<a href="https://colab.research.google.com/github/nisha1729/lottery-ticket/blob/master/Conv2_iterative_pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
pip install inferno-pytorch

In [0]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from inferno.extensions.layers.reshape import Flatten
from builtins import range
from math import sqrt, ceil
  
  
def weights_init(m):
    if type(m) == nn.Linear:
        m.weight.data.normal_(0.0, 1e-3)
        m.bias.data.fill_(0.)

def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

#--------------------------------
# Device configuration
#--------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device: %s'%device)

#--------------------------------
# Hyper-parameters
#--------------------------------
input_size = 3
num_classes = 10
hidden_size = [64,128,256,512]
fc_size = 256
num_epochs = 1
batch_size = 200
learning_rate = 2e-3
learning_rate_decay = 0.95
reg=0.001
num_training= 49000
num_validation =1000
norm_layer = None
prune_percent = 20
prune_iter = 50


#  Initialising the initial mask with all ones
mask_layerwise = {};


#-------------------------------------------------
# Load the CIFAR-10 dataset
#-------------------------------------------------
data_aug_transforms = []

norm_transform = transforms.Compose(data_aug_transforms+[transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])
cifar_dataset = torchvision.datasets.CIFAR10(root='datasets/',
                                           train=True,
                                           transform=norm_transform,
                                           download=True)

test_dataset = torchvision.datasets.CIFAR10(root='datasets/',
                                          train=False,
                                          transform=test_transform
                                          )
#-------------------------------------------------
# Prepare the training and validation splits
#-------------------------------------------------
mask = list(range(num_training))
train_dataset = torch.utils.data.Subset(cifar_dataset, mask)
mask = list(range(num_training, num_training + num_validation))
val_dataset = torch.utils.data.Subset(cifar_dataset, mask)

#-------------------------------------------------
# Data loader
#-------------------------------------------------
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=batch_size,
                                           shuffle=False)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


#-------------------------------------------------
# Convolutional neural network
#-------------------------------------------------
class ConvNet(nn.Module):
    def __init__(self, input_size, hidden_layers, num_classes, norm_layer=None):
        super(ConvNet, self).__init__()
        #######################################################################################################
        # The Conv-2, Conv-4, and Conv-6 architectures are variants of the VGG (Simonyan & Zisserman,
        # 2014) network architecture scaled down for the CIFAR10 (Krizhevsky & Hinton, 2009) dataset. Like
        # VGG, the networks consist of a series of modules. Each module has two layers of 3x3 convolutional
        # filters followed by a maxpool layer with stride 2. After all of the modules are two fully-connected
        # layers of size 256 followed by an output layer of size 10; in VGG, the fully-connected layers are of
        # size 4096 and the output layer is of size 1000. Like VGG, the first module has 64 convolutions in
        # each layer, the second has 128, the third has 256, etc. The Conv-2, Conv-4, and Conv-6 architectures
        # have 1, 2, and 3 modules, respectively.
        #######################################################################################################
        layers = []
        
        # 1 module for Conv-2
        layers.append(nn.Conv2d(input_size, 64 , kernel_size = 3, stride = 1, padding = 1))
#         layers.append(nn.BatchNorm2d(64))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(64, 64 , kernel_size = 3, stride = 1, padding = 1))
#         layers.append(nn.BatchNorm2d(64))
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0))
        
        # Fully Connected Layers
        layers.append(Flatten())
        
        layers.append(nn.Linear(16384, fc_size))
        layers.append(nn.Linear(fc_size, fc_size))
        
        # Output Layer
        layers.append(nn.Linear(fc_size, num_classes))
        
        self.layers = nn.Sequential(*layers)
        

    def forward(self, x):
        out = self.layers(x)
        for name,param in model.named_parameters():
          if 'weight' in name:
            new_param = param*mask_layerwise[name].float()
#             if '0' in name:
#               print("forward prun")
#               print(new_param)
#               break
            model.state_dict()[name].data.copy_(new_param)
        return out
      
def PrintModelSize(model, disp=True):
    model_sz = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters = ", model_sz)
    return model_sz

  
def VisualizeFilter(model):
    w = model.layers[0].weight
    w_grid = torchvision.utils.make_grid(w,8, normalize = True, scale_each = True)
    w_grid = w_grid.permute(2, 1, 0)
    plt.imshow(w_grid.detach().cpu().numpy())    
    return w_grid;
  
def prune(prune_percent, model, mask_layerwise):
    for name, param in model.named_parameters():
        if 'weight' in name:
            sort_array = torch.sort(torch.abs(torch.masked_select(param, mask_layerwise[name]))).values
            thres_index = int(prune_percent*len(sort_array)/100)
            threshold = sort_array[thres_index]
            new_mask = torch.where(torch.abs(param).cpu() <= threshold.cpu(), torch.zeros(mask_layerwise[name].shape).byte().cpu(), mask_layerwise[name].cpu()).byte().cuda()
            mask_layerwise[name] = new_mask        # Updating mask
#             print(mask_layerwise[name])
#     for name, param in model.named_parameters():
      
#       print(param)
#       break

  
model = ConvNet(input_size, hidden_size, num_classes, norm_layer=norm_layer).to(device)
model.apply(weights_init)

# Initialising mask with all ones
for name, param in model.named_parameters():
    mask_layerwise[name] = (torch.ones(param.size()).byte().to(device))

# Print the model
print(model)


# Print model size
PrintModelSize(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=reg)
loss_history = []

# Train the model
lr = learning_rate
max_val_acc = 0
train_accuracy = []
val_accuracy = []
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
      
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    
    # TRAINING
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_accuracy.append(100*correct/total)
    
    # Code to update the lr
    lr *= learning_rate_decay
    update_lr(optimizer, lr)
    model.eval()
    
    # VALIDATION
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        val_accuracy.append(100*correct/total)

        print('Validataion accuracy is: {} %'.format(100 * correct / total))
        
        # Saving best model
        if(correct/total>max_val_acc):
          print("Saving the model...")
          torch.save(model.state_dict(), 'model_early.ckpt')
          max_val_acc = correct/total

    model.train()

# plt.plot(loss_history)

# TESTING

# Test the model before pruning
model.eval()

# Load the best model
best_model = torch.load("model_early.ckpt")
model.load_state_dict(best_model)

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if total == 1000:
            break

    print('Accuracy of the network before pruning on the {} test images: {} %'.format(total, 100 * correct / total))

print("*********Pruning***********")
prune_accuracy = []
for iter in range(prune_iter):

  # Pruning the network
  prune(prune_percent, model, mask_layerwise)

  # Print first param tensor
#   for param in model.parameters():
#     print(param)
#     break
for parma in model.paranter
  # Test the model after pruning
  with torch.no_grad():

      correct = 0
      total = 0
      for images, labels in test_loader:
          images = images.to(device)
          labels = labels.to(device)
          outputs = model(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
          if total == 1000:
              break
      acc = 100 * correct / total
      print('Accuracy of the network after pruning on the {} test images: {} %'.format(total, acc))
      print("Iteration no: %d, Percent = %d", iter, prune_percent)

  prune_accuracy.append(acc)
plt.plot(prune_accuracy, label = "Validation")
# plt.plot(train_accuracy, label = "Train")
plt.show()


# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')


Using device: cuda
Files already downloaded and verified
ConvNet(
  (layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Flatten()
    (6): Linear(in_features=16384, out_features=256, bias=True)
    (7): Linear(in_features=256, out_features=256, bias=True)
    (8): Linear(in_features=256, out_features=10, bias=True)
  )
)
Number of trainable parameters =  4301642
Epoch [1/1], Step [100/245], Loss: 1.4773
Epoch [1/1], Step [200/245], Loss: 1.1658
Validataion accuracy is: 54.7 %
Saving the model...
Accuracy of the network before pruning on the 1000 test images: 55.8 %
*********Pruning***********
Accuracy of the network after pruning on the 1000 test images: 55.4 %
Iteration no: %d, Percent = %d 0 20
Accuracy of the network after pruning on the 1000 

IndexError: ignored

In [0]:
  
def prune(prune_percent, model, mask_layerwise):
    for name, param in model.named_parameters():
        if 'weight' in name:
            sort_array = torch.sort(torch.abs(torch.masked_select(param, mask_layerwise[name]))).values
            thres_index = int(prune_percent*len(sort_array)/100)
            threshold = sort_array[thres_index]
            new_mask = torch.where(torch.abs(param).cpu() <= threshold.cpu(), torch.zeros(mask_layerwise[name].shape).byte().cpu(), mask_layerwise[name].cpu()).byte().cuda()
            mask_layerwise[name] = new_mask        # Updating mask
    for name, param in model.named_parameters():
      print(mask_layerwise[name])
      break

print("*********Pruning***********")
prune_accuracy = []

for iter in range(prune_iter):

  # Pruning the network
  prune(prune_percent, model, mask_layerwise)

  # Print first param tensor
  for param in model.parameters():
    print(param)
    break

  # Test the model after pruning
  with torch.no_grad():

      correct = 0
      total = 0
      for images, labels in test_loader:
          images = images.to(device)
          labels = labels.to(device)
          outputs = model(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
          if total == 1000:
              break
      acc = 100 * correct / total
      print('Accuracy of the network after pruning on the {} test images: {} %'.format(total, acc))
      print("Iteration no: %d, Percent = %d", iter, prune_percent)

  prune_accuracy.append(acc)
plt.plot(prune_accuracy, label = "Validation")
# plt.plot(train_accuracy, label = "Train")
plt.show()


# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')


IndentationError: ignored

In [0]:
print(param.size())

torch.Size([64, 3, 3, 3])


In [0]:
# current_mask = torch.ones(param.size()).to(device)
percent = 10
# for param in model.named_parameters():
print("Before pruning")
print(param)
mask = torch.ones(param.size()).byte().to(device)
print(type(param))
print(type(mask))
print(type(torch.zeros(mask.shape)))
sort_array = torch.sort(torch.abs(torch.masked_select(param, mask))).values
print("Sort aray = ", sort_array)
thres_index = int(percent*len(sort_array)/100)
print("Thres index = ", thres_index)
threshold = sort_array[thres_index]
# new_mask = param.le(threshold).float()
mask = mask.cpu()
print(torch.zeros(mask.shape).byte().cpu().is_cuda)
print(mask.is_cuda)
new_mask = torch.where(torch.abs(param).cpu() <= threshold.cpu(), torch.zeros(mask.shape).byte().cpu(), mask.cpu()).float().cuda()
param_new = param*new_mask
print("Threshold = ", threshold)
print(new_mask)
print("After pruning")
print(param_new)


In [0]:
print(mask_layerwise)

In [0]:
print(model)

In [0]:
for param in model.named_parameters():
#   print(name)
  print(param)
  

In [0]:
\#  Working Code

#  Initialising the initial mask with all ones
mask_layerwise = {};
for name, param in model.named_parameters():
    mask_layerwise[name] = (torch.ones(param.size()).byte().to(device))


# Pruning the weights 
for name, param in model.named_parameters():
  if 'weight' in name:
    sort_array = torch.sort(torch.abs(torch.masked_select(param, mask_layerwise[name]))).values
    thres_index = int(percent*len(sort_array)/100)
    threshold = sort_array[thres_index]
    new_mask = param.ge(threshold).float()
    new_param = param*new_mask
    mask_layerwise[name] = new_mask        # Updating mask
    param = new_param                      # Updating param


In [0]:
for param in model.parameters():
  current_mask = torch.ones(param.size()).to(device)
  param, new_mask = prune(prune_percent, param, current_mask)

# Test the model after pruning
model.eval()

# TODO: Early Stopping

with torch.no_grad():
    
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if total == 1000:
            break

    print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))

plt.plot(val_accuracy, label = "Validation")
plt.plot(train_accuracy, label = "Train")
plt.show()


In [0]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

x = torch.tensor([[1,2,3,4],[2,4,1,3],[5,8,2,4],[5,3,6,1],[6,4,2,6]]).float()
percent = 20
print("Original x")
print(x)
print(torch.masked_select(x, mask))
# current_mask = x.ge(-100)
mask = torch.ones(x.size()).byte()
x_sort_tensor = torch.sort(x)
x_sort_array = torch.sort(torch.abs(torch.masked_select(x, mask))).values
print(x_sort_array)
thres_index = int(percent*len(x_sort_array)/100)
threshold = x_sort_array[thres_index]
print(thres_index)
print(type(x))
print(type(mask))
print(type(torch.zeros(mask.shape)))
new_mask = torch.where(torch.abs(x) <= threshold, torch.zeros(mask.shape).byte(), mask).float()
# new_mask = x.ge(threshold).float()
print("Threshold = ", threshold)
print("New Mask")
print(new_mask)
x_new = x*new_mask
print("Pruned x")
print(x_new)

In [0]:
from torchvision import models
from torchsummary import summary
summary(model, (3, 32, 32))

In [0]:
print(model)