In [44]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [86]:
# 메인 셀

import sys
import torchvision
import os
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import time



class SYNAPSE_FC_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias)
        return F.linear(spike_one_time, weight, bias=bias)

    @staticmethod
    def backward(ctx, grad_output_current):
        #############밑에부터 수정해라#######
        spike_one_time, spike_now, weight, bias = ctx.saved_tensors
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = grad_output_current_clone @ weight
        if ctx.needs_input_grad[2]:
            grad_weight = grad_output_current_clone.t() @ spike_now
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum(0)

        return grad_input_spike, None, grad_weight, grad_bias

     
class SYNAPSE_FC(nn.Module):
    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_FC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_features, self.in_features, requires_grad=True)
        # self.bias = torch.randn(self.out_features, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

    def forward(self, spike):
        # spike: [Time, Batch, Features]   
        Time = spike.shape[0]
        Batch = spike.shape[1] 
        output_current = torch.zeros(Time, Batch, self.out_features)

        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])
        spike_now = torch.zeros_like(spike_detach[0])

        for t in range(Time):
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past
            output_current[t]= SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias) 
            spike_past = spike_now

        return output_current 


class SYNAPSE_CONV_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias, stride=1, padding=1):
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias, torch.tensor([stride], requires_grad=False), torch.tensor([padding], requires_grad=False))
        return F.conv2d(spike_one_time, weight, bias=bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output_current):
        spike_one_time, spike_now, weight, bias, stride, padding = ctx.saved_tensors
        stride=stride.item()
        padding=padding.item()
        
        ## 이거 클론해야되는지 모르겠음!!!!
        grad_output_current_clone = grad_output_current.clone()

        grad_input_spike = grad_weight = grad_bias = None


        if ctx.needs_input_grad[0]:
            grad_input_spike = F.conv_transpose2d(grad_output_current_clone, weight, stride=stride, padding=padding)
        if ctx.needs_input_grad[2]:
            grad_weight = torch.nn.grad.conv2d_weight(spike_now, weight.shape, grad_output_current_clone,
                                                      stride=stride, padding=padding)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output_current_clone.sum((0, -1, -2))

        return grad_input_spike, None, grad_weight, grad_bias, None, None

     



class SYNAPSE_CONV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # self.weight = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, requires_grad=True)
        # self.bias = torch.randn(self.out_channels, requires_grad=True)
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        Time = spike.shape[0]
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size) // self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size) // self.stride + 1
        output_current = torch.zeros(Time, Batch, Channel, Height, Width)

        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])
        spike_now = torch.zeros_like(spike_detach[0])

        for t in range(Time):
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past
            output_current[t]= SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding) 
            spike_past = spike_now

        return output_current 



class LIF_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width):
        v_one_time = v_one_time * v_decay + input_current_one_time # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float() #fire
        ctx.save_for_backward(v_one_time, torch.tensor([v_decay], requires_grad=False), 
                              torch.tensor([v_threshold], requires_grad=False), 
                              torch.tensor([v_reset], requires_grad=False), 
                              torch.tensor([sg_width], requires_grad=False)) # save before reset
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0) # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        v_one_time, v_decay, v_threshold, v_reset, sg_width = ctx.saved_tensors
        v_decay=v_decay.item()
        v_threshold=v_threshold.item()
        v_reset=v_reset.item()
        sg_width=sg_width.item()

        grad_input_current = grad_output_spike.clone()
        # grad_temp_v = grad_output_v.clone() # not used

        ################ select one of the following surrogate gradient functions ################
        #===========surrogate gradient function (rectangle)
        grad_input_current = grad_input_current * ((v_one_time - v_threshold).abs() < sg_width/2).float() / sg_width

        #===========surrogate gradient function (sigmoid)
        # sig = torch.sigmoid((v_one_time - v_threshold))
        # grad_input_current =  sig*(1-sig)*grad_input_current

        #===========surrogate gradient function (rough rectangle)
        # v_minus_th = (v_one_time - v_threshold)
        # grad_input_current[v_minus_th <= -.5] = 0
        # grad_input_current[v_minus_th > .5] = 0
        ###########################################################################################
        return grad_input_current, None, None, None, None, None

class LIF_layer(nn.Module):
    def __init__ (self, v_init = 0.0, v_decay = 0.8, v_threshold = 0.5, v_reset = 0.0, sg_width = 1):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width

    def forward(self, input_current):
        v = torch.full_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        post_spike = torch.full_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        # i와 v와 post_spike size는 여기서 다 같음: [Time, Batch, Channel, Height, Width] 

        Time = v.shape[0]
        for t in range(Time):
            # leaky하고 input_current 더하고 fire하고 reset까지 (backward직접처리)
            post_spike[t], v[t] = LIF_METHOD.apply(input_current[t], v[t], 
                                            self.v_decay, self.v_threshold, self.v_reset, self.sg_width) 

        return post_spike
    
    


torch.manual_seed(42)

# HEPER PARAMETER
TIME = 8
BATCH = 128
IMAGE_PIXEL_CHANNEL = 1
IMAGE_SIZE = 28
CLASS_NUM = 10

## SYNAPSE_CONV 레이어의 하이퍼파라미터
synapse_conv_in_channels = IMAGE_PIXEL_CHANNEL
# synapse_conv_out_channels = layer별 지정
# synapse_conv_kernel_size = layer별 지정
synapse_conv_stride = 1
synapse_conv_padding = 1
synapse_conv_trace_const1 = 1
synapse_conv_trace_const2 = 0.7

## LIF_layer 레이어의 하이퍼파라미터
lif_layer_v_init = 0.0
lif_layer_v_decay = 0.8
lif_layer_v_threshold = 1.2
lif_layer_v_reset = 0.0
lif_layer_sg_width = 1

## SYNAPSE_FC 레이어의 하이퍼파라미터
# synapse_fc_in_features = 마지막CONV_OUT_CHANNEL * H * W
synapse_fc_out_features = CLASS_NUM
synapse_fc_trace_const1 = 1
synapse_fc_trace_const2 = 0.7


class MY_SNN_MK1(nn.Module):
    def __init__(self):
        super(MY_SNN_MK1, self).__init__()

        in_channels = synapse_conv_in_channels
        out_channels = 64
        self.synapse_conv1 = SYNAPSE_CONV(in_channels=in_channels, 
                                          out_channels=out_channels, 
                                          kernel_size=3, 
                                          stride=synapse_conv_stride, 
                                          padding=synapse_conv_padding, 
                                          trace_const1=synapse_conv_trace_const1, 
                                          trace_const2=synapse_conv_trace_const2)
        


        in_channels = 64
        out_channels = 64
        self.synapse_conv2 = SYNAPSE_CONV(in_channels=in_channels, 
                                          out_channels=out_channels, 
                                          kernel_size=3, 
                                          stride=synapse_conv_stride, 
                                          padding=synapse_conv_padding, 
                                          trace_const1=synapse_conv_trace_const1, 
                                          trace_const2=synapse_conv_trace_const2)
        

        
        self.lif_layer = LIF_layer(v_init=lif_layer_v_init, 
                                   v_decay=lif_layer_v_decay, 
                                   v_threshold=lif_layer_v_threshold, 
                                   v_reset=lif_layer_v_reset, 
                                   sg_width=lif_layer_sg_width)
        


        self.synapse_FC = SYNAPSE_FC(in_features=64*28*28,  # 마지막CONV의 OUT_CHANNEL * H * W
                                      out_features=CLASS_NUM, 
                                      trace_const1=synapse_fc_trace_const1, 
                                      trace_const2=synapse_fc_trace_const2)
        

    def forward(self, spike_input):
        spike_input = self.synapse_conv1(spike_input)
        spike_input = self.lif_layer(spike_input)
                                     
        spike_input = self.synapse_conv2(spike_input)
        spike_input = self.lif_layer(spike_input)


        spike_input = spike_input.view(spike_input.size(0), spike_input.size(1), -1)
        spike_input = self.synapse_FC(spike_input)

        spike_input = spike_input.sum(axis=0)

        return spike_input




############################################################
####################### DATASET ############################
transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5))])


trainset = torchvision.datasets.MNIST(root='/data2',
                                      train=True,
                                      download=True,
                                      transform=transform)

# 조금만 쓰기
# subset_indices = torch.randperm(len(trainset))[:1000]
# trainset = torch.utils.data.Subset(trainset, subset_indices)

testset = torchvision.datasets.MNIST(root='/data2',
                                     train=False,
                                     download=True,
                                     transform=transform)

train_loader = DataLoader(trainset,
                          batch_size =BATCH,
                          shuffle = True,
                          num_workers =2)
test_loader = DataLoader(testset,
                          batch_size =BATCH,
                          shuffle = False,
                          num_workers =2)
####################### DATASET END ############################
################################################################


os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2,3"  # Set the GPU 2 to use

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# 모델을 GPU로 이동합니다.
net = MY_SNN_MK1().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(100):
    print('epoch', epoch)
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        print('\niter', i)
        iter_one_train_time_start = time.time()

        inputs, labels = data
        inputs = inputs.repeat(TIME, 1, IMAGE_PIXEL_CHANNEL, 1, 1)
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        print("Epoch: {}, Iter: {}, Loss: {}".format(epoch + 1, i + 1, running_loss / 100))
        running_loss = 0.0

        iter_one_train_time_end = time.time()
        elapsed_time = iter_one_train_time_end - iter_one_train_time_start  # 실행 시간 계산
        print(f"iter_one_train_time: {elapsed_time} seconds")


        correct = 0
        total = 0
        acc = 0
        if i % 100 == 9:
            iter_one_val_time_start = time.time()

            with torch.no_grad():
                how_many_val_image=0
                for data in test_loader:
                    how_many_val_image += 1
                    images, labels = data
                    images = images.repeat(TIME, 1, IMAGE_PIXEL_CHANNEL, 1, 1)
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    # if how_many_val_image > 10:
                    #     break
                print(f'\nvalidation acc\n: {100 * correct / total:.2f}%')


            iter_one_val_time_end = time.time()
            elapsed_time = iter_one_val_time_end - iter_one_val_time_start  # 실행 시간 계산
            print(f"iter_one_val_time: {elapsed_time} seconds")
            if acc < correct / total:
                acc = correct / total
                torch.save(net.state_dict(), "save_now_net_weights.pth")
                torch.save(net, "save_now_net.pth")




epoch 0

iter 0
Epoch: 1, Iter: 1, Loss: 18.902303466796877
iter_one_train_time: 8.941988468170166 seconds

iter 1
Epoch: 1, Iter: 2, Loss: 14.474342041015625
iter_one_train_time: 8.653029203414917 seconds

iter 2
Epoch: 1, Iter: 3, Loss: 15.340267333984375
iter_one_train_time: 10.006683111190796 seconds

iter 3
Epoch: 1, Iter: 4, Loss: 15.04623046875
iter_one_train_time: 7.384870529174805 seconds

iter 4
Epoch: 1, Iter: 5, Loss: 16.676324462890626
iter_one_train_time: 6.261949062347412 seconds

iter 5
Epoch: 1, Iter: 6, Loss: 18.211563720703126
iter_one_train_time: 7.029981374740601 seconds

iter 6
Epoch: 1, Iter: 7, Loss: 21.69876220703125
iter_one_train_time: 7.999080181121826 seconds

iter 7
Epoch: 1, Iter: 8, Loss: 27.79680908203125
iter_one_train_time: 9.965918779373169 seconds

iter 8
Epoch: 1, Iter: 9, Loss: 23.44623291015625
iter_one_train_time: 8.080379724502563 seconds

iter 9
Epoch: 1, Iter: 10, Loss: 18.593533935546876
iter_one_train_time: 7.797908544540405 seconds
15.66%


KeyboardInterrupt: 