# Module

In [1]:
import argparse
from tqdm import tqdm_notebook as tq
import tqdm
import os, time, math, copy
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from collections import namedtuple
import matplotlib.pyplot as plt

torch.set_printoptions(precision=8, linewidth=50000)
import warnings
warnings.filterwarnings(action='ignore')

# Print Colors

In [2]:
BLACK	= '\033[30m'
RED		= '\033[31m'
GREEN	= '\033[32m'
YELLOW	= '\033[33m'
BLUE	= '\033[34m'
MAGENTA	= '\033[35m'
CYAN	= '\033[36m'
RESET	= '\033[0m'
SEL		= '\033[7m'

In [3]:
class	fxp:
	def	__init__(self, bIn, iBWF):
		self.iFullBW	= len(bIn)
		self.iIntgBW	= self.iFullBW - iBWF
		self.bSign		= bIn[0]
		self.bIntg		= bIn[:self.iIntgBW]
		self.bFrac		= bIn[self.iIntgBW:]
		self.fFull		= 0
		try:
			for idx, bit in enumerate(bIn):
				if	idx == 0:
					self.fFull = self.fFull + int(bit,2) * -pow(2, self.iIntgBW - 1)
				else:
					self.fFull = self.fFull + int(bit,2) * pow(2, self.iIntgBW - 1 - idx)
		except:
			print(bIn)
		self.dispFull	= RED + self.bIntg + BLUE + self.bFrac + RESET
		return

In [4]:
class	flp2fix:
	def	__init__(self, fIn, iBW, iBWF):
		self.fMin		= fxp('1' + (iBW-1)*'0', iBWF).fFull
		self.fMax		= fxp('0' + (iBW-1)*'1', iBWF).fFull
		self.fResol		= 2 ** -iBWF
		if fIn < self.fMin or fIn > self.fMax:
			print('Out of input range during flp -> fix converting ')
		
		self.iBW		= iBW
		self.iBWI		= iBW - iBWF
		self.iBWF		= iBWF

		self.iFLP2INT	= abs(int(fIn * 2 ** iBWF))
		if fIn < 0:
			self.iFLP2INT = 2 ** (iBW-1) - self.iFLP2INT

		if fIn >= 0:
			self.bFull = bin(self.iFLP2INT)[2:].rjust(iBW, '0')
		else:
			self.bFull = '1'+bin(self.iFLP2INT)[2:].rjust(iBW-1, '0')
			if len(self.bFull) > iBW:
				self.bFull = '0' * iBW

		self.cssFxp		= fxp(self.bFull, self.iBWF)
		self.bSign		= self.cssFxp.bSign
		self.bIntg		= self.cssFxp.bIntg
		self.bFrac		= self.cssFxp.bFrac
		self.fFull		= self.cssFxp.fFull
		return

In [5]:
def	flp2fixTensor(fIn, iBW, iBWF):
	fMin = - 2 ** (iBW - iBWF - 1)
	fMax = (2 ** (iBW-1) - 1) * (2 ** -iBWF)
	fTensor = fIn * (2 ** iBWF)
	fTensor = fTensor.round() * (2 ** -iBWF)
	return fTensor

# User Define Variable

In [6]:
data_path = '~/dataset'

# Parser

In [8]:
parser = argparse.ArgumentParser(description='PyTorch for MNIST dataset')
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--shuffle', action='store_true', default=False, help='enables data shuffle')
parser.add_argument('--dataset', type=str, default='mnist', help='training dataset')
parser.add_argument('--data_path', type=str, default=data_path, help='path to MNIST')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--optimizer', type=str, default='adam', help='optimizer')
parser.add_argument('--loss_func', type=str, default='cel', help='optimizer')
parser.add_argument('--quant_opt', type=str, default='asym', help='Type of Quantization')
parser.add_argument('--full_bits', type=int, default=8, help='Number of Quantization Bits')
parser.add_argument('--frac_bits', type=int, default=4, help='Number of Quantization Bits')
parser.add_argument('--act_quant', type=bool, default=False, help='Activation Quantization')
parser.add_argument('--disp', type=bool, default=False, help='Display Model Information')

args = parser.parse_args(args=[])

# Preparing Data

In [9]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.device == 'cuda' else {}
if args.dataset == 'mnist':
	train_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=True,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

	test_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=False,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

# Build Model

In [10]:
class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(28*28, 32)
		self.relu1 = nn.ReLU()
		self.fc2 = nn.Linear(32, 16)
		self.relu2 = nn.ReLU()
		self.fc3 = nn.Linear(16, 10)
		
	def forward(self, x):
		x = self.flatten(x)
		x = self.fc1(x)
		x = self.relu1(x)
		x = self.fc2(x)
		x = self.relu2(x)
		logits = self.fc3(x)
		return logits
	

In [11]:
def genOptimizer(model, args):
	if args.optimizer == 'sgd':
		optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
	if args.optimizer == 'adam':
		optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	return optimizer

def genLossFunc(args):
	if args.loss_func == 'cel':
		loss_func = nn.CrossEntropyLoss()
	return loss_func


In [12]:
def train(train_loader, model, epoch, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		pred = model(image)
		loss = loss_func(pred, label)
		running_loss += loss.item()#*image.size(0)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
				
	print(f'Epoch {epoch+1:<3d}: Avg. Loss: {running_loss/len(train_loader.dataset):.4f}', end = '\t')

In [13]:
def test(test_loader, model, epoch, args):
	model.eval()
	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
# 		for batch_index, (image, label) in enumerate(tq(test_loader, desc='Test', leave=False)):
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			pred = model(image)
			loss += loss_func(pred, label).item()#*image.size(0)
			correct += (pred.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [14]:
def main(model):
	for epoch in range(args.epochs):
		train(train_loader, model, epoch, args)
		test(test_loader, model, epoch, args)
	print("Done!")
	return model

model = main(MLP().to(args.device))

Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1  : Avg. Loss: 0.0079	Accuracy: 9192/10000 (91.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2  : Avg. Loss: 0.0037	Accuracy: 9371/10000 (93.7%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3  : Avg. Loss: 0.0030	Accuracy: 9463/10000 (94.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4  : Avg. Loss: 0.0025	Accuracy: 9501/10000 (95.0%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5  : Avg. Loss: 0.0022	Accuracy: 9544/10000 (95.4%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6  : Avg. Loss: 0.0020	Accuracy: 9562/10000 (95.6%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7  : Avg. Loss: 0.0018	Accuracy: 9575/10000 (95.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8  : Avg. Loss: 0.0017	Accuracy: 9584/10000 (95.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9  : Avg. Loss: 0.0016	Accuracy: 9597/10000 (96.0%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10 : Avg. Loss: 0.0015	Accuracy: 9610/10000 (96.1%)
Done!


In [15]:
def model2fix(model, full_width=args.full_bits, frac_width=args.frac_bits):
	for name, _ in model.named_parameters():
		exec(f'model.{name}.data = flp2fixTensor(model.{name}.data, {full_width}, {frac_width})')
	return model

def in2fix(images, full_width=args.full_bits, frac_width=args.frac_bits):
	dim_images = images.size()
	images = images.view(-1)
	for idx_image, image in enumerate(images):
		temp_css = flp2fix(image, full_width, frac_width)
		images[idx_image] = torch.tensor(temp_css.fFull)
		del temp_css
	return images.view(dim_images)

In [16]:
def quantFixForward(model, x, args):
	cmodel = copy.deepcopy(model).to(args.device)
	with torch.no_grad():
		x = cmodel.flatten(x)
		act0 = flp2fixTensor(x, args.full_bits, args.frac_bits)
		
		act1 = model.relu1(cmodel.fc1(act0))
		act1 = flp2fixTensor(act1, args.full_bits, args.frac_bits)

		act2 = cmodel.relu2(cmodel.fc2(act1))
		act2 = flp2fixTensor(act2, args.full_bits, args.frac_bits)
		
		act3 = cmodel.fc3(act2)
		act3 = flp2fixTensor(act3, args.full_bits, args.frac_bits)
	return cmodel, act0, act1, act2, act3 

In [17]:
def testQuant(model, test_loader, args):
	qmodel = copy.deepcopy(model).to(args.device)
	qmodel = model2fix(qmodel, args.full_bits, args.frac_bits)
	qmodel.eval()

	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			qmodel, act0, act1, act2, act3  = quantFixForward(qmodel, image, args)
			y = act3
			loss += loss_func(y, label).item()#*image.size(0)
			correct += (y.argmax(1) == label).type(torch.int).sum().item()
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%) Loss: {loss/len(test_loader.dataset):.2f}')
	return qmodel, act0, act1, act2, act3

In [18]:
qmodel, act0, act1, act2, act3 = testQuant(model, test_loader, args)

Accuracy: 9597/10000 (96.0%) Loss: 0.00


In [19]:
act1

tensor([[ 0.00000000,  6.93750000,  0.00000000,  1.50000000,  0.00000000,  2.37500000,  5.93750000,  0.00000000,  2.31250000,  0.00000000,  3.87500000,  1.00000000,  2.62500000,  3.12500000,  0.00000000,  1.12500000,  0.00000000,  1.00000000,  0.00000000,  4.43750000,  2.43750000,  1.00000000,  2.93750000,  0.81250000,  1.75000000,  4.31250000,  2.43750000,  4.25000000,  4.31250000,  2.06250000,  0.00000000,  0.00000000],
        [ 0.00000000,  4.62500000,  0.00000000,  0.68750000,  0.31250000,  0.00000000,  0.00000000,  4.75000000,  0.00000000,  4.56250000,  7.06250000,  5.50000000,  0.00000000,  5.00000000,  3.18750000,  0.00000000,  1.50000000,  3.43750000,  0.00000000,  2.50000000,  6.18750000,  0.00000000,  8.25000000,  0.00000000,  3.56250000,  5.50000000,  6.12500000,  2.12500000,  0.00000000,  5.75000000,  0.43750000,  0.00000000],
        [ 0.00000000, 11.06250000,  0.00000000,  4.18750000,  0.00000000,  0.00000000,  2.31250000,  1.75000000,  2.93750000,  2.43750000,  6.875000