In [36]:
from metavision_core.event_io import EventsIterator, load_events
import tonic
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import random
import statistics
from IPython.utils import io
from tqdm.notebook import tqdm
# import sdl2.ext

import models
import utils

In [37]:
# %matplotlib inline
# %matplotlib ipympl
# %matplotlib tk
plt.rc('animation', html='jshtml')
plt.rcParams["image.cmap"] = "bone"

In [38]:
in_data = tonic.datasets.NMNIST(save_to="../data", train=True) # (34, 34, 2) 3
# in_data = tonic.datasets.POKERDVS(save_to="../data",train=True) # (35, 35, 2) 3
# in_data = tonic.datasets.DVSGesture(save_to="../data",train=True) # (128, 128, 2) 2
# in_data = tonic.datasets.ASLDVS(save_to="../data", train=True) # (240, 180, 2) 3

in_data.sensor_size

(34, 34, 2)

In [39]:
noise = tonic.transforms.UniformNoise(
    sensor_size=in_data.sensor_size[:-1]+(1,),
    n=int(4172*2),
)
to_frame = tonic.transforms.ToFrame(in_data.sensor_size[:-1]+(1,), time_window=1000)
# to_frame = tonic.transforms.ToFrame(in_data.sensor_size[:-1]+(1,), n_time_bins=100)

In [40]:
# net = models.Conv_FastLIF(models.kernels.gaussian3, 0.95,1.5)#.cuda(0)
net = models.LIF()#.cuda(0)

In [41]:
list(net.parameters())

[Parameter containing:
 tensor(200., requires_grad=True),
 Parameter containing:
 tensor(100., requires_grad=True)]

In [42]:
optimizer = torch.optim.Adam(net.parameters(), lr=1)
for epoch in range(10):
    random.seed(0)
    ids = random.sample(range(len(in_data)), 10)
    losses = []
    for id in tqdm(ids):
        events, label = in_data[id]
        events = events.squeeze()
        events = np.array(list(filter(lambda ev: ev[3] == 0, events)))
        # print(events)

        true_frames = torch.tensor(to_frame(events)).float()  # .cuda(0)
        in_frames = torch.tensor(to_frame(noise(events))).float()  # .cuda(0)

        out_frames = net(in_frames)

        out_frames = utils.frame_merge(out_frames, 10).clamp_max(1).cpu()
        true_frames = utils.frame_merge(true_frames, 10).clamp_max(1).cpu()

        loss = nn.functional.l1_loss(out_frames, true_frames)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().item())
    print(statistics.mean(losses))




  0%|          | 0/10 [00:00<?, ?it/s]

0.027685041539371014


  0%|          | 0/10 [00:00<?, ?it/s]

0.027156365849077702


  0%|          | 0/10 [00:00<?, ?it/s]

0.026854193955659866


  0%|          | 0/10 [00:00<?, ?it/s]

0.02716929167509079


  0%|          | 0/10 [00:00<?, ?it/s]

0.02757656220346689


  0%|          | 0/10 [00:00<?, ?it/s]

0.027415085397660733


  0%|          | 0/10 [00:00<?, ?it/s]

0.02763900477439165


  0%|          | 0/10 [00:00<?, ?it/s]

0.027470071241259574


  0%|          | 0/10 [00:00<?, ?it/s]

0.02770850732922554


  0%|          | 0/10 [00:00<?, ?it/s]

0.027743407897651194


In [43]:
random.seed(0)
ids = random.sample(range(len(in_data)), 100)
noisy_loss = []
denoisy_loss = []
for id in tqdm(ids):
    events, label = in_data[id]
    events = events.squeeze()
    events = np.array(list(filter(lambda ev: ev[3] == 0, events)))
    # print(events)

    true_frames = torch.tensor(to_frame(events)).float()  # .cuda(0)
    in_frames = torch.tensor(to_frame(noise(events))).float()  # .cuda(0)

    out_frames = net(in_frames)

    in_frames = utils.frame_merge(in_frames, 10).clamp_max(1).cpu()
    out_frames = utils.frame_merge(out_frames, 10).clamp_max(1).cpu()
    true_frames = utils.frame_merge(true_frames, 10).clamp_max(1).cpu()
    noisy_loss.append(loss(in_frames, true_frames).detach().item())
    denoisy_loss.append(loss(out_frames, true_frames).item())

display(statistics.mean(noisy_loss))
display(statistics.mean(denoisy_loss))

  0%|          | 0/100 [00:00<?, ?it/s]

TypeError: 'Tensor' object is not callable

In [None]:
with io.capture_output() as captured:
    anim = tonic.utils.plot_animation(np.concatenate((in_frames,out_frames,true_frames),axis=3))
display(anim)

In [None]:
anim.save("../img/lif-temp.gif",dpi=300)