In [1]:
import argparse
import tqdm
from tqdm import tqdm_notebook as tq
import os, time, math, copy, random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from collections import namedtuple

torch.set_printoptions(precision=8, linewidth=50000)
import warnings
warnings.filterwarnings(action='ignore')

In [2]:
BLACK	= '\033[30m'
RED		= '\033[31m'
GREEN	= '\033[32m'
YELLOW	= '\033[33m'
BLUE	= '\033[34m'
MAGENTA	= '\033[35m'
CYAN	= '\033[36m'
RESET	= '\033[0m'
SEL		= '\033[7m'

In [3]:
parser = argparse.ArgumentParser(description='fixed_mac')
parser.add_argument('--device', type=str, default='cpu', help='Device')
parser.add_argument('--full_bits', type=int, default=8, help='Number of Quantization Bits') # full bits
parser.add_argument('--frac_bits', type=int, default=3, help='Number of Quantization Bits') # fraction bits
parser.add_argument('--bBW', type=int, default=4, help='Number of bitwidth')
args = parser.parse_args(args=[])

In [4]:
class	fxp:
	def	__init__(self, bIn, iBWF):
		self.iFullBW	= len(bIn)
		self.iIntgBW	= self.iFullBW - iBWF
		self.bSign		= bIn[0]
		self.bIntg		= bIn[:self.iIntgBW]
		self.bFrac		= bIn[self.iIntgBW:]
		self.fFull		= 0
		try:
			for idx, bit in enumerate(bIn):
				if	idx == 0:
					self.fFull = self.fFull + int(bit,2) * -pow(2, self.iIntgBW - 1)
				else:
					self.fFull = self.fFull + int(bit,2) * pow(2, self.iIntgBW - 1 - idx)
		except:
			print(bIn)
		self.dispFull	= RED + self.bIntg + BLUE + self.bFrac + RESET
		return

In [5]:
class	flp2fix:
	def	__init__(self, fIn, iBW, iBWF):
		self.fMin		= - 2 ** (iBW - iBWF - 1)
		self.fMax		= (2 ** (iBW-1) - 1) * (2 ** -iBWF)
		self.fResol		= 2 ** -iBWF
		if fIn < self.fMin or fIn > self.fMax:
			print(f'({fIn}): Out of input range ({self.fMax}/{self.fMin}) during flp -> fix converting ')
		self.iBW		= iBW
		self.iBWI		= iBW - iBWF # integer
		self.iBWF		= iBWF # fraction

		self.iFLP2INT	= abs(int(fIn * 2 ** iBWF))
		if fIn < 0:
			self.iFLP2INT = 2 ** (iBW-1) - self.iFLP2INT

		if fIn >= 0:
			self.bFull = bin(self.iFLP2INT)[2:].rjust(iBW, '0')
		else:
			self.bFull = '1'+bin(self.iFLP2INT)[2:].rjust(iBW-1, '0')
			if len(self.bFull) > iBW:
				self.bFull = '0' * iBW

		self.cssFxp		= fxp(self.bFull, self.iBWF)
		self.bSign		= self.cssFxp.bSign
		self.bIntg		= self.cssFxp.bIntg
		self.bFrac		= self.cssFxp.bFrac
		self.fFull		= self.cssFxp.fFull
		return

In [6]:
iIN = random.uniform(-(2**(args.bBW)),0)
iA = flp2fix(5, args.full_bits, args.frac_bits).bFull
ibA = flp2fix(iIN, args.full_bits, args.frac_bits).bFull
print(iIN)
print(iA, ibA)
iB = fxp('010110010', args.frac_bits).fFull
print(iB)

-10.57596997113258
00101000 10101100
22.25


## Integer -> Tow's complement Binary

In [7]:
def int2bin(iIn, iBW): # iBW : bit length = 16
    iBW		= iBW + 1
    if iIn >= 0:
	    bOut	= bin(iIn).replace('0b','').rjust(iBW, '0')
    else:
	    bOut	= bin(iIn & (pow(2,iBW)-1)).replace('0b','').rjust(iBW, '1')
    return bOut[1:] # magnitude : go into the comparator and compare with LFSR's output

In [8]:
int2bin(-5,4)

'1011'

## XOR

In [9]:
def XOR(iA, iB):
	if iA != iB:
		oOUT = '1'
	else:
		oOUT = '0'
	return oOUT

## LFSR

In [10]:
def LFSR(seed, flag, taps):
    sr, xor = seed, 0
    for t in taps:
        xor += int(sr[t-1])
    if xor%2 == 0.0:
        xor = 0
    else:
        xor = 1
    sr, xor = str(xor) + sr[:-1], 0
    if sr == flag[0]:
        return len(seed)*'0'
    else:
        return sr

## Comparator

In [11]:
def COMP(iA, l_lfsr):
    for comp in range(0, len(l_lfsr)): #lfsrlist : 16개의 elements -> len = 16
        oA = '0'
        if abs(iA) > int(l_lfsr[comp],2):
            oA = '1'
        elif abs(iA) < int(l_lfsr[comp],2):
            oA = '0'
            
    return oA

## SNG_Conventional

In [12]:
def SNG(bBW, iA): # bBW = 8

    cA = int2bin(iA, bBW)
    s_cA = cA[0]    
    oAlist = []
    l_lfsr = []
    ran = random.randint(1, pow(2,bBW-1)-1)
    ranZP = bBW-1-len(bin(ran).replace('0b',''))
    num = ranZP*'0'+bin(ran).replace('0b','')
    
    for lfsr in range (0, (2**(bBW-1))):
        if not l_lfsr:
            l_lfsr.append(num)
        else:
            l_lfsr.append(LFSR(l_lfsr[lfsr-1], l_lfsr, (7,6)))
            
        oAlist.append(COMP(iA, l_lfsr))
        
    Onum = oAlist.count('1')
    oAlist.insert(0, s_cA)
    sA = ''.join(oAlist)
#    print(l_lfsr)
        
    return sA

In [13]:
len(SNG(8,25))

129

## Inter-Blocks: Output Revision(OUR) Scheme

In [14]:
def pos(SN):
    return SN[0] == '0'
def neg(SN):
    return SN[0] == '1'
def counter(x):
    return x.count('1')

In [15]:
def gen_acclist(bBW, IN):
    A_list=[]
    acc_A=0
    acc_Alist=[]
    for i in range(0,2**(bBW-1)):
        A = 0
        for k in range(len(IN)):
            if IN[k][i] == '1':
                A += 1
        A_list.append(A)
        acc_A += A_list[i]
        acc_Alist.append(acc_A)     
    
    return acc_Alist

In [16]:
def gen_So(bBW, diff):
    So_list=[]
    Ao_list=[]
    global So
    for o in range(2**(bBW-1)):
        if o == 0:
            Ao = 0 # first A_o is fixed
            if diff[o] > 0:
                So = 1
            else:
                So = 0
            Ao_list.append(Ao)
            So_list.append(So)
        else:
            Ao += So_list[o-1]
            Ao_list.append(Ao)
            if diff[o] > Ao_list[o]:
                So = 1
            elif diff[o] <= Ao_list[o]:
                So = 0
            So_list.append(So)
    
    return So_list

In [17]:
def block(IN):
    BLOCK1=[]
    BLOCK2=[]
    BLOCK3=[]
    BLOCK4=[]
    BLOCK5=[]
    BLOCK6=[]
    BLOCK7=[]
    BLOCK8=[]
    for i in range(len(IN)):
        block1 = IN[i][1:17]
        block2 = IN[i][17:33]
        block3 = IN[i][33:49]
        block4 = IN[i][49:65]
        block5 = IN[i][65:81]
        block6 = IN[i][81:97]
        block7 = IN[i][97:113]
        block8 = IN[i][113:]
        BLOCK1.append(block1)
        BLOCK2.append(block2)
        BLOCK3.append(block3)
        BLOCK4.append(block4)
        BLOCK5.append(block5)
        BLOCK6.append(block6)
        BLOCK7.append(block7)
        BLOCK8.append(block8)
    return BLOCK1, BLOCK2, BLOCK3, BLOCK4, BLOCK5, BLOCK6, BLOCK7, BLOCK8

In [69]:
def OUR(cnt, bBW): # bBW = 8
    global diff1, diff2, diff3, diff4, diff5, diff6, diff7, diff8, sign
    gen_SN = []
#    output = []
    for i in range(0, 2**cnt):
        iA = random.randint(-64,63)
        SN = SNG(bBW,iA) 
        gen_SN.append(SN)
  
    # sorting random input bit-stream(positive/negative)
    pos_IN = list(filter(pos, gen_SN))
    neg_IN = list(filter(neg, gen_SN))

    # Block devision
    pBLOCK1, pBLOCK2, pBLOCK3, pBLOCK4, pBLOCK5, pBLOCK6, pBLOCK7, pBLOCK8 = block(pos_IN)
    nBLOCK1, nBLOCK2, nBLOCK3, nBLOCK4, nBLOCK5, nBLOCK6, nBLOCK7, nBLOCK8 = block(neg_IN)

    # generate list of number of accumulated 1s
    Ap1_list = gen_acclist(bBW-3, pBLOCK1)
    An1_list = gen_acclist(bBW-3, nBLOCK1)
    Ap2_list = gen_acclist(bBW-3, pBLOCK2)
    An2_list = gen_acclist(bBW-3, nBLOCK2)
    Ap3_list = gen_acclist(bBW-3, pBLOCK3)
    An3_list = gen_acclist(bBW-3, nBLOCK3)
    Ap4_list = gen_acclist(bBW-3, pBLOCK4)
    An4_list = gen_acclist(bBW-3, nBLOCK4)
    Ap5_list = gen_acclist(bBW-3, pBLOCK5)
    An5_list = gen_acclist(bBW-3, nBLOCK5)
    Ap6_list = gen_acclist(bBW-3, pBLOCK6)
    An6_list = gen_acclist(bBW-3, nBLOCK6)
    Ap7_list = gen_acclist(bBW-3, pBLOCK7)
    An7_list = gen_acclist(bBW-3, nBLOCK7)
    Ap8_list = gen_acclist(bBW-3, pBLOCK8)
    An8_list = gen_acclist(bBW-3, nBLOCK8)

    # determine sign of output
    if Ap1_list[-1]+Ap2_list[-1]+Ap3_list[-1]+Ap4_list[-1] > An1_list[-1]+An2_list[-1]+An3_list[-1]+An4_list[-1]:
        diff1 = [x-y for x,y in zip(Ap1_list, An1_list)]
        diff2 = [x-y for x,y in zip(Ap2_list, An2_list)]
        diff3 = [x-y for x,y in zip(Ap3_list, An3_list)]
        diff4 = [x-y for x,y in zip(Ap4_list, An4_list)]
        diff5 = [x-y for x,y in zip(Ap5_list, An5_list)]
        diff6 = [x-y for x,y in zip(Ap6_list, An6_list)]
        diff7 = [x-y for x,y in zip(Ap7_list, An7_list)]
        diff8 = [x-y for x,y in zip(Ap8_list, An8_list)]
        sign  = ['0']        
    elif Ap1_list[-1]+Ap2_list[-1]+Ap3_list[-1]+Ap4_list[-1] < An1_list[-1]+An2_list[-1]+An3_list[-1]+An4_list[-1]:
        diff1 = [x-y for x,y in zip(An1_list, Ap1_list)]
        diff2 = [x-y for x,y in zip(An2_list, Ap2_list)]
        diff3 = [x-y for x,y in zip(An3_list, Ap3_list)]
        diff4 = [x-y for x,y in zip(An4_list, Ap4_list)]
        diff5 = [x-y for x,y in zip(An5_list, Ap5_list)]
        diff6 = [x-y for x,y in zip(An6_list, Ap6_list)]
        diff7 = [x-y for x,y in zip(An7_list, Ap7_list)]
        diff8 = [x-y for x,y in zip(An8_list, Ap8_list)]
        sign  = ['1']
    
    So1 = gen_So(bBW-3, diff1)
    So2 = gen_So(bBW-3, diff2)
    So3 = gen_So(bBW-3, diff3)
    So4 = gen_So(bBW-3, diff4)
    So5 = gen_So(bBW-3, diff5)
    So6 = gen_So(bBW-3, diff6)
    So7 = gen_So(bBW-3, diff7)
    So8 = gen_So(bBW-3, diff8)
    So_list = So1 + So2 + So3 + So4 + So5 + So6 + So7 + So8

    result = list(map(str, So_list))
    tempout = ''.join(result)

    p = abs((Ap1_list[-1]+Ap2_list[-1]+Ap3_list[-1]+Ap4_list[-1]+Ap5_list[-1]+Ap6_list[-1]+Ap7_list[-1]+Ap8_list[-1])-(An1_list[-1]+An2_list[-1]+An3_list[-1]+An4_list[-1]+An5_list[-1]+An6_list[-1]+An7_list[-1]+An8_list[-1]))
    q = counter(tempout)

    p_list=[]
    q_list=[counter(tempout)]
    out=[]
    for k in range(2**(bBW-1)):
        p_list.append(p)
        if q < p:
            q += 1
            out.append('1')
        elif q > p:
            q -= 1
            out.append('0')
        else:
            q = p
            out.append(tempout[k])
        q_list.append(q)
    sout = sign + out
    output = ''.join(sout)
    
    print("===========================================")
    print(" Index : %s"%(i+1))
    print("-------------------------------------------")
    print("real output: {0}/{1}".format((Ap1_list[-1]+Ap2_list[-1]+Ap3_list[-1]+Ap4_list[-1]+Ap5_list[-1]+Ap6_list[-1]+Ap7_list[-1]+Ap8_list[-1])-(An1_list[-1]+An2_list[-1]+An3_list[-1]+An4_list[-1]+An5_list[-1]+An6_list[-1]+An7_list[-1]+An8_list[-1]),128))
    print("cal output : {0}/{1}".format(counter(output[1:]),128))
    print('positive input = {0}'.format(pos_IN), len(pos_IN[0]))
    print('negative input = {0}'.format(neg_IN), len(neg_IN[0]))
    print('pBLOCK1 = {0}'.format(pBLOCK1))
    print('nBLOCK1 = {0}'.format(nBLOCK1))
    print(Ap1_list, An1_list)
    print('pBLOCK2 = {0}'.format(pBLOCK2))
    print('nBLOCK2 = {0}'.format(nBLOCK2))
    print(Ap2_list, An2_list)
    print('pBLOCK3 = {0}'.format(pBLOCK3))
    print('nBLOCK3 = {0}'.format(nBLOCK3))
    print(Ap3_list, An3_list)
    print('pBLOCK4 = {0}'.format(pBLOCK4))
    print('nBLOCK4 = {0}'.format(nBLOCK4))
    print(Ap4_list, An4_list)
    print('pBLOCK5 = {0}'.format(pBLOCK5))
    print('nBLOCK5 = {0}'.format(nBLOCK5))
    print(Ap5_list, An5_list)
    print('pBLOCK6 = {0}'.format(pBLOCK6))
    print('nBLOCK6 = {0}'.format(nBLOCK6))
    print(Ap6_list, An6_list)
    print('pBLOCK7 = {0}'.format(pBLOCK7))
    print('nBLOCK7 = {0}'.format(nBLOCK7))
    print(Ap7_list, An7_list)
    print('pBLOCK8 = {0}'.format(pBLOCK8))
    print('nBLOCK8 = {0}'.format(nBLOCK8))
    print(Ap8_list, An8_list)
    print(So1, So2, So3, So4)
    print(tempout)
    print(p_list)
    print(q_list)
    print(output)
    print("===========================================")

    return output, counter(output[1:]), p

In [76]:
OUR(3,8)

 Index : 8
-------------------------------------------
real output: -61/128
cal output : 61/128
positive input = ['010100000000111110111110011110101110000110111010011000101011000000011110001110110110010010100100001001110010110100010001100110101', '001100100101001000000001100101101000000001001101010100000000111110111110011110101110000010111010011000001011000000011110000110111', '000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000001'] 129
negative input = ['100000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000001', '100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000010000000000000000000000000001', '100011111011111001111010111000011011101001100010101100000101111000111011011001001010010000100111001011010001000110011010101000001', '11101100100001000000000001100101101000000001000101010100000000111

('100011111111110001110010011000000001111101010001100111110101001011111001111101000001100000000011100011110001001011000011101000000',
 61,
 61)

In [20]:
n = 0
k = 0
mse = 0
for i in range(1000):
    result, p = OUR(4,8)
    if p > 128:
        #print("intra-Block module result is: overflow!")
        k += 1
        #continue
    
    else:
        print("===========================================")
        print(" Index : %s"%(i+1))
        print("-------------------------------------------")
        print("intra-Block module result is: {0}/128".format(counter(result[1:])))
        print("real accumulation result is : {0}/128".format(p))
        print("===========================================")
        Error = abs(counter(result[1:])/128-p/128)*100
        if Error <= 5:
            n += 1
           
        #mse += pow((counter(result[1:])-p)/128,2)
        
#MSE = mse/(100-k)
#print(mse)
print("128-bit's MSE is: {0}".format(MSE))
print("Accuracy is: {0}/{1}, {2}%".format(n, 1000-k, n/(1000-k)*100))

 Index : 16
-------------------------------------------
real output: -47/128
cal output : 94/128
positive input = ['000000000001100000100000000000000000000000000100000000000000000000000000000000000011110011110000110000010000000001000000000000001', '000000000000011000011000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001', '000000000000001100000100000000001000100000000000000111110111110001110100110000010011010001000000001000000001110000110110110000001', '011101001100010101100000001111000111011011001001010010000000111001011010001000110011010101000000001111101111100111101011100001101', '001000101000000000000111110111110001110101110000010111010001000000001000000001110000110110110000000000000000000110000010100000001', '001100100101001000000011100101101000000001001101010100000000111110111110011110101110000010111010011000001011000000011110001110111', '00000000001111000111000011000001000000000100000000000000000000000011000001000000000000

ValueError: too many values to unpack (expected 2)