In [1]:
# Implements 1 level cache (No backward pass needed so not implemented; only accuracy determination required)
# Weight sharing not included

# Cache.approx just updates W, A values according to hits and misses and does not compute the result. Actual result computation
# is done at the end in parallel by doing updated_W * updated_A
# Current - LRU policy replacement. Future - May use LRU + count-used based policy

# Find out how hdw simulation is done in papers. They ofcourse don't actually build the hdw, they just simulate it. But, 
# they are still able to find the accuracy value by running the program over their hdw simulation

# To be added (and be performed in this order):
# Network Pruning
# Wt sharing
# Bit Masking (both weights and activations; may try retraining also for both) / Pytorch trained 8 bit int quantization
# May merge weight bit masking in 3rd step with 2nd step 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
from collections  import namedtuple
import sys

device = torch.device('cuda')
SAVE_PATH = 'D://models//main_net.pth'

In [3]:
transform = {
    'train':transforms.Compose([
        transforms.RandomResizedCrop(224), 
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229,0.224,0.225])]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}

In [4]:
data_dir = 'D:\\datasets\\ILSVRC2012_img_val - Retrain\\'
dataset = {x:datasets.ImageFolder(os.path.join(data_dir, x), transform[x]) for x in ['train', 'val']}

In [5]:
dataloader = {x:torch.utils.data.DataLoader(dataset[x], batch_size = 1, shuffle = False)
              for x in ['train', 'val']}

In [6]:
dataset_size = {x:len(dataset[x]) for x in ['train', 'val']}
class_names = dataset['train'].classes

In [7]:
class cache:
    global device
    
    def __init__(self):
        self.size = 1000
        self.W = torch.randn(self.size, device = device)
        self.A = torch.randn(self.size, device = device)
        self.time = torch.ones(self.size, device = device, dtype = torch.int32)
        self.epsilon = 1e-2
        self.hits = torch.tensor(0, device = device, dtype = torch.float32).view(1)
        self.misses = torch.tensor(0, device = device, dtype = torch.float32).view(1)
        
    def batch_wise_approx(self, orig_W, orig_A, num_send):
        
        if orig_W.shape != orig_A.shape:
            sys.exit("W-shape and A-shape unequal")
        
        shape = orig_W.shape
        num_elem = orig_W.numel()
        
        orig_W = orig_W.flatten()
        orig_A = orig_A.flatten()
        
        out_W = torch.zeros(orig_W.shape, dtype = torch.float32, device = device)
        out_A = torch.zeros(orig_A.shape, dtype = torch.float32, device = device)
        
        for i in range(int(num_elem / num_send) + 1):
            start = i * num_send
            end = min((i + 1) * num_send, num_elem)
            out_W[start:end], out_A[start:end], hits, misses = self.approx(orig_W[start:end], orig_A[start:end])
            tot = hits + misses
            print('\r{} / {} | Hits {} / {}, {:.2f}%'.format(i, int(num_elem / num_send), hits, tot,  hits * 100.0 / tot), end = '')
            
        return out_W.view(shape), out_A.view(shape)
        
    def approx(self, orig_W, orig_A):
        
        s1 = orig_W.shape
        s2 = orig_A.shape
        
        orig_W = orig_W.flatten()
        orig_A = orig_A.flatten()
        if orig_W.shape != orig_A.shape:
            sys.exit("W-shape and A-shape unequal")
        
        # Remove indices where W = 0 or A = 0
        
        ind1 = (orig_W == 0).nonzero() 
        ind2 = (orig_A == 0).nonzero()
        zero_inds = torch.unique(torch.cat((ind1, ind2))).view(-1,1)
        all_inds = torch.arange(orig_W.shape[0], device = device).view(1,-1)
        non_zero_inds = (((zero_inds - all_inds) == 0).sum(dim = 0) == 0).nonzero().view(-1) # set(all_inds) - set(zero_inds)
        W_ = orig_W[non_zero_inds]
        A_ = orig_A[non_zero_inds]
        
        W = W_
        A = A_
        
        while(True):
            
            #print('Size is: ', len(W), len(A))
            
            # Find hits, misses

            I = torch.cat((W.view(1,-1),A.view(1,-1)), dim = 0)
            S = torch.cat((self.W.view(1,-1), self.A.view(1,-1)), dim = 0)

            #x = I.repeat(S.shape[1],1,1)
            x = I.expand(S.shape[1],-1,-1)
            #y = S.t().view(S.shape[1],2,1).repeat(1,1,I.shape[1])
            y = S.t().view(S.shape[1],2,1).expand(-1,-1,I.shape[1])
            dist, LI = torch.abs(x - y).sum(dim = 1).min(dim = 0)

            misses = ~(dist < self.epsilon)
            t = misses.nonzero()
            
            if t.shape[0] != 0:
                max_lim = t[0]
            else:
                max_lim = W.shape[0]
            LI = LI[:max_lim]
            
            #print('LI', LI)
            
            if LI.shape[0] != 0: # Hits encountered
                
                # Update num hits
                self.hits += max_lim

                # Update W, A corresponding to hits
                W[:max_lim] = self.W[LI]
                A[:max_lim] = self.A[LI]

                # Update time corresponding to hits

                used = torch.unique(LI)
                x = LI.view(-1,1) == used.view(1,-1)
                x = x.type(torch.int32)
                x = x * torch.arange(start = 1, end = x.shape[0]+1, device = device).view(-1,1).expand_as(x)
                x = x.shape[0] - 1 - torch.argmax(x, dim = 0)

                self.time += LI.shape[0]
                self.time[used] = 0
                self.time[used] += x

#                 print('Time', self.time)
#                 print('W', W)
#                 print('A', A)
            
            if t.shape[0] != 0: # A miss encountered
                
                self.misses += 1
                
                # Update stored values in cache
                ind = torch.argmax(self.time)
                self.W[ind] = W[max_lim]
                self.A[ind] = A[max_lim]
                self.time += 1
                self.time[ind] = 0
                
#                 print('Miss at', max_lim, W[max_lim], A[max_lim])
#                 print('New cache W', self.W)
#                 print('New cache A', self.A)
#                 print('New time', self.time)
#                 print()
                
                if max_lim == W.shape[0] - 1:
                    break
                
                # Update W, A, prod for next cycle
                W = W[max_lim + 1:]
                A = A[max_lim + 1:]
                
            else:
                break
        
        orig_W[non_zero_inds] = W_
        orig_A[non_zero_inds] = A_
        
        hits = self.hits.item()
        misses = self.misses.item()
        
        self.hits = torch.tensor(0, device = device, dtype = torch.float32).view(1)
        self.misses = torch.tensor(0, device = device, dtype = torch.float32).view(1)
        
        return orig_W.view(s1), orig_A.view(s2), hits, misses
            

In [8]:
global_cache = cache()

In [9]:
# Testing

# global_cache.size = 3
# global_cache.epsilon = 3
# global_cache.W = torch.tensor([1.,2,3], device = device, dtype = torch.float32)
# global_cache.A = torch.tensor([7.,2,5], device = device, dtype = torch.float32)
# global_cache.time = torch.ones(global_cache.size, device = device, dtype = torch.int32)

# x = torch.tensor([1., 2, 3, 1, 2, 1, 1, 3, 0, 3, 2], device = device, dtype = torch.float32)
# y = torch.tensor([1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], device = device, dtype = torch.float32)
# x, y = global_cache.approx(x, y)

# print()
# print(global_cache.hits, global_cache.misses)

In [10]:
x = torch.tensor([])
torch.cat((x, torch.zeros(10)))

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [11]:
class cached_conv(nn.Module):
    
    global device
    
    def __init__(self, wt_layer, cache):
        super(cached_conv, self).__init__()
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.stride = wt_layer.stride
        self.padding = wt_layer.padding
        #self.dilation = wt_layer.dilation
        #self.groups = wt_layer.groups
        self.cache = cache
        
    def forward(self, x):
        
        print('Reached conv')
        
        A_prev = x
        W = self.weight
        b = self.bias
        stride = self.stride
        pad = self.padding
        cache = self.cache
        
        #return F.conv2d(x, W, bias = b, stride = stride, padding = pad)
        
        (m, n_C_prev, n_H_prev, n_W_prev) = A_prev.shape
        (n_C, n_C_prev, f, f) = W.shape
        
        # Compute the dimensions of the CONV output volume 
        n_H = int((n_H_prev + 2*pad[0] - f)/stride[0]) + 1
        n_W =int((n_W_prev + 2*pad[1] - f)/stride[1]) + 1

        y = F.unfold(A_prev, (f, f), padding = pad, stride = stride).transpose(2,1)
        #y = y.view(m, 1, y.shape[1],y.shape[2]).repeat((1,n_C,1,1))
        y = y.view(m, 1, y.shape[1],y.shape[2]).expand((-1,n_C,-1,-1))

        W = W.view(n_C, -1)
        #W = W.view(1,n_C, 1, W.shape[1]).repeat(m, 1, y.shape[2], 1)
        W = W.view(1,n_C, 1, W.shape[1]).expand(m, -1, y.shape[2], -1)
        
        W, y = cache.batch_wise_approx(W, y, 35000)
        
        Z = torch.sum(W * y, dim = 3).view(m, n_C, n_H, n_W)
        Z = Z + b.view(1,b.shape[0], 1, 1)
        
        #print(torch.sum(torch.abs(Z - F.conv2d(x, self.weight, bias = self.bias, stride = self.stride, padding = self.padding))))
        
        #sys.exit()
        
        return Z

In [12]:
class cached_fc(nn.Module):
    
    global device
    
    def __init__(self, wt_layer, cache):
        super(cached_fc, self).__init__()
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.cache = cache
        
    def forward(self, x):
        
        #print('Reached fc')
               
        return F.linear(x, self.weight, bias = self.bias)
        
        A_prev = x
        W = self.weight
        b = self.bias
        cache = self.cache
        
        (m, n_prev) = A_prev.shape
        (n, n_prev) = W.shape

        A_prev = A_prev.view(m, 1, n_prev).expand(-1, n, -1)
        W = W.view(1, n, n_prev).expand(m, -1, -1)

        W, A_prev = cache.batch_wise_approx(W, A_prev, 30000)
        
        Z = (A_prev * W).sum(dim = 2).view(m, n)
        Z = Z + b.view(1, n)
        
        #print(torch.sum(torch.abs(Z - F.linear(x, self.weight, bias = self.bias))))

        return Z

In [13]:
class AlexNet(nn.Module):

    def __init__(self, init_state_dict, num_classes=1000):
        super(AlexNet, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        
        self.load_state_dict(init_state_dict)
        
        self.init_cache_layers()
       
    def init_cache_layers(self):
        
        ind = -1
        global global_cache 
        q_list = []
        for layer in self.features:
            if isinstance(layer, nn.Conv2d):
                ind += 1
                q_list.append(cached_conv(layer, global_cache))
            else:
                q_list.append(layer)
        self.features = nn.Sequential(*q_list)
        
        ind = -1
        q_list = []
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                ind += 1
                q_list.append(cached_fc(layer, global_cache))
            else:
                q_list.append(layer)
        self.classifier = nn.Sequential(*q_list)
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [14]:
def check_accuracy(model, phase):
    
    global device
    
    model.to(device)
    model.eval()
#     if record_grad:
#         model.train()
#     else:
#         model.eval()

        
    done = 0
    acc = 0.0
    since = time.time()
    corrects = torch.tensor(0)
    total_loss = 0.0
    corrects = corrects.to(device)
    loss = 100.0
    
    for inputs, labels in dataloader[phase]:

        inputs = inputs.to(device)
        labels = labels.to(device)
        
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels)

        done += len(inputs)
        print('\r{}, {}, {:.2f}%, {:.2f}'.format(corrects.item(), done, corrects.item() * 100.0 / done, total_loss), end = '')
#         if done >= 1000:
#             break
                    
    acc = corrects.double() / done
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    time_elapsed = time.time() - since
    print('Total time taken = {} seconds'.format(time_elapsed))

    return acc


In [15]:
alexnet = models.alexnet(pretrained=True)
model = AlexNet(init_state_dict=torch.load('D://models//undone_wt_shared_net.pth'))

In [16]:
model.classifier

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): cached_fc()
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): cached_fc()
  (5): ReLU(inplace=True)
  (6): cached_fc()
)

In [17]:
model.to(device)
torch.cuda.empty_cache()

check_accuracy(model, 'val')

Reached conv
30 / 2007 | Hits 19126.0 / 29795.0, 64.19%

RuntimeError: CUDA out of memory. Tried to allocate 398.00 MiB (GPU 0; 6.00 GiB total capacity; 4.38 GiB already allocated; 380.14 MiB free; 4.39 GiB reserved in total by PyTorch)

In [None]:
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################