In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [9]:
# Should the .float() calls scattered through this code be removed?
# TODO: try dropping them later and compare the results experimentally.
############################################

import sys
import torchvision
import os
import torch
import torch.nn as nn


# GPU selection
# NOTE: these variables are read when CUDA is first initialised in this process,
# so they have to be set before any CUDA call is made.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Enumerate GPUs in PCI bus order
# Expose only the devices with indices 1 and 2 to this process.
os.environ["CUDA_VISIBLE_DEVICES"]= "1,2" 

class SYNAPSE_UPDATE(torch.autograd.Function):
    """2-D convolution whose weight gradient is driven by a pre-synaptic trace.

    Forward computes an ordinary convolution of the current input spikes.
    Backward returns
      * the ordinary convolution gradient for ``pre_spike``,
      * a weight gradient that correlates the upstream gradient with
        ``pre_trace`` instead of the instantaneous spikes,
      * the usual bias gradient (upstream gradient summed over N, H, W).

    ``weight`` and ``bias`` are passed as tensors (not as a module) so that
    autograd runs ``backward`` and accumulates parameter gradients even when
    ``pre_spike`` itself does not require grad.

    NOTE(review): the previous version of this class was a non-callable
    placeholder. The trace-based weight gradient is the presumed intent,
    inferred from the trace that SYNAPSE_CONV builds -- confirm against the
    intended learning rule.
    """

    @staticmethod
    def forward(ctx, pre_spike, pre_trace, weight, bias, stride, padding):
        """Return ``conv2d(pre_spike, weight, bias)`` and stash what backward needs.

        Args:
            pre_spike: input spikes for one time step, [Batch, Channel, Height, Width].
            pre_trace: trace of the input spikes, same shape as ``pre_spike``;
                only used in ``backward``.
            weight: convolution kernel, [out_channels, in_channels, kH, kW].
            bias: optional bias of shape [out_channels], or ``None``.
            stride: convolution stride (int or pair).
            padding: zero padding (int or pair).
        """
        ctx.save_for_backward(pre_trace, weight)
        ctx.input_shape = pre_spike.shape
        ctx.has_bias = bias is not None
        ctx.stride = stride
        ctx.padding = padding
        return nn.functional.conv2d(pre_spike, weight, bias, stride, padding)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (pre_spike, pre_trace, weight, bias, stride, padding)."""
        pre_trace, weight = ctx.saved_tensors
        # Upstream gradients can be expanded/non-contiguous views (e.g. after .sum()).
        grad_output = grad_output.contiguous()

        grad_pre_spike = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_pre_spike = torch.nn.grad.conv2d_input(
                ctx.input_shape, weight, grad_output,
                stride=ctx.stride, padding=ctx.padding)
        if ctx.needs_input_grad[2]:
            # The trace, not the instantaneous spike, drives the weight update.
            grad_weight = torch.nn.grad.conv2d_weight(
                pre_trace, weight.shape, grad_output,
                stride=ctx.stride, padding=ctx.padding)
        if ctx.has_bias and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        # The trace and the two hyper-parameters receive no gradient.
        return grad_pre_spike, None, grad_weight, grad_bias, None, None


class SYNAPSE_CONV(nn.Module):
    """Convolutional synapse applied independently at every time step.

    A running trace of the (detached) input spikes,
    ``trace[t] = trace_const1 * spike[t] + trace_const2 * trace[t-1]``,
    is handed to SYNAPSE_UPDATE together with the spikes of that step.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            to ``nn.Conv2d``.
        trace_const1: weight of the current spikes in the trace.
        trace_const2: weight of the previous trace value.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 
                 trace_const1, trace_const2):
        super(SYNAPSE_CONV, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.trace_const1 = trace_const1
        self.trace_const2 = trace_const2

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)

    def forward(self, pre_spike):
        """Convolve every time step of ``pre_spike``.

        Args:
            pre_spike: [Time, Batch, Channel, Height, Width]

        Returns:
            Tensor of shape [Time, Batch, out_channels, Height_out, Width_out]
            on the same device as ``pre_spike``.
        """
        Time = pre_spike.shape[0]

        if Time == 0:
            # Nothing to convolve: return an empty tensor with the right
            # trailing shape. Integer division keeps the sizes integral.
            # Assumes numeric padding (nn.Conv2d stores it as a pair).
            k_h, k_w = self.conv.kernel_size
            s_h, s_w = self.conv.stride
            p_h, p_w = self.conv.padding
            Height = max((pre_spike.shape[3] + 2 * p_h - k_h) // s_h + 1, 0)
            Width = max((pre_spike.shape[4] + 2 * p_w - k_w) // s_w + 1, 0)
            return pre_spike.new_zeros(
                (0, pre_spike.shape[1], self.out_channels, Height, Width))

        # The trace is built from detached spikes: it only feeds the weight
        # gradient and must not create an extra autograd path to the input.
        pre_spike_detach = pre_spike.detach()
        pre_spike_past = torch.zeros_like(pre_spike_detach[0])

        outputs = []
        for t in range(Time):
            pre_spike_now = self.trace_const1*pre_spike_detach[t] + self.trace_const2*pre_spike_past
            outputs.append(SYNAPSE_UPDATE.apply(
                pre_spike[t], pre_spike_now,
                self.conv.weight, self.conv.bias,
                self.conv.stride, self.conv.padding))
            pre_spike_past = pre_spike_now

        # Stacking keeps device/dtype of the conv output and avoids in-place
        # writes into a pre-allocated buffer.
        return torch.stack(outputs)



class LIF(torch.autograd.Function):
    """One leaky-integrate-and-fire step with a surrogate spike gradient.

    Forward: ``v <- v * v_decay + i``; a spike is emitted where
    ``v >= v_threshold``; spiking neurons then have ``v_threshold`` subtracted
    and the result is clamped at 0.

    Backward: only the gradient w.r.t. the input current ``i`` is produced,
    using a rectangular surrogate of width ``sg_width`` centred on the
    threshold. The gradient arriving through the returned membrane potential
    is ignored, and no gradient is returned for ``v``.
    """

    @staticmethod
    def forward(ctx, i, v, v_decay, v_threshold, v_reset, sg_width):
        """Advance the neuron by one step.

        Args:
            i: input current (tensor).
            v: membrane potential before this step (tensor, same shape as ``i``).
            v_decay: leak factor applied to ``v``.
            v_threshold: firing threshold.
            v_reset: accepted for interface compatibility; not used by the
                subtract-threshold reset below.
            sg_width: width of the rectangular surrogate gradient.

        Returns:
            (spike, v): float spike tensor (0./1.) and the post-reset potential.
        """
        v = v * v_decay + i # leak + pre-synaptic current integrate
        spike = (v >= v_threshold).float() # fire
        # save_for_backward only accepts tensors; the scalar hyper-parameters
        # needed in backward are kept as plain attributes on ctx instead.
        ctx.save_for_backward(v) # potential before reset
        ctx.v_threshold = v_threshold
        ctx.sg_width = sg_width
        v = (v - spike * v_threshold).clamp_min(0) # reset
        return spike, v

    @staticmethod
    def backward(ctx, grad_output_spike, grad_output_v):
        """Return gradients for (i, v, v_decay, v_threshold, v_reset, sg_width)."""
        (v,) = ctx.saved_tensors
        v_threshold = ctx.v_threshold
        sg_width = ctx.sg_width

        ################ select one of the following surrogate gradient functions ################
        #===========surrogate gradient function (rectangle)
        surrogate = ((v - v_threshold).abs() < sg_width/2).float() / sg_width
        grad_input_current = grad_output_spike * surrogate

        #===========surrogate gradient function (sigmoid)
        # sig = torch.sigmoid((v - v_threshold))
        # grad_input_current = sig*(1-sig)*grad_output_spike

        #===========surrogate gradient function (rough rectangle)
        # v_minus_th = (v - v_threshold)
        # grad_input_current = grad_output_spike.clone()
        # grad_input_current[v_minus_th <= -.5] = 0
        # grad_input_current[v_minus_th > .5] = 0
        ###########################################################################################
        return grad_input_current, None, None, None, None, None

class LIF_layer(nn.Module):
    """Layer of leaky-integrate-and-fire neurons unrolled over the time axis.

    Args:
        v_init: initial membrane potential.
        v_decay: leak factor applied to the potential at every step.
        v_threshold: firing threshold.
        v_reset: forwarded to ``LIF.apply``.
        sg_width: surrogate-gradient width forwarded to ``LIF.apply``.
    """

    def __init__ (self, v_init = 0.0, v_decay = 0.8, v_threshold = 0.5, v_reset = 0.0, sg_width = 1):
        super(LIF_layer, self).__init__()
        self.v_init = v_init
        self.v_decay = v_decay
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.sg_width = sg_width

    def forward(self, i):
        """Run the neurons over every time step of the input current.

        Args:
            i: input current; the first dimension is time
               (the rest of the file uses [Time, Batch, Channel, Height, Width]).

        Returns:
            Float spike tensor with the same shape as ``i``.
        """
        Time = i.shape[0]
        if Time == 0:
            # No time steps: nothing can spike.
            return torch.zeros_like(i, dtype = torch.float)

        # Membrane potential for a single time step, carried from one step to
        # the next so that leak and integration act across time.
        v = torch.full_like(i[0], fill_value = self.v_init, dtype = torch.float)

        spikes = []
        for t in range(Time):
            # leak, add the input current, fire and reset
            # (the backward pass is implemented by LIF itself)
            spike, v = LIF.apply(i[t], v,
                                 self.v_decay, self.v_threshold, self.v_reset, self.sg_width)
            spikes.append(spike)

        # Collect per-step outputs instead of writing in place into tensors
        # that take part in autograd.
        return torch.stack(spikes)
    
    




In [None]:
# https://bo-10000.tistory.com/181
# Blog post about detach()

In [7]:
import torch
# Scratch: small integer tensor (not used by the lines below).
a = torch.tensor([1, 1])

# Scratch: draw a 2x3 uniform sample and binarise it at 0.5 into a 0./1. float mask.
b = torch.rand(2, 3).gt(0.5).float()
b

tensor([[1., 1., 1.],
        [1., 0., 1.]])