In [1]:
import torch
import torch.nn as nn

In [None]:
def quantize_tensor_symmetric(tensor: torch.Tensor, n_bits=8):
    """Quantize ``tensor`` with a single symmetric (zero-point-free) scale.

    The real interval ``[-max|x|, +max|x|]`` is mapped onto the integer
    interval ``[-q_max, q_max]`` with ``q_max = 2 ** (n_bits - 1) - 1``.

    Args:
        tensor: Floating-point tensor to quantize.
        n_bits: Bit width of the integer grid. The result is stored as
            ``torch.int8``, so values above 8 do not fit the storage type.

    Returns:
        A ``(quantized, scale)`` tuple: ``quantized`` is an ``int8`` tensor
        with the shape of ``tensor`` and ``scale`` is a 0-dim tensor such
        that ``quantized * scale`` approximates ``tensor``.
    """
    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Symmetric quantization only needs the largest magnitude: the grid is
    # centred on zero, so the scale is max|x| / q_max.
    max_abs = tensor.abs().max()
    scale = max_abs / q_max

    # An all-zero tensor would give scale == 0 and a division by zero below.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # Round to the nearest grid point; a bare int8 cast would truncate
    # towards zero and bias every value.
    quantized = torch.clamp(torch.round(tensor / scale), q_min, q_max).to(torch.int8)

    return quantized, scale

def dequantize_tensor_symmetric(quantized: torch.tensor, scale: float):
    """Map symmetric integer codes back to floating point.

    Args:
        quantized: Tensor of integer codes.
        scale: Step size that was used to produce ``quantized``.

    Returns:
        A ``float32`` tensor equal to ``quantized`` converted to float and
        multiplied by ``scale``.
    """
    as_float = quantized.to(torch.float32)
    return torch.mul(as_float, scale)

def quantize_tensor_asymmetric(tensor: torch.Tensor, n_bits=8):
    """Quantize ``tensor`` with a per-tensor affine (scale + zero point) map.

    The real interval ``[min(x, 0), max(x, 0)]`` is mapped onto the integer
    interval ``[-q_max, q_max]`` with ``q_max = 2 ** (n_bits - 1) - 1``.

    Args:
        tensor: Floating-point tensor to quantize.
        n_bits: Bit width of the integer grid. Results are stored as
            ``torch.int8``, so values above 8 do not fit the storage type.

    Returns:
        A ``(quantized, scale, zero_point)`` tuple. ``quantized`` is an
        ``int8`` tensor shaped like ``tensor``; ``scale`` is a 0-dim float
        tensor and ``zero_point`` a 0-dim ``int8`` tensor such that
        ``(quantized - zero_point) * scale`` approximates ``tensor``.
    """
    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Use the signed extrema (not the magnitudes) and widen the range so it
    # always contains 0. That keeps the zero point inside [q_min, q_max];
    # otherwise clamping it would shift and saturate the whole tensor.
    max_val = torch.clamp(tensor.max(), min=0)
    min_val = torch.clamp(tensor.min(), max=0)

    scale = (max_val - min_val) / (q_max - q_min)
    # All-zero input: avoid dividing by zero below.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    # The zero point lives on the integer grid, so it is clamped to the
    # integer range rather than to the real-valued range of the data.
    zero_point = torch.clamp(q_min - torch.round(min_val / scale), q_min, q_max)

    quantized = torch.clamp(
        torch.round(tensor / scale + zero_point),
        q_min, q_max
    ).to(torch.int8)

    return quantized, scale, zero_point.to(torch.int8)

def dequantize_tensor_asymmetric(quantized: torch.Tensor, scale: float, zero_point: int):
    """Map affine-quantized integer codes back to floating point.

    Args:
        quantized: Tensor of integer codes.
        scale: Step size used during quantization (float or 0-dim tensor).
        zero_point: Integer code that represents real 0.0 (int or tensor).

    Returns:
        A floating-point tensor ``(quantized - zero_point) * scale``.
    """
    # Convert to float BEFORE subtracting: int8 - int8 stays int8 and wraps
    # around (e.g. 127 - (-127) would become -2 instead of 254).
    codes = quantized.to(torch.float32)
    offset = torch.as_tensor(zero_point, dtype=torch.float32, device=quantized.device)
    return (codes - offset) * scale

In [2]:
def per_channel_quantization(tensor: torch.Tensor, channel_dim: int, n_bits: int = 8):
    """Affine-quantize ``tensor`` with one scale / zero point per channel.

    Every slice along ``channel_dim`` gets its own parameters, computed from
    that slice's minimum and maximum (widened to include 0).

    Args:
        tensor: Floating-point tensor to quantize.
        channel_dim: Dimension that indexes the channels.
        n_bits: Bit width of the integer grid ``[-q_max, q_max]`` with
            ``q_max = 2 ** (n_bits - 1) - 1``. Results are stored as
            ``torch.int8``, so values above 8 do not fit the storage type.

    Returns:
        A ``(quantized_tensor, scales, zero_points)`` tuple.
        ``quantized_tensor`` is ``int8`` with the same shape as ``tensor``;
        ``scales`` (float) and ``zero_points`` (``int8``) are 1-D tensors
        with one entry per channel.
    """
    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Quantization parameters are statistics, not part of the autograd graph.
    tensor = tensor.detach()

    channels_first = tensor.transpose(0, channel_dim) if channel_dim != 0 else tensor
    # One row per channel so all channels are processed in a single pass.
    flat = channels_first.reshape(channels_first.shape[0], -1)

    # Widen each range to contain 0 so every zero point fits in
    # [q_min, q_max] without saturating the channel's data.
    max_val = torch.clamp(flat.max(dim=1).values, min=0)
    min_val = torch.clamp(flat.min(dim=1).values, max=0)

    scales = (max_val - min_val) / (q_max - q_min)
    # All-zero channels: avoid dividing by zero.
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)

    zero_points = torch.clamp(q_min - torch.round(min_val / scales), q_min, q_max)

    quantized_flat = torch.clamp(
        torch.round(flat / scales[:, None] + zero_points[:, None]),
        q_min, q_max
    ).to(torch.int8)

    # Restore the original layout (the per-channel slices must not be
    # flattened together, or the result loses the input's shape).
    quantized_tensor = quantized_flat.reshape(channels_first.shape)
    if channel_dim != 0:
        quantized_tensor = quantized_tensor.transpose(0, channel_dim)

    return quantized_tensor, scales, zero_points.to(torch.int8)

def dequantize_per_channel(quantized_tensor: torch.Tensor, scales: torch.Tensor, zero_points: torch.Tensor, channel_dim: int):
    """Invert a per-channel affine quantization.

    Args:
        quantized_tensor: Tensor of integer codes.
        scales: One step size per channel (1-D, length = number of channels).
        zero_points: One zero point per channel (1-D, same length).
        channel_dim: Dimension of ``quantized_tensor`` that indexes channels.

    Returns:
        A ``float32`` tensor with the same shape as ``quantized_tensor``
        holding ``(code - zero_point[c]) * scale[c]`` for each channel ``c``.
    """
    channels_first = quantized_tensor.transpose(0, channel_dim) if channel_dim != 0 else quantized_tensor

    # Shape (C, 1, ..., 1) so the per-channel parameters broadcast over all
    # remaining dimensions.
    broadcast_shape = (-1,) + (1,) * (channels_first.dim() - 1)
    device = quantized_tensor.device
    scale_view = torch.as_tensor(scales).to(device=device, dtype=torch.float32).reshape(broadcast_shape)
    offset_view = torch.as_tensor(zero_points).to(device=device, dtype=torch.float32).reshape(broadcast_shape)

    # Cast to float before subtracting: int8 arithmetic would wrap around
    # whenever code - zero_point leaves [-128, 127].
    restored = (channels_first.to(torch.float32) - offset_view) * scale_view

    if channel_dim != 0:
        restored = restored.transpose(0, channel_dim)

    return restored

In [3]:
class QuantizedLinearLayer(nn.Module):
    """Linear layer that can switch to simulated int8 inference.

    Until ``quantize_weights()`` is called the layer behaves exactly like
    the wrapped ``nn.Linear``. Afterwards ``forward`` quantizes the input
    with the stored activation parameters, dequantizes input and weights,
    and runs a regular linear operation on the dequantized values.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the wrapped ``nn.Linear`` has a bias term.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Honour the caller's ``bias`` choice instead of always enabling it.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.register_buffer('weight_quantized', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('weight_scale', torch.zeros(out_features))
        self.register_buffer('weight_zero_point', torch.zeros(out_features))

        self.register_buffer('activation_scale', torch.tensor(1.0))
        self.register_buffer('activation_zero_point', torch.tensor(0.0))

        # Plain attribute (not a buffer), so it is not part of state_dict.
        self.quantized = False

    def quantize_weights(self):
        """Quantize the float weights per output channel and enable the quantized path."""
        weight = self.linear.weight.detach()
        quantized_weight, scale, zero_point = per_channel_quantization(weight, channel_dim=0)

        # Keep the buffers on the weight's device, with the dtypes they were
        # registered with and the (out_features, in_features) layout that
        # forward() relies on.
        device = weight.device
        self.weight_quantized = torch.as_tensor(quantized_weight).reshape(weight.shape).to(device=device, dtype=torch.int8)
        self.weight_scale = torch.as_tensor(scale).reshape(-1).to(device=device, dtype=torch.float32)
        self.weight_zero_point = torch.as_tensor(zero_point).reshape(-1).to(device=device, dtype=torch.float32)

        self.quantized = True

    def set_activation_scale(self, scale, zero_point):
        """Store the activation quantization parameters.

        Accepts Python numbers as well as tensors; buffers can only hold
        tensors, so the values are converted first.
        """
        target = self.activation_scale
        self.activation_scale = torch.as_tensor(scale, dtype=target.dtype, device=target.device)
        self.activation_zero_point = torch.as_tensor(zero_point, dtype=target.dtype, device=target.device)

    def forward(self, x):
        if not self.quantized:
            return self.linear(x)

        q_min, q_max = -127, 127  # assume 8 bit quantization

        # Quantize the activations (round to the integer grid), then map
        # them back to real values so the quantization error is reproduced.
        quantized_x = torch.clamp(
            torch.round(x / self.activation_scale + self.activation_zero_point),
            q_min, q_max
        )
        dequantized_x = (quantized_x - self.activation_zero_point) * self.activation_scale

        # Per-output-channel dequantization of the weights; computed in
        # float so code - zero_point cannot overflow int8.
        dequantized_weight = (
            self.weight_quantized.to(torch.float32) - self.weight_zero_point.unsqueeze(1)
        ) * self.weight_scale.unsqueeze(1)

        # Simulated quantized inference built from plain tensor ops only.
        output = nn.functional.linear(dequantized_x, dequantized_weight, self.linear.bias)

        return output

In [4]:
class QuantizedModel(nn.Module):
    """Three-layer MLP assembled from ``QuantizedLinearLayer`` blocks.

    Layout: fc1 -> ReLU -> fc2 -> ReLU -> fc3.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        self.fc1 = QuantizedLinearLayer(in_features, hidden_features)
        self.relu1 = nn.ReLU()
        self.fc2 = QuantizedLinearLayer(hidden_features, hidden_features)
        self.relu2 = nn.ReLU()
        self.fc3 = QuantizedLinearLayer(hidden_features, out_features)

    def forward(self, x):
        """Run the input through both hidden stages and the output layer."""
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)
        return self.fc3(hidden)

    def quantize_weights(self):
        """Quantize the weights of every linear layer, in fc1, fc2, fc3 order."""
        for layer in (self.fc1, self.fc2, self.fc3):
            layer.quantize_weights()

#### Min-Max Calibration

In [None]:
def calibrate_activations(model, layer_name, calibration_loader, n_bits: int = 8):
    """Min-max calibration of a symmetric activation scale for one layer.

    A forward hook on ``layer_name`` tracks the largest absolute output
    value seen while the calibration batches run through ``model``.

    Args:
        model: Module to calibrate. It is switched to eval mode and left
            in eval mode.
        layer_name: Key of the observed layer in ``model.named_modules()``.
        calibration_loader: Iterable of batches; each is passed to ``model``.
        n_bits: Bit width of the symmetric grid ``[-q_max, q_max]`` with
            ``q_max = 2 ** (n_bits - 1) - 1`` (127 for 8 bits).

    Returns:
        0-dim CPU tensor ``max|activation| / q_max`` (1.0 if every observed
        activation is zero).

    Raises:
        KeyError: ``layer_name`` is not a module of ``model``.
        RuntimeError: The hook never saw any activation.
    """
    q_max = 2 ** (n_bits - 1) - 1
    max_abs = None

    def hook_fn(_module, _inputs, output):
        # Keep only a running maximum instead of storing every activation.
        nonlocal max_abs
        if output.numel() == 0:
            return
        batch_max = output.detach().abs().max().cpu()
        max_abs = batch_max if max_abs is None else torch.maximum(max_abs, batch_max)

    module = dict(model.named_modules())[layer_name]
    handle = module.register_forward_hook(hook_fn)

    model.eval()
    try:
        with torch.no_grad():
            for data in calibration_loader:
                model(data)
    finally:
        # Always detach the hook, even when a batch fails, so it does not
        # keep firing on later forward passes.
        handle.remove()

    if max_abs is None:
        raise RuntimeError("no activations were recorded during calibration")

    # The range is symmetric around zero, so the largest magnitude maps to
    # q_max (dividing by the full 2 * q_max span would halve the scale and
    # clip the upper half of the activations).
    scale = max_abs / q_max
    if scale == 0:
        scale = torch.ones_like(scale)

    return scale

#### Percentile Calibration

In [5]:
def percentile_calibration(activations, percentile=99.99):
    """Derive a symmetric int8 scale from a central percentile range.

    The central ``percentile`` percent of the values is kept; the remaining
    mass is split equally between both tails and ignored as outliers.

    Args:
        activations: Floating-point tensor of observed activations.
        percentile: Percentage (0-100) of values that should fall inside
            ``[lower_quantile, upper_quantile]``.

    Returns:
        A ``(scale, lower_quantile, upper_quantile)`` tuple of 0-dim
        tensors, with ``scale = max(|lower|, |upper|) / 127``.

    Raises:
        ValueError: ``percentile`` is outside ``[0, 100]``.
    """
    if not 0.0 <= percentile <= 100.0:
        raise ValueError(f"percentile must be in [0, 100], got {percentile}")

    acts = activations.flatten()

    # torch.quantile expects a fraction in [0, 1]: convert the percentage
    # left over (100 - percentile) to a fraction and split it over two tails.
    alpha = (100.0 - percentile) / 100.0 / 2.0
    lower_quantile = torch.quantile(acts, alpha)
    upper_quantile = torch.quantile(acts, 1.0 - alpha)

    scale = torch.maximum(lower_quantile.abs(), upper_quantile.abs()) / 127

    return scale, lower_quantile, upper_quantile

#### MSE Calibration

In [None]:
def mse_calibration(activations, num_candidates=100):
    """Pick the symmetric int8 clipping threshold with the lowest MSE.

    Candidate thresholds are spread evenly between 50% and 100% of the
    largest absolute activation. For each one the activations are
    quantized, dequantized, and compared against the originals.

    Args:
        activations: Floating-point tensor of observed activations.
        num_candidates: Number of thresholds to evaluate.

    Returns:
        A ``(best_scale, best_candidate)`` tuple of 0-dim tensors, where
        ``best_scale = best_candidate / 127``. For all-zero input the
        result is ``(1.0, 0.0)``.
    """
    # The grid is symmetric, so the threshold must be based on magnitude;
    # the signed maximum would ignore large negative activations.
    peak = float(activations.abs().max())
    if peak == 0.0:
        # Nothing to search: every threshold would give a zero scale.
        return torch.tensor(1.0), torch.tensor(0.0)

    candidates = torch.linspace(0.5 * peak, peak, steps=num_candidates)
    best_candidate = None
    best_scale = None
    mse = float('inf')

    for candidate in candidates:
        scale = candidate / 127

        # Round to the nearest grid point before clipping to the int8 range.
        quantized = torch.clamp(torch.round(activations / scale), -127, 127)
        dequantized = quantized * scale

        loss = torch.square(activations - dequantized).sum().item()

        if loss < mse:
            mse = loss
            best_scale = scale
            best_candidate = candidate

    return best_scale, best_candidate

### Quantization Aware Training

In [6]:
from torch.autograd import Function

In [7]:
class FakeQuantize(Function):
    """Quantize-dequantize op with a straight-through gradient estimator.

    The forward pass snaps ``input`` to the integer grid defined by
    ``scale`` / ``zero_point`` / ``[q_min, q_max]`` and maps it back to real
    values. The backward pass treats rounding as the identity and blocks
    the gradient for elements that were clipped.
    """

    @staticmethod
    def forward(ctx, input, scale, zero_point, q_min, q_max):
        """Return the fake-quantized version of ``input`` (same shape, float)."""
        scaled = input / scale + zero_point

        # Stay in floating point: the rounded values are exact integers, and
        # this avoids int8 wrap-around for wider grids or offsets.
        quantized = torch.clamp(torch.round(scaled), q_min, q_max)

        # Remember which elements were inside the representable range.
        ctx.save_for_backward((scaled >= q_min) & (scaled <= q_max))

        dequantized = (quantized - zero_point) * scale
        return dequantized

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through unclipped elements; none for the other arguments."""
        (in_range,) = ctx.saved_tensors
        return grad_output * in_range, None, None, None, None

fakequantize = FakeQuantize.apply

In [8]:
class MinMaxObserver(nn.Module):
    """Tracks the running min / max of the tensors passed through it.

    In training mode ``forward`` updates the statistics and returns its
    input unchanged; in eval mode it is a pure pass-through.
    ``calculate_params`` turns the statistics into an affine scale and
    zero point for the integer grid ``[-q_max, q_max]``.

    Args:
        n_bits: Bit width of the grid; ``q_max = 2 ** (n_bits - 1) - 1``.
    """

    def __init__(self, n_bits=8):
        super().__init__()

        self.n_bits = n_bits
        self.q_max = 2 ** (n_bits - 1) - 1
        self.q_min = - self.q_max

        # +inf / -inf mark "nothing observed yet".
        self.register_buffer('min_val', torch.tensor(float('inf')))
        self.register_buffer('max_val', torch.tensor(float('-inf')))

        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0.0))

    def forward(self, x):
        if self.training:
            # Detach so the statistics never hold on to the autograd graph.
            observed = x.detach()
            current_min = observed.min().to(self.min_val)
            current_max = observed.max().to(self.max_val)

            self.min_val = torch.minimum(self.min_val, current_min)
            self.max_val = torch.maximum(self.max_val, current_max)

        return x

    def calculate_params(self):
        """Compute and store ``scale`` and ``zero_point`` from the statistics.

        Returns:
            ``(scale, zero_point)`` as 0-dim tensors. If nothing has been
            observed yet, the currently stored values are returned as-is.
        """
        if not (torch.isfinite(self.min_val) and torch.isfinite(self.max_val)):
            return self.scale, self.zero_point

        # Widen the range to contain 0 so the zero point lands inside
        # [q_min, q_max] and real 0.0 is exactly representable.
        min_val = torch.clamp(self.min_val, max=0)
        max_val = torch.clamp(self.max_val, min=0)

        scale = (max_val - min_val) / (self.q_max - self.q_min)
        if scale == 0.0:
            # Buffers only accept tensors, so use a tensor-valued 1.0.
            scale = torch.ones_like(scale)

        zero_point = torch.clamp(
            torch.round(self.q_min - min_val / scale),
            self.q_min, self.q_max
        )

        self.scale = scale
        self.zero_point = zero_point

        return self.scale, self.zero_point

In [11]:
class MovingAverageMinMaxObserver(MinMaxObserver):
    """Min / max observer that smooths its statistics with an EMA.

    The first observed batch initialises ``min_val`` / ``max_val``; every
    later batch moves them by ``averaging_constant`` towards the batch's
    own minimum / maximum.

    Args:
        n_bits: Bit width forwarded to ``MinMaxObserver``.
        averaging_constant: Weight of the newest batch in the average.
    """

    def __init__(self, n_bits=8, averaging_constant=0.01):
        super().__init__(n_bits)

        self.averaging_constant = averaging_constant

    def forward(self, x):
        if self.training:
            # Detach first: averaging graph-attached tensors would chain the
            # autograd graphs of all previous batches together and keep
            # them alive for as long as the observer exists.
            observed = x.detach()
            current_min = observed.min().to(self.min_val)
            current_max = observed.max().to(self.max_val)

            new_weight = self.averaging_constant
            old_weight = 1. - self.averaging_constant

            # The +/-inf initial values mean "no batch seen yet".
            if torch.isinf(self.min_val):
                self.min_val = current_min
            else:
                self.min_val = new_weight * current_min + old_weight * self.min_val

            if torch.isinf(self.max_val):
                self.max_val = current_max
            else:
                self.max_val = new_weight * current_max + old_weight * self.max_val

        return x

In [12]:
class QATLinearLayer(nn.Module):
    """Linear layer with optional fake quantization for QAT.

    Inputs always pass through ``activation_observer``. While ``qat_mode``
    is off the layer is a plain ``nn.Linear``; once ``enable_qat()`` has
    been called, inputs and weights are fake-quantized (quantize, then
    dequantize) before the linear operation, with gradients passed straight
    through the rounding step.

    Note: the observer is a submodule, so calling ``train()`` / ``eval()``
    on a parent module also switches the observer on / off.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        self.linear = nn.Linear(in_features, out_features, bias=True)

        self.activation_observer = MovingAverageMinMaxObserver()
        self.register_buffer("activation_scale", None)
        self.register_buffer("activation_zero_point", None)

        self.qat_mode = False

    def enable_observer(self):
        """Let the observer update its statistics."""
        self.activation_observer.train()

    def disable_observer(self):
        """Freeze the observer's statistics."""
        self.activation_observer.eval()

    def disble_observer(self):
        """Misspelled name kept for existing callers; same as ``disable_observer``."""
        self.disable_observer()

    def calculate_params(self):
        """Copy the observer's scale / zero point into this layer's buffers."""
        scale, zero_point = self.activation_observer.calculate_params()
        self.activation_scale = scale
        self.activation_zero_point = zero_point

    def enable_qat(self):
        """Freeze the observer, fix the activation parameters, turn on fake quantization."""
        self.qat_mode = True
        self.disable_observer()
        self.calculate_params()

    @staticmethod
    def _fake_quantize(values, scale, zero_point, q_min, q_max):
        """Quantize-dequantize ``values``; the gradient flows as if this were the identity."""
        quantized = torch.clamp(torch.round(values / scale + zero_point), q_min, q_max)
        dequantized = (quantized - zero_point) * scale
        # Straight-through estimator: forward value is ``dequantized``,
        # backward sees only the ``values`` term.
        return values + (dequantized - values).detach()

    def _fake_quantize_weight(self, q_min, q_max):
        """Fake-quantize the weight with one scale / zero point per output row."""
        weight = self.linear.weight
        with torch.no_grad():
            # Per-row range, widened to contain 0 so the zero point fits
            # inside [q_min, q_max].
            row_min = torch.clamp(weight.min(dim=1, keepdim=True).values, max=0)
            row_max = torch.clamp(weight.max(dim=1, keepdim=True).values, min=0)
            scale = (row_max - row_min) / (q_max - q_min)
            scale = torch.where(scale == 0, torch.ones_like(scale), scale)
            zero_point = torch.clamp(q_min - torch.round(row_min / scale), q_min, q_max)
        return self._fake_quantize(weight, scale, zero_point, q_min, q_max)

    def forward(self, x):
        x = self.activation_observer(x)

        if not self.qat_mode:
            return self.linear(x)

        # Clamp to the integer grid limits, not to the observed real-valued
        # min / max of the activations.
        q_min = self.activation_observer.q_min
        q_max = self.activation_observer.q_max

        x_fq = self._fake_quantize(
            x, self.activation_scale, self.activation_zero_point,
            q_min, q_max
        )

        # Use the live parameter (not ``.data``) so the weights keep
        # receiving gradients during quantization-aware training.
        weight_fq = self._fake_quantize_weight(q_min, q_max)

        return nn.functional.linear(x_fq, weight_fq, self.linear.bias)

In [13]:
class QATModel(nn.Module):
    """Three-layer MLP assembled from ``QATLinearLayer`` blocks.

    Layout: fc1 -> ReLU -> fc2 -> ReLU -> fc3. The helper methods fan a
    request out to every ``QATLinearLayer`` contained in the model.
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        self.fc1 = QATLinearLayer(in_features, hidden_features)
        self.relu1 = nn.ReLU()
        self.fc2 = QATLinearLayer(hidden_features, hidden_features)
        self.relu2 = nn.ReLU()
        self.fc3 = QATLinearLayer(hidden_features, out_features)

    def forward(self, x):
        """Run ``x`` through the network (an ``nn.Module`` without ``forward`` cannot be called)."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)

        return x

    def _qat_layers(self):
        """Return every QATLinearLayer submodule, in ``modules()`` order."""
        return [module for module in self.modules() if isinstance(module, QATLinearLayer)]

    def enable_observer(self):
        """Let all activation observers collect statistics."""
        for layer in self._qat_layers():
            layer.enable_observer()

    def disable_observer(self):
        """Freeze all activation observers."""
        for layer in self._qat_layers():
            # ``disble_observer`` is the method name QATLinearLayer exposes.
            layer.disble_observer()

    def calculate_params(self):
        """Recompute the activation scale / zero point of every layer."""
        for layer in self._qat_layers():
            layer.calculate_params()

    def enable_qat(self):
        """Switch every layer to fake-quantized (QAT) mode."""
        for layer in self._qat_layers():
            layer.enable_qat()