# YOLOv8 Malaria Detection Training
## Clinical Parasite Detection in Microscope Images

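This notebook trains a YOLOv8 model for malaria parasite detection using the Kaggle "Cell Images for Detecting Malaria" dataset (`iarunava/cell-images-for-detecting-malaria`). The dataset contains single-cell crops labeled only as `Parasitized` or `Uninfected`, so the bounding boxes used for training are generated heuristically around each parasitized cell by contour detection; uninfected cells are used as background images with empty label files.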

## 1. Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

In [None]:
# Install dependencies
!pip install ultralytics kagglehub wandb -q
!pip install opencv-python matplotlib seaborn scikit-learn -q

In [None]:
# Import libraries
import os
import shutil
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import YOLO
import kagglehub
import yaml
from PIL import Image
import torch

# Seed the Python and NumPy global RNGs, then PyTorch, with the same value (42).
for seed_fn in (random.seed, np.random.seed):
    seed_fn(42)
torch.manual_seed(42)

## 2. Download and Prepare Dataset

In [None]:
# Fetch the Kaggle malaria cell-image dataset through kagglehub and keep the
# returned location in `kaggle_path` for the conversion cells below.
print("Downloading Kaggle malaria dataset...")
dataset_handle = "iarunava/cell-images-for-detecting-malaria"
kaggle_path = kagglehub.dataset_download(dataset_handle)
print(f"Dataset downloaded to: {kaggle_path}")

In [None]:
def generate_cell_bbox(image_path, padding_ratio=0.1):
    """Estimate a YOLO-format bounding box around the cell in an image.

    The image is converted to grayscale, blurred, and Otsu-thresholded. The
    bounding rectangle of the largest external contour is taken as the cell,
    grown by ``padding_ratio`` of its own size on every side, and clipped to
    the image borders.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        padding_ratio: Fraction of the detected box width/height that is
            added as padding on each side.

    Returns:
        A tuple ``(center_x, center_y, width, height)`` with every value
        divided by the image width/height. When the image cannot be read or
        no contour is found, the fixed centered box ``(0.5, 0.5, 0.8, 0.8)``
        is returned instead.
    """
    fallback_bbox = (0.5, 0.5, 0.8, 0.8)

    image = cv2.imread(str(image_path))
    if image is None:
        return fallback_bbox

    h, w = image.shape[:2]
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return fallback_bbox

    largest_contour = max(contours, key=cv2.contourArea)
    x, y, bbox_w, bbox_h = cv2.boundingRect(largest_contour)

    pad_x = int(bbox_w * padding_ratio)
    pad_y = int(bbox_h * padding_ratio)

    # Clamp each edge on its own. Deriving the padded width from the already
    # clamped left/top corner would push the padding that was cut off at the
    # border onto the opposite side and over-extend the box there.
    x1 = max(0, x - pad_x)
    y1 = max(0, y - pad_y)
    x2 = min(w, x + bbox_w + pad_x)
    y2 = min(h, y + bbox_h + pad_y)

    center_x = (x1 + x2) / 2 / w
    center_y = (y1 + y2) / 2 / h
    norm_width = (x2 - x1) / w
    norm_height = (y2 - y1) / h

    return (center_x, center_y, norm_width, norm_height)

In [None]:
# Build the <split>/images and <split>/labels folders that the later cells
# fill and that data.yaml points at. Existing folders are left untouched.
yolo_path = Path("yolo_malaria")
for split in ['train', 'val', 'test']:
    for subdir in ("images", "labels"):
        (yolo_path / split / subdir).mkdir(parents=True, exist_ok=True)

print(f"Created YOLOv8 structure at: {yolo_path}")

In [None]:
# Convert dataset to YOLO format: locate the class folders, list every image
# with its class id, and cut the list into train/val/test.
kaggle_path = Path(kaggle_path)
cell_images_path = None

# Use the first directory that directly contains both class folders.
for root, dirs, files in os.walk(kaggle_path):
    dirs.sort()  # visit subdirectories in a stable order on every filesystem
    if 'Parasitized' in dirs and 'Uninfected' in dirs:
        cell_images_path = Path(root)
        break

if cell_images_path is None:
    # Fail with a clear message instead of a TypeError on `None / class_name`.
    raise FileNotFoundError(
        f"No directory containing both 'Parasitized' and 'Uninfected' found under {kaggle_path}"
    )

print(f"Found cell images at: {cell_images_path}")

# Process images
all_files = []
for class_name in ['Parasitized', 'Uninfected']:
    class_path = cell_images_path / class_name
    # glob() order depends on the filesystem; sorting makes the seeded
    # shuffle below produce the same split everywhere.
    class_files = sorted(class_path.glob('*.png'))
    print(f"Found {len(class_files)} images in {class_name}")

    # Parasitized images get class 0; uninfected images carry no class
    # (they become background images with empty label files).
    class_id = 0 if class_name == 'Parasitized' else None
    all_files.extend((img_path, class_id) for img_path in class_files)

if not all_files:
    raise RuntimeError(f"No .png images found under {cell_images_path}")

# Split dataset (70% train / 20% val / 10% test).
# A dedicated seeded generator keeps the split identical when this cell is
# re-run, regardless of how much the global RNG has been used in between.
random.Random(42).shuffle(all_files)
total = len(all_files)
train_end = int(total * 0.7)
val_end = int(total * 0.9)

splits = {
    'train': all_files[:train_end],
    'val': all_files[train_end:val_end],
    'test': all_files[val_end:]
}

print(f"Dataset split: Train={len(splits['train'])}, Val={len(splits['val'])}, Test={len(splits['test'])}")

In [None]:
# Convert and copy files: every image is copied under a sequential name and
# gets a label file with the same stem.
for split_name, files in splits.items():
    print(f"Processing {split_name} split...")

    images_dir = yolo_path / split_name / "images"
    labels_dir = yolo_path / split_name / "labels"

    for i, (img_path, class_id) in enumerate(files):
        if i % 1000 == 0:
            print(f"  Processed {i}/{len(files)} images")

        stem = f"{split_name}_{i:06d}"
        new_img_path = images_dir / f"{stem}.png"
        shutil.copy2(img_path, new_img_path)

        label_path = labels_dir / f"{stem}.txt"

        if class_id is not None:  # Parasitized: one box line per image
            bbox = generate_cell_bbox(str(img_path))
            label_text = f"{class_id} {bbox[0]:.6f} {bbox[1]:.6f} {bbox[2]:.6f} {bbox[3]:.6f}\n"
        else:  # Uninfected: empty label file
            label_text = ""

        # Always (re)write the file. touch() would keep the old content of an
        # existing file, so a previous run that placed a parasitized image at
        # this index would leave a stale box on an uninfected image.
        label_path.write_text(label_text)

    print(f"Completed {split_name}: {len(files)} images")

In [None]:
# Create data.yaml describing the dataset root, the three image folders and
# the single class.
data_config = {
    'path': str(yolo_path.absolute()),
    'train': 'train/images',
    'val': 'val/images',
    'test': 'test/images',
    'nc': 1,
    'names': ['malaria_parasite'],
}

# Serialize with the YAML library rather than string interpolation so that a
# dataset path containing spaces, '#', ':' or quotes is quoted correctly.
# sort_keys=False keeps the keys in the order written above.
yaml_content = yaml.safe_dump(data_config, sort_keys=False)

yaml_path = yolo_path / "malaria_data.yaml"
with open(yaml_path, 'w') as f:
    f.write(yaml_content)

print(f"Created data.yaml at: {yaml_path}")

## 3. Visualize Dataset Samples

In [None]:
# Visualize samples: show up to six training images and overlay the boxes
# stored in their label files.
import matplotlib.patches as patches

# Sorted so the seeded sampler picks the same files on every filesystem.
train_images = sorted((yolo_path / "train" / "images").glob("*.png"))
# random.sample raises if asked for more items than exist; cap the request.
samples = random.sample(train_images, min(6, len(train_images)))

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Hide every panel's axes up front so unused panels stay blank when fewer
# than six images are available.
for ax in axes:
    ax.axis('off')

for i, img_path in enumerate(samples):
    # The context manager closes the underlying file once the pixels have
    # been handed to matplotlib.
    with Image.open(img_path) as image:
        w, h = image.size
        axes[i].imshow(image)

    axes[i].set_title(f"Sample {i+1}")

    label_path = yolo_path / "train" / "labels" / f"{img_path.stem}.txt"
    if not label_path.exists() or label_path.stat().st_size == 0:
        continue  # background image: nothing to draw

    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue  # blank or malformed line
            center_x, center_y, width, height = map(float, parts[1:5])

            # Label values are fractions of the image size (center + extent);
            # convert to a top-left corner and size in pixels.
            x = (center_x - width/2) * w
            y = (center_y - height/2) * h
            box_w = width * w
            box_h = height * h

            rect = patches.Rectangle((x, y), box_w, box_h,
                                     linewidth=2, edgecolor='red', facecolor='none')
            axes[i].add_patch(rect)
            axes[i].text(x, y-5, 'Parasite', color='red', fontsize=10, weight='bold')

plt.tight_layout()
plt.show()

## 4. Initialize YOLOv8 Model

In [None]:
# Load the YOLOv8 nano checkpoint and report how many parameters it has.
model = YOLO('yolov8n.pt')
print("YOLOv8 model loaded successfully")
param_count = sum(tensor.numel() for tensor in model.model.parameters())
print(f"Model parameters: {param_count:,}")

## 5. Configure Training Parameters

In [None]:
# Training configuration
# Only arguments that influence training and its built-in validation are
# listed. The previous version also carried predict/visualize/export settings
# (source, show, format, opset, ...), which have no effect on `model.train`,
# plus `batch_size`, which is not an Ultralytics argument (the batch size is
# `batch`) and, as far as I recall, makes argument validation fail before
# training starts.
# NOTE(review): `boxes` and `save_hybrid` were also dropped because their
# availability depends on the installed Ultralytics version - confirm against
# the pinned release.
training_config = {
    # Data and schedule
    'data': str(yaml_path),
    'epochs': 50,  # Reduced for Colab
    'batch': 16,   # Adjust based on GPU memory
    'imgsz': 640,
    'fraction': 1.0,  # Use full dataset
    'patience': 15,
    'resume': False,

    # Optimizer
    'optimizer': 'AdamW',
    'lr0': 0.01,
    'weight_decay': 0.0005,
    'cos_lr': True,  # Cosine learning rate scheduler
    'amp': True,  # Automatic Mixed Precision
    'dropout': 0.0,

    # Model and augmentation behaviour
    'pretrained': True,
    'freeze': None,
    'single_cls': True,  # Single class detection
    'rect': False,  # Disable rectangular training for better augmentation
    'multi_scale': False,
    'close_mosaic': 10,  # Disable mosaic augmentation in last 10 epochs

    # Runtime
    'device': 0 if torch.cuda.is_available() else 'cpu',
    'workers': 2,  # Reduced for Colab
    'seed': 42,
    'deterministic': True,
    'profile': False,
    'verbose': True,

    # Output location and checkpoints
    'project': 'malaria_detection',
    'name': 'yolov8n_malaria',
    'exist_ok': True,
    'save_period': 10,

    # Validation performed during training
    'val': True,
    'iou': 0.7,
    'max_det': 300,
    'save_json': True,
    'plots': True,
}

# Echo the handful of settings most worth double-checking before a long run.
print("Training configuration:")
for key in ('data', 'epochs', 'batch', 'imgsz', 'lr0', 'device'):
    print(f"  {key}: {training_config[key]}")

## 6. Start Training

In [None]:
# Launch training with `training_config` and report where the best
# checkpoint was written.
print("Starting YOLOv8 training...")
print(f"Training on device: {training_config['device']}")
n_train = len(splits['train'])
n_val = len(splits['val'])
print(f"Dataset: {n_train} train, {n_val} val images")

results = model.train(**training_config)

print("Training completed!")
best_weights = model.trainer.best
print(f"Best model saved at: {best_weights}")

## 7. Evaluate Model

In [None]:
# Reload the best checkpoint recorded by the trainer and score it on the
# `test` split declared in data.yaml.
best_model = YOLO(model.trainer.best)

test_results = best_model.val(data=str(yaml_path), split='test')
box_metrics = test_results.box

print("Test Results:")
print(f"mAP50: {box_metrics.map50:.4f}")
print(f"mAP50-95: {box_metrics.map:.4f}")
print(f"Precision: {box_metrics.mp:.4f}")
print(f"Recall: {box_metrics.mr:.4f}")

## 8. Test Inference

In [None]:
# Test inference on sample images: run the best model on up to six test
# images and draw its detections.
# Sorted so the same six files are shown on every filesystem.
test_images = sorted((yolo_path / "test" / "images").glob("*.png"))[:6]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Hide every panel's axes up front so unused panels stay blank when fewer
# than six test images exist.
for ax in axes:
    ax.axis('off')

for i, img_path in enumerate(test_images):
    # Use a dedicated name: reusing `results` here would overwrite the
    # training results object returned by `model.train`.
    predictions = best_model(str(img_path))
    detections = predictions[0].boxes

    # OpenCV loads BGR; matplotlib expects RGB.
    image = cv2.imread(str(img_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Draw predictions
    if len(detections) > 0:
        boxes = detections.xyxy.cpu().numpy()
        confs = detections.conf.cpu().numpy()

        for box, conf in zip(boxes, confs):
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # Keep the caption inside the image when the box touches the top edge.
            text_y = max(y1 - 10, 12)
            cv2.putText(image, f'Parasite {conf:.2f}', (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    axes[i].imshow(image)
    axes[i].set_title(f"Test {i+1} - {len(detections)} detections")

plt.tight_layout()
plt.show()

## 9. Export Model

In [None]:
# Export model to different formats
print("Exporting model...")

# Export to ONNX (recommended for production).
# `optimize=True` was removed from this call: in Ultralytics it is the
# TorchScript mobile-optimization switch and does not affect ONNX export.
# NOTE(review): if graph simplification was the intent, `simplify=True` is the
# ONNX-specific option - confirm against the installed Ultralytics version.
onnx_path = best_model.export(format='onnx')
print(f"ONNX model exported to: {onnx_path}")

# Export to TorchScript
torchscript_path = best_model.export(format='torchscript')
print(f"TorchScript model exported to: {torchscript_path}")

print("Model export completed!")

## 10. Download Results

In [None]:
# Create download package: bundle the best weights, exported models, data
# config, result plots and metrics into one zip archive.
import zipfile

results_dir = Path(model.trainer.save_dir)

# ZIP_DEFLATED compresses the entries; the ZipFile default (ZIP_STORED) only
# concatenates them uncompressed.
with zipfile.ZipFile('malaria_detection_results.zip', 'w',
                     compression=zipfile.ZIP_DEFLATED) as zipf:
    # Add best model
    zipf.write(model.trainer.best, 'best_model.pt')

    # Add exported models (skipped if an export did not produce a file)
    if os.path.exists(onnx_path):
        zipf.write(onnx_path, 'best_model.onnx')
    if os.path.exists(torchscript_path):
        zipf.write(torchscript_path, 'best_model.torchscript')

    # Add data config
    zipf.write(yaml_path, 'malaria_data.yaml')

    # Add training result plots, in a stable order
    for file in sorted(results_dir.glob('*.png')):
        zipf.write(file, f'results/{file.name}')

    # Add metrics
    metrics_csv = results_dir / 'results.csv'
    if metrics_csv.exists():
        zipf.write(metrics_csv, 'results/training_metrics.csv')

print("Results packaged in: malaria_detection_results.zip")
print("Download this file to use the trained model locally.")

# Display final summary
print("\n=== Training Summary ===")
print("Model: YOLOv8n")
print(f"Dataset: {len(all_files)} total images")
# This is the configured epoch budget, not a duration; early stopping
# (`patience`) may end training sooner.
print(f"Epochs configured: {training_config['epochs']}")
# The value comes from the test-split evaluation, so label it as such.
print(f"Test mAP50: {test_results.box.map50:.4f}")
print(f"Best model: {model.trainer.best}")
print(f"ONNX export: {onnx_path}")