Pytorch Alexnet Example
=======================

This is a complete example of training an AlexNet in PyTorch, entirely within a notebook, using nothing but widely used library functions.

Warning: this notebook downloads a full large-scale dataset (Places365). It is too large to handle practically on Google Colab, so you need to run this notebook on your own server.

In [None]:
import torch, torchvision, os

def train_alexnet_places(num_steps=100000):
    '''
    Builds a fresh AlexNet for Places365 and trains it for num_steps SGD steps.

    Validation, progress printing and checkpointing (into 'checkpoints') are
    handled by the monitor returned from make_checkpointing_function.
    Returns the trained model.
    '''
    print("Making alexnet...")
    net = make_untrained_alexnet_places()
    net.train()
    print("Loading datasets...")
    loaders = get_train_and_val_data_loaders()
    train_loader, val_loader = loaders
    print("Training classifier...")
    sgd_settings = dict(momentum=0.9, init_lr=2e-2, weight_decay=5e-4)
    train_classifier(
        net,
        train_loader,
        max_iter=num_steps,
        monitor=make_checkpointing_function(val_loader, checkpoint_dir='checkpoints'),
        **sgd_settings)
    return net

Untrained Alexnet
-----------------

This function creates an untrained AlexNet with randomly initialized parameters: Kaiming-normal weights and zero biases.

In [None]:
from torch import nn
from collections import OrderedDict
def make_untrained_alexnet_places(device='cuda', num_classes=365):
    '''
    Builds an AlexNet with freshly randomized parameters.

    Args:
        device: where to place the model. Defaults to 'cuda' (the original
            behavior); pass 'cpu' to build the network without a GPU.
        num_classes: output width of the final fc8 layer (365 for Places365).

    Returns:
        An nn.Sequential in training mode. Every weight tensor gets
        Kaiming-normal init (ReLU gain) and every bias is zero. conv1 has no
        padding, and there are three 3x3/stride-2 max-pools, so a 227x227
        input arrives at pool5 as 6x6. fc6's input size assumes that.
    '''
    # channel widths
    w = [3, 96, 256, 384, 384, 256, 4096, 4096, num_classes]
    # Alexnet splits channels into groups
    groups = [1, 2, 1, 2, 2]
    model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11,
            stride=4,
            groups=groups[0], bias=True)),
        ('relu1', nn.ReLU(inplace=True)),
        ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2,
            groups=groups[1], bias=True)),
        ('relu2', nn.ReLU(inplace=True)),
        ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1,
            groups=groups[2], bias=True)),
        ('relu3', nn.ReLU(inplace=True)),
        ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1,
            groups=groups[3], bias=True)),
        ('relu4', nn.ReLU(inplace=True)),
        ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1,
            groups=groups[4], bias=True)),
        ('relu5', nn.ReLU(inplace=True)),
        ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('flatten', nn.Flatten()),
        ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)),
        ('relu6', nn.ReLU(inplace=True)),
        ('dropout6', nn.Dropout()),
        ('fc7', nn.Linear(w[6], w[7], bias=True)),
        ('relu7', nn.ReLU(inplace=True)),
        ('dropout7', nn.Dropout()),
        ('fc8', nn.Linear(w[7], w[8]))
    ]))
    # Setup the initial parameters randomly
    for n, p in model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    # Default 'cuda' matches the previous hard-coded model.cuda().
    model.to(device)
    model.train()
    return model

We can call the function to make a network, and then list all the network's trainable parameters.

In [None]:
# Build a network and list each trainable parameter with its shape.
a = make_untrained_alexnet_places()
for param_name, param in a.named_parameters():
    shape = tuple(param.shape)
    print(param_name, shape)

And we can save the untrained (randomly initialized) network.

In [None]:
# torch.save does not create parent directories, and 'checkpoints' may not
# exist yet on a fresh run.
os.makedirs('checkpoints', exist_ok=True)
# Keep the randomly initialized (untrained) weights as a reference point.
torch.save(a.state_dict(), 'checkpoints/uninitialized_alexnet.pth')

Main Training Loop
------------------

This is a generic training loop for a classifier.

In [None]:
def train_classifier(model, train_data_loader, max_iter,
                     momentum=0.9, init_lr=2e-2, weight_decay=5e-4,
                     monitor=None):
    '''
    Trains a classifier with SGD and a one-cycle learning-rate schedule.

    Args:
        model: network to train. Batches are moved to the device of its
            first parameter.
        train_data_loader: iterable of (input, target) batches. It is
            re-iterated as many times as needed to reach max_iter steps.
        max_iter: total number of optimizer steps.
        momentum, init_lr, weight_decay: SGD settings. init_lr is also the
            peak learning rate of the OneCycleLR schedule.
        monitor: optional callback, called as
            monitor(model, iter_num, loss, accuracy, batch_size).
            It is called once before training (iter 0, zero stats) and after
            every step. loss is a Python float; accuracy is the fraction of
            the training batch classified correctly.

    Raises:
        ValueError: if train_data_loader yields no batches. Without this
            check the outer loop would never terminate.
    '''
    device = next(model.parameters()).device
    model.train()
    if monitor is not None:
        monitor(model, 0, 0.0, 0.0, 0)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=init_lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, init_lr, max_iter)
    iter_num = 0
    while iter_num < max_iter:
        steps_this_pass = 0
        for t_input, t_target in train_data_loader:
            # Copy data to wherever the model lives
            input_var = t_input.to(device)
            target_var = t_target.to(device)
            # Evaluate model
            output = model(input_var)
            loss = torch.nn.functional.cross_entropy(output, target_var)
            # Perform one step of SGD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # Learning rate schedule
            steps_this_pass += 1
            # Check training set accuracy
            _, pred = output.detach().max(1)
            batch_size = len(t_input)
            accuracy = target_var.eq(pred).float().sum().item() / batch_size
            # Advance, and report some stats
            iter_num += 1
            if monitor is not None:
                # Hand over a plain float: a grad-tracking loss tensor kept
                # in a running sum would keep extending the autograd graph.
                monitor(model, iter_num, loss.item(), accuracy, batch_size)
            if iter_num >= max_iter:
                break
        if steps_this_pass == 0:
            raise ValueError('train_data_loader yielded no batches')

Data set
--------

This is the definition of the Places365 data set used for training and validation.
If we do not have the files, we link or download them, and then we make a
Dataset object that defines how to resize, crop, and normalize the images.

The DataLoader objects wrap each dataset in a streaming loader that uses
multiple worker processes to load and batch the image data quickly.

In [None]:
def get_places_data_set(split, crop_size=227, download=True):
    '''
    Returns an ImageFolder dataset for one Places365 split ('train' or 'val').

    If datasets/places/{split} is missing and download is True, the data is
    provided in one of two ways:
      - by symlinking the CSAIL NFS copy, when that path is visible, or
      - by downloading and extracting places_{split}.zip from the web mirror.

    Images are resized so the shorter side is 256, then cropped to
    crop_size: a random crop plus horizontal flip for 'train', a center
    crop otherwise. Finally they are normalized with the given per-channel
    mean/std.
    '''
    dirname = f'datasets/places/{split}'
    nfs_source = '/data/vision/torralba/datasets/places/files'
    web_source = 'https://dissect.csail.mit.edu/datasets/'
    if not os.path.exists(dirname) and download:
        if os.path.exists(nfs_source):
            # os.symlink does not create the parent directory.
            os.makedirs('datasets', exist_ok=True)
            os.symlink(nfs_source, 'datasets/places')
        else:
            os.makedirs(dirname, exist_ok=True)
            torchvision.datasets.utils.download_and_extract_archive(
                # Use the mirror URL variable, not a string literal of its name.
                f'{web_source}places_{split}.zip',
                'datasets',
                md5=dict(val='593bbc21590cf7c396faac2e600cd30c',
                         train='d1db6ad3fc1d69b94da325ac08886a01')[split])
    if split == 'train':
        cropping_rule = [
            torchvision.transforms.RandomCrop(crop_size),
            torchvision.transforms.RandomHorizontalFlip() ]
    else:
        cropping_rule = [torchvision.transforms.CenterCrop(crop_size)]
    places_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256)
        ] + cropping_rule + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return torchvision.datasets.ImageFolder(
        dirname, transform=places_transform)

def get_train_and_val_data_loaders(batch_size=256, num_workers=48):
    '''
    Returns [train_loader, val_loader] over the Places365 splits.

    Only the training loader shuffles. Both loaders use pinned memory.
    batch_size and num_workers default to the previously hard-coded values
    (256 and 48). Lower num_workers on machines with fewer CPU cores.
    '''
    return [
        torch.utils.data.DataLoader(
            get_places_data_set(split),
            batch_size=batch_size, shuffle=(split == 'train'),
            num_workers=num_workers, pin_memory=True)
        for split in ['train', 'val']
    ]

Generic Evaluation and Checkpointing Utilities
----------------------------------------------

 * **measure_val_accuracy_and_loss** evaluates the model on the holdout set and reports its performance.
 * **save_model_iteration** saves the current model parameters in a pytorch file.
 * **make_checkpointing_function** makes a callback function for periodically evaluating and saving a model during training.
 * **AverageMeter** tracks averages (e.g., average accuracy, average loss).

In [None]:
def measure_val_accuracy_and_loss(model, val_data_loader):
    '''
    Evaluates the model (in inference mode) on holdout data.

    The model is switched to eval mode and left there. Callers that continue
    training must call model.train() afterwards.

    Returns:
        (val_acc, val_loss): AverageMeter objects. Their .avg is the
        sample-weighted mean top-1 accuracy and the mean cross-entropy loss.
    '''
    model.eval()
    # Move batches to wherever the model lives rather than assuming CUDA.
    device = next(model.parameters()).device
    val_loss, val_acc = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for images, targets in val_data_loader:
            images, targets = images.to(device), targets.to(device)
            output = model(images)
            loss = torch.nn.functional.cross_entropy(output, targets)
            _, pred = output.max(1)
            n = images.size(0)
            accuracy = targets.eq(pred).float().sum().item() / n
            # cross_entropy returns the batch mean, so weight it by batch size.
            val_acc.update(accuracy, n)
            val_loss.update(loss.item(), n)
    return val_acc, val_loss

def save_model_iteration(model, iter_num, checkpoint_dir):
    '''
    Saves the current parameters of the model to a file.

    The state_dict is written to {checkpoint_dir}/iter_{iter_num}.pth.
    checkpoint_dir is created first if needed, because torch.save does not
    create parent directories.
    '''
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'iter_{iter_num}.pth'))
        
def make_checkpointing_function(val_data_loader, checkpoint_dir=None, checkpoint_freq=100):
    '''
    Makes a callback to monitor training and make checkpoints.

    The returned monitor(model, iter_num, loss, accuracy, batch_size) keeps
    batch-size-weighted running averages of training accuracy and loss.
    These averages are cumulative since the callback was created.
    Whenever iter_num is a multiple of checkpoint_freq (including 0) it:
      - evaluates on val_data_loader,
      - saves a checkpoint if checkpoint_dir is not None,
      - prints a summary line,
      - puts the model back in training mode.
    '''
    avg_train_accuracy, avg_train_loss = AverageMeter(), AverageMeter()
    def monitor(model, iter_num, loss, accuracy, batch_size):
        # float() accepts Python numbers and 0-dim tensors alike. If a
        # grad-tracking tensor were accumulated as-is, the running sum would
        # keep extending an autograd graph on every call.
        avg_train_accuracy.update(float(accuracy), batch_size)
        avg_train_loss.update(float(loss), batch_size)
        if iter_num % checkpoint_freq == 0:
            val_accuracy, val_loss = measure_val_accuracy_and_loss(model, val_data_loader)
            if checkpoint_dir is not None:
                save_model_iteration(model, iter_num, checkpoint_dir)
            print(f'Iter {iter_num}, ' +
                  f'train acc {avg_train_accuracy.avg:.3g} loss {avg_train_loss.avg:.3g}, ' +
                  f'val acc {val_accuracy.avg:.3g}, loss {val_loss.avg:.3g}')
            # Validation left the model in eval mode.
            model.train()
    return monitor
            
class AverageMeter(object):
    '''
    Tracks the latest value and a weighted running mean of a stream.

    After update(val, n):
      - val is the most recent value,
      - sum is the weighted total,
      - count is the total weight,
      - avg is sum / count (unchanged while count is zero).
    '''
    def __init__(self):
        self.reset()

    def reset(self):
        '''Forget all recorded history.'''
        self.val = self.avg = self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        '''Record val with weight n.'''
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count == 0:
            return
        self.avg = self.sum / self.count

Now do the work
---------------

Try loading alexnet from a checkpoint.  If we have not yet saved a checkpoint snapshot with the number of iterations we want, then train it. 

In [None]:
num_iterations = 100
checkpoint_path = f'checkpoints/iter_{num_iterations}.pth'
# Only a missing checkpoint should trigger training. Other failures, such as
# a corrupt file, an architecture mismatch, a CUDA error or Ctrl-C, should
# surface rather than being swallowed by a bare except and silently retraining.
try:
    saved_state = torch.load(checkpoint_path)
except FileNotFoundError:
    a = train_alexnet_places(num_iterations)
else:
    a = make_untrained_alexnet_places()
    a.load_state_dict(saved_state)

Now view one validation image, reversing the dataset normalization so that it displays with natural colors.

In [None]:
dsv = get_places_data_set('val')
im, label = dsv[5000]
# Keep the image on the GPU; the prediction cell feeds it to the network.
im = im.cuda()
# Reverse the normalization: Normalize computed (x - mean) / std per channel,
# so x = x_norm * std + mean. permute to channel-last layout for imshow.
channel_mean = torch.tensor([0.485, 0.456, 0.406])
channel_std = torch.tensor([0.229, 0.224, 0.225])
# Clamp because float round-off can push values slightly outside [0, 1],
# which makes imshow clip and emit a warning.
unnormalized = (im.cpu().permute(1, 2, 0) * channel_std + channel_mean).clamp(0, 1)

from matplotlib import pyplot as plt
plt.imshow(unnormalized)
plt.axis('off')
plt.show()

Finally, run the network on the image and print its prediction alongside the ground-truth label.

Note that the network expects to work on batches, so `im[None]` forms a batch containing a single image.

In [None]:
a.eval()
# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    output = a(im[None])
# Index of the highest-scoring class for the single image in the batch,
# converted to a Python int for list indexing.
pred = output.max(1)[1][0].item()

print('prediction: ', dsv.classes[pred])
print('groundtruth: ', dsv.classes[label])