In [1]:
#Expand notebook to take full screen width
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

### Layered Training

In [2]:
import sys
sys.path.insert(0,'../../src/')

In [3]:
#execution example: python layer_retrain.py --layer 1 --alpha 2500 --beta 10 --tau 1e-6 --mixtures 4
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.nn.modules import Module
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import copy
import pickle
import argparse

model_dir = "./models/"
import model_archs
from utils_plot import show_sws_weights, show_weights, print_dims, prune_plot, draw_sws_graphs, joint_plot
from utils_model import test_accuracy, train_epoch, retrain_sws_epoch, model_prune, get_weight_penalty, layer_accuracy
from utils_misc import trueAfterN, logsumexp
from utils_sws import GaussianMixturePrior, special_flatten, KL, compute_responsibilies, merger, sws_prune
from mnist_loader import search_train_data, search_retrain_data, search_validation_data, train_data, test_data, batch_size

def retrain_layer(model_retrain, model_orig, data_loader, test_data_full, test_labels_full, alpha, beta, tau, mixtures, model_dir):
    weight_loader = model_retrain.state_dict()
    for layer in model_retrain.state_dict():
        weight_loader[layer] = model_orig.state_dict()[layer]
    model_retrain.load_state_dict(weight_loader)

    exp_name = "{}_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(model_retrain.name, alpha, beta, 50, tau, int(mixtures), int(temp), data_size)
    gmp = GaussianMixturePrior(mixtures, [x for x in model_retrain.parameters()], 0.99, ab = (alpha, beta), scaling = False)
    gmp.print_batch = False

    print ("Model Name: {}".format(model_retrain.name))
    criterion = nn.MSELoss()
    opt = torch.optim.Adam([
        {'params': model_retrain.parameters(), 'lr': 1e-4},
        {'params': [gmp.means], 'lr': 1e-4},
        {'params': [gmp.gammas, gmp.rhos], 'lr': 3e-3}])#log precisions and mixing proportions

    
    for epoch in range(50):
        model_retrain, loss = retrain_sws_epoch(model_retrain, gmp, opt, criterion, data_loader, tau)

        if (trueAfterN(epoch, 10)):
            print('Epoch: {}. Loss: {:.2f}'.format(epoch+1, float(loss.data)))
            layer_accuracy(model_retrain, gmp, model_orig, test_data_full, test_labels_full)
            
    if(model_save_dir!=""):
        torch.save(model_retrain, model_save_dir + '/mnist_retrain_{}.m'.format(exp_name))
        with open(model_save_dir + '/mnist_retrain_{}_gmp.p'.format(exp_name),'wb') as f:
            pickle.dump(gmp, f)
            
    return model_retrain, gmp

In [4]:
from utils_misc import  model_load_dir
from extract_targets import get_targets

###
alpha = 100
beta = 10
tau = 1e-6
layer = 2
mixtures = 6
temp = 6

data_size = 'search'

###
test_data_full =  Variable(test_data(fetch = "data")).cuda()
test_labels_full =  Variable(test_data(fetch = "labels")).cuda()
val_data_full =  Variable(search_validation_data(fetch = "data")).cuda()
val_labels_full =  Variable(search_validation_data(fetch = "labels")).cuda()

model_name = "SWSModel"
model_file = 'mnist_{}_{}_{}'.format(model_name, 100, data_size)
model_orig = torch.load(model_load_dir + model_file + '.m').cuda()
target_dir = model_file.replace("search", "full")
 

In [5]:
x_start = 0
x_end = 60000
if (data_size == "search"):
    x_start = 40000
    x_end = 50000

In [7]:
layer = 4
if (layer == 1):
    layer_model = model_archs.SWSModelConv1().cuda()
    input = Variable(train_data(fetch = "data")[x_start:x_end]).cuda()
    output = get_targets(target_dir, temp, ["conv1.out"])["conv1.out"][x_start:x_end]
if (layer == 2):
    layer_model = model_archs.SWSModelConv2().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["conv1.out"])["conv1.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["conv2.out"])["conv2.out"][x_start:x_end]
if (layer == 3):
    layer_model = model_archs.SWSModelFC1().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["conv2.out"])["conv2.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["fc1.out"])["fc1.out"][x_start:x_end]
if (layer == 4):
    layer_model = model_archs.SWSModelFC2().cuda()
    input = nn.ReLU()(get_targets(target_dir, temp, ["fc1.out"])["fc1.out"][x_start:x_end])
    output = get_targets(target_dir, temp, ["fc2.out"])["fc2.out"][x_start:x_end]

print (type(input), type(output))

<class 'torch.autograd.variable.Variable'> <class 'torch.autograd.variable.Variable'>


In [21]:
dataset = torch.utils.data.TensorDataset(input.data, output.data)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

model, gmp = retrain_layer(layer_model, model_orig, loader, test_data_full, test_labels_full, alpha, beta, tau, mixtures, model_dir + model_file)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 10.0 Variance: 1.0
Model Name: SWSModelConv2
Epoch: 10. Loss: 11.15
Original: 98.71% - Retrain: 98.58% - Prune: 13.15%
Epoch: 20. Loss: 6.32
Original: 98.71% - Retrain: 98.35% - Prune: 9.85%
Epoch: 30. Loss: 3.75
Original: 98.71% - Retrain: 98.16% - Prune: 9.74%
Epoch: 40. Loss: 1.86
Original: 98.71% - Retrain: 97.91% - Prune: 9.74%
Epoch: 50. Loss: 1.07
Original: 98.71% - Retrain: 97.72% - Prune: 9.74%


NameError: name 'model_save_dir' is not defined

### Layered Results

In [14]:
import sys
sys.path.insert(0,'../../src/')
import os
import argparse
from retrain_model import retrain_model
savedir = os.getcwd() + "/models/"
import copy

import pickle
from mnist_loader import train_data
from utils_sws import sws_prune, compressed_model
from utils_model import test_accuracy, layer_accuracy, sws_replace
import torch
from torch.autograd import Variable
from mnist_loader import search_train_data, search_retrain_data, search_validation_data, train_data, test_data, batch_size
from utils_misc import model_load_dir

if __name__=="__main__":
    test_data_full =  Variable(test_data(fetch = "data")).cuda()
    test_labels_full =  Variable(test_data(fetch = "labels")).cuda()
    val_data_full =  Variable(search_validation_data(fetch = "data")).cuda()
    val_labels_full =  Variable(search_validation_data(fetch = "labels")).cuda()
    #parser = argparse.ArgumentParser()
    #parser.add_argument('--start', dest = "start", help="Start Search", required=True, type=(int))
    #parser.add_argument('--end', dest = "end", help="End Search", required=True, type=(int))
    #args = parser.parse_args()
    #start = int(args.start)
    #end = int(args.end)
    start = 100
    end = 105
    
    with open("../sobol_search.p", "rb") as handle:
        params = pickle.load(handle)
    for i in range (start,end):
        print ("exp:{} mean: {}, var: {}, tau: {}, temp: {}, mixtures: {}".format(i, params['mean'][i], params['var'][i], params['tau'][i], float(params['temp'][i]), int(params['mixtures'][i])))
        mean = float(params['mean'][i])
        var = float(params['var'][i])
        beta = mean/var
        alpha = mean * beta
        
        model_name = "SWSModel"
        model_file = 'mnist_{}_{}_{}'.format(model_name, 100, "search")
        model_orig = torch.load(model_load_dir + model_file + '.m').cuda()

        conv1_exp_name = "SWSModelConv1_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(alpha, beta, 50, float(params['tau'][i]), int(params['mixtures'][i]), int(params['temp'][i]), 'search')
        conv1_model_file = "./models/mnist_SWSModel_100_searchmnist_retrain_{}".format(conv1_exp_name)
        conv2_exp_name = "SWSModelConv2_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(alpha, beta, 50, float(params['tau'][i]), int(params['mixtures'][i]), int(params['temp'][i]), 'search')
        conv2_model_file = "./models/mnist_SWSModel_100_searchmnist_retrain_{}".format(conv2_exp_name)
        fc1_exp_name = "SWSModelFC1_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(alpha, beta, 50, float(params['tau'][i]), int(params['mixtures'][i]), int(params['temp'][i]), 'search')
        fc1_model_file = "./models/mnist_SWSModel_100_searchmnist_retrain_{}".format(fc1_exp_name)
        fc2_exp_name = "SWSModelFC2_a{}_b{}_r{}_t{}_m{}_kdT{}_{}".format(alpha, beta, 50, float(params['tau'][i]), int(params['mixtures'][i]), int(params['temp'][i]), 'search')
        fc2_model_file = "./models/mnist_SWSModel_100_searchmnist_retrain_{}".format(fc2_exp_name)
        
        conv1_model = torch.load("{}.m".format(conv1_model_file)).cuda()
        with open("{}_gmp.p".format(conv1_model_file), "rb") as handle:
            conv1_gmp = pickle.load(handle)
        conv2_model = torch.load("{}.m".format(conv2_model_file)).cuda()
        with open("{}_gmp.p".format(conv2_model_file), "rb") as handle:
            conv2_gmp = pickle.load(handle)
        fc1_model = torch.load("{}.m".format(fc1_model_file)).cuda()
        with open("{}_gmp.p".format(fc1_model_file), "rb") as handle:
            fc1_gmp = pickle.load(handle)
        fc2_model = torch.load("{}.m".format(fc2_model_file)).cuda()
        with open("{}_gmp.p".format(fc2_model_file), "rb") as handle:
            fc2_gmp = pickle.load(handle)
            
        conv1_res = layer_accuracy(conv1_model, conv1_gmp, model_orig, val_data_full, val_labels_full)
        conv2_res = layer_accuracy(conv2_model, conv2_gmp, model_orig, val_data_full, val_labels_full)
        fc1_res = layer_accuracy(fc1_model, fc1_gmp, model_orig, val_data_full, val_labels_full)
        fc2_res = layer_accuracy(fc2_model, fc2_gmp, model_orig, val_data_full, val_labels_full)
        
        pruned_model = sws_replace(model_orig, sws_prune(conv1_model, conv1_gmp), sws_prune(conv2_model, conv2_gmp), sws_prune(fc1_model, fc1_gmp), sws_prune(fc2_model, fc2_gmp))
        test_acc = test_accuracy(test_data_full, test_labels_full, pruned_model)[0]
        val_acc = test_accuracy(val_data_full, val_labels_full, pruned_model)[0]
        print ("test: {}, val: {}".format(test_acc, val_acc))
        #cm = compressed_model(pruned_model.state_dict(), [conv1_gmp, conv2_gmp, fc1_gmp, fc2_gmp])
        #cr = cm.get_cr(6)[0]
        #print ("CR: {}".format(cr))
        #sp = (cm.binned_weights == 0).sum() / float(cm.binned_weights.size) * 100.0
        #print ("SP: {}".format(sp))
        if not os.path.exists("results.csv"):
            with open("results.csv", "w") as out_csv:
                out_csv.write("Exp, Mean, Var, Tau, Temp, Mixtures, conv1_val, conv1_sp, conv2_val, conv2_sp, fc1_val, fc1_sp, fc2_val, fc_2sp, Test Acc, Val Acc\n")
                out_csv.write(", ".join([str(x) for x in [i, params['mean'][i], params['var'][i], params['tau'][i], int(params['temp'][i]), int(params['mixtures'][i]), 
                                           conv1_res[1], conv1_res[2], conv2_res[1], conv1_res[2], fc1_res[1], fc1_res[2], fc2_res[1], fc2_res[2], test_acc, val_acc]]) + "\n")
        else:
            with open("results.csv", "a") as out_csv:
                out_csv.write(", ".join([str(x) for x in [i, params['mean'][i], params['var'][i], params['tau'][i], int(params['temp'][i]), int(params['mixtures'][i]), 
                                           conv1_res[1], conv1_res[2], conv2_res[1], conv1_res[2], fc1_res[1], fc1_res[2], fc2_res[1], fc2_res[2], test_acc, val_acc]]) + "\n")

#mnist_retrain_SWSModel_a1.3365233866589001e-06_b0.0004151108285503527_r50_t1.6436915121153314e-07_m6_kdT20_search.m
#mnist_retrain_SWSModel_a86791.37374683257_b8.524404751815132_r50_t2.5311641601652367e-05_m9.0_kdT7_search.m

exp:100 mean: 0.0032196784442513784, var: 0.8956708445545833, tau: 1.7742852308703055e-08, temp: 2.0, mixtures: 11
Original: 98.73% - Retrain: 25.40% - Prune: 32.89% - Sparsity: 0.00%
Original: 98.73% - Retrain: 97.61% - Prune: 10.09% - Sparsity: 100.00%
Original: 98.73% - Retrain: 95.68% - Prune: 10.09% - Sparsity: 100.00%
Original: 98.73% - Retrain: 98.14% - Prune: 94.31% - Sparsity: 0.00%
test: 9.8, val: 9.91
100, 0.0032196784442513784, 0.8956708445545833, 1.7742852308703055e-08, 2, 11, 32.89, 0.0, 10.09, 0.0, 10.09, 100.0, 94.31, 0.0, 9.8, 9.91

exp:101 mean: 0.0032196784442513784, var: 0.8956708445545833, tau: 1.7742852308703055e-08, temp: 10.0, mixtures: 8
Original: 98.73% - Retrain: 25.38% - Prune: 44.58% - Sparsity: 8.00%
Original: 98.73% - Retrain: 97.60% - Prune: 10.09% - Sparsity: 100.00%
Original: 98.73% - Retrain: 95.66% - Prune: 10.09% - Sparsity: 100.00%
Original: 98.73% - Retrain: 80.39% - Prune: 79.85% - Sparsity: 100.00%
test: 9.8, val: 9.91
101, 0.0032196784442513784

In [13]:
conv2_res

(97.61, 10.09, 100.0)