In [1]:
# This file is for learning both pytorch and variational dropout sparsifies deep neural network
import math
import torch
import time
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

#from local_logger import Logger
from torch.nn import Parameter
from torchvision import datasets, transforms

import imageio
import seaborn as sns
import cv2
CONV2D=True

In [3]:
#torch.cuda.is_available()

True

In [2]:
class LinearSVDO(nn.Module):
    def __init__(self, in_features, out_features, threshold, bias=True):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features  # the input dimension
        self.out_features = out_features # the output dimension 
        self.threshold = threshold # the threshold that determines whether a neuron is set to be zero 
        
        self.W = Parameter(torch.Tensor(out_features, in_features)) # so I guess it's output channel, inputchannel
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features)) # output channel, input channel ?
        
        print("The weight matrix", self.W.shape)
        print("The log sigma", self.log_sigma.shape)
        print("The bias term", self.bias.shape)
        self.reset_parameters()
    
    def reset_parameters(self):
        """This function is used to initialize the parameters"""
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)  # the weight matrix follows a normal distribution with mean 0.0 and sigma 0.02
        self.log_sigma.data.fill_(-5)  # the logsigma is initialized to be -5
        
    def forward(self, x):
        self.log_alpha = self.log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(self.W))  # self.log_alpha is actually the dropout rate for each element
        self.log_alpha = torch.clamp(self.log_alpha, -10, 10)  # this is equivalent to tf.clip_by_value
        
        if self.training:
            lrt_mean = F.linear(x, self.W) + self.bias
            lrt_std = torch.sqrt(F.linear(x * x, torch.exp(self.log_sigma * 2.0)) + 1e-8)
            eps = lrt_std.data.new(lrt_std.size()).normal_()
            return lrt_mean + lrt_std * eps
    
        return F.linear(x, self.W * (self.log_alpha < 3).float()) + self.bias
    
    def kl_reg(self):
        """This function returns the KL divergence"""
        k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
        kl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
        a = -torch.sum(kl)
        return a        

In [3]:
class Conv2DSVDO(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, ard_init=-10, threshold=3):
#         super(Conv2DSVDO, self).__init__(in_channels, out_channels, kernel_size, stride, 
#                                          padding, dilation, groups)
#        self.bias=None
        super(Conv2DSVDO, self).__init__()
        self.threshold = threshold
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ard_init = ard_init  #  I think this parameter is for cliping
        self.stride = stride
        self.padding=padding
        self.dilation=dilation
        self.groups=groups
        
        kh, kw = kernel_size, kernel_size
        self.W = Parameter(torch.Tensor(out_channels, in_channels, kh, kw))
        self.bias = Parameter(torch.Tensor(out_channels))
        self.log_sigma = Parameter(torch.Tensor(out_channels, in_channels, kh, kw))
        
        print("The weight matrix", self.W.shape)
        print("The bias shape", self.bias.shape)
        print("The logsigma shape", self.log_sigma.shape)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)
        
    def forward(self, x):
        eps = 1e-8
        self.log_alpha = self.log_sigma * 2.0 - 2.0 * torch.log(torch.abs(self.W) + 1e-16)
        self.log_alpha = torch.clamp(self.log_alpha, -10.0, 10.0)
        
        if self.training:
            conv_mu = F.conv2d(x, self.W, self.bias, self.stride, self.padding, 
                               self.dilation, self.groups)
            conv_std = torch.sqrt(F.conv2d(x * x, torch.exp(self.log_sigma * 2.0)) + eps)
            noise = torch.normal(torch.zeros_like(conv_mu), torch.ones_like(conv_std))
            conv = conv_mu + conv_std * noise
            return conv
        
        
        return F.conv2d(x, self.W * (self.log_alpha < self.threshold).float(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)
    def kl_reg(self):
        k1, k2, k3 = torch.Tensor([0.63576]).cuda(), torch.Tensor([1.8732]).cuda(), torch.Tensor([1.48695]).cuda()
        kl = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
        a = -torch.sum(kl)
        return a            

In [4]:
class Net(nn.Module):
    def __init__(self, threshold):
        super(Net, self).__init__()
        self.fc1 = LinearSVDO(28*28, 300, threshold)
        self.fc2 = LinearSVDO(300, 100, threshold)
        self.fc3 = LinearSVDO(100, 10, threshold)
        self.threshold = threshold
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

In [5]:
class Conv2DNet(nn.Module):
    def __init__(self, threshold):
        super(Conv2DNet, self).__init__()
        self.threshold = threshold
        self.conv0 = Conv2DSVDO(1, 20, 5)
        self.conv1 = Conv2DSVDO(20, 50, 5)
        
        self.fc1 = LinearSVDO(50*4*4, 300, threshold)
        self.fc2 = LinearSVDO(300, 10, threshold)
        
    def forward(self, x):
        x = F.relu(self.conv0(x))
        x = F.max_pool2d(x, 2)
#        print("----first conv layer", x.shape)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
#        print("----second conv layer", x.shape)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
#        print("----first fc layer", x.shape)
        x = F.relu(self.fc2(x))
#        print("----second fc layer", x.shape)
        return F.log_softmax(x, dim=1)

In [6]:
# This is for loading the dataset
def get_mnist(batch_size):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, 
                                                             transform=transform), batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True, 
                                                            transform=transform), batch_size=batch_size,
                                             shuffle=True, num_workers=4)
    return train_loader, test_loader

In [7]:
class SGVLB(nn.Module):
    def __init__(self, net, train_size):
        super(SGVLB, self).__init__()
        self.train_size = train_size
        self.net = net
        
    def forward(self, input, target, kl_weight=1.0):
        assert not target.requires_grad
        kl = 0.0
        for module in self.net.children():
#            print(module)
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()
        return F.cross_entropy(input, target) * self.train_size + kl_weight * kl

In [8]:
if CONV2D is True:
    model = Conv2DNet(threshold=3).cuda()
else:
    model = Net(threshold=3).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80], 
                                                gamma=0.2)
train_loader, test_loader = get_mnist(batch_size=100)
sgvlb = SGVLB(model, len(train_loader.dataset)).cuda()

The weight matrix torch.Size([20, 1, 5, 5])
The bias shape torch.Size([20])
The logsigma shape torch.Size([20, 1, 5, 5])
The weight matrix torch.Size([50, 20, 5, 5])
The bias shape torch.Size([50])
The logsigma shape torch.Size([50, 20, 5, 5])
The weight matrix torch.Size([300, 800])
The log sigma torch.Size([300, 800])
The bias term torch.Size([1, 300])
The weight matrix torch.Size([10, 300])
The log sigma torch.Size([10, 300])
The bias term torch.Size([1, 10])


In [17]:
# model = Net(threshold=3).cuda()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80], 
#                                                 gamma=0.2)
# # fmt = {"tr_loss": '3.1e', "te_loss": "3.1e", "sp_0": ".3f", "sp_1": ".3f", "lr":"3.1e", "kl":".2f"}
# # logger = Logger("sparse_vd", fmt=fmt)

# train_loader, test_loader = get_mnist(batch_size=100)
# sgvlb = SGVLB(model, len(train_loader.dataset)).cuda()

The weight matrix torch.Size([300, 784])
The log sigma torch.Size([300, 784])
The bias term torch.Size([1, 300])
The weight matrix torch.Size([100, 300])
The log sigma torch.Size([100, 300])
The bias term torch.Size([1, 100])
The weight matrix torch.Size([10, 100])
The log sigma torch.Size([10, 100])
The bias term torch.Size([1, 10])


In [15]:
w_use_conv0 = model.conv0.W.cpu().detach().numpy()
w_use_conv1 = model.conv1.W.cpu().detach().numpy()
w_use_fc0 = model.fc1.W.cpu().detach().numpy()
w_use_fc1 = model.fc2.W.cpu().detach().numpy()
group = [w_use_conv0, w_use_conv1, w_use_fc0, w_use_fc1]
[print(v.shape) for v in group]


(20, 1, 5, 5)
(50, 20, 5, 5)
(300, 800)
(10, 300)


[None, None, None, None]

In [9]:
kl_weight = 0.02
epochs = 60
W_group = [[] for i in range(4)]
val_accu = []
for epoch in range(1, epochs + 1):
    start = time.time()
    model.train()
    train_loss, train_acc = 0.0, 0.0
    kl_weight = min(kl_weight+0.02, 1)
    for batch_idx, (data, target) in enumerate(train_loader):
        if CONV2D is True:
            data = data.view(-1, 1, 28, 28)
        else:
            data = data.view(-1, 28*28)
        data = data.cuda()
        target = target.cuda()
        
        optimizer.zero_grad()
        
        output = model(data)
        pred = output.data.max(1)[1]
        loss = sgvlb(output, target, kl_weight)
        loss.backward()

        optimizer.step()
        
        train_loss += loss
        train_acc += np.sum(pred.cpu().numpy() == target.cpu().data.numpy())
    model.eval()
    test_loss, test_acc = 0.0, 0.0
    for batch_idx, (data, target) in enumerate(test_loader):
        if CONV2D is True:
            data = data.view(-1, 1, 28, 28)
        else:
            data = data.view(-1, 28*28)
        data = data.cuda()
        target = target.cuda()
        output = model(data)
        test_loss += float(sgvlb(output, target, kl_weight))
        pred = output.data.max(1)[1]
        test_acc += np.sum(pred.cpu().numpy() == target.cpu().data.numpy())
    if epoch % 2 == 0 or epoch == epochs:
        for c_iter, c in enumerate(model.children()):
            W_group[c_iter].append(c.log_alpha.cpu().data.numpy())
        val_accu.append(test_acc / len(test_loader.dataset) * 100)
    end = time.time()
    
    loss_vec = ["epoch", epoch, "trlos", train_loss/len(train_loader.dataset), 
                "tracc", train_acc/len(train_loader.dataset) * 100,
                "valos", test_loss/ len(test_loader.dataset), 
                "valacc", test_acc/ len(test_loader.dataset) * 100]
    for i, c in enumerate(model.children()):
        if hasattr(c, 'kl_reg'):
            loss_vec.append((c.log_alpha.cpu().data.numpy() > model.threshold).mean())
    loss_vec.append("time")
    loss_vec.append(end-start)
    
    if CONV2D is True:
        print("{:s}{: d} {:s}{: .3f} {:s}{: .3f} {:s}{: .3f} {:s}{: .3f} {: .2f}{: .2f}{: .2f}{: .2f} {:s}{: .3f}".format(*loss_vec))          
    else:
        print("{:s}{: d} {:s}{: .3f} {:s}{: .3f} {:s}{: .3f} {:s}{: .3f} {: .2f}{: .2f}{: .2f} {:s}{: .3f}".format(*loss_vec))          
    
    scheduler.step()



epoch 1 trlos 227.264 tracc 87.232 valos 18.966 valacc 98.070  0.01 0.12 0.41 0.09 time 13.762
epoch 2 trlos-15.424 tracc 98.207 valos-38.794 valacc 98.700  0.03 0.26 0.58 0.13 time 14.473
epoch 3 trlos-65.177 tracc 98.480 valos-80.467 valacc 98.810  0.03 0.37 0.69 0.17 time 15.731
epoch 4 trlos-109.354 tracc 98.818 valos-121.362 valacc 98.980  0.06 0.47 0.75 0.21 time 15.914
epoch 5 trlos-146.817 tracc 98.788 valos-150.556 valacc 98.820  0.06 0.51 0.76 0.24 time 15.171
epoch 6 trlos-184.531 tracc 98.928 valos-186.870 valacc 98.900  0.07 0.55 0.79 0.27 time 15.766
epoch 7 trlos-220.764 tracc 98.940 valos-226.872 valacc 99.070  0.14 0.62 0.84 0.33 time 14.965
epoch 8 trlos-256.723 tracc 99.005 valos-258.985 valacc 98.840  0.14 0.64 0.85 0.34 time 12.393
epoch 9 trlos-291.590 tracc 99.010 valos-293.901 valacc 99.090  0.21 0.65 0.86 0.33 time 15.960
epoch 10 trlos-326.124 tracc 99.058 valos-331.300 valacc 99.130  0.24 0.68 0.89 0.38 time 15.837
epoch 11 trlos-361.324 tracc 99.045 valos-36

In [10]:
all_w, kep_w = 0, 0
for c in model.children():
    kep_w += (c.log_alpha.cpu().data.numpy() < model.threshold).sum()
    all_w += c.log_alpha.cpu().data.numpy().size
print("the kept weight ratio is %.3f" % (kep_w / all_w * 100))


the kept weight ratio is 1.753


In [20]:
def generate_animation(save_name, feature, rate, test_accu):
    save_dir = '/project/bo/exp_data/'
    with imageio.get_writer(save_dir + '/%s.gif' % save_name, mode='I', fps=5) as writer:
        for iterr, single_feature in enumerate(feature):
            fig = plt.figure(figsize=(4,2.5))
            ax = fig.add_subplot(111)
            sns.heatmap(abs(single_feature), cmap='Reds')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title("epoch %d sparsity %.2f test accuracy %.2f" % (iterr * 2, rate[iterr] * 100, test_accu[iterr]), fontsize=6)
            plt.savefig(save_dir + '/im.png', pad_inches=0, bbox_inches='tight', dpi=300)
            plt.close()
            im = cv2.imread(save_dir + '/im.png')[:, :, ::-1]
            writer.append_data(im)

In [22]:
def create_canvas(feature):
    """This function is used to create canvas
    Args: 
        feature [out_channel, in_channel, kh, kw]
    """
    nx, ny, fh, fw = np.shape(feature)
    x_values = np.linspace(-3, 3, nx)
    y_values = np.linspace(-3, 3, ny)
    canvas = np.empty((fw * nx, fh * ny))
    
    for i, yi in enumerate(x_values):
        f_sub = feature[i]
        for j, xj in enumerate(y_values):
            f_use = f_sub[j]
#             f_use[:, -1] = 10.0
#             f_use[:, 0] = 10.0
#             f_use[-1, :] = 10.0
#             f_use[0, :] = 10.0
            canvas[(nx - i - 1) * fh:(nx - i) * fh,
                   j * fw:(j + 1) * fw] = f_use        
    return canvas

In [23]:
for iterr, save_feature in enumerate(W_group):
    num_dim = np.shape(save_feature[0])
    num_element_tot = np.prod(num_dim)
    rate = [np.sum((v >= model.threshold).astype('int32')) / num_element_tot for v in save_feature]
    save_feature = [v * (v < model.threshold).astype('int32') for v in save_feature]
    if len(num_dim) == 4:
        save_name = "feature_conv_%d" % iterr
        save_feature = [create_canvas(v) for v in save_feature]
    elif len(num_dim) == 2:
        save_name = "feature_fc_%d" % iterr
    generate_animation(save_name, save_feature, rate, val_accu)

In [33]:
for iterr, save_feature in enumerate(W_group):
    save_name = "feature_anim_%d" % iterr
    num_element_tot = np.prod(np.shape(save_feature[0]))
    rate = [np.sum((v >= model.threshold).astype('int32')) / num_element_tot for v in save_feature]
    save_feature = [v * (v < model.threshold).astype('int32') for v in save_feature]
    generate_animation(save_name, save_feature, rate)