In [None]:
import torch
import torch.nn as nn
import quartz
import matplotlib.pyplot as plt
import sinabs.layers as sl

In [None]:
# Experiment configuration.
# - t_max: passed to Net (IFSqueeze layer), to quartz's encode/decode helpers,
#   and drawn as the spike-threshold line in the plot.
# - batch_size: passed to Net (time unflattening and the IFSqueeze layer).
t_max, batch_size = 16, 1

class Net(nn.Module):
    """Minimal one-layer network: a bias-free ``nn.Linear(3, 1)`` feeding a quartz ``IFSqueeze`` layer.

    The input is passed through ``sl.FlattenTime`` before the per-sample layers and
    through ``sl.UnflattenTime`` afterwards (presumably folding the time axis into
    the batch axis and back — TODO confirm against the sinabs version in use).

    Args:
        t_max: forwarded to ``quartz.IFSqueeze``.
        batch_size: forwarded to ``sl.UnflattenTime`` and ``quartz.IFSqueeze``.
    """

    def __init__(
        self,
        t_max: int,
        batch_size: int,
    ):
        super().__init__()
        # Keep this registration order: it fixes the module/state_dict ordering and
        # the order in which random weight initialisation happens.
        self.flatten_time = sl.FlattenTime()
        self.unflatten_time = sl.UnflattenTime(batch_size=batch_size)
        self.weight1 = nn.Linear(3, 1, bias=False)
        # record_v_mem=True makes v_mem available as layer1.v_mem_recorded for plotting.
        self.layer1 = quartz.IFSqueeze(
            t_max=t_max,
            index=0,
            record_v_mem=True,
            rectification=True,
            batch_size=batch_size,
        )

    def forward(self, data: torch.Tensor):
        """Apply flatten-time -> linear -> IFSqueeze -> unflatten-time, in that order."""
        stages = (self.flatten_time, self.weight1, self.layer1, self.unflatten_time)
        for stage in stages:
            data = stage(data)
        return data

    def __len__(self):
        """Return the hard-coded layer count (1), used as ``n_layers`` for input encoding."""
        return 1

    def reset_states(self):
        """Reset the internal state of the IFSqueeze layer before a new run."""
        self.layer1.reset_states()

net = Net(t_max=t_max, batch_size=batch_size)
# Overwrite the randomly initialised Linear weights with a constant 0.5 so the run
# is deterministic. nn.init.constant_ fills the parameter in place under no_grad,
# instead of mutating `.data` (which bypasses autograd's bookkeeping).
nn.init.constant_(net.weight1.weight, 0.5)

In [None]:
# Constant all-ones input. nn.Linear acts on the last dimension, so its size is
# taken from net.weight1.in_features (3) rather than hard-coded; the previous
# (1, 1, 2, 2) tensor was dead code, overwritten before use.
static_data = torch.ones(1, 1, 1, net.weight1.in_features)
input_data = quartz.utils.encode_inputs(static_data, t_max=t_max, n_layers=len(net))
# NOTE(review): assumes encode_inputs returns (batch, time, ...) so dim 1 is time — confirm.
n_time_steps = input_data.shape[1]

In [None]:
# Reset layer1's state first (Net.reset_states -> layer1.reset_states) so re-running
# this cell does not carry state over from a previous run; presumably this also
# clears the recorded v_mem — TODO confirm in quartz.IFSqueeze.
net.reset_states()
# Forward pass over the encoded input.
output = net(input_data)

In [None]:
fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
# Recorded v_mem of layer1. NOTE(review): with batch_size=1 and a single output
# neuron the flattened index presumably equals the time step — confirm.
ax1.plot(net.layer1.v_mem_recorded.flatten().detach().numpy(), color='C0', label='v_mem')
# Constant reference line at t_max, drawn across the whole x-range. Colours are set
# explicitly so each of the three lines keeps its own colour.
ax1.axhline(t_max, color='C1', linestyle='--', label='spike threshold')
# Scale the output by t_max so unit-valued spikes are drawn at the threshold height.
ax1.plot(t_max * output.flatten().detach().numpy(), color='C2', label='output')
ax1.set(title='layer1: v_mem and output', xlabel='time step', ylabel='value')
ax1.legend()
# Render the figure instead of leaking the Legend object's repr as cell output.
plt.show()

In [None]:
# Decode the network output using the same t_max as the encoder; presumably the
# inverse of quartz.utils.encode_inputs — TODO confirm.
# NOTE(review): encoding uses quartz.utils.encode_inputs but decoding uses the
# top-level quartz.decode_outputs; verify both resolve to the intended helpers.
quartz.decode_outputs(output, t_max)

In [None]:
# Tuple of index tensors (one per dimension) locating the positive entries of
# `output` (presumably the spikes); same result as torch.where(output > 0).
output.gt(0).nonzero(as_tuple=True)

In [None]:
# Display the number of time steps in the encoded input (input_data.shape[1]).
n_time_steps