<a href="https://colab.research.google.com/github/kihoon71/quantization_code/blob/main/mixed_precision_decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np

class VectorWiseQuantization:
  def __init__(self, X, W):
    self.X = X
    self.W = W

    # scaling factor vectors
    self.C_x = self.get_abs_max(self.X, axis=1) # by row
    self.C_w = self.get_abs_max(self.W, axis=0) # by column

    #quantized_x, quantized_w
    self.q_x = self.absmax_quantization_x()
    self.q_w = self.absmax_quantization_w()

    #quantized_matrix
    self.quantized_matrix_multiplication = self.quantized_matrix_multiplication(self.q_x, self.q_w)

    #dequantized_matrix
    self.dequantized_matrix = self.dequantization()

  def get_abs_max(self, matrix, axis=0):
    return np.max(np.abs(matrix), axis=axis)

  def get_range_data_type(self, d_type='int8'):
    return float(np.iinfo(d_type).max)

  def get_absmax_scale(self, absmax, dtype='int8'):
    scale = self.get_range_data_type(dtype) / absmax
    return scale.astype('float16')

  def absmax_quantization_x(self):
    scale = self.get_absmax_scale(self.C_x)
    quantized_x = np.round(self.X * scale[:, np.newaxis] )
    return quantized_x.astype('int8')

  def absmax_quantization_w(self):
    scale = self.get_absmax_scale(self.C_w)
    quantized_w = np.round(self.W * scale[np.newaxis, :])
    return quantized_w.astype('int8')

  def quantized_matrix_multiplication(self, x, w):
    ## if we do not type-cast before the matmul overflow issue will come out.
    x_32 = x.astype('int32')
    w_32 = w.astype('int32')
    result = np.dot(x_32, w_32)
    return result

  def dequantization(self):
    outer_product = np.outer(self.C_x, self.C_w)
    matrix_ = self.quantized_matrix_multiplication * outer_product
    matrix_ = matrix_ / (self.get_range_data_type() ** 2)

    return matrix_

In [46]:
class MixedPrecisionDecomposition:
  def __init__(self, X, W):
    self.X = X
    self.W = W

    #outlier matrices
    self.outlier_indices = self.get_outlier_indices(self.X, axis=0)
    self.outlier_x = self.X[:, self.outlier_indices]
    self.outlier_w = self.W[self.outlier_indices, :]

    #reuslt of matmul as float16 with outlier matrices
    self.outlier_result = self.dot_outliers()

    #non-outlier matrices
    self.non_outlier_indices = np.setdiff1d(np.arange(self.X.shape[0]), self.outlier_indices)
    self.non_outlier_x = self.X[:, self.non_outlier_indices]
    self.non_outlier_w = self.W[self.non_outlier_indices, :]

    #quantization matmul matrix
    self.quantized_matrix_multiplication = VectorWiseQuantization(self.non_outlier_x, self.non_outlier_w)

    self.dequantized_matrix = self.quantized_matrix_multiplication.dequantized_matrix

    #dequantization matrix + f16 outlier matmul matrix
    self.mixed_precision_matrix = self.combine()

  def get_outlier_indices(self, matrix, axis=0, threshold=6):

    outlier_indices = np.any(np.abs(matrix) > threshold, axis=axis)
    return np.where(outlier_indices)[0]

  def dot_outliers(self):
    x_f16 = self.outlier_x.astype('float16')
    w_f16 = self.outlier_w.astype('float16')
    result = np.dot(x_f16, w_f16)
    return result

  def combine(self):
    return self.dequantized_matrix + self.outlier_result


In [47]:
import numpy as np

np.random.seed(0)

def get_random_matrix(rows, cols, mean=0, std=1):
  # 원하는 범위 설정
  lower_bound = -10
  upper_bound = 10

  random_matrix = np.random.normal(mean, std, size=(rows, cols))

  random_matrix = np.clip(random_matrix, lower_bound, upper_bound)

  return random_matrix


random_matrix_x = get_random_matrix(5,5, mean=0, std=3)
random_matrix_w = get_random_matrix(5,5, mean=0, std=1)


# 결과 출력
print("랜덤 정규분포 행렬:")
print(random_matrix_x)
print(random_matrix_w)


랜덤 정규분포 행렬:
[[ 5.29215704  1.20047163  2.93621395  6.7226796   5.60267397]
 [-2.93183364  2.85026525 -0.45407162 -0.30965656  1.23179551]
 [ 0.43213071  4.36282052  2.28311318  0.36502505  1.3315897 ]
 [ 1.00102298  4.48223722 -0.61547479  0.9392031  -2.56228722]
 [-7.65896945  1.96085579  2.5933086  -2.22649506  6.80926387]]
[[-1.45436567  0.04575852 -0.18718385  1.53277921  1.46935877]
 [ 0.15494743  0.37816252 -0.88778575 -1.98079647 -0.34791215]
 [ 0.15634897  1.23029068  1.20237985 -0.38732682 -0.30230275]
 [-1.04855297 -1.42001794 -1.70627019  1.9507754  -0.50965218]
 [-0.4380743  -1.25279536  0.77749036 -1.61389785 -0.21274028]]


In [48]:
decompostion = MixedPrecisionDecomposition(random_matrix_x, random_matrix_w)
print(decompostion.outlier_x,'\n', decompostion.outlier_w)

[[ 5.29215704  6.7226796   5.60267397]
 [-2.93183364 -0.30965656  1.23179551]
 [ 0.43213071  0.36502505  1.3315897 ]
 [ 1.00102298  0.9392031  -2.56228722]
 [-7.65896945 -2.22649506  6.80926387]] 
 [[-1.45436567  0.04575852 -0.18718385  1.53277921  1.46935877]
 [-1.04855297 -1.42001794 -1.70627019  1.9507754  -0.50965218]
 [-0.4380743  -1.25279536  0.77749036 -1.61389785 -0.21274028]]


In [35]:
print(decompostion.non_outlier_x)
print(decompostion.non_outlier_w)

[[ 1.20047163  2.93621395]
 [ 2.85026525 -0.45407162]
 [ 4.36282052  2.28311318]
 [ 4.48223722 -0.61547479]
 [ 1.96085579  2.5933086 ]]
[[ 0.15494743  0.37816252 -0.88778575 -1.98079647 -0.34791215]
 [ 0.15634897  1.23029068  1.20237985 -0.38732682 -0.30230275]]


In [49]:
print(decompostion.dequantized_matrix)
print(decompostion.outlier_result)

[[ 0.64556128  4.06660588  2.46051996 -3.52626182 -1.30307252]
 [ 0.37194819  0.52461764 -3.07629584 -5.47077534 -0.85638168]
 [ 1.0312403   4.4377278  -1.15654716 -9.52592347 -2.201107  ]
 [ 0.60146815  0.95525987 -4.71037827 -8.64445336 -1.37862471]
 [ 0.70953812  3.9311347   1.37357548 -4.89413226 -1.46348169]]
[[-17.2    -16.33    -8.1     12.19     3.16  ]
 [  4.047   -1.237    2.033   -7.086   -4.414 ]
 [ -1.595   -2.168    0.3318  -0.7754   0.1656]
 [ -1.318    1.923   -3.781    7.504    1.538 ]
 [ 10.49    -5.72    10.52   -27.08   -11.57  ]]


In [50]:
print(decompostion.mixed_precision_matrix)

[[-16.55756372 -12.26151912  -5.64104254   8.66123818   1.85708373]
 [  4.41882319  -0.71268704  -1.04309271 -12.55671284  -5.27044418]
 [ -0.56348626   2.26975905  -0.82476005 -10.3013141   -2.03545758]
 [ -0.71689122   2.87811143  -8.49162827  -1.14054711   0.15946122]
 [ 11.20172562  -1.7876153   11.89701298 -31.97225726 -13.03379419]]


In [51]:
print(np.dot(random_matrix_x, random_matrix_w))

[[-16.55512064 -12.25679827  -5.64060626   8.66883843   1.85265038]
 [  4.41967919  -0.71840279  -1.04153266 -12.5558374   -5.26652566]
 [ -0.5615892    2.29197021  -0.79650171 -10.30077733  -2.04243318]
 [ -0.71990321   2.85994283  -8.501364    -1.13821841   0.16393162]
 [ 11.19986743  -1.78736878  11.90409671 -31.96087144 -13.03381044]]
