In [1]:
#Expand notebook to take full screen width
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

#Jupyter magic to notify when a cell finishes execution with %%notify command -- does not work with Jupyterlab
import jupyternotify
ip = get_ipython()
ip.register_magics(jupyternotify.JupyterNotifyMagics)

###
import sys
sys.path.insert(0,'../src/')

%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [7]:
#execution example: python retrain.py --model SWSModel --alpha 2500 --beta 10 --tau 1e-6 --mixtures 8 --temp 10
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.nn.modules import Module
from torch.autograd import Variable
import numpy as np

import model_archs
from utils_plot import show_sws_weights, show_weights, print_dims, prune_plot, draw_sws_graphs, joint_plot
from utils_model import test_accuracy, train_epoch, retrain_sws_epoch, model_prune, get_weight_penalty
from utils_misc import trueAfterN, logsumexp, root_dir, model_load_dir
from utils_sws import GaussianMixturePrior, special_flatten, KL, compute_responsibilies, merger, sws_prune, sws_prune_l2
from mnist_loader import search_train_data, search_retrain_data, search_validation_data, train_data, test_data, batch_size
import copy
import pickle
import argparse
# Number of SWS retraining epochs run by retrain_model(); also recorded in exp_name as "r{}".
retraining_epochs = 50

def retrain_model(alpha, beta, tau, temp, mixtures, model_name, data_size, lr, model_save_dir = "", scaling = False):
    """Retrain a pretrained MNIST model under a Gaussian-mixture prior, prune it, and report metrics.

    Loads ``mnist_{model_name}_100_{data_size}.m`` from ``model_load_dir``. It then trains the
    model jointly with a ``GaussianMixturePrior`` for ``retraining_epochs`` epochs. A deep copy
    is pruned with ``sws_prune_l2``. The function prints test accuracy (and, for the search
    split, validation accuracy) before and after pruning, plus the percentage of zero entries
    in the pruned state dict.

    Args:
        alpha, beta: passed to ``GaussianMixturePrior`` as ``ab=(alpha, beta)``.
        tau: forwarded to ``retrain_sws_epoch`` (presumably the weight of the prior term).
        temp: 0 trains with ``CrossEntropyLoss`` on the labelled loader. Otherwise the model
            is fit with ``MSELoss`` to ``softmax(outputs / temp)``, where ``outputs`` are the
            saved ``fc2`` outputs of the corresponding "full" model, and the loss multiplier
            passed to ``retrain_sws_epoch`` is ``temp ** 2``.
        mixtures: number of mixture components in the prior.
        model_name: architecture name used in the pretrained model's filename.
        data_size: ``'search'`` (search retrain split; validation accuracy is also reported)
            or ``'full'``.
        lr: Adam learning rates ``(model weights, mixture means, gammas & rhos)``.
        model_save_dir: if non-empty, the retrained model and the pickled prior are saved there.
        scaling: forwarded to ``GaussianMixturePrior``.

    Returns:
        ``(model, gmp)``: the retrained (unpruned) model and the fitted prior.

    Raises:
        ValueError: if ``data_size`` is neither ``'search'`` nor ``'full'``.
    """
    if data_size == 'search':
        train_dataset = search_retrain_data
        val_data_full = Variable(search_validation_data(fetch='data')).cuda()
        val_labels_full = Variable(search_validation_data(fetch='labels')).cuda()
        # Rows of the saved "full" model outputs that line up with the search retrain split.
        (x_start, x_end) = (40000, 50000)
    elif data_size == 'full':
        train_dataset = train_data
        (x_start, x_end) = (0, 60000)
    else:
        # Fail fast: otherwise train_dataset / x_start / x_end would be unbound further down.
        raise ValueError("data_size must be 'search' or 'full', got {!r}".format(data_size))
    test_data_full = Variable(test_data(fetch='data')).cuda()
    test_labels_full = Variable(test_data(fetch='labels')).cuda()

    model_file = 'mnist_{}_{}_{}'.format(model_name, 100, data_size)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(model_load_dir + model_file + '.m').cuda()

    if temp == 0:
        criterion = nn.CrossEntropyLoss()
        loader = torch.utils.data.DataLoader(dataset=train_dataset(), batch_size=batch_size, shuffle=True)
        temp_mult = 1
    else:
        # Distillation-style targets: temperature-softened outputs of the "full" model.
        criterion = nn.MSELoss()
        targets_path = "{}{}_targets/{}.out.m".format(model_load_dir, model_file.replace("search", "full"), "fc2")
        saved_outputs = torch.load(targets_path)[x_start:x_end]
        soft_targets = (nn.Softmax(dim=1)(saved_outputs / temp)).data
        dataset = torch.utils.data.TensorDataset(train_dataset(fetch='data'), soft_targets)
        loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
        # Presumably the usual temp**2 rescaling of a softened-target loss.
        temp_mult = temp ** 2

    exp_name = "{}_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(model.name, alpha, beta, retraining_epochs, tau, int(mixtures), int(temp), data_size)
    # NOTE(review): the meaning of 0.99 is defined by GaussianMixturePrior — confirm in utils_sws.
    gmp = GaussianMixturePrior(mixtures, list(model.parameters()), 0.99, ab = (alpha, beta), scaling = scaling)
    gmp.print_batch = False

    opt = torch.optim.Adam([
        {'params': model.parameters(), 'lr': lr[0]},
        {'params': [gmp.means], 'lr': lr[1]},
        {'params': [gmp.gammas, gmp.rhos], 'lr': lr[2]}])  # log precisions and mixing proportions

    for epoch in range(retraining_epochs):
        model, _ = retrain_sws_epoch(model, gmp, opt, criterion, loader, tau, temp_mult)

        if trueAfterN(epoch, 25):
            test_acc = test_accuracy(test_data_full, test_labels_full, model)
            print('Epoch: {}. Test Accuracy: {:.2f}'.format(epoch+1, test_acc[0]))
    if model_save_dir != "":
        torch.save(model, model_save_dir + '/mnist_retrain_{}.m'.format(exp_name))
        with open(model_save_dir + '/mnist_retrain_{}_gmp.p'.format(exp_name), 'wb') as f:
            pickle.dump(gmp, f)

    test_accuracy_pre = float(test_accuracy(test_data_full, test_labels_full, model)[0])
    val_accuracy_pre = 0 if (data_size != 'search') else float(test_accuracy(val_data_full, val_labels_full, model)[0])

    # Prune a copy so the returned model stays unpruned. Named `pruned_model` so it does not
    # shadow the `model_prune` helper imported from utils_model.
    pruned_model = copy.deepcopy(model)
    pruned_model.load_state_dict(sws_prune_l2(pruned_model, gmp))
    test_accuracy_prune = float(test_accuracy(test_data_full, test_labels_full, pruned_model)[0])
    val_accuracy = 0 if (data_size != 'search') else float(test_accuracy(val_data_full, val_labels_full, pruned_model)[0])

    # Percentage of entries exactly equal to zero in the pruned state dict (flattened once).
    flat_weights = special_flatten(pruned_model.state_dict())
    sparsity = float((flat_weights == 0).sum()) / flat_weights.numel() * 100
    print('Retrain Test: {:.2f}, Retrain Validation: {:.2f}, Prune Test: {:.2f}, Prune Validation: {:.2f}, Prune Sparsity: {:.2f}'
          .format(test_accuracy_pre, val_accuracy_pre, test_accuracy_prune, val_accuracy, sparsity))

    return model, gmp

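### Measuring Run-to-Run Variance in Results

The same search-split configuration is retrained several times to see how much test/validation accuracy and sparsity change between runs.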

In [12]:
# Load the hyperparameter search results (a Sobol search, per the filename) and
# select one experiment to re-run.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
with open("../search/sobol_search.p", "rb") as handle:
    params = pickle.load(handle)

i = 240  # index of the selected search experiment
print("exp:{} mean: {}, var: {}, tau: {}, temp: {}, mixtures: {}".format(
    i, *(params[key][i] for key in ('mean', 'var', 'tau', 'temp')), int(params['mixtures'][i])))

# Convert the selected mean/variance into the (alpha, beta) pair later passed to
# GaussianMixturePrior: beta = mean / var and alpha = mean * beta (= mean**2 / var).
# NOTE(review): this matches a shape/rate Gamma (mean = alpha/beta, var = alpha/beta**2);
# confirm that reading against utils_sws.
mean = float(params['mean'][i])
var = float(params['var'][i])
beta = mean / var
alpha = mean * beta

exp:240 mean: 0.10181517217181825, var: 0.1034304555503225, tau: 6.479230606842127e-07, temp: 19.0, mixtures: 9


In [9]:
# Retrain the same search-split configuration several times to see how much the
# accuracies and sparsity vary from run to run (the outputs below differ per run).
n_variance_runs = 10
for run in range(n_variance_runs):
    # The returned (model, gmp) pair is discarded; only the printed metrics are used.
    # It is deliberately not assigned to `_`, because rebinding `_` in the user
    # namespace stops IPython from updating its last-output variable.
    retrain_model(alpha, beta, float(params['tau'][i]), params['temp'][i], int(params['mixtures'][i]),
                  'SWSModel', 'search', (5e-4, 1e-4, 3e-3), "", False)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.63
Epoch: 50. Test Accuracy: 98.70
Retrain Test: 98.70, Retrain Validation: 98.52, Prune Test: 98.66, Prune Validation: 98.41, Prune Sparsity: 70.09
0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.74
Epoch: 50. Test Accuracy: 98.69
Retrain Test: 98.69, Retrain Validation: 98.67, Prune Test: 98.44, Prune Validation: 98.37, Prune Sparsity: 71.17
0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.80
Epoch: 50. Test Accuracy: 98.70
Retrain Test: 98.70, Retrain Validation: 98.57, Prune Test: 98.64, Prune Validation: 98.44, Prune Sparsity: 68.92
0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Ep

### Search split (10K) vs. full training set (60K)

#### Search split (10K retraining examples)

In [14]:
# Retrain on the 10K search split with the selected hyperparameters, without prior scaling.
model, gmp = retrain_model(
    alpha, beta,
    tau=float(params['tau'][i]),
    temp=params['temp'][i],
    mixtures=int(params['mixtures'][i]),
    model_name='SWSModel',
    data_size='search',
    lr=(5e-4, 1e-4, 3e-3),  # (model weights, mixture means, log precisions & mixing proportions)
    model_save_dir="",
    scaling=False,
)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.63
Epoch: 50. Test Accuracy: 98.67
Retrain Test: 98.67, Retrain Validation: 98.74, Prune Test: 98.63, Prune Validation: 98.56, Prune Sparsity: 70.45


#### Full training set (60K examples)

In [15]:
# Retrain on the full 60K training set with the same hyperparameters, without prior scaling.
model, gmp = retrain_model(
    alpha, beta,
    tau=float(params['tau'][i]),
    temp=params['temp'][i],
    mixtures=int(params['mixtures'][i]),
    model_name='SWSModel',
    data_size='full',
    lr=(5e-4, 1e-4, 3e-3),  # (model weights, mixture means, log precisions & mixing proportions)
    model_save_dir="",
    scaling=False,
)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.89
Epoch: 50. Test Accuracy: 98.94
Retrain Test: 98.94, Retrain Validation: 0.00, Prune Test: 98.75, Prune Validation: 0.00, Prune Sparsity: 79.67


#### Search split (10K retraining examples) with scaling

In [16]:
# Retrain on the 10K search split with scaling=True forwarded to GaussianMixturePrior.
model, gmp = retrain_model(
    alpha, beta,
    tau=float(params['tau'][i]),
    temp=params['temp'][i],
    mixtures=int(params['mixtures'][i]),
    model_name='SWSModel',
    data_size='search',
    lr=(5e-4, 1e-4, 3e-3),  # (model weights, mixture means, log precisions & mixing proportions)
    model_save_dir="",
    scaling=True,
)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.74
Epoch: 50. Test Accuracy: 98.55
Retrain Test: 98.55, Retrain Validation: 98.50, Prune Test: 98.52, Prune Validation: 98.41, Prune Sparsity: 71.14


#### Full training set (60K examples) with scaling

In [17]:
# Retrain on the full 60K training set with scaling=True forwarded to GaussianMixturePrior.
model, gmp = retrain_model(
    alpha, beta,
    tau=float(params['tau'][i]),
    temp=params['temp'][i],
    mixtures=int(params['mixtures'][i]),
    model_name='SWSModel',
    data_size='full',
    lr=(5e-4, 1e-4, 3e-3),  # (model weights, mixture means, log precisions & mixing proportions)
    model_save_dir="",
    scaling=True,
)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 0.10181517217181825 Variance: 0.1034304555503225
Epoch: 25. Test Accuracy: 98.79
Epoch: 50. Test Accuracy: 99.00
Retrain Test: 99.00, Retrain Validation: 0.00, Prune Test: 98.51, Prune Validation: 0.00, Prune Sparsity: 78.08
