In [9]:
import os

import torch
from torchvision import datasets, models, transforms
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from utils.data import get_data_loaders
from utils.train_eval import train

In [2]:
# Seed torch's global RNG before the data loaders and the model are created so
# that weight initialisation (and, presumably, loader shuffling — TODO confirm
# how get_data_loaders shuffles) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dataset to use: either "cifar10" or "flowers102".
dataset = "cifar10"
batch_size = 64
num_workers = 2
train_loader, test_loader, num_classes = get_data_loaders(
    dataset=dataset, batch_size=batch_size, num_workers=num_workers
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Report how many mini-batches each split yields per epoch.
print("\n".join(
    f"Number of {split} batches: {len(loader)}"
    for split, loader in (("training", train_loader), ("testing", test_loader))
))

Number of training batches: 782
Number of testing batches: 157


In [4]:
# Baseline network under test: ResNet-18 built without pretrained weights,
# with its classifier head resized to the dataset's number of classes.
model = models.resnet18(weights=None)

learning_rate = 1e-3

# Replace the final fully connected layer so it emits `num_classes` logits.
fc_input = model.fc.in_features
model.fc = nn.Linear(fc_input, num_classes)

# Move the model to its target device BEFORE constructing the optimizer (as the
# torch.optim docs advise) so the optimizer is built over on-device parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
# Total element count across every parameter tensor of the model.
num_params = sum(map(torch.numel, model.parameters()))
print("Number of parameters:", num_params)

Number of parameters: 11181642


In [6]:
# TensorBoard event files for this run are written under ./logs/<dataset>/baseline.
log_dir = "./logs/{}/baseline".format(dataset)
writer = SummaryWriter(log_dir=log_dir)

In [7]:
# Train the model for `num_epochs` epochs, logging through the TensorBoard `writer`.
num_epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure the model lives on `device` (returns the same module if it already does).
model = model.to(device)

# Log every this many batches.
# NOTE(review): the saved output of this cell reports the loss every 100
# batches, not every 150 — either that output is stale from an earlier run or
# `train` does not honour "log_freq_tr"; re-run / check utils.train_eval.train.
log_train_every = 150
log_test_every = 250

config = {
    "opt": optimizer,
    "crit": criterion,
    "log_freq_tr": log_train_every,
    "log_freq_test": log_test_every,
}

train(model, train_loader, test_loader, num_epochs, config, device, writer)

# Flush any pending TensorBoard events to disk and release the event file.
writer.close()

[Epoch: 1, Batch: 100] Loss: 1.861
[Epoch: 1, Batch: 200] Loss: 1.655
[Epoch: 1, Batch: 300] Loss: 1.571
[Epoch: 1, Batch: 400] Loss: 1.443
[Epoch: 1, Batch: 500] Loss: 1.411
[Epoch: 1, Batch: 600] Loss: 1.348
[Epoch: 1, Batch: 700] Loss: 1.334
Training finished.


In [10]:
# Save the trained weights to <log_dir>/checkpoints/<num_epochs>.pth.
# A state_dict is saved instead of the pickled nn.Module so the checkpoint does
# not depend on the pickled class/module layout and can be read with
# torch.load(..., weights_only=True). To restore:
#   model = models.resnet18(weights=None)
#   model.fc = nn.Linear(model.fc.in_features, num_classes)
#   model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
ckpt_dir = os.path.join(log_dir, "checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, f"{num_epochs}.pth")
torch.save(model.state_dict(), ckpt_path)