In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

import copy
import types

In [None]:
# Seed PyTorch's global RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(42)

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def snip_forward_conv2d(self, x):
    """Masked replacement for ``nn.Conv2d.forward``.

    Convolves ``x`` with the element-wise product of ``self.weight`` and
    ``self.weight_mask``; bias, stride, padding, dilation and groups are taken
    from the layer unchanged. Intended to be bound to a layer instance that
    carries a ``weight_mask`` attribute.
    """
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_weight,
        self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )


def snip_forward_linear(self, x):
    """Masked replacement for ``nn.Linear.forward``.

    Applies the affine map using ``self.weight * self.weight_mask`` as the
    weight matrix and the layer's own bias. Intended to be bound to a layer
    instance that carries a ``weight_mask`` attribute.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)

In [None]:
def SNIP_mask_add(net):
    """Attach a trainable all-ones ``weight_mask`` to every Conv2d/Linear layer.

    The network is modified in place (no deep copy is taken). For each
    ``nn.Conv2d`` or ``nn.Linear`` module found by ``net.modules()``:

    * ``weight_mask`` is (re)assigned to a new ``nn.Parameter`` of ones with the
      same shape as ``weight``; calling this function again replaces any
      existing mask.
    * ``forward`` is rebound to the matching masked forward function so the
      layer computes with ``weight * weight_mask``.

    Args:
        net: the ``nn.Module`` to instrument.
    """
    # TODO: shuffle?
    for module in net.modules():
        is_conv = isinstance(module, nn.Conv2d)
        is_linear = isinstance(module, nn.Linear)
        if not (is_conv or is_linear):
            continue

        module.weight_mask = nn.Parameter(torch.ones_like(module.weight))

        # Swap in the mask-aware forward pass for this layer type.
        if is_conv:
            module.forward = types.MethodType(snip_forward_conv2d, module)
        if is_linear:
            module.forward = types.MethodType(snip_forward_linear, module)


def SNIP_mask_quantize(net, keep_ratio):
    """Build binary keep-masks selecting the globally most salient weights.

    The saliency of a weight is the absolute gradient stored on the owning
    layer's ``weight_mask``. Scores from every ``nn.Conv2d`` / ``nn.Linear``
    layer are pooled, normalised by their sum, and the top
    ``int(num_weights * keep_ratio)`` scores define a threshold. A weight is
    kept when its score is greater than or equal to that threshold, so ties at
    the threshold can keep slightly more than the requested count.

    Args:
        net: module whose Conv2d/Linear layers carry a ``weight_mask`` with a
            populated ``.grad``.
        keep_ratio: fraction of weights to keep, in ``[0, 1]``.

    Returns:
        A list with one float32 tensor of 0/1 per Conv2d/Linear layer, in
        ``net.modules()`` order, each shaped like that layer's weight. The
        list is empty when the network has no such layers.

    Raises:
        ValueError: if ``keep_ratio`` is outside ``[0, 1]``.
        RuntimeError: if a prunable layer has no ``weight_mask`` or the mask
            has no gradient yet.
    """
    if not 0.0 <= keep_ratio <= 1.0:
        raise ValueError(
            "keep_ratio must be in [0, 1], got {!r}".format(keep_ratio))

    grads_abs = []
    for layer in net.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            mask = getattr(layer, "weight_mask", None)
            if mask is None or mask.grad is None:
                raise RuntimeError(
                    "every Conv2d/Linear layer needs a weight_mask with a "
                    "gradient; attach the masks and run a backward pass first")
            grads_abs.append(torch.abs(mask.grad))

    # Nothing to prune: torch.cat would reject an empty list.
    if not grads_abs:
        return []

    # Gather all scores in a single vector and normalise.
    all_scores = torch.cat([torch.flatten(g) for g in grads_abs])
    norm_factor = torch.sum(all_scores)
    if norm_factor == 0:
        # All gradients are zero; dividing would turn every score into NaN.
        # Leave the scores unnormalised instead (they all tie at zero).
        norm_factor = torch.ones_like(norm_factor)
    all_scores.div_(norm_factor)

    num_params_to_keep = int(len(all_scores) * keep_ratio)
    if num_params_to_keep == 0:
        # topk(…, 0) returns an empty tensor, so there is no threshold to
        # index; keeping zero parameters means pruning everything.
        return [torch.zeros_like(g, dtype=torch.float32) for g in grads_abs]

    threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
    acceptable_score = threshold[-1]  # smallest score among the kept ones

    # Recompute the normalised score per layer with the same division used
    # above so the comparison against the threshold is exact.
    return [((g / norm_factor) >= acceptable_score).float() for g in grads_abs]



In [None]:
def apply_prune_mask(net, keep_masks):
    """Zero out pruned weights and block gradients from reaching them.

    ``keep_masks`` is matched one-to-one, in ``net.modules()`` order, with the
    network's ``nn.Conv2d`` / ``nn.Linear`` layers. In a mask, ``0`` means
    "prune this weight" and ``1`` means "keep it". Biases are not touched.

    For every layer this function

    1. sets the masked-out entries of ``layer.weight`` to zero, and
    2. registers a gradient hook on ``layer.weight`` that multiplies incoming
       gradients by the mask, so pruned entries receive a zero gradient.

    Calling the function again on the same network replaces the previously
    registered hook instead of stacking a second one. Optimiser state that
    already exists (e.g. momentum buffers) is not modified here.

    Args:
        net: module to prune in place.
        keep_masks: sequence of 0/1 tensors, one per Conv2d/Linear layer, each
            with the same shape as that layer's weight.

    Raises:
        ValueError: if the number of masks differs from the number of
            Conv2d/Linear layers.
        AssertionError: if a mask's shape differs from its layer's weight.
    """

    def make_hook(mask):
        # Bind the mask per layer; a bare closure over the loop variable would
        # make every hook see the last mask.
        def hook(grad):
            return grad * mask
        return hook

    # Keep only the layers that carry prunable weights so they line up with
    # the masks.
    prunable_layers = [
        layer for layer in net.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    keep_masks = list(keep_masks)
    if len(prunable_layers) != len(keep_masks):
        # zip() would silently ignore the surplus on either side.
        raise ValueError(
            "got {} masks for {} prunable layers".format(
                len(keep_masks), len(prunable_layers)))

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert layer.weight.shape == keep_mask.shape, \
            "mask shape does not match layer weight shape"

        # Step 1: set the masked weights to zero (NB the biases are ignored).
        with torch.no_grad():
            layer.weight.masked_fill_(keep_mask == 0., 0.)

        # Step 2: make sure their gradients remain zero.
        previous_hook = getattr(layer, "_prune_grad_hook", None)
        if previous_hook is not None:
            previous_hook.remove()
        layer._prune_grad_hook = layer.weight.register_hook(make_hook(keep_mask))


In [None]:
class VGG16(nn.Module):
    """VGG-style convolutional classifier with a compact three-layer head.

    The convolutional trunk is generated from ``config`` by ``make_layers``
    (always with batch norm). The classifier head starts with
    ``nn.Linear(512, 512)``, so the trunk must produce exactly 512 features
    per sample once flattened.

    Args:
        config: sequence describing the trunk; an ``int`` adds a 3x3
            convolution with that many output channels, ``'M'`` adds a 2x2
            max-pool with stride 2.
        num_classes: size of the output layer.
    """

    def __init__(self, config, num_classes=10):
        super().__init__()

        self.features = self.make_layers(config, batch_norm=True)

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),  # 512 * 7 * 7 in the original VGG
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def make_layers(config, batch_norm=False):  # TODO: BN yes or no?
        """Translate ``config`` into an ``nn.Sequential`` feature extractor.

        Args:
            config: sequence of ``int`` (conv output channels) and ``'M'``
                (max-pool) entries, processed in order.
            batch_norm: when true, insert ``nn.BatchNorm2d`` between each
                convolution and its ReLU.

        Returns:
            ``nn.Sequential`` whose first convolution expects 3 input channels.
        """
        layers = []
        in_channels = 3
        for v in config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(v),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                # The next convolution consumes this one's output channels.
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class log-probabilities (``log_softmax`` over dim 1)."""
        x = self.features(x)
        # Collapse everything but the batch dimension for the linear head.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = F.log_softmax(x, dim=1)
        return x

In [None]:
def network_init():
    """Create the model, optimiser, LR scheduler and CIFAR-10 data loaders.

    Reads the module-level hyper-parameters ``lr``, ``weight_decay``,
    ``lr_decay_interval`` and ``batch_size``, which must be defined before
    this function is called. The training split is downloaded into
    ``'_dataset'`` if necessary; the test split is read from the same
    directory without downloading.

    Returns:
        Tuple ``(net, optimiser, scheduler, train_loader, val_loader)``.
    """
    net = VGG16([
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
        512, 512, 512, 'M'
    ])
    optimiser = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimiser, lr_decay_interval, gamma=0.1)

    # Per-channel mean / std used to normalise both splits.
    normalise = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalise,
    ])

    # Evaluation must be deterministic: no random cropping or flipping here,
    # otherwise validation metrics change from run to run and are measured on
    # shifted images.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalise,
    ])

    train_dataset = CIFAR10('_dataset', True, train_transform, download=True)
    test_dataset = CIFAR10('_dataset', False, test_transform, download=False)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True,
                              num_workers=2, pin_memory=True)
    val_loader = DataLoader(test_dataset, batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)

    return net, optimiser, scheduler, train_loader, val_loader

In [None]:
def training(epoch, model, optimizer, scheduler, criterion, device, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        epoch: current epoch index (not used inside the function).
        model: network to optimise; switched to training mode.
        optimizer: optimiser stepped once per batch.
        scheduler: learning-rate scheduler (not used inside the function; the
            caller is expected to step it).
        criterion: loss callable taking ``(outputs, labels)``. If it exposes
            ``reduction == 'sum'`` its value is treated as a per-batch sum;
            otherwise it is treated as a per-batch mean.
        device: device the batches are moved to.
        train_loader: iterable of ``(features, labels)`` batches.

    Returns:
        Mean loss per sample over the pass, or ``0.0`` if the loader yields
        no samples.
    """
    model.train()
    loss_is_batch_sum = getattr(criterion, "reduction", "mean") == "sum"

    loss_sum = 0.0
    total = 0
    for feats, labels in train_loader:
        feats, labels = feats.to(device), labels.to(device)

        # Only clears gradients of tensors registered with the optimiser;
        # any other parameter's .grad keeps accumulating across batches.
        optimizer.zero_grad()

        outputs = model(feats)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        batch_len = len(feats)
        # Weight a per-batch mean by the batch size so that dividing by the
        # sample count below gives a true per-sample average, also when the
        # last batch is smaller than the others.
        loss_sum += loss.item() if loss_is_batch_sum else loss.item() * batch_len
        total += batch_len

    if total == 0:
        return 0.0
    return loss_sum / total
  

In [None]:
def validate(epoch, model, criterion, device, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        epoch: current epoch index (not used inside the function).
        model: network to evaluate; switched to evaluation mode and left in it.
        criterion: loss callable taking ``(outputs, labels)``. If it exposes
            ``reduction == 'sum'`` its value is treated as a per-batch sum;
            otherwise it is treated as a per-batch mean.
        device: device the batches are moved to.
        data_loader: iterable of ``(features, labels)`` batches.

    Returns:
        Tuple ``(mean_loss_per_sample, accuracy)`` where accuracy is the
        fraction of samples whose highest-scoring class equals the label.
        Returns ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()
    loss_is_batch_sum = getattr(criterion, "reduction", "mean") == "sum"

    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for X, Y in data_loader:
            X, Y = X.to(device), Y.to(device)
            output = model(X)
            loss = criterion(output, Y.long())

            # Softmax is monotonic, so the arg-max of the raw scores already
            # identifies the predicted class.
            pred_labels = output.argmax(dim=1)
            correct += (pred_labels == Y).sum().item()

            batch_len = len(X)
            # Weight a per-batch mean by the batch size so the final division
            # by the sample count gives a true per-sample average.
            loss_sum += loss.item() if loss_is_batch_sum else loss.item() * batch_len
            total += batch_len

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

In [None]:
# Training schedule
epochs = 70              # number of passes of the outer training loop
lr_decay_interval = 20   # StepLR step size, in scheduler steps

# Optimiser / loader settings read by network_init()
batch_size = 128
lr = 1e-1
weight_decay = 5e-4

In [None]:
# Driver: build the model and data pipeline, then alternate one epoch of
# training with a global saliency-based pruning step.
if __name__ == '__main__':

      # The optimiser is created inside network_init(), i.e. before any
      # weight_mask parameter exists, so the masks are never stepped by it.
      net, optimiser, lr_scheduler, train_loader, val_loader = network_init()
      net = net.to(device)
      # Fraction of weights SNIP_mask_quantize keeps (the rest are pruned).
      keep_ratio=0.05
      criterion = nn.CrossEntropyLoss()


      for epoch in range(epochs):
          # Re-attaches a fresh all-ones weight_mask to every Conv2d/Linear
          # layer, discarding the mask (and its gradient) from the last epoch.
          SNIP_mask_add(net)

          train_loss = training(epoch, net, optimiser, lr_scheduler, criterion, device,train_loader)

          val_loss, val_acc = validate(epoch, net, criterion, device, val_loader)

          # One scheduler step per epoch.
          lr_scheduler.step()

          # Rank weights by |weight_mask.grad| left behind by this epoch's
          # backward passes, then zero the weights that fall below the cut.
          keep_masks = SNIP_mask_quantize(net, keep_ratio)  # TODO: shuffle?
          apply_prune_mask(net, keep_masks)

          
          print('Epoch: {} \t train-Loss: {:.4f}, \tval-Loss: {:.4f}, \tval-acc: {:.4f}'.format(epoch+1,  train_loss, val_loss, val_acc))
