In [1]:
# importing necessary modules

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms


import numpy as np

#### Downloading MNIST and building the data loaders

In [2]:
# Both splits share the same preprocessing: tensor conversion followed by
# normalisation (0.1307 / 0.3081 are presumably the MNIST pixel mean / std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=64, shuffle=True)

# Evaluation only sums per-sample loss / correct counts, so batch order is
# irrelevant; a fixed order avoids needless shuffling and keeps runs deterministic.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=mnist_transform),
    batch_size=1000, shuffle=False)

#### Basic fully-connected MNIST architecture with helper functions

In [3]:
# helper function for model evaluation
# helper function for model evaluation
def eval(model, test_loader):
    """Evaluate a classifier on ``test_loader`` and print average loss and accuracy.

    The model is switched to eval mode via ``model.eval()`` and is left in that
    mode; callers that continue training must call ``model.train()`` themselves.

    Args:
        model: module mapping a batch of inputs to unnormalised class scores.
        test_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` (used for the sample count).

    Returns:
        Tuple ``(test_loss, accuracy)``: mean per-sample negative log-likelihood
        and the fraction of correctly classified samples.

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = F.log_softmax(model(data), dim=1)
            # Sum (not mean) per batch so dividing by the dataset size below gives a
            # correct per-sample average even when the last batch is smaller.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError('test_loader.dataset is empty; nothing to evaluate')
    test_loss /= num_samples
    accuracy = correct / num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, num_samples,
        100. * accuracy))
    return test_loss, accuracy

In [4]:
class MnistNet(nn.Module):
    """Baseline two-layer fully-connected classifier.

    Each sample is flattened, passed through ``fc1`` + ReLU and then ``fc2``,
    which returns unnormalised class scores (logits). The defaults reproduce
    the original 784 -> 256 -> 10 MNIST architecture, so saved state dicts
    (``fc1.weight``, ``fc1.bias``, ``fc2.weight``, ``fc2.bias``) stay compatible.

    Args:
        in_features: number of input features after flattening.
        hidden_size: width of the hidden layer.
        num_classes: number of output classes.
    """

    def __init__(self, in_features=784, hidden_size=256, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Flatten everything but the batch dimension, e.g. (B, 1, 28, 28) -> (B, 784).
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [5]:
# training baseline model
from tqdm import trange

# Seed once so weight initialisation and the training-set shuffle order are reproducible.
torch.manual_seed(0)

net = MnistNet()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

num_epoch = 10
for epoch in trange(num_epoch):
    # The eval helper leaves the model in eval mode, so switch to training mode explicitly.
    net.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = F.log_softmax(net(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

100%|██████████| 10/10 [01:41<00:00, 10.23s/it]


In [6]:
# Accuracy of the dense baseline on the held-out test split.
eval(model=net, test_loader=test_loader)


Test set: Average loss: 0.0999, Accuracy: 9773/10000 (98%)


In [7]:
# Persist the baseline weights; the SVD experiments reload them from this file.
torch.save(obj=net.state_dict(), f='baseline_mnist.model')

## SVD

#### Truncated SVD of `fc1`, no finetuning

In [8]:
# Replace fc1's weight with its best rank-k approximation (truncated SVD); no retraining.
net_full_svd = MnistNet()
net_full_svd.load_state_dict(torch.load('baseline_mnist.model'))

k = 25
# state_dict() tensors share storage with the model, so copy_ below updates it in place.
params = net_full_svd.state_dict()
for name in ['fc1.weight']:
    weight = params[name]
    u, s, vT = torch.linalg.svd(weight, full_matrices=False)
    # W_k = U_k diag(s_k) V_k^T. Scaling U's first k columns by s_k avoids building a
    # dense diagonal full of zeroed singular values, and works for tensors on any device.
    weight.copy_((u[:, :k] * s[:k]) @ vT[:k])

In [9]:
eval(net_full_svd, test_loader)

# fc1 maps 784 -> 256 features. The rank-k factorisation stores V_k^T (k x 784),
# Sigma_k (counted as a dense k x k matrix, i.e. a conservative estimate) and
# U_k (256 x k). Biases are not counted.
num_params = 784 * 256
new_num_params = k * (784 + k + 256)
print('Compression rate: x{:.4f}'.format(num_params / new_num_params))


Test set: Average loss: 0.1509, Accuracy: 9628/10000 (96%)
Compression rate: x7.5382


#### Truncated SVD of `fc1`, finetuning only the last layer

In [10]:
# Rank-k truncated SVD of fc1, later followed by finetuning of the last layer.
net_svd = MnistNet()
net_svd.load_state_dict(torch.load('baseline_mnist.model'))

k = 15
# state_dict() tensors share storage with the model, so copy_ below updates it in place.
params = net_svd.state_dict()
for name in ['fc1.weight']:
    weight = params[name]
    u, s, vT = torch.linalg.svd(weight, full_matrices=False)
    # W_k = U_k diag(s_k) V_k^T. Scaling U's first k columns by s_k avoids building a
    # dense diagonal full of zeroed singular values, and works for tensors on any device.
    weight.copy_((u[:, :k] * s[:k]) @ vT[:k])

In [11]:
num_epoch = 5

# Freeze fc1 explicitly. Otherwise autograd still computes gradients for its
# weights, and because the optimizer only zeroes fc2's gradients, fc1's keep
# accumulating across steps.
for p in net_svd.fc1.parameters():
    p.requires_grad_(False)
# Optimise the last layer by name rather than by slicing the parameter list.
optimizer = optim.Adam(net_svd.fc2.parameters(), lr=1e-3)

for epoch in trange(num_epoch):
    net_svd.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = F.log_softmax(net_svd(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

100%|██████████| 5/5 [00:40<00:00,  8.04s/it]


In [12]:
eval(net_svd, test_loader)

# fc1 maps 784 -> 256 features. The rank-k factorisation stores V_k^T (k x 784),
# Sigma_k (counted as a dense k x k matrix, i.e. a conservative estimate) and
# U_k (256 x k). Biases are not counted.
num_params = 784 * 256
new_num_params = k * (784 + k + 256)
print('Compression rate: x{:.4f}'.format(num_params / new_num_params))


Test set: Average loss: 0.1244, Accuracy: 9645/10000 (96%)
Compression rate: x12.6827


## Sparse Variational Dropout

#### Linear layer with sparse variational dropout

In [13]:
class varLinear(nn.Module):
    """Fully-connected layer with sparse variational dropout (Molchanov et al., 2017).

    Each weight ``theta`` has its own Gaussian noise with log standard deviation
    ``logstd``; ``log_alpha = log(sigma^2 / theta^2)`` (clamped to [-10, 10])
    drives the KL penalty and, in eval mode, pruning.

    Args:
        shape: ``(in_features, out_features)``. The weight matrix itself is stored
            as ``(out_features, in_features)``, as ``F.linear`` expects.
        prune_val: in eval mode, weights with ``log_alpha >= prune_val`` are zeroed.
    """

    def __init__(self, shape, prune_val):
        super(varLinear, self).__init__()

        in_features, out_features = shape[0], shape[1]
        self.weight = nn.Parameter((0.02) ** 0.5 * torch.randn(out_features, in_features))
        self.logstd = nn.Parameter(-5.0 * torch.ones(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(1, out_features))

        self.prune_val = prune_val
        # Train/eval mode is tracked by nn.Module's own ``training`` flag (.train()/.eval()).

    def _compute_log_alpha(self):
        """Return clamp(log(sigma^2 / theta^2), -10, 10) for the current parameters."""
        log_alpha = self.logstd * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(self.weight))
        return torch.clamp(log_alpha, -10, 10)

    def forward(self, x):
        # Cached on the module because external code (the compression-rate cell)
        # reads ``log_alpha`` after a forward pass.
        self.log_alpha = self._compute_log_alpha()

        if self.training:
            # Local reparameterisation trick: sample pre-activations, not weights.
            lrt_mean = F.linear(x, self.weight) + self.bias
            lrt_var = F.linear(x * x, torch.exp(self.logstd * 2.0))
            # The epsilon must be added to the variance itself. If a row of x is
            # all zeros (possible after ReLU), the variance is 0, the gradient of
            # sqrt at 0 is infinite, and backward produces NaNs.
            lrt_std = torch.sqrt(lrt_var + 1e-8)
            eps = torch.randn_like(lrt_std)
            return lrt_mean + lrt_std * eps

        pruned = self.weight * (self.log_alpha < self.prune_val).float()
        return F.linear(x, pruned) + self.bias

    def kl(self):
        """Approximate KL divergence (Molchanov et al.) summed over all weights, up to a constant.

        Computed from the current parameters, so it is valid even before the
        first forward pass.
        """
        # Plain Python floats (rather than CPU tensors) broadcast on any device.
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        log_alpha = self._compute_log_alpha()
        neg_kl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha))
        return -neg_kl.sum()

#### MLP with linear variational layer

In [14]:
class MnistNetVar(nn.Module):
    """Two-layer MNIST MLP (784 -> 256 -> 10) built from ``varLinear`` layers.

    Args:
        prune_val: log-alpha pruning threshold, forwarded to both layers and also
            stored on the network so callers can reuse it when counting kept weights.
    """

    def __init__(self, prune_val):
        super().__init__()
        self.prune_val = prune_val
        self.fc1 = varLinear((784, 256), prune_val)
        self.fc2 = varLinear((256, 10), prune_val)

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [15]:
# Threshold on log(alpha): in eval mode, weights whose log(alpha) is not below 3.0 are zeroed.
varnet = MnistNetVar(3.0)

In [16]:
optimizer = optim.Adam(params=varnet.parameters(), lr=1e-3)
# Multiply the learning rate by 0.2 at each milestone epoch.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 60, 70, 80],
    gamma=0.2,
)

In [17]:
kl_weight_step = 0.07  # KL warm-up increment per epoch; the weight is capped at 1
epochs = 100

for epoch in range(1, epochs + 1):
    varnet.train()
    # Linear KL warm-up: 0.07, 0.14, ... and 1.0 from epoch 15 onwards. The original
    # incremented before first use, so its initial value 0.07 was never applied.
    kl_weight = min(kl_weight_step * epoch, 1)

    for data, target in train_loader:
        optimizer.zero_grad()

        output = F.log_softmax(varnet(data), dim=1)
        # Negative ELBO: the data term is scaled to the full dataset size, and the
        # KL term is summed over every variational layer of the network.
        nll = F.nll_loss(output, target) * len(train_loader.dataset)
        kl = sum(module.kl() for module in varnet.children())
        loss = nll + kl_weight * kl

        loss.backward()
        optimizer.step()

    # PyTorch expects scheduler.step() after optimizer.step(). Stepping at the start
    # of the epoch skips the initial learning rate and shifts every milestone by one.
    scheduler.step()

In [18]:
eval(varnet, test_loader)

# Count weights (biases excluded) before and after pruning. Each layer caches
# ``log_alpha`` during its forward pass, so the values reflect the eval() call above.
num_params = 0
new_num_params = 0
for layer in varnet.children():
    keep_mask = layer.log_alpha.detach().cpu() < varnet.prune_val
    num_params += layer.weight.numel()
    new_num_params += int(keep_mask.sum())

print('Compression rate: x{:.4f}'.format(1. * num_params / new_num_params))


Test set: Average loss: 0.0584, Accuracy: 9841/10000 (98%)
Compression rate: x23.3933
