In [None]:
import torch
from models import get_model
from utils.train_utils import (
    get_mnist,
    get_cifar10,
    train_epoch,
    eval_epoch,
    save_checkpoint,
)

# Seed PyTorch's global random number generator before any model or data
# loader is created, so a fresh "Restart Kernel -> Run All" starts from the
# same random state every time.
# NOTE(review): this only covers randomness drawn from torch's global RNG;
# whether the project helpers (get_model / get_mnist / get_cifar10) also use
# Python's `random` or NumPy is not visible from this notebook -- confirm.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch reports one as available, otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Training LeNet on the MNIST Dataset

In [None]:
# LeNet / MNIST run: build the network, optimise it with Adam for a fixed
# number of epochs, print metrics after every epoch, then save a checkpoint.
LENET_EPOCHS = 10
LENET_LR = 1e-3
LENET_CKPT = "checkpoints/lenet_mnist.pth"

model = get_model("lenet")
model = model.to(device)
train_loader, test_loader = get_mnist()
opt = torch.optim.Adam(model.parameters(), lr=LENET_LR)

for epoch in range(LENET_EPOCHS):
    # train_epoch hands back two values, unpacked here as (loss, acc);
    # only `acc` is reported below.
    train_metrics = train_epoch(model, train_loader, opt, device)
    loss, acc = train_metrics
    # Note: the per-epoch "val" figure is computed on test_loader.
    val = eval_epoch(model, test_loader, device)
    print(f"[LeNet] Epoch {epoch+1}: acc={acc:.3f}, val={val:.3f}")

save_checkpoint(model, LENET_CKPT)

Training ResNet10 on the CIFAR-10 Dataset

In [None]:
# ResNet10 / CIFAR-10 run: build the network, optimise it with Adam for a
# fixed number of epochs, print metrics after every epoch, then save a
# checkpoint.
RESNET10_EPOCHS = 50
RESNET10_LR = 1e-3
RESNET10_CKPT = "checkpoints/resnet10_cifar10.pth"

model = get_model("resnet10")
model = model.to(device)
train_loader, test_loader = get_cifar10()
opt = torch.optim.Adam(model.parameters(), lr=RESNET10_LR)

for epoch in range(RESNET10_EPOCHS):
    # train_epoch hands back two values, unpacked here as (loss, acc);
    # only `acc` is reported below.
    train_metrics = train_epoch(model, train_loader, opt, device)
    loss, acc = train_metrics
    # Note: the per-epoch "val" figure is computed on test_loader.
    val = eval_epoch(model, test_loader, device)
    print(f"[ResNet10] Epoch {epoch+1}: acc={acc:.3f}, val={val:.3f}")

save_checkpoint(model, RESNET10_CKPT)

Training ResNet18 on the CIFAR-10 Dataset

In [None]:
# ResNet18 / CIFAR-10 run: build the network, optimise it with Adam for a
# fixed number of epochs, print metrics after every epoch, then save a
# checkpoint.
RESNET18_EPOCHS = 50
RESNET18_LR = 1e-3
RESNET18_CKPT = "checkpoints/resnet18_cifar10.pth"

model = get_model("resnet18")
model = model.to(device)
train_loader, test_loader = get_cifar10()
opt = torch.optim.Adam(model.parameters(), lr=RESNET18_LR)

for epoch in range(RESNET18_EPOCHS):
    # train_epoch hands back two values, unpacked here as (loss, acc);
    # only `acc` is reported below.
    train_metrics = train_epoch(model, train_loader, opt, device)
    loss, acc = train_metrics
    # Note: the per-epoch "val" figure is computed on test_loader.
    val = eval_epoch(model, test_loader, device)
    print(f"[ResNet18] Epoch {epoch+1}: acc={acc:.3f}, val={val:.3f}")

save_checkpoint(model, RESNET18_CKPT)