# 05 - YOLO Dataset Preparation
## Creating Train/Val/Test Splits and Training YOLO

This notebook covers:
- Creating train/val/test splits
- Organizing YOLO directory structure
- Creating data.yaml configuration
- Training a YOLOv8 model
- Evaluation and inference

In [None]:
# Setup: make the project root importable, then load all dependencies
import sys
sys.path.append('..')

import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml
from sklearn.model_selection import train_test_split

print("✅ All imports successful!")

## Configuration

In [None]:
# Input/output locations, relative to the notebook's working directory
SPECTROGRAM_DIR = Path('../data/spectrograms')
ANNOTATION_DIR = Path('../data/annotations')
DATASET_DIR = Path('../data/yolo_dataset')

# Split ratios: fractions of the annotated images assigned to each split.
# They must sum to 1, otherwise the split function cannot honour all three.
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
assert abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) < 1e-9, "split ratios must sum to 1"

# Signal classes; the list position is the class id written to data.yaml 'names'
SIGNAL_CLASSES = ['bluetooth', 'wifi', 'zigbee', 'drone']

# ':g' avoids float artefacts such as 0.15 * 100 -> 15.000000000000002
print(f"Train: {TRAIN_RATIO * 100:g}%, Val: {VAL_RATIO * 100:g}%, Test: {TEST_RATIO * 100:g}%")
print(f"Classes: {SIGNAL_CLASSES}")

## Create Dataset Split

In [None]:
def _label_path(annotation_dir, img_file):
    """Return the YOLO label file paired with ``img_file`` (same stem, ``.txt`` suffix)."""
    return Path(annotation_dir) / f"{Path(img_file).stem}.txt"


def create_yolo_dataset_split(image_dir, annotation_dir, output_dir,
                              train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                              seed=42):
    """
    Organize annotated images into the YOLO train/val/test directory structure.

    Only ``*.png`` files in ``image_dir`` that have a same-stem ``.txt`` file in
    ``annotation_dir`` are used. Files are copied (sources are untouched) to::

        output_dir/{train,val,test}/images/<stem>.png
        output_dir/{train,val,test}/labels/<stem>.txt

    NOTE: files left in ``output_dir`` by an earlier run are not removed. If
    the set of images changes, delete ``output_dir`` first. Otherwise stale
    copies can end up in more than one split and leak test data into training.

    Args:
        image_dir: Directory containing the ``*.png`` images.
        annotation_dir: Directory containing the YOLO ``.txt`` label files.
        output_dir: Root of the YOLO dataset to create.
        train_ratio: Fraction of images for training (must be > 0).
        val_ratio: Fraction of images for validation (>= 0).
        test_ratio: Fraction of images for testing (>= 0).
        seed: ``random_state`` passed to ``train_test_split`` so the split is reproducible.

    Returns:
        Tuple ``(n_train, n_val, n_test)`` with the number of images per split.

    Raises:
        ValueError: If the ratios are negative, ``train_ratio`` is 0, the
            ratios do not sum to 1, or no annotated images are found.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    if train_ratio <= 0 or any(r < 0 for r in ratios):
        raise ValueError(f"Split ratios must be non-negative with train_ratio > 0, got {ratios}")
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError(f"Split ratios must sum to 1, got {ratios} (sum={sum(ratios)})")

    # Only images with a matching label file take part in the split
    image_files = sorted(Path(image_dir).glob("*.png"))
    print(f"Found {len(image_files)} images")

    annotated_images = [img for img in image_files if _label_path(annotation_dir, img).exists()]
    print(f"Found {len(annotated_images)} annotated images")
    if not annotated_images:
        raise ValueError(
            f"No .png images in {image_dir} have a matching .txt label in {annotation_dir}"
        )

    # First carve off the training set, then divide the remainder between
    # val and test in proportion val_ratio : test_ratio. Empty splits are
    # handled explicitly because train_test_split rejects train_size=1.0.
    holdout_ratio = val_ratio + test_ratio
    if holdout_ratio == 0:
        train_files, temp_files = list(annotated_images), []
    else:
        train_files, temp_files = train_test_split(
            annotated_images, train_size=train_ratio, random_state=seed
        )

    if test_ratio == 0:
        val_files, test_files = list(temp_files), []
    elif val_ratio == 0:
        val_files, test_files = [], list(temp_files)
    else:
        val_files, test_files = train_test_split(
            temp_files, train_size=val_ratio / holdout_ratio, random_state=seed
        )

    splits = {
        'train': train_files,
        'val': val_files,
        'test': test_files
    }

    # Create the directory structure and copy each image with its label
    out_root = Path(output_dir)
    for split_name, files in splits.items():
        images_out = out_root / split_name / 'images'
        labels_out = out_root / split_name / 'labels'
        images_out.mkdir(parents=True, exist_ok=True)
        labels_out.mkdir(parents=True, exist_ok=True)

        for img_file in files:
            shutil.copy(img_file, images_out / img_file.name)
            # Existence was already checked when building annotated_images
            label_file = _label_path(annotation_dir, img_file)
            shutil.copy(label_file, labels_out / label_file.name)

    print(f"\n✅ Dataset split created in {output_dir}")
    print(f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")

    return len(train_files), len(val_files), len(test_files)

# Build the YOLO dataset using the paths and ratios from the configuration cell
train_count, val_count, test_count = create_yolo_dataset_split(
    image_dir=SPECTROGRAM_DIR,
    annotation_dir=ANNOTATION_DIR,
    output_dir=DATASET_DIR,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
)

## Create data.yaml

In [None]:
# Create the Ultralytics dataset config. 'path' is the dataset root; the split
# entries are image directories relative to it. Ultralytics looks for labels
# in the sibling 'labels' directories created by the split step.
DATASET_DIR.mkdir(parents=True, exist_ok=True)  # allow running before the split cell

data_yaml = {
    # resolve() gives a clean absolute path with no '..' components
    'path': str(DATASET_DIR.resolve()),
    'train': 'train/images',
    'val': 'val/images',
    'test': 'test/images',
    'nc': len(SIGNAL_CLASSES),
    'names': SIGNAL_CLASSES
}

# Dump once; the same text is written to disk and shown below
yaml_text = yaml.dump(data_yaml, default_flow_style=False)
yaml_path = DATASET_DIR / 'data.yaml'
yaml_path.write_text(yaml_text, encoding='utf-8')

print(f"✅ Created {yaml_path}")
print("\nContents:")
print(yaml_text)

## Train YOLOv8 Model

In [None]:
from ultralytics import YOLO

# Training hyperparameters used by the training cell below
MODEL_SIZE = 'yolov8n'  # Options: yolov8n, yolov8s, yolov8m, yolov8l, yolov8x
EPOCHS = 100
IMAGE_SIZE = 256  # square input size passed to YOLO as imgsz
BATCH_SIZE = 16

print("\n".join([
    f"Training {MODEL_SIZE} for {EPOCHS} epochs",
    f"Image size: {IMAGE_SIZE}x{IMAGE_SIZE}",
    f"Batch size: {BATCH_SIZE}",
]))

In [None]:
import torch  # installed as an ultralytics dependency

# Initialize from the pretrained checkpoint for the chosen model size
model = YOLO(f'{MODEL_SIZE}.pt')

# Use the first GPU when CUDA is available, otherwise fall back to CPU so the
# notebook still runs on machines without a GPU (device=0 would fail there).
train_device = 0 if torch.cuda.is_available() else 'cpu'

# Train (named train_results so it is not confused with the prediction
# results produced in the inference section)
train_results = model.train(
    data=str(yaml_path),
    epochs=EPOCHS,
    imgsz=IMAGE_SIZE,
    batch=BATCH_SIZE,
    device=train_device,
    workers=4,

    # Optimization
    optimizer='AdamW',
    lr0=0.001,
    lrf=0.01,  # final learning rate = lr0 * lrf
    momentum=0.937,
    weight_decay=0.0005,

    # Augmentation: rotation, shear, perspective and flips are disabled
    # because they would mix up the time and frequency axes of a spectrogram.
    # NOTE(review): hsv_* jitter recolors the spectrogram colormap, and
    # scale/mosaic change the apparent signal bandwidth. Confirm these do not
    # blur the cues that separate the classes.
    hsv_h=0.015,
    hsv_s=0.7,
    hsv_v=0.4,
    degrees=0.0,  # No rotation
    translate=0.1,
    scale=0.5,
    shear=0.0,
    perspective=0.0,
    flipud=0.0,  # No vertical flip
    fliplr=0.0,  # No horizontal flip
    mosaic=1.0,
    mixup=0.0,

    # Validation / checkpointing
    val=True,
    save_period=10,  # also keep a checkpoint every 10 epochs

    # Outputs go to ../models/rf_signal_detection; exist_ok reuses that
    # directory on re-run instead of creating a numbered copy
    project='../models',
    name='rf_signal_detection',
    exist_ok=True
)

print("\n✅ Training complete!")
print(f"Best weights: {model.trainer.best}")

## Evaluate Model

In [None]:
# Load the best checkpoint written by the training cell
# (project='../models', name='rf_signal_detection')
best_model_path = Path('../models/rf_signal_detection/weights/best.pt')
if not best_model_path.exists():
    raise FileNotFoundError(
        f"No trained weights at {best_model_path}; run the training cell first."
    )
model = YOLO(str(best_model_path))

# Evaluate on the held-out test split declared in data.yaml
metrics = model.val(data=str(yaml_path), split='test')

print(f"\nmAP@0.5: {metrics.box.map50:.4f}")
print(f"mAP@0.5:0.95: {metrics.box.map:.4f}")
print("\nPer-class metrics:")

# metrics.box.p / metrics.box.r are indexed by position in ap_class_index,
# which lists only the classes that occur in the test labels. They are NOT
# indexed by class id, so map each class id to its position first.
class_position = {int(c): pos for pos, c in enumerate(metrics.box.ap_class_index)}
for class_id, class_name in enumerate(SIGNAL_CLASSES):
    print(f"  {class_name}:")
    pos = class_position.get(class_id)
    if pos is None:
        print("    (no instances in the test split)")
        continue
    print(f"    Precision: {metrics.box.p[pos]:.4f}")
    print(f"    Recall: {metrics.box.r[pos]:.4f}")

## Run Inference on Test Images

In [None]:
import cv2

# Preview settings: how many test images to show and the detection thresholds
N_PREVIEW = 6
CONF_THRESHOLD = 0.25
IOU_THRESHOLD = 0.45

# Get test images
test_images = sorted((DATASET_DIR / 'test/images').glob('*.png'))
preview_images = test_images[:N_PREVIEW]
if not preview_images:
    raise FileNotFoundError(f"No test images found in {DATASET_DIR / 'test/images'}")

# Run inference (named pred_results to keep it distinct from training output)
pred_results = model.predict(
    source=preview_images,
    conf=CONF_THRESHOLD,
    iou=IOU_THRESHOLD,
    save=False
)

# Lay out at most 3 columns and size the grid to the number of previews;
# squeeze=False keeps `axes` 2-D even for a single row or column.
n_shown = len(preview_images)
n_cols = min(3, n_shown)
n_rows = (n_shown + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for idx, (ax, img_path, result) in enumerate(zip(axes, preview_images, pred_results)):
    # result.plot() returns a BGR array (OpenCV order); matplotlib expects RGB
    img_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
    ax.imshow(img_rgb)
    ax.set_title(f"Test Image {idx + 1}: {img_path.name}")
    ax.axis('off')

# Hide any grid cells left over when n_shown is not a multiple of n_cols
for ax in axes[n_shown:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## Summary

Training is complete. Before treating the model as ready for deployment, check the test-set results above: overall mAP and per-class precision/recall.

Next steps:
- Fine-tune hyperparameters
- Collect more training data
- Implement real-time inference pipeline