In [3]:
'''
Loosely based upon training code in https://github.com/xternalz/WideResNet-pytorch
'''

'\nLoosely based upon training code in https://github.com/xternalz/WideResNet-pytorch\n'

In [7]:
import time

import torch
from torch.autograd import Variable

from average_meter import AverageMeter
from mlp import MLP

In [None]:
def train(train_loader, model, criterion, optimizer,
          performance_stats=None, epoch=1, total_epochs=-1,
          verbose=True, print_freq=10,
          tensorboard_log_function=None,
          tensorboard_stats=None):
    '''
    Trains the model for one epoch.

    x, y are input and target
    y_hat is the predicted output

    train_loader yields (x, y) batches; each batch is moved to the GPU,
    so a CUDA device is required.

    performance_stats is a dictionary of name:function pairs
    where the function calculates some performance score; it is
    called as function(y_hat, y) with y_hat detached from the graph.
    Defaults to no extra stats.

    epoch and total_epochs are only used for progress reporting.

    see the docs for the 'display_training_stats' function for
    info on verbose, print_freq, tensorboard_log_function, and
    tensorboard_stats (which defaults to ['train_loss'] here)

    Returns the dictionary of name:AverageMeter pairs collected over
    the epoch ('batch_time', 'train_loss' and one entry per
    performance stat).
    '''

    # None sentinels instead of mutable default arguments
    if performance_stats is None:
        performance_stats = {}
    if tensorboard_stats is None:
        tensorboard_stats = ['train_loss']

    stats = {'batch_time': AverageMeter(), 'train_loss': AverageMeter()}
    stats.update({name: AverageMeter() for name in performance_stats})

    # enter training mode
    model.train()

    # begin timing the epoch
    stopwatch = time.time()

    # iterate over the batches of the epoch
    for i, (x, y) in enumerate(train_loader):
        # 'async' is a reserved word since Python 3.7; the keyword
        # is now called non_blocking
        y = y.cuda(non_blocking=True)
        x = x.cuda()

        # forward pass
        y_hat = model(x)
        loss = criterion(y_hat, y)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # track loss and performance stats
        batch_size = x.size(0)
        stats['train_loss'].update(loss.item(), batch_size)
        predictions = y_hat.detach()
        for stat_name, stat_func in performance_stats.items():
            stats[stat_name].update(stat_func(predictions, y), batch_size)

        # track batch time
        stats['batch_time'].update(time.time() - stopwatch)
        stopwatch = time.time()

        # display progress
        display_training_stats('training', stats, i, len(train_loader),
                               epoch, total_epochs, print_results=verbose,
                               print_freq=print_freq,
                               tensorboard_log_function=tensorboard_log_function,
                               tensorboard_stats=tensorboard_stats)

    return stats

In [None]:
def validate(val_loader, model, criterion, epoch,
             performance_stats=None, verbose=True, print_freq=10,
             tensorboard_log_function=None,
             total_epochs=-1, tensorboard_stats=None):
    '''
    Evaluates the model on the validation set.

    x, y are input and target
    y_hat is the predicted output

    val_loader yields (x, y) batches; each batch is moved to the GPU,
    so a CUDA device is required.

    performance_stats is a dictionary of name:function pairs
    where the function calculates some performance score; it is
    called as function(y_hat, y). Defaults to no extra stats.

    epoch and total_epochs are only used for progress reporting.

    see the docs for the 'display_training_stats' function for
    info on verbose, print_freq, tensorboard_log_function, and
    tensorboard_stats (which defaults to ['val_loss'] here)

    Returns the dictionary of name:AverageMeter pairs collected over
    the validation set ('batch_time', 'val_loss' and one entry per
    performance stat).
    '''

    # None sentinels instead of mutable default arguments
    if performance_stats is None:
        performance_stats = {}
    if tensorboard_stats is None:
        tensorboard_stats = ['val_loss']

    stats = {'batch_time': AverageMeter(), 'val_loss': AverageMeter()}
    stats.update({name: AverageMeter() for name in performance_stats})

    # enter evaluation mode
    model.eval()

    # begin timing
    stopwatch = time.time()

    # no gradients are needed for evaluation
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            y = y.cuda(non_blocking=True)
            x = x.cuda()

            # forward pass only
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # track loss and performance stats
            batch_size = x.size(0)
            stats['val_loss'].update(loss.item(), batch_size)
            for stat_name, stat_func in performance_stats.items():
                stats[stat_name].update(stat_func(y_hat, y), batch_size)

            # track batch time
            stats['batch_time'].update(time.time() - stopwatch)
            stopwatch = time.time()

            # display progress
            display_training_stats('validation', stats, i, len(val_loader),
                                   epoch, total_epochs, print_results=verbose,
                                   print_freq=print_freq,
                                   tensorboard_log_function=tensorboard_log_function,
                                   tensorboard_stats=tensorboard_stats)

    return stats
    
    

In [None]:
def display_training_stats(phase, stats, batch, ttl_batches,
                           epoch=1, ttl_epochs=1, print_results=True,
                           print_freq=10, tensorboard_log_function=None,
                           tensorboard_stats=None):
    '''
    Handles the logging of training and validation statistics to
    printed output and tensorboard.

    phase is a string, typically 'training' or 'validation'

    stats is a dictionary of name:meter pairs; each meter must expose
    a numeric 'mean' attribute

    batch, ttl_batches, epoch, and ttl_epochs are integers indicating
    the current epoch & batch and the total number of epochs and batches

    if print_results is True, prints results every print_freq batches
    (a print_freq of zero or less disables printing)

    if tensorboard_log_function is provided, all stats whose names are
    passed in the tensorboard_stats list will be logged to tensorboard
    by calling tensorboard_log_function(name, mean, epoch); a name that
    is missing from stats raises KeyError
    '''

    if print_results and print_freq > 0 and batch % print_freq == 0:
        # NOTE(review): the mean is printed twice; the first slot was
        # presumably meant for the latest batch value - confirm which
        # attribute AverageMeter exposes for that.
        stats_report = '\t'.join(
            '{name} {meter.mean:.4f} ({meter.mean:.4f})'.format(
                name=name, meter=meter)
            for name, meter in stats.items())
        # '{{' and '}}' are literal braces, so the phase prints as {phase}
        header = ('{{{phase}}} epoch: [{epoch}/{ttl_epochs}] '
                  'batch [{batch}/{ttl_batches}]\n').format(
                      phase=phase, epoch=epoch, ttl_epochs=ttl_epochs,
                      batch=batch, ttl_batches=ttl_batches)
        print(header + stats_report)

    if tensorboard_log_function is not None:
        for name in tensorboard_stats or ():
            tensorboard_log_function(name, stats[name].mean, epoch)