In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict
import pickle as pkl
import datetime
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from models.resnet import ResNet18

def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

batch_size = 8

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
    download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
    shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
    download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
    shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [2]:
def train_model(model, trainloader, epochs=10, lr=1e-3, momentum=.9):
    """
    Train a model on CIFAR-10 using the PyTorch Module API and prints model 
    accuracies during training.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Lists of validation accuracies at the end of each epoch.
    """
    device = get_device()
    model = model.to(device)  # move the model parameters to CPU/GPU
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for e in range(epochs):
        for t, (x, y) in enumerate(trainloader):
            model.train()  # put model to training mode
            x = x.to(device)  # move to device, e.g. GPU
            y = y.to(device)

            scores = model(x)
            loss = criterion(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def eval_model(model, testloader):
    device = get_device()

    correct_pred = defaultdict(int)
    total_pred = defaultdict(int)

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)

            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count)/total_pred[classname]
        print(f'Accuracy for class: {classname:5s} is accuracy {accuracy:.1f}%')

def train_eval(model, trainloader, epochs=10, lr=1e-3, momentum=.9):
    train_model(model, trainloader, epochs=epochs, lr=lr, momentum=momentum)
    eval_model(model, testloader)

def save_model(model, model_name):
    date_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    with open(f'/media/3tb/chet/saved_models/{model_name}_{date_str}.pkl', 'wb') as f:
        pkl.dump(model, f)

In [3]:
def sequential_head():
    return nn.Sequential(
        nn.Conv2d(3, 6, 5),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d(2,2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, 10)
    )

### Default Head ###

In [4]:
seq_net = sequential_head()
train_eval(seq_net, trainloader, epochs=10)


Accuracy for class: ship  is accuracy 70.3%
Accuracy for class: plane is accuracy 72.2%
Accuracy for class: frog  is accuracy 81.2%
Accuracy for class: cat   is accuracy 31.6%
Accuracy for class: car   is accuracy 78.6%
Accuracy for class: truck is accuracy 72.0%
Accuracy for class: dog   is accuracy 59.0%
Accuracy for class: horse is accuracy 68.8%
Accuracy for class: bird  is accuracy 50.9%
Accuracy for class: deer  is accuracy 50.9%


In [5]:
save_model(seq_net, 'sequential_cifar')

### ResNet18 ###

In [6]:
resnet = ResNet18()
train_eval(resnet, trainloader, epochs=10)

Accuracy for class: cat   is accuracy 68.7%
Accuracy for class: ship  is accuracy 89.0%
Accuracy for class: plane is accuracy 86.8%
Accuracy for class: frog  is accuracy 86.4%
Accuracy for class: car   is accuracy 92.9%
Accuracy for class: truck is accuracy 87.5%
Accuracy for class: dog   is accuracy 74.6%
Accuracy for class: horse is accuracy 85.2%
Accuracy for class: deer  is accuracy 80.5%
Accuracy for class: bird  is accuracy 68.2%


In [7]:
save_model(resnet, 'resnet_cifar')

### Timing ###

In [8]:
%timeit -n 5 -r 2 train_model(seq_net, trainloader, epochs=1)

17.5 s ± 55.3 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)


In [9]:
%timeit -n 5 -r 2 train_model(resnet, trainloader, epochs=1)

1min 24s ± 807 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)


### Smaller Datasets ###

In [10]:
class_idxs = np.vstack([np.nonzero(np.array(trainset.targets) == label)[0] for label in range(4)])

In [11]:
def cifar_subset(dataset_size):
    return torch.utils.data.Subset(trainset, class_idxs[:,:dataset_size//4].flatten())

In [12]:
trainset100 = cifar_subset(100)
trainloader100 = torch.utils.data.DataLoader(trainset100, batch_size=batch_size, shuffle=True, num_workers=2)

trainset1000 = cifar_subset(1000)
trainloader1000 = torch.utils.data.DataLoader(trainset1000, batch_size=batch_size, shuffle=True, num_workers=2)

trainset10000 = cifar_subset(10000)
trainloader10000 = torch.utils.data.DataLoader(trainset10000, batch_size=batch_size, shuffle=True, num_workers=2)

In [13]:
seq_net_100 = sequential_head()
train_eval(seq_net_100, trainloader100, epochs=10)
save_model(seq_net_100, 'sequential_cifar_100')

Accuracy for class: plane is accuracy 97.5%
Accuracy for class: car   is accuracy 11.5%
Accuracy for class: bird  is accuracy 0.3%


In [14]:
seq_net_1000 = sequential_head()
train_eval(seq_net_1000, trainloader1000, epochs=10)
save_model(seq_net_1000, 'sequential_cifar_1000')

Accuracy for class: cat   is accuracy 61.6%
Accuracy for class: plane is accuracy 60.7%
Accuracy for class: bird  is accuracy 46.8%
Accuracy for class: car   is accuracy 69.5%


In [15]:
seq_net_10000 = sequential_head()
train_eval(seq_net_10000, trainloader10000, epochs=10)
save_model(seq_net_10000, 'sequential_cifar_10000')

Accuracy for class: cat   is accuracy 72.3%
Accuracy for class: plane is accuracy 78.2%
Accuracy for class: car   is accuracy 85.9%
Accuracy for class: bird  is accuracy 71.0%
