## Symmetric Quantization 

In [None]:
import torch
# Function to calculate scale in symmetric mode
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    # Get the maximum absolute value in the tensor
    r_max = tensor.abs().max().item()
    print(r_max)
   
    # Get the maximum value for the dtype (int8)
    q_max = torch.iinfo(dtype).max
    print(q_max)
    
    # Calculate and return the scale
    return r_max / q_max

# Test the implementation on a 4x4 matrix
test_tensor = torch.randn((4, 4))
print(test_tensor)
scale = get_q_scale_symmetric(test_tensor)
print(f'Symmetric Scale: {scale}')

tensor([[ 0.3647,  1.1854, -0.4698, -0.3893],
        [-0.7885,  0.2649,  0.4186, -1.1576],
        [-1.4789,  0.0940,  0.1670,  0.3681],
        [ 0.6532,  1.5599, -2.4730, -1.7660]])
2.473046064376831
127
Symmetric Scale: 0.01947280365651048


In [6]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    # Get the scale using the symmetric method
    scale = get_q_scale_symmetric(tensor)
    
    # Quantize the tensor
    quantized_tensor = tensor / scale
    quantized_tensor = torch.round(quantized_tensor).clamp(-128, 127).to(dtype)
    
    return quantized_tensor, scale

# Quantize the test tensor
quantized_tensor, scale = linear_q_symmetric(test_tensor)
print(f'Quantized Tensor (Symmetric):\n{quantized_tensor}')

2.473046064376831
127
Quantized Tensor (Symmetric):
tensor([[  19,   61,  -24,  -20],
        [ -40,   14,   21,  -59],
        [ -76,    5,    9,   19],
        [  34,   80, -127,  -91]], dtype=torch.int8)


## Asymmetric Quantization 

In [10]:
def get_q_scale_and_zero_point_asymmetric(tensor, dtype=torch.int8):
    # Get the min and max values in the tensor
    r_min = tensor.min().item()
    r_max = tensor.max().item()
    
    # Get the min and max values for the dtype (int8)
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max
    
    # Calculate scale
    scale = (r_max - r_min) / (q_max - q_min)
    
    # Calculate zero point
    zero_point = q_min - (r_min / scale)
    zero_point = int(round(zero_point))
    
    return scale, zero_point

print(test_tensor)
# Calculate scale and zero point for asymmetric mode
scale, zero_point = get_q_scale_and_zero_point_asymmetric(test_tensor)
print(f'Asymmetric Scale: {scale}, Zero Point: {zero_point}')

tensor([[ 0.3647,  1.1854, -0.4698, -0.3893],
        [-0.7885,  0.2649,  0.4186, -1.1576],
        [-1.4789,  0.0940,  0.1670,  0.3681],
        [ 0.6532,  1.5599, -2.4730, -1.7660]])
Asymmetric Scale: 0.01581564557318594, Zero Point: 28


In [9]:
def linear_q_asymmetric(tensor, dtype=torch.int8):
    # Get scale and zero point using asymmetric method
    scale, zero_point = get_q_scale_and_zero_point_asymmetric(tensor)
    
    # Quantize the tensor
    quantized_tensor = (tensor / scale) + zero_point
    quantized_tensor = torch.round(quantized_tensor).clamp(-128, 127).to(dtype)
    
    return quantized_tensor, scale, zero_point

# Quantize the test tensor in asymmetric mode
quantized_tensor, scale, zero_point = linear_q_asymmetric(test_tensor)
print(f'Quantized Tensor (Asymmetric):\n{quantized_tensor}')

Quantized Tensor (Asymmetric):
tensor([[  51,  103,   -2,    3],
        [ -22,   45,   54,  -45],
        [ -66,   34,   39,   51],
        [  69,  127, -128,  -84]], dtype=torch.int8)
