In [78]:
from torch.quantization import quantize_dynamic
from torch import nn
import torch
import copy
import numpy as np
import math
import genesys_quantized_linear

In [56]:
class PrintLayer(nn.Module):
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        # Do your print / debug stuff here
        print(x)
        return x

In [57]:
# From FBGEMM
# https://github.com/pytorch/FBGEMM/blob/8f1b8777745d412c10d254284a72d76357ac287a/include/fbgemm/QuantUtils.h#L45
def clamp(src, precision=8, signed=False):
    min_num = -(1 << (precision - 1)) if signed else 0
    max_num = ((1 << (precision - 1)) - 1) if signed else (1 << precision) - 1;
    
    return np.minimum(np.maximum(src, min_num), max_num)

In [58]:
# From FBGEMM
# https://github.com/pytorch/FBGEMM/blob/8f1b8777745d412c10d254284a72d76357ac287a/include/fbgemm/QuantUtils.h#L62
def quantize(src, zero_point, scale, precision=8, signed=False):
    inv_scale = 1.0/scale
    transformed_val = src * inv_scale;
    transformed_val = zero_point + np.round(transformed_val)
    result = clamp(transformed_val, precision, signed)
    return result

In [59]:
# From FBGEMM
# https://github.com/pytorch/FBGEMM/blob/8f1b8777745d412c10d254284a72d76357ac287a/include/fbgemm/QuantUtils.h#L147
def dequantize(src, zero_point, scale):
    result = scale * (src - zero_point)
    return result

In [60]:
# Origial Model:
model = nn.Sequential(
    nn.Linear(32,1)
)

In [61]:
from torch.quantization.observer import MovingAverageMinMaxObserver

# Pytorch Quantization
model.eval()
m_pytorch = copy.deepcopy(model)
m_pytorch.eval()
#backend = "x86"

"""Insert stubs"""
m_pytorch_q = nn.Sequential(torch.quantization.QuantStub(), 
                  *m_pytorch, 
                  torch.quantization.DeQuantStub())

"""Prepare"""
# qconfig = torch.quantization.get_default_qconfig(backend)
# print(qconfig)

# qconfig = torch.quantization.QConfig(
#   activation=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8),
#   weight=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
# )

qconfig = torch.quantization.QConfig(
  activation=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
  weight=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.qint8)
)

m_pytorch_q.qconfig = qconfig
#torch.per_tensor_symmetric
torch.quantization.prepare(m_pytorch_q, inplace=True)

with torch.inference_mode():
  for _ in range(1000):
    x = torch.rand(10, 32)
    m_pytorch_q(x)
    
"""Convert"""
torch.quantization.convert(m_pytorch_q, inplace=True)

Sequential(
  (0): Quantize(scale=tensor([0.0039]), zero_point=tensor([0]), dtype=torch.quint8)
  (1): QuantizedLinear(in_features=32, out_features=1, scale=0.0029710179660469294, zero_point=255, qscheme=torch.per_tensor_affine)
  (2): DeQuantize()
)

In [62]:
print(m_pytorch_q[1].weight())
weight_pytorch = torch.int_repr(m_pytorch_q[1].weight()).numpy()
print(weight_pytorch)

tensor([[-0.0714, -0.0013, -0.1288, -0.1416,  0.1288,  0.0089, -0.0115, -0.0026,
          0.0319, -0.1377, -0.0319, -0.1046,  0.1492,  0.0306, -0.0880, -0.0969,
          0.0000, -0.1301,  0.0013,  0.0727,  0.0421,  0.0000,  0.1416, -0.1250,
         -0.0281, -0.1760,  0.1454, -0.1569, -0.1046,  0.0587, -0.1467,  0.0791]],
       size=(1, 32), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.0012753951596096158,
       zero_point=10)
[[ -46    9  -91 -101  111   17    1    8   35  -98  -15  -72  127   34
   -59  -66   10  -92   11   67   43   10  121  -88  -12 -128  124 -113
   -72   56 -105   72]]


In [63]:
# Custome Quantization
weight = model[0].weight.detach()
weight_scale = m_pytorch_q[1].weight().q_scale()
weight_zero_point = m_pytorch_q[1].weight().q_zero_point()
dequantize_weight_pytorch = dequantize(weight_pytorch, scale=weight_scale, zero_point=weight_zero_point)
# print("Dequantzed pytorch weight:")
# print(dequantize_weight_pytorch)
print("Weight Scale: ", weight_scale)
print("Weight Zero Ploint: ", weight_zero_point)
print("Original weight:")
print(weight)
weight_q = quantize(src=weight, scale=weight_scale, zero_point=weight_zero_point, signed=True).numpy()
print("Quantzed weight:")
print(weight_q)
print("Dequantzed weight:")
weight_dequant = dequantize(weight_q, scale=weight_scale, zero_point=weight_zero_point)
print(weight_dequant)

Weight Scale:  0.0012753951596096158
Weight Zero Ploint:  10
Original weight:
tensor([[-0.0719, -0.0007, -0.1291, -0.1409,  0.1288,  0.0089, -0.0111, -0.0029,
          0.0321, -0.1373, -0.0313, -0.1050,  0.1499,  0.0312, -0.0876, -0.0971,
         -0.0003, -0.1303,  0.0011,  0.0723,  0.0425,  0.0003,  0.1416, -0.1251,
         -0.0284, -0.1754,  0.1456, -0.1567, -0.1041,  0.0588, -0.1464,  0.0795]])
Quantzed weight:
[[ -46.    9.  -91. -101.  111.   17.    1.    8.   35.  -98.  -15.  -72.
   127.   34.  -59.  -66.   10.  -92.   11.   67.   43.   10.  121.  -88.
   -12. -128.  124. -113.  -72.   56. -105.   72.]]
Dequantzed weight:
[[-0.07142213 -0.0012754  -0.1288149  -0.14156887  0.1288149   0.00892777
  -0.01147856 -0.00255079  0.03188488 -0.13774268 -0.03188488 -0.10458241
   0.14922123  0.03060948 -0.08800226 -0.09693003  0.         -0.13009031
   0.0012754   0.07269753  0.04208804  0.          0.14156887 -0.12498873
  -0.02805869 -0.17600453  0.14539506 -0.1568736  -0.10458241  0

In [64]:
# Check Weight Quantization
diff = weight_pytorch - weight_q
print(diff)
print("Quantization Sucess!") if (abs(np.max(diff)) <= 1) else print("Quantization failed!")

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0.]]
Quantization Sucess!


In [65]:
# From https://github.com/google/gemmlowp/blob/master/doc/quantization_example.cc#L210
def quantizeMultiplierSmallerThanOne(real_multiplier):
    right_shift = 0
    while real_multiplier < 0.5:
        real_multiplier *= 2.0
        right_shift+=1
        
    quantized_multiplier = np.round(real_multiplier * (1 << 31))
    quantized_multiplier = np.int64(quantized_multiplier)
    assert quantized_multiplier <= (1<<31)
    quantized_multiplier = np.int32(quantized_multiplier)
    
    if quantized_multiplier == (1 << 31):
        quantized_multiplier /= 2
        right_shift-=1
    
    assert right_shift >= 0
    assert quantized_multiplier <= np.iinfo(quantized_multiplier.dtype).max
    return quantized_multiplier, right_shift

In [66]:
# Inference Test
input_float = torch.rand(1,32)

In [67]:
#######################
# Custome output
#######################
# 1. Quantize Input:
input_scale = m_pytorch_q[0].scale.numpy()[0]
input_zero_point = m_pytorch_q[0].zero_point.numpy()[0]
input_q = quantize(src=input_float, scale=input_scale, zero_point=input_zero_point, signed=False).numpy()
print("Input Scale: ", input_scale)
print("Input Zero Point: ", input_zero_point)
print("Input Float:")
print(input_float)
print("Input Quant:")
print(input_q)
print("Input Dequant:")
input_dequant = dequantize(src=input_q, scale=input_scale, zero_point=input_zero_point)
print(input_dequant)

Input Scale:  0.003910441
Input Zero Point:  0
Input Float:
tensor([[0.4966, 0.6472, 0.6732, 0.6343, 0.6171, 0.3961, 0.9619, 0.0564, 0.7982,
         0.0521, 0.4522, 0.6521, 0.6271, 0.1004, 0.0307, 0.6566, 0.2897, 0.7411,
         0.3003, 0.3013, 0.9382, 0.9840, 0.8076, 0.0075, 0.3455, 0.1321, 0.9654,
         0.8554, 0.3008, 0.2634, 0.9033, 0.2811]])
Input Quant:
[[127. 166. 172. 162. 158. 101. 246.  14. 204.  13. 116. 167. 160.  26.
    8. 168.  74. 190.  77.  77. 240. 252. 207.   2.  88.  34. 247. 219.
   77.  67. 231.  72.]]
Input Dequant:
[[0.496626   0.6491332  0.67259586 0.63349146 0.61784965 0.39495453
  0.9619685  0.05474617 0.79772997 0.05083573 0.45361114 0.6530436
  0.62567055 0.10167146 0.03128353 0.65695405 0.28937262 0.74298376
  0.30110395 0.30110395 0.9385058  0.98543113 0.8094613  0.00782088
  0.3441188  0.13295498 0.9658789  0.85638654 0.30110395 0.26199955
  0.90331185 0.28155175]]


In [68]:
#######################
# pytorch output
output_qt = model(input_float).detach().numpy()
output_q_pt = m_pytorch_q(input_float)
print("Reference Full Precision Output: ", output_qt)
print("Reference Quantization Output: ", output_q_pt)
#######################

output_scale = m_pytorch_q[1].scale
output_zero_point = m_pytorch_q[1].zero_point
print("Output Scale: ", output_scale)
print("Output Zero Point: ", output_zero_point)
output_quantized = quantize(src=output_qt, scale=output_scale, zero_point=output_zero_point, signed=False)
print("Quantized Float Output Output: ", output_quantized)

Reference Full Precision Output:  [[-0.36835927]]
Reference Quantization Output:  tensor([[-0.3714]])
Output Scale:  0.0029710179660469294
Output Zero Point:  255
Quantized Float Output Output:  [[131.]]


In [69]:
# 2. Gemm
print("Original Gemm:")
gemm = input_float @ weight.reshape(32,1)
print(gemm)                                
gemm_quantized = quantize(src=gemm, scale=output_scale, zero_point=output_zero_point, precision=32, signed=False)
print("Original Gemm Quantized:")
print(gemm_quantized)
# Gemm original equation 1:
q1=input_q[0]-input_zero_point
q2=weight_q[0]-weight_zero_point
S=input_scale * weight_scale
print(q1 @ q2.transpose() * S)

# Gemm original equation 2:
sum_gemm = 0 
for idx, x in enumerate(input_q[0]):
    sum_gemm += (x-input_zero_point) * (weight_q[0][idx]-weight_zero_point)
print(sum_gemm*S)

# Gemm dequant
gemm_dequant = input_dequant @ weight_dequant.transpose()
print(gemm_dequant)

# Integer Gemm 
gemm_out_q = (input_q @ weight_q.transpose()).astype(np.int32)

Original Gemm:
tensor([[-0.2468]])
Original Gemm Quantized:
tensor([[172.]], dtype=torch.float64)
-0.24954741794820284
-0.24954741794820284
[[-0.24954742]]


In [70]:
act_times_w_scale = input_scale * weight_scale
M = (act_times_w_scale / output_scale)

In [71]:
# # 3. Quantize Bias
bias_float = model[0].bias.detach()
bias_q = quantize(src=bias_float, scale=act_times_w_scale, zero_point=0, precision = 32, signed=False).numpy()
print("Bias float: ")
print(bias_float)
print("Bias quantized: ")
print(bias_q)

Bias float: 
tensor([-0.1215])
Bias quantized: 
[0.]


In [72]:
from fxpmath import Fxp

# 3. M decompose
_, right_shift = quantizeMultiplierSmallerThanOne(M)
M0 = M*(2**right_shift)
M0_fxp = Fxp(M0, signed=True, n_word=64, n_frac=32)

N = input_float.shape[1]
#print(N)
NZ1Z2 = N * input_zero_point * weight_zero_point
a1 = np.sum(input_q, axis=1)
a2 = np.sum(weight_q, axis=1)
Z1a2 = int(input_zero_point * a2[0])
Z2a1 = int((weight_zero_point * a1)[0])
# print(NZ1Z2)
# print(Z1a2)
# print(Z2a1)

In [73]:
print("Before ", gemm_out_q)
for r in range(gemm_out_q.shape[0]):
    gemm_out_q[r,:] = gemm_out_q[r,:] - Z1a2
for c in range(gemm_out_q.shape[1]):
    gemm_out_q[:,c] = gemm_out_q[:,c] - Z2a1
print("After ", gemm_out_q)
gemm_out_q = gemm_out_q + bias_q

Before  [[-8416]]
After  [[-50036]]


In [74]:
# M0_fxp.info()
# print(gemm_out_q*M0)
# out_test = M0_fxp * integer
# out_test.info()

In [75]:
# M0 here
fxp_part = (M0_fxp*(NZ1Z2+gemm_out_q)) >> right_shift
gemm_final_out_q = output_zero_point + fxp_part.astype(np.int32)
print(fxp_part)

#gemm_final_out_q = gemm_final_out_q + bias_q.numpy()

gemm_final_out = dequantize(src=gemm_final_out_q, scale=output_scale, zero_point=output_zero_point)
print("Quantized Gemm")
print(gemm_final_out)

# # Bias Add:
output_q = gemm_final_out_q

# # Final Dequantize
output = dequantize(src=output_q, scale=output_scale, zero_point=output_zero_point)
print("Final Output:")
print(output)

# Original Fucntion
# q3 = np.round(gemm_out_q * M + output_zero_point + bias_q).numpy().astype(np.int32)
# output_test = dequantize(src=q3, scale=output_scale, zero_point=output_zero_point)
# print(q3)
# print(output_test)

[[-83.99391077]]
Quantized Gemm
[[-0.24956551]]
Final Output:
[[-0.24956551]]
