In [12]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time

from snntorch import spikegen
import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML

from tqdm import tqdm


from modules.data_loader import *
from modules.network import *
from modules.neuron import *
from modules.synapse import *

In [13]:
class SYNAPSE_CONV_gra_test(nn.Module):
    """Conv2d-style layer whose gradients come from ``SYNAPSE_CONV_METHOD_gra_test``.

    Parameters are drawn with ``torch.randn`` (standard normal), so copy them into
    any reference layer before comparing outputs or gradients.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Hyper-parameters kept on the module; stride/padding are read by forward().
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.stride, self.padding = stride, padding
        # weight must be sampled before bias so the RNG stream is consumed in the same order.
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(weight_shape))
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, spike):
        """Convolve ``spike`` with this layer's parameters via the custom autograd function."""
        conv_fn = SYNAPSE_CONV_METHOD_gra_test.apply
        return conv_fn(spike, self.weight, self.bias, self.stride, self.padding)

class SYNAPSE_CONV_METHOD_gra_test(torch.autograd.Function):
    """2-D convolution with an explicitly written backward pass.

    The forward pass is ``F.conv2d``; the backward pass computes the input,
    weight and bias gradients by hand so they can be checked against the
    gradients autograd produces for ``nn.Conv2d``.
    """

    @staticmethod
    def forward(ctx, spike_one_time, weight, bias, stride=1, padding=1):
        """Convolve ``spike_one_time`` with ``weight`` and add ``bias``.

        Args:
            spike_one_time: input passed to ``F.conv2d`` (presumably
                ``(N, C_in, H, W)`` -- TODO confirm against callers).
            weight: convolution kernel passed to ``F.conv2d``.
            bias: per-output-channel bias, or ``None``.
            stride: int or tuple, forwarded unchanged to ``F.conv2d``.
            padding: int or tuple, forwarded unchanged to ``F.conv2d``.

        Returns:
            The ``F.conv2d`` output.
        """
        ctx.save_for_backward(spike_one_time, weight, bias)
        # Non-tensor hyper-parameters are stored on ctx directly. Wrapping them in
        # tensors and calling .item() later fails for tuple stride/padding.
        ctx.stride = stride
        ctx.padding = padding
        return F.conv2d(spike_one_time, weight, bias=bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output_current):
        """Return gradients for (spike_one_time, weight, bias, stride, padding)."""
        spike_one_time, weight, bias = ctx.saved_tensors
        stride, padding = ctx.stride, ctx.padding

        # No clone is needed: grad_output_current is only read below, never
        # modified in place.
        grad_input_spike = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # conv2d_input receives the exact input shape and derives the
            # output_padding itself. A bare conv_transpose2d without
            # output_padding returns a too-small gradient when stride > 1 and
            # the strided windows do not tile the input exactly (e.g. 6x6, stride 2).
            grad_input_spike = torch.nn.grad.conv2d_input(
                spike_one_time.shape, weight, grad_output_current,
                stride=stride, padding=padding)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(
                spike_one_time, weight.shape, grad_output_current,
                stride=stride, padding=padding)
        if bias is not None and ctx.needs_input_grad[2]:
            # The bias is broadcast over the batch and both spatial dims, so its
            # gradient sums grad_output over those dims.
            grad_bias = grad_output_current.sum((0, -1, -2))

        # stride and padding are not tensors and receive no gradient.
        return grad_input_spike, grad_weight, grad_bias, None, None
   

In [14]:
# Comparison tolerances: the custom and reference backward passes may accumulate
# in a different order, so float32 results can differ by rounding noise. The
# absolute term covers entries near zero, where a purely relative check is too strict.
RTOL = 1e-5
ATOL = 1e-4

# Seed so the printed comparison is reproducible on Restart & Run All.
torch.manual_seed(0)

batch = 3
in_channels = 2
out_channels = 4

image_size = 5

input_tensor = torch.randn(batch, in_channels, image_size, image_size, requires_grad=True)
# Separate leaf holding the same values for the reference layer. If both layers
# backprop into `input_tensor`, its .grad becomes the SUM of the two input
# gradients and the layers can no longer be compared.
standard_input = input_tensor.detach().clone().requires_grad_(True)

# Define custom convolution layer
custom_conv = SYNAPSE_CONV_gra_test(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
custom_conv_output = custom_conv(input_tensor)

# Define standard convolution layer with the same weights and biases as custom layer
standard_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
standard_conv.weight = nn.Parameter(custom_conv.weight.data.clone())
standard_conv.bias = nn.Parameter(custom_conv.bias.data.clone())
standard_conv_output = standard_conv(standard_input)

# Compare forward outputs
print("Forward outputs are equal: ", torch.allclose(custom_conv_output, standard_conv_output, rtol=RTOL, atol=ATOL))

# Backprop the same upstream gradient through both layers. The graphs are
# independent, so neither call needs retain_graph.
grad_output = torch.randn_like(custom_conv_output)
custom_conv_output.backward(grad_output)
standard_conv_output.backward(grad_output)

# Compare gradients w.r.t. input: custom side vs. reference side (not a tensor with itself)
print("Input gradients are equal: ", torch.allclose(input_tensor.grad, standard_input.grad, rtol=RTOL, atol=ATOL))

# Compare gradients w.r.t. weights
print("Weight gradients are equal: ", torch.allclose(custom_conv.weight.grad, standard_conv.weight.grad, rtol=RTOL, atol=ATOL))

# Compare gradients w.r.t. biases
print("Bias gradients are equal: ", torch.allclose(custom_conv.bias.grad, standard_conv.bias.grad, rtol=RTOL, atol=ATOL))

Forward outputs are equal:  True
grad_input_spike_conv tensor([[[[  1.6675,   3.1113,  -1.3269,   3.1651,   0.6786],
          [  3.2753,  -1.0269,  -7.5661,   0.4283,  -5.3309],
          [ -1.1095,  -2.7449,  -7.3839,   7.9657,   3.9452],
          [ -1.4781,  -0.8452,   9.1830,   9.7885,   7.1655],
          [  5.9332,   3.7079,   0.2336,  -9.2079,  -9.6860]],

         [[ -4.8036,   1.2488,   4.3674,   2.8444,   0.7935],
          [ -3.7342,   2.2655,   5.7872,   4.2463,   3.9889],
          [ -1.9559,   4.3582,   2.1653,  -1.3253,   0.8523],
          [ -3.8095,   8.5844,   2.3080,   0.7438,  -3.2798],
          [ -0.0207,  -5.9985,  -4.2232,   1.0674,   0.8361]]],


        [[[  2.9352,   2.1296,   5.0763,   1.8191,  -4.6384],
          [  8.9888,   5.7801,   0.6527,   1.3835,   0.7977],
          [ -7.4659,  -4.2613,  -4.6242,   3.5618,   4.2628],
          [ -1.9347,   1.2162,  -4.2999,   3.9684,   4.6007],
          [  0.2886,  -1.1876,   1.1517,  -5.7773,   3.8654]],

       

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CustomTwoLayerConvNet(nn.Module):
    """conv -> ReLU -> conv stack built from ``SYNAPSE_CONV_gra_test`` layers."""

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Both convolutions share the same kernel/stride/padding settings.
        conv_settings = (kernel_size, stride, padding)
        self.layer1 = SYNAPSE_CONV_gra_test(in_channels, hidden_channels, *conv_settings)
        self.layer2 = SYNAPSE_CONV_gra_test(hidden_channels, out_channels, *conv_settings)

    def forward(self, x):
        """Apply layer1, a ReLU non-linearity, then layer2."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

class StandardTwoLayerConvNet(nn.Module):
    """Reference conv -> ReLU -> conv stack built from ``nn.Conv2d``."""

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Both convolutions share the same kernel/stride/padding settings.
        conv_settings = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.layer1 = nn.Conv2d(in_channels, hidden_channels, **conv_settings)
        self.layer2 = nn.Conv2d(hidden_channels, out_channels, **conv_settings)

    def forward(self, x):
        """Apply layer1, a ReLU non-linearity, then layer2."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

# Parameters
batch = 1
in_channels = 5
hidden_channels = 5
out_channels = 5
image_size = 5
# Comparison tolerances: the custom and reference backward passes may accumulate
# in a different order, so float32 results can differ by rounding noise. The
# absolute term covers entries near zero, where a purely relative check is too
# strict. This replaces rounding every gradient to 5 decimals first, which can
# itself push two nearly equal values onto opposite sides of a rounding boundary.
RTOL = 1e-5
ATOL = 1e-4

# Seed so the printed comparison is reproducible on Restart & Run All.
torch.manual_seed(0)

# Define input tensor
input_tensor = torch.randn(batch, in_channels, image_size, image_size, requires_grad=True)
# Separate leaf holding the same values for the reference model. If both models
# backprop into `input_tensor`, its .grad becomes the SUM of the two input
# gradients and the models can no longer be compared.
standard_input = input_tensor.detach().clone().requires_grad_(True)

# Define custom two-layer convolution model
custom_model = CustomTwoLayerConvNet(in_channels, hidden_channels, out_channels)
custom_model_output = custom_model(input_tensor)

# Define standard two-layer convolution model with the same weights and biases as custom model
standard_model = StandardTwoLayerConvNet(in_channels, hidden_channels, out_channels)
for layer_name in ("layer1", "layer2"):
    custom_layer = getattr(custom_model, layer_name)
    standard_layer = getattr(standard_model, layer_name)
    standard_layer.weight = nn.Parameter(custom_layer.weight.data.clone())
    standard_layer.bias = nn.Parameter(custom_layer.bias.data.clone())
standard_model_output = standard_model(standard_input)

# Compare forward outputs
print("Forward outputs are equal: ",
      torch.allclose(custom_model_output, standard_model_output, rtol=RTOL, atol=ATOL))

# Backprop the same upstream gradient through both models. The graphs are
# independent, so neither call needs retain_graph.
grad_output = torch.randn_like(custom_model_output)
custom_model_output.backward(grad_output)
standard_model_output.backward(grad_output)

# Compare every gradient: custom side vs. reference side (never a tensor with itself).
gradient_pairs = [
    ("Input gradients", input_tensor.grad, standard_input.grad),
    ("First layer weight gradients", custom_model.layer1.weight.grad, standard_model.layer1.weight.grad),
    ("First layer bias gradients", custom_model.layer1.bias.grad, standard_model.layer1.bias.grad),
    ("Second layer weight gradients", custom_model.layer2.weight.grad, standard_model.layer2.weight.grad),
    ("Second layer bias gradients", custom_model.layer2.bias.grad, standard_model.layer2.bias.grad),
]
for label, custom_grad, standard_grad in gradient_pairs:
    print(f"{label} are equal: ", torch.allclose(custom_grad, standard_grad, rtol=RTOL, atol=ATOL))


Forward outputs are equal:  True
grad_input_spike_conv tensor([[[[  4.4154,   2.2984,  -9.4555,   4.0858,   1.1753],
          [ -4.3042,   2.3617,   7.4733,  -3.1823,  -4.4650],
          [  3.8756,  -2.2621, -14.6387,   4.7701,  -0.4026],
          [  0.1012,   5.2098,  13.5921,  -0.8115,  -3.6333],
          [ -2.9049,   1.3135,  -8.9027,   5.9925,  -0.5756]],

         [[  4.5611,   4.0176,   3.1602,   0.2289,  -0.3995],
          [  1.0245,  -9.4194,   1.5867,   2.1504,  -4.5593],
          [  7.4433,   4.2246,   5.7605,  -6.3857,   1.0267],
          [ -5.0355,   0.8542,  10.2554,   0.0686,  -5.2112],
          [  0.6433,  -2.8248,   0.1386,   2.9333,   1.3182]],

         [[  0.2378,   1.1505,   2.6447,  -1.8690,  -0.2379],
          [ -8.6084,   5.2001, -13.9032,   3.7629,   7.4997],
          [ -4.0535,  -8.6637,   2.0248,  -3.9737,  -0.1870],
          [ -1.2925,   2.1714, -13.4965,   3.9888,  -0.6797],
          [ -0.9666,   6.8602,  -4.8485,   2.0156,  -1.0743]],

         

In [16]:
# Re-check the first-layer weight gradients of the two models on their own.
print(
    "First layer weight gradients are equal: ",
    torch.allclose(
        custom_model.layer1.weight.grad,
        standard_model.layer1.weight.grad,
    ),
)

First layer weight gradients are equal:  True


In [17]:
# Round the custom model's layer-1 weight gradient to 5 decimals, in place on .grad, then display it.
# NOTE(review): this mutates .grad, so re-running the cell re-rounds an already-rounded tensor.
custom_model.layer1.weight.grad = (custom_model.layer1.weight.grad * 1e5).round() / 1e5
custom_model.layer1.weight.grad

tensor([[[[ 1.5139e+01,  1.5467e+01,  2.4720e+01],
          [-2.8639e+00, -7.0500e+01, -1.4494e+01],
          [ 5.1309e+00,  2.2252e+01, -1.0698e+01]],

         [[ 1.0914e+01,  1.7602e+01,  1.1687e+01],
          [-2.8848e+00, -1.0948e+01, -7.0184e+00],
          [-3.8666e+00,  2.5753e+01,  2.8845e+01]],

         [[ 1.3137e+01, -1.0906e+01, -1.1161e+01],
          [-3.3867e+01,  3.5889e+00, -4.2976e+00],
          [-4.8495e-01, -2.5969e+01, -5.7526e+00]],

         [[-3.2585e+01, -9.0445e+00, -1.9721e+01],
          [ 1.4435e+01, -2.3112e+00,  6.6558e+00],
          [-1.2920e+00,  2.4451e+01,  1.6336e+01]],

         [[ 9.2033e+00, -1.4991e+01, -2.5255e+01],
          [-3.4645e+01,  3.6436e+01,  2.6582e+01],
          [ 7.6796e+00, -6.8395e+00, -1.6259e+01]]],


        [[[ 6.9662e+00,  1.8094e+00,  7.5254e+00],
          [-2.1284e+01,  6.2344e+00,  2.1621e+01],
          [ 9.9756e+00, -3.0127e+00, -3.8798e+01]],

         [[ 3.5498e+00,  4.6208e+00,  1.2358e+01],
          [-1.625

In [18]:
# Round the reference model's layer-1 weight gradient to 5 decimals, in place on .grad, then display it.
# NOTE(review): this mutates .grad, so re-running the cell re-rounds an already-rounded tensor.
standard_model.layer1.weight.grad = (standard_model.layer1.weight.grad * 1e5).round() / 1e5
standard_model.layer1.weight.grad

tensor([[[[ 1.5139e+01,  1.5467e+01,  2.4720e+01],
          [-2.8639e+00, -7.0500e+01, -1.4494e+01],
          [ 5.1309e+00,  2.2252e+01, -1.0698e+01]],

         [[ 1.0914e+01,  1.7602e+01,  1.1687e+01],
          [-2.8848e+00, -1.0948e+01, -7.0184e+00],
          [-3.8666e+00,  2.5753e+01,  2.8845e+01]],

         [[ 1.3137e+01, -1.0906e+01, -1.1161e+01],
          [-3.3867e+01,  3.5889e+00, -4.2976e+00],
          [-4.8495e-01, -2.5969e+01, -5.7526e+00]],

         [[-3.2585e+01, -9.0445e+00, -1.9721e+01],
          [ 1.4435e+01, -2.3112e+00,  6.6558e+00],
          [-1.2920e+00,  2.4451e+01,  1.6336e+01]],

         [[ 9.2033e+00, -1.4991e+01, -2.5255e+01],
          [-3.4645e+01,  3.6436e+01,  2.6582e+01],
          [ 7.6796e+00, -6.8395e+00, -1.6259e+01]]],


        [[[ 6.9662e+00,  1.8094e+00,  7.5254e+00],
          [-2.1284e+01,  6.2344e+00,  2.1621e+01],
          [ 9.9756e+00, -3.0127e+00, -3.8798e+01]],

         [[ 3.5498e+00,  4.6208e+00,  1.2358e+01],
          [-1.625

In [19]:
# First element of the custom model's layer-1 weight gradient.
custom_model.layer1.weight.grad[0, 0, 0, 0]

tensor(15.1393)

In [20]:
# First element of the reference model's layer-1 weight gradient.
standard_model.layer1.weight.grad[0, 0, 0, 0]

tensor(15.1393)

In [21]:
# Element-wise exact equality of the two (rounded) layer-1 weight gradients.
torch.eq(custom_model.layer1.weight.grad, standard_model.layer1.weight.grad)

tensor([[[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]],

         [[True, True, True],
          [True, True, True],
          [True, True, True]]],


        [[[True, True, True],
          [True, True, True],
          [True,