# Module

In [1]:
import argparse
from tqdm import tqdm_notebook as tq
import tqdm
import os, time, math
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from collections import namedtuple
import matplotlib.pyplot as plt

torch.set_printoptions(precision=4, linewidth=50000)
import warnings
warnings.filterwarnings(action='ignore')

# User Define Variable

In [2]:
data_path = '~/dataset'

# Parser

In [3]:
parser = argparse.ArgumentParser(description='PyTorch for MNIST dataset')
parser.add_argument('--device', type=str, default='cuda', help='Device')
parser.add_argument('--shuffle', action='store_true', default=False, help='enables data shuffle')
parser.add_argument('--dataset', type=str, default='mnist', help='training dataset')
parser.add_argument('--data_path', type=str, default=data_path, help='path to MNIST')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--optimizer', type=str, default='adam', help='optimizer')
parser.add_argument('--loss_func', type=str, default='cel', help='optimizer')
parser.add_argument('--quant_opt', type=str, default='asym', help='Type of Quantization')
parser.add_argument('--num_bits', type=int, default=8, help='Number of Quantization Bits')
parser.add_argument('--act_quant', type=bool, default=False, help='Activation Quantization')
parser.add_argument('--disp', type=bool, default=False, help='Display Model Information')

args = parser.parse_args(args=[])

# Preparing Data

In [4]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.device == 'cuda' else {}
if args.dataset == 'mnist':
	train_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=True,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

	test_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=False,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

# Build Model

In [5]:
class MyReLU(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input):
		ctx.save_for_backward(input)
		return input.clamp(min=0, max=2**15)
	@staticmethod
	def backward(ctx, grad_output):
		input, = ctx.saved_tensors
		grad_input = grad_output.clone()
		grad_input[input < 0] = 0
		return grad_input

class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(28*28, 32)
		self.relu1 = nn.ReLU()
		self.fc2 = nn.Linear(32, 16)
		self.relu2 = nn.ReLU()
		self.fc3 = nn.Linear(16, 10)
		
	def forward(self, x):
		x = self.flatten(x)
		x = self.fc1(x)
		x = self.relu1(x)
		x = self.fc2(x)
		x = self.relu2(x)
		logits = self.fc3(x)
		return logits
	
class LENET5(nn.Module):
	def __init__(self):
		super(LENET5, self).__init__()
		#C1
		self.conv1 = nn.Conv2d(1, 6, 5)
		self.relu1 = nn.ReLU()
		#S2
		self.pool1 = nn.MaxPool2d(2)
		#C3
		self.conv2 = nn.Conv2d(6, 16, 5)
		self.relu2 = nn.ReLU()
		#S4
		self.pool2 = nn.MaxPool2d(2)
		#C5
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(256, 120)
		self.relu3 = nn.ReLU()
		#F6
		self.fc2 = nn.Linear(120, 84)
		self.relu4 = nn.ReLU()
		#OUTPUT
		self.fc3 = nn.Linear(84, 10)
		self.relu5 = nn.ReLU()

	def forward(self, x):
		x = self.conv1(x)
		x = self.relu1(x)
		x = self.pool1(x)
		x = self.conv2(x)
		x = self.relu2(x)
		x = self.pool2(x)
		x = self.flatten(x)
		x = self.fc1(x)
		x = self.relu3(x)
		x = self.fc2(x)
		x = self.relu4(x)
		x = self.fc3(x)
		logits = self.relu5(x)
		return logits

In [6]:
def genOptimizer(model, args):
	if args.optimizer == 'sgd':
		optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
	if args.optimizer == 'adam':
		optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	return optimizer

def genLossFunc(args):
	if args.loss_func == 'cel':
		loss_func = nn.CrossEntropyLoss()
	return loss_func


In [7]:
def accuracy(output, target, topk=(1,)):
	with torch.no_grad():
		maxk = max(topk)
		batch_size = target.size(0)

		_, pred = output.topk(maxk, 1, True, True)
		pred = pred.t()
		correct = pred.eq(target.view(1, -1).expand_as(pred))

		res = []
		for k in topk:
			correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
			res.append(correct_k.mul_(100.0 / batch_size))
		return res

In [8]:
def train(train_loader, model, epoch, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		pred = model(image)
		loss = loss_func(pred, label)
		running_loss += loss.item()#*image.size(0)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
				
	print(f'Epoch {epoch+1:<3d}: Avg. Loss: {running_loss/len(train_loader.dataset):.4f}', end = '\t')

In [9]:
def test(test_loader, model, epoch, args):
	model.eval()
	loss_func = genLossFunc(args)
	loss, correct = 0, 0
	with torch.no_grad():
# 		for batch_index, (image, label) in enumerate(tq(test_loader, desc='Test', leave=False)):
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			pred = model(image)
			loss += loss_func(pred, label).item()#*image.size(0)
			correct += (pred.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [10]:
def main(model):
	for epoch in range(args.epochs):
		train(train_loader, model, epoch, args)
		test(test_loader, model, epoch, args)
	print("Done!")
	return model

model = main(MLP().to(args.device))

Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1  : Avg. Loss: 0.0084	Accuracy: 9086/10000 (90.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2  : Avg. Loss: 0.0043	Accuracy: 9263/10000 (92.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3  : Avg. Loss: 0.0035	Accuracy: 9364/10000 (93.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4  : Avg. Loss: 0.0029	Accuracy: 9444/10000 (94.4%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5  : Avg. Loss: 0.0025	Accuracy: 9494/10000 (94.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6  : Avg. Loss: 0.0023	Accuracy: 9534/10000 (95.3%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7  : Avg. Loss: 0.0020	Accuracy: 9549/10000 (95.5%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8  : Avg. Loss: 0.0018	Accuracy: 9569/10000 (95.7%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9  : Avg. Loss: 0.0017	Accuracy: 9577/10000 (95.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10 : Avg. Loss: 0.0016	Accuracy: 9597/10000 (96.0%)
Done!


In [11]:
def calcScaleZeroPoint(min_val, max_val, args):
	qmin, qmax = 0., 2.**args.num_bits - 1.
	if args.quant_opt == 'asym':
		scale = (max_val - min_val) / (qmax - qmin)
		if scale is not 0.:
			scale = 2**round(math.log(abs(scale), 2))
		initial_zero_point = -min_val / scale
		zero_point = int({initial_zero_point < qmin: qmin,  initial_zero_point > qmax: qmax, qmin <= initial_zero_point <= qmax: initial_zero_point}.get(True, False))
	if args.quant_opt == 'sym':
		scale = max(abs(min_val), abs(max_val)) / qmax
		zero_point = 0
	return scale, zero_point

def quantizeTensor(input_tensor, min_val, max_val, args):
	if not min_val and not max_val:
		min_val, max_val = input_tensor.min(), input_tensor.max()
	
	qmin, qmax = 0., 2.**args.num_bits - 1.
	scale, zero_point = calcScaleZeroPoint(min_val, max_val, args)
	
	if args.quant_opt == 'asym':
		quant_tensor = (input_tensor / scale + zero_point).clamp(qmin, qmax).round()
	if args.quant_opt == 'sym':
		quant_tensor = (input_tensor / scale).clamp(qmax, qmax).round()
	return 	dict(zip(['qTensor', 'scale', 'zero_point'], [quant_tensor, scale, zero_point]))

\begin{equation*}
r = S(q-Z) \\
here, r: real-value, S: scale, q: quantized-value, Z: zero-point
\end{equation*}

In [12]:
def dequantizeTensor(qDict, args):
	if args.quant_opt == 'asym':
		dequant_tensor = qDict['scale'] * (qDict['qTensor'].float() - qDict['zero_point'])
	if args.quant_opt == 'sym':
		dequant_tensor = qDict['scale'] * (qDict['qTensor'].float())	
	return dequant_tensor

In [13]:
def updateStats(actTensor, stats, layerName):
	# dim=0 : find min/max in each col
	# dim=1 : find min/max in each row
	maxValue = torch.max(actTensor, dim=1)[0]
	minValue = torch.min(actTensor, dim=1)[0]
	
	if layerName not in stats:
		stats[layerName] = {'max': maxValue.sum(), 'min': minValue.sum(), 'total': 1}
	else:
		stats[layerName]['max'] += maxValue.sum().item()
		stats[layerName]['min'] += maxValue.sum().item()
		stats[layerName]['total'] += 1
		
	weighting = 2.0 / (stats[layerName]['total']) + 1
	
	if 'ema_min' in stats[layerName]:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_min']
	else:
		stats[layerName]['ema_min'] = weighting*(minValue.mean().item())

	if 'ema_max' in stats[layerName]:
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item()) + (1 - weighting) * stats[layerName]['ema_max']
	else: 
		stats[layerName]['ema_max'] = weighting*(maxValue.mean().item())
	return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):
	x = model.flatten(x)
	stats = updateStats(x, stats, 'fc1')
	x = F.relu(model.fc1(x))
	
	stats = updateStats(x, stats, 'fc2')
	x = F.relu(model.fc2(x))

	stats = updateStats(x, stats, 'fc3')
# 	x = model.fc3(x)
	return stats

# Entry function to get stats of all functions.
def gatherStats(model, test_loader, args):
	model.eval()
	stats = {}
	with torch.no_grad():
		for image, label in test_loader:
			image, label = image.to(args.device), label.to(args.device)
			stats = gatherActivationStats(model, image, stats)

	final_stats = {}
	for key, value in stats.items():
		final_stats[key] = { "max" : value["max"] / value["total"], "min" : value["min"] / value["total"], "ema_min": value["ema_min"], "ema_max": value["ema_max"] }
	return final_stats

In [14]:
class FakeQuantOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val, max_val, args):
		#def quantizeTensor(input_tensor, min_val=None, max_val=None, num_bits=4, opt='asym'):
        x = quantizeTensor(x, min_val, max_val, args)
        x = dequantizeTensor(x, args)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # straight through estimator
        return grad_output, None, None, None

In [15]:
def quantAwareTrainingForward(x, model, stats, args):
	with torch.no_grad():
		params = model.state_dict()
# 		layer_names, _ = zip(*list(model.named_modules())[1:])

	x = model.flatten(x)
	
	model.fc1.weight.data = FakeQuantOp.apply(model.fc1.weight.data, None, None, args)
	model.fc1.bias.data = FakeQuantOp.apply(model.fc1.bias.data, None, None, args)

# 	x = F.relu(model.fc1(x))
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc1')
	if args.act_quant:
		x = FakeQuantOp.apply(x, stats['fc1']['ema_min'], stats['fc1']['ema_max'], args)
# 	x = model.relu1(model.fc1(x))
	relu1 = MyReLU.apply
	x = relu1(model.fc1(x))
	
	model.fc2.weight.data = FakeQuantOp.apply(model.fc2.weight.data, None, None, args)
	model.fc2.bias.data = FakeQuantOp.apply(model.fc2.bias.data, None, None, args)
# 	x = F.relu(model.fc2(x))
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc2')
	if args.act_quant:
		x = FakeQuantOp.apply(x, stats['fc2']['ema_min'], stats['fc2']['ema_max'], args)
# 	x = model.relu2(model.fc2(x))
	relu2 = MyReLU.apply
	x = relu2(model.fc2(x))
	
	model.fc3.weight.data = FakeQuantOp.apply(model.fc3.weight.data, None, None, args)
	model.fc3.bias.data = FakeQuantOp.apply(model.fc3.bias.data, None, None, args)
# 	x = model.fc3(x)
	with torch.no_grad():
		stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'fc3')
	if args.act_quant:
		x = FakeQuantOp.apply(x, stats['fc3']['ema_min'], stats['fc3']['ema_max'], args)
	x = model.fc3(x)	
	return x, params, stats

In [16]:
def trainQAT(train_loader, model, epoch, stats, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0 
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		optimizer.zero_grad()
		y = model(image)
		y, params, stats = quantAwareTrainingForward(image, model, stats, args)
		
		# Recover FP32 to improve accuracy
		# 		model.fc1.weight.data = params['fc1.weight']
		for param, _ in model.named_parameters():
			exec(f'model.{param}.data = params["{param}"]')

		loss = loss_func(y, label)		
		running_loss += loss.item()#*image.size(0)

		loss.backward()
		optimizer.step()	
		if batch_index == max_batch_index:
			print(f'[Epoch {epoch+1:>2d}] Training Loss: {running_loss/len(train_loader.dataset):.4f}', end = ' / ')
			
	return stats

In [17]:
def testQAT(test_loader, model, epoch, stats, args):
	model.eval()
	loss_func = genLossFunc(args)
	loss, correct = 0, 0
	with torch.no_grad():
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			y, _, _ = quantAwareTrainingForward(image, model, stats, args)	
			loss += loss_func(y, label).item()#*image.size(0)
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Test Loss:  {loss:>.4f} / Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [None]:
def mainQAT(model):
	stats = {}
	for epoch in range(args.epochs):
		args.act_quant = True if epoch > 5 else False
		stats = trainQAT(train_loader, model, epoch, stats, args)
		testQAT(test_loader, model, epoch, stats, args)
	print("Done!")
	return model, stats

q_model, old_stats = mainQAT(MLP().to(args.device))
# q_model, old_stats = mainQAT(model)

Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  1] Training Loss: 0.0142 / Test Loss:  0.0073 / Accuracy: 8666/10000 (86.7%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

[Epoch  2] Training Loss: 0.0068 / Test Loss:  0.0061 / Accuracy: 8868/10000 (88.7%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

In [None]:
testQAT(test_loader, q_model, epoch=10, stats=old_stats, args=args)

In [None]:
def quantizeLayer(layer, x, args):
	wDict = quantizeTensor(layer.weight.data, None, None, args)
	bDict = quantizeTensor(layer.bias.data, None, None, args)
	layer.weight.data = wDict['qTensor']
	layer.bias.data = bDict['qTensor']
	
	xDict = quantizeTensor(x, None, None, args)
	oActDict = quantizeTensor(layer(xDict['qTensor']), None, None, args)
	return oActDict

def quantForward(model, x, args):
	x = model.flatten(x)
	# Input
	actDict0 = quantizeTensor(x, None, None, args)
	# 1st FC Layer
	actDict1 = quantizeLayer(model.fc1, actDict0['qTensor'], args)
	act1 = model.relu1(actDict1['qTensor'])
	actDict2 = quantizeLayer(model.fc2, act1, args)
	act2 = model.relu2(actDict2['qTensor'])
	actDict3 = quantizeLayer(model.fc3, act2, args)
	return actDict0, actDict1, actDict2, actDict3 

In [None]:
def testQuant(model, test_loader, stats, quant, args):
	model.eval()
	loss_func = genLossFunc(args)
	loss, correct = 0, 0
	with torch.no_grad():
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			if quant:
				actDict0, actDict1, actDict2, actDict3  = quantForward(model, image, args)
				y = actDict3['qTensor']
			else:
				y = model(image)
			loss += loss_func(y, label).item()#*image.size(0)
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%) Loss: {loss/len(test_loader.dataset):.2f}')
	return actDict0, actDict1, actDict2, actDict3


In [None]:
actDict0, actDict1, actDict2, actDict3 = testQuant(q_model, test_loader, old_stats, True, args)
# testQuant(MLP().to(args.device), test_loader, old_stats, True, args)

In [None]:
actDict0['qTensor'][15].view(28,28)

In [None]:
for actDict in [actDict0, actDict1, actDict2, actDict3]:
	print(actDict)

In [None]:
quantizeTensor(q_model.fc1(actDict0['qTensor']), None, None, args)