In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [None]:

import sys
import torchvision
import os
import torch
import torch.nn as nn


# Pin GPU selection for this process: enumerate CUDA devices by PCI bus ID
# and expose only the devices numbered 1 and 2 under that ordering.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1,2",
})



class LIF_neuron(nn.Module):
    """Holder for the hyper-parameters of a leaky integrate-and-fire neuron.

    Only the constructor is defined here; the class has no ``forward`` of its
    own, so it stores configuration and an initial membrane value and nothing
    else.
    """

    def __init__(
        self,
        v_decay=0.8,
        v_threshold=0.5,
        v_init=0.0,
        v_reset=0.0,
    ):
        """Store the given values as plain (non-parameter) attributes.

        Args:
            v_decay: stored as ``self.v_decay``.
            v_threshold: stored as ``self.v_threshold``.
            v_init: stored as ``self.v_init`` and used as the starting value
                of ``self.v``.
            v_reset: stored as ``self.v_reset``.
        """
        super().__init__()

        # Hyper-parameters
        self.v_reset = v_reset
        self.v_init = v_init
        self.v_threshold = v_threshold
        self.v_decay = v_decay

        # Membrane state starts at the configured initial value.
        self.v = self.v_init
         
    




class LIFNeuron(nn.Module):
    """Stateful leaky integrate-and-fire neuron with a first-order synaptic filter.

    Every call to ``forward`` advances the neuron by one time step
    (``dt = 1.0``) with forward-Euler updates::

        i_syn <- i_syn + (-i_syn + i_in) / tau_syn * dt
        v     <- v     + (-v + i_syn)    / tau_mem * dt

    A spike (1.0) is emitted wherever ``v >= v_threshold``; those entries of
    ``v`` are then set to ``v_reset``.

    The state (``self.v`` and ``self.i_syn``) lives on the instance and is
    carried over from one call to the next. Call ``reset_state()`` before
    feeding a new, unrelated sequence or an input with a different shape.

    NOTE(review): another class named ``LIFNeuron`` is defined later in this
    notebook; when both cells are executed the later definition shadows this
    one.
    """

    def __init__(self, tau_mem=20.0, tau_syn=5.0, v_reset=0.0, v_threshold=1.0):
        """Store the neuron constants and initialise the state.

        Args:
            tau_mem: membrane time constant used in the ``v`` update.
            tau_syn: synaptic time constant used in the ``i_syn`` update.
            v_reset: value written to ``v`` where a spike occurred; also the
                initial value of ``v``.
            v_threshold: spike threshold (a spike is emitted when
                ``v >= v_threshold``).
        """
        super(LIFNeuron, self).__init__()

        # Parameters
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.v_reset = v_reset
        self.v_threshold = v_threshold

        # State variables
        self.reset_state()

    def reset_state(self):
        """Return the membrane potential and synaptic current to their initial values."""
        self.v = self.v_reset
        self.i_syn = 0.0

    def forward(self, i_in):
        """Advance the neuron by one time step.

        Args:
            i_in: input current. A tensor, or anything ``torch.as_tensor``
                accepts (e.g. a Python number), which is converted to a tensor
                of the default floating-point dtype.

        Returns:
            Tensor of 0.0 / 1.0 values (``float32``), 1.0 where the membrane
            potential reached ``v_threshold`` in this step.
        """
        dt = 1.0  # time step

        # Plain numbers would otherwise leave the state as Python floats and
        # fail at ``.float()`` below, because the comparison yields a ``bool``.
        if not torch.is_tensor(i_in):
            i_in = torch.as_tensor(i_in, dtype=torch.get_default_dtype())

        # The state is rebuilt out-of-place on purpose: an in-place ``+=`` would
        # also modify any tensor the caller kept from a previous step (for
        # example a recorded membrane trace).

        # Update synaptic current
        self.i_syn = self.i_syn + (-self.i_syn + i_in) / self.tau_syn * dt

        # Update membrane potential
        dv = (-self.v + self.i_syn) / self.tau_mem * dt
        self.v = self.v + dv

        # Check for spike
        fired = self.v >= self.v_threshold
        spike = fired.float()

        # Reset. The reset value is materialised with the dtype/device of ``v``
        # so that ``torch.where`` never has to mix a bare Python scalar with a
        # tensor of a different dtype.
        reset_value = torch.as_tensor(
            self.v_reset, dtype=self.v.dtype, device=self.v.device
        )
        self.v = torch.where(fired, reset_value, self.v)

        return spike

In [1]:
class LIFNeuronFunction(torch.autograd.Function):
    """One forward-Euler step of a leaky integrate-and-fire neuron as a custom autograd op.

    ``forward`` computes, with ``dt = 1.0``::

        i_syn_new = i_syn + (-i_syn + i_in) / tau_syn * dt
        v_new     = v     + (-v + i_syn_new) / tau_mem * dt
        spike     = (v_new >= v_threshold) as float
        v_new     = v_reset where a spike occurred, otherwise unchanged

    ``backward`` is a placeholder: the gradient arriving at ``spike`` is passed
    to ``i_in`` unchanged, and every other input receives no gradient.
    Gradients arriving at the ``v_new`` / ``i_syn_new`` outputs are ignored.
    TODO: replace with the intended (surrogate) gradient.
    """

    @staticmethod
    def forward(ctx, i_in, v, i_syn, tau_mem, tau_syn, v_reset, v_threshold):
        """Advance the neuron state by one step.

        Args:
            ctx: autograd context.
            i_in: input current (tensor).
            v: membrane potential before the step.
            i_syn: synaptic current before the step.
            tau_mem: membrane time constant.
            tau_syn: synaptic time constant.
            v_reset: value assigned to ``v_new`` where a spike occurred.
            v_threshold: spike threshold.

        Returns:
            Tuple ``(spike, v_new, i_syn_new)``.
        """
        dt = 1.0  # time step

        # Update synaptic current
        i_syn_new = i_syn + (-i_syn + i_in) / tau_syn * dt

        # Update membrane potential
        dv = (-v + i_syn_new) / tau_mem * dt
        v_new = v + dv

        # Check for spike
        fired = v_new >= v_threshold
        spike = fired.float()

        # Reset; the reset value is given the dtype/device of ``v_new`` so that
        # ``torch.where`` does not have to mix a bare scalar with a tensor.
        reset_value = torch.as_tensor(
            v_reset, dtype=v_new.dtype, device=v_new.device
        )
        v_new = torch.where(fired, reset_value, v_new)

        # ``ctx.save_for_backward`` accepts tensors only, so the Python-float
        # constants must not be passed to it. The placeholder backward needs
        # nothing but the shape of ``i_in``; keeping unused tensors alive would
        # only cost memory. Save the required tensors here once a real
        # gradient is implemented.
        ctx.i_in_shape = i_in.shape if torch.is_tensor(i_in) else None

        return spike, v_new, i_syn_new

    @staticmethod
    def backward(ctx, grad_spike, grad_v_new=None, grad_i_syn_new=None):
        """Placeholder gradient: pass the spike gradient straight to ``i_in``.

        Autograd supplies one incoming gradient per output of ``forward``
        (``spike``, ``v_new``, ``i_syn_new``); only the first is used.

        Returns:
            One entry per ``forward`` input (after ``ctx``); all are ``None``
            except, when required, the gradient for ``i_in``.
        """
        # Compute gradients (this is just an example, you should replace this with the correct gradients)
        grad_i_in = grad_v = grad_i_syn = grad_tau_mem = grad_tau_syn = grad_v_reset = grad_v_threshold = None

        if ctx.needs_input_grad[0]:
            grad_i_in = grad_spike.clone()
            # ``spike`` has the broadcast shape of the inputs; reduce back to
            # the shape of ``i_in`` when ``i_in`` itself was broadcast.
            if grad_i_in.shape != ctx.i_in_shape:
                grad_i_in = grad_i_in.sum_to_size(ctx.i_in_shape)

        return grad_i_in, grad_v, grad_i_syn, grad_tau_mem, grad_tau_syn, grad_v_reset, grad_v_threshold

class LIFNeuron(nn.Module):
    """Module wrapper that runs one ``LIFNeuronFunction`` step per call.

    ``self.v`` and ``self.i_syn`` are registered as ``nn.Parameter`` objects
    and serve as the initial state; ``forward`` never modifies them. To advance
    the neuron over several time steps, feed the ``v_new`` / ``i_syn_new``
    values returned by one call into the next call through the optional
    ``v`` / ``i_syn`` arguments.
    """

    def __init__(self, tau_mem=20.0, tau_syn=5.0, v_reset=0.0, v_threshold=1.0):
        """Store the neuron constants and create the initial-state parameters.

        Args:
            tau_mem: membrane time constant forwarded to ``LIFNeuronFunction``.
            tau_syn: synaptic time constant forwarded to ``LIFNeuronFunction``.
            v_reset: reset value forwarded to ``LIFNeuronFunction``; also the
                initial value of ``self.v``.
            v_threshold: spike threshold forwarded to ``LIFNeuronFunction``.
        """
        super(LIFNeuron, self).__init__()

        # Parameters
        self.tau_mem = tau_mem
        self.tau_syn = tau_syn
        self.v_reset = v_reset
        self.v_threshold = v_threshold

        # State variables. ``float(...)`` guarantees a floating-point tensor:
        # an integer ``v_reset`` (e.g. ``0``) would otherwise produce an
        # integer tensor, which ``nn.Parameter`` rejects because it cannot
        # require gradients.
        self.v = nn.Parameter(torch.tensor(float(v_reset)))
        self.i_syn = nn.Parameter(torch.tensor(0.0))

    def forward(self, i_in, v=None, i_syn=None):
        """Run a single neuron step.

        Args:
            i_in: input current for this step.
            v: membrane potential to start from; defaults to ``self.v``.
            i_syn: synaptic current to start from; defaults to ``self.i_syn``.

        Returns:
            The tuple ``(spike, v_new, i_syn_new)`` produced by
            ``LIFNeuronFunction.apply``.
        """
        if v is None:
            v = self.v
        if i_syn is None:
            i_syn = self.i_syn
        return LIFNeuronFunction.apply(i_in, v, i_syn, self.tau_mem, self.tau_syn, self.v_reset, self.v_threshold)

NameError: name 'torch' is not defined