In [1]:
import torch
import random
import numpy as np
import os

In [2]:
import modules

In [3]:
def seed_assign(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN for reproducible execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same order as the file has always used: Python, NumPy, torch CPU,
    # current CUDA device, all CUDA devices.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Ask cuDNN to use deterministic algorithms only.
    torch.backends.cudnn.deterministic = True
    # Pin autotuning off explicitly. False is already PyTorch's default, but if
    # an earlier cell in the same kernel session turned benchmarking on, cuDNN
    # could select different kernels from run to run and defeat the flag above.
    torch.backends.cudnn.benchmark = False


In [4]:
class SNN(torch.nn.Module):
    """Small spiking network: Conv2d -> LIF -> Linear -> LIF, summed over time.

    The network consumes a tensor shaped [Batch, Time_step, Channel, H, W] and
    returns a [Batch, CLASS_NUM] tensor obtained by summing the output of the
    last LIF layer over the time axis.

    Args:
        v_decay, v_threshold, v_reset_mode, sg_width, surrogate: Forwarded
            unchanged to ``modules.neuron.LIF_layer`` for both LIF layers.
        CLASS_NUM: Number of output classes (width of the linear layer).
        in_channels: Number of channels of the input images.
        IMAGE_SIZE: Side length of the (square) input images.
        time_step: Nominal number of simulation steps; stored as ``self.TIME``.
    """

    def __init__(self, v_decay, v_threshold, v_reset_mode, sg_width, surrogate, CLASS_NUM, in_channels, IMAGE_SIZE, time_step):
        super().__init__()
        self.TIME = time_step

        hidden_channels = 32
        kernel_size, stride, padding = 3, 1, 1

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels,
                                     kernel_size=kernel_size, stride=stride, padding=padding)
        # Standard Conv2d output-size formula. The stride must divide the whole
        # padded extent; the earlier expression `(S + 2 - 3 // 1) + 1` only
        # produced the right value because the stride happens to be 1.
        conv_out_size = (IMAGE_SIZE + 2 * padding - kernel_size) // stride + 1

        self.lif1 = modules.neuron.LIF_layer(v_decay, v_threshold, v_reset_mode, sg_width, surrogate)

        self.fc1 = torch.nn.Linear(hidden_channels * conv_out_size * conv_out_size, CLASS_NUM)

        self.lif2 = modules.neuron.LIF_layer(v_decay, v_threshold, v_reset_mode, sg_width, surrogate)

    def forward(self, x):
        """Run the network on ``x`` shaped [Batch, Time_step, Channel, H, W].

        Returns:
            Tensor shaped [Batch, CLASS_NUM]: the second LIF layer's output
            summed over the time axis.
        """
        # NOTE(review): assumes both LIF layers take a time-major
        # [Time_step, Batch, ...] tensor and return the same shape - confirm
        # against modules.neuron.LIF_layer.

        # [Batch, Time_step, C, H, W] -> [Time_step, Batch, C, H, W]
        x = torch.transpose(x, 0, 1)

        # Read the time/batch sizes from the input itself instead of trusting
        # self.TIME: regrouping by a mismatched self.TIME would silently mix
        # time steps with batch entries whenever T * B is divisible by it.
        T, B = x.shape[0], x.shape[1]

        # Fold time into the batch axis so the stateless conv sees 4-D input.
        x = x.reshape(T * B, *x.shape[2:])
        # SHAPE : [Time_step * Batch, Channel, H, W]

        x = self.conv1(x)
        # SHAPE : [Time_step * Batch, 32, H', W']

        x = x.reshape(T, B, *x.shape[1:])
        # SHAPE : [Time_step, Batch, 32, H', W']

        x = self.lif1(x)
        # SHAPE : [Time_step, Batch, 32, H', W']

        # Fold time into batch again and flatten all features for the linear layer.
        x = x.reshape(T * B, -1)
        x = self.fc1(x)
        # SHAPE : [Time_step * Batch, CLASS_NUM]

        x = x.reshape(T, B, *x.shape[1:])
        # SHAPE : [Time_step, Batch, CLASS_NUM]

        x = self.lif2(x)
        # SHAPE : [Time_step, Batch, CLASS_NUM]

        # Accumulate the output over all time steps.
        return x.sum(dim=0)
        # SHAPE : [Batch, CLASS_NUM]

In [5]:
def _run_epoch(net, loader, criterion, device, time_step, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    When ``optimizer`` is given the network is put in training mode and updated
    after every batch; otherwise it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    net.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # Repeat coding: the same image is presented at every time step
            # (rate coding etc. could be tried here instead).
            # SHAPE : [Batch, Time_step, C, H, W]
            images = images.unsqueeze(1).repeat(1, time_step, 1, 1, 1)

            if training:
                optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # detach() replaces the deprecated `.data` access for reading predictions.
            _, predicted = torch.max(outputs.detach(), 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            loss_sum += loss.item()

    # Loss is averaged over batches, accuracy over samples.
    return loss_sum / len(loader), 100 * n_correct / n_seen


def snn_system(seed,
                which_data,
                batch_size,
                data_path,
                learning_rate,
                time_step,
                v_decay,
                v_threshold,
                v_reset_mode,
                sg_width,
                surrogate,
                max_epoch,
                gpu):
    """Train and validate an ``SNN`` classifier, printing metrics after every epoch.

    Args:
        seed: Seed passed to ``seed_assign``.
        which_data, data_path, batch_size: Forwarded to
            ``modules.data_loader.data_loader``, which supplies the train/test
            loaders plus the class count, input channels and image size.
        learning_rate: Learning rate for the Adam optimizer.
        time_step: Number of times each image is repeated along the time axis.
        v_decay, v_threshold, v_reset_mode, sg_width, surrogate: Forwarded to ``SNN``.
        max_epoch: Number of epochs to run.
        gpu: GPU index/indices exposed through ``CUDA_VISIBLE_DEVICES``.

    Returns:
        None. Results are only printed.
    """
    # Select the GPU before anything touches torch.cuda. str() lets callers pass
    # an int index, which os.environ would otherwise reject with a TypeError.
    # NOTE(review): CUDA_VISIBLE_DEVICES has no effect once CUDA has been
    # initialised in this process, e.g. when the cell is re-run in the same kernel.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    seed_assign(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader, CLASS_NUM, in_channels, IMAGE_SIZE = modules.data_loader.data_loader(which_data, data_path, batch_size)
    net = SNN(v_decay=v_decay, v_threshold=v_threshold, v_reset_mode=v_reset_mode, sg_width=sg_width, surrogate=surrogate, CLASS_NUM=CLASS_NUM, in_channels=in_channels, IMAGE_SIZE=IMAGE_SIZE, time_step=time_step)
    net = net.to(device)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    for epoch in range(max_epoch):
        print(f'Epoch-{epoch}')

        avg_train_loss, training_accuracy = _run_epoch(
            net, train_loader, criterion, device, time_step, optimizer=optimizer)
        print(f'Training_loss: {avg_train_loss:.4f}, Training_accuracy: {training_accuracy:.2f} %')

        # Validation
        avg_val_loss, validation_accuracy = _run_epoch(
            net, test_loader, criterion, device, time_step)
        print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {validation_accuracy:.2f} %')
        print()


In [6]:
# ---- Reproducibility / hardware ----
seed = 42
gpu = '0'  # index of the GPU to use

# ---- Data ----
which_data = 'MNIST'
batch_size = 64
data_path = '/data2'  # <-- do not point the dataset path at an NFS mount; ask if it is unclear what this means

# ---- Optimisation ----
learning_rate = 0.001
max_epoch = 10

# ---- Spiking-neuron settings ----
time_step = 10
v_decay = 0.5
v_threshold = 0.5
v_reset_mode = 'soft_reset'  # 'soft_reset' or 'hard_reset'
sg_width = 4.0  # surrogate gradient width
surrogate = 'sigmoid'  # 'sigmoid' or 'rectangle' or 'rough_rectangle' or 'hard_sigmoid'

# Launch training with every setting passed by keyword.
snn_system(**{
    'seed': seed,
    'which_data': which_data,
    'batch_size': batch_size,
    'data_path': data_path,
    'learning_rate': learning_rate,
    'time_step': time_step,
    'v_decay': v_decay,
    'v_threshold': v_threshold,
    'v_reset_mode': v_reset_mode,
    'sg_width': sg_width,
    'surrogate': surrogate,
    'max_epoch': max_epoch,
    'gpu': gpu,
})

Epoch-0
Training_loss: 4.2778, Training_accuracy: 23.14 %
Validation Loss: 3.5227, Validation Accuracy: 29.75 %

Epoch-1
Training_loss: 1.1334, Training_accuracy: 73.37 %
Validation Loss: 0.3618, Validation Accuracy: 87.36 %

Epoch-2
Training_loss: 0.2866, Training_accuracy: 90.08 %
Validation Loss: 0.1287, Validation Accuracy: 96.33 %

Epoch-3
Training_loss: 0.1039, Training_accuracy: 97.22 %
Validation Loss: 0.0979, Validation Accuracy: 97.19 %

Epoch-4
Training_loss: 0.0791, Training_accuracy: 97.85 %
Validation Loss: 0.0988, Validation Accuracy: 97.21 %

Epoch-5
Training_loss: 0.0580, Training_accuracy: 98.28 %
Validation Loss: 0.0925, Validation Accuracy: 97.12 %

Epoch-6
Training_loss: 0.0480, Training_accuracy: 98.58 %
Validation Loss: 0.0950, Validation Accuracy: 97.45 %

Epoch-7
Training_loss: 0.0436, Training_accuracy: 98.72 %
Validation Loss: 0.1216, Validation Accuracy: 96.42 %

Epoch-8
Training_loss: 0.0366, Training_accuracy: 98.87 %
Validation Loss: 0.0956, Validation Ac