02 -- Train RiceLeafs Dataset

## Overview: What This Notebook Does

This notebook fine-tunes an ImageNet-pretrained **Vision Transformer (ViT-B/16)** from `torchvision` to classify images from the processed RiceLeafs dataset into disease classes. It is an image-level classifier: it predicts one label per image and does not use bounding boxes.

### Step-by-step workflow

1. **Import PyTorch and torchvision**  
   - `torch` and `torch.nn` provide the model building blocks, the loss, and the optimizer.  
   - `torchvision` provides the pretrained ViT, the image transforms, and the `ImageFolder` dataset loader.

2. **Set the configuration**  
   - `DATA_DIR`: root of the processed dataset (`processed/riceleaf`), which must contain `train/` and `val/` subfolders.  
   - `BATCH_SIZE`, `LR`, `EPOCHS`, and `IMG_SIZE` (224, the input size torchvision's pretrained ViT-B/16 expects).  
   - `DEVICE`: the GPU if one is available, otherwise the CPU.

3. **Load the data**  
   - `datasets.ImageFolder` expects one subfolder per class (e.g. `train/<class_name>/*.jpg`) and assigns class indices 0, 1, 2, … in sorted folder-name order.  
   - Training images are resized and augmented with random horizontal flips and small rotations; validation images are only resized.  
   - `DataLoader` groups the images into batches; only the training loader shuffles.

4. **Adapt the pretrained ViT**  
   - Load `vit_b_16` with ImageNet weights.  
   - Replace the 1000-class ImageNet head (`model.heads.head`) with a new linear layer that outputs `NUM_CLASSES` logits.

5. **Train the classifier**  
   - For every batch in every epoch, the loop:
     1. Moves the images and labels to `DEVICE`.  
     2. Runs a forward pass to get class logits.  
     3. Computes the cross-entropy loss against the true labels.  
     4. Backpropagates and updates all weights (backbone and head) with Adam.  
   - The mean training loss is printed after each epoch.

6. **Evaluate on the validation set**  
   - With the model in eval mode and gradients disabled, take the highest-scoring class for each validation image.  
   - Report top-1 accuracy: correct predictions divided by the number of validation images.

7. **Save the weights**  
   - Save the model's `state_dict` to a `.pth` file. To reuse it for inference, rebuild `vit_b_16` with the same replacement head before loading the weights.


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


In [None]:
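DATA_DIR = "processed/riceleaf"   # must contain train/ and val/ subfolders
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 15
IMG_SIZE = 224                    # torchvision's pretrained vit_b_16 expects 224x224 inputs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"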


In [None]:
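# Channel mean/std of ImageNet, the statistics the pretrained ViT weights were trained with
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize, light augmentation, tensor conversion, normalization
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),  # random rotation in [-10, 10] degrees
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation: deterministic preprocessing, no augmentation
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])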


# ImageFolder expects DATA_DIR/<split>/<class_name>/<image files>;
# labels come from the sorted class-folder names
train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATA_DIR, "val"),   transform=val_tfms)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

NUM_CLASSES = len(train_ds.classes)
train_ds.classes  # class names; label i corresponds to train_ds.classes[i]
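`datasets.ImageFolder` derives labels from directory names: it lists the subfolders of the split directory, sorts them, and numbers them from 0. The sketch below reproduces that rule in plain Python so the contents of `train_ds.classes` can be predicted before any images are loaded. `folder_class_mapping` and the class names in the demo are hypothetical, not the dataset's real classes. Because the order is alphabetical, `train/` and `val/` need exactly the same class folders; otherwise the same index can mean different diseases in the two splits.

```python
import os
import tempfile


def folder_class_mapping(split_dir: str) -> dict:
    """Hypothetical re-implementation of ImageFolder's labelling rule:
    sorted subdirectory names, numbered from 0."""
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Demo with made-up class folders created in a throwaway directory.
with tempfile.TemporaryDirectory() as root:
    for name in ["Healthy", "BrownSpot", "LeafBlast"]:
        os.makedirs(os.path.join(root, "train", name))
    demo_mapping = folder_class_mapping(os.path.join(root, "train"))

print(demo_mapping)  # {'BrownSpot': 0, 'Healthy': 1, 'LeafBlast': 2}
```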


In [None]:
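# Load ViT-B/16 with ImageNet-1k weights (torchvision >= 0.13 weights API)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
# Replace the 1000-class ImageNet head with one sized for the rice leaf classes
model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)
model = model.to(DEVICE)

# CrossEntropyLoss expects raw logits and integer class labels
criterion = nn.CrossEntropyLoss()
# Fine-tune all parameters (backbone and new head) with Adam
optimizer = Adam(model.parameters(), lr=LR)


for epoch in range(EPOCHS):
    model.train()  # switch to training mode (the evaluation cell switches to eval mode)
    total_loss = 0.0

    for imgs, labels in tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()            # clear gradients from the previous step
        preds = model(imgs)              # logits, shape (batch, NUM_CLASSES)
        loss = criterion(preds, labels)
        loss.backward()                  # backpropagate
        optimizer.step()                 # update the weights

        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {total_loss/len(train_dl):.4f}")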
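The training loop feeds the head's raw logits straight into `nn.CrossEntropyLoss`, so the model needs no softmax layer. For each image the loss is the negative log-softmax probability of the true class, $\ell = \log\sum_j e^{z_j} - z_y$, averaged over the batch under the default `reduction="mean"`. The pure-Python sketch below, with made-up logits, computes that quantity by hand to make the formula concrete. It is an illustration only, not how PyTorch implements the loss internally.

```python
import math


def cross_entropy(logits: list, label: int) -> float:
    """Cross-entropy for one sample from raw logits: logsumexp(z) - z[label]."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def batch_cross_entropy(batch_logits: list, labels: list) -> float:
    """Mean over the batch, mirroring CrossEntropyLoss's default reduction."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Made-up logits for a batch of two images and three classes.
demo_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
demo_labels = [0, 2]
print(batch_cross_entropy(demo_logits, demo_labels))
```

A uniform prediction over three classes (the second image) costs $\log 3 \approx 1.0986$, which is roughly where a freshly initialized head starts.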


In [None]:
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for imgs, labels in val_dl:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
accuracy
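The evaluation cell computes top-1 accuracy: the prediction is the index of the largest logit, and accuracy is the fraction of validation images whose prediction matches the label. Softmax is monotonic, so the argmax of the raw logits equals the argmax of the probabilities, which is why the cell can skip softmax. The pure-Python sketch below uses made-up numbers. The per-class breakdown is an addition not present in the notebook; it is included because a single overall accuracy can hide one disease class that the model mostly misses.

```python
from collections import Counter


def argmax(values: list) -> int:
    """Index of the largest value; the first one wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])


def top1_accuracy(batch_logits: list, labels: list) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    preds = [argmax(z) for z in batch_logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


def per_class_accuracy(batch_logits: list, labels: list) -> dict:
    """Accuracy restricted to the samples of each true class (per-class recall)."""
    preds = [argmax(z) for z in batch_logits]
    totals = Counter(labels)
    hits = Counter(y for p, y in zip(preds, labels) if p == y)
    return {c: hits[c] / totals[c] for c in sorted(totals)}


# Made-up logits for four images and three classes.
demo_logits = [[3.0, 1.0, 0.0], [0.2, 2.5, 0.1], [1.0, 0.9, 0.8], [0.0, 0.1, 2.0]]
demo_labels = [0, 1, 2, 2]
print(top1_accuracy(demo_logits, demo_labels))       # 0.75
print(per_class_accuracy(demo_logits, demo_labels))  # {0: 1.0, 1: 1.0, 2: 0.5}
```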


In [None]:
# Save the fine-tuned weights from the final epoch (no best-epoch selection is done above)
torch.save(model.state_dict(), "riceleaf_vit_b16.pth")
