In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
from RL_utils import *
import lava.lib.dl.slayer as slayer
%load_ext autoreload
%autoreload 2

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    #device = torch.device("cuda")     
else:
    device = torch.device("cpu")
    
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class SNN_Net(torch.nn.Module):
    def __init__(self, time_steps, layers, layer_idx):
        super().__init__()
        self.time_steps = time_steps
        self.layers = layers
        self.layer_idx = layer_idx

        neuron_params = {
                'threshold'     : 1.25,
                'current_decay' : 1, # this must be 1 to use batchnorm
                'voltage_decay' : 0.03,
                'tau_grad'      : 1,
                'scale_grad'    : 1,
            }
        neuron_params_norm = {
                **neuron_params,
                # 'norm'    : slayer.neuron.norm.MeanOnlyBatchNorm,
            }
        self.generated_layers = [slayer.block.sigma_delta.Input(sdnn_params)]
        self.generated_layers += [self.layers[l_idx][idx] for l_idx, idx in enumerate(self.layer_idx)]
        self.generated_layers += [slayer.block.sigma_delta.Flatten()]
        self.generated_layers += [ 
                slayer.block.sigma_delta.Dense(sdnn_dense_params, 16 * 128, 128, weight_scale=2, weight_norm=True),
                slayer.block.sigma_delta.Output(sdnn_dense_params, 128, 10, weight_scale=2, weight_norm=True)
        ]
        
        self.blocks = torch.nn.ModuleList(self.generated_layers)

    def forward(self, x):
        x = slayer.utils.time.replicate(x, self.time_steps)
        for block in self.blocks:
            x = block(x)
        return x

    def export_hdf5(self, filename):
        # network export to hdf5 format
        h = h5py.File(filename, 'w')
        simulation = h.create_group('simulation')
        simulation['Ts'] = 1
        simulation['tSample'] = self.time_steps
        layer = h.create_group('layer')
        for i, b in enumerate(self.blocks):
            b.export_hdf5(layer.create_group(f'{i}'))

In [7]:
import torch.nn.functional as F
sdnn_params = { # sigma-delta neuron parameters
                'threshold'     : 0.1,    # delta unit threshold
                'tau_grad'      : 0.5,    # delta unit surrogate gradient relaxation parameter
                'scale_grad'    : 1,      # delta unit surrogate gradient scale parameter
                'requires_grad' : True,   # trainable threshold
                'shared_param'  : True,   # layer wise threshold
                'activation'    : F.relu, # activation function
            }

sdnn_cnn_params = { # conv layer has additional mean only batch norm
                **sdnn_params,                                 # copy all sdnn_params
                'norm' : slayer.neuron.norm.MeanOnlyBatchNorm, # mean only quantized batch normalizaton
            }

sdnn_dense_params = { # dense layers have additional dropout units enabled
                **sdnn_cnn_params,                        # copy all sdnn_cnn_params
                'dropout' : slayer.neuron.Dropout(p=0.0), # neuron dropout
            }


l1 = [slayer.block.sigma_delta.Conv(sdnn_cnn_params, 3, 32, 3, padding=1, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 3, 32, 5, padding=2, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 3, 32, 7, padding=3, stride=2, weight_scale=2, weight_norm=True)]

l2 = [slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 32, 3, padding=1, stride=1, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 32, 5, padding=2, stride=1, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 32, 7, padding=3, stride=1, weight_scale=2, weight_norm=True)]

l3 = [slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 64, 3, padding=1, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 64, 5, padding=2, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 32, 64, 7, padding=3, stride=2, weight_scale=2, weight_norm=True)]

l4 = [slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 64, 3, padding=1, stride=1, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 64, 5, padding=2, stride=1, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 64, 7, padding=3, stride=1, weight_scale=2, weight_norm=True)]

l5 = [slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 128, 3, padding=1, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 128, 5, padding=2, stride=2, weight_scale=2, weight_norm=True),
      slayer.block.sigma_delta.Conv(sdnn_cnn_params, 64, 128, 7, padding=3, stride=2, weight_scale=2, weight_norm=True)]
layers = [l1, l2, l3, l4, l5]

In [15]:
import torch.optim as optim
net = SNN_Net(time_steps=10, layers=layers, layer_idx=[0, 0, 0, 0, 0]).to(device)
criterion = slayer.loss.SpikeMax(moving_window=None, mode='softmax', reduction='sum')
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [17]:
#for name, param in net.named_parameters():
#    print(param.shape, name)
#res = MGM_SNN(net, criterion, optimizer, trainloader, num_batchs=10, num_params=50)
#print(res)

In [18]:
def test(model, testloader):
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

def train_snn(model, optimizer, criterion, trainloader, testloader, epochs=1):
    for epoch in range(epochs):  # loop over the dataset multiple times
        print(epoch)
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            print(i)

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0
                
    # test results
    test(model, testloader)

In [23]:
import numpy as np
for i in range(200):
    layer_idx = np.random.randint(3, size=5).tolist()
    print(layer_idx)
    net = SNN_Net(time_steps=10, layers=layers, layer_idx=layer_idx).to(device)
    
    criterion = slayer.loss.SpikeMax(moving_window=None, mode='probability', reduction='sum')
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    MGM_score = MGM_SNN(net, criterion, optimizer, trainloader, num_batchs=10, num_params=50)
    print(MGM_score)
    train_snn(net, optimizer, criterion, trainloader, testloader, epochs=1)

[2, 0, 2, 1, 2]
tensor(nan)
0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115


KeyboardInterrupt: 

In [10]:
train_snn(net, optimizer, criterion, trainloader, testloader, epochs=2)

0
0
1
2
3
4
5
6
7
8
9
10
11
12
13


KeyboardInterrupt: 