In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms


import copy
import types

In [2]:
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x7fcd509147d0>

In [3]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
def snip_forward_conv2d(self, x):
    """Masked replacement for a convolution layer's ``forward``.

    Convolves ``x`` with ``self.weight`` multiplied element-wise by
    ``self.weight_mask``, reusing the layer's own bias, stride, padding,
    dilation and groups. Because the mask takes part in the computation,
    back-propagation also produces a gradient for ``self.weight_mask``.
    """
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_weight,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )


def snip_forward_linear(self, x):
    """Masked replacement for a linear layer's ``forward``.

    Applies the affine map using ``self.weight`` multiplied element-wise by
    ``self.weight_mask`` and the layer's own bias, so that back-propagation
    also produces a gradient for ``self.weight_mask``.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, bias=self.bias)

In [5]:
def SNIP_mask_add(net):
    """Attach a fresh all-ones ``weight_mask`` to every conv / linear layer of ``net``.

    For each ``nn.Conv2d`` and ``nn.Linear`` module this
      * stores a new ``nn.Parameter`` of ones, shaped like ``layer.weight``,
        as ``layer.weight_mask`` (replacing any mask from a previous call), and
      * rebinds ``layer.forward`` to the matching masked forward function so
        the mask takes part in the forward pass.

    ``net`` is modified in place; nothing is returned.
    """
    # TODO (kept from the original author): shuffle?
    for module in net.modules():
        is_conv = isinstance(module, nn.Conv2d)
        is_linear = isinstance(module, nn.Linear)
        if not (is_conv or is_linear):
            continue

        module.weight_mask = nn.Parameter(torch.ones_like(module.weight))

        # Swap in the forward implementation that multiplies by the mask.
        if is_conv:
            module.forward = types.MethodType(snip_forward_conv2d, module)
        if is_linear:
            module.forward = types.MethodType(snip_forward_linear, module)

def SNIP_mask_quantize(net, keep_ratio):
    """Convert accumulated ``weight_mask`` gradients into binary keep-masks.

    Every ``nn.Conv2d`` / ``nn.Linear`` module of ``net`` is scored on its own:
    the absolute gradient of its ``weight_mask`` is normalised by the layer's
    total, and the ``int(numel * keep_ratio[i])`` highest-scoring entries
    (plus any ties at the cut-off) are marked 1, the rest 0.

    Args:
        net: module whose conv / linear layers carry a ``weight_mask``
            parameter with a populated ``.grad``.
        keep_ratio: indexable of per-layer keep fractions, in the order the
            prunable layers appear in ``net.modules()``.

    Returns:
        list of float tensors, one per prunable layer, each shaped like that
        layer's weight.

    Raises:
        RuntimeError: if a prunable layer's ``weight_mask`` has no gradient.
        IndexError: if ``keep_ratio`` has fewer entries than prunable layers.
    """
    keep_masks = []
    prunable_layers = [
        module for module in net.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    for layer_num, layer in enumerate(prunable_layers):
        mask_grad = layer.weight_mask.grad
        if mask_grad is None:
            raise RuntimeError(
                "weight_mask of prunable layer {} has no gradient; run a "
                "backward pass before quantising the masks".format(layer_num))

        grads_abs = torch.abs(mask_grad)
        norm_factor = torch.sum(grads_abs)

        # Normalise out of place. The previous code called div_() on
        # torch.flatten(grads_abs), which is a view of grads_abs, so the
        # tensor was normalised in place and then divided by norm_factor a
        # second time when the mask was built.
        # With an all-zero gradient there is nothing to rank; skipping the
        # division avoids NaN scores (every entry then ties at zero).
        scores = grads_abs / norm_factor if norm_factor > 0 else grads_abs

        num_params_to_keep = min(
            int(scores.numel() * keep_ratio[layer_num]), scores.numel())
        if num_params_to_keep <= 0:
            # topk(k=0) returns an empty tensor, so indexing [-1] would fail.
            keep_masks.append(torch.zeros_like(scores, dtype=torch.float))
            continue

        top_scores, _ = torch.topk(
            scores.flatten(), num_params_to_keep, sorted=True)
        acceptable_score = top_scores[-1]  # smallest score that still survives

        keep_masks.append((scores >= acceptable_score).float())

    return keep_masks

In [6]:
def apply_prune_mask(net, keep_masks):
    """Zero every weight of ``net`` whose keep-mask entry is 0.

    The conv / linear layers of ``net`` are paired, in ``net.modules()``
    order, with the tensors in ``keep_masks``; each mask must have the same
    shape as its layer's weight. A mask value of 1 keeps the weight and a
    value of 0 prunes it. Biases are not touched.

    Only the current weight values are overwritten: no gradient hook is
    installed here, so later optimiser steps are free to move pruned weights
    away from zero.
    """
    # Drop every module that is not prunable so layers and masks line up
    # one-to-one.
    prunable_layers = [
        module for module in net.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert (layer.weight.shape == keep_mask.shape)

        pruned = keep_mask == 0.
        layer.weight.data[pruned] = 0.


In [7]:
class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> norm -> ReLU -> 1x1 conv -> norm, plus a shortcut.

    Args:
        in_channels: channels of the block input.
        out_channels: channels produced by both convolutions.
        stride: stride of the first (3x3) convolution.
        downsample: optional module applied to the input on the shortcut
            path; when ``None`` the input is added unchanged.
        base_width: accepted for interface compatibility; not used here.
        padding: padding of the first (3x3) convolution.
        batch_norm: callable building a normalisation layer from a channel
            count; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, base_width=1, padding=1, batch_norm=None):
        super().__init__()
        # Fixed: the default used to be the non-existent ``nn.Batchnorm2D``,
        # which raised AttributeError whenever batch_norm was left as None.
        bn_layer = nn.BatchNorm2d if batch_norm is None else batch_norm

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=False)
        self.bn1 = bn_layer(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn2 = bn_layer(out_channels)

        # The shortcut is the identity unless the caller supplies a module
        # that reshapes the input to match the main branch.
        self.downsample = downsample
        if self.downsample is None:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = self.downsample

        self.stride = stride

        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(x)
        out = self.relu2(out)
        return out


class ResNet34(nn.Module):
    """ResNet-style image classifier built from four stages of ``BasicBlock``.

    Args:
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: size of the final linear layer's output.
        zero_init_residual: when true, the scale of the second norm layer
            (``bn2``) of every block is initialised to zero.
        base_width: stored and forwarded to every block.
        batch_norm: callable building a normalisation layer from a channel
            count; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, layers, num_classes=10, zero_init_residual=False, base_width=64, batch_norm=None):
        super().__init__()
        block = BasicBlock
        # Fixed: bn_layer used to be bound only when batch_norm was None, so
        # passing a custom norm layer raised NameError.
        bn_layer = nn.BatchNorm2d if batch_norm is None else batch_norm
        self.bn_layer = bn_layer

        self.conv_out_channels = 64
        # Running channel count consumed and updated by container().
        self.in_channels = self.conv_out_channels
        self.base_width = base_width

        # Stem: 7x7 stride-2 convolution on 3-channel input, then max-pooling.
        self.conv1 = nn.Conv2d(3, self.conv_out_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = bn_layer(self.conv_out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.container(block, 64, layers[0])
        self.layer2 = self.container(block, 128, layers[1], stride=2, dilate=False)
        self.layer3 = self.container(block, 256, layers[2], stride=2, dilate=False)
        self.layer4 = self.container(block, 512, layers[3], stride=2, dilate=False)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        # Xavier-normal weights for every convolution; unit scale and zero
        # shift for every normalisation layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def container(self, block, in_channels, num_basicblocks, stride=1, dilate=False):
        """Build one stage of ``num_basicblocks`` blocks as an ``nn.Sequential``.

        Args:
            block: block class to instantiate.
            in_channels: number of channels the stage *outputs* (the name is
                historical); the stage's input width is ``self.in_channels``.
            num_basicblocks: how many blocks the stage contains.
            stride: stride of the first block; later blocks use stride 1.
            dilate: accepted for interface compatibility; not used here.

        Side effect: sets ``self.in_channels`` to the stage's output width.
        """
        bn_layer = self.bn_layer
        downsample = None
        # A projection shortcut is needed whenever the first block changes the
        # tensor shape: a spatial stride or a different channel count. The
        # channel check is new; without it a stride-1 stage that widens the
        # network fails at the residual addition.
        if stride != 1 or self.in_channels != in_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, in_channels, kernel_size=1, stride=stride, bias=False),
                bn_layer(in_channels),
            )

        layers = [
            block(self.in_channels, in_channels, stride, downsample, self.base_width, padding=1, batch_norm=bn_layer)
        ]
        self.in_channels = in_channels
        for _ in range(1, num_basicblocks):
            layers.append(block(self.in_channels, in_channels, base_width=self.base_width, batch_norm=bn_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch to class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch axis, flatten everything else
        x = self.fc(x)

        return x

In [9]:
def network_init():
    """Build the model, optimiser, LR scheduler and CIFAR-10 data loaders.

    Reads the module-level settings ``lr``, ``weight_decay``,
    ``lr_decay_interval`` and ``batch_size``, which must be defined before
    this function is called.

    Returns:
        ``(net, optimiser, scheduler, train_loader, val_loader)``
    """
    net = ResNet34([3, 4, 6, 3])
    optimiser = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimiser, lr_decay_interval, gamma=0.1)

    # The same per-channel normalisation is applied to both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    # Random crop + horizontal flip are training-time augmentation only.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Fixed: the evaluation pipeline used to apply RandomCrop(32, padding=4)
    # as well, so validation ran on randomly shifted images and its metrics
    # changed from run to run. Evaluation images are now only converted and
    # normalised.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # download=True on the training split fetches the dataset on first use;
    # the test split is then opened with download=False from the same root.
    train_dataset = CIFAR10('_dataset', True, train_transform, download=True)
    test_dataset = CIFAR10('_dataset', False, test_transform, download=False)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return net, optimiser, scheduler, train_loader, val_loader

In [10]:
def training(epoch, model, optimizer, scheduler, criterion, device, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        epoch: current epoch index; accepted for interface symmetry, unused.
        model: module to train (switched to training mode here).
        optimizer: optimiser stepped once per batch.
        scheduler: accepted but not stepped here.
        criterion: loss callable taking ``(outputs, labels)``.
        device: device the batches are moved to.
        train_loader: iterable of ``(features, labels)`` batches.

    Returns:
        The mean loss per sample over the pass as a float, or ``0.0`` when
        the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    total = 0
    # With the default 'mean' reduction each batch loss is already a
    # per-sample average and must be re-weighted by the batch size; with
    # 'sum' it is already a total.
    sum_reduced = getattr(criterion, "reduction", "mean") == "sum"

    for feats, labels in train_loader:
        feats, labels = feats.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(feats)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        num_samples = len(feats)
        batch_loss = loss.item()
        # Fixed: batch means used to be summed unweighted and then divided by
        # the sample count, which under-reported the loss by roughly a factor
        # of the batch size.
        loss_sum += batch_loss if sum_reduced else batch_loss * num_samples
        total += num_samples
        # The per-batch torch.cuda.empty_cache() call and the manual `del`s
        # were removed: they do not affect results, and emptying the cache
        # every step only forces the allocator to request memory again.

    if total == 0:
        return 0.0
    return loss_sum / total
  

In [11]:
def validate(epoch, model, criterion, device, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        epoch: current epoch index; accepted for interface symmetry, unused.
        model: module to evaluate (switched to eval mode and left there).
        criterion: loss callable taking ``(outputs, labels)``.
        device: device the batches are moved to.
        data_loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``(mean loss per sample, fraction of correct predictions)``, or
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    # With the default 'mean' reduction each batch loss is already a
    # per-sample average and must be re-weighted by the batch size.
    sum_reduced = getattr(criterion, "reduction", "mean") == "sum"

    with torch.no_grad():
        for X, Y in data_loader:
            X, Y = X.to(device), Y.to(device)
            output = model(X)
            loss = criterion(output, Y.long())

            # Softmax is monotonic, so the arg-max of the raw outputs is the
            # predicted class; the explicit softmax was redundant.
            pred_labels = output.argmax(dim=1)
            correct += (pred_labels == Y).sum().item()

            num_samples = len(X)
            batch_loss = loss.item()
            # Fixed: batch means used to be summed unweighted and then divided
            # by the sample count, under-reporting the loss.
            loss_sum += batch_loss if sum_reduced else batch_loss * num_samples
            total += num_samples

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

In [12]:
# --- Run configuration -------------------------------------------------------
# Schedule: number of passes over the training set, and the number of epochs
# between learning-rate decays (the StepLR step size).
epochs = 70
lr_decay_interval = 20

# Optimiser: initial learning rate and L2 weight decay passed to SGD.
lr = 0.1
weight_decay = 5e-4

# Data: number of samples per batch for both data loaders.
batch_size = 128

In [13]:
# lr_scheduler = optim.lr_scheduler.StepLR(
#         optimiser, 20, gamma=0.1)

In [None]:
if __name__ == '__main__':

    net, optimiser, lr_scheduler, train_loader, val_loader = network_init()
    net = net.to(device)

    # Per-layer keep ratios for SNIP pruning, indexed in the order the
    # conv / linear layers appear in net.modules():
    #   layers 0-4   keep 5 %
    #   layers 5-35  keep 0.5 %
    #   layer  36    (the final linear layer) keeps everything
    # Fixed: the slice assignments used to write 32 values into 31 slots and
    # 5 values into 4 slots, which grew the list to 39 entries and shifted
    # it, so layer 5 was left unpruned and the final layer was cut to 0.5 %.
    # NOTE(review): the ranges above are the most plausible reading of the
    # original slices; confirm they match the intended schedule.
    keep_ratio = [1] * 37
    keep_ratio[:5] = [0.05] * 5
    keep_ratio[5:36] = [0.005] * 31

    # Fail fast if the schedule and the network ever get out of step.
    num_prunable = sum(
        isinstance(m, (nn.Conv2d, nn.Linear)) for m in net.modules())
    assert len(keep_ratio) == num_prunable, (
        "keep_ratio has {} entries but the network has {} prunable layers"
        .format(len(keep_ratio), num_prunable))

    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # Install fresh all-ones masks so this epoch's backward passes
        # accumulate a new mask gradient for every prunable layer.
        SNIP_mask_add(net)

        train_loss = training(epoch, net, optimiser, lr_scheduler, criterion, device, train_loader)

        val_loss, val_acc = validate(epoch, net, criterion, device, val_loader)

        lr_scheduler.step()

        # Rank the weights by the accumulated mask gradients, then zero the
        # ones that fall below each layer's cut-off.
        keep_masks = SNIP_mask_quantize(net, keep_ratio)  # TODO: shuffle?
        apply_prune_mask(net, keep_masks)

        print('Epoch: {} \t train-Loss: {:.4f}, \tval-Loss: {:.4f}, \tval-acc: {:.4f}'.format(epoch+1,  train_loss, val_loss, val_acc))
