# Custom quantiser layer: W8A16LinearLayer

In [2]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
random_int8 = torch.randint(-128, 127, (32, 16)).to(torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1,32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [6]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

tensor([[  89.0000, -510.0000,  110.0000,  -57.7500,  -61.5000,  175.0000,
          151.0000, -444.0000, -178.0000,  174.0000, -233.0000,  302.0000,
          612.0000,  -92.5000,  392.0000,  612.0000,  -31.8750,  -37.5000,
         -165.0000, -219.0000,  576.0000,  -93.0000,   92.5000,   52.0000,
          188.0000,  235.0000,    3.7969,   65.5000, -472.0000,   41.0000,
          -32.7500,   13.0625]], dtype=torch.bfloat16)

In [7]:
F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales

tensor([[ 1.3400e+02,  1.6800e+02, -1.3100e+02,  1.0000e+01,  6.7500e+01,
         -6.0250e+01,  7.4219e-02,  2.4800e+02,  3.6800e+02, -2.3100e+02,
         -1.7700e+02, -1.6500e+02,  2.7400e+02, -1.8400e+02,  3.9800e+02,
         -9.8000e+02,  5.0750e+01, -4.5750e+01, -6.8000e+01,  1.8600e+02,
         -4.8600e+02, -1.4900e+02,  5.6250e+01, -1.0200e+02, -3.1600e+02,
          0.0000e+00,  7.0000e+00, -2.8250e+01, -3.0600e+02, -6.9500e+01,
         -3.1500e+01,  1.9750e+01]], dtype=torch.bfloat16)

In [11]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

tensor([[  28.0000, -191.0000,  -59.0000, -688.0000,  125.0000, -280.0000,
          400.0000, -235.0000,  -92.0000,   -3.1406,   38.2500,  -28.6250,
            8.4375,  227.0000,   11.5000,  104.5000,  204.0000,   73.0000,
          -39.2500,   62.5000,  -34.7500, -175.0000,  240.0000,  160.0000,
          -45.5000,  153.0000,   17.6250,  624.0000,   86.0000,   42.0000,
          -35.5000,  107.0000]], dtype=torch.bfloat16)

In [68]:
print("With bias:")
print(w8_a16_forward(input=random_hs, weight=random_int8, scales=scales, bias=None))

print("Without bias:")
print(w8_a16_forward(input=random_hs, weight=random_int8, scales=scales))

With bias:
tensor([[  27.3750, -191.0000,  -59.0000, -688.0000,  126.0000, -280.0000,
          402.0000, -234.0000,  -92.5000,   -3.8438,   37.2500,  -29.0000,
            8.8125,  227.0000,    9.5000,  103.5000,  203.0000,   72.5000,
          -38.2500,   63.5000,  -33.5000, -175.0000,  240.0000,  160.0000,
          -44.5000,  154.0000,   14.4375,  624.0000,   85.0000,   41.2500,
          -34.2500,  105.5000]], dtype=torch.bfloat16)
Without bias:
tensor([[  27.3750, -191.0000,  -59.0000, -688.0000,  126.0000, -280.0000,
          402.0000, -234.0000,  -92.5000,   -3.8438,   37.2500,  -29.0000,
            8.8125,  227.0000,    9.5000,  103.5000,  203.0000,   72.5000,
          -38.2500,   63.5000,  -33.5000, -175.0000,  240.0000,  160.0000,
          -44.5000,  154.0000,   14.4375,  624.0000,   85.0000,   41.2500,
          -34.2500,  105.5000]], dtype=torch.bfloat16)


In [None]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()
        self.register_buffer("int8_weights", torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))

        if bias is not None:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def w8_a16_forward(weight, input, scales, bias=None):
        casted_weights = weight.to(input.dtype)
        output = F.linear(input, casted_weights) * scales
        if bias is not None:
            output = output + bias
        return output

    def quantise(self, weights):
        w_fp32 = weights.clone().to(torch.float32)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(w_fp32.dtype)

        int8_w = torch.round(w_fp32/scales.unsqueeze(1)).to(torch.int8)
        self.int8_weights = int8_w
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(input=input, weight=self.int8_weights, scales=self.scales, bias=self.bias)


In [77]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(1, 6, 16)
dummy_hidden_states.shape
module(dummy_hidden_states).dtype


torch.float32