# YOLOv8 — Crop-wise Model Training Notebook
This notebook contains EDA, dataset preparation, training, validation, inference, and export cells for training crop-wise YOLO models on the CommonCropDiseases dataset.

In [None]:
# Cell 2 — Import Libraries
import glob
import os
import random
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.display import Image
from tqdm import tqdm
from ultralytics import YOLO

In [None]:
# Cell 3 — Mount Drive (optional, Colab only)
# Importing google.colab and mounting the drive are handled separately so that a
# genuine mount failure is not misreported as "not running in Colab".
try:
    from google.colab import drive
except ImportError:
    print("Not running in Colab or google.colab not available. Skipping drive mount.")
else:
    try:
        drive.mount('/content/drive')
    except Exception as e:
        # Best-effort: the rest of the notebook can still run without Drive.
        print("Google Drive mount failed:", e)

-----------------------------------------------------
🧪 EDA BLOCK — VERY IMPORTANT
-----------------------------------------------------

In [None]:
# Cell 4 — Explore All Class Folders
DATASET_PATH = "/content/CommonCropDiseases"  # adjust path as needed

# isdir (rather than exists) so that a stray file at this path is not passed
# to os.listdir, which would raise.
if os.path.isdir(DATASET_PATH):
    # Every sub-directory of the dataset root is treated as one class.
    classes = sorted(
        d for d in os.listdir(DATASET_PATH)
        if os.path.isdir(os.path.join(DATASET_PATH, d))
    )
    print("Total classes:", len(classes))
    # Jupyter only renders a cell's final *top-level* expression; a bare
    # `classes` nested inside this `if` shows nothing, so print it explicitly.
    print(classes)
else:
    # `classes` is intentionally left undefined: later EDA cells test for it.
    print("DATASET_PATH does not exist. Update DATASET_PATH before running EDA.")

In [None]:
# Cell 5 — Show Random Samples from Each Class
def show_random_samples(class_name, n=4):
    """Plot up to ``n`` distinct, randomly chosen images of one class in a row.

    Args:
        class_name: Name of the class sub-folder under ``DATASET_PATH``.
        n: Maximum number of images to display.

    Only the first 50 ``*.jpg`` paths returned by ``glob`` are candidates.
    Prints a message and returns without plotting when the folder holds no
    ``.jpg`` files or when none of the sampled files can be decoded.
    """
    if n <= 0:
        return

    folder = os.path.join(DATASET_PATH, class_name)
    candidates = glob.glob(os.path.join(folder, "*.jpg"))[:50]
    if not candidates:
        print(f"No .jpg images found for class '{class_name}' in {folder}")
        return

    # Sample without replacement so the same picture is not shown twice, and
    # never ask for more images than are available.
    picked = random.sample(candidates, min(n, len(candidates)))

    loaded = []
    for path in picked:
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print("Skipping unreadable image:", path)
            continue
        loaded.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    if not loaded:
        print(f"No readable images for class '{class_name}'")
        return

    plt.figure(figsize=(12, 6))
    for i, img in enumerate(loaded):
        plt.subplot(1, len(loaded), i + 1)
        plt.imshow(img)
        plt.title(class_name)
        plt.axis("off")
    plt.show()

if 'classes' not in globals():
    print("Run the previous cell to set DATASET_PATH and classes.")
else:
    # Preview only the first five classes to keep the output short.
    for c in classes[:5]:
        show_random_samples(c)

In [None]:
# Cell 6 — Class Distribution Plot
if 'classes' in globals():
    # Number of .jpg files (extension matched case-insensitively) per class folder.
    class_counts = {c: len([f for f in os.listdir(os.path.join(DATASET_PATH, c)) if f.lower().endswith('.jpg')]) for c in classes}

    plt.figure(figsize=(14,8))
    sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()))
    plt.xticks(rotation=90)
    plt.xlabel("Class")
    plt.ylabel("Number of images")
    # Report the number of classes actually found instead of a hard-coded 48.
    plt.title(f"Class Distribution ({len(class_counts)} classes)")
    plt.tight_layout()
    plt.show()
else:
    print("Run the dataset exploration cell first.")

In [None]:
# Cell 7 — Image Quality Checker (Find Corrupt Files)
# Paths of .jpg files that OpenCV cannot decode; consumed by the cleanup cell.
bad_images = []

if 'classes' in globals():
    for cls in classes:
        folder = os.path.join(DATASET_PATH, cls)
        for img_path in glob.glob(folder + "/*.jpg"):
            try:
                img = cv2.imread(img_path)
                # cv2.imread returns None for unreadable files.
                if img is None or img.size == 0:
                    bad_images.append(img_path)
            except Exception:
                bad_images.append(img_path)

    print("Corrupt images:", len(bad_images))
    # A bare `bad_images[:10]` nested in this `if` is never rendered by Jupyter
    # (only the cell's last top-level expression is), so print the preview.
    if bad_images:
        print("First corrupt files:", *bad_images[:10], sep="\n  ")
else:
    print("Run dataset exploration first.")

In [None]:
# Cell 8 — Remove Corrupt Images
# Deletes every path collected in `bad_images` and reports how many were
# actually removed, instead of announcing success unconditionally.
removed_count = 0
for b in bad_images:
    try:
        os.remove(b)
        removed_count += 1
    except FileNotFoundError:
        # Already gone (e.g. this cell was run twice) — nothing left to clean.
        continue
    except OSError as e:
        print("Failed to remove:", b, e)
print(f"Removed {removed_count} of {len(bad_images)} corrupt images.")

In [None]:
# Cell 9 — Visualize Image Sizes
# (height, width) of every .jpg that OpenCV manages to decode.
sizes = []

if 'classes' not in globals():
    print("Run dataset exploration first.")
else:
    for cls in tqdm(classes):
        for img_path in glob.glob(f"{DATASET_PATH}/{cls}/*.jpg"):
            img = cv2.imread(img_path)
            if img is None:
                # Unreadable file: leave it out of the statistics.
                continue
            sizes.append(img.shape[:2])  # H, W

    heights = [dims[0] for dims in sizes]
    widths = [dims[1] for dims in sizes]

    # Overlay the two density estimates on one axis.
    plt.figure(figsize=(10,5))
    sns.kdeplot(heights, label='Height')
    sns.kdeplot(widths, label='Width')
    plt.legend()
    plt.title("Image Dimension Distribution")
    plt.show()

-----------------------------------------------------
🌱 Crop-Wise Dataset Generator
-----------------------------------------------------

In [None]:
# Cell 10 — Define Crops → Disease Mapping
# Maps each crop name to the dataset class folders (one per disease / healthy
# state) that belong to it. Folder names must match the directories on disk.
# NOTE(review): the Rice folders use a double underscore ("Rice__") while every
# other crop uses a triple underscore — confirm against the actual folder names.
CROPS = {
    "Apple": [
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___Apple_scab",
        "Apple___healthy",
    ],
    "Corn": [
        "Corn___Cercospora_leaf_spot",
        "Corn___Northern_Leaf_Blight",
        "Corn___Common_rust",
        "Corn___healthy",
    ],
    "Potato": [
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
    ],
    "Tomato": [
        "Tomato___Bacterial_spot",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites",
        "Tomato___Yellow_Leaf_Curl_Virus",
        "Tomato___healthy",
    ],
    "Sugarcane": [
        "Sugarcane___RedRot",
        "Sugarcane___Mosaic",
        "Sugarcane___healthy",
        "Sugarcane___Rust",
        "Sugarcane___Yellow",
    ],
    "Rice": [
        "Rice__Tungro",
        "Rice__Brownspot",
        "Rice__Blast",
        "Rice__Healthy",
        "Rice__Bacterialblight",
    ],
    "Grape": [
        "Grape___healthy",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
    ],
    "Cherry": [
        "Cherry_(including_sour)___healthy",
        "Cherry_(including_sour)___Powdery_mildew",
    ],
}

In [None]:
# Cell 11 — Augmentation Preview (Important)
def preview_augmentation(image_path):
    """Show an image side by side with three simple augmentations of it.

    Panels, left to right: the original, a horizontal flip, a
    brightness/contrast boost (``alpha=1.2``, ``beta=30``) and an 11x11
    Gaussian blur.

    Args:
        image_path: Path of the image file to load.

    Raises:
        FileNotFoundError: ``image_path`` is not an existing file.
        ValueError: the file exists but OpenCV cannot decode it.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising on undecodable input.
        raise ValueError(f"Could not decode image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    variants = [
        ("Original", img),
        ("Horizontal flip", cv2.flip(img, 1)),
        ("Brightness", cv2.convertScaleAbs(img, alpha=1.2, beta=30)),
        ("Gaussian blur", cv2.GaussianBlur(img, (11, 11), 0)),
    ]

    plt.figure(figsize=(12, 3))
    for i, (label, im) in enumerate(variants):
        plt.subplot(1, len(variants), i + 1)
        plt.imshow(im)
        plt.title(label)
        plt.axis("off")
    plt.show()

# Preview the first .jpg of the first class; any failure (no dataset, empty
# folder, unreadable file) is reported rather than raised.
try:
    first_class_images = glob.glob(f"{DATASET_PATH}/{classes[0]}/*.jpg")
    preview_augmentation(first_class_images[0])
except Exception as e:
    print("Could not preview augmentation:", e)

-----------------------------------------------------
📦 Create YOLO Dataset (Auto)
-----------------------------------------------------

In [None]:
# Cell 12 — Create Dataset Builder
def create_crop_dataset(crop, disease_folders, dataset_path=None,
                        output_root="/content", train_ratio=0.8, seed=42):
    """Build a YOLO-style train/val dataset for one crop.

    Copies the ``*.jpg`` files of every folder in ``disease_folders`` into
    ``<output_root>/<crop>/images/{train,val}`` and writes one label file per
    image into ``<output_root>/<crop>/labels/{train,val}``.

    Args:
        crop: Crop name; used as the output directory name.
        disease_folders: Class folder names under ``dataset_path`` to include.
        dataset_path: Root of the source dataset. Defaults to the notebook's
            global ``DATASET_PATH``.
        output_root: Directory under which the crop dataset is created.
        train_ratio: Fraction of each folder's images that goes to ``train``.
        seed: Seed for the shuffle, so re-running yields the same split
            instead of leaking images between train and val.

    Notes:
        * Copied files are prefixed with their class folder name. Different
          folders may contain identically named files, which would otherwise
          overwrite each other in the shared output directory.
        * Every label is ``0 0.5 0.5 1 1`` (class 0, box covering the whole
          image), for healthy and diseased folders alike.
          NOTE(review): with a single class and full-image boxes the disease
          identity is not encoded in the labels — confirm this is intended.
        * Files already present in the output folders are left in place and
          also receive a label file.
    """
    if dataset_path is None:
        dataset_path = DATASET_PATH

    base = os.path.join(output_root, crop)
    splits = ("train", "val")
    for split in splits:
        os.makedirs(os.path.join(base, "images", split), exist_ok=True)
        os.makedirs(os.path.join(base, "labels", split), exist_ok=True)

    # Private RNG plus sorted input makes the split reproducible without
    # touching the global random state.
    rng = random.Random(seed)

    for cls in disease_folders:
        images = sorted(glob.glob(os.path.join(dataset_path, cls, "*.jpg")))
        if not images:
            print(f"Warning: no .jpg images found for '{cls}' in {dataset_path}")
            continue

        rng.shuffle(images)
        cut = int(train_ratio * len(images))
        for split, subset in zip(splits, (images[:cut], images[cut:])):
            target_dir = os.path.join(base, "images", split)
            for src in subset:
                # Class prefix keeps same-named files from different folders apart.
                target_name = f"{cls}__{os.path.basename(src)}"
                shutil.copy(src, os.path.join(target_dir, target_name))

    # YOLO label = full image bounding box
    for split in splits:
        label_dir = os.path.join(base, "labels", split)
        for img_name in os.listdir(os.path.join(base, "images", split)):
            # splitext swaps only the real extension, unlike str.replace which
            # would also rewrite a ".jpg" occurring inside the file name.
            stem = os.path.splitext(img_name)[0]
            with open(os.path.join(label_dir, stem + ".txt"), "w") as f:
                f.write("0 0.5 0.5 1 1")

In [None]:
# Cell 13 — Auto-Generate YAML File
def create_yaml(crop, output_root="/content"):
    """Write the YOLO ``data.yaml`` for one crop dataset and return its path.

    The file points at ``<output_root>/<crop>`` with ``images/train`` and
    ``images/val`` splits and declares a single class named ``disease``.

    Args:
        crop: Crop name; the file is written to ``<output_root>/<crop>/data.yaml``.
        output_root: Directory that contains the per-crop dataset folders.

    Returns:
        Path of the written ``data.yaml``.
    """
    base = os.path.join(output_root, crop)
    # Create the folder if needed so this works even before the dataset is built.
    os.makedirs(base, exist_ok=True)

    content = f"""
path: {base}
train: images/train
val: images/val

nc: 1
names: ["disease"]
"""
    yaml_path = os.path.join(base, "data.yaml")
    with open(yaml_path, "w") as f:
        f.write(content)
    return yaml_path

-----------------------------------------------------
🚀 TRAINING MODULE
-----------------------------------------------------

In [None]:
# Cell 14 — Train YOLO for One Crop
def train_crop(crop, epochs=30, imgsz=640, batch=16, device=None,
               lr0=0.001, optimizer="Adam", patience=10, weights="yolov8m.pt"):
    """Train a YOLO model on one crop's dataset and return the model.

    Args:
        crop: Crop name; the dataset config is read from
            ``/content/<crop>/data.yaml``.
        epochs, imgsz, batch, lr0, optimizer, patience: Forwarded unchanged to
            ``model.train``.
        device: Training device forwarded to ``model.train``. ``None`` leaves
            the choice to Ultralytics, so the cell also runs on machines
            without a GPU; pass ``0`` to force the first GPU.
        weights: Checkpoint passed to ``YOLO(...)`` as the starting point.

    Returns:
        The ``YOLO`` model object after ``train`` has run.

    Raises:
        FileNotFoundError: the crop's ``data.yaml`` does not exist.
    """
    data_yaml = f"/content/{crop}/data.yaml"
    if not os.path.isfile(data_yaml):
        # Fail early with a clear message instead of deep inside the trainer.
        raise FileNotFoundError(
            f"{data_yaml} not found - build the dataset and YAML for '{crop}' first."
        )

    print(f"🔥 Training Model for {crop}")

    model = YOLO(weights)
    model.train(
        data=data_yaml,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        device=device,
        lr0=lr0,
        optimizer=optimizer,
        patience=patience,
    )

    return model

In [None]:
# Cell 15 — Train All Crop Models
# Keep every trained model keyed by crop name. Assigning only to `model` would
# overwrite it on each iteration and leave just the last crop's model reachable.
trained_models = {}
for crop, disease_classes in CROPS.items():
    # Order matters: the dataset and its YAML must exist before training.
    create_crop_dataset(crop, disease_classes)
    create_yaml(crop)
    # `model` is still updated each iteration, so after the loop it refers to
    # the last trained crop, as before.
    model = trained_models[crop] = train_crop(crop)

-----------------------------------------------------
📊 VALIDATION + METRICS
-----------------------------------------------------

In [None]:
# Cell 16 — Validate Model
# Loads weights from a hard-coded run directory and runs validation on them.
# This rebinds `model`, replacing the object left over from the training loop.
# NOTE(review): "runs/detect/train" is presumably the first training run of the
# session (later runs appear to get numbered suffixes) — confirm which crop's
# weights are being validated here.
model = YOLO("runs/detect/train/weights/best.pt")
metrics = model.val()
# Last top-level expression: rendered as the cell output.
metrics

In [None]:
# Cell 17 — Plot Training Curves
# `Image` (IPython.display) is imported with the other dependencies in the
# imports cell rather than mid-notebook.
Image(filename="runs/detect/train/results.png")

-----------------------------------------------------
🎯 INFERENCE + TESTING
-----------------------------------------------------

In [None]:
# Cell 18 — Test on a Random Image
# Picks a random validation image of the first crop, runs the current `model`
# on it and draws the annotated prediction inline.
try:
    first_crop = list(CROPS.keys())[0]
    val_images = glob.glob(f"/content/{first_crop}/images/val/*.jpg")
    if not val_images:
        raise FileNotFoundError(f"no validation images under /content/{first_crop}/images/val")

    # Actually random, as the cell title promises (previously always the first hit).
    test_img = random.choice(val_images)
    results = model(test_img)

    # `show=True` relies on an OpenCV GUI window; render the annotated frame
    # with matplotlib instead. Results.plot() returns a BGR array.
    annotated = results[0].plot()
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    plt.title(os.path.basename(test_img))
    plt.axis("off")
    plt.show()
except Exception as e:
    # Best-effort demo cell: report the problem instead of stopping the notebook.
    print("Test image inference failed:", e)

In [None]:
# Cell 19 — Video Inference
# Video inference is disabled by default: the predict call below stays commented
# out until a video file is available.
# Ensure test_video.mp4 exists in the working directory or provide full path
# model.predict(source="test_video.mp4", show=True, save=True, conf=0.5)
print("Uncomment and run model.predict(...) after adding a video file.")

-----------------------------------------------------
💾 EXPORT MODEL
-----------------------------------------------------

In [None]:
# Cell 20 — Export ONNX / TensorRT
# Exports whatever `model` currently refers to. Failures are reported rather
# than raised so the notebook can finish.
try:
    model.export(format="onnx")
except Exception as e:
    print("ONNX export failed (ensure dependencies installed):", e)
# For TensorRT (requires TensorRT + proper environment)
# model.export(format="engine")  # TensorRT