In [1]:
# Copyright (c) 2024 Byeonghyeon Kim 
# github site: https://github.com/bhkim003/ByeonghyeonKim
# email: bhkim003@snu.ac.kr
 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


In [46]:
import torch
from torchvision import datasets, transforms

# MNIST 데이터 로드 및 전처리
def load_mnist():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = datasets.MNIST('/data2/mnist/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
    testset = datasets.MNIST('/data2/mnist/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
    return trainset, testset

# LIF 뉴런 모델
class LIFNeuron:
    def __init__(self, tau=10, threshold=1, reset_potential=0, dt=1):
        self.tau = tau
        self.threshold = threshold
        self.reset_potential = reset_potential
        self.dt = dt
        self.membrane_potential = reset_potential
        self.spike = 0
        self.last_spike_time = 0

    def reset(self):
        self.membrane_potential = self.reset_potential
        self.spike = 0

    def integrate(self, inputs):
        self.membrane_potential *= torch.exp(-torch.tensor(self.dt) / torch.tensor(self.tau))  # 실수를 텐서로 변환
        total_input = torch.sum(inputs)
        self.membrane_potential += total_input
        if self.membrane_potential >= self.threshold:
            self.spike = 1
            self.last_spike_time = 0
            self.membrane_potential = self.reset_potential
        else:
            self.spike = 0
            self.last_spike_time += 1

# SNN 모델
class SNN:
    def __init__(self, num_neurons):
        self.num_neurons = num_neurons
        self.neurons = [LIFNeuron() for _ in range(num_neurons)]
        self.weights = torch.rand(num_neurons, num_neurons)

    def simulate(self, inputs):
        output_spikes = []
        for neuron in self.neurons:
            neuron.reset()

        for i, neuron in enumerate(self.neurons):
            neuron.integrate(inputs * self.weights[i])
            output_spikes.append(neuron.spike)
        
        return output_spikes

    def train_stdp(self, input_spikes, output_spikes):
        pass  # STDP 훈련 알고리즘 구현

# 훈련 함수
def train_snn(snn, x_train, y_train, epochs=1):
    for epoch in range(epochs):
        print("Epoch:", epoch+1)
        for i in range(10):
            print(i,len(x_train))
            inputs = x_train[i].view(-1)
            label = y_train[i]
            output_spikes = snn.simulate(inputs)
            snn.train_stdp(inputs, output_spikes)
            # 이후에 다른 훈련 알고리즘 적용 가능

# 테스트 함수
def test_snn(snn, x_test, y_test):
    correct = 0
    for i in range(len(x_test)):
        inputs = x_test[i].view(-1)  # inputs 텐서를 평탄화
        label = y_test[i]
        output_spikes = snn.simulate(inputs)
        predicted_label = torch.argmax(torch.tensor(output_spikes))  # 리스트를 텐서로 변환
        if predicted_label == label:
            correct += 1
    accuracy = correct / len(x_test)
    print("Test Accuracy:", accuracy)

# MNIST 데이터 로드
trainset, testset = load_mnist()

# SNN 모델 초기화
num_neurons = 784  # MNIST 이미지 크기에 맞는 입력 뉴런 수
snn = SNN(num_neurons)

# SNN 훈련
train_snn(snn, trainset.data, trainset.targets)

# SNN 테스트
test_snn(snn, testset.data, testset.targets)

Epoch: 1
0 60000
1 60000
2 60000
3 60000
4 60000
5 60000
6 60000
7 60000
8 60000
9 60000


KeyboardInterrupt: 

In [28]:
import numpy as np

# 파라미터 설정
num_inputs = 10
num_neurons = 5
threshold = 2.5

# 가중치와 입력 초기화
weights = np.random.rand(num_inputs, num_neurons)
inputs = np.random.rand(num_inputs)

# 뉴런의 출력 계산
outputs = np.dot(inputs, weights)

# 임계값을 초과하는 뉴런에 대해 스파이크 발생
spikes = outputs > threshold

print("Spikes:")
print(spikes)

Spikes:
[False False False False False]
