# Custom Build an 8-bit Quantizer

W8A16LinearLayer
                    # 8-bit  # 16-bit         # optional
* w8_a16_forward -> weights, input,   scales, bias=None

* Cast the 8-bit weights to the same data type as the input, "casted weights", keeping the "casted weights" in the same range as before, [-128, 127]

* Next,
((𝑖𝑛𝑝𝑢𝑡𝑠⋅``casted weights'')∗𝑠𝑐𝑎𝑙𝑒)+𝑏𝑖𝑎𝑠

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

random_int8 = torch.randint(-128, 127, (32, 16)).to(torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [2]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

tensor([[-478.0000,  280.0000, -362.0000, -414.0000,  121.5000, -262.0000,
           41.2500,   75.5000, -151.0000,  -78.0000,   39.5000, -117.0000,
         -164.0000, -132.0000, -344.0000, -209.0000,   38.7500,  316.0000,
          220.0000,   -6.1875,  252.0000,   72.0000,  -57.7500, -123.5000,
          -33.7500,    6.4062, -254.0000,   -9.1875,  187.0000,  -27.7500,
          -19.5000,  260.0000]], dtype=torch.bfloat16)

In [4]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias  

tensor([[-732.0000, -243.0000, -124.5000,  328.0000,    3.7812,  -71.5000,
          -58.5000, -157.0000,  169.0000,  176.0000,    8.8750,   19.7500,
           89.0000, -119.0000,  145.0000,   32.0000,   42.7500, -540.0000,
           -7.5312,   -4.3125,   26.0000, -173.0000,   31.8750,  -33.2500,
          -45.2500,   11.1250,   54.2500,   -4.6562,  -65.5000,   34.5000,
           -9.8125, -410.0000]], dtype=torch.bfloat16)

In [5]:
def w8_a16_forward(weight, input, scales, bias = None):
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales 
    if bias is not None:
        output = output + bias 
    return output 

In [17]:
import torch
import torch.nn as nn

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias = True, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )

        self.register_buffer("scales",
                             torch.randn((out_features), dtype = dtype ))
        
        if bias: 
            self.register_buffer("bias", 
                                 torch.randn((1,out_features), dtype=dtype))
        else:
            self.bias = None 
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)
    
    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales
        


In [11]:
dummy_instance = W8A16LinearLayer(16,32)

In [12]:
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

torch.Size([32, 16])
torch.Size([32])


In [14]:
module = W8A16LinearLayer(16,32)
dummy_hidden_states = torch.randn(1,6,16)

In [16]:
module(dummy_hidden_states).shape

torch.Size([1, 6, 32])

In [18]:
module = W8A16LinearLayer(4, 8)

In [19]:
print("Weights before:\n" , module.int8_weights)

Weights before:
 tensor([[-114,  -30,    5,   -6],
        [-126,  109,  111,  -76],
        [  74,   19,   48,   29],
        [  96,  -55,  -58,   48],
        [ -35,   46,   64,   96],
        [ -83,   87,   38,  -48],
        [ -20,   68,   45,   48],
        [  -2,  -16, -112,   83]], dtype=torch.int8)


In [20]:
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

In [21]:
module.quantize(random_matrix)

In [22]:
print("Weights After:\n" , module.int8_weights)

Weights After:
 tensor([[  18,  127,   50,   70,   -7,  -75,  -82,    8],
        [ -89,  127,  -37,  -21,   53,  -76,   32,   31],
        [-127,   -4,  -47,   -6,  -64, -104,  -69,  -49],
        [ -98,    5, -127,  -49,  -36,  -12,   69,  -70]], dtype=torch.int8)


In [23]:
module.scales

tensor([0.0167, 0.0203, 0.0137, 0.0145], dtype=torch.bfloat16)

In [24]:
module.scales.shape

torch.Size([4])

In [25]:
module.int8_weights.shape

torch.Size([4, 8])

In [26]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

tensor([[ 0.3008,  2.1250,  0.8359,  1.1719, -0.1172, -1.2578, -1.3750,  0.1338],
        [-1.8047,  2.5781, -0.7500, -0.4258,  1.0703, -1.5391,  0.6484,  0.6289],
        [-1.7422, -0.0549, -0.6445, -0.0825, -0.8789, -1.4297, -0.9492, -0.6719],
        [-1.4219,  0.0728, -1.8438, -0.7109, -0.5234, -0.1738,  1.0000, -1.0156]],
       dtype=torch.bfloat16)

In [27]:
### original weights
random_matrix

tensor([[ 0.2969,  2.1250,  0.8438,  1.1641, -0.1099, -1.2578, -1.3750,  0.1387],
        [-1.8047,  2.5781, -0.7422, -0.4180,  1.0781, -1.5547,  0.6406,  0.6211],
        [-1.7422, -0.0566, -0.6445, -0.0850, -0.8750, -1.4375, -0.9453, -0.6758],
        [-1.4219,  0.0679, -1.8438, -0.7148, -0.5312, -0.1787,  1.0000, -1.0156]],
       dtype=torch.bfloat16)

In [28]:
(random_matrix - module.int8_weights 
 * module.scales.unsqueeze(1)).abs().mean()

tensor(0.0041, dtype=torch.bfloat16)