In [3]:
# imports

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

from utils.simple_classifier import SimpleCNN

In [None]:
# global training hyperparameters (read by the training function below)
BATCH_SIZE = 32          # samples per mini-batch handed to the DataLoader
LEARNING_RATE = 1e-3     # learning rate passed to the optimizer
NUM_EPOCHS = 10          # number of passes over the training set

In [4]:
# CIFAR-10 images have three (RGB) channels, so Normalize gets one mean/std per
# channel instead of relying on a single-element tuple being broadcast.
# Each channel is transformed as (x - 0.5) / 0.5.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# directory the CIFAR-10 files are downloaded to / read from
DATA_ROOT = './data'

# tensor conversion + normalization, identical in both pipelines so that the
# augmentations are the only difference between them
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]

# augmentations
augmented_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),  # rotate by a random angle in [-20, +20] degrees
    *_tensor_and_normalize,
])

# without augmentations
basic_transform = transforms.Compose(list(_tensor_and_normalize))

# The transform is the only thing that varies between the two training sets, so
# any difference in training behaviour can be attributed to the augmentations.
# The test set always uses the un-augmented transform.
train_set_aug = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=augmented_transform
)
train_set_no_aug = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=basic_transform
)
test_set = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=basic_transform
)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# train and eval the model
def train_model_augmentations(train_set):
    """Train a fresh ``SimpleCNN`` on ``train_set`` and record per-epoch training metrics.

    Every call builds a new ``SimpleCNN()``, a ``CrossEntropyLoss`` criterion and
    an Adam optimizer, so separate calls do not share weights. Batch size,
    learning rate and epoch count are read from the module-level ``BATCH_SIZE``,
    ``LEARNING_RATE`` and ``NUM_EPOCHS``. The model and each batch are placed on
    CUDA when it is available, otherwise on the CPU.

    Args:
        train_set: dataset of ``(image, label)`` pairs; it must support ``len()``
            and be accepted by ``torch.utils.data.DataLoader``.

    Returns:
        ``(train_loss, train_acc)``: two lists with one entry per epoch.
        ``train_loss`` is the mean per-sample loss and ``train_acc`` the accuracy
        in percent, both accumulated over the training batches while the model
        is being updated (i.e. in train mode, not a separate evaluation pass).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    num_samples = len(train_set)

    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    epochs = NUM_EPOCHS
    train_loss, train_acc = [], []

    for epoch in range(epochs):
        model.train()
        loss_sum, correct = 0.0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so a
            # smaller final batch does not count as much as a full one.
            loss_sum += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        epoch_loss = loss_sum / num_samples
        epoch_acc = 100 * correct / num_samples
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        print(f'epoch {epoch+1}/{epochs}, loss: {epoch_loss:.4f}, acc: {epoch_acc:.2f}%')

    return train_loss, train_acc