In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [9]:
# TODO: should the .float() calls scattered through this cell be removed?
#   Try dropping them in a later experiment and compare the results.
############################################

import sys
import torchvision
import os
import torch
import torch.nn as nn


# GPU selection: enumerate devices in PCI bus order (so indices start from 0
# in that order) and expose only devices 1 and 2 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1,2",
})

class SYNAPSE_UPDATE(torch.autograd.Function):
    """Single-time-step synapse (Conv2d) update with a hand-written backward pass.

    Forward applies ``this_layer`` to the spikes of one time step. Backward
    returns only the gradient with respect to those input spikes; no gradient
    is produced for ``spike_now`` or for ``this_layer`` itself.

    NOTE(review): because the layer is passed as a non-tensor argument, its
    weight/bias receive no gradient through this Function. ``spike_now`` is
    saved but not yet used in backward — confirm whether a trace-based weight
    update is still to be added.
    """

    @staticmethod
    def forward(ctx, spike_one_time, spike_now, this_layer):
        """Apply the convolution to one time step.

        Args:
            spike_one_time: input spikes of a single time step, as accepted by
                ``this_layer`` (for nn.Conv2d: [Batch, Channel, Height, Width]).
            spike_now: trace tensor for the current step (saved, currently unused).
            this_layer: the ``nn.Conv2d`` whose weights are applied.

        Returns:
            ``this_layer(spike_one_time)``.
        """
        # save_for_backward() accepts tensors only, so the module itself is
        # kept as a plain attribute on the context instead.
        ctx.save_for_backward(spike_one_time, spike_now)
        ctx.this_layer = this_layer
        return this_layer(spike_one_time)

    @staticmethod
    def backward(ctx, grad_output_current):
        """Return the gradient w.r.t. ``spike_one_time`` (and None for the rest)."""
        spike_one_time, spike_now = ctx.saved_tensors  # spike_now: not used yet
        this_layer = ctx.this_layer

        grad_input_spike = None
        if ctx.needs_input_grad[0]:
            # Gradient of a convolution w.r.t. its input, using the layer's own
            # weights. conv2d_input takes the target input shape, so the result
            # has the right size even when stride > 1, and it runs on the same
            # device/dtype as the weights and the incoming gradient.
            grad_input_spike = nn.grad.conv2d_input(
                spike_one_time.shape,
                this_layer.weight.detach(),
                grad_output_current,
                stride=this_layer.stride,
                padding=this_layer.padding,
                dilation=this_layer.dilation,
                groups=this_layer.groups,
            )

        return grad_input_spike, None, None

     



class SYNAPSE_CONV(nn.Module):
    """Convolutional synapse applied separately to every time step of a spike train.

    Wraps an ``nn.Conv2d`` and sends each time step through ``SYNAPSE_UPDATE``.
    Alongside the raw spikes it maintains a detached running trace
    ``trace_const1 * spike[t] + trace_const2 * trace[t-1]`` which is handed to
    ``SYNAPSE_UPDATE`` as its second argument.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        trace_const1: weight of the current spikes in the trace.
        trace_const2: weight of the previous trace value in the trace.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, trace_const1=1, trace_const2=0.7):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)

    def _conv_out_size(self, in_size, axis):
        """Integer output length of ``self.conv`` along spatial ``axis`` (0=H, 1=W)."""
        conv = self.conv
        # nn.Conv2d normalises kernel_size/stride/padding/dilation to per-axis
        # tuples, so this works for both int and tuple constructor arguments.
        span = conv.dilation[axis] * (conv.kernel_size[axis] - 1) + 1
        return (in_size + 2 * conv.padding[axis] - span) // conv.stride[axis] + 1

    def forward(self, spike):
        """Convolve every time step of ``spike``.

        Args:
            spike: tensor laid out as [Time, Batch, Channel, Height, Width].

        Returns:
            Tensor [Time, Batch, out_channels, Height_out, Width_out] holding
            the synaptic current of every time step.
        """
        # spike: [Time, Batch, Channel, Height, Width]
        Time, Batch = spike.shape[0], spike.shape[1]
        # Floor division keeps the sizes integral; torch.zeros rejects floats.
        Height = self._conv_out_size(spike.shape[3], 0)
        Width = self._conv_out_size(spike.shape[4], 1)
        # Allocate on the input's device so the layer also works on GPU.
        output_current = torch.zeros(
            Time, Batch, self.out_channels, Height, Width,
            dtype=self.conv.weight.dtype, device=spike.device,
        )

        # The trace is built from detached spikes so it carries no autograd history.
        spike_detach = spike.detach()
        spike_past = spike_detach.new_zeros(spike_detach.shape[1:])

        for t in range(Time):
            spike_now = self.trace_const1 * spike_detach[t] + self.trace_const2 * spike_past
            output_current[t] = SYNAPSE_UPDATE.apply(spike[t], spike_now, self.conv)
            spike_past = spike_now

        return output_current



class LIF(torch.autograd.Function):
    """One time step of a leaky integrate-and-fire neuron with a surrogate gradient.

    Forward: leak and integrate, fire where the potential reaches the
    threshold, then reset by subtracting the threshold (clamped at 0).
    Backward: passes the spike gradient to the input current through a
    rectangular surrogate of width ``sg_width`` centred on the threshold.
    No gradient is returned for the incoming membrane potential or for the
    scalar hyper-parameters.
    """

    @staticmethod
    def forward(ctx, input_current_one_time, v_one_time, v_decay, v_threshold, v_reset, sg_width):
        """Advance the neuron by one step.

        Args:
            input_current_one_time: pre-synaptic current for this step.
            v_one_time: membrane potential before this step.
            v_decay: multiplicative leak applied to ``v_one_time``.
            v_threshold: firing threshold.
            v_reset: accepted for interface compatibility; not used by the
                subtractive reset below.
            sg_width: width of the rectangular surrogate gradient window.

        Returns:
            (spike, v_one_time): the 0/1 spike tensor and the post-reset potential.
        """
        v_one_time = v_one_time * v_decay + input_current_one_time  # leak + pre-synaptic current integrate
        spike = (v_one_time >= v_threshold).float()  # fire
        # save_for_backward() accepts tensors only; the scalar hyper-parameters
        # are stored as plain attributes on ctx. The potential is saved BEFORE
        # the reset because the surrogate is evaluated on the pre-reset value.
        ctx.save_for_backward(v_one_time)
        ctx.v_threshold = v_threshold
        ctx.sg_width = sg_width
        v_one_time = (v_one_time - spike * v_threshold).clamp_min(0)  # reset
        return spike, v_one_time

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        """Surrogate gradient w.r.t. the input current; ``grad_output_v`` is ignored."""
        (v_one_time,) = ctx.saved_tensors
        v_threshold = ctx.v_threshold
        sg_width = ctx.sg_width

        ################ select one of the following surrogate gradient functions ################
        #===========surrogate gradient function (rectangle)
        in_window = ((v_one_time - v_threshold).abs() < sg_width / 2).float()
        grad_input_current = grad_output_spike * in_window / sg_width

        #===========surrogate gradient function (sigmoid)
        # sig = torch.sigmoid((v_one_time - v_threshold))
        # grad_input_current =  sig*(1-sig)*grad_output_spike

        #===========surrogate gradient function (rough rectangle)
        # v_minus_th = (v_one_time - v_threshold)
        # grad_input_current = grad_output_spike.clone()
        # grad_input_current[v_minus_th <= -.5] = 0
        # grad_input_current[v_minus_th > .5] = 0
        ###########################################################################################
        return grad_input_current, None, None, None, None, None

class LIF_layer(nn.Module):
    """Layer of leaky integrate-and-fire neurons unrolled over the time axis.

    Args:
        v_init: initial membrane potential.
        v_decay: multiplicative leak applied at every step.
        v_threshold: firing threshold.
        v_reset: forwarded to ``LIF`` unchanged.
        sg_width: width of the surrogate-gradient window used by ``LIF``.
    """

    def __init__ (self, v_init = 0.0, v_decay = 0.8, v_threshold = 0.5, v_reset = 0.0, sg_width = 1):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width

    def forward(self, input_current):
        """Run the neurons over every time step of ``input_current``.

        Args:
            input_current: tensor whose first axis is time, e.g.
                [Time, Batch, Channel, Height, Width].

        Returns:
            Spike tensor with the same shape as ``input_current`` (dtype float).
        """
        # Membrane potential of a single time step, carried from step to step.
        v = torch.full(
            tuple(input_current.shape[1:]), self.v_init,
            dtype=torch.float, device=input_current.device,
        )
        # zeros_like() has no fill_value argument; the spike buffer starts at 0.
        # post_spike has the same shape as input_current.
        post_spike = torch.zeros_like(input_current, dtype=torch.float)

        Time = input_current.shape[0]
        for t in range(Time):
            # Leak, integrate the input current, fire and reset (the backward
            # pass is handled inside LIF). Feeding the updated potential back
            # in lets the state persist across time steps.
            post_spike[t], v = LIF.apply(input_current[t], v,
                                         self.v_decay, self.v_threshold, self.v_reset, self.sg_width)

        return post_spike
    
    




In [None]:
# https://bo-10000.tistory.com/181
# Blog post about detach()

In [1]:
# Autograd conv code suggested by Copilot; probably not very useful here

import torch
from torch.autograd import Function
import torch.nn.functional as F

class Conv2dFunction(Function):
    """2-D convolution as a custom autograd Function with explicit gradients."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Compute ``F.conv2d`` and remember what backward needs.

        Args:
            input: input tensor passed to ``F.conv2d``.
            weight: convolution kernel passed to ``F.conv2d``.
            bias: optional bias, or None.
            stride, padding, dilation, groups: forwarded to ``F.conv2d``.

        Returns:
            The convolution output.
        """
        # Tensors go through save_for_backward; plain hyper-parameters are
        # stored as attributes on the context.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        return F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weight, bias) and None for the hyper-parameters."""
        input, weight, bias = ctx.saved_tensors
        conv_args = dict(stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation, groups=ctx.groups)

        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Takes the target input shape, so the result is correctly sized
            # even when stride > 1 makes the transposed convolution ambiguous.
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, **conv_args)
        if ctx.needs_input_grad[1]:
            # The weight gradient is a correlation of the input with
            # grad_output, not a transposed convolution.
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, **conv_args)
        if bias is not None and ctx.needs_input_grad[2]:
            # One value per output channel: reduce over batch and both spatial axes.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None


In [16]:
# Create the input, weight, and bias tensors (all tracked by autograd)
input = torch.randn(1, 1, 3, 3, requires_grad=True)
weight = torch.randn(1, 1, 3, 3, requires_grad=True)
bias = torch.randn(1, requires_grad=True)

# Run the forward pass through the custom Conv2dFunction
output = Conv2dFunction.apply(input, weight, bias)

# Backpropagate a random upstream gradient.
# NOTE: backward() returns None, so output2 is None (see the printed result);
# the computed gradients are stored on input.grad / weight.grad / bias.grad.
output2 = output.backward(torch.randn(1, 1, 1, 1))

print(output)
print(output2)

tensor([[[[-4.1372]]]], grad_fn=<Conv2dFunctionBackward>)
None


In [15]:
# `import torch` had been fused onto a comment line further down, so
# torch.randn was used before this cell imported torch; import it up front.
import torch
import torch.nn as nn

conv_layer = nn.Conv2d(in_channels=5, out_channels=64, kernel_size=3, stride=1, padding=1)

print(conv_layer.weight.shape)
# Weight layout: (out_channels, in_channels, kernel_size, kernel_size)

# Input layout: (N, C_in, H, W)

# Create an input tensor with batch size 10
input_tensor = torch.randn(10, 5, 32, 32)

# Pass the input tensor through the Conv2d layer
output_tensor = conv_layer(input_tensor)

print(output_tensor.shape)
# Output: torch.Size([10, 64, 32, 32])






a = torch.tensor([1,1])


# Threshold uniform random values at 0.5 to get a 0/1 float mask
b=torch.rand(2,3)
b = (b>0.5).float()
b





torch.Size([64, 5, 3, 3])
torch.Size([10, 64, 32, 32])
