In [16]:
import torch
import torch.nn as nn

In [26]:
class QuantizedLinearLayer(nn.Module):
    """Quantized version of nn.Linear.

    The weight and (optional) bias are stored in quantized form together with
    a scale and a zero point each. On every forward pass they are dequantized
    with ``(q - zero_point) / scale`` and the linear map is computed in float.

    Args:
        input_dim: number of input features (stored for reference only).
        output_dim: number of output features (stored for reference only).
        weight: quantized weight; it is used as ``x @ weight.T``, so it must
            be laid out as (output_dim, input_dim).
        weight_scale: scale used to quantize the weight.
        weight_zero_point: zero point used to quantize the weight.
        bias: quantized bias, or ``None`` for a layer without bias.
        bias_scale: scale used to quantize the bias (ignored without bias).
        bias_zero_point: zero point used to quantize the bias (ignored
            without bias).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        weight,
        weight_scale,
        weight_zero_point,
        bias,
        bias_scale,
        bias_zero_point,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Buffers (not Parameters): they move with .to()/.cuda() and are
        # saved in the state_dict, but are never trained.
        self._register_frozen("weight", weight)
        self._register_frozen("bias", bias)
        self._register_frozen("weight_scale", weight_scale)
        self._register_frozen("weight_zero_point", weight_zero_point)
        self._register_frozen("bias_scale", bias_scale)
        self._register_frozen("bias_zero_point", bias_zero_point)

    def _register_frozen(self, name, value):
        """Registers ``value`` as a buffer that is cut off from autograd."""
        if value is not None:
            # detach() drops any graph/requires_grad inherited from the float
            # layer the quantization parameters were computed from
            value = torch.as_tensor(value).detach()
        self.register_buffer(name, value)

    def forward(self, x):
        """Applies the dequantized linear map to ``x``.

        Args:
            x: input of shape (batch_size, input_features).

        Returns:
            ``x @ W.T`` plus the dequantized bias when the layer has one.
        """
        # dequantize params
        weight = (self.weight.float() - self.weight_zero_point) / self.weight_scale

        # compute
        output = x @ weight.T
        if self.bias is None:
            # layer was created without a bias term
            return output

        bias = (self.bias.float() - self.bias_zero_point) / self.bias_scale
        return output + bias

In [1]:
def _quantize_tensor(tensor):
    """Affinely quantizes ``tensor`` to unsigned 8-bit integers.

    The quantization is ``q = clamp(round(x * scale + zero_point), 0, 255)``
    and is inverted with ``x ~= (q - zero_point) / scale``.

    Returns:
        A tuple ``(quantized, scale, zero_point)`` where ``quantized`` and
        ``zero_point`` are ``torch.uint8`` tensors and ``scale`` is a 0-dim
        tensor in the dtype of ``tensor``. None of them is attached to the
        autograd graph.
    """
    # detach so the results do not drag the layer's autograd graph along
    values = tensor.detach()

    if values.numel() == 0:
        # min()/max() are undefined on empty tensors
        low = high = values.new_zeros(())
    else:
        # extend interval to include zero, so that 0.0 maps exactly onto the
        # zero point
        low = values.min().clamp(max=0)
        high = values.max().clamp(min=0)

    span = high - low
    if span > 0:
        scale = 255 / span
    else:
        # span is 0 only when every value is 0 (or the tensor is empty);
        # 255 / 0 would give an inf scale and NaN quantized values
        scale = torch.ones_like(span)

    zero_point = (-low * scale).round().clamp(0, 255).to(torch.uint8)
    quantized = (values * scale + zero_point).round().clamp(0, 255).to(torch.uint8)
    return quantized, scale, zero_point


def quantize_linear(linear_layer):
    """Quantizes a linear layer to unsigned 8-bit integers.

    Args:
        linear_layer: layer exposing ``weight`` and ``bias`` tensors
            (``bias`` may be ``None``), e.g. an ``nn.Linear``.

    Returns:
        The tuple ``(weight_quantized, weight_scale, weight_zero_point,
        bias_quantized, bias_scale, bias_zero_point)``. The three bias
        entries are ``None`` when the layer has no bias.
    """
    weight_quantized, weight_scale, weight_zero_point = _quantize_tensor(
        linear_layer.weight
    )

    # same for bias
    bias = linear_layer.bias
    if bias is None:
        # layer built with bias=False: nothing to quantize
        bias_quantized = bias_scale = bias_zero_point = None
    else:
        bias_quantized, bias_scale, bias_zero_point = _quantize_tensor(bias)

    return (
        weight_quantized,
        weight_scale,
        weight_zero_point,
        bias_quantized,
        bias_scale,
        bias_zero_point,
    )

In [None]:
def quantize_model(model, exclude_layers=None):
    """Quantizes a model in place by replacing its linear layers.

    Every ``nn.Linear`` found among the (nested) children of ``model`` is
    replaced with a ``QuantizedLinearLayer`` built from ``quantize_linear``.

    Args:
        model: the model to quantize; it is modified in place.
        exclude_layers: names of child modules to leave untouched. A name is
            compared with the attribute name of each child at every nesting
            level (not with a dotted path), and an excluded module is not
            descended into. ``None`` (the default) excludes nothing; a single
            string is treated as one name.

    Returns:
        the same ``model`` object, now quantized
    """
    if exclude_layers is None:
        exclude_layers = ()
    elif isinstance(exclude_layers, str):
        # `name in "some_string"` would be a substring test, not a name match
        exclude_layers = (exclude_layers,)

    # iterate over a snapshot because children are replaced inside the loop
    for name, layer in list(model.named_children()):
        if name in exclude_layers:
            continue
        if isinstance(layer, nn.Linear):
            # replace layer with quantized version
            setattr(
                model,
                name,
                QuantizedLinearLayer(
                    layer.in_features,
                    layer.out_features,
                    *quantize_linear(layer),
                ),
            )
        else:
            # recursively quantize children
            quantize_model(layer, exclude_layers)

    return model

# Example usage

In [19]:
# Float reference layer and a random batch to feed through it.
linear_layer = nn.Linear(10, 20)

x = torch.randn(5, 10)

# Quantize the weight and bias of the reference layer.
(weight_quantized, weight_scale, weight_zero_point,
 bias_quantized, bias_scale, bias_zero_point) = quantize_linear(linear_layer)

# Wrap the quantized parameters in a layer with the same dimensions.
quantized_linear_layer = QuantizedLinearLayer(
    input_dim=linear_layer.in_features,
    output_dim=linear_layer.out_features,
    weight=weight_quantized,
    weight_scale=weight_scale,
    weight_zero_point=weight_zero_point,
    bias=bias_quantized,
    bias_scale=bias_scale,
    bias_zero_point=bias_zero_point,
)

In [20]:
# Original float bias, shown for comparison with its dequantized version below.
linear_layer.bias

Parameter containing:
tensor([ 0.0088, -0.0333, -0.1156, -0.1462, -0.1421,  0.2478, -0.1945,  0.0979,
        -0.0684, -0.2976, -0.1623,  0.0833, -0.1441,  0.2712, -0.1024, -0.2685,
         0.2813, -0.1863,  0.0755,  0.2845], requires_grad=True)

In [21]:
# Dequantize the bias with (q - zero_point) / scale; the result should be
# close to linear_layer.bias, up to the 8-bit rounding error.
(bias_quantized - bias_zero_point.float()) / bias_scale

tensor([ 0.0091, -0.0342, -0.1164, -0.1461, -0.1415,  0.2488, -0.1940,  0.0982,
        -0.0685, -0.2968, -0.1621,  0.0822, -0.1438,  0.2717, -0.1027, -0.2694,
         0.2808, -0.1872,  0.0753,  0.2853], grad_fn=<DivBackward0>)

In [22]:
# Inspect every value returned by quantize_linear for the example layer.
weight_quantized, weight_scale, weight_zero_point, bias_quantized, bias_scale, bias_zero_point

(tensor([[115,  79, 189, 213,  28,  92, 211,  46, 105, 150],
         [ 38, 183, 170, 142, 132,  73,  87,  26, 233, 104],
         [ 35, 241,  56, 230, 196,  53, 194, 103, 134, 122],
         [ 45, 144, 142,  71, 133, 136, 204, 204,  19, 100],
         [149,  14,   0,  71,  30, 227, 237, 165,  50,  15],
         [ 73,  60, 214,  47,  46,  23, 181, 217, 119,   4],
         [100, 142,  21, 217, 106,  17, 237, 178, 185, 186],
         [124,  60,  83,  53,  84, 121, 168, 205, 223,  54],
         [129, 130,  55, 112,  34,  72,  91,   6,  72,  64],
         [187,  74,  99,  55,  62, 236,  28, 207, 255, 177],
         [140, 191, 100,  27,  55,  22,  46, 240,  56, 232],
         [211,  42,  30,  86, 216, 111,  93, 114,  97, 137],
         [ 64, 169,  40,  83,  48,  68,  98,  57, 188, 121],
         [ 98, 180,  69, 225, 139, 254,  93,  94, 175,  44],
         [140,  27,  29, 228, 171,  44, 230, 139, 124, 196],
         [117, 175,  64, 225,   2,  52,  18,  40,  92, 225],
         [ 32, 244,  99,

In [23]:
# Output of the original float layer on the sample batch x.
linear_layer(x)

tensor([[ 0.0768,  0.6342,  0.3130, -1.0972, -1.2296, -0.5090,  0.7722,  0.1783,
         -0.2836,  0.4136,  0.0305, -0.0108,  0.4759,  0.3487,  0.4402,  0.6275,
          1.8768, -0.2668, -0.5348,  0.1860],
        [ 0.1502,  0.2328,  0.3824, -0.9654, -0.7395, -0.4140,  0.1421, -0.2863,
          0.1670, -0.3156, -0.4549, -0.1251, -0.1044,  1.1238, -0.0629,  0.6653,
          1.0077, -0.3789, -0.1617, -0.1031],
        [ 0.1273, -0.5554, -0.1577, -0.4666,  1.1988,  0.4176,  0.4678,  0.7854,
          0.6538, -0.4186, -0.4085,  0.9378,  0.1510,  0.0620,  0.7029, -0.5298,
         -0.6442, -0.9324,  0.1821,  0.8202],
        [-0.4484, -0.3160, -0.6315, -0.0603,  0.2521,  0.6745, -1.0236,  0.1898,
          0.2037, -0.1817, -0.0967,  0.3500, -0.4088,  0.5202, -0.7298, -0.5025,
         -0.5360,  0.0209,  0.3130,  0.4785],
        [-0.6428, -0.3609, -0.1791, -0.2906, -0.6202, -0.4382, -0.1434, -0.3393,
          0.2912, -0.0483,  1.3269,  0.8876,  0.2011, -0.3006,  0.0754,  0.7666,
      

In [12]:
# Type-promotion check: float minus uint8 is computed in float, so the result
# can go negative instead of wrapping around as uint8 arithmetic would.
torch.tensor(1.0) - torch.tensor(255, dtype=torch.uint8)

tensor(-254.)