In [None]:
import torch
import torch.nn as nn
import quartz
import matplotlib.pyplot as plt
import sinabs.layers as sl
import numpy as np
np.set_printoptions(suppress=True)
# Fix the RNG so the random inputs and the freshly initialised Linear/Conv2d
# weights further down are reproducible across Restart & Run All.
torch.manual_seed(0)

In [None]:
# Quantisation grid: t_max steps across [0, 1], so adjacent levels are q apart.
t_max = 4
q = 1 / t_max
# Uniform random samples in [0, 1) used to inspect the quantisation.
inputs = torch.rand(100_000)

In [None]:
# Histogram of the inputs snapped to the nearest multiple of q.
# torch.round is used directly instead of relying on np.round's fallback
# conversion of a torch tensor. Bin edges sit halfway between levels, so each of
# the t_max + 1 levels gets its own bin regardless of the data's min/max.
level_edges = (np.arange(t_max + 2) - 0.5) * q
np.histogram((q * torch.round(inputs / q)).numpy(), bins=level_edges)


In [None]:
# Same quantisation written as round(x * t_max) / t_max.
q_inputs = (inputs * t_max).round() / t_max
# bins=t_max would give t_max equal-width bins over [0, 1] and put the two top
# levels (0.75 and 1.0 for t_max = 4) into the same closed last bin. Use t_max + 1
# bins centred on the levels k / t_max instead.
np.histogram(q_inputs.numpy(), bins=(np.arange(t_max + 2) - 0.5) / t_max)

In [None]:
# Distinct quantisation levels and how many samples landed on each. This is more
# informative than the truncated repr of all 100000 quantised values.
q_inputs.unique(return_counts=True)

In [None]:
# Count raw inputs per interval [k / t_max, (k + 1) / t_max), i.e. floor-based
# binning, unlike the round-to-nearest cells above. The edges are derived from
# t_max; the literal [0.25, 0.5, 0.75, 1] only matched t_max = 4.
np.bincount(np.digitize(inputs.numpy(), bins=np.arange(1, t_max + 1) / t_max))

In [None]:
# Network experiment settings. Note: t_max is redefined here, replacing the
# value used in the quantisation cells above.
t_max, batch_size = 5, 1

class Net(nn.Module):
    """Minimal single-neuron test network.

    forward() applies, in order: ``sl.FlattenTime``, a bias-free 1->1
    ``nn.Linear``, a ``quartz.IFSqueeze`` layer (membrane potential recording
    enabled), and ``sl.UnflattenTime``.
    NOTE(review): the sinabs layers presumably merge/split the batch and time
    dimensions around the per-frame layers — confirm against sinabs docs.
    """

    def __init__(
        self,
        t_max: int,
        batch_size: int,
    ):
        super().__init__()

        # Registration order is kept: flatten_time, unflatten_time, weight1, layer1.
        self.flatten_time = sl.FlattenTime()
        self.unflatten_time = sl.UnflattenTime(batch_size=batch_size)
        self.weight1 = nn.Linear(1, 1, bias=False)
        self.layer1 = quartz.IFSqueeze(
            t_max=t_max,
            record_v_mem=True,
            rectification=False,
            batch_size=batch_size,
        )

    def forward(self, data: torch.Tensor):
        """Pass ``data`` through flatten -> weight1 -> layer1 -> unflatten."""
        frames = self.flatten_time(data)
        spikes = self.layer1(self.weight1(frames))
        return self.unflatten_time(spikes)

    def __len__(self):
        # Number of spiking layers in the network; this notebook passes
        # len(net) to encode_inputs as n_layers.
        return 1

    def reset_states(self):
        """Reset the state held by the IFSqueeze layer."""
        self.layer1.reset_states()

net = Net(t_max=t_max, batch_size=batch_size)
# Set the single synaptic weight to -0.5. nn.init.constant_ fills in place under
# torch.no_grad(), which replaces mutating `.weight.data` directly.
nn.init.constant_(net.weight1.weight, -0.5)

In [None]:
# Constant all-ones input for the single-neuron network.
static_data = torch.ones(1, 1, 1, 1)
# NOTE(review): the Lift cell further down passes n_layers=len(net) to
# encode_inputs, but this call relies on its default — confirm that default
# matches the one-layer Net.
input_data = quartz.utils.encode_inputs(static_data, t_max=t_max)
# Dimension 1 of the encoded input is treated as the time axis.
n_time_steps = input_data.shape[1]

In [None]:
# Clear state left over from a previous run, then simulate. No gradients are
# needed for this inspection, so skip building the autograd graph.
net.reset_states()
with torch.no_grad():
    output = net(input_data)

In [None]:
# input_data

In [None]:
fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
# Recorded membrane potential of the single neuron over time.
v_mem_trace = net.layer1.v_mem_recorded[0, :, 0, 0, 0].flatten().detach().numpy()
ax1.plot(v_mem_trace, label='v_mem', drawstyle='steps-mid')
# Threshold drawn at t_max; axhline needs an explicit colour because it does not
# advance the colour cycle.
ax1.axhline(t_max, color='k', linestyle='--', label='spike threshold')
# The output is scaled by t_max so it plots on the same scale as v_mem.
ax1.plot(t_max * output[0, :, 0, 0, 0].flatten().detach().numpy(), label='output * t_max', drawstyle='steps-mid')
ax1.set(xlabel='time step', ylabel='value', title='IFSqueeze membrane potential and output')
ax1.legend();

In [None]:
# Presumably decodes the network's output spike trains back to static values
# using the same t_max as the encoding.
# NOTE(review): encoding uses quartz.utils.encode_inputs while this calls
# quartz.decode_outputs — confirm the two are the intended encode/decode pair.
quartz.decode_outputs(output, t_max)

In [None]:
# Membrane potential at the last recorded time step. Index -1 works for any
# run length, whereas a fixed step such as 999 raises IndexError on short runs.
net.layer1.v_mem_recorded[0, -1, 0, 0, 0]

In [None]:
# Number of time steps in the encoded input (dimension 1 of input_data as
# produced by quartz.utils.encode_inputs).
n_time_steps

In [None]:
# Batch of 2 samples with 5 channels of 3x3 values, all equal to 0.25.
static_data = torch.full((2, 5, 3, 3), 0.25)
input_data = quartz.utils.encode_inputs(static_data, t_max=t_max, n_layers=len(net))
n_time_steps = input_data.shape[1]

# One Conv2d, used directly and wrapped in quartz.Lift. The next cells compare
# the Lift output against an explicit flatten/unflatten over the time axis.
layer = nn.Conv2d(5, 1, kernel_size=3, bias=False)
layer2 = quartz.Lift(layer)

In [None]:
# Reference: merge dims 0-1, apply the conv per frame, and split back into
# (-1, n_time_steps). This overwrites `output` from the network simulation above.
# Gradients are not needed for the comparison.
with torch.no_grad():
    output = layer(input_data.flatten(0, 1)).unflatten(0, (-1, n_time_steps))
    output2 = layer2(input_data)

In [None]:
# torch.equal requires identical shapes and values. `(a == b).all()` would
# broadcast mismatched shapes and could report True for outputs of different shape.
torch.equal(output, output2)

In [None]:
# Show the reference output one time step at a time (the first t_max steps).
for step in range(t_max):
    frame = output[:, step]
    print(frame)

In [None]:
# Same per-time-step view for the quartz.Lift output.
for t in range(t_max):
    print(output2[:, t, ...])

In [None]:
import torch
import matplotlib.pyplot as plt
# np.set_printoptions(True)

In [None]:
# Re-display of the quantised inputs (also inspected right after they were
# computed near the top of the notebook).
q_inputs