# Palm Oil Disease Classification - YOLOv8 Training

This notebook trains a **YOLOv8 Classification** model (`-cls`) for palm oil leaf disease detection.

- **Task**: Image classification (assigning one label to the whole image)
- **Dataset Structure**: `datasets/palm-tree-leaves-diseases/{split}/{class_name}`

## 1. Setup and Installation

In [1]:
# Install required packages if not already present
!pip install ultralytics roboflow opencv-python pillow matplotlib



In [2]:
import json
import os
import random
import shutil
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import yaml
from IPython.display import Image as IPImage, display
from PIL import Image
from ultralytics import YOLO

# Report the PyTorch build and, when CUDA is usable, the name of GPU 0
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: Quadro RTX 5000


## 2. Define Dataset Path
Since we are doing classification with an existing folder structure, we point directly to the dataset root.

In [3]:
# Dataset root, resolved against the kernel's current working directory
# (expected layout: notebooks/ sitting next to datasets/).
dataset_path = (Path("..") / "datasets" / "palm-tree-leaves-diseases").resolve()

# One directory per split under the dataset root
train_dir, val_dir, test_dir = (
    dataset_path / split_name for split_name in ("train", "valid", "test")
)

print(f"Dataset Root: {dataset_path}")
print(f"Train Directory Exists: {train_dir.exists()}")

Dataset Root: /home/datasets/palm-tree-leaves-diseases
Train Directory Exists: False


## 3. Explore Dataset (Classification Mode)
We infer class names from the folder names.

In [None]:
# Get class names from folder names in 'train' and count the images per split.

# File suffixes treated as images (compared case-insensitively), so datasets
# that mix .jpg / .JPG / .jpeg / .png are not under-counted.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def count_images(split_dir):
    """Return the number of image files found recursively under ``split_dir``.

    Args:
        split_dir: ``pathlib.Path`` of a split folder (e.g. train or valid).

    Returns:
        int: count of regular files whose suffix is in ``IMAGE_SUFFIXES``;
        0 when the directory does not exist.
    """
    if not split_dir.exists():
        return 0
    return sum(
        1
        for candidate in split_dir.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in IMAGE_SUFFIXES
    )


if train_dir.exists():
    # Each immediate sub-folder of 'train' is one class; sorted for a stable order
    classes = sorted(d.name for d in train_dir.iterdir() if d.is_dir())
    print(f"Detected Classes ({len(classes)}): {classes}")

    # Count images
    train_count = count_images(train_dir)
    val_count = count_images(val_dir)
    print(f"Training images: {train_count}")
    print(f"Validation images: {val_count}")
else:
    print("❌ Error: Train directory not found at path!")

In [None]:
# Visualize sample images from different classes
def visualize_samples(root_dir, class_list, num_samples=4):
    """Show one sample image for each of the first ``num_samples`` classes.

    For every selected class the alphabetically first ``*.jpg`` file in
    ``root_dir / class_name`` is drawn in its own panel, titled with the class
    name. A class without any ``*.jpg`` file gets an empty, titled panel.

    Args:
        root_dir: Directory containing one sub-folder per class.
        class_list: Sequence of class (sub-folder) names.
        num_samples: Maximum number of classes to show, one panel each.

    Returns:
        None. The figure is rendered with ``plt.show()``; when there is no
        class to display a short message is printed instead.
    """
    # Select first few classes to display
    display_classes = list(class_list)[:num_samples]
    if not display_classes:
        print("No classes to display.")
        return

    n_panels = len(display_classes)
    # squeeze=False keeps `axes` two-dimensional even for a single panel, so the
    # same indexing works for any number of classes.
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

    for ax, class_name in zip(axes[0], display_classes):
        ax.axis('off')
        ax.set_title(class_name)

        # Sorted so the same file is picked on every run
        images = sorted((Path(root_dir) / class_name).glob('*.jpg'))
        if images:
            # Context manager releases the file handle once the pixels are drawn
            with Image.open(images[0]) as img:
                ax.imshow(img)

    plt.tight_layout()
    plt.show()

# Preview one image per class, but only when the training split is present
# (that is also the condition under which `classes` was populated).
if train_dir.exists():
    print("Sample images per class:")
    visualize_samples(root_dir=train_dir, class_list=classes)

## 4. Initialize YOLO Classification Model
We use **`yolov8m-cls.pt`**. The `-cls` suffix matters: it selects the image-classification variant of YOLOv8 rather than the default detection model.

In [None]:
# Location of the pretrained classification checkpoint, relative to the notebook
MODEL_DIR = "../models"
MODEL_NAME = "yolov8m-cls.pt"
MODEL_PATH = "/".join((MODEL_DIR, MODEL_NAME))
print(MODEL_PATH)

In [None]:
# Build the YOLO model from the checkpoint path and report which task it targets
model = YOLO(MODEL_PATH)

print(
    f"Model loaded from: {MODEL_PATH}",
    f"Task: {model.task}",
    sep="\n",
)

## 5. Train Model

In [None]:
# Run configuration (also read by the evaluation and export cells below)
PROJECT_NAME = "palm_disease_classification"
EXPERIMENT_NAME = "yolov8m_cls_run"
EPOCHS = 50  # Adjust as needed
BATCH_SIZE = 16
IMAGE_SIZE = 224  # Standard for classification (can use 640 if high res needed)

# Augmentation settings forwarded to the trainer
augmentation_args = {
    "degrees": 10.0,
    "translate": 0.1,
    "scale": 0.5,
    "fliplr": 0.5,
    "mosaic": 0.0,  # Mosaic is typically for detection, often disabled for cls
}

# Optimizer choice and initial learning rate
optimizer_args = {
    "optimizer": 'AdamW',
    "lr0": 0.001,
}

# Launch training; `data` is the dataset root that holds the split folders
results = model.train(
    data=str(dataset_path),
    epochs=EPOCHS,
    imgsz=IMAGE_SIZE,
    batch=BATCH_SIZE,
    project=PROJECT_NAME,
    name=EXPERIMENT_NAME,
    patience=10,
    save=True,
    **augmentation_args,
    **optimizer_args,
)

print("Training completed!")

## 6. Evaluate Model

In [None]:
# Run validation on the trained model and report top-k accuracy.
# Note: top-5 accuracy is trivially 1.0 when the dataset has five or fewer classes.
metrics = model.val()

print("\nValidation Metrics:")
for label, value in (("Top-1", metrics.top1), ("Top-5", metrics.top5)):
    print(f"{label} Accuracy: {value:.4f}")

In [None]:
# Show Training Curves & Confusion Matrix

# Prefer the directory the trainer actually wrote to. Ultralytics appends a
# numeric suffix to the run name when the folder already exists, so after a
# re-run the outputs are NOT in PROJECT_NAME/EXPERIMENT_NAME. Fall back to that
# default location only when no trainer is attached to the model.
trainer = getattr(model, "trainer", None)
trained_dir = getattr(trainer, "save_dir", None)
results_dir = Path(trained_dir) if trained_dir else Path(PROJECT_NAME) / EXPERIMENT_NAME
print(f"Results directory: {results_dir}")

for heading, filename in (
    ("Training Results:", "results.png"),
    ("\nConfusion Matrix:", "confusion_matrix.png"),
):
    print(heading)
    plot_path = results_dir / filename
    if plot_path.exists():
        display(IPImage(filename=str(plot_path)))
    else:
        # Say so explicitly instead of leaving a heading with nothing under it
        print(f"  (not found: {plot_path})")

## 7. Test Predictions

In [None]:
# Load the best model weights
best_model_path = results_dir / "weights" / "best.pt"
best_model = YOLO(best_model_path)

# Pick random validation images.
# The file list is sorted and the generator seeded so the same sample is drawn
# on every run; taking the first few files instead would usually return images
# from a single class folder.
val_images = sorted(val_dir.rglob('*.jpg'))
sample_rng = random.Random(0)
test_images = sample_rng.sample(val_images, k=min(4, len(val_images)))

if not test_images:
    print(f"No .jpg images found under {val_dir}")

for img_path in test_images:
    # Predict (a separate name is used so the training `results` is not overwritten)
    prediction = best_model.predict(source=str(img_path), verbose=False)[0]

    # Extract top 1 class
    top1_idx = prediction.probs.top1
    top1_conf = prediction.probs.top1conf.item()
    pred_class = prediction.names[top1_idx]

    # The parent folder name is the ground-truth label in this dataset layout
    print(f"\nImage: {img_path.name}")
    print(f"  True Class: {img_path.parent.name}")
    print(f"  Predicted:  {pred_class} ({top1_conf:.2%})")

## 8. Export Model & Metadata

In [None]:
# Create production models folder (parents=True so a missing ../models tree is created too)
models_output_dir = Path("../models")
models_output_dir.mkdir(parents=True, exist_ok=True)

# 1. Copy Best Weights
prod_model_path = models_output_dir / "best_cls.pt"
shutil.copy(best_model_path, prod_model_path)
print(f"Model saved to: {prod_model_path}")

# 2. Save Metadata (Important for API)
# Take the index -> name map from the trained model itself, so the metadata
# always matches the labels the exported weights predict (rather than a
# separate listing of the dataset folders).
class_names = {int(idx): name for idx, name in best_model.names.items()}

metadata = {
    "model_type": "YOLOv8-Classification",
    "model_name": EXPERIMENT_NAME,
    "image_size": IMAGE_SIZE,
    "num_classes": len(class_names),
    # JSON object keys are strings: indices are written as "0", "1", ...
    "class_names": class_names,
    "training_date": datetime.now().isoformat(),
    "metrics": {
        "top1": float(metrics.top1),
        "top5": float(metrics.top5)
    }
}

metadata_path = models_output_dir / "model_metadata.json"
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print(f"Metadata saved to: {metadata_path}")