In [1]:
# Implements 1 level cache (No backward pass needed so not implemented; only accuracy determination required)
# Weight sharing not included

# Cache.approx just updates W, A values according to hits and misses and does not compute the result. Actual result computation
# is done at the end in parallel by doing updated_W * updated_A
# Current - LRU policy replacement. Future - May use LRU + count-used based policy

# Find out how hdw simulation is done in papers. They of course don't actually build the hdw; they just simulate it. But
# they are still able to find the accuracy value by running the program over their hdw simulation

# To be added (and be performed in this order):
# Network Pruning
# Wt sharing
# Bit Masking (both weights and activations; may try retraining also for both) / Pytorch trained 8 bit int quantization
# May merge weight bit masking in 3rd step with 2nd step 

# Find out exactly how 1000 caches will be used in hdw. They all have the same data ? If miss occurs in one cache, access 
# to all caches will be stopped ? Using the parallelism of 1000 caches to process misses in parallel can really speed up
# the simulation. The sequential processing of misses is the current bottleneck

# Speed Up methods:
# Prune W and A in initial step to be able to increase batch size (Currently, very difficult to vectorize. Basically, hard
# to get equal size search ranges for each unique weight in the new weight-activation pairs to prune cache_W. Search ranges
# can be made equal for each unique weight in a sequential manner)
# Use multiple caches which are dependent / independent
# May do distance calculation between new weight-activation pairs and cache pairs in batches to increase the batch size fed to
# the cache

# First: Make updating time corresponding to hits consume less memory and run faster. (Start sampling LI from the end in
# small batches and break from the loop when all unique_LI have been found)
# Currently, it is the bottleneck for increasing batch size. So even if say 16 caches are used in parallel, only the same no 
# of hits as are being processed currently could be processed in parallel  

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
import copy
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
from collections  import namedtuple
import sys
import Cache_Definition
import torch.multiprocessing as mp

# Device used for the caches, the scratch output tensors and the model.
device = torch.device('cpu')
# Checkpoint path for the main network (not referenced in the cells visible here).
SAVE_PATH = 'D://models//main_net.pth'

In [3]:
# Channel-wise mean and standard deviation passed to Normalize in both phases.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

transform = dict(
    # Training: random 224 crop and horizontal flip, then tensor conversion and normalisation.
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
    # Validation: resize to 256, centre-crop to 224, then tensor conversion and normalisation.
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
)

In [4]:
# Root folder; each phase is read from the sub-directory carrying its name.
data_dir = 'D:\\datasets\\ILSVRC2012_img_val - Retrain\\'
# One ImageFolder per phase, paired with that phase's preprocessing pipeline.
dataset = dict(
    (phase, datasets.ImageFolder(os.path.join(data_dir, phase), transform[phase]))
    for phase in ('train', 'val')
)

In [5]:
# One loader per phase: single-sample batches, iterated in dataset order (no shuffling).
dataloader = dict(
    (phase, torch.utils.data.DataLoader(dataset[phase], batch_size=1, shuffle=False))
    for phase in ('train', 'val')
)

In [6]:
# Number of samples available in each phase.
dataset_size = dict((phase, len(dataset[phase])) for phase in ('train', 'val'))
# Class list exposed by the training dataset.
class_names = dataset['train'].classes

In [7]:
# Number of worker processes; each batch is split into at most this many chunks.
num_units = 12
# One cache object per chunk position; chunk k of every batch is paired with global_caches[k].
global_caches = [Cache_Definition.cache(device=device) for i in range(num_units)]
# Worker pool that evaluates the chunks in parallel.
# NOTE(review): the pool is never closed or joined in the visible cells — confirm it is cleaned up elsewhere.
process_set = mp.Pool(processes = num_units)

In [8]:
def batch_wise_approx(orig_W, orig_A, num_send):
    """Send every (weight, activation) pair through the caches and return the results.

    Both tensors are flattened and processed ``num_send`` pairs at a time.
    Each batch is cut into at most ``num_units`` contiguous chunks whose sizes
    differ by at most one; chunk ``k`` is paired with ``global_caches[k]`` and
    evaluated by ``Cache_Definition.approx`` in the ``process_set`` worker pool.
    A one-line progress report (batch index and the hit statistics of the
    previously processed batch) is printed for every batch.

    Args:
        orig_W: weight tensor, already expanded to the shape of ``orig_A``.
        orig_A: activation tensor with the same shape as ``orig_W``.
        num_send: positive integer, number of pairs handed to the caches per batch.

    Returns:
        Tuple ``(W, A)`` of float32 tensors on ``device`` with the input shape,
        filled with the values returned by the caches.

    Raises:
        ValueError: if the two shapes differ or ``num_send`` is not positive.
    """
    if orig_W.shape != orig_A.shape:
        raise ValueError("W-shape and A-shape unequal")
    if num_send <= 0:
        raise ValueError("num_send must be a positive integer")

    shape = orig_W.shape
    num_elem = orig_W.numel()

    orig_W = orig_W.flatten()
    orig_A = orig_A.flatten()

    out_W = torch.zeros(orig_W.shape, dtype=torch.float32, device=device)
    out_A = torch.zeros(orig_A.shape, dtype=torch.float32, device=device)

    num_batches = num_elem // num_send
    hits = 0
    tot = 1  # non-zero so the very first progress line does not divide by zero

    # The "+ 1" covers the trailing partial batch; when num_elem is an exact
    # multiple of num_send the extra iteration is empty and stops the loop.
    for i in range(num_batches + 1):

        print('\r{} / {} | Hits {} / {}, {:.2f}%'.format(i, num_batches, hits, tot, hits * 100.0 / tot), end='', flush=True)

        start = i * num_send
        end = min(start + num_send, num_elem)
        if end == start:
            break

        # Split the batch: the first r chunks get q + 1 pairs, the remaining ones q.
        q, r = divmod(end - start, num_units)
        split_vals = []
        beg = start
        for k in range(num_units):
            size = q + 1 if k < r else q
            if size == 0:
                # Fewer pairs than workers: the remaining workers get nothing.
                break
            split_vals.append((global_caches[k], orig_W[beg:beg + size], orig_A[beg:beg + size]))
            beg += size

        # NOTE(review): the cache objects are shipped to the worker processes; whether
        # their updates are visible here on the next batch depends on Cache_Definition — confirm.
        output = process_set.starmap(Cache_Definition.approx, split_vals, chunksize=1)

        # Each worker returns (W, A, hits, misses); stitch the chunks back in order.
        out_W[start:end] = torch.cat([out[0] for out in output])
        out_A[start:end] = torch.cat([out[1] for out in output])

        # Statistics are per batch (not cumulative) and are shown on the next progress line.
        hits = sum(out[2] for out in output)
        misses = sum(out[3] for out in output)
        tot = hits + misses

    return out_W.view(shape), out_A.view(shape)

In [9]:
# # Testing

# global_cache.size = 3
# global_cache.epsilon = 3
# global_cache.W = torch.tensor([1.,2,3], device = device, dtype = torch.float32)
# global_cache.A = torch.tensor([7.,2,5], device = device, dtype = torch.float32)
# global_cache.time = torch.ones(global_cache.size, device = device, dtype = torch.int32)

# x = torch.tensor([1., 2, 3, 1, 2, 1, 1, 3, 0, 3, 2], device = device, dtype = torch.float32)
# y = torch.tensor([1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], device = device, dtype = torch.float32)
# x, y, hits, misses = global_cache.approx(x, y)

# print(x)
# print(y)
# print(hits, misses)

# # Expected Answer
# # tensor([2., 2., 2., 1., 1., 1., 1., 3., 0., 3., 2.], device='cuda:0')
# # tensor([2., 2., 2., 4., 4., 7., 7., 8., 9., 8., 0.], device='cuda:0')
# # 7.0 2.0

In [10]:
# # Testing

# global_cache.size = 1
# global_cache.epsilon = 0.1
# global_cache.W = torch.tensor([1.], device = device, dtype = torch.float32)
# global_cache.A = torch.tensor([7.], device = device, dtype = torch.float32)
# global_cache.time = torch.ones(global_cache.size, device = device, dtype = torch.int32)

# x = torch.tensor([1., 2, 3, 1, 2, 1, 1, 3, 0, 3, 2], device = device, dtype = torch.float32)
# y = torch.tensor([1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 0], device = device, dtype = torch.float32)
# x, y, hits, misses = global_cache.approx(x, y)

# print(x)
# print(y)
# print(hits, misses)

# # Expected Answer
# # tensor([1., 2., 3., 1., 2., 1., 1., 3., 0., 3., 2.], device='cuda:0')
# # tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.,  0.],
# #        device='cuda:0')
# # 0 9

In [11]:
class cached_conv(nn.Module):
    """Convolution that routes every weight/activation product through the caches.

    Wraps an existing convolution layer, reusing its ``weight``, ``bias``,
    ``stride`` and ``padding``. The wrapped layer's ``dilation`` and ``groups``
    are ignored, so only plain (dilation 1, groups 1) convolutions are
    reproduced correctly.
    """

    def __init__(self, wt_layer):
        """Borrow the parameters and geometry of ``wt_layer`` (an nn.Conv2d-like layer)."""
        super(cached_conv, self).__init__()
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.stride = wt_layer.stride
        self.padding = wt_layer.padding

    def forward(self, x):
        """Compute the convolution of ``x`` (shape ``(m, C_in, H, W)``).

        The input is unfolded into patches, weights and patches are expanded to a
        common ``(m, C_out, positions, C_in * kH * kW)`` shape, passed through
        ``batch_wise_approx`` and then multiplied and summed over the last axis.

        Returns:
            Tensor of shape ``(m, C_out, H_out, W_out)``.
        """
        print('Reached conv')

        W = self.weight
        b = self.bias
        stride = self.stride
        pad = self.padding

        m, _, n_H_prev, n_W_prev = x.shape
        # Kernel height and width are tracked separately so non-square kernels work.
        n_C, _, f_H, f_W = W.shape

        # Spatial size of the output volume.
        n_H = (n_H_prev + 2 * pad[0] - f_H) // stride[0] + 1
        n_W = (n_W_prev + 2 * pad[1] - f_W) // stride[1] + 1

        # (m, positions, C_in * kH * kW): one row per output position.
        y = F.unfold(x, (f_H, f_W), padding=pad, stride=stride).transpose(2, 1)
        # expand() broadcasts without copying; one view of the patches per output channel.
        y = y.view(m, 1, y.shape[1], y.shape[2]).expand((-1, n_C, -1, -1))

        W = W.view(n_C, -1)
        W = W.view(1, n_C, 1, W.shape[1]).expand(m, -1, y.shape[2], -1)

        W, y = batch_wise_approx(W, y, 12000)

        Z = torch.sum(W * y, dim=3).view(m, n_C, n_H, n_W)
        if b is not None:
            # Layers built with bias=False carry no bias term.
            Z = Z + b.view(1, b.shape[0], 1, 1)

        return Z

In [12]:
class cached_fc(nn.Module):
    """Fully connected layer that can optionally route its products through the caches.

    Wraps an existing linear layer, reusing its ``weight`` and ``bias``. By
    default the layer computes a plain ``F.linear``, which is what the notebook
    did unconditionally before; pass ``use_cache=True`` to send every
    weight/activation pair through ``batch_wise_approx`` instead.
    """

    def __init__(self, wt_layer, use_cache=False):
        """Borrow the parameters of ``wt_layer``; ``use_cache`` selects the cached path."""
        super(cached_fc, self).__init__()
        self.weight = wt_layer.weight
        self.bias = wt_layer.bias
        self.use_cache = use_cache

    def forward(self, x):
        """Apply the layer to ``x``; the cached path requires a 2-D ``(m, n_prev)`` input.

        Returns:
            Tensor of shape ``(m, n)`` where ``n`` is the number of output features.
        """
        if not self.use_cache:
            return F.linear(x, self.weight, bias=self.bias)

        W = self.weight
        b = self.bias

        m, n_prev = x.shape
        n = W.shape[0]

        # Broadcast both operands to (m, n, n_prev) without copying.
        A_prev = x.view(m, 1, n_prev).expand(-1, n, -1)
        W = W.view(1, n, n_prev).expand(m, -1, -1)

        W, A_prev = batch_wise_approx(W, A_prev, 12000)

        Z = (A_prev * W).sum(dim=2).view(m, n)
        if b is not None:
            # Layers built with bias=False carry no bias term.
            Z = Z + b.view(1, n)

        return Z

In [13]:
class AlexNet(nn.Module):
    """AlexNet whose Conv2d / Linear layers are swapped for their cached wrappers.

    The standard layer stack is built first so ``init_state_dict`` can be loaded,
    then every ``nn.Conv2d`` in ``features`` is wrapped in ``cached_conv`` and
    every ``nn.Linear`` in ``classifier`` in ``cached_fc``.
    """

    def __init__(self, init_state_dict, num_classes=1000):
        """Build the network, load ``init_state_dict`` and install the cached layers."""
        super(AlexNet, self).__init__()

        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        dense_stack = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*dense_stack)

        # Weights must be loaded before wrapping: the wrappers borrow the layers' parameters.
        self.load_state_dict(init_state_dict)
        self.init_cache_layers()

    def init_cache_layers(self):
        """Rebuild ``features`` and ``classifier`` with cached wrappers around weighted layers."""
        self.features = nn.Sequential(*[
            cached_conv(layer) if isinstance(layer, nn.Conv2d) else layer
            for layer in self.features
        ])
        self.classifier = nn.Sequential(*[
            cached_fc(layer) if isinstance(layer, nn.Linear) else layer
            for layer in self.classifier
        ])

    def forward(self, x):
        """Run features, average pooling, flattening and the classifier on ``x``."""
        feats = self.features(x)
        pooled = self.avgpool(feats)
        return self.classifier(torch.flatten(pooled, 1))

In [14]:
def check_accuracy(model, phase):
    """Evaluate ``model`` on ``dataloader[phase]`` and return its top-1 accuracy.

    The model is moved to ``device`` and put in eval mode. A running progress
    line is printed after every batch, followed by the final accuracy and the
    elapsed wall-clock time.

    Returns:
        0-dim double tensor: fraction of correctly classified samples.
    """
    model.to(device)
    model.eval()

    started_at = time.time()
    seen = 0
    corrects = torch.tensor(0).to(device)
    total_loss = 0.0  # never accumulated; kept only because the progress line prints it

    for inputs, labels in dataloader[phase]:
        inputs, labels = inputs.to(device), labels.to(device)

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels)

        seen += len(inputs)
        print('\r{}, {}, {:.2f}%, {:.2f}'.format(corrects.item(), seen, corrects.item() * 100.0 / seen, total_loss), end = '')

    acc = corrects.double() / seen
    print('\n{} Acc: {:.4f} %'.format(phase, acc * 100))

    print('Total time taken = {} seconds'.format(time.time() - started_at))

    return acc


In [15]:
# Torchvision's pretrained AlexNet (not referenced again in the cells visible here).
alexnet = models.alexnet(pretrained=True)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# written from a GPU session still loads on this CPU configuration.
model = AlexNet(init_state_dict=torch.load('D://models//undone_wt_shared_net.pth', map_location=device))

In [16]:
# Display the classifier stack to confirm each Linear layer was replaced by cached_fc.
model.classifier

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): cached_fc()
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): cached_fc()
  (5): ReLU(inplace=True)
  (6): cached_fc()
)

In [None]:
# Place the model on the configured device before evaluating.
model.to(device)
# Releases cached CUDA memory; presumably a leftover from GPU runs, since `device` is the CPU here.
torch.cuda.empty_cache()

# Evaluate on the validation split; the returned accuracy is the cell's displayed value.
check_accuracy(model, 'val')

Reached conv
5856 / 5856 | Hits 4312 / 10248, 42.08%Reached conv
10328 / 18662 | Hits 3777 / 5559, 67.94%

In [None]:
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################
####################################################################################################################