# Mask Training Based on Lottery Ticket Paper

PyTorch implementation of [**Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask**](https://arxiv.org/abs/1905.01067) (from Uber AI).

## Prepare Dataset

### Load MNIST Dataset

In [None]:
from torchvision import datasets, transforms

# Fetch both MNIST splits (downloaded into ./data when not already present).
train_set, test_set = (
    datasets.MNIST('./data', train=is_train, download=True)
    for is_train in (True, False)
)

# Images as NumPy arrays with every pixel value divided by 255.0
# (presumably mapping 0-255 intensities into [0, 1]).
train_set_x, test_set_x = (
    split.data.numpy() / 255.0 for split in (train_set, test_set)
)

# Labels as NumPy arrays.
train_set_y, test_set_y = (
    split.targets.numpy() for split in (train_set, test_set)
)

# Sanity check: first label and label-array shape, then the shape of one
# image and of the whole image array.
print(train_set_y[0], train_set_y.shape)
print(train_set_x[0].shape, train_set_x.shape)

5 (60000,)
(28, 28) (60000, 28, 28)


### Define and Load A Sequential Dataset Class

In [None]:
from torch.utils.data import Sampler,SequentialSampler,BatchSampler,Dataset,DataLoader
import torch
import numpy as np

class SimpleDataset(Dataset):
    """Map-style dataset over paired sample and label arrays.

    Indexing returns a ``(sample, label)`` tuple: the sample is reshaped to
    ``out_shape`` and converted to a float32 tensor, the label is converted
    to a long tensor.
    """

    def __init__(self, x_values, y_values, out_shape):
        # Samples and labels are stored as given; conversion to tensors
        # happens lazily, one item at a time, in __getitem__.
        self.X, self.y = x_values, y_values
        self.out_shape = out_shape

    def __len__(self):
        # The number of items is driven by the sample container.
        return len(self.X)

    def __getitem__(self, index):
        sample = self.X[index].reshape(self.out_shape)
        label = self.y[index]
        features = torch.as_tensor(sample, dtype=torch.float32)
        target = torch.as_tensor(label, dtype=torch.long)
        return features, target
    

## Define Custom Classes

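### Define a Bernoulli function with a gradient

PyTorch's built-in Bernoulli sampling (`torch.bernoulli`) does not support differentiation, so we define a custom autograd function that samples in the forward pass and supplies its own gradient in the backward pass.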

In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

class Bern(torch.autograd.Function):
    """
    Bernoulli sampling with a hand-written backward pass.

    ``torch.bernoulli`` on its own provides no gradient, so this autograd
    function defines one: the incoming gradient is scaled element-wise by
    the probabilities that were used for sampling.

    Inputs: Tensor of arbitrary shape holding probabilities in [0, 1].
    Outputs: Tensor of the same shape containing only {0, 1}, where each
             entry is sampled with the corresponding input probability.
    """
    @staticmethod
    def forward(ctx, input):
        sample = torch.bernoulli(input)
        # Keep the probabilities around; backward() needs them.
        ctx.save_for_backward(input)
        return sample

    @staticmethod
    def backward(ctx, grad_output):
        (probs,) = ctx.saved_tensors
        # Gradient w.r.t. the probabilities: upstream gradient times p.
        return grad_output * probs


### Define a MaskedLinear layer

`MaskedLinear` is a custom fully connected linear layer whose weights $W_f$ remain constant once they have been randomly initialized.
A second weight matrix $W_m$, with the same shape as $W_f$, is used to generate a binary mask and is trained through backpropagation. Each entry of $W_m$ is passed through a sigmoid to obtain the probability $p$ of the $Bern(p)$ sample that decides whether the corresponding entry of $W_f$ is kept.

In [None]:
class MaskedLinear(nn.Module):
    """
    Fully connected linear layer (no bias) with frozen weights and a
    trainable stochastic binary mask.

    The weights $W_f$ (``fcw``) are drawn once from a standard normal
    distribution and never trained. A second matrix $W_m$ (``mask``) with
    the same shape is the only trainable parameter: each entry of $W_m$ is
    passed through a sigmoid to obtain the probability $p$ of a $Bern(p)$
    sample, and the sampled {0, 1} mask is multiplied element-wise with
    $W_f$ before the linear map is applied.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        device: device on which the tensors are created (``None`` uses the
            PyTorch default).
    """
    def __init__(self, in_features, out_features, device=None):
        super(MaskedLinear, self).__init__()
        self.device = device

        # Fully Connected Weights: frozen, so they are not a Parameter.
        # Registering them as a buffer keeps them out of parameters() while
        # still letting .to()/.cuda() move them together with the mask and
        # including them in state_dict(), so a saved model can be restored.
        self.register_buffer(
            'fcw', torch.randn((out_features, in_features), device=device))
        # Weights of Mask: the only trainable tensor of the layer.
        self.mask = nn.Parameter(torch.randn_like(self.fcw))

    def forward(self, x):
        # Generate probability of bernoulli distributions
        s_m = torch.sigmoid(self.mask)
        # Generate a binary mask based on the distributions
        g_m = Bern.apply(s_m)
        # Keep weights where mask is 1 and set others to 0
        effective_weight = self.fcw * g_m
        # Apply the effective weight on the input data
        lin = F.linear(x, effective_weight)

        return lin

    def __str__(self):
        # Number of weight entries, used to report the mean mask value.
        # (torch.prod() needs a tensor, not unpacked ints; numel() is the
        # direct way to get the element count.)
        count = self.fcw.numel()
        return 'Mask Layer: \n FC Weights: {}, {}, MASK: {}'.format(
            self.fcw.sum(), torch.abs(self.fcw).sum(), self.mask.sum() / count)


### Define a Custom Masked ANN

A simple fully connected masked network for our test purposes.

In [None]:
class MaskANN(nn.Module):
    """
    Three-layer fully connected masked network (784 -> 1200 -> 1200 -> 10)
    built from MaskedLinear layers with ReLU activations in between.

    Args:
        device: device handed to every MaskedLinear layer.
    """
    def __init__(self, device):
        super(MaskANN, self).__init__()
        self.ml1 = MaskedLinear(784, 1200, device)
        self.ml2 = MaskedLinear(1200, 1200, device)
        self.ml3 = MaskedLinear(1200, 10, device)

    def forward(self, x):
        """Return the raw (un-normalised) outputs of the last layer for ``x``."""
        x = F.relu(self.ml1(x))
        x = F.relu(self.ml2(x))
        # No activation on the output layer.
        return self.ml3(x)

    def get_layers(self):
        """Return the masked layers in forward order."""
        return [self.ml1, self.ml2, self.ml3]

    def print_weights(self):
        """Print the sum and the absolute sum of each layer's fixed weights."""
        # MaskedLinear stores its fixed weights in `fcw`; it has no `weight`
        # attribute, so reading `.weight` here raised AttributeError.
        for number, layer in enumerate(self.get_layers(), start=1):
            print('FC {}: '.format(number),
                  layer.fcw.sum().item(),
                  torch.abs(layer.fcw).sum().item())

### Definition of application functions: train, test, and main

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` after each batch.

    Args:
        model: network to optimise; it is switched to training mode.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches that also
            exposes ``__len__`` and a ``dataset`` attribute (both are used
            for the progress message).
        optimizer: optimizer that updates the model's parameters.
        epoch: epoch number, used only in the progress message.
    """
    progress_template = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
    criterion = torch.nn.CrossEntropyLoss()
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard update: clear old gradients, forward, backward, step.
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Progress is reported after every batch.
        samples_done = step * len(inputs)
        samples_total = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print(progress_template.format(
            epoch, samples_done, samples_total, percent_done,
            batch_loss.item()))

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    The model is switched to evaluation mode and gradients are disabled.
    The printed average loss is the per-sample cross-entropy, and the
    accuracy is the share of samples whose arg-max output equals the target.

    Args:
        model: network to evaluate.
        device: device the batches are moved to before the forward pass.
        test_loader: iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute whose length is the number of samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            # Accumulate the *summed* batch loss: the total is divided by the
            # number of samples below, so adding per-batch means would make
            # the reported average too small by roughly the batch size.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    # Guard against an empty dataset instead of dividing by zero.
    denominator = max(num_samples, 1)
    test_loss /= denominator

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / denominator))


# The name matches pytest's default test-function prefix; mark the helper so
# pytest does not try to collect it when this code is imported by a test module.
test.__test__ = False

# The trained network is kept at module level so it can be inspected after
# main() has finished.
mask_net = None

def main():
    """Train a MaskANN on MNIST and evaluate it after every epoch.

    Uses the module-level ``train_set_x`` / ``train_set_y`` /
    ``test_set_x`` / ``test_set_y`` arrays, flattens every image to a
    784-element vector, and stores the network in the global ``mask_net``.
    """
    global mask_net

    seed = 0
    batch_size = 100
    epochs = 10
    learning_rate = 0.001

    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SimpleDataset is the dataset class defined in this file; the previously
    # referenced `SpikeDataset` does not exist here and raised NameError.
    train_ds = SimpleDataset(train_set_x, train_set_y, (784,))
    test_ds = SimpleDataset(test_set_x, test_set_y, (784,))

    # Batches are served in dataset order (no shuffling).
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=False, drop_last=True)

    # Keep the final partial batch when evaluating so that every test sample
    # is counted; test() divides by the full dataset size.
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, drop_last=False)

    mask_net = MaskANN(device)

    # 0.001 is also Adam's default learning rate; passing it explicitly keeps
    # the hyper-parameter visible next to the others.
    optimizer = optim.Adam(mask_net.parameters(), lr=learning_rate)

    for epoch in range(1, epochs + 1):
        train(mask_net, device, train_loader, optimizer, epoch)
        test(mask_net, device, test_loader)

if __name__ == '__main__':
    main()

Parameter values of MaskNet:
wayt:  tensor(750996.1250, grad_fn=<SumBackward0>)

Test set: Average loss: 10.3730, Accuracy: 3942/10000 (39%)

wayt:  tensor(752205., grad_fn=<SumBackward0>)

Test set: Average loss: 6.6777, Accuracy: 5324/10000 (53%)

wayt:  tensor(753496., grad_fn=<SumBackward0>)

Test set: Average loss: 4.8698, Accuracy: 6226/10000 (62%)

wayt:  tensor(754822.0625, grad_fn=<SumBackward0>)

Test set: Average loss: 3.8178, Accuracy: 6768/10000 (68%)

wayt:  tensor(756206.2500, grad_fn=<SumBackward0>)

Test set: Average loss: 3.0681, Accuracy: 7172/10000 (72%)

wayt:  tensor(757582.5000, grad_fn=<SumBackward0>)

Test set: Average loss: 2.7339, Accuracy: 7462/10000 (75%)

wayt:  tensor(759038.6250, grad_fn=<SumBackward0>)

Test set: Average loss: 2.3422, Accuracy: 7587/10000 (76%)

wayt:  tensor(760453.0625, grad_fn=<SumBackward0>)

Test set: Average loss: 2.1755, Accuracy: 7732/10000 (77%)

wayt:  tensor(761976.5000, grad_fn=<SumBackward0>)

Test set: Average loss: 1.7397