<a href="https://colab.research.google.com/github/jeshraghian/VariationAware/blob/master/ECTC_printout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [38]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from collections import namedtuple
import pandas as pd
from pandas import DataFrame
import shutil

## Quantization Functions

In [39]:
from collections import namedtuple
import torch
import torch.nn as nn

QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

def calcScaleZeroPoint(min_val, max_val,num_bits=8):
  # Calc Scale and zero point of next 
  qmin = 0.
  qmax = 2.**num_bits - 1.

  scale = (max_val - min_val) / (qmax - qmin)

  initial_zero_point = qmin - min_val / scale
  
  zero_point = 0
  if initial_zero_point < qmin:
      zero_point = qmin
  elif initial_zero_point > qmax:
      zero_point = qmax
  else:
      zero_point = initial_zero_point

  zero_point = int(zero_point)

  return scale, zero_point

def calcScaleZeroPointSym(min_val, max_val,num_bits=8):
  
  # Calc Scale 
  max_val = max(abs(min_val), abs(max_val))
  qmin = 0.
  qmax = 2.**(num_bits-1) - 1.

  scale = max_val / qmax

  return scale, 0

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    
    if not min_val and not max_val: 
      min_val, max_val = x.min(), x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    q_x = q_x.round().byte()
    
    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)

def quantize_tensor_sym(x, num_bits=8, min_val=None, max_val=None):
    
    if not min_val and not max_val: 
      min_val, max_val = x.min(), x.max()

    max_val = max(abs(min_val), abs(max_val))
    qmin = 0.
    qmax = 2.**(num_bits-1) - 1.

    scale = max_val / qmax   

    q_x = x/scale

    q_x.clamp_(-qmax, qmax).round_()
    q_x = q_x.round()
    return QTensor(tensor=q_x, scale=scale, zero_point=0)

def dequantize_tensor_sym(q_x):
    return q_x.scale * (q_x.tensor.float())

In [40]:
# quantization aware training

class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits=8, min_val=None, max_val=None):
        x = quantize_tensor(x,num_bits=num_bits, min_val=min_val, max_val=max_val)
        x = dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None

In [41]:
# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
  max_val, _ = torch.max(x, dim=1)
  min_val, _ = torch.min(x, dim=1)

  # add ema calculation

  if key not in stats:
    stats[key] = {"max": max_val.sum(), "min": min_val.sum(), "total": 1}
  else:
    stats[key]['max'] += max_val.sum().item()
    stats[key]['min'] += min_val.sum().item()
    stats[key]['total'] += 1
  
  weighting = 2.0 / (stats[key]['total']) + 1

  if 'ema_min' in stats[key]:
    stats[key]['ema_min'] = weighting*(min_val.mean().item()) + (1- weighting) * stats[key]['ema_min']
  else:
    stats[key]['ema_min'] = weighting*(min_val.mean().item())

  if 'ema_max' in stats[key]:
    stats[key]['ema_max'] = weighting*(max_val.mean().item()) + (1- weighting) * stats[key]['ema_max']
  else: 
    stats[key]['ema_max'] = weighting*(max_val.mean().item())

  stats[key]['min_val'] = stats[key]['min']/ stats[key]['total']
  stats[key]['max_val'] = stats[key]['max']/ stats[key]['total']
  
  return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):

  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')
  
  x = F.relu(model.conv1(x))

  x = F.max_pool2d(x, 2, 2)
  
  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
  
  x = F.relu(model.conv2(x))

  x = F.max_pool2d(x, 2, 2)

  x = x.view(-1, 4*4*50)
  
  stats = updateStats(x, stats, 'fc1')

  x = F.relu(model.fc1(x))
  
  stats = updateStats(x, stats, 'fc2')

  x = model.fc2(x)

  return stats

# Entry function to get stats of all functions.
def gatherStats(model, test_loader):
    device = 'cuda'
    
    model.eval()
    test_loss = 0
    correct = 0
    stats = {}
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            stats = gatherActivationStats(model, data, stats)
    
    final_stats = {}
    for key, value in stats.items():
      final_stats[key] = { "max" : value["max"] / value["total"], "min" : value["min"] / value["total"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
    return final_stats

## Forward pass

In [42]:
def quantAwareTrainingForward(model, x, stats, device, vis=False, axs=None, sym=False, num_bits=8, act_quant=False, noise=0):

  # set noise to 0.02 for 2% variation, 0.1 for 10% etc.
  conv1weight = model.conv1.weight.data
  model.conv1.weight.data = FakeQuantOp.apply(model.conv1.weight.data, num_bits)

  #print("The size of x is: {}".format(x.size()))
  #print("The size of conv1 is: {}".format(model.conv1.weight.data.size()))
  a = x.size()[0]
  b = model.conv1.weight.data.size()[0]

  if noise:
    memristor_dist1 = torch.normal(1, noise, size = [int(a), int(b), 24, 24]).to(device)
    x = F.relu(model.conv1(x)*memristor_dist1)
  else:
    x = F.relu(model.conv1(x))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv1']['ema_min'], stats['conv1']['ema_max'])

  x = F.max_pool2d(x, 2, 2)

  conv2weight = model.conv2.weight.data
  model.conv2.weight.data = FakeQuantOp.apply(model.conv2.weight.data, num_bits)

  #print("The size of x is: {}".format(x.size()))
  #print("The size of conv2 is: {}".format(model.conv1.weight.data.size()))

  a = x.size()[0]
  b = model.conv2.weight.data.size()[0]

  if noise:
    memristor_dist2 = torch.normal(1, noise, size = [int(a), int(b), 8, 8]).to(device)
    x = F.relu(model.conv2(x)*memristor_dist2)

  else:
    x = F.relu(model.conv2(x))


  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
    
  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['conv2']['ema_min'], stats['conv2']['ema_max'])


  x = F.max_pool2d(x, 2, 2)
  x = x.view(-1, 4*4*50)

  fc1weight = model.fc1.weight.data
  model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, num_bits)
  #print("fc1weight size is: {}".format(fc1weight.size()))

  #print("The size of x is: {}".format(x.size()))
  #print("The size of conv2 is: {}".format(model.fc1.weight.data.size()))

  a = x.size()[0]
  b = model.fc1.weight.data.size()[0]

  #print("The size of a at fc1 is: {}".format(a)) #should be 64
  #print("The size of b at fc1 is: {}".format(b)) # should be 500
  
  if noise:
    memristor_dist3 = torch.normal(1, noise, size = [int(a), int(b)]).to(device)
    x = F.relu(model.fc1(x)*memristor_dist3)
  else:
    x = F.relu(model.fc1(x))
  #print("x after fc1 is: {}".format(x.size()))

  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc1')

  if act_quant:
    x = FakeQuantOp.apply(x, num_bits, stats['fc1']['ema_min'], stats['fc1']['ema_max'])

  a = x.size()[0]
  b = model.fc2.weight.data.size()[0]

  if noise:
    memristor_dist4 = torch.normal(1, noise, size = [int(a), int(b)]).to(device)
    x = model.fc2(x)*memristor_dist4
  else:
    x = model.fc2(x)
  #print("x after fc2 is: {}".format(x.size()))
  
  with torch.no_grad():
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc2')

  return F.log_softmax(x, dim=1), conv1weight, conv2weight, fc1weight, stats

In [43]:
# Define Network
# redefine the nn.Module with a random gaussian term 

class NetMem(nn.Module):
    def __init__(self, mnist=True):
      
        super(NetMem, self).__init__()
        if mnist:
          num_channels = 1
        else:
          num_channels = 3
          
        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        if mnist:
          self.fc1 = nn.Linear(4*4*50, 500)
          self.flatten_shape = 4*4*50
        else:
          self.fc1 = nn.Linear(1250, 500)
          self.flatten_shape = 1250

        self.fc2 = nn.Linear(500, 10)
        
    def forward(self, x, vis=False, axs=None):
        X = 0
        y = 0

        if vis:
          axs[X,y].set_xlabel('Entry into network, input distribution visualised below: ')
          visualise(x, axs[X,y])

          axs[X,y+1].set_xlabel("Visualising weights of conv 1 layer: ")
          visualise(self.conv1.weight.data, axs[X,y+1])

        # memristor_dist1 = torch.normal(1, 0.02, size = self.conv1.weight.data.size()) 
        x = F.relu(self.conv1(x))

        if vis:
          axs[X,y+2].set_xlabel('Output after conv1 visualised below: ')
          visualise(x,axs[X,y+2])

          axs[X,y+3].set_xlabel("Visualising weights of conv 2 layer: ")
          visualise(self.conv2.weight.data, axs[X,y+3])

        x = F.max_pool2d(x, 2, 2)

        # memristor_dist2 = torch.normal(1, 0.02, size = self.conv2.weight.data.size()) 
        x = F.relu(self.conv2(x))

        if vis:
          axs[X,y+4].set_xlabel('Output after conv2 visualised below: ')
          visualise(x,axs[X,y+4])

          axs[X+1,y].set_xlabel("Visualising weights of fc 1 layer: ")
          visualise(self.fc1.weight.data, axs[X+1,y])

        x = F.max_pool2d(x, 2, 2)  
        x = x.view(-1, self.flatten_shape)

        # memristor_dist3 = torch.normal(1, 0.02, size = self.fc1.weight.data.size()) 
        x = F.relu(self.fc1(x))

        if vis:
          axs[X+1,y+1].set_xlabel('Output after fc1 visualised below: ')
          visualise(x,axs[X+1,y+1])

          axs[X+1,y+2].set_xlabel("Visualising weights of fc 2 layer: ")
          visualise(self.fc2.weight.data, axs[X+1,y+2])

        # memristor_dist4 = torch.normal(1, 0.02, size = self.fc2.weight.data.size()) 
        x = self.fc2(x)

        if vis:
          axs[X+1,y+3].set_xlabel('Output after fc2 visualised below: ')
          visualise(x,axs[X+1,y+3])

        return F.log_softmax(x, dim=1)
           

In [44]:
def trainQuantAware(args, model, device, train_loader, optimizer, epoch, stats, act_quant=False, num_bits=8, noise=0):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output, conv1weight, conv2weight, fc1weight, stats = quantAwareTrainingForward(model, data, stats, device, num_bits=num_bits, act_quant=act_quant, noise=noise)

        model.conv1.weight.data = conv1weight
        model.conv2.weight.data = conv2weight
        model.fc1.weight.data = fc1weight

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args["log_interval"] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return stats, loss.item()

def testQuantAware(args, model, device, test_loader, stats, act_quant, num_bits=8, noise=0):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, conv1weight, conv2weight, fc1weight, _ = quantAwareTrainingForward(model, data, stats, device, num_bits=num_bits, act_quant=act_quant, noise=noise)
            
            model.conv1.weight.data = conv1weight
            model.conv2.weight.data = conv2weight
            model.fc1.weight.data = fc1weight

            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        test_acc))

    return test_loss, test_acc

def mainQuantAware(mnist=True):
 
    batch_size = 64
    test_batch_size = 64
    epochs = 20
    lr = 0.01
    momentum = 0.5
    seed = 1
    log_interval = 500
    save_model = False
    no_cuda = False
    
    use_cuda = not no_cuda and torch.cuda.is_available()
    torch.manual_seed(seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    if mnist:
      train_loader = torch.utils.data.DataLoader(
          datasets.MNIST('../data', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
          batch_size=batch_size, shuffle=True, **kwargs)
      
      test_loader = torch.utils.data.DataLoader(
          datasets.MNIST('../data', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])),
          batch_size=test_batch_size, shuffle=True, **kwargs)
    else:
      transform = transforms.Compose(
          [transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

      trainset = datasets.CIFAR10(root='./dataCifar', train=True,
                                              download=True, transform=transform)
      train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                shuffle=True, num_workers=2)

      testset = datasets.CIFAR10(root='./dataCifar', train=False,
                                            download=True, transform=transform)
      test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                              shuffle=False, num_workers=2)
          
  
    model = NetMem(mnist=mnist).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    args = {}
    args["log_interval"] = log_interval
    num_trials = 5
    stats = {}
    df = pd.DataFrame(columns=['trial', 'variation', 'num_bits', 'lr', 'epoch', 'train_set_loss', 'test_set_loss', 'test_set_accuracy'])

    for num_bits in [10]:
      for noise in [0.05, 0.1, 0.2, 0.5]:
        for i in range(num_trials):
          print(f"num_bits: {num_bits}, noise: {noise}, trial: {i}")
          for epoch in range(1, epochs + 1):
              if epoch <= 5:
                lr = 0.01
              if epoch > 5:
                act_quant = True 
                lr = 0.001
              else:
                act_quant = False

              stats, train_loss = trainQuantAware(args, model, device, train_loader, optimizer, epoch, stats, act_quant, num_bits=num_bits, noise=noise)
              test_loss, test_acc = testQuantAware(args, model, device, test_loader, stats, act_quant, num_bits=num_bits, noise=noise)

              # if (save_model):
              #     torch.save(model.state_dict(),"mnist_cnn.pt")
                
              df = df.append(
                  {'trial': i, 'variation': noise, 'num_bits': num_bits, 'lr': lr, 'epoch': epoch, 'train_set_loss': train_loss, 'test_set_loss': test_loss, 'test_set_accuracy': test_acc}, ignore_index=True)
              df.to_csv('ECTC2.csv', index=False)
              # if SAVE_GOOGLE_COLAB:
              #   shutil.copy("ECTC.csv", "/content/ECTC.csv")

    return model, stats

model, old_stats = mainQuantAware()

num_bits: 10, noise: 0.05, trial: 0

Test set: Average loss: 2.2954, Accuracy: 1135/10000 (11%)


Test set: Average loss: 2.2840, Accuracy: 1539/10000 (15%)


Test set: Average loss: 2.2642, Accuracy: 1569/10000 (16%)


Test set: Average loss: 2.2427, Accuracy: 1620/10000 (16%)


Test set: Average loss: 2.2246, Accuracy: 1628/10000 (16%)


Test set: Average loss: 2.2175, Accuracy: 1817/10000 (18%)


Test set: Average loss: 2.1687, Accuracy: 1961/10000 (20%)


Test set: Average loss: 2.1079, Accuracy: 2099/10000 (21%)


Test set: Average loss: 2.0704, Accuracy: 2079/10000 (21%)


Test set: Average loss: 2.0549, Accuracy: 2119/10000 (21%)



ZeroDivisionError: ignored