# L4-A - Building your own Quantizer: Custom Build an 8-Bit Quantizer

In this lesson, you will learn how to compress any model in 8-bit precision.

## Step 1: class `W8A16LinearLayer`

- Build the target class, `W8A16LinearLayer()`, that will be responsible for quantizing your model.

### 1.1 - `w8_a16_forward` Function

-
```Python
W8A16LinearLayer
                    # 8-bit  # 16-bit         # optional
* w8_a16_forward -> weights, input,   scales, bias=None
                    
```
- Cast the 8-bit `weights` to the same data type as the `input`, "casted weights",
- keeping the "casted weights" in the same range as before, [-128, 127]
- Next, $$(({inputs} \cdot \text{``casted weights''}) * {scale}) + {bias}$$ 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
random_int8 = torch.randint(low=-128, high=127, size=(32, 16)).to(torch.int8)
random_hs = torch.randn(size=(1, 16), dtype=torch.bfloat16)
scales = torch.randn(size=(1, 32), dtype=torch.bfloat16)
bias = torch.randn(size=(1, 32), dtype=torch.bfloat16)

In [3]:
F.linear(input=random_hs, weight=random_int8.to(random_hs.dtype))

tensor([[ 219.0000, -282.0000, -148.0000,  -28.2500,   17.0000,  129.0000,
          -96.0000,  282.0000, -216.0000, -568.0000,  -53.0000, -270.0000,
         -334.0000,  -48.0000, -390.0000,  488.0000,   77.0000, -170.0000,
          -36.2500, -450.0000, -294.0000,  338.0000, -270.0000, -470.0000,
         -256.0000,   75.5000,  354.0000, -568.0000, -157.0000,   61.5000,
         -320.0000,  -49.5000]], dtype=torch.bfloat16)

In [4]:
F.linear(input=random_hs, weight=random_int8.to(random_hs.dtype)) * scales

tensor([[-157.0000, -131.0000,  -69.0000,   35.7500,  -18.0000,  134.0000,
           12.1875,  -94.0000,  204.0000, -440.0000,   31.6250,  -76.5000,
          -28.2500,   72.5000,  564.0000, -356.0000, -100.0000,   10.6875,
            2.8125, -272.0000, -255.0000,   82.0000, -398.0000,  352.0000,
           57.5000,   -4.6875,  -79.0000, -330.0000,  127.5000,  118.0000,
           91.0000,   -7.6875]], dtype=torch.bfloat16)

In [5]:
F.linear(input=random_hs, weight=random_int8.to(random_hs.dtype)) * scales + bias

tensor([[-158.0000, -131.0000,  -69.0000,   35.7500,  -16.6250,  134.0000,
           12.1250,  -94.0000,  204.0000, -442.0000,   31.8750,  -78.5000,
          -26.0000,   72.0000,  564.0000, -356.0000, -101.0000,   10.7500,
            4.3750, -272.0000, -256.0000,   81.5000, -398.0000,  352.0000,
           58.2500,   -6.7500,  -79.5000, -330.0000,  128.0000,  118.5000,
           90.5000,   -7.8750]], dtype=torch.bfloat16)

- Implement all this as a function, `w8_a16_forward`

In [6]:
torch.finfo(torch.bfloat16)

finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)

In [7]:
def w8_a16_forward(weight, input, scales, bias=None):
    casted_weights = weight.to(input.dtype)
    output = F.linear(input=input, weight=casted_weights) * scales

    if bias is not None:
        output += bias
    
    return output

In [8]:
print(f"With bias: {w8_a16_forward(weight=random_int8, input=random_hs, scales=scales, bias=bias)}")

print(f"Without bias: {w8_a16_forward(weight=random_int8, input=random_hs, scales=scales)}")

With bias: tensor([[-158.0000, -131.0000,  -69.0000,   35.7500,  -16.6250,  134.0000,
           12.1250,  -94.0000,  204.0000, -442.0000,   31.8750,  -78.5000,
          -26.0000,   72.0000,  564.0000, -356.0000, -101.0000,   10.7500,
            4.3750, -272.0000, -256.0000,   81.5000, -398.0000,  352.0000,
           58.2500,   -6.7500,  -79.5000, -330.0000,  128.0000,  118.5000,
           90.5000,   -7.8750]], dtype=torch.bfloat16)
Without bias: tensor([[-157.0000, -131.0000,  -69.0000,   35.7500,  -18.0000,  134.0000,
           12.1875,  -94.0000,  204.0000, -440.0000,   31.6250,  -76.5000,
          -28.2500,   72.5000,  564.0000, -356.0000, -100.0000,   10.6875,
            2.8125, -272.0000, -255.0000,   82.0000, -398.0000,  352.0000,
           57.5000,   -4.6875,  -79.0000, -330.0000,  127.5000,  118.0000,
           91.0000,   -7.6875]], dtype=torch.bfloat16)


### 1.2 - `init` Function of class `W8A16LinearLayer`

- This is how the `init` is of [PyTorch Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear):
```Python
def __init__(self, in_features, out_features, bias=True,
             device=None, dtype=None)

```

In [9]:
# running this will result in an error
class W8A16LinearLayer:
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.int8_weights = nn.Parameter(data=torch.Tensor([0, 1]).to(dtype=torch.int8))

try:
    W8A16LinearLayer(in_features=1, out_features=1)
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

[91m RuntimeError :  Only Tensors of floating point and complex dtype can require gradients [0m


Store the weights as buffer

In [10]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer("int8_weights",
                              torch.randint(low=-128, high=127, size=(out_features, in_features), dtype=torch.int8)
                              )
        
        self.register_buffer("scales",
                             torch.randn((1, out_features), dtype=dtype)
                             )
        
        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features), dtype=dtype)
                                 )
        else:
            self.bias = None


- Test your implementation

In [11]:
dummy_instance = W8A16LinearLayer(in_features=16, out_features=32)

In [12]:
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

torch.Size([32, 16])
torch.Size([1, 32])


### 1.3 - `forward` Function of class `W8A16LinearLayer`

- Use the `w8_a16_forward` defined earlier (Step 1.1) to define the `forward` function.

In [13]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer("int8_weights",
                              torch.randint(low=-128, high=127, size=(out_features, in_features), dtype=torch.int8)
                              )
        
        self.register_buffer("scales",
                             torch.randn((1, out_features), dtype=dtype)
                             )
        
        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features), dtype=dtype)
                                 )
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(weight=self.int8_weights, input=input, scales=self.scales, bias=self.bias)

In [14]:
module = W8A16LinearLayer(in_features=16, out_features=32)
dummy_hidden_states = torch.randn(1,6,16)

In [15]:
module(dummy_hidden_states).shape

torch.Size([1, 6, 32])

In [16]:
module(dummy_hidden_states).dtype

torch.float32

KA: Check that output has same data type as the input

In [17]:
module = W8A16LinearLayer(in_features=16, out_features=32, dtype=torch.bfloat16)
dummy_hidden_states = torch.randn(1,6,16, dtype=torch.bfloat16)

In [18]:
module(dummy_hidden_states).dtype

torch.bfloat16

### 1.4 - `quantize` Function of class `W8A16LinearLayer`

- `quantize` function will dynamically quantize half-precision weights into `torch.int8`

In [19]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer("int8_weights",
                              torch.randint(low=-128, high=127, size=(out_features, in_features), dtype=torch.int8)
                              )
        
        self.register_buffer("scales",
                             torch.randn((1, out_features), dtype=dtype)
                             )
        
        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features), dtype=dtype)
                                 )
        else:
            self.bias = None
    
    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(weight=self.int8_weights, input=input, scales=self.scales, bias=self.bias)

In [20]:
module = W8A16LinearLayer(in_features=4, out_features=8)

In [21]:
print(f"Weights before:\n", module.int8_weights)

Weights before:
 tensor([[ -13,  -71,  -29,  108],
        [ -23,  -54,   22,  -64],
        [-116,  -81,  -32,   85],
        [ -79,  -48,  -75,  126],
        [  26,  -52,  -16,  -70],
        [  44,  112,  -25,   -8],
        [ -76,  -34,  -51,   81],
        [ -37,   73,    6,   22]], dtype=torch.int8)


In [22]:
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)

In [23]:
random_matrix

tensor([[ 0.8750,  0.1396, -0.3438,  0.4395, -0.9570, -0.6875,  0.5117, -0.3145],
        [-0.1953,  0.7031,  0.8945, -1.6797, -1.0078,  2.0781,  0.6562,  1.8125],
        [ 0.4648,  0.1904, -1.5781, -0.9609,  1.3281,  0.6211,  0.4414, -0.5508],
        [-1.7734,  0.6953,  0.4824, -0.8672,  0.3320, -0.1797, -0.0286, -0.9570]],
       dtype=torch.bfloat16)

In [24]:
module.quantize(weights=random_matrix)

In [25]:
print(f"Weights after:\n", module.int8_weights)

Weights after:
 tensor([[ 116,   18,  -46,   58, -127,  -91,   68,  -42],
        [ -12,   43,   55, -102,  -62,  127,   40,  111],
        [  37,   15, -126,  -77,  106,   50,   36,  -44],
        [-127,   50,   34,  -62,   24,  -13,   -2,  -68]], dtype=torch.int8)


In [26]:
module.scales

tensor([0.0075, 0.0164, 0.0125, 0.0140], dtype=torch.bfloat16)

In [27]:
module.scales.shape

torch.Size([4])

In [28]:
module.int8_weights.shape

torch.Size([4, 8])

In [29]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

tensor([[ 0.8750,  0.1357, -0.3477,  0.4375, -0.9570, -0.6875,  0.5117, -0.3164],
        [-0.1963,  0.7031,  0.8984, -1.6719, -1.0156,  2.0781,  0.6562,  1.8125],
        [ 0.4609,  0.1865, -1.5703, -0.9570,  1.3203,  0.6211,  0.4492, -0.5469],
        [-1.7734,  0.6992,  0.4746, -0.8672,  0.3359, -0.1816, -0.0280, -0.9492]],
       dtype=torch.bfloat16)

In [30]:
### original weights
random_matrix

tensor([[ 0.8750,  0.1396, -0.3438,  0.4395, -0.9570, -0.6875,  0.5117, -0.3145],
        [-0.1953,  0.7031,  0.8945, -1.6797, -1.0078,  2.0781,  0.6562,  1.8125],
        [ 0.4648,  0.1904, -1.5781, -0.9609,  1.3281,  0.6211,  0.4414, -0.5508],
        [-1.7734,  0.6953,  0.4824, -0.8672,  0.3320, -0.1797, -0.0286, -0.9570]],
       dtype=torch.bfloat16)

Compute the average quantization error

In [31]:
(random_matrix - module.int8_weights * module.scales.unsqueeze(1)).abs().mean()

tensor(0.0030, dtype=torch.bfloat16)