In [None]:
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as T

In [None]:
##  NETWORK  ##
class ConvNormRelu(nn.Module):
    """3x3 convolution followed by batch normalisation and a ReLU.

    The convolution uses ``padding=1`` with the default stride, so the spatial
    size of the input is preserved; only the channel count changes.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=3,
            padding=1,
        )
        self.norm = nn.BatchNorm2d(num_features=c_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        features = self.conv(x)
        normalised = self.norm(features)
        return self.relu(normalised)
        

class VGG(nn.Module):
    """Small VGG-style image classifier with four conv/pool stages.

    Each stage is a ``ConvNormRelu`` block followed by a 2x2 max-pool that halves
    the spatial size, taking the channels 3 -> 64 -> 128 -> 256 -> 512. The final
    feature map is reduced to one value per channel by global average pooling and
    fed to a 10-way linear classifier.

    Args:
        cfg: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, cfg=None):
        super().__init__()
        self.conv1 = ConvNormRelu(3, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = ConvNormRelu(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = ConvNormRelu(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = ConvNormRelu(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pool down to 1x1. The previous
        # AvgPool2d(kernel_size=1, stride=2) averaged nothing: on the 2x2 map
        # produced from 32x32 inputs it kept only the top-left position, and for
        # any other input size the flattened width no longer matched the
        # 512-feature classifier. The layer has no parameters, so state_dict
        # keys are unchanged.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)``.

        ``x`` is a batch of 3-channel images; height and width must be at least
        16 so that a non-empty map survives the four 2x2 max-pools.
        """
        h = self.pool1(self.conv1(x))
        h = self.pool2(self.conv2(h))
        h = self.pool3(self.conv3(h))
        h = self.pool4(self.conv4(h))
        h = self.avg_pool(h)
        # (batch, 512, 1, 1) -> (batch, 512)
        h = h.view(h.size(0), -1)
        return self.classifier(h)
        

In [None]:
##  DATA  ##

# Dataset cache directory under the user's home. os.path.expanduser resolves the
# home directory portably, whereas indexing os.environ['HOME'] raises KeyError
# when that variable is not set (e.g. on Windows).
DATA_DIR = os.path.join(os.path.expanduser('~'), '.Data')

# Preprocessing shared by both splits: convert each image to a tensor, then
# normalise every RGB channel with a fixed (mean, std) pair.
# NOTE(review): the constants look like the usual CIFAR-10 channel statistics - confirm.
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Datasets & loaders. download=True fetches CIFAR-10 into DATA_DIR if it is missing.
trainset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)

# The test split is iterated in a fixed order (no shuffling).
testset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)

# Human-readable class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
##  MODEL  ##
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA instead of failing in net.to().
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
net = VGG()
net.to(DEVICE)

criterion = nn.CrossEntropyLoss()
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-3)
# NOTE(review): T_max=200 spans 200 scheduler steps, but the training loop in this
# notebook steps the scheduler only 20 times, so the learning rate barely decays.
# Confirm whether T_max should match the number of epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=200)

In [None]:
##  TRAINING  ##
# Best test accuracy (in percent) seen so far. test() reads and updates it through
# `global best_acc` to decide when to write a new checkpoint.
best_acc = 0

def train(epoch):
    """Run one optimisation pass over ``trainloader`` and report the epoch metrics.

    Operates on the notebook-level ``net``, ``opt``, ``criterion``, ``trainloader``
    and ``DEVICE``.

    Args:
        epoch: Epoch index; only used in the progress message.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        top-1 accuracy in percent over all samples seen. ``(0.0, 0.0)`` is
        returned when the loader yields no batches.
    """
    print(f"{epoch = }")
    net.train()
    loss_sum = 0.0
    correct = total = batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        opt.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        opt.step()

        loss_sum += loss.item()
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        batches += 1

    # Guard the divisions so an empty loader reports zeros instead of raising.
    avg_loss = loss_sum / batches if batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    # Same layout as the summary line printed by test().
    print(f"Train Loss: {avg_loss: .3f} | Acc: {accuracy: .3f}")
    return avg_loss, accuracy
    


@torch.no_grad()
def test(epoch, chkpt_dir='/home/evan/Checkpoints'):
    global best_acc
    net.eval()
    test_loss = 0
    correct = total = 0
    for i, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    print(f"Test Loss: {test_loss / (i+1): .3f} | Acc: {100 * correct / total: .3f}")
    
    # Save checkpoint.
    acc = 100 * correct / total
    if acc > best_acc:
        print("Saving...")
        state = dict(
            net=net.state_dict(),
            acc=acc,
            epoch=epoch
        )
        if not os.path.isdir(chkpt_dir):
            os.mkdir(chkpt_dir)
        torch.save(state, chkpt_dir + '/chkpt.pth')
        best_acc = acc

In [None]:
# One training pass and one evaluation pass per epoch; the learning-rate
# schedule advances once after each epoch.
for epoch in range(20):
    train(epoch)
    test(epoch)
    scheduler.step()