In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# 1. Dataset definition
class MutationDataset(Dataset):
    """Map-style dataset pairing each mutation sample with its label.

    Item ``idx`` is ``(mutation_tensor[idx], labels[idx])``; both inputs are
    indexed along their first axis.

    Args:
        mutation_tensor: per-sample mutation features. The original note gives
            the shape as ``[N, 4384, 4]`` (not checked here). Any sequence that
            supports ``len()`` and integer indexing works (torch tensor, NumPy
            array, list).
        labels: per-sample targets; presumably ``N`` class indices.

    Raises:
        ValueError: if the two inputs do not hold the same number of samples.
    """

    def __init__(self, mutation_tensor, labels):
        # Fail fast: a length mismatch would otherwise silently ignore extra
        # labels, or raise IndexError partway through an epoch.
        if len(mutation_tensor) != len(labels):
            raise ValueError(
                f"mutation_tensor has {len(mutation_tensor)} samples but "
                f"labels has {len(labels)}"
            )
        self.mutation_tensor = mutation_tensor  # [N, 4384, 4] per original note
        self.labels = labels  # [N]

    def __len__(self):
        # len() equals size(0) for tensors and also works for arrays/lists.
        return len(self.mutation_tensor)

    def __getitem__(self, idx):
        return self.mutation_tensor[idx], self.labels[idx]

# 2. Model & hyperparameter setup
# Resolve the device first so the model is moved there *before* the optimizer
# is constructed, as the torch.optim documentation recommends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MutationAttentionExternalVTransformed()
model.to(device)
criterion = nn.CrossEntropyLoss()  # multi-class classification
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 10
batch_size = 32

# 3. Prepare the frozen external V
# detach() already returns a tensor with requires_grad=False, so V gets no
# gradient and is never updated; clone() gives it storage independent of
# pretrained_glove_tensor. The original note gives the shape as [4384, 128]
# (not checked here).
external_V = pretrained_glove_tensor.detach().clone().to(device)

# 4. DataLoader
train_dataset = MutationDataset(mutation_input_train, labels_train)
# Fail early with a clear message instead of a ZeroDivisionError at the
# first epoch's loss report.
if len(train_dataset) == 0:
    raise ValueError("training dataset is empty")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 5. Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    num_samples = 0

    for mutation_input_batch, label_batch in train_loader:
        mutation_input_batch = mutation_input_batch.to(device)
        # NOTE(review): CrossEntropyLoss with class-index targets needs an
        # integer (torch.long) label tensor - confirm labels_train's dtype.
        label_batch = label_batch.to(device)

        # Forward; the original note gives the output as [batch_size, 26].
        output = model(mutation_input_batch, external_V)
        loss = criterion(output, label_batch)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns the batch *mean* (default reduction="mean"),
        # so weight it by batch size. A plain mean of batch means would
        # over-weight a smaller final batch.
        batch_len = label_batch.size(0)
        total_loss += loss.item() * batch_len
        num_samples += batch_len

    print(f"Epoch {epoch+1}, Loss: {total_loss / num_samples:.4f}")
