# YOLOv11s Fine-Tuning for Chest X-Ray Abnormality Detection

**Purpose**: Fine-tune a YOLOv11s model on the VinBigData Chest X-ray dataset to detect 14 disease classes

**Date**: 2025-11-08

**Dataset**: Roboflow Universe - VinBigData Chest X-ray Symptom Detection (version 3, YOLOv11 format)

**Output**: Trained model weights exported to `../backend/models/yolov11s_finetuned.pt`

## Workflow Overview

1. **Dataset Download**: Download preprocessed dataset from Roboflow
2. **Auto-Labeling**: Label images without annotations as "Normal" (Bình thường)
3. **Class Mapping Verification**: Ensure English-Vietnamese class mapping is correct
4. **Preprocessing**: Apply filters for data augmentation
5. **Training**: Fine-tune YOLOv11s with tqdm progress tracking and WandB logging
6. **Validation**: Test model on validation set
7. **Export**: Save best weights for production use

## Requirements

- Python 3.12.3+
- ultralytics (YOLOv11)
- roboflow (dataset download)
- wandb (experiment tracking)
- tqdm (progress bars)
- numpy, pillow (image processing)

## Section 1: Environment Setup

Install required packages and import dependencies.

In [None]:
# # Install required packages
# !pip install -q roboflow ultralytics wandb tqdm pillow numpy

In [None]:
# Import dependencies
import os
import json
import shutil
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image

# Set working directory to repository root.
# Later cells open repo-relative paths such as 'configs/class_mapping.json'.
# Only step up one level when configs/ is not already visible, so re-running
# this cell does not keep climbing the directory tree.
if not Path('configs').is_dir():
    os.chdir('..')
print(f"Current working directory: {os.getcwd()}")

## Section 2: Dataset Download from Roboflow

Download the VinBigData Chest X-ray Symptom Detection dataset (version 3) in YOLOv11 format.

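**API Key**: wQ9S049DhK8xjIhNy6zv — this is a credential. Rotate it, and supply keys through an environment variable (e.g. `ROBOFLOW_API_KEY`) instead of committing them to the notebook.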

**Project**: chest-xray-symptom-detection

**Workspace**: vinbigdataxrayproject

In [None]:
from roboflow import Roboflow

# Read the Roboflow API key from the environment rather than hard-coding it:
# a key committed to a notebook is a leaked credential (rotate any key that has
# already been committed). Example: `export ROBOFLOW_API_KEY=...` before
# starting Jupyter.
api_key = os.environ.get("ROBOFLOW_API_KEY")
if not api_key:
    raise RuntimeError(
        "ROBOFLOW_API_KEY is not set; export your Roboflow API key as an "
        "environment variable before running this cell."
    )

# Initialize Roboflow with API key
rf = Roboflow(api_key=api_key)

# Access the VinBigData Chest X-ray project
project = rf.workspace("vinbigdataxrayproject").project("chest-xray-symptom-detection")

# Download version 3 in YOLOv11 format
version = project.version(3)
dataset = version.download("yolov11")

print(f"\n‚úì Dataset downloaded to: {dataset.location}")
print(f"‚úì Dataset structure:")
print(f"  - Train: {dataset.location}/train/")
print(f"  - Validation: {dataset.location}/valid/")
print(f"  - Test: {dataset.location}/test/")
print(f"  - Config: {dataset.location}/data.yaml")

## Section 3: Auto-Labeling Normal Images

Images without bounding box annotations (empty or missing .txt files) should be labeled as "Normal" (Bình thường).

This ensures the model can distinguish between healthy and abnormal X-rays.

In [None]:
# Auto-label images without annotations as "Normal"

def auto_label_normal_images(dataset_path, image_extensions=('.jpg', '.png')):
    """
    Ensure every image has a YOLO label file, creating empty ones where missing.

    An empty label file means "no objects" to YOLO; this notebook treats such
    images as "Normal" (no abnormality).

    Args:
        dataset_path: Path to dataset root (contains train/, valid/, test/)
        image_extensions: File suffixes to scan in each split's images/ folder
            (matched with a case-sensitive glob). Defaults to .jpg and .png.

    Returns:
        Total number of images, across all splits, whose label file was
        missing or empty (missing ones are created empty).
    """
    root = Path(dataset_path)
    normal_count = 0

    for split in ['train', 'valid', 'test']:
        images_dir = root / split / 'images'
        labels_dir = root / split / 'labels'

        if not images_dir.exists():
            print(f"‚ö†Ô∏è Skipping {split}: images directory not found")
            continue

        # Create labels directory if it doesn't exist
        labels_dir.mkdir(parents=True, exist_ok=True)

        image_files = [p for ext in image_extensions for p in images_dir.glob(f'*{ext}')]

        print(f"\nüìÅ Processing {split} split: {len(image_files)} images")

        # Count per image (not per label file) so orphan label files without a
        # matching image cannot inflate the "normal" total or drive the
        # "abnormal" total negative.
        split_normal = 0
        for img_path in tqdm(image_files, desc=f"Auto-labeling {split}"):
            label_path = labels_dir / f"{img_path.stem}.txt"

            if not label_path.exists() or label_path.stat().st_size == 0:
                # Empty label file = no bounding boxes = "Normal"
                label_path.touch()
                split_normal += 1

        normal_count += split_normal
        abnormal_images = len(image_files) - split_normal

        print(f"  ‚úì {split}: {split_normal} normal, {abnormal_images} abnormal images")

    print(f"\n‚úì Auto-labeling complete: {normal_count} images labeled as 'Normal'")
    return normal_count

# Give every unannotated image in train/valid/test an empty ("Normal") label file
dataset_path = dataset.location
normal_count = auto_label_normal_images(dataset_path=dataset_path)

## Section 4: Class Mapping Verification

Verify that all 14 English disease classes map correctly to Vietnamese translations.

Write a copy of `data.yaml` (`data_vi.yaml`) that adds the Vietnamese class names (`names_vi`) alongside the English ones (`names_en`).

In [None]:
import yaml

# Load the English -> Vietnamese class mapping from configs.
# NOTE(review): the code below treats it as a flat {english_name: vietnamese_name}
# dict; confirm configs/class_mapping.json has that shape.
with open('configs/class_mapping.json', 'r', encoding='utf-8') as f:
    class_mapping = json.load(f)

print("üìã English-Vietnamese Class Mapping:")
print("=" * 60)
for idx, (en_name, vi_name) in enumerate(class_mapping.items()):
    print(f"{idx:2d}. {en_name:25s} ‚Üí {vi_name}")
print("=" * 60)
print(f"‚úì Total classes: {len(class_mapping)}")

# Load dataset.yaml
yaml_path = Path(dataset.location) / 'data.yaml'
print(f"\nüìÑ Reading dataset config: {yaml_path}")

# Read data.yaml once, explicitly as UTF-8 (the platform default encoding is
# not UTF-8 everywhere, e.g. Windows), print it, then parse the same text.
yaml_content = yaml_path.read_text(encoding='utf-8')

print("\nüîç Original data.yaml classes:")
print(yaml_content)

# An empty file parses to None; treat it as an empty config.
data_config = yaml.safe_load(yaml_content) or {}

# Ultralytics accepts `names` either as a list or as an {index: name} mapping.
# Normalise to a list ordered by class index: iterating a dict directly would
# yield the integer indices instead of the class names.
raw_names = data_config.get('names', [])
if isinstance(raw_names, dict):
    dataset_classes = [raw_names[key] for key in sorted(raw_names, key=int)]
else:
    dataset_classes = list(raw_names)
print(f"\n‚úì Dataset has {len(dataset_classes)} classes")

# Check for missing mappings ("Normal" needs no translation entry)
missing_mappings = [
    class_name for class_name in dataset_classes
    if class_name not in class_mapping and class_name != "Normal"
]

if missing_mappings:
    print(f"‚ö†Ô∏è WARNING: {len(missing_mappings)} classes missing Vietnamese mapping:")
    for cls in missing_mappings:
        print(f"  - {cls}")
else:
    print("‚úì All classes have Vietnamese mappings!")

# Vietnamese names in dataset class order; unmapped names fall back to English
vietnamese_names = [class_mapping.get(class_name, class_name) for class_name in dataset_classes]

# Add both name lists alongside the original config
data_config['names_vi'] = vietnamese_names
data_config['names_en'] = dataset_classes

# Save as a separate file; the original data.yaml is left untouched
yaml_output_path = Path(dataset.location) / 'data_vi.yaml'
with open(yaml_output_path, 'w', encoding='utf-8') as f:
    yaml.dump(data_config, f, allow_unicode=True, default_flow_style=False)

print(f"\n‚úì Updated config saved to: {yaml_output_path}")
print(f"‚úì Vietnamese names added: {len(vietnamese_names)} classes")

# Display mapping summary
print("\nüìä Class Mapping Summary:")
print("=" * 60)
for idx, (en, vi) in enumerate(zip(dataset_classes, vietnamese_names)):
    print(f"{idx:2d}. {en:25s} ‚Üí {vi}")
print("=" * 60)

## Section 5: Preprocessing and Data Augmentation

Apply custom filters from backend for preprocessing:
- Histogram Equalization (contrast enhancement)
- Gaussian Blur (noise reduction)

**Note**: Implemented in tasks T062-T065.

In [None]:
# Import custom filter implementations from backend
import sys

# The environment-setup cell changes the working directory to the repository
# root, where the backend sources live at backend/src ('../backend/src' would
# point above the repo). '../backend/src' is kept as a fallback for a kernel
# still running from the notebook directory. Using an absolute path plus a
# membership check keeps re-runs from appending duplicates to sys.path.
for candidate_dir in (Path('backend/src'), Path('../backend/src')):
    if candidate_dir.is_dir():
        backend_src = str(candidate_dir.resolve())
        if backend_src not in sys.path:
            sys.path.append(backend_src)
        break

from filters.histogram import apply_histogram_equalization
from filters.gaussian import apply_gaussian_blur

print("‚úì Filter implementations imported successfully")
print("  - Histogram Equalization: apply_histogram_equalization()")
print("  - Gaussian Blur: apply_gaussian_blur()")

In [None]:
# Apply preprocessing filters to enhance training images (optional)

def preprocess_image(image_path, output_path):
    """
    Apply histogram equalization then Gaussian blur to a grayscale copy of an image.

    Args:
        image_path: Path to input image
        output_path: Path to save preprocessed image (saved as 8-bit grayscale)

    Returns:
        The array returned by apply_gaussian_blur, before the uint8
        conversion used for saving.
    """
    # Load as single-channel grayscale; the context manager closes the file handle
    with Image.open(image_path) as img:
        img_array = np.array(img.convert('L'))

    # Apply histogram equalization for contrast enhancement
    img_equalized = apply_histogram_equalization(img_array)

    # Apply Gaussian blur for noise reduction
    img_blurred = apply_gaussian_blur(img_equalized)

    # Clip before the uint8 cast: a bare astype would wrap out-of-range values
    # (e.g. 256 -> 0, -1 -> 255) if the filters return values outside [0, 255].
    result_img = Image.fromarray(np.clip(img_blurred, 0, 255).astype(np.uint8))
    result_img.save(output_path)

    return img_blurred

# Output layout: <dataset>/preprocessed/train/images
dataset_root = Path(dataset.location)
preprocessed_dir = dataset_root / 'preprocessed'
preprocessed_dir.mkdir(exist_ok=True)

print("üîÑ Preprocessing training images...")
print(f"üìÅ Output directory: {preprocessed_dir}")

train_images_dir = dataset_root / 'train' / 'images'
preprocessed_train_dir = preprocessed_dir / 'train' / 'images'
preprocessed_train_dir.mkdir(parents=True, exist_ok=True)

# Demo only: filter the first 100 training JPEGs returned by glob
sample_images = list(train_images_dir.glob('*.jpg'))[:100]

if not sample_images:
    print("‚ö†Ô∏è No training images found to preprocess")
else:
    print(f"üìä Processing {len(sample_images)} sample images...")

    for img_path in tqdm(sample_images, desc="Preprocessing"):
        output_path = preprocessed_train_dir.joinpath(img_path.name)
        try:
            preprocess_image(img_path, output_path)
        except Exception as e:
            print(f"‚ö†Ô∏è Error processing {img_path.name}: {e}")

    # Report how many JPEGs now sit in the output folder
    n_written = sum(1 for _ in preprocessed_train_dir.glob('*.jpg'))
    print(f"‚úì Preprocessing complete: {n_written} images processed")

print("\nüí° Note: Preprocessing is optional. Original images can be used for training.")
print("   Preprocessed images are stored separately and can be used for comparison.")

In [None]:
# Data augmentation with filter-based and geometric transforms

import math
import random

from PIL import ImageEnhance


def augment_image(img_array, *, return_transforms=False):
    """
    Apply random photometric and geometric augmentations to an image.

    Each step fires independently with probability 0.5:
      * one filter (histogram equalization or Gaussian blur, chosen 50/50),
      * horizontal flip,
      * rotation by a uniform angle in [-15, 15] degrees (PIL rotates
        counter-clockwise about the centre; canvas size is kept and the
        uncovered corners are filled with 0),
      * brightness factor in [0.8, 1.2],
      * contrast factor in [0.8, 1.2].

    Args:
        img_array: Input image as numpy array (cast to uint8 for PIL).
        return_transforms: If True, also report the geometric transforms that
            were applied, so bounding-box labels can be updated to match.

    Returns:
        Augmented image as numpy array, or, when ``return_transforms`` is True,
        a tuple ``(array, {'hflip': bool, 'angle': float})`` where ``angle`` is
        0.0 if no rotation was applied.
    """
    hflip = False
    angle = 0.0

    # Random filter application (50% chance, then one of two filters)
    if random.random() > 0.5:
        if random.random() > 0.5:
            img_array = apply_histogram_equalization(img_array)
        else:
            img_array = apply_gaussian_blur(img_array)

    # Convert to PIL for geometric transforms
    img = Image.fromarray(img_array.astype(np.uint8))

    # Random horizontal flip (50% chance)
    if random.random() > 0.5:
        img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        hflip = True

    # Random rotation (-15 to 15 degrees)
    if random.random() > 0.5:
        angle = random.uniform(-15, 15)
        img = img.rotate(angle, fillcolor=0)

    # Random brightness adjustment (0.8 to 1.2)
    if random.random() > 0.5:
        img = ImageEnhance.Brightness(img).enhance(random.uniform(0.8, 1.2))

    # Random contrast adjustment (0.8 to 1.2)
    if random.random() > 0.5:
        img = ImageEnhance.Contrast(img).enhance(random.uniform(0.8, 1.2))

    augmented = np.array(img)
    if return_transforms:
        return augmented, {'hflip': hflip, 'angle': angle}
    return augmented


def _transform_yolo_labels(label_lines, width, height, hflip=False, angle=0.0):
    """
    Map YOLO boxes through the flip/rotation applied by augment_image.

    Args:
        label_lines: Lines of a YOLO label file, ``class xc yc w h`` with
            coordinates normalised to [0, 1]. Blank lines are ignored.
        width, height: Image size in pixels. Rotation is done in pixel space
            because normalised x and y have different scales.
        hflip: Whether the image was flipped left-right (applied first).
        angle: Counter-clockwise rotation in degrees about the image centre
            (applied after the flip, matching augment_image's order).

    Returns:
        New label lines. A rotated box becomes the axis-aligned box enclosing
        its rotated corners, clipped to the image; boxes left with zero area
        are dropped.

    Raises:
        ValueError: If a non-blank line does not have exactly 5 fields.
    """
    theta = math.radians(angle)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cx, cy = width / 2.0, height / 2.0
    transformed = []

    for line in label_lines:
        fields = line.split()
        if not fields:
            continue
        if len(fields) != 5:
            raise ValueError(f"Expected 5 fields in YOLO label line, got {len(fields)}: {line!r}")

        cls_id = fields[0]
        xc, yc, w, h = map(float, fields[1:])

        if hflip:
            xc = 1.0 - xc

        if angle:
            x0, x1 = (xc - w / 2) * width, (xc + w / 2) * width
            y0, y1 = (yc - h / 2) * height, (yc + h / 2) * height
            xs, ys = [], []
            for px, py in ((x0, y0), (x1, y0), (x1, y1), (x0, y1)):
                dx, dy = px - cx, py - cy
                # Visually counter-clockwise rotation in image coordinates
                # (y axis points down), matching PIL's Image.rotate.
                xs.append(cx + dx * cos_t + dy * sin_t)
                ys.append(cy - dx * sin_t + dy * cos_t)
            left, right = max(0.0, min(xs)), min(float(width), max(xs))
            top, bottom = max(0.0, min(ys)), min(float(height), max(ys))
            if right <= left or bottom <= top:
                continue  # box rotated entirely out of frame
            xc = (left + right) / 2 / width
            yc = (top + bottom) / 2 / height
            w = (right - left) / width
            h = (bottom - top) / height

        transformed.append(f"{cls_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}")

    return transformed


# Create augmented dataset directory
augmented_dir = Path(dataset.location) / 'augmented'
augmented_train_dir = augmented_dir / 'train' / 'images'
augmented_labels_dir = augmented_dir / 'train' / 'labels'
augmented_train_dir.mkdir(parents=True, exist_ok=True)
augmented_labels_dir.mkdir(parents=True, exist_ok=True)

print("üé≤ Data Augmentation Configuration:")
print("=" * 60)
print("Augmentation techniques:")
# A filter is applied 50% of the time and split evenly between the two,
# so each individual filter fires 25% of the time.
print("  ‚úì Random histogram equalization (25% chance)")
print("  ‚úì Random Gaussian blur (25% chance)")
print("  ‚úì Random horizontal flip (50% chance)")
print("  ‚úì Random rotation ¬±15¬∞ (50% chance)")
print("  ‚úì Random brightness 0.8-1.2x (50% chance)")
print("  ‚úì Random contrast 0.8-1.2x (50% chance)")
print("=" * 60)

# Sample augmentation (augment first 50 images as demo)
train_images_dir = Path(dataset.location) / 'train' / 'images'
train_labels_dir = Path(dataset.location) / 'train' / 'labels'
sample_images = list(train_images_dir.glob('*.jpg'))[:50]

if sample_images:
    print(f"\nüìä Augmenting {len(sample_images)} sample images...")

    augmented_count = 0
    for img_path in tqdm(sample_images, desc="Augmenting"):
        try:
            # Load (closing the file handle) and augment image
            with Image.open(img_path) as img:
                img_array = np.array(img.convert('L'))
            augmented_array, transforms = augment_image(img_array, return_transforms=True)

            # Build the labels before saving anything, so a malformed label file
            # skips this image instead of leaving an image with mismatched boxes.
            label_path = train_labels_dir / f"{img_path.stem}.txt"
            aug_label_path = augmented_labels_dir / f"aug_{img_path.stem}.txt"
            new_labels = None
            if label_path.exists() and (transforms['hflip'] or transforms['angle']):
                height, width = augmented_array.shape[:2]
                new_labels = _transform_yolo_labels(
                    label_path.read_text(encoding='utf-8').splitlines(),
                    width, height, **transforms,
                )

            # Save augmented image
            aug_img_path = augmented_train_dir / f"aug_{img_path.name}"
            Image.fromarray(augmented_array.astype(np.uint8)).save(aug_img_path)

            # Write labels: flipped/rotated boxes must follow the image
            if label_path.exists():
                if new_labels is None:
                    # No geometric change: the original boxes are still valid
                    shutil.copy(label_path, aug_label_path)
                else:
                    aug_label_path.write_text(
                        '\n'.join(new_labels) + ('\n' if new_labels else ''),
                        encoding='utf-8',
                    )

            augmented_count += 1
        except Exception as e:
            print(f"‚ö†Ô∏è Error augmenting {img_path.name}: {e}")

    print(f"\n‚úì Augmentation complete: {augmented_count} images augmented")
    print(f"üìÅ Augmented images: {augmented_train_dir}")
    print(f"üìÅ Augmented labels: {augmented_labels_dir}")
else:
    print("‚ö†Ô∏è No training images found for augmentation")

print("\nüí° Note: Augmented images can be merged with original training set")
print("   to increase dataset diversity and improve model generalization.")

In [None]:
# Dataset statistics visualization

def analyze_dataset(dataset_path):
    """
    Collect per-split image and label statistics for a YOLO-format dataset.

    Args:
        dataset_path: Path to dataset root (contains train/, valid/, test/)

    Returns:
        Dictionary with keys:
          'splits': {split: {'total_images', 'normal', 'abnormal',
                             'total_annotations'}} for each split whose
                    images/ directory exists
          'class_distribution': {} (not populated by this function)
          'image_dimensions': list of (width, height) for readable images
          'normal_count', 'abnormal_count': totals across splits
        An image is abnormal when its label file has at least one non-blank
        line; a missing, empty, or blank-only label file counts as normal.
    """
    stats = {
        'splits': {},
        'class_distribution': {},
        'image_dimensions': [],
        'normal_count': 0,
        'abnormal_count': 0
    }

    for split in ['train', 'valid', 'test']:
        images_dir = Path(dataset_path) / split / 'images'
        labels_dir = Path(dataset_path) / split / 'labels'

        if not images_dir.exists():
            continue

        image_files = list(images_dir.glob('*.jpg')) + list(images_dir.glob('*.png'))

        split_stats = {
            'total_images': len(image_files),
            'normal': 0,
            'abnormal': 0,
            'total_annotations': 0
        }

        for img_path in image_files:
            # The context manager closes each file; a bare Image.open leaks one
            # handle per image, which adds up on large datasets.
            try:
                with Image.open(img_path) as img:
                    stats['image_dimensions'].append(img.size)
            except Exception as exc:  # best-effort: skip unreadable images, but say so
                print(f"  Warning: could not read size of {img_path.name}: {exc}")

            label_path = labels_dir / f"{img_path.stem}.txt"
            n_annotations = 0
            if label_path.exists():
                with open(label_path, 'r', encoding='utf-8') as f:
                    # Blank lines (e.g. a trailing empty line) are not boxes
                    n_annotations = sum(1 for line in f if line.strip())

            if n_annotations:
                split_stats['abnormal'] += 1
                stats['abnormal_count'] += 1
                split_stats['total_annotations'] += n_annotations
            else:
                split_stats['normal'] += 1
                stats['normal_count'] += 1

        stats['splits'][split] = split_stats

    return stats

def _share(part, whole):
    """Return *part* as a percentage of *whole*, or 0 when *whole* is empty."""
    return (part / whole * 100) if whole > 0 else 0


# Analyze original dataset
print("üìä Dataset Statistics Analysis")
print("=" * 80)

dataset_stats = analyze_dataset(dataset.location)

# Per-split breakdown
print("\nüóÇÔ∏è Dataset Splits:")
print("-" * 80)
for split, split_stats in dataset_stats['splits'].items():
    n_total = split_stats['total_images']
    n_normal = split_stats['normal']
    n_abnormal = split_stats['abnormal']

    print(f"\n{split.upper()}:")
    print(f"  Total images:      {n_total:5d}")
    print(f"  Normal (healthy):  {n_normal:5d} ({_share(n_normal, n_total):5.1f}%)")
    print(f"  Abnormal:          {n_abnormal:5d} ({_share(n_abnormal, n_total):5.1f}%)")
    print(f"  Total annotations: {split_stats['total_annotations']:5d}")

# Totals across all splits (these percentages also drive the balance check)
normal_total = dataset_stats['normal_count']
abnormal_total = dataset_stats['abnormal_count']
total_images = normal_total + abnormal_total
normal_pct = _share(normal_total, total_images)
abnormal_pct = _share(abnormal_total, total_images)

print("\n" + "=" * 80)
print("üìà Overall Dataset Statistics:")
print("-" * 80)
print(f"Total images:      {total_images:5d}")
print(f"Normal (healthy):  {normal_total:5d} ({normal_pct:5.1f}%)")
print(f"Abnormal:          {abnormal_total:5d} ({abnormal_pct:5.1f}%)")

# Image size ranges (integer-floored averages)
dims = dataset_stats['image_dimensions']
if dims:
    widths, heights = zip(*dims)

    print(f"\nüìê Image Dimensions:")
    print(f"  Width  - Min: {min(widths):4d}px, Max: {max(widths):4d}px, Avg: {sum(widths)//len(widths):4d}px")
    print(f"  Height - Min: {min(heights):4d}px, Max: {max(heights):4d}px, Avg: {sum(heights)//len(heights):4d}px")

print("=" * 80)

# Normal/abnormal balance verdict
print("\n‚öñÔ∏è Dataset Balance Analysis:")
if normal_pct > 80:
    balance_msgs = (
        "  ‚ö†Ô∏è WARNING: Dataset is heavily imbalanced towards normal images",
        "     Consider using weighted loss or oversampling abnormal cases",
    )
elif abnormal_pct > 80:
    balance_msgs = (
        "  ‚ö†Ô∏è WARNING: Dataset is heavily imbalanced towards abnormal images",
        "     This is unusual for medical datasets",
    )
elif 40 <= normal_pct <= 60:
    balance_msgs = ("  ‚úì Dataset is well balanced between normal and abnormal cases",)
else:
    balance_msgs = (
        "  ‚ÑπÔ∏è Dataset has moderate class imbalance",
        "     Consider monitoring per-class metrics during training",
    )
for msg in balance_msgs:
    print(msg)

print("\n‚úì Dataset analysis complete")

## Section 6: Model Training

Fine-tune YOLOv11s with:
- Base model: YOLOv11s pretrained weights (`yolo11s.pt` in Ultralytics naming)
- Epochs: 50
- Batch size: 16
- Image size: 640
- Early stopping patience: 10

**Note**: Implemented in tasks T066-T070.

In [None]:
# Initialize Weights & Biases for experiment tracking

import wandb

# Hyperparameters attached to the W&B run. This cell only logs them; it does
# not pass them to model.train().
wandb_run_config = {
    "model": "YOLOv11s",
    "dataset": "VinBigData Chest X-ray v3",
    "epochs": 50,
    "batch_size": 16,
    "image_size": 640,
    "patience": 10,
    "optimizer": "AdamW",
    "learning_rate": 0.001,
    "augmentation": "enabled",
    "preprocessing": "histogram_eq + gaussian_blur",
}

wandb.init(project="chest-xray-detection", name="yolov11s-finetune", config=wandb_run_config)

active_run = wandb.run
print("‚úì WandB initialized successfully")
print(f"  Project: chest-xray-detection")
print(f"  Run name: {active_run.name}")
print(f"  Run URL: {active_run.url}")
print("\nüìä Hyperparameters logged:")
for param_name, param_value in wandb.config.items():
    print(f"  - {param_name}: {param_value}")

In [None]:
# Train YOLOv11s model with integrated progress tracking and logging

import torch
from ultralytics import YOLO

# Training configuration.
# optimizer/lr0 are set explicitly so the run matches the hyperparameters logged
# to WandB (AdamW, 0.001); with Ultralytics' default optimizer='auto' the
# optimizer and learning rate may be chosen automatically instead.
training_config = {
    'data': str(Path(dataset.location) / 'data.yaml'),
    'epochs': 50,
    'batch': 16,
    'imgsz': 640,
    'patience': 10,
    'optimizer': 'AdamW',
    'lr0': 0.001,
    'save': True,
    'plots': True,
    'verbose': True,
    # Ultralytics takes a CUDA device index (or 'cpu'); ask torch directly
    # rather than shelling out to nvidia-smi.
    'device': 0 if torch.cuda.is_available() else 'cpu',
}

print("üöÄ Starting YOLOv11s Training")
print("=" * 80)
print("\n‚öôÔ∏è Training Configuration:")
for key, value in training_config.items():
    print(f"  {key:15s}: {value}")
print("=" * 80)

# Check device
device_info = training_config['device']
if device_info != 'cpu':
    print("\nüéÆ GPU detected: Training will use CUDA acceleration")
else:
    print("\nüíª No GPU detected: Training will use CPU (slower)")

# Load base YOLOv11s model. Ultralytics publishes it as 'yolo11s.pt' (no "v");
# 'yolov11s.pt' is not a downloadable asset name.
print("\nüì¶ Loading base YOLOv11s model...")
model = YOLO('yolo11s.pt')

print("‚úì Model loaded successfully")
print(f"  Model architecture: YOLOv11s")
print(f"  Parameters: ~{sum(p.numel() for p in model.model.parameters()) / 1e6:.1f}M")

# Train the model with WandB integration.
# NOTE(review): depending on the Ultralytics version, its W&B callback may also
# need `yolo settings wandb=True` to be active — confirm for the pinned version.
print("\nüèãÔ∏è Starting training...")
print("üìä Progress will be tracked with tqdm and logged to WandB")
print("-" * 80)

try:
    results = model.train(
        **training_config,
        project='chest-xray-detection',
        name='yolov11s-finetune',
    )

    print("\n" + "=" * 80)
    print("‚úì Training completed successfully!")
    print("=" * 80)

    # Guard the lookup: the returned metrics object may be missing (None)
    metrics_dict = getattr(results, 'results_dict', None) or {}

    # Display training results
    print("\nüìà Training Results:")
    print(f"  Best epoch: {results.best_epoch if hasattr(results, 'best_epoch') else 'N/A'}")
    print(f"  Best mAP50: {metrics_dict.get('metrics/mAP50(B)', 'N/A')}")
    print(f"  Best mAP50-95: {metrics_dict.get('metrics/mAP50-95(B)', 'N/A')}")
    print(f"  Final loss: {metrics_dict.get('train/box_loss', 'N/A')}")

    # The trainer records where it wrote the best checkpoint
    best_model_path = Path(model.trainer.best)
    print(f"\nüíæ Best model saved to: {best_model_path}")

except Exception as e:
    print(f"\n‚ùå Training failed with error: {e}")
    print("Please check the error message above and try again")
    raise

print("\n‚úì Training phase complete")

## Section 7: Validation and Analysis

Test trained model on validation set:
- Calculate mAP (mean Average Precision)
- Generate confusion matrix
- Visualize predictions with Vietnamese labels

**Note**: Implemented in tasks T071-T074.

In [None]:
# Validate trained model on test set

print("üß™ Model Validation on Test Set")
print("=" * 80)

# Load the fine-tuned model
if best_model_path.exists():
    print(f"üì¶ Loading fine-tuned model: {best_model_path}")
    model = YOLO(str(best_model_path))
    print("‚úì Model loaded successfully")
else:
    print("‚ö†Ô∏è WARNING: Fine-tuned model not found, using last trained model")
    model = YOLO('runs/detect/train/weights/best.pt')

# Run validation on test set
test_data_yaml = Path(dataset.location) / 'data.yaml'

print(f"\nüîç Running validation on test set...")
print(f"üìÑ Data config: {test_data_yaml}")
print("-" * 80)

try:
    # Validate the model
    metrics = model.val(data=str(test_data_yaml), split='test')
    
    print("\n" + "=" * 80)
    print("üìä Validation Results:")
    print("=" * 80)
    
    # Display key metrics
    results_dict = metrics.results_dict
    
    print("\nüéØ Overall Metrics:")
    print(f"  mAP50:       {results_dict.get('metrics/mAP50(B)', 0):.4f}")
    print(f"  mAP50-95:    {results_dict.get('metrics/mAP50-95(B)', 0):.4f}")
    print(f"  Precision:   {results_dict.get('metrics/precision(B)', 0):.4f}")
    print(f"  Recall:      {results_dict.get('metrics/recall(B)', 0):.4f}")
    
    # Per-class metrics if available
    if hasattr(metrics, 'box'):
        print("\nüìã Per-Class Metrics:")
        print(f"  {'Class':<25s} {'Precision':>10s} {'Recall':>10s} {'mAP50':>10s}")
        print("  " + "-" * 60)
        
        # Get class names
        class_names = model.names
        
        # Display metrics for each class
        for class_id, class_name in class_names.items():
            if hasattr(metrics.box, 'class_result'):
                try:
                    class_metrics = metrics.box.class_result(class_id)
                    p = class_metrics[0] if len(class_metrics) > 0 else 0
                    r = class_metrics[1] if len(class_metrics) > 1 else 0
                    ap50 = class_metrics[2] if len(class_metrics) > 2 else 0
                    print(f"  {class_name:<25s} {p:>10.4f} {r:>10.4f} {ap50:>10.4f}")
                except:
                    pass
    
    # Confusion matrix location
    confusion_matrix_path = Path('runs/detect/val/confusion_matrix.png')
    if confusion_matrix_path.exists():
        print(f"\nüìà Confusion matrix saved to: {confusion_matrix_path}")
    
    print("\n‚úì Validation complete")
    
except Exception as e:
    print(f"\n‚ùå Validation failed: {e}")
    print("   Please ensure training was completed successfully")
    raise

print("=" * 80)

In [None]:
# Verify Vietnamese label mapping in predictions

print("üåê Vietnamese Label Verification")
print("=" * 80)

# Load class mapping
with open('configs/class_mapping.json', 'r', encoding='utf-8') as f:
    class_mapping = json.load(f)

print("\nüìã Verifying model predictions use correct Vietnamese labels...")

# Get model's class names
model_classes = model.names

print(f"\n‚úì Model has {len(model_classes)} classes")
print("\nüîç Class Mapping Verification:")
print("-" * 80)
print(f"{'Class ID':<10s} {'English Name':<30s} {'Vietnamese Name':<30s} {'Status':<10s}")
print("-" * 80)

# Verify each class
mapping_errors = []
for class_id, class_name_en in model_classes.items():
    # Check if English name has Vietnamese mapping
    if class_name_en in class_mapping:
        class_name_vi = class_mapping[class_name_en]
        status = "‚úì OK"
    else:
        class_name_vi = "MISSING"
        status = "‚úó ERROR"
        mapping_errors.append(class_name_en)
    
    print(f"{class_id:<10d} {class_name_en:<30s} {class_name_vi:<30s} {status:<10s}")

print("-" * 80)

if mapping_errors:
    print(f"\n‚ùå Found {len(mapping_errors)} classes without Vietnamese mapping:")
    for cls in mapping_errors:
        print(f"   - {cls}")
    print("\n‚ö†Ô∏è WARNING: These classes will display in English in the application")
    print("   Update configs/class_mapping.json to add missing translations")
else:
    print("\n‚úì All classes have Vietnamese translations!")
    print("  Predictions will display correctly in the web application")

# Create reverse mapping (Vietnamese to English) for validation
reverse_mapping = {vi: en for en, vi in class_mapping.items()}

print(f"\nüìä Mapping Statistics:")
print(f"  Total classes in model:  {len(model_classes)}")
print(f"  Classes with Vietnamese: {len(model_classes) - len(mapping_errors)}")
print(f"  Classes missing Vietnamese: {len(mapping_errors)}")
print(f"  Coverage: {(len(model_classes) - len(mapping_errors)) / len(model_classes) * 100:.1f}%")

# Test prediction with Vietnamese labels (if model is loaded)
print("\nüß™ Testing prediction with Vietnamese labels...")
test_images_dir = Path(dataset.location) / 'test' / 'images'
if test_images_dir.exists():
    test_images = list(test_images_dir.glob('*.jpg'))[:1]  # Test with first image
    
    if test_images:
        test_img = test_images[0]
        print(f"   Test image: {test_img.name}")
        
        # Run prediction
        results = model(test_img, verbose=False)
        
        if len(results) > 0 and len(results[0].boxes) > 0:
            print(f"   ‚úì Found {len(results[0].boxes)} detections")
            
            # Display detections with Vietnamese labels
            for box in results[0].boxes[:3]:  # Show first 3
                class_id = int(box.cls[0])
                confidence = float(box.conf[0])
                class_name_en = model_classes[class_id]
                class_name_vi = class_mapping.get(class_name_en, class_name_en)
                
                print(f"     - {class_name_vi} ({class_name_en}): {confidence:.2%}")
        else:
            print("   ‚úì No detections (image is normal)")
    else:
        print("   ‚ö†Ô∏è No test images found")
else:
    print("   ‚ö†Ô∏è Test images directory not found")

print("\n" + "=" * 80)
print("‚úì Vietnamese label verification complete")
print("=" * 80)

In [None]:
# Visualize sample predictions with Vietnamese labels

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Confidence cut-offs for drawing boxes. Boxes at or below LOW_CONF_THRESHOLD
# are not drawn; the legend printed at the end is derived from these values.
HIGH_CONF_THRESHOLD = 0.7
LOW_CONF_THRESHOLD = 0.4
MAX_SAMPLES = 6  # the 2x3 subplot grid below holds six images


def _box_style(confidence):
    """Return (color, linestyle, label_prefix) for a detection, or None to skip it."""
    if confidence > HIGH_CONF_THRESHOLD:
        return 'red', '-', '🔴'      # solid red: high confidence
    if confidence > LOW_CONF_THRESHOLD:
        return 'orange', '--', '🟠'  # dashed orange: medium confidence
    return None                       # low confidence: not drawn


def _draw_detections(ax, boxes):
    """Draw every box in *boxes* that passes the confidence cut-off onto *ax*.

    Each drawn box is labelled with its Vietnamese class name (the English
    name when no translation exists) and its confidence.

    Returns:
        The number of boxes actually drawn.
    """
    drawn = 0
    for box in boxes:
        confidence = float(box.conf[0])
        style = _box_style(confidence)
        if style is None:
            continue
        color, linestyle, label_prefix = style

        # Box corners in pixel coordinates (xyxy format).
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        class_name_en = model_classes[int(box.cls[0])]
        class_name_vi = class_mapping.get(class_name_en, class_name_en)

        ax.add_patch(patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2,
            edgecolor=color,
            facecolor='none',
            linestyle=linestyle,
        ))
        ax.text(
            x1, y1 - 10,
            f"{label_prefix} {class_name_vi}\n{confidence:.1%}",
            color='white',
            fontsize=9,
            bbox=dict(
                boxstyle='round,pad=0.5',
                facecolor=color,
                alpha=0.8,
                edgecolor='none'
            ),
            verticalalignment='bottom'
        )
        drawn += 1
    return drawn


print("🖼️ Sample Prediction Visualization")
print("=" * 80)

test_images_dir = Path(dataset.location) / 'test' / 'images'

if not test_images_dir.exists():
    print("⚠️ Test images directory not found")
else:
    # Sorted so the same samples are shown on every run.
    test_images = sorted(test_images_dir.glob('*.jpg'))[:MAX_SAMPLES]

    if not test_images:
        print("⚠️ No test images found")
    else:
        print(f"📊 Visualizing predictions for {len(test_images)} sample images...")

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()

        # zip stops at the shorter sequence, so at most len(axes) images are drawn.
        for ax, img_path in zip(axes, test_images):
            # Read the pixels inside a context manager so the file handle is closed.
            with Image.open(img_path) as img:
                img_array = np.array(img)

            results = model(img_path, verbose=False)

            ax.imshow(img_array, cmap='gray')
            ax.set_title(f"{img_path.name}", fontsize=10, pad=10)
            ax.axis('off')

            if len(results) > 0 and len(results[0].boxes) > 0:
                drawn_count = _draw_detections(ax, results[0].boxes)

                # Count only the boxes actually drawn; low-confidence ones are hidden.
                ax.text(
                    10, img_array.shape[0] - 10,
                    f"Phát hiện: {drawn_count} bất thường",
                    color='white',
                    fontsize=10,
                    bbox=dict(
                        boxstyle='round,pad=0.5',
                        facecolor='blue',
                        alpha=0.7,
                        edgecolor='none'
                    ),
                    verticalalignment='bottom'
                )
            else:
                # Normal image - the model returned no detections at all.
                ax.text(
                    img_array.shape[1] // 2, img_array.shape[0] // 2,
                    "✓ Bình thường\n(Không phát hiện bất thường)",
                    color='white',
                    fontsize=12,
                    ha='center',
                    va='center',
                    bbox=dict(
                        boxstyle='round,pad=1',
                        facecolor='green',
                        alpha=0.8,
                        edgecolor='none'
                    )
                )

        # Blank out grid cells left empty when fewer than six images exist.
        for ax in axes[len(test_images):]:
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('sample_predictions_vi.png', dpi=150, bbox_inches='tight')
        print(f"\n✓ Visualization saved to: sample_predictions_vi.png")
        plt.show()

        # Legend text is derived from the same thresholds used for drawing.
        print("\n📖 Legend:")
        print(f"  🔴 Solid red box    = High confidence (>{HIGH_CONF_THRESHOLD:.0%})")
        print(f"  🟠 Dashed orange box = Medium confidence "
              f"({LOW_CONF_THRESHOLD * 100:.0f}-{HIGH_CONF_THRESHOLD:.0%})")
        print("  ✓ Green overlay     = Normal (no abnormalities)")

        print("\n💡 All labels are displayed in Vietnamese as configured")
        print("   This matches the production behavior in the web application")

print("\n" + "=" * 80)
print("✓ Sample prediction visualization complete")
print("=" * 80)

In [None]:
# Compare base YOLO11s vs fine-tuned model performance

print("⚖️ Model Comparison: Base vs Fine-Tuned")
print("=" * 80)

print("\n📦 Loading models for comparison...")

# Fine-tuned model (already loaded by an earlier cell).
finetuned_model = model
# target_model_path is assigned in the export cell (Section 8), which normally
# runs after this one; fall back to a label instead of raising NameError.
print(f"✓ Fine-tuned model: {globals().get('target_model_path', 'model loaded in this session')}")

# Ultralytics publishes YOLO11 weights as 'yolo11s.pt' (no "v");
# 'yolov11s.pt' is not a released asset and cannot be downloaded.
BASE_WEIGHTS = 'yolo11s.pt'
try:
    base_model = YOLO(BASE_WEIGHTS)
    print(f"✓ Base model: {BASE_WEIGHTS} (pretrained on COCO)")
except Exception as e:
    # Keep going so the fine-tuned metrics are still reported.
    base_model = None
    print(f"❌ Could not load base model {BASE_WEIGHTS}: {e}")

# NOTE(review): the base model predicts COCO class ids, which do not line up
# with this dataset's classes, so its test-set scores are expected to be near
# zero. The comparison mainly confirms that fine-tuning produced a usable
# detector rather than measuring a like-for-like improvement.

test_data_yaml = Path(dataset.location) / 'data.yaml'

# Ultralytics results_dict keys, in the order (mAP50, mAP50-95, precision, recall).
_EVAL_METRIC_KEYS = (
    'metrics/mAP50(B)',
    'metrics/mAP50-95(B)',
    'metrics/precision(B)',
    'metrics/recall(B)',
)


def _evaluate_on_test(yolo_model, data_yaml):
    """Validate *yolo_model* on the 'test' split of *data_yaml* and print its metrics.

    Returns:
        (mAP50, mAP50-95, precision, recall). All zeros when the model is
        missing or validation raises, so the report below can still run.
    """
    if yolo_model is None:
        print("   ❌ Evaluation failed: model not loaded")
        return 0, 0, 0, 0
    try:
        results = yolo_model.val(data=str(data_yaml), split='test', verbose=False).results_dict
    except Exception as e:
        print(f"   ❌ Evaluation failed: {e}")
        return 0, 0, 0, 0

    map50, map50_95, precision, recall = (results.get(key, 0) for key in _EVAL_METRIC_KEYS)
    print(f"   mAP50:       {map50:.4f}")
    print(f"   mAP50-95:    {map50_95:.4f}")
    print(f"   Precision:   {precision:.4f}")
    print(f"   Recall:      {recall:.4f}")
    return map50, map50_95, precision, recall


print("\n🔬 Evaluating both models on test set...")
print("-" * 80)

print("\n1️⃣ Fine-tuned Model Evaluation:")
(finetuned_map50, finetuned_map50_95,
 finetuned_precision, finetuned_recall) = _evaluate_on_test(finetuned_model, test_data_yaml)

print("\n2️⃣ Base Model Evaluation:")
(base_map50, base_map50_95,
 base_precision, base_recall) = _evaluate_on_test(base_model, test_data_yaml)

# Calculate improvements
print("\n" + "=" * 80)
print("📊 Performance Comparison:")
print("=" * 80)


def calc_improvement(base, finetuned):
    """Return the relative change from *base* to *finetuned*, in percent.

    Returns 0 when *base* is 0 because a relative change is undefined there;
    the report below shows such metrics as "n/a" rather than "no change".
    """
    if base == 0:
        return 0
    return ((finetuned - base) / base) * 100


# (metric name, base score, fine-tuned score), in display order.
metric_rows = [
    ('mAP50', base_map50, finetuned_map50),
    ('mAP50-95', base_map50_95, finetuned_map50_95),
    ('Precision', base_precision, finetuned_precision),
    ('Recall', base_recall, finetuned_recall),
]

improvements = {name: calc_improvement(base, tuned) for name, base, tuned in metric_rows}

# Only metrics with a non-zero base score have a meaningful relative change;
# averaging in the placeholder 0s would understate the fine-tuned model.
defined_improvements = [improvements[name] for name, base, _ in metric_rows if base != 0]
relative_comparison_available = bool(defined_improvements)
avg_improvement = (
    sum(defined_improvements) / len(defined_improvements)
    if relative_comparison_available else 0.0
)

# Table: relative change (n/a when the base score is 0) plus the absolute
# difference, which stays informative when base scores are near zero.
print(f"\n{'Metric':<15s} {'Base':>12s} {'Fine-tuned':>12s} {'Improvement':>15s} {'Abs. change':>12s}")
print("-" * 70)
for name, base, tuned in metric_rows:
    relative = f"{improvements[name]:>14.1f}%" if base != 0 else f"{'n/a':>15s}"
    print(f"{name:<15s} {base:>12.4f} {tuned:>12.4f} {relative} {tuned - base:>+12.4f}")
print("-" * 70)

# Overall assessment
if not relative_comparison_available:
    print("\n⚠️ Relative improvement is undefined: every base-model score is 0")
    print("   (base evaluation failed or the base model does not detect these classes)")
    print("   Judge the fine-tuned model by its absolute scores in the table above")
else:
    print(f"\n📈 Average Improvement: {avg_improvement:+.1f}%")

    if avg_improvement > 20:
        print("\n🎉 EXCELLENT: Fine-tuning significantly improved model performance!")
        print("   The model is well-suited for chest X-ray abnormality detection")
    elif avg_improvement > 10:
        print("\n✓ GOOD: Fine-tuning improved model performance")
        print("   The model shows better accuracy on the target dataset")
    elif avg_improvement > 0:
        print("\n✓ MODERATE: Fine-tuning provided some improvement")
        print("   Consider additional training epochs or data augmentation")
    else:
        print("\n⚠️ WARNING: Fine-tuning did not improve performance")
        print("   Check training logs, hyperparameters, or dataset quality")

# Visualize comparison
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

metrics_names = [name for name, _, _ in metric_rows]
base_values = [base for _, base, _ in metric_rows]
finetuned_values = [tuned for _, _, tuned in metric_rows]

x = np.arange(len(metrics_names))
width = 0.35

bars1 = ax.bar(x - width/2, base_values, width, label='Base YOLOv11s', color='lightblue')
bars2 = ax.bar(x + width/2, finetuned_values, width, label='Fine-tuned', color='darkblue')

ax.set_xlabel('Metrics', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Base vs Fine-tuned Model Performance', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metrics_names)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Add value labels on top of each bar.
for bars in (bars1, bars2):
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.3f}',
                ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('model_comparison.png', dpi=150, bbox_inches='tight')
print(f"\n✓ Comparison chart saved to: model_comparison.png")
plt.show()

print("\n" + "=" * 80)
print("✓ Model comparison complete")
print("=" * 80)

print("\n🎯 Recommendation:")
if relative_comparison_available and avg_improvement > 10:
    print("   ✓ Use the fine-tuned model in production")
    # target_model_path is assigned in the export cell, which may not have run yet.
    print(f"   ✓ Model path: {globals().get('target_model_path', 'see Section 8: Model Export')}")
elif not relative_comparison_available:
    print("   ⚠️ No usable baseline - decide based on the fine-tuned model's absolute metrics")
else:
    print("   ⚠️ Review training configuration and dataset quality")
    print("   ⚠️ Consider longer training or different hyperparameters")

## Section 8: Model Export

Copy the best trained weights (`best.pt`) into the backend models directory for production use.

**Output path**: `backend/models/yolov11s_finetuned.pt`

In [None]:
# Export best model weights to backend for production use
import filecmp

# Section 1 changes the working directory to the repository root, where the
# backend lives at ./backend ('../backend' would resolve outside the repo).
# Fall back to ../backend only when ./backend is not found, e.g. if the kernel
# is still in the notebook directory.
backend_root = Path('backend') if Path('backend').is_dir() else Path('..') / 'backend'
backend_models_dir = backend_root / 'models'
backend_models_dir.mkdir(parents=True, exist_ok=True)

target_model_path = backend_models_dir / 'yolov11s_finetuned.pt'

print("📦 Exporting Model Weights")
print("=" * 80)

# best_model_path is set by the training cell when it ran in this session.
if 'best_model_path' not in globals():
    print("⚠️ WARNING: Training results not found in this session")
    print("   Attempting to find best model from runs directory...")

    runs_dir = Path('runs/detect')
    if runs_dir.exists():
        # Only folders that contain trained weights qualify: validation runs
        # (e.g. from the .val() calls in the comparison cell) may also create
        # folders under runs/detect, and those would be the newest.
        candidates = sorted(
            runs_dir.glob('*/weights/best.pt'),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if candidates:
            best_model_path = candidates[0]
            print(f"   Found: {best_model_path}")
        else:
            print("❌ No training runs found. Please run training first.")
            best_model_path = None
    else:
        print("❌ No runs directory found. Please run training first.")
        best_model_path = None
else:
    print(f"✓ Using best model from training: {best_model_path}")

# Copy model to backend
copied = False
if best_model_path and Path(best_model_path).exists():
    print(f"\n📂 Source: {best_model_path}")
    print(f"📂 Target: {target_model_path}")

    shutil.copy(best_model_path, target_model_path)
    copied = True

    if target_model_path.exists():
        target_size = target_model_path.stat().st_size

        print(f"\n✓ Model exported successfully!")
        print(f"  File size: {target_size / (1024*1024):.2f} MB")

        # Byte-for-byte comparison: equal sizes alone do not prove an intact copy.
        if filecmp.cmp(best_model_path, target_model_path, shallow=False):
            print("  ✓ Content verified: File copied correctly")
        else:
            source_size = Path(best_model_path).stat().st_size
            print(f"  ⚠️ WARNING: Copied file differs from source "
                  f"(source: {source_size} bytes, target: {target_size} bytes)")

        print(f"\n🎯 Model is ready for production use!")
        print(f"   Backend can now load weights from: {target_model_path}")
        print(f"   Update backend.src.models.yolo_detector.py to use this path")
    else:
        print("❌ Error: Model export failed - file not found at target location")
else:
    print("\n❌ Error: Cannot export model - best weights file not found")
    print("   Please complete training first")

# Test model loading (optional). Only load the file this cell just wrote;
# otherwise a stale copy from an earlier run could be reported as a success.
if copied and target_model_path.exists():
    try:
        print("\n🔍 Testing model loading...")
        test_model = YOLO(str(target_model_path))
        print("✓ Model loads successfully in Ultralytics YOLO")
        print(f"  Model type: {type(test_model.model).__name__}")
        print(f"  Number of classes: {len(test_model.names)}")
        print(f"  Class names: {list(test_model.names.values())[:5]}...")  # first 5 classes
    except Exception as e:
        print(f"⚠️ Warning: Model loading test failed: {e}")
else:
    print("\n⚠️ Skipping model loading test: no model was exported in this run")

print("\n" + "=" * 80)
print("✓ Model export complete")
print("=" * 80)

# Close WandB run
wandb.finish()
print("\n✓ WandB run finished")