In [1]:
import torch, torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from cnn import CNNVarTime

In [2]:
# Seed torch's RNG on all devices before the model is built, so weight init
# (and any torch-based shuffling done later) is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CNNVarTime(in_ch=2, num_classes=11).to(device)

# Adam's weight_decay adds an L2 penalty to the gradient (coupled L2, not AdamW-style decoupled decay)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

In [5]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: module to train; it is switched to train mode here.
        loader: iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(logits, targets)``. Assumed to return
            the batch *mean* (the ``nn.CrossEntropyLoss`` default), which is why it is
            re-weighted by the batch size below.
        device: device each batch is moved to before the forward pass.

    Returns:
        Sample-weighted average loss over the samples actually seen this epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running, n_seen = 0.0, 0
    for xb, yb in tqdm(loader, leave=False):
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = xb.size(0)
        running += loss.item() * batch_size
        n_seen += batch_size
    # Divide by the samples actually iterated, not len(loader.dataset): the two
    # differ when the loader uses drop_last=True or samples only a subset.
    if n_seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return running / n_seen

@torch.no_grad()
def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        model: module to evaluate; it is switched to eval mode here.
        loader: iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
        criterion: loss called as ``criterion(logits, targets)``. Assumed to return
            the batch *mean*, hence the re-weighting by batch size.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the samples actually seen.
        Accuracy compares ``logits.argmax(1)`` to the targets.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, correct, n_seen = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        batch_size = xb.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == yb).sum().item()
        n_seen += batch_size
    # Normalise by samples iterated, not len(loader.dataset) (differs with drop_last / subset samplers).
    if n_seen == 0:
        raise ValueError("eval_epoch: loader yielded no samples")
    return total_loss / n_seen, correct / n_seen

In [6]:
# Run training: EPOCHS passes, checkpointing whenever validation accuracy improves.
# NOTE(review): train_loader / val_loader are not defined anywhere in this notebook
# (the recorded run failed with NameError) — add a data-loading cell above this one.
EPOCHS = 15
CKPT_PATH = "irmas_cnn.pt"
# Start below any achievable accuracy so the first epoch always writes a checkpoint,
# even if validation accuracy is 0.0 (otherwise the fine-tuning cell has nothing to load).
best_acc = float("-inf")
for epoch in range(1, EPOCHS + 1):
    tr_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    va_loss, va_acc = eval_epoch(model, val_loader, criterion, device)
    print(f"epoch {epoch:02d} | train {tr_loss:.4f} | val {va_loss:.4f} | acc {va_acc:.3f}")
    if va_acc > best_acc:
        best_acc = va_acc
        # Keep the "model" key the fine-tuning cell reads; epoch/val_acc identify which checkpoint this is.
        torch.save({"model": model.state_dict(), "epoch": epoch, "val_acc": va_acc}, CKPT_PATH)
        print(f"  saved {CKPT_PATH}")

NameError: name 'train_loader' is not defined

In [None]:
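# # load pretrained (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
# base = CNNVarTime(in_ch=2, num_classes=11)
# base.load_state_dict(torch.load("irmas_cnn.pt", map_location="cpu")["model"])

# # new head for 4 classes
# # NOTE(review): assumes the classifier head's parameters are the only ones prefixed "fc." — confirm against CNNVarTime
# ft = CNNVarTime(in_ch=2, num_classes=4)
# ft.load_state_dict({k:v for k,v in base.state_dict().items() if not k.startswith("fc.")}, strict=False)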