# Worm Detection - YOLOv8 Training

**Dataset**: 2,891 patches (2,457 train / 434 val) from 16 frames  
**Model**: YOLOv8m (medium) - best accuracy/speed balance  
**Hardware**: Kaggle 2x T4 GPU (the training cell below runs on the first GPU only, `device=0`)  
**Estimated Time**: ~15-25 min

## Step 1: Install Ultralytics

In [None]:
# Install the Ultralytics package (imported later as `ultralytics.YOLO`); -q keeps pip's output quiet.
!pip install -q ultralytics

## Step 2: Upload & Extract Dataset

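Upload `worm_dataset_yolov8.zip` as a Kaggle Dataset and attach it to this notebook. The cell below searches `/kaggle/input` for it and expects the data to end up in `/kaggle/working/yolov8_dataset` (containing `data.yaml`, `train/`, and `val/`).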

In [None]:
import os
import shutil

# Locate the dataset among the attached Kaggle inputs and stage it under
# /kaggle/working so later cells can read (and rewrite) it.
# Adjust the names below if your upload is laid out differently.
kaggle_input = '/kaggle/input'
dataset_path = '/kaggle/working/yolov8_dataset'
dataset_dirs = os.listdir(kaggle_input)
print('Available datasets:', dataset_dirs)

for d in dataset_dirs:
    source_dir = os.path.join(kaggle_input, d)
    if not os.path.isdir(source_dir):
        # A stray file directly under the input root cannot be listed; skip it
        # instead of crashing with NotADirectoryError.
        continue

    # Case 1: the dataset is still a zip archive -> unpack it into the working
    # directory (the archive is expected to contain a `yolov8_dataset` folder).
    for zf in os.listdir(source_dir):
        if not zf.endswith('.zip'):
            continue
        print(f'Extracting {zf}...')
        shutil.unpack_archive(os.path.join(source_dir, zf), '/kaggle/working/')

    # Case 2: the dataset was uploaded as a directory (auto-extracted), either
    # directly at the dataset root or nested one level down in `yolov8_dataset`.
    nested_dir = os.path.join(source_dir, 'yolov8_dataset')
    if os.path.exists(os.path.join(source_dir, 'data.yaml')):
        print(f'Found dataset at {source_dir}')
        shutil.copytree(source_dir, dataset_path, dirs_exist_ok=True)
    elif os.path.exists(os.path.join(nested_dir, 'data.yaml')):
        print(f'Found dataset at {nested_dir}')
        shutil.copytree(nested_dir, dataset_path, dirs_exist_ok=True)

# Verify: fail with an actionable message rather than a bare listdir error
# when nothing was staged.
if not os.path.isdir(dataset_path):
    raise FileNotFoundError(
        f'No dataset was staged at {dataset_path}. Looked in {kaggle_input} '
        f'(entries: {dataset_dirs}) for a .zip archive or a folder containing '
        f'data.yaml. Attach the dataset to this notebook or adjust the paths above.'
    )

print(f'\nDataset path: {dataset_path}')
print(f'Contents: {os.listdir(dataset_path)}')
print(f'Train images: {len(os.listdir(os.path.join(dataset_path, "train/images")))}')
print(f'Val images: {len(os.listdir(os.path.join(dataset_path, "val/images")))}')

## Step 3: Verify and Rewrite data.yaml

In [None]:
# Show the data.yaml that came with the dataset before it is replaced.
yaml_path = os.path.join(dataset_path, 'data.yaml')
with open(yaml_path, 'r') as src:
    content = src.read()
print(content)

# Overwrite the file so that `path` points at the staged copy of the dataset.
# The trailing empty entry yields the final newline of the file.
yaml_lines = [
    '# Worm Detection Dataset',
    f'path: {dataset_path}',
    'train: train/images',
    'val: val/images',
    '',
    'nc: 1',
    "names: ['worm']",
    '',
]
with open(yaml_path, 'w') as dst:
    dst.write('\n'.join(yaml_lines))
print('data.yaml updated with correct paths.')

## Step 4: Visualize Sample Data

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Preview up to six training patches with their ground-truth boxes overlaid.
image_dir = os.path.join(dataset_path, 'train/images')
label_dir = os.path.join(dataset_path, 'train/labels')
train_imgs = sorted(os.listdir(image_dir))[:6]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
panels = axes.ravel()
for ax, img_name in zip(panels, train_imgs):
    img = cv2.imread(os.path.join(image_dir, img_name))
    if img is None:
        # cv2.imread returns None for a file it cannot decode; show a
        # placeholder panel instead of crashing in cvtColor.
        ax.set_title(f'{img_name[:20]} (unreadable)', fontsize=10)
        ax.axis('off')
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    # The label file shares the image's stem. splitext handles any image
    # extension (.jpg, .png, ...) and names that contain '.jpg' mid-string.
    lbl_name = os.path.splitext(img_name)[0] + '.txt'
    lbl_path = os.path.join(label_dir, lbl_name)
    if os.path.exists(lbl_path):
        with open(lbl_path) as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    # Blank or truncated row: nothing to draw
                    continue
                # Row layout as read here: <class> <cx> <cy> <w> <h>, with the
                # four geometry values given as fractions of the image size
                # (they are scaled by w and h below).
                cx, cy, bw, bh = (float(v) for v in parts[1:5])
                x1 = int((cx - bw/2) * w)
                y1 = int((cy - bh/2) * h)
                x2 = int((cx + bw/2) * w)
                y2 = int((cy + bh/2) * h)
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

    ax.imshow(img)
    ax.set_title(img_name[:20], fontsize=10)
    ax.axis('off')

# Blank out panels left unused when fewer than six images are available
for ax in panels[len(train_imgs):]:
    ax.axis('off')

plt.suptitle('Sample Training Patches with Labels', fontsize=14)
plt.tight_layout()
plt.savefig('sample_patches.png', dpi=100)
plt.show()

## Step 5: Train YOLOv8

In [None]:
from ultralytics import YOLO

# Load the YOLOv8m (medium) weights.
# Other sizes: yolov8n (nano/fast), yolov8s (small), yolov8l (large), yolov8x (extra-large)
model = YOLO('yolov8m.pt')

# Run schedule, hardware, and output location
run_cfg = {
    'epochs': 100,
    'imgsz': 416,
    'batch': 32,            # T4 can handle batch=32 at 416px
    'patience': 20,         # Early stopping if no improvement for 20 epochs
    'save': True,
    'save_period': 10,      # Save checkpoint every 10 epochs
    'device': 0,            # Use first GPU
    'workers': 4,
    'project': 'worm_detection',
    'name': 'yolov8m_worms',
}

# Augmentation (YOLOv8 built-in)
augment_cfg = {
    'augment': True,
    'hsv_h': 0.015,         # Hue
    'hsv_s': 0.4,           # Saturation
    'hsv_v': 0.3,           # Value
    'degrees': 15,          # Rotation
    'translate': 0.1,       # Translation
    'scale': 0.3,           # Scale
    'fliplr': 0.5,          # Horizontal flip
    'flipud': 0.3,          # Vertical flip
    'mosaic': 1.0,          # Mosaic
    'mixup': 0.1,           # Mixup
}

# Optimizer and learning-rate schedule
optim_cfg = {
    'optimizer': 'AdamW',
    'lr0': 0.001,
    'lrf': 0.01,            # Final LR = lr0 * lrf
    'warmup_epochs': 5,
    'weight_decay': 0.0005,
    'cos_lr': True,         # Cosine LR scheduler
    'close_mosaic': 10,     # Disable mosaic for last 10 epochs
    'amp': True,            # Mixed precision training
}

# Train
results = model.train(data=yaml_path, **run_cfg, **augment_cfg, **optim_cfg)

print('Training complete!')

## Step 6: Evaluate Performance

In [None]:
# Reload the best checkpoint from the training run and score it on the val split.
# NOTE(review): this path assumes the run was written to exactly
# worm_detection/yolov8m_worms — confirm against the training cell's output
# if training was run more than once.
best_model_path = 'worm_detection/yolov8m_worms/weights/best.pt'
model = YOLO(best_model_path)

metrics = model.val(data=yaml_path)

# Print a small fixed-width report: labels are padded to 12 columns
box = metrics.box
divider = '=' * 50
report_rows = (
    ('mAP50:', box.map50),
    ('mAP50-95:', box.map),
    ('Precision:', box.mp),
    ('Recall:', box.mr),
)

print(f"\n{divider}")
print("RESULTS")
print(divider)
for label, value in report_rows:
    print(f"{label:<12}{value:.4f}")
print(divider)

## Step 7: Visualize Predictions

In [None]:
# Predict on the first eight validation images (sorted by name) and show them in a 2x4 grid.
val_dir = os.path.join(dataset_path, 'val/images')
val_images = sorted(os.listdir(val_dir))[:8]

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
for idx, img_name in enumerate(val_images):
    row, col = divmod(idx, 4)
    ax = axes[row][col]
    img_path = os.path.join(val_dir, img_name)

    # Keep only detections with confidence >= 0.5
    results = model(img_path, conf=0.5)
    prediction = results[0]

    # Convert the rendered overlay from BGR to RGB before handing it to matplotlib
    annotated = cv2.cvtColor(prediction.plot(), cv2.COLOR_BGR2RGB)

    ax.imshow(annotated)
    ax.set_title(f'{len(prediction.boxes)} worms', fontsize=10)
    ax.axis('off')

plt.suptitle('Validation Predictions (YOLOv8m)', fontsize=14)
plt.tight_layout()
plt.savefig('val_predictions.png', dpi=100)
plt.show()

## Step 8: Export Model

Export the best model in multiple formats for deployment.

In [None]:
# Export the best checkpoint to ONNX (cross-platform deployment) and TorchScript,
# both at the 416px training resolution.
model = YOLO(best_model_path)
export_jobs = (
    {'format': 'onnx', 'imgsz': 416, 'simplify': True},
    {'format': 'torchscript', 'imgsz': 416},
)
for job in export_jobs:
    model.export(**job)

# List everything now sitting in the weights directory with its size in MB
print('\nExported models:')
export_dir = 'worm_detection/yolov8m_worms/weights'
for fname in os.listdir(export_dir):
    n_bytes = os.path.getsize(os.path.join(export_dir, fname))
    print(f'  {fname}: {n_bytes / (1024*1024):.1f} MB')

## Step 9: Download Model

Download `best.pt` from `worm_detection/yolov8m_worms/weights/best.pt`  
This is the model you'll use for local tracking.

In [None]:
# Put the best checkpoint at the top of the working directory under a descriptive name
shutil.copy(best_model_path, '/kaggle/working/best_worm_yolov8m.pt')

# Copy whichever of the training plots exist alongside it
results_dir = 'worm_detection/yolov8m_worms'
plot_names = ('results.png', 'confusion_matrix.png', 'P_curve.png', 'R_curve.png', 'F1_curve.png')
for plot_file in plot_names:
    src = os.path.join(results_dir, plot_file)
    if not os.path.exists(src):
        continue
    shutil.copy(src, f'/kaggle/working/{plot_file}')

# Summarise the downloadable artifacts (weights, plots, ONNX) with sizes in MB
print('Files ready for download in /kaggle/working/')
wanted_suffixes = ('.pt', '.png', '.onnx')
for fname in os.listdir('/kaggle/working/'):
    if not fname.endswith(wanted_suffixes):
        continue
    n_bytes = os.path.getsize(f'/kaggle/working/{fname}')
    print(f'  {fname}: {n_bytes / (1024*1024):.1f} MB')

## Training Summary

### Training Curves

In [None]:
from IPython.display import Image, display

def _show_if_present(path, width):
    """Render the image at *path* inline at *width* pixels; do nothing if the file is absent."""
    if os.path.exists(path):
        display(Image(filename=path, width=width))


# Training curves
results_img = os.path.join(results_dir, 'results.png')
_show_if_present(results_img, 900)

# Confusion matrix
cm_img = os.path.join(results_dir, 'confusion_matrix.png')
_show_if_present(cm_img, 600)