# Sparse Variational Dropout

Note: the images below contain typos — in a few places $\alpha_{i,j}w_{i,j}$ should read $\alpha_{i,j}\mu_{i,j}$.

![alt text](https://ars-ashuha.github.io/images/ss_vd1.png)
![alt text](https://ars-ashuha.github.io/images/ss_vd2.png)

- Variational Dropout Sparsifies Deep Neural Networks https://arxiv.org/abs/1701.05369
- Cheating link https://github.com/ars-ashuha/sparse-vd-pytorch/blob/master/svdo-solution.ipynb

# Install 

In [None]:
# Logger
# If this import fails, check that you are running Python 3
# and that the logger.py file has been downloaded
# into the same folder as this notebook.
from logger import Logger

# Implementation

In [None]:
import torch
import numpy as np

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from logger import Logger
from torch.nn import Parameter
from torchvision import datasets, transforms

In [None]:
# Load a dataset
def get_mnist(batch_size):
    """Build shuffled MNIST train and test data loaders.

    Images are converted to tensors and normalized with mean 0.1307 and
    std 0.3081. The dataset is read from '../day2_vae/mnist' with
    ``download=False``, so the files must already be present there.

    Args:
        batch_size: number of samples per batch, used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
        Both loaders shuffle their data.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def make_loader(is_train):
        # Both splits share the same preprocessing pipeline and batch size.
        split = datasets.MNIST('../day2_vae/mnist', train=is_train,
                               download=False, transform=preprocess)
        return torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=True)

    return make_loader(True), make_loader(False)

In [None]:
class LinearSVDO(nn.Module):
    """Fully-connected layer with Sparse Variational Dropout.

    Each weight has a learned mean ``W`` and a learned log standard deviation
    ``log_sigma``. ``log_alpha = log(sigma^2 / W^2)`` (clamped to [-10, 10])
    measures how noisy a weight is; at evaluation time weights whose
    ``log_alpha`` is not below ``threshold`` are zeroed out.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        threshold: weights with ``log_alpha >= threshold`` are pruned in
            evaluation mode.
        bias: if ``False`` the layer has no additive bias (``self.bias`` is
            ``None``). Defaults to ``True``.

    Attributes set at runtime:
        log_alpha: the most recently computed clamped log-alpha tensor. It is
            refreshed by ``kl_reg()`` and by ``forward()`` in evaluation mode,
            and does not exist before the first such call.
    """

    def __init__(self, in_features, out_features, threshold, bias=True):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold

        # Allocated uninitialized; reset_parameters() fills them in below.
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(1, out_features))
        else:
            # Registering None keeps `self.bias` a valid attribute.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize bias to 0, W ~ N(0, 0.02) and log_sigma to -5."""
        if self.bias is not None:
            self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def _update_log_alpha(self):
        """Recompute, cache on ``self.log_alpha`` and return clamped log-alpha."""
        # log(sigma^2 / W^2); the 1e-16 guards against log(0) for zero weights.
        log_alpha = 2 * (self.log_sigma - torch.log(1e-16 + torch.abs(self.W)))
        # Clip log alpha to [-10, 10] for numerical stability.
        self.log_alpha = torch.clamp(log_alpha, -10, 10)
        return self.log_alpha

    def forward(self, x):
        """Apply the layer to ``x``.

        In training mode the output is sampled: noise is drawn per activation
        from a Gaussian with the activation's mean and standard deviation
        (the local reparameterization trick). In evaluation mode the output
        is deterministic and uses only the weights that survive pruning.
        """
        if self.training:
            # Activation mean: x.dot(W^T) (+ b).
            lrt_mean = F.linear(x, self.W)
            if self.bias is not None:
                lrt_mean = lrt_mean + self.bias
            # Activation std: sqrt((x*x).dot(sigma^2) + 1e-8).
            lrt_std = torch.sqrt(
                F.linear(x * x, torch.exp(2 * self.log_sigma)) + 1e-8)

            eps = torch.randn_like(lrt_mean)
            return lrt_mean + lrt_std * eps

        log_alpha = self._update_log_alpha()
        # Prune redundant weights using the layer's own threshold.
        keep_mask = (log_alpha < self.threshold).to(self.W.dtype)
        out = F.linear(x, self.W * keep_mask)
        if self.bias is not None:
            out = out + self.bias
        return out

    def kl_reg(self):
        """Return the summed KL regularizer of the layer as a 0-dim tensor.

        Uses the sigmoid-based approximation of the KL term from the Sparse
        Variational Dropout paper linked at the top of this notebook. The
        paper's additive constant is omitted here; it does not affect
        gradients.
        """
        log_alpha = self._update_log_alpha()
        # Plain Python floats so the result follows the parameters' device.
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        neg_kl = (k1 * torch.sigmoid(k2 + k3 * log_alpha)
                  - 0.5 * torch.log1p(torch.exp(-log_alpha)))
        return -neg_kl.sum()


In [None]:
# Define a simple 2 layer Network
class Net(nn.Module):
    """Two-layer network of LinearSVDO layers: 784 -> 300 -> 10.

    Args:
        threshold: pruning threshold passed to both layers and kept on the
            model as ``self.threshold``.
    """

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold
        # fc1 is registered before fc2, so model.children() yields them in
        # that order.
        self.fc1 = LinearSVDO(28*28, 300, threshold)
        self.fc2 = LinearSVDO(300, 10, threshold)

    def forward(self, x):
        """Return log-probabilities over the 10 outputs (softmax over dim 1)."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [None]:
# Define a new Loss Function -- SGVLB 
class SGVLB(nn.Module):
    """Stochastic Gradient Variational Lower Bound loss.

    The loss is ``cross_entropy * train_size + kl_weight * KL``, where KL is
    the sum of ``kl_reg()`` over every direct child of ``net`` that defines
    that method.

    Args:
        net: the model whose direct children provide ``kl_reg()``.
        train_size: number of training samples; scales the mean
            cross-entropy up to the whole dataset.
    """

    def __init__(self, net, train_size):
        super(SGVLB, self).__init__()
        self.train_size = train_size
        self.net = net

    def forward(self, input, target, kl_weight=1.0):
        """Compute the loss for one batch.

        Args:
            input: network outputs for the batch, as accepted by
                ``F.cross_entropy``.
            target: class indices; must not require gradients.
            kl_weight: multiplier on the KL term (used for annealing).

        Returns:
            A tensor of shape ``(1,)`` holding the loss.
        """
        assert not target.requires_grad
        # Accumulate on the same device as the network output, so the loss is
        # not pinned to the CPU.
        kl = torch.zeros(1, dtype=torch.float32, device=input.device)
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()

        # Data term rescaled from a batch mean to the full training set.
        data_term = F.cross_entropy(input, target) * self.train_size
        return data_term + kl_weight * kl

In [None]:
# Model, optimizer and a step schedule that multiplies the learning rate
# by 0.2 at each listed milestone.
model = Net(threshold=3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 60, 70, 80],
    gamma=0.2,
)

# Format spec for each logged column, handed to the Logger.
fmt = {
    'tr_los': '3.1e',
    'te_loss': '3.1e',
    'sp_0': '.3f',
    'sp_1': '.3f',
    'lr': '3.1e',
    'kl': '.2f',
}
logger = Logger('sparse_vd', fmt=fmt)

# Data loaders, and the loss scaled by the number of training samples.
train_loader, test_loader = get_mnist(batch_size=100)
sgvlb = SGVLB(model, len(train_loader.dataset))

In [None]:
kl_weight = 0.02
epochs = 100

for epoch in range(1, epochs + 1):
    model.train()
    train_loss, train_acc = 0, 0
    # KL annealing: the weight grows by 0.02 per epoch and is capped at 1
    # (so the first epoch already runs with 0.04).
    kl_weight = min(kl_weight+0.02, 1)
    logger.add_scalar(epoch, 'kl', kl_weight)
    # Read the learning rate actually in use from the optimizer rather than
    # scheduler.get_lr(), which is not meant to be called outside step().
    logger.add_scalar(epoch, 'lr', optimizer.param_groups[0]['lr'])
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        optimizer.zero_grad()

        output = model(data)
        pred = output.argmax(dim=1)
        loss = sgvlb(output, target, kl_weight)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (pred == target).sum().item()

    logger.add_scalar(epoch, 'tr_los', train_loss / len(train_loader.dataset))
    logger.add_scalar(epoch, 'tr_acc', train_acc / len(train_loader.dataset) * 100)

    # Evaluation: the layers switch to their deterministic, pruned forward
    # pass; no gradients are needed, so no autograd graph is built.
    model.eval()
    test_loss, test_acc = 0, 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.view(-1, 28*28)
            output = model(data)
            test_loss += sgvlb(output, target, kl_weight).item()
            pred = output.argmax(dim=1)
            test_acc += (pred == target).sum().item()

    logger.add_scalar(epoch, 'te_loss', test_loss / len(test_loader.dataset))
    logger.add_scalar(epoch, 'te_acc', test_acc / len(test_loader.dataset) * 100)

    # Per-layer sparsity: fraction of weights whose log_alpha exceeds the
    # pruning threshold. log_alpha was refreshed by the eval forward pass.
    for i, c in enumerate(model.children()):
        if hasattr(c, 'kl_reg'):
            logger.add_scalar(epoch, 'sp_%s' % i, (c.log_alpha.data.numpy() > model.threshold).mean())

    # Advance the schedule once per epoch, after the optimizer has stepped
    # (the call order expected by PyTorch >= 1.1).
    scheduler.step()

    logger.iter_info()

In [None]:
# Count, over all layers, the total number of weights and the number whose
# log_alpha lies below the pruning threshold (i.e. the weights that are kept).
all_w, kep_w = 0, 0

for c in model.children():
    layer_log_alpha = c.log_alpha.data.numpy()
    all_w += layer_log_alpha.size
    kep_w += (layer_log_alpha < model.threshold).sum()

# The printed value is total / kept.
print('kept weight ratio = ', all_w/kep_w)

    # Good result should be like 
    #   epoch    kl       lr    tr_los    tr_acc    te_loss    te_acc    sp_0    sp_1
    #  -------  ----  -------  --------  --------  ---------  --------  ------  ------
    #      100     1  1.6e-06  -1.4e+03      98.0   -1.4e+03      98.3   0.969   0.760
    # kept weight ratio = 30.109973454683352

# Visualization

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

from matplotlib import rcParams
rcParams['figure.figsize'] = 16, 3
rcParams['figure.dpi'] = 300


log_alpha = (model.fc1.log_alpha.detach().numpy() < 3).astype(np.float)
W = model.fc1.W.detach().numpy()

plt.imshow(log_alpha * W, cmap='hot', interpolation=None)
plt.colorbar()

In [None]:
from matplotlib import rcParams
rcParams['figure.figsize'] = 8, 5

# Absolute values of the masked first-layer weights, computed once instead of
# once per tile.
masked_W = np.abs(log_alpha * W)

# 15 x 15 grid of 28 x 28 tiles, one tile per hidden unit.
z = np.zeros((28*15, 28*15))

for i in range(15):
    for j in range(15):
        # Row-major unit index, starting at unit 0 so the first neuron is
        # not skipped.
        s = i * 15 + j
        z[i*28:(i+1)*28, j*28:(j+1)*28] = masked_W[s].reshape(28, 28)

plt.imshow(z, cmap='hot_r')
plt.colorbar()
plt.axis('off')

# Compression with Sparse Matrices

In [None]:
import scipy
import numpy as np
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, dok_matrix

# Dense weights and log-alpha of the first layer.
M = list(model.children())[0].W.data.numpy()
LA = list(model.children())[0].log_alpha.data.numpy()

# Coordinates (in row-major order) and values of the weights that survive
# pruning, i.e. those whose log_alpha is below the threshold.
row, col = np.nonzero(LA < model.threshold)
data = M[row, col]

# The same surviving weights in three sparse storage formats.
Mcsr = csr_matrix((data, (row, col)), shape=M.shape)
Mcsc = csc_matrix((data, (row, col)), shape=M.shape)
Mcoo = coo_matrix((data, (row, col)), shape=M.shape)

In [None]:
# Dense weight matrix, saved as a compressed .npz archive.
np.savez_compressed('M_w', M)

# The three sparse variants, each saved to its own .npz file.
for stem, sparse_mat in (('Mcsr_w', Mcsr), ('Mcsc_w', Mcsc), ('Mcoo_w', Mcoo)):
    scipy.sparse.save_npz(stem, sparse_mat)

In [None]:
# Shell listing (with human-readable sizes) of the files whose names contain
# "_w", to compare the size of the dense and sparse weight files on disk.
ls -lah | grep _w