In [None]:
import sys
sys.path.append('..')
import numpy as np
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from SVHN.models.allconv import AllConvNet
from SVHN.models.wrn import WideResNet
import utils.svhn_loader as svhn
from utils.validation_dataset import validation_split

In [None]:
# Experiment configuration: a plain dict standing in for command-line arguments.
args = {
    # Any truthy value holds out a validation split from the training data
    # and adds a 'calib_' tag to checkpoint / results file names.
    'calibration': '',
    'epochs': 20,
    # Used only as a prefix in checkpoint / results file names.
    'dataset': 'svhn',
    # SGD hyper-parameters.
    'learning_rate': 0.01,
    'momentum': 0.9,
    'decay': 0.0005,  # passed to SGD as weight_decay
    # Batch sizes for the training and test loaders.
    'batch_size': 128,
    'test_bs': 200,
    # 'allconv' selects AllConvNet; any other value selects WideResNet.
    'model': 'wrn',
    # WideResNet constructor arguments.
    'layers': 16,
    'widen_factor': 4,
    'droprate': 0.4,
    # Directory that receives checkpoints and the results CSV.
    'save': './SVHN/results/baseline',
    # Directory to resume a checkpoint from; '' disables resuming.
    'load': '',
    # Boolean flag. It previously held the string 'store_true' (an argparse
    # action name), which is truthy and therefore read as "enabled".
    'test': False,
    # Number of GPUs: >1 wraps the model in DataParallel, >0 moves it to CUDA.
    'ngpu': 1,
    # Passed to the DataLoaders as num_workers.
    'prefetch': 4,
}

In [4]:
# !wget http://ufldl.stanford.edu/housenumbers/extra_32x32.mat

In [None]:
# Directory handed to the SVHN dataset loader.
root = '../SVHN'

# Working copy of the configuration; epoch metrics are added to it during training.
state = dict(args)

# Fix the torch and numpy seeds before any data is loaded.
torch.manual_seed(1)
np.random.seed(1)

train_data = svhn.SVHN(
    root,
    split='train_and_extra',
    transform=trn.ToTensor(),
    download=False,
)
test_data = svhn.SVHN(
    root,
    split='test',
    transform=trn.ToTensor(),
    download=False,
)
num_classes = 10

# With calibration enabled, hold out a validation subset and tag output file names.
calib_indicator = 'calib_' if args['calibration'] else ''
if args['calibration']:
    train_data, val_data = validation_split(train_data, val_share=5000/604388.)

# Options shared by both loaders.
loader_options = dict(num_workers=args['prefetch'], pin_memory=True)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args['batch_size'],
    shuffle=True,
    **loader_options)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args['test_bs'],
    shuffle=False,
    **loader_options)

In [None]:
# Build the network named in the configuration (WideResNet unless 'allconv').
if args['model'] != 'allconv':
    net = WideResNet(args['layers'], num_classes, args['widen_factor'], dropRate=args['droprate'])
else:
    net = AllConvNet(num_classes)

start_epoch = 0

# Optionally resume: scan epoch indices from 999 down to 0 and load the first
# checkpoint file that exists in args['load'].
if args['load'] != '':
    checkpoint_stem = args['dataset'] + calib_indicator + '_' + args['model'] + '_baseline_epoch_'
    for i in reversed(range(1000)):
        model_name = os.path.join(args['load'], checkpoint_stem + str(i) + '.pt')
        if not os.path.isfile(model_name):
            continue
        net.load_state_dict(torch.load(model_name))
        print('Model restored! Epoch:', i)
        start_epoch = i + 1
        break
    # A resume was requested but no checkpoint was found.
    assert start_epoch != 0, "could not resume"

# Multi-GPU: replicate the model across the first `ngpu` devices.
gpu_count = args['ngpu']
if gpu_count > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(gpu_count)))

if gpu_count > 0:
    net.cuda()
    torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders

# SGD with Nesterov momentum; hyper-parameters come from the configuration copy.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=state['learning_rate'],
    momentum=state['momentum'],
    weight_decay=state['decay'],
    nesterov=True,
)


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Half-cosine interpolation from `lr_max` down to `lr_min`.

    Args:
        step: current step index.
        total_steps: step count at which the curve reaches `lr_min`.
        lr_max: value returned at `step == 0`.
        lr_min: value returned at `step == total_steps`.

    Returns:
        lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * step / total_steps)).
    """
    half_range = (lr_max - lr_min) * 0.5
    phase = step / total_steps * np.pi
    return lr_min + half_range * (1 + np.cos(phase))


def _lr_factor(step):
    """Multiplicative LR factor for `step`, following the cosine schedule.

    The schedule spans args['epochs'] * len(train_loader) steps. It starts at a
    factor of 1 and ends at 1e-6 / args['learning_rate'], i.e. a final learning
    rate of 1e-6 once multiplied by the base learning rate.
    """
    total_steps = args['epochs'] * len(train_loader)
    final_factor = 1e-6 / args['learning_rate']
    # lr_max is 1 because LambdaLR expects a multiplier, not an absolute rate.
    return cosine_annealing(step, total_steps, 1, final_factor)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)


# /////////////// Training ///////////////

def train():
    """Run one training pass over `train_loader`.

    For every batch: forward pass, cross-entropy loss, backward pass, optimizer
    update, then one step of the learning-rate scheduler. An exponential moving
    average of the batch loss (weight 0.2 on the newest batch) is stored in
    state['train_loss'] when the pass finishes.

    Reads the module-level `net`, `train_loader`, `optimizer`, `scheduler` and
    `state`. Batches are moved to CUDA unconditionally.
    """
    net.train()  # enter train mode
    loss_avg = 0.0
    for data, target in train_loader:
        data, target = data.cuda(), target.long().cuda()

        # forward
        logits = net(data)
        loss = F.cross_entropy(logits, target)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Step the schedule only after the weights have been updated. Stepping
        # it first makes PyTorch >= 1.1 skip the first value of the schedule
        # and emit a warning.
        scheduler.step()

        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2

    state['train_loss'] = loss_avg

# test function
def test():
    """Evaluate `net` on `test_loader` with gradients disabled.

    Writes two entries into the module-level `state`:
      * 'test_loss'     - mean of the per-batch cross-entropy losses
      * 'test_accuracy' - correct predictions divided by the dataset size

    Batches are moved to CUDA unconditionally.
    """
    net.eval()
    total_loss = 0.0
    num_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.cuda()
            target = target.long().cuda()

            # forward
            output = net(data)
            batch_loss = F.cross_entropy(output, target)

            # predicted class = index of the largest score along dim 1
            _, pred = output.max(1)
            num_correct += (pred == target).sum().item()

            # accumulate the per-batch loss
            total_loss += batch_loss.item()

    state['test_loss'] = total_loss / len(test_loader)
    state['test_accuracy'] = num_correct / len(test_loader.dataset)


# if args['test']:
#     test()
#     print(state)
#     exit()

# Make sure the output directory exists and really is a directory.
save_dir = args['save']
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
if not os.path.isdir(save_dir):
    raise Exception('%s is not a dir' % save_dir)

# Create (or truncate) the results log and write its column header.
results_name = args['dataset'] + calib_indicator + '_' + args['model'] + '_baseline_training_results.csv'
with open(os.path.join(save_dir, results_name), 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

print('Beginning Training\n')

# Main loop
for epoch in range(start_epoch, args['epochs']):
    state['epoch'] = epoch

    begin_epoch = time.time()

    train()
    test()

    # Save model
    torch.save(net.state_dict(),
               os.path.join(args['save'], args['dataset'] + calib_indicator + '_' + args['model'] +
                            '_baseline_epoch_' + str(epoch) + '.pt'))
    # Let us not waste space and delete the previous model
    prev_path = os.path.join(args['save'], args['dataset'] + calib_indicator + '_' + args['model'] +
                             '_baseline_epoch_' + str(epoch - 1) + '.pt')
    if os.path.exists(prev_path): os.remove(prev_path)

    # Show results

    with open(os.path.join(args['save'], args['dataset'] + calib_indicator + '_' + args['model'] + 
                                      '_baseline_training_results_epoch_20.csv'), 'a') as f:
        f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
            (epoch + 1),
            time.time() - begin_epoch,
            state['train_loss'],
            state['test_loss'],
            100 - 100. * state['test_accuracy'],
        ))

    # # print state with rounded decimals
    # print({k: round(v, 4) if isinstance(v, float) else v for k, v in state.items()})

    print('Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'.format(
        (epoch + 1),
        int(time.time() - begin_epoch),
        state['train_loss'],
        state['test_loss'],
        100 - 100. * state['test_accuracy'])
    )
