In [1]:
import torch
import snntorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_per_process_memory_fraction(0.85, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

State = namedtuple("State", "obs labels")



In [2]:
class _SHD2Raster():
    """ 
    Tool for rastering SHD samples into frames. Packs bits along the temporal axis for memory efficiency. This means
        that the used will have to apply jnp.unpackbits(events, axis=<time axis>) prior to feeding the data to the network.
    """

    def __init__(self, encoding_dim, sample_T = 100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T
        
    def __call__(self, events):
        # tensor has dimensions (time_steps, encoding_dim)
        tensor = np.zeros((events["t"].max()+1, self.encoding_dim), dtype=int)
        np.add.at(tensor, (events["t"], events["x"]), 1)
        #return tensor[:self.sample_T,:]
        tensor = tensor[:self.sample_T,:]
        tensor = np.minimum(tensor, 1)
        #tensor = np.packbits(tensor, axis=0) pytorch does not have an unpack feature.
        return tensor

In [3]:
sample_T = 64
shd_timestep = 1e-6
shd_channels = 700
net_channels = 128
net_dt = 1/sample_T
batch_size = 256

obs_shape = tuple([net_channels,])
act_shape = tuple([20,])

transform = transforms.Compose([
    transforms.Downsample(
        time_factor=shd_timestep / net_dt,
        spatial_factor=net_channels / shd_channels
    ),
    _SHD2Raster(net_channels, sample_T=sample_T)
])

train_dataset = datasets.SHD("./data", train=True, transform=transform)
test_dataset = datasets.SHD("./data", train=False, transform=transform)



In [4]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)

In [5]:
def shuffle(dataset):
    x, y = dataset

    cutoff = y.shape[0] % batch_size

    indices = torch.randperm(y.shape[0])[:-cutoff]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return State(obs=obs, labels=labels)

In [6]:
test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_test, y_test = next(test_dl)
x_test, y_test = x_test.to(torch.uint8), y_test.to(torch.uint8)
x_test, y_test = x_test.to(device), y_test.to(device)
x_test, y_test = shuffle((x_test, y_test))

In [7]:
num_hidden = 64
# Define Network
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = torch.nn.Linear(128, num_hidden)
        self.lif1 = snntorch.Leaky(beta=torch.ones(num_hidden)*0.5, learn_beta=True)
        self.fc2 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif2 = snntorch.Leaky(beta=torch.ones(num_hidden)*0.5, learn_beta=True)
        self.fc3 = torch.nn.Linear(num_hidden, 20)
        self.lif3 = snntorch.Leaky(beta=torch.ones(20)*0.5, learn_beta=True, threshold=10e6)

    def forward(self, x):

        x = x.float() # [batch, time, channel]
        
        x = x.permute(1,0,2) # [time, batch, channel]
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        V = []

        # need to fix since data is not time leading axis...
        for i, step in enumerate(x):
            cur1 = self.fc1(step)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            
            V.append(mem3)

        
        return torch.stack(V, axis=0).permute(1,0,2)
        
# Load the network onto CUDA if available
net = Net().to(device)

In [8]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [78]:
num_epochs = 300
loss_hist = []
counter = 0


# Outer training loop
for epoch in range(num_epochs):    
    
    train_batch = shuffle((x_train, y_train))
    train_data, targets = train_batch
    
    
    # Minibatch training loop
    for data, targets in zip(train_data, targets):

        # forward pass
        net.train()
        out_V = net(data)
        # initialize the loss & sum over time
        loss_val = loss(torch.sum(out_V, axis=-2), targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # Store loss history for future plotting
    loss_hist.append(loss_val.item())


# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net.eval()
        # Test set forward pass
        out_V = net(test_data)
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=-2), test_targets) )
    
    test_acc = np.mean(batch_acc)


KeyboardInterrupt: 

In [10]:
test_acc

0.767578125

### Using NIR to load a network from Spyx in snnTorch:

The following code is boiler plate until snnTorch merges support for importing from NIR.

In [105]:
import snntorch as snn
import numpy as np
import torch
import nir
import typing


# TODO: implement this?
class ImportedNetwork(torch.nn.Module):
    """Wrapper for a snnTorch network. NOTE: not working atm."""
    def __init__(self, module_list):
        super().__init__()
        self.module_list = module_list

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        return x


def create_snntorch_network(module_list):
    return torch.nn.Sequential(*module_list)


def _lif_to_snntorch_module(
        lif: typing.Union[nir.LIF, nir.CubaLIF]
) -> torch.nn.Module:
    """Parse a LIF node into snnTorch."""
    if isinstance(lif, nir.LIF):
        assert np.alltrue(lif.v_leak == 0), 'v_leak not supported'
        assert np.allclose(lif.r , 1. - 1. / lif.tau), 'r not supported'
        assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons'
        threshold = lif.v_threshold[0]
        mod = snn.Leaky(
            beta=1. - 1. / lif.tau,
            threshold=threshold,
            reset_mechanism='subtract',
            init_hidden=True,
        )
        return mod

    elif isinstance(lif, nir.LI):
        assert np.alltrue(lif.v_leak == 0), 'v_leak not supported'
        assert np.allclose(lif.r , 1. - 1. / lif.tau), 'r not supported'
        mod = snn.Leaky(
            beta=1. - 1. / lif.tau,
            reset_mechanism='none',
            init_hidden=True,
            output=True,
        )
        return mod

    elif isinstance(lif, nir.CubaLIF):
        assert np.alltrue(lif.v_leak == 0), 'v_leak not supported'
        assert np.alltrue(lif.r == 1. - 1. / lif.tau_mem), 'r not supported'  # NOTE: is this right?
        assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons'
        threshold = lif.v_threshold[0]
        mod = snn.RSynaptic(
            alpha=1. - 1. / lif.tau_syn,
            beta=1. - 1. / lif.tau_mem,
            threshold=threshold,
            all_to_all=True,
            reset_mechanism='zero',
            linear_features=lif.tau_mem.shape[0] if len(lif.tau_mem.shape) == 1 else None,
            init_hidden=True,
        )
        return mod

    else:
        raise ValueError('called _lif_to_snntorch_module on non-LIF node')


def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module:
    """Convert a NIR node to a snnTorch module.

    Supported NIR nodes: Affine.
    """
    if isinstance(node, (nir.LIF, nir.CubaLIF, nir.LI)):
        return _lif_to_snntorch_module(node)

    elif isinstance(node, (nir.Affine, nir.Linear)):
        if len(node.weight.shape) != 2:
            raise NotImplementedError('only 2D weight matrices are supported')
        
        linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False)
        linear.weight.data = torch.Tensor(node.weight.T)
        
        return linear

    else:
        raise NotImplementedError(f'node type {type(node).__name__} not supported')


def _rnn_subgraph_to_snntorch_module(
        lif: typing.Union[nir.LIF, nir.CubaLIF], w_rec: typing.Union[nir.Affine, nir.Linear]
) -> torch.nn.Module:
    """Parse an RNN subgraph consisting of a LIF node and a recurrent weight matrix into snnTorch.

    NOTE: for now always set it as a recurrent linear layer (not RecurrentOneToOne)
    """
    assert isinstance(lif, (nir.LIF, nir.CubaLIF)), 'only LIF or CubaLIF nodes supported as RNNs'
    mod = _lif_to_snntorch_module(lif)
    mod.recurrent.weight.data = torch.Tensor(w_rec.weight)
    if isinstance(w_rec, nir.Linear):
        mod.recurrent.register_parameter('bias', None)
        mod.recurrent.reset_parameters()
    else:
        mod.recurrent.bias.data = torch.Tensor(w_rec.bias)
    return mod


def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph):
    """Get the next node key in the NIR graph."""
    possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key]
    # possible_next_node_keys += [edge[1] + '.input' for edge in graph.edges if edge[0] == node_key]
    assert len(possible_next_node_keys) <= 1, 'branching networks are not supported'
    if len(possible_next_node_keys) == 0:
        return None
    else:
        return possible_next_node_keys[0]


def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module:
    """Convert NIR graph to snnTorch module.

    :param graph: a saved snnTorch model as a parameter dictionary
    :type graph: nir.ir.NIRGraph

    :return: snnTorch module
    :rtype: torch.nn.Module
    """
    node_key = 'input'
    visited_node_keys = [node_key]
    module_list = []

    while _get_next_node_key(node_key, graph) is not None:
        node_key = _get_next_node_key(node_key, graph)

        assert node_key not in visited_node_keys, 'cyclic NIR graphs not supported'

        if node_key == 'output':
            visited_node_keys.append(node_key)
            continue

        if node_key in graph.nodes:
            visited_node_keys.append(node_key)
            node = graph.nodes[node_key]
            print(f'simple node {node_key}: {type(node).__name__}')
            module = _to_snntorch_module(node)
        else:
            # check if it's a nested node
            print(f'potential subgraph node: {node_key}')
            sub_node_keys = [n for n in graph.nodes if n.startswith(f'{node_key}.')]
            assert len(sub_node_keys) > 0, f'no nodes found for subgraph {node_key}'

            # parse subgraph
            # NOTE: for now only looking for RNN subgraphs
            rnn_sub_node_keys = [f'{node_key}.{n}' for n in ['input', 'output', 'lif', 'w_rec']]
            if set(sub_node_keys) != set(rnn_sub_node_keys):
                raise NotImplementedError('only RNN subgraphs are supported')
            print('found RNN subgraph')
            module = _rnn_subgraph_to_snntorch_module(
                graph.nodes[f'{node_key}.lif'], graph.nodes[f'{node_key}.w_rec']
            )
            for nk in sub_node_keys:
                visited_node_keys.append(nk)

        module_list.append(module)

    if len(visited_node_keys) != len(graph.nodes):
        print(graph.nodes.keys(), visited_node_keys)
        raise ValueError('not all nodes visited')

    return create_snntorch_network(module_list)

In [106]:
G = nir.read("./spyx_shd.nir")

dict_keys(['LI', 'LIF', 'LIF_1', 'input', 'linear', 'linear_1', 'linear_2', 'output'])

In [110]:
net2 = from_nir(G).to(device)

simple node linear: Linear
simple node LIF: LIF
simple node linear_1: Linear
simple node LIF_1: LIF
simple node linear_2: Linear
simple node LI: LI


  return _lif_to_snntorch_module(node)


In [121]:
from snntorch import utils

def forward_pass(network, data):
  v_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  transposed_data = torch.permute(data, (1,0,2))

  for step in transposed_data:  # data.size(0) = number of time steps
      spk_out, v = network(step)
      v_rec.append(v)
  
  return torch.stack(v_rec)

In [125]:
# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net2.zero_grad()
        # Test set forward pass
        out_V = forward_pass(net2, test_data.to(torch.float32))
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=0), test_targets) )
    
    test_acc = np.mean(batch_acc)

test_acc

0.7431640625

As we can see, it gets about the same accuracy as it did in Spyx.

In [128]:
batch_acc

[0.71875,
 0.76171875,
 0.73828125,
 0.7109375,
 0.73046875,
 0.7578125,
 0.75,
 0.77734375]