In [1]:
import torch
import random
import numpy as np
import os

In [2]:
import modules

In [3]:
def seed_assign(seed):
    random.seed(seed)                          # Python random 시드 고정
    np.random.seed(seed)                       # NumPy 시드 고정
    torch.manual_seed(seed)                    # PyTorch CPU 시드 고정
    torch.cuda.manual_seed(seed)               # PyTorch GPU 시드 고정
    torch.cuda.manual_seed_all(seed)           # PyTorch 멀티 GPU 시드 고정
    torch.backends.cudnn.deterministic = True  # 연산의 결정론적 동작 보장
    # torch.backends.cudnn.benchmark = False     # 성능 최적화 비활성화 (결정론적 보장)


In [4]:
class SNN(torch.nn.Module):
    def __init__(self, v_decay, v_threshold, v_reset_mode, sg_width, surrogate, CLASS_NUM, in_channels, IMAGE_SIZE, time_step):
        super(SNN, self).__init__()
        self.TIME = time_step

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        IMAGE_SIZE = (IMAGE_SIZE + 2 - 3 // 1) + 1
        self.lif1 = modules.neuron.LIF_layer(v_decay, v_threshold, v_reset_mode, sg_width, surrogate)
        self.fc1 = torch.nn.Linear(32 * IMAGE_SIZE * IMAGE_SIZE, CLASS_NUM)
        self.lif2 = modules.neuron.LIF_layer(v_decay, v_threshold, v_reset_mode, sg_width, surrogate)

    def forward(self, x):
        # SHAPE : [Batch, Time_step, Channel, H, W]
        # print(x.shape)
        x = torch.transpose(x, 0, 1)
        # print(x.shape)

        T, B, *spatial_dims = x.shape
        x = x.reshape(T * B, *spatial_dims)
        # print(x.shape)

        x = self.conv1(x)
        # print(x.shape)

        TB, *spatial_dims = x.shape
        x = x.reshape(self.TIME , TB // self.TIME, *spatial_dims)
        # print(x.shape)

        x = self.lif1(x)
        # print(x.shape)
        
        T, B, *spatial_dims = x.shape
        x = x.reshape(T * B, *spatial_dims)
        # print(x.shape)

        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        # print(x.shape)

        TB, *spatial_dims = x.shape
        x = x.reshape(self.TIME , TB // self.TIME, *spatial_dims)
        # print(x.shape)

        x = self.lif2(x)
        # print(x.shape)

        x = x.sum(dim=0)
        # print(x.shape)
        
        return x

Hyperparameter

In [5]:
which_data = 'MNIST'
batch_size = 64
data_path = '/data2' # <-- 데이터셋 경로 nfs로 지정하면 안됩니다. 무슨 말인지 잘 모르겠으면 문의해주세요.
learning_rate = 0.001
time_step = 10
v_decay = 0.5
v_threshold = 0.5
v_reset_mode = 'soft_reset' # 'soft_reset' or 'hard_reset'
sg_width = 4.0 # surrogate gradient width
surrogate = 'sigmoid' # 'sigmoid' or 'rectangle' or 'rough_rectangle' or 'hard_sigmoid'

max_epoch = 5
gpu = '0' # 사용할 GPU 번호


In [6]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]= gpu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader, CLASS_NUM, in_channels, IMAGE_SIZE = modules.data_loader.data_loader(which_data, data_path, batch_size)
net = SNN(v_decay=v_decay, v_threshold=v_threshold, v_reset_mode=v_reset_mode, sg_width=sg_width, surrogate=surrogate, CLASS_NUM=CLASS_NUM, in_channels=in_channels, IMAGE_SIZE=IMAGE_SIZE, time_step=time_step)
net = net.to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

In [7]:
for epoch in range(max_epoch):
    net.train()
    correct_train = 0
    total_train = 0
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        print(images.shape, images)
        images = images.unsqueeze(1).repeat(1, time_step, 1, 1, 1)  # repeat 코딩. rate코딩 등 실험해봐도 좋다. # (batch, time_step, C, H, W)
        
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Training accuracy
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()
        train_loss += loss.item()

    training_accuracy = 100 * correct_train / total_train
    avg_train_loss = train_loss / len(train_loader)
    print(f'Epoch-{epoch} training_loss: {avg_train_loss:.4f}, training_accuracy: {training_accuracy:.2f} %')

    # Validation
    net.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            images = images.unsqueeze(1).repeat(1, time_step, 1, 1, 1)  # repeat 코딩. rate코딩 등 실험해봐도 좋다. # (batch, time_step, C, H, W)

            outputs = net(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    validation_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(test_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {validation_accuracy:.2f} %')
    print()


torch.Size([64, 1, 28, 28]) tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
   