In [1]:
# BasicRDOPTQ
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Optional, Dict
import copy
import numpy as np

class QuantizedConv2d(nn.Module):
    def __init__(self, conv_layer: nn.Conv2d, bit_width: int = 8):
        super().__init__()
        self.in_channels = conv_layer.in_channels
        self.out_channels = conv_layer.out_channels
        self.kernel_size = conv_layer.kernel_size
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding
        self.bit_width = bit_width
        self.n_levels = 2 ** bit_width - 1


        self.register_buffer('weight_float', conv_layer.weight.data.clone())
        if conv_layer.bias is not None:
            self.register_buffer('bias_float', conv_layer.bias.data.clone())
        else:
            self.bias_float = None


        self.weight_scale = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.weight_zero_point = nn.Parameter(torch.zeros(self.out_channels, 1, 1, 1))

        if self.bias_float is not None:
            self.bias_scale = nn.Parameter(torch.ones(self.out_channels))
            self.bias_zero_point = nn.Parameter(torch.zeros(self.out_channels))


        self._initialize_scales()

    def _initialize_scales(self):
        with torch.no_grad():
            for c in range(self.out_channels):
                w_channel = self.weight_float[c]
                w_min = w_channel.min()
                w_max = w_channel.max()
                scale = (w_max - w_min) / self.n_levels
                zero_point = -w_min / (scale + 1e-8)
                self.weight_scale.data[c] = scale
                self.weight_zero_point.data[c] = zero_point.clamp(0, self.n_levels)

            if self.bias_float is not None:
                for c in range(self.out_channels):
                    b_val = self.bias_float[c]
                    b_abs = abs(b_val) + 1e-8
                    self.bias_scale.data[c] = b_abs / (self.n_levels / 2)
                    self.bias_zero_point.data[c] = self.n_levels / 2

    def quantize_weight(self):
        w_normalized = self.weight_float / (self.weight_scale + 1e-8) + self.weight_zero_point
        w_quant = torch.clamp(torch.round(w_normalized), 0, self.n_levels)
        w_dequant = (w_quant - self.weight_zero_point) * self.weight_scale
        return w_dequant

    def quantize_bias(self):
        if self.bias_float is None:
            return None
        b_normalized = self.bias_float / (self.bias_scale + 1e-8) + self.bias_zero_point
        b_quant = torch.clamp(torch.round(b_normalized), 0, self.n_levels)
        b_dequant = (b_quant - self.bias_zero_point) * self.bias_scale
        return b_dequant

    def forward(self, x):
        w_quant = self.quantize_weight()
        b_quant = self.quantize_bias()

        return nn.functional.conv2d(
            x, w_quant, b_quant,
            stride=self.stride,
            padding=self.padding
        )

    def get_quantized_params(self):
        w_normalized = self.weight_float / (self.weight_scale + 1e-8) + self.weight_zero_point
        w_quant_int = torch.clamp(torch.round(w_normalized), 0, self.n_levels).to(torch.int8)

        if self.bias_float is not None:
            b_normalized = self.bias_float / (self.bias_scale + 1e-8) + self.bias_zero_point
            b_quant_int = torch.clamp(torch.round(b_normalized), 0, self.n_levels).to(torch.int8)
        else:
            b_quant_int = None

        return {
            'weight_int8': w_quant_int,
            'weight_scale': self.weight_scale,
            'weight_zero_point': self.weight_zero_point,
            'bias_int8': b_quant_int,
            'bias_scale': self.bias_scale if self.bias_float is not None else None,
            'bias_zero_point': self.bias_zero_point if self.bias_float is not None else None
        }


class RDO_PTQ:
    def __init__(self,
                 model: nn.Module,
                 bit_width: int = 8,
                 lambda_rd: float = 0.01,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.bit_width = bit_width
        self.lambda_rd = lambda_rd
        self.device = device
        self.quantized_layers = {}

    def _get_quantizable_layers(self) -> List[Tuple[str, nn.Module]]:
        quantizable = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                quantizable.append((name, module))
        return quantizable

    def _compute_rate_distortion_loss(self,
                                     original_output: torch.Tensor,
                                     quantized_output: torch.Tensor,
                                     quantized_layer: QuantizedConv2d) -> torch.Tensor:


        distortion = torch.mean((original_output - quantized_output) ** 2)


        w_quant = quantized_layer.quantize_weight()
        num_params = w_quant.numel()
        if quantized_layer.bias_float is not None:
            num_params += quantized_layer.bias_float.numel()


        rate_penalty = torch.mean(torch.abs(quantized_layer.weight_scale))


        rd_loss = distortion + self.lambda_rd * rate_penalty

        return rd_loss

    def _get_layer_activations(self, layer_name: str,
                               inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        activations = {}

        def hook_fn(module, input, output):
            activations['input'] = input[0].detach().clone()
            activations['output'] = output.detach().clone()


        target_layer = dict(self.model.named_modules())[layer_name]
        handle = target_layer.register_forward_hook(hook_fn)


        self.model.eval()
        with torch.no_grad():
            _ = self.model(inputs)


        handle.remove()

        return activations['input'], activations['output']

    def quantize_layer(self,
                      layer_name: str,
                      layer: nn.Conv2d,
                      calibration_inputs: torch.Tensor,
                      original_outputs: torch.Tensor,
                      num_iterations: int = 200,
                      lr: float = 1e-4) -> QuantizedConv2d:

        print(f"Quantizing layer: {layer_name}")


        q_layer = QuantizedConv2d(layer, self.bit_width).to(self.device)


        params = [q_layer.weight_scale, q_layer.weight_zero_point]
        if q_layer.bias_float is not None:
            params.extend([q_layer.bias_scale, q_layer.bias_zero_point])

        optimizer = optim.Adam(params, lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_iterations)

        best_loss = float('inf')
        best_state = None


        q_layer.train()
        for iteration in range(num_iterations):
            optimizer.zero_grad()


            quantized_output = q_layer(calibration_inputs)


            loss = self._compute_rate_distortion_loss(
                original_outputs, quantized_output, q_layer
            )


            loss.backward()


            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

            optimizer.step()
            scheduler.step()


            if loss.item() < best_loss:
                best_loss = loss.item()
                best_state = {
                    'weight_scale': q_layer.weight_scale.data.clone(),
                    'weight_zero_point': q_layer.weight_zero_point.data.clone(),
                }
                if q_layer.bias_float is not None:
                    best_state['bias_scale'] = q_layer.bias_scale.data.clone()
                    best_state['bias_zero_point'] = q_layer.bias_zero_point.data.clone()

            if iteration % 50 == 0 or iteration == num_iterations - 1:
                print(f"  Iter {iteration:3d}/{num_iterations}, Loss: {loss.item():.6f}, LR: {scheduler.get_last_lr()[0]:.2e}")


        if best_state is not None:
            q_layer.weight_scale.data = best_state['weight_scale']
            q_layer.weight_zero_point.data = best_state['weight_zero_point']
            if q_layer.bias_float is not None:
                q_layer.bias_scale.data = best_state['bias_scale']
                q_layer.bias_zero_point.data = best_state['bias_zero_point']

        q_layer.eval()
        return q_layer

    def quantize_model(self,
                      calibration_images: torch.Tensor,
                      num_iterations: int = 200,
                      lr: float = 1e-4) -> nn.Module:

        calibration_images = calibration_images.to(self.device)
        quantizable_layers = self._get_quantizable_layers()

        print(f"\nFound {len(quantizable_layers)} quantizable layers")
        print(f"Quantizing to {self.bit_width}-bit with lambda={self.lambda_rd}\n")


        quantized_model = copy.deepcopy(self.model)


        for idx, (layer_name, layer) in enumerate(quantizable_layers):
            print(f"[{idx+1}/{len(quantizable_layers)}] Processing: {layer_name}")


            layer_inputs, layer_outputs = self._get_layer_activations(
                layer_name, calibration_images
            )


            q_layer = self.quantize_layer(
                layer_name, layer, layer_inputs, layer_outputs,
                num_iterations, lr
            )


            self._replace_layer(quantized_model, layer_name, q_layer)
            self.quantized_layers[layer_name] = q_layer

            print()

        print("Quantization complete!\n")
        return quantized_model

    def _replace_layer(self, model: nn.Module, layer_name: str, new_layer: nn.Module):

        parts = layer_name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_layer)

    def export_quantized_params(self, output_path: str = 'quantized_params.pth'):

        quantized_params = {}
        for name, q_layer in self.quantized_layers.items():
            quantized_params[name] = q_layer.get_quantized_params()

        torch.save(quantized_params, output_path)
        print(f"Exported quantized parameters to {output_path}")
        return quantized_params


def demo_rdo_ptq():

    print("="*70)
    print("RDO-PTQ Demo: Rate-Distortion Optimized Quantization")
    print("="*70)


    class SimpleImageEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 3, stride=2, padding=1),
                nn.ReLU()
            )

        def forward(self, x):
            return self.encoder(x)


    print("\n[Step 1] Creating pre-trained model...")
    model = SimpleImageEncoder()
    for param in model.parameters():
        nn.init.normal_(param, mean=0, std=0.02)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"FP32 model size: {total_params * 4 / (1024**2):.2f} MB")


    print("\n[Step 2] Preparing calibration data...")
    num_calibration_images = 20
    calibration_data = torch.randn(num_calibration_images, 3, 128, 128)
    print(f"Calibration set: {num_calibration_images} images")


    print("\n[Step 3] Initializing RDO-PTQ quantizer...")
    bit_width = 8
    lambda_rd = 0.005

    rdo_ptq = RDO_PTQ(
        model=model,
        bit_width=bit_width,
        lambda_rd=lambda_rd,
        device='cpu'
    )
    print(f"Target: INT{bit_width}, Lambda: {lambda_rd}")


    print("\n[Step 4] Quantizing model...")
    print("-" * 70)
    quantized_model = rdo_ptq.quantize_model(
        calibration_images=calibration_data,
        num_iterations=150,
        lr=5e-4
    )
    print("-" * 70)


    print("[Step 5] Evaluating quantized model...")

    int8_size = total_params * (bit_width / 8) / (1024**2)
    compression_ratio = (total_params * 4) / (total_params * bit_width / 8)

    print(f"\n Compression Results:")
    print(f"  FP32 size: {total_params * 4 / (1024**2):.2f} MB")
    print(f"  INT8 size: {int8_size:.2f} MB")
    print(f"  Compression: {compression_ratio:.1f}x")


    print(f"\n Quality Metrics (on test images):")
    test_images = torch.randn(10, 3, 128, 128)

    model.eval()
    quantized_model.eval()

    with torch.no_grad():
        original_out = model(test_images)
        quantized_out = quantized_model(test_images)

        mse = torch.mean((original_out - quantized_out) ** 2).item()
        psnr = 10 * torch.log10(torch.tensor(original_out.max()**2 / (mse + 1e-10))).item()
        max_diff = torch.max(torch.abs(original_out - quantized_out)).item()
        relative_error = (torch.mean(torch.abs(original_out - quantized_out)) /
                         (torch.mean(torch.abs(original_out)) + 1e-10)).item() * 100

    print(f"  MSE: {mse:.6f}")
    print(f"  PSNR: {psnr:.2f} dB")
    print(f"  Max difference: {max_diff:.4f}")
    print(f"  Relative error: {relative_error:.2f}%")


    print(f"\n Exporting quantized parameters...")
    rdo_ptq.export_quantized_params('quantized_model.pth')

    print("\n" + "="*70)
    print("Demo complete!")
    print("="*70)

    return model, quantized_model


if __name__ == "__main__":
    original_model, quantized_model = demo_rdo_ptq()

RDO-PTQ Demo: Rate-Distortion Optimized Quantization

[Step 1] Creating pre-trained model...
Total parameters: 370,816
FP32 model size: 1.41 MB

[Step 2] Preparing calibration data...
Calibration set: 20 images

[Step 3] Initializing RDO-PTQ quantizer...
Target: INT8, Lambda: 0.005

[Step 4] Quantizing model...
----------------------------------------------------------------------

Found 3 quantizable layers
Quantizing to 8-bit with lambda=0.005

[1/3] Processing: encoder.0
Quantizing layer: encoder.0
  Iter   0/150, Loss: 0.000002, LR: 5.00e-04
  Iter  50/150, Loss: 0.000024, LR: 3.70e-04
  Iter 100/150, Loss: 0.000023, LR: 1.20e-04
  Iter 149/150, Loss: 0.000023, LR: 0.00e+00

[2/3] Processing: encoder.2
Quantizing layer: encoder.2
  Iter   0/150, Loss: 0.000002, LR: 5.00e-04
  Iter  50/150, Loss: 0.000013, LR: 3.70e-04
  Iter 100/150, Loss: 0.000012, LR: 1.20e-04
  Iter 149/150, Loss: 0.000012, LR: 0.00e+00

[3/3] Processing: encoder.4
Quantizing layer: encoder.4
  Iter   0/150, Los

  psnr = 10 * torch.log10(torch.tensor(original_out.max()**2 / (mse + 1e-10))).item()


In [2]:
#ProductionRDOPTQ
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Optional, Dict, Union
import copy
import numpy as np
from collections import defaultdict
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class QuantizedConv2d(nn.Module):

    def __init__(self, conv_layer: nn.Conv2d, bit_width: int = 8, quantize_activations: bool = True):
        super().__init__()
        self.in_channels = conv_layer.in_channels
        self.out_channels = conv_layer.out_channels
        self.kernel_size = conv_layer.kernel_size
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding
        self.dilation = conv_layer.dilation
        self.groups = conv_layer.groups
        self.bit_width = bit_width
        self._quantize_activations = quantize_activations


        self.qmax = (2 ** (bit_width - 1)) - 1
        self.qmin = -(2 ** (bit_width - 1))


        self.register_buffer('weight_float', conv_layer.weight.data.clone())
        if conv_layer.bias is not None:
            self.register_buffer('bias_float', conv_layer.bias.data.clone())
        else:
            self.bias_float = None


        self.weight_scale = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

        if self.bias_float is not None:
            self.bias_scale = nn.Parameter(torch.ones(self.out_channels))


        if self._quantize_activations:
            self.act_scale = nn.Parameter(torch.ones(1))
            self.register_buffer('act_running_min', torch.zeros(1))
            self.register_buffer('act_running_max', torch.zeros(1))
            self.register_buffer('num_batches_tracked', torch.tensor(0))

        self._initialize_scales()

    def _initialize_scales(self):

        with torch.no_grad():

            for c in range(self.out_channels):
                w_channel = self.weight_float[c]
                w_max_abs = torch.max(torch.abs(w_channel))
                scale = w_max_abs / (self.qmax + 1e-8)
                self.weight_scale.data[c] = scale.clamp(min=1e-8)


            if self.bias_float is not None:
                for c in range(self.out_channels):
                    b_abs = torch.abs(self.bias_float[c])
                    self.bias_scale.data[c] = (b_abs / (self.qmax + 1e-8)).clamp(min=1e-8)

    def quantize_weights(self):

        w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
        w_quant = torch.clamp(torch.round(w_div), self.qmin, self.qmax)
        w_dequant = w_quant * self.weight_scale
        return w_dequant

    def quantize_bias(self):

        if self.bias_float is None:
            return None
        b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
        b_quant = torch.clamp(torch.round(b_div), self.qmin, self.qmax)
        b_dequant = b_quant * self.bias_scale
        return b_dequant

    def quantize_activations(self, x: torch.Tensor, training: bool = False):

        if not self._quantize_activations:
            return x

        if training:

            x_min = x.min()
            x_max = x.max()

            momentum = 0.1
            self.act_running_min = (1 - momentum) * self.act_running_min + momentum * x_min
            self.act_running_max = (1 - momentum) * self.act_running_max + momentum * x_max
            self.num_batches_tracked += 1


        x_max_abs = torch.max(torch.abs(self.act_running_min), torch.abs(self.act_running_max))
        scale = x_max_abs / (self.qmax + 1e-8)
        scale = scale.clamp(min=1e-8)


        x_div = x / scale
        x_quant = torch.clamp(torch.round(x_div), self.qmin, self.qmax)
        x_dequant = x_quant * scale

        return x_dequant

    def forward(self, x):

        if self._quantize_activations:
            x = self.quantize_activations(x, self.training)


        w_quant = self.quantize_weights()
        b_quant = self.quantize_bias()


        out = nn.functional.conv2d(
            x, w_quant, b_quant,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups
        )

        return out

    def compute_compression_rate(self) -> float:

        with torch.no_grad():
            w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
            w_quant = torch.clamp(torch.round(w_div), self.qmin, self.qmax)


            total_entropy = 0
            for c in range(self.out_channels):
                unique_vals = torch.unique(w_quant[c])
                n_unique = len(unique_vals)
                effective_bits = np.log2(n_unique + 1)
                total_entropy += effective_bits * w_quant[c].numel()

            total_weights = w_quant.numel()
            avg_bits = total_entropy / total_weights

        return avg_bits

    def get_quantized_params(self) -> Dict[str, torch.Tensor]:

        with torch.no_grad():
            w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
            w_int8 = torch.clamp(torch.round(w_div), self.qmin, self.qmax).to(torch.int8)

            result = {
                'weight_int8': w_int8,
                'weight_scale': self.weight_scale.clone(),
            }

            if self.bias_float is not None:
                b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
                b_int8 = torch.clamp(torch.round(b_div), self.qmin, self.qmax).to(torch.int8)
                result['bias_int8'] = b_int8
                result['bias_scale'] = self.bias_scale.clone()

            if self._quantize_activations:
                result['act_scale'] = self.act_scale.clone()
                result['act_min'] = self.act_running_min.clone()
                result['act_max'] = self.act_running_max.clone()

        return result


class QuantizedLinear(nn.Module):

    def __init__(self, linear_layer: nn.Linear, bit_width: int = 8, quantize_activations: bool = True):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.bit_width = bit_width
        self._quantize_activations = quantize_activations

        self.qmax = (2 ** (bit_width - 1)) - 1
        self.qmin = -(2 ** (bit_width - 1))

        self.register_buffer('weight_float', linear_layer.weight.data.clone())
        if linear_layer.bias is not None:
            self.register_buffer('bias_float', linear_layer.bias.data.clone())
        else:
            self.bias_float = None

        self.weight_scale = nn.Parameter(torch.ones(self.out_features, 1))

        if self.bias_float is not None:
            self.bias_scale = nn.Parameter(torch.ones(self.out_features))

        if self._quantize_activations:
            self.act_scale = nn.Parameter(torch.ones(1))
            self.register_buffer('act_running_min', torch.zeros(1))
            self.register_buffer('act_running_max', torch.zeros(1))

        self._initialize_scales()

    def _initialize_scales(self):
        with torch.no_grad():
            for c in range(self.out_features):
                w_row = self.weight_float[c]
                w_max_abs = torch.max(torch.abs(w_row))
                self.weight_scale.data[c] = (w_max_abs / (self.qmax + 1e-8)).clamp(min=1e-8)

            if self.bias_float is not None:
                for c in range(self.out_features):
                    b_abs = torch.abs(self.bias_float[c])
                    self.bias_scale.data[c] = (b_abs / (self.qmax + 1e-8)).clamp(min=1e-8)

    def quantize_weights(self):
        w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
        w_quant = torch.clamp(torch.round(w_div), self.qmin, self.qmax)
        return w_quant * self.weight_scale

    def quantize_bias(self):
        if self.bias_float is None:
            return None
        b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
        b_quant = torch.clamp(torch.round(b_div), self.qmin, self.qmax)
        return b_quant * self.bias_scale

    def forward(self, x):
        w_quant = self.quantize_weights()
        b_quant = self.quantize_bias()
        return nn.functional.linear(x, w_quant, b_quant)

    def get_quantized_params(self):
        with torch.no_grad():
            w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
            w_int8 = torch.clamp(torch.round(w_div), self.qmin, self.qmax).to(torch.int8)

            result = {'weight_int8': w_int8, 'weight_scale': self.weight_scale.clone()}

            if self.bias_float is not None:
                b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
                b_int8 = torch.clamp(torch.round(b_div), self.qmin, self.qmax).to(torch.int8)
                result['bias_int8'] = b_int8
                result['bias_scale'] = self.bias_scale.clone()

        return result


class ProductionRDO_PTQ:

    def __init__(self,
                 model: nn.Module,
                 bit_width: int = 8,
                 lambda_rd: float = 0.01,
                 quantize_activations: bool = True,
                 use_blocks: bool = True,
                 mixed_precision: bool = False,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):

        self.model = model.to(device)
        self.default_bit_width = bit_width
        self.lambda_rd = lambda_rd
        self.quantize_activations = quantize_activations
        self.use_blocks = use_blocks
        self.mixed_precision = mixed_precision
        self.device = device

        self.quantized_layers = {}
        self.layer_bit_widths = {}
        self.layer_sensitivities = {}

    def _get_quantizable_layers(self) -> List[Tuple[str, nn.Module]]:

        quantizable = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                quantizable.append((name, module))
        return quantizable

    def _get_layer_blocks(self, layers: List[Tuple[str, nn.Module]], block_size: int = 3) -> List[List[Tuple[str, nn.Module]]]:

        if not self.use_blocks:
            return [[layer] for layer in layers]

        blocks = []
        for i in range(0, len(layers), block_size):
            blocks.append(layers[i:i + block_size])
        return blocks

    def analyze_layer_sensitivity(self,
                                  calibration_data: torch.Tensor,
                                  num_samples: int = 100) -> Dict[str, float]:

        logger.info("Analyzing layer sensitivities...")

        quantizable_layers = self._get_quantizable_layers()
        sensitivities = {}

        self.model.eval()
        with torch.no_grad():

            baseline_output = self.model(calibration_data[:num_samples])

        for layer_name, layer in quantizable_layers:

            if isinstance(layer, nn.Conv2d):
                temp_q_layer = QuantizedConv2d(layer, bit_width=4, quantize_activations=False)
            else:
                temp_q_layer = QuantizedLinear(layer, bit_width=4, quantize_activations=False)


            original_layer = self._get_layer_by_name(self.model, layer_name)
            self._replace_layer(self.model, layer_name, temp_q_layer)


            with torch.no_grad():
                quantized_output = self.model(calibration_data[:num_samples])
                sensitivity = torch.mean((baseline_output - quantized_output) ** 2).item()

            sensitivities[layer_name] = sensitivity


            self._replace_layer(self.model, layer_name, original_layer)

            logger.info(f"  {layer_name}: sensitivity = {sensitivity:.6f}")

        self.layer_sensitivities = sensitivities
        return sensitivities

    def assign_bit_widths(self, total_bit_budget: Optional[float] = None):

        if not self.mixed_precision or not self.layer_sensitivities:

            for name, _ in self._get_quantizable_layers():
                self.layer_bit_widths[name] = self.default_bit_width
            return

        logger.info("Assigning mixed precision bit-widths...")


        sensitivities = np.array(list(self.layer_sensitivities.values()))
        sens_min, sens_max = sensitivities.min(), sensitivities.max()
        sens_norm = (sensitivities - sens_min) / (sens_max - sens_min + 1e-8)


        bit_options = [4, 6, 8]
        layer_names = list(self.layer_sensitivities.keys())

        for i, (name, sens) in enumerate(zip(layer_names, sens_norm)):
            if sens > 0.7:
                bit_width = 8
            elif sens > 0.3:
                bit_width = 6
            else:
                bit_width = 4

            self.layer_bit_widths[name] = bit_width
            logger.info(f"  {name}: {bit_width}-bit (sensitivity={sens:.3f})")

    def _compute_hessian_trace(self,
                               layer_outputs: torch.Tensor,
                               target_outputs: torch.Tensor) -> torch.Tensor:

        loss = torch.mean((layer_outputs - target_outputs) ** 2)
        grads = torch.autograd.grad(loss, layer_outputs, create_graph=True)[0]
        hessian_trace = torch.mean(grads ** 2)
        return hessian_trace

    def _compute_rd_loss(self,
                        original_output: torch.Tensor,
                        quantized_output: torch.Tensor,
                        quantized_layer: Union[QuantizedConv2d, QuantizedLinear],
                        use_hessian: bool = False) -> Tuple[torch.Tensor, Dict[str, float]]:

        distortion = torch.mean((original_output - quantized_output) ** 2)


        rate = torch.mean(torch.log(quantized_layer.weight_scale.clamp(min=1e-8) + 1e-8))


        if use_hessian:
            hessian_trace = self._compute_hessian_trace(quantized_output, original_output)
            distortion = distortion * (1.0 + 0.1 * hessian_trace)


        total_loss = distortion + self.lambda_rd * rate

        metrics = {
            'distortion': distortion.item(),
            'rate': rate.item(),
            'total': total_loss.item()
        }

        return total_loss, metrics

    def _get_layer_activations(self,
                               layer_name: str,
                               inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        activations = {}

        def hook_fn(module, inp, output):
            activations['input'] = inp[0].detach().clone()
            activations['output'] = output.detach().clone()

        target_layer = self._get_layer_by_name(self.model, layer_name)
        handle = target_layer.register_forward_hook(hook_fn)

        self.model.eval()
        with torch.no_grad():
            _ = self.model(inputs)

        handle.remove()
        return activations['input'], activations['output']

    def quantize_layer(self,
                      layer_name: str,
                      layer: Union[nn.Conv2d, nn.Linear],
                      calibration_inputs: torch.Tensor,
                      original_outputs: torch.Tensor,
                      bit_width: int,
                      num_iterations: int = 300,
                      lr: float = 1e-3) -> Union[QuantizedConv2d, QuantizedLinear]:

        logger.info(f"Quantizing {layer_name} to {bit_width}-bit...")


        if isinstance(layer, nn.Conv2d):
            q_layer = QuantizedConv2d(layer, bit_width, self.quantize_activations).to(self.device)
        else:
            q_layer = QuantizedLinear(layer, bit_width, self.quantize_activations).to(self.device)


        params = [q_layer.weight_scale]
        if q_layer.bias_float is not None:
            params.append(q_layer.bias_scale)
        if self.quantize_activations:
            if hasattr(q_layer, 'act_scale'):
                params.append(q_layer.act_scale)


        optimizer = optim.AdamW(params, lr=lr, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_iterations)

        best_loss = float('inf')
        best_state = {p: p.data.clone() for p in params}

        q_layer.train()
        for iteration in range(num_iterations):
            optimizer.zero_grad()

            quantized_output = q_layer(calibration_inputs)

            loss, metrics = self._compute_rd_loss(
                original_outputs, quantized_output, q_layer, use_hessian=(iteration % 10 == 0)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
            optimizer.step()
            scheduler.step()

            if loss.item() < best_loss:
                best_loss = loss.item()
                best_state = {p: p.data.clone() for p in params}

            if iteration % 50 == 0 or iteration == num_iterations - 1:
                logger.info(f"  Iter {iteration:3d}/{num_iterations} | "
                          f"Loss: {metrics['total']:.6f} | "
                          f"D: {metrics['distortion']:.6f} | "
                          f"R: {metrics['rate']:.4f}")


        for p in params:
            p.data = best_state[p]

        q_layer.eval()


        if hasattr(q_layer, 'compute_compression_rate'):
            comp_rate = q_layer.compute_compression_rate()
            logger.info(f"  Effective bits/weight: {comp_rate:.2f}")

        return q_layer

    def quantize_block(self,
                      block: List[Tuple[str, nn.Module]],
                      calibration_data: torch.Tensor,
                      num_iterations: int = 300) -> List[Union[QuantizedConv2d, QuantizedLinear]]:

        quantized_block = []

        for layer_name, layer in block:

            bit_width = self.layer_bit_widths.get(layer_name, self.default_bit_width)


            layer_inputs, layer_outputs = self._get_layer_activations(layer_name, calibration_data)


            q_layer = self.quantize_layer(
                layer_name, layer, layer_inputs, layer_outputs,
                bit_width, num_iterations
            )

            quantized_block.append(q_layer)
            self.quantized_layers[layer_name] = q_layer

        return quantized_block

    def quantize_model(self,
                      calibration_data: torch.Tensor,
                      num_iterations: int = 300,
                      analyze_sensitivity: bool = True) -> nn.Module:

        calibration_data = calibration_data.to(self.device)


        if self.mixed_precision and analyze_sensitivity:
            self.analyze_layer_sensitivity(calibration_data)
            self.assign_bit_widths()
        else:
            for name, _ in self._get_quantizable_layers():
                self.layer_bit_widths[name] = self.default_bit_width


        quantizable_layers = self._get_quantizable_layers()
        blocks = self._get_layer_blocks(quantizable_layers)

        logger.info(f"\nQuantizing {len(quantizable_layers)} layers in {len(blocks)} blocks")


        quantized_model = copy.deepcopy(self.model)


        for block_idx, block in enumerate(blocks):
            logger.info(f"\n[Block {block_idx + 1}/{len(blocks)}]")

            quantized_block = self.quantize_block(block, calibration_data, num_iterations)


            for (layer_name, _), q_layer in zip(block, quantized_block):
                self._replace_layer(quantized_model, layer_name, q_layer)

        logger.info("\n✓ Quantization complete!")
        return quantized_model

    def _get_layer_by_name(self, model: nn.Module, layer_name: str) -> nn.Module:

        parts = layer_name.split('.')
        layer = model
        for part in parts:
            layer = getattr(layer, part)
        return layer

    def _replace_layer(self, model: nn.Module, layer_name: str, new_layer: nn.Module):

        parts = layer_name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_layer)

    def evaluate_model(self,
                      original_model: nn.Module,
                      quantized_model: nn.Module,
                      test_data: torch.Tensor) -> Dict[str, float]:

        original_model.eval()
        quantized_model.eval()

        with torch.no_grad():
            orig_out = original_model(test_data)
            quant_out = quantized_model(test_data)

            mse = torch.mean((orig_out - quant_out) ** 2).item()
            max_val = orig_out.abs().max().item()
            psnr = 10 * np.log10(max_val**2 / (mse + 1e-10)) if mse > 0 else float('inf')

            rel_error = (torch.mean(torch.abs(orig_out - quant_out)) /
                        (torch.mean(torch.abs(orig_out)) + 1e-10)).item() * 100

            cosine_sim = nn.functional.cosine_similarity(
                orig_out.flatten(), quant_out.flatten(), dim=0
            ).item()


        orig_params = sum(p.numel() for p in original_model.parameters())
        orig_size_mb = orig_params * 4 / (1024 ** 2)


        quant_size_bits = 0
        for name, q_layer in self.quantized_layers.items():
            bit_width = self.layer_bit_widths.get(name, self.default_bit_width)
            if isinstance(q_layer, (QuantizedConv2d, QuantizedLinear)):
                num_params = q_layer.weight_float.numel()
                if q_layer.bias_float is not None:
                    num_params += q_layer.bias_float.numel()
                quant_size_bits += num_params * bit_width

        quant_size_mb = quant_size_bits / (8 * 1024 ** 2)
        compression_ratio = orig_size_mb / quant_size_mb if quant_size_mb > 0 else 0

        metrics = {
            'mse': mse,
            'psnr_db': psnr,
            'relative_error_pct': rel_error,
            'cosine_similarity': cosine_sim,
            'original_size_mb': orig_size_mb,
            'quantized_size_mb': quant_size_mb,
            'compression_ratio': compression_ratio
        }

        return metrics

    def export_quantized_model(self,
                               quantized_model: nn.Module,
                               output_path: str = 'quantized_model.pth'):

        export_dict = {
            'model_state': quantized_model.state_dict(),
            'quantized_params': {},
            'bit_widths': self.layer_bit_widths,
            'config': {
                'default_bit_width': self.default_bit_width,
                'lambda_rd': self.lambda_rd,
                'quantize_activations': self.quantize_activations
            }
        }

        for name, q_layer in self.quantized_layers.items():
            export_dict['quantized_params'][name] = q_layer.get_quantized_params()

        torch.save(export_dict, output_path)
        logger.info(f"Exported quantized model to {output_path}")

        return export_dict


def demo_production_rdo_ptq():


    print("="*80)
    print("Production RDO-PTQ: Full-Featured Quantization Pipeline")
    print("="*80)


    class ResNetBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(out_channels)

            self.downsample = None
            if stride != 1 or in_channels != out_channels:
                self.downsample = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                    nn.BatchNorm2d(out_channels)
                )

        def forward(self, x):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)
            return out

    class ImageClassifier(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

            self.layer1 = ResNetBlock(64, 64)
            self.layer2 = ResNetBlock(64, 128, stride=2)
            self.layer3 = ResNetBlock(128, 256, stride=2)

            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, num_classes)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)

            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

    print("\n[Step 1] Creating pre-trained model...")
    model = ImageClassifier(num_classes=10)


    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    total_params = sum(p.numel() for p in model.parameters())
    fp32_size = total_params * 4 / (1024 ** 2)
    print(f"Total parameters: {total_params:,}")
    print(f"FP32 model size: {fp32_size:.2f} MB")


    print("\n[Step 2] Preparing calibration data...")
    num_calibration = 50
    calibration_data = torch.randn(num_calibration, 3, 224, 224)
    print(f"Calibration set: {num_calibration} images (224x224)")


    print("\n[Step 3] Initializing Production RDO-PTQ...")
    print("Features enabled:")
    print("  ✓ Activation quantization")
    print("  ✓ Block-wise reconstruction")
    print("  ✓ Mixed precision search")
    print("  ✓ Hessian-guided optimization")

    quantizer = ProductionRDO_PTQ(
        model=model,
        bit_width=8,
        lambda_rd=0.005,
        quantize_activations=True,
        use_blocks=True,
        mixed_precision=True,
        device='cpu'
    )


    print("\n[Step 4] Running quantization pipeline...")
    print("-" * 80)

    quantized_model = quantizer.quantize_model(
        calibration_data=calibration_data,
        num_iterations=150,
        analyze_sensitivity=True
    )

    print("-" * 80)

    print("\n[Step 5] Evaluating quantized model...")
    test_data = torch.randn(20, 3, 224, 224)

    metrics = quantizer.evaluate_model(model, quantized_model, test_data)

    print("\n" + "="*80)
    print("QUANTIZATION RESULTS")
    print("="*80)

    print("\n Compression Metrics:")
    print(f"  Original size:    {metrics['original_size_mb']:.2f} MB")
    print(f"  Quantized size:   {metrics['quantized_size_mb']:.2f} MB")
    print(f"  Compression:      {metrics['compression_ratio']:.2f}x")

    print("\n Accuracy Metrics:")
    print(f"  MSE:              {metrics['mse']:.8f}")
    print(f"  PSNR:             {metrics['psnr_db']:.2f} dB")
    print(f"  Relative Error:   {metrics['relative_error_pct']:.2f}%")
    print(f"  Cosine Sim:       {metrics['cosine_similarity']:.4f}")

    print("\n Per-Layer Bit-Widths:")
    for layer_name, bit_width in list(quantizer.layer_bit_widths.items())[:10]:
        print(f"  {layer_name:30s} → {bit_width}-bit")
    if len(quantizer.layer_bit_widths) > 10:
        print(f"  ... and {len(quantizer.layer_bit_widths) - 10} more layers")

    print("\n[Step 6] Exporting quantized model...")
    export_dict = quantizer.export_quantized_model(quantized_model, 'production_quantized.pth')

    print(f"\n Export complete:")
    print(f"  Model state:      {len(export_dict['model_state'])} tensors")
    print(f"  Quantized params: {len(export_dict['quantized_params'])} layers")
    print(f"  File saved:       production_quantized.pth")


    print("\n[Step 7] Deployment Info:")
    print("-" * 80)

    return model, quantized_model, quantizer


if __name__ == "__main__":

    original_model, quantized_model, quantizer = demo_production_rdo_ptq()



Production RDO-PTQ: Full-Featured Quantization Pipeline

[Step 1] Creating pre-trained model...
Total parameters: 1,236,618
FP32 model size: 4.72 MB

[Step 2] Preparing calibration data...
Calibration set: 50 images (224x224)

[Step 3] Initializing Production RDO-PTQ...
Features enabled:
  ✓ Activation quantization
  ✓ Block-wise reconstruction
  ✓ Mixed precision search
  ✓ Hessian-guided optimization

[Step 4] Running quantization pipeline...
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

[Step 5] Evaluating quantized model...

QUANTIZATION RESULTS

 Compression Metrics:
  Original size:    4.72 MB
  Quantized size:   0.68 MB
  Compression:      6.96x

 Accuracy Metrics:
  MSE:              0.00197210
  PSNR:             9.82 dB
  Relative Error:   62.41%
  Cosine Sim:       0.7467

 Per-Layer Bit-Widths:
  conv1                          → 4-bit
  layer1.conv1          

In [3]:
#FixedRDOPTQ
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Dict
import copy
import numpy as np

print("Importing RDO-PTQ modules...")

class QuantizedConv2d(nn.Module):

    def __init__(self, conv_layer: nn.Conv2d, bit_width: int = 8):
        super().__init__()
        self.in_channels = conv_layer.in_channels
        self.out_channels = conv_layer.out_channels
        self.kernel_size = conv_layer.kernel_size
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding
        self.bit_width = bit_width

        self.qmax = (2 ** (bit_width - 1)) - 1
        self.qmin = -(2 ** (bit_width - 1))

        self.register_buffer('weight_float', conv_layer.weight.data.clone())
        if conv_layer.bias is not None:
            self.register_buffer('bias_float', conv_layer.bias.data.clone())
        else:
            self.bias_float = None

        self.weight_scale = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

        if self.bias_float is not None:
            self.bias_scale = nn.Parameter(torch.ones(self.out_channels))

        self._initialize_scales()

    def _initialize_scales(self):
        with torch.no_grad():
            for c in range(self.out_channels):
                w_channel = self.weight_float[c]
                w_max_abs = torch.max(torch.abs(w_channel))
                self.weight_scale.data[c] = (w_max_abs / (self.qmax + 1e-8)).clamp(min=1e-8)

            if self.bias_float is not None:
                for c in range(self.out_channels):
                    b_abs = torch.abs(self.bias_float[c])
                    self.bias_scale.data[c] = (b_abs / (self.qmax + 1e-8)).clamp(min=1e-8)

    def quantize_weights(self):
        w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
        w_quant = torch.clamp(torch.round(w_div), self.qmin, self.qmax)
        return w_quant * self.weight_scale

    def quantize_bias(self):
        if self.bias_float is None:
            return None
        b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
        b_quant = torch.clamp(torch.round(b_div), self.qmin, self.qmax)
        return b_quant * self.bias_scale

    def forward(self, x):
        w_quant = self.quantize_weights()
        b_quant = self.quantize_bias()
        return nn.functional.conv2d(x, w_quant, b_quant,
                                   stride=self.stride, padding=self.padding)

    def get_quantized_params(self):
        with torch.no_grad():
            w_div = self.weight_float / self.weight_scale.clamp(min=1e-8)
            w_int8 = torch.clamp(torch.round(w_div), self.qmin, self.qmax).to(torch.int8)

            result = {'weight_int8': w_int8, 'weight_scale': self.weight_scale.clone()}

            if self.bias_float is not None:
                b_div = self.bias_float / self.bias_scale.clamp(min=1e-8)
                b_int8 = torch.clamp(torch.round(b_div), self.qmin, self.qmax).to(torch.int8)
                result['bias_int8'] = b_int8
                result['bias_scale'] = self.bias_scale.clone()

        return result


class FixedRDO_PTQ:


    def __init__(self,
                 model: nn.Module,
                 bit_width: int = 8,
                 lambda_rd: float = 0.01,
                 mixed_precision: bool = False,
                 device: str = 'cpu'):
        self.model = model.to(device)
        self.default_bit_width = bit_width
        self.lambda_rd = lambda_rd
        self.mixed_precision = mixed_precision
        self.device = device

        self.quantized_layers = {}
        self.layer_bit_widths = {}
        self.layer_sensitivities = {}

    def _get_quantizable_layers(self) -> List[Tuple[str, nn.Module]]:
        quantizable = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                quantizable.append((name, module))
        print(f"DEBUG: Found {len(quantizable)} Conv2d layers")
        return quantizable

    def analyze_sensitivity(self, calibration_data: torch.Tensor) -> Dict[str, float]:

        print("\n Analyzing layer sensitivities...")

        quantizable_layers = self._get_quantizable_layers()
        sensitivities = {}


        data_subset = calibration_data[:min(10, len(calibration_data))].to(self.device)

        self.model.eval()
        with torch.no_grad():
            baseline = self.model(data_subset)

        for idx, (layer_name, layer) in enumerate(quantizable_layers):
            print(f"  [{idx+1}/{len(quantizable_layers)}] Testing {layer_name}...", end=' ')


            temp_q = QuantizedConv2d(layer, bit_width=4).to(self.device)
            temp_q.eval()


            parts = layer_name.split('.')
            parent = self.model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            original = getattr(parent, parts[-1])


            setattr(parent, parts[-1], temp_q)


            with torch.no_grad():
                quantized = self.model(data_subset)
                sens = torch.mean((baseline - quantized) ** 2).item()

            sensitivities[layer_name] = sens
            print(f"sensitivity={sens:.6f}")


            setattr(parent, parts[-1], original)

        self.layer_sensitivities = sensitivities
        return sensitivities

    def assign_bit_widths(self):

        if not self.layer_sensitivities:
            for name, _ in self._get_quantizable_layers():
                self.layer_bit_widths[name] = self.default_bit_width
            return

        print("\n Assigning bit-widths...")

        sens_list = np.array(list(self.layer_sensitivities.values()))
        if len(sens_list) == 0:
            return

        sens_min, sens_max = sens_list.min(), sens_list.max()

        for name, sens in self.layer_sensitivities.items():

            if sens_max > sens_min:
                norm_sens = (sens - sens_min) / (sens_max - sens_min)
            else:
                norm_sens = 0.5


            if norm_sens > 0.7:
                bits = 8
            elif norm_sens > 0.3:
                bits = 6
            else:
                bits = 4

            self.layer_bit_widths[name] = bits
            print(f"  {name:30s} → {bits}-bit (sens={norm_sens:.3f})")

    def _get_activations(self, layer_name: str, inputs: torch.Tensor):

        acts = {}

        def hook(module, inp, out):
            acts['input'] = inp[0].detach().clone()
            acts['output'] = out.detach().clone()


        parts = layer_name.split('.')
        layer = self.model
        for part in parts:
            layer = getattr(layer, part)

        handle = layer.register_forward_hook(hook)

        self.model.eval()
        with torch.no_grad():
            _ = self.model(inputs)

        handle.remove()
        return acts['input'], acts['output']

    def quantize_layer(self,
                      layer_name: str,
                      layer: nn.Conv2d,
                      layer_inputs: torch.Tensor,
                      layer_outputs: torch.Tensor,
                      bit_width: int,
                      num_iters: int = 100) -> QuantizedConv2d:


        print(f"  Quantizing {layer_name} → {bit_width}-bit")

        q_layer = QuantizedConv2d(layer, bit_width).to(self.device)


        params = [q_layer.weight_scale]
        if q_layer.bias_float is not None:
            params.append(q_layer.bias_scale)

        optimizer = optim.Adam(params, lr=1e-3)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_iters)

        best_loss = float('inf')
        best_state = [p.data.clone() for p in params]

        q_layer.train()
        for it in range(num_iters):
            optimizer.zero_grad()

            q_out = q_layer(layer_inputs)


            distortion = torch.mean((layer_outputs - q_out) ** 2)
            rate = torch.mean(torch.log(q_layer.weight_scale.clamp(min=1e-8) + 1e-8))
            loss = distortion + self.lambda_rd * rate

            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()
            scheduler.step()

            if loss.item() < best_loss:
                best_loss = loss.item()
                best_state = [p.data.clone() for p in params]

            if it % 25 == 0 or it == num_iters - 1:
                print(f"    Iter {it:3d}/{num_iters} | Loss: {loss.item():.6f} | "
                      f"D: {distortion.item():.6f} | R: {rate.item():.4f}")


        for p, best in zip(params, best_state):
            p.data = best

        q_layer.eval()
        return q_layer

    def quantize_model(self,
                      calibration_data: torch.Tensor,
                      num_iterations: int = 100) -> nn.Module:


        print("\n" + "="*80)
        print("STARTING QUANTIZATION")
        print("="*80)

        calibration_data = calibration_data.to(self.device)


        quantizable_layers = self._get_quantizable_layers()

        if len(quantizable_layers) == 0:
            print("ERROR: No quantizable layers found!")
            return self.model


        if self.mixed_precision:
            self.analyze_sensitivity(calibration_data)
            self.assign_bit_widths()
        else:
            print(f"\n Using uniform {self.default_bit_width}-bit quantization")
            for name, _ in quantizable_layers:
                self.layer_bit_widths[name] = self.default_bit_width


        print(f"\n Quantizing {len(quantizable_layers)} layers...")
        print("="*80)

        quantized_model = copy.deepcopy(self.model)

        for idx, (layer_name, layer) in enumerate(quantizable_layers):
            print(f"\n[Layer {idx+1}/{len(quantizable_layers)}] {layer_name}")


            layer_in, layer_out = self._get_activations(layer_name, calibration_data)


            bit_width = self.layer_bit_widths.get(layer_name, self.default_bit_width)


            q_layer = self.quantize_layer(
                layer_name, layer, layer_in, layer_out,
                bit_width, num_iterations
            )


            parts = layer_name.split('.')
            parent = quantized_model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1], q_layer)

            self.quantized_layers[layer_name] = q_layer

        print("\n" + "="*80)
        print("✓ QUANTIZATION COMPLETE")
        print("="*80)

        return quantized_model

    def evaluate(self, orig_model, quant_model, test_data):

        orig_model.eval()
        quant_model.eval()

        test_data = test_data.to(self.device)

        with torch.no_grad():
            orig_out = orig_model(test_data)
            quant_out = quant_model(test_data)

            mse = torch.mean((orig_out - quant_out) ** 2).item()
            max_val = orig_out.abs().max().item()
            psnr = 10 * np.log10(max_val**2 / (mse + 1e-10)) if mse > 0 else 100.0

            rel_err = (torch.mean(torch.abs(orig_out - quant_out)) /
                      (torch.mean(torch.abs(orig_out)) + 1e-10)).item() * 100

            cos_sim = nn.functional.cosine_similarity(
                orig_out.flatten(), quant_out.flatten(), dim=0
            ).item()


        orig_params = sum(p.numel() for p in orig_model.parameters())
        orig_size_mb = orig_params * 4 / (1024 ** 2)

        quant_bits = 0
        for name, q_layer in self.quantized_layers.items():
            bits = self.layer_bit_widths.get(name, 8)
            quant_bits += q_layer.weight_float.numel() * bits
            if q_layer.bias_float is not None:
                quant_bits += q_layer.bias_float.numel() * bits

        quant_size_mb = quant_bits / (8 * 1024 ** 2)
        compression = orig_size_mb / quant_size_mb if quant_size_mb > 0 else 0

        return {
            'mse': mse,
            'psnr_db': psnr,
            'relative_error_pct': rel_err,
            'cosine_similarity': cos_sim,
            'original_size_mb': orig_size_mb,
            'quantized_size_mb': quant_size_mb,
            'compression_ratio': compression
        }


def demo():


    print("="*80)
    print("Fixed Production RDO-PTQ Demo")
    print("="*80)


    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(32, 64, 3, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, stride=2, padding=1),
                nn.ReLU(),
            )
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(128, 10)

        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

    print("\n[1] Creating model...")
    model = SimpleModel()
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"FP32 size: {total_params * 4 / 1024**2:.2f} MB")

    print("\n[2] Preparing calibration data...")
    calib_data = torch.randn(30, 3, 64, 64)
    print(f"Calibration: {len(calib_data)} images")

    print("\n[3] Initializing quantizer...")
    quantizer = FixedRDO_PTQ(
        model=model,
        bit_width=8,
        lambda_rd=0.01,
        mixed_precision=True,
        device='cpu'
    )

    print("\n[4] Running quantization...")
    quant_model = quantizer.quantize_model(calib_data, num_iterations=50)

    print("\n[5] Evaluation...")
    test_data = torch.randn(20, 3, 64, 64)
    metrics = quantizer.evaluate(model, quant_model, test_data)

    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)
    print(f"\n Compression:")
    print(f"  Original:    {metrics['original_size_mb']:.2f} MB")
    print(f"  Quantized:   {metrics['quantized_size_mb']:.2f} MB")
    print(f"  Ratio:       {metrics['compression_ratio']:.2f}x")

    print(f"\n Accuracy:")
    print(f"  PSNR:        {metrics['psnr_db']:.2f} dB")
    print(f"  Rel Error:   {metrics['relative_error_pct']:.2f}%")
    print(f"  Cosine Sim:  {metrics['cosine_similarity']:.4f}")

    print(f"\n Bit-widths:")
    for name, bits in list(quantizer.layer_bit_widths.items())[:5]:
        print(f"  {name:25s} → {bits}-bit")

    print("\n" + "="*80)
    print("✓ Demo complete!")
    print("="*80)

    return model, quant_model


if __name__ == "__main__":
    demo()

Importing RDO-PTQ modules...
Fixed Production RDO-PTQ Demo

[1] Creating model...
Total parameters: 94,538
FP32 size: 0.36 MB

[2] Preparing calibration data...
Calibration: 30 images

[3] Initializing quantizer...

[4] Running quantization...

STARTING QUANTIZATION
DEBUG: Found 3 Conv2d layers

 Analyzing layer sensitivities...
DEBUG: Found 3 Conv2d layers
  [1/3] Testing features.0... sensitivity=0.000031
  [2/3] Testing features.2... sensitivity=0.000643
  [3/3] Testing features.4... sensitivity=0.004673

 Assigning bit-widths...
  features.0                     → 4-bit (sens=0.000)
  features.2                     → 4-bit (sens=0.132)
  features.4                     → 8-bit (sens=1.000)

 Quantizing 3 layers...

[Layer 1/3] features.0
  Quantizing features.0 → 4-bit
    Iter   0/50 | Loss: -0.006547 | D: 0.017696 | R: -2.4244
    Iter  25/50 | Loss: -0.008407 | D: 0.016087 | R: -2.4494
    Iter  49/50 | Loss: -0.008488 | D: 0.016006 | R: -2.4494

[Layer 2/3] features.2
  Quantizin