In [1]:
# Here I load the pre-trained Classical Nets and convert them to Finite Gaussian Nets (FGNs)
# 1 - load classical network
# 2 - convert to fgn
# 3 - check mnist performance
# 4 - lower sigma until performance hit (no retraining)
# 5 - save results

In [2]:
from __future__ import print_function

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np
from scipy import stats

import torch_helper_lib as th

import json
import pickle

In [4]:
# random seeds
# torch.manual_seed(1665)
# np.random.seed(3266)

# torch.backends.cudnn.deterministic = True
# torch.cuda.manual_seed_all(999)

In [5]:
# Choose the compute device: CUDA is used only when it is both requested and available
use_cuda = False
print("CUDA Available: ", torch.cuda.is_available())
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

CUDA Available:  True
Using device: cpu


In [6]:
import matplotlib as mpl
# Select the 'nbagg' backend for use inside Jupyter.
# Keep this call above the pyplot import below: the backend has to be set
# before pyplot is imported.
mpl.use('nbagg')
import matplotlib.pyplot as plt

In [7]:
# MNIST datasets and dataloaders.
# The shared transform first converts pixel values from 0-255 to 0-1 (ToTensor)
# and then normalizes by the precomputed mean and std (Normalize).

batch_size = 64

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# training split (shuffled)
mnist_train_set = datasets.MNIST('../MNIST-dataset', train=True, download=False,
                                 transform=mnist_transform)
mnist_train_loader = torch.utils.data.DataLoader(
    mnist_train_set, batch_size=batch_size, shuffle=True)

# test split (not shuffled)
mnist_test_set = datasets.MNIST('../MNIST-dataset', train=False, download=False,
                                transform=mnist_transform)
mnist_test_loader = torch.utils.data.DataLoader(
    mnist_test_set, batch_size=batch_size, shuffle=False)

In [8]:
# show an example: the image at index 1 of the training set, with its label
train_set = mnist_train_loader.dataset
print("Shape:", train_set.train_data.shape)

x = train_set.train_data[1]
x_as_numpy = x.numpy()
print("type:", type(x_as_numpy))
print("shape:", x.shape)
# print("sample:", x.numpy())

y = train_set.train_labels[1]
print("Label:", y.numpy())
print("type:", type(y))

# grey-scale rendering of the raw (un-normalized) image tensor
greys = plt.cm.get_cmap('Greys')
plt.imshow(x, cmap=greys)
plt.colorbar()
plt.show()

Shape: torch.Size([60000, 28, 28])
type: <type 'numpy.ndarray'>
shape: torch.Size([28, 28])
Label: 0
type: <class 'torch.Tensor'>


<IPython.core.display.Javascript object>

In [9]:
# list of sizes to try: no hidden layer, then 1, 2 and 3 hidden layers of equal width
hidden_widths = (8, 128, 1024)
network_sizes = [[]] + [[width] * depth for depth in (1, 2, 3) for width in hidden_widths]

In [10]:
# pre-trained classic models: the checkpoint directory, plus the file-name
# pieces of one example checkpoint (hidden layers [8], experiment 0)
model_dir = "./saved_models/Classic_MNIST_Nets"
model_id = model_dir + "/model_[8]"
exp_num = "_exp_0"
full_model = "_full.pth"
state_dict = "_state_dict.pth"

# assemble the two example paths and show the state-dict one
full_path = "".join([model_id, exp_num, full_model])
dict_path = "".join([model_id, exp_num, state_dict])
print(dict_path)

./saved_models/Classic_MNIST_Nets/model_[8]_exp_0_state_dict.pth


In [11]:
# Output location for the converted FGN models. File names are appended to this
# string by plain concatenation, so it has to end with a path separator.
save_path= "./saved_models/Retrained_FG_Nets//"

In [12]:
# functions for test
# nll loss function
def classic_nll_loss_func(model, output, target):
    """Negative log-likelihood loss for a classic net.

    `model` is accepted but not used here; `output` and `target` are passed
    unchanged to F.nll_loss and its result is returned.
    """
    loss = F.nll_loss(output, target)
    return loss
# number of correct pred function for classic net
def classic_pred_func(output, target):
    """Return how many rows of `output` have their maximum at the index given in `target`."""
    # column of indices of the max log-probability, one per row
    _, predicted = output.max(1, keepdim=True)
    hits = predicted.eq(target.long().view_as(predicted))
    return hits.sum().item()

# nll loss function
def fgn_nll_loss_func(model, output, target):
    """Negative log-likelihood loss for an FGN net.

    `output` is unpacked as a pair; only its first element (the predictions)
    enters the loss, the second element (the likelihoods) is ignored.
    `model` is accepted but not used here.
    """
    log_probs, _ = output
    return F.nll_loss(log_probs, target)
# number of correct pred function for fgnet
def fgn_pred_func(output, target):
    """Return how many predictions of an FGN net match `target`.

    `output` is unpacked as a pair; only its first element (the predictions) is used.
    """
    log_probs, _ = output
    # column of indices of the max log-probability, one per row
    _, predicted = log_probs.max(1, keepdim=True)
    hits = predicted.eq(target.long().view_as(predicted))
    return hits.sum().item()

In [13]:
# helper functions for manipulating the sigma of fgnets

def half_sigma(model):
    """Halve the sigma of every FGN layer found in `model`.

    Each th.FGN_layer gets its `sigs` attribute replaced by a new
    torch.nn.Parameter holding the old values divided by 2.
    """
    is_fgn_layer = lambda module: isinstance(module, th.FGN_layer)
    for layer in filter(is_fgn_layer, model.modules()):
        layer.sigs = torch.nn.Parameter(layer.sigs / 2.0)
            
def double_sigma(model):
    """Double the sigma of every FGN layer found in `model`.

    Each th.FGN_layer gets its `sigs` attribute replaced by a new
    torch.nn.Parameter holding the old values multiplied by 2.
    """
    for module in model.modules():
        if not isinstance(module, th.FGN_layer):
            continue
        module.sigs = torch.nn.Parameter(module.sigs * 2.0)

In [None]:
# Convert every pre-trained classic net to an FGN, shrink sigma (no retraining)
# until test accuracy takes a hit, step back once, and save the result.
num_exps = 12                # checkpoints exp_0 .. exp_11 are loaded for every size
conversion_tolerance = 0.99  # converted accuracy must reach this fraction of the classic accuracy
accuracy_keep_frac = 0.95    # "performance hit" = accuracy no longer above this fraction of the converted accuracy
max_halvings = 64            # safety cap so the sigma search always terminates

# for each classical network
for s in network_sizes:
    print("Working on size:", s)
    # for each exp
    for exp_num in range(num_exps):
        print("Working on exp:", exp_num)

        # create classic model
        classic_model = th.Classic_MNIST_Net(hidden_l_nums=s)
        # load its weights from the saved state dict
        classic_model_path = model_dir+'/model_'+str(s)+'_exp_'+str(exp_num)+'_state_dict.pth'
        print(classic_model_path)
        classic_model.to(device)
        # map_location keeps the checkpoint tensors on the selected device, so a
        # checkpoint saved from a GPU still loads when running on CPU
        classic_model.load_state_dict(torch.load(classic_model_path, map_location=device))
        print("Model loaded")

        # classic model perf
        classic_test_res = th.test(classic_model, device, mnist_test_loader, loss_func=classic_nll_loss_func, verbose=True, pred_func=classic_pred_func)

        # create model to be converted
        fgn_model = th.Feedforward_FGN_net(28*28,10,s).to(device)

        # convert
        th.convert_Classic2FGN(classic_model=classic_model, fgn_model=fgn_model)
        print("Model converted")

        # perf (check that same or close to above)
        fgn_test_res1 = th.test(fgn_model, device, mnist_test_loader, loss_func=fgn_nll_loss_func, verbose=True, pred_func=fgn_pred_func)
        if fgn_test_res1['test_accuracy'] < conversion_tolerance*classic_test_res['test_accuracy']:
            print("Error during conversion")
        else:
            print("Conversion successful")

        # halve sigma until performance hit (or until the safety cap is reached)
        accuracy_floor = accuracy_keep_frac*fgn_test_res1['test_accuracy']
        current_test_res = fgn_test_res1
        num_halvings = 0
        while current_test_res['test_accuracy'] > accuracy_floor and num_halvings < max_halvings:
            print("Halfing sigma")
            half_sigma(fgn_model)
            num_halvings += 1
            current_test_res = th.test(fgn_model, device, mnist_test_loader, loss_func=fgn_nll_loss_func, verbose=True, pred_func=fgn_pred_func)

        # set back: undo the last halving, but only if a halving actually happened
        # and it is the one that caused the hit; otherwise sigma would end up
        # larger than what was just evaluated
        performance_hit = not (current_test_res['test_accuracy'] > accuracy_floor)
        if num_halvings > 0 and performance_hit:
            print("Doubling Sigma")
            double_sigma(fgn_model)
            current_test_res = th.test(fgn_model, device, mnist_test_loader, loss_func=fgn_nll_loss_func, verbose=True, pred_func=fgn_pred_func)
        elif num_halvings >= max_halvings:
            print("Stopped after {} halvings without a performance hit".format(num_halvings))

        # save model
        model_name = "fgn_model_{}_exp_{}".format(str(s), str(exp_num))
        print("Saving converted model {} in {}".format(model_name, save_path))
        # save model entirely
        torch.save(fgn_model, save_path+model_name+"_full.pth")

        # save model weights
        torch.save(fgn_model.state_dict(), save_path+model_name+"_state_dict.pth")

        print()

Working on size: []
Working on exp: 0
./saved_models/Classic_MNIST_Nets/model_[]_exp_0_state_dict.pth
Model loaded
Test set - Average loss: 0.0044, Accuracy: 9198/10000 (92%)
Model converted
Test set - Average loss: 0.0044, Accuracy: 9198/10000 (92%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0044, Accuracy: 9198/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0044, Accuracy: 9198/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0044, Accuracy: 9198/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0047, Accuracy: 9198/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0123, Accuracy: 9204/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0347, Accuracy: 9192/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0362, Accuracy: 3686/10000 (37%)
Doubling Sigma
Test set - Average loss: 0.0347, Accuracy: 9192/10000 (92%)
Saving converted model fgn_model_[]_exp_0 in ./saved_models/Retrained_FG_Nets//

Working on exp: 1
./saved_models/Classic_MNIST_Nets/m

Test set - Average loss: 0.0347, Accuracy: 9193/10000 (92%)
Saving converted model fgn_model_[]_exp_8 in ./saved_models/Retrained_FG_Nets//

Working on exp: 9
./saved_models/Classic_MNIST_Nets/model_[]_exp_9_state_dict.pth
Model loaded
Test set - Average loss: 0.0044, Accuracy: 9207/10000 (92%)
Model converted
Test set - Average loss: 0.0044, Accuracy: 9207/10000 (92%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0044, Accuracy: 9207/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0044, Accuracy: 9206/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0043, Accuracy: 9208/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0046, Accuracy: 9207/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0123, Accuracy: 9212/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0347, Accuracy: 9204/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0362, Accuracy: 3518/10000 (35%)
Doubling Sigma
Test set - Average loss: 0.0347, Accuracy: 9204/10000 (92%)
Saving conver

Test set - Average loss: 0.0059, Accuracy: 8925/10000 (89%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0074, Accuracy: 8920/10000 (89%)
Halfing sigma
Test set - Average loss: 0.0223, Accuracy: 8888/10000 (89%)
Halfing sigma
Test set - Average loss: 0.0360, Accuracy: 8328/10000 (83%)
Doubling Sigma
Test set - Average loss: 0.0223, Accuracy: 8888/10000 (89%)
Saving converted model fgn_model_[8]_exp_8 in ./saved_models/Retrained_FG_Nets//

Working on exp: 9
./saved_models/Classic_MNIST_Nets/model_[8]_exp_9_state_dict.pth
Model loaded
Test set - Average loss: 0.0055, Accuracy: 9032/10000 (90%)
Model converted
Test set - Average loss: 0.0057, Accuracy: 9035/10000 (90%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0073, Accuracy: 9024/10000 (90%)
Halfing sigma
Test set - Average loss: 0.0224, Accuracy: 8979/10000 (90%)
Halfing sigma
Test set - Average loss: 0.0360, Accuracy: 8078/10000 (81%)
Doubling Sigma
Test set - Average loss: 0.0224, Accuracy: 8979/

Test set - Average loss: 0.0340, Accuracy: 3661/10000 (37%)
Doubling Sigma
Test set - Average loss: 0.0250, Accuracy: 9393/10000 (94%)
Saving converted model fgn_model_[128]_exp_6 in ./saved_models/Retrained_FG_Nets//

Working on exp: 7
./saved_models/Classic_MNIST_Nets/model_[128]_exp_7_state_dict.pth
Model loaded
Test set - Average loss: 0.0020, Accuracy: 9621/10000 (96%)
Model converted
Test set - Average loss: 0.0020, Accuracy: 9621/10000 (96%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0020, Accuracy: 9620/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0021, Accuracy: 9619/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0027, Accuracy: 9614/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0104, Accuracy: 9590/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0251, Accuracy: 9362/10000 (94%)
Halfing sigma
Test set - Average loss: 0.0340, Accuracy: 3505/10000 (35%)
Doubling Sigma
Test set - Average loss: 0.0251, Accuracy: 9362/10000 (94%)
Saving

Test set - Average loss: 0.0025, Accuracy: 9578/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0090, Accuracy: 9240/10000 (92%)
Halfing sigma
Test set - Average loss: 0.0342, Accuracy: 4937/10000 (49%)
Doubling Sigma
Test set - Average loss: 0.0090, Accuracy: 9240/10000 (92%)
Saving converted model fgn_model_[1024]_exp_4 in ./saved_models/Retrained_FG_Nets//

Working on exp: 5
./saved_models/Classic_MNIST_Nets/model_[1024]_exp_5_state_dict.pth
Model loaded
Test set - Average loss: 0.0020, Accuracy: 9649/10000 (96%)
Model converted
Test set - Average loss: 0.0020, Accuracy: 9649/10000 (96%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0020, Accuracy: 9649/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0020, Accuracy: 9647/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0021, Accuracy: 9639/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0025, Accuracy: 9591/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0090, Accuracy: 9199/10000 (92%)
Halfi

Test set - Average loss: 0.0065, Accuracy: 8885/10000 (89%)
Model converted
Test set - Average loss: 0.0067, Accuracy: 8893/10000 (89%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0083, Accuracy: 8881/10000 (89%)
Halfing sigma
Test set - Average loss: 0.0188, Accuracy: 8799/10000 (88%)
Halfing sigma
Test set - Average loss: 0.0368, Accuracy: 1591/10000 (16%)
Doubling Sigma
Test set - Average loss: 0.0188, Accuracy: 8799/10000 (88%)
Saving converted model fgn_model_[8, 8]_exp_3 in ./saved_models/Retrained_FG_Nets//

Working on exp: 4
./saved_models/Classic_MNIST_Nets/model_[8, 8]_exp_4_state_dict.pth
Model loaded
Test set - Average loss: 0.0058, Accuracy: 9003/10000 (90%)
Model converted
Test set - Average loss: 0.0061, Accuracy: 9002/10000 (90%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0077, Accuracy: 9002/10000 (90%)
Halfing sigma
Test set - Average loss: 0.0184, Accuracy: 8950/10000 (90%)
Halfing sigma
Test set - Average loss: 0.0363, Accuracy

Test set - Average loss: 0.0017, Accuracy: 9682/10000 (97%)
Model converted
Test set - Average loss: 0.0017, Accuracy: 9681/10000 (97%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0017, Accuracy: 9680/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0017, Accuracy: 9678/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0022, Accuracy: 9668/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0075, Accuracy: 9647/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0218, Accuracy: 9369/10000 (94%)
Halfing sigma
Test set - Average loss: 0.0338, Accuracy: 2684/10000 (27%)
Doubling Sigma
Test set - Average loss: 0.0218, Accuracy: 9369/10000 (94%)
Saving converted model fgn_model_[128, 128]_exp_3 in ./saved_models/Retrained_FG_Nets//

Working on exp: 4
./saved_models/Classic_MNIST_Nets/model_[128, 128]_exp_4_state_dict.pth
Model loaded
Test set - Average loss: 0.0017, Accuracy: 9684/10000 (97%)
Model converted
Test set - Average loss: 0.0017, Accuracy: 9683/10000 (

Test set - Average loss: 0.0084, Accuracy: 9084/10000 (91%)
Doubling Sigma
Test set - Average loss: 0.0024, Accuracy: 9593/10000 (96%)
Saving converted model fgn_model_[1024, 1024]_exp_0 in ./saved_models/Retrained_FG_Nets//

Working on exp: 1
./saved_models/Classic_MNIST_Nets/model_[1024, 1024]_exp_1_state_dict.pth
Model loaded
Test set - Average loss: 0.0017, Accuracy: 9670/10000 (97%)
Model converted
Test set - Average loss: 0.0017, Accuracy: 9669/10000 (97%)
Conversion successful
Halfing sigma
Test set - Average loss: 0.0017, Accuracy: 9668/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0018, Accuracy: 9663/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0018, Accuracy: 9655/10000 (97%)
Halfing sigma
Test set - Average loss: 0.0023, Accuracy: 9586/10000 (96%)
Halfing sigma
Test set - Average loss: 0.0083, Accuracy: 8934/10000 (89%)
Doubling Sigma
Test set - Average loss: 0.0023, Accuracy: 9586/10000 (96%)
Saving converted model fgn_model_[1024, 1024]_exp_1 in ./saved_mod