In [2]:
import torch

In [4]:
def linear_q_with_scale_and_zero_point(
    r_tensor, scale, zero_point, dtype=torch.int8):
    """
    Linearly quantize a real-valued tensor with a given scale and zero point.

    Computes ``clamp(round(r_tensor / scale + zero_point), q_min, q_max)``,
    where ``q_min`` / ``q_max`` are the limits of ``dtype``, and casts the
    result to ``dtype``.

    Args:
        r_tensor: floating-point tensor to quantize.
        scale: quantization step size; dividing by it maps real values
            onto the integer grid.
        zero_point: offset added after scaling (0 for symmetric schemes).
        dtype: integer torch dtype of the result (default ``torch.int8``).

    Returns:
        The quantized tensor, with dtype ``dtype``.
    """
    # Map onto the integer grid, then snap to the nearest level.
    q_float = torch.round(r_tensor / scale + zero_point)

    # Saturate to the representable range before casting so that
    # out-of-range values do not overflow the integer dtype.
    limits = torch.iinfo(dtype)
    return q_float.clamp(limits.min, limits.max).to(dtype)


def get_q_scale_symmetric(tensor, dtype=torch.int8):
    """
    Compute the per-tensor scale for symmetric linear quantization.

    Symmetric quantization maps ``[-r_max, r_max]`` onto ``[-q_max, q_max]``
    with a zero point of 0, so ``scale = r_max / q_max``, where
    ``r_max = max(|tensor|)`` and ``q_max`` is the largest value of ``dtype``.

    Args:
        tensor: non-empty floating-point tensor to be quantized.
        dtype: signed integer torch dtype of the quantized values
            (default ``torch.int8``).

    Returns:
        The scale as a Python float. For an all-zero ``tensor`` this is
        ``1.0``: any positive scale quantizes zeros to 0, whereas a scale of
        0 would make ``tensor / scale`` evaluate to NaN.
    """
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # Guard the all-zero tensor: a zero scale would later divide 0 by 0.
    if r_max == 0:
        return 1.0

    return r_max / q_max


def linear_q_symmetric(tensor, dtype=torch.int8):
    """
    Symmetrically quantize ``tensor`` to ``dtype`` with a per-tensor scale.

    Args:
        tensor: floating-point tensor to quantize.
        dtype: signed integer torch dtype of the result
            (default ``torch.int8``).

    Returns:
        ``(quantized_tensor, scale)``. ``quantized_tensor`` has dtype
        ``dtype`` and ``scale`` is the value used for quantization. The zero
        point is always 0, so ``tensor`` is approximately
        ``quantized_tensor * scale``.
    """
    # The scale must be derived from the same dtype we quantize to;
    # otherwise e.g. int16 would only use the int8 range [-127, 127].
    scale = get_q_scale_symmetric(tensor, dtype=dtype)

    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor,
        scale=scale,
        zero_point=0,  # symmetric quantization: zero point is 0
        dtype=dtype,
    )

    return quantized_tensor, scale

## Linear Quantization: Inference
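- W8A32 means the weights are stored in 8 bits (`int8`) while the activations stay in 32 bits (`float32`).
- At inference time the weight is dequantized, $W \approx s_w\,(q_w - z_w)$, and an ordinary `float32` linear layer is applied.
- For simplicity, the linear layer has no bias.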

In [5]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    """
    Bias-free linear layer with int8 weights and float32 activations (W8A32).

    The weight is dequantized as ``(q_w - z_w) * s_w`` and then applied with
    ``torch.nn.functional.linear``, i.e. ``output = input @ W.T``.

    Args:
        input: float32 activation tensor.
        q_w: int8 quantized weight.
        s_w: weight scale.
        z_w: weight zero point (0 for symmetric quantization).

    Returns:
        Output of the linear layer applied with the dequantized weight.

    Raises:
        AssertionError: if ``input`` is not float32 or ``q_w`` is not int8.
    """
    assert input.dtype == torch.float32, "activations must be float32"
    assert q_w.dtype == torch.int8, "weights must be int8"

    # Linear dequantization is r = s * (q - z): the zero point is removed
    # *before* scaling. (q * s + z is only correct when z == 0.)
    dequantized_weight = (q_w.to(torch.float32) - z_w) * s_w
    output = torch.nn.functional.linear(input, dequantized_weight)

    return output

In [6]:
# Activation vector; the W8A32 layer requires float32 activations.
input = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)

In [7]:
# 3x3 weight matrix to be quantized; F.linear uses rows as output features.
weight = torch.tensor([
    [-2.00, -1.13, 0.42],
    [-1.51,  0.25, 1.62],
    [ 0.23,  1.35, 2.15],
])

In [9]:
# Symmetric int8 quantization of the weight: returns the codes and the scale.
q_w, s_w = linear_q_symmetric(weight)

In [10]:
# The int8 codes; the largest-magnitude weight maps to +/-127.
q_w

tensor([[-118,  -67,   25],
        [ -89,   15,   96],
        [  14,   80,  127]], dtype=torch.int8)

In [11]:
# The per-tensor scale: max(|weight|) / 127.
s_w

0.016929134609192376

In [12]:
# Zero point is 0 because the weight was quantized symmetrically.
output = quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w=0)

In [13]:
# Result of the linear layer computed from the dequantized int8 weight.
print(f"This is the W8A32 output: {output}")

This is the W8A32 output: tensor([-2.9965,  3.8768,  9.3957])


In [15]:
# Full-precision reference: the same bias-free linear op on the original weight.
fp32_output = torch.nn.functional.linear(input, weight)

In [17]:
# Compare with the W8A32 output above; the gap is the quantization error.
print(f"This is the output if we don't quantize: {fp32_output}")


This is the output if we don't quantize: tensor([-3.0000,  3.8500,  9.3800])
