In [1]:
from spikingjelly.activation_based.neuron import ParametricLIFNode, LIFNode, BaseNode
from jaxtyping import Float, Int
from torch import Tensor

import torch
import matplotlib.pyplot as plt
import numpy as np
import spikingjelly.activation_based as snn
from spikingjelly.activation_based import surrogate, neuron, functional
from typing import Callable, Any, overload
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

from utils.datasets import generate_lp_dataset, encode_temporal

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

np.random.seed(42)
rng = torch.manual_seed(42)

In [3]:
class TransposeLayer(torch.nn.Module):
    dims: tuple[int, int]

    def __init__(self, dims: tuple[int, int]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(*self.dims)

class PrintLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        print(input.shape)
        return input

In [21]:
class CCN(torch.nn.Module):
    def __init__(self,
                 vector_dim:int,
                 cc_acc:int,
                 feature_dims:list[int]):
        """
        Cross-correlation network initialization.
        
        :param self: 설명
        :param input_dim: ...
        :param n_acc: 설명
        """
        super(CCN, self).__init__()
        self.vector_dim = vector_dim
        self.cc_acc = cc_acc
        self.feature_dims = feature_dims
        
        self.model = torch.nn.Sequential(
            torch.nn.Unflatten(2, (2, self.vector_dim)), # T,N,2D -> T,N,2,D
            TransposeLayer((2,3)),
            torch.nn.Linear(2, self.cc_acc, bias=False),
            ParametricLIFNode(v_reset=None, surrogate_function=surrogate.Sigmoid()),
            torch.nn.Flatten(start_dim=2)
        )
        # Use ParametricLIF only in first layer
        feature_dims = [self.cc_acc * self.vector_dim] + self.feature_dims
        for in_dim, out_dim in zip(feature_dims[:-1], feature_dims[1:]):
            self.model.append(torch.nn.Linear(in_dim, out_dim))
            self.model.append(LIFNode(v_reset=None, surrogate_function=surrogate.Sigmoid()))
        

    def forward(self, x:Float[Tensor, "T N 2D"], return_rate:bool=True, reset:bool=True):
        """
        Compute the correlation between two input tensors.
        
        :param self: 설명
        :param x: 설명
        :type x: torch.Tensor
        """
        if reset:
            for layer in self.model:
                if isinstance(layer, BaseNode):
                    layer.reset()
        x = self.model(x)
        # print(x[:,0,...])
        if return_rate:
            x = x.mean(dim=0)
        return x


In [22]:
NUM_SAMPLES = 10000  # 총 1000 개의 샘플 생성
VECTOR_DIM = 10      # 각 벡터는 10차원
MAX_VAL = 10.0
TIME_STEPS = 20     # SNN을 20 타임스텝 동안 실행

model = CCN(vector_dim = VECTOR_DIM, cc_acc=5, feature_dims=[3, 1]).to(device)
model

CCN(
  (model): Sequential(
    (0): Unflatten(dim=2, unflattened_size=(2, 10))
    (1): TransposeLayer()
    (2): Linear(in_features=2, out_features=5, bias=False)
    (3): ParametricLIFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
    )
    (4): Flatten(start_dim=2, end_dim=-1)
    (5): Linear(in_features=50, out_features=3, bias=True)
    (6): LIFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
    )
    (7): Linear(in_features=3, out_features=1, bias=True)
    (8): LIFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
    )
  )
)

In [23]:
X_data, y_data = generate_lp_dataset(NUM_SAMPLES, VECTOR_DIM, p=1.0, max_val=MAX_VAL)
X_data = torch.FloatTensor(encode_temporal(X_data, TIME_STEPS)) # TN(2D)
y_data = torch.FloatTensor(y_data) / y_data.max() # ND
dataset = torch.utils.data.TensorDataset(X_data.transpose(1, 0), y_data)  # TN(2D) -> NT(2D)

train_test_split:Callable[[np.ndarray, float, int], tuple[np.ndarray, np.ndarray]] = train_test_split
train_set, test_set = train_test_split(dataset, test_size=0.2, random_state=42)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    drop_last=False
    )
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=32,
    shuffle=False,
    drop_last=False
)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = torch.tensor(float("inf"))

pbar = tqdm(range(20))
for epoch in pbar:
    model.train()
    for batch in tqdm(train_loader, leave=False):
        inputs:Float[Tensor, "N T 2D"]; targets:Float[Tensor, "N D"]
        inputs, targets = batch
        inputs = inputs.to(device); targets = targets.to(device)
        # 모델 학습 코드 추가
        out:Float[Tensor, "N D"] = model(inputs.transpose(1, 0)) # NT(2D)->TND->model->ND
        loss = criterion(out, targets)
        
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})
    
    with torch.no_grad():
        model.eval()
        for batch in tqdm(test_loader, leave=False):
            inputs, targets = batch
            inputs = inputs.to(device); targets = targets.to(device)
            # 모델 학습 코드 추가
            out = model(inputs.transpose(1, 0)) # NT(2D)->TND->model->ND
            
            loss = criterion(out, targets)
            pbar.set_postfix({"loss": loss.item(), "pred": out[0].item(), "target": targets[0].item()})
        
        

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

In [24]:
sample = dataset[0:10]
inputs, targets = sample
inputs = inputs.to(device); targets = targets.to(device)
out = model(inputs.transpose(1, 0))

In [25]:
torch.hstack((out, targets))

tensor([[0.0000, 0.2851],
        [0.0000, 0.5489],
        [0.0000, 0.2824],
        [0.0000, 0.6193],
        [0.0000, 0.3226],
        [0.0000, 0.5060],
        [0.0000, 0.4093],
        [0.0000, 0.5624],
        [0.0000, 0.6035],
        [0.0000, 0.5745]], device='cuda:0')