In [1]:
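# Implements weight sharing: each layer's nonzero weights are replaced by indices into a small table of trainable centroids (initialised with k-means); zero weights stay zero.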

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
import sys

# Compute device for the whole notebook; fall back to CPU so the cells still
# run (slowly) on a machine without CUDA instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint path for the best weight-shared model found by train().
SAVE_PATH = 'D://models//wt_shared_net.pth'

In [3]:
# Standard ImageNet per-channel mean / std, shared by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# 'train' applies random crop/flip augmentation; 'val' is a deterministic
# resize + center crop. Both end at 224x224 normalized tensors.
transform = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
)

In [4]:
data_dir = 'D:\\datasets\\ILSVRC2012_img_val - Retrain\\'
# One ImageFolder per split: <data_dir>/train and <data_dir>/val, each using
# the matching entry of `transform`.
dataset = {
    split: datasets.ImageFolder(root=os.path.join(data_dir, split), transform=transform[split])
    for split in ('train', 'val')
}

In [5]:
# Loaders used by check_accuracy()/train().
# Only the training split is shuffled. ImageFolder orders samples class by
# class, so an unshuffled training loader feeds SGD batches that each cover
# only a small contiguous run of classes. Shuffling 'val' is unnecessary
# because only the total number of correct predictions is reported.
dataloader = {
    x: torch.utils.data.DataLoader(dataset[x], batch_size = 512, shuffle = (x == 'train'),
                                   num_workers = 6, pin_memory=True)
    for x in ['train', 'val']
}

In [6]:
# Small-batch, shuffled view of the training split.
batch_size_actk = 24
dataloader_actk = torch.utils.data.DataLoader(
    dataset['train'],
    batch_size=batch_size_actk,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)

In [7]:
# Number of samples per split, and the class (sub-folder) names of the training split.
dataset_size = dict((split, len(dataset[split])) for split in ('train', 'val'))
class_names = dataset['train'].classes

In [8]:
class Quantize(nn.Module):
    """Weight-shared replacement for a Conv2d ('conv') or Linear ('fc') layer.

    The nonzero weights of `wt_layer` are represented by a table of trainable
    centroids (`centroids`, shape (num_vals, 1)) plus one fixed integer index
    per nonzero weight (`labels`). Weights that are exactly zero in `wt_layer`
    are excluded by `mask` and stay zero in the reconstructed weight.

    Args:
        layer_type: 'conv' or 'fc'; selects F.conv2d or F.linear in forward().
        wt_layer: source layer. Its weight, bias and (for 'conv') stride,
            padding, dilation and groups are read.
        num_vals: number of shared values (centroids).
        quick: if False, centroids/labels are initialised by fitting
            MiniBatchKMeans to the nonzero weights. If True, they are
            zero-filled placeholders of the right shapes (e.g. to be filled by
            load_state_dict afterwards).

    Raises:
        ValueError: if layer_type is not 'conv' or 'fc'.
    """

    def __init__(self, layer_type, wt_layer, num_vals, quick = False):
        super(Quantize, self).__init__()

        if layer_type not in ('conv', 'fc'):
            # Raise rather than sys.exit(): SystemExit would kill the script /
            # kernel instead of surfacing a catchable error.
            raise ValueError("Invalid layer type given")

        weight = wt_layer.weight
        # Build the mask on the weight's own device so the indexing below works
        # whether wt_layer lives on CPU or GPU; the stored copy lives on `device`.
        mask = weight != 0
        mat = weight[mask]
        self.mask = mask.to(device)
        self.wt_shape = weight.shape

        if not quick:
            flat_mat = mat.detach().to('cpu').view(-1, 1)
            #kmeans = KMeans(n_clusters=num_vals, n_jobs=12)
            kmeans = MiniBatchKMeans(n_clusters=num_vals, batch_size=1000000)
            kmeans.fit(flat_mat)

            self.centroids = nn.Parameter(torch.from_numpy(kmeans.cluster_centers_).to(device).requires_grad_(True))
            # One centroid index per nonzero weight (same order as weight[mask]); frozen.
            self.labels = nn.Parameter(torch.from_numpy(kmeans.labels_).to(device).view(mat.shape).type(torch.long), requires_grad=False)
        else:
            self.centroids = nn.Parameter(torch.zeros((num_vals, 1), device = device).requires_grad_(True))
            self.labels = nn.Parameter(torch.zeros(mat.shape, dtype = torch.long, device = device), requires_grad=False)

        # NOTE: this attribute shadows nn.Module.type(); kept for compatibility.
        self.type = layer_type
        self.num_reps = num_vals
        if wt_layer.bias is None:
            self.bias = None
        else:
            # NOTE(review): when wt_layer is not already on `device`, .to() returns a
            # plain tensor rather than the nn.Parameter, so the bias is then NOT
            # registered (absent from parameters()/state_dict, not optimized).
            # Kept as-is so existing checkpoint keys stay unchanged — confirm intent.
            self.bias = wt_layer.bias.to(device).requires_grad_(True)
        if layer_type == 'conv':
            self.stride = wt_layer.stride
            self.padding = wt_layer.padding
            self.dilation = wt_layer.dilation
            self.groups = wt_layer.groups

    def forward(self, x):
        """Rebuild the dense weight from centroids[labels] and apply the layer."""
        vals = torch.squeeze(self.centroids[self.labels], dim = -1).type(torch.float32)
        # Allocate directly on the values' device (avoids a CPU alloc + copy per call).
        wt = torch.zeros(self.wt_shape, dtype = torch.float32, device = vals.device)
        wt[self.mask] = vals

        if self.type == 'conv':
            return F.conv2d(x, wt, bias = self.bias, stride = self.stride, padding = self.padding, dilation = self.dilation, groups = self.groups)
        else:
            return F.linear(x, wt, bias = self.bias)

In [9]:
class AlexNet(nn.Module):
    """AlexNet whose Conv2d/Linear layers can be swapped for weight-shared Quantize layers.

    The layer layout has the same module names as torchvision's alexnet, so its
    state_dict can be loaded directly.

    Args:
        init_model: optional model whose state_dict is loaded into this one
            before quantization (keys must match).
        quant_nums_w: optional sequence with the number of shared values for
            each Conv2d (in `features`) then each Linear (in `classifier`), in
            order; if None, the model is left unquantized.
        quick: forwarded to Quantize (skip k-means, use placeholders).
        num_classes: output size of the final Linear layer.
    """

    def __init__(self, init_model, quant_nums_w, quick = False, num_classes=1000):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

        if init_model is not None:
            # load_state_dict copies values into this model's own tensors, so the
            # source model is never aliased and no deepcopy is needed.
            self.load_state_dict(init_model.state_dict())

        if quant_nums_w is not None:
            self.init_wt_quantizers(quant_nums_w, quick = quick)

    def init_wt_quantizers(self, quant_nums, quick):
        """Replace each Conv2d in `features` and each Linear in `classifier` with Quantize.

        quant_nums[i] is the number of shared values for the i-th such layer
        (features first, then classifier). Prints 'Done i' after each layer.

        Raises:
            ValueError: if quant_nums has fewer entries than there are layers to
                quantize (checked before any layer is replaced).
        """
        n_layers = (sum(isinstance(m, nn.Conv2d) for m in self.features)
                    + sum(isinstance(m, nn.Linear) for m in self.classifier))
        if len(quant_nums) < n_layers:
            raise ValueError('quant_nums has {} entries but {} layers need quantizing'
                             .format(len(quant_nums), n_layers))

        self.features, next_ind = self._quantize_layers(self.features, nn.Conv2d, 'conv', quant_nums, 0, quick)
        self.classifier, _ = self._quantize_layers(self.classifier, nn.Linear, 'fc', quant_nums, next_ind, quick)

    @staticmethod
    def _quantize_layers(seq, layer_cls, layer_type, quant_nums, start, quick):
        """Return (new Sequential, next index), swapping every `layer_cls` module for Quantize."""
        ind = start
        new_layers = []
        for layer in seq:
            if isinstance(layer, layer_cls):
                new_layers.append(Quantize(layer_type, layer, quant_nums[ind], quick))
                print('Done', ind)
                ind += 1
            else:
                new_layers.append(layer)
        return nn.Sequential(*new_layers), ind

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [10]:
def check_accuracy(model, phase, record_grad, criterion = None, optimizer = None, I = None):
    """Run `model` over one phase of data, optionally taking an optimizer step per batch.

    Args:
        model: moved to the global `device` and put in eval() mode. This also
            applies when record_grad is True, so Dropout is disabled during
            fine-tuning (presumably deliberate — confirm).
        phase: key of the global `dataloader` to iterate ('train' / 'val');
            also used in the printed summary.
        record_grad: if True, each batch does zero_grad -> forward -> loss ->
            backward -> optimizer.step(), and the loss is accumulated.
            Otherwise, only a no-grad forward pass is done.
        criterion, optimizer: required when record_grad is True.
        I: optional (inputs, labels) batch. If given, only this batch is used
            instead of iterating dataloader[phase].

    Returns:
        acc (0-dim double tensor, fraction correct) when record_grad is False;
        otherwise (acc, total_loss), where total_loss is the sum over batches
        of loss * batch_size. Accuracy is measured on the pre-step outputs.
    """
    model.to(device)
    model.eval()

    done = 0
    since = time.time()
    corrects = torch.tensor(0, device = device)
    total_loss = 0.0

    # A single supplied batch behaves exactly like a one-batch loader.
    batches = dataloader[phase] if I is None else [I]

    for inputs, labels in batches:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(bool(record_grad)):
            if record_grad:
                optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels)
            if record_grad:
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        if record_grad:
            total_loss += loss.item() * inputs.size(0)

        done += len(inputs)
        print('\r{}, {}, {:.2f}%, {:.2f}'.format(corrects.item(), done, corrects.item() * 100.0 / done, total_loss), end = '')

    acc = corrects.double() / done
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    time_elapsed = time.time() - since
    print('Total time taken = {} seconds'.format(time_elapsed))

    if record_grad:
        return acc, total_loss
    else:
        return acc


In [11]:
def train(model, criterion, optimizer, scheduler = None, num_epochs = 25, I = None, do_baseline = True):
    """Fine-tune `model`, checkpointing the best validation accuracy to SAVE_PATH.

    Each epoch runs one training pass and then one validation pass through
    check_accuracy (both on the single batch `I` when given). Whenever the
    validation accuracy beats the best so far, the state_dict is saved to
    SAVE_PATH. At the end the best saved state is loaded back into `model`.

    Args:
        model, criterion, optimizer: passed through to check_accuracy.
        scheduler: optional LR scheduler; stepped once per epoch.
        num_epochs: number of train+val rounds.
        I: optional (inputs, labels) batch used instead of the full loaders.
        do_baseline: if True, val and train accuracy are measured first (on
            the full loaders) and the val accuracy seeds the best-so-far.

    Returns:
        `model`, with the best checkpoint's weights loaded.
    """
    print('          ', end = '\r')
    best_acc = 0.0

    if do_baseline:
        best_acc = check_accuracy(model, phase = 'val', record_grad = False)
        # Baseline train accuracy is only printed (by check_accuracy).
        check_accuracy(model, phase = 'train', record_grad = False)
        print('.......... Baseline Evaluation Done ..............')

    since = time.time()
    # Save the starting weights so the final load always finds a checkpoint,
    # even if no epoch improves on best_acc.
    torch.save(model.state_dict(), SAVE_PATH)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        check_accuracy(model, 'train', criterion=criterion, optimizer=optimizer, record_grad = True, I = I)

        epoch_acc = check_accuracy(model, 'val', record_grad = False, I = I)
        if epoch_acc > best_acc:
            print('Saving')
            best_acc = epoch_acc
            torch.save(model.state_dict(), SAVE_PATH)

        print()

        if scheduler is not None:
            scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(torch.load(SAVE_PATH))
    return model

In [12]:
# Build the AlexNet architecture and load the pruned checkpoint into it.
# No pretrained download is needed: the strict load below overwrites every
# parameter, so ImageNet weights fetched via pretrained=True would be discarded.
# (`device` is already set in the configuration cell.)
alexnet = models.alexnet()
alexnet.load_state_dict(torch.load('D://models//undone_pruned_net.pth'))

<All keys matched successfully>

In [13]:
# model = alexnet
# model.to(device)
# check_accuracy(model, 'val', record_grad = False)

In [14]:
# Checkpoint written once by the slow (k-means initialised) path below.
ORIG_PATH = 'D://models//NNA_quants.pth'

# Slow path, run once: k-means initialisation of every layer, then save.
# model = AlexNet(init_model=alexnet, quant_nums_w = [32]*8, quick = False)
# torch.save(model.state_dict(), ORIG_PATH)

# Fast path: zero-filled centroid/label placeholders (quick=True), filled
# from a saved checkpoint right after construction.
model = AlexNet(
    init_model=alexnet,
    quant_nums_w=[32] * 8,  # 32 shared values for each of the 5 conv + 3 linear layers
    quick=True,
)
# model.load_state_dict(torch.load(ORIG_PATH))
model.load_state_dict(torch.load(SAVE_PATH))

Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7


<All keys matched successfully>

In [15]:
# Collect and show every tensor the optimizer would train.
params = [p for p in model.parameters() if p.requires_grad]
print(len(params), params)

8 [Parameter containing:
tensor([[ 0.0474],
        [-0.1279],
        [ 0.2761],
        [-0.3138],
        [-0.0402],
        [ 0.1549],
        [ 0.6490],
        [-0.2235],
        [-0.4541],
        [ 0.0802],
        [ 0.3952],
        [-0.0605],
        [ 0.0207],
        [ 0.2302],
        [-0.1868],
        [-0.6518],
        [-0.0209],
        [ 0.1002],
        [ 0.1255],
        [-0.0826],
        [-0.3815],
        [ 0.5585],
        [ 0.3306],
        [-0.2637],
        [-0.1040],
        [ 0.0620],
        [ 0.1894],
        [-0.1548],
        [ 0.0339],
        [ 0.4757],
        [ 0.8451],
        [-0.5259]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([[-0.0397],
        [ 0.0343],
        [-0.2697],
        [ 0.2221],
        [-0.1401],
        [ 0.0945],
        [ 0.0174],
        [-0.0168],
        [-0.0845],
        [ 2.1282],
        [ 0.4799],
        [ 0.0580],
        [ 0.1466],
        [-0.0313],
        [-0.3692],
        [-0.0598],
  

In [None]:
model.to(device)
torch.cuda.empty_cache()

criterion = nn.CrossEntropyLoss()
# Hand SGD only trainable tensors; Quantize's `labels` are frozen
# (requires_grad=False) and are skipped.
params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.SGD(params, lr=1e-8, momentum=0.9)

# No LR schedule for now; alternative kept for reference:
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.5)
exp_lr_scheduler = None

# I=None -> use the full dataloaders; to debug on one fixed batch instead:
#I = next(iter(dataloader['train']))
I = None

model = train(model, criterion, optimizer, exp_lr_scheduler, num_epochs=100, I=I, do_baseline=True)

5397, 10000, 53.97%, 0.00
val Acc: 53.9700 %
Total time taken = 34.36496663093567 seconds
12901, 27648, 46.66%, 0.00

In [None]:
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################