## 1. Asymmetric linear quantization

In [7]:
import torch

def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    """Asymmetric linear quantization of a real-valued tensor.

    Computes q = clamp(round(tensor / scale + zero_point), qmin, qmax), where
    [qmin, qmax] is the representable range of the integer ``dtype``.

    Args:
        tensor: real-valued tensor to quantize.
        scale: step size between consecutive quantized levels.
        zero_point: integer offset added after scaling.
        dtype: integer dtype of the result (default ``torch.int8``).

    Returns:
        A tensor of ``dtype`` with the same shape as ``tensor``.
    """
    bounds = torch.iinfo(dtype)
    # Multiply by the reciprocal (rather than dividing) to match the original arithmetic exactly.
    shifted = (1 / scale) * tensor + zero_point
    # Round to the nearest level, saturate to the dtype's range, then cast.
    return torch.round(shifted).clamp(bounds.min, bounds.max).to(dtype=dtype)

In [3]:
# Small 3x3 float tensor used throughout this section to exercise quantization
test_tensor = torch.tensor([
    [191.6, -13.5, 728.6],
    [92.14, 295.5, -184.0],
    [0.0, 684.6, 245.5],
])

Quantizing with arbitrary, hand-picked values for `scale` and `zero_point`:

In [5]:
# Arbitrary hand-picked (not random) parameters, used only to test the implementation
scale, zero_point = 3.5, -70

In [8]:
# Quantize the test tensor with the hand-picked parameters (int8 by default)
quantized_tensor = linear_q_with_scale_and_zero_point(
    tensor=test_tensor,
    scale=scale,
    zero_point=zero_point,
)

In [9]:
# Inspect the quantized values (int8 by default; values outside the dtype's range are clamped)
quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

In [15]:
# Dequantization: map integers back to reals with r = scale * (q - zero_point).
# Cast to float first so the subtraction is not carried out in the quantized integer dtype.
dequantized_tensor = (quantized_tensor.to(torch.float32) - zero_point) * scale
dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

In [16]:
# dequantization error
# Dequantization error: mean squared difference between the original and reconstructed values
mse = ((test_tensor - dequantized_tensor) ** 2).mean()
mse

tensor(170.8753)

Finding `scale` and `zero_point` from the tensor's own min/max range:

In [21]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    """Derive asymmetric-quantization parameters from a tensor's value range.

    Maps the real range [rmin, rmax] of ``tensor`` onto the integer range
    [qmin, qmax] of ``dtype``:

        scale      = (rmax - rmin) / (qmax - qmin)
        zero_point = qmin - rmin / scale   (clipped to [qmin, qmax], else rounded)

    Args:
        tensor: non-empty tensor whose values will be quantized.
        dtype: integer dtype of the quantized representation (default int8).

    Returns:
        (scale, zero_point): ``scale`` is a Python float and ``zero_point`` is
        a Python int in [qmin, qmax].

    Degenerate input: if every element has the same value (rmax == rmin),
    the formula above would divide by zero. In that case the range is widened
    to include 0 so that ``scale`` is non-zero. An all-zero tensor falls back to
    ``scale = 1.0``.
    """
    rmin = tensor.min().item()
    rmax = tensor.max().item()
    qmin = torch.iinfo(dtype).min
    qmax = torch.iinfo(dtype).max

    if rmax == rmin:
        # Constant tensor: (rmax - rmin) is 0, so rmin / scale below would
        # raise ZeroDivisionError. Stretch the range to include zero.
        rmin = min(rmin, 0.0)
        rmax = max(rmax, 0.0)

    scale = (rmax - rmin) / (qmax - qmin)
    if scale == 0:
        # All-zero tensor (or a range that underflows): any positive scale works.
        scale = 1.0

    zero_point = qmin - rmin / scale

    # Clip zero_point into the representable range; otherwise round it to the nearest Python int
    if zero_point > qmax:
        zero_point = qmax
    elif zero_point < qmin:
        zero_point = qmin
    else:
        zero_point = int(round(zero_point))

    return scale, zero_point

In [22]:
# Derive scale and zero point from the test tensor's own min/max range
new_scale, new_zero_point = get_q_scale_and_zero_point(tensor=test_tensor)
print(f"new scale: {new_scale}")
print(f"new zero point: {new_zero_point}")

new scale: 3.578823433670343
new zero point: -77


## 2. Symmetric linear quantization