### Part 1: Core Quantization Functions

In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
def quantize_tensor_symmetric(tensor, num_bits=8):
    """
    Symmetric quantization: range centered at 0
    
    Args:
        tensor: FP32 tensor to quantize
        num_bits: bit width (8 for INT8)
    
    Returns:
        q_tensor: quantized tensor (INT)
        scale: quantization scale
    """

    q_max = 2 ** (num_bits - 1) - 1
    q_min = - q_max

    max_val = tensor.abs().max()
    scale = max_val / q_max

    if scale == 0.0:
        scale = 1.0

    q_tensor = torch.clamp(
        torch.round(tensor / scale), 
        q_min, q_max
    ).to(torch.int8)

    return q_tensor, scale

def dequantize_tensor(q_tensor, scale):
    """
    Dequantize: INT8 -> FP32
    
    Args:
        q_tensor: quantized tensor (INT8)
        scale: quantization scale
    
    Returns:
        tensor: dequantized FP32 tensor
    """
    return q_tensor.float() * scale


def quantize_tensor_asymmetric(tensor, num_bits=8):
    """
    Asymmetric quantization: uses full range
    
    Args:
        tensor: FP32 tensor to quantize
        num_bits: bit width
    
    Returns:
        q_tensor: quantized tensor (INT)
        scale: quantization scale
        zero_point: zero point
    """
    q_max = 2 ** (num_bits - 1) - 1
    q_min = - q_max

    max_val = tensor.abs().max()
    min_val = tensor.abs().min()

    scale = (max_val - min_val) / (q_max - q_min)

    if scale == 0.0:
        scale = 1.0

    zero_point = q_min - torch.round(min_val / scale)
    zero_point = torch.clamp(zero_point, q_min, q_max).to(torch.int8)

    q_tensor = torch.clamp(
        torch.round(tensor / scale) + zero_point,
        q_min, q_max 
    ).to(torch.int8)

    return q_tensor, scale, zero_point

def dequantize_tensor_asymmetric(q_tensor, scale, zero_point):
    """
    Dequantize asymmetric: UINT8 -> FP32
    """
    return (q_tensor.float() - zero_point) * scale

print("="*60)
print("TESTING BASIC QUANTIZATION")
print("="*60)

tensor_fp32 = torch.randn(4, 4) * 2.0
print(f"\nOriginal tensor:\n{tensor_fp32}")
print(f"Range: [{tensor_fp32.min():.3f}, {tensor_fp32.max():.3f}]")

q_tensor, scale = quantize_tensor_symmetric(tensor_fp32)
print(f"\nQuantized (INT8):\n{q_tensor}")
print(f"Scale: {scale:.6f}")

dq_tensor = dequantize_tensor(q_tensor, scale)
print(f"\nDequantized:\n{dq_tensor}")

error = (tensor_fp32 - dq_tensor).abs().mean()
print(f"\nMean absolute error: {error:.6f}")

tensor_relu = torch.relu(torch.randn(4, 4))
print(f"\n\nReLU activations (asymmetric):\n{tensor_relu}")
print(f"Range: [{tensor_relu.min():.3f}, {tensor_relu.max():.3f}]")

q_tensor_asym, scale_asym, zp = quantize_tensor_asymmetric(tensor_relu)
print(f"\nQuantized (UINT8):\n{q_tensor_asym}")
print(f"Scale: {scale_asym:.6f}, Zero point: {zp}")

dq_tensor_asym = dequantize_tensor_asymmetric(q_tensor_asym, scale_asym, zp)
print(f"\nDequantized:\n{dq_tensor_asym}")

error_asym = (tensor_relu - dq_tensor_asym).abs().mean()
print(f"\nMean absolute error: {error_asym:.6f}")

TESTING BASIC QUANTIZATION

Original tensor:
tensor([[ 0.1793, -3.9349, -1.2818,  0.9817],
        [-2.7567,  2.7398, -1.8403, -0.9391],
        [-3.3878, -0.0266, -1.0098,  2.2864],
        [ 0.0894, -0.6689, -1.5727,  3.1289]])
Range: [-3.935, 3.129]

Quantized (INT8):
tensor([[   6, -127,  -41,   32],
        [ -89,   88,  -59,  -30],
        [-109,   -1,  -33,   74],
        [   3,  -22,  -51,  101]], dtype=torch.int8)
Scale: 0.030984

Dequantized:
tensor([[ 0.1859, -3.9349, -1.2703,  0.9915],
        [-2.7576,  2.7266, -1.8280, -0.9295],
        [-3.3772, -0.0310, -1.0225,  2.2928],
        [ 0.0930, -0.6816, -1.5802,  3.1294]])

Mean absolute error: 0.007639


ReLU activations (asymmetric):
tensor([[0.0000, 0.0000, 1.1442, 0.9498],
        [2.4502, 0.0193, 0.0000, 0.0831],
        [0.8293, 0.1243, 0.0000, 0.0000],
        [2.5610, 0.0000, 1.0084, 0.0000]])
Range: [0.000, 2.561]

Quantized (UINT8):
tensor([[-127, -127,  -14,  -33],
        [ 116, -125, -127, -119],
        [ -45, 

### Part 2: Per-Channel Quantization

In [8]:
def quantize_per_channel_symmetric(tensor, channel_dim=0, num_bits=8):
    """
    Per-channel symmetric quantization for weights
    
    Args:
        tensor: Weight tensor (e.g., [out_features, in_features])
        channel_dim: Which dimension is the channel (0 for output channels)
        num_bits: bit width
    
    Returns:
        q_tensor: quantized tensor
        scales: per-channel scales
    """
    q_max = 2 ** (num_bits - 1) - 1
    q_min = - q_max

    tensor_transposed = tensor.transpose(0, channel_dim) if channel_dim != 0 else tensor
    num_channels = tensor_transposed.shape[0]

    scales = []
    q_tensor_list = []

    for i in range(num_channels):
        channel_data = tensor_transposed[i,:]
        max_val = channel_data.abs().max()

        scale = max_val / q_max
        if scale == 0.0:
            scale = 1.0
        scales.append(scale)

        q_channel = torch.clamp(
            torch.round(channel_data / scale),
            q_min, q_max
        ).to(torch.int8)

        q_tensor_list.append(q_channel)

    q_tensor = torch.stack(q_tensor_list, dim=0)
    scales = torch.tensor(scales)

    if channel_dim != 0:
        q_tensor = q_tensor.transpose(0, channel_dim)
    
    return q_tensor, scales

def dequantize_per_channel(q_tensor, scales, channel_dim=0):
    """
    Dequantize per-channel quantized tensor
    """
    q_tensor_transposed = q_tensor.transpose(0, channel_dim) if channel_dim != 0 else q_tensor
    num_channels = q_tensor_transposed.shape[0]
    
    dq_list = []
    for i in range(num_channels):
        dq_channel = q_tensor_transposed[i].float() * scales[i]
        dq_list.append(dq_channel)
    
    dq_tensor = torch.stack(dq_list)
    
    if channel_dim != 0:
        dq_tensor = dq_tensor.transpose(0, channel_dim)
    
    return dq_tensor

print("\n" + "="*60)
print("TESTING PER-CHANNEL QUANTIZATION")
print("="*60)

weight = torch.randn(4, 8)  
weight[0] *= 0.5  
weight[1] *= 2.0  
weight[2] *= 1.0  
weight[3] *= 0.3  

print(f"\nWeight matrix:\n{weight}")
print(f"\nPer-channel statistics:")
for i in range(4):
    print(f"  Channel {i}: range [{weight[i].min():.3f}, {weight[i].max():.3f}]")

q_tensor_pt, scale_pt = quantize_tensor_symmetric(weight)
dq_tensor_pt = dequantize_tensor(q_tensor_pt, scale_pt)
error_pt = (weight - dq_tensor_pt).abs().mean()
print(f"\nPer-tensor error: {error_pt:.6f}")

q_tensor_pc, scales_pc = quantize_per_channel_symmetric(weight, channel_dim=0)
dq_tensor_pc = dequantize_per_channel(q_tensor_pc, scales_pc, channel_dim=0)
error_pc = (weight - dq_tensor_pc).abs().mean()
print(f"Per-channel error: {error_pc:.6f}")
print(f"Improvement: {(error_pt - error_pc) / error_pt * 100:.1f}%")

print(f"\nPer-channel scales: {scales_pc}")


TESTING PER-CHANNEL QUANTIZATION

Weight matrix:
tensor([[ 0.1753,  0.9806,  0.3283,  0.7704, -0.5617,  0.6033,  0.1387,  0.2544],
        [-0.8346,  0.5745,  0.5997,  0.4270,  0.4686, -1.8178,  1.7313,  2.6731],
        [ 0.1446, -1.6448,  1.1356, -1.9062, -0.2997,  0.1748, -1.0819, -1.7786],
        [ 0.2215, -0.2306, -0.2458, -0.1055, -0.1148,  0.2646, -0.4096, -0.3463]])

Per-channel statistics:
  Channel 0: range [-0.562, 0.981]
  Channel 1: range [-1.818, 2.673]
  Channel 2: range [-1.906, 1.136]
  Channel 3: range [-0.410, 0.265]

Per-tensor error: 0.006477
Per-channel error: 0.003069
Improvement: 52.6%

Per-channel scales: tensor([0.0077, 0.0210, 0.0150, 0.0032])


### Part 3: Quantized Linear Layer

In [12]:
class QuantizedLinear(nn.Module):
    """
    Linear layer with INT8 weights and INT8 activations
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        
        # Store original FP32 weights for initialization
        self.linear_fp32 = nn.Linear(in_features, out_features, bias=bias)
        
        # Quantized weights (will be set by quantize method)
        self.register_buffer('weight_quantized', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('weight_scales', torch.zeros(out_features))
        
        # Activation scale (will be set during calibration)
        self.register_buffer('activation_scale', torch.tensor(1.0))
        
        self.quantized = False
    
    def quantize_weights(self):
        """Quantize the weights (per-channel)"""
        weight = self.linear_fp32.weight.data
        q_weight, scales = quantize_per_channel_symmetric(weight, channel_dim=0)
        
        self.weight_quantized = q_weight
        self.weight_scales = scales
        self.quantized = True
    
    def set_activation_scale(self, scale):
        """Set activation quantization scale (from calibration)"""
        self.activation_scale = scale
    
    def forward(self, x):
        if not self.quantized:
            # FP32 mode
            return self.linear_fp32(x)
        
        # Quantize input activations
        x_q, x_scale = quantize_tensor_symmetric(x)
        
        # Dequantize weights for computation (in practice, use INT8 matmul)
        weight_dq = dequantize_per_channel(
            self.weight_quantized, 
            self.weight_scales, 
            channel_dim=0
        )
        
        # Dequantize activations
        x_dq = dequantize_tensor(x_q, x_scale)
        
        # Compute (in FP32 for simplicity - real impl would use INT8 ops)
        output = torch.nn.functional.linear(x_dq, weight_dq, self.linear_fp32.bias)
        
        return output


# ==================== TESTING QUANTIZED LINEAR ====================

print("\n" + "="*60)
print("TESTING QUANTIZED LINEAR LAYER")
print("="*60)

# Create layers
linear_fp32 = nn.Linear(128, 64)
linear_quant = QuantizedLinear(128, 64)

# Copy weights
linear_quant.linear_fp32.load_state_dict(linear_fp32.state_dict())

# Test input
x = torch.randn(8, 128)  # Batch of 8

# FP32 forward
output_fp32 = linear_fp32(x)
print(f"\nFP32 output shape: {output_fp32.shape}")
print(f"FP32 output range: [{output_fp32.min():.3f}, {output_fp32.max():.3f}]")

# Quantize and forward
linear_quant.quantize_weights()
output_quant = linear_quant(x)
print(f"\nQuantized output shape: {output_quant.shape}")
print(f"Quantized output range: [{output_quant.min():.3f}, {output_quant.max():.3f}]")

# Compare outputs
error = (output_fp32 - output_quant).abs().mean()
relative_error = error / output_fp32.abs().mean()
print(f"\nAbsolute error: {error:.6f}")
print(f"Relative error: {relative_error:.6f} ({relative_error*100:.2f}%)")

# Memory comparison
fp32_size = linear_fp32.weight.numel() * 4  # 4 bytes per float
int8_size = linear_quant.weight_quantized.numel() * 1  # 1 byte per int8
scales_size = linear_quant.weight_scales.numel() * 4  # scales are FP32

print(f"\nMemory usage:")
print(f"  FP32: {fp32_size} bytes")
print(f"  INT8: {int8_size + scales_size} bytes")
print(f"  Reduction: {fp32_size / (int8_size + scales_size):.2f}x")


TESTING QUANTIZED LINEAR LAYER

FP32 output shape: torch.Size([8, 64])
FP32 output range: [-1.851, 1.513]

Quantized output shape: torch.Size([8, 64])
Quantized output range: [-1.850, 1.510]

Absolute error: 0.004590
Relative error: 0.009914 (0.99%)

Memory usage:
  FP32: 32768 bytes
  INT8: 8448 bytes
  Reduction: 3.88x


In [13]:
class QuantizedModel(nn.Module):
    """Quantized version"""
    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super().__init__()
        self.fc1 = QuantizedLinear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = QuantizedLinear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = QuantizedLinear(hidden_size, num_classes)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def quantize_all_weights(self):
        """Quantize all linear layers"""
        self.fc1.quantize_weights()
        self.fc2.quantize_weights()
        self.fc3.quantize_weights()

## Step 7: Advanced PyTorch Implementation - Fake Quantization & QAT

### Part 1: Advanced Fake Quantization with Observers

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

In [35]:
class FakeQuantize(Function):
    """
    Fake quantization with Straight-Through Estimator
    Supports both symmetric and asymmetric quantization
    """
    @staticmethod
    def forward(self, x, scale, zero_point, q_max, q_min):
        x_int = torch.clamp(
            torch.round(x / scale) + zero_point,
            q_min, q_max
        )
        x_deq = (x_int - zero_point) * scale
        return x_deq

    @staticmethod
    def backward(self, grad_output):
        return grad_output, None, None, None, None

fake_quantize = FakeQuantize.apply

In [36]:
class MinMaxObserver(nn.Module):
    """
    Observes min/max values during calibration to compute scales
    """
    def __init__(self, dtype=torch.qint8, qscheme='per_tensor_symmetric'):
        super().__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        
        # Quantization range
        if dtype == torch.qint8:
            self.quant_min = -127
            self.quant_max = 127
        elif dtype == torch.quint8:
            self.quant_min = 0
            self.quant_max = 255
        
        # Observed statistics
        self.register_buffer('min_val', torch.tensor(float('inf')))
        self.register_buffer('max_val', torch.tensor(float('-inf')))
        
        # Computed scale and zero point
        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0, dtype=torch.int32))
    
    def forward(self, x):
        """Update min/max statistics"""
        if self.training:
            # Update running min/max
            current_min = x.min()
            current_max = x.max()
            
            self.min_val = torch.min(self.min_val, current_min)
            self.max_val = torch.max(self.max_val, current_max)
        
        return x
    
    def calculate_qparams(self):
        """Calculate quantization parameters from observed min/max"""
        if self.qscheme == 'per_tensor_symmetric':
            max_val = torch.max(self.min_val.abs(), self.max_val.abs())
            self.scale = max_val / 127
            self.zero_point = torch.tensor(0, dtype=torch.int32)
        
        elif self.qscheme == 'per_tensor_asymmetric':
            self.scale = (self.max_val - self.min_val) / 255
            self.zero_point = torch.round(-self.min_val / self.scale).to(torch.int32)
            self.zero_point = torch.clamp(self.zero_point, 0, 255)
        
        # Handle zero scale
        if self.scale == 0:
            self.scale = torch.tensor(1.0)
        
        return self.scale, self.zero_point

In [37]:
class MovingAverageMinMaxObserver(MinMaxObserver):
    """
    Observer with exponential moving average (better for QAT)
    """
    def __init__(self, dtype=torch.qint8, qscheme='per_tensor_symmetric', 
                 averaging_constant=0.01):
        super().__init__(dtype, qscheme)
        self.averaging_constant = averaging_constant
    
    def forward(self, x):
        """Update with exponential moving average"""
        if self.training:
            current_min = x.detach().min()
            current_max = x.detach().max()
            
            if self.min_val == float('inf'):
                # First observation
                self.min_val = current_min
                self.max_val = current_max
            else:
                # EMA update
                self.min_val = (
                    self.min_val * (1 - self.averaging_constant) +
                    current_min * self.averaging_constant
                )
                self.max_val = (
                    self.max_val * (1 - self.averaging_constant) +
                    current_max * self.averaging_constant
                )
        
        return x

In [None]:
class QATLinear(nn.Module):
    """
    Linear layer with fake quantization for QAT
    Uses observers to track statistics and compute scales
    """
    def __init__(self, in_features, out_features, bias=True, 
                 weight_qscheme='per_channel_symmetric',
                 activation_qscheme='per_tensor_symmetric'):
        super().__init__()
        
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.weight_qscheme = weight_qscheme
        self.activation_qscheme = activation_qscheme
        
        # Observers
        if weight_qscheme == 'per_channel_symmetric':
            # Per-channel observer (one per output channel)
            self.weight_observer = nn.ModuleList([
                MinMaxObserver(torch.qint8, 'per_tensor_symmetric')
                for _ in range(out_features)
            ])
        else:
            self.weight_observer = MinMaxObserver(torch.qint8, weight_qscheme)
        
        self.activation_observer = MovingAverageMinMaxObserver(
            torch.qint8, activation_qscheme
        )
        
        # Quantization parameters (set after calibration)
        self.register_buffer('weight_scale', None)
        self.register_buffer('weight_zero_point', None)
        self.register_buffer('activation_scale', torch.tensor(1.0))
        self.register_buffer('activation_zero_point', torch.tensor(0, dtype=torch.int32))
        
        # QAT mode flag
        self.qat_mode = False
    
    def enable_observer(self):
        """Enable observers for calibration"""
        for module in self.modules():
            if isinstance(module, (MinMaxObserver, MovingAverageMinMaxObserver)):
                module.train()
    
    def disable_observer(self):
        """Disable observers after calibration"""
        for module in self.modules():
            if isinstance(module, (MinMaxObserver, MovingAverageMinMaxObserver)):
                module.eval()
    
    def calculate_qparams(self):
        """Calculate quantization parameters from observers"""
        # Weight quantization parameters
        if self.weight_qscheme == 'per_channel_symmetric':
            scales = []
            for i, observer in enumerate(self.weight_observer):
                # Observe this channel
                with torch.no_grad():
                    observer(self.linear.weight[i])
                    scale, _ = observer.calculate_qparams()
                    scales.append(scale)
            self.weight_scale = torch.stack(scales)
            self.weight_zero_point = torch.zeros(len(scales), dtype=torch.int32)
        else:
            with torch.no_grad():
                self.weight_observer(self.linear.weight)
                self.weight_scale, self.weight_zero_point = self.weight_observer.calculate_qparams()
        
        # Activation quantization parameters
        self.activation_scale, self.activation_zero_point = \
            self.activation_observer.calculate_qparams()
    
    def enable_qat(self):
        """Enable QAT mode (fake quantization)"""
        self.qat_mode = True
        self.calculate_qparams()
    
    def forward(self, x):
        # Observe activations
        x = self.activation_observer(x)
        
        if not self.qat_mode:
            # Normal FP32 forward
            return self.linear(x)
        
        # Fake quantize activations
        x_fq = fake_quantize(
            x,
            self.activation_scale,
            self.activation_zero_point,
            self.activation_observer.quant_min,
            self.activation_observer.quant_max
        )
        
        # Fake quantize weights (per-channel)
        if self.weight_qscheme == 'per_channel_symmetric':
            weight_fq = []
            for i in range(self.linear.weight.shape[0]):
                w_channel = self.linear.weight[i]
                w_fq_channel = fake_quantize(
                    w_channel,
                    self.weight_scale[i],
                    self.weight_zero_point[i],
                    -127, 127
                )
                weight_fq.append(w_fq_channel)
            weight_fq = torch.stack(weight_fq)
        else:
            weight_fq = fake_quantize(
                self.linear.weight,
                self.weight_scale,
                self.weight_zero_point,
                -127, 127
            )
        
        # Linear with fake-quantized inputs
        return F.linear(x_fq, weight_fq, self.linear.bias)

class QATModel(nn.Module):
    """
    Complete model with QAT support
    """
    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super().__init__()
        self.fc1 = QATLinear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = QATLinear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = QATLinear(hidden_size, num_classes)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def enable_observer(self):
        """Enable all observers for calibration"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.enable_observer()
    
    def disable_observer(self):
        """Disable all observers after calibration"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.disable_observer()
    
    def calculate_qparams(self):
        """Calculate quantization parameters for all layers"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.calculate_qparams()
    
    def enable_qat(self):
        """Enable QAT mode (fake quantization)"""
        for module in self.modules():
            if isinstance(module, QATLinear):
                module.enable_qat()

In [40]:
# ==================== TRAINING PIPELINE ====================

def calibrate_model(model, calibration_loader, device='cpu'):
    """
    Calibration phase: collect statistics
    
    Args:
        model: Model with observers
        calibration_loader: DataLoader with calibration data
        device: Device to run on
    """
    print("Starting calibration...")
    model.eval()
    model.enable_observer()
    
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(calibration_loader):
            data = data.to(device)
            model(data)
            
            if batch_idx % 10 == 0:
                print(f"  Calibration batch {batch_idx}/{len(calibration_loader)}")
    
    # Calculate quantization parameters
    model.calculate_qparams()
    model.disable_observer()
    
    print("Calibration complete!")
    
    # Print scales
    print("\nQuantization scales:")
    for name, module in model.named_modules():
        if isinstance(module, QATLinear):
            print(f"  {name}:")
            print(f"    Weight scale: {module.weight_scale.mean().item():.6f}")
            print(f"    Activation scale: {module.activation_scale.item():.6f}")


def train_qat(model, train_loader, optimizer, criterion, device='cpu'):
    """
    QAT training for one epoch
    """
    model.train()
    total_loss = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
    
    return total_loss / len(train_loader)


def evaluate(model, test_loader, device='cpu'):
    """
    Evaluate model accuracy
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    
    accuracy = 100.0 * correct / total
    return accuracy


# ==================== COMPLETE QAT WORKFLOW ====================

def complete_qat_workflow():
    """
    Full QAT workflow: Pretrain → Calibrate → QAT → Evaluate
    """
    print("="*60)
    print("COMPLETE QAT WORKFLOW")
    print("="*60)
    
    # Hyperparameters
    batch_size = 64
    num_calibration_batches = 100
    num_qat_epochs = 3
    
    # Create dummy datasets (replace with real MNIST/CIFAR)
    def get_dummy_loader(num_samples, batch_size):
        data = torch.randn(num_samples, 784)
        targets = torch.randint(0, 10, (num_samples,))
        dataset = torch.utils.data.TensorDataset(data, targets)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    train_loader = get_dummy_loader(1000, batch_size)
    calibration_loader = get_dummy_loader(500, batch_size)
    test_loader = get_dummy_loader(200, batch_size)
    
    # Step 1: Create and pretrain FP32 model
    print("\nStep 1: Pretraining FP32 model...")
    model = QATModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    
    # Pretrain for a few epochs
    for epoch in range(2):
        loss = train_qat(model, train_loader, optimizer, criterion)
        print(f"Pretrain Epoch {epoch+1}: Loss = {loss:.4f}")
    
    fp32_acc = evaluate(model, test_loader)
    print(f"\nFP32 Accuracy: {fp32_acc:.2f}%")
    
    # Step 2: Calibration
    print("\n" + "="*60)
    print("Step 2: Calibration")
    print("="*60)
    calibrate_model(model, calibration_loader)
    
    # Step 3: Enable QAT mode
    print("\n" + "="*60)
    print("Step 3: QAT Training")
    print("="*60)
    model.enable_qat()
    
    # Use smaller learning rate for QAT fine-tuning
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    for epoch in range(num_qat_epochs):
        print(f"\nQAT Epoch {epoch+1}/{num_qat_epochs}")
        loss = train_qat(model, train_loader, optimizer, criterion)
        qat_acc = evaluate(model, test_loader)
        print(f"QAT Epoch {epoch+1}: Loss = {loss:.4f}, Accuracy = {qat_acc:.2f}%")
    
    # Step 4: Final evaluation
    print("\n" + "="*60)
    print("Final Results")
    print("="*60)
    print(f"FP32 Accuracy: {fp32_acc:.2f}%")
    print(f"QAT Accuracy: {qat_acc:.2f}%")
    print(f"Accuracy drop: {fp32_acc - qat_acc:.2f}%")
    
    return model


# ==================== RUN ====================

if __name__ == "__main__":
    model = complete_qat_workflow()

COMPLETE QAT WORKFLOW

Step 1: Pretraining FP32 model...
  Batch 0/16, Loss: 2.3075
Pretrain Epoch 1: Loss = 2.3127
  Batch 0/16, Loss: 2.0951
Pretrain Epoch 2: Loss = 1.9762

FP32 Accuracy: 9.00%

Step 2: Calibration
Starting calibration...
  Calibration batch 0/8
Calibration complete!

Quantization scales:
  fc1:
    Weight scale: 0.000407
    Activation scale: 0.032719
  fc2:
    Weight scale: 0.000584
    Activation scale: 0.017808
  fc3:
    Weight scale: 0.000601
    Activation scale: 0.008153

Step 3: QAT Training

QAT Epoch 1/3
  Batch 0/16, Loss: 2.7003
QAT Epoch 1: Loss = 2.6010, Accuracy = 11.00%

QAT Epoch 2/3
  Batch 0/16, Loss: 2.6204
QAT Epoch 2: Loss = 2.5973, Accuracy = 11.00%

QAT Epoch 3/3
  Batch 0/16, Loss: 2.6751
QAT Epoch 3: Loss = 2.6024, Accuracy = 11.00%

Final Results
FP32 Accuracy: 9.00%
QAT Accuracy: 11.00%
Accuracy drop: -2.00%
