## Custom 8-bit quantizer

We will implement a per-channel quantization method that stores the weights of the model's linear layers as 8-bit integers, with one scale per output channel.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Random stand-ins used to try out the quantized forward pass: an int8
# weight matrix, one bfloat16 activation row, and bfloat16 scales and bias.
# torch.randint's upper bound is exclusive, so the weights lie in [-128, 126].
random_int8 = torch.randint(low=-128, high=127, size=(32, 16)).to(dtype=torch.int8)
random_hs = torch.randn(1, 16, dtype=torch.bfloat16)
scales = torch.randn(1, 32, dtype=torch.bfloat16)
bias = torch.randn(1, 32, dtype=torch.bfloat16)

In [10]:
def w8_a16_forward(weight, input, scales, bias=None):
    """Apply a linear projection using weights stored in a narrower dtype.

    The weights are cast to the dtype of ``input``, multiplied against it
    with ``F.linear`` (no bias), and the result is rescaled by ``scales``.
    ``bias`` is added after the rescaling, and only when it is provided.

    Args:
        weight: Weight tensor laid out as ``F.linear`` expects, i.e.
            ``(out_features, in_features)``.
        input: Activation tensor; its dtype drives the computation.
        scales: Multiplier broadcast against the ``F.linear`` output.
        bias: Optional tensor broadcast-added to the scaled output.

    Returns:
        The scaled (and optionally biased) projection of ``input``.
    """
    # The matmul runs in the activation dtype, so up-cast the weights first.
    weight_in_act_dtype = weight.to(input.dtype)
    scaled = F.linear(input, weight_in_act_dtype) * scales

    if bias is None:
        return scaled
    return scaled + bias

In [8]:
class W8A16LinearLayer(nn.Module):
    """Linear layer holding int8 weights plus one scale per output row.

    The layer is created with random placeholder buffers; call
    :meth:`quantize` with the real floating-point weight matrix to fill in
    ``int8_weights`` and ``scales``. The forward pass delegates to
    ``w8_a16_forward``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When true, a ``(1, out_features)`` ``bias`` buffer is created;
            otherwise ``self.bias`` is ``None``.
        dtype: Dtype of the ``scales`` and ``bias`` buffers.
    """

    def __init__(self, in_features, out_features, bias = True, dtype = torch.float32):
        super().__init__()
        # Everything is registered as a buffer (not a Parameter) so that it
        # is saved with the module without being treated as trainable.
        # The random values are placeholders until quantize() is called.
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            ),
        )
        self.register_buffer(
            "scales", torch.randn((out_features), dtype=dtype)
        )

        if bias:
            self.register_buffer(
                "bias", torch.randn((1, out_features), dtype=dtype)
            )
        else:
            self.bias = None

    def quantize(self, weights):
        """Quantize ``weights`` row by row and store the result on the layer.

        Each row is scaled so that its largest absolute value maps to 127,
        then rounded to int8. The per-row scales are kept in the dtype of
        ``weights``.

        Args:
            weights: Floating-point weight matrix whose last dimension is
                reduced to obtain one scale per row.
        """
        # The stored buffers must not carry an autograd graph, otherwise
        # passing an nn.Parameter would keep the full-precision weights alive.
        with torch.no_grad():
            w_fp32 = weights.to(torch.float32)

            scales = w_fp32.abs().max(dim=-1).values / 127
            scales = scales.to(weights.dtype)
            # An all-zero row (or a scale that underflows in a low-precision
            # dtype) would give 0/0 = NaN below; a scale of 1 keeps such rows
            # at exactly zero after quantization.
            scales = torch.where(scales == 0, torch.ones_like(scales), scales)

            rounded = torch.round(weights / scales.unsqueeze(-1))
            # Low-precision division can land on 127.5, which rounds to 128
            # and would wrap around when cast to int8, so clamp first.
            int8_weights = rounded.clamp(-128, 127).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        """Run the quantized linear projection on ``input``."""
        return w8_a16_forward(self.int8_weights,
                              input, self.scales, self.bias)


In [9]:
def replace_linear_with_target(module,
                               target_class, module_name_to_exclude):
    """Swap ``nn.Linear`` children of ``module`` for ``target_class`` in place.

    Every direct child that is an ``nn.Linear`` and whose name is not listed
    in ``module_name_to_exclude`` is replaced by
    ``target_class(in_features, out_features, has_bias, weight_dtype)``.
    If the old layer had a bias, that same bias object is assigned to the
    replacement. All other children are visited recursively.

    Args:
        module: Root module to modify in place.
        target_class: Callable building the replacement layer.
        module_name_to_exclude: Child names that must be left untouched.
    """
    for child_name, submodule in module.named_children():
        # Anything that is not a replaceable Linear is descended into instead.
        if not isinstance(submodule, nn.Linear) or \
          any([x == child_name for x in module_name_to_exclude]):
            replace_linear_with_target(
                submodule, target_class, module_name_to_exclude)
            continue

        previous_bias = submodule.bias
        has_bias = previous_bias is not None

        replacement = target_class(submodule.in_features,
                                   submodule.out_features,
                                   has_bias,
                                   submodule.weight.dtype)
        setattr(module, child_name, replacement)
        if has_bias:
            # Carry the original bias over to the freshly created layer.
            replacement.bias = previous_bias

# Instantiate the layer. in_features and out_features are required
# constructor arguments, so a bare W8A16LinearLayer() raises TypeError;
# 16 -> 32 matches the (32, 16) random_int8 demo weights defined earlier.
W8A16LinearLayer(in_features=16, out_features=32)