In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms
import numpy as np
import os
from sklearn.cluster import KMeans
import copy

In [30]:
class QuantizeNetwork(object):
  def __init__(self, verbose = True):
    self.model = None
    self.num_cluster = None
    self.verbose = verbose
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  def quantize_network(self,model,num_cluster):
    self.model = copy.deepcopy(model)
    self.num_cluster = num_cluster
    self._k_means_quantization(self.model.features._modules.items())
    self._k_means_quantization(self.model.classifier._modules.items())
    self.model = torch.quantization.quantize_dynamic( self.model, {torch.nn.BatchNorm2d,torch.nn.Conv2d},  dtype=torch.qint8) 
    return self.model

  def _k_means_quantization(self,modules):
    for layer, (name, module) in enumerate(modules):
      if not isinstance(module,nn.ReLU) and not isinstance(module,nn.MaxPool2d):
        weight = module.weight.data.cpu().numpy()
        org_shape =  module.weight.shape
        flatten_weights = weight.flatten()
        old_unique_weights = np.unique(flatten_weights)
        space = np.linspace(np.min(flatten_weights), np.max(flatten_weights), num=2**self.num_cluster)
        kclusters = KMeans(n_clusters=len(space), init=space.reshape(-1,1), n_init=1, precompute_distances=True, algorithm="full")
        kclusters.fit(weight.reshape(-1,1))
        new_weight = kclusters.cluster_centers_[kclusters.labels_].reshape(-1)
        new_unique_weights = np.unique(new_weight)
        module.weight.data = torch.from_numpy(new_weight.reshape(org_shape)).to(self.device)
        if self.verbose:
          print('layer_names [',module,'] -> unique weights count old:',len(old_unique_weights),', new:',len(new_unique_weights))

# Testing the model

In [3]:
from google.colab import drive
drive.mount('/content/gdrive')
checkpoint_loc = '/content/gdrive/MyDrive/11785/project/'

Mounted at /content/gdrive


In [4]:
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

In [5]:
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

INIT_LR = 0.1
WEIGHT_DECAY_RATE = 0.0005
EPOCHS = 70
lr_decay_interval = 10
batch_size = 128

In [6]:
VGG_CONFIGS = {
    # M for MaxPool, Number for channels
    'D': [
        64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
        512, 512, 512, 'M'
    ],
}


class VGG_SNIP(nn.Module):
    """
    This is a base class to generate three VGG variants used in SNIP paper:
        1. VGG-C (16 layers)
        2. VGG-D (16 layers)
        3. VGG-like

    Some of the differences:
        * Reduced size of FC layers to 512
        * Adjusted flattening to match CIFAR-10 shapes
        * Replaced dropout layers with BatchNorm
    """

    def __init__(self, config, num_classes=10):
        super().__init__()

        self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),  # 512 * 7 * 7 in the original VGG
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),  # instead of dropout
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def make_layers(config, batch_norm=False):  # TODO: BN yes or no?
        layers = []
        in_channels = 3
        for v in config:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [
                        conv2d,
                        nn.BatchNorm2d(v),
                        nn.ReLU(inplace=True)
                    ]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)  
        x = F.log_softmax(x, dim=1)
        return x

In [7]:
def get_cifar10_dataloaders(train_batch_size, test_batch_size):

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR10('_dataset', True, train_transform, download=True)
    test_dataset = CIFAR10('_dataset', False, test_transform, download=False)

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    return train_loader, test_loader

In [8]:
def cifar10_experiment():
    
    BATCH_SIZE = 128
    LR_DECAY_INTERVAL = 20
    
    net = VGG_SNIP('D').to(device)
    # net = 
    optimiser = optim.SGD(
        net.parameters(),
        lr=INIT_LR,
        momentum=0.9,
        weight_decay=WEIGHT_DECAY_RATE)
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimiser, LR_DECAY_INTERVAL, gamma=0.1)
    
    train_loader, val_loader = get_cifar10_dataloaders(BATCH_SIZE,
                                                       BATCH_SIZE)  # TODO

    return net, optimiser, lr_scheduler, train_loader, val_loader

In [10]:
initial_net, optimiser, lr_scheduler, train_loader, val_loader = cifar10_experiment()
initial_net = initial_net.to(device)
torch.save(initial_net,'/content/init.pt')
initial_net

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to _dataset/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting _dataset/cifar-10-python.tar.gz to _dataset


VGG_SNIP(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128

In [32]:
# load after pruned network and perform quantization.
after_pruning_net = torch.load('/content/gdrive/MyDrive/11785/project/check/after_pruning.ptmodel')
after_pruning_net = after_pruning_net.cuda()

In [13]:
import time

def validate(epoch, model, criterion, device, data_loader):
    start_time = time.time()
    with torch.no_grad():
        model.eval()
        running_loss, accuracy,total  = 0.0, 0.0, 0

        
        for i, (X, Y) in enumerate(data_loader):
            
            X, Y = X.to(device), Y.to(device)
            output= model(X)
            loss = criterion(output, Y.long())

            _,pred_labels = torch.max(F.softmax(output, dim=1), 1)
            pred_labels = pred_labels.view(-1)
            
            accuracy += torch.sum(torch.eq(pred_labels, Y)).item()

            running_loss += loss.item()
            total += len(X)

            torch.cuda.empty_cache()
            
            del X
            del Y
        
        return running_loss/total, accuracy/total, (time.time() - start_time)

In [33]:
 criterion = nn.CrossEntropyLoss()
 val_loss, val_acc,time_taken = validate(0, after_pruning_net, criterion, device, val_loader)
 print(val_loss, ' ', val_acc,' ',time_taken)

0.0022592162996530533   0.9193   2.32389760017395


Testing using quantization new code

In [36]:
q_net = QuantizeNetwork()
q_network = q_net.quantize_network(after_pruning_net,5)

layer_names [ Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ] -> unique weights count old: 1590 , new: 32
layer_names [ BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ] -> unique weights count old: 64 , new: 32
layer_names [ Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ] -> unique weights count old: 26883 , new: 32
layer_names [ BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ] -> unique weights count old: 64 , new: 32
layer_names [ Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ] -> unique weights count old: 47853 , new: 32
layer_names [ BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ] -> unique weights count old: 128 , new: 32
layer_names [ Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ] -> unique weights count old: 76178 , new: 32
layer_names [ BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_

In [37]:
 criterion = nn.CrossEntropyLoss()
 val_loss, val_acc,time_taken = validate(0, q_network, criterion, device, val_loader)
 print(val_loss, ' ', val_acc,' ',time_taken)

0.0022042215384542943   0.9184   2.4405198097229004
