In [1]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils

# --- Spiking-neuron hyperparameters (SEW ResNet paper, Appendix A.1) ---
ALPHA: float = 2.0        # slope of the ArcTan surrogate gradient
V_TH: float = 1.0         # membrane firing threshold shared by every spiking layer
V_RESET: float = 0.0      # reset potential; NOTE: not passed to any layer below -- the
                          # neurons use reset_mechanism='zero', which matches this value
LEARN_BETA: bool = True   # learnable decay factor, used here as a stand-in for PLIF's
                          # learnable membrane time constant

# ArcTan surrogate used in place of the Heaviside derivative during backprop.
spike_grad = surrogate.atan(alpha=ALPHA)

class SEWBlock(nn.Module):
    """
    Spike-Element-Wise (SEW) residual block using the ADD connect function.

    Residual path: Conv3x3 -> BN -> spiking neuron -> Conv3x3 -> BN -> spiking neuron.
    The block output is g(A, S) = A + S, where A is the residual-path spike output and
    S is the block input (see Section 3.3, Eq. 9 and Fig. 1(c) of the SEW ResNet paper).

    Args:
        channels: Number of input and output channels. The convolutions keep the
            width and (with padding=1) the spatial size unchanged.
    """

    def __init__(self, channels):
        super().__init__()

        def make_conv():
            # Width- and size-preserving 3x3 convolution. BN follows it, so no bias.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        def make_neuron():
            # Leaky neuron with a learnable decay, hard reset to zero and ArcTan surrogate.
            # init_hidden=True makes the layer keep its own membrane state between calls.
            return snn.Leaky(
                beta=0.5,
                threshold=V_TH,
                reset_mechanism='zero',
                spike_grad=spike_grad,
                init_hidden=True,
                learn_beta=LEARN_BETA,
            )

        # Build and register the layers in the same order as the data flow:
        # conv1, bn1, sn1, conv2, bn2, sn2.
        self.conv1 = make_conv()
        self.bn1 = nn.BatchNorm2d(channels)
        self.sn1 = make_neuron()

        self.conv2 = make_conv()
        self.bn2 = nn.BatchNorm2d(channels)
        self.sn2 = make_neuron()

    def forward(self, x):
        """
        Run one time step through the block.

        Args:
            x: Input S for this time step (presumably shaped (B, channels, H, W)).

        Returns:
            A + S: the residual-path spikes added element-wise to the input.
        """
        identity = x

        stages = (
            (self.conv1, self.bn1, self.sn1),
            (self.conv2, self.bn2, self.sn2),
        )
        spikes = x
        for conv, norm, neuron in stages:
            spikes = neuron(norm(conv(spikes)))

        # ADD connect function. Per the paper, stacking k blocks keeps outputs bounded
        # by k+1 instead of letting them grow without bound.
        return spikes + identity

class Net7B(nn.Module):
    """
    7B-Net architecture for DVS Gesture (SEW ResNet paper).

    Structure: c32k3s1-BN-PLIF-{SEW Block-MPk2s2}*7-FC11

    All constructor arguments are optional. The defaults reproduce the original
    7B-Net exactly: same submodule names, same registration order and same
    parameter count.

    Args:
        in_channels: Channels per input frame (2 for on/off event polarities).
        channels: Width of the stem and of every SEW block.
        num_blocks: Number of {SEW Block - MaxPool(2, 2)} stages.
        num_classes: Number of output classes (11 for DVS Gesture).
        input_size: Height/width of the square input frames. It is only used to size
            the FC head. Each MaxPool(2, 2) floors the spatial size by half, so the
            head sees ``input_size // 2**num_blocks`` pixels per side (1 for 128 / 2**7).

    Raises:
        ValueError: If ``num_blocks`` poolings would shrink ``input_size`` below 1 pixel.
    """

    def __init__(self, in_channels=2, channels=32, num_blocks=7, num_classes=11,
                 input_size=128):
        super().__init__()

        # Validate the head geometry up front so a bad configuration fails at
        # construction time, not on the first forward pass.
        spatial = input_size // (2 ** num_blocks)
        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {num_blocks} 2x2 max-pools"
            )

        # 1. Stem: c32k3s1-BN-PLIF
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.sn1 = snn.Leaky(beta=0.5, threshold=V_TH, reset_mechanism='zero',
                             spike_grad=spike_grad, init_hidden=True, learn_beta=LEARN_BETA)

        # 2. Body: {SEW Block - MPk2s2} * num_blocks (7 in the paper)
        self.layers = nn.ModuleList()
        for _ in range(num_blocks):
            self.layers.append(SEWBlock(channels))
            self.layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # 3. Head: flatten the pooled feature map and apply FC11.
        # Defaults: 32 channels * 1 * 1 = 32 input features.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(channels * spatial * spatial, num_classes)

        # Readout: a leaky integrator over the FC outputs. The threshold is unreachable
        # (1e9) and reset is disabled, so the membrane potential is a leaky running sum
        # (decay beta, learnable) rather than a spike train.
        self.final_integ = snn.Leaky(beta=0.5, threshold=1e9, reset_mechanism='none',
                                     spike_grad=spike_grad, init_hidden=True,
                                     learn_beta=LEARN_BETA, output=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unroll the network over the leading time dimension.

        Args:
            x: Input of shape (T, B, C, H, W), with T >= 1.

        Returns:
            Readout membrane potentials stacked over time, shape (T, B, num_classes),
            e.g. for a per-step CrossEntropyLoss.

        Raises:
            ValueError: If ``x`` is not 5-D or has no time steps.
        """
        if x.dim() != 5:
            raise ValueError(f"expected input of shape (T, B, C, H, W), got {tuple(x.shape)}")
        if x.size(0) == 0:
            raise ValueError("input has zero time steps")

        # Clear membrane states left over from the previous batch.
        # NOTE(review): depending on the snntorch version, utils.reset may only inspect a
        # module's direct children. So reset the top level (stem + readout neurons) and
        # also each body layer (the SEW blocks' neurons). Resets on pooling layers are no-ops.
        utils.reset(self)
        for layer in self.layers:
            utils.reset(layer)

        mem_rec = []
        for t_input in x:  # iterates over dim 0 (time)
            # Stem
            out = self.sn1(self.bn1(self.conv1(t_input)))

            # Body (SEW block + MaxPool repeats)
            for layer in self.layers:
                out = layer(out)

            # Head + readout. Spikes are never emitted (threshold 1e9), so only the
            # membrane potential is kept.
            out = self.fc(self.flatten(out))
            _, mem = self.final_integ(out)
            mem_rec.append(mem)

        return torch.stack(mem_rec)

# --- Example Usage ---
if __name__ == "__main__":
    # Seed so the weights and dummy input are reproducible on Restart & Run All.
    torch.manual_seed(0)

    model = Net7B()

    # Dummy DVS-like input: T=16 time steps (as used in the paper for DVS Gesture),
    # batch of 4, 2 polarity channels, 128x128. Event frames are sparse and
    # non-negative, so use binary Bernoulli(0.1) frames instead of Gaussian noise.
    num_steps, batch_size = 16, 4
    dummy_input = (torch.rand(num_steps, batch_size, 2, 128, 128) < 0.1).float()

    # Forward-only shape check. eval() keeps BatchNorm running stats untouched, and
    # no_grad() avoids building an autograd graph across all unrolled time steps.
    model.eval()
    with torch.no_grad():
        output_mem = model(dummy_input)

    print(f"Model Parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"Output Shape: {output_mem.shape}")
    assert output_mem.shape == (num_steps, batch_size, 11), "unexpected output shape"

Model Parameters: 130939
Output Shape: torch.Size([16, 4, 11])
