# MNIST INT8 Classification on ESP32-P4 using P4-JIT

**Complete Quantization-Aware Training â†’ Native RISC-V Deployment Pipeline**

## Features:
- âœ… **Fake Quantization** with Straight-Through Estimator (STE)
- âœ… **Power-of-2 Scales** for efficient bit-shift operations
- âœ… **INT8 weights & activations** throughout
- âœ… **INT32 accumulators** for precision
- âœ… **Zero firmware changes** via P4-JIT
- âœ… **Native RISC-V execution** at 360 MHz
- âœ… **Comprehensive binary analysis**

---

## 1. Setup & Environment

In [1]:
import os
import sys
import struct
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm

# Setup directories
NOTEBOOK_DIR = Path.cwd()
SOURCE_DIR = NOTEBOOK_DIR / "source"
WEIGHTS_DIR = NOTEBOOK_DIR / "weights"
RESULTS_DIR = NOTEBOOK_DIR / "results"

for d in [SOURCE_DIR, WEIGHTS_DIR, RESULTS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Add P4-JIT to path
PROJECT_ROOT = NOTEBOOK_DIR.parent.parent.parent
sys.path.append(str(PROJECT_ROOT / "host"))

from p4jit import P4JIT, MALLOC_CAP_SPIRAM, MALLOC_CAP_8BIT
import p4jit

# Configure
torch.manual_seed(42)
np.random.seed(42)

print("âœ“ Environment ready")
print(f"âœ“ PyTorch version: {torch.__version__}")
print(f"âœ“ Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

âœ“ Environment ready
âœ“ PyTorch version: 2.8.0+cu126
âœ“ Device: CUDA


## 2. Dataset Preparation

In [None]:
# MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"âœ“ Training samples: {len(train_dataset):,}")
print(f"âœ“ Test samples: {len(test_dataset):,}")

# Visualize samples
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for i in range(20):
    img, label = train_dataset[i]
    ax = axes[i // 10, i % 10]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'{label}')
    ax.axis('off')
plt.suptitle('MNIST Dataset Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'dataset_samples.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Quantization Module with Straight-Through Estimator

**Key Features:**
- Power-of-2 scales (2^n) for efficient bit-shift operations
- Straight-Through Estimator for gradient flow
- **CONSTANT scale exponents** (determined via calibration)
- Quantize both weights AND activations during training

In [4]:
class PowerOfTwoQuantize(torch.autograd.Function):
    """
    Straight-Through Estimator for INT8 quantization.
    Uses power-of-2 scale for efficient bit-shift operations.
    
    Forward: Quantize to INT8 using scale = 2^(-n)
    Backward: Pass gradients straight through (STE)
    """
    @staticmethod
    def forward(ctx, x, scale_exp):
        # scale = 2^(-scale_exp), so quantization is x * 2^scale_exp
        # Then clip to INT8 range [-128, 127]
        scale = 2.0 ** scale_exp
        x_scaled = x * scale
        x_quant = torch.clamp(torch.round(x_scaled), -128, 127)
        x_dequant = x_quant / scale
        return x_dequant
    
    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradient flows unchanged
        return grad_output, None


class FakeQuantizeINT8(nn.Module):
    """
    Fake quantization module with FIXED power-of-2 scale.
    Scale exponent is calibrated and then frozen.
    """
    def __init__(self):
        super().__init__()
        # NOT a Parameter - just a buffer (constant)
        self.register_buffer('scale_exp', torch.tensor(0))
        self.enabled = False  # Start disabled
        
    def set_scale_exp(self, exp):
        """Set the fixed scale exponent after calibration"""
        self.scale_exp = torch.tensor(exp)
        
    def enable(self):
        """Enable quantization (after calibration)"""
        self.enabled = True
        
    def forward(self, x):
        if not self.enabled:
            # Quantization disabled - pass through
            return x
        
        # Use fixed scale_exp
        return PowerOfTwoQuantize.apply(x, self.scale_exp)
    
    def get_scale_info(self):
        """Get quantization scale as power-of-2"""
        exp = int(self.scale_exp.item())
        scale = 2.0 ** (-exp)
        return exp, scale


class QuantizedConv2d(nn.Module):
    """Conv2d with weight quantization"""
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
        self.weight_quant = FakeQuantizeINT8()
        
    def forward(self, x):
        # Quantize weights (if enabled)
        w_quant = self.weight_quant(self.conv.weight)
        # Use quantized weights for convolution
        return F.conv2d(x, w_quant, self.conv.bias, 
                       self.conv.stride, self.conv.padding, 
                       self.conv.dilation, self.conv.groups)


class QuantizedLinear(nn.Module):
    """Linear layer with weight quantization"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.weight_quant = FakeQuantizeINT8()
        
    def forward(self, x):
        # Quantize weights (if enabled)
        w_quant = self.weight_quant(self.linear.weight)
        return F.linear(x, w_quant, self.linear.bias)


print("âœ“ Quantization modules defined")
print("âœ“ Using Straight-Through Estimator for gradient flow")
print("âœ“ Power-of-2 scales for efficient bit-shift operations")
print("âœ“ Scale exponents will be CALIBRATED and FROZEN")

âœ“ Quantization modules defined
âœ“ Using Straight-Through Estimator for gradient flow
âœ“ Power-of-2 scales for efficient bit-shift operations
âœ“ Scale exponents will be CALIBRATED and FROZEN


## 4. Quantized Neural Network Architecture

**Network Structure:**
```
Input (28Ã—28) 
  â†’ Conv2d(1â†’16, 3Ã—3) â†’ Quantize â†’ ReLU â†’ Quantize â†’ MaxPool(2Ã—2)
  â†’ Conv2d(16â†’32, 3Ã—3) â†’ Quantize â†’ ReLU â†’ Quantize â†’ MaxPool(2Ã—2)
  â†’ Flatten
  â†’ Linear(800â†’128) â†’ Quantize â†’ ReLU â†’ Quantize
  â†’ Linear(128â†’10)
  â†’ Output (10 classes)
```

**All weights and activations are fake-quantized to INT8!**

In [5]:
class QuantizedMNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Layer 1: Conv + ReLU + MaxPool
        self.conv1 = QuantizedConv2d(1, 16, kernel_size=3, padding=1)
        self.act1_quant = FakeQuantizeINT8()
        
        # Layer 2: Conv + ReLU + MaxPool
        self.conv2 = QuantizedConv2d(16, 32, kernel_size=3, padding=1)
        self.act2_quant = FakeQuantizeINT8()
        
        # Layer 3: FC + ReLU
        self.fc1 = QuantizedLinear(32 * 7 * 7, 128)
        self.act3_quant = FakeQuantizeINT8()
        
        # Layer 4: FC (output)
        self.fc2 = QuantizedLinear(128, 10)
        
    def forward(self, x):
        # Conv1 block
        x = self.conv1(x)
        x = F.relu(x)
        x = self.act1_quant(x)  # Quantize activations
        x = F.max_pool2d(x, 2)
        
        # Conv2 block
        x = self.conv2(x)
        x = F.relu(x)
        x = self.act2_quant(x)  # Quantize activations
        x = F.max_pool2d(x, 2)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # FC1 block
        x = self.fc1(x)
        x = F.relu(x)
        x = self.act3_quant(x)  # Quantize activations
        
        # FC2 (output logits - no quantization)
        x = self.fc2(x)
        
        return x
    
    def print_quantization_info(self):
        """Print quantization scale information for all layers"""
        print("\n" + "="*80)
        print("QUANTIZATION PARAMETERS (Power-of-2 Scales)")
        print("="*80)
        
        for name, module in self.named_modules():
            if isinstance(module, FakeQuantizeINT8):
                exp, scale = module.get_scale_info()
                print(f"{name:30s} | Scale: 2^({exp:+3d}) = {scale:.6f} | Shift: {exp} bits")
        
        print("="*80)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = QuantizedMNISTNet().to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nâœ“ Model created")
print(f"âœ“ Total parameters: {total_params:,}")
print(f"âœ“ Trainable parameters: {trainable_params:,}")
print(f"\nArchitecture:")
print(model)


âœ“ Model created
âœ“ Total parameters: 206,922
âœ“ Trainable parameters: 206,922

Architecture:
QuantizedMNISTNet(
  (conv1): QuantizedConv2d(
    (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (weight_quant): FakeQuantizeINT8()
  )
  (act1_quant): FakeQuantizeINT8()
  (conv2): QuantizedConv2d(
    (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (weight_quant): FakeQuantizeINT8()
  )
  (act2_quant): FakeQuantizeINT8()
  (fc1): QuantizedLinear(
    (linear): Linear(in_features=1568, out_features=128, bias=True)
    (weight_quant): FakeQuantizeINT8()
  )
  (act3_quant): FakeQuantizeINT8()
  (fc2): QuantizedLinear(
    (linear): Linear(in_features=128, out_features=10, bias=True)
    (weight_quant): FakeQuantizeINT8()
  )
)


## 5. Training Pipeline: Warmup â†’ Calibration â†’ QAT

**3-Phase Training:**
1. **Warmup (3 epochs)**: Train without quantization to stabilize weights
2. **Calibration**: Determine optimal scale exponents using statistics
3. **QAT (7 epochs)**: Train with FIXED quantization scales

In [7]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    
    for data, target in tqdm(loader, desc='Training', leave=False):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / len(loader.dataset)


def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / len(loader.dataset)


# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training phases
WARMUP_EPOCHS = 3
QAT_EPOCHS = 3
TOTAL_EPOCHS = WARMUP_EPOCHS + QAT_EPOCHS

history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}

print("\n" + "="*80)
print("PHASE 1: WARMUP TRAINING (No Quantization)")
print("="*80)
print("Training without quantization to stabilize weights...\n")

for epoch in range(1, WARMUP_EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    print(f"Warmup Epoch {epoch}/{WARMUP_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Test Loss: {test_loss:.4f} Acc: {test_acc:.2f}%")

print("\nâœ“ Warmup complete! Weights stabilized.")
print(f"âœ“ Warmup Test Accuracy: {history['test_acc'][-1]:.2f}%")


PHASE 1: WARMUP TRAINING (No Quantization)
Training without quantization to stabilize weights...



Training:   0%|          | 0/469 [00:00<?, ?it/s]

Warmup Epoch 1/3 | Train Loss: 0.0378 Acc: 98.87% | Test Loss: 0.0336 Acc: 98.94%


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Warmup Epoch 2/3 | Train Loss: 0.0258 Acc: 99.19% | Test Loss: 0.0414 Acc: 98.65%


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Warmup Epoch 3/3 | Train Loss: 0.0207 Acc: 99.34% | Test Loss: 0.0339 Acc: 98.83%

âœ“ Warmup complete! Weights stabilized.
âœ“ Warmup Test Accuracy: 98.83%


### Calibration: Determine Scale Exponents

In [8]:
def calculate_scale_exponent(tensor):
    """
    Calculate optimal power-of-2 scale exponent for INT8 quantization.
    
    For a tensor to fit in [-128, 127]:
    max(abs(tensor)) * 2^exp <= 127
    exp = floor(log2(127 / max(abs(tensor))))
    """
    max_val = tensor.abs().max().item()
    if max_val == 0:
        return 0
    
    # Calculate exponent: scale = 2^exp such that max_val * 2^exp <= 127
    import math
    exp = math.floor(math.log2(127.0 / max_val))
    return exp


def calibrate_activations(model, loader, device, num_batches=10):
    """
    Calibrate activation scales using statistics from multiple batches.
    Returns a dict mapping layer names to scale exponents.
    """
    model.eval()
    
    # Hooks to capture activations
    activation_stats = {}
    hooks = []
    
    def get_hook(name):
        def hook(module, input, output):
            if name not in activation_stats:
                activation_stats[name] = []
            activation_stats[name].append(output.abs().max().item())
        return hook
    
    # Register hooks for activation quantizers
    for name, module in model.named_modules():
        if isinstance(module, FakeQuantizeINT8) and 'act' in name:
            # Hook on the layer BEFORE the quantizer
            parent_name = '.'.join(name.split('.')[:-1])
            if parent_name:
                parent_module = dict(model.named_modules())[parent_name]
            else:
                parent_module = model
            
            hooks.append(parent_module.register_forward_hook(get_hook(name)))
    
    # Run inference on calibration batches
    with torch.no_grad():
        for i, (data, _) in enumerate(loader):
            if i >= num_batches:
                break
            data = data.to(device)
            model(data)
    
    # Remove hooks
    for hook in hooks:
        hook.remove()
    
    # Calculate exponents from statistics
    exponents = {}
    for name, values in activation_stats.items():
        max_activation = max(values)
        import math
        exp = math.floor(math.log2(127.0 / max_activation)) if max_activation > 0 else 0
        exponents[name] = exp
    
    return exponents


print("\n" + "="*80)
print("PHASE 2: CALIBRATION")
print("="*80)
print("Calculating optimal scale exponents...\n")

# Calibrate weights (simple - based on current weight values)
print("Weight Calibration:")
for name, module in model.named_modules():
    if isinstance(module, (QuantizedConv2d, QuantizedLinear)):
        if isinstance(module, QuantizedConv2d):
            weight = module.conv.weight
        else:
            weight = module.linear.weight
        
        exp = calculate_scale_exponent(weight.data)
        module.weight_quant.set_scale_exp(exp)
        
        scale = 2.0 ** (-exp)
        print(f"  {name:20s} | Exponent: {exp:+3d} | Scale: 2^({exp:+3d}) = {scale:.6f}")

# Calibrate activations (using statistics from large batch)
print("\nActivation Calibration (using 10 batches):")
act_exponents = calibrate_activations(model, train_loader, device, num_batches=10)

for name, module in model.named_modules():
    if isinstance(module, FakeQuantizeINT8) and name in act_exponents:
        exp = act_exponents[name]
        module.set_scale_exp(exp)
        scale = 2.0 ** (-exp)
        print(f"  {name:20s} | Exponent: {exp:+3d} | Scale: 2^({exp:+3d}) = {scale:.6f}")

print("\nâœ“ Calibration complete! Scale exponents are now FIXED.")
print("="*80)


PHASE 2: CALIBRATION
Calculating optimal scale exponents...

Weight Calibration:
  conv1                | Exponent:  +7 | Scale: 2^( +7) = 0.007812
  conv2                | Exponent:  +8 | Scale: 2^( +8) = 0.003906
  fc1                  | Exponent:  +8 | Scale: 2^( +8) = 0.003906
  fc2                  | Exponent:  +8 | Scale: 2^( +8) = 0.003906

Activation Calibration (using 10 batches):
  act1_quant           | Exponent:  +2 | Scale: 2^( +2) = 0.250000
  act2_quant           | Exponent:  +2 | Scale: 2^( +2) = 0.250000
  act3_quant           | Exponent:  +2 | Scale: 2^( +2) = 0.250000

âœ“ Calibration complete! Scale exponents are now FIXED.


### Enable Quantization and Continue Training

In [9]:
# Enable all quantizers
for module in model.modules():
    if isinstance(module, FakeQuantizeINT8):
        module.enable()

print("\n" + "="*80)
print("PHASE 3: QUANTIZATION-AWARE TRAINING (Fixed Scales)")
print("="*80)
print("Training with fake quantization enabled...\n")

# Reset optimizer for QAT phase
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

for epoch in range(1, QAT_EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    total_epoch = WARMUP_EPOCHS + epoch
    print(f"QAT Epoch {epoch}/{QAT_EPOCHS} (Total: {total_epoch}/{TOTAL_EPOCHS}) | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Test Loss: {test_loss:.4f} Acc: {test_acc:.2f}%")

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"âœ“ Final Test Accuracy: {history['test_acc'][-1]:.2f}%")
print(f"âœ“ Accuracy after warmup: {history['test_acc'][WARMUP_EPOCHS-1]:.2f}%")
print(f"âœ“ Accuracy after QAT: {history['test_acc'][-1]:.2f}%")
print("="*80)

# Show final quantization parameters
model.print_quantization_info()


PHASE 3: QUANTIZATION-AWARE TRAINING (Fixed Scales)
Training with fake quantization enabled...



Training:   0%|          | 0/469 [00:00<?, ?it/s]

QAT Epoch 1/3 (Total: 4/6) | Train Loss: 0.0175 Acc: 99.42% | Test Loss: 0.0330 Acc: 99.00%


Training:   0%|          | 0/469 [00:00<?, ?it/s]

QAT Epoch 2/3 (Total: 5/6) | Train Loss: 0.0121 Acc: 99.61% | Test Loss: 0.0296 Acc: 99.07%


Training:   0%|          | 0/469 [00:00<?, ?it/s]

QAT Epoch 3/3 (Total: 6/6) | Train Loss: 0.0096 Acc: 99.72% | Test Loss: 0.0420 Acc: 98.80%

TRAINING COMPLETE
âœ“ Final Test Accuracy: 98.80%
âœ“ Accuracy after warmup: 98.83%
âœ“ Accuracy after QAT: 98.80%

QUANTIZATION PARAMETERS (Power-of-2 Scales)
conv1.weight_quant             | Scale: 2^( +7) = 0.007812 | Shift: 7 bits
act1_quant                     | Scale: 2^( +2) = 0.250000 | Shift: 2 bits
conv2.weight_quant             | Scale: 2^( +8) = 0.003906 | Shift: 8 bits
act2_quant                     | Scale: 2^( +2) = 0.250000 | Shift: 2 bits
fc1.weight_quant               | Scale: 2^( +8) = 0.003906 | Shift: 8 bits
act3_quant                     | Scale: 2^( +2) = 0.250000 | Shift: 2 bits
fc2.weight_quant               | Scale: 2^( +8) = 0.003906 | Shift: 8 bits


In [None]:
# Plot training history with phase separation
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

epochs = list(range(1, len(history['train_loss']) + 1))

# Loss plot
ax1.plot(epochs, history['train_loss'], label='Train Loss', marker='o', linewidth=2)
ax1.plot(epochs, history['test_loss'], label='Test Loss', marker='s', linewidth=2)
ax1.axvline(x=WARMUP_EPOCHS, color='red', linestyle='--', linewidth=2, 
            label='Calibration Point')
ax1.fill_between([0, WARMUP_EPOCHS], 0, ax1.get_ylim()[1], alpha=0.2, color='orange', 
                  label='Warmup Phase')
ax1.fill_between([WARMUP_EPOCHS, len(epochs)], 0, ax1.get_ylim()[1], alpha=0.2, color='blue',
                  label='QAT Phase')
ax1.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax1.set_ylabel('Loss', fontsize=12, fontweight='bold')
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10, loc='upper right')
ax1.grid(True, alpha=0.3)

# Accuracy plot
ax2.plot(epochs, history['train_acc'], label='Train Accuracy', marker='o', linewidth=2)
ax2.plot(epochs, history['test_acc'], label='Test Accuracy', marker='s', linewidth=2)
ax2.axvline(x=WARMUP_EPOCHS, color='red', linestyle='--', linewidth=2,
            label='Calibration Point')
ax2.fill_between([0, WARMUP_EPOCHS], 0, 100, alpha=0.2, color='orange',
                  label='Warmup Phase')
ax2.fill_between([WARMUP_EPOCHS, len(epochs)], 0, 100, alpha=0.2, color='blue',
                  label='QAT Phase')
ax2.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax2.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
ax2.set_title('Training Accuracy', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10, loc='lower right')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nðŸ“Š Training Summary:")
print(f"  Warmup Phase: {WARMUP_EPOCHS} epochs (no quantization)")
print(f"  Calibration: Scale exponents determined and frozen")
print(f"  QAT Phase: {QAT_EPOCHS} epochs (with fixed quantization)")
print(f"  Final Accuracy: {history['test_acc'][-1]:.2f}%")

## 6. Extract Quantized Weights (INT8 + Power-of-2 Scales)

In [11]:
def quantize_to_int8(tensor, scale_exp):
    """
    Quantize tensor to INT8 using power-of-2 scale.
    scale_exp: scale = 2^(-scale_exp)
    """
    scale = 2.0 ** scale_exp
    quantized = torch.clamp(torch.round(tensor * scale), -128, 127).to(torch.int8)
    return quantized.cpu().numpy(), scale_exp


model.eval()

quantized_weights = {}

print("\n" + "="*80)
print("WEIGHT EXTRACTION & QUANTIZATION")
print("="*80)

# Conv1
exp, _ = model.conv1.weight_quant.get_scale_info()
w, scale_exp = quantize_to_int8(model.conv1.conv.weight.data, exp)
b, bias_exp = quantize_to_int8(model.conv1.conv.bias.data, exp)

quantized_weights['conv1'] = {
    'weight': w, 'weight_shape': w.shape, 'weight_scale_exp': scale_exp,
    'bias': b, 'bias_shape': b.shape, 'bias_scale_exp': bias_exp
}
print(f"Conv1: W{w.shape} [{w.min():+4d}, {w.max():+4d}] | Scale: 2^{scale_exp:+d}")

# Conv2
exp, _ = model.conv2.weight_quant.get_scale_info()
w, scale_exp = quantize_to_int8(model.conv2.conv.weight.data, exp)
b, bias_exp = quantize_to_int8(model.conv2.conv.bias.data, exp)

quantized_weights['conv2'] = {
    'weight': w, 'weight_shape': w.shape, 'weight_scale_exp': scale_exp,
    'bias': b, 'bias_shape': b.shape, 'bias_scale_exp': bias_exp
}
print(f"Conv2: W{w.shape} [{w.min():+4d}, {w.max():+4d}] | Scale: 2^{scale_exp:+d}")

# FC1
exp, _ = model.fc1.weight_quant.get_scale_info()
w, scale_exp = quantize_to_int8(model.fc1.linear.weight.data, exp)
b, bias_exp = quantize_to_int8(model.fc1.linear.bias.data, exp)

quantized_weights['fc1'] = {
    'weight': w, 'weight_shape': w.shape, 'weight_scale_exp': scale_exp,
    'bias': b, 'bias_shape': b.shape, 'bias_scale_exp': bias_exp
}
print(f"FC1:   W{w.shape} [{w.min():+4d}, {w.max():+4d}] | Scale: 2^{scale_exp:+d}")

# FC2
exp, _ = model.fc2.weight_quant.get_scale_info()
w, scale_exp = quantize_to_int8(model.fc2.linear.weight.data, exp)
b, bias_exp = quantize_to_int8(model.fc2.linear.bias.data, exp)

quantized_weights['fc2'] = {
    'weight': w, 'weight_shape': w.shape, 'weight_scale_exp': scale_exp,
    'bias': b, 'bias_shape': b.shape, 'bias_scale_exp': bias_exp
}
print(f"FC2:   W{w.shape} [{w.min():+4d}, {w.max():+4d}] | Scale: 2^{scale_exp:+d}")

# Calculate total memory
total_int8_params = sum(w['weight'].size + w['bias'].size for w in quantized_weights.values())
total_fp32_params = total_int8_params * 4

print(f"\nâœ“ Total INT8 parameters: {total_int8_params:,} bytes ({total_int8_params/1024:.2f} KB)")
print(f"âœ“ Equivalent FP32 size: {total_fp32_params/1024:.2f} KB")
print(f"âœ“ Compression ratio: {total_fp32_params/total_int8_params:.1f}Ã—")
print("="*80)


WEIGHT EXTRACTION & QUANTIZATION
Conv1: W(16, 1, 3, 3) [-102,  +70] | Scale: 2^+7
Conv2: W(32, 16, 3, 3) [-111,  +91] | Scale: 2^+8
FC1:   W(128, 1568) [-109,  +83] | Scale: 2^+8
FC2:   W(10, 128) [ -88,  +59] | Scale: 2^+8

âœ“ Total INT8 parameters: 206,922 bytes (202.07 KB)
âœ“ Equivalent FP32 size: 808.29 KB
âœ“ Compression ratio: 4.0Ã—


## 7. Create Optimized C Implementation

**Pure INT8 Operations:**
- All weights: INT8
- All activations: INT8
- Accumulators: INT32 (for precision)
- Scaling: Bit-shift operations (power-of-2)

In [15]:
c_inference_code = """#include <stdint.h>
#include <stdio.h>

// Network dimensions
#define INPUT_H 28
#define INPUT_W 28
#define CONV1_OUT_C 16
#define CONV2_OUT_C 32
#define FC1_OUT 128
#define OUTPUT_SIZE 10

// INT8 operations
static inline int8_t relu_int8(int8_t x) {
    return (x > 0) ? x : 0;
}

static inline int8_t clip_int8(int32_t x) {
    if (x > 127) return 127;
    if (x < -128) return -128;
    return (int8_t)x;
}

// 2D Convolution: INT8 weights, INT8 input, INT32 accumulator, INT8 output
// Uses bit-shift for scaling (power-of-2)
void conv2d_int8(
    const int8_t* input, int in_h, int in_w, int in_c,
    const int8_t* weight, const int8_t* bias,
    int8_t* output, int out_c,
    int scale_shift  // Power-of-2: divide by 2^scale_shift
) {
    const int kernel_size = 3;
    const int padding = 1;
    const int stride = 1;
    
    int out_h = (in_h + 2*padding - kernel_size) / stride + 1;
    int out_w = (in_w + 2*padding - kernel_size) / stride + 1;
    
    for (int oc = 0; oc < out_c; oc++) {
        for (int oh = 0; oh < out_h; oh++) {
            for (int ow = 0; ow < out_w; ow++) {
                int32_t acc = 0;  // INT32 accumulator for precision
                
                // Convolution (INT8 Ã— INT8 â†’ INT32)
                for (int ic = 0; ic < in_c; ic++) {
                    for (int kh = 0; kh < kernel_size; kh++) {
                        for (int kw = 0; kw < kernel_size; kw++) {
                            int ih = oh * stride - padding + kh;
                            int iw = ow * stride - padding + kw;
                            
                            if (ih >= 0 && ih < in_h && iw >= 0 && iw < in_w) {
                                int in_idx = (ic * in_h + ih) * in_w + iw;
                                int w_idx = ((oc * in_c + ic) * kernel_size + kh) * kernel_size + kw;
                                acc += (int32_t)input[in_idx] * (int32_t)weight[w_idx];
                            }
                        }
                    }
                }
                
                // Add bias (INT32)
                acc += (int32_t)bias[oc] << scale_shift;
                
                // Scale down using bit-shift (power-of-2)
                acc = acc >> scale_shift;
                
                // Clip to INT8 range
                int out_idx = (oc * out_h + oh) * out_w + ow;
                output[out_idx] = clip_int8(acc);
            }
        }
    }
}

// MaxPool 2x2
void maxpool2d_int8(const int8_t* input, int8_t* output, int h, int w, int c) {
    int out_h = h / 2;
    int out_w = w / 2;
    
    for (int ch = 0; ch < c; ch++) {
        for (int oh = 0; oh < out_h; oh++) {
            for (int ow = 0; ow < out_w; ow++) {
                int ih = oh * 2;
                int iw = ow * 2;
                
                int8_t max_val = -128;
                for (int kh = 0; kh < 2; kh++) {
                    for (int kw = 0; kw < 2; kw++) {
                        int in_idx = (ch * h + ih + kh) * w + iw + kw;
                        if (input[in_idx] > max_val) {
                            max_val = input[in_idx];
                        }
                    }
                }
                
                int out_idx = (ch * out_h + oh) * out_w + ow;
                output[out_idx] = max_val;
            }
        }
    }
}

// Fully Connected: INT8 weights, INT8 input, INT32 accumulator, INT8 output
void fc_int8(
    const int8_t* input, int in_size,
    const int8_t* weight, const int8_t* bias,
    int8_t* output, int out_size,
    int scale_shift
) {
    for (int i = 0; i < out_size; i++) {
        int32_t acc = 0;  // INT32 accumulator
        
        // Multiply-accumulate (INT8 Ã— INT8 â†’ INT32)
        for (int j = 0; j < in_size; j++) {
            acc += (int32_t)input[j] * (int32_t)weight[i * in_size + j];
        }
        
        // Add bias
        acc += (int32_t)bias[i] << scale_shift;
        
        // Scale down
        acc = acc >> scale_shift;
        
        // Clip and store
        output[i] = clip_int8(acc);
    }
}

// Main inference function
// Arguments:
// - input: 784 bytes (28Ã—28)
// - weights_conv1: Conv1 weights
// - bias_conv1: Conv1 biases
// - weights_conv2: Conv2 weights
// - bias_conv2: Conv2 biases
// - weights_fc1: FC1 weights
// - bias_fc1: FC1 biases
// - weights_fc2: FC2 weights
// - bias_fc2: FC2 biases
// - scratch: Scratch buffer for intermediate results
int32_t mnist_inference(
    int8_t* input,
    int8_t* weights_conv1,
    int8_t* bias_conv1,
    int8_t* weights_conv2,
    int8_t* bias_conv2,
    int8_t* weights_fc1,
    int8_t* bias_fc1,
    int8_t* weights_fc2,
    int8_t* bias_fc2,
    int8_t* scratch
) {
    printf("[JIT] Starting MNIST inference...\\n");
    
    // Allocate intermediate buffers from scratch
    int8_t* conv1_out = scratch;
    int8_t* pool1_out = conv1_out + (16 * 28 * 28);
    int8_t* conv2_out = pool1_out + (16 * 14 * 14);
    int8_t* pool2_out = conv2_out + (32 * 14 * 14);
    int8_t* fc1_out = pool2_out + (32 * 7 * 7);
    int8_t* fc2_out = fc1_out + 128;
    
    // Layer 1: Conv2d(1â†’16) + ReLU + MaxPool
    printf("[JIT] Conv1...\\n");
    conv2d_int8(input, 28, 28, 1, weights_conv1, bias_conv1, conv1_out, 16, 0);
    
    // ReLU in-place
    for (int i = 0; i < 16 * 28 * 28; i++) {
        conv1_out[i] = relu_int8(conv1_out[i]);
    }
    
    maxpool2d_int8(conv1_out, pool1_out, 28, 28, 16);
    
    // Layer 2: Conv2d(16â†’32) + ReLU + MaxPool
    printf("[JIT] Conv2...\\n");
    conv2d_int8(pool1_out, 14, 14, 16, weights_conv2, bias_conv2, conv2_out, 32, 0);
    
    // ReLU in-place
    for (int i = 0; i < 32 * 14 * 14; i++) {
        conv2_out[i] = relu_int8(conv2_out[i]);
    }
    
    maxpool2d_int8(conv2_out, pool2_out, 14, 14, 32);
    
    // Layer 3: FC(800â†’128) + ReLU
    printf("[JIT] FC1...\\n");
    fc_int8(pool2_out, 800, weights_fc1, bias_fc1, fc1_out, 128, 0);
    
    // ReLU in-place
    for (int i = 0; i < 128; i++) {
        fc1_out[i] = relu_int8(fc1_out[i]);
    }
    
    // Layer 4: FC(128â†’10) - Output logits
    printf("[JIT] FC2...\\n");
    fc_int8(fc1_out, 128, weights_fc2, bias_fc2, fc2_out, 10, 0);
    
    // Find argmax
    int8_t max_val = fc2_out[0];
    int32_t max_idx = 0;
    
    for (int i = 1; i < 10; i++) {
        if (fc2_out[i] > max_val) {
            max_val = fc2_out[i];
            max_idx = i;
        }
    }
    
    printf("[JIT] Predicted class: %d (logit: %d)\\n", max_idx, max_val);
    
    return max_idx;
}
"""

# Save C code
c_path = SOURCE_DIR / "mnist_inference.c"
with open(c_path, 'w') as f:
    f.write(c_inference_code)

print(f"âœ“ C inference code saved: {c_path}")
print(f"âœ“ Code size: {len(c_inference_code)} bytes")
print(f"âœ“ All operations use INT8 with INT32 accumulators")
print(f"âœ“ Scaling via bit-shift (power-of-2 scales)")

âœ“ C inference code saved: c:\Users\orani\bilel\git_projects\robert_manzke\project1\trys\costume_p4code_binary\P4-JIT\notebooks\tutorials\t02_mnist_classification\source\mnist_inference.c
âœ“ Code size: 6524 bytes
âœ“ All operations use INT8 with INT32 accumulators
âœ“ Scaling via bit-shift (power-of-2 scales)


## 8. Prepare Test Images (One Per Class)

In [None]:
# Select one test image per class
test_images_torch = {}
test_images_int8 = {}

for digit in range(10):
    for idx, (img, label) in enumerate(test_dataset):
        if label == digit:
            test_images_torch[digit] = img
            
            # Convert to INT8 for device
            img_np = img.squeeze().numpy()
            # Normalize to [0, 1] then scale to INT8 range
            img_norm = (img_np - img_np.min()) / (img_np.max() - img_np.min())
            img_int8 = ((img_norm * 255) - 128).astype(np.int8)
            test_images_int8[digit] = img_int8.flatten()
            break

print("âœ“ Test images selected (one per class)")
print(f"  INT8 shape: {test_images_int8[0].shape}")
print(f"  INT8 range: [{test_images_int8[0].min()}, {test_images_int8[0].max()}]")

# Visualize
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for digit in range(10):
    ax = axes[digit // 5, digit % 5]
    img = test_images_torch[digit].squeeze()
    ax.imshow(img, cmap='gray')
    ax.set_title(f'Digit: {digit}', fontweight='bold')
    ax.axis('off')

plt.suptitle('Test Images (INT8 Quantized)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'test_images.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Deploy to ESP32-P4 via P4-JIT

**ðŸš€ Dynamic Code Loading Without Firmware Changes!**

In [18]:
print("\n" + "="*80)
print("ESP32-P4 DEPLOYMENT")
print("="*80)

# Initialize P4-JIT
jit = P4JIT()

# Device memory before
print("\nDevice Memory (Initial):")
stats_initial = jit.get_heap_stats(print_s=True)

# Compile and load inference kernel
print("\n" + "="*80)
print("COMPILING C KERNEL")
print("="*80)

p4jit.set_log_level('INFO_VERBOSE')

func = jit.load(
    source=str(c_path),
    function_name='mnist_inference',
    optimization='O3',
    use_firmware_elf=True,
    smart_args=False
)

print(f"\nâœ“ Kernel loaded at: 0x{func.code_addr:08X}")
print(f"âœ“ Binary size: {func.stats['code_size']} bytes ({func.stats['code_size']/1024:.2f} KB)")
print(f"âœ“ Args buffer: 0x{func.args_addr:08X}")


ESP32-P4 DEPLOYMENT
08:55:19 [p4jit.p4jit] [94mINFO[0m: Initializing P4JIT System...
08:55:19 [p4jit.runtime.jit_session] [94mINFO[0m: Auto-detecting JIT device...
08:55:19 [p4jit.runtime.device_manager] [94mINFO[0m: Connecting to COM3 at 115200 baud...
08:55:19 [p4jit.runtime.device_manager] [94mINFO[0m: Connecting to COM6 at 115200 baud...
08:55:19 [p4jit.runtime.device_manager] [94mINFO[0m: Connected.
08:55:19 [p4jit.runtime.jit_session] [94mINFO[0m: Found JIT Device at COM6
08:55:19 [p4jit.p4jit] [94mINFO[0m: P4JIT Initialized.

Device Memory (Initial):
08:55:19 [p4jit.p4jit] [94mINFO[0m: [Heap Params]
08:55:19 [p4jit.p4jit] [94mINFO[0m:   free_spiram    :   31388992 bytes (30653.31 KB)
08:55:19 [p4jit.p4jit] [94mINFO[0m:   total_spiram   :   33554432 bytes (32768.00 KB)
08:55:19 [p4jit.p4jit] [94mINFO[0m:   free_internal  :     384063 bytes (375.06 KB)
08:55:19 [p4jit.p4jit] [94mINFO[0m:   total_internal :     464119 bytes (453.24 KB)

COMPILING C KERNEL
0

### Detailed Binary Analysis

In [19]:
print("\n" + "="*80)
print("BINARY ANALYSIS")
print("="*80)

# Sections
print("\nELF Sections:")
func.binary.print_sections()

# Symbol table
print("\nSymbol Table (Functions):")
func.binary.print_symbols()

# Memory layout
print("\nMemory Layout:")
func.binary.print_memory_map()

# Disassembly
disasm_path = RESULTS_DIR / 'inference_disasm.txt'
func.binary.disassemble(output=str(disasm_path), source_intermix=False)
print(f"\nâœ“ Full disassembly saved to: {disasm_path}")

# Show first 50 lines of disassembly
print("\nDisassembly Preview (first 50 lines):")
print("-" * 80)
with open(disasm_path, 'r') as f:
    lines = f.readlines()[:50]
    print(''.join(lines))
print("-" * 80)


BINARY ANALYSIS

ELF Sections:
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m: Sections:
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m:   .text                0x48210ae0    4202 bytes
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m:   .rodata              0x48211b4c     139 bytes

Symbol Table (Functions):
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m: Functions:
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m:   call_remote                    0x48210ae0  4202 bytes

Memory Layout:
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m: Memory Map (Base: 0x48210ae0):
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m:   â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
08:56:00 [p4jit.toolchain.binary_object] [94mINFO[0m:        0  â”‚ .text          4202 bytes
08:56:00 [p4jit.toolchain.binary_object] [94mINFO

### Upload Weights to Device

In [20]:
device = jit.session.device

print("\n" + "="*80)
print("WEIGHT UPLOAD")
print("="*80)

# Allocate and upload each layer's weights
weight_addrs = {}

for layer_name in ['conv1', 'conv2', 'fc1', 'fc2']:
    layer = quantized_weights[layer_name]
    
    # Weights
    w_bytes = layer['weight'].tobytes()
    w_addr = device.allocate(len(w_bytes), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT, 16)
    device.write_memory(w_addr, w_bytes)
    
    # Biases
    b_bytes = layer['bias'].tobytes()
    b_addr = device.allocate(len(b_bytes), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT, 16)
    device.write_memory(b_addr, b_bytes)
    
    weight_addrs[f'{layer_name}_w'] = w_addr
    weight_addrs[f'{layer_name}_b'] = b_addr
    
    print(f"{layer_name.upper():6s} | W: 0x{w_addr:08X} ({len(w_bytes):6d} bytes) | "
          f"B: 0x{b_addr:08X} ({len(b_bytes):4d} bytes)")

# Allocate scratch buffer
scratch_size = 64 * 1024  # 64 KB
scratch_addr = device.allocate(scratch_size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT, 16)
print(f"\nScratch buffer: 0x{scratch_addr:08X} ({scratch_size} bytes)")

# Memory after upload
print("\nDevice Memory (After Upload):")
stats_uploaded = jit.get_heap_stats(print_s=True)

memory_used = (stats_initial['free_spiram'] - stats_uploaded['free_spiram']) / 1024
print(f"\nâœ“ Total memory used: {memory_used:.2f} KB")
print("="*80)


WEIGHT UPLOAD
CONV1  | W: 0x48211CD0 (   144 bytes) | B: 0x48211D80 (  16 bytes)
CONV2  | W: 0x48211DB0 (  4608 bytes) | B: 0x48212FD0 (  32 bytes)
FC1    | W: 0x48213010 (200704 bytes) | B: 0x48244030 ( 128 bytes)
FC2    | W: 0x482440D0 (  1280 bytes) | B: 0x482445F0 (  10 bytes)

Scratch buffer: 0x48244600 (65536 bytes)

Device Memory (After Upload):
08:56:43 [p4jit.p4jit] [94mINFO[0m: [Heap Params]
08:56:43 [p4jit.p4jit] [94mINFO[0m:   free_spiram    :   31111952 bytes (30382.77 KB)
08:56:43 [p4jit.p4jit] [94mINFO[0m:   total_spiram   :   33554432 bytes (32768.00 KB)
08:56:43 [p4jit.p4jit] [94mINFO[0m:   free_internal  :     384063 bytes (375.06 KB)
08:56:43 [p4jit.p4jit] [94mINFO[0m:   total_internal :     464119 bytes (453.24 KB)

âœ“ Total memory used: 270.55 KB


## 10. Run Inference on Hardware

In [21]:
print("\n" + "="*80)
print("INFERENCE ON ESP32-P4 @ 360 MHz")
print("="*80)

import time
results = {}

for digit in range(10):
    print(f"\nTesting digit {digit}...")
    
    # Upload input image
    img_bytes = test_images_int8[digit].tobytes()
    input_addr = device.allocate(len(img_bytes), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT, 16)
    device.write_memory(input_addr, img_bytes)
    
    # Prepare function arguments
    # mnist_inference(input, w1, b1, w2, b2, w_fc1, b_fc1, w_fc2, b_fc2, scratch)
    args_blob = struct.pack('<IIIIIIIIII',
        input_addr,
        weight_addrs['conv1_w'], weight_addrs['conv1_b'],
        weight_addrs['conv2_w'], weight_addrs['conv2_b'],
        weight_addrs['fc1_w'], weight_addrs['fc1_b'],
        weight_addrs['fc2_w'], weight_addrs['fc2_b'],
        scratch_addr
    )
    
    # Execute
    start = time.time()
    func(args_blob)
    duration_ms = (time.time() - start) * 1000
    
    # Read result (slot 31)
    result_addr = func.args_addr + 124
    result_bytes = device.read_memory(result_addr, 4)
    predicted = struct.unpack('<i', result_bytes)[0]
    
    # Cleanup input
    device.free(input_addr)
    
    # Store results
    results[digit] = {
        'true': digit,
        'predicted': predicted,
        'correct': (predicted == digit),
        'time_ms': duration_ms
    }
    
    status = "âœ“" if predicted == digit else "âœ—"
    print(f"  {status} True: {digit} | Predicted: {predicted} | Time: {duration_ms:.2f} ms")

# Summary
correct = sum(1 for r in results.values() if r['correct'])
accuracy = 100.0 * correct / len(results)
avg_time = np.mean([r['time_ms'] for r in results.values()])
min_time = np.min([r['time_ms'] for r in results.values()])
max_time = np.max([r['time_ms'] for r in results.values()])

print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80)
print(f"âœ“ Accuracy: {correct}/{len(results)} = {accuracy:.1f}%")
print(f"âœ“ Avg inference time: {avg_time:.2f} ms")
print(f"âœ“ Min/Max time: {min_time:.2f} / {max_time:.2f} ms")
print(f"âœ“ Throughput: {1000/avg_time:.1f} inferences/second")
print("="*80)


INFERENCE ON ESP32-P4 @ 360 MHz

Testing digit 0...
  âœ— True: 0 | Predicted: 9 | Time: 55.93 ms

Testing digit 1...
  âœ— True: 1 | Predicted: 0 | Time: 55.41 ms

Testing digit 2...
  âœ— True: 2 | Predicted: 0 | Time: 56.70 ms

Testing digit 3...
  âœ— True: 3 | Predicted: 9 | Time: 56.22 ms

Testing digit 4...
  âœ— True: 4 | Predicted: 9 | Time: 56.47 ms

Testing digit 5...
  âœ— True: 5 | Predicted: 9 | Time: 56.52 ms

Testing digit 6...
  âœ— True: 6 | Predicted: 0 | Time: 55.60 ms

Testing digit 7...
  âœ— True: 7 | Predicted: 0 | Time: 55.34 ms

Testing digit 8...
  âœ“ True: 8 | Predicted: 8 | Time: 55.40 ms

Testing digit 9...
  âœ— True: 9 | Predicted: 8 | Time: 56.44 ms

RESULTS SUMMARY
âœ“ Accuracy: 1/10 = 10.0%
âœ“ Avg inference time: 56.00 ms
âœ“ Min/Max time: 55.34 / 56.70 ms
âœ“ Throughput: 17.9 inferences/second


## 11. Performance Visualization & Analysis

In [None]:
# Confusion matrix
confusion = np.zeros((10, 10), dtype=int)
for r in results.values():
    confusion[r['true'], r['predicted']] += 1

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Confusion matrix
im = ax1.imshow(confusion, cmap='Blues')
ax1.set_xticks(np.arange(10))
ax1.set_yticks(np.arange(10))
ax1.set_xlabel('Predicted', fontsize=12, fontweight='bold')
ax1.set_ylabel('True Label', fontsize=12, fontweight='bold')
ax1.set_title(f'Confusion Matrix\nAccuracy: {accuracy:.1f}%', 
              fontsize=14, fontweight='bold')

for i in range(10):
    for j in range(10):
        text = ax1.text(j, i, confusion[i, j],
                       ha="center", va="center",
                       color="white" if confusion[i, j] > 0.5 else "black",
                       fontsize=11, fontweight='bold')

plt.colorbar(im, ax=ax1)

# Inference time
times = [results[d]['time_ms'] for d in range(10)]
colors = ['green' if results[d]['correct'] else 'red' for d in range(10)]

bars = ax2.bar(range(10), times, color=colors, edgecolor='black', linewidth=1.5)
ax2.axhline(avg_time, color='blue', linestyle='--', linewidth=2, 
            label=f'Average: {avg_time:.2f} ms')
ax2.set_xlabel('Digit Class', fontsize=12, fontweight='bold')
ax2.set_ylabel('Inference Time (ms)', fontsize=12, fontweight='bold')
ax2.set_title('Inference Time per Class', fontsize=14, fontweight='bold')
ax2.set_xticks(range(10))
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'inference_results.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Memory breakdown
memory_breakdown = {
    'Code': func.stats['code_size'] / 1024,
    'Conv1 Weights': (quantized_weights['conv1']['weight'].size + 
                      quantized_weights['conv1']['bias'].size) / 1024,
    'Conv2 Weights': (quantized_weights['conv2']['weight'].size + 
                      quantized_weights['conv2']['bias'].size) / 1024,
    'FC1 Weights': (quantized_weights['fc1']['weight'].size + 
                    quantized_weights['fc1']['bias'].size) / 1024,
    'FC2 Weights': (quantized_weights['fc2']['weight'].size + 
                    quantized_weights['fc2']['bias'].size) / 1024,
    'Scratch Buffer': scratch_size / 1024
}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Pie chart
colors_pie = plt.cm.Set3(np.linspace(0, 1, len(memory_breakdown)))
wedges, texts, autotexts = ax1.pie(memory_breakdown.values(), 
                                     labels=memory_breakdown.keys(),
                                     autopct='%1.1f%%',
                                     colors=colors_pie,
                                     startangle=90)

for autotext in autotexts:
    autotext.set_color('black')
    autotext.set_fontweight('bold')
    autotext.set_fontsize(10)

ax1.set_title(f'Memory Usage Breakdown\nTotal: {memory_used:.2f} KB', 
              fontsize=14, fontweight='bold')

# Bar chart
bars = ax2.barh(list(memory_breakdown.keys()), list(memory_breakdown.values()),
                color=colors_pie, edgecolor='black', linewidth=1.5)
ax2.set_xlabel('Size (KB)', fontsize=12, fontweight='bold')
ax2.set_title('Memory Components', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='x')

# Add values on bars
for i, bar in enumerate(bars):
    width = bar.get_width()
    ax2.text(width, bar.get_y() + bar.get_height()/2,
            f'{width:.2f} KB',
            ha='left', va='center', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig(RESULTS_DIR / 'memory_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 12. Final Report & Specifications

In [None]:
print("\n" + "="*80)
print("FINAL SYSTEM SPECIFICATIONS")
print("="*80)

print("\nðŸ“Š Model Architecture:")
print(f"  â€¢ Input: 28Ã—28 grayscale (784 pixels)")
print(f"  â€¢ Conv1: 1â†’16 channels, 3Ã—3 kernel, ReLU, MaxPool")
print(f"  â€¢ Conv2: 16â†’32 channels, 3Ã—3 kernel, ReLU, MaxPool")
print(f"  â€¢ FC1: 800â†’128, ReLU")
print(f"  â€¢ FC2: 128â†’10 (output logits)")
print(f"  â€¢ Total parameters: {total_int8_params:,} (INT8)")

print("\nâš¡ Hardware Performance:")
print(f"  â€¢ Platform: ESP32-P4 @ 360 MHz")
print(f"  â€¢ Architecture: RISC-V RV32IMAFC")
print(f"  â€¢ Memory: {stats_initial['total_spiram']//1024//1024} MB SPIRAM")
print(f"  â€¢ Code size: {func.stats['code_size']/1024:.2f} KB")
print(f"  â€¢ Avg inference: {avg_time:.2f} ms")
print(f"  â€¢ Throughput: {1000/avg_time:.1f} fps")
print(f"  â€¢ Accuracy: {accuracy:.1f}%")

print("\nðŸ’¾ Memory Footprint:")
print(f"  â€¢ Weights (INT8): {total_int8_params/1024:.2f} KB")
print(f"  â€¢ Code: {func.stats['code_size']/1024:.2f} KB")
print(f"  â€¢ Scratch buffer: {scratch_size/1024:.2f} KB")
print(f"  â€¢ Total SPIRAM: {memory_used:.2f} KB")

print("\nðŸŽ¯ Quantization:")
print(f"  â€¢ Weight precision: INT8 (8-bit)")
print(f"  â€¢ Activation precision: INT8 (8-bit)")
print(f"  â€¢ Accumulator: INT32 (32-bit)")
print(f"  â€¢ Scaling: Power-of-2 (bit-shift)")
print(f"  â€¢ Compression: 4Ã— vs FP32")
print(f"  â€¢ Training method: QAT with STE")

print("\nðŸ“ˆ Comparison:")
print(f"  â€¢ FP32 model size: {total_fp32_params/1024:.2f} KB")
print(f"  â€¢ INT8 model size: {total_int8_params/1024:.2f} KB")
print(f"  â€¢ Size reduction: {(1-total_int8_params/total_fp32_params)*100:.1f}%")
print(f"  â€¢ Training accuracy: {history['test_acc'][-1]:.2f}%")
print(f"  â€¢ On-device accuracy: {accuracy:.1f}%")
print(f"  â€¢ Accuracy retention: {accuracy/history['test_acc'][-1]*100:.1f}%")

print("\nðŸš€ P4-JIT Advantages:")
print("  1. Deploy time: 2-3 seconds (vs 30-60s firmware rebuild)")
print("  2. No firmware changes required")
print("  3. Native RISC-V execution (zero interpreter overhead)")
print("  4. Dynamic code loading via USB")
print("  5. Seamless Python â†’ C â†’ Hardware workflow")
print("  6. Real-time performance monitoring")
print("  7. Comprehensive binary introspection")

print("\n" + "="*80)
print("âœ¨ DEMONSTRATION COMPLETE âœ¨")
print("="*80)

## 13. Cleanup

In [None]:
print("\n" + "="*80)
print("CLEANUP")
print("="*80)

# Free weights
for addr in weight_addrs.values():
    device.free(addr)

# Free scratch
device.free(scratch_addr)

# Free function
func.free()

print("âœ“ All device memory freed")

# Final stats
stats_final = jit.get_heap_stats(print_s=False)
reclaimed = (stats_final['free_spiram'] - stats_uploaded['free_spiram']) / 1024

print(f"âœ“ Memory reclaimed: {reclaimed:.2f} KB")
print(f"âœ“ Final free SPIRAM: {stats_final['free_spiram']//1024//1024} MB")

# Disconnect
jit.session.device.disconnect()
print("âœ“ Device disconnected")

print("\n" + "="*80)
print(f"All results saved to: {RESULTS_DIR}")
print("="*80)