# Stain Detection Model Training

June 21, 2025

In [None]:
# Install dependencies

import sys
import subprocess

def install_packages(packages=None):
    """Install the notebook's Python dependencies with pip into the running interpreter.

    Args:
        packages: Optional list of pip requirement strings. When omitted, the
            default set used by this notebook is installed. An empty list is a
            no-op (pip rejects an ``install`` call with no requirements).

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    if packages is None:
        packages = [
            'ultralytics',
            'torch',
            'torchvision',
            'onnx',
            'onnxruntime-gpu',
            'opencv-python',
            'pillow',
            'matplotlib',
            'seaborn',
            'pandas',
            'ipynbname',
            # Imported directly by later cells (`import yaml`, `import psutil`);
            # list them explicitly instead of relying on transitive installs.
            'pyyaml',
            'psutil',
        ]

    if not packages:
        print("No packages to install.")
        return

    # `sys.executable -m pip` installs into the same environment as this kernel.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

    print("All packages installed successfully!")

# Install the default dependency set into this kernel's environment via pip.
install_packages()

In [None]:
# GPU check 

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
import yaml
from ultralytics import YOLO
from IPython.display import display, Image, clear_output
import pandas as pd

# Report what PyTorch can see before committing to a long training run.
print("=== GPU Check ===")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected! Training will be very slow.")
else:
    print(f"GPU count: {torch.cuda.device_count()}")
    # Only the first device is described here.
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


In [None]:
# Split the dataset into training and validation sets

import os
import random
import shutil

import ipynbname

def split_dataset_for_yolo(source_dir, dest_dir, split_ratio=0.8, test_ratio=0.1):
    """
    Copy a flat image+label folder into YOLO-style train/val/test subfolders.

    Every ``.jpg`` file directly inside *source_dir* is assigned to exactly one
    split; its same-stem ``.txt`` label (if present) is copied alongside it into
    ``<dest_dir>/<split>/labels``. Existing ``train``/``val``/``test`` folders
    under *dest_dir* are deleted first.

    :param source_dir: Directory containing the original images and labels.
    :param dest_dir: Directory under which ``train``, ``val`` and ``test`` are created.
    :param split_ratio: Fraction of images assigned to training (default 0.8).
    :param test_ratio: Fraction of images held out for testing (default 0.1).
        Whatever remains goes to validation. The three splits never overlap.
    """
    # Start from empty output folders so files from a previous run cannot leak in.
    for split in ('train', 'test', 'val'):
        split_dir = os.path.join(dest_dir, split)
        if os.path.exists(split_dir):
            shutil.rmtree(split_dir)
        os.makedirs(os.path.join(split_dir, 'images'))
        os.makedirs(os.path.join(split_dir, 'labels'))

    all_files = [f for f in os.listdir(source_dir)
                 if os.path.isfile(os.path.join(source_dir, f)) and f.endswith('.jpg')]
    random.shuffle(all_files)

    # Carve the shuffled list into contiguous, non-overlapping slices so that no
    # test image is also seen during training or validation.
    train_end = int(len(all_files) * split_ratio)
    test_end = train_end + int(len(all_files) * test_ratio)
    splits = {
        'train': all_files[:train_end],
        'test': all_files[train_end:test_end],
        'val': all_files[test_end:],
    }

    for split, files in splits.items():
        _copy_image_and_label(source_dir, dest_dir, split, files)

    print(f"Dataset split completed: {len(splits['train'])} training files, "
          f"{len(splits['test'])} test files, {len(splits['val'])} validation files.")


def _copy_image_and_label(source_dir, dest_dir, split, files):
    """Copy each image in *files*, plus its same-stem ``.txt`` label if present, into *split*."""
    for file in files:
        shutil.copy(os.path.join(source_dir, file), os.path.join(dest_dir, split, 'images', file))
        label_file = os.path.splitext(file)[0] + '.txt'
        label_src = os.path.join(source_dir, label_file)
        if os.path.exists(label_src):
            shutil.copy(label_src, os.path.join(dest_dir, split, 'labels', label_file))

# Work out where this code lives: a plain script has __file__, a Jupyter
# kernel does not, in which case the notebook's own path is used instead.
try:
    script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    notebook_path = ipynbname.path()
    script_dir = os.path.dirname(notebook_path)

print(f"Script directory: {script_dir}")

# Raw images and labels live in dataset/src; the split is written next to
# them as dataset/{train,val,test}.
ds_directory = os.path.join(script_dir, "dataset")
source_directory = os.path.join(ds_directory, "src")

split_dataset_for_yolo(source_directory, ds_directory)


In [None]:
# Data Augmentation with Contrast Enhancement

import cv2
import numpy as np
import random
import os
from pathlib import Path
import shutil

def apply_contrast_enhancement(image, method='clahe', **kwargs):
    """Return a contrast-enhanced version of *image*.

    Processing happens on a single grayscale channel: a 3-D (BGR) input is
    converted to gray first and the result is converted back to 3-channel BGR,
    so colour information is discarded.

    Args:
        image: 2-D grayscale or 3-D BGR ``uint8`` array (as from ``cv2.imread``).
        method: One of ``'clahe'``, ``'histogram_eq'``, ``'contrast_stretch'``,
            ``'adaptive_gamma'``.
        **kwargs: Optional per-method parameters; anything omitted is drawn at
            random: ``clip_limit`` / ``tile_size`` (clahe),
            ``lower_pct`` / ``upper_pct`` (contrast_stretch).

    Returns:
        ``uint8`` array with the same number of dimensions as *image*.

    Raises:
        ValueError: if *method* is not one of the supported names.
    """
    valid_methods = ('clahe', 'histogram_eq', 'contrast_stretch', 'adaptive_gamma')
    if method not in valid_methods:
        # Otherwise an unknown name falls through every branch and fails later
        # with an UnboundLocalError on `enhanced`.
        raise ValueError(f"Unknown contrast method {method!r}; expected one of {valid_methods}")

    is_color = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_color else image.copy()

    if method == 'clahe':
        clip_limit = kwargs.get('clip_limit', random.uniform(1.5, 3.0))
        tile_size = kwargs.get('tile_size', random.choice([8, 16]))
        clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(tile_size, tile_size))
        enhanced = clahe.apply(gray)

    elif method == 'histogram_eq':
        enhanced = cv2.equalizeHist(gray)

    elif method == 'contrast_stretch':
        lower_pct = kwargs.get('lower_pct', random.uniform(1, 3))
        upper_pct = kwargs.get('upper_pct', random.uniform(97, 99))
        lower = np.percentile(gray, lower_pct)
        upper = np.percentile(gray, upper_pct)
        if upper <= lower:
            # Flat image: stretching would divide by zero and yield NaNs, whose
            # uint8 cast is undefined. Leave intensities unchanged instead.
            enhanced = gray
        else:
            enhanced = np.clip((gray - lower) / (upper - lower) * 255, 0, 255).astype(np.uint8)

    else:  # 'adaptive_gamma'
        mean_intensity = np.mean(gray)
        # NOTE(review): the LUT below computes out = in ** (1 / gamma), so
        # gamma = 2.2 *brightens* and gamma = 0.8 darkens. As written, bright
        # images (mean > 127) are brightened further and dark ones darkened —
        # confirm this direction is intended.
        gamma = 2.2 if mean_intensity > 127 else 0.8
        gamma += random.uniform(-0.2, 0.2)  # Add randomness

        inv_gamma = 1.0 / gamma
        table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
        enhanced = cv2.LUT(gray, table)

    if is_color:
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)

    return enhanced

def _random_photometric_augment(img):
    """Return one randomly augmented copy of *img* (contrast, brightness, noise, blur).

    Only pixel intensities change and geometry is untouched, so the source
    image's YOLO label remains valid for the result.
    """
    method = random.choice(['clahe', 'contrast_stretch', 'adaptive_gamma'])
    augmented_img = apply_contrast_enhancement(img, method=method).copy()

    # Random brightness offset, saturated to [0, 255]. cv2.convertScaleAbs is
    # deliberately not used: it returns |alpha*x + beta|, so with a negative
    # beta any pixel darker than |beta| would be reflected upward (brightened)
    # instead of clipped to black.
    if random.random() < 0.5:
        brightness = random.randint(-30, 30)
        augmented_img = np.clip(augmented_img.astype(np.int16) + brightness, 0, 255).astype(np.uint8)

    # Random noise (very subtle for fabric)
    if random.random() < 0.3:
        noise = np.random.normal(0, random.uniform(1, 3), augmented_img.shape).astype(np.int16)
        augmented_img = np.clip(augmented_img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

    # Random slight blur (to simulate camera focus variations)
    if random.random() < 0.2:
        kernel_size = random.choice([3, 5])
        augmented_img = cv2.GaussianBlur(augmented_img, (kernel_size, kernel_size), 0)

    return augmented_img


def augment_dataset(source_dir, work_dir, augmentation_factor=3):
    """
    Build a train/val/test dataset in *work_dir* with extra augmented training images.

    Images (``*.jpg``, ``*.jpeg``, ``*.png``) directly inside *source_dir* are
    shuffled and split 80% train / 10% test / remainder val. Each image and its
    same-stem ``.txt`` label (if present) is copied to
    ``<work_dir>/<split>/{images,labels}``. Training images additionally get
    *augmentation_factor* photometric variants named ``<stem>_aug<i><suffix>``,
    each with a copy of the original label.

    Args:
        source_dir: Original dataset directory (dataset/src)
        work_dir: Working directory for augmented data (dataset/work); it is
            deleted and recreated on every call.
        augmentation_factor: How many augmented versions per original training image

    Returns:
        *work_dir*, unchanged.
    """
    # Clear and create work directory structure
    work_path = Path(work_dir)
    if work_path.exists():
        shutil.rmtree(work_path)

    for split in ['train', 'val', 'test']:
        for subdir in ['images', 'labels']:
            (work_path / split / subdir).mkdir(parents=True, exist_ok=True)

    source_path = Path(source_dir)
    image_files = [*source_path.glob('*.jpg'), *source_path.glob('*.jpeg'), *source_path.glob('*.png')]

    print(f"Found {len(image_files)} images to augment")

    random.shuffle(image_files)

    train_split = int(len(image_files) * 0.8)
    test_split = int(len(image_files) * 0.1)

    splits = {
        'train': image_files[:train_split],
        'test': image_files[train_split:train_split + test_split],
        'val': image_files[train_split + test_split:]
    }

    total_processed = 0

    for split_name, files in splits.items():
        print(f"Processing {split_name} split: {len(files)} files")
        images_out = work_path / split_name / 'images'
        labels_out = work_path / split_name / 'labels'

        for img_file in files:
            label_file = source_path / (img_file.stem + '.txt')

            # Original image and label
            shutil.copy2(img_file, images_out / img_file.name)
            if label_file.exists():
                shutil.copy2(label_file, labels_out / label_file.name)

            # Only the training split is augmented; val/test stay as real images.
            if split_name == 'train':
                img = cv2.imread(str(img_file))
                if img is None:
                    print(f"⚠️ Could not read {img_file}; skipping its augmentations")
                else:
                    for aug_idx in range(augmentation_factor):
                        augmented_img = _random_photometric_augment(img)
                        aug_stem = f"{img_file.stem}_aug{aug_idx}"
                        cv2.imwrite(str(images_out / f"{aug_stem}{img_file.suffix}"), augmented_img)

                        # Augmentations are photometric only, so the label is reused as-is.
                        if label_file.exists():
                            shutil.copy2(label_file, labels_out / f"{aug_stem}.txt")

            total_processed += 1
            if total_processed % 10 == 0:
                print(f"Processed {total_processed}/{len(image_files)} images")

    # Count final dataset
    for split in ['train', 'val', 'test']:
        img_count = len(list((work_path / split / 'images').glob('*')))
        print(f"{split}: {img_count} images")

    print(f"✅ Augmented dataset created in {work_dir}")
    return work_dir

# Create augmented dataset
try:
    script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Inside a Jupyter kernel __file__ is undefined; locate the notebook instead.
    notebook_path = ipynbname.path()
    script_dir = os.path.dirname(notebook_path)

source_directory = os.path.join(script_dir, "dataset", "src")
work_directory = os.path.join(script_dir, "dataset", "work")

# augment_dataset() deletes and recreates work_directory itself, so no separate
# clean-up step is needed here.
print("Creating augmented dataset...")
augmented_dataset_dir = augment_dataset(source_directory, work_directory, augmentation_factor=3)

# Create data.yaml pointing to the augmented dataset. The mapping is serialised
# with yaml.safe_dump so the absolute dataset path gets quoted when needed;
# pasting it into a YAML template breaks if the path contains ": " or " #".
stain_data_config = {
    'path': work_directory,    # dataset root dir
    'train': 'train/images',   # train images (relative to 'path')
    'val': 'val/images',       # val images (relative to 'path')
    'test': 'test/images',     # test images (relative to 'path')
    'names': {0: 'stain'},     # class index -> class name
    'nc': 1,                   # number of classes
}

# data.yaml goes to the current working directory, which is where the
# validation, visualisation and training cells read it from.
with open('data.yaml', 'w', encoding='utf-8') as f:
    f.write("# Stain Detection Dataset Configuration\n")
    yaml.safe_dump(stain_data_config, f, sort_keys=False, allow_unicode=True)

print("✅ data.yaml updated to use augmented dataset")
print(f"Dataset path: {os.path.join(script_dir, 'dataset', 'work')}")

In [None]:
# Dataset structure validation

def validate_dataset_structure():
    """Check that data.yaml exists and that each split's image folder is usable.

    Reads ``data.yaml`` from the current working directory, prints its contents,
    then resolves each of ``train``/``val``/``test`` against the ``path`` key and
    reports how many images it holds.

    Returns:
        True if data.yaml parsed to a mapping, every split directory exists, and
        the train and val splits each contain at least one image; False
        otherwise. The test split may be empty (very small datasets).
    """
    if not os.path.exists('data.yaml'):
        print("❌ data.yaml not found!")
        return False

    with open('data.yaml', 'r') as f:
        data_config = yaml.safe_load(f)

    if not isinstance(data_config, dict):
        print("❌ data.yaml is empty or not a mapping!")
        return False

    print("=== Dataset Configuration ===")
    for key, value in data_config.items():
        print(f"{key}: {value}")

    # A missing 'path' falls back to resolving split dirs relative to the CWD.
    dataset_root = str(data_config.get('path') or '')
    all_ok = True

    for split in ['train', 'val', 'test']:  # Note: 'val' not 'valid' in some configs
        img_dir = str(data_config.get(split, f'../{split}/images'))
        # Handle relative paths
        if img_dir.startswith('../'):
            img_dir = img_dir[3:]  # Remove '../'

        img_dir = os.path.join(dataset_root, img_dir)
        if not os.path.exists(img_dir):
            print(f"❌ {split} directory not found: {img_dir}")
            all_ok = False
            continue

        img_count = len([f for f in os.listdir(img_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"✅ {split}: {img_count} images found in {img_dir}")
        if img_count == 0 and split != 'test':
            print(f"❌ {split} split contains no images")
            all_ok = False

    return all_ok

# Validate dataset. `dataset_valid` gates the visualisation and training cells.
dataset_valid = validate_dataset_structure()


In [None]:
# Test system

import gc
import torch
import psutil
import os

def check_memory_and_disk():
    """Print current RAM, disk, GPU and shared-memory usage.

    Disk usage is reported for the filesystem containing ``/``. The GPU line
    compares ``torch.cuda.memory_allocated()`` (current device) with device 0's
    total memory; per the PyTorch docs ``memory_allocated`` counts only memory
    occupied by this process's tensors, so it may under-report card usage.
    Shared memory (``/dev/shm``) is only reported where ``os.statvfs`` and that
    mount exist (presumably Linux).
    """
    # Check RAM
    memory = psutil.virtual_memory()
    print(f"RAM: {memory.used/1024**3:.1f}GB used / {memory.total/1024**3:.1f}GB total ({memory.percent:.1f}%)")

    # Check disk space
    disk = psutil.disk_usage('/')
    print(f"Disk: {disk.used/1024**3:.1f}GB used / {disk.total/1024**3:.1f}GB total ({disk.used/disk.total*100:.1f}%)")

    # Check GPU memory if available
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        gpu_memory = torch.cuda.memory_allocated() / 1024**3
        gpu_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {gpu_memory:.1f}GB used / {gpu_total:.1f}GB total")

    # Check shared memory
    try:
        shm_stats = os.statvfs('/dev/shm')
    except (AttributeError, OSError):
        # os.statvfs does not exist on Windows (AttributeError) and /dev/shm may
        # be missing elsewhere (OSError). A bare `except:` would also swallow
        # KeyboardInterrupt and genuine bugs.
        print("Shared memory stats not available")
        return

    shm_total = shm_stats.f_frsize * shm_stats.f_blocks / 1024**3
    shm_free = shm_stats.f_frsize * shm_stats.f_bavail / 1024**3
    print(f"Shared Memory: {shm_total-shm_free:.1f}GB used / {shm_total:.1f}GB total")

def clean_memory():
    """Collect unreachable Python objects and, when CUDA is present, drop cached GPU blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        print("Memory cleaned")
        return
    # Release PyTorch's cached allocator blocks, then wait for queued GPU work.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("Memory cleaned")

def clear_temp_files():
    """Delete regular files named ``torch_*`` in the system temp directory.

    Files that cannot be removed (in use, or not ours to delete) are skipped.
    Directories matching the pattern are left alone.

    NOTE(review): the pattern is matched in the shared system temp dir, so any
    other program's ``torch_*`` files there are removed as well.

    Returns:
        The number of files removed.
    """
    import glob
    import tempfile

    temp_dir = tempfile.gettempdir()
    removed_count = 0
    for file_path in glob.glob(os.path.join(temp_dir, 'torch_*')):
        if not os.path.isfile(file_path):
            continue
        try:
            os.unlink(file_path)
        except OSError:
            # Best-effort clean-up: skip files we cannot delete rather than abort.
            continue
        removed_count += 1

    print(f"Cleared {removed_count} temporary PyTorch files")
    return removed_count

# Report resource usage, then remove torch_* temp files and free cached memory.
check_memory_and_disk()
clear_temp_files()
clean_memory()

In [None]:
# Visualize samples

import random

def _label_path_for(img_path):
    """Map ``.../<split>/images/<stem>.<ext>`` to ``.../<split>/labels/<stem>.txt``."""
    img_path = Path(img_path)
    return img_path.parent.parent / 'labels' / (img_path.stem + '.txt')


def _draw_yolo_boxes(ax, label_path, img_w, img_h):
    """Draw every normalised ``class xc yc w h`` line of *label_path* on *ax*; return how many were drawn."""
    from matplotlib.patches import Rectangle

    n_boxes = 0
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            _, x_center, y_center, width, height = map(float, parts[:5])
            # Convert normalised centre/size to a pixel top-left corner and size.
            x1 = (x_center - width / 2) * img_w
            y1 = (y_center - height / 2) * img_h
            ax.add_patch(Rectangle((x1, y1), width * img_w, height * img_h,
                                   linewidth=2, edgecolor='red', facecolor='none'))
            n_boxes += 1
    return n_boxes


def visualize_sample_data(num_samples=6):
    """Show random training images with all of their YOLO label boxes drawn in red.

    The training image folder comes from ``data.yaml`` in the current working
    directory (``path`` joined with ``train``), falling back to ``./train/images``.
    Each image's label is looked up as ``<split>/labels/<stem>.txt`` beside the
    ``images`` folder, independent of image extension or path separator.

    Args:
        num_samples: Maximum number of distinct images to display (default 6).
    """
    # Find sample images
    with open('data.yaml', 'r') as f:
        data_config = yaml.safe_load(f)
    train_dir = os.path.join(data_config['path'], data_config['train'])
    if not os.path.exists(train_dir):
        train_dir = './train/images'

    if not os.path.exists(train_dir):
        print("❌ Train directory not found!")
        return

    image_files = [f for f in os.listdir(train_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    if not image_files:
        print("❌ No images found in train directory!")
        return

    # Sample without replacement so no image repeats, and keep each chosen name
    # so the subplot title always matches the image actually displayed.
    chosen = random.sample(image_files, max(0, min(num_samples, len(image_files))))

    # Grow the 3-column grid with the sample count instead of a fixed 2x3.
    cols = 3
    rows = max(1, -(-len(chosen) // cols))  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for i, name in enumerate(chosen):
        ax = axes[i]
        ax.axis('off')
        img_path = os.path.join(train_dir, name)

        img = cv2.imread(img_path)
        if img is None:
            ax.set_title(f'Sample {i+1}: {name} (unreadable)')
            continue
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ax.imshow(img_rgb)
        ax.set_title(f'Sample {i+1}: {name}')

        label_path = _label_path_for(img_path)
        if not label_path.exists():
            continue

        img_h, img_w = img_rgb.shape[:2]
        try:
            n_boxes = _draw_yolo_boxes(ax, label_path, img_w, img_h)
        except (OSError, ValueError) as e:
            print(f"Could not parse label for {name}: {e}")
            continue
        if n_boxes:
            ax.set_title(f'Sample {i+1}: {name} (with stain)')

    # Hide grid cells that received no image.
    for ax in axes[len(chosen):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize samples (skipped when validate_dataset_structure() returned a falsy value).
if dataset_valid:
    visualize_sample_data()


In [None]:
# Configure training
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Sigmoid focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Each element's binary cross-entropy is scaled by
    ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the predicted probability
    of the target label. Confidently-correct (easy) elements are down-weighted,
    so training concentrates on hard ones such as subtle stains.

    NOTE(review): no cell in this notebook passes this module to ``model.train``,
    so as written it has no effect on the YOLO run — confirm whether it is meant
    to be wired in.

    Args:
        alpha: Weight where ``target == 1``; ``1 - alpha`` is used where ``target == 0``.
        gamma: Focusing exponent; ``0`` reduces to alpha-weighted BCE.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the focal loss from raw logits *pred* and matching *target* labels."""
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        prob = torch.sigmoid(pred)
        not_target = 1 - target

        # Probability assigned to the true label, and the per-element class weight.
        prob_of_truth = prob * target + (1 - prob) * not_target
        class_weight = self.alpha * target + (1 - self.alpha) * not_target

        loss = class_weight * (1 - prob_of_truth) ** self.gamma * bce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

class EnhancedTrainingConfig:
    """Hyper-parameters for the stain-detection YOLO run, held as plain attributes.

    ``display()`` prints every attribute in definition order. Not every field is
    forwarded by ``train_yolo_model`` (see the per-field notes).
    """

    def __init__(self):
        # Model settings
        self.model_size = 'yolo11n.pt'  # Start with nano

        # Training parameters - ENHANCED FOR SUBTLE DETECTION
        self.epochs = 150  # Increased for better convergence
        self.batch_size = 8
        self.imgsz = 640
        self.patience = 30  # Epochs without improvement before early stopping

        # Paths
        self.data_yaml = 'data.yaml'
        self.project_name = 'fabric_stain_enhanced'
        self.experiment_name = 'stain_detection_focal_v1'

        # Hardware
        self.device = 0 if torch.cuda.is_available() else 'cpu'
        self.workers = 0

        # Memory optimization
        self.cache = False
        self.amp = True
        self.save = True
        self.verbose = True

        # Loss gains passed to Ultralytics as cls / box / dfl. Per the Ultralytics
        # docs the defaults are cls=0.5, box=7.5, dfl=1.5, so only cls is raised here.
        self.cls_loss_weight = 1.0  # Classification loss gain
        self.box_loss_weight = 7.5  # Box regression loss gain
        self.dfl_loss_weight = 1.5  # Distribution focal loss gain

        # Confidence and NMS thresholds - CRITICAL FOR FALSE POSITIVES
        self.conf_threshold = 0.25  # Minimum detection confidence
        self.iou_threshold = 0.7    # IoU threshold for NMS

        # Optimizer settings - tuned for subtle features
        self.optimizer = 'AdamW'  # Better for fine-tuning
        self.lr0 = 0.001  # Lower initial learning rate for stability
        self.lrf = 0.001  # Final LR as a *fraction* of lr0 (final LR = lr0 * lrf)
        self.momentum = 0.937
        self.weight_decay = 0.0005
        self.single_cls = True

        # Enhanced augmentation for fabric textures
        self.hsv_h = 0.008    # Reduced hue variation (fabric color consistency)
        self.hsv_s = 0.3      # Reduced saturation (white fabric)
        self.hsv_v = 0.2      # Reduced brightness variation
        self.degrees = 5.0    # Reduced rotation (fabric typically flat)
        self.translate = 0.1  # Reduced translation
        self.scale = 0.2      # Reduced scaling
        self.fliplr = 0.5     # Keep horizontal flip
        self.flipud = 0.1     # Minimal vertical flip
        self.perspective = 0.0002  # Minimal perspective change
        self.mixup = 0.0      # Disable mixup (can blur stain boundaries)
        self.copy_paste = 0.0 # Disable copy-paste

        # Early stopping and validation settings
        self.val_period = 1   # NOTE: not forwarded by train_yolo_model
        self.save_period = 10 # Save checkpoint every 10 epochs

        # Class balancing. NOTE: these YOLOv5-style gains are not forwarded by
        # train_yolo_model, so they currently have no effect.
        self.cls_pw = 1.0     # Positive weight for classification
        self.obj_pw = 1.0     # Positive weight for objectness

    def display(self):
        """Print a header followed by one ``name: value`` line per attribute."""
        print("=== Enhanced Training Configuration ===")
        for attr, value in self.__dict__.items():
            print(f"{attr}: {value}")

# Create enhanced config and print every setting so it is recorded in the cell output.
config = EnhancedTrainingConfig()
config.display()

In [None]:
# Train the model

def train_yolo_model(config):
    """Train a YOLO detector using the hyper-parameters held in *config*.

    Args:
        config: Object exposing the attributes read below (e.g. an
            ``EnhancedTrainingConfig``). ``save_period`` is optional and
            defaults to -1 when absent.

    Returns:
        ``(results, model)`` — the value returned by ``model.train(...)`` and the
        ``YOLO`` model instance that was trained.
    """
    print("🚀 Starting YOLOv11 training for stain detection...")
    print(f"Using device: {config.device}")

    # Load model
    model = YOLO(config.model_size)

    # Start training with enhanced parameters
    results = model.train(
        data=config.data_yaml,
        epochs=config.epochs,
        imgsz=config.imgsz,
        batch=config.batch_size,
        patience=config.patience,
        save=config.save,
        # Previously configured but never forwarded, so periodic checkpoints were
        # silently disabled. -1 is used when a config does not define it.
        save_period=getattr(config, 'save_period', -1),
        cache=config.cache,
        device=config.device,
        workers=config.workers,
        project=config.project_name,
        name=config.experiment_name,
        exist_ok=True,
        pretrained=True,
        optimizer=config.optimizer,
        verbose=config.verbose,
        seed=42,
        deterministic=True,
        single_cls=config.single_cls,
        amp=config.amp,

        # Learning rate settings
        lr0=config.lr0,
        lrf=config.lrf,
        momentum=config.momentum,
        weight_decay=config.weight_decay,

        # Loss weights - CRITICAL FOR SUBTLE DETECTION
        cls=config.cls_loss_weight,
        box=config.box_loss_weight,
        dfl=config.dfl_loss_weight,

        # Confidence thresholds.
        # NOTE(review): Ultralytics appears to reuse `conf` for the validation
        # passes run during training (its validation default is 0.001); 0.25 may
        # lower the reported mAP and affect best.pt selection — confirm.
        conf=config.conf_threshold,
        iou=config.iou_threshold,

        # Enhanced augmentation
        hsv_h=config.hsv_h,
        hsv_s=config.hsv_s,
        hsv_v=config.hsv_v,
        degrees=config.degrees,
        translate=config.translate,
        scale=config.scale,
        fliplr=config.fliplr,
        flipud=config.flipud,
        perspective=config.perspective,
        mixup=config.mixup,
        copy_paste=config.copy_paste,

        # Training schedule
        warmup_epochs=5.0,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,

        # Validation settings
        val=True,
        plots=True,
        save_json=True,

        # Additional settings for subtle detection
        mosaic=0.8,  # Reduced mosaic augmentation
        close_mosaic=15,  # Disable mosaic for the final 15 epochs
    )

    print("✅ Enhanced training completed!")
    return results, model

# Start training (this will take a while!). Runs only when dataset validation
# passed; the analysis cell checks for the `training_results` global created here.
if dataset_valid:
    print("Starting training... This may take 30+ minutes depending on your GPU.")
    training_results, trained_model = train_yolo_model(config)
else:
    print("❌ Cannot start training due to dataset issues. Please fix the dataset structure first.")


In [None]:
# Training report

def analyze_training_results(project_name, experiment_name):
    """Display the plots of a finished run and locate its best checkpoint.

    Looks in ``<project_name>/<experiment_name>`` for ``results.png``,
    ``confusion_matrix.png`` and up to two ``val_batch*_pred.jpg`` images, showing
    whichever exist inline.

    Returns:
        ``Path`` to ``weights/best.pt`` if it exists, otherwise None (also None
        when the run directory is missing).
    """
    run_dir = Path(project_name) / experiment_name

    if not run_dir.exists():
        print(f"❌ Results directory not found: {run_dir}")
        return

    print(f"📊 Analyzing results from: {run_dir}")

    plots = (
        ('results.png', "=== Training Curves ==="),
        ('confusion_matrix.png', "=== Confusion Matrix ==="),
    )
    for filename, heading in plots:
        plot_path = run_dir / filename
        if plot_path.exists():
            print(heading)
            display(Image(str(plot_path)))

    pred_batches = list(run_dir.glob('val_batch*_pred.jpg'))
    if pred_batches:
        print("=== Validation Predictions ===")
        for batch_img in pred_batches[:2]:
            print(f"Validation predictions: {batch_img.name}")
            display(Image(str(batch_img)))

    weights = run_dir / 'weights' / 'best.pt'
    if not weights.exists():
        print("❌ Best model not found!")
        return None
    print(f"✅ Best model saved at: {weights}")
    return weights

# Analyze results, but only if the training cell ran (at notebook top level
# locals() is the global namespace).
if 'training_results' in locals():
    best_model_path = analyze_training_results(config.project_name, config.experiment_name)

In [None]:
# Model Validation and Metrics

def validate_model(model_path):
    """Run ``YOLO.val()`` on the weights at *model_path* and print headline box metrics.

    No dataset is passed to ``val()``, so Ultralytics presumably uses the data
    settings stored with the weights — verify if validating a different split.

    Returns:
        The metrics object returned by ``val()``, or None if *model_path* is missing.
    """
    if os.path.exists(model_path):
        print("🔍 Validating model...")
        metrics = YOLO(model_path).val()

        print("=== Validation Metrics ===")
        headline = (("mAP50", "map50"), ("mAP50-95", "map"), ("Precision", "mp"), ("Recall", "mr"))
        for label, attr in headline:
            print(f"{label}: {getattr(metrics.box, attr):.4f}")
        return metrics

    print(f"❌ Model not found: {model_path}")
    return None

# Validate model, only when the analysis cell found a best.pt.
if 'best_model_path' in locals() and best_model_path:
    validation_metrics = validate_model(best_model_path)

In [None]:
# ONNX export

def export_to_onnx(model_path, output_name="stain_detection_model"):
    """Export the YOLO weights at *model_path* to ONNX under a stable file name.

    Ultralytics picks the export file name itself (typically next to the weights,
    e.g. ``best.onnx``). That file is kept, and a copy named
    ``<output_name>.onnx`` is written to the same directory.

    Args:
        model_path: Path to a ``.pt`` checkpoint.
        output_name: File stem for the exported copy. Pass None (or "") to keep
            only the file name Ultralytics chose.

    Returns:
        Path (str) of the ONNX file named after *output_name* (or Ultralytics'
        path when *output_name* is empty), or None if *model_path* is missing.
    """
    import shutil

    if not os.path.exists(model_path):
        print(f"❌ Model not found: {model_path}")
        return None

    print("📦 Exporting model to ONNX format...")

    # Load model
    model = YOLO(model_path)

    # Export to ONNX.
    # NOTE(review): per the Ultralytics export docs `optimize` targets TorchScript
    # mobile export and is presumably ignored for ONNX — confirm.
    onnx_path = str(model.export(
        format='onnx',
        imgsz=640,
        optimize=True,
        half=False,  # Use FP32 for better compatibility
        int8=False,
        dynamic=False,
        simplify=True,
        opset=11,  # ONNX opset 11 for wide compatibility
    ))

    # Honour output_name (previously accepted but ignored).
    if output_name:
        named_path = os.path.join(os.path.dirname(onnx_path), f"{output_name}.onnx")
        if os.path.abspath(named_path) != os.path.abspath(onnx_path):
            shutil.copyfile(onnx_path, named_path)
        onnx_path = named_path

    print(f"✅ ONNX model exported to: {onnx_path}")
    print(f"Model size: {os.path.getsize(onnx_path) / (1024*1024):.2f} MB")

    return onnx_path

# Export to ONNX, only when the analysis cell found a best.pt.
if 'best_model_path' in locals() and best_model_path:
    onnx_model_path = export_to_onnx(best_model_path)

In [None]:
# Test 

import random

def get_grid_position(bbox, img_width, img_height, box_format='xyxy'):
    """Return the 3x3 grid cell (0-8, row-major from top-left) containing a box's centre.

    Args:
        bbox: Box in pixels. With ``box_format='xyxy'`` (default) this is
            ``[x1, y1, x2, y2]`` — the layout of Ultralytics ``boxes.xyxy`` and
            what ``test_inference`` passes. With ``'xywh'`` it is ``[x, y, w, h]``
            where (x, y) is the top-left corner.
        img_width: Image width in pixels.
        img_height: Image height in pixels.
        box_format: ``'xyxy'`` or ``'xywh'``.

    Returns:
        ``row * 3 + col`` with row/col in 0..2; centres on or past an image edge
        are clamped into the nearest cell.

    Raises:
        ValueError: if *box_format* is not supported.
    """
    if box_format == 'xyxy':
        # Centre of a corner-pair box is the midpoint of its corners. Treating
        # xyxy as xywh (x1 + x2 / 2) shifts every centre right/down.
        center_x = (bbox[0] + bbox[2]) / 2
        center_y = (bbox[1] + bbox[3]) / 2
    elif box_format == 'xywh':
        center_x = bbox[0] + bbox[2] / 2
        center_y = bbox[1] + bbox[3] / 2
    else:
        raise ValueError(f"Unsupported box_format {box_format!r}; use 'xyxy' or 'xywh'")

    # Map to 0..2 per axis, clamping centres that fall on/over the far edge.
    grid_x = max(0, min(2, int(center_x / img_width * 3)))
    grid_y = max(0, min(2, int(center_y / img_height * 3)))

    return grid_y * 3 + grid_x

def test_inference(model_path, test_image_path=None, confidence_threshold=0.5):
    """Test inference on a sample image"""
    
    if not os.path.exists(model_path):
        print(f"❌ Model not found: {model_path}")
        return
    
    # Find a test image if not provided
    if test_image_path is None:
        test_dirs = ['dataset/test/images', '../test/images', 'val/images', '../val/images', 'valid/images', '../valid/images']
        for test_dir in test_dirs:
            if os.path.exists(test_dir):
                test_images = [f for f in os.listdir(test_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
                if test_images:
                    test_image_path = os.path.join(test_dir, random.choice(test_images))
                    break
    
    if test_image_path is None or not os.path.exists(test_image_path):
        print("❌ No test image found!")
        return
    
    print(f"🧪 Testing inference on: {test_image_path}")
    
    # Load model
    model = YOLO(model_path)
    
    # Load and display original image
    image = cv2.imread(test_image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image_rgb.shape[:2]
    
    # Run prediction
    results = model.predict(test_image_path, conf=confidence_threshold, verbose=False)
    
    # Process results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Original image
    ax1.imshow(image_rgb)
    ax1.set_title('Original Image')
    ax1.axis('off')
    
    # Prediction results
    ax2.imshow(image_rgb)
    ax2.set_title('Predictions with Grid')
    ax2.axis('off')
    
    # Draw grid lines
    for i in range(1, 3):
        ax2.axvline(x=img_width * i / 3, color='blue', linestyle='--', alpha=0.5)
        ax2.axhline(y=img_height * i / 3, color='blue', linestyle='--', alpha=0.5)
    
    stain_detected = False
    detections = []
    
    # Process detections
    for result in results:
        if result.boxes is not None and len(result.boxes) > 0:
            stain_detected = True
            
            for box in result.boxes:
                # Get bounding box coordinates
                bbox = box.xyxy[0].cpu().numpy()  # [x1, y1, x2, y2]
                confidence = box.conf[0].cpu().numpy()
                
                # Draw bounding box
                from matplotlib.patches import Rectangle
                rect = Rectangle((bbox[0], bbox[1]), bbox[2]-bbox[0], bbox[3]-bbox[1], 
                               linewidth=2, edgecolor='red', facecolor='none')
                ax2.add_patch(rect)
                
                # Get grid position
                grid_pos = get_grid_position(bbox, img_width, img_height)
                
                # Add text annotation
                ax2.text(bbox[0], bbox[1]-10, f'Stain: {confidence:.2f}\\nGrid: {grid_pos}', 
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7),
                        fontsize=10)
                
                detections.append({
                    'grid_position': grid_pos,
                    'confidence': float(confidence),
                    'bbox': bbox.tolist()
                })
    
    plt.tight_layout()
    plt.show()
    
    # Print results
    print(f"Stain detected: {stain_detected}")
    if stain_detected:
        grid_map = {
            0: "Top-Left", 1: "Top-Center", 2: "Top-Right",
            3: "Middle-Left", 4: "Middle-Center", 5: "Middle-Right",
            6: "Bottom-Left", 7: "Bottom-Center", 8: "Bottom-Right"
        }
        
        for i, detection in enumerate(detections):
            grid_pos = detection['grid_position']
            confidence = detection['confidence']
            print(f"Detection {i+1}: Grid position {grid_pos} ({grid_map[grid_pos]}), Confidence: {confidence:.3f}")
    
    return detections

# Test inference on a sample image, only when the analysis cell found a best.pt.
if 'best_model_path' in locals() and best_model_path:
    test_results = test_inference(best_model_path)