In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

import copy
import types
import time

In [2]:
torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
def forward_new(self, x):
  return F.conv2d(x, self.weight * self.w_mask, self.bias,\
                         self.stride, self.padding, self.dilation, self.groups) if isinstance(self, nn.Conv2d)\
                         else F.linear(x, self.weight * self.w_mask, self.bias)

In [4]:
def layer_mask_gen(model, keep_ratio):
    layer_num=0
    masks = []
    for layer in model.modules():
      
      if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
        absolute_gradient =torch.abs(layer.w_mask.grad)  

        value = absolute_gradient.reshape(-1, )
        #value = torch.hstack(value)
        sum_of_values = value.sum()
        final_val = value/sum_of_values

        req_params = (keep_ratio[layer_num] * len(final_val) )
        req_params = int(req_params)
        top_K = torch.topk(final_val, req_params, sorted=True)[0]

        masks.append(absolute_gradient/sum_of_values >= top_K[len(top_K)-1])
        layer_num +=1           

    return masks


In [5]:
def mask_gen(model, keep_ratio):
    absolute_gradients= []
    absolute_gradients =[torch.abs(layer.w_mask.grad) for layer in model.modules() if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear)] 

    value = [each_lay_grad.reshape(-1, ) for each_lay_grad in absolute_gradients]
    value = torch.hstack(value)
    sum_of_values = value.sum()
    final_val = value/sum_of_values

    req_params = (len(final_val) * keep_ratio)
    req_params = int(req_params)
    top_K = torch.topk(final_val, req_params, sorted=True)[0]
    
    
    masks = []
    masks = [layer_gradient/sum_of_values >= top_K[len(top_K)-1] for layer_gradient in absolute_gradients]

    return masks

In [6]:
def fb_training(model_fb, keep_ratio, train_dataloader, generate_mask, device):
    
    X, Y = next(iter(train_dataloader))
    X = X.to(device)
    Y = Y.to(device)

    
    model_fb = copy.deepcopy(model_fb)
    

    for layer in model_fb.modules():
      #print("current_layer _before", layer)
      if isinstance(layer, nn.Conv2d) or isinstance(layer,nn.Linear):
        layer.w_mask = nn.Parameter(torch.ones(layer.weight.shape).to(device)) 
        nn.init.xavier_normal_(layer.weight)
        layer.weight.requires_grad = False
        layer.forward = types.MethodType(forward_new, layer)
     

    model_fb.zero_grad()
    out = model_fb.forward(X)
    loss = F.nll_loss(out, Y)
    loss.backward()

    return generate_mask(model_fb, keep_ratio)

    

In [7]:
def freeze(gradients):
  return gradients*mask

def activate_hook(mask):
  return freeze(grads)

In [8]:
def mask_app(model, keep_masks):

    layer_prun=[]
    i=0
    for layer in model.modules():
      if layer == nn.Conv2d or layer ==nn.Linear:
        layer_prun.append(layer) 
        assert layer.weight.shape == keep_masks[i].shape
        layer.weight.data[keep_masks[i].shape==0.] =0.
        layer.weight.register_hook(activate_hook(keep_masks[i]))
        i+=1


In [9]:
class VGG16(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()

       

        self.convol = nn.Sequential(
                  nn.Conv2d(3, 64, kernel_size= 3, padding=1), 
                  nn.BatchNorm2d(64),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(64, 64, kernel_size=3, padding=1), 
                  nn.BatchNorm2d(64),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(kernel_size=2, stride=2),
                  nn.Conv2d(64, 128, kernel_size = 3, padding=1), 
                  nn.BatchNorm2d(128),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(128, 128, kernel_size = 3, padding=1), 
                  nn.BatchNorm2d(128),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(kernel_size=2, stride=2),
                  nn.Conv2d(128, 256, kernel_size = 3, padding=1), 
                  nn.BatchNorm2d(256),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(256, 256, kernel_size=3, padding = 1), 
                  nn.BatchNorm2d(256),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(256, 256, kernel_size = 3, padding= 1), 
                  nn.BatchNorm2d(256),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(kernel_size=2, stride=2),
                  nn.Conv2d(256, 512, kernel_size = 3,padding=1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(512, 512, kernel_size = 3, padding = 1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(512, 512, kernel_size=3 , padding = 1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(kernel_size=2, stride=2),
                  nn.Conv2d(512, 512, kernel_size = 3, padding = 1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(512, 512, kernel_size = 3, padding = 1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(512, 512, kernel_size = 3, padding = 1), 
                  nn.BatchNorm2d(512),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(kernel_size=2, stride=2),
                  
              )
        
        self.Linear = nn.Sequential(
                  nn.Linear(512, 512),  
                  nn.ReLU(True),
                  nn.BatchNorm1d(512),  
                  nn.Linear(512, 512),
                  nn.ReLU(True),
                  nn.BatchNorm1d(512),  
                  nn.Linear(512, num_classes),)


    def forward(self, x):
        x = self.convol(x)
        x = x.reshape(x.shape[0], -1)
        x = self.Linear(x)  
        x = F.log_softmax(x, dim=1)
        return x

In [10]:
class SimpleResidualBlock(nn.Module):
    def __init__(self, in_channel_size, out_channel_size, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel_size, out_channel_size, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel_size)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channel_size, out_channel_size, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel_size)

        if stride == 1:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channel_size, out_channel_size, kernel_size=1, stride=stride, bias=False)
        self.bn_shortcut= nn.BatchNorm2d(out_channel_size)
        self.relu_shortcut = nn.ReLU(inplace=True)
 

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        
        out = self.conv2(out)
        out = self.bn2(out)

        shortcut = self.shortcut(x)
        shortcut= self.bn_shortcut(shortcut)
        
        out = self.relu_shortcut(out + shortcut)
        
        return out

    def init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight, 1.732)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [11]:
class ResNet34(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
            
            SimpleResidualBlock(64, 64, 1),
            
            SimpleResidualBlock(64, 64, 1),
            
            SimpleResidualBlock(64, 64, 1),
            
            SimpleResidualBlock(64, 128, 2),
            
            SimpleResidualBlock(128, 128, 1),

            SimpleResidualBlock(128, 128, 1),
            
            SimpleResidualBlock(128, 128, 1),

            SimpleResidualBlock(128, 256, 2),

            SimpleResidualBlock(256, 256, 1),

            SimpleResidualBlock(256, 256, 1),

            SimpleResidualBlock(256, 256, 1),

            SimpleResidualBlock(256, 256, 1),

            SimpleResidualBlock(256, 256, 1),

            SimpleResidualBlock(256, 512, 2),

            SimpleResidualBlock(512, 512, 1),
            
            SimpleResidualBlock(512, 512, 1),
            
            nn.AdaptiveAvgPool2d((1, 1)), 
            
            nn.Flatten(),
        )
        self.relu = nn.ReLU(inplace=True)
        
        self.linear_output = nn.Linear(512,num_classes) 

        
               
    def forward(self, x):
        embedding = self.layers(x) 
        output = self.linear_output(self.relu(embedding))
     
        return output      


    def init_weights(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m,nn.Linear):
            nn.init.xavier_normal_(m.weight, 1.732)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [12]:
def training(epoch, model, optimizer, scheduler, criterion, train_loader):
  model.train()
  avg_loss = 0.0
  av_loss=0.0
  total=0
  for batch_num, (feats, labels) in enumerate(train_loader):
      feats, labels = feats.to(device), labels.to(device)
      
      optimizer.zero_grad()

      outputs = model(feats)


      loss = criterion(outputs, labels.long())
      loss.backward()
      
      optimizer.step()
      
      avg_loss += loss.item()
      av_loss += loss.item() 
      total +=len(feats) 

      torch.cuda.empty_cache()
      del feats
      del labels
      del loss

  del train_loader

  return avg_loss/total
  

In [13]:
def validate(epoch, model, criterion, data_loader):
    with torch.no_grad():
        model.eval()
        running_loss, accuracy,total  = 0.0, 0.0, 0

        
        for i, (X, Y) in enumerate(data_loader):
            
            X, Y = X.to(device), Y.to(device)
            output= model(X)
            loss = criterion(output, Y.long())

            _,pred_labels = torch.max(F.softmax(output, dim=1), 1)
            pred_labels = pred_labels.view(-1)
            
            accuracy += torch.sum(torch.eq(pred_labels, Y)).item()

            running_loss += loss.item()
            total += len(X)

            torch.cuda.empty_cache()
            
            del X
            del Y
        
        return running_loss/total, accuracy/total

In [14]:
CUDA_LAUNCH_BLOCKING=1

In [15]:
if __name__ == '__main__':

      model_dict = {"VGG16": VGG16, "ResNet34": ResNet34}
      dataset_dict = {"CIFAR10":CIFAR10, "CIFAR100":CIFAR100}
      ####arguments#############
      model = ResNet34
      dataset=CIFAR100
      lr_rate = 0.1
      step_size = 20
      batch_size = 128
      weight_decay = 0.0005
      epochs = 70
      layer_wise = False
      ####arguments############

      if layer_wise == False:
        keep_ratio = 0.05
      else:
        if model == VGG16:
          layer_mult = 16
        elif model_dict["ResNet34"] == model:
          layer_mult = 37
        elif model == ResNet50:
          layer_mult = 53
        keep_ratio = [0.05]* layer_mult
      
    
      if dataset == dataset_dict["CIFAR10"]:
        num_classes =10
      elif dataset== dataset_dict["CIFAR100"] :
        num_classes = 100

      generate_mask= mask_gen if layer_wise == False else layer_mask_gen # gloabl pruning or layer_wise pruning
      print(num_classes)
      

      net = model(num_classes=num_classes) #num_classes for CIFAR10 = 10 for CIFAR100= 100
      #print(net)
      optimiser = optim.SGD( net.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)
      scheduler = optim.lr_scheduler.StepLR(optimiser, step_size=20, gamma=0.1)

      transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),
                        (0.2023, 0.1994, 0.2010)),
      ])

      transform_test = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),
                        (0.2023, 0.1994, 0.2010)),
      ])

      train_dataset = dataset('_dataset', True, transform_train, download=True)
      test_dataset = dataset('_dataset', False, transform_test, download=False)

      train_loader = DataLoader( train_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
      val_loader = DataLoader( test_dataset, batch_size, shuffle=False, num_workers=2, pin_memory=True)

      net = net.to(device)
      
      #masks = fb_training(net, keep_ratio, train_loader, generate_mask,device) 
      #mask_app(net, masks)
      
      
      
      criterion = nn.CrossEntropyLoss()


      for epoch in range(epochs):

          train_loss = training(epoch, net, optimiser, scheduler, criterion,train_loader)

          start_time = time.time()
          val_loss, val_acc = validate(epoch, net, criterion, val_loader)
          end_time = time.time()
          

          scheduler.step()


          print('Epoch: {} \t train-Loss: {:.4f}, \tval-Loss: {:.4f}, \tval-acc: {:.4f} \tinference_time: {:.4f}:'.format(epoch+1,  train_loss, val_loss, val_acc, end_time - start_time))


100
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to _dataset/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))


Extracting _dataset/cifar-100-python.tar.gz to _dataset
Epoch: 1 	 train-Loss: 0.0330, 	val-Loss: 0.0295, 	val-acc: 0.1238 	inference_time: 1.8495:
Epoch: 2 	 train-Loss: 0.0276, 	val-Loss: 0.0268, 	val-acc: 0.1818 	inference_time: 1.6747:
Epoch: 3 	 train-Loss: 0.0255, 	val-Loss: 0.0255, 	val-acc: 0.2066 	inference_time: 1.7483:
Epoch: 4 	 train-Loss: 0.0240, 	val-Loss: 0.0252, 	val-acc: 0.2391 	inference_time: 1.8480:
Epoch: 5 	 train-Loss: 0.0230, 	val-Loss: 0.0237, 	val-acc: 0.2602 	inference_time: 1.8013:
Epoch: 6 	 train-Loss: 0.0220, 	val-Loss: 0.0236, 	val-acc: 0.2572 	inference_time: 1.7557:
Epoch: 7 	 train-Loss: 0.0214, 	val-Loss: 0.0220, 	val-acc: 0.2900 	inference_time: 1.9088:
Epoch: 8 	 train-Loss: 0.0208, 	val-Loss: 0.0219, 	val-acc: 0.2961 	inference_time: 1.8389:
Epoch: 9 	 train-Loss: 0.0214, 	val-Loss: 0.0310, 	val-acc: 0.2800 	inference_time: 1.7865:
Epoch: 10 	 train-Loss: 0.0203, 	val-Loss: 0.0209, 	val-acc: 0.3203 	inference_time: 1.8490:
Epoch: 11 	 train-Loss