In [None]:
from dataclasses import dataclass, asdict

import torch, wandb
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from utils.config import Config, ModelType
from utils.model.seq2seq import SimpleRNN, SimpleLSTM, LearnableDelayRNN
from utils.data import SwapDataset, collate_fn
from tqdm.auto import tqdm

# torch.autograd.set_detect_anomaly(True)

config = Config()
config.model_type = ModelType.LEARNABLE_DELAY_RNN
config.input_size = 11      # the dataset cell derives the symbol count as k = input_size - 1
config.num_classes = 10
config.max_delay = 20
config.seq_min = 5          # minimum sequence length
config.seq_max = 20         # maximum sequence length

# Seed before the dataset cell runs, so any torch-based sampling there is reproducible.
# NOTE(review): if SwapDataset draws from `random`/numpy instead of torch, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Use the requested GPU when it exists; otherwise fall back to CPU so the notebook still runs.
GPU_INDEX = 4
config.device = torch.device(
    f"cuda:{GPU_INDEX}"
    if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX
    else "cpu"
)

# wandb.init() already starts the run; it is closed explicitly in the last cell.
run = wandb.init(project="QSWAP_RNN", name="LearnableDelayRNN_QSWAP", config=asdict(config))

print(f"Using device: {config.device}")

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /root/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33msizzflair97[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using device: cuda:4


In [2]:
DATASET_SIZE = 5000  # number of samples used for one epoch

# Both splits share the same task parameters; only their sizes differ (test = 10% of train).
swap_task = dict(k=config.input_size - 1, min_len=config.seq_min, max_len=config.seq_max)
train_dataset = SwapDataset(size=DATASET_SIZE, **swap_task)
test_dataset = SwapDataset(size=DATASET_SIZE // 10, **swap_task)

# collate_fn (the project's padding function) must be registered so variable-length
# sequences can be batched together.
# NOTE(review): with the batch size seen in the outputs (64), 5000 and 500 are not exact
# multiples, so each loader ends with a short batch. LearnableDelayRNN receives
# config.batch_size at construction; confirm it handles that (else pass drop_last=True).
loader_kwargs = dict(batch_size=config.batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)   # reshuffle every epoch for training
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)    # fixed order for evaluation

In [None]:
torch.manual_seed(42) # 재현성을 위해 시드 고정

match config.model_type:
    case ModelType.SIMPLE_RNN:
        model = SimpleRNN(config.input_size, config.hidden_size, config.num_classes, config=config).to(config.device)
    case ModelType.SIMPLE_LSTM:
        model = SimpleLSTM(config.input_size, config.hidden_size, config.num_classes, config=config).to(config.device)
    case ModelType.LEARNABLE_DELAY_RNN:
        model = LearnableDelayRNN(config.batch_size, config.input_size, config.hidden_size, config.num_classes, max_delay=config.max_delay, config=config).to(config.device)
    case _:
        raise ValueError(f"Unknown model type: {config.model_type}")
    
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# 5. 학습 루프
for epoch in tqdm(range(config.epochs), desc="Epochs"):
    total_loss = 0
    model.train()
    for i, (inputs, targets, lengths) in tqdm(enumerate(train_loader), total=len(train_loader), desc="Batches", leave=False):
        inputs, targets = inputs.to(config.device), targets.to(config.device)
        
        optimizer.zero_grad()
        
        # 모델 Forward
        outputs = model(inputs, config.seq_max)
        
        # Loss
        # outputs: [Batch, Max_Len, K] -> Flatten
        # targets: [Batch, Max_Len] -> Flatten (-1은 ignore_index 처리됨)
        loss = criterion(outputs.reshape(-1, config.seq_max), targets.reshape(-1))
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        wandb.log({"Loss/Train": loss.item(),
                   "Accuracy/Train": (outputs.argmax(dim=2) == targets).float().mean().item()})
        if (i+1) % 300 == 0:
            print(f'Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, targets, lengths in test_loader:
            inputs, targets = inputs.to(config.device), targets.to(config.device)
            
            outputs = model(inputs, lengths)
            _, predicted = torch.max(outputs.data, 2)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        
        wandb.log({"Accuracy/Validation": 100 * correct / total})
        print(f'Validation Accuracy after Epoch {epoch+1}: {100 * correct / total:.2f}%')
            
# 6. 평가 루프
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, targets, lengths in tqdm(test_loader, desc="Testing"):
        inputs, targets = inputs.to(config.device), targets.to(config.device)
        
        outputs = model(inputs, lengths)
        _, predicted = torch.max(outputs.data, 2)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        wandb.log({"Accuracy/Test": 100 * correct / total})

    print(f'Test Accuracy of the RNN on the 10000 test images (PSMNIST): {100 * correct / total:.2f}%')

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/79 [00:00<?, ?it/s]

ValueError: Expected input batch_size (736) to match target batch_size (1280).

In [4]:
# Debug check: output vs. target shapes from the last training batch
# (relies on `outputs`/`targets` left over from the training cell).
tuple(tensor.shape for tensor in (outputs, targets))

(torch.Size([64, 23, 10]), torch.Size([64, 20]))

In [None]:
# End the W&B run via its public API. This is equivalent to leaving a
# `with wandb.init(...)` block without an exception, but avoids calling dunders manually.
run.finish()