In [1]:
import os
# When the kernel is started from a directory named "notebooks", move one level up so
# that the relative paths used below (./data, ./models, ./figs) and the `src` package
# resolve from the parent directory.
# os.path.basename is used instead of splitting on '/' so the check also works on
# platforms whose path separator is not '/'.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir('..')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import Module
import torch.optim as optim

from src.model import teacherNet, studentNet
from src.sws import GaussianMixturePrior
from src.train import train
import copy
from tqdm.notebook import tqdm
from src.kd import extract_logits, kd_ce_loss
batch_size = 256

In [3]:
import torchvision
import torchvision.transforms as transforms

# The only preprocessing applied to dataset samples is conversion to a tensor.
transform = transforms.Compose(
    [transforms.ToTensor()])

# Training split; download=True fetches the data into ./data/MNIST/ when needed.
trainset = torchvision.datasets.MNIST(root='./data/MNIST/', train=True,
                                        download=True, transform=transform)
# shuffle=False is load-bearing: logits extracted through this loader later in the
# notebook are paired index-by-index with the dataset's stored tensors, so the
# iteration order has to match the dataset order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# Test split, used by evaluate() for all accuracy figures.
testset = torchvision.datasets.MNIST(root='./data/MNIST/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [4]:
class LeNet_300_100(nn.Module):
    """Fully connected classifier: 784 inputs -> 300 -> 100 -> 10 outputs.

    The forward pass flattens its input to rows of 28*28 values and returns the
    raw output of the last linear layer (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        self.name = 'LeNet-300-100'

        # The attribute names fc1/fc2/fc3 become the state-dict keys, so they
        # must stay unchanged for previously saved checkpoints to load.
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        """Return a (rows, 10) tensor of scores for the flattened input."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
import seaborn as sns
import matplotlib.pyplot as plt
from src.train import evaluate
import numpy as np 
from src.sws import sws_prune_copy

%matplotlib inline

def show_sws_weights_log(model, means=0, precisions=0, epoch=-1, accuracy=-1, savefile = ""):
    """
    Plot a log-scale histogram of every value in the model's state dict, optionally
    overlaid with the mixture components.

    model      -- module whose ``state_dict()`` tensors are flattened and pooled.
    means      -- sequence of centres for the non-zero components, or 0 to skip
                  the overlay.
    precisions -- sequence ``p`` whose first entry belongs to the component centred
                  at 0 and whose remaining entries are aligned with ``means``; each
                  band is drawn one standard deviation wide on either side, with
                  the deviation computed as ``sqrt(1 / exp(p))``. 0 skips the overlay.
    epoch      -- only used in the output file name, as ``epoch + 1``.
    accuracy   -- accepted for call compatibility; not used by the plot.
    savefile   -- when non-empty, the figure is written to
                  ``./figs/{savefile}_{epoch+1}.png`` and closed; otherwise it is shown.
    """
    # Fetch the state dict once and join all tensors in a single pass. The empty
    # float32 seed keeps the result a float32 array even for an empty state dict.
    state = model.state_dict()
    chunks = [np.array([], dtype=np.float32)]
    chunks.extend(tensor.view(-1).cpu().numpy() for tensor in state.values())
    weights = np.hstack(chunks)

    # Create the figure directly. Calling plt.clf() first would create (or wipe) an
    # unrelated "current" figure that this function never closes.
    fig = plt.figure(figsize=(20, 3))

    #2-Logplot
    sns.distplot(weights, kde=False, color="g",bins=200,norm_hist=True, hist_kws={'log':True})
    #plot mean and precision
    if not (means==0 or precisions==0):
        # Component centred at zero: line at 0 plus a +/- one standard deviation band.
        plt.axvline(0, linewidth = 1)
        std_dev0 = np.sqrt(1/np.exp(precisions[0]))
        plt.axvspan(xmin=-std_dev0, xmax=std_dev0, alpha=0.3)

        # Remaining components: precisions[1:] lines up with means.
        for mean, precision in zip(means, precisions[1:]):
            plt.axvline(mean, linewidth = 1)
            std_dev = np.sqrt(1/np.exp(precision))
            plt.axvspan(xmin=mean - std_dev, xmax=mean + std_dev, alpha=0.1)
    plt.xlabel("Weight Value")
    plt.ylabel("Density")
    plt.xlim([-1.2, 1.2])
    plt.ylim([1e-3, 1e2])

    if savefile!="":
        plt.savefig("./figs/{}_{}.png".format(savefile, epoch+1), bbox_inches='tight')
        # Close exactly the figure created above so repeated calls do not pile up figures.
        plt.close(fig)
    else:
        plt.show()
        
def get_sparsity(model_prune):
    """
    Return the percentage (0-100) of state-dict entries that are exactly zero.

    Every tensor returned by ``model_prune.state_dict()`` contributes to both the
    zero count and the element count. When the state dict holds no elements at
    all, 0.0 is returned instead of dividing by zero.
    """
    zero_count = 0
    total_count = 0
    # state_dict() builds a fresh mapping on each call, so fetch it only once.
    for tensor in model_prune.state_dict().values():
        zero_count += int((tensor == 0).sum().item())
        total_count += tensor.numel()

    if total_count == 0:
        return 0.0
    return zero_count / total_count * 100.0


def joint_plot(model, model_orig, gmp, epoch, retraining_epochs, acc, pruned_acc, sparsity, savefile = ""):
    """
    Scatter the values of ``model`` ("Retrained") against those of ``model_orig``
    ("Pretrained") and draw the mixture components as horizontal bands.

    The centres come from ``gmp.means`` with an explicit 0 prepended; the band
    half-width is two standard deviations, each computed from ``gmp.gammas`` as
    ``sqrt(1 / exp(gamma))``. The title reports ``epoch + 1`` of
    ``retraining_epochs`` together with ``acc``, ``pruned_acc`` and ``sparsity``.
    With a non-empty ``savefile`` the figure is written to
    ``./figs/jp_{savefile}_{epoch+1}.png`` and closed, otherwise it is shown.
    """
    def _pooled_values(net):
        # Flatten every state-dict tensor of `net` and join them into one array.
        pooled = np.array([], dtype=np.float32)
        for tensor in net.state_dict().values():
            pooled = np.hstack((pooled, tensor.view(-1).cpu().numpy()))
        return pooled

    weights_T = _pooled_values(model)
    weights_0 = _pooled_values(model_orig)

    # Component centres and standard deviations.
    component_means = gmp.means.clone().data.cpu().numpy()
    mu_T = np.concatenate([np.zeros(1), component_means])
    component_gammas = gmp.gammas.clone().data.cpu().numpy()
    std_T = np.sqrt(1/np.exp(component_gammas))

    x_lo, x_hi = -1.2, 1.2
    order = np.random.permutation(len(weights_0))
    grid = sns.jointplot(
        weights_0[order],
        weights_T[order],
        size=8,
        kind="scatter",
        color=sns.color_palette()[0],
        stat_func=None,
        edgecolor='w',
        marker='o',
        joint_kws={"s": 8},
        marginal_kws=dict(bins=1000),
        ratio=4,
    )
    grid.ax_joint.hlines(mu_T, x_lo, x_hi, lw=0.5)

    # The first band (component at 0) is green, all others use the palette's first colour.
    for k in range(len(mu_T)):
        band_color = 'g' if k == 0 else sns.color_palette()[0]
        grid.ax_joint.fill_between(
            np.linspace(x_lo, x_hi, 10),
            mu_T[k] - 2 * std_T[k],
            mu_T[k] + 2 * std_T[k],
            color=band_color,
            alpha=0.1,
        )

    title_text = "Epoch: {}/{}\nTest accuracy: {:.2f}%\nPrune Accuracy: {:.2f}%\nSparsity: {:.2f}%".format(
        epoch+1, retraining_epochs, acc, pruned_acc, sparsity)
    plt.title(title_text)
    grid.ax_marg_y.set_xscale("log")
    grid.set_axis_labels("Pretrained", "Retrained")
    grid.ax_marg_x.set_xlim(-1.2, 1.2)
    grid.ax_marg_y.set_ylim(-1.2, 1.2)

    if savefile == "":
        plt.show()
    else:
        plt.savefig("./figs/jp_{}_{}.png".format(savefile, epoch+1), bbox_inches='tight')
        plt.close()

def evaluate(model, testloader):
    """
    Return the top-1 accuracy of ``model`` over ``testloader`` as a percentage (0-100).

    model      -- module mapping a batch of images to a (batch, classes) score tensor.
    testloader -- iterable yielding ``(images, labels)`` batches.

    Inputs are moved to the device of the model's parameters, so a CPU model is
    evaluated on the CPU even when CUDA is available. The model is switched to
    eval mode for the pass and its previous train/eval mode is restored afterwards
    (also when an exception interrupts the pass).

    Raises ZeroDivisionError when ``testloader`` yields no samples.
    """
    was_training = model.training

    # Follow the model's own device; fall back to the "CUDA if present" rule only
    # for models that have no parameters to inspect.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    correct = 0
    total = 0
    model.eval()
    try:
        with torch.no_grad():
            for (images, labels) in testloader:
                outputs = model(images.to(device))
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                # Compare on the CPU so labels never need to be moved.
                correct += (predicted.cpu() == labels.cpu()).sum().item()
    finally:
        model.train(was_training)
    return 100. * correct / total

 


def retrain(model, gmp, dataloader, testloader, optimizer, criterion, epochs=10, writer=None, scheduler=None, savefile="lenet"):
    """
    Retrain ``model`` for ``epochs`` epochs; after every epoch evaluate it, prune a
    copy, save diagnostic plots and print a one-line summary.

    model      -- network trained in place; it is left in eval mode on return.
    gmp        -- mixture prior object; ``gmp.means`` and ``gmp.gammas`` are read for
                  plotting and the object is handed to ``sws_prune_copy``.
    dataloader -- iterable of batches ``(inputs, *extras)``. The model receives
                  ``inputs.float()`` and ``criterion`` receives ``(outputs, *extras)``.
    testloader -- passed to ``evaluate`` for the retrained and the pruned model.
    optimizer  -- stepped once per batch.
    criterion  -- callable returning the loss tensor to back-propagate.
    epochs     -- number of passes over ``dataloader``.
    writer     -- optional object exposing ``add_scalar(tag, value, step)``
                  (presumably a TensorBoard SummaryWriter -- confirm); receives the
                  accuracy and the mean training loss of each epoch.
    scheduler  -- optional learning-rate scheduler, stepped once per epoch after that
                  epoch's optimizer updates.
    savefile   -- file-name prefix forwarded to ``joint_plot`` and
                  ``show_sws_weights_log``; the default keeps the original "lenet".
    """
    # Move batches to wherever the model lives; only parameter-less models fall back
    # to the "CUDA if present" rule.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    # Snapshot of the starting weights, used as the "pretrained" axis of the joint plot.
    model_orig = copy.deepcopy(model)
    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        for data in dataloader:
            # data = (inputs, targets, teacher_scores(optional))
            data = tuple(x.to(device) for x in data)

            optimizer.zero_grad()
            outputs = model(data[0].float())

            loss = criterion(outputs, *data[1:])
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Step the scheduler only after this epoch's optimizer updates; stepping it
        # first skips the initial learning-rate value on PyTorch >= 1.1.
        if scheduler:
            scheduler.step()

        acc = evaluate(model, testloader)
        pruned_model = sws_prune_copy(model, gmp, 'l2')
        pruned_acc = evaluate(pruned_model, testloader)
#         resp_pruned = sws_prune_copy(model, gmp, 'p')
#         resp_pruned_acc = evaluate(resp_pruned, testloader)
        sparsity = get_sparsity(pruned_model)
#         resp_sparsity = get_sparsity(resp_pruned)

        joint_plot(model, model_orig, gmp, epoch, epochs, acc, pruned_acc, sparsity, savefile = savefile)
        show_sws_weights_log(model = model,
                             means = list(gmp.means.data.clone().cpu()),
                             precisions = list(gmp.gammas.data.clone().cpu()),
                             epoch=epoch,
                             accuracy=acc,
                             savefile=savefile)
        print("Epoch {} accuracy = {:.2f}% pruned_accuracy = {:.2f}% sparsity = {:.2f}%".format(epoch + 1,
                                                                               acc,
                                                                               pruned_acc,
                                                                               sparsity,))
        # Release the pruned copy before the next epoch creates another one.
        pruned_model.cpu()
        del pruned_model
        if writer:
            writer.add_scalar('accuracy', acc, epoch)
            writer.add_scalar('training loss', running_loss/len(dataloader), epoch)

    model.eval()

### Normal Training

In [5]:
# torch.manual_seed(0)
# model = LeNet_300_100().cuda()

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-3)

# train(model, trainloader, testloader, optimizer, criterion, epochs=30, writer=None)
# torch.save(model.state_dict(), "./models/mnist_lenet.pt")

In [6]:
# Pretrained LeNet-300-100 restored from the saved checkpoint. The .cuda() call is
# unconditional, so this cell needs a CUDA device.
# NOTE(review): orig_model is only referenced from commented-out code in the visible
# part of the notebook; the retraining cell loads its own copy of the checkpoint.
orig_model = LeNet_300_100().cuda()
orig_model.load_state_dict(torch.load("./models/mnist_lenet.pt"))

### Retrain with Soft-weight Sharing

In [7]:
# logits = extract_logits(orig_model, trainloader).cpu()
# kdtrain = torch.utils.data.TensorDataset(trainset.train_data, trainset.train_labels, logits)
# kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)

# opt = torch.optim.Adam(params=model.parameters(), lr=5e-4)

# class kd_mse_loss:
    
#     def __init__(self, temperature, alpha, criterion=nn.CrossEntropyLoss()):
#         self.temperature = temperature
#         self.alpha = alpha
#         self.criterion = criterion
#     '''
#     Calculate the mse loss between logits_S and logits_T
#     :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
#     :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
#     :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
#     '''
    
#     def __call__(self, logits_S, labels, logits_T):
#         if isinstance(self.temperature, torch.Tensor) and self.temperature.dim() > 0:
#             self.temperature = self.temperature.unsqueeze(-1)
#         beta_logits_T = logits_T / self.temperature
#         beta_logits_S = logits_S / self.temperature
#         kd_loss = F.mse_loss(beta_logits_S, beta_logits_T)
#         label_loss = self.criterion(logits_S, labels)
#         loss = (1.-self.alpha) * label_loss + self.alpha * kd_loss
#         print(kd_loss)
#         return loss

# train(model, kdloader, testloader, opt, kd_ce_loss(temperature=4, alpha=0.8), epochs=30, writer=None)

In [None]:
# Fix the torch seed, then restore the pretrained checkpoint into the model that
# will be retrained (requires a CUDA device).
torch.manual_seed(0)
model = LeNet_300_100().cuda()
model.load_state_dict(torch.load("./models/mnist_lenet.pt"))
# Logits of the pretrained model on the training set. They are paired by position
# with the dataset's stored tensors below, which relies on trainloader iterating in
# dataset order (it is built with shuffle=False).
logits = extract_logits(model, trainloader).cpu()
# NOTE(review): train_data / train_labels are the dataset's stored tensors, so the
# `transform` given to the dataset is presumably not applied to them -- confirm the
# input scale matches what the model was trained on before using kdloader.
kdtrain = torch.utils.data.TensorDataset(trainset.train_data, trainset.train_labels, logits)
# NOTE(review): kdloader is not used by the retrain(...) call later in this cell,
# which is given trainloader instead.
kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)


def get_ab(mean, var):
    """
    Convert a (mean, variance) pair into an (alpha, beta) pair.

    beta is ``mean / var`` and alpha is ``mean * beta``; these equal the shape and
    rate of a Gamma distribution that has the given mean and variance.
    """
    rate = mean / var
    shape = mean * rate
    return shape, rate

class SWSLoss:
    """
    Task loss plus a weighted prior penalty.

    Calling the object returns ``criterion(outputs, labels) + tau * gmp.call()``.
    """

    def __init__(self, criterion, gmp, tau):
        # criterion: callable(outputs, labels); gmp: object exposing call();
        # tau: multiplier applied to the value returned by gmp.call().
        self.criterion, self.gmp, self.tau = criterion, gmp, tau

    def __call__(self, outputs, labels):
        data_term = self.criterion(outputs, labels)
        prior_term = self.gmp.call()
        # tau stays constant between calls; a per-call decay (tau * 0.95) is left disabled.
        return data_term + self.tau * prior_term

# Mixture prior over the network weights. GaussianMixturePrior comes from src.sws;
# the meaning of its arguments is inferred from their names here -- confirm against
# that module. The two get_ab(...) calls convert (mean, variance) pairs into
# (alpha, beta) pairs: one for the zero component and one for the other components.
# Values in trailing parenthesised comments look like alternative settings -- TODO confirm.
n_mixtures = 16
zero_mixing_proportion = 0.99
gmp = GaussianMixturePrior(nb_components = n_mixtures, 
                           network_weights = [x for x in model.parameters()], 
                           pi_zero = zero_mixing_proportion, 
                           zero_ab = get_ab(2500, 1250),  #(10000, 100)
                           ab = get_ab(100, 10), #(1000, 10)
                           means = [],
                           scaling = False)

# Three parameter groups with separate learning rates: the network weights, the
# mixture means, and the mixture gammas/rhos.
optimizable_params = [
    {'params': model.parameters(), 'lr': 2e-4}, #(5e-3)
    {'params': [gmp.means], 'lr': 5e-5}, #(5e-6)
    {'params': [gmp.gammas, gmp.rhos], 'lr': 3e-3} #(5e-4)
]

# Cross-entropy plus tau * prior term; the loop trains on trainloader (not kdloader)
# and writes per-epoch figures under ./figs/.
opt = torch.optim.Adam(optimizable_params)
criterion = SWSLoss(nn.CrossEntropyLoss(), gmp, tau=1e-6) #(1e-6)
retrain(model, gmp, trainloader, testloader, opt, criterion, epochs=100, writer=None)

0-component Mean: 2500.0 Variance: 1250.0
Non-zero component Mean: 100.0 Variance: 10.0


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

Layer Loss: 289535.000
Layer Loss: 251.706
Layer Loss: 34082.000
Layer Loss: 150.885
Layer Loss: 811.361
Layer Loss: 22.138
0-neglogprop Loss: -13828.171
Remaining-neglogprop Loss: -39147.242




Epoch 1 accuracy = 98.11% pruned_accuracy = 97.50% sparsity = 26.86%
Epoch 2 accuracy = 98.13% pruned_accuracy = 97.60% sparsity = 33.43%
Epoch 3 accuracy = 98.17% pruned_accuracy = 97.61% sparsity = 37.32%
Epoch 4 accuracy = 98.20% pruned_accuracy = 97.66% sparsity = 40.27%
Epoch 5 accuracy = 98.20% pruned_accuracy = 97.76% sparsity = 43.08%
Epoch 6 accuracy = 98.25% pruned_accuracy = 97.70% sparsity = 45.61%
Epoch 7 accuracy = 98.28% pruned_accuracy = 97.79% sparsity = 47.94%
Epoch 8 accuracy = 98.28% pruned_accuracy = 97.78% sparsity = 50.01%
Epoch 9 accuracy = 98.29% pruned_accuracy = 97.84% sparsity = 51.97%
Epoch 10 accuracy = 98.34% pruned_accuracy = 97.87% sparsity = 53.69%
Epoch 11 accuracy = 98.30% pruned_accuracy = 97.93% sparsity = 55.30%
Epoch 12 accuracy = 98.31% pruned_accuracy = 98.00% sparsity = 56.83%
Epoch 13 accuracy = 98.30% pruned_accuracy = 98.05% sparsity = 58.34%
Epoch 14 accuracy = 98.30% pruned_accuracy = 97.99% sparsity = 59.79%
Epoch 15 accuracy = 98.26% pr

In [None]:
# Prune the retrained model and save the histogram of its weights. A dedicated file
# prefix is used: with savefile="lenet" and epoch=30 the image would be written to
# ./figs/lenet_31.png, replacing one of the per-epoch frames that the GIF below reads.
pruned_model = sws_prune_copy(model, gmp, 'l2')
show_sws_weights_log(model = pruned_model,
                     means = list(gmp.means.data.clone().cpu()),
                     precisions = list(gmp.gammas.data.clone().cpu()),
                     epoch=30,
                     accuracy=evaluate(pruned_model, testloader),
                     savefile="lenet_pruned")

get_sparsity(pruned_model)

import imageio


def _frames_to_gif(frame_pattern, gif_path, n_frames=100):
    """Read frame_pattern.format(1) .. frame_pattern.format(n_frames) and write them as one GIF."""
    frames = [imageio.imread(frame_pattern.format(idx)) for idx in range(1, n_frames + 1)]
    imageio.mimsave(gif_path, frames)


# One frame per retraining epoch (100 epochs were run above).
_frames_to_gif("./figs/lenet_{}.png", './figs/lenet_sws_weights.gif')
_frames_to_gif("./figs/jp_lenet_{}.png", './figs/lenet_jp.gif')

![Joint plot of pretrained vs. retrained weights](../figs/lenet_jp.gif "SWS")

![Weight histogram during soft-weight-sharing retraining](../figs/lenet_sws_weights.gif "SWS")
