# YOLOv26 Quantization Aware Training (QAT)

## Overview
This notebook implements a complete Quantization Aware Training (QAT) pipeline for **YOLOv26**, tailored for deployment on **ESP32-P4** (via ESP-DL and ESP-PPQ). The `PLATFORM` setting in the configuration cell also accepts ESP32-S3.

### Key Features of YOLOv26
YOLOv26 represents an evolution in the YOLO architecture focusing on efficiency and end-to-end deployment.
- **NMS-Free Prediction**: Utilizes a dual-head architecture (One-to-Many for training signal, One-to-One for inference), eliminating the need for Non-Maximum Suppression (NMS) during inference.
- **RegMax=1**: Unlike YOLOv8 (RegMax=16) which uses Distribution Focal Loss (DFL), YOLOv26 uses direct regression (`RegMax=1`), reducing output channel complexity and post-processing overhead.
- **Efficient Attention**: Incorporates attention mechanisms (e.g., C2PSA) optimized for low-latency edge devices.
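
The difference between DFL decoding and direct regression can be illustrated with a small sketch. It is not taken from the YOLOv26 code base: `dfl_decode` and `direct_decode` are hypothetical helper names, and the 16-bin logits are made-up values. It only shows why `RegMax=1` shrinks the box output from 64 channels to 4 per scale and removes the softmax step from post-processing.

```python
import math

def dfl_decode(logits):
    """DFL-style decoding of one box side: expected bin index under softmax(logits)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return sum(i * e / total for i, e in enumerate(exps))

def direct_decode(value):
    """With RegMax=1 the head emits the distance itself, so there is nothing to decode."""
    return value

# One box side predicted two ways (illustrative values only).
logits_16 = [0.0] * 16
logits_16[5] = 10.0                      # distribution sharply peaked at bin 5
dfl_distance = dfl_decode(logits_16)     # close to 5, after a softmax over 16 bins
direct_distance = direct_decode(5.0)     # exactly 5, no softmax needed

# Box channels per scale: 4 sides * RegMax
box_channels_regmax16 = 4 * 16
box_channels_regmax1 = 4 * 1
```

On an integer-only target, skipping the 16-way softmax per box side is the main practical gain.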

### Optimizer Choice: SGD vs MuSGD
This QAT pipeline uses plain **SGD (Stochastic Gradient Descent)** rather than **MuSGD**.
- **Reason**: QAT here is only a short fine-tuning step on a model that is already pretrained and calibrated.
- **Simplicity**: Because the model starts in a good state, the additional complexity of MuSGD (which is designed for more aggressive topology changes or from-scratch quantization training) is not needed. Standard SGD is sufficient and effective for this task.
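
To give a feel for how gentle this fine-tuning step is, here is a minimal scalar sketch of SGD with momentum and weight decay, using the hyperparameters from `QATConfig` in this notebook (`1e-6`, `0.937`, `5e-4`). The function `sgd_step` is a hypothetical helper written for illustration. It follows the usual PyTorch-style update rule, and the constant gradient of `1.0` is an arbitrary assumption.

```python
def sgd_step(weight, grad, velocity, lr=1e-6, momentum=0.937, weight_decay=5e-4):
    """One SGD-with-momentum update for a scalar weight (PyTorch-style formulation)."""
    grad = grad + weight_decay * weight      # L2 weight decay folded into the gradient
    velocity = momentum * velocity + grad    # momentum buffer accumulates gradients
    weight = weight - lr * velocity          # parameter update
    return weight, velocity

w, v = 0.5, 0.0
for _ in range(100):
    w, v = sgd_step(w, grad=1.0, velocity=v)

drift = 0.5 - w   # total movement after 100 steps with a constant gradient of 1.0
```

Even after 100 steps the weight moves by well under 0.01, which matches the intent of nudging an already calibrated model rather than retraining it.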

### QAT Workflow Specifics
This pipeline addresses specific challenges in quantizing YOLOv26 for embedded targets:
1.  **Custom Export Patches**: 
    - Patches `Attention` modules to use static reshaping (`view(-1, ...)`), ensuring compatibility with ESP-DL's static graph compiler.
    - Preserves all 6 output heads (3x Aux, 3x Main) during QAT to calculate accurate loss, while enabling `dynamic=False` export for final deployment.
2.  **Sensitive Layer Analysis**: Automatically identifies and disables quantization for the Auxiliary Branch to stabilize training.
3.  **PPQ Integration**: Uses the internal quantization pipeline (Simplify -> Fusion -> Parameter Quantize -> Calibration -> Finetuning) before starting the QAT loop.
4.  **Custom Validator**: A dedicated validator that mimics the Quantized Graph execution to report real on-target mAP metrics.


In [14]:
# Install Dependencies
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.





In [15]:
# System Imports
import os
import sys
# Add pipeline source to path for local imports
sys.path.append('scripts')

import types
import torch
import esp_ppq.lib as PFL
from esp_ppq.executor import TorchExecutor
from esp_ppq.core import QuantizationVisibility, TargetPlatform
from esp_ppq.api import get_target_platform
from esp_ppq.api.interface import load_onnx_graph
from esp_ppq.quantization.optim import (
    QuantizeSimplifyPass, QuantizeFusionPass, ParameterQuantizePass,
    RuntimeCalibrationPass, PassiveParameterQuantizePass, QuantAlignmentPass
)
from ultralytics.data.utils import check_det_dataset

In [16]:
# --- Important Configuration ---
IMG_SZ_I = 512   # input image size in pixels (any size)
PLATFORM = "s3"  # "p4" or "s3"
DATA_YAML_FILE_I = "coco.yaml"  # Ultralytics default dataset YAML; you can use your own

In [None]:
# --- Configuration ---
class QATConfig:
    # Training Parameters
    EPOCHS = 10           # about 10 epochs are needed at img_sz 640
    BATCH_SIZE = 14       # on 8GB VRAM use 14 for 640 img_sz and 20 for 512 img_sz
    IMG_SZ = IMG_SZ_I
    DATA_FRACTION = 1.0   # 1.0 = full dataset; use 0.005 (0.5%) for fast debugging
    SEED = 1234
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Optimizer
    OPTIMIZER_LR = 1e-6
    OPTIMIZER_MOMENTUM = 0.937
    OPTIMIZER_WEIGHT_DECAY = 5e-4
    
    # Data Settings
    DATA_YAML_FILE = DATA_YAML_FILE_I
    DATA_FALLBACK_PATH = "coco2017/images/train2017"
    CALIB_MAX_IMAGES = 8192
    CALIB_VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
    
    # Quantization Settings
    TARGET_PLATFORM = get_target_platform("esp32"+PLATFORM, 8)
    CALIB_STEPS = 32
    QUANT_CALIB_METHOD = "kl"
    QUANT_ALIGNMENT = "Align to Output"
    EXPORT_OPSET = 13
    EXPORT_DYNAMIC = False 
    
    # Loss Defaults (Standard YOLOv8/v10/v26 defaults)
    LOSS_DEFAULTS = {
        'box': 7.5, 'cls': 0.5, 'dfl': 1.5, 'pose': 12.0, 'kobj': 1.0,
        'label_smoothing': 0.0, 'nbs': 64, 'hsv_h': 0.015, 'hsv_s': 0.7,
        'hsv_v': 0.4, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,
        'shear': 0.0, 'perspective': 0.0, 'flipud': 0.0, 'fliplr': 0.5,
        'mosaic': 1.0, 'mixup': 0.0, 'copy_paste': 0.0,
    }

    # Model Paths
    # Assuming current directory
    BASE_DIR = os.getcwd()
    MODEL_NAME = "yolo26n"
    PT_FILE = f"{MODEL_NAME}.pt"
    ONNX_FILE = f"{MODEL_NAME}_train.onnx"
    
    # Derived Paths - Output Structure updated for GitHub workflow
    ESPDL_OUTPUT_DIR = os.path.join(BASE_DIR, "output", f"{os.path.splitext(DATA_YAML_FILE_I)[0]}_{IMG_SZ}_s8_{PLATFORM}")
    ONNX_PATH = os.path.join(ESPDL_OUTPUT_DIR, ONNX_FILE)
    
    # Plotting
    VAL_PLOT_MAX_BATCHES = 3

    # Validation Batch Size
    VAL_BATCH_SIZE = 24

In [18]:
# --- Virtual Module Injection ---
# This tricks Python into thinking 'config.py' exists and contains our QATConfig class.
# This allows 'trainer.py' and 'utils.py' to do 'from config import QATConfig'

if 'config' in sys.modules:
    # If it exists, update it
    sys.modules['config'].QATConfig = QATConfig
else:
    # Create a dummy module
    config_module = types.ModuleType('config')
    config_module.QATConfig = QATConfig
    sys.modules['config'] = config_module

# Create the output directory up front, in case local modules write to it at import time
os.makedirs(QATConfig.ESPDL_OUTPUT_DIR, exist_ok=True)

# Now we can safely import our local modules that depend on 'config'
from utils import seed_everything, register_mod_op, patch_v8_detection_loss, get_exclusive_ancestors
from dataset import get_calibration_loader, get_train_loader
from trainer import QATTrainer
from export import apply_export_patches, ESP_YOLO

# Set seeds for reproducibility
seed_everything(QATConfig.SEED)
register_mod_op()
patch_v8_detection_loss()
print("Virtual Config injected and Modules initialized.")

Registered 'Mod' handler for PPQ.
Virtual Config injected and Modules initialized.
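
The virtual-module trick used in the cell above can be reproduced in isolation. The sketch below uses hypothetical names (`demo_config`, `DemoConfig`) so that it does not collide with the notebook's real `config` module. It relies only on the fact that Python's import system consults `sys.modules` before searching the file system.

```python
import sys
import types

class DemoConfig:
    EPOCHS = 10

# Build an in-memory module and register it under a hypothetical name.
demo_module = types.ModuleType("demo_config")
demo_module.DemoConfig = DemoConfig
sys.modules["demo_config"] = demo_module

# The import system finds the entry in sys.modules, so no demo_config.py is needed on disk.
from demo_config import DemoConfig as Imported

same_object = Imported is DemoConfig
```

Because the imported name is the very same class object, edits made to the class in the notebook are visible to every script that imported it this way.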


### ESP-PPQ ONNX Compatibility Patch

In [19]:
from esp_ppq_patch import apply_esp_ppq_patches
apply_esp_ppq_patches()

Applying ESP-PPQ Runtime Patches...
  [x] Patched OnnxParser.refine_graph
  [x] Patched Backend: Slice
  [x] Patched Backend: Gather
ESP-PPQ Runtime Patches Applied Successfully.


## 1. Model Preparation & Export
We load the PyTorch checkpoint (`.pt`) and export it to ONNX. 
Critically, we apply `ESP_Attention` patches here to ensure the exported ONNX graph uses static reshaping, preventing runtime errors on the target device.
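
With a static export (`dynamic=False`), every head output has a fixed shape that can be worked out ahead of time. The following sketch assumes the usual detection strides of 8, 16 and 32 and a channel layout of `4 * reg_max + nc`. Both are assumptions here, although they agree with the shapes printed by the export log below for a 512 px input. `head_output_shapes` is a hypothetical helper, not part of the pipeline.

```python
def head_output_shapes(img_size, nc, reg_max, strides=(8, 16, 32), batch=1):
    """Expected NCHW shape of each detection head output for a static export."""
    channels = 4 * reg_max + nc              # box channels followed by class channels
    return [(batch, channels, img_size // s, img_size // s) for s in strides]

shapes = head_output_shapes(img_size=512, nc=80, reg_max=1)
# [(1, 84, 64, 64), (1, 84, 32, 32), (1, 84, 16, 16)]
```

Comparing these values against the shapes reported by the exporter is a quick sanity check that the patches kept the intended head layout.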

In [20]:
def extract_model_meta():
    """Dynamically extracts model metadata (NC, RegMax, etc.) from the PT checkpoint."""
    print(f"Loading {QATConfig.PT_FILE} to extract metadata...")
    # We use ESP_YOLO here just to inspect, but standard YOLO would work too for this part
    tmp_model = ESP_YOLO(QATConfig.PT_FILE)
    
    # Access the Detect Head (last layer)
    detect_head = tmp_model.model.model[-1]
    
    # Derive input channels (ch) from the first layer of cv2
    ch = [m[0].conv.in_channels for m in detect_head.cv2]
    
    meta = {
        'nc': detect_head.nc,
        'reg_max': detect_head.reg_max,
        'stride': detect_head.stride,
        'ch': ch
    }
    
    if isinstance(meta['stride'], torch.Tensor):
        meta['stride'] = meta['stride'].tolist()
        
    print(f"Extracted Metadata: NC={meta['nc']}, RegMax={meta['reg_max']}, Stride={meta['stride']}")
    return meta

def prepare_onnx():
    """Ensures the correct ONNX model exists with all patches applied."""
    try:
        # Load using ESP_YOLO to enforce custom export logic
        model = ESP_YOLO(QATConfig.PT_FILE)
        
        # Apply patches (Attention & Detect)
        apply_export_patches(model)
        
        # Export
        model.export(format="onnx", opset=QATConfig.EXPORT_OPSET, simplify=True, 
                        imgsz=QATConfig.IMG_SZ, dynamic=QATConfig.EXPORT_DYNAMIC)
        print("Export complete.")
            
    except Exception as e:
        print(f"Error exporting model: {e}")
        raise  # re-raise with the original traceback

# Run Preparation
prepare_onnx()
model_meta = extract_model_meta()

Applying ESP-DL patches for export...
Patched 2 Attention modules.
Patched Detect module: <class 'ultralytics.nn.modules.head.Detect'>
Ultralytics 8.4.7  Python-3.9.21 torch-2.8.0+cu126 CPU (13th Gen Intel Core(TM) i7-13650HX)
>> Fuse method blocked! Keeping all heads.
YOLO26n summary (fused): 146 layers, 2,562,496 parameters, 0 gradients, 6.0 GFLOPs

PyTorch: starting from 'yolo26n.pt' with input shape (16, 3, 512, 512) BCHW and output shape(s) ((16, 84, 64, 64), (16, 84, 32, 32), (16, 84, 16, 16), (16, 84, 64, 64), (16, 84, 32, 32), (16, 84, 16, 16)) (5.3 MB)

ONNX: starting export with onnx 1.17.0 opset 13...
ONNX: simplifying with onnxsim 0.4.36...
ONNX: export success  1.3s, saved as 'c:\Users\orani\bilel\git_projects\yolo26\yolo26n_esp32_repo\esp-dl\examples\tutorial\how_to_quantize_model\quantize_yolo26\output\coco_640_s8_s3\yolo26n_train.onnx' (9.9 MB)

Export complete (2.4s)
Results saved to C:\Users\orani\bilel\git_proje

## 2. Quantizer Initialization
We load the ONNX graph and initialize the PPQ Quantizer.
We also perform a graph analysis to separate the **Auxiliary Branch** (used for training guidance) from the **Main Branch** (used for inference). We disable quantization on the Aux branch to ensure gradients flow correctly without noise.
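
The idea behind the branch separation can be sketched on a toy graph. The real `get_exclusive_ancestors` helper lives in the local `utils` module and is not shown in this notebook, so the version below is an assumption about its behaviour: collect everything reachable backwards from the auxiliary outputs, then subtract everything the main outputs also depend on. The node names are invented for the example.

```python
def ancestors(producers, outputs):
    """All nodes reachable backwards from `outputs`; `producers` maps node -> input nodes."""
    seen = set()
    stack = list(outputs)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(producers.get(node, []))
    return seen

def exclusive_ancestors(producers, aux_outputs, main_outputs):
    """Nodes that feed the aux outputs but never the main outputs."""
    return ancestors(producers, aux_outputs) - ancestors(producers, main_outputs)

# Toy graph: shared backbone and neck, then one aux and one main head.
toy = {
    "backbone": [],
    "neck": ["backbone"],
    "aux_conv": ["neck"],
    "main_conv": ["neck"],
    "one2many_p3": ["aux_conv"],
    "one2one_p3": ["main_conv"],
}
aux_only = exclusive_ancestors(toy, ["one2many_p3"], ["one2one_p3"])
```

The shared backbone and neck are excluded from the result, so only the layers that exist purely for the training signal are kept in FP32.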

In [21]:
# Load Graph
graph = load_onnx_graph(onnx_import_file=QATConfig.ONNX_PATH)

# Identify and Disable Aux Layer Quantization
output_names = list(graph.outputs.keys())
aux_ops = set()
if len(output_names) >= 6:
    # Assuming order: [one2many_p3, ... one2one_p3, ...]
    # Note: Custom exporter uses names one2many_p3 etc. which is explicitly handled here
    aux_outputs = output_names[0:3]
    main_outputs = output_names[3:6]
    print("Identifying auxiliary branch operators to disable quantization...")
    aux_ops = get_exclusive_ancestors(graph, aux_outputs, main_outputs)
    print(f"Found {len(aux_ops)} operators exclusive to auxiliary branch.")
else:
    print("WARNING: Graph output count < 6. Cannot separate aux/main branches.")

# Initialize Quantizer
quantizer = PFL.Quantizer(platform=QATConfig.TARGET_PLATFORM, graph=graph)
dispatching_table = PFL.Dispatcher(graph=graph, method="conservative").dispatch(
    quantizer.quant_operation_types
)

# Enforce FP32 for Aux ops and defaults
for opname, platform in dispatching_table.items():
    if platform == TargetPlatform.UNSPECIFIED:
        dispatching_table[opname] = TargetPlatform(quantizer.target_platform)
        
for op in aux_ops:
    if op.name in dispatching_table:
        dispatching_table[op.name] = TargetPlatform.FP32

# Apply Quantization
for op in graph.operations.values():
    quantizer.quantize_operation(op_name=op.name, platform=dispatching_table[op.name])

Identifying auxiliary branch operators to disable quantization...
Found 45 operators exclusive to auxiliary branch.


## 3. Calibration
Before training, we must calibrate the quantization parameters (scale/offset) on a representative dataset, so that QAT starts from a well-initialized network.
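
As a rough illustration of what calibration produces, the sketch below derives a symmetric int8 scale from a handful of observed activation values. Note the substitution: the notebook configures the `"kl"` calibration method, whereas this example uses the simpler max-absolute-value rule, chosen only because it fits in a few lines. The helper names and sample values are hypothetical.

```python
def symmetric_int8_scale(samples):
    """Max-abs calibration: map the largest observed magnitude to 127."""
    max_abs = max(abs(v) for v in samples)
    return max_abs / 127.0 if max_abs > 0 else 1.0

def quantize(value, scale):
    """Round to the nearest int8 step and clamp to the representable range."""
    q = round(value / scale)
    return max(-128, min(127, q))

activations = [-1.2, 0.3, 2.54, -0.7, 1.9]   # stand-in for values seen during calibration
scale = symmetric_int8_scale(activations)    # 2.54 / 127, about 0.02
q = quantize(1.9, scale)                     # 95
saturated = quantize(100.0, scale)           # clamps to 127
```

A KL-based method typically picks a tighter range than max-abs when the activations have rare outliers, trading a little clipping for finer resolution on common values.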

In [22]:
# Data Loaders
data_cfg = check_det_dataset(QATConfig.DATA_YAML_FILE)
cali_loader = get_calibration_loader(data_cfg)
train_loader = get_train_loader(data_cfg)

# Tracing
executor = TorchExecutor(graph=graph)
dummy_input = torch.zeros([1, 3, QATConfig.IMG_SZ, QATConfig.IMG_SZ]).to(QATConfig.DEVICE)
executor.tracing_operation_meta(inputs=dummy_input)

# Calibration Pipeline
print("Running Calibration Pipeline...")
pipeline = PFL.Pipeline([
    QuantizeSimplifyPass(),
    QuantizeFusionPass(activation_type=quantizer.activation_fusion_types),
    ParameterQuantizePass(),
    RuntimeCalibrationPass(method=QATConfig.QUANT_CALIB_METHOD),
    PassiveParameterQuantizePass(clip_visiblity=QuantizationVisibility.EXPORT_WHEN_ACTIVE),
    QuantAlignmentPass(elementwise_alignment=QATConfig.QUANT_ALIGNMENT),
])

pipeline.optimize(
    calib_steps=QATConfig.CALIB_STEPS,
    collate_fn=(lambda x: x.type(torch.float).to(QATConfig.DEVICE)),
    graph=graph,
    dataloader=cali_loader,
    executor=executor,
)

Using dataset at: C:\Users\orani\OneDrive\Desktop\bilel\Projects\p_2025\esp32p4dl_pip\pytorch_lab\version1\yolov5\datasets\coco\images\train2017
Fast image access  (ping: 0.1±0.1 ms, read: 1038.5±659.0 MB/s, size: 155.2 KB)
Scanning C:\Users\orani\OneDrive\Desktop\bilel\Projects\p_2025\esp32p4dl_pip\pytorch_lab\version1\yolov5\datasets\coco\labels\train2017.cache... 117266 images, 1021 backgrounds, 0 corrupt: 100% ━━━━━━━━━━━━ 118287/118287  0.0s
Subsampling dataset to 0.5%: 591 samples
Running Calibration Pipeline...
[12:58:57] PPQ Quantize Simplify Pass Running ...         Finished.
[12:58:57] PPQ Quantization Fusion Pass Running ...       Finished.
[12:58:58] PPQ Parameter Quantization Pass Running ...    Finished.
[12:58:58] PPQ Runtime Calibration Pass Running ...       

Calibration Progress(Phase 1):   0%|          | 0/32 [00:00<?, ?it/s]


RuntimeError: Op Execution Error: /model.10/m/m.0/attn/Reshape(Type: Reshape, Num of Input: 2, Num of Output: 1)

### Baseline Validation (Check PTQ Accuracy)

In [None]:
# Initialize the trainer just for evaluation
print("Initializing Trainer for Baseline Check...")
trainer = QATTrainer(graph=graph, model_meta=model_meta, device=QATConfig.DEVICE)

# Run validation on the graph in its current state (after Calibration/PTQ, before Training)
print("Running Baseline Validation on Quantized Graph...")
ptq_mAP = trainer.eval()

print("\n--- Baseline Results ---")
print(f"PTQ mAP50-95: {ptq_mAP:.3f}")
print("Target: This serves as the baseline. QAT will now attempt to improve this score.")

Initializing Trainer for Baseline Check...
Loading YOLOv26n model in Trainer to access correct Loss function...
Initializing Persistent Validator (reusing dataloader)...
Running Baseline Validation on Quantized Graph...
Ultralytics 8.4.7  Python-3.9.21 torch-2.8.0+cu126 CUDA:0 (NVIDIA GeForce RTX 4060 Laptop GPU, 8188MiB)
val: Fast image access  (ping: 0.3±0.1 ms, read: 176.5±117.9 MB/s, size: 104.6 KB)
val: Scanning C:\Users\orani\OneDrive\Desktop\bilel\Projects\p_2025\esp32p4dl_pip\pytorch_lab\version1\yolov5\datasets\coco\labels\val2017.cache... 4952 images, 48 backgrounds, 0 corrupt: 100% ━━━━━━━━━━━━ 5000/5000  0.0s
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% ━━━━━━━━━━━━ 209/209 1.3it/s 2:37<0.6s
                   all       5000      36335      0.604      0.444      0.492      0.342

--- Baseline Results ---
PTQ mAP50-95: 0.342
Target: This serves as the baseline. QAT will now attempt to improve this score.

## 4. QAT Training Loop
We now fine-tune the quantized model. The `QATTrainer` handles the forward pass through the quantized graph (simulated by PPQ) and the backward pass to update weights.
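
A common way to simulate a quantized graph during training is fake quantization with a straight-through estimator (STE). The notebook does not show how esp-ppq implements this internally, so the following is a generic scalar sketch of the technique under that assumption, with hypothetical function names: the forward pass snaps a value to the int8 grid, while the backward pass lets the gradient through unchanged inside the clipping range.

```python
def fake_quant(x, scale, qmin=-128, qmax=127):
    """Forward pass: quantize to the integer grid, then map back to float."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale

def ste_grad(x, scale, upstream, qmin=-128, qmax=127):
    """Backward pass (STE): pass the gradient inside the clip range, zero it outside."""
    inside = qmin * scale <= x <= qmax * scale
    return upstream if inside else 0.0

scale = 0.05
y = fake_quant(0.1234, scale)                        # snaps to 2 * 0.05
clipped = fake_quant(100.0, scale)                   # saturates at 127 * 0.05
g_inside = ste_grad(0.1234, scale, upstream=1.0)     # gradient passes through
g_outside = ste_grad(100.0, scale, upstream=1.0)     # gradient is blocked
```

Because rounding has zero gradient almost everywhere, the STE is what makes weight updates possible at all in this setting.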

Note: The validation step uses our `QuantizedModelValidator` which manually decodes the raw graph outputs using the metadata extracted earlier (`NC`, `RegMax`, `Stride`).
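
The decoding that such a validator has to perform can be sketched for a single grid cell. The `QuantizedModelValidator` source is not part of this notebook, so this is an assumption based on the usual Ultralytics-style anchor-point scheme with `RegMax=1`: the four box channels are left/top/right/bottom distances in grid units, and class logits go through a sigmoid. `decode_cell` and `sigmoid` are hypothetical helpers.

```python
import math

def decode_cell(ltrb, col, row, stride):
    """Convert (left, top, right, bottom) distances in grid units to xyxy pixels."""
    cx, cy = col + 0.5, row + 0.5            # anchor point at the cell centre
    l, t, r, b = ltrb
    return ((cx - l) * stride, (cy - t) * stride,
            (cx + r) * stride, (cy + b) * stride)

def sigmoid(logit):
    """Class confidence from a raw class logit."""
    return 1.0 / (1.0 + math.exp(-logit))

box = decode_cell(ltrb=(1.5, 2.0, 1.5, 2.0), col=10, row=20, stride=8)
# (72.0, 148.0, 96.0, 180.0)
score = sigmoid(0.0)
# 0.5
```

Since the one-to-one head is NMS-free, a confidence threshold on `score` is all that remains after this step.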

In [None]:
print("Starting QAT Training...")

os.makedirs(QATConfig.ESPDL_OUTPUT_DIR, exist_ok=True)

best_mAP = 0
for epoch in range(QATConfig.EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{QATConfig.EPOCHS} ---")
    
    # Train Epoch
    trainer.epoch(train_loader)
    
    # Validate
    current_mAP = trainer.eval()
    print(f"Epoch: {epoch+1}, mAP50-95: {current_mAP:.3f}")
    
    if current_mAP > best_mAP:
        best_mAP = current_mAP
        print(f"New best mAP! Saving to {QATConfig.ESPDL_OUTPUT_DIR}...")
        
        # Save Native Graph using the new helper method
        trainer.save_graph(os.path.join(QATConfig.ESPDL_OUTPUT_DIR, "Best_yolo26n.native"))

Starting QAT Training...

--- Epoch 1/2 ---


Epoch 0: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s, loss=6.7541] 

Epoch Loss: 36.5569
Ultralytics 8.4.7  Python-3.9.21 torch-2.8.0+cu126 CUDA:0 (NVIDIA GeForce RTX 4060 Laptop GPU, 8188MiB)
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% ━━━━━━━━━━━━ 209/209 2.1it/s 1:40<0.5s
                   all       5000      36335      0.609      0.463      0.506      0.357
Epoch: 1, mAP50-95: 0.357
New best mAP! Saving to c:\Users\orani\bilel\git_projects\yolo26\yolo26n_esp32_repo\esp-dl\examples\tutorial\how_to_quantize_model\quantize_yolo26\output\coco_640_s8_s3...

--- Epoch 2/2 ---


Epoch 1: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s, loss=6.5141] 

Epoch Loss: 35.3756
Ultralytics 8.4.7  Python-3.9.21 torch-2.8.0+cu126 CUDA:0 (NVIDIA GeForce RTX 4060 Laptop GPU, 8188MiB)
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100% ━━━━━━━━━━━━ 209/209 2.0it/s 1:43<0.6s
                   all       5000      36335      0.608      0.472      0.515      0.364
Epoch: 2, mAP50-95: 0.364
New best mAP! Saving to c:\Users\orani\bilel\git_projects\yolo26\yolo26n_esp32_repo\esp-dl\examples\tutorial\how_to_quantize_model\quantize_yolo26\output\coco_640_s8_s3...


In [None]:
# Load Best Graph back into memory at the end
print("Reloading Best Graph...")
trainer.load_graph(os.path.join(QATConfig.ESPDL_OUTPUT_DIR, "Best_yolo26n.native"))
# Update global graph 
graph = trainer.graph 
print("Best Graph Reloaded.")

Reloading Best Graph...
Best Graph Reloaded.


In [None]:
import esp_ppq.lib as PFL
from config import QATConfig
import os
from esp_ppq.api import load_native_graph
from esp_ppq.IR import BaseGraph  

# --- Helper Function: Robust Graph Pruning ---
def prune_graph_safely(graph: BaseGraph) -> BaseGraph:
    """
    Robust pruning function for ESP-PPQ graphs.
    Safely removes disconnected operations and unused variables.
    """
    print("Starting Safe Pruning Procedure...")
    round_count = 0
    while True:
        this_round_op_removed = 0
        this_round_var_removed = 0
        
        # A. Find Dead Ops
        dead_ops = []
        for op in list(graph.operations.values()):
            is_output = any(var.name in graph.outputs for var in op.outputs)
            has_consumers = any(len(var.dest_ops) > 0 for var in op.outputs)
            
            if not is_output and not has_consumers:
                dead_ops.append(op)
        
        # Remove Dead Ops Safely
        for op in dead_ops:
            for var in list(op.inputs):
                op.inputs.remove(var)
                if op in var.dest_ops:
                    var.dest_ops.remove(op)
            graph.remove_operation(op, keep_coherence=False)
            this_round_op_removed += 1
            
        # B. Find Dead Variables
        dead_vars = []
        for var in list(graph.variables.values()):
            is_input = var.name in graph.inputs
            is_output = var.name in graph.outputs
            if is_input or is_output: continue
            
            if len(var.dest_ops) == 0:
                dead_vars.append(var)
                 
        # Remove Dead Variables
        for var in dead_vars:
            if var.name in graph.variables:
                graph.variables.pop(var.name)
                this_round_var_removed += 1
        
        round_count += 1
        if this_round_op_removed == 0 and this_round_var_removed == 0:
            break
            
    print(f"Pruning Finished. Total Rounds: {round_count}")
    return graph

# --- MAIN LOGIC: Split Outputs at Concat Source ---

print("Reloading Best Graph...")
native_graph_path = os.path.join(QATConfig.ESPDL_OUTPUT_DIR, "Best_yolo26n.native")
graph = load_native_graph(import_file=native_graph_path)

print("Preparing Graph for Inference...")

# 1. Remove Aux Heads First (Cleanup)
output_names = list(graph.outputs.keys())
if len(output_names) >= 6:
    aux_heads = output_names[0:3] 
    print(f"Removing Aux Heads: {aux_heads}")
    for name in aux_heads:
        if name in graph.outputs:
            graph.outputs.pop(name)
    prune_graph_safely(graph)

# 2. Apply Splitting Strategy
targets = ["one2one_p3", "one2one_p4", "one2one_p5"]
collected_outputs = {}

for target_name in targets:
    if target_name in graph.outputs:
        original_output_var = graph.variables[target_name]
        producer = original_output_var.source_op 
        
        if producer and producer.type == "Concat":
            print(f"Splitting {target_name} at Source Concat ({producer.name})...")
            
            box_var = None
            cls_var = None
            
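            # Inspect Concat inputs to find Box (4 channels) and Cls (80 channels, COCO).
            # Compare the channel axis only (NCHW): a spatial size of 4 or 80
            # (e.g. the 80x80 P3 map at 640 px) must not be mistaken for a channel count.
            for input_var in producer.inputs:
                dims = input_var.shape
                if dims is not None and len(dims) > 1:
                    if dims[1] == 4: box_var = input_var
                    elif dims[1] == 80: cls_var = input_var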
            
            if box_var and cls_var:
                # Rename Scheme: Suffix per Request
                pair_config = [
                    (box_var, f"{target_name}_box"),  # e.g. one2one_p3_box
                    (cls_var, f"{target_name}_cls")   # e.g. one2one_p3_cls
                ]
                
                for var, new_name in pair_config:
                    old_name = var.name
                    
                    # Robust Renaming (Modify Private + Registry)
                    if old_name in graph.variables:
                        graph.variables.pop(old_name)
                    
                    var._name = new_name
                    graph.variables[new_name] = var
                    
                    collected_outputs[new_name] = var
                
                # Update Graph Outputs: Remove old name
                graph.outputs.pop(target_name)
                
                # Properly Remove Concat Op
                graph.remove_operation(producer, keep_coherence=False)
                # Unlink inputs
                for var in producer.inputs:
                    if producer in var.dest_ops:
                        var.dest_ops.remove(producer)

                print(f"  -> Created {pair_config[0][1]} and {pair_config[1][1]}")
            else:
                print(f"ERROR: Shape mismatch for {target_name}")
        else:
            print(f"WARNING: Source for {target_name} is not Concat.")

# 3. Register New Outputs in Strict Order (Box Group then Cls Group)
final_output_list = [
    "one2one_p3_box", "one2one_p4_box", "one2one_p5_box",
    "one2one_p3_cls", "one2one_p4_cls", "one2one_p5_cls"
]

print("Updating Graph Output Order...")
graph.outputs.clear() # Enforce strict order
count_added = 0
for name in final_output_list:
    if name in collected_outputs:
        graph.outputs[name] = collected_outputs[name]
        count_added += 1

print(f"Registered {count_added} split outputs.")

# 4. Final Prune
prune_graph_safely(graph)

print(f"Final Graph Outputs: {list(graph.outputs.keys())}") 

# 5. Export
inference_export_path = os.path.join(QATConfig.ESPDL_OUTPUT_DIR, f"yolo26n_{QATConfig.IMG_SZ}_s8_{PLATFORM}.espdl")
print(f"Exporting Split Inference Model to {inference_export_path}...")

exporter = PFL.Exporter(platform=QATConfig.TARGET_PLATFORM)
exporter.export(inference_export_path, graph=graph)
print("Export Done!")

Reloading Best Graph...
Preparing Graph for Inference...
Removing Aux Heads: ['one2many_p3', 'one2many_p4', 'one2many_p5']
Starting Safe Pruning Procedure...
Pruning Finished. Total Rounds: 11
Splitting one2one_p3 at Source Concat (/model.23/Concat_3)...
  -> Created one2one_p3_box and one2one_p3_cls
Splitting one2one_p4 at Source Concat (/model.23/Concat_4)...
  -> Created one2one_p4_box and one2one_p4_cls
Splitting one2one_p5 at Source Concat (/model.23/Concat_5)...
  -> Created one2one_p5_box and one2one_p5_cls
Updating Graph Output Order...
Registered 6 split outputs.
Starting Safe Pruning Procedure...
Pruning Finished. Total Rounds: 2
Final Graph Outputs: ['one2one_p3_box', 'one2one_p4_box', 'one2one_p5_box', 'one2one_p3_cls', 'one2one_p4_cls', 'one2one_p5_cls']
Exporting Split Inference Model to c:\Users\orani\bilel\git_projects\yolo26\yolo26n_esp32_repo\esp-dl\examples\tutorial\how_to_quantize_model\quantize_yolo26\output\coco_640_s8_s3\yolo26n_640_s8_s3.espdl...
[INFO]