In [7]:
#!/usr/bin/env python3
"""
Extract quantized weights from PyTorch checkpoint and generate C++ header file.
Fixed for checkpoint with fire.conv1/conv2/conv3 naming convention.
"""

import torch
import numpy as np
import sys
from pathlib import Path

def analyze_weight_distribution(state_dict):
    """Print per-layer and overall weight statistics for a state dict.

    Every entry whose key contains ``'weight'`` and whose value is a
    non-empty tensor is reported with its min/max. The values of all such
    entries are then pooled to print the overall min, max, mean and
    standard deviation.

    Args:
        state_dict: Mapping from parameter name to ``torch.Tensor``.

    Returns:
        Tuple ``(overall_min, overall_max)`` over all pooled weight values.

    Raises:
        ValueError: If no usable weight tensor is found, since there is
            nothing to compute statistics over.
    """
    print("\n" + "="*70)
    print("WEIGHT DISTRIBUTION ANALYSIS")
    print("="*70)

    all_weights = []
    for key, val in state_dict.items():
        if 'weight' not in key:
            continue
        # Skip metadata-style entries and empty tensors: they carry no values
        # to analyse and would otherwise crash .numpy() / .min().
        if not torch.is_tensor(val) or val.numel() == 0:
            continue
        # detach() is required for tensors that still track gradients;
        # calling .numpy() on those raises a RuntimeError.
        weights = val.detach().cpu().numpy()
        all_weights.append(weights.flatten())
        print(f"{key:40s} min={weights.min():7.4f} max={weights.max():7.4f}")

    if not all_weights:
        raise ValueError("No weight tensors found in state_dict")

    all_weights = np.concatenate(all_weights)
    print("\n" + "="*70)
    print(f"OVERALL: min={all_weights.min():.4f}, max={all_weights.max():.4f}")
    print(f"         mean={all_weights.mean():.4f}, std={all_weights.std():.4f}")
    print("="*70 + "\n")

    return all_weights.min(), all_weights.max()

def quantize_to_fixed_point(weight, int_bits=4, frac_bits=4, verbose=False,
                            return_scale=False):
    """
    Convert float weights to a signed fixed-point integer representation.

    With the defaults (4 integer bits including sign, 4 fractional bits) the
    representable range is -8.0 to +7.9375 in steps of 0.0625, stored as
    integers in [-128, 127].

    If the weights exceed the representable range, the whole array is
    multiplied by a single factor < 1 so that it fits. That factor is NOT
    encoded in the returned integers; pass ``return_scale=True`` to obtain it.

    Args:
        weight: Array-like of float weights (numpy array or nested sequence).
        int_bits: Number of integer bits, including the sign bit.
        frac_bits: Number of fractional bits.
        verbose: If True, print the input and output ranges.
        return_scale: If True, return ``(quantized, scale_factor)`` instead of
            just ``quantized``. ``scale_factor`` is 1.0 when no rescaling was
            needed.

    Returns:
        ``np.int32`` array with the same shape as ``weight``, or a
        ``(array, scale_factor)`` tuple when ``return_scale`` is True.

    Raises:
        ValueError: If ``weight`` contains NaN or infinite values, which
            cannot be represented and would otherwise be cast to arbitrary
            integers.
    """
    # Accept plain sequences as well as ndarrays; dtype is left untouched.
    weight = np.asarray(weight)
    scale_factor = 1.0

    # Nothing to quantize: min()/max() below would raise on an empty array.
    if weight.size == 0:
        empty = np.zeros(weight.shape, dtype=np.int32)
        return (empty, scale_factor) if return_scale else empty

    if not np.all(np.isfinite(weight)):
        raise ValueError("weight contains NaN or infinite values; cannot quantize")

    # Check input range
    w_min, w_max = weight.min(), weight.max()

    if verbose:
        print(f"  Input range: [{w_min:.4f}, {w_max:.4f}]")

    # Target range for signed fixed-point
    max_val = (2 ** (int_bits - 1)) - 1  # 7 for 4 int bits
    min_val = -(2 ** (int_bits - 1))      # -8 for 4 int bits
    max_fp = max_val + (1 - 2**(-frac_bits))  # 7.9375
    min_fp = min_val  # -8.0

    # If weights exceed range, shrink the whole array by one common factor
    if w_max > max_fp or w_min < min_fp:
        # Calculate required scale to fit within range
        scale_down = max(abs(w_max / max_fp), abs(w_min / min_fp))
        scale_factor = 1.0 / scale_down
        print(f"  WARNING: Weights [{w_min:.4f}, {w_max:.4f}] exceed [{min_fp}, {max_fp}]")
        print(f"  Scaling by {scale_factor:.6f} to fit range")
        weight = weight * scale_factor

    # Scale to fixed-point
    scale = 2 ** frac_bits  # 16 for 4 fractional bits

    # Quantize: multiply by scale and round
    quantized = np.round(weight * scale).astype(np.int32)

    # Saturate to valid range (should rarely trigger after scaling)
    max_int = (2 ** (int_bits + frac_bits - 1)) - 1  # 127 for 8-bit
    min_int = -(2 ** (int_bits + frac_bits - 1))     # -128 for 8-bit

    clipped = np.clip(quantized, min_int, max_int)
    n_clipped = int(np.sum(clipped != quantized))
    if n_clipped:
        print(f"  Clipped {n_clipped} values to [{min_int}, {max_int}]")

    if verbose:
        print(f"  Output int range: [{clipped.min()}, {clipped.max()}]")
        print(f"  Represents float: [{clipped.min()/scale:.4f}, {clipped.max()/scale:.4f}]")

    if return_scale:
        return clipped, scale_factor
    return clipped

def format_array_cpp(name, data, line_width=12):
    """Render a numpy array as a flat C++ ``const fixed_point_t`` definition.

    The array is flattened, each element is converted with ``int()``, and
    the values are laid out ``line_width`` per row with a four-space indent.

    Args:
        name: Identifier used for the C++ array.
        data: Numpy array of values to emit.
        line_width: Number of values placed on each output row.

    Returns:
        The full array definition as a single newline-joined string.
    """
    values = data.flatten()
    count = len(values)
    header = f"const fixed_point_t {name}[{count}] = {{"

    # One text row per chunk of `line_width` values, each ending in a comma.
    rows = [
        "    " + ", ".join(str(int(v)) for v in values[start:start + line_width]) + ","
        for start in range(0, count, line_width)
    ]

    # The final element must not be followed by a comma.
    if rows:
        rows[-1] = rows[-1][:-1]

    return "\n".join([header, *rows, "};"])

def extract_conv_weights(state_dict, key_name, out_name, verbose=False):
    """Quantize one convolution weight tensor and format it as a C++ array.

    Args:
        state_dict: Mapping from parameter name to tensor.
        key_name: Key of the weight tensor to extract.
        out_name: Identifier for the generated C++ array.
        verbose: Forwarded to ``quantize_to_fixed_point``.

    Returns:
        The C++ array definition string, or ``None`` (after printing an
        error) when ``key_name`` is not present in ``state_dict``.
    """
    found = key_name in state_dict
    if not found:
        print(f"  ERROR: {key_name} not found!")
        return None

    weights_np = state_dict[key_name].cpu().numpy()
    for label, value in (("Found", key_name), ("Shape", weights_np.shape)):
        print(f"  {label}: {value}")

    fixed = quantize_to_fixed_point(weights_np, verbose=verbose)
    return format_array_cpp(out_name, fixed)

def extract_fire_module(state_dict, fire_name, fire_num, verbose=False):
    """Quantize and format the three weight tensors of one Fire module.

    The checkpoint stores the module's layers as ``<fire_name>.conv1``
    (squeeze), ``.conv2`` (expand 1x1) and ``.conv3`` (expand 3x3).

    Args:
        state_dict: Mapping from parameter name to tensor.
        fire_name: Key prefix of the module in the checkpoint, e.g. ``"fire2"``.
        fire_num: Number used in the generated C++ array names.
        verbose: Forwarded to ``quantize_to_fixed_point``.

    Returns:
        List of three C++ array definition strings (squeeze, expand1x1,
        expand3x3), or ``None`` (after printing an error for the first
        missing key) when any of the three tensors is absent.
    """
    # (checkpoint sub-layer, label used in the log, C++ array name part)
    parts = (
        ("conv1", "Squeeze", "squeeze"),
        ("conv2", "Expand1", "expand1x1"),
        ("conv3", "Expand3", "expand3x3"),
    )
    keys = [f"{fire_name}.{conv}.weight" for conv, _, _ in parts]

    # Bail out on the first key that is missing.
    for key in keys:
        if key not in state_dict:
            print(f"  ERROR: {key} not found!")
            return None

    for key, (conv, label, _) in zip(keys, parts):
        print(f"  {label} ({conv}): {state_dict[key].shape}")

    # Quantize all three tensors first, then format them.
    quantized = [
        quantize_to_fixed_point(state_dict[key].cpu().numpy(), verbose=verbose)
        for key in keys
    ]

    return [
        format_array_cpp(f"fire{fire_num}_{cpp_part}_weights_flat", values)
        for (_, _, cpp_part), values in zip(parts, quantized)
    ]

def main():
    """Load ``squeezenet_qat.pth`` and write ``weights_combined.h``.

    The checkpoint is loaded on CPU, its weight distribution is printed,
    and the weights of conv1, fire2..fire9 and conv10 are quantized and
    emitted as ``inline const`` C++ arrays inside a header guard.

    The process exits with status 1 if the checkpoint file is missing or if
    any expected layer cannot be extracted; in that case no header is
    written, so a partial file can never be mistaken for a complete one.
    """
    checkpoint_path = "squeezenet_qat.pth"
    output_file = "weights_combined.h"

    if not Path(checkpoint_path).exists():
        print(f"Error: File {checkpoint_path} not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source
    # (recent PyTorch releases accept weights_only=True to restrict loading).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Handle different checkpoint formats: the state dict may be wrapped
    # under one of several conventional keys, or be the checkpoint itself.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    print(f"Loaded {len(state_dict)} parameters")

    # Analyze weight distribution
    w_min, w_max = analyze_weight_distribution(state_dict)

    # The signed range is asymmetric: -8.0 is representable, +8.0 is not.
    target_min, target_max = -8.0, 7.9375
    if w_max > target_max or w_min < target_min:
        print(f"\n⚠️  WARNING: Weights range [{w_min:.4f}, {w_max:.4f}]")
        print(f"   exceeds target range [-8.0, 7.9375]")
        print(f"   Each layer will be scaled individually to fit.\n")

    print("\n" + "="*70)
    print("EXTRACTING WEIGHTS")
    print("="*70)

    output_lines = [
        "#ifndef WEIGHTS_COMBINED_H",
        "#define WEIGHTS_COMBINED_H",
        "",
        "#include \"config.h\"",
        "",
        "// Auto-generated weight file from squeezenet_qat.pth",
        "// Quantized to 8-bit fixed-point (4 int, 4 frac)",
        "// Range: -8.0 to +7.9375 (resolution: 0.0625)",
        "",
    ]

    # Fire modules 2-9
    fire_modules = [(num, f"fire{num}") for num in range(2, 10)]
    total_steps = len(fire_modules) + 2  # conv1 + fire modules + conv10
    total_weights = 0

    def add_conv(step, label, key, out_name):
        """Extract one conv layer into output_lines; return its weight count."""
        print(f"\n[{step}/{total_steps}] Extracting {label}...")
        array_def = extract_conv_weights(state_dict, key, out_name, verbose=True)
        if not array_def:
            print(f"ERROR: {label} extraction failed!")
            sys.exit(1)
        tensor = state_dict[key]
        # Size comment is derived from the tensor actually found in the
        # checkpoint, in the checkpoint's own dimension order.
        dims = "x".join(str(d) for d in tensor.shape)
        output_lines.append(f"// {label}: {dims} = {tensor.numel():,} weights")
        output_lines.append("inline " + array_def)
        output_lines.append("")
        return tensor.numel()

    # Conv1
    total_weights += add_conv(1, "Conv1", "conv1.weight", "conv1_weights_flat")

    for step, (fire_num, fire_name) in enumerate(fire_modules, start=2):
        print(f"\n[{step}/{total_steps}] Extracting Fire{fire_num}...")
        fire_weights = extract_fire_module(state_dict, fire_name, fire_num, verbose=False)
        if not fire_weights:
            # A header missing one module's arrays is unusable downstream,
            # so fail the same way as for Conv1/Conv10.
            print(f"ERROR: Fire{fire_num} extraction failed!")
            sys.exit(1)
        output_lines.append(f"// Fire{fire_num}")
        for weight_def in fire_weights:
            output_lines.append("inline " + weight_def)
        output_lines.append("")
        total_weights += sum(
            state_dict[f"{fire_name}.conv{i}.weight"].numel() for i in (1, 2, 3)
        )

    # Conv10
    total_weights += add_conv(total_steps, "Conv10", "conv10.weight", "conv10_weights_flat")

    # Close header guard
    output_lines.append("#endif // WEIGHTS_COMBINED_H")

    # Write to file (terminated by a newline so the header ends cleanly)
    print(f"\n{'='*70}")
    print(f"Writing to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(output_lines) + '\n')

    file_size = Path(output_file).stat().st_size / (1024 * 1024)

    print(f"✓ Successfully extracted {total_weights:,} weights")
    print(f"✓ File size: {file_size:.2f} MB")
    print(f"{'='*70}\n")

    print("Next steps:")
    print("1. Compile: make clean && make weight_load")
    print("2. Test: ./tb_weight_loading")
    print("3. Verify weight ranges are within [-128, 127]")

# Run the extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Loading checkpoint from squeezenet_qat.pth...
Loaded 52 parameters

WEIGHT DISTRIBUTION ANALYSIS
conv1.weight                             min=-0.2870 max= 0.3948
fire2.conv1.weight                       min=-0.6087 max= 0.4776
fire2.conv2.weight                       min=-0.1245 max= 0.2498
fire2.conv3.weight                       min=-0.2449 max= 0.2877
fire3.conv1.weight                       min=-0.5484 max= 0.5835
fire3.conv2.weight                       min=-0.1264 max= 0.2192
fire3.conv3.weight                       min=-0.2986 max= 0.3525
fire4.conv1.weight                       min=-0.7804 max= 0.7987
fire4.conv2.weight                       min=-0.0852 max= 0.2099
fire4.conv3.weight                       min=-0.2670 max= 0.4170
fire5.conv1.weight                       min=-0.4005 max= 0.5334
fire5.conv2.weight                       min=-0.2113 max= 0.3256
fire5.conv3.weight                       min=-0.2719 max= 0.3134
fire6.conv1.weight                       min=-0.3783 max= 