# Module

In [1]:
import argparse
from tqdm import tqdm_notebook as tq
import tqdm
import os, time, math, copy
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from collections import namedtuple
import matplotlib.pyplot as plt

torch.set_printoptions(precision=8, linewidth=50000)
import warnings
warnings.filterwarnings(action='ignore')

# User Define Variable

In [2]:
data_path = '~/dataset'

# Parser

In [3]:
parser = argparse.ArgumentParser(description='PyTorch for MNIST dataset')
parser.add_argument('--device', type=str, default='cuda', help='Device')
parser.add_argument('--shuffle', action='store_true', default=False, help='enables data shuffle')
parser.add_argument('--dataset', type=str, default='mnist', help='training dataset')
parser.add_argument('--data_path', type=str, default=data_path, help='path to MNIST')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--optimizer', type=str, default='adam', help='optimizer')
parser.add_argument('--loss_func', type=str, default='cel', help='optimizer')
parser.add_argument('--quant_opt', type=str, default='asym', help='Type of Quantization')
parser.add_argument('--num_bits', type=int, default=8, help='Number of Quantization Bits')
parser.add_argument('--act_quant', type=bool, default=False, help='Activation Quantization')
parser.add_argument('--disp', type=bool, default=False, help='Display Model Information')

args = parser.parse_args(args=[])

# Preparing Data

In [4]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.device == 'cuda' else {}
if args.dataset == 'mnist':
	train_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=True,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

	test_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=False,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

# Build Model

In [5]:
class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(28*28, 32)
		self.relu1 = nn.ReLU()
		self.fc2 = nn.Linear(32, 16)
		self.relu2 = nn.ReLU()
		self.fc3 = nn.Linear(16, 10)
		
	def forward(self, x):
		x = self.flatten(x)
		x = self.fc1(x)
		x = self.relu1(x)
		x = self.fc2(x)
		x = self.relu2(x)
		logits = self.fc3(x)
		return logits


In [6]:
def genOptimizer(model, args):
	if args.optimizer == 'sgd':
		optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
	if args.optimizer == 'adam':
		optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	return optimizer

def genLossFunc(args):
	if args.loss_func == 'cel':
		loss_func = nn.CrossEntropyLoss()
	return loss_func

In [7]:
def accuracy(output, target, topk=(1,)):
	with torch.no_grad():
		maxk = max(topk)
		batch_size = target.size(0)

		_, pred = output.topk(maxk, 1, True, True)
		pred = pred.t()
		correct = pred.eq(target.view(1, -1).expand_as(pred))

		res = []
		for k in topk:
			correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
			res.append(correct_k.mul_(100.0 / batch_size))
		return res

In [8]:
def train(train_loader, model, epoch, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		pred = model(image)
		loss = loss_func(pred, label)
		running_loss += loss.item()#*image.size(0)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
				
	print(f'Epoch {epoch+1:<3d}: Avg. Loss: {running_loss/len(train_loader.dataset):.4f}', end = '\t')

In [9]:
def test(test_loader, model, epoch, args):
	model.eval()
	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
# 		for batch_index, (image, label) in enumerate(tq(test_loader, desc='Test', leave=False)):
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			pred = model(image)
			loss += loss_func(pred, label).item()#*image.size(0)
			correct += (pred.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [10]:
def main(model):
	for epoch in range(args.epochs):
		train(train_loader, model, epoch, args)
		test(test_loader, model, epoch, args)
	print("Done!")
	return model

model = main(MLP().to(args.device))

Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1  : Avg. Loss: 0.0085	Accuracy: 9176/10000 (91.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2  : Avg. Loss: 0.0039	Accuracy: 9333/10000 (93.3%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3  : Avg. Loss: 0.0031	Accuracy: 9428/10000 (94.3%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4  : Avg. Loss: 0.0026	Accuracy: 9494/10000 (94.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5  : Avg. Loss: 0.0023	Accuracy: 9533/10000 (95.3%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6  : Avg. Loss: 0.0021	Accuracy: 9555/10000 (95.5%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7  : Avg. Loss: 0.0019	Accuracy: 9579/10000 (95.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8  : Avg. Loss: 0.0017	Accuracy: 9594/10000 (95.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9  : Avg. Loss: 0.0016	Accuracy: 9621/10000 (96.2%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10 : Avg. Loss: 0.0015	Accuracy: 9624/10000 (96.2%)
Done!


In [11]:
def calcScaleZeroPoint(min_val, max_val, args):
	qmin, qmax = 0., 2.**args.num_bits - 1.
	if args.quant_opt == 'asym':
		scale = (max_val - min_val) / (qmax - qmin)
		if scale is not 0.:
			scale = 2**torch.round(torch.log2(torch.tensor(abs(scale))))
		initial_zero_point = -min_val / scale
		if initial_zero_point < qmin:
			zero_point = int(qmin)
		elif initial_zero_point <= qmax:
			zero_point = int(initial_zero_point)
		else:
			zero_point = int(qmax)
#		zero_point = int({initial_zero_point < qmin: qmin,  initial_zero_point > qmax: qmax, qmin <= initial_zero_point <= qmax: initial_zero_point}.get(True, False))
	if args.quant_opt == 'sym':
		scale = max(abs(min_val), abs(max_val)) / qmax
		zero_point = 0
	return scale, zero_point

In [12]:
def quantizeTensor(input_tensor, min_val, max_val, args):
	if not min_val and not max_val:
		min_val, max_val = input_tensor.min(), input_tensor.max()
	
	qmin, qmax = 0., 2.**args.num_bits - 1.
	scale, zero_point = calcScaleZeroPoint(min_val, max_val, args)
	
	if args.quant_opt == 'asym':
		quant_tensor = (input_tensor / scale + zero_point).clamp(qmin, qmax).round()
	if args.quant_opt == 'sym':
		quant_tensor = (input_tensor / scale).clamp(qmax, qmax).round()
	return 	dict(zip(['qTensor', 'scale', 'zero_point'], [quant_tensor, scale, zero_point]))

\begin{equation*}
r = S(q-Z) \\
here, r: real-value, S: scale, q: quantized-value, Z: zero-point
\end{equation*}

In [13]:
def dequantizeTensor(qDict, args):
	if args.quant_opt == 'asym':
		dequant_tensor = qDict['scale'] * (qDict['qTensor'].float() - qDict['zero_point'])
	if args.quant_opt == 'sym':
		dequant_tensor = qDict['scale'] * (qDict['qTensor'].float())
	return dequant_tensor

In [14]:
def updateStats(actTensor, stats, layerName):
	# dim=0 : find min/max in each col
	# dim=1 : find min/max in each row
	maxValue = torch.max(actTensor, dim=1)[0]
	minValue = torch.min(actTensor, dim=1)[0]
	
	if layerName not in stats:
		stats[layerName] = {'max': maxValue.sum(), 'min': minValue.sum(), 'total': 1}
	else:
		stats[layerName]['max'] += maxValue.sum().item()
		stats[layerName]['min'] += maxValue.sum().item()
		stats[layerName]['total'] += 1
		
	weighting = 2.0 / (stats[layerName]['total']) + 1
	
	if 'ema_min' in stats[layerName]:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_min']
	else:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item())

	if 'ema_max' in stats[layerName]:
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_max']
	else: 
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item())
	return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
	x = model.flatten(x)
	stats = updateStats(x, stats, 'fc1')
	x = F.relu(model.fc1(x))
	
	stats = updateStats(x, stats, 'fc2')
	x = F.relu(model.fc2(x))

	stats = updateStats(x, stats, 'fc3')
# 	x = model.fc3(x)
	return stats

# Entry function to get stats of all functions.
def gatherStats(model, test_loader, args):
	model.eval()
	stats = {}
	with torch.no_grad():
		for image, label in test_loader:
			image, label = image.to(args.device), label.to(args.device)
			stats = gatherActivationStats(model, image, stats)

	final_stats = {}
	for key, value in stats.items():
		final_stats[key] = { "max" : value["max"] / value["total"], "min" : value["min"] / value["total"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
	return final_stats

In [15]:
class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val, max_val, args):
		#def quantizeTensor(input_tensor, min_val=None, max_val=None, num_bits=4, opt='asym'):
        x = quantizeTensor(x, min_val, max_val, args)
        x = dequantizeTensor(x, args)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None

In [16]:
def quantAwareTrainingForward(x, model, stats, args):
	with torch.no_grad():
		params = model.state_dict()
# 		layer_names, _ = zip(*list(model.named_modules())[1:])

	x = model.flatten(x)
	
	model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, None, None, args)
	model.fc1.bias.data = FakeQuantOp.apply(model.fc1.bias.data, None, None, args)

	with torch.no_grad():
		stats = updateStats(x, stats, 'fc1')
	if args.act_quant:
# 		x = FakeQuantOp.apply(x, stats['fc1']['ema_min'], stats['fc1']['ema_max'], args)
		x = FakeQuantOp.apply(x, None, None, args)

	x = model.relu1(model.fc1(x))
	
	model.fc2.weight.data = FakeQuantOp.apply(model.fc2.weight.data, None, None, args)
	model.fc2.bias.data = FakeQuantOp.apply(model.fc2.bias.data, None, None, args)

	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc2')
	if args.act_quant:
# 		x = FakeQuantOp.apply(x, stats['fc2']['ema_min'], stats['fc2']['ema_max'], args)
		x = FakeQuantOp.apply(x, None, None, args)

	x = model.relu2(model.fc2(x))
	
	model.fc3.weight.data = FakeQuantOp.apply(model.fc3.weight.data, None, None, args)
	model.fc3.bias.data = FakeQuantOp.apply(model.fc3.bias.data, None, None, args)

	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc3')
	if args.act_quant:
# 		x = FakeQuantOp.apply(x, stats['fc3']['ema_min'], stats['fc3']['ema_max'], args)
		x = FakeQuantOp.apply(x, None, None, args)
	x = model.fc3(x)
	return x, params, stats

In [17]:
def trainQAT(train_loader, model, epoch, stats, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0 
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		optimizer.zero_grad()
		y = model(image)
		y, params, stats = quantAwareTrainingForward(image, model, stats, args)
		
		# Recover FP32 to improve accuracy
		# 		model.fc1.weight.data = params['fc1.weight']
		for param, _ in model.named_parameters():
			exec(f'model.{param}.data = params["{param}"]')

		loss = loss_func(y, label)		
		running_loss += loss.item()#*image.size(0)

		loss.backward()
		optimizer.step()	
		if batch_index == max_batch_index:
			print(f'[Epoch {epoch+1:>2d}] Training Loss: {running_loss/len(train_loader.dataset):.4f}', end = ' / ')
			
	return stats

In [18]:
def testQAT(test_loader, model, epoch, stats, args):
	cmodel = copy.deepcopy(model).to(args.device)
	cmodel.eval()
	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			y, _, _ = quantAwareTrainingForward(image, cmodel, stats, args)	
			loss += loss_func(y, label).item()#*image.size(0)
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Test Loss:  {loss:>.4f} / Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [19]:
def mainQAT(model):
	stats = {}
	for epoch in range(args.epochs):
		args.act_quant = True if epoch > 5 else False
		stats = trainQAT(train_loader, model, epoch, stats, args)
		testQAT(test_loader, model, epoch, stats, args)
	print("Done!")
	return model, stats

q_model, old_stats = mainQAT(MLP().to(args.device))
# q_model, old_stats = mainQAT(model)

Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  1] Training Loss: 0.0092 / Test Loss:  0.0045 / Accuracy: 9163/10000 (91.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  2] Training Loss: 0.0041 / Test Loss:  0.0035 / Accuracy: 9323/10000 (93.2%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  3] Training Loss: 0.0033 / Test Loss:  0.0030 / Accuracy: 9437/10000 (94.4%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  4] Training Loss: 0.0029 / Test Loss:  0.0026 / Accuracy: 9501/10000 (95.0%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  5] Training Loss: 0.0024 / Test Loss:  0.0024 / Accuracy: 9560/10000 (95.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  6] Training Loss: 0.0022 / Test Loss:  0.0023 / Accuracy: 9586/10000 (95.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  7] Training Loss: 0.0020 / Test Loss:  0.0022 / Accuracy: 9595/10000 (96.0%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  8] Training Loss: 0.0018 / Test Loss:  0.0022 / Accuracy: 9604/10000 (96.0%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  9] Training Loss: 0.0017 / Test Loss:  0.0022 / Accuracy: 9609/10000 (96.1%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch 10] Training Loss: 0.0016 / Test Loss:  0.0021 / Accuracy: 9615/10000 (96.2%)
Done!


In [20]:
testQAT(test_loader, q_model, epoch=10, stats=old_stats, args=args)

Test Loss:  0.0021 / Accuracy: 9615/10000 (96.2%)


In [21]:
def quantizeLayerForward(layer, x, args):
	c_layer = copy.deepcopy(layer).to(args.device)
	with torch.no_grad():
		y_wo_bias = c_layer(x) - c_layer.bias.data

		yDict = quantizeTensor(y_wo_bias, None, None, args)
		wDict = quantizeTensor(c_layer.weight.data, None, None, args)
		bDict = quantizeTensor(c_layer.bias.data, None, None, args)

		c_layer.weight.data = wDict['qTensor']
		c_layer.bias.data = bDict['qTensor']

		# 64 x # of output nodes
		xDict = quantizeTensor(x, None, None, args)

		# Int 
		z1 = wDict['zero_point']
		z2 = xDict['zero_point']
		z3 = yDict['zero_point']

		a1 = torch.sum(wDict['qTensor'], dim=1)
		a2 = torch.sum(xDict['qTensor'], dim=1)

		q1 = wDict['qTensor']
		q2 = xDict['qTensor']

		m = wDict['scale'] * xDict['scale'] / yDict['scale']
		n = len(x)

		q1_q2_mul = torch.matmul(q1, q2.t())

		a1 = a1.view(a1.size(0), 1) * torch.ones(q1_q2_mul.size()).to(args.device)
		a2 = a2.view(1, a2.size(0)) * torch.ones(q1_q2_mul.size()).to(args.device)

		q_y = z3 + m * ( (n * z1 * z2) - (z1 * a2) - (z2 * a1) + q1_q2_mul)
	
	return dict(zip(['qTensor', 'scale', 'zero_point'], [q_y.round().t(), yDict['scale'], yDict['zero_point']]))

In [22]:
def quantForward(model, x, args):
	flp_act0 = model.flatten(x)
	flp_act1 = model.relu1(model.fc1(flp_act0))
	flp_act2 = model.relu2(model.fc2(flp_act1))
	flp_act3 = model.fc3(flp_act2)
	# Input	
	qDict0 = quantizeTensor(flp_act0, None, None, args)
	flp_act0 = dequantizeTensor(qDict0, args)
	# 1st FC Layer
	qDict1 = quantizeLayerForward(model.fc1, flp_act0, args)
	flp_act1 = model.relu1(dequantizeTensor(qDict1, args) + model.fc1.bias)

	qDict2 = quantizeLayerForward(model.fc2, flp_act1, args)
	flp_act2 = model.relu2(dequantizeTensor(qDict2, args) + model.fc2.bias)

	qDict3 = quantizeLayerForward(model.fc3, flp_act2, args)
	return model, qDict1, qDict2, qDict3

In [23]:
def testQuant(model, test_loader, stats,  args):
	model.eval()
	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			o_model, qDict1, qDict2, qDict3  = quantForward(model, image, args)
			y = dequantizeTensor(qDict3, args) + model.fc3.bias.data
			loss += loss_func(y, label).item()#*image.size(0)
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%) Loss: {loss/len(test_loader.dataset):.4f}')
	return o_model, qDict1, qDict2, qDict3

In [24]:
o_model, q_Dict1, q_Dict2, q_Dict3 = testQuant(q_model, test_loader, old_stats, args)

Accuracy: 9602/10000 (96.0%) Loss: 0.0022
