In [44]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [9]:
# 중간중간에 .float()라고 해놓은거 지워야 되나
#   나중에 한번 떼고 실험해보자
############################################

import sys
import torchvision
import os
import torch
import torch.nn as nn

# GPU   selection
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "1,2" 





class SYNAPSE_CONV_METHOD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, spike_one_time, spike_now, weight, bias, in_channels, out_channels, kernel_size, stride, padding):
        this_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        this_layer.weight.data = weight.data
        this_layer.bias.data = bias.data    

        ctx.save_for_backward(spike_one_time, spike_now, weight, bias, in_channels, out_channels, kernel_size, stride, padding)

        return this_layer(spike_one_time)

    @staticmethod
    def backward(ctx, grad_output_current):
        spike_one_time, spike_now, weight, bias, in_channels, out_channels, kernel_size, stride, padding = ctx.saved_tensors

        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        
        # ConvTranspose2d 레이어 생성
        conv_transpose = nn.ConvTranspose2d(
            in_channels=out_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True
        )

        # 원래 Conv2d 레이어의 가중치를 사용
        conv_transpose.weight.data = this_layer.weight.data

        # ConvTranspose2d를 사용하여 입력에 대한 그래디언트 계산
        grad_input_spike = conv_transpose(grad_output_current)

        return grad_input_spike, None, grad_weight, grad_bias, None, None, None, None, None

     



class SYNAPSE_CONV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)
        self.weight = self.conv.weight
        self.bias = self.conv.bias

    def forward(self, spike):
        # spike: [Time, Batch, Channel, Height, Width]   
        Time = spike.shape[0]
        Batch = spike.shape[1] 
        Channel = self.out_channels
        Height = (spike.shape[3] + self.padding*2 - self.kernel_size)/self.stride + 1
        Width = (spike.shape[4] + self.padding*2 - self.kernel_size)/self.stride + 1
        output_current = torch.zeros(Time, Batch, Channel, Height, Width)

        # spike_detach = spike.detach().clone()
        spike_detach = spike.detach()
        spike_past = torch.zeros_like(spike_detach[0])
        spike_now = torch.zeros_like(spike_detach[0])

        for t in range(Time):
            spike_now = self.trace_const1*spike_detach[t] + self.trace_const2*spike_past
            output_current[t]= SYNAPSE_CONV_METHOD.apply(spike[t], spike_now, self.conv.weight, self.bias,
                                             self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding) 
            spike_past = spike_now

        return output_current 



class LIF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width):
        v_one_time = v_one_time * v_decay + input_current_one_time # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float() #fire
        ctx.save_for_backward(v_one_time, v_decay, v_threshold, v_reset, sg_width) # save before reset
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0) # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        v_one_time, v_decay, v_threshold, v_reset, sg_width = ctx.saved_tensors
        grad_input_current = grad_output_spike.clone()
        # grad_temp_v = grad_output_v.clone() # not used

        ################ select one of the following surrogate gradient functions ################
        #===========surrogate gradient function (rectangle)
        grad_input_current = grad_input_current * ((v_one_time - v_threshold).abs() < sg_width/2).float() / sg_width

        #===========surrogate gradient function (sigmoid)
        # sig = torch.sigmoid((v_one_time - v_threshold))
        # grad_input_current =  sig*(1-sig)*grad_input_current

        #===========surrogate gradient function (rough rectangle)
        # v_minus_th = (v_one_time - v_threshold)
        # grad_input_current[v_minus_th <= -.5] = 0
        # grad_input_current[v_minus_th > .5] = 0
        ###########################################################################################
        return grad_input_current, None, None, None, None, None

class LIF_layer(nn.Module):
    def __init__ (self, v_init = 0.0, v_decay = 0.8, v_threshold = 0.5, v_reset = 0.0, sg_width = 1):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width

    def forward(self, input_current):
        v = torch.full_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        post_spike = torch.zeros_like(input_current, fill_value = self.v_init, dtype = torch.float) # v (membrane potential) init
        # i와 v와 post_spike size는 여기서 다 같음: [Time, Batch, Channel, Height, Width] 

        Time = v.shape[0]
        for t in range(Time):
            # leaky하고 input_current 더하고 fire하고 reset까지 (backward직접처리)
            post_spike[t], v[t] = LIF.apply(input_current[t], v[t], 
                                            self.v_decay, self.v_threshold, self.v_reset, self.sg_width) 

        return post_spike
    
    




In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Save input and weight for backward pass
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        # Perform forward pass
        output = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward pass
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups

        # Compute gradients w.r.t input and weight
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = F.conv_transpose2d(grad_output, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
            grad_input = torch.round(grad_input).int()
        if ctx.needs_input_grad[1]:
            grad_output_padded = F.pad(grad_output, (padding, padding, padding, padding))
            grad_weight = F.conv2d(input, grad_output_padded)
            grad_weight = torch.round(grad_weight).int()
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
            grad_bias = torch.round(grad_bias).int()

        return grad_input, grad_weight, grad_bias, None, None, None, None
    
    
# Usage
my_conv = MyConv2d.apply
input = torch.randn(1, 1, 5, 5, requires_grad=True)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
output = my_conv(input, weight)
output.sum().backward()
print(input.grad)
print(weight.grad)

tensor([[[[ 1.,  1.,  1.,  0.,  0.],
          [ 1.,  2.,  3.,  2.,  1.],
          [ 0.,  1.,  2.,  2.,  1.],
          [-1.,  0.,  2.,  2.,  1.],
          [-1., -1., -1.,  0.,  0.]]]])
tensor([[[[ 2.,  5.,  2.],
          [ 0.,  2.,  0.],
          [-2.,  1.,  0.]]]])


In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
        # Save input and weight for backward pass
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        # Perform forward pass
        output = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward pass
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups

        # Compute gradients w.r.t input and weight
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = F.conv_transpose2d(grad_output, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
        if ctx.needs_input_grad[1]:
            grad_output_padded = F.pad(grad_output, (padding, padding, padding, padding))
            grad_weight = F.conv2d(input, grad_output_padded)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias, None, None, None, None
# Usage
my_conv = MyConv2d.apply
input = torch.randn(1, 1, 5, 5, requires_grad=True)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
output = my_conv(input, weight)
output.sum().backward()
print(input.grad)
print(weight.grad)

tensor([[[[ 1.4130,  1.2828,  1.9121,  0.4991,  0.6293],
          [ 3.9457,  3.2379,  2.6662, -1.2795, -0.5717],
          [ 4.0311,  3.2764,  3.0860, -0.9451, -0.1904],
          [ 2.6182,  1.9936,  1.1739, -1.4442, -0.8197],
          [ 0.0855,  0.0386,  0.4198,  0.3344,  0.3813]]]])
tensor([[[[-3.3715, -5.6048, -3.5409],
          [ 1.1237, -2.3749,  0.7058],
          [ 0.0332,  0.6295,  4.6315]]]])


In [51]:
# MyConv2d 사용
my_conv = MyConv2d.apply
input_my_conv = torch.randn(1, 1, 5, 5, requires_grad=True)
weight_my_conv = torch.randn(1, 1, 3, 3, requires_grad=True)
output_my_conv = my_conv(input_my_conv, weight_my_conv)
output_my_conv.sum().backward()

# nn.Conv2d 사용
conv = nn.Conv2d(1, 1, 3, stride=1, padding=0, dilation=1, groups=1)
input_conv = input_my_conv.clone().detach().requires_grad_(True)
weight_conv = weight_my_conv.clone().detach().requires_grad_(True)
conv.weight = nn.Parameter(weight_conv)
output_conv = conv(input_conv)
output_conv.sum().backward()

# 그래디언트 비교
print(torch.allclose(input_my_conv.grad, input_conv.grad))  # True면 입력에 대한 그래디언트가 동일
print(weight_conv.grad is not None)  # True면 가중치에 대한 그래디언트가 계산됨
print(torch.allclose(weight_my_conv.grad, weight_conv.grad))  # True면 가중치에 대한 그래디언트가 동일

True
False


TypeError: allclose(): argument 'other' (position 2) must be Tensor, not NoneType

In [None]:
# https://bo-10000.tistory.com/181
# detach 관련 포스팅

In [1]:
# copilot이 준 autograd conv 코드라는데 별로 쓸 데 없을 거 같음

import torch
from torch.autograd import Function
import torch.nn.functional as F

class Conv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Save arguments to context to use in backward
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        # Perform forward convolution
        output = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Load saved tensors
        input, weight, bias = ctx.saved_tensors

        # Calculate gradients
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = F.conv_transpose2d(grad_output, weight, stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation, groups=ctx.groups)
        if ctx.needs_input_grad[1]:
            grad_weight = F.conv_transpose2d(input, grad_output, stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation, groups=ctx.groups)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias, None, None, None, None


In [16]:
# 입력, 가중치, 편향 텐서 생성
input = torch.randn(1, 1, 3, 3, requires_grad=True)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
bias = torch.randn(1, requires_grad=True)

# Conv2dFunction 사용
output = Conv2dFunction.apply(input, weight, bias)

# 그래디언트 계산
output2 = output.backward(torch.randn(1, 1, 1, 1))

print(output)
print(output2)

tensor([[[[-4.1372]]]], grad_fn=<Conv2dFunctionBackward>)
None


In [29]:
import torch.nn as nn
conv_layer = nn.Conv2d(in_channels=5, out_channels=64, kernel_size=3, stride=1, padding=1)

print(conv_layer.weight.shape)
# out_channels, in_channels, kernel_size, kernel_size

# input은 N, C-in, H, Wimport torch

# 배치 크기가 10인 입력 텐서 생성
input_tensor = torch.randn(10, 5, 32, 32)

# Conv2d 레이어를 통해 입력 텐서 전달
output_tensor = conv_layer(input_tensor)

print(output_tensor.shape)
# 출력: torch.Size([10, 64, 32, 32])






import torch
a = torch.tensor([1,1])


b=torch.rand(2,3)
b = (b>0.5).float()
b


print(conv_layer.bias.shape)
print(conv_layer.weight.shape)


torch.Size([64, 5, 3, 3])
torch.Size([10, 64, 32, 32])
torch.Size([64])
torch.Size([64, 5, 3, 3])


In [38]:
print(conv_layer.weight.data)
conv_layer.weight.data = conv_layer.weight.data + 1
# print(conv_layer.weight.data)


tensor([[[[2.0601, 2.0517, 1.9235],
          [1.9142, 1.8625, 2.0826],
          [1.9622, 2.1115, 1.9114]],

         [[2.0156, 2.0409, 1.9783],
          [1.9721, 2.0103, 1.8834],
          [1.9347, 2.0165, 1.9584]],

         [[1.9579, 1.9664, 2.0507],
          [1.8645, 2.0277, 2.0766],
          [1.8842, 2.0763, 2.1095]],

         [[2.1363, 1.8818, 2.1197],
          [2.1128, 1.8667, 1.9653],
          [2.1313, 2.1334, 1.9699]],

         [[1.8600, 1.9308, 2.0743],
          [2.0305, 2.0491, 2.1310],
          [1.8754, 1.9908, 1.8914]]],


        [[[1.9892, 1.8944, 1.9313],
          [1.9487, 2.0942, 2.0314],
          [1.9013, 2.0375, 2.1096]],

         [[1.9986, 1.9797, 1.9915],
          [2.1212, 2.0582, 2.1400],
          [1.9639, 2.0424, 1.9789]],

         [[2.0781, 1.9984, 2.1082],
          [1.9477, 1.8938, 1.9783],
          [1.8922, 1.9297, 2.0864]],

         [[2.1104, 2.1118, 2.0031],
          [2.0683, 2.1434, 2.0573],
          [1.8828, 2.1452, 1.9578]],

        