In [1]:
import torch
import os
import shutil
import random
from pathlib import Path
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Fixed seed so the data split and any later stochastic steps give the same
# results on every Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "CPU"
print(f"✅ Using: {device}, {device_name}")

# Paths are relative to the notebook's working directory.
data_raw_dir = Path("../data/raw/PlantVillage")   # input: one sub-folder per class
data_split_dir = Path("../data/splits")           # output root for train/val/test copies
data_split_dir.mkdir(parents=True, exist_ok=True)


✅ Using: cuda, NVIDIA GeForce RTX 4050 Laptop GPU


In [2]:
def split_dataset(raw_dir, split_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1,
                  seed=42, extensions=(".jpg", ".jpeg", ".png")):
    """Copy a class-per-folder image dataset into train/val/test sub-folders.

    Every sub-directory of ``raw_dir`` is treated as one class. For each class
    the images are shuffled and copied to ``split_dir/<split>/<class>/`` where
    ``<split>`` is ``train``, ``val`` or ``test``.

    Args:
        raw_dir: Directory (str or Path) holding one sub-folder per class.
        split_dir: Destination root (str or Path); sub-folders are created.
        train_ratio: Fraction of each class copied to ``train`` (floored).
        val_ratio: Fraction of each class copied to ``val`` (floored).
        test_ratio: Nominal test fraction. It is only used to validate that the
            ratios sum to 1; ``test`` receives every remaining image.
        seed: Seed for a private ``random.Random`` so the split is reproducible
            and a re-run assigns every image to the same split. ``None`` falls
            back to the global ``random`` module (unseeded unless the caller
            seeded it).
        extensions: File suffixes treated as images. They are matched
            case-insensitively, so ``.JPG`` counts as ``.jpg``.

    Raises:
        AssertionError: If the ratios are negative or do not sum to 1.

    Note:
        Existing files at the destination are never overwritten or removed.
        If ``split_dir`` was filled by an earlier run with a different (or no)
        seed, those copies stay where they are. Clear ``split_dir`` first to
        avoid the same image ending up in more than one split.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"
    assert min(train_ratio, val_ratio, test_ratio) >= 0, "Ratios must be non-negative"

    raw_dir, split_dir = Path(raw_dir), Path(split_dir)
    allowed_suffixes = {ext.lower() for ext in extensions}
    # A private RNG keeps the split reproducible without disturbing the global
    # `random` state. seed=None keeps the original global-RNG behaviour.
    rng = random if seed is None else random.Random(seed)

    # Sort so the split depends only on the seed, not on the filesystem's
    # directory-listing order.
    classes = sorted(d.name for d in raw_dir.iterdir() if d.is_dir())
    print(f"Found {len(classes)} classes.")

    for cls in classes:
        cls_dir = raw_dir / cls
        # Case-insensitive suffix check: glob("*.jpg") is case-sensitive on
        # POSIX and would silently skip files such as "x.JPG".
        images = sorted(
            p for p in cls_dir.iterdir()
            if p.is_file() and p.suffix.lower() in allowed_suffixes
        )
        rng.shuffle(images)

        n_total = len(images)
        n_train = int(train_ratio * n_total)
        n_val = int(val_ratio * n_total)
        n_test = n_total - n_train - n_val

        splits = {
            "train": images[:n_train],
            "val": images[n_train:n_train + n_val],
            "test": images[n_train + n_val:],
        }

        for split_name, split_images in splits.items():
            split_cls_dir = split_dir / split_name / cls
            split_cls_dir.mkdir(parents=True, exist_ok=True)

            for img_path in split_images:
                target_path = split_cls_dir / img_path.name
                # Skip files already copied by a previous run.
                if not target_path.exists():
                    shutil.copy(img_path, target_path)

        print(f"[{cls}] → Train: {n_train}, Val: {n_val}, Test: {n_test}")

#split_dataset(data_raw_dir, data_split_dir) # Run this line only once, then comment it


Found 38 classes.
[Apple___Apple_scab] → Train: 441, Val: 126, Test: 63
[Apple___Black_rot] → Train: 434, Val: 124, Test: 63
[Apple___Cedar_apple_rust] → Train: 192, Val: 55, Test: 28
[Apple___healthy] → Train: 1151, Val: 329, Test: 165
[Blueberry___healthy] → Train: 1051, Val: 300, Test: 151
[Cherry_(including_sour)___healthy] → Train: 597, Val: 170, Test: 87
[Cherry_(including_sour)___Powdery_mildew] → Train: 736, Val: 210, Test: 106
[Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot] → Train: 359, Val: 102, Test: 52
[Corn_(maize)___Common_rust_] → Train: 834, Val: 238, Test: 120
[Corn_(maize)___healthy] → Train: 813, Val: 232, Test: 117
[Corn_(maize)___Northern_Leaf_Blight] → Train: 689, Val: 197, Test: 99
[Grape___Black_rot] → Train: 826, Val: 236, Test: 118
[Grape___Esca_(Black_Measles)] → Train: 968, Val: 276, Test: 139
[Grape___healthy] → Train: 296, Val: 84, Test: 43
[Grape___Leaf_blight_(Isariopsis_Leaf_Spot)] → Train: 753, Val: 215, Test: 108
[Orange___Haunglongbing_(Citrus_