<a href="https://colab.research.google.com/github/kihoon71/quantization_code/blob/main/quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Zero point Quantization

In [None]:
import numpy as np

def get_scale(r, d_type='int8'):
  """Return the factor that stretches the value range of ``r`` onto ``d_type``.

  Args:
    r: Array-like of real values; only its minimum and maximum are used.
    d_type: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``(d_max - d_min) / (r_max - r_min)``, where ``d_min``/``d_max`` are the
    limits of ``d_type``.

  Note:
    When every element of ``r`` is equal the denominator is zero, so the
    division does not produce a usable finite scale.
  """
  lowest_real = np.min(r)
  highest_real = np.max(r)
  limits = np.iinfo(d_type)

  integer_span = limits.max - limits.min
  real_span = highest_real - lowest_real
  return integer_span / real_span

def get_zeropoint(r, d_type='int8'):
  """Return the zero point (integer offset) for quantizing ``r`` to ``d_type``.

  Args:
    r: Array-like of real values; only its minimum is used here (the
      maximum enters through ``get_scale``).
    d_type: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``d_min - round(r_min * scale)`` as a float, where ``scale`` comes from
    ``get_scale(r, d_type)`` and ``d_min`` is the smallest value of
    ``d_type``. With this offset ``r_min`` maps to ``d_min``.
  """
  # Only the lower ends of both ranges are needed; the upper ends are
  # already accounted for inside get_scale.
  r_min = np.min(r)
  d_min = np.iinfo(d_type).min

  scale = get_scale(r, d_type)

  zero_point = -1 * np.round(r_min * scale) + d_min

  return zero_point


def zero_point_quantiztion(r, value,d_type='int8'):
  """Quantize ``value`` with the zero-point (affine) scheme calibrated on ``r``.

  The function name keeps its original spelling so existing callers still
  resolve it.

  Args:
    r: Array-like of real values that defines the calibration range.
    value: Scalar or array of real values to quantize.
    d_type: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``round(scale * value + zero_point)`` saturated to the limits of
    ``d_type`` and cast to ``d_type``.
  """
  # Affine parameters derived from the calibration data.
  scale = get_scale(r, d_type)
  zero_point = get_zeropoint(r, d_type)

  q_value = np.round(scale * value + zero_point)

  # Saturate before casting: rounding can land one step outside the integer
  # range (e.g. r = [-1, 1] and value = 1.0 give 127.5 -> 128 for int8), and
  # an unclipped cast would wrap that around to the opposite end.
  limits = np.iinfo(d_type)
  q_value = np.clip(q_value, limits.min, limits.max)
  return q_value.astype(d_type)

def dequantize_value(r, value, d_type='int8'):
  """Map a zero-point-quantized ``value`` back to the real domain.

  Args:
    r: Array-like of real values that was used as the calibration range.
    value: Quantized scalar or array.
    d_type: Integer dtype the value was quantized to (default ``'int8'``).

  Returns:
    ``(value - zero_point) / scale`` using the scale and zero point
    recomputed from ``r``.
  """
  factor = get_scale(r, d_type)
  offset = get_zeropoint(r, d_type)
  return (value - offset) / factor

# Demo: calibrate on evenly spaced samples over [-1, 1] and quantize 0.5.
r = np.linspace(-1.0, 1.0)
quantized_value = zero_point_quantiztion(r, 0.5)

# Calibration data and its extremes.
print(r)
print(r.min(), r.max())

# Affine parameters derived from that range.
print('scale :', get_scale(r))
print('zero point :', get_zeropoint(r))

# Round trip: the integer code for 0.5, then its reconstruction.
print(quantized_value)
print(dequantize_value(r, quantized_value))

[-1.         -0.95918367 -0.91836735 -0.87755102 -0.83673469 -0.79591837
 -0.75510204 -0.71428571 -0.67346939 -0.63265306 -0.59183673 -0.55102041
 -0.51020408 -0.46938776 -0.42857143 -0.3877551  -0.34693878 -0.30612245
 -0.26530612 -0.2244898  -0.18367347 -0.14285714 -0.10204082 -0.06122449
 -0.02040816  0.02040816  0.06122449  0.10204082  0.14285714  0.18367347
  0.2244898   0.26530612  0.30612245  0.34693878  0.3877551   0.42857143
  0.46938776  0.51020408  0.55102041  0.59183673  0.63265306  0.67346939
  0.71428571  0.75510204  0.79591837  0.83673469  0.87755102  0.91836735
  0.95918367  1.        ]
-1.0 1.0
scale : 127.5
zero point : 0.0
64
0.5019607843137255


# Absmax quantization

In [None]:
def get_absmax(r):
  """Return the largest absolute value found in ``r``."""
  magnitudes = np.abs(r)
  return np.max(magnitudes)

def get_range_data_type(d_type='int8'):
  """Return the largest value representable by the integer dtype ``d_type``."""
  type_info = np.iinfo(d_type)
  return type_info.max

def get_absmax_scale(r, dtype='int8'):
  """Return the absmax scale: the dtype's maximum divided by ``max(|r|)``.

  Args:
    r: Array-like of real values used for calibration.
    dtype: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``get_range_data_type(dtype) / get_absmax(r)``.

  Note:
    An all-zero ``r`` makes the divisor zero.
  """
  peak = get_absmax(r)
  return get_range_data_type(dtype) / peak

def absmax_quantization(r, value, dtype='int8'):
  """Quantize ``value`` with the absmax scale calibrated on ``r``.

  Args:
    r: Array-like of real values used for calibration.
    value: Scalar or array of real values to quantize.
    dtype: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``round(value * scale)`` saturated to the limits of ``dtype`` and cast
    to ``dtype``, where ``scale`` is ``get_absmax_scale(r, dtype)``.
  """
  scale = get_absmax_scale(r, dtype)
  quantized_value = np.round(value * scale)

  # A value whose magnitude exceeds max(|r|) scales past the integer range;
  # saturate it so the cast below cannot wrap around.
  limits = np.iinfo(dtype)
  quantized_value = np.clip(quantized_value, limits.min, limits.max)
  return quantized_value.astype(dtype)

def dequantize_value(r, value, dtype='int8'):
  """Map an absmax-quantized ``value`` back to the real domain.

  NOTE: this rebinds the name ``dequantize_value`` that the zero-point
  section of this notebook defines earlier; once this cell has run, the
  name refers to the absmax version below.

  Args:
    r: Array-like of real values that was used for calibration.
    value: Quantized scalar or array.
    dtype: Integer dtype the value was quantized to (default ``'int8'``).

  Returns:
    ``value / scale`` where ``scale`` is ``get_absmax_scale(r, dtype)``.
  """
  scale = get_absmax_scale(r, dtype)
  r_value = value / scale
  return r_value


def matrix_absmax_quantization(matrix_q, dtype='int8'):
  """Absmax-quantize a whole matrix using a single scale derived from it.

  Args:
    matrix_q: Array of real values to quantize.
    dtype: Integer dtype accepted by ``np.iinfo`` (default ``'int8'``).

  Returns:
    ``round(matrix_q * scale)`` cast to ``dtype``, where ``scale`` is the
    maximum of ``dtype`` divided by ``max(|matrix_q|)``.
  """
  abs_max = get_absmax(matrix_q)
  scale = get_range_data_type(dtype) / abs_max
  quantized_matrix = np.round(matrix_q * scale)
  # Cast to the requested dtype; the scale above is already computed for it,
  # so a fixed 'int8' cast would be wrong for any other dtype.
  return quantized_matrix.astype(dtype)

def matrix_dequantize_value(matrix_q, quantized_matrix, dtype='int8'):
  """Map an absmax-quantized matrix back to the real domain.

  Args:
    matrix_q: The original real-valued matrix; used only to recompute the
      absmax scale.
    quantized_matrix: The quantized matrix to convert back.
    dtype: Integer dtype the matrix was quantized to (default ``'int8'``).

  Returns:
    ``quantized_matrix / scale`` where ``scale`` is
    ``get_absmax_scale(matrix_q, dtype)``.
  """
  scale = get_absmax_scale(matrix_q, dtype)
  r_value_matrix = quantized_matrix / scale
  return r_value_matrix


# Demo: absmax-quantize the fourth element of a random 8-element vector.
vector_a = np.random.randn(8)
value = vector_a[3]
quantized = absmax_quantization(vector_a, value)

print('vector a :', vector_a)
print('4번째 원소 :', value)
print('vector a의 절대 최대 값 :', get_absmax(vector_a))
print('data의 범위값 :', get_range_data_type())
# `quantized` already holds absmax_quantization(vector_a, value).
print('양자화 값 :', quantized)
print('역 양자화 값', dequantize_value(vector_a, quantized))

vector a : [ 0.67229476  0.40746184 -0.76991607  0.53924919 -0.67433266  0.03183056
 -0.63584608  0.67643329]
4번째 원소 : 0.5392491912918173
vector a의 절대 최대 값 : 0.7699160744453164
data의 범위값 : 127
양자화 값 : 89
역 양자화 값 0.539547485241206


# Matrix quantization operation

$X_{f16}W_{f16} = C_{f16} \approx
\frac{1}{c_{x_{f16}}c_{w_{f16}}}C_{i32} = S_{f16} \cdot C_{i32}$

$\approx S_{f16} \cdot A_{i8}B_{i8} = S_{f16} \cdot Q(A_{f16})\,Q(B_{f16})$

아래는 quantization을 실행하는 코드에 대한 것이다.

In [None]:
# Fix the RNG seed so the matrices drawn below are reproducible.
np.random.seed(0)

# Operands of X_f16 * W_f16: two 8x8 standard-normal draws cast to float16.
input_matrix_x = np.random.randn(8, 8).astype(np.float16)
weight_matrix_w = np.random.randn(8, 8).astype(np.float16)

print('inut_x')
print(input_matrix_x, input_matrix_x.dtype)
print('weight_w')
print(weight_matrix_w, weight_matrix_w.dtype)

# Reference product computed directly on the float16 operands.
result_matrix_C = input_matrix_x @ weight_matrix_w
print('result_matrix_C')
print(result_matrix_C, result_matrix_C.dtype)


inut_x
[[ 1.764    0.4001   0.9785   2.24     1.867   -0.977    0.95    -0.1514 ]
 [-0.1032   0.4106   0.144    1.454    0.761    0.1217   0.4438   0.3337 ]
 [ 1.494   -0.2052   0.313   -0.854   -2.553    0.654    0.8643  -0.742  ]
 [ 2.27    -1.454    0.04575 -0.1871   1.533    1.47     0.1549   0.3782 ]
 [-0.8877  -1.98    -0.348    0.1564   1.23     1.202   -0.3872  -0.3022 ]
 [-1.049   -1.42    -1.706    1.951   -0.51    -0.438   -1.253    0.7773 ]
 [-1.614   -0.2128  -0.8955   0.387   -0.5107  -1.181   -0.02818  0.4282 ]
 [ 0.0665   0.3025  -0.6343  -0.3628  -0.6724  -0.3596  -0.813   -1.727  ]] float16
weight_w
[[ 0.1774  -0.4019  -1.63     0.463   -0.907    0.05194  0.729    0.129  ]
 [ 1.14    -1.234    0.4023  -0.6846  -0.8706  -0.5786  -0.3115   0.05615]
 [-1.165    0.901    0.4656  -1.536    1.488    1.8955   1.179   -0.1799 ]
 [-1.07     1.055   -0.403    1.223    0.2083   0.9766   0.3564   0.7065 ]
 [ 0.0105   1.786    0.127    0.402    1.883   -1.348   -1.2705   0.969  ]


In [None]:

# Quantize X and W, then multiply their product by S_f16.
quantized_X = matrix_absmax_quantization(input_matrix_x)
quantized_W = matrix_absmax_quantization(weight_matrix_w)

print('==============quantized==============')
print(quantized_X, quantized_X.dtype)
print(quantized_W, quantized_W.dtype)

# Round-trip each matrix to see what quantization alone loses.
print('==============dequantized==============')
dequantized_x = matrix_dequantize_value(input_matrix_x, quantized_X)
dequantized_w = matrix_dequantize_value(weight_matrix_w, quantized_W)
print(dequantized_x, dequantized_x.dtype)
print(dequantized_w, dequantized_w.dtype)

print('==============origin-dequantized==============')
print(input_matrix_x - dequantized_x)
print(weight_matrix_w - dequantized_w)

scale_x = get_absmax_scale(input_matrix_x)
scale_w = get_absmax_scale(weight_matrix_w)

print('scale_x :', scale_x)
print('scale_w :', scale_w)

# S_f16 is the reciprocal of the product of both quantization scales.
S_f16 = 1 / (scale_x * scale_w)
print('S_f16 :', S_f16)

# NOTE(review): both operands are the quantized matrices as returned above
# (printed as int8), so the product is presumably accumulated in that same
# narrow type and can wrap; a later cell repeats this with int32 operands
# for comparison.
quantized_scaled_C = S_f16 * (quantized_X @ quantized_W)
print(quantized_scaled_C, quantized_scaled_C.dtype)


[[  88   20   49  111   93  -49   47   -8]
 [  -5   20    7   72   38    6   22   17]
 [  74  -10   16  -42 -127   33   43  -37]
 [ 113  -72    2   -9   76   73    8   19]
 [ -44  -98  -17    8   61   60  -19  -15]
 [ -52  -71  -85   97  -25  -22  -62   39]
 [ -80  -11  -45   19  -25  -59   -1   21]
 [   3   15  -32  -18  -33  -18  -40  -86]] int8
[[  12  -26 -107   30  -59    3   48    8]
 [  74  -81   26  -45  -57  -38  -20    4]
 [ -76   59   30 -100   97  124   77  -12]
 [ -70   69  -26   80   14   64   23   46]
 [   1  117    8   26  123  -88  -83   63]
 [ -77  127  -27  -49  126   97  122   59]
 [ -56  125  -18   52   62  -10   40   60]
 [  25  -72   20   87  -45  -10  -28  121]] int8
[[ 1.76882382  0.40200541  0.98491326  2.23113004  1.86932517 -0.98491326
   0.94471272 -0.16080217]
 [-0.10050135  0.40200541  0.14070189  1.44721949  0.76381029  0.12060162
   0.44220595  0.3417046 ]
 [ 1.48742003 -0.20100271  0.32160433 -0.84421137 -2.55273438  0.66330893
   0.86431164 -0.7437100

In [None]:
# Element-wise gap between the float16 reference product and the rescaled
# product of the quantized matrices.
errors = np.subtract(result_matrix_C, quantized_scaled_C)
print(errors)

[[-2.48220653e+00  5.43417846e+00 -2.80954737e+00  3.82062008e+00
   2.59688754e+00 -2.23383422e-01 -4.62686524e-01  3.21402825e+00]
 [-1.66386260e+00  3.31445267e+00 -1.55923230e-01  2.27202229e+00
   2.12198558e+00  4.81597226e-01 -1.42634661e-01  2.91101719e+00]
 [-1.22715471e+00 -1.76809355e+00 -3.06294614e+00 -2.49804688e+00
  -3.10966713e+00  4.37706035e+00  6.50655176e+00 -2.91986054e+00]
 [-2.79064671e+00  6.21481286e+00 -4.54333766e+00  1.88016744e+00
   4.80865554e+00  8.75480783e-01  2.82810955e+00  3.72431725e+00]
 [-3.38119739e+00  6.75756671e+00  6.81290315e-02  5.57943958e-01
   6.52885073e+00  8.28436324e-01  2.29799312e-01  1.34300378e+00]
 [ 3.48612669e-03 -2.29230771e+00  2.68643765e-01  5.62217049e+00
  -3.48680471e+00 -4.57742377e-01 -2.88556568e+00  8.33519489e-01]
 [ 1.65304363e+00 -3.23966511e+00  2.51002731e+00  2.48751209e+00
  -3.14959102e+00 -2.40869516e+00 -3.75296853e+00 -5.68171171e-01]
 [ 1.98582020e+00 -2.93711519e+00 -3.93927346e-01 -2.61598764e+00
  -

위와 같은 경우에 에러가 매우 큼을 알 수 있었다. int8 행렬끼리 곱하면 결과도 int8로 누적되어 오버플로가 발생하기 때문이다. 다만 중간 결과값을 int32로 저장한 뒤 Scaling factor를 곱해주면 원래 결과에 가까운 근사치가 나오게 됨을 알 수 있었다.

In [None]:
# Widen both quantized matrices to int32 before multiplying, then apply the
# same S_f16 rescaling as before.
intermediate = quantized_X.astype(np.int32) @ quantized_W.astype(np.int32)
intermediate = np.multiply(intermediate, S_f16)
print(intermediate)

[[-2.43722583  5.45976724 -2.84537811  3.77917639  2.61869519 -0.19777085
  -0.443831    3.1132761 ]
 [-1.65567726  3.26798642 -0.16024668  2.25360344  2.11334915  0.47643397
  -0.14240731  2.88905378]
 [-1.26997797 -1.75225587 -3.08405646 -2.51965598 -3.06467922  4.37218136
   6.56396218 -2.92903856]
 [-2.77186763  6.20871284 -4.55088255  1.90019954  4.82646992  0.8827408
   2.82876905  3.70689622]
 [-3.35810595  6.79525873  0.10826647  0.54840657  6.52213196  0.85751964
   0.19684812  1.30904002]
 [-0.02829692 -2.30558364  0.28789038  5.68768046 -3.45037851 -0.45951734
  -2.94380217  0.89442866]
 [ 1.66551967 -3.20739411  2.54087867  2.50089389 -3.15818208 -2.4326122
  -3.81362481 -0.56747623]
 [ 1.93187978 -2.88444016 -0.36939781 -2.57225133 -2.86813867 -1.00361785
  -0.51764905 -5.01562866]]


In [None]:
# Element-wise gap between the float16 reference product and the rescaled
# int32 product.
errors_with_intermediate = np.subtract(result_matrix_C, intermediate)
print(errors_with_intermediate)

[[-4.12897998e-02  1.17025825e-03  2.50656057e-02  4.11361100e-02
  -1.50768879e-03  1.28343253e-02  9.74897150e-03  6.44582767e-02]
 [-1.03383639e-02  7.40420488e-03  1.55526903e-03 -1.14159353e-02
  -3.97414754e-03  9.16172997e-03  1.48438376e-02 -2.33503298e-03]
 [ 3.26732823e-02 -3.58300659e-02  7.88458002e-03  2.16091018e-02
  -3.88364071e-02 -3.23376111e-02 -2.88059328e-02 -6.50831699e-03]
 [-3.47729904e-02 -5.58783544e-03  2.35387980e-02 -9.57453915e-03
   5.56132957e-03  9.34904133e-03 -6.50342026e-03  2.35725302e-02]
 [ 4.59032652e-03 -1.40087275e-02 -1.06102178e-02  6.76921305e-03
  -6.50695658e-03 -3.76954179e-02 -6.41843569e-03  4.43654330e-03]
 [ 3.48612669e-03 -8.86948216e-03 -4.63132317e-02 -4.70554594e-02
  -2.22777407e-02  1.46931188e-02  2.77865422e-02 -3.26122535e-02]
 [-4.80604790e-04 -1.13558928e-02 -9.62866509e-03 -3.21438901e-02
  -2.10485956e-05  3.22215723e-02  2.65154366e-02 -1.69964261e-02]
 [ 1.73389660e-02 -2.37629699e-02 -2.31100079e-04 -1.75924153e-02
  -