In [1]:
import torch

# Small hand-written float tensor (mixed signs, one exact zero) used to exercise
# the quantization functions below.
test_tensor = torch.tensor([
    [191.6, -13.5, 728.6],
    [92.14, 295.5, -184],
    [0, 684.6, 245.5],
])

  # Device torch allocates new tensors on when no `device=` is given.
device: torch.device = torch.empty(0).device


#### Finding `Scale` and `Zero Point` for Quantization

In [3]:
# Integer range of the int8 target type.
q_min, q_max = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max

In [4]:
q_max  # for int8: the largest representable integer

127

In [5]:
q_min  # for int8: the smallest representable integer

-128

In [6]:
# Real-valued range of the tensor, as plain Python numbers.
r_min, r_max = test_tensor.min().item(), test_tensor.max().item()

In [7]:
r_min  # smallest element of test_tensor

-184.0

In [8]:
r_max  # largest element of test_tensor

728.5999755859375

In [12]:
# Scale: how much real value one integer step represents.
scale = (r_max - r_min) / (q_max - q_min)
# Zero point: solve r_min / scale + zero_point = q_min, then round to the
# nearest integer.
zero_point = int(round(q_min - r_min / scale))

In [13]:
scale  # real-valued width of one quantization step

3.578823433670343

In [14]:
zero_point  # integer offset chosen so that r_min maps to q_min

-77

In [15]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    """Compute asymmetric (min/max) linear-quantization parameters for ``tensor``.

    The real range ``[tensor.min(), tensor.max()]`` is mapped linearly onto the
    integer range of ``dtype``.

    Args:
        tensor: Tensor whose value range should be covered.
        dtype: Integer dtype to quantize to (default ``torch.int8``).

    Returns:
        ``(scale, zero_point)``. ``scale`` is a positive Python float and
        ``zero_point`` is a Python int clipped to ``[q_min, q_max]``.
    """
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    if r_min == r_max:
        # A constant tensor would give scale == 0 and a ZeroDivisionError below.
        # Stretch the range to include 0.0 so the constant lands exactly on
        # q_min or q_max.
        r_min, r_max = min(r_min, 0.0), max(r_max, 0.0)

    scale = (r_max - r_min) / (q_max - q_min)
    if scale == 0:
        # All-zero tensor: any positive scale represents it exactly.
        scale = 1.0

    zero_point = q_min - (r_min / scale)

    # Clip the zero point into [q_min, q_max] (an all-positive tensor, for
    # example, pushes it below q_min), then round to the nearest integer.
    zero_point = int(round(min(max(zero_point, q_min), q_max)))

    return scale, zero_point

In [16]:
# Recompute scale and zero point for test_tensor via the helper (int8 by default).
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)

In [17]:
new_scale  # matches `scale` from the manual calculation

3.578823433670343

In [18]:
new_zero_point  # matches `zero_point` from the manual calculation

-77

#### Quantization and Dequantization with Calculated `Scale` and `Zero Point`

- Use the calculated `scale` and `zero_point` with the functions `linear_quantization` (which takes an explicit scale and zero point) and `linear_dequantization`.

In [19]:
def linear_quantization(tensor, scale, zero_point, dtype=torch.int8):
    """Affine-quantize ``tensor`` with a precomputed ``scale`` and ``zero_point``.

    Computes ``round(tensor / scale + zero_point)``, saturates the result to the
    representable range of ``dtype`` and casts it to ``dtype``.
    """
    bounds = torch.iinfo(dtype)
    shifted = torch.round(tensor / scale + zero_point)
    return shifted.clamp(bounds.min, bounds.max).to(dtype)

# Quantize test_tensor with the scale/zero point computed above (dtype defaults to int8).
quantized_tensor = linear_quantization(test_tensor, new_scale, new_zero_point)

In [20]:
def linear_dequantization(quantized_tensor, scale, zero_point):
    """Map quantized integers back to approximate reals: ``scale * (q - zero_point)``."""
    as_float = quantized_tensor.float()
    return (as_float - zero_point) * scale

# Map the quantized integers back to floats so the reconstruction error can be measured.
dequantized_tensor = linear_dequantization(quantized_tensor,new_scale, new_zero_point)

In [21]:
(dequantized_tensor-test_tensor).square().mean()  # mean squared reconstruction error (MSE)

tensor(1.5730)

#### Put Everything Together: Your Own Linear Quantizer

- Now, put everything together to make your own linear quantizer: compute the scale and zero point from the tensor itself, then quantize the tensor with them.

In [26]:
def _linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    """Affine-quantize ``tensor`` with an explicit ``scale`` and ``zero_point``.

    Computes ``round(tensor / scale + zero_point)``, saturates it to the range
    of ``dtype`` and casts the result to ``dtype``.
    """
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    return torch.round(tensor / scale + zero_point).clamp(q_min, q_max).to(dtype)


def linear_quantization(tensor, dtype=torch.int8):
    """Quantize ``tensor`` end to end, deriving scale and zero point from its range.

    Args:
        tensor: Floating-point tensor to quantize.
        dtype: Integer dtype to quantize to (default ``torch.int8``).

    Returns:
        ``(quantized_tensor, scale, zero_point)``: the quantized tensor in
        ``dtype`` plus the parameters from ``get_q_scale_and_zero_point``.
    """
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype=dtype)

    # This definition shadows the earlier 4-argument ``linear_quantization``,
    # so calling that name here would re-enter this very function with too many
    # arguments (TypeError). Use the private helper instead.
    quantized_tensor = _linear_q_with_scale_and_zero_point(
        tensor, scale, zero_point, dtype=dtype
    )

    return quantized_tensor, scale, zero_point

In [27]:
# Seeded local generator so this random example (and its quantization error)
# is reproducible on Restart & Run All, without touching torch's global RNG.
r_tensor = torch.randn((4, 4), generator=torch.Generator().manual_seed(0))

In [28]:
r_tensor  # inspect the random 4x4 example tensor

tensor([[-0.6067, -0.8505, -0.6025, -0.2585],
        [-0.1563,  0.7334,  1.7707,  0.7963],
        [ 0.2709, -1.5073, -0.3815, -0.8398],
        [ 0.1023, -1.0382,  1.1344, -0.5180]])