In [4]:
#!/usr/bin/env python3
"""
Extract quantized weights from PyTorch checkpoint and generate C++ header file.
Usage: python extract_weights.py squeeze_qat.pth
"""

import torch
import numpy as np
import sys
from pathlib import Path

def explore_state_dict(state_dict):
    """Print every key of *state_dict* with its shape and return the sorted keys.

    Entries that have no ``.shape`` attribute are listed with their type name
    in angle brackets instead of aborting the whole dump.

    Args:
        state_dict: Mapping of parameter names to tensors/arrays.

    Returns:
        list: The keys of *state_dict* in sorted order.
    """
    print("\n" + "=" * 70)
    print("STATE DICT STRUCTURE")
    print("=" * 70)

    keys = sorted(state_dict.keys())
    for key in keys:
        value = state_dict[key]
        shape = getattr(value, "shape", None)
        # Non-array entries have no shape; show their type so the dump completes.
        if shape is not None:
            description = str(tuple(shape))
        else:
            description = f"<{type(value).__name__}>"
        print(f"{key:50s} {description:20s}")

    print("=" * 70 + "\n")
    return keys

def quantize_to_fixed_point(weight, int_bits=4, frac_bits=4):
    """Convert float weights to signed fixed-point integers.

    Each value is multiplied by ``2**frac_bits``, rounded to the nearest
    integer (``np.round`` rounds exact halves to even) and saturated to the
    signed two's-complement range of an ``int_bits + frac_bits``-bit word.
    With the defaults (4, 4) the result lies in [-128, 127].

    Args:
        weight: Array-like of numbers (anything ``np.asarray`` accepts).
        int_bits: Integer bits, including the sign bit.
        frac_bits: Fractional bits.

    Returns:
        np.ndarray: int32 array with the same shape as *weight*.

    Raises:
        ValueError: If ``int_bits + frac_bits`` is outside [1, 32] (not
            representable in the int32 result) or *weight* contains NaN.
    """
    total_bits = int_bits + frac_bits
    if not 1 <= total_bits <= 32:
        raise ValueError(
            f"int_bits + frac_bits must be between 1 and 32, got {total_bits}"
        )
    scale = 2 ** frac_bits

    # Signed two's-complement range of a total_bits-wide word.
    max_val = (2 ** (total_bits - 1)) - 1
    min_val = -(2 ** (total_bits - 1))

    # np.asarray first: a plain Python list would otherwise be *repeated* by
    # ``* scale`` instead of being multiplied element-wise.
    scaled = np.asarray(weight, dtype=np.float64) * scale
    if np.isnan(scaled).any():
        raise ValueError("cannot quantize weights containing NaN")

    # Saturate while still in floating point: casting infinite or out-of-range
    # floats to int32 is platform-dependent and can produce the wrong sign.
    return np.clip(np.round(scaled), min_val, max_val).astype(np.int32)

def find_layer_key(state_dict, patterns):
    """Return the best ``.weight`` key in *state_dict* for the first matching pattern.

    Patterns are tried in order. For each pattern, candidate keys (only those
    ending in ``.weight``) are ranked:

    1. the pattern is a dotted-path prefix of the key
       (``"conv1.weight"`` -> ``"conv1.weight"``,
       ``"fire2.squeeze"`` -> ``"fire2.squeeze.weight"``);
    2. the pattern sits on dotted-segment boundaries deeper in the key
       (``"conv1.weight"`` -> ``"module.conv1.weight"``);
    3. plain substring match, kept as a last resort.

    The ranking stops ``"conv1.weight"`` from resolving to
    ``"fire2.conv1.weight"`` merely because that key comes first in the dict.

    Args:
        state_dict: Mapping whose keys are dotted parameter names.
        patterns: Iterable of candidate names, highest priority first.

    Returns:
        The matching key, or None if no pattern matches any ``.weight`` key.
    """
    weight_keys = [k for k in state_dict.keys() if k.endswith('.weight')]
    for pattern in patterns:
        rooted = next(
            (k for k in weight_keys if f"{k}.".startswith(f"{pattern}.")), None
        )
        if rooted is not None:
            return rooted
        bounded = next((k for k in weight_keys if f".{pattern}." in f".{k}."), None)
        if bounded is not None:
            return bounded
        loose = next((k for k in weight_keys if pattern in k), None)
        if loose is not None:
            return loose
    return None

def format_array_cpp(name, data, line_width=12):
    """Render *data* as a flat C++ ``const fixed_point_t`` array definition.

    Every element is converted with ``int()``. Values are written
    ``line_width`` per row, comma-separated, with no comma after the final
    value.

    Args:
        name: C++ identifier for the array.
        data: NumPy array of any shape; it is flattened in row-major order.
        line_width: Maximum number of values per output row.

    Returns:
        str: The declaration line, the value rows, and a closing ``};``,
        joined by newlines.
    """
    values = [str(int(v)) for v in data.flatten()]
    declaration = f"const fixed_point_t {name}[{len(values)}] = {{"

    rows = []
    for start in range(0, len(values), line_width):
        rows.append("    " + ", ".join(values[start:start + line_width]))

    # Joining rows with ",\n" puts a comma after every row except the last one.
    pieces = [declaration]
    if rows:
        pieces.append(",\n".join(rows))
    pieces.append("};")
    return "\n".join(pieces)

def extract_conv_weights(state_dict, layer_patterns, out_name):
    """Find one convolution weight, quantize it and format it as a C++ array.

    Args:
        state_dict: Mapping of parameter names to tensors.
        layer_patterns: Module path (str) or list of candidate module paths,
            tried in order; ``.weight`` is appended to each before lookup.
        out_name: C++ identifier for the generated array.

    Returns:
        The C++ array definition string, or None (after printing a warning)
        if no key matched.
    """
    # Accept a single pattern as well as a list of them.
    if isinstance(layer_patterns, str):
        layer_patterns = [layer_patterns]

    key = find_layer_key(state_dict, [p + '.weight' for p in layer_patterns])

    if key is None:
        print(f"Warning: Could not find weight key matching {layer_patterns}")
        return None

    print(f"  Found: {key}")
    tensor = state_dict[key]
    # Quantized tensors (tensor.is_quantized) do not support .numpy();
    # dequantize them to float before conversion.
    if getattr(tensor, "is_quantized", False):
        tensor = tensor.dequantize()
    weight = tensor.detach().cpu().numpy()
    print(f"  Shape: {weight.shape}")
    quantized = quantize_to_fixed_point(weight)
    return format_array_cpp(out_name, quantized)

def extract_fire_weights(state_dict, fire_patterns):
    """Extract and quantize the three convolutions of one Fire module.

    Each convolution is looked up under every prefix in *fire_patterns*. The
    canonical sub-module name (``squeeze`` / ``expand1x1`` / ``expand3x3``) is
    tried across all prefixes before the ``convN`` alias. The aliases cover
    checkpoints laid out as ``fireN.conv1`` (squeeze, 1x1), ``fireN.conv2``
    (expand, 1x1) and ``fireN.conv3`` (expand, 3x3).

    Args:
        state_dict: Mapping of parameter names to tensors.
        fire_patterns: Module prefix (str) or list of candidate prefixes such
            as ``"fire2"`` or ``"features.3"``.

    Returns:
        A list of three C++ array definitions (squeeze, expand1x1,
        expand3x3), or None (after printing a warning) if any is missing.
    """
    import re

    # A bare string would otherwise be iterated character by character.
    if isinstance(fire_patterns, str):
        fire_patterns = [fire_patterns]

    # (role used in names/warnings, label for the log, sub-module names by priority)
    roles = (
        ("squeeze", "Squeeze", ("squeeze", "conv1")),
        ("expand1x1", "Expand1", ("expand1x1", "conv2")),
        ("expand3x3", "Expand3", ("expand3x3", "conv3")),
    )

    keys = {}
    for role, _, sub_names in roles:
        candidates = [f"{p}.{sub}" for sub in sub_names for p in fire_patterns]
        key = find_layer_key(state_dict, candidates)
        if not key:
            print(f"Warning: Could not find {role} weights for patterns {fire_patterns}")
            return None
        keys[role] = key

    for role, label, _ in roles:
        print(f"  {label}: {keys[role]} {state_dict[keys[role]].shape}")

    def to_numpy(key):
        tensor = state_dict[key]
        # Quantized tensors do not support .numpy(); dequantize them first.
        if getattr(tensor, "is_quantized", False):
            tensor = tensor.dequantize()
        return tensor.detach().cpu().numpy()

    quantized = {role: quantize_to_fixed_point(to_numpy(key)) for role, key in keys.items()}

    # Name the arrays after the Fire index: prefer the caller's patterns, then
    # fall back to the matched key itself (e.g. "fire2.conv1.weight").
    fire_num = "unknown"
    for text in [*fire_patterns, keys["squeeze"]]:
        match = re.search(r'fire(\d+)', text.lower())
        if match:
            fire_num = match.group(1)
            break

    fire_name = f"fire{fire_num}"
    return [
        format_array_cpp(f"{fire_name}_{role}_weights_flat", quantized[role])
        for role, _, _ in roles
    ]

def _unwrap_state_dict(checkpoint):
    """Return the parameter mapping from the checkpoint layouts this script supports.

    Dicts wrapping the weights under ``state_dict``, ``model_state_dict`` or
    ``model`` are unwrapped. Objects exposing a callable ``state_dict()``
    (a pickled module) are asked for it. Anything else is returned as is.
    """
    if isinstance(checkpoint, dict):
        for wrapper_key in ('state_dict', 'model_state_dict', 'model'):
            if wrapper_key in checkpoint:
                checkpoint = checkpoint[wrapper_key]
                break
    # torch.save(model) pickles the module itself; ask it for its weights.
    if not isinstance(checkpoint, dict) and callable(getattr(checkpoint, 'state_dict', None)):
        checkpoint = checkpoint.state_dict()
    return checkpoint


def _append_definitions(output_lines, comment, definitions):
    """Append a comment line, each definition prefixed with ``inline``, then a blank line."""
    output_lines.append(comment)
    for definition in definitions:
        output_lines.append("inline " + definition)
    output_lines.append("")


def main(checkpoint_path=None):
    """Load a checkpoint, quantize its conv weights and write them to ``weights.h``.

    Args:
        checkpoint_path: Path to the ``.pth`` file. When omitted, the first
            command-line argument is used unless it starts with ``-``;
            otherwise ``squeezenet_qat.pth``.

    Exits with status 1 if the checkpoint file does not exist. The header is
    written to ``weights.h`` in the current working directory.
    """
    import re

    if checkpoint_path is None:
        first_arg = sys.argv[1] if len(sys.argv) > 1 else ""
        # Skip option-like arguments, e.g. the "-f <connection file>" an
        # IPython/Jupyter kernel is launched with, so the notebook cell still runs.
        if first_arg and not first_arg.startswith("-"):
            checkpoint_path = first_arg
        else:
            checkpoint_path = "squeezenet_qat.pth"

    if not Path(checkpoint_path).exists():
        print(f"Error: File {checkpoint_path} not found")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(security): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = _unwrap_state_dict(checkpoint)

    # First, explore the structure
    all_keys = explore_state_dict(state_dict)

    print("Extracting weights...")

    output_lines = [
        "#ifndef WEIGHTS_COMBINED_H",
        "#define WEIGHTS_COMBINED_H",
        "",
        "#include \"config.h\"",
        "",
        f"// Auto-generated weight file from {Path(checkpoint_path).name}",
        "// Quantized to 8-bit fixed-point (4 int, 4 frac)",
        "// Arrays are flattened row-major in PyTorch weight order [out][in][kh][kw]",
        "",
    ]

    # Conv1 - try multiple possible patterns
    print("\nExtracting Conv1 weights...")
    conv1_patterns = ['features.0', 'conv1', 'features.conv1', 'backbone.0']
    conv1 = extract_conv_weights(state_dict, conv1_patterns, "conv1_weights_flat")
    if conv1:
        # NOTE(review): dims follow the checkpoint's conv1.weight (96, 3, 7, 7),
        # which is emitted in that order; confirm the C++ side indexes
        # [out][in][kh][kw] rather than [kh][kw][in][out].
        _append_definitions(output_lines, "// Conv1: 96x3x7x7 (out x in x kh x kw)", [conv1])
    else:
        print("ERROR: Could not find Conv1 weights!")
        print("Please manually check the state_dict structure above.")

    print("\nExtracting Fire module weights...")

    # Detect Fire modules by module name alone: some checkpoints call the
    # sub-convolutions squeeze/expand*, others conv1/conv2/conv3.
    fire_nums = set()
    for key in all_keys:
        match = re.search(r'fire(\d+)', key.lower())
        if match:
            fire_nums.add(int(match.group(1)))
    fire_nums = sorted(fire_nums)

    if fire_nums:
        print(f"Found {len(fire_nums)} fire modules")
        print(f"Fire module numbers detected: {fire_nums}")

        for fire_num in fire_nums:
            print(f"\nExtracting Fire{fire_num} weights...")
            # Try multiple naming patterns
            fire_patterns = [
                f'features.fire{fire_num}',
                f'fire{fire_num}',
                f'features.{fire_num}',
                f'backbone.fire{fire_num}'
            ]
            fire_weights = extract_fire_weights(state_dict, fire_patterns)
            if fire_weights:
                _append_definitions(output_lines, f"// Fire{fire_num}", fire_weights)
    else:
        print("No fire modules found with pattern 'fire' in keys")
        print("Attempting fallback: checking features.X for Fire-like structure...")

        # Fallback: typical torchvision SqueezeNet layout, where
        # features.3, 4, 5, 7, 8, 9, 10, 12 are the Fire modules.
        fire_mapping = {
            2: ['features.3', 'features.fire2'],
            3: ['features.4', 'features.fire3'],
            4: ['features.5', 'features.fire4'],
            5: ['features.7', 'features.fire5'],
            6: ['features.8', 'features.fire6'],
            7: ['features.9', 'features.fire7'],
            8: ['features.10', 'features.fire8'],
            9: ['features.12', 'features.fire9']
        }

        for fire_num, patterns in fire_mapping.items():
            print(f"\nExtracting Fire{fire_num} weights...")
            fire_weights = extract_fire_weights(state_dict, patterns)
            if fire_weights:
                _append_definitions(output_lines, f"// Fire{fire_num}", fire_weights)

    # Conv10 (final classifier conv)
    print("\nExtracting Conv10 weights...")
    conv10_patterns = ['classifier.1', 'classifier.conv', 'conv10', 'features.13', 'final_conv']
    conv10 = extract_conv_weights(state_dict, conv10_patterns, "conv10_weights_flat")
    if conv10:
        _append_definitions(output_lines, "// Conv10: 10x512x1x1 (out x in x kh x kw)", [conv10])
    else:
        print("WARNING: Could not find Conv10 weights!")

    # Close the header guard
    output_lines.append("#endif // WEIGHTS_COMBINED_H")

    output_file = "weights.h"
    print(f"\nWriting to {output_file}...")
    # Terminate with a newline (text-file convention; some compilers warn otherwise).
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(output_lines) + '\n')

    print(f"✓ Successfully extracted weights to {output_file}")
    print("\nNext steps:")
    print("1. Compile the project: make weight_load")
    print("2. Run the test: ./tb_weight_loading")
    print("3. Verify weight values match your expected quantization")

# Entry point: run the extraction when this code is executed directly
# rather than imported.
if __name__ == "__main__":
    main()

Loading checkpoint from squeezenet_qat.pth...

STATE DICT STRUCTURE
conv1.bias                                         (96,)               
conv1.weight                                       (96, 3, 7, 7)       
conv10.bias                                        (10,)               
conv10.weight                                      (10, 512, 1, 1)     
fire2.conv1.bias                                   (16,)               
fire2.conv1.weight                                 (16, 96, 1, 1)      
fire2.conv2.bias                                   (64,)               
fire2.conv2.weight                                 (64, 16, 1, 1)      
fire2.conv3.bias                                   (64,)               
fire2.conv3.weight                                 (64, 16, 3, 3)      
fire3.conv1.bias                                   (16,)               
fire3.conv1.weight                                 (16, 128, 1, 1)     
fire3.conv2.bias                                   (64,)            