# Nano-U Training Workflow Example

This notebook demonstrates the complete training pipeline for the Nano-U segmentation model, including:
1. Dataset preparation
2. Teacher model training (BU_Net)
3. Student model training with knowledge distillation
4. NAS monitoring for architecture analysis
5. Quantization for edge deployment
6. Evaluation and visualization

**Runtime**: ~30-60 minutes on GPU (NVIDIA RTX 5060)

## Setup and Imports

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Add project to path
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
sys.path.insert(0, str(PROJECT_ROOT))

# TensorFlow setup
import tensorflow as tf
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

# Custom imports
from src.train import train
from src.utils.config import load_config
from src.utils import get_project_root

print("\n✓ Imports successful")

## Step 1: Load Configuration

In [None]:
# Load project configuration
config = load_config()

# Display key settings
print("📋 Configuration Loaded")
print(f"\nData Settings:")
print(f"  Input shape: {config['data']['input_shape']}")
print(f"  Normalization mean: {config['data']['normalization']['mean']}")
print(f"  Normalization std: {config['data']['normalization']['std']}")
print(f"  Train/Val/Test split: {config['data']['split']}")

print(f"\nModel Architectures:")
print(f"  Nano_U filters: {config['models']['nano_u']['filters']}")
print(f"  Nano_U bottleneck: {config['models']['nano_u']['bottleneck']}")
print(f"  BU_Net filters: {config['models']['bu_net']['filters'][:3]}... ({len(config['models']['bu_net']['filters'])} total)")

print(f"\nTraining Settings:")
print(f"  Nano_U epochs: {config['training']['nano_u']['epochs']}")
print(f"  Nano_U batch size: {config['training']['nano_u']['batch_size']}")
print(f"  Distillation alpha: {config['training']['nano_u']['distillation']['alpha']}")
print(f"  Distillation temperature: {config['training']['nano_u']['distillation']['temperature']}")

## Step 2: Prepare Dataset

Run data preparation if not already done:

In [None]:
import subprocess
from pathlib import Path

# Check if data is already prepared
processed_dir = Path(get_project_root()) / 'data' / 'processed_data'

if (processed_dir / 'train' / 'img').exists():
    train_count = len(list((processed_dir / 'train' / 'img').glob('*.png')))
    print(f"✓ Data already prepared ({train_count} training samples)")
else:
    print("🔄 Running data preparation...")
    result = subprocess.run(
        [sys.executable, str(Path(get_project_root()) / 'src' / 'prepare_data.py')],
        cwd=str(get_project_root())
    )
    if result.returncode == 0:
        print("✓ Data preparation complete")
    else:
        print("✗ Data preparation failed")
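
How `src/prepare_data.py` assigns samples to the train/val/test folders is not shown in this notebook. As a rough illustration only, a deterministic split driven by ratios like those in `config['data']['split']` could look like the sketch below. The function name `split_samples`, the default ratios, and the seed are hypothetical.

```python
import random


def split_samples(names, ratios=(0.7, 0.2, 0.1), seed=42):
    """Shuffle file names deterministically and cut them into train/val/test lists.

    The ratios and seed are illustrative defaults, not values read from the project config.
    """
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    shuffled = sorted(names)  # sort first so the result does not depend on input order
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * ratios[0])
    n_val = int(len(shuffled) * ratios[1])
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # remainder, so no sample is dropped
    }


splits = split_samples([f"frame_{i:04d}.png" for i in range(100)])
print({name: len(items) for name, items in splits.items()})
```

Fixing the seed keeps the test set stable across runs, which matters when teacher and student results are compared later.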

## Step 3: Train Teacher Model (BU_Net)

The teacher is a larger U-Net architecture that will guide the student model.

In [None]:
print("🎓 Training Teacher Model (BU_Net)")
print("="*60)

# Train teacher
teacher_model, teacher_history = train(
    model_name="bu_net",
    epochs=50,  # Reduced for demo (use 100 in production)
    batch_size=16,
    lr=1e-4,
    enable_nas_monitoring=False,  # Skip NAS for teacher
    augment=True
)

print("\n✓ Teacher training complete")
print(f"  Final validation IoU: {teacher_history.history['val_binary_iou'][-1]:.4f}")

## Step 4: Train Student Model with Knowledge Distillation

Now train the lightweight student model using the teacher as a guide.
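
The distillation loss itself lives inside `src.train` and is not shown here. A minimal sketch of one common formulation for binary segmentation follows, written in plain Python so it runs without the project installed. It assumes that `alpha` weights the hard-label term (which would match the comment that `alpha=0.3` favors distillation) and that the soft term is scaled by the squared temperature. Both are assumptions about the project, not confirmed behavior.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bce(target, prob, eps=1e-7):
    """Binary cross-entropy of a predicted probability against a (possibly soft) target."""
    prob = min(max(prob, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def distillation_loss(student_logits, teacher_logits, labels, alpha=0.3, temperature=4.0):
    """alpha * hard-label loss + (1 - alpha) * T^2 * soft-target loss, averaged over pixels."""
    hard, soft = 0.0, 0.0
    for s, t, y in zip(student_logits, teacher_logits, labels):
        hard += bce(y, sigmoid(s))
        # Dividing logits by the temperature softens both distributions
        soft += bce(sigmoid(t / temperature), sigmoid(s / temperature))
    n = len(labels)
    return alpha * hard / n + (1.0 - alpha) * temperature ** 2 * soft / n


# Three pixels: student logits, teacher logits, ground-truth labels
print(distillation_loss([1.2, -0.4, 2.5], [3.0, -2.0, 4.0], [1.0, 0.0, 1.0]))
```

With a higher temperature the teacher's probabilities move toward 0.5, so the student also receives signal from pixels where the teacher is confident but not certain.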

In [None]:
print("👶 Training Student Model (Nano_U) with Distillation")
print("="*60)

# Get teacher model path
models_dir = Path(get_project_root()) / 'models'
teacher_path = models_dir / 'bu_net.keras'

# Train student with distillation
student_model, student_history = train(
    model_name="nano_u",
    epochs=50,  # Reduced for demo
    batch_size=8,
    lr=1e-4,
    distill=True,
    teacher_weights=str(teacher_path),
    alpha=0.3,  # Favor distillation loss
    temperature=4.0,  # Softer targets
    enable_nas_monitoring=True,  # Enable NAS for analysis
    nas_layers=['encoder_conv_0', 'encoder_conv_1', 'bottleneck'],
    nas_log_dir='logs/nas_demo',
    nas_csv_path='logs/nas_demo/nano_u_metrics.csv',
    augment=True
)

print("\n✓ Student training complete")
print(f"  Final validation IoU: {student_history.history['val_binary_iou'][-1]:.4f}")
print(f"  Model saved to: {models_dir / 'nano_u.keras'}")

## Step 5: Compare Teacher vs Student

In [None]:
# Model size comparison
teacher_params = teacher_model.count_params()
student_params = student_model.count_params()
compression = (1 - student_params / teacher_params) * 100

print("📊 Model Comparison")
print("="*60)
print(f"BU_Net (Teacher):")
print(f"  Parameters: {teacher_params:,} ({teacher_params/1e3:.1f}K)")
print(f"\nNano_U (Student):")
print(f"  Parameters: {student_params:,} ({student_params/1e3:.1f}K)")
print(f"\nCompression: {compression:.1f}% parameter reduction")
print(f"Parameter ratio: {teacher_params/student_params:.1f}x (a rough proxy for speedup; measure latency on the target device)")

# Performance comparison
print(f"\nPerformance Comparison:")
teacher_iou = teacher_history.history['val_binary_iou'][-1]
student_iou = student_history.history['val_binary_iou'][-1]
iou_gap = (teacher_iou - student_iou) / teacher_iou * 100

print(f"  BU_Net validation IoU: {teacher_iou:.4f}")
print(f"  Nano_U validation IoU: {student_iou:.4f}")
print(f"  Performance gap: {iou_gap:.1f}%")

## Step 6: Visualize Training Curves

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Loss curves
axes[0].plot(teacher_history.history['loss'], label='Teacher train', linewidth=2)
axes[0].plot(teacher_history.history['val_loss'], label='Teacher val', linewidth=2, linestyle='--')
axes[0].plot(student_history.history['loss'], label='Student train', linewidth=2)
axes[0].plot(student_history.history['val_loss'], label='Student val', linewidth=2, linestyle='--')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()
axes[0].grid(alpha=0.3)

# IoU curves
axes[1].plot(teacher_history.history['binary_iou'], label='Teacher train', linewidth=2)
axes[1].plot(teacher_history.history['val_binary_iou'], label='Teacher val', linewidth=2, linestyle='--')
axes[1].plot(student_history.history['binary_iou'], label='Student train', linewidth=2)
axes[1].plot(student_history.history['val_binary_iou'], label='Student val', linewidth=2, linestyle='--')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Binary IoU')
axes[1].set_title('Segmentation Performance (IoU) Comparison')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to training_comparison.png")

## Step 7: NAS Analysis

Visualize redundancy metrics from the NAS monitoring callback.
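
How the callback computes `redundancy_score`, `mean_correlation`, `condition_number`, and `trace` is not shown in this notebook. As a hedged illustration of the general idea, one simple proxy for feature redundancy is the mean absolute pairwise Pearson correlation between channel activations: values near 1 suggest channels carry overlapping information. The function names below are hypothetical and the sketch uses plain Python lists instead of real feature maps.

```python
import math
from itertools import combinations


def pearson(x, y):
    """Pearson correlation of two equally long activation vectors."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return 0.0  # a constant channel is treated as uncorrelated
    return cov / math.sqrt(var_x * var_y)


def mean_abs_correlation(channels):
    """Average |correlation| over all channel pairs; 0.0 if there are fewer than two channels."""
    pairs = list(combinations(channels, 2))
    if not pairs:
        return 0.0
    return sum(abs(pearson(a, b)) for a, b in pairs) / len(pairs)


# Each inner list holds one channel's activations flattened over pixels
activations = [[0.1, 0.5, 0.9, 0.3], [0.2, 1.0, 1.8, 0.6], [0.9, 0.1, 0.4, 0.7]]
print(mean_abs_correlation(activations))
```

In this toy input the second channel is an exact multiple of the first, which is the kind of pair a redundancy metric is meant to flag.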

In [None]:
import pandas as pd

# Load NAS metrics
nas_csv = Path(get_project_root()) / 'logs' / 'nas_demo' / 'nano_u_metrics.csv'

if nas_csv.exists():
    nas_data = pd.read_csv(nas_csv)
    
    print("📊 NAS Monitoring Results")
    print("="*60)
    print(f"\nLast epoch metrics:")
    last_row = nas_data.iloc[-1]
    print(f"  Redundancy score: {last_row.get('redundancy_score', float('nan')):.4f}")
    print(f"  Mean correlation: {last_row.get('mean_correlation', 'N/A')}")
    print(f"  Condition number: {last_row.get('condition_number', 'N/A')}")
    print(f"  Trace: {last_row.get('trace', 'N/A')}")
    
    # Plot redundancy over time
    fig, ax = plt.subplots(figsize=(10, 5))
    if 'redundancy_score' in nas_data.columns:
        ax.plot(nas_data.index, nas_data['redundancy_score'], linewidth=2, marker='o')
        ax.axhline(y=0.7, color='r', linestyle='--', label='High redundancy threshold')
        ax.axhline(y=0.3, color='g', linestyle='--', label='Low redundancy threshold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Redundancy Score')
    ax.set_title('Feature Redundancy Analysis During Training')
    ax.legend()
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('nas_redundancy.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\n✓ NAS analysis plots saved")
else:
    print("⚠ NAS CSV file not found. Skipping NAS analysis.")

## Step 8: Quantize Model for Edge Deployment
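
The internals of `src/quantize.py` are not shown here. For orientation, INT8 quantization in TFLite maps real values to integers with an affine scheme, `real = scale * (q - zero_point)`. The sketch below applies that mapping to a single tensor range in plain Python. The helper names are hypothetical, and calibration over a representative dataset is left out.

```python
def quant_params(x_min, x_max, qmin=-128, qmax=127):
    """Derive scale and zero point so that [x_min, x_max] maps onto [qmin, qmax]."""
    # Widen the range to include 0.0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # constant all-zero tensor: any positive scale works
    zero_point = int(round(qmin - x_min / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a real value to a clipped INT8 code."""
    q = int(round(x / scale)) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Recover the approximate real value from an INT8 code."""
    return (q - zero_point) * scale


scale, zero_point = quant_params(-1.0, 1.0)
for value in (-1.0, -0.3, 0.0, 0.42, 1.0):
    code = quantize(value, scale, zero_point)
    print(value, code, dequantize(code, scale, zero_point))
```

The round-trip error stays within about one quantization step (`scale`), which is why small models with well-behaved activation ranges usually lose little accuracy at INT8.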

In [None]:
print("🔧 Quantizing Student Model to INT8")
print("="*60)

# Run quantization script
result = subprocess.run(
    [sys.executable, str(Path(get_project_root()) / 'src' / 'quantize.py'),
     '--model-name', 'nano_u',
     '--output', str(models_dir / 'nano_u_int8.tflite')],
    cwd=str(get_project_root())
)

if result.returncode == 0:
    # Get file sizes
    keras_size = (models_dir / 'nano_u.keras').stat().st_size / 1024  # KB
    tflite_size = (models_dir / 'nano_u_int8.tflite').stat().st_size / 1024  # KB
    
    print(f"\n✓ Quantization complete")
    print(f"  Keras model: {keras_size:.1f} KB")
    print(f"  TFLite INT8: {tflite_size:.1f} KB")
    print(f"  Size reduction: {(1 - tflite_size/keras_size)*100:.1f}%")
else:
    print("✗ Quantization failed")

## Step 9: Run Inference on Test Set

In [None]:
import cv2
import glob

# Load test images
test_dir = Path(get_project_root()) / 'data' / 'processed_data' / 'test'
test_images = sorted(glob.glob(str(test_dir / 'img' / '*.png')))[:5]  # Load 5 test samples

print(f"🔮 Running Inference on {len(test_images)} Test Images")
print("="*60)

# Prepare model for inference
student_model = tf.keras.models.load_model(models_dir / 'nano_u.keras', compile=False)

# squeeze=False always returns a 2D axes array, even for a single test image
fig, axes = plt.subplots(len(test_images), 3, figsize=(12, 4*len(test_images)), squeeze=False)

for idx, img_path in enumerate(test_images):
    # Load image and mask
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Build the mask path from path components so it also works with Windows separators
    mask_path = Path(img_path).parent.parent / 'mask' / Path(img_path).name
    mask_gt = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    
    # Normalize to [-1, 1] (mean=0.5, std=0.5) and infer
    img_norm = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
    logits = student_model(np.expand_dims(img_norm, 0), training=False)[0].numpy()
    mask_pred = (1 / (1 + np.exp(-logits)) * 255).astype(np.uint8)  # Sigmoid + scale
    
    # Visualize
    axes[idx, 0].imshow(img)
    axes[idx, 0].set_title(f'Input Image {idx+1}')
    axes[idx, 0].axis('off')
    
    axes[idx, 1].imshow(mask_gt, cmap='gray')
    axes[idx, 1].set_title('Ground Truth')
    axes[idx, 1].axis('off')
    
    axes[idx, 2].imshow(mask_pred.squeeze(), cmap='gray')
    axes[idx, 2].set_title('Prediction')
    axes[idx, 2].axis('off')

plt.tight_layout()
plt.savefig('inference_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Inference visualizations saved to inference_samples.png")

## Step 10: Evaluate Models
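
The metrics reported by `src/evaluate.py` are not listed in this notebook, but the training history tracks `binary_iou`, so a short reminder of what an IoU score measures may help when reading the results. The sketch below computes foreground IoU on flat masks in plain Python, assuming (as in Step 9) that the model outputs logits that are passed through a sigmoid. The function names are hypothetical. Keras's `BinaryIoU` metric can additionally average over the background class, so its numbers may differ from this foreground-only version.

```python
import math


def binarize(logits, threshold=0.5):
    """Sigmoid each logit and threshold the probability into a 0/1 mask value."""
    return [1 if 1.0 / (1.0 + math.exp(-z)) >= threshold else 0 for z in logits]


def foreground_iou(pred, target):
    """Intersection over union of the foreground (value 1) pixels in two flat masks."""
    inter = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    union = sum(1 for p, t in zip(pred, target) if p == 1 or t == 1)
    return inter / union if union else 1.0  # two empty masks agree perfectly


pred_mask = binarize([2.0, -1.0, 0.5, -3.0])
true_mask = [1, 1, 0, 0]
print(foreground_iou(pred_mask, true_mask))
```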

In [None]:
print("📈 Model Evaluation")
print("="*60)

# Run evaluation
result = subprocess.run(
    [sys.executable, str(Path(get_project_root()) / 'src' / 'evaluate.py'),
     '--model-name', 'nano_u',
     '--out', str(Path(get_project_root()) / 'eval_results.json')],
    cwd=str(get_project_root())
)

if result.returncode == 0:
    import json
    with open(Path(get_project_root()) / 'eval_results.json') as f:
        metrics = json.load(f)
    print(f"\n✓ Evaluation complete")
    for key, value in metrics.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")
else:
    print(f"⚠ Evaluation failed or evaluate.py is unavailable (exit code {result.returncode})")

## Summary

Congratulations! You've successfully completed the Nano-U training pipeline:

✅ **Teacher Model (BU_Net)**: Trained with ~180K parameters  
✅ **Student Model (Nano_U)**: Trained with knowledge distillation, ~41K parameters  
✅ **NAS Monitoring**: Analyzed feature redundancy during training  
✅ **Quantization**: Converted to INT8 TFLite for edge deployment  
✅ **Evaluation**: Benchmarked performance on test set  

### Next Steps

1. **Deploy to ESP32-S3**:
   ```bash
   cd esp_flash
   cargo build --release
   espflash flash --monitor models/nano_u_int8.tflite
   ```

2. **Fine-tune Hyperparameters**:
   - Adjust learning rate, batch size, distillation temperature
   - Run NAS monitoring to detect overparameterization
   - Implement architecture changes based on NAS recommendations

3. **Integrate Custom Dataset**:
   - See [`docs/CUSTOM_DATASET_INTEGRATION.md`](../docs/CUSTOM_DATASET_INTEGRATION.md)
   - Update `config/config.yaml` with new dataset paths
   - Re-run `src/prepare_data.py`

4. **Performance Optimization**:
   - Use NAS analysis to identify redundant layers
   - Reduce model size for better edge deployment
   - Measure inference latency on target hardware
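
The latency measurement from item 4 can be prototyped on the host before moving to the device. The helper below is a hypothetical sketch using only the standard library: it times any zero-argument callable after a few warm-up calls and reports the median, which is less sensitive to outliers than the mean. Host timings only indicate relative changes between model variants and are no substitute for measurements on the ESP32-S3 itself.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=3, runs=20):
    """Time repeated calls of fn and return min/median/max latency in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are discarded (caches, lazy initialization)
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return {
        "min_ms": min(samples),
        "median_ms": statistics.median(samples),
        "max_ms": max(samples),
    }


# Stand-in workload; in the notebook this could wrap a single forward pass of the student model
print(measure_latency_ms(lambda: sum(i * i for i in range(10_000))))
```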

### References

- [Main README](../README.md)
- [NAS Technical Reference](../docs/NAS_README.md)
- [Usage Examples](../docs/USAGE_EXAMPLES.md)
- [Custom Dataset Integration](../docs/CUSTOM_DATASET_INTEGRATION.md)