# MNIST Model

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple
import copy
import warnings
warnings.filterwarnings(action='ignore')

# Preparing Data

In [2]:
dataset_dir = '~/dataset'
train_data = datasets.MNIST(root=dataset_dir, train=True,  download=True, transform=ToTensor())
test_data  = datasets.MNIST(root=dataset_dir, train=False, download=True, transform=ToTensor())

batch_size = 64

# Number of trainData/validData/testData = 50000/10000/10000
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader  = DataLoader(test_data,  batch_size=batch_size)

In [3]:
def visualise(x, axs):
	x = x.view(-1).cpu().numpy()
	axs.hist(x)

# MNIST Model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

class NN(nn.Module):
	def __init__(self):
		super(NN, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(28*28, 16)
		self.fc2 = nn.Linear(16, 16)
		self.fc3 = nn.Linear(16, 10)
		
	def forward(self, x):
		x = self.flatten(x)
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		logits = self.fc3(x)

		return logits

loss_func = nn.CrossEntropyLoss()

Using cuda device


# Training

In [5]:
def train(train_dataloader, model, loss_func, optimizer, epoch):
	model.train()
	max_batch_index = int(np.floor(len(train_data)/batch_size))
	for batch_index, (image, label) in enumerate(train_dataloader):
		image, label = image.to(device), label.to(device)
		pred = model(image)
		loss = loss_func(pred, label)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
				
		if batch_index == max_batch_index:
			print(f'Epoch {epoch+1:<3d}: Loss: {loss.item():.2f}', end = '\t')

In [6]:
def test(test_dataloader, model, loss_func, epoch):
	model.eval()
	loss, correct = 0, 0
	with torch.no_grad():
		for image, label in test_dataloader:
			image, label = image.to(device), label.to(device)
			pred = model(image)
			loss += loss_func(pred, label).item()
			correct += (pred.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_dataloader.dataset)
	correct_rate = 100 * correct / len(test_dataloader.dataset)
	print(f'Accuracy: {correct}/{len(test_data)} ({correct_rate:>.1f}%)')

In [7]:
def main():
	model = NN().to(device)
	optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
	epochs = 10
	for epoch in range(epochs):
		train(train_dataloader, model, loss_func, optimizer, epoch)
		test(test_dataloader, model, loss_func, epoch)
	print("Done!")
	return model

# model = main()

# Quantization of Network

## Quantization Functions

In [8]:
def calcScaleZeroPoint(min_val, max_val, num_bits=8, opt='asym'):
	qmin, qmax = 0., 2.**num_bits - 1.
	if opt == 'asym':
		scale = (max_val - min_val) / (qmax - qmin)
		initial_zero_point = -min_val / scale
		zero_point = int({initial_zero_point < qmin: qmin,  initial_zero_point > qmax: qmax, qmin <= initial_zero_point <= qmax: initial_zero_point}.get(True, False))
	if opt == 'sym':
		scale = max(abs(min_val), abs(max_val)) / qmax
		zero_point = 0
	return scale, zero_point

In [9]:
def quantizeTensor(input_tensor, min_val=None, max_val=None, num_bits=8, opt='asym'):
	if not min_val and not max_val:
		min_val, max_val = input_tensor.min(), input_tensor.max()
	
	qmin, qmax = 0., 2.**num_bits - 1.
	scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits, opt)
	
	if opt == 'asym':
		quant_tensor = (input_tensor / scale + zero_point).clamp(qmin, qmax).round()
	if opt == 'sym':
		quant_tensor = (input_tensor / scale).clamp(qmax, qmax).round()
	
	qTuple = namedtuple('qTuple', ['tensor', 'scale', 'zero_point'])
	return qTuple(tensor=quant_tensor, scale=scale, zero_point=zero_point)

In [10]:
def dequantizeTensor(qTuple, opt='asym'):
	if opt == 'asym':
		dequant_tensor = qTuple.scale * (qTuple.tensor.float() - qTuple.zero_point)
	if opt == 'sym':
		dequant_tensor = qTuple.scale * (qTuple.tensor.float())	
	return dequant_tensor

In [11]:
def updateStats(actTensor, stats, layerName):
	# dim=0 : find min/max in each col
	# dim=1 : find min/max in each row
	maxValue = torch.max(actTensor, dim=1)[0]
	minValue = torch.min(actTensor, dim=1)[0]
	
	if layerName not in stats:
		stats[layerName] = {'max': maxValue.sum(), 'min': minValue.sum(), 'total': 1}
	else:
		stats[layerName]['max'] += maxValue.sum().item()
		stats[layerName]['min'] += maxValue.sum().item()
		stats[layerName]['total'] += 1
		
	weighting = 2.0 / (stats[layerName]['total']) + 1

	if 'ema_min' in stats[layerName]:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_min']
	else:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item())

	if 'ema_max' in stats[layerName]:
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_max']
	else: 
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item())

		stats[layerName]['min_val'] = stats[layerName]['min']/ stats[layerName]['total']
		stats[layerName]['max_val'] = stats[layerName]['max']/ stats[layerName]['total']

	return stats

In [12]:
# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
	x = model.flatten(x)
	stats = updateStats(x, stats, 'fc1')
	x = F.relu(model.fc1(x))
	
	stats = updateStats(x, stats, 'fc2')
	x = F.relu(model.fc2(x))

	stats = updateStats(x, stats, 'fc3')
	x = model.fc3(x)

	return stats

In [13]:
# Entry function to get stats of all functions.
def gatherStats(model, testDataLoader):
	model.eval()
	stats = {}
	with torch.no_grad():
		for img, lab in testDataLoader:
			img, lab = img.to(device), lab.to(device)
			stats = gatherActivationStats(model, img, stats)

	final_stats = {}
	for key, value in stats.items():
		final_stats[key] = { "max" : value["max"] / value["total"], "min" : value["min"] / value["total"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
	return final_stats

In [14]:
class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val=None, max_val=None, num_bits=8, opt='asym'):
		#def quantizeTensor(input_tensor, min_val=None, max_val=None, num_bits=4, opt='asym'):
        x = quantizeTensor(x, min_val=min_val, max_val=max_val, num_bits=num_bits, opt=opt)
        x = dequantizeTensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None, None

In [26]:
def quantAwareTrainingForward(x, model, stats, act_quant=False, num_bits=8, opt='asym'):
	params = model.state_dict()
	
	x = model.flatten(x)
	
	model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, None, None, num_bits, opt)
	model.fc1.bias.data = FakeQuantOp.apply(model.fc1.bias.data, None, None, num_bits, opt)

# 	x = F.relu(model.fc1(x))
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc1')
	if act_quant:
		x = FakeQuantOp.apply(x, stats['fc1']['ema_min'], stats['fc1']['ema_max'], num_bits, opt)
	x = F.relu(model.fc1(x))
	
	model.fc2.weight.data = FakeQuantOp.apply(model.fc2.weight.data, None, None, num_bits, opt)
	model.fc2.bias.data = FakeQuantOp.apply(model.fc2.bias.data, None, None, num_bits, opt)
# 	x = F.relu(model.fc2(x))
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc2')
	if act_quant:
		x = FakeQuantOp.apply(x, stats['fc2']['ema_min'], stats['fc2']['ema_max'], num_bits, opt)
	x = F.relu(model.fc2(x))

	model.fc3.weight.data = FakeQuantOp.apply(model.fc3.weight.data, None, None, num_bits, opt)
	model.fc3.bias.data = FakeQuantOp.apply(model.fc3.bias.data, None, None, num_bits, opt)
# 	x = model.fc3(x)
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc3')
	if act_quant:
		x = FakeQuantOp.apply(x, stats['fc3']['ema_min'], stats['fc3']['ema_max'], num_bits, opt)
	x = model.fc3(x)
	
#	return F.log_softmax(x, dim=1), fc1_weight, fc2_weight, fc3_weight, stats
	return x, params, stats

In [27]:
def trainQAT(train_dataloader, model, loss_func, optimizer, epoch, stats, act_quant=False, num_bits=8, opt='asym'):
	model.train()
	max_batch_index = int(np.floor(len(train_data)/batch_size))
	for batch_index, (image, label) in enumerate(train_dataloader):
		image, label = image.to(device), label.to(device)
		optimizer.zero_grad()
		y = model(image)
		y, params, stats = quantAwareTrainingForward(image, model, stats, act_quant, num_bits, opt)
		
		# Recover FP32 to improve accuracy
		model.fc1.weight.data = params['fc1.weight']
		model.fc2.weight.data = params['fc2.weight']
		model.fc3.weight.data = params['fc3.weight']
		model.fc1.bias.data   = params['fc1.bias']
		model.fc2.bias.data   = params['fc2.bias']
		model.fc3.bias.data   = params['fc3.bias']

		loss = loss_func(y, label)
		loss.backward()
		optimizer.step()
		
		if batch_index == max_batch_index:
			print(f'Epoch {epoch+1:<3d}: Loss: {loss.item():.2f}', end = '\t')
			
	return stats

In [28]:
def testQAT(test_dataloader, model, epoch, loss_func, stats, act_quant=False, num_bits=8, opt='asym'):
	model.eval()
	loss, correct = 0, 0
	with torch.no_grad():
		for image, label in test_dataloader:
			image, label = image.to(device), label.to(device)
			y, _, _ = quantAwareTrainingForward(image, model, stats, act_quant, num_bits, opt)	
			loss += loss_func(y, label).item()
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_dataloader.dataset)
	correct_rate = 100 * correct / len(test_dataloader.dataset)
	print(f'Accuracy: {correct}/{len(test_data)} ({correct_rate:>.1f}%)')

In [33]:
def mainQAT():
	model = NN().to(device)
	optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
	epochs = 10
	num_bits = 16
	stats = {}
	opt = 'asym'
	for epoch in range(epochs):
		act_quant = True if epoch > 5 else False
		#def trainQAT(train_dataloader, model, loss_func, optimizer, epoch, stats, act_quant=False, num_bits=8, opt='asym'):
		stats = trainQAT(train_dataloader, model, loss_func, optimizer, epoch, stats, act_quant, num_bits, opt)
		testQAT(test_dataloader, model, epoch, loss_func, stats, act_quant, num_bits, opt)
	print("Done!")
	return model, stats

In [34]:
q_model, old_stats = mainQAT()

Epoch 1  : Loss: 2.12	Accuracy: 2971/10000 (29.7%)
Epoch 2  : Loss: 1.23	Accuracy: 6480/10000 (64.8%)
Epoch 3  : Loss: 0.56	Accuracy: 7841/10000 (78.4%)
Epoch 4  : Loss: 0.38	Accuracy: 8402/10000 (84.0%)
Epoch 5  : Loss: 0.31	Accuracy: 8633/10000 (86.3%)
Epoch 6  : Loss: 0.27	Accuracy: 8740/10000 (87.4%)
Epoch 7  : Loss: 0.35	Accuracy: 8196/10000 (82.0%)
Epoch 8  : Loss: 0.32	Accuracy: 8333/10000 (83.3%)
Epoch 9  : Loss: 0.53	Accuracy: 7996/10000 (80.0%)
Epoch 10 : Loss: 0.32	Accuracy: 8413/10000 (84.1%)
Done!


## Rework Forward pass of Linear and Conv Layers to support Quantisation

In [None]:
testQAT(test_dataloader, q_model, epoch=10, loss_func=loss_func, stats=old_stats, act_quant=True, num_bits=4, opt='asym')

In [None]:
weight = q_model.state_dict()['fc1.weight']
weight

In [None]:
aa = quantizeTensor(weight)
aa

In [None]:
dequantizeTensor(aa)

In [None]:
q_model.state_dict()

In [None]:
def quantizeLayer(qActTuple, layer, stat, numBits=4, opt='asym'):
	W, B = layer.weight.data, layer.bias.data
	
	qWeightTuple = quantizeTensor(W, numBits=numBits, opt=opt)
	qBiasTuple   = quantizeTensor(B, numBits=numBits, opt=opt)
	
	layer.weight.data = qWeightTuple.qTensor.float()
	layer.bias.data   = qBiasTuple.qTensor.float()
	
	nextScale, nextZeroPoint = calcScaleZeroPoint(minValue=stat['min'], maxValue=stat['max'], numBits=numBits, opt=opt)
	
	weightScale = qWeightTuple.scale
	weightZeroPoint = qWeightTuple.zeroPoint
	
	if opt == 'asym':
		layer.weight.data = ((qWeightTuple.scale * qActTuple.scale) / nextScale) * (layer.weight.data - qWeightTuple.zeroPoint)
		layer.bias.data = (qBiasTuple.scale / nextScale) * (layer.bias.data - qBiasTuple.zeroPoint)		
		oAct = layer(qActTuple.qTensor.float() - qActTuple.zeroPoint) + nextZeroPoint
	if opt == 'sym':
		layer.weight.data = ((qWeightTuple.scale * qActTuple.scale) / nextScale) * (layer.weight.data)
		layer.bias.data = (qBiasTuple.scale / nextScale) * (layer.bias.data)
		oAct = layer(qActTuple.qTensor.float())
		
	layer.weight.data, layer.bias.data = W, B
	
	return oAct.round(), nextScale, nextZeroPoint

In [None]:
def quantForward(x, model, stats, num_bits=8, opt='asym'):
	# Quantise before inputting into incoming layers
	x = quantizeTensor()
	if sym:
		x = quantize_tensor_sym(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)
	else:
		x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'], num_bits=num_bits)

	x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.conv1, stats['conv2'], x.scale, x.zero_point, vis, axs, X=X, y=y+1, sym=sym, num_bits=num_bits)

	x = F.max_pool2d(x, 2, 2)



	x, scale_next, zero_point_next = quantizeLayer(x, model.conv2, stats['fc1'], scale_next, zero_point_next, vis, axs, X=X, y=y+3, sym=sym, num_bits=num_bits)

	x = F.max_pool2d(x, 2, 2)


	x = x.view(-1, 4*4*50)

	x, scale_next, zero_point_next = quantizeLayer(x, model.fc1, stats['fc2'], scale_next, zero_point_next, vis, axs, X=X+1, y=0, sym=sym, num_bits=num_bits)


	# Back to dequant for final layer
# 	if sym:
# 		x = dequantize_tensor_sym(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))
# 	else:
# 		x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))


	x = model.fc2(x)

	if vis:
		axs[X+1,3].set_xlabel('Unquantised Weights of fc2 layer')
		visualise(model.fc2.weight.data,axs[X+1,3])

		axs[X+1,2].set_xlabel('Output after fc2 but dequantised visualised below: ')
		visualise(x,axs[X+1,4])

	return F.log_softmax(x, dim=1)