In [None]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import os
import sys
sys.path.append('..')
import pickle
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from models import resnet
from utils.tinyimages_80mn_loader import TinyImages
from utils.validation_dataset import validation_split
import utils.svhn_loader as svhn

In [None]:
# Hyper-parameters and paths for the OECC fine-tuning run, kept in one dict.
args = dict(
    dataset='svhn',
    model='ResNet34',
    calibration='',  # empty string is falsy: no validation split is carved out
    epochs=5,
    learning_rate=0.001,
    batch_size=128,  # in-distribution batch size
    oe_batch_size=256,  # outlier-exposure batch size
    test_bs=200,
    momentum=0.9,
    decay=0.0005,  # Weight decay (L2 penalty)
    save='./Mahalanobis_Experiments/results/Mahal_OECC_tune',
    # NOTE(review): 'store_true' is an argparse action name, not a value; as a
    # non-empty string it is truthy. Presumably a leftover from a CLI script —
    # confirm whether this key is still needed.
    test='store_true',
    ngpu=2,
    prefetch=4,  # passed to the data loaders as num_workers
    lambda_1=0.07,  # To push the known classes to higher prediction probabilities
    lambda_2=0.03,  # To push the outliers towards uniform
)

In [None]:
# Move the kernel's working directory one level up, so the relative paths used
# below ('SVHN', './Mahalanobis_Experiments/...') resolve against the parent
# directory. This is not idempotent: every re-run of this cell moves up one
# more level, so execute it exactly once per kernel session.
cd ..

In [None]:
root = 'SVHN'

# Mutable run-state record, seeded with a shallow copy of every hyper-parameter.
state = dict(args)

# Seed both the torch and the numpy global RNGs.
torch.manual_seed(1)
np.random.seed(1)

# The train and test splits go through the same deterministic preprocessing,
# so the pipeline is written once and shared.
# NOTE(review): these constants match the widely used CIFAR-10 channel
# statistics; presumably kept to match the preprocessing the pre-trained
# checkpoint was trained with — confirm.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)
in_dist_transform = trn.Compose([trn.ToTensor(), trn.Normalize(norm_mean, norm_std)])

train_data_in = svhn.SVHN(root, split='train', transform=in_dist_transform, download=False)
test_data = svhn.SVHN(root, split='test', transform=in_dist_transform, download=False)
num_classes = 10

# Optional calibration split: when args['calibration'] is truthy, a validation
# subset is carved out of the training data and output file names are tagged
# with 'calib_'.
calib_indicator = ''
if args['calibration']:
    # Hold out about 5000 samples. The share is derived from the actual dataset
    # size instead of a hard-coded total (previously 604388), so the held-out
    # count stays ~5000 whichever SVHN split is loaded.
    train_data_in, val_data = validation_split(
        train_data_in, val_share=5000 / len(train_data_in))
    calib_indicator = 'calib_'

# Outlier-exposure dataset with a random horizontal flip as augmentation.
# The ToTensor -> ToPILImage round trip presumably converts the raw samples
# into PIL images so the flip can be applied — TODO confirm against TinyImages.
# NOTE(review): the normalization constants here differ from the ones applied
# to the SVHN datasets built earlier in this cell
# ((0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010)) — confirm this is
# intentional.
oe_pipeline = [
    trn.ToTensor(),
    trn.ToPILImage(),
    trn.RandomHorizontalFlip(),
    trn.ToTensor(),
    trn.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
tiny_images = TinyImages(transform=trn.Compose(oe_pipeline))


# In-distribution training batches, reshuffled every epoch.
train_loader_in = torch.utils.data.DataLoader(
    train_data_in,
    batch_size=args['batch_size'], shuffle=True,
    num_workers=args['prefetch'], pin_memory=True)

# Outlier batches are read sequentially (shuffle=False); the training loop
# randomizes the starting point through the dataset's `offset` attribute.
train_loader_out = torch.utils.data.DataLoader(
    tiny_images,
    batch_size=args['oe_batch_size'], shuffle=False,
    num_workers=args['prefetch'], pin_memory=True)

# Evaluation uses the dedicated test batch size; previously this reused the
# training batch size and left the 'test_bs' setting unused.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args['test_bs'], shuffle=False,
    num_workers=args['prefetch'], pin_memory=True)

In [None]:
# Build the network and restore the pre-trained baseline weights.
net = resnet.ResNet34(num_classes)
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; the model is moved to the GPUs right below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net.load_state_dict(torch.load(
    './Mahalanobis_Experiments/pre_trained/resnet_svhn.pth', map_location='cpu'))

# Replicate the model across the configured number of GPUs (ids 0..ngpu-1)
# instead of a hard-coded [0, 1] list.
net = torch.nn.DataParallel(net, device_ids=list(range(args['ngpu']))).cuda()

# Input batches are sent to the first GPU.
device = torch.device('cuda:0')
cudnn.benchmark = True  # fire on all cylinders

# Optimizer: SGD with Nesterov momentum and L2 weight decay.
optimizer = torch.optim.SGD(
    net.parameters(), state['learning_rate'], momentum=state['momentum'],
    weight_decay=state['decay'], nesterov=True)

# Learning Rate
def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value for ``step`` out of ``total_steps``.

    The result equals ``lr_max`` at ``step == 0`` and falls along half a
    cosine period to ``lr_min`` at ``step == total_steps``.
    """
    progress = step / total_steps
    # Weight goes from 1 (start) to 0 (end) as progress goes from 0 to 1.
    cosine_weight = 0.5 * (1 + np.cos(progress * np.pi))
    return lr_min + (lr_max - lr_min) * cosine_weight


def _lr_factor(step):
    """Multiplicative LR factor for ``step``, annealed over the whole run."""
    total_steps = args['epochs'] * len(train_loader_in)
    # lr_lambda yields a multiplier of the base LR, so the maximum is 1 and the
    # floor is the absolute minimum LR (1e-6) divided by the base LR.
    min_factor = 1e-6 / args['learning_rate']
    return cosine_annealing(step, total_steps, 1, min_factor)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)

In [None]:
# /////////////// Training ///////////////
def train():
    """Run one epoch of OECC fine-tuning and record the smoothed loss.

    Each step pairs one in-distribution batch with one outlier batch, pushes
    the concatenation through ``net`` and minimizes:

    * cross-entropy on the in-distribution samples,
    * ``lambda_1`` * sum of squared gaps between the max softmax probability
      of each in-distribution sample and the baseline training accuracy,
    * ``lambda_2`` * sum of absolute gaps between every softmax probability of
      each outlier sample and the uniform value ``1 / num_classes``.

    The exponential moving average of the per-step loss is stored in
    ``state['train_loss']``. Nothing is returned.
    """
    net.train()  # enter train mode
    loss_avg = 0.0

    # Loop-invariant constants of the OECC loss.
    baseline_train_acc = 0.967  # A_tr: training accuracy of the baseline
    uniform_prob = 1 / num_classes

    # start at a random point of the outlier dataset; this induces more randomness without obliterating locality
    train_loader_out.dataset.offset = np.random.randint(len(train_loader_out.dataset))
    # zip stops as soon as the shorter of the two loaders is exhausted.
    for in_set, out_set in zip(train_loader_in, train_loader_out):
        num_in = len(in_set[0])
        data = torch.cat((in_set[0], out_set[0]), 0)
        target = in_set[1]

        data, target = data.to(device), target.long().to(device)

        # forward
        optimizer.zero_grad()
        logits = net(data)

        # Supervised term: only the first num_in rows have labels.
        loss = F.cross_entropy(logits[:num_in], target)

        ## OECC Loss Function
        probabilities = F.softmax(logits, dim=1)  # in- and outlier images alike
        max_probs, _ = torch.max(probabilities, dim=1)
        prob_diff_in = max_probs[:num_in] - baseline_train_acc
        loss += args['lambda_1'] * torch.sum(prob_diff_in ** 2)  ## 1st Regularization term
        prob_diff_out = probabilities[num_in:] - uniform_prob
        loss += args['lambda_2'] * torch.sum(torch.abs(prob_diff_out))  ## 2nd Regularization term

        # backward
        loss.backward()
        optimizer.step()
        # Advance the LR schedule only after the optimizer update; stepping the
        # scheduler first skips the first value of the schedule.
        scheduler.step()

        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2

    state['train_loss'] = loss_avg

# test function
def test():
    """Evaluate ``net`` on ``test_loader`` and record loss and accuracy.

    Writes the mean of the per-batch cross-entropy losses to
    ``state['test_loss']`` and the fraction of correctly classified samples to
    ``state['test_accuracy']``. Nothing is returned.
    """
    net.eval()
    total_loss = 0.0
    num_correct = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.long().to(device)

            logits = net(images)
            total_loss += F.cross_entropy(logits, labels).item()

            # Predicted class = index of the largest logit in each row.
            _, predicted = torch.max(logits, dim=1)
            num_correct += (predicted == labels).sum().item()

    state['test_loss'] = total_loss / len(test_loader)
    state['test_accuracy'] = num_correct / len(test_loader.dataset)


# Make save directory
save_dir = args['save']
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
if not os.path.isdir(save_dir):
    raise Exception('%s is not a dir' % save_dir)

# Every artifact of this run shares the "<dataset><calib>_<model>" stem.
run_stem = args['dataset'] + calib_indicator + '_' + args['model']
results_csv = os.path.join(save_dir, run_stem + '_Mahal_OECC_tune_training_results.csv')


def checkpoint_path(epoch_index):
    """Return the checkpoint file path for the given zero-based epoch index."""
    return os.path.join(
        save_dir, run_stem + '_Mahal_OECC_tune_epoch_' + str(epoch_index) + '.pth')


# Start a fresh results log; any existing file with this name is overwritten.
with open(results_csv, 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

print('Beginning Training\n')


# Main loop
for epoch in range(args['epochs']):
    state['epoch'] = epoch

    begin_epoch = time.time()

    train()
    test()

    # Save the current weights, then drop the previous epoch's checkpoint so
    # only the most recent one stays on disk.
    # NOTE(review): this saves net.state_dict() as-is; if net is wrapped in
    # DataParallel the keys carry a 'module.' prefix — confirm downstream
    # loaders expect that.
    torch.save(net.state_dict(), checkpoint_path(epoch))
    stale_checkpoint = checkpoint_path(epoch - 1)
    if os.path.exists(stale_checkpoint):
        os.remove(stale_checkpoint)

    # Show results
    test_error_pct = 100 - 100. * state['test_accuracy']

    with open(results_csv, 'a') as f:
        f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
            (epoch + 1),
            time.time() - begin_epoch,
            state['train_loss'],
            state['test_loss'],
            test_error_pct,
        ))

    print('Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'.format(
        (epoch + 1),
        int(time.time() - begin_epoch),
        state['train_loss'],
        state['test_loss'],
        test_error_pct)
    )