In [1]:
#Expand notebook to take full screen width
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

### Layered Training

In [2]:
import sys
sys.path.insert(0,'../../src/')

In [3]:
#execution example: python layer_retrain.py --layer 1 --alpha 2500 --beta 10 --tau 1e-6 --mixtures 4
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.nn.modules import Module
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import copy
import pickle
import argparse

model_dir = "./models/"
import model_archs
from utils_plot import show_sws_weights, show_weights, print_dims, prune_plot, draw_sws_graphs, joint_plot
from utils_model import test_accuracy, train_epoch, retrain_sws_epoch, model_prune, get_weight_penalty, layer_accuracy
from utils_misc import trueAfterN, logsumexp
from utils_sws import GaussianMixturePrior, special_flatten, KL, compute_responsibilies, merger, sws_prune
from mnist_loader import search_train_data, search_retrain_data, search_validation_data, train_data, test_data, batch_size

def retrain_layer(model_retrain, model_orig, data_loader, test_data_full, test_labels_full, alpha, beta, tau, mixtures, model_dir):
    weight_loader = model_retrain.state_dict()
    for layer in model_retrain.state_dict():
        weight_loader[layer] = model_orig.state_dict()[layer]
    model_retrain.load_state_dict(weight_loader)

    exp_name = "{}_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(model_retrain.name, alpha, beta, 50, tau, int(mixtures), int(temp), data_size)
    gmp = GaussianMixturePrior(mixtures, [x for x in model_retrain.parameters()], 0.99, ab = (alpha, beta), scaling = False)
    gmp.print_batch = False

    print ("Model Name: {}".format(model_retrain.name))
    criterion = nn.MSELoss()
    opt = torch.optim.Adam([
        {'params': model_retrain.parameters(), 'lr': 1e-4},
        {'params': [gmp.means], 'lr': 1e-4},
        {'params': [gmp.gammas, gmp.rhos], 'lr': 3e-3}])#log precisions and mixing proportions

    
    for epoch in range(50):
        model_retrain, loss = retrain_sws_epoch(model_retrain, gmp, opt, criterion, data_loader, tau)

        if (trueAfterN(epoch, 10)):
            print('Epoch: {}. Loss: {:.2f}'.format(epoch+1, float(loss.data)))
            layer_accuracy(model_retrain, gmp, model_orig, test_data_full, test_labels_full)
            
    if(model_save_dir!=""):
        torch.save(model_retrain, model_save_dir + '/mnist_retrain_{}.m'.format(exp_name))
        with open(model_save_dir + '/mnist_retrain_{}_gmp.p'.format(exp_name),'wb') as f:
            pickle.dump(gmp, f)
            
    return model_retrain, gmp

In [4]:
from utils_misc import  model_load_dir
from extract_targets import get_targets

###
alpha = 100
beta = 10
tau = 1e-6
layer = 2
mixtures = 6
temp = 6

data_size = 'search'

###
test_data_full =  Variable(test_data(fetch = "data")).cuda()
test_labels_full =  Variable(test_data(fetch = "labels")).cuda()
val_data_full =  Variable(search_validation_data(fetch = "data")).cuda()
val_labels_full =  Variable(search_validation_data(fetch = "labels")).cuda()

model_name = "SWSModel"
model_file = 'mnist_{}_{}_{}'.format(model_name, 100, data_size)
model_orig = torch.load(model_load_dir + model_file + '.m').cuda()
target_dir = model_file.replace("search", "full")
 

In [5]:
x_start = 0
x_end = 60000
if (data_size == "search"):
    x_start = 40000
    x_end = 50000

In [7]:
layer = 4
if (layer == 1):
    layer_model = model_archs.SWSModelConv1().cuda()
    input = Variable(train_data(fetch = "data")[x_start:x_end]).cuda()
    output = get_targets(target_dir, temp, ["conv1.out"])["conv1.out"][x_start:x_end]
if (layer == 2):
    layer_model = model_archs.SWSModelConv2().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["conv1.out"])["conv1.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["conv2.out"])["conv2.out"][x_start:x_end]
if (layer == 3):
    layer_model = model_archs.SWSModelFC1().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["conv2.out"])["conv2.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["fc1.out"])["fc1.out"][x_start:x_end]
if (layer == 4):
    layer_model = model_archs.SWSModelFC2().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["fc1.out"])["fc1.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["fc2.out"])["fc2.out"][x_start:x_end]

print (type(input), type(output))

<class 'torch.autograd.variable.Variable'> <class 'torch.autograd.variable.Variable'>


In [21]:
dataset = torch.utils.data.TensorDataset(input.data, output.data)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

model, gmp = retrain_layer(layer_model, model_orig, loader, test_data_full, test_labels_full, alpha, beta, tau, mixtures, model_dir + model_file)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 10.0 Variance: 1.0
Model Name: SWSModelConv2
Epoch: 10. Loss: 11.15
Original: 98.71% - Retrain: 98.58% - Prune: 13.15%
Epoch: 20. Loss: 6.32
Original: 98.71% - Retrain: 98.35% - Prune: 9.85%
Epoch: 30. Loss: 3.75
Original: 98.71% - Retrain: 98.16% - Prune: 9.74%
Epoch: 40. Loss: 1.86
Original: 98.71% - Retrain: 97.91% - Prune: 9.74%
Epoch: 50. Loss: 1.07
Original: 98.71% - Retrain: 97.72% - Prune: 9.74%


NameError: name 'model_save_dir' is not defined