In [None]:
# Implements 1 level cache (No backward pass needed so not implemented; only accuracy determination required)
# Weight sharing not included

# Cache.approx just updates W, A values according to hits and misses and does not compute the result. Actual result computation
# is done at the end in parallel by doing updated_W * updated_A
# Current: LRU replacement policy. Future: may use a combined LRU + use-count based policy

# Find out how hardware simulation is done in papers. They of course don't actually build the hardware, they just simulate
# it. Yet they are still able to measure accuracy by running the program on top of their hardware simulation

# To be added (and be performed in this order):
# Network Pruning
# Wt sharing
# Bit Masking (both weights and activations; may try retraining also for both) / Pytorch trained 8 bit int quantization
# May merge weight bit masking in 3rd step with 2nd step 

# Find out exactly how 1000 caches would be used in hardware. Do they all hold the same data? If a miss occurs in one
# cache, is access to all caches stalled? Using the parallelism of 1000 caches to process misses in parallel could really
# speed up the simulation. The sequential processing of misses is the current bottleneck

# Speed Up methods:
# Prune W and A in initial step to be able to increase batch size (Currently, very difficult to vectorize. Basically, hard
# to get equal size search ranges for each unique weight in the new weight-activation pairs to prune cache_W. Search ranges
# can be made equal for each unique weight in a sequential manner)
# Use multiple caches which are dependent / independent
# May do distance calculation between new weight-activation pairs and cache pairs in batches to increase the batch size fed to
# the cache

# First: make the time update for hits consume less memory and run faster. (Start sampling LI from the end in
# small batches and break from the loop once all unique_LI have been found)
# Currently, this is the bottleneck for increasing the batch size. So even if, say, 16 caches are used in parallel, only
# the same number of hits as are processed now could be processed in parallel

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
from collections  import namedtuple
import sys

# Device for every tensor in this notebook (cache storage, model, data): the simulation runs on CPU
device = torch.device('cpu')

In [None]:
# Per-channel (R, G, B) normalisation statistics expected by the pretrained torchvision AlexNet.
# Defined once so the train and val pipelines cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = {
    # Training-time augmentation: random resized crop + horizontal flip
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)]),
    # Deterministic evaluation preprocessing: resize to 256, then center-crop 224
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)]),
}

In [None]:
# Root folder holding 'train' and 'val' sub-folders in ImageFolder layout.
# Set the IMAGENET_RETRAIN_DIR environment variable to point elsewhere instead of editing this path.
data_dir = os.environ.get('IMAGENET_RETRAIN_DIR', '../datasets/ILSVRC2012_img_val - Retrain/')
dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), transform[x]) for x in ['train', 'val']}

In [None]:
# One image per batch, fixed order: every sample is pushed through the cache simulation sequentially
dataloader = {
    split: torch.utils.data.DataLoader(dataset[split], batch_size=1, shuffle=False)
    for split in ('train', 'val')
}

In [None]:
# Number of images per split
dataset_size = {split: len(dataset[split]) for split in ('train', 'val')}
# Class (sub-folder) names discovered by ImageFolder for the training split
class_names = dataset['train'].classes

In [None]:
class cache:
    """Simulation of `num` parallel approximate-reuse caches with LRU replacement.

    Row r of the bank stores `size` (weight, activation) pairs. A new pair looked
    up in row r is a *hit* when the L1 distance |w - W[r, j]| + |a - A[r, j]| to its
    nearest stored entry j is below `epsilon`; the stored pair is then returned in
    place of the new one. On a *miss* the new pair is passed through unchanged and
    overwrites the entry of row r with the largest age (least recently used).

    Only the operand values are approximated here; the products are computed
    afterwards by the caller as updated_W * updated_A.

    Attributes:
        W, A (Tensor[num, size]): cached weights / activations, initialised with
            random normal values.
        time (Tensor[num, size], int32): age of each entry; every lookup in a row
            adds 1 to all its entries, and the entry just used/written is set to 0.
        hits, misses (Tensor[num], int64): per-row counters, reset by
            `batch_wise_approx` each time it prints progress.
    """

    def __init__(self, num, size, epsilon):
        self.size = size
        self.num = num
        self.W = torch.randn((self.num, self.size), device = device)
        self.A = torch.randn((self.num, self.size), device = device)
        self.time = torch.ones((self.num, self.size), device = device, dtype = torch.int32)
        self.epsilon = epsilon
        self.hits = torch.zeros((self.num,), device = device, dtype = torch.int64)
        self.misses = torch.zeros((self.num,), device = device, dtype = torch.int64)

    @staticmethod
    def _flat_copy(t):
        """Return a 1-D tensor with the elements of `t` that never aliases `t`'s memory."""
        flat = t.reshape(-1)
        # reshape/flatten return a *view* (or `t` itself) when `t` is already
        # contiguous -- e.g. an fc weight expanded over a batch of 1. Writing the
        # approximated values into such a view would overwrite the model's
        # parameters, so copy whenever the result shares `t`'s data.
        if flat.data_ptr() == t.data_ptr():
            flat = flat.clone()
        return flat

    def batch_wise_approx(self, orig_W, orig_A, period = 500):
        """Pass every (weight, activation) pair through the cache, `self.num` pairs at a time.

        Args:
            orig_W, orig_A: tensors with the same number of elements; element i of
                each forms one pair. Pairs where either value is 0 skip the cache
                and are returned unchanged.
            period: print hit statistics (and reset the hit/miss counters) every
                `period` batches.

        Returns:
            (W, A): new tensors shaped like `orig_W` / `orig_A` holding the
            approximated values. The input tensors are never modified.

        Raises:
            ValueError: if the inputs have different numbers of elements.
        """
        s1 = orig_W.shape
        s2 = orig_A.shape

        if orig_W.numel() != orig_A.numel():
            raise ValueError("W-shape and A-shape unequal")

        flat_W = self._flat_copy(orig_W)
        flat_A = self._flat_copy(orig_A)

        # A product with a zero operand is zero whatever the cache returns,
        # so only pairs with both operands non-zero are simulated.
        non_zero_mask = (flat_W != 0) & (flat_A != 0)
        W = flat_W[non_zero_mask]
        A = flat_A[non_zero_mask]

        out_W = torch.zeros_like(W)
        out_A = torch.zeros_like(A)
        num_elem = W.shape[0]
        upper_lim = (num_elem // self.num) // period

        for i, start in enumerate(range(0, num_elem, self.num)):
            end = min(start + self.num, num_elem)
            out_W[start:end], out_A[start:end] = self.approx(W[start:end], A[start:end])

            if i % period == 0:
                hits = int(torch.sum(self.hits))
                tot = hits + int(torch.sum(self.misses))
                pct = hits * 100.0 / tot if tot else 0.0

                # Statistics are reported per window, so start counting afresh
                self.hits = torch.zeros_like(self.hits)
                self.misses = torch.zeros_like(self.misses)

                print('\r{} / {} | Hits {} / {}, {:.2f}%'.format(i // period, upper_lim, hits, tot, pct), end = '', flush = True)

        flat_W[non_zero_mask] = out_W
        flat_A[non_zero_mask] = out_A

        return flat_W.view(s1), flat_A.view(s2)

    def approx(self, W, A):
        """Look up N <= `self.num` pairs, pair i in cache row i, and update the caches.

        Args:
            W, A: 1-D tensors of equal length N. `batch_wise_approx` only passes
                pairs without zeros.

        Returns:
            (W, A): the same tensor objects, modified in place -- hit positions
            hold the matched cached pair, miss positions are unchanged.

        Raises:
            ValueError: if N > self.num or W / A are not both of shape (N,).
        """
        N = W.shape[0]

        if N > self.num:
            raise ValueError("Invalid number of W / A provided")

        if (W.shape != (N,)) or (A.shape != (N,)):
            raise ValueError("W and A have incorrect shape")

        # Views onto the first N rows: the in-place updates below write through to self.*
        my_W = self.W[:N]
        my_A = self.A[:N]
        my_time = self.time[:N]
        my_hits = self.hits[:N]
        my_misses = self.misses[:N]

        # Nearest entry per row: I is (N, 2, 1), S is (N, 2, size) -> L1 distances (N, size)
        I = torch.cat([W.view(-1, 1), A.view(-1, 1)], dim = 1).unsqueeze(2)
        S = torch.cat([my_W.unsqueeze(1), my_A.unsqueeze(1)], dim = 1)
        dist, LI = torch.abs(I - S).sum(dim = 1).min(dim = 1)

        hits = dist < self.epsilon
        misses = ~hits
        num_hits = torch.sum(hits).item()
        num_misses = N - num_hits
        row_ids = torch.arange(0, N, device = W.device)

        if num_hits > 0:
            my_hits[hits] += 1

            # Age every entry of the hit rows, then mark the matched entry as just used
            my_time[hits] += 1
            rows = row_ids[hits]
            my_time[rows, LI[hits]] = 0

            # Approximate reuse: substitute the cached pair for the incoming one
            W[hits] = my_W[rows, LI[hits]]
            A[hits] = my_A[rows, LI[hits]]

        if num_misses > 0:
            my_misses[misses] += 1

            # Evict the oldest (largest age) entry of each missing row
            rows = row_ids[misses]
            _, rem_inds = torch.max(my_time[misses], dim = 1)
            my_W[rows, rem_inds] = W[misses]
            my_A[rows, rem_inds] = A[misses]

            my_time[misses] += 1
            my_time[rows, rem_inds] = 0

        return W, A
        

In [None]:
# Single cache bank shared by every wrapped conv / fc layer: 16 rows of 200000 entries each,
# hit when the L1 distance between (weight, activation) pairs is below 1e-4
global_cache = cache(num=16, size=200000, epsilon=1e-4)

In [None]:
# # Testing

# W = torch.tensor([1,2,9,3], dtype = torch.float32).flatten()
# A = torch.tensor([2,6,0,1], dtype = torch.float32).flatten()
# cache = cache(4, 2, 4)
# cache.W[0] = torch.tensor([4, 10])
# cache.A[0] = torch.tensor([2, 5])
# cache.W[1] = torch.tensor([4, 10])
# cache.A[1] = torch.tensor([2, 5])
# cache.W[2] = torch.tensor([4, 10])
# cache.A[2] = torch.tensor([2, 5])
# cache.W[3] = torch.tensor([4, 10])
# cache.A[3] = torch.tensor([2, 5])

# W_, A_ = cache.approx(W, A)
# # W_ tensor([4., 2., 9., 4.])
# # A_ tensor([2., 6., 0., 2.])
# # cache.W tensor([[ 4., 10.],
# #                 [ 4.,  2.],
# #                 [ 4.,  9.],
# #                 [ 4., 10.]])
# # cache.A tensor([[2., 5.],
# #                 [2., 6.],
# #                 [2., 0.],
# #                 [2., 5.]])
# # cache.time tensor([[0, 2],
# #                   [2, 0],
# #                   [2, 0],
# #                   [0, 2]], dtype=torch.int32)
# # cache.hits tensor([1, 0, 0, 1])
# # cache.misses tensor([0, 1, 1, 0])

In [None]:
class cached_conv(nn.Module):
    """Replacement for an ``nn.Conv2d`` whose multiplications go through a cache simulator.

    The convolution is computed explicitly -- im2col via ``F.unfold``, then an
    element-wise product and a sum -- so that every (weight, activation) pair can
    be passed through ``cache.batch_wise_approx`` before the products are taken.
    Weight and bias Parameters are shared with ``wt_layer``, not copied.

    Supports stride, zero padding, dilation, non-square kernels and a missing
    bias. Grouped convolutions and non-zero padding modes are rejected.
    """

    def __init__(self, wt_layer, cache):
        super(cached_conv, self).__init__()
        # Only a plain dense, zero-padded convolution is reproduced by the unfold path.
        if getattr(wt_layer, 'groups', 1) != 1:
            raise ValueError('cached_conv does not support grouped convolutions')
        if getattr(wt_layer, 'padding_mode', 'zeros') != 'zeros':
            raise ValueError('cached_conv only supports zero padding')
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.stride = wt_layer.stride
        self.padding = wt_layer.padding
        self.dilation = getattr(wt_layer, 'dilation', (1, 1))
        self.cache = cache

    def forward(self, x):
        """Compute the convolution of `x` (m, C_in, H, W) with cache-approximated operands."""
        print('\nReached conv')

        A_prev = x
        W = self.weight
        b = self.bias
        stride = self.stride
        pad = self.padding
        dil = self.dilation
        cache = self.cache

        (m, n_C_prev, n_H_prev, n_W_prev) = A_prev.shape
        (n_C, _, f_H, f_W) = W.shape

        # Output spatial size (standard Conv2d formula, including dilation)
        n_H = (n_H_prev + 2 * pad[0] - dil[0] * (f_H - 1) - 1) // stride[0] + 1
        n_W = (n_W_prev + 2 * pad[1] - dil[1] * (f_W - 1) - 1) // stride[1] + 1

        # im2col: (m, K, L) -> (m, L, K), with L = n_H * n_W patches of K = C_in * f_H * f_W values
        y = F.unfold(A_prev, (f_H, f_W), dilation = dil, padding = pad, stride = stride).transpose(2, 1)
        L = y.shape[1]

        # Pair every output channel with every patch using broadcast views (no copy here)
        y = y.unsqueeze(1).expand(-1, n_C, -1, -1)                        # (m, n_C, L, K)
        W = W.view(n_C, -1).view(1, n_C, 1, -1).expand(m, -1, L, -1)     # (m, n_C, L, K)

        W, y = cache.batch_wise_approx(W, y)

        Z = torch.sum(W * y, dim = 3).view(m, n_C, n_H, n_W)
        if b is not None:
            Z = Z + b.view(1, -1, 1, 1)

        return Z

In [None]:
class cached_fc(nn.Module):
    """Replacement for an ``nn.Linear`` whose multiplications go through a cache simulator.

    Every (weight, activation) pair of the matrix product is materialised by
    broadcasting and passed through ``cache.batch_wise_approx`` before the
    products are summed. Weight and bias Parameters are shared with
    ``wt_layer``, not copied. A missing bias is supported.
    """

    def __init__(self, wt_layer, cache):
        super(cached_fc, self).__init__()
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.cache = cache

    def forward(self, x):
        """Compute ``x @ W.T + b`` for `x` of shape (m, in_features) with cache-approximated operands.

        Raises:
            ValueError: if x's feature size does not match the weight's in_features.
        """
        print('\nReached fc')

        A_prev = x
        W = self.weight
        b = self.bias
        cache = self.cache

        (m, n_prev) = A_prev.shape
        (n, n_in) = W.shape
        if n_in != n_prev:
            raise ValueError('cached_fc expected {} input features, got {}'.format(n_in, n_prev))

        # Broadcast views pairing every output unit with every input feature: (m, n, n_prev).
        # unsqueeze (unlike view) also works for non-contiguous inputs.
        A_prev = A_prev.unsqueeze(1).expand(-1, n, -1)
        W = W.unsqueeze(0).expand(m, -1, -1)

        W, A_prev = cache.batch_wise_approx(W, A_prev)

        Z = (A_prev * W).sum(dim = 2).view(m, n)
        if b is not None:
            Z = Z + b.view(1, n)

        return Z

In [None]:
class AlexNet(nn.Module):
    """AlexNet whose Conv2d / Linear layers route their multiplications through a simulated cache.

    The layer layout mirrors torchvision's AlexNet so that its pretrained
    ``state_dict`` can be loaded. The weights are loaded first; the layers are
    then wrapped by `init_cache_layers`, and the wrappers share the same
    Parameter objects.
    """

    def __init__(self, init_state_dict, num_classes=1000):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

        self.load_state_dict(init_state_dict)

        self.init_cache_layers()

    @staticmethod
    def _wrap_layers(seq, layer_type, wrapper, shared_cache):
        """Return a new Sequential where each `layer_type` layer of `seq` is replaced by `wrapper(layer, shared_cache)`."""
        return nn.Sequential(*[
            wrapper(layer, shared_cache) if isinstance(layer, layer_type) else layer
            for layer in seq
        ])

    def init_cache_layers(self, shared_cache=None):
        """Wrap every Conv2d in `features` with cached_conv and every Linear in `classifier` with cached_fc.

        Args:
            shared_cache: cache object used by all wrapped layers; defaults to the
                notebook-level `global_cache`.
        """
        if shared_cache is None:
            shared_cache = global_cache
        self.features = self._wrap_layers(self.features, nn.Conv2d, cached_conv, shared_cache)
        self.classifier = self._wrap_layers(self.classifier, nn.Linear, cached_fc, shared_cache)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
def check_accuracy(model, phase, loaders=None, max_batches=None, run_device=None):
    """Evaluate the top-1 accuracy of `model` on one data split.

    Args:
        model: module mapping an input batch to class scores of shape (batch, classes).
        phase: key into `loaders`, e.g. 'train' or 'val'.
        loaders: mapping phase -> iterable of (inputs, labels) batches; defaults to
            the notebook-level `dataloader`.
        max_batches: stop after this many batches (None = whole split). This is useful
            because the cache simulation is very slow.
        run_device: device to evaluate on; defaults to the notebook-level `device`.

    Returns:
        0-dim float64 tensor with the fraction of correctly classified samples.

    Raises:
        ValueError: if no samples were evaluated.
    """
    loaders = dataloader if loaders is None else loaders
    target = device if run_device is None else run_device

    model.to(target)
    model.eval()

    done = 0
    corrects = torch.tensor(0, device = target)
    since = time.time()

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(loaders[phase]):
            if max_batches is not None and batch_idx >= max_batches:
                break

            inputs = inputs.to(target)
            labels = labels.to(target)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels)

            done += len(inputs)
            print('\r{}, {}, {:.2f}%'.format(corrects.item(), done, corrects.item() * 100.0 / done), end = '')

    if done == 0:
        raise ValueError("no samples evaluated for phase '{}'".format(phase))

    acc = corrects.double() / done
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    time_elapsed = time.time() - since
    print('Total time taken = {} seconds'.format(time_elapsed))

    return acc


In [None]:
# Pretrained torchvision AlexNet; its state_dict initialises the cache-simulating model
alexnet = models.alexnet(pretrained=True)
model = AlexNet(alexnet.state_dict())

In [None]:
# Display the classifier head to confirm its Linear layers were replaced by cached_fc wrappers
model.classifier

In [None]:
# Place the model on the evaluation device and release cached GPU memory before the long run
model = model.to(device)
torch.cuda.empty_cache()

check_accuracy(model, 'val')

In [None]:
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################