In [None]:
import sys
sys.path.append("../../")

import torch
import torch.nn as nn

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.plots import plot_losses

In [None]:
# --- Configuration ---------------------------------------------------------
SEED = 42                            # arbitrary, fixed so reruns are comparable
DATA_ROOT = "../../assets/cifar10"   # download/cache directory shared by both splits
BATCH_SIZE = 64

# Seed torch's global RNG before anything random happens. With a fresh kernel
# and Run All, this pins the weight initialisation of the model built in a later
# cell, the shuffle order of the training loader and the random augmentations
# below (bit-exact repeatability on GPU backends is not guaranteed).
torch.manual_seed(SEED)

# these are the standard pre-computed values
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random augmentations are applied on the PIL image first,
# then the image is converted to a tensor and normalised per channel.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# Validation pipeline: no augmentation, only tensor conversion + normalisation,
# so evaluation always sees the same inputs.
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# download=True fetches the archive into DATA_ROOT if it is not already there.
train_dataset = CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=train_transforms,
)
val_dataset = CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=val_transforms,
)

# Only the training split is shuffled; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Baseline classifier: flatten each 3x32x32 image into a vector and pass it
# through a single hidden layer to produce one logit per class.
in_features = 32 * 32 * 3
hidden_units = 128
num_classes = 10

mlp_layers = [
    nn.Flatten(),
    nn.Linear(in_features, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, num_classes),
]
model = nn.Sequential(*mlp_layers)

In [None]:
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        model: Module to train; switched to training mode by this call.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor. The per-batch value is weighted by the batch size, which
            assumes a mean-reduced loss -- confirm if a different reduction
            is used.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sample-weighted mean loss over the samples actually processed
        in this pass.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Divide by the samples seen, not len(dataloader.dataset): the two differ
    # when the loader drops the last batch or samples only part of the dataset.
    return running_loss / num_samples

In [None]:
def evaluate(model, dataloader, criterion, device):
    """Compute the mean loss and top-1 accuracy of ``model`` on ``dataloader``.

    The model is switched to evaluation mode and left in that mode; gradients
    are disabled for the duration of the pass.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor. The per-batch value is weighted by the batch size, which
            assumes a mean-reduced loss -- confirm if a different reduction
            is used.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple[float, float]: ``(avg_loss, accuracy)`` computed over the samples
        actually processed; ``accuracy`` is a fraction in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    num_samples = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            # Predicted class = index of the largest output along dim 1.
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            num_samples += batch_size

    # Normalise by the samples seen, not len(dataloader.dataset): the two differ
    # when the loader drops the last batch or samples only part of the dataset.
    avg_loss = total_loss / num_samples
    accuracy = total_correct / num_samples

    return avg_loss, accuracy

In [None]:
num_epochs = 50

# Prefer a CUDA GPU, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model first so the optimizer below is constructed from parameters
# that already live on the device used for training.
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [None]:
# Per-epoch loss history; both lists are handed to the plotting call later on.
train_losses, val_losses = [], []

for epoch in tqdm(range(num_epochs)):
    # One optimisation pass over the training split ...
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    # ... followed by a gradient-free pass over the validation split.
    # val_acc is overwritten every epoch, so after the loop it holds the
    # accuracy of the final epoch.
    val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)

    train_losses += [train_loss]
    val_losses += [val_loss]

Expect a final validation accuracy of **~47%** after 50 epochs.

In [None]:
# val_acc is the accuracy fraction from the last evaluate() call of the
# training loop; scale it to a percentage with two decimals for display.
print(f"Final Validation Accuracy: {val_acc*100:.2f}%")

In [None]:
# Draw the training and validation loss histories together, one label per curve.
loss_curves = [train_losses, val_losses]
curve_labels = ["Train Loss", "Validation Loss"]
plot_losses(loss_curves, curve_labels)