# Quantizing Weights & Activations for Inference

In [1]:
import torch
import torch.nn as nn
from helper import linear_quantization_symm, get_q_scale_symmetric

In [2]:
def quantized_linear_W8A32_without_bias(input,q_w,s_w,z_w):
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8
    dequantized_w = s_w*(q_w.to(torch.int32)-z_w)
    output = torch.nn.functional.linear(input,dequantized_w)
    return output

In [3]:
input = torch.tensor([1, 2, 3], dtype=torch.float32)

In [4]:
weight = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])

In [5]:
q_w, s_w  = linear_quantization_symm(weight)

In [6]:
print ("Quantized Weights =",q_w)
print ("Scale =",s_w)

Quantized Weights = tensor([[-118,  -67,   25],
        [ -89,   15,   96],
        [  14,   80,  127]], dtype=torch.int8)
Scale = 0.016929134609192376


In [7]:
output_w_q = quantized_linear_W8A32_without_bias(input,q_w,s_w,0)

In [8]:
output_wo_q = torch.nn.functional.linear(input,weight)

In [9]:
print ("W8A32 Output:")
print (output_w_q)

W8A32 Output:
tensor([-2.9965,  3.8768,  9.3957])


In [10]:
print ("Output Without Quantization:")
print (output_wo_q)

Output Without Quantization:
tensor([-3.0000,  3.8500,  9.3800])


# Custom Build an 8-Bit Quantizer
We create a class `W8A16LinearLayer`, which will be responsible for quantizing the model.

In [11]:
random_int8 = torch.randint(-128,127, (32,16)).to(dtype = torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [12]:
def w8_a16_forward(input, weight, scale, bias = None):
    weight = weight.to(dtype = input.dtype)
    if (bias == None):
        return torch.nn.functional.linear(input,weight)*scale
    return torch.nn.functional.linear(input,weight)*scale + bias

In [13]:
print("With bias:\n\n", w8_a16_forward(random_hs, random_int8, scales, bias))

print("\nWithout bias:\n\n", w8_a16_forward(random_hs, random_int8, scales))

With bias:

 tensor([[ -412.0000,   308.0000,    95.5000,    51.7500,   680.0000,  -124.5000,
          1168.0000,   -43.0000,   165.0000,   324.0000,  -474.0000,  -122.5000,
          -288.0000,   944.0000,   608.0000,   258.0000,   180.0000,  -121.0000,
         -1344.0000,   -33.7500,    20.1250,  -139.0000,   280.0000,  -203.0000,
          -219.0000,  -183.0000,   580.0000,  -123.0000,  -202.0000,    58.0000,
           160.0000,    -6.5312]], dtype=torch.bfloat16)

Without bias:

 tensor([[ -414.0000,   308.0000,    95.5000,    51.7500,   680.0000,  -124.5000,
          1168.0000,   -42.0000,   164.0000,   324.0000,  -474.0000,  -121.0000,
          -288.0000,   944.0000,   608.0000,   258.0000,   180.0000,  -121.5000,
         -1344.0000,   -34.0000,    21.3750,  -141.0000,   280.0000,  -203.0000,
          -220.0000,  -183.0000,   580.0000,  -124.0000,  -200.0000,    59.0000,
           161.0000,    -5.6875]], dtype=torch.bfloat16)


In [14]:
class W8A16LinearLayer(nn.Module):
    def __init__(self,input_features,output_features, bias = True, dtype = torch.float32):
        super().__init__()

        self.register_buffer('int8_weights', torch.randint(-128,127, (output_features,input_features)).to(torch.int8))
        self.register_buffer('scales', torch.randn((output_features),dtype= dtype))
        if bias:
            self.register_buffer('bias', torch.randn((1,output_features), dtype = dtype))
        else:
            self.bias = None


    def quantize(self,weights):
        r_max = weights.clone().abs().max(dim=-1).values
        q_max = torch.iinfo(torch.int8).max
        scales = r_max/q_max
        scales = scales.to(weights.dtype)
        self.scales = r_max/q_max
        self.int8_weights = torch.round(weights/self.scales.unsqueeze(1)).to(torch.int8)


    def forward(self, input):
        return w8_a16_forward(input, self.int8_weights, self.scales, self.bias)



In [15]:
module = W8A16LinearLayer(4, 8)

In [16]:
print (module.int8_weights)

tensor([[  71,  -78,   -4,   10],
        [ 101,  -21,  -42,   30],
        [-100,   61,  -76,  -44],
        [  82,   20,   90,  -45],
        [ -86,   47,   97,  -35],
        [  14,   35,  -56,    9],
        [ -47,  -31,  -51,  -43],
        [ -59,  -94,  -93, -107]], dtype=torch.int8)


In [17]:
# random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)
random_matrix = torch.tensor([[ 0.4668,  0.3750, -0.2969,  0.9180,  1.2578,  0.1089,  0.0000,  1.0703],
        [-0.9805,  0.0649, -0.6484, -0.1465, -1.4219,  0.0615,  0.0000,  0.0000],
        [-0.5039, -1.9219,  0.0713,  1.0625,  0.0347, -1.3203, -0.1699,  1.4922],
        [ 0.0403, -0.7891, -0.5391,  0.9141,  1.4531, -1.0859, -0.3398, -0.4336]],
       dtype=torch.bfloat16)

In [18]:
module.quantize(random_matrix)
print (module.int8_weights)

tensor([[  47,   38,  -30,   93,  127,   11,    0,  108],
        [ -88,    6,  -58,  -13, -128,    6,    0,    0],
        [ -33, -127,    5,   70,    2,  -87,  -11,   98],
        [   4,  -69,  -47,   80, -128,  -95,  -30,  -38]], dtype=torch.int8)


In [19]:
module.scales

tensor([0.0099, 0.0112, 0.0151, 0.0114], dtype=torch.bfloat16)

In [20]:
print (module.int8_weights.shape)
print (module.scales.shape)

torch.Size([4, 8])
torch.Size([4])


In [21]:
dequantized_wt = module.scales.unsqueeze(1) *  module.int8_weights
print (dequantized_wt)

tensor([[ 0.4648,  0.3750, -0.2969,  0.9180,  1.2578,  0.1089,  0.0000,  1.0703],
        [-0.9844,  0.0669, -0.6484, -0.1455, -1.4297,  0.0669,  0.0000,  0.0000],
        [-0.5000, -1.9219,  0.0757,  1.0625,  0.0303, -1.3203, -0.1660,  1.4844],
        [ 0.0457, -0.7891, -0.5352,  0.9141, -1.4609, -1.0859, -0.3418, -0.4336]],
       dtype=torch.bfloat16)


In [22]:
print (random_matrix)

tensor([[ 0.4668,  0.3750, -0.2969,  0.9180,  1.2578,  0.1089,  0.0000,  1.0703],
        [-0.9805,  0.0649, -0.6484, -0.1465, -1.4219,  0.0615,  0.0000,  0.0000],
        [-0.5039, -1.9219,  0.0713,  1.0625,  0.0347, -1.3203, -0.1699,  1.4922],
        [ 0.0403, -0.7891, -0.5391,  0.9141,  1.4531, -1.0859, -0.3398, -0.4336]],
       dtype=torch.bfloat16)


# Quantization Pipeline
- Replace all of the `torch.nn.Linear` layers with the `W8A16LinearLayer` layer.
- Call `quantize` on the linear layers using the original weights.

## 1. Linear Layer Replacement

In [23]:
def replace_linear_with_target(module,target_class, excluded_module):
    for name,child in module.named_children():
        if (isinstance(child, nn.Linear) and not any([x==name for x in excluded_module])):
            old_bias = child.bias
            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            setattr(module, name, new_module)
            if old_bias != None:
                getattr(module,name).bias = old_bias
        else:
            replace_linear_with_target(child,target_class,excluded_module)

In [24]:
class Test_Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(3,1)
        self.l1 = nn.Linear(2,3)
        self.l2 = nn.Linear(14,4, bias= False)
        self.lm_head = nn.Linear(2, 1, bias=False)

t1 = Test_Model()

In [25]:
replace_linear_with_target(t1,W8A16LinearLayer,[])
print (t1)

Test_Model(
  (emb): Embedding(3, 1)
  (l1): W8A16LinearLayer()
  (l2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)


## 2. Layer Replacement + Quantization

In [26]:
def replace_linear_with_target_and_quantize(module,target_class, excluded_module):
    for name,child in module.named_children():
        if (isinstance(child, nn.Linear) and not any([x==name for x in excluded_module])):
            old_bias = child.bias
            old_wt = child.weight
            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            new_module.quantize(old_wt)
            setattr(module, name, new_module)

            # print("Before Wt: ", old_wt)
            # print("After Wt: ", getattr(module,name).int8_weights, '\n')

            if old_bias != None:
                getattr(module,name).bias = old_bias

        else:
            replace_linear_with_target(child,target_class,excluded_module)

In [27]:
class Test_Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(4,4)
        self.l1 = nn.Linear(4,4)
        self.l2 = nn.Linear(4,4, bias= False)
        self.lm_head = nn.Linear(4,4, bias=False)

t1 = Test_Model()

In [28]:
print (t1)

Test_Model(
  (emb): Embedding(4, 4)
  (l1): Linear(in_features=4, out_features=4, bias=True)
  (l2): Linear(in_features=4, out_features=4, bias=False)
  (lm_head): Linear(in_features=4, out_features=4, bias=False)
)


In [29]:
replace_linear_with_target_and_quantize(t1,W8A16LinearLayer, ['lm_head'])
print (t1)

Test_Model(
  (emb): Embedding(4, 4)
  (l1): W8A16LinearLayer()
  (l2): W8A16LinearLayer()
  (lm_head): Linear(in_features=4, out_features=4, bias=False)
)


We can observe that our function is able to replace the required nn.Linear layers in the model t1