In [None]:
from dataclasses import dataclass, asdict

import torch, wandb
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from utils.config import Config, ModelType
from utils.model.classification import SimpleRNN, SimpleLSTM, SimpleGRU, LearnableDelayRNN
from tqdm.auto import tqdm

# torch.autograd.set_detect_anomaly(True)

# Pick the compute device. This prefers the second GPU (cuda:1) when it exists,
# then the first GPU, then the CPU.
# Hard-coding "cuda:1" whenever CUDA is available fails with
# "invalid device ordinal" on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Experiment configuration: every hyperparameter of the run lives here.
config = Config(
    model_type=ModelType.DelayedRNN,
    max_delay=40,
    max_think_steps=100,
    # NOTE(review): with seed=None the pixel permutation and the weight init
    # change on every run. Set an integer seed when comparing model types.
    seed=None,
    batch_size=32,
    input_size=1,
    # NOTE(review): the training cell reads config.seq_length; presumably
    # Config defaults it to 784 (28 * 28) -- TODO confirm.
    # seq_length=784,
    seq_min=5,
    seq_max=20,
    hidden_size=256,
    num_classes=10,
    learning_rate=0.01,
    epochs=100,
    device=device,
)

# wandb.init() returns an already-started Run, so no manual run.__enter__()
# call is needed.
run = wandb.init(entity="CIDA", project="PSMNIST_RNN", name=f"{config.model_type.name}", config=asdict(config))

print(f"Using device: {config.device}")

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home1/paul6598/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mjoonhuk6598[0m ([33mjoonhuk6598-university-of-seoul[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using device: cpu


In [None]:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=config.batch_size, shuffle=False)

# 3. 고정된 순열(Permutation) 생성
# 모든 배치와 에폭에서 동일한 순서로 섞어야 PSMNIST가 성립됩니다.
if config.seed is not None:
    torch.manual_seed(config.seed) # 재현성을 위해 시드 고정
perm_order = torch.randperm(config.seq_length).to(config.device)

match config.model_type:
    case ModelType.RNN:
        model = SimpleRNN(config.input_size, config.hidden_size, config.num_classes, config=config).to(config.device)
    case ModelType.LSTM:
        model = SimpleLSTM(config.input_size, config.hidden_size, config.num_classes, config=config).to(config.device)
    case ModelType.GRU:
        model = SimpleGRU(config.input_size, config.hidden_size, config.num_classes, config=config).to(config.device)
    case ModelType.DelayedRNN:
        model = LearnableDelayRNN(config.input_size, config.hidden_size, config.num_classes, max_delay=config.max_delay, config=config).to(config.device)
    case _:
        raise ValueError(f"Unknown model type: {config.model_type}")
    
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# 5. 학습 루프
for epoch in tqdm(range(config.epochs), desc="Epochs"):
    model.train()
    for i, (images, labels) in tqdm(enumerate(train_loader), total=len(train_loader), desc="Batches", leave=False):
        # 이미지 변형: (Batch, 1, 28, 28) -> (Batch, 784, 1)
        images = images.view(-1, config.seq_length, config.input_size).to(config.device)
        labels = labels.to(config.device)
        
        # *** 중요: 여기서 픽셀 순서를 섞습니다 (PSMNIST 핵심) ***
        images = images[:, perm_order, :]
        
        # 순전파
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 역전파 및 최적화
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        wandb.log({"Loss/Train": loss.item(),
                   "Accuracy/Train": (outputs.argmax(dim=1) == labels).float().mean().item()})
        if (i+1) % 300 == 0:
            print(f'Epoch [{epoch+1}/{config.epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.view(-1, config.seq_length, config.input_size).to(config.device)
            
            # 테스트셋에도 동일한 순열 적용
            images = images[:, perm_order, :]
            labels = labels.to(config.device)
            
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        wandb.log({"Accuracy/Validation": 100 * correct / total})
        print(f'Validation Accuracy after Epoch {epoch+1}: {100 * correct / total:.2f}%')
# 6. 평가 루프
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.view(-1, config.seq_length, config.input_size).to(config.device)
        
        # 테스트셋에도 동일한 순열 적용
        images = images[:, perm_order, :]
        labels = labels.to(config.device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        wandb.log({"Accuracy/Test": 100 * correct / total})

    print(f'Test Accuracy of the RNN on the 10000 test images (PSMNIST): {100 * correct / total:.2f}%')

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f05cd1f1940>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f0782378c30, execution_count=2 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7f0782378d60, raw_cell="train_dataset = datasets.MNIST(root='./data', trai.." transformed_cell="train_dataset = datasets.MNIST(root='./data', trai.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bgate1_remote/home1/paul6598/delayed-rnn/psmnist.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost

In [None]:
# End the wandb run through its public API. This flushes pending logs and marks
# the run finished, instead of invoking the context-manager dunder by hand.
run.finish()