PyTorch AlexNet Example
=======================

This is a complete example of training AlexNet in PyTorch, entirely within a notebook, using nothing but widely used library functions.

In [None]:
import torch, torchvision, os

def train_alexnet_places():
    '''
    Trains an AlexNet classifier on the places data set from scratch.

    Builds the untrained network, places it on the GPU when one is
    available, and runs the generic training loop for 100000 iterations
    with a monitor that evaluates on the validation split and writes
    checkpoints under 'checkpoints'.

    Returns:
        The trained model.
    '''
    alexnet = make_untrained_alexnet_places()
    # The model must sit on the same device as the batches it is fed;
    # use the GPU when one is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    alexnet.to(device)
    alexnet.train()
    train_loader, val_loader = get_train_and_val_data_loaders()
    checkpointer = make_checkpointing_function(val_loader, checkpoint_dir='checkpoints')
    train_classifier(alexnet, train_loader,
                     max_iter=100000,
                     momentum=0.9,
                     init_lr=2e-2,
                     weight_decay=5e-4,
                     monitor=checkpointer)
    return alexnet

Untrained Alexnet
-----------------

This function creates an untrained AlexNet, with randomly initialized weights and zeroed biases.

In [None]:
def make_untrained_alexnet_places():
    '''
    Builds a 365-class AlexNet and re-initializes all of its parameters.

    Parameters whose name contains 'bias' are set to zero; every other
    parameter is drawn from a Kaiming-normal distribution with ReLU gain.
    '''
    net = torchvision.models.alexnet(num_classes=365)
    for name, param in net.named_parameters():
        if 'bias' not in name:
            # Weight tensors: random Kaiming initialization.
            torch.nn.init.kaiming_normal_(param, nonlinearity='relu')
            continue
        # Bias vectors start at zero.
        torch.nn.init.zeros_(param)
    return net

Main Training Loop
------------------

This is a generic training loop for a classifier.

In [None]:
def train_classifier(model, train_data_loader, max_iter,
                     momentum=0.9, init_lr=2e-2, weight_decay=5e-4,
                     monitor=None):
    '''
    Trains a classifier with SGD and a one-cycle learning rate schedule.

    Batches are moved to whatever device the model's parameters are on.
    This function does not change the model's train/eval mode itself.

    Args:
        model: the torch.nn.Module to train.
        train_data_loader: iterable of (input, target) batches; it is
            iterated repeatedly until max_iter steps have been taken.
        max_iter: total number of SGD steps to take.
        momentum: SGD momentum.
        init_lr: learning rate given to SGD and used as the peak
            (max_lr) of the one-cycle schedule.
        weight_decay: weight decay given to SGD.
        monitor: optional callback. It is called once before training as
            monitor(model, 0, 0.0, 0.0, 0) and after every step as
            monitor(model, iter_num, loss, accuracy, batch_size), where
            loss is a plain Python float.
    '''
    if monitor is not None:
        monitor(model, 0, 0.0, 0.0, 0)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=init_lr, momentum=momentum, weight_decay=weight_decay)
    # The scheduler is stepped once per iteration, max_iter times in all.
    # It raises if stepped more often than total_steps, so the schedule
    # has to span max_iter steps (not max_iter - 1).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=init_lr, total_steps=max_iter)
    # Feed the data to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    iter_num = 0
    while iter_num < max_iter:
        for t_input, t_target in train_data_loader:
            # Copy data onto the model's device
            input_var, target_var = [d.to(device) for d in [t_input, t_target]]
            # Evaluate model
            output = model(input_var)
            loss = torch.nn.functional.cross_entropy(output, target_var)
            # Perform one step of SGD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step() # Learning rate schedule
            # Check training set accuracy
            _, pred = output.max(1)
            batch_size = len(t_input)
            accuracy = target_var.eq(pred).float().sum().item() / batch_size
            # Advance, and report some stats
            iter_num += 1
            if monitor is not None:
                # Hand over a float, not the tensor, so the monitor cannot
                # keep the autograd history alive by accumulating it.
                monitor(model, iter_num, loss.item(), accuracy, batch_size)
            if iter_num >= max_iter:
                break

Data set
--------

This is the definition of the places data set used for training and validation.
If we do not have the files, we download them.  Then we make a
Dataset object that defines how to resize, crop, and normalize the images.

The DataLoader objects wrap the dataset in a streaming object that uses
multiple worker processes to batch the image data and load it quickly.

In [None]:
def get_places_data_set(split, crop_size=227, download=True):
    '''
    Returns the places data set for one split as an ImageFolder.

    Images are resized to 256, cropped to crop_size, converted to tensors
    and normalized. The 'train' split uses a random crop plus a random
    horizontal flip; any other split uses a center crop.

    Args:
        split: 'train' or 'val'; selects the directory
            datasets/places/{split} and the archive to download.
        crop_size: side length of the crop applied to every image.
        download: when True and the split directory does not exist,
            the archive is downloaded and extracted under 'datasets'.
    '''
    dirname = f'datasets/places/{split}'
    if not os.path.exists(dirname) and download:
        # NOTE(review): the directory is created before the download runs,
        # so a failed download leaves it behind and the next call will skip
        # downloading - confirm whether that is intended.
        os.makedirs(dirname, exist_ok=True)
        torchvision.datasets.utils.download_and_extract_archive(
            'https://dissect.csail.mit.edu/datasets/' +
            'places_%s.zip' % split,
            'datasets',
            md5=dict(val='593bbc21590cf7c396faac2e600cd30c',
                     train='d1db6ad3fc1d69b94da325ac08886a01')[split])
    if split == 'train':
        # Honor crop_size for training as well, so both splits produce
        # images of the same requested size.
        cropping_rule = [
            torchvision.transforms.RandomCrop(crop_size),
            torchvision.transforms.RandomHorizontalFlip() ]
    else:
        cropping_rule = [torchvision.transforms.CenterCrop(crop_size)]
    places_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256)
        ] + cropping_rule + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return torchvision.datasets.ImageFolder(
        dirname, transform=places_transform)

def get_train_and_val_data_loaders(batch_size=256, num_workers=48):
    '''
    Returns [train_loader, val_loader] for the places data set.

    Only the training loader shuffles; both use pinned memory.

    Args:
        batch_size: number of images per batch for both loaders.
        num_workers: number of loader worker processes for each loader;
            lower this on machines with fewer cores.
    '''
    return [
        torch.utils.data.DataLoader(
            get_places_data_set(split),
            batch_size=batch_size, shuffle=(split == 'train'),
            num_workers=num_workers, pin_memory=True)
        for split in ['train', 'val']
    ]

Generic Evaluation and Checkpointing Utilities
----------------------------------------------

 * **measure_val_accuracy_and_loss** evaluates the model on the holdout set and reports its performance.
 * **save_model_iteration** saves the current model parameters in a pytorch file.
 * **make_checkpointing_function** makes a callback function for periodically evaluating and saving a model during training.
 * **AverageMeter** tracks averages (e.g., average accuracy, average loss).

In [None]:
def measure_val_accuracy_and_loss(model, val_data_loader):
    '''
    Evaluates the model (in inference mode) on holdout data.

    The model is switched to eval mode and left there; a caller that goes
    on training has to call model.train() again. Batches are moved to the
    device the model's parameters are on.

    Args:
        model: the classifier to evaluate.
        val_data_loader: iterable of (input, target) batches.

    Returns:
        (val_acc, val_loss): two AverageMeter objects holding the
        per-example average accuracy and cross-entropy loss, in the
        order given by the function name.
    '''
    model.eval()
    val_loss, val_acc = AverageMeter(), AverageMeter()
    device = next(model.parameters()).device
    with torch.no_grad():
        for batch_input, target in val_data_loader:
            input_var, target_var = [d.to(device) for d in [batch_input, target]]
            output = model(input_var)
            loss = torch.nn.functional.cross_entropy(output, target_var)
            _, pred = output.max(1)
            batch_size = batch_input.size(0)
            accuracy = target_var.eq(pred).float().sum().item() / batch_size
            # Weight each batch by its size so a short final batch does
            # not skew the averages.
            val_loss.update(loss.item(), batch_size)
            val_acc.update(accuracy, batch_size)
    return val_acc, val_loss

def save_model_iteration(model, iter_num, checkpoint_dir):
    '''
    Saves the current parameters of the model to a file.

    The state dict is written to {checkpoint_dir}/iter_{iter_num}.pth.
    The directory is created first if it does not exist yet.
    '''
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'iter_{iter_num}.pth'))
        
def make_checkpointing_function(val_data_loader, checkpoint_dir=None, checkpoint_freq=1000):
    '''
    Makes a callback to monitor training and make checkpoints.

    Args:
        val_data_loader: holdout data passed to
            measure_val_accuracy_and_loss at every checkpoint.
        checkpoint_dir: directory for saved parameters, or None to skip
            saving.
        checkpoint_freq: evaluate, save and print whenever iter_num is a
            multiple of this number (including iteration 0).

    Returns:
        monitor(model, iter_num, loss, accuracy, batch_size): a callback
        that keeps running averages of training loss and accuracy over
        all calls so far and does the periodic checkpointing.
    '''
    avg_train_accuracy, avg_train_loss = AverageMeter(), AverageMeter()
    def monitor(model, iter_num, loss, accuracy, batch_size):
        # The pre-training call reports a batch size of 0; it carries no
        # statistics, and averaging with zero weight would divide by zero.
        if batch_size > 0:
            avg_train_accuracy.update(accuracy, batch_size)
            # float() accepts both numbers and scalar tensors and makes
            # sure no autograd history is accumulated in the meter.
            avg_train_loss.update(float(loss), batch_size)
        if iter_num % checkpoint_freq == 0:
            val_accuracy, val_loss = measure_val_accuracy_and_loss(model, val_data_loader)
            if checkpoint_dir is not None:
                save_model_iteration(model, iter_num, checkpoint_dir)
            print(f'Iter {iter_num}, ' + 
                  f'train acc {avg_train_accuracy.avg} loss {avg_train_loss.avg}, ' +
                  f'val acc {val_accuracy.avg}, loss {val_loss.avg}')
            # Evaluation switched the model to eval mode; resume training mode.
            model.train()
    return monitor            
            
class AverageMeter(object):
    '''
    To keep running averages.

    Attributes:
        val: the most recent value passed to update.
        sum: the weighted sum of all values seen since the last reset.
        count: the total weight seen since the last reset.
        avg: sum / count, or 0 while count is still 0.
    '''
    def __init__(self):
        self.reset()
    def reset(self):
        '''
        Clears all statistics.
        '''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        '''
        Adds a value that stands for n observations (e.g., a batch mean).
        '''
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on an empty meter must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0

Now do the work
---------------

In [None]:
# Run the whole pipeline: build the model, load the data, and train with
# periodic validation and checkpointing. The trained model is returned.
train_alexnet_places()