In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from torchvision import datasets, transforms
import itertools

from tqdm import tqdm_notebook
from torch.utils.data.sampler import SubsetRandomSampler
import math
import numpy as np

In [None]:
def build_dataset(dataset='MNIST', dataset_dir='./data', batch_size=100,total_training_size=100):
    """Create train/test DataLoaders for MNIST or CIFAR10.

    Args:
        dataset: 'MNIST' or 'CIFAR10' (any other key raises KeyError).
        dataset_dir: download/cache directory passed to torchvision.
        batch_size: batch size used by both loaders.
        total_training_size: only the first ``total_training_size`` training
            indices are used; they are visited in random order each epoch.

    Returns:
        (train_loader, test_loader). The test loader is not shuffled.
    """
    dataset_classes = {'MNIST': datasets.MNIST, 'CIFAR10': datasets.CIFAR10}
    dataset_cls = dataset_classes[dataset]

    # MNIST is only converted to tensors; CIFAR10 is additionally scaled to [-1, 1].
    preprocessing = {
        'MNIST': transforms.ToTensor(),
        'CIFAR10': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }
    transform = preprocessing[dataset]

    def load_split(is_train):
        # Both splits share the same root, transform and download flag.
        return dataset_cls(root=dataset_dir,
                           train=is_train,
                           transform=transform,
                           download=True)

    train_split = load_split(True)
    subset_indices = list(range(len(train_split)))[:total_training_size]
    train_loader = data.DataLoader(dataset=train_split,
                                   batch_size=batch_size,
                                   sampler=SubsetRandomSampler(subset_indices))

    test_split = load_split(False)
    test_loader = data.DataLoader(dataset=test_split,
                                  batch_size=batch_size,
                                  shuffle=False)

    return train_loader, test_loader

In [None]:
class GaussianDropout(nn.Module):
    """Multiplicative Gaussian noise layer ("Gaussian dropout").

    In training mode every activation is multiplied by independent noise
    ``e ~ N(1, alpha)`` (mean 1, *variance* ``alpha``). In eval mode the input
    is returned unchanged.

    Args:
        alpha: noise variance. ``dropout()`` in this notebook passes
            ``p / (1 - p)`` for a dropout rate ``p``.
    """

    def __init__(self, alpha=1.0):
        super(GaussianDropout, self).__init__()
        # 1-element tensor, kept in this form for compatibility with existing code.
        self.alpha = torch.Tensor([alpha])

    def forward(self, x):
        """Return ``x * e`` with ``e ~ N(1, alpha)`` in training mode, else ``x``."""
        # Read the mode flag. Calling self.train() here would switch the module
        # into training mode and always return a truthy module, so noise would
        # be applied even after .eval().
        if not self.training:
            return x
        # std = sqrt(variance). Move it to x's device/dtype so CUDA inputs work.
        std = self.alpha.sqrt().to(device=x.device, dtype=x.dtype)
        epsilon = torch.randn_like(x) * std + 1
        return x * epsilon

In [None]:
class VariationalDropout(nn.Module):
    """Multiplicative Gaussian noise with a learnable per-unit variance.

    ``log_alpha`` has shape ``(dim,)`` and is broadcast against the last axis
    of the input. In training mode the input is multiplied by
    ``e ~ N(1, alpha)`` (mean 1, variance ``alpha = exp(log_alpha)``). In eval
    mode the input is returned unchanged.

    Args:
        alpha: initial alpha for every unit. It is also stored as
            ``max_alpha`` and used as the clamp bound in ``forward``.
        dim: number of units (size of ``log_alpha``).
    """

    def __init__(self, alpha=1.0, dim=None):
        super(VariationalDropout, self).__init__()
        self.dim = dim
        self.max_alpha = alpha
        # Initial alpha, stored in log-space so it stays positive while learned.
        log_alpha = (torch.ones(dim) * alpha).log()
        self.log_alpha = nn.Parameter(log_alpha)

    def kl(self):
        """Mean over units of the polynomial KL approximation.

        Computes ``-(0.5*log_alpha + c1*alpha + c2*alpha**2 + c3*alpha**3)``
        per unit and averages it.
        """
        c1 = 1.16145124
        c2 = -1.50204118
        c3 = 0.58629921
        alpha = self.log_alpha.exp()

        negative_kl = 0.5 * self.log_alpha + c1 * alpha + c2 * alpha**2 + c3 * alpha**3

        kl = -negative_kl

        return kl.mean()

    def forward(self, x):
        """Return ``x * e`` with ``e ~ N(1, alpha)`` in training mode, else ``x``."""
        # Read the mode flag. Calling self.train() would switch the module into
        # training mode and always be truthy.
        if not self.training:
            return x

        # N(0, 1) noise on x's device/dtype.
        epsilon = torch.randn_like(x)

        # NOTE(review): this clamps log(alpha) against max_alpha, which is a
        # linear-scale value (the initial alpha). Kept as in the original;
        # confirm whether log(max_alpha) or 0 (alpha <= 1) was intended.
        with torch.no_grad():
            self.log_alpha.clamp_(max=self.max_alpha)
        alpha = self.log_alpha.exp()

        # Shift and scale to N(1, alpha): mean 1, std sqrt(alpha). Without
        # the +1 the noise would be zero-mean and destroy the signal.
        return x * (1 + alpha.sqrt() * epsilon)

In [None]:
def dropout(p=None, dim=None, method='standard'):
    """Build a dropout layer of the requested kind.

    Args:
        p: dropout rate. For 'gaussian' and 'variational' it is converted to
            the noise variance ``p / (1 - p)``.
        dim: number of units; only used by 'variational'.
        method: 'standard' (``nn.Dropout``), 'gaussian' or 'variational'.

    Returns:
        An ``nn.Module`` implementing the chosen dropout.

    Raises:
        ValueError: if ``method`` is not one of the supported names. Previously
            ``None`` was returned silently and failed later inside the model.
    """
    if method == 'standard':
        return nn.Dropout(p)
    if method == 'gaussian':
        return GaussianDropout(p/(1-p))
    if method == 'variational':
        return VariationalDropout(p/(1-p), dim)
    raise ValueError(
        "Unknown dropout method %r; expected 'standard', 'gaussian' or 'variational'"
        % (method,))

In [None]:
class Net(nn.Module):
    """3-layer fully-connected classifier with dropout after each hidden layer.

    ``self.net`` layout (the pruning cells index it positionally as
    ``net[0]``, ``net[1]``, ``net[3]``, ``net[4]``):
        0 Linear(image_dim, hidden_dim)
        1 dropout(dropout_rate, hidden_dim, dropout_method)
        2 ReLU
        3 Linear(hidden_dim, hidden_dim)
        4 dropout(dropout_rate, hidden_dim, dropout_method)
        5 ReLU
        6 Linear(hidden_dim, n_classes)

    Args:
        image_dim: size of the flattened input.
        dropout_method: forwarded to ``dropout()``.
        dropout_rate: forwarded to ``dropout()``.
        hidden_dim: width of both hidden layers (default 200, as before).
        n_classes: number of output logits (default 10, as before).
    """

    def __init__(self,
                 image_dim=28*28,
                 dropout_method='standard',dropout_rate=0.2,
                 hidden_dim=200, n_classes=10):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(image_dim, hidden_dim),
            dropout(dropout_rate, hidden_dim, dropout_method),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            dropout(dropout_rate, hidden_dim, dropout_method),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        )

    def kl(self):
        """Sum of ``kl()`` over all VariationalDropout submodules (0 if none)."""
        return sum(module.kl().sum()
                   for module in self.net.modules()
                   if isinstance(module, VariationalDropout))

    def forward(self, x):
        """Return logits for a batch of flattened inputs."""
        return self.net(x)

In [None]:
class Solver(object):
    """Build data loaders and a ``Net``, then train and evaluate it.

    Args:
        dropout_method: forwarded to ``Net`` ('standard', 'gaussian', 'variational').
        dataset: 'MNIST' or 'CIFAR10'.
        n_epochs: number of passes over the training loader in ``train()``.
        lr: Adam learning rate.
        dropout_rate: forwarded to ``Net``.
        batch_size: forwarded to ``build_dataset``.
        total_training_size: forwarded to ``build_dataset``.
        device: torch device for the model and batches. Defaults to CUDA when
            available, otherwise CPU (previously CUDA was required).
    """

    def __init__(self, dropout_method='standard', dataset='MNIST', n_epochs=50, lr=0.005,dropout_rate=0.2,batch_size=100,total_training_size=100, device=None):
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.dropout_method = dropout_method
        self.dropout_rate = dropout_rate
        self.total_training_size = total_training_size

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.train_loader, self.test_loader = build_dataset(
            dataset, './data', batch_size=batch_size,
            total_training_size=total_training_size)

        self.image_dim = {'MNIST': 28*28, 'CIFAR10': 3*32*32}[dataset]

        self.net = Net(
            image_dim=self.image_dim,
            dropout_method=dropout_method,dropout_rate=dropout_rate).to(self.device)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)

    def _run_epoch(self):
        """Run one optimisation pass over ``self.train_loader``.

        Returns:
            (mean classification loss, mean KL) averaged over *batches*.
            KL is 0.0 unless ``dropout_method == 'variational'``. Both are
            0.0 for an empty loader.
        """
        loss_sum = 0.0
        kl_sum = 0.0
        n_batches = 0
        for images, labels in self.train_loader:
            images = images.view(-1, self.image_dim).to(self.device)
            labels = labels.to(self.device)
            logits = self.net(images)

            loss = self.loss_fn(logits, labels)

            if self.dropout_method == 'variational':
                kl = self.net.kl()
                # The KL regulariser is down-weighted by a fixed factor of 10.
                objective = loss + kl / 10
                kl_sum += kl.item()
            else:
                objective = loss

            self.optimizer.zero_grad()
            objective.backward()
            self.optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        if n_batches == 0:
            return 0.0, 0.0
        # Average over the number of batches. The previous code divided by the
        # size of the last batch, which is not a meaningful normaliser.
        return loss_sum / n_batches, kl_sum / n_batches

    def train(self):
        """Train for ``n_epochs`` epochs, printing per-batch mean loss (and KL)."""
        self.net.train()

        for epoch_i in tqdm_notebook(range(1, self.n_epochs + 1)):
            epoch_loss, epoch_kl = self._run_epoch()
            if self.dropout_method == 'variational':
                print("Epoch =%d | loss=%.4f,kl:%.4f "%(epoch_i,epoch_loss,epoch_kl))
            else:
                print("Epoch =%d | loss=%.4f"%(epoch_i,epoch_loss))

    def evaluate(self):
        """Print and return test accuracy in percent as a Python float.

        The network is put in eval mode; gradients are not tracked.
        Returns 0.0 for an empty test loader.
        """
        total = 0
        correct = 0
        self.net.eval()
        with torch.no_grad():
            for images, labels in self.test_loader:
                images = images.view(-1, self.image_dim).to(self.device)
                logits = self.net(images)
                predicted = logits.argmax(dim=1).cpu()
                total += labels.size(0)
                correct += int((predicted == labels).sum())

        accuracy = 100.0 * correct / total if total else 0.0
        print("Accuracy : %.2f"%(accuracy))
        return accuracy

In [None]:
# Batch sizes 100, 150, ..., 1000 (kept as a set for compatibility).
batch_sizesDict = set(range(100, 1050, 50))
results = []
# Standard dropout (p=0.25) on 50k MNIST training images, one model per batch size.
for bat in sorted(batch_sizesDict):
    standard_solver = Solver(
        dropout_method='standard',
        n_epochs=50,
        dropout_rate=0.25,
        batch_size=bat,
        total_training_size=50000,
    )
    standard_solver.train()
    result = standard_solver.evaluate()
    results.append((bat, result))

In [None]:
# Header line followed by the list of (batch_size, accuracy) pairs.
print("Printing Standart Dropout Results:", results, sep="\n")

In [None]:
results_variational = []
# range(100, 150, 50) currently yields a single batch size: 100.
for bat in range(100, 150, 50):
    variational_solver = Solver(
        'variational',
        n_epochs=50,
        batch_size=bat,
        total_training_size=50000,
    )
    variational_solver.train()
    result = variational_solver.evaluate()
    results_variational.append((bat, result))

In [None]:
# Header line followed by the list of (batch_size, accuracy) pairs.
print("Printing Variational Dropout Results:", results_variational, sep="\n")

In [None]:
# Train a variational-dropout model and count how many weights of the two
# hidden Linear layers would be pruned by an alpha threshold (nothing is
# zeroed in this cell).
result_pruned_weights = []
for bat in range(100,150, 50):
    variational_solver2 = Solver('variational',n_epochs=50,batch_size=bat,total_training_size=50000)
    variational_solver2.train()

    result = variational_solver2.evaluate()

    layers = variational_solver2.net.net
    # Threshold = mean + population std of the alphas of the FIRST
    # variational dropout layer (layers[1]); it is applied to both layers.
    # Computed with torch so it also works when the model lives on the GPU.
    alpha_first = layers[1].log_alpha.detach().exp()
    mean = alpha_first.mean().item()
    std = alpha_first.std(unbiased=False).item()
    threshold = mean + std

    number_of_elements = 0
    total_weights = 0
    # (Linear index, dropout index): dropout unit i scales output unit i of
    # the preceding Linear, i.e. row i of its weight matrix.
    for linear_idx, dropout_idx in ((0, 1), (3, 4)):
        weight = layers[linear_idx].weight
        row_alpha = layers[dropout_idx].log_alpha.detach().exp()
        prunable_rows = row_alpha > threshold
        number_of_elements += int(prunable_rows.sum()) * weight.size(1)
        total_weights += weight.numel()

    # Fraction of the considered weights (previously divided by a hard-coded
    # 1568, which is not the weight count and could exceed 1).
    percentage = number_of_elements / total_weights
    print("Pruned weights = %d percentage of pruned=%f"%(number_of_elements,percentage))
    result_pruned_weights.append((bat,number_of_elements,result))

In [None]:
# Header (two lines) followed by (batch_size, prunable_weights, accuracy) tuples.
print("Printing number of weights suitable for pruning\nbatch size, prunable weigh,accuracy",
      result_pruned_weights, sep="\n")

In [None]:
# Train a variational-dropout model, zero the rows of the two hidden Linear
# layers whose dropout alpha exceeds a threshold, and compare accuracy
# before and after pruning.
result_pruned_weights_2 = []
for bat in range(100,150, 50):
    variational_solver2 = Solver('variational',n_epochs=50,batch_size=bat,total_training_size=50000)
    variational_solver2.train()

    result = variational_solver2.evaluate()

    layers = variational_solver2.net.net
    # Threshold = mean + population std of the alphas of the FIRST
    # variational dropout layer (layers[1]); it is applied to both layers.
    alpha_first = layers[1].log_alpha.detach().exp()
    mean = alpha_first.mean().item()
    std = alpha_first.std(unbiased=False).item()
    threshold = mean + std

    number_of_elements = 0
    # (Linear index, dropout index): dropout unit i scales output unit i of
    # the preceding Linear, i.e. row i of its weight matrix.
    for linear_idx, dropout_idx in ((0, 1), (3, 4)):
        weight = layers[linear_idx].weight
        row_alpha = layers[dropout_idx].log_alpha.detach().exp()
        prunable_rows = row_alpha > threshold
        with torch.no_grad():
            # Zero whole rows at once instead of element-by-element.
            weight[prunable_rows] = 0
        number_of_elements += int(prunable_rows.sum()) * weight.size(1)

    result_after_prune = variational_solver2.evaluate()
    result_pruned_weights_2.append((bat,number_of_elements,result,result_after_prune))

In [None]:
# Header (two lines) followed by (batch_size, pruned, acc_before, acc_after) tuples.
print("Printing results after pruned weights \nbatch size, pruned weights, accuracy before prune, accuracy after prune",
      result_pruned_weights_2, sep="\n")