In [1]:
import torch
import torch.nn as nn

In [6]:
def quantize_tensor_symmetric(x: torch.Tensor, n_bits: int):
    """Quantize a tensor with one symmetric scale (no zero point).

    The largest absolute value of ``x`` is mapped to ``q_max = 2**(n_bits-1) - 1``
    and the integer range is restricted to ``[-q_max, q_max]``.

    Args:
        x: Tensor to quantize.
        n_bits: Bit-width of the signed integer grid (must be >= 2).

    Returns:
        A ``(quantized, scale)`` tuple. ``quantized`` is an integer tensor with
        the same shape as ``x``; ``scale`` is a 0-dim tensor such that
        ``quantized.float() * scale`` approximates ``x``.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2 (``q_max`` would be 0).
    """
    if n_bits < 2:
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")

    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Store in the narrowest signed type that can hold q_max. Casting
    # everything to int8 would silently wrap around for n_bits > 8.
    if n_bits <= 8:
        storage_dtype = torch.int8
    elif n_bits <= 16:
        storage_dtype = torch.int16
    elif n_bits <= 32:
        storage_dtype = torch.int32
    else:
        storage_dtype = torch.int64

    max_val = x.abs().max()

    scale = max_val / q_max
    # An all-zero input gives scale == 0; fall back to 1 to avoid dividing by
    # zero. torch.where keeps `scale` a tensor on every path.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    quantized = torch.clamp(torch.round(x / scale), q_min, q_max).to(storage_dtype)

    return quantized, scale

In [7]:
def dequantize_tensor_assymetric(quantized: torch.Tensor, scale: float):
    """Map symmetrically-quantized integers back to floating point.

    Despite its (misspelled) name this is the *symmetric* dequantizer: it
    applies only a scale and no zero point. The original name is kept so
    existing callers continue to work; prefer ``dequantize_tensor_symmetric``.

    Args:
        quantized: Integer tensor of quantized values.
        scale: Scale factor to multiply by (a Python float or a tensor).

    Returns:
        ``quantized`` converted to float and multiplied by ``scale``.
    """
    return quantized.float() * scale


# Correctly named alias for the function above.
dequantize_tensor_symmetric = dequantize_tensor_assymetric

In [None]:
def quantize_tensor_asymmetric(x: torch.Tensor, n_bits: int):
    """Quantize a tensor with a scale and a zero point (affine quantization).

    The observed range of ``x`` (extended to include 0) is mapped linearly
    onto the integer range ``[-q_max, q_max]`` with
    ``q_max = 2**(n_bits-1) - 1``.

    Args:
        x: Tensor to quantize.
        n_bits: Bit-width of the signed integer grid (must be >= 2).

    Returns:
        A ``(quantized, scale, zero_point)`` tuple. ``quantized`` has the shape
        of ``x``; ``scale`` and ``zero_point`` are 0-dim tensors such that
        ``(quantized - zero_point) * scale`` approximates ``x``.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2 (``q_max`` would be 0).
    """
    if n_bits < 2:
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")

    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Store in the narrowest signed type that can hold q_max. Casting
    # everything to int8 would silently wrap around for n_bits > 8.
    if n_bits <= 8:
        storage_dtype = torch.int8
    elif n_bits <= 16:
        storage_dtype = torch.int16
    elif n_bits <= 32:
        storage_dtype = torch.int32
    else:
        storage_dtype = torch.int64

    # Force the range to contain 0. Otherwise the zero point falls outside
    # [q_min, q_max], gets clamped below, and every value of a tensor whose
    # range excludes zero (e.g. [10, 20]) saturates to the same code.
    min_val = x.min().clamp(max=0)
    max_val = x.max().clamp(min=0)

    scale = (max_val - min_val) / (q_max - q_min)
    # An all-zero input gives scale == 0; fall back to 1 to avoid dividing by
    # zero. torch.where keeps `scale` a tensor on every path.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    zero_point = q_min - torch.round(min_val / scale)
    zero_point = torch.clamp(zero_point, q_min, q_max).to(storage_dtype)

    quantized = torch.clamp(
        torch.round(x / scale) + zero_point,
        q_min, q_max
    ).to(storage_dtype)

    return quantized, scale, zero_point

In [9]:
def dequantize_tensor_asymmetric(quantized: torch.Tensor, scale: float, zero_point: int):
    """Map affine-quantized integers back to floating point.

    Args:
        quantized: Integer tensor of quantized values.
        scale: Scale factor (a Python float or a tensor).
        zero_point: Integer code that represents 0.0 (a Python int or a tensor).

    Returns:
        ``(quantized - zero_point) * scale`` computed in floating point.
    """
    # Widen to float *before* subtracting: int8 - int8 is evaluated in int8,
    # so e.g. 127 - (-127) would wrap around to -2 instead of 254.
    return (quantized.float() - zero_point) * scale

In [10]:
def per_channel_quantization_symmetric(x: torch.Tensor, channel_dim: int, n_bits: int):
    """Symmetrically quantize ``x`` with one scale per slice along ``channel_dim``.

    Each slice ``x.select(channel_dim, i)`` gets its own scale
    ``max(|slice|) / q_max`` with ``q_max = 2**(n_bits-1) - 1``; values are
    rounded and clamped to ``[-q_max, q_max]``.

    Args:
        x: Tensor to quantize.
        channel_dim: Dimension of ``x`` that indexes the channels.
        n_bits: Bit-width of the signed integer grid (must be >= 2).

    Returns:
        A ``(quantized, scales)`` tuple. ``quantized`` has the shape of ``x``;
        ``scales`` is a 1-D tensor with one entry per channel, on the same
        device as ``x``.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2 (``q_max`` would be 0).
    """
    if n_bits < 2:
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")

    q_max = 2 ** (n_bits - 1) - 1
    q_min = -q_max

    # Store in the narrowest signed type that can hold q_max. Casting
    # everything to int8 would silently wrap around for n_bits > 8.
    if n_bits <= 8:
        storage_dtype = torch.int8
    elif n_bits <= 16:
        storage_dtype = torch.int16
    elif n_bits <= 32:
        storage_dtype = torch.int32
    else:
        storage_dtype = torch.int64

    # Work on a detached view with the channel axis first.
    channels_first = x.detach()
    if channel_dim != 0:
        channels_first = channels_first.transpose(0, channel_dim)
    n_channels = channels_first.shape[0]

    # One max-abs per channel, computed in a single pass instead of a Python
    # loop over channels.
    max_vals = channels_first.reshape(n_channels, -1).abs().amax(dim=1)

    scales = max_vals / q_max
    # All-zero channels give scale == 0; fall back to 1 to avoid dividing by zero.
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)

    # Reshape scales to (n_channels, 1, 1, ...) so they broadcast over each channel.
    broadcast_shape = (n_channels,) + (1,) * (channels_first.dim() - 1)
    quantized = torch.clamp(
        torch.round(channels_first / scales.view(broadcast_shape)),
        q_min, q_max
    ).to(storage_dtype)

    # Restore the caller's axis order.
    quantized = quantized.transpose(0, channel_dim) if channel_dim != 0 else quantized

    return quantized, scales

def per_channel_dequantization_symmetric(quantized: torch.Tensor, scales: torch.Tensor, channel_dim: int):
    """Invert per-channel symmetric quantization.

    Args:
        quantized: Integer tensor of quantized values.
        scales: One scale per slice of ``quantized`` along ``channel_dim``.
        channel_dim: Dimension of ``quantized`` that indexes the channels.

    Returns:
        A float32 tensor with the shape of ``quantized`` in which every slice
        along ``channel_dim`` has been multiplied by its scale.

    Raises:
        ValueError: If the number of scales differs from the number of channels.
    """
    channels_first = quantized.transpose(0, channel_dim) if channel_dim != 0 else quantized
    values = channels_first.float()

    per_channel = torch.as_tensor(scales, dtype=values.dtype, device=values.device).reshape(-1)
    if per_channel.numel() != values.shape[0]:
        raise ValueError(
            f"expected {values.shape[0]} scales for dimension {channel_dim}, "
            f"got {per_channel.numel()}"
        )

    # Reshape scales to (n_channels, 1, 1, ...) so they broadcast over each channel.
    broadcast_shape = (-1,) + (1,) * (values.dim() - 1)
    dequantized = values * per_channel.reshape(broadcast_shape)

    # Transpose the *tensor* back. The previous version called .transpose on a
    # Python list before stacking, which failed for every channel_dim != 0.
    return dequantized.transpose(0, channel_dim) if channel_dim != 0 else dequantized

In [13]:
class QuantizedLinear(nn.Module):
    """Linear layer whose weights can be replaced by per-channel quantized weights.

    Until ``quantize_weights`` is called the layer behaves like the wrapped
    ``nn.Linear``. Afterwards, ``forward`` quantizes the input, dequantizes
    both input and weights, and runs a floating-point linear on the results.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        channel_dim: Dimension passed to the per-channel (de)quantization
            helpers for both the weight and the activations. Defaults to 0.
        n_bits: Bit-width passed to the per-channel quantizer.
    """

    def __init__(self, in_features, out_features, channel_dim=0, n_bits=8):
        super().__init__()

        # Despite the name this is the full nn.Linear module (weight + bias).
        self.fp32_weight = nn.Linear(in_features, out_features)
        self.n_bits = n_bits
        self.channel_dim = channel_dim

        # The weight has shape (out_features, in_features), so the number of
        # per-channel scales depends on which dimension is the channel axis.
        n_channels = (out_features, in_features)[channel_dim]

        self.register_buffer('weight_quantized', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('weight_scale', torch.ones(n_channels))

        # Not read by forward(); only written through set_activation_scale().
        self.register_buffer('activation_scale', torch.ones(out_features))

        self.quantized = False

    def quantize_weights(self):
        """Quantize the fp32 weight once and store the codes and scales in buffers."""
        if self.quantized:
            return

        weight = self.fp32_weight.weight.detach()
        quantized, scales = per_channel_quantization_symmetric(weight, self.channel_dim, self.n_bits)

        self.weight_quantized = quantized
        # Keep the scales next to the weights regardless of where the helper built them.
        self.weight_scale = scales.to(weight.device)

        self.quantized = True

    def set_activation_scale(self, scale):
        """Store ``scale`` in the ``activation_scale`` buffer.

        Registered buffers only accept tensors, so plain numbers are converted.
        """
        self.activation_scale = torch.as_tensor(
            scale, dtype=torch.float32, device=self.activation_scale.device
        )

    def forward(self, x):
        """Apply the layer; uses quantize-dequantize once weights are quantized."""
        if not self.quantized:
            # Use the wrapped module directly (weight and bias).
            return self.fp32_weight(x)

        x_q, x_scale = per_channel_quantization_symmetric(x, self.channel_dim, self.n_bits)

        w_dq = per_channel_dequantization_symmetric(self.weight_quantized, self.weight_scale, self.channel_dim)
        x_dq = per_channel_dequantization_symmetric(x_q, x_scale, self.channel_dim)

        # The bias stays in fp32 and is applied in both code paths.
        return nn.functional.linear(x_dq, w_dq, self.fp32_weight.bias)


In [14]:
class QuantizedModel(nn.Module):
    """Three-layer MLP built from ``QuantizedLinear`` layers.

    Args:
        input_size: Number of features after flattening each sample.
        hidden_size: Width of the two hidden layers.
        num_classes: Number of output logits.
        channel_dim: Channel dimension forwarded to every ``QuantizedLinear``.
    """

    def __init__(self, input_size=784, hidden_size=256, num_classes=10, channel_dim=0):
        super().__init__()
        # channel_dim is passed explicitly: QuantizedLinear takes it as its
        # third argument, and omitting it raised a TypeError.
        self.fc1 = QuantizedLinear(input_size, hidden_size, channel_dim=channel_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = QuantizedLinear(hidden_size, hidden_size, channel_dim=channel_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = QuantizedLinear(hidden_size, num_classes, channel_dim=channel_dim)

    def forward(self, x):
        """Flatten each sample and run it through the three layers."""
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

    def quantize_all_weights(self):
        """Quantize all linear layers"""
        for layer in (self.fc1, self.fc2, self.fc3):
            layer.quantize_weights()

In [15]:
from torch.autograd import Function

In [16]:
class FakeQuantization(Function):
    """Quantize-dequantize with a straight-through gradient.

    The forward pass snaps ``x`` onto the integer grid defined by ``scale``,
    ``zero_point`` and ``[q_min, q_max]`` and maps it straight back to
    floating point. The backward pass returns the incoming gradient unchanged
    for ``x`` and ``None`` for every other argument.
    """

    @staticmethod
    def forward(ctx, x, scale, zero_point, q_max, q_min):
        # The grid values stay in floating point. A cast to int8 is not needed
        # for fake quantization and would wrap around for ranges wider than
        # 8 bits (or when subtracting an integer zero point).
        quantized = torch.clamp(
            torch.round(x / scale) + zero_point,
            q_min, q_max
        )

        dequantized = (quantized - zero_point) * scale
        return dequantized

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward() input: only x receives a gradient.
        return grad_output, None, None, None, None

fake_quantization = FakeQuantization.apply

In [None]:
class MinMaxObserver(nn.Module):
    """Track the running min/max of observed tensors and derive affine qparams.

    Args:
        n_bits: Bit-width of the signed integer grid. The grid is
            ``[-q_max, q_max]`` with ``q_max = 2**(n_bits-1) - 1``.
    """

    def __init__(self, n_bits=8):
        super().__init__()

        self.n_bits = n_bits

        self.quantized_max = 2 ** (n_bits - 1) - 1
        self.quantized_min = - self.quantized_max

        # Start at -inf / +inf so the first observation always replaces them.
        self.register_buffer('max_val', torch.tensor(float('-inf')))
        self.register_buffer('min_val', torch.tensor(float('inf')))

        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0.0))

    def forward(self, x):
        """Record the min/max of ``x`` while in training mode; return ``x`` unchanged."""
        if self.training:
            # Detach so the statistics never hold on to an autograd graph.
            observed = x.detach()
            batch_max = observed.max().to(self.max_val)
            batch_min = observed.min().to(self.min_val)

            self.max_val = torch.max(batch_max, self.max_val)
            self.min_val = torch.min(batch_min, self.min_val)

        return x

    def calculate_params(self):
        """Compute and store ``scale`` and ``zero_point`` from the recorded range.

        Returns:
            The ``(scale, zero_point)`` buffers. If nothing has been observed
            yet, the current values are returned unchanged.
        """
        # Nothing observed: the +/-inf sentinels would turn into NaN parameters.
        if torch.isinf(self.min_val) or torch.isinf(self.max_val):
            return self.scale, self.zero_point

        # Force the range to contain 0 so the zero point lies inside the
        # integer grid and the clamp below never distorts it.
        min_val = torch.clamp(self.min_val, max=0.0)
        max_val = torch.clamp(self.max_val, min=0.0)

        scale = (max_val - min_val) / (self.quantized_max - self.quantized_min)
        # Degenerate (all-zero) range: fall back to 1. torch.where keeps the
        # value a tensor, which a registered buffer requires.
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        zero_point = self.quantized_min - torch.round(min_val / scale)
        zero_point = torch.clamp(zero_point, self.quantized_min, self.quantized_max)

        self.scale = scale
        self.zero_point = zero_point

        return self.scale, self.zero_point

In [None]:
class MovingAverageMinMaxObserver(MinMaxObserver):
    """Observer that smooths min/max with an exponential moving average.

    Each update computes
    ``averaging_constant * running + (1 - averaging_constant) * batch``.

    Args:
        averaging_constant: Weight given to the running statistic.
        n_bits: Bit-width forwarded to ``MinMaxObserver``.
    """

    def __init__(self, averaging_constant, n_bits=8):
        super().__init__(n_bits)
        self.averaging_constant = averaging_constant

    def forward(self, x):
        """Update the running min/max from ``x`` in training mode; return ``x`` unchanged."""
        if self.training:
            # Detach so the statistics never hold on to an autograd graph.
            observed = x.detach()
            # Tensors expose .min()/.max(); there is no .min_val()/.max_val().
            min_val = observed.min().to(self.min_val)
            max_val = observed.max().to(self.max_val)

            if torch.isinf(self.min_val) or torch.isinf(self.max_val):
                # First observation: seed the running statistics directly.
                self.min_val = min_val
                self.max_val = max_val
            else:
                keep = self.averaging_constant
                self.min_val = keep * self.min_val + (1. - keep) * min_val
                self.max_val = keep * self.max_val + (1. - keep) * max_val

        return x

In [20]:
class QATLinear(nn.Module):
    """Linear layer with observers and optional fake quantization for QAT.

    While ``qat_mode`` is False the layer is a plain ``nn.Linear`` whose input
    passes through the activation observer. After ``enable_qat`` both the
    input and the weight are fake-quantized before the linear operation.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.linear = nn.Linear(in_features, out_features)

        self.weight_observer = MovingAverageMinMaxObserver(0.9)
        self.activation_observer = MovingAverageMinMaxObserver(0.9)

        self.register_buffer('weight_scale', torch.tensor(1.0))
        self.register_buffer('weight_zero_point', torch.tensor(0.0))
        self.register_buffer('activation_scale', torch.tensor(1.0))
        self.register_buffer('activation_zero_point', torch.tensor(0.0))

        self.qat_mode = False

    def enable_observers(self):
        """Put every observer submodule into training mode so it records statistics."""
        # modules() must be *called*; iterating the bound method raises TypeError.
        for module in self.modules():
            # MovingAverageMinMaxObserver subclasses MinMaxObserver, so one check covers both.
            if isinstance(module, MinMaxObserver):
                module.train()

    def enable_observer(self):
        """Singular-name alias of ``enable_observers``."""
        self.enable_observers()

    def disable_observer(self):
        """Put every observer submodule into eval mode so it stops recording."""
        for module in self.modules():
            if isinstance(module, MinMaxObserver):
                module.eval()

    def disable_observers(self):
        """Plural-name alias of ``disable_observer``."""
        self.disable_observer()

    def calculate_params(self):
        """Refresh the weight/activation scale and zero-point buffers from the observers."""
        with torch.no_grad():
            # The weight observer is only fed here, and must receive the
            # weight tensor rather than the nn.Linear module. It records the
            # weight only while it is in training mode.
            self.weight_observer(self.linear.weight)
            self.weight_scale, self.weight_zero_point = self.weight_observer.calculate_params()

            self.activation_scale, self.activation_zero_point = self.activation_observer.calculate_params()

    def calculate_qparams(self):
        """Alias of ``calculate_params``."""
        self.calculate_params()

    def enable_qat(self):
        """Switch forward() to fake quantization and compute the parameters it uses."""
        self.qat_mode = True
        self.calculate_params()

    def forward(self, x):
        """Observe ``x`` and apply the (optionally fake-quantized) linear layer."""
        x = self.activation_observer(x)

        if not self.qat_mode:
            # Plain fp32 path: weight and bias of the wrapped module.
            return self.linear(x)

        x_fq = fake_quantization(
            x,
            self.activation_scale,
            self.activation_zero_point,
            self.activation_observer.quantized_max,
            self.activation_observer.quantized_min
        )

        weight_fq = fake_quantization(
            self.linear.weight,
            self.weight_scale,
            self.weight_zero_point,
            self.weight_observer.quantized_max,
            self.weight_observer.quantized_min
        )

        # The bias is not fake-quantized but is still applied.
        return nn.functional.linear(x_fq, weight_fq, self.linear.bias)

In [21]:
class QATModel(nn.Module):
    """
    Complete model with QAT support

    A three-layer MLP built from ``QATLinear`` layers. The helper methods
    forward each request to every ``QATLinear`` found in ``self.modules()``.
    """
    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super().__init__()
        self.fc1 = QATLinear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = QATLinear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = QATLinear(hidden_size, num_classes)

    def forward(self, x):
        """Flatten each sample and run it through the three layers."""
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

    def enable_observer(self):
        """Enable all observers for calibration"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                # QATLinear names this method enable_observers (plural).
                module.enable_observers()

    def disable_observer(self):
        """Disable all observers after calibration"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.disable_observer()

    def calculate_qparams(self):
        """Calculate quantization parameters for all layers"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                # QATLinear names this method calculate_params.
                module.calculate_params()

    def enable_qat(self):
        """Enable QAT mode (fake quantization)"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.enable_qat()

## GPTQ (TODO: not yet implemented)