In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Callable
from enum import Enum
import tonic
from torch import optim

import tonic.transforms as tonic_transforms
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Subset, TensorDataset


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class Sampler(torch.nn.Module):
    '''Base class for samplers.

    The base class only records ``window_size``; subclasses add ``forward``.

    Args:
        window_size: stored unchanged as ``self.window_size``.
    '''
    def __init__(self, window_size) -> None:
        super().__init__()
        self.window_size = window_size

# class LearnableSampler(Sampler):
#     def __init__(self, window_size: int) -> None:
#         super(LearnableSampler, self).__init__(window_size=window_size)
#         self.linear = torch.nn.Linear(1, self.window_size)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = x.unsqueeze(-1)
#         x = self.linear(x)
#         return x

class HU(torch.nn.Module):
    '''Hybrid unit: an optional chain of processing stages applied in a fixed order.

    Stages (each may be None, in which case it is skipped):
    ``window_set -> window_conv -> sampler -> non_linear -> precision_convert``.
    Only ``non_linear`` is set from the constructor; the other stages start as
    None and are assigned by subclasses.

    Args:
        window_size: stored as ``self.window_size``.
        non_linear: optional module applied after the sampler.
    '''
    # Attribute names of the stages, in the order ``forward`` applies them.
    _STAGE_ORDER = ('window_set', 'window_conv', 'sampler',
                    'non_linear', 'precision_convert')

    def __init__(self, window_size: int, non_linear: torch.nn.Module = None) -> None:
        super(HU, self).__init__()
        self.window_size = window_size
        self.window_set = None
        self.window_conv = None
        self.sampler = None
        self.non_linear = non_linear
        self.precision_convert = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Pass ``x`` through every configured (non-None) stage in order.'''
        for stage_name in self._STAGE_ORDER:
            # Look each stage up just before use, so it is read at the same point as before.
            stage = getattr(self, stage_name)
            if stage is not None:
                x = stage(x)
        return x


class PrecisionConvert(torch.nn.Module):
    '''Module wrapper around an arbitrary tensor-to-tensor ``converter`` callable.

    Args:
        converter: callable applied to the input in ``forward``.
    '''
    def __init__(self, converter: Callable[[torch.Tensor], torch.Tensor]) -> None:
        super().__init__()
        self.converter = converter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Return ``converter(x)``.'''
        return self.converter(x)

class A2SPrecisionConvert(PrecisionConvert):
    '''``PrecisionConvert`` subtype installed by ``A2SHU`` as its output stage.

    Adds no behaviour; it exists so the A2S conversion stage has its own type.

    Args:
        converter: callable forwarded unchanged to ``PrecisionConvert``.
    '''
    def __init__(self, converter: Callable[[torch.Tensor], torch.Tensor]) -> None:
        super().__init__(converter=converter)


class A2SHU(HU):
    '''Analog-to-spike hybrid unit.

    An ``HU`` whose final stage is an ``A2SPrecisionConvert`` built from
    ``converter``.

    Args:
        window_size: forwarded to ``HU``.
        converter: callable wrapped in ``A2SPrecisionConvert``.
        non_linear: optional module forwarded to ``HU``.
    '''
    def __init__(self, window_size: int, converter: Callable[[torch.Tensor], torch.Tensor],
                 non_linear: torch.nn.Module = None) -> None:
        super().__init__(window_size, non_linear)
        self.precision_convert = A2SPrecisionConvert(converter=converter)

    def check(self):
        '''Assert the unit has the A2S layout.

        The layout is: no windowing stages, a sampler, and a precision converter.

        Raises:
            AssertionError: if the layout is violated (when assertions are enabled).
        '''
        assert self.window_set is None
        assert self.window_conv is None
        assert self.sampler is not None
        assert self.precision_convert is not None

class A2SLearnableCoding(A2SHU):
    '''``A2SHU`` whose sampler is a ``LearnableSampler``.

    The unit layout is validated with ``check()`` at the end of construction.

    Args:
        window_size: forwarded to ``A2SHU`` and then to ``LearnableSampler``.
        converter: callable for the precision-conversion stage.
        non_linear: optional module applied after the sampler.
    '''
    def __init__(self, window_size: int, converter: Callable[[torch.Tensor], torch.Tensor],
                 non_linear: torch.nn.Module = None) -> None:
        super().__init__(window_size, converter, non_linear)
        sampler_window = self.window_size
        self.sampler = LearnableSampler(window_size=sampler_window)
        # Fail fast if the stage layout is not the one A2S expects.
        self.check()

class HardUpdateAfterSpike(torch.nn.Module):
    '''Hard reset of the membrane potential after a spike.

    Args:
        value: potential written where ``spike`` is 1.
    '''
    def __init__(self, value: float) -> None:
        super().__init__()
        self.value = value

    def forward(self, x: torch.Tensor, spike: torch.Tensor) -> torch.Tensor:
        '''Return ``value`` where spike == 1 and ``x`` where spike == 0.

        For non-binary ``spike`` the result is the linear blend
        ``spike * value + (1 - spike) * x``.
        '''
        reset_part = spike * self.value
        keep_part = (1 - spike) * x
        return reset_part + keep_part

class Rectangle(torch.autograd.Function):
    '''Heaviside spike function with a rectangular surrogate gradient.

    Forward: ``(v3 > v_th).float()``.
    Backward: the incoming gradient is scaled by ``1 / window_size`` wherever
    ``|v3 - v_th| < window_size / 2`` and zeroed elsewhere. ``v_th`` receives
    no gradient.

    ``window_size`` is a class attribute shared by every user; ``IF.__init__``
    overwrites it. The value is captured at forward time, so reassigning it
    afterwards cannot change the gradient of a graph that was already built.
    '''
    window_size = 1

    @staticmethod
    def forward(ctx, v3: torch.Tensor, v_th) -> torch.Tensor:
        ctx.save_for_backward(v3, torch.as_tensor(v_th, device=v3.device))
        # Snapshot the shared window so backward uses the value active during forward.
        ctx.window_size = Rectangle.window_size
        out = (v3 > v_th).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        v3, v_th = ctx.saved_tensors
        window = ctx.window_size
        mask = torch.abs(v3 - v_th) < window / 2
        # One gradient per forward input: (v3, v_th); the threshold is not trained.
        return grad_output * mask.float() * 1 / window, None

    @staticmethod
    def symbolic(g: torch._C.Graph, input: torch._C.Value, v_th0: float) -> torch._C.Value:
        # ONNX export hook: emits a custom ``snn::RectangleFire`` node.
        return g.op("snn::RectangleFire", input, v_th0_f=v_th0)


class Accumulate(torch.nn.Module):
    '''Integrate the input current into the membrane potential.

    Args:
        v_init: fill value for the potential when ``forward`` receives ``v=None``.
    '''
    def __init__(self, v_init) -> None:
        super().__init__()
        self.v_init = v_init

    def forward(self, u_in, v=None) -> torch.Tensor:
        '''Return ``u_in + v``; a missing ``v`` is a tensor like ``u_in`` filled with ``v_init``.'''
        potential = torch.full_like(u_in, self.v_init) if v is None else v
        return u_in + potential

class FireWithConstantThreshold(torch.nn.Module):
    '''Emit spikes by calling ``surrogate_function.apply(v, v_th)`` with a fixed threshold.

    Args:
        surrogate_function: object exposing ``apply(v, v_th)``
            (e.g. a ``torch.autograd.Function`` subclass).
        v_th: constant firing threshold.
    '''
    def __init__(self, surrogate_function, v_th) -> None:
        super().__init__()
        self.surrogate_function = surrogate_function
        self.v_th = v_th

    def forward(self, v) -> torch.Tensor:
        '''Return the spike tensor produced by the surrogate for potential ``v``.'''
        return self.surrogate_function.apply(v, self.v_th)

class IF(torch.nn.Module):
    '''Integrate-and-Fire neuron.

    Each call does three things:
    - adds ``u_in`` to the membrane potential;
    - fires where the updated potential exceeds ``v_th`` (``Rectangle`` surrogate);
    - hard-resets the firing positions to ``v_reset``.

    Args:
        v_th: firing threshold.
        v_reset: potential assigned to positions that spiked.
        v_init: starting potential when ``forward`` gets ``v=None``;
            defaults to ``v_reset``.
        window_size: surrogate-gradient window width. NOTE: this is written to
            the class attribute ``Rectangle.window_size``, so it affects every
            neuron using ``Rectangle``, not only this instance.
        debug: when True, ``forward`` prints the shape of every intermediate
            tensor (the previous unconditional behaviour). Off by default so
            unrolling over many time steps does not flood the output.
    '''
    def __init__(self, v_th, v_reset, v_init=None, window_size=1, debug=False):
        super(IF, self).__init__()
        self.debug = debug
        self.reset = HardUpdateAfterSpike(value=v_reset)
        self.accumulate = Accumulate(
            v_init=self.reset.value if v_init is None else v_init)
        Rectangle.window_size = window_size
        self.fire = FireWithConstantThreshold(
            surrogate_function=Rectangle, v_th=v_th)

    def _log(self, *parts) -> None:
        # Debug-only shape tracing; a no-op unless ``debug`` was enabled.
        if self.debug:
            print(*parts)

    def forward(self, u_in: torch.Tensor, v: torch.Tensor = None):
        '''Run one time step.

        Args:
            u_in: input current for this step.
            v: membrane potential from the previous step, or None to start
               from ``v_init``.

        Returns:
            ``(spike, v)``: the 0/1 float spike tensor and the post-reset potential.
        '''
        self._log("[IF] input u_in shape:", u_in.shape)
        if v is not None:
            self._log("[IF] input v shape:", v.shape)

        v_update = self.accumulate(u_in, v)
        self._log("[IF] v_update shape:", v_update.shape)

        spike = self.fire(v_update)
        self._log("[IF] spike shape:", spike.shape)

        v = self.reset(v_update, spike)
        self._log("[IF] new v shape:", v.shape)

        return spike, v


class Leaky(torch.nn.Module):
    '''Leak/bias update applied to the membrane potential.

    Args:
        alpha: multiplicative decay, used only when ``adpt_en`` is True;
            must satisfy ``alpha <= 1``.
        beta: additive bias, always applied.
        adpt_en: if True compute ``alpha * x + beta``, otherwise ``x + beta``.

    Raises:
        AssertionError: if ``alpha > 1`` (when assertions are enabled).
    '''
    def __init__(self, alpha, beta, adpt_en=True):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        assert alpha <= 1
        self.adpt_en = adpt_en

    def forward(self, x: torch.Tensor):
        scaled = self.alpha * x if self.adpt_en else x
        return scaled + self.beta

class LIF(torch.nn.Module):
    '''Leaky-Integrate-and-Fire neuron: an ``IF`` step followed by a ``Leaky`` update.

    Args:
        v_th: firing threshold (stored as ``if_node.fire.v_th``).
        v_leaky_alpha: decay factor (``v_leaky.alpha``); must be <= 1.
        v_leaky_beta: additive bias (``v_leaky.beta``).
        v_reset: post-spike potential (``if_node.reset.value``), default 0.
        v_leaky_adpt_en: if True the leak is ``alpha * v + beta``, otherwise
            ``v + beta`` (``v_leaky.adpt_en``), default False.
        v_init: starting potential when no state is passed; defaults to ``v_reset``.
        window_size: ``Rectangle`` surrogate window width, default 1.
    '''
    def __init__(self, v_th, v_leaky_alpha, v_leaky_beta, v_reset=0, v_leaky_adpt_en=False, v_init=None, window_size=1):
        super().__init__()
        self.if_node = IF(v_th=v_th, v_reset=v_reset,
                          v_init=v_init, window_size=window_size)
        self.v_leaky = Leaky(alpha=v_leaky_alpha,
                             beta=v_leaky_beta, adpt_en=v_leaky_adpt_en)

    def forward(self, u_in: torch.Tensor, v=None):
        '''Run one step; returns ``(spike, leaked potential)``.'''
        spike, v_after_reset = self.if_node(u_in, v)
        return spike, self.v_leaky(v_after_reset)


# class OutputRateCoding(torch.nn.Module):
#     def __init__(self, dim=0) -> None:
#         super().__init__()
#         self.dim = dim

#     def forward(self, x: torch.Tensor):
#         x = torch.stack(x, dim=0)
#         return x.mean(dim=self.dim)

class OutputRateCoding(nn.Module):
    '''Decode spike trains into firing rates by averaging over time.

    Accepts either an already-stacked tensor or a list/tuple of per-step
    tensors. ``Model.multi_step_forward`` returns a list, which is what
    ``A2SModel.forward`` hands to its ``encode`` stage. A sequence is stacked
    along a new leading dimension before averaging.

    Args:
        dim: dimension to average over (default 0, the stacked time axis).
    '''
    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            # A plain list has no ``.mean``; stack the per-step outputs first.
            x = torch.stack(list(x), dim=0)
        # e.g. (T, B, 10) -> (B, 10) for dim=0
        return x.mean(dim=self.dim)

# class LearnableSampler(Sampler):
#     def __init__(self, window_size: int) -> None:
#         super(LearnableSampler, self).__init__(window_size=window_size)
#         self.linear = torch.nn.Linear(1, self.window_size)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = x.unsqueeze(-1)
#         x = self.linear(x)
#         return x


class LearnableSampler(Sampler):
    '''Learnable linear projection over the feature (last) axis of the input.

    Args:
        window_size: stored on the instance via ``Sampler``; the projection
            itself does not use it.
        features: size of the input's last dimension and of the output's;
            defaults to 576, the value this sampler was previously hard-coded to.
    '''
    def __init__(self, window_size: int, features: int = 576) -> None:
        super().__init__(window_size=window_size)
        self.linear = torch.nn.Linear(features, features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear maps the last dimension and keeps all leading ones, so a
        # (T, B, F) input needs no manual flatten/unflatten. This also accepts
        # non-contiguous tensors (where ``.view`` raised) and inputs of other ranks.
        return self.linear(x)

class InputMode(Enum):
    '''How ``Model.multi_step_forward`` feeds its input across time steps.'''
    # The same input ``x`` is passed to ``forward`` at every step.
    STATIC = 'static'
    # Step ``i`` receives ``x[i]`` (``x`` is indexed along its first axis).
    SEQUENTIAL = 'sequential'


class Model(torch.nn.Module):
    '''Base class for stateful models unrolled over ``time_interval`` steps.

    Subclasses implement ``forward(x, *state)`` returning
    ``(output, *new_state)``; the new state is fed into the next step.

    Args:
        time_interval: number of steps ``multi_step_forward`` runs.
        mode: ``InputMode`` selecting how the input is presented per step.
    '''
    def __init__(self, time_interval: int, mode: InputMode) -> None:
        super(Model, self).__init__()
        self.time_interval = time_interval
        self.mode = mode

    def multi_step_forward(self, x, *args):
        '''Run ``forward`` ``time_interval`` times, threading state through ``args``.

        Returns:
            List with one ``output`` per step.

        Raises:
            ValueError: if ``self.mode`` is not a supported ``InputMode``.
        '''
        if self.mode == InputMode.STATIC:
            input_at = lambda step: x
        elif self.mode == InputMode.SEQUENTIAL:
            input_at = lambda step: x[step]
        else:
            raise ValueError('Unsupported input mode')

        per_step_outputs = []
        state = args
        for step in range(self.time_interval):
            output, *state = self.forward(input_at(step), *state)
            per_step_outputs.append(output)
        return per_step_outputs



class A2SModel(torch.nn.Module):
    '''ANN-to-SNN pipeline: ``ann -> a2shu -> reshape -> snn (unrolled) -> encode``.

    The four stage attributes start as None and must be assigned before the
    model is called.

    Args:
        T: stored as ``self.T``; ``forward`` does not read it.
    '''
    def __init__(self, T) -> None:
        super().__init__()
        self.T = T
        self.ann: torch.nn.Module = None
        self.a2shu: A2SHU = None
        self.snn: Model = None
        self.encode: torch.nn.Module = None

    def reshape(self, x: torch.Tensor):
        '''Layout hook between ``a2shu`` and ``snn``; identity by default.'''
        return x

    def forward(self, x, *args):
        ann_features = self.ann(x)
        converted = self.a2shu(ann_features)
        # Presumably [N, C, H, W] -> [N, C, H, W, T] -> [T, ...]; the actual
        # layout depends on the assigned stages and on ``reshape``.
        time_major = self.reshape(converted)
        step_outputs = self.snn.multi_step_forward(time_major, *args)
        return self.encode(step_outputs)

In [6]:
class HybridLayer(nn.Module):
    '''ANN layer followed by a LIF spiking layer.

    Args:
        in_features: input size (input channels when ``use_conv``).
        out_features: output size (output channels when ``use_conv``).
        use_conv: use a 3x3, padding-1 ``Conv2d`` instead of a ``Linear`` layer.
        v_th, v_leaky_alpha, v_leaky_beta, v_reset: forwarded to ``LIF``.
            The defaults are the values this layer previously hard-coded.
    '''
    def __init__(self, in_features, out_features, use_conv=False,
                 v_th=0.1, v_leaky_alpha=0.5, v_leaky_beta=0, v_reset=0):
        super().__init__()
        self.use_conv = use_conv

        if use_conv:
            self.ann_layer = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        else:
            self.ann_layer = nn.Linear(in_features, out_features)

        self.snn_layer = LIF(v_th=v_th, v_leaky_alpha=v_leaky_alpha,
                             v_leaky_beta=v_leaky_beta, v_reset=v_reset)

    def forward(self, x, v=None):
        '''One step: ANN transform, then the LIF neuron; returns ``(spike, v)``.'''
        ann_out = self.ann_layer(x)
        spike_out, v = self.snn_layer(ann_out, v)  # SNN processes ANN output
        return spike_out, v
