In [1]:
import torch
import torchvision
import numpy as np
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

## Import the MNIST dataset and create the train and test loaders

In [2]:
# Training hyper-parameters for the ANN.
n_epochs = 5
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

# Seed torch's global RNG so runs are reproducible.
random_seed = 1
torch.backends.cudnn.enabled = True
torch.manual_seed(random_seed)


def _mnist_loader(train, batch_size):
    """Return a shuffled, pinned-memory DataLoader over MNIST stored under '.'.

    Images are only converted with ToTensor(); no normalisation is applied.
    """
    dataset = torchvision.datasets.MNIST(
        '.', train=train, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor()
        ]))
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, shuffle=True)


train_loader = _mnist_loader(True, batch_size_train)
test_loader = _mnist_loader(False, batch_size_test)

# NOTE(review): torchvision.transforms.Normalize((0.1307,), (0.3081,)) is left out,
# presumably so pixel values stay in [0, 1] for the spike encoder used later — confirm.

In [None]:
# Take the first test batch and show six digits with their labels.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
example_data.shape

fig = plt.figure()
for panel, (image, label) in enumerate(zip(example_data[:6], example_targets[:6])):
    plt.subplot(2, 3, panel + 1)
    plt.tight_layout()
    # image is (channels, H, W); show its single channel.
    plt.imshow(image[0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(label))
    plt.xticks([])
    plt.yticks([])


In [None]:
class Net(nn.Module):
    """Small CNN classifier for 1x28x28 MNIST digits that returns log-probabilities.

    All layers live in one ``nn.Sequential`` called ``mnist``, so state-dict keys
    look like ``mnist.<index>.<param>``. Later code addresses layers by these
    indices (conv 0/2/7, batch-norm 3/8/14, linear 13/16), so do not reorder them.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.mnist = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3),       # 0: -> 8x26x26
            nn.ReLU(),                            # 1
            nn.Conv2d(8, 16, kernel_size=3),      # 2: -> 16x24x24
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),  # 3
            nn.MaxPool2d(2, stride=2),            # 4: -> 16x12x12
            nn.ReLU(),                            # 5
            nn.Dropout2d(),                       # 6
            nn.Conv2d(16, 20, kernel_size=5),     # 7: -> 20x8x8
            nn.BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),  # 8
            nn.MaxPool2d(2, stride=2),            # 9: -> 20x4x4
            nn.ReLU(),                            # 10
            nn.Dropout2d(),                       # 11
            nn.Flatten(),                         # 12: -> 320
            nn.Linear(320, 50),                   # 13
            nn.BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),  # 14
            nn.ReLU(),                            # 15
            nn.Linear(50, 10),                    # 16: class logits
            # Explicit class dimension; the implicit-dim form is deprecated and warns.
            nn.LogSoftmax(dim=1),                 # 17
        )

    def forward(self, x):
        """Return (batch, 10) log-probabilities for a (batch, 1, 28, 28) input."""
        x = self.mnist(x)
        return x


In [None]:
# Build the model on the GPU. Module.cuda() moves the module in place and returns it,
# so network_cpu and network are the same object.
network_cpu = Net()
network = network_cpu.cuda()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

# Loss histories filled by train()/test(). test_counter holds, for each epoch
# boundary 0..n_epochs, the number of training samples seen at that point.
train_losses, train_counter, test_losses = [], [], []
test_counter = [epoch_idx * len(train_loader.dataset) for epoch_idx in range(n_epochs + 1)]

In [None]:
def train(epoch):
    """Train `network` for one epoch over `train_loader`.

    Every `log_interval` batches this prints progress and records the current
    batch loss in `train_losses`. It records the number of training samples seen
    so far in `train_counter`. It also checkpoints the model and optimizer state to
    'model.pth' / 'optimizer.pth'.

    Args:
        epoch: 1-based epoch number, used for logging and the sample counter.
    """
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data.cuda())
        loss = F.nll_loss(output, target.cuda())
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # Samples seen so far across all epochs. Use the loader's batch size
            # rather than a hard-coded 64.
            train_counter.append(
                batch_idx * train_loader.batch_size
                + (epoch - 1) * len(train_loader.dataset))
            torch.save(network.state_dict(), 'model.pth')
            torch.save(optimizer.state_dict(), 'optimizer.pth')


def test():
    """Evaluate `network` on `test_loader` once.

    Appends the per-sample average NLL loss to `test_losses` and prints the loss
    and accuracy. Averaging and reporting happen once, after all batches.
    """
    network.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            target = target.cuda()
            output = network(data.cuda())
            # Sum over the batch so the division below yields a per-sample mean.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
#Network training
# Each epoch runs one pass of train(), which also checkpoints model.pth /
# optimizer.pth every log_interval batches, followed by one evaluation with test().
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

In [None]:
# Reload the last checkpoint written by train() into a fresh network for inference.
# saved_weights is reused later when folding batch-norm into the SNN weights.
saved_weights = torch.load('model.pth')

saved_model = Net().cuda()
saved_model.load_state_dict(saved_weights)
# eval() makes BatchNorm use its running statistics and disables Dropout2d.
saved_model.eval()
print(saved_model)

In [None]:
# Classify one external grayscale image with the reloaded model.
image_path = 'temp_folder/img_137.jpg'
input_image = cv2.imread(image_path, 0)  # 0: load as single-channel grayscale
if input_image is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError('could not read image: {}'.format(image_path))
if input_image.shape != (28, 28):
    raise ValueError('expected a 28x28 image, got shape {}'.format(input_image.shape))
# Scale 8-bit pixels into [0, 1], like ToTensor() does for the training data.
input_image = input_image / 255
input_tensor = torch.tensor(input_image[np.newaxis, np.newaxis, :, :], dtype=torch.float32)
with torch.no_grad():
    output = saved_model(input_tensor.cuda())
print(output)

In [None]:
# Indices of layers inside Net.mnist (nn.Sequential positions). These are also the
# numbers in state-dict keys such as 'mnist.<idx>.weight'.
conv_layer_idx = [0, 2, 7]            # Conv2d
relu_layer_idx = [1, 5, 10, 15, 16]   # ReLU outputs; 16 is the final Linear (logits)
fc_layer_idx = [13, 16]               # Linear
batch_norm2d_idx = [3, 8]             # BatchNorm2d
batch_norm1d_idx = [14]               # BatchNorm1d
max_pool_idx = [4, 9]                 # MaxPool2d

In [None]:
def merge_batch_norm2d(conv_layer_idx, bn2d_idx, state_dict=None, eps=1e-05):
    """Fold an eval-mode BatchNorm2d into the Conv2d layer that feeds it.

    With s = gamma / sqrt(running_var + eps), the folded layer computes the same
    thing as conv followed by batch-norm:
        w' = w * s (per output channel),  b' = (b - running_mean) * s + beta

    Args:
        conv_layer_idx: index of the Conv2d inside ``Net.mnist``.
        bn2d_idx: index of the BatchNorm2d inside ``Net.mnist``.
        state_dict: mapping with 'mnist.<idx>.<param>' keys. Defaults to the
            global ``saved_weights`` loaded from the training checkpoint.
        eps: batch-norm epsilon (Net uses 1e-05).

    Returns:
        (weight, bias): the folded conv parameters.
    """
    if state_dict is None:
        state_dict = saved_weights
    conv_key = 'mnist.{}.'.format(conv_layer_idx)
    bn_key = 'mnist.{}.'.format(bn2d_idx)
    conv_weight = state_dict[conv_key + 'weight']   # (out_ch, in_ch, kh, kw)
    conv_bias = state_dict[conv_key + 'bias']
    bn2d_gamma = state_dict[bn_key + 'weight']
    bn2d_beta = state_dict[bn_key + 'bias']
    bn2d_mean = state_dict[bn_key + 'running_mean']
    bn2d_std = torch.sqrt(state_dict[bn_key + 'running_var'] + eps)

    scale = torch.div(bn2d_gamma, bn2d_std)          # one factor per output channel
    output_conv_weight = conv_weight * scale.reshape(-1, 1, 1, 1)
    output_conv_bias = torch.add(torch.mul(conv_bias - bn2d_mean, scale), bn2d_beta)
    return output_conv_weight, output_conv_bias


def merge_batch_norm1d(fc_layer_idx, bn1d_idx, state_dict=None, eps=1e-05):
    """Fold an eval-mode BatchNorm1d into the Linear layer that feeds it.

    Same folding as ``merge_batch_norm2d``, with the per-feature scale applied
    to the rows (output features) of the (out_features, in_features) weight.

    Args:
        fc_layer_idx: index of the Linear inside ``Net.mnist``.
        bn1d_idx: index of the BatchNorm1d inside ``Net.mnist``.
        state_dict: mapping with 'mnist.<idx>.<param>' keys. Defaults to the
            global ``saved_weights``.
        eps: batch-norm epsilon (Net uses 1e-05).

    Returns:
        (weight, bias): the folded linear parameters.
    """
    if state_dict is None:
        state_dict = saved_weights
    fc_key = 'mnist.{}.'.format(fc_layer_idx)
    bn_key = 'mnist.{}.'.format(bn1d_idx)
    fc_weight = state_dict[fc_key + 'weight']       # (out_features, in_features)
    fc_bias = state_dict[fc_key + 'bias']
    bn1d_gamma = state_dict[bn_key + 'weight']
    bn1d_beta = state_dict[bn_key + 'bias']
    bn1d_mean = state_dict[bn_key + 'running_mean']
    bn1d_std = torch.sqrt(state_dict[bn_key + 'running_var'] + eps)

    scale = torch.div(bn1d_gamma, bn1d_std)          # one factor per output feature
    output_fc_weight = fc_weight * scale.reshape(-1, 1)
    output_fc_bias = torch.add(torch.mul(fc_bias - bn1d_mean, scale), bn1d_beta)
    return output_fc_weight, output_fc_bias

In [None]:
def output_list_fn(layers_list, data_loader, model):
    """Run `model` layer by layer on CPU and collect flattened activations.

    Args:
        layers_list: indices (positions in the sequential `model`) whose outputs
            should be recorded.
        data_loader: iterable of (data, target) batches; targets are ignored.
        model: an iterable of layers (e.g. ``nn.Sequential``) that lives on CPU.

    Returns:
        dict mapping 'input' to all input pixels and str(idx) to every activation of
        layer idx. Each value is a 1-D numpy array covering all batches in loader order.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    inputs = []
    captured = {idx: [] for idx in layers_list}
    # Inference only: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for data, _ in data_loader:
            inputs.append(data)
            x = data.cpu()
            for idx, layer in enumerate(model):
                x = layer(x)
                if idx in captured:
                    captured[idx].append(x)
    if not inputs:
        raise ValueError('data_loader produced no batches')

    # Concatenate once at the end; growing a tensor with torch.cat per batch is quadratic.
    output_dict = {'input': torch.flatten(torch.cat(inputs, 0)).detach().numpy()}
    for idx in layers_list:
        output_dict[str(idx)] = torch.flatten(torch.cat(captured[idx], 0)).detach().numpy()
    return output_dict

In [None]:
def output_list_fn_cuda(layers_list, data_loader, model, device='cuda'):
    """Run `model` layer by layer on `device` and collect flattened activations on CPU.

    Args:
        layers_list: indices (positions in the sequential `model`) whose outputs
            should be recorded.
        data_loader: iterable of (data, target) batches; targets are ignored.
        model: an iterable of layers (e.g. ``nn.Sequential``) already on `device`.
        device: device the forward pass runs on (default 'cuda').

    Returns:
        dict mapping 'input' to all input pixels and str(idx) to every activation of
        layer idx. Each value is a 1-D numpy array covering all batches in loader order.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    inputs = []
    captured = {idx: [] for idx in layers_list}
    # Inference only: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for data, _ in data_loader:
            inputs.append(data.cpu())
            x = data.to(device)
            for idx, layer in enumerate(model):
                x = layer(x)
                if idx in captured:
                    # Copy to host right away so device memory only holds one batch.
                    captured[idx].append(x.detach().cpu())
    if not inputs:
        raise ValueError('data_loader produced no batches')

    # Concatenate once at the end; growing a tensor with torch.cat per batch is quadratic.
    output_dict = {'input': torch.flatten(torch.cat(inputs, 0)).numpy()}
    for idx in layers_list:
        output_dict[str(idx)] = torch.flatten(torch.cat(captured[idx], 0)).numpy()
    return output_dict

    

In [None]:
# Record the activations of every layer in relu_layer_idx over the whole training
# set, running the reloaded model's layers on the GPU.
relu_model = saved_model.mnist.cuda()
relu_output_dict = output_list_fn_cuda(relu_layer_idx, train_loader, relu_model)

In [None]:
# Activation percentiles for the input and for each recorded layer. They drive the
# data-based weight normalisation of the SNN (the 99.9th percentile is used later).
# Keys look like 'layer_input_99_9' / 'layer_5_99_9'.
relu_layers = [1, 5, 10, 15, 16]
percentile_levels = {'99': 99, '99_9': 99.9, '99_99': 99.99, '100': 100}
percentile_dict = {}
for layer_name in ['input'] + [str(idx) for idx in relu_layers]:
    # One np.percentile call per layer computes all levels from a single pass over
    # the (large) activation array.
    values = np.percentile(relu_output_dict[layer_name], list(percentile_levels.values()))
    for suffix, level, value in zip(percentile_levels, percentile_levels.values(), values):
        percentile_dict['layer_{}_{}'.format(layer_name, suffix)] = value
        print('layer {}: {} percentile'.format(layer_name, level), value)

In [None]:
# Histogram of the recorded activations for the input and each layer in
# relu_layer_idx. relu_output_dict values are flat numpy arrays keyed by
# 'input' / str(layer index).
number_of_bins = 100
for layer_name in ['input'] + [str(idx) for idx in relu_layer_idx]:
    counts, bin_edges = np.histogram(relu_output_dict[layer_name], bins=number_of_bins)
    fig, ax = plt.subplots()
    ax.scatter(bin_edges[:-1], counts)
    ax.set(title='Activation histogram: layer {}'.format(layer_name),
           xlabel='activation value (bin start)', ylabel='count')
    plt.show()

In [None]:
def normalize_weight_bias(weight,bias,l1_out,l2_out):
    """Rescale one layer's parameters for data-based SNN weight normalisation.

    Args:
        weight: layer weight tensor.
        bias: layer bias tensor.
        l1_out: activation scale (e.g. a percentile) of the layer's input.
        l2_out: activation scale of the layer's output.

    Returns:
        (weight * l1_out / l2_out, bias / l2_out)
    """
    input_to_output_ratio = l1_out / l2_out
    return weight * input_to_output_ratio, bias / l2_out

In [None]:
# Plain affine parameters for every layer. Batch-norm is folded into the conv/linear
# layer that precedes it.
conv1_weight, conv1_bias = saved_weights['mnist.0.weight'], saved_weights['mnist.0.bias']
conv2_weight, conv2_bias = merge_batch_norm2d(2, 3)
conv3_weight, conv3_bias = merge_batch_norm2d(7, 8)
fc1_weight, fc1_bias = merge_batch_norm1d(13, 14)
fc2_weight, fc2_bias = saved_weights['mnist.16.weight'], saved_weights['mnist.16.bias']


def _act_scale(layer):
    """99.9th-percentile activation recorded for `layer` ('input' or a layer index)."""
    return percentile_dict['layer_{}_99_9'.format(layer)]


# Normalise each layer by the ratio of its input and output activation scales.
snn_conv1_weight, snn_conv1_bias = normalize_weight_bias(conv1_weight, conv1_bias, _act_scale('input'), _act_scale(1))
snn_conv2_weight, snn_conv2_bias = normalize_weight_bias(conv2_weight, conv2_bias, _act_scale(1), _act_scale(5))
snn_conv3_weight, snn_conv3_bias = normalize_weight_bias(conv3_weight, conv3_bias, _act_scale(5), _act_scale(10))
snn_fc1_weight, snn_fc1_bias = normalize_weight_bias(fc1_weight, fc1_bias, _act_scale(10), _act_scale(15))
snn_fc2_weight, snn_fc2_bias = normalize_weight_bias(fc2_weight, fc2_bias, _act_scale(15), _act_scale(16))

In [None]:
# Same nn.Sequential layer indices as declared earlier, repeated so this section
# can be re-run on its own.
conv_layer_idx = [0, 2, 7]            # Conv2d
relu_layer_idx = [1, 5, 10, 15, 16]   # ReLU outputs; 16 is the final Linear (logits)
fc_layer_idx = [13, 16]               # Linear
batch_norm2d_idx = [3, 8]             # BatchNorm2d
batch_norm1d_idx = [14]               # BatchNorm1d
max_pool_idx = [4, 9]                 # MaxPool2d


In [None]:
# Larger batches for the SNN simulation; memory is not pinned here.
train_batch_size = 2000
test_batch_size = 1000


def _snn_mnist_loader(train, batch_size):
    """Return a shuffled DataLoader over MNIST stored under '.', converted with ToTensor() only."""
    dataset = torchvision.datasets.MNIST(
        '.', train=train, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor()
        ]))
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, pin_memory=False, shuffle=True)


snn_train_loader = _snn_mnist_loader(True, train_batch_size)
snn_test_loader = _snn_mnist_loader(False, test_batch_size)


In [None]:
# Number of samples in each split, used to pre-allocate the per-sample result arrays.
# len(dataset) equals the number of samples the loader yields (drop_last is off),
# so there is no need to iterate the whole dataset to count them.
train_size = len(snn_train_loader.dataset)
print(train_size)

test_size = len(snn_test_loader.dataset)
print(test_size)
    

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before the SNN cells below.
torch.cuda.empty_cache()

In [None]:
## Build the spiking network from the BN-folded, percentile-normalised snn_* weights.
## The simulations below record one prediction per timestep, so we can measure how
## many timesteps are needed, on average, before the output can be trusted.


def _spiking_conv(weight, bias):
    """GPU Conv2d (stride 1, no padding) whose parameters are fixed to `weight` and `bias`.

    `weight` uses the Conv2d layout (out_channels, in_channels, kh, kw).
    """
    out_channels, in_channels = weight.shape[0], weight.shape[1]
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=tuple(weight.shape[2:])).cuda()
    # Frozen parameters: the SNN is inference-only. Trainable parameters would make
    # every simulated timestep extend an autograd graph.
    layer.weight = nn.Parameter(weight, requires_grad=False)
    layer.bias = nn.Parameter(bias, requires_grad=False)
    return layer


def _spiking_linear(weight, bias):
    """GPU Linear layer with fixed parameters; `weight` is (out_features, in_features)."""
    out_features, in_features = weight.shape
    layer = nn.Linear(in_features, out_features).cuda()
    layer.weight = nn.Parameter(weight, requires_grad=False)
    layer.bias = nn.Parameter(bias, requires_grad=False)
    return layer


spikenet_1 = _spiking_conv(snn_conv1_weight, snn_conv1_bias)
spikenet_2 = _spiking_conv(snn_conv2_weight, snn_conv2_bias)
spikenet_3 = _spiking_conv(snn_conv3_weight, snn_conv3_bias)
spikenet_fc_1 = _spiking_linear(snn_fc1_weight, snn_fc1_bias)
spikenet_fc_2 = _spiking_linear(snn_fc2_weight, snn_fc2_bias)
max_pool2 = nn.MaxPool2d(2, stride=2).cuda()
max_pool3 = nn.MaxPool2d(2, stride=2).cuda()
number_of_timesteps = 500  # 1000 was also tried
train_output_expected = np.zeros((train_size, 1))
train_output_array = np.zeros((train_size, number_of_timesteps))

In [None]:
# Accuracy (%) of a running majority vote over the per-timestep SNN predictions on
# the training set: entry t uses the votes of timesteps 0..t.
# NOTE(review): this cell comes before the cell that fills train_output_array. Run
# that cell first, otherwise every prediction is still 0.
num_classes = 10
train_predictions = train_output_array.astype(np.int64)
train_labels = train_output_expected[:, 0]
total_samples = train_predictions.shape[0]
sample_rows = np.arange(total_samples)
vote_counts = np.zeros((total_samples, num_classes), dtype=np.int64)
accuracy_over_time = np.zeros(train_predictions.shape[1])
for idx in range(train_predictions.shape[1]):
    # Add this timestep's vote and take the most-voted class. argmax breaks ties
    # toward the smallest class, as scipy.stats.mode does. The incremental counts
    # also avoid scipy's 1-D vs (N, 1) shape mismatch (SciPy >= 1.11), which
    # broadcast the comparison to N x N.
    vote_counts[sample_rows, train_predictions[:, idx]] += 1
    majority = vote_counts.argmax(axis=1)
    accuracy_over_time[idx] = np.sum(majority == train_labels) * 100 / total_samples
print(accuracy_over_time)

In [None]:
## Run the SNN on the training data

In [None]:
# Simulate the rate-coded SNN on the training set. For every sample, record the
# predicted class at every timestep in train_output_array.
number_of_timesteps = 500
train_output_expected = np.zeros((train_size, 1))
train_output_array = np.zeros((train_size, number_of_timesteps))

sample_offset = 0
# No gradients are needed. Without no_grad, the in-place potential updates would
# chain one autograd graph across all timesteps and exhaust GPU memory.
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(snn_train_loader):
        start = sample_offset
        end = start + data.shape[0]
        sample_offset = end
        train_output_expected[start:end, 0] = target.numpy()
        data = data.cuda()
        batch_len = data.shape[0]
        device = data.device
        # Membrane potentials of every spiking stage, reset for each batch.
        spike_pot = torch.zeros_like(data)
        sp1_out = torch.zeros((batch_len, 8, 26, 26), device=device)
        sp2_out = torch.zeros((batch_len, 16, 24, 24), device=device)
        sp3_out = torch.zeros((batch_len, 20, 8, 8), device=device)
        sp_fc1_out = torch.zeros((batch_len, 50), device=device)
        sp_fc2_out = torch.zeros((batch_len, 10), device=device)
        # Keep predictions on the GPU; copy them to host once per batch.
        batch_predictions = torch.empty((batch_len, number_of_timesteps), dtype=torch.long, device=device)

        for i in range(number_of_timesteps):
            # Each stage integrates its input. It emits a spike where the potential
            # exceeds 1.0 and subtracts the spike (reset by subtraction).
            spike_pot.add_(data)
            spike_frame = torch.gt(spike_pot, 1.0) * 1.0
            spike_pot.sub_(spike_frame)
            sp1_out.add_(spikenet_1(spike_frame))
            sp2_in = torch.gt(sp1_out, 1.0) * 1.0
            sp1_out.sub_(sp2_in)
            sp2_out.add_(spikenet_2(sp2_in))
            sp3_in = torch.gt(sp2_out, 1.0) * 1.0
            sp2_out.sub_(sp3_in)
            sp3_in_red = max_pool2(sp3_in)
            sp3_out.add_(spikenet_3(sp3_in_red))
            sp4_in = torch.gt(sp3_out, 1.0) * 1.0
            sp3_out.sub_(sp4_in)
            sp4_in_red = max_pool3(sp4_in)
            sp4_in_red_flat = torch.flatten(sp4_in_red, start_dim=1)
            sp_fc1_out.add_(spikenet_fc_1(sp4_in_red_flat))
            sp_fc2_in = torch.gt(sp_fc1_out, 1.0) * 1.0
            sp_fc1_out.sub_(sp_fc2_in)
            sp_fc2_out.add_(spikenet_fc_2(sp_fc2_in))
            sof_pot = F.softmax(sp_fc2_out, dim=1)
            batch_predictions[:, i] = torch.argmax(sof_pot, dim=1)

        train_output_array[start:end, :] = batch_predictions.cpu().numpy()
        print('train batch', batch_idx, 'samples', start, 'to', end)
        torch.cuda.empty_cache()


In [None]:
## Run the SNN on the test data

In [None]:
### Simulate the rate-coded SNN on the test set. For every sample, record the
### predicted class at every timestep in test_output_array.
### Memory: everything runs under torch.no_grad() (no autograd graph across
### timesteps), state tensors are allocated directly on the GPU, and predictions
### are copied to host once per batch.

number_of_timesteps = 400

test_output_expected = np.zeros((test_size, 1))
test_output_array = np.zeros((test_size, number_of_timesteps))
# Offsets come from the real batch lengths. Multiplying batch_idx by
# train_batch_size (2000) overran the test arrays, since this loader yields
# test_batch_size (1000)-sample batches.
sample_offset = 0
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(snn_test_loader):
        start = sample_offset
        end = start + data.shape[0]
        sample_offset = end
        test_output_expected[start:end, 0] = target.numpy()
        data = data.cuda()
        batch_len = data.shape[0]
        device = data.device
        # Membrane potentials of every spiking stage, reset for each batch.
        spike_pot = torch.zeros_like(data)
        sp1_out = torch.zeros((batch_len, 8, 26, 26), device=device)
        sp2_out = torch.zeros((batch_len, 16, 24, 24), device=device)
        sp3_out = torch.zeros((batch_len, 20, 8, 8), device=device)
        sp_fc1_out = torch.zeros((batch_len, 50), device=device)
        sp_fc2_out = torch.zeros((batch_len, 10), device=device)
        batch_predictions = torch.empty((batch_len, number_of_timesteps), dtype=torch.long, device=device)

        for i in range(number_of_timesteps):
            # Integrate, spike where the potential exceeds 1.0, reset by subtraction.
            spike_pot.add_(data)
            spike_frame = torch.gt(spike_pot, 1.0) * 1.0
            spike_pot.sub_(spike_frame)
            sp1_out.add_(spikenet_1(spike_frame))
            sp2_in = torch.gt(sp1_out, 1.0) * 1.0
            sp1_out.sub_(sp2_in)
            sp2_out.add_(spikenet_2(sp2_in))
            sp3_in = torch.gt(sp2_out, 1.0) * 1.0
            sp2_out.sub_(sp3_in)
            sp3_in_red = max_pool2(sp3_in)
            sp3_out.add_(spikenet_3(sp3_in_red))
            sp4_in = torch.gt(sp3_out, 1.0) * 1.0
            sp3_out.sub_(sp4_in)
            sp4_in_red = max_pool3(sp4_in)
            sp4_in_red_flat = torch.flatten(sp4_in_red, start_dim=1)
            sp_fc1_out.add_(spikenet_fc_1(sp4_in_red_flat))
            sp_fc2_in = torch.gt(sp_fc1_out, 1.0) * 1.0
            sp_fc1_out.sub_(sp_fc2_in)
            sp_fc2_out.add_(spikenet_fc_2(sp_fc2_in))
            sof_pot = F.softmax(sp_fc2_out, dim=1)
            batch_predictions[:, i] = torch.argmax(sof_pot, dim=1)

        test_output_array[start:end, :] = batch_predictions.cpu().numpy()
        print('test batch', batch_idx, 'allocated bytes:', torch.cuda.memory_allocated(device=0))
        torch.cuda.empty_cache()


In [None]:
# Release unused cached GPU memory held by PyTorch's allocator after the test-set simulation.
torch.cuda.empty_cache()

In [None]:
# Accuracy (%) of a running majority vote over the per-timestep SNN predictions on
# the test set: entry t uses the votes of timesteps 0..t.
num_classes = 10
test_predictions = test_output_array.astype(np.int64)
test_labels = test_output_expected[:, 0]
total_samples = test_predictions.shape[0]
sample_rows = np.arange(total_samples)
vote_counts = np.zeros((total_samples, num_classes), dtype=np.int64)
accuracy_over_time = np.zeros(test_predictions.shape[1])
for idx in range(test_predictions.shape[1]):
    # Add this timestep's vote and take the most-voted class. argmax breaks ties
    # toward the smallest class, as scipy.stats.mode does. The incremental counts
    # also avoid scipy's 1-D vs (N, 1) shape mismatch (SciPy >= 1.11), which
    # broadcast the comparison to N x N.
    vote_counts[sample_rows, test_predictions[:, idx]] += 1
    majority = vote_counts.argmax(axis=1)
    accuracy_over_time[idx] = np.sum(majority == test_labels) * 100 / total_samples
    

In [None]:
# Majority-vote accuracy (%) after each SNN timestep, as computed by the previous cell.
print(accuracy_over_time)

In [None]:
# Sanity check: shapes of the recorded SNN predictions and their ground-truth labels.
for recorded in (train_output_array, train_output_expected,
                 test_output_array, test_output_expected):
    print(recorded.shape)

In [None]:
class compute_analysis():
    """Estimate how much work a convolution layer does on binary spike frames.

    For every frame passed in, it accumulates:
      * the dense ("pure ANN") operation count the layer would need
        (N * out_h * out_w * out_ch * kh * kw * in_ch), and
      * the number of those operations whose input is non-zero, plus how many
        receptive-field windows are completely silent.

    Two counting methods exist:
      * ``conv_analysis`` applies an all-ones convolution. Each output value is the
        number of active inputs in that receptive field, summed over all input
        channels. A histogram of these counts accumulates in ``sparsity_tensor``.
      * ``crude_conv_analysis`` walks every window with Python loops, one input
        channel at a time, and appends a density per window to
        ``ifmap_sparsity_list``. It is very slow and only suited to small frames.

    A call with ``first_frame`` true starts a new measurement. It recomputes the
    output geometry and resets the dense-op total (and, for ``conv_analysis``, the
    histogram). The other counters keep accumulating.

    Args:
        inp_ch: input channels of the analysed layer.
        out_ch: output channels of the analysed layer.
        padding: 'valid' (no padding). Any other value is treated as "output has
            the input's spatial size" when counting ops. The counting convolution
            itself never pads.
        filter_size: (kh, kw). Pass it by keyword: it comes after ``padding``.
        device: device of the counting convolution (default 'cuda'). Frames given
            to ``conv_analysis`` must live on this device.
    """

    def __init__(self, inp_ch, out_ch, padding='valid', filter_size=(3,3), device='cuda'):
        self.inp_ch = inp_ch
        self.out_ch = out_ch
        self.filter_size = filter_size
        self.padding = padding
        self.total_number_of_conv_ops = 0
        self.single_conv_ops = 0
        self.shape = None
        self.output_shape = np.zeros(2)
        self.first_ifmap = 1   # set to 0 once crude_conv_analysis has recorded a window
        self.ifmap_sparsity_list = []
        self.number_of_zero_ifmaps = 0
        self.number_of_nonzero_addtions = 0
        self.layer_dict = {}
        self.sparsity_tensor = None
        # All-ones weights and zero bias: each output counts the active inputs in its window.
        self.conv_counter = nn.Conv2d(inp_ch, out_ch, kernel_size=filter_size).to(device)
        self.conv_weight = torch.zeros((out_ch, inp_ch, *filter_size), dtype=torch.float32) + 1.0
        self.conv_bias = torch.zeros(out_ch)
        # Frozen: the counter is never trained, so it should not track gradients.
        self.conv_counter.weight = nn.Parameter(self.conv_weight.to(device), requires_grad=False)
        self.conv_counter.bias = nn.Parameter(self.conv_bias.to(device), requires_grad=False)

    def _start_measurement(self, shape):
        """Recompute output geometry from a (N, C, H, W) shape and reset the dense-op total."""
        if self.padding == 'valid':
            self.output_shape[0] = int(shape[2] - self.filter_size[0] + 1)
            self.output_shape[1] = int(shape[3] - self.filter_size[1] + 1)
        else:
            self.output_shape[0] = int(shape[2])
            self.output_shape[1] = int(shape[3])
        self.total_number_of_conv_ops = 0
        self.single_conv_ops = (self.output_shape[0] * self.output_shape[1] * self.out_ch
                                * self.filter_size[0] * self.filter_size[1] * self.inp_ch)

    def _dense_op_count(self, shape):
        """Operations a dense convolution would need for a (N, C, H, W) frame."""
        return (shape[0] * self.output_shape[0] * self.output_shape[1] * self.out_ch
                * self.filter_size[0] * self.filter_size[1] * shape[1])

    ## This will measure sparsity over input images
    def crude_conv_analysis(self, input_frame, first_frame):
        """Loop-based sparsity count over every (sample, channel, window); returns 0."""
        self.shape = input_frame.shape
        shape = self.shape
        if first_frame:
            self._start_measurement(shape)
        self.total_number_of_conv_ops += self._dense_op_count(shape)
        kh, kw = self.filter_size[0], self.filter_size[1]
        # NOTE(review): each window sums a single input channel, but the density is
        # relative to the full kh*kw*inp_ch receptive field.
        number_of_mult_in_conv_per_op = kh * kw * self.inp_ch
        for i in range(shape[0]):
            for j in range(shape[1]):
                for id1 in range(int(self.output_shape[0])):
                    for id2 in range(int(self.output_shape[1])):
                        window = input_frame[i, j, id1:id1 + kh, id2:id2 + kw]
                        # .item() works for CPU and CUDA frames alike.
                        number_of_on_fields = torch.sum(window).item()
                        self.number_of_nonzero_addtions += number_of_on_fields * self.out_ch
                        if number_of_on_fields == 0:
                            self.number_of_zero_ifmaps += 1
                        self.ifmap_sparsity_list.append(number_of_on_fields / number_of_mult_in_conv_per_op)
                        self.first_ifmap = 0
        return 0

    def conv_analysis(self, input_frame, first_frame):
        """Convolution-based sparsity count for one (N, C, H, W) frame; returns 0."""
        self.shape = input_frame.shape
        shape = self.shape
        if first_frame:
            self._start_measurement(shape)
        self.total_number_of_conv_ops += self._dense_op_count(shape)
        # NOTE(review): a count can be at most kh*kw*inp_ch, but the histogram range
        # also multiplies by out_ch, so its upper bins stay empty.
        number_of_mult_in_conv_per_op = self.filter_size[0] * self.filter_size[1] * self.inp_ch * self.out_ch
        with torch.no_grad():
            number_of_on_fields = self.conv_counter(input_frame).cpu()
        histogram = torch.histc(number_of_on_fields, bins=number_of_mult_in_conv_per_op + 1,
                                max=number_of_mult_in_conv_per_op, min=0)
        if first_frame or self.sparsity_tensor is None:
            self.sparsity_tensor = histogram
        else:
            self.sparsity_tensor.add_(histogram)
        self.number_of_nonzero_addtions += torch.sum(number_of_on_fields).item()
        self.number_of_zero_ifmaps += torch.sum(torch.eq(number_of_on_fields, 0) * 1.0).item()
        return 0

    def summary(self):
        """Return (and store in ``layer_dict``) the accumulated statistics.

        'sparsity_tensor' is None when ``conv_analysis`` was never called.
        """
        sparsity = None if self.sparsity_tensor is None else torch.flatten(self.sparsity_tensor)
        self.layer_dict = {'total_pure_ann_ops': self.total_number_of_conv_ops,
                           'single_ann_ops': self.single_conv_ops,
                           'ifmap_sparsity_list': self.ifmap_sparsity_list,
                           'number_of_zero_ifmaps': self.number_of_zero_ifmaps,
                           'total_nonzero_ops': self.number_of_nonzero_addtions,
                           'sparsity_tensor': sparsity}
        return self.layer_dict

    def fc_analysis(self, input_frame):
        """Placeholder for fully connected layers; not implemented, returns 0."""
        return 0

        
        

In [None]:
# Measure per-layer spike sparsity of the three conv layers over a few test batches.
# Predictions are not recorded here, so test_output_* from the earlier simulation
# are left untouched.
number_of_timesteps = 200
number_of_analysis_batches = 6
conv1_analysis = compute_analysis(1, 8)
conv2_analysis = compute_analysis(8, 16)
# filter_size by keyword: the third positional parameter is `padding`, so passing
# (5,5) positionally would analyse conv3 as a 3x3 layer.
conv3_analysis = compute_analysis(16, 20, filter_size=(5,5))

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(snn_test_loader):
        if batch_idx >= number_of_analysis_batches:
            break
        data = data.cuda()
        batch_len = data.shape[0]
        device = data.device
        spike_pot = torch.zeros_like(data)
        sp1_out = torch.zeros((batch_len, 8, 26, 26), device=device)
        sp2_out = torch.zeros((batch_len, 16, 24, 24), device=device)

        for i in range(number_of_timesteps):
            # Only the very first frame starts a measurement. Later batches must
            # accumulate, or the dense-op totals would cover just the last batch
            # while the non-zero counts cover all of them.
            first_frame = batch_idx == 0 and i == 0
            spike_pot.add_(data)
            spike_frame = torch.gt(spike_pot, 1.0) * 1.0
            spike_pot.sub_(spike_frame)
            conv1_analysis.conv_analysis(spike_frame, first_frame)

            sp1_out.add_(spikenet_1(spike_frame))
            sp2_in = torch.gt(sp1_out, 1.0) * 1.0
            sp1_out.sub_(sp2_in)
            conv2_analysis.conv_analysis(sp2_in, first_frame)

            sp2_out.add_(spikenet_2(sp2_in))
            sp3_in = torch.gt(sp2_out, 1.0) * 1.0
            sp2_out.sub_(sp3_in)
            # spikenet_3 consumes the pooled spikes, so analyse those.
            sp3_in_red = max_pool2(sp3_in)
            conv3_analysis.conv_analysis(sp3_in_red, first_frame)
            # Layers after conv3 are not analysed, so they are not simulated here.

        print('batch_id', batch_idx, 'allocated bytes:', torch.cuda.memory_allocated(device=0))
        torch.cuda.empty_cache()

print(conv1_analysis.summary())
