In [1]:
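# Implements weight sharing for AlexNet: the non-zero weights of every conv / fully-connected
# layer are clustered with k-means and replaced by a small set of trainable shared centroids.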

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
import sys

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint file that the training loops overwrite with the best model found so far.
SAVE_PATH = '../models/wt_shared_net.pth'

In [3]:
# Per-split preprocessing pipelines, keyed by split name.
#   'train': random 224-pixel crop + random horizontal flip (augmentation)
#   'val'  : deterministic resize to 256 followed by a 224-pixel centre crop
# Both end with tensor conversion and the same per-channel normalisation.
transform = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

In [4]:
# Root folder; each split is read from the sub-directory named after it.
data_dir = '../datasets/ILSVRC2012_img_val - Retrain/'
dataset = dict(
    (split, datasets.ImageFolder(os.path.join(data_dir, split), transform[split]))
    for split in ('train', 'val')
)

In [5]:
# One loader per split. Only the training split is shuffled: it is the loader SGD
# iterates over, and ImageFolder lists samples class by class, so un-shuffled
# batches would each contain only a handful of classes. Validation order is kept fixed.
dataloader = {x:torch.utils.data.DataLoader(dataset[x], batch_size = 512, shuffle = (x == 'train'), num_workers = 16, pin_memory=True)
              for x in ['train', 'val']}

In [6]:
# Separate, shuffled loader over the training split that uses a small batch size.
batch_size_actk = 24
dataloader_actk = torch.utils.data.DataLoader(
    dataset['train'],
    batch_size=batch_size_actk,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)

In [7]:
# Number of samples in each split, and the class names taken from the training split.
dataset_size = {'train': len(dataset['train']), 'val': len(dataset['val'])}
class_names = dataset['train'].classes

In [8]:
class Quantize(nn.Module):
    """Weight-shared replacement for a single ``Conv2d`` or ``Linear`` layer.

    Every non-zero weight of the wrapped layer is represented by one of
    ``num_vals`` shared values:

    * ``centroids`` -- ``(num_vals, 1)`` trainable parameter holding the shared values.
    * ``labels``    -- frozen ``long`` parameter with, for each non-zero weight, the
      index of the centroid it uses.
    * ``mask``      -- boolean tensor of the original weight shape; ``False`` where the
      source weight is exactly zero, so those positions stay zero.

    Args:
        layer_type: ``'conv'`` or ``'fc'``; selects ``F.conv2d`` or ``F.linear`` in ``forward``.
        wt_layer:   the layer whose weight/bias (and conv hyper-parameters) are taken over.
        num_vals:   number of shared values (k-means clusters).
        quick:      if true, skip k-means and create zero-filled ``centroids`` / ``labels``
                    of the right shape, to be filled afterwards (e.g. by ``load_state_dict``).

    Raises:
        ValueError: if ``layer_type`` is neither ``'conv'`` nor ``'fc'``.

    NOTE(review): ``mask`` is a plain attribute rather than a registered buffer, so it is
    not stored in ``state_dict()`` and does not follow ``Module.to()``. ``bias`` appears to
    share that limitation whenever ``.to(device)`` has to make a copy -- confirm before
    relying on a saved checkpoint alone to restore a quantized model.
    """

    def __init__(self, layer_type, wt_layer, num_vals, quick = False):
        super(Quantize, self).__init__()

        if layer_type not in ('conv', 'fc'):
            # Raise instead of sys.exit(): a bad argument should not terminate the kernel.
            raise ValueError("Invalid layer type given")

        self.wt_shape = wt_layer.weight.shape

        # Build the mask on the weight's own device and index there; only afterwards
        # move it to the target device, so weight and index never sit on different devices.
        nonzero = wt_layer.weight != 0
        mat = wt_layer.weight[nonzero]
        self.mask = nonzero.to(device)

        if quick:
            centroids = torch.zeros((num_vals, 1), device = device)
            labels = torch.zeros(mat.shape, dtype = torch.long, device = device)
        else:
            # Cluster the surviving weights as 1-D points; each weight gets a cluster label.
            flat_mat = mat.detach().to('cpu').view(-1, 1).numpy()
            kmeans = MiniBatchKMeans(n_clusters=num_vals, batch_size=5000000)
            kmeans.fit(flat_mat)
            centroids = torch.from_numpy(kmeans.cluster_centers_).to(device)
            labels = torch.from_numpy(kmeans.labels_).to(device).view(mat.shape).type(torch.long)

        self.centroids = nn.Parameter(centroids, requires_grad=True)
        self.labels = nn.Parameter(labels, requires_grad=False)

        self.type = layer_type
        self.num_reps = num_vals

        # Layers created with bias=False carry no bias tensor.
        src_bias = wt_layer.bias
        self.bias = None if src_bias is None else src_bias.to(device).requires_grad_(True)

        if layer_type == 'conv':
            self.stride = wt_layer.stride
            self.padding = wt_layer.padding
            self.dilation = wt_layer.dilation
            self.groups = wt_layer.groups

    def _dense_weight(self):
        """Rebuild the full float32 weight tensor from ``centroids``, ``labels`` and ``mask``."""
        vals = torch.squeeze(self.centroids[self.labels], dim = -1).type(torch.float32)
        # Allocate directly next to the mask instead of creating on CPU and copying over.
        wt = torch.zeros(self.wt_shape, dtype = torch.float32, device = self.mask.device)
        wt[self.mask] = vals
        return wt

    def forward(self, x):
        """Apply the wrapped convolution / linear map using the reconstructed weight."""
        wt = self._dense_weight()

        if self.type == 'conv':
            return F.conv2d(x, wt, bias = self.bias, stride = self.stride, padding = self.padding, dilation = self.dilation, groups = self.groups)
        else:
            return F.linear(x, wt, bias = self.bias)

In [9]:
class AlexNet(nn.Module):
    """AlexNet whose conv / linear layers can be swapped for ``Quantize`` modules and back.

    Args:
        init_model:   optional model whose ``state_dict()`` is copied into this network
                      before any quantization; ``None`` keeps the fresh initialisation.
        quant_nums_w: optional sequence with one cluster count per quantizable layer,
                      consumed in order (conv layers of ``features`` first, then the
                      linear layers of ``classifier``). ``None`` leaves the network dense.
        quick:        forwarded to ``Quantize``.
        num_classes:  size of the final linear layer.
    """

    def __init__(self, init_model, quant_nums_w, quick = False, num_classes=1000):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

        if init_model is not None:
            self.load_state_dict(copy.deepcopy(init_model.state_dict()))

        if quant_nums_w is not None:
            self.init_wt_quantizers(quant_nums_w, quick = quick)

    def init_wt_quantizers(self, quant_nums, quick):
        """Replace every ``Conv2d`` in ``features`` and every ``Linear`` in ``classifier``
        with a ``Quantize`` module, using ``quant_nums[i]`` clusters for the i-th such layer."""
        ind = 0
        for attr, layer_cls, kind in (('features', nn.Conv2d, 'conv'),
                                      ('classifier', nn.Linear, 'fc')):
            q_list = []
            for layer in getattr(self, attr):
                if isinstance(layer, layer_cls):
                    q_list.append(Quantize(kind, layer, quant_nums[ind], quick))
                    print('Done', ind)
                    ind += 1
                else:
                    q_list.append(layer)
            setattr(self, attr, nn.Sequential(*q_list))

    def forward(self, x):
        """Run features -> adaptive average pool -> flatten -> classifier."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    @staticmethod
    def _restore_layer(layer, kind):
        """Return a dense ``Conv2d`` (kind == 'conv') or ``Linear`` equivalent of a
        ``Quantize`` module; any other layer is returned unchanged."""
        if not isinstance(layer, Quantize):
            return layer

        vals = torch.squeeze(layer.centroids[layer.labels], dim = -1).type(torch.float32)
        wt = torch.zeros(layer.wt_shape, dtype = torch.float32, device = layer.mask.device)
        wt[layer.mask] = vals

        has_bias = layer.bias is not None
        if kind == 'conv':
            # A conv weight is (out, in / groups, kH, kW): multiply by groups to get
            # in_channels back, and keep both kernel dimensions for non-square kernels.
            my_layer = nn.Conv2d(wt.shape[1] * layer.groups, wt.shape[0], kernel_size = tuple(wt.shape[2:]),
                                 stride = layer.stride, padding = layer.padding, dilation = layer.dilation,
                                 groups = layer.groups, bias = has_bias)
        else:
            my_layer = nn.Linear(wt.shape[1], wt.shape[0], bias = has_bias)

        # New layers are created on the CPU; keep them on the device the weights live on
        # so the model stays usable without an extra .to() call.
        my_layer = my_layer.to(wt.device)
        my_layer.weight.data.copy_(wt)
        if has_bias:
            my_layer.bias.data.copy_(layer.bias)
        return my_layer

    def un_quantize(self):
        """Convert every ``Quantize`` module back into an ordinary ``Conv2d`` / ``Linear``
        layer carrying the reconstructed dense weights (in place, returns ``None``)."""
        # Only values are copied, so no autograd graph needs to be recorded.
        with torch.no_grad():
            self.features = nn.Sequential(*[self._restore_layer(layer, 'conv') for layer in self.features])
            self.classifier = nn.Sequential(*[self._restore_layer(layer, 'fc') for layer in self.classifier])
        

In [10]:
def check_accuracy(model, phase, record_grad, criterion = None, optimizer = None, I = None):
    """Run one pass over a data split (or one fixed batch) and report top-1 accuracy.

    Args:
        model:       network to evaluate; it is moved to the global ``device``.
        phase:       key into the global ``dataloader`` dict ('train' or 'val'); also used
                     as the label in the printed summary.
        record_grad: if true, every batch additionally performs an optimisation step
                     (zero_grad -> forward -> loss -> backward -> step).
        criterion:   loss function; required when ``record_grad`` is true.
        optimizer:   optimiser; required when ``record_grad`` is true.
        I:           optional ``(inputs, labels)`` batch. When given, only this batch is
                     processed and ``dataloader[phase]`` is not touched.

    Returns:
        ``acc`` -- fraction of correct predictions as a 0-dim double tensor -- when
        ``record_grad`` is false, otherwise ``(acc, total_loss)`` where ``total_loss`` is
        the sum over batches of ``loss * batch_size``.

    Raises:
        ValueError: if ``record_grad`` is true but ``criterion`` or ``optimizer`` is missing.
    """
    if record_grad and (criterion is None or optimizer is None):
        raise ValueError("criterion and optimizer are required when record_grad is True")

    model.to(device)
    # The model is kept in eval mode even while parameters are being updated, so dropout
    # stays disabled during fine-tuning. (Switching to train() was tried and disabled.)
    model.eval()

    # A single fixed batch is treated as a one-element "loader" so both cases share one loop.
    batches = dataloader[phase] if I is None else [I]

    done = 0
    total_loss = 0.0
    since = time.time()
    corrects = torch.tensor(0, device = device)

    for inputs, labels in batches:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(bool(record_grad)):
            if record_grad:
                optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            # Predictions are counted before the optimiser step of this batch.
            corrects += torch.sum(preds == labels)

            if record_grad:
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        if record_grad:
            total_loss += loss.item() * inputs.size(0)

        done += len(inputs)
        # '\r' keeps the running tally on a single console line.
        print('\r{}, {}, {:.2f}%, {:.2f}'.format(corrects.item(), done, corrects.item() * 100.0 / done, total_loss), end = '')

    acc = corrects.double() / done
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    time_elapsed = time.time() - since
    print('Total time taken = {} seconds'.format(time_elapsed))

    if record_grad:
        return acc, total_loss
    else:
        return acc


In [11]:
def train_limited(model, criterion, optimizer, scheduler = None, num_epochs = 25, I = None, do_baseline = True):
    """Fine-tune ``model`` using only the training split, keeping the best checkpoint.

    Each epoch runs one optimisation pass and then one gradient-free evaluation pass, both
    through ``check_accuracy`` on the 'train' split (or on the fixed batch ``I``). Whenever
    the evaluation accuracy beats the best seen so far, the weights are written to
    ``SAVE_PATH``. After the last epoch the best checkpoint is loaded back into ``model``.

    Args:
        model:       network to train.
        criterion:   loss function passed to ``check_accuracy``.
        optimizer:   optimiser passed to ``check_accuracy``.
        scheduler:   optional LR scheduler, stepped once per epoch.
        num_epochs:  number of epochs.
        I:           optional fixed ``(inputs, labels)`` batch used instead of the loader
                     inside the epoch loop.
        do_baseline: if true, evaluate once before training and use that accuracy as the
                     initial best; otherwise the initial best is 0.0.

    Returns:
        ``model`` with the best saved weights loaded.
    """
    print('          ', end = '\r')
    best_acc = 0.0

    if do_baseline:
        best_acc = check_accuracy(model, phase = 'train', record_grad = False)
        print('.......... Baseline Evaluation Done ..............')

    since = time.time()
    # Guarantees a checkpoint exists even if no epoch improves on the baseline.
    torch.save(model.state_dict(), SAVE_PATH)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Optimisation pass.
        check_accuracy(model, 'train', criterion=criterion, optimizer=optimizer, record_grad = True, I = I)

        # Evaluation pass on the same split decides whether to checkpoint.
        epoch_acc = check_accuracy(model, 'train', record_grad = False, I = I)
        if epoch_acc > best_acc:
            print('Saving')
            best_acc = epoch_acc
            torch.save(model.state_dict(), SAVE_PATH)

        print()

        if scheduler is not None:
            scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # The tracked metric is training-split accuracy, so label it as such.
    print('Best train Acc: {:4f}'.format(best_acc))

    model.load_state_dict(torch.load(SAVE_PATH))
    return model

In [12]:
def train(model, criterion, optimizer, scheduler = None, num_epochs = 25, I = None, do_baseline = True):
    """Train on the 'train' split and checkpoint on 'val' accuracy.

    Per epoch: one optimisation pass via ``check_accuracy(..., 'train', record_grad=True)``
    followed by one gradient-free pass on 'val'. The weights are written to ``SAVE_PATH``
    whenever the validation accuracy exceeds the best value so far, and the saved weights
    are loaded back into ``model`` before it is returned.

    With ``do_baseline`` true, 'val' and then 'train' are evaluated once up front and the
    validation result seeds the best accuracy; otherwise it starts at 0.0. ``scheduler``,
    if given, is stepped once per epoch. ``I`` is forwarded to ``check_accuracy`` for the
    per-epoch passes.
    """
    print('          ', end = '\r')
    top_acc = 0.0

    if do_baseline:
        baseline_val = check_accuracy(model, phase = 'val', record_grad = False)
        check_accuracy(model, phase = 'train', record_grad = False)
        print('.......... Baseline Evaluation Done ..............')
        top_acc = baseline_val

    started = time.time()
    torch.save(model.state_dict(), SAVE_PATH)

    last_epoch = num_epochs - 1
    for epoch_idx in range(num_epochs):
        print('Epoch {}/{}'.format(epoch_idx, last_epoch))
        print('-' * 10)

        # Optimisation pass over the training split.
        _, _ = check_accuracy(model, 'train', criterion=criterion, optimizer=optimizer, record_grad = True, I = I)

        # Validation pass; checkpoint on improvement.
        val_acc = check_accuracy(model, 'val', record_grad = False, I = I)
        if val_acc > top_acc:
            print('Saving')
            top_acc = val_acc
            torch.save(model.state_dict(), SAVE_PATH)

        print()

        if scheduler is not None:
            scheduler.step()

    elapsed = time.time() - started
    print('Training complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))
    print('Best val Acc: {:4f}'.format(top_acc))

    model.load_state_dict(torch.load(SAVE_PATH))
    return model

In [14]:
# Re-select the CUDA device (repeats the assignment made next to the imports).
device = torch.device('cuda')
# Torchvision AlexNet with its pretrained weights: the reference model.
model = models.alexnet(pretrained=True)
# Reference accuracy on the 'train' split, no gradient updates.
check_accuracy(model, 'train', record_grad = False)

19271, 40000, 48.18%, 0.00
train Acc: 48.1775 %
Total time taken = 24.76977825164795 seconds


tensor(0.4818, device='cuda:0', dtype=torch.float64)

In [15]:
# Replace the pretrained weights with a previously saved checkpoint
# (presumably the pruned network, judging by the file name -- verify).
model.load_state_dict(torch.load('../models/undone_pruned_net.pth'))
model.to(device)
# Accuracy of the loaded weights on the 'train' split.
check_accuracy(model, 'train', record_grad = False)

19079, 40000, 47.70%, 0.00
train Acc: 47.6975 %
Total time taken = 24.577978372573853 seconds


tensor(0.4770, device='cuda:0', dtype=torch.float64)

In [16]:
# Only referenced by the disabled lines below.
ORIG_PATH = '../models/NNA_quants.pth'

# Disabled: full k-means initialisation (quick = False) and caching of its state dict.
# NOTE(review): `alexnet` is not defined in any visible cell.
# model = AlexNet(init_model=alexnet, quant_nums_w = [32]*8, quick = False)
# torch.save(model.state_dict(), ORIG_PATH)

# quick = True skips k-means and creates zero-filled centroids / labels of the right
# shape (32 shared values for each of the 8 layers); the actual values are then read
# from the checkpoint at SAVE_PATH, which therefore has to exist from an earlier run.
# Note that this rebinds `model` from the dense network to the quantized one.
model = AlexNet(init_model=model, quant_nums_w = [32]*8, quick = True)
#model.load_state_dict(torch.load(ORIG_PATH))
model.load_state_dict(torch.load(SAVE_PATH))

Done 0
Done 1
Done 2
Done 3
Done 4
Done 5
Done 6
Done 7


<All keys matched successfully>

In [17]:
# Collect the parameters that will be optimised (requires_grad=True) and print how many
# there are together with their full contents.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
print(len(params), params)

8 [Parameter containing:
tensor([[-0.0304],
        [ 0.1038],
        [-0.1847],
        [ 0.3390],
        [ 0.0443],
        [-0.3106],
        [ 0.1667],
        [-0.1163],
        [-0.6406],
        [ 0.4367],
        [-0.4331],
        [-0.0665],
        [ 0.0150],
        [ 0.2522],
        [ 0.6203],
        [-0.2628],
        [ 0.0616],
        [ 0.1319],
        [-0.0158],
        [ 0.2086],
        [-0.0890],
        [ 0.8414],
        [-0.1482],
        [ 0.0288],
        [-0.3729],
        [-0.2239],
        [-0.0470],
        [ 0.5087],
        [-0.5059],
        [ 0.0815],
        [ 0.2964],
        [ 0.3867]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([[-0.0291],
        [ 0.0409],
        [-0.0799],
        [ 0.2527],
        [-0.2078],
        [ 0.0216],
        [ 0.0996],
        [-0.4167],
        [-0.0501],
        [ 2.1189],
        [-0.1180],
        [-0.0126],
        [ 0.0795],
        [ 0.6553],
        [ 0.1579],
        [-0.2571],
  

In [18]:
model.to(device)
torch.cuda.empty_cache()

criterion = nn.CrossEntropyLoss() 
# Hand only the parameters with requires_grad=True to the optimiser.
params = filter(lambda p: p.requires_grad, model.parameters())
# Plain SGD with momentum and a very small learning rate (1e-7).
optimizer = optim.SGD(params, lr = 1e-7, momentum = 0.9)
# No learning-rate schedule is used; the StepLR alternative is disabled.
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.5)
exp_lr_scheduler = None
# I = None makes each pass iterate over the whole data loader; the disabled line
# would instead fix a single training batch.
#I = next(iter(dataloader['train']))
I = None

# train_limited trains and evaluates on the 'train' split only and keeps the best checkpoint.
model = train_limited(model, criterion, optimizer, exp_lr_scheduler, num_epochs = 100, I = I, do_baseline = True)

18714, 40000, 46.78%, 0.00
train Acc: 46.7850 %
Total time taken = 27.968446969985962 seconds
.......... Baseline Evaluation Done ..............
Epoch 0/99
----------
18796, 40000, 46.99%, 97743.12
train Acc: 46.9900 %
Total time taken = 92.58258867263794 seconds
18791, 40000, 46.98%, 0.00
train Acc: 46.9775 %
Total time taken = 28.590837240219116 seconds
Saving

Epoch 1/99
----------
18826, 40000, 47.06%, 97957.89
train Acc: 47.0650 %
Total time taken = 91.80765104293823 seconds
18711, 40000, 46.78%, 0.00
train Acc: 46.7775 %
Total time taken = 27.66166114807129 seconds

Epoch 2/99
----------
18777, 40000, 46.94%, 98287.92
train Acc: 46.9425 %
Total time taken = 92.3066155910492 seconds
18825, 40000, 47.06%, 0.00
train Acc: 47.0625 %
Total time taken = 28.00094509124756 seconds
Saving

Epoch 3/99
----------
18772, 40000, 46.93%, 98172.43
train Acc: 46.9300 %
Total time taken = 91.87693190574646 seconds
18777, 40000, 46.94%, 0.00
train Acc: 46.9425 %
Total time taken = 27.8856966495513

KeyboardInterrupt: 

In [24]:
# Reload the checkpoint stored at SAVE_PATH (the training run above was interrupted
# before train_limited reached its own reload step) and re-check its accuracy.
model.load_state_dict(torch.load(SAVE_PATH))
model.to(device)
check_accuracy(model, 'train', record_grad = False)

18786, 40000, 46.97%, 0.00
train Acc: 46.9650 %
Total time taken = 28.752476453781128 seconds


tensor(0.4697, device='cuda:0', dtype=torch.float64)

In [25]:
# Swap every Quantize module for an ordinary Conv2d / Linear layer holding the
# reconstructed dense weights.
model.un_quantize()

In [26]:
# Display the feature extractor to confirm it consists of plain Conv2d layers again.
model.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [27]:
# Display the classifier to confirm it consists of plain Linear layers again.
model.classifier

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): Linear(in_features=9216, out_features=4096, bias=True)
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=4096, out_features=4096, bias=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

In [29]:
# Make sure the rebuilt model is on the target device, then re-check accuracy on the
# 'train' split to compare against the quantized result.
model.to(device)
check_accuracy(model, 'train', record_grad = False)

18741, 40000, 46.85%, 0.00
train Acc: 46.8525 %
Total time taken = 24.960099935531616 seconds


tensor(0.4685, device='cuda:0', dtype=torch.float64)

In [30]:
# Persist the weights of the de-quantized (plain Conv2d / Linear) model.
torch.save(model.state_dict(), '../models/undone_wt_shared_net.pth')

In [None]:
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################