In [None]:
# For Python 3.9.18
#accelerate==0.26.1
#seaborn==0.13.1
#torch==2.1.1
#transformers==4.35.0

# Linear Quantization I: Asymmetric mode
Quantize and De-quantize a Tensor

In [None]:
import torch

## Quantization with Random Scale and Zero Point
- Implement Linear Quantization for when the "scale" and the "zero point" are known/randomly selected.

In [None]:
def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype = torch.int8):
    """
    Linearly quantize `tensor` with a known scale and zero point.

    Computes round(tensor / scale + zero_point), clamps the result to the
    representable range of `dtype` and casts it to `dtype`.

    Args:
        tensor: tensor to quantize.
        scale: quantization scale (scalar, or a tensor that broadcasts with `tensor`).
        zero_point: quantized value that represents the real value 0.
        dtype: integer torch dtype of the result (default torch.int8).

    Returns:
        The quantized tensor, of dtype `dtype`.
    """
    limits = torch.iinfo(dtype)

    # Map the real values onto the integer grid, then snap to the nearest integer.
    shifted = tensor / scale + zero_point
    snapped = torch.round(shifted)

    # Saturate instead of wrapping around when a value falls outside the range.
    return snapped.clamp(limits.min, limits.max).to(dtype)

In [None]:
### a dummy tensor to test the implementation
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [None]:
### these are random values for "scale" and "zero_point" to test the implementation
scale = 3.5
zero_point = -70

In [None]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point)
quantized_tensor

## Dequantization with Random Scale and Zero Point
- Now, Dequantize the tensor to see how precise the quantization is.

In [None]:
dequantized_tensor = scale * (quantized_tensor.float() - zero_point)
dequantized_tensor

In [None]:
### without casting to float
# NOTE(review): quantized_tensor is int8 here (the default dtype of
# linear_q_with_scale_and_zero_point), so the subtraction is presumably carried
# out in int8 and can wrap around before the multiplication by the float scale.
# Compare the result with the .float() version above.
scale * (quantized_tensor - zero_point)

In [None]:
def linear_dequantization(quantized_tensor, scale, zero_point):
    """
    Map a quantized tensor back to real values: scale * (q - zero_point).

    The quantized tensor is cast to float first so that the subtraction is
    not performed in the narrow integer dtype.
    """
    centered = quantized_tensor.float() - zero_point
    return scale * centered

In [None]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Draw `tensor` as an annotated seaborn heatmap on the axes `ax`.

    Cell values are printed with two decimals, the colour bar is hidden and
    both axes lose their tick labels. `vmin`, `vmax` and `cmap` are passed
    straight to `sns.heatmap`.
    """
    values = tensor.cpu().numpy()
    sns.heatmap(
        values,
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        annot=True,
        fmt=".2f",
        cbar=False,
    )
    ax.set_title(title)
    # Tick labels carry no information for these small demo matrices.
    ax.set_xticklabels([])
    ax.set_yticklabels([])


def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype = torch.int8, n_bits = 8):
    """
    Show four heatmaps side by side: the original tensor, its quantized
    version, the de-quantized tensor and the absolute error between the
    original and the de-quantized tensor.

    `dtype` fixes the colour range of the quantized panel and `n_bits` is
    only used in that panel's title.
    """
    fig, axes = plt.subplots(1, 4, figsize=(15, 4))
    ax_original, ax_quantized, ax_dequantized, ax_error = axes

    # Panel 1: the original values on a plain white background.
    plot_matrix(original_tensor, ax_original, 'Original Tensor',
                cmap=ListedColormap(['white']))

    # Panel 2: quantized values, coloured over the full range of `dtype`.
    limits = torch.iinfo(dtype)
    plot_matrix(quantized_tensor, ax_quantized,
                f'{n_bits}-bit Linear Quantized Tensor',
                vmin=limits.min, vmax=limits.max, cmap='coolwarm')

    # Panel 3: the values recovered by de-quantization.
    plot_matrix(dequantized_tensor, ax_dequantized, 'Dequantized Tensor',
                cmap='coolwarm')

    # Panel 4: element-wise absolute quantization error.
    error = (original_tensor - dequantized_tensor).abs()
    plot_matrix(error, ax_error, 'Quantization Error Tensor',
                cmap=ListedColormap(['white']))

    fig.tight_layout()
    plt.show()

In [None]:
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
(dequantized_tensor - test_tensor).square().mean()

# Linear Quantization I: Get the Scale and Zero Point

In [None]:
### a dummy tensor to test the implementation
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

## Finding Scale and Zero Point for Quantization

In [None]:
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max
print(q_min, "\n", q_max)

In [None]:
r_min = test_tensor.min().item()
r_max = test_tensor.max().item()
print(r_min, "\n", r_max)

In [None]:
scale = (r_max - r_min) / (q_max - q_min)
scale

In [None]:
# Zero point of the asymmetric mapping: z = q_min - r_min / scale (still a float here).
zero_point = q_min - (r_min / scale)
print(zero_point)
# Round to the nearest integer so that it can be used as a quantized value.
zero_point = int(round(zero_point))
print(zero_point)

In [None]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    """
    Compute the scale and zero point for asymmetric linear quantization.

    The real range [r_min, r_max] of `tensor` is mapped onto the integer
    range [q_min, q_max] of `dtype`:

        scale      = (r_max - r_min) / (q_max - q_min)
        zero_point = q_min - r_min / scale   (rounded, clipped to [q_min, q_max])

    Args:
        tensor: tensor whose min and max define the real range.
        dtype: integer torch dtype to quantize to (default torch.int8).

    Returns:
        (scale, zero_point): a Python float and a Python int.
    """
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    if r_max == r_min:
        # Constant tensor: the range is empty and the scale would be 0, which
        # makes r_min / scale divide by zero. Widen the range so it also
        # contains 0; the constant value is then still exactly representable.
        r_min, r_max = min(r_min, 0.0), max(r_max, 0.0)

    if r_max == r_min:
        # All-zero tensor: any non-zero scale reproduces it exactly.
        scale = 1.0
    else:
        scale = (r_max - r_min) / (q_max - q_min)

    zero_point = q_min - (r_min / scale)

    # Clip the zero point to [q_min, q_max], then round and cast to int.
    zero_point = int(round(min(max(zero_point, q_min), q_max)))

    return scale, zero_point

In [None]:
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)
print(new_scale, "\n", new_zero_point)

## Quantization and Dequantization with Calculated Scale and Zero Point

In [None]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, new_scale, new_zero_point)
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale, new_zero_point)

In [None]:
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
(dequantized_tensor-test_tensor).square().mean()

## Everything together

In [None]:
def linear_quantization(tensor, dtype=torch.int8):
    """
    Asymmetric linear quantization of `tensor` in one call.

    Derives the scale and zero point from the tensor itself and then
    quantizes with them.

    Returns:
        (quantized_tensor, scale, zero_point)
    """
    q_params = get_q_scale_and_zero_point(tensor, dtype=dtype)
    scale, zero_point = q_params

    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale, zero_point, dtype=dtype
    )

    return quantized_tensor, scale, zero_point

In [None]:
r_tensor = torch.randn((4, 4))
r_tensor

In [None]:
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)
print(quantized_tensor, "\n", scale, "\n", zero_point)

In [None]:
dequantized_tensor = linear_dequantization(quantized_tensor,
                                           scale, zero_point)
dequantized_tensor

In [None]:
plot_quantization_errors(r_tensor, quantized_tensor, dequantized_tensor)

In [None]:
(dequantized_tensor-r_tensor).square().mean()

# Linear Quantization: Symmetric Mode

In asymmetric mode, r_min is mapped to q_min and r_max is mapped to q_max.

In symmetric quantization, however, we map -r_max to -q_max and r_max to q_max, where r_max is max(abs(tensor)). Because both the floating-point range and the quantized range are symmetric around zero, the real value 0 always maps to the quantized value 0 (the zero point is 0). This type of quantization therefore only needs a scale.

In [None]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    """
    Scale for symmetric linear quantization.

    Returns max(|tensor|) divided by the largest value representable in
    `dtype`, as a Python float.
    """
    largest_magnitude = tensor.abs().max().item()
    return largest_magnitude / torch.iinfo(dtype).max

In [None]:
### test the implementation on a 4x4 matrix
test_tensor = torch.randn((4, 4))
test_tensor

In [None]:
get_q_scale_symmetric(test_tensor)

In [None]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    """
    Symmetric per-tensor linear quantization.

    Args:
        tensor: tensor to quantize.
        dtype: integer torch dtype of the result (default torch.int8).

    Returns:
        (quantized_tensor, scale)
    """
    # The scale must be derived for the same dtype that is used for clamping
    # and casting below; otherwise a non-default dtype (e.g. torch.int16)
    # would only ever use the int8 part of its range.
    scale = get_q_scale_symmetric(tensor, dtype=dtype)

    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor,
        scale=scale,
        # in symmetric quantization the zero point is always 0
        zero_point=0,
        dtype=dtype,
    )

    return quantized_tensor, scale

In [None]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

In [None]:
dequantized_tensor = linear_dequantization(quantized_tensor,scale,0)

In [None]:
plot_quantization_errors(
    test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
def quantization_error(tensor, dequantized_tensor):
    """Mean squared difference between `dequantized_tensor` and `tensor`."""
    difference = dequantized_tensor - tensor
    return torch.mean(torch.square(torch.abs(difference)))

In [None]:
print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor)}""")

## Granularity of quantization
Everything above was per-tensor quantization, i.e. the entire tensor shares a single scale and zero point.

Other types of granularity are:
- Per channel: each slice along a chosen axis is quantized separately, i.e. has its own scale and zero point
- Per group: each group of consecutive numbers is quantized separately, i.e. has its own scale and zero point

# Linear Quantization II: Per Channel Quantization

In [None]:
def linear_q_symmetric_per_channel(r_tensor,dim,dtype=torch.int8):
    """
    Symmetric linear quantization with one scale per slice along `dim`.

    Args:
        r_tensor: real-valued tensor to quantize.
        dim: dimension whose slices each get their own scale.
        dtype: integer torch dtype of the result (default torch.int8).

    Returns:
        (quantized_tensor, scale): `scale` has size 1 in every dimension
        except `dim`, so it broadcasts against `r_tensor`.
    """
    output_dim = r_tensor.shape[dim]
    # Allocate the scales on the same device as the input so the division
    # inside linear_q_with_scale_and_zero_point does not mix devices.
    scale = torch.zeros(output_dim, device=r_tensor.device)
    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim,index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)
    # Reshape the scale so it lines up with `dim` and broadcasts elsewhere.
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
    return quantized_tensor, scale

In [None]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [None]:
### along the rows (dim = 0)
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(
    test_tensor, dim=0)

### along the columns (dim = 1)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(
    test_tensor, dim=1)

In [None]:
dequantized_tensor_0 = linear_dequantization(
    quantized_tensor_0, scale_0, 0)

plot_quantization_errors(
    test_tensor, quantized_tensor_0, dequantized_tensor_0)

In [None]:
print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor_0)}""")

In [None]:
dequantized_tensor_1 = linear_dequantization(
    quantized_tensor_1, scale_1, 0)

plot_quantization_errors(
    test_tensor, quantized_tensor_1, dequantized_tensor_1, n_bits=8)

print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor_1)}""")

# Linear Quantization II: Per Group Quantization

In [None]:
def linear_q_symmetric_per_group(tensor, group_size,
                                 dtype=torch.int8):
    """
    Symmetric linear quantization with one scale per group of `group_size`
    consecutive elements within each row of a 2-D tensor.

    Args:
        tensor: 2-D tensor whose second dimension is a multiple of `group_size`.
        group_size: number of elements that share one scale.
        dtype: integer torch dtype of the result (default torch.int8).

    Returns:
        (quantized_tensor, scale): the quantized tensor in the original
        shape, and one scale per group with shape (num_groups, 1).
    """
    t_shape = tensor.shape
    # Check the rank first: indexing t_shape[1] on a 1-D tensor would raise
    # an IndexError before the intended assertion could fire.
    assert tensor.dim() == 2, "only 2-D tensors are supported"
    # Only rows are grouped and quantized
    assert t_shape[1] % group_size == 0, \
        "the row length must be a multiple of group_size"

    # reshape (unlike view) also accepts non-contiguous inputs such as a
    # transposed matrix.
    tensor = tensor.reshape(-1, group_size)

    # Every row of the reshaped tensor is one group, so per-channel
    # quantization along dim 0 yields one scale per group.
    quantized_tensor, scale = linear_q_symmetric_per_channel(
                                tensor, dim=0, dtype=dtype)

    quantized_tensor = quantized_tensor.view(t_shape)

    return quantized_tensor, scale

In [None]:
def linear_dequantization_per_group(quantized_tensor, scale, 
                                    group_size):
    """
    De-quantize a tensor that was quantized per group.

    The tensor is viewed as rows of `group_size` elements so that `scale`
    (one entry per group) broadcasts over it, de-quantized with a zero point
    of 0, and viewed back in its original shape.
    """
    original_shape = quantized_tensor.shape

    grouped = quantized_tensor.view(-1, group_size)
    restored = linear_dequantization(grouped, scale, 0)

    return restored.view(original_shape)

In [None]:
test_tensor = torch.rand((6, 6))

In [None]:
group_size = 3

In [None]:
quantized_tensor, scale = linear_q_symmetric_per_group(
    test_tensor, group_size=group_size)

dequantized_tensor = linear_dequantization_per_group(
    quantized_tensor, scale, group_size=group_size)

plot_quantization_errors(
    test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor)}""")

We can either quantize the weights or we can quantize both weights and activations.

- When we quantize only the weights, we need to dequantize them before performing the multiplication with the activations.
- We can also quantize both the weights and the activations. This results in integer multiplication and addition, which is not supported by all hardware.

In [None]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    """
    Linear layer (no bias) with int8 weights and float32 activations.

    The weights are de-quantized to float32 and then used in a regular
    floating-point linear operation.

    Args:
        input: float32 activation tensor.
        q_w: int8 quantized weight matrix.
        s_w: scale of the weights.
        z_w: zero point of the weights (0 for symmetric quantization).

    Returns:
        The float32 output of the linear operation.
    """
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    # De-quantize with r = s * (q - z), the same formula as
    # linear_dequantization. Adding z after scaling (q * s + z) is only
    # correct when z == 0.
    dequantized_weight = (q_w.to(torch.float32) - z_w) * s_w
    output = torch.nn.functional.linear(input, dequantized_weight)

    return output

In [None]:
input = torch.tensor([1, 2, 3], dtype=torch.float32)
weight = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])

In [None]:
q_w, s_w  = linear_q_symmetric(weight)

In [None]:
output = quantized_linear_W8A32_without_bias(input,
                                             q_w,
                                             s_w,
                                             0)

In [None]:
print(f"This is the W8A32 output: {output}")

In [None]:
fp32_output = torch.nn.functional.linear(input, weight)
print(f"This is the output if we don't quantize: {fp32_output}")

# Custom Build an 8-Bit Quantizer

For inference only

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Random stand-ins used to exercise w8_a16_forward: int8 weights of shape (32, 16),
# a (1, 16) bfloat16 input, and (1, 32) bfloat16 scales and bias.
# NOTE: torch.randint excludes its upper bound, so the weights lie in [-128, 126].
random_int8 = torch.randint(-128, 127, (32, 16)).to(torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [None]:
def w8_a16_forward(weight, input, scales, bias=None):
    """
    Forward pass of a linear layer whose weights are stored quantized.

    The weights are cast to the dtype of `input`, the linear operation is
    applied, and the result is multiplied by `scales`. `bias` is added when
    it is given.
    """
    weight_in_input_dtype = weight.to(input.dtype)
    scaled = F.linear(input, weight_in_input_dtype) * scales

    if bias is None:
        return scaled
    return scaled + bias

In [None]:
print("With bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales))

In [None]:
class W8A16LinearLayer(nn.Module):
    """
    Inference-only linear layer that stores its weights as int8 together
    with one scale per output row.

    The tensors are registered as buffers. They start out random and are
    meant to be overwritten through `quantize`.
    """

    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: whether the layer has a bias buffer.
            dtype: dtype of the scales and of the bias.
        """
        super().__init__()

        # Placeholder int8 weights of shape (out_features, in_features).
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )

        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))

        else:
            self.bias = None

    def quantize(self, weights):
        """
        Quantize `weights` row by row (symmetric, int8) and store the result
        in `int8_weights` and `scales`.

        Args:
            weights: 2-D floating-point weight matrix.
        """
        # Work on a detached float32 copy: the layer is inference-only, so the
        # stored scales must not keep an autograd graph alive, and float32
        # avoids rounding errors when `weights` is fp16/bf16.
        w_fp32 = weights.detach().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)
        # A row of zeros (or a scale that underflows in a low-precision dtype)
        # would give a scale of 0 and hence 0/0 = NaN below. Any non-zero
        # scale represents such a row exactly, so use 1.
        scales = torch.where(scales == 0, torch.ones_like(scales), scales)

        # Divide by the scale exactly as it is stored, so that
        # int8_weights * scales reproduces the weights as closely as possible,
        # and clamp so a rounding overshoot cannot wrap around in int8.
        int8_range = torch.iinfo(torch.int8)
        int8_weights = torch.round(
            w_fp32 / scales.to(torch.float32).unsqueeze(1)
        ).clamp(int8_range.min, int8_range.max).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        """Apply the quantized linear layer to `input` via `w8_a16_forward`."""
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)

In [None]:
module = W8A16LinearLayer(4, 8)
print("Weights before:\n" , module.int8_weights)

In [None]:
# Weights are stored as (out_features, in_features). `module` was created with
# in_features=4 and out_features=8, so the matrix to quantize must be (8, 4).
random_matrix = torch.randn((8, 4), dtype=torch.bfloat16)
module.quantize(random_matrix)
print("Weights After:\n" , module.int8_weights)

In [None]:
print(module.scales)
print(module.scales.shape)
print(module.int8_weights.shape)

In [None]:
# Mean absolute difference between the original matrix and its reconstruction
# (int8 weights multiplied by their per-row scales).
(random_matrix - module.int8_weights 
 * module.scales.unsqueeze(1)).abs().mean()

# Replace PyTorch layers with Quantized Layers

In [None]:
def replace_linear_with_target_and_quantize(module, 
                               target_class, module_name_to_exclude):
    """
    Recursively replace the `nn.Linear` children of `module`, in place, with
    quantized `target_class` layers.

    Each replacement is built from the old layer's in/out features, bias flag
    and weight dtype; the old weight is passed to its `quantize` method and
    the old bias (if any) is carried over.

    Args:
        module: root module to modify.
        target_class: class of the replacement layer; it is constructed as
            target_class(in_features, out_features, has_bias, dtype) and must
            provide `quantize(weights)`.
        module_name_to_exclude: child names that must not be replaced.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear) or \
                any(x == name for x in module_name_to_exclude):
            # Not a replaceable layer: descend into its own children instead.
            replace_linear_with_target_and_quantize(
                child, target_class, module_name_to_exclude)
            continue

        bias, weight = child.bias, child.weight

        replacement = target_class(child.in_features,
                                   child.out_features,
                                   bias is not None,
                                   weight.dtype)
        setattr(module, name, replacement)

        replacement.quantize(weight)

        if bias is not None:
            replacement.bias = bias

In [None]:
class DummyModel(nn.Module):
    """
    Tiny model for trying out layer replacement: an embedding, a linear
    layer with bias, one without bias, and a bias-free `lm_head`.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1, 1)
        # Linear layer with a bias
        self.linear_1 = nn.Linear(1, 1)
        # Linear layer without a bias
        self.linear_2 = nn.Linear(1, 1, bias=False)
        # LM prediction head
        self.lm_head = nn.Linear(1, 1, bias=False)

In [None]:
model_1 = DummyModel()

In [None]:
replace_linear_with_target_and_quantize(model_1, W8A16LinearLayer, ["lm_head"])
print(model_1)

# Packing 2-bit Weights

In [None]:
# Example Tensor: [1, 0, 3, 2]
    # 1 0 3 2 - 01 00 11 10

    # Starting point of packed int8 Tensor
    # [0000 0000]

    ##### First Iteration Start:
    # packed int8 Tensor State: [0000 0000]
    # 1 = 0000 0001
    # 0000 0001
    # No left shifts in the First Iteration
    # After bit-wise OR operation between 0000 0000 and 0000 0001:
    # packed int8 Tensor State: 0000 0001
    ##### First Iteration End

    ##### Second Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 0 = 0000 0000
    # 0000 0000
    # 2 left shifts:
    # [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
    # After bit-wise OR operation between 0000 0001 and 0000 0000:
    # packed int8 Tensor State: 0000 0001
    ##### Second Iteration End

    ##### Third Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 3 = 0000 0011
    # 0000 0011
    # 4 left shifts:
    # [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
    # 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
    # After bit-wise OR operation between 0000 0001 and 0011 0000:
    # packed int8 Tensor State: 0011 0001
    ##### Third Iteration End

    ##### Fourth Iteration Start:
    # packed int8 Tensor State: [0011 0001]
    # 2 = 0000 0010
    # 0000 0010
    # 6 left shifts:
    # [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
    # 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
    # 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
    # After bit-wise OR operation between 0011 0001 and 1000 0000:
    # packed int8 Tensor State: 1011 0001
    ##### Fourth Iteration End

    # Final packed int8 Tensor State: [1011 0001]

In [None]:
def pack_weights(uint8tensor, bits):
    """
    Pack low-bit values, stored one per uint8 element, into a dense uint8 tensor.

    Each output byte holds `8 // bits` consecutive input values; the first
    value occupies the least significant bits.

    Example (bits=2): [1, 0, 3, 2] -> 0b10110001 = 177.

    Args:
        uint8tensor: 1-D uint8 tensor of values that each fit in `bits` bits.
        bits: bits per value; must divide 8 (1, 2, 4 or 8).

    Returns:
        1-D uint8 tensor with `len(uint8tensor) * bits // 8` elements.

    Raises:
        ValueError: if `bits` does not divide 8, or if the number of values
            does not fill a whole number of bytes.
    """
    if bits <= 0 or 8 % bits != 0:
        # With any other width, 8 // bits values do not fill a byte exactly
        # and the loop below would silently leave some input values out.
        raise ValueError(f"bits needs to divide 8 - got {bits}")

    num_steps = 8 // bits

    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError(f"The input shape needs to be a multiple "
                         f"of {num_steps} - got {uint8tensor.shape[0]}")

    num_values = uint8tensor.shape[0] * bits // 8

    unpacked_idx = 0

    packed_tensor = torch.zeros((num_values), dtype=torch.uint8)

    for i in range(num_values):
        for j in range(num_steps):
            # The j-th value of byte i goes `bits * j` positions to the left.
            packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
            unpacked_idx += 1
    return packed_tensor

In [None]:
unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)

In [None]:
pack_weights(unpacked_tensor, 2)

In [None]:
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)

In [None]:
pack_weights(unpacked_tensor, 2)

In [None]:
def unpack_weights(uint8tensor, bits):
    """
    Unpack a uint8 tensor of packed low-bit values into one value per element.

    Each input byte yields `8 // bits` values, starting from its least
    significant bits.

    Example (bits=2): 177 = 0b10110001 -> [1, 0, 3, 2].

    Args:
        uint8tensor: 1-D uint8 tensor of packed bytes.
        bits: number of bits per packed value.

    Returns:
        1-D uint8 tensor with `len(uint8tensor) * 8 // bits` elements.
    """
    values_per_byte = 8 // bits
    total_values = uint8tensor.shape[0] * 8 // bits
    low_bits_mask = 2 ** bits - 1

    unpacked_tensor = torch.zeros((total_values), dtype=torch.uint8)

    for byte_idx in range(uint8tensor.shape[0]):
        first_slot = byte_idx * values_per_byte
        for step in range(values_per_byte):
            # Bring the wanted value down to the lowest bits; the bits above
            # it are removed by the mask after the loop.
            unpacked_tensor[first_slot + step] |= \
                uint8tensor[byte_idx] >> (bits * step)

    # Keep only the lowest `bits` bits of every entry.
    unpacked_tensor &= low_bits_mask
    return unpacked_tensor

In [None]:
# NOTE: despite the variable name, these are the *packed* bytes that are passed
# to unpack_weights below (177 = 0b10110001, 255 = 0b11111111).
unpacked_tensor = torch.tensor([177, 255], dtype=torch.uint8)

In [None]:
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(unpacked_tensor, 2)