In [44]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [9]:
# 메인 셀

import sys
import torchvision
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# GPU selection: enumerate devices in PCI bus order and expose only physical GPUs 1 and 2.
# NOTE(review): CUDA reads these variables when it initializes, so they only take effect
# if no CUDA call has happened yet in this kernel (torch is already imported above).
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1,2",
)


class SYNAPSE_FC_METHOD(torch.autograd.Function):
    """Fully-connected synapse whose weight gradient is taken from a trace.

    Forward is a plain ``F.linear(spike_one_time, weight, bias)``.
    Backward returns the ordinary linear gradient for ``spike_one_time`` and
    ``bias``, but the weight gradient is computed from ``spike_now`` (the
    trace supplied by the caller) instead of from ``spike_one_time``.
    ``spike_now`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias):
        # presumably spike_one_time / spike_now are [Batch, in_features] - TODO confirm with callers
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias)
        return F.linear(spike_one_time, weight, bias=bias)

    @staticmethod
    def backward(ctx, grad_output_current):
        # Exactly the four tensors saved in forward (the old code unpacked six and failed).
        spike_one_time, spike_now, weight, bias = ctx.saved_tensors
        out_features, in_features = weight.shape

        grad_input_spike = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # y = x W^T + b  =>  dL/dx = dL/dy @ W
            grad_input_spike = grad_output_current.matmul(weight)
        if ctx.needs_input_grad[2]:
            # dL/dW = dL/dy^T @ x, with the trace substituted for x.
            # Flatten any leading dims so the batch is summed over.
            grad_flat = grad_output_current.reshape(-1, out_features)
            trace_flat = spike_now.reshape(-1, in_features)
            grad_weight = grad_flat.t().matmul(trace_flat)
        if bias is not None and ctx.needs_input_grad[3]:
            # Bias is added to every row, so its gradient is the sum over rows.
            grad_bias = grad_output_current.reshape(-1, out_features).sum(0)

        # One gradient per forward input: spike_one_time, spike_now, weight, bias.
        return grad_input_spike, None, grad_weight, grad_bias

     
class SYNAPSE_FC(nn.Module):
    """Fully-connected synapse layer applied step by step over a spike train.

    At every time step the detached input spikes are folded into a trace
    ``spike_now = trace_const1 * spike[t] + trace_const2 * spike_now_prev``
    and ``SYNAPSE_FC_METHOD`` is applied to ``(spike[t], spike_now)``.
    """

    def __init__(self, in_features, out_features, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_FC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # nn.Parameter registers the tensors with the module, so .parameters(),
        # .to(device) and state_dict() see them (plain tensors are ignored).
        self.weight = nn.Parameter(torch.randn(self.out_features, self.in_features))
        self.bias = nn.Parameter(torch.randn(self.out_features))

    def forward(self, spike):
        """spike: [Time, Batch, Features] -> output current [Time, Batch, out_features]."""
        Time = spike.shape[0]
        if Time == 0:
            return spike.new_zeros(0, spike.shape[1], self.out_features)

        # The trace is built from detached spikes so it carries no graph history.
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])

        outputs = []
        for t in range(Time):
            spike_now = self.trace_const1 * spike_detach[t] + self.trace_const2 * spike_past
            outputs.append(SYNAPSE_FC_METHOD.apply(spike[t], spike_now, self.weight, self.bias))
            spike_past = spike_now

        # Stacking keeps the result on the input's device/dtype, unlike a
        # preallocated CPU float32 buffer.
        return torch.stack(outputs)


class SYNAPSE_CONV_METHOD(torch.autograd.Function):
    """2-D convolutional synapse whose weight gradient is taken from a trace.

    Forward is a plain ``F.conv2d``. Backward returns the exact conv gradient
    for ``spike_one_time`` and ``bias``, while the weight gradient is computed
    from ``spike_now`` (the trace supplied by the caller) instead of from
    ``spike_one_time``. ``spike_now``, ``stride`` and ``padding`` receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias, stride=1, padding=1):
        # save_for_backward accepts tensors (or None) only; stride/padding are
        # plain Python values, so they live on ctx.
        ctx.save_for_backward(spike_one_time, spike_now, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        return F.conv2d(spike_one_time, weight, bias=bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output_current):
        from torch.nn.grad import conv2d_input, conv2d_weight

        spike_one_time, spike_now, weight, bias = ctx.saved_tensors
        stride, padding = ctx.stride, ctx.padding

        grad_input_spike = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # Exact input gradient for any batch/channel count and stride.
            grad_input_spike = conv2d_input(spike_one_time.shape, weight, grad_output_current,
                                            stride=stride, padding=padding)
        if ctx.needs_input_grad[2]:
            # Weight gradient with the trace substituted for the input spikes.
            # (F.conv2d(input, grad_output) was only correct in special cases,
            # e.g. batch 1 with single-channel tensors.)
            grad_weight = conv2d_weight(spike_now, weight.shape, grad_output_current,
                                        stride=stride, padding=padding)
        if bias is not None and ctx.needs_input_grad[3]:
            # Bias is added at every batch element and spatial position.
            grad_bias = grad_output_current.sum(dim=(0, 2, 3))

        return grad_input_spike, None, grad_weight, grad_bias, None, None

     



class SYNAPSE_CONV(nn.Module):
    """Convolutional synapse layer applied step by step over a spike train.

    At every time step the detached input spikes are folded into a trace
    ``spike_now = trace_const1 * spike[t] + trace_const2 * spike_now_prev``
    and ``SYNAPSE_CONV_METHOD`` is applied to ``(spike[t], spike_now)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        # nn.Parameter registers the tensors with the module, so .parameters(),
        # .to(device) and state_dict() see them (plain tensors are ignored).
        self.weight = nn.Parameter(torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.randn(self.out_channels))

    def forward(self, spike):
        """spike: [Time, Batch, Channel, Height, Width] -> [Time, Batch, out_channels, H_out, W_out]."""
        Time = spike.shape[0]
        if Time == 0:
            # Integer division: torch.zeros/new_zeros reject the float sizes '/' produces.
            Height = (spike.shape[3] + self.padding * 2 - self.kernel_size) // self.stride + 1
            Width = (spike.shape[4] + self.padding * 2 - self.kernel_size) // self.stride + 1
            return spike.new_zeros(0, spike.shape[1], self.out_channels, Height, Width)

        # The trace is built from detached spikes so it carries no graph history.
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])

        outputs = []
        for t in range(Time):
            spike_now = self.trace_const1 * spike_detach[t] + self.trace_const2 * spike_past
            outputs.append(SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.weight, self.bias, self.stride, self.padding))
            spike_past = spike_now

        # Stacking takes the output shape/device from the conv results directly.
        return torch.stack(outputs)



class LIF(torch.autograd.Function):
    """One leaky integrate-and-fire step with a surrogate gradient.

    Forward: ``v = v * v_decay + input``; spike where ``v >= v_threshold``;
    then reset by subtracting ``v_threshold`` where a spike occurred and
    clamping at 0. ``v_reset`` is accepted but not used by this reset rule.
    Returns ``(spike, v_after_reset)``.

    Backward: only the input current gets a gradient, via the rectangular
    surrogate ``1/sg_width`` on ``|v - v_threshold| < sg_width / 2``
    (v taken before reset). The membrane-potential input gets no gradient.
    """

    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width):
        v_one_time = v_one_time * v_decay + input_current_one_time  # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float()  # fire
        # save_for_backward accepts tensors only (passing the float
        # hyper-parameters raised TypeError), so scalars go on ctx.
        ctx.save_for_backward(v_one_time)  # save before reset
        ctx.v_threshold = v_threshold
        ctx.sg_width = sg_width
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0)  # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        (v_one_time,) = ctx.saved_tensors
        v_threshold, sg_width = ctx.v_threshold, ctx.sg_width
        # grad_output_v is intentionally unused: no gradient flows through the membrane state.

        ################ select one of the following surrogate gradient functions ################
        #===========surrogate gradient function (rectangle)
        in_window = ((v_one_time - v_threshold).abs() < sg_width / 2).float()
        grad_input_current = grad_output_spike * in_window / sg_width

        #===========surrogate gradient function (sigmoid)
        # sig = torch.sigmoid((v_one_time - v_threshold))
        # grad_input_current = sig * (1 - sig) * grad_output_spike

        #===========surrogate gradient function (rough rectangle)
        # v_minus_th = (v_one_time - v_threshold)
        # grad_input_current = grad_output_spike.clone()
        # grad_input_current[v_minus_th <= -.5] = 0
        # grad_input_current[v_minus_th > .5] = 0
        ###########################################################################################
        return grad_input_current, None, None, None, None, None

class LIF_layer(nn.Module):
    """Runs ``LIF`` over the leading (time) dimension of an input current.

    The membrane potential starts at ``v_init`` and is carried from one time
    step to the next, so the leak ``v * v_decay`` acts on the previous step's
    (post-reset) potential.
    """

    def __init__(self, v_init=0.0, v_decay=0.8, v_threshold=0.5, v_reset=0.0, sg_width=1):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width

    def forward(self, input_current):
        """input_current: [Time, ...] -> spikes with the same shape (float)."""
        Time = input_current.shape[0]
        if Time == 0:
            return torch.zeros_like(input_current, dtype=torch.float)

        # Membrane potential for a single time step. It is updated each step; the
        # old code re-read the untouched v[t+1] (= v_init) every step, so the
        # state never carried over.
        v = torch.full_like(input_current[0], fill_value=self.v_init, dtype=torch.float)

        spikes = []
        for t in range(Time):
            # leak, integrate input_current, fire and reset (custom backward in LIF)
            spike, v = LIF.apply(input_current[t], v,
                                 self.v_decay, self.v_threshold, self.v_reset, self.sg_width)
            spikes.append(spike)

        return torch.stack(spikes)
    
    




In [93]:
#커스텀 컨볼루션 레이어 되는지 확인하는 셀


import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv2d(torch.autograd.Function):
    """``F.conv2d`` with a hand-written backward, used to check custom conv gradients.

    Backward uses ``torch.nn.grad.conv2d_input`` / ``conv2d_weight``. These give
    exact gradients for any batch size, channel counts, stride, padding,
    dilation and groups.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
        # Tensors go through save_for_backward; plain hyper-parameters are stored on ctx.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        return F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        from torch.nn.grad import conv2d_input, conv2d_weight

        input, weight, bias = ctx.saved_tensors
        conv_args = dict(stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation, groups=ctx.groups)

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # conv2d_input reconstructs exactly input.shape. A bare
            # conv_transpose2d is ambiguous about the output size when stride > 1.
            grad_input = conv2d_input(input.shape, weight, grad_output, **conv_args)
        if ctx.needs_input_grad[1]:
            # F.conv2d(input, grad_output) only matched the true gradient in special
            # cases (e.g. batch 1 with single-channel tensors and stride 1).
            grad_weight = conv2d_weight(input, weight.shape, grad_output, **conv_args)
        if bias is not None and ctx.needs_input_grad[2]:
            # Bias is added at every batch element and spatial position.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None
    
    
# Usage: compare MyConv2d gradients with nn.Conv2d on identical inputs and parameters.
torch.manual_seed(0)  # make the random tensors (and the printed gradients) reproducible

my_conv = MyConv2d.apply
input = torch.randn(1, 1, 5, 5, requires_grad=True)
input2 = input.clone().detach().requires_grad_(True)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
weight2 = weight.clone().detach().requires_grad_(True)
bias = torch.randn(1, requires_grad=True)
bias2 = bias.clone().detach().requires_grad_(True)
# Positional arguments after bias: stride, padding, dilation, groups
output = my_conv(input, weight, bias, 1, 1, 1, 1)
output.sum().backward()
print(input.grad)
print(weight.grad)
print(bias.grad)

# Reference layer sharing the same weight/bias values
conv = nn.Conv2d(1, 1, 3, stride=1, padding=1, dilation=1, groups=1)
conv.weight = nn.Parameter(weight2)
conv.bias = nn.Parameter(bias2)

# Forward pass
output_conv = conv(input2)

# Backward pass
output_conv.sum().backward()

print(input2.grad)
print(conv.weight.grad)
print(conv.bias.grad)

# Print each comparison and fail loudly on a mismatch instead of only printing False.
for grad_name, custom_grad, reference_grad in (
    ("input", input.grad, input2.grad),
    ("weight", weight.grad, conv.weight.grad),
    ("bias", bias.grad, conv.bias.grad),
):
    grads_match = torch.allclose(custom_grad, reference_grad)
    print(grads_match)
    assert grads_match, f"{grad_name} gradient of MyConv2d differs from nn.Conv2d"


tensor([[[[2.8091, 5.3851, 5.3851, 5.3851, 2.5255],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [0.9901, 2.9766, 2.9766, 2.9766, 2.6546]]]])
tensor([[[[ 4.7602,  4.6082,  3.2449],
          [ 1.5601,  1.5482, -0.2261],
          [-0.5064, -0.3129, -1.3671]]]])
tensor([25.])
tensor([[[[2.8091, 5.3851, 5.3851, 5.3851, 2.5255],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [1.7492, 4.8344, 4.8344, 4.8344, 2.7727],
          [0.9901, 2.9766, 2.9766, 2.9766, 2.6546]]]])
tensor([[[[ 4.7602,  4.6082,  3.2449],
          [ 1.5601,  1.5482, -0.2261],
          [-0.5064, -0.3129, -1.3671]]]])
tensor([25.])
True
True
True


In [None]:
# https://bo-10000.tistory.com/181
# detach 관련 포스팅

In [52]:
import torch
import torch.nn as nn

conv_layer = nn.Conv2d(in_channels=5, out_channels=64, kernel_size=3, stride=1, padding=1)

# Conv2d weight layout: [out_channels, in_channels, kernel_size, kernel_size]
print(conv_layer.weight.shape)
print(conv_layer.weight)

# Conv2d input layout: [N, C_in, H, W]; build a batch of 10 inputs
input_tensor = torch.randn(10, 5, 32, 32)

# Pass the input tensor through the Conv2d layer
output_tensor = conv_layer(input_tensor)

print(output_tensor.shape)
# kernel 3, stride 1, padding 1 preserves the spatial size -> torch.Size([10, 64, 32, 32])
assert output_tensor.shape == (10, 64, 32, 32)

a = torch.tensor([1, 1])

# Thresholding uniform random values yields a binary (0./1.) tensor.
# A bare `b` mid-cell displays nothing in Jupyter, so print it explicitly.
b = torch.rand(2, 3)
b = (b > 0.5).float()
print(b)

print(conv_layer.bias.shape)
print(conv_layer.weight.shape)


torch.Size([64, 5, 3, 3])
Parameter containing:
tensor([[[[-0.0611, -0.0643, -0.1282],
          [ 0.0746,  0.0501,  0.0935],
          [-0.0972, -0.0330,  0.1036]],

         [[ 0.1026, -0.1350, -0.0846],
          [ 0.0761,  0.0362,  0.0318],
          [ 0.1138, -0.1256,  0.1194]],

         [[-0.0611,  0.0265,  0.1006],
          [-0.0847,  0.1399,  0.1136],
          [ 0.0560,  0.0165, -0.0669]],

         [[-0.0692,  0.1255,  0.0081],
          [-0.0638,  0.0958, -0.0263],
          [ 0.1478,  0.0063, -0.0636]],

         [[ 0.0913, -0.0478,  0.1305],
          [-0.0689,  0.0300,  0.1242],
          [-0.0663,  0.1127,  0.0968]]],


        [[[ 0.0884, -0.1476, -0.1015],
          [ 0.1254, -0.0437, -0.0883],
          [-0.0084, -0.1422, -0.0645]],

         [[-0.0779, -0.0813,  0.0893],
          [-0.0158, -0.0726, -0.1080],
          [-0.1322, -0.1105, -0.0963]],

         [[-0.0633,  0.0936,  0.0793],
          [ 0.0351, -0.1105,  0.0954],
          [ 0.1288,  0.0475,  0.1034]],