# Module

In [1]:
import argparse
from tqdm import tqdm_notebook as tq
import tqdm
import os, time, math, copy
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from collections import namedtuple
import matplotlib.pyplot as plt
import random
import time
import datetime

torch.set_printoptions(precision=8, linewidth=50000)
import warnings
warnings.filterwarnings(action='ignore')

  warn(f"Failed to load image Python extension: {e}")


# Print Colors

In [2]:
BLACK	= '\033[30m'
RED		= '\033[31m'
GREEN	= '\033[32m'
YELLOW	= '\033[33m'
BLUE	= '\033[34m'
MAGENTA	= '\033[35m'
CYAN	= '\033[36m'
RESET	= '\033[0m'
SEL		= '\033[7m'

In [3]:
def int2bin(iIn,iBW):
    iBW = iBW + 1
    if iIn >= 0:
        bOut = bin(iIn).replace('0b','').rjust(iBW,'0')
    else :
        bOut = bin(iIn & (pow(2,iBW)-1)).replace('0b','').rjust(iBW,'1')
    return bOut[1:]

In [4]:
def XOR(iA,iB):
    if iA != iB :
        iOut = '1'
    else : 
        iOut = '0'
    return iOut

In [5]:
def snum(a):
    if a >= 0 :
        return '0'
    else :
        return '1'

In [6]:
def binInv(bIn):
    bOut = bin(int(bIn,2)^(pow(2,len(bIn))-1)).replace('0b','').rjust(len(bIn),'0')
    return bOut

In [7]:
class fxp:
    def __init__(self, bIn, iBWF):
        self.iFullBW = len(bIn)
        self.iIntgBW = self.iFullBW - iBWF
        self.bSign = bIn[0]
        self.bIntg = bIn[:self.iIntgBW]
        self.bFrac = bIn[self.iIntgBW:]
        self.fFull = 0
        try:
            for idx, bit in enumerate(bIn):
                if idx == 0:
                    self.fFull = self.fFull + int(bit,2) * -pow(2, self.iIntgBW - 1)
                else:
                    self.fFull = self.fFull + int(bit,2) * pow(2, self.iIntgBW - 1 - idx)
        except:
            print(bIn)
        self.dispFull = self.bIntg +"."+ self.bFrac 
        return

In [8]:
class flp2fix:
    def __init__(self, fIn, iBW, iBWF):
        self.fMin = - 2 ** (iBW - iBWF - 1)
        self.fMax = (2 ** (iBW-1) - 1) * (2 ** -iBWF)
        self.fResol = 2 ** -iBWF
        #if fIn < self.fMin or fIn > self.fMax:
            #print(f'({fIn}): Out of input range ({self.fMax}/{self.fMin}) during flp -> fix converting ')
        self.iBW = iBW
        self.iBWI = iBW - iBWF
        self.iBWF = iBWF

        self.iFLP2INT = abs(int(fIn * 2 ** iBWF))
        if fIn < 0:
            self.iFLP2INT = 2 ** (iBW-1) - self.iFLP2INT

        if fIn >= 0:
            self.bFull = bin(self.iFLP2INT)[2:].rjust(iBW, '0')
        else:
            self.bFull = '1'+bin(self.iFLP2INT)[2:].rjust(iBW-1, '0')
            if len(self.bFull) > iBW:
                self.bFull = '0' * iBW

        self.cssFxp = fxp(self.bFull, self.iBWF)
        self.bSign = self.cssFxp.bSign
        self.bIntg = self.cssFxp.bIntg
        self.bFrac = self.cssFxp.bFrac
        self.fFull = self.cssFxp.fFull
        return

In [9]:
def flp2fixTensor(fIn, iBW, iBWF):
    fMin = - 2 ** (iBW - iBWF - 1)
    fMax = (2 ** (iBW-1) - 1) * (2 ** -iBWF)
    fList = []
    for aTensor in fIn.view(-1):
        fList.append(flp2fix(aTensor, iBW, iBWF).fFull)
    return torch.tensor(fList).view(fIn.size())

# User Define Variables

In [10]:
data_path = '~/dataset'

In [11]:
parser = argparse.ArgumentParser(description='PyTorch for MNIST dataset')
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--shuffle', action='store_true', default=False, help='enables data shuffle')
parser.add_argument('--dataset', type=str, default='mnist', help='training dataset')
parser.add_argument('--data_path', type=str, default=data_path, help='path to MNIST')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--optimizer', type=str, default='adam', help='optimizer')
parser.add_argument('--loss_func', type=str, default='cel', help='optimizer')
parser.add_argument('--quant_opt', type=str, default='asym', help='Type of Quantization')
parser.add_argument('--full_bits', type=int, default=16, help='Number of Quantization Bits')
parser.add_argument('--frac_bits', type=int, default=8, help='Number of Quantization Bits')
#parser.add_argument('--pretrained', type=bool, default=True, help='Pretrained Model')
parser.add_argument('--act_quant', type=bool, default=False, help='Activation Quantization')
parser.add_argument('--disp', type=bool, default=False, help='Display Model Information')
parser.add_argument('--bBW',type=int,default=6,help='bit number')
args = parser.parse_args(args=[])

# Preparing Data

In [12]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.device == 'cuda' else {}
if args.dataset == 'mnist':
	train_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=True,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

	test_loader = torch.utils.data.DataLoader(
		dataset=datasets.MNIST(
			root=args.data_path,
			train=False,
			download=True,
			transform=transforms.ToTensor()
		),
		batch_size=args.batch_size,
		shuffle=args.shuffle,
		**kwargs
	)

# Build Model

In [13]:
class MLP(nn.Module):
	def __init__(self):
		super(MLP, self).__init__()
		self.flatten = nn.Flatten()
		self.fc1 = nn.Linear(28*28, 32)
		self.relu1 = nn.ReLU()
		self.fc2 = nn.Linear(32, 64)
		self.relu2 = nn.ReLU()
		self.fc3 = nn.Linear(64, 10)
		
	def forward(self, x):
		x = self.flatten(x)
		x = self.fc1(x)
		x = self.relu1(x)
		x = self.fc2(x)
		x = self.relu2(x)
		logits = self.fc3(x)
		return logits

In [14]:
def genOptimizer(model, args):
	if args.optimizer == 'sgd':
		optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
	if args.optimizer == 'adam':
		optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
	return optimizer

def genLossFunc(args):
	if args.loss_func == 'cel':
		loss_func = nn.CrossEntropyLoss()
	return loss_func

In [15]:
def train(train_loader, model, epoch, args):
	model.train()
	loss_func = genLossFunc(args)
	optimizer = genOptimizer(model, args)
	max_batch_index = int(np.floor(len(train_loader.dataset)/args.batch_size))
	running_loss = 0
	for batch_index, (image, label) in enumerate(tq(train_loader, desc='Train', leave=False)):
		image, label = image.to(args.device), label.to(args.device)
		pred = model(image)
		loss = loss_func(pred, label)
		running_loss += loss.item()#*image.size(0)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
				
	print(f'Epoch {epoch+1:<3d}: Avg. Loss: {running_loss/len(train_loader.dataset):.4f}', end = '\t')

In [16]:
def test(test_loader, model, args):
	model.eval()
	with torch.no_grad():
		loss_func = genLossFunc(args)
		loss, correct = 0, 0
# 		for batch_index, (image, label) in enumerate(tq(test_loader, desc='Test', leave=False)):
		for batch_index, (image, label) in enumerate(test_loader):
			image, label = image.to(args.device), label.to(args.device)
			pred = model(image)
			loss += loss_func(pred, label).item()#*image.size(0)
			correct += (pred.argmax(1) == label).type(torch.int).sum().item()
	loss /= len(test_loader.dataset)
	correct_rate = 100 * correct / len(test_loader.dataset)
	print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%)')

In [17]:
def main(model):
	for epoch in range(args.epochs):
		train(train_loader, model, epoch, args)
		test(test_loader, model, args)
	print("Done!")
	return model

In [18]:
#if args.pretrained:
#    if os.path.isfile('preTrainedModel.pth'):
#        model = MLP().to(args.device)
#        model.load_state_dict(torch.load('preTrainedModel.pth'))
#        test(test_loader, model, args)
#    else:
#        model = main(MLP().to(args.device))
#        torch.save(model.state_dict(), 'preTrainedModel.pth')
#else:
model = main(MLP().to(args.device))
torch.save(model.state_dict(), 'preTrainedModel.pth')

Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1  : Avg. Loss: 0.0071	Accuracy: 9220/10000 (92.2%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2  : Avg. Loss: 0.0035	Accuracy: 9377/10000 (93.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3  : Avg. Loss: 0.0027	Accuracy: 9466/10000 (94.7%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4  : Avg. Loss: 0.0022	Accuracy: 9535/10000 (95.3%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5  : Avg. Loss: 0.0019	Accuracy: 9593/10000 (95.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 6  : Avg. Loss: 0.0016	Accuracy: 9586/10000 (95.9%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7  : Avg. Loss: 0.0014	Accuracy: 9579/10000 (95.8%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8  : Avg. Loss: 0.0013	Accuracy: 9540/10000 (95.4%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9  : Avg. Loss: 0.0012	Accuracy: 9518/10000 (95.2%)


Train:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10 : Avg. Loss: 0.0011	Accuracy: 9569/10000 (95.7%)
Done!


In [19]:
print(model.fc1.weight)

Parameter containing:
tensor([[-0.03084081,  0.01904124,  0.02467639,  ..., -0.01760078, -0.03035460, -0.00424182],
        [-0.03064621,  0.00843286, -0.00851725,  ...,  0.02726848, -0.03423236, -0.02163524],
        [-0.02528932,  0.02233844, -0.02095168,  ...,  0.01256039,  0.02028259,  0.01839883],
        ...,
        [-0.00507159,  0.01260033, -0.03568311,  ...,  0.01650016, -0.01334844,  0.01835725],
        [ 0.01415202,  0.02092912, -0.02527218,  ...,  0.01994696,  0.01392094, -0.02142347],
        [-0.00415314, -0.02397374, -0.03201337,  ...,  0.01135057,  0.00744670,  0.01904514]], requires_grad=True)


# SNG

In [20]:
def Comp(a,lfsr,snum):
    for com in range(0,len(a)):
        oA = '0'
        if a[com]!=lfsr[com]:
            if(int(a[com]) > int(lfsr[com])):
                oA = '1'
            break
    return XOR(oA,snum)

In [21]:
def perm(a):
    al = len(a)
    blist = []
    for i in range(al) :
        #print(al-i-1)
        blist.append(a[al-i-1])
    
    b = "".join(blist)
    
    return b

In [22]:
class LFSR6:
    def Random(self):
        self.b0 = eval(f'str(random.randint(0,1))')
        self.b1 = eval(f'str(random.randint(0,1))')
        self.b2 = eval(f'str(random.randint(0,1))')
        self.b3 = eval(f'str(random.randint(0,1))')
        self.b4 = eval(f'str(random.randint(0,1))')
        self.b5 = eval(f'str(random.randint(0,1))')
        
        return self.b0 + self.b1 + self.b2 + self.b3 + self.b4 + self.b5
    
    def Normal(self,stream):
        self.b0 = XOR(int(stream[4]),int(stream[5]))
        self.b1 = stream[0]
        self.b2 = stream[1]
        self.b3 = stream[2]
        self.b4 = stream[3]
        self.b5 = stream[4]
        
        return self.b0 + self.b1 + self.b2 + self.b3 + self.b4 + self.b5
    
    def Allzero(self):
        self.b0 = '0'
        self.b1 = '0'
        self.b2 = '0'
        self.b3 = '0'
        self.b4 = '0'
        self.b5 = '0'
        
        return self.b0 + self.b1 + self.b2 + self.b3 + self.b4 + self.b5

In [23]:
def LFSRlist6():
    lfsr = LFSR6()
    lfsrlist = []
    for k in range(2**(args.bBW)-1): #lfsr number generating
        if k == 0:
            lfsrlist.append(lfsr.Random())
        else :
            lfsrlist.append(lfsr.Normal(lfsrlist[k-1]))
        if (k == 2**(args.bBW)-2):
            lfsrlist.append(lfsr.Allzero())
    
    if (args.bBW) != args.frac_bits :
        if args.bBW < args.frac_bits :
            for i in range(len(lfsrlist)):
                lfsrlist[i] = lfsrlist[i] + (args.frac_bits-args.bBW)*'0'
        else :
            print("it can't work")
            return 0
    
    return lfsrlist

In [95]:
def SNG(iIN,lfsr):

    sNUM = snum(iIN)
    
    bIN = flp2fix(iIN,args.full_bits,args.frac_bits).bFull
    bFRAC = bIN[-(args.frac_bits):]
    if sNUM == 1 :
        bFRAC = bin(int(binInv(bFRAC),2)+1).replace('0b','').rjust(args.bBW,'0')
    oAlist = []
    
    for k in range(2**(args.bBW)): #lfsr number generating
        lNUM = lfsr[k]
        a = Comp(bFRAC,lNUM,sNUM)
        oAlist.append(a) #comparator of input a
    
    oAlist.insert(0,sNUM)
    sA = "".join(oAlist)
    if bIN == '0'*args.full_bits :
        return "0"*((2**(args.bBW))+1)
    else :
        return sA

In [112]:
def SNG_P(iIN,lfsr):
    sNUM = snum(iIN)
    
    bIN = flp2fix(iIN,args.full_bits,args.frac_bits).bFull
    bFRAC = bIN[-(args.frac_bits):]
    if sNUM == 1 :
        bFRAC = bin(int(binInv(bFRAC),2)+1).replace('0b','').rjust(args.bBW,'0')
    oAlist = []
    
    for k in range(2**(args.bBW)): #lfsr number generating
        if (args.bBW == args.frac_bits) :    
            lNUM = perm(lfsr[k])
        elif (args.bBW < args.frac_bits) :
            lNUM = perm(lfsr[k][:args.bBW])+(args.frac_bits-args.bBW)*"0"
        a = Comp(bFRAC,lNUM,sNUM)
        oAlist.append(a) #comparator of input a
    
    oAlist.insert(0,sNUM)
    sA = "".join(oAlist)
    if bIN == '0'*args.full_bits :
        return "0"*((2**(args.bBW))+1)
    else :
        return sA

In [113]:
def SNGnumpy(fIn,lfsr):
    start = time.time()
    sList = []
    for aNumpy in fIn.view(-1):
        sList.append(SNG(float(aNumpy),lfsr))
    end = time.time()
    sec = (end-start)
    result_list = str(datetime.timedelta(seconds=sec)).split(".")
    print(f'SNGnumpy : {result_list[0]}')                 
    return np.array(sList).reshape(fIn.size())

In [114]:
def SNGpnumpy(fIn,lfsr):
    start = time.time()
    sList = []
    for aNumpy in fIn.view(-1):
        sList.append(SNG_P(float(aNumpy),lfsr))
    end = time.time()
    sec = (end-start)
    result_list = str(datetime.timedelta(seconds=sec)).split(".")
    print(f'SNGpnumpy : {result_list[0]}')                   
    return  np.array(sList).reshape(fIn.size())

In [115]:
def CountOne(nIn):
    nlist = []
    for num in nIn.reshape(-1):
        n = 0
        for a in num:
            if a == '1' :
                n += 1
        if a[0] == '1' :
            nlist.append(n-1)
        else :
            nlist.append(n)
    return torch.tensor(nlist).view(nIn.shape)

In [116]:
def defSign(nIn):
    nlist = []
    for num in nIn.reshape(-1):
        if num[0] == '1' :
            nlist.append(-1)
        else :
            nlist.append(1)
    return torch.tensor(nlist).view(nIn.shape)

In [117]:
def mul(a,b):
    al = len(a)
    bl = len(b)
    
    outlist = []
    
    if al != bl :
        print("length of string is different")
        return 0
    
    outlist.append(XOR(a[0],b[0]))
    
    for i in range(al-1) :
        outlist.append(str(int(a[i+1]) & int(b[i+1])))
    
    #print(outlist)
    out = "".join(outlist)
    return out    

In [118]:
def defSign1(nIn):
    if nIn[0] == '1' :
        return -1
    else :
        return 1

In [119]:
def CountOne1(nIn):
    n = 0
    for num in nIn:
        if num == '1' :
            n += 1
    if nIn[0] == '1' :
        return n-1
    else :
        return n

In [120]:
def S2None(sIn,SF):
    print(sIn)
    s = defSign1(sIn)
    o = (CountOne1(sIn)/(2**args.bBW))*SF*s
    end = time.time()
    return o

In [121]:
def mulNumpy(aIn,bIn,aSF,wSF):
    start = time.time()
    mList = []
    for i in range(aIn.shape[0]):
        for j in range(bIn.T.shape[1]):
            sum = 0
            for k in range(aIn.shape[1]):
                sum += S2None(mul((aIn[i][k].astype(str)),(bIn.T)[k][j].astype(str)),aSF*wSF)
            mList.append(sum)
    end = time.time()
    sec = (end-start)
    result_list = str(datetime.timedelta(seconds=sec)).split(".")
    print(f'mulNumpy : {result_list[0]}')
    return torch.tensor(mList).view(aIn.shape[0],bIn.T.shape[1])

## Find max, min 

In [128]:
def findMaxMin(data):
    start = time.time()
    
    max = torch.max(data)
    min = torch.min(data)
    SF=torch.max(abs(max),abs(min)).item()
    
    end = time.time()
    sec = (end-start)
    result_list = str(datetime.timedelta(seconds=sec)).split(".")
    print(f'findMaxMin : {result_list[0]}')
    
    return SF

## Checking time

## Fixed model

In [129]:
def model2fix(model, args):
	for name, _ in model.named_parameters():
		exec(f'model.{name}.data = flp2fixTensor(model.{name}.data, {args.full_bits}, {args.frac_bits})')
	return model

In [130]:
def N2S2N(model, iX, iW, iB, args):
    lfsr = LFSRlist6()
    
    xSF = findMaxMin(iX)
    wSF = findMaxMin(iW)

    x = SNGnumpy(iX/xSF,lfsr)
    w = SNGpnumpy(iW/wSF,lfsr)
    
    sout = mulNumpy(x,w,xSF,wSF) + iB
    
    return sout

In [131]:
def quantFixForward(model, x, args):
    cmodel = copy.deepcopy(model).to(args.device)
    
    with torch.no_grad():
        i0 = cmodel.flatten(x)
        act0 = N2S2N(cmodel,i0,cmodel.fc1.weight,cmodel.fc1.bias,args)
        print("act0 successed")
        
        i1 = model.relu1(act0)
        act1 = N2S2N(cmodel,i1,cmodel.fc2.weight,cmodel.fc2.bias,args)
        print("act1 successed")
        
        i2 = cmodel.relu2(act1)
        act2 = N2S2N(cmodel,i2,cmodel.fc3.weight,cmodel.fc3.bias,args)
        print("act2 successed")
        
        act3 = flp2fixTensor(act2, args.full_bits, args.frac_bits)
        print("act3 successed")
    return cmodel, act0, act1, act2, act3 

In [132]:
def testQuant(model, test_loader, args):
    
    start = time.time()
    
    qmodel = copy.deepcopy(model).to(args.device)
    qmodel = model2fix(qmodel, args)
    qmodel.eval()
    
    with torch.no_grad():
        loss_func = genLossFunc(args)
        loss, correct = 0, 0
        for batch_index, (image, label) in enumerate(tq(test_loader,desc='Test',leave=False)):
            start = time.time()
            image, label = image.to(args.device), label.to(args.device)
            qmodel, act0, act1, act2, act3  = quantFixForward(qmodel, image, args)
            y = act3
            loss += loss_func(y, label).item()#*image.size(0)
            correct += (y.argmax(1) == label).type(torch.int).sum().item()
            end = time.time()
            sec = end - start
            result_list = str(datetime.timedelta(seconds=sec)).split(".")
            print(f'image {batch_index} time  : {result_list[0]}')
    correct_rate = 100 * correct / len(test_loader.dataset)
    print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({correct_rate:>.1f}%) Loss: {loss/len(test_loader.dataset):.2f}')
    
    end = time.time()
    sec = (end-start)
    result_list = str(datetime.timedelta(seconds=sec)).split(".")
    print(f'Total time is : {result_list[0]}')
    return qmodel, act0, act1, act2, act3

In [None]:
qmodel, act0, act1, act2, act3 = testQuant(model, test_loader, args)

Test:   0%|          | 0/157 [00:00<?, ?it/s]

findMaxMin : 0:00:00
findMaxMin : 0:00:00
SNGnumpy : 0:00:02
SNGpnumpy : 0:00:03
10000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
10000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
10000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
00000000000000000000000000000000000000000000000000000000000000000
10000000000000000000000000000000000000000000000000000000000000000
1000000000000000000000000000000000000000000000000000000000000