# Training a Custom Cabruca Segmentation Model

This notebook demonstrates how to train a custom segmentation model for your specific cabruca plantation data.

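## Contents
1. Setup and Imports
2. Data Preparation
3. Create Synthetic Dataset (Optional)
4. Dataset and Data Loaders
5. Model Configuration
6. Initialize Model and Trainer
7. Training Loop
8. Visualize Training Progress
9. Model Evaluation
10. Visualize Predictions
11. Export Model for Deployment
12. Next Steps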

## 1. Setup and Imports

In [None]:
import sys
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
import pandas as pd
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Add project source
sys.path.append('../src')

# Import project modules
from models.cabruca_segmentation_model import CabrucaSegmentationModel
from data.dataset import CabrucaDataset
from training.advanced_trainer import AdvancedTrainer
from evaluation.agroforestry_metrics import AgroforestryMetrics

# Set device
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print("Using CPU")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 2. Data Preparation

In [None]:
# Data configuration
DATA_DIR = Path("../data")
TRAIN_DIR = DATA_DIR / "train"
VAL_DIR = DATA_DIR / "val"
TEST_DIR = DATA_DIR / "test"

# Create directories if they don't exist
for dir_path in [TRAIN_DIR, VAL_DIR, TEST_DIR]:
    dir_path.mkdir(parents=True, exist_ok=True)
    (dir_path / "images").mkdir(exist_ok=True)
    (dir_path / "masks").mkdir(exist_ok=True)
    (dir_path / "annotations").mkdir(exist_ok=True)

print("Data directories:")
print(f"  Train: {TRAIN_DIR}")
print(f"  Val: {VAL_DIR}")
print(f"  Test: {TEST_DIR}")

# Check for existing data
train_images = list((TRAIN_DIR / "images").glob("*"))
print(f"\nFound {len(train_images)} training images")

if len(train_images) == 0:
    print("\n⚠️ No training data found!")
    print("Please add your images and annotations to the data directories")
    print("\nExpected structure:")
    print("  data/train/images/ - Training images")
    print("  data/train/masks/ - Segmentation masks")
    print("  data/train/annotations/ - COCO format annotations")
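Before training, it is worth confirming that every image has a mask. The sketch below assumes the file naming used by `create_synthetic_dataset` in the next section (`img_XXXX.<ext>` paired with `mask_XXXX.png`). Both `unpaired_images` and `find_unpaired_images` are hypothetical helpers, not part of the project's `src` modules; adapt the naming rule to your own data.

```python
from pathlib import Path

def unpaired_images(image_names, mask_names):
    """Return image file names with no matching mask name.

    Assumes the naming used by create_synthetic_dataset in this notebook:
    img_XXXX.<ext> pairs with mask_XXXX.png (adapt for your own files).
    """
    masks = set(mask_names)
    missing = []
    for name in sorted(image_names):
        expected = Path(name).stem.replace("img_", "mask_", 1) + ".png"
        if expected not in masks:
            missing.append(name)
    return missing

def find_unpaired_images(split_dir):
    """Apply unpaired_images to the images/ and masks/ folders of one split."""
    split_dir = Path(split_dir)
    return unpaired_images(
        [p.name for p in (split_dir / "images").glob("*")],
        [p.name for p in (split_dir / "masks").glob("*")],
    )

# Usage: print(find_unpaired_images(TRAIN_DIR))
```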

## 3. Create Synthetic Dataset (Optional)

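If you don't have real data yet, create a small synthetic dataset for demonstration: uniform green images with filled circles standing in for cacao (dark green) and shade-tree (brown) crowns, plus matching semantic masks and COCO-style bounding-box annotations.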

In [None]:
def create_synthetic_dataset(num_samples=10, img_size=512):
    """Create synthetic plantation images for demonstration."""
    import cv2
    import json
    
    for split, num in [("train", num_samples), ("val", num_samples//4), ("test", num_samples//4)]:
        split_dir = DATA_DIR / split
        
        annotations = {
            "images": [],
            "annotations": [],
            "categories": [
                {"id": 1, "name": "cacao"},
                {"id": 2, "name": "shade"}
            ]
        }
        
        for i in range(num):
            # Create synthetic image
            img = np.zeros((img_size, img_size, 3), dtype=np.uint8)
            img[:, :] = [34, 139, 34]  # Green background
            
            # Create semantic mask
            mask = np.zeros((img_size, img_size), dtype=np.uint8)
            mask[:, :] = 2  # Understory
            
            # Add random trees
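            # Seed by split *and* index; seeding by index alone makes val/test
            # images identical to the train images with the same index
            split_offset = {"train": 0, "val": 10_000, "test": 20_000}[split]
            np.random.seed(split_offset + i)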
            num_trees = np.random.randint(10, 30)
            
            img_annotations = []
            
            for j in range(num_trees):
                x = np.random.randint(50, img_size-50)
                y = np.random.randint(50, img_size-50)
                radius = np.random.randint(15, 40)
                
                # Determine species
                is_cacao = np.random.random() > 0.3
                
                if is_cacao:
                    color = [0, 100, 0]  # Dark green for cacao
                    mask_value = 0  # Cacao in mask
                    category_id = 1
                else:
                    color = [19, 69, 139]  # Brown for shade (BGR order, since cv2.imwrite expects BGR)
                    mask_value = 1  # Shade in mask
                    category_id = 2
                
                # Draw tree in image
                cv2.circle(img, (x, y), radius, color, -1)
                cv2.circle(mask, (x, y), radius, mask_value, -1)
                
                # Add annotation
                img_annotations.append({
                    "id": len(annotations["annotations"]) + len(img_annotations),  # unique across the split
                    "image_id": i,
                    "category_id": category_id,
                    "bbox": [x-radius, y-radius, 2*radius, 2*radius],
                    "area": np.pi * radius * radius,
                    "iscrowd": 0
                })
            
            # Save image and mask
            img_path = split_dir / "images" / f"img_{i:04d}.jpg"
            mask_path = split_dir / "masks" / f"mask_{i:04d}.png"
            
            cv2.imwrite(str(img_path), img)
            cv2.imwrite(str(mask_path), mask)
            
            # Add to annotations
            annotations["images"].append({
                "id": i,
                "file_name": img_path.name,
                "width": img_size,
                "height": img_size
            })
            annotations["annotations"].extend(img_annotations)
        
        # Save annotations
        ann_path = split_dir / "annotations" / "annotations.json"
        with open(ann_path, 'w') as f:
            json.dump(annotations, f, indent=2)
        
        print(f"Created {num} synthetic {split} samples")

# Create synthetic dataset if no real data exists
if len(train_images) == 0:
    print("Creating synthetic dataset for demonstration...")
    create_synthetic_dataset(num_samples=20)
    print("✅ Synthetic dataset created")
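A quick consistency pass over the COCO file can catch problems such as duplicate annotation ids, dangling `image_id` references, or boxes outside the image before they surface as confusing training errors. This is a minimal sketch that checks only a few fields of the COCO format; `check_coco_annotations` is a hypothetical helper.

```python
from collections import Counter

def check_coco_annotations(coco):
    """Return a list of problems found in a COCO-style annotation dict (subset of checks)."""
    problems = []
    images = {img["id"]: img for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}

    # Annotation ids must be unique within a file
    id_counts = Counter(ann["id"] for ann in coco["annotations"])
    duplicates = sorted(i for i, n in id_counts.items() if n > 1)
    if duplicates:
        problems.append(f"duplicate annotation ids: {duplicates[:10]}")

    for ann in coco["annotations"]:
        img = images.get(ann["image_id"])
        if img is None:
            problems.append(f"annotation {ann['id']} references missing image {ann['image_id']}")
            continue
        if ann["category_id"] not in category_ids:
            problems.append(f"annotation {ann['id']} has unknown category {ann['category_id']}")
        x, y, w, h = ann["bbox"]  # COCO boxes are [x_min, y_min, width, height]
        if x < 0 or y < 0 or x + w > img["width"] or y + h > img["height"]:
            problems.append(f"annotation {ann['id']} bbox exceeds image bounds")
    return problems

# Usage:
# import json
# with open(TRAIN_DIR / "annotations" / "annotations.json") as f:
#     print(check_coco_annotations(json.load(f)) or "OK")
```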

## 4. Dataset and Data Loaders

In [None]:
# Define augmentation transforms
train_transform = A.Compose([
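    # Recent albumentations releases take size=(h, w); older 1.x releases took height, width.
    # scale is a fraction of the original image area, so values above 1.0 cannot be cropped.
    A.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0)),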
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),  # albumentations 1.x API; 2.x replaces var_limit with std_range
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Create datasets
train_dataset = CabrucaDataset(
    image_dir=str(TRAIN_DIR / "images"),
    mask_dir=str(TRAIN_DIR / "masks"),
    annotation_file=str(TRAIN_DIR / "annotations" / "annotations.json"),
    transform=train_transform
)

val_dataset = CabrucaDataset(
    image_dir=str(VAL_DIR / "images"),
    mask_dir=str(VAL_DIR / "masks"),
    annotation_file=str(VAL_DIR / "annotations" / "annotations.json"),
    transform=val_transform
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

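# Create data loaders
batch_size = 4 if device.type == 'mps' else 8

def collate_fn(batch):
    """Stack same-sized images; keep per-image targets as a list of dicts.

    Assumes CabrucaDataset returns (image_tensor, target_dict). The default collate
    would try to stack targets with different numbers of boxes and fail.
    """
    images, targets = zip(*batch)
    return torch.stack(images), list(targets)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # 0 avoids multiprocessing issues inside notebooks (notably on macOS)
    pin_memory=False,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate_fn
)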

# Inspect the shapes of one training batch
batch = next(iter(train_loader))
images, targets = batch

print(f"\nBatch shape: {images.shape}")
print(f"Number of targets: {len(targets)}")
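To actually look at an augmented sample, the normalization has to be undone first. A minimal sketch, assuming the ImageNet mean/std used in the transforms above, `A.Normalize`'s default `max_pixel_value=255`, and a CHW float array such as `images[0].cpu().numpy()`; `denormalize` is a hypothetical helper.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def denormalize(chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert A.Normalize (max_pixel_value=255) for display.

    chw: float array of shape (3, H, W) as produced by Normalize + ToTensorV2.
    Returns a uint8 HWC image suitable for plt.imshow.
    """
    hwc = np.transpose(chw, (1, 2, 0))      # CHW -> HWC
    img = (hwc * std + mean) * 255.0        # undo (x / 255 - mean) / std
    return np.clip(np.round(img), 0, 255).astype(np.uint8)

# Usage:
# plt.imshow(denormalize(images[0].cpu().numpy())); plt.axis('off'); plt.show()
```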

## 5. Model Configuration

In [None]:
# Model configuration
model_config = {
    'num_instance_classes': 3,  # background, cacao, shade
    'num_semantic_classes': 5,  # cacao, shade, understory, bare_soil, shadow
    'backbone': 'resnet50',
    'pretrained': True,
    'freeze_backbone': False
}

# Training configuration
training_config = {
    'epochs': 10,  # Reduced for demo
    'learning_rate': 0.001,
    'weight_decay': 0.0001,
    'optimizer': 'AdamW',
    'scheduler': 'CosineAnnealingLR',
    'gradient_clip': 1.0,
    'mixed_precision': False,  # Set to False for MPS
    'gradient_accumulation_steps': 2,
    'early_stopping_patience': 5,
    'save_best_only': True
}

# Save configuration
config = {
    'model': model_config,
    'training': training_config,
    'data': {
        'batch_size': batch_size,
        'num_workers': 0,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset)
    }
}

config_path = Path("../configs/notebook_config.yaml")
config_path.parent.mkdir(exist_ok=True)
with open(config_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print("Configuration saved to:", config_path)
print("\nModel Configuration:")
for key, value in model_config.items():
    print(f"  {key}: {value}")
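The configuration names `CosineAnnealingLR` and `gradient_accumulation_steps`, but how `AdvancedTrainer` applies them is not shown here. The sketch below assumes the scheduler is stepped once per epoch with `T_max` equal to `epochs` and `eta_min = 0`, and uses the closed-form schedule documented for PyTorch's `CosineAnnealingLR`. It also assumes accumulation multiplies the per-step batch size.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing, as documented for torch.optim.lr_scheduler.CosineAnnealingLR.

    Assumes one scheduler step per epoch and t_max equal to the epoch count.
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

def effective_batch_size(batch_size, accumulation_steps):
    """Samples contributing to each optimizer update when gradients are accumulated."""
    return batch_size * accumulation_steps

# With the values above (learning_rate=0.001, epochs=10):
schedule = [cosine_annealing_lr(e, 0.001, 10) for e in range(10)]
print([f"{lr:.6f}" for lr in schedule])
print("Effective batch size:", effective_batch_size(8, 2))  # batch_size is 8 off MPS, 4 on MPS
```

Note that if the training loop runs fewer epochs than `T_max`, the learning rate never reaches its minimum.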

## 6. Initialize Model and Trainer

In [None]:
# Initialize model
print("Initializing model...")
model = CabrucaSegmentationModel(
    num_instance_classes=model_config['num_instance_classes'],
    num_semantic_classes=model_config['num_semantic_classes'],
    backbone=model_config['backbone'],
    pretrained=model_config['pretrained']
)

# Move model to device
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Approx. model size (float32 weights): {total_params * 4 / 1024**2:.1f} MB")

# Initialize trainer
trainer = AdvancedTrainer(
    model=model,
    config=training_config,
    device=str(device),
    output_dir="../outputs/notebook_training"
)

print("\n✅ Model and trainer initialized")
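The size printed above assumes 4 bytes per parameter (float32 weights). During training, weights, gradients and AdamW's two moment buffers together take roughly four times that, before activations. A small sketch of that arithmetic; the helper names are hypothetical.

```python
# Bytes per parameter for common weight formats
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}

def estimate_weight_size_mb(num_params, dtype="float32"):
    """Approximate size of the weights alone, in MiB (no buffers, optimizer state, or activations)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**2

def estimate_adamw_training_mb(num_params):
    """Rough float32 training footprint: weights + gradients + AdamW's two moment buffers."""
    return 4 * estimate_weight_size_mb(num_params, "float32")

# Usage with total_params from the cell above:
# for dtype in BYTES_PER_PARAM:
#     print(f"{dtype:>8}: {estimate_weight_size_mb(total_params, dtype):.1f} MB")
# print(f"AdamW training (excl. activations): {estimate_adamw_training_mb(total_params):.1f} MB")
```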

## 7. Training Loop

In [None]:
# Training loop with per-epoch logging and best-model checkpointing

# Initialize metrics tracking
train_losses = []
val_losses = []
learning_rates = []

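# Training parameters: reuse the configured epoch count rather than a separate hard-coded value
num_epochs = training_config['epochs']
best_val_loss = float('inf')
Path("../outputs").mkdir(parents=True, exist_ok=True)  # best_model.pth is saved here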

print("Starting training...")
print("=" * 50)

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    train_batches = 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    
    for batch_idx, (images, targets) in enumerate(progress_bar):
        # Move to device
        images = images.to(device)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v 
                   for k, v in t.items()} for t in targets]
        
        # Forward pass
        loss = trainer.train_step(images, targets)
        
        train_loss += loss
        train_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss:.4f}'})
    
    avg_train_loss = train_loss / train_batches
    train_losses.append(avg_train_loss)
    
    # Validation phase
    model.eval()
    val_loss = 0
    val_batches = 0
    
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"): 
            images = images.to(device)
            targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v 
                       for k, v in t.items()} for t in targets]
            
            loss = trainer.validate_step(images, targets)
            val_loss += loss
            val_batches += 1
    
    avg_val_loss = val_loss / val_batches
    val_losses.append(avg_val_loss)
    
    # Learning rate
    current_lr = trainer.optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)
    
    # Update learning rate
    if trainer.scheduler:
        trainer.scheduler.step()
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "../outputs/best_model.pth")
        print(f"  💾 Saved best model (val_loss: {avg_val_loss:.4f})")
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning Rate: {current_lr:.6f}")
    print("-" * 50)

print("\n✅ Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
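`training_config` defines `early_stopping_patience`, but the loop above always runs every epoch. A minimal sketch of patience-based early stopping that could be wired into the loop; `EarlyStopping` and its `min_delta` parameter are hypothetical, not part of `AdvancedTrainer`.

```python
class EarlyStopping:
    """Signal a stop when validation loss hasn't improved for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

# Usage: create before the epoch loop, then call after computing avg_val_loss:
# stopper = EarlyStopping(patience=training_config['early_stopping_patience'])
# if stopper.step(avg_val_loss):
#     print(f"Early stopping at epoch {epoch+1}")
#     break
```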

## 8. Visualize Training Progress

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
axes[0].plot(train_losses, label='Train Loss', marker='o')
axes[0].plot(val_losses, label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate schedule
axes[1].plot(learning_rates, marker='o', color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Training summary
print("Training Summary:")
print(f"  Final train loss: {train_losses[-1]:.4f}")
print(f"  Final val loss: {val_losses[-1]:.4f}")
print(f"  Best val loss: {min(val_losses):.4f} (Epoch {val_losses.index(min(val_losses))+1})")
print(f"  Train loss reduction (first to last epoch): {(train_losses[0] - train_losses[-1])/train_losses[0]*100:.1f}%")

## 9. Model Evaluation

In [None]:
# Load best model
model.load_state_dict(torch.load("../outputs/best_model.pth", map_location=device))
model.eval()

# Evaluate on validation set

evaluator = AgroforestryMetrics()
all_metrics = []

print("Evaluating model on validation set...")

with torch.no_grad():
    for i, (images, targets) in enumerate(tqdm(val_loader)):
        if i >= 5:  # Evaluate first 5 batches for demo
            break
            
        images = images.to(device)
        
        # Get predictions
        outputs = model(images)
        
        # Split model outputs once per batch (exact format depends on CabrucaSegmentationModel)
        if isinstance(outputs, tuple):
            instance_out, semantic_out = outputs
        else:
            instance_out = outputs
            semantic_out = None

        # Placeholder per-image statistics; a real evaluation would compute mAP, IoU, etc.
        for j in range(len(images)):
            pred = instance_out[j]
            scores = pred['scores'].cpu().numpy() if 'scores' in pred else np.array([])
            all_metrics.append({
                'num_detections': len(pred['boxes']) if 'boxes' in pred else 0,
                # None (not 0) when there are no detections, so empty images don't skew the mean
                'avg_confidence': float(scores.mean()) if scores.size > 0 else None
            })

# Aggregate metrics
if all_metrics:
    avg_detections = np.mean([m['num_detections'] for m in all_metrics])
    confidences = [m['avg_confidence'] for m in all_metrics if m['avg_confidence'] is not None]

    print("\nValidation Metrics:")
    print(f"  Average detections per image: {avg_detections:.1f}")
    if confidences:
        print(f"  Average confidence: {np.mean(confidences):.3f}")
    else:
        print("  Average confidence: n/a (no detections)")
else:
    print("No metrics calculated")
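The placeholder above only counts detections. For the semantic head, a standard metric is per-class IoU from a pixel confusion matrix, IoU = TP / (TP + FP + FN). A sketch in NumPy; the class order follows the `num_semantic_classes` comment in the model configuration, and the usage lines assume `semantic_out` holds (N, C, H, W) logits and that targets carry a `semantic_mask` key, both of which are assumptions about the project's model and dataset.

```python
import numpy as np

SEMANTIC_CLASSES = ["cacao", "shade", "understory", "bare_soil", "shadow"]

def confusion_matrix(pred, target, num_classes):
    """Pixel-level confusion matrix; rows are ground truth, columns are predictions."""
    valid = (target >= 0) & (target < num_classes)
    idx = num_classes * target[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_iou(conf):
    """IoU = TP / (TP + FP + FN); NaN for classes absent from both prediction and target."""
    tp = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(union > 0, tp / union, np.nan)

# Usage (assumed output/target formats):
# pred = semantic_out.argmax(1).cpu().numpy()
# conf = sum(confusion_matrix(p, t['semantic_mask'].cpu().numpy(), 5) for p, t in zip(pred, targets))
# iou = per_class_iou(conf)
# print(dict(zip(SEMANTIC_CLASSES, iou)), "mIoU:", np.nanmean(iou))
```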

## 10. Visualize Predictions

In [None]:
# Run inference on a few validation images
import cv2
from inference.batch_inference import BatchInferenceEngine, VisualizationTools

# Take the first three images (sorted, since glob order is not guaranteed)
test_images = sorted((VAL_DIR / "images").glob("*"))[:3]

if test_images:
    fig, axes = plt.subplots(len(test_images), 3, figsize=(12, 4*len(test_images)))
    if len(test_images) == 1:
        axes = axes.reshape(1, -1)
    
    # Initialize inference engine with trained model
    engine = BatchInferenceEngine(
        model_path="../outputs/best_model.pth",
        device=str(device)
    )
    
    for idx, img_path in enumerate(test_images):
        # Load image
        img = cv2.imread(str(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Run inference
        result = engine.process_single(str(img_path))
        
        # Original image
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original')
        axes[idx, 0].axis('off')
        
        # Detection overlay
        overlay = VisualizationTools.create_overlay(img, result, alpha=0.4)
        axes[idx, 1].imshow(overlay)
        axes[idx, 1].set_title(f'Detections ({len(result.trees)} trees)')
        axes[idx, 1].axis('off')
        
        # Semantic segmentation
        semantic_colored = BatchInferenceEngine.SEMANTIC_COLORS[result.semantic_map]
        axes[idx, 2].imshow(semantic_colored.astype(np.uint8))
        axes[idx, 2].set_title('Segmentation')
        axes[idx, 2].axis('off')
    
    plt.suptitle('Model Predictions on Validation Set', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("No test images found")
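`BatchInferenceEngine.SEMANTIC_COLORS[result.semantic_map]` relies on NumPy integer-array indexing: a (num_classes, 3) color table indexed by an (H, W) class map yields an (H, W, 3) image. A self-contained sketch of the same lookup; the palette here is hypothetical and not the project's actual colors.

```python
import numpy as np

# Hypothetical palette: one RGB row per semantic class index (not the project's colors)
PALETTE = np.array([
    [0, 100, 0],      # 0 cacao
    [139, 69, 19],    # 1 shade
    [154, 205, 50],   # 2 understory
    [210, 180, 140],  # 3 bare_soil
    [64, 64, 64],     # 4 shadow
], dtype=np.uint8)

def colorize(class_map, palette=PALETTE):
    """Map an (H, W) integer class map to an (H, W, 3) RGB image via a lookup table."""
    class_map = np.asarray(class_map)
    if class_map.min() < 0 or class_map.max() >= len(palette):
        raise ValueError("class_map contains indices outside the palette")
    return palette[class_map]
```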

## 11. Export Model for Deployment

In [None]:
# Export model checkpoint
OUTPUT_DIR = Path("../outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Save full checkpoint with metadata
checkpoint = {
    'model_state_dict': model.state_dict(),
    'model_config': model_config,
    'training_config': training_config,
    'epoch': len(train_losses),
    'best_val_loss': best_val_loss,
    'train_losses': train_losses,
    'val_losses': val_losses
}

checkpoint_path = OUTPUT_DIR / "checkpoint_final.pth"
torch.save(checkpoint, checkpoint_path)
print(f"✅ Model checkpoint saved to {checkpoint_path}")

# Export to ONNX (optional)
try:
    dummy_input = torch.randn(1, 3, 512, 512).to(device)
    onnx_path = OUTPUT_DIR / "model.onnx"
    
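    torch.onnx.export(
        # Exports the instance head only, and assumes it accepts a raw
        # (1, 3, 512, 512) image tensor; if it consumes backbone features,
        # pass a matching feature tensor or export the full model instead.
        model.instance_head,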
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                     'output': {0: 'batch_size'}}
    )
    print(f"✅ ONNX model exported to {onnx_path}")
except Exception as e:
    print(f"⚠️ ONNX export failed: {e}")

# Create deployment package
deployment_info = {
    'model_path': str(checkpoint_path),
    'model_config': model_config,
    'input_size': [512, 512],
    'normalization': {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    },
    'classes': ['background', 'cacao', 'shade'],  # instance classes
    'semantic_classes': ['cacao', 'shade', 'understory', 'bare_soil', 'shadow'],
    'device_tested': str(device),
    'performance': {
        'val_loss': float(best_val_loss),
        'training_time': 'see_logs',
        'inference_time': 'measure_on_deployment'
    }
}

import json
deployment_path = OUTPUT_DIR / "deployment_info.json"
with open(deployment_path, 'w') as f:
    json.dump(deployment_info, f, indent=2)

print(f"✅ Deployment info saved to {deployment_path}")
print("\n📦 Model ready for deployment!")
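A consumer of `deployment_info.json` has to reproduce the validation preprocessing: 512×512 input, ImageNet normalization, CHW layout with a batch dimension. A minimal sketch, assuming an RGB image already resized to `input_size` (check that this matches the color order `CabrucaDataset` used in training); `preprocess_for_deployment` is a hypothetical helper.

```python
import numpy as np

def preprocess_for_deployment(image_rgb, info):
    """Turn an RGB uint8 image into a normalized (1, 3, H, W) float32 batch.

    `info` is the dict written to deployment_info.json. Assumes the image is already
    resized to info['input_size'] (resizing is left to the caller, e.g. cv2.resize).
    """
    h, w = info["input_size"]
    if image_rgb.shape[:2] != (h, w):
        raise ValueError(f"expected {(h, w)} image, got {image_rgb.shape[:2]}")
    mean = np.asarray(info["normalization"]["mean"], dtype=np.float32)
    std = np.asarray(info["normalization"]["std"], dtype=np.float32)
    x = (image_rgb.astype(np.float32) / 255.0 - mean) / std  # same formula as A.Normalize
    return np.transpose(x, (2, 0, 1))[np.newaxis]             # HWC -> NCHW

# Usage:
# import json
# with open(OUTPUT_DIR / "deployment_info.json") as f:
#     info = json.load(f)
# batch = preprocess_for_deployment(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB), info)
```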

## 12. Next Steps

### Model Improvements
1. **Data Augmentation**: Add more aggressive augmentations for better generalization
2. **Hyperparameter Tuning**: Use Optuna or Ray Tune for systematic optimization
3. **Architecture Changes**: Try different backbones (ResNet101, EfficientNet)
4. **Loss Functions**: Experiment with focal loss, dice loss combinations
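For the loss-function item, a common choice for segmentation is soft Dice combined with cross-entropy. The sketch below shows soft Dice in NumPy for clarity only; a training loss would use the same formula with torch ops so gradients flow. The `eps` smoothing term is an assumption (it makes classes absent from both prediction and target count as a perfect match).

```python
import numpy as np

def soft_dice_loss(probs, target_onehot, eps=1e-6):
    """Mean soft Dice loss over classes.

    probs, target_onehot: arrays of shape (C, H, W); probs are per-class
    probabilities (e.g. softmax output), target_onehot is the one-hot ground truth.
    """
    axes = (1, 2)
    intersection = (probs * target_onehot).sum(axis=axes)
    denom = probs.sum(axis=axes) + target_onehot.sum(axis=axes)
    dice = (2 * intersection + eps) / (denom + eps)
    return float(1.0 - dice.mean())

def one_hot(mask, num_classes):
    """(H, W) integer mask -> (C, H, W) one-hot float array."""
    return (np.arange(num_classes)[:, None, None] == mask[None]).astype(np.float32)
```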

### Deployment Options
1. **API Server**: Deploy using FastAPI (see `api_server.py`)
2. **Streamlit App**: Interactive web interface (see `viewer.py`)
3. **Docker**: Containerize for cloud deployment
4. **Edge Deployment**: Convert to TensorFlow Lite or Core ML

### Performance Optimization
1. **Model Pruning**: Reduce model size while maintaining accuracy
2. **Quantization**: INT8 quantization for faster inference
3. **Knowledge Distillation**: Train smaller student models
4. **Batch Processing**: Optimize for throughput
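To make the quantization item concrete, here is the arithmetic of symmetric per-tensor INT8 quantization: scale = max|w| / 127, q = round(w / scale) clipped to [-127, 127], and dequantization q · scale, which bounds the error by scale / 2. This illustrates the idea only; PyTorch's quantization tooling also handles calibration, per-channel scales and quantized kernels.

```python
import numpy as np

def quantize_int8_symmetric(weights):
    """Symmetric per-tensor INT8 quantization: returns (int8 values, scale)."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map INT8 values back to float32."""
    return q.astype(np.float32) * scale

# Example on random weights shaped like a small layer
rng = np.random.default_rng(0)
w = rng.normal(0, 0.05, size=(64, 64)).astype(np.float32)
q, scale = quantize_int8_symmetric(w)
err = np.abs(dequantize(q, scale) - w).max()
print(f"scale={scale:.6f}, max abs error={err:.6f} (bound: scale/2)")
```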

### Integration
1. **QGIS Plugin**: Use the model in GIS workflows
2. **Drone Integration**: Real-time processing of drone footage
3. **Mobile App**: Deploy to mobile devices for field use

### Resources
- [Training Best Practices](../docs/training_guide.md)
- [API Documentation](../API_DOCUMENTATION.md)
- [Deployment Guide](../docs/deployment.md)