In [1]:
import os
import shutil
import random
from pathlib import Path

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

import torch

In [5]:
# Paths are relative to the notebook's working directory.
original_dataset_dir = Path("dataset/PlantVillage")  # source tree; expected to hold one sub-folder per class
split_dataset_dir = Path("dataset_split")  # Output folder; wiped and rebuilt on every run
SEED = 42  # seed for the train/val/test shuffle

# Refuse to wipe a directory that is, or contains, the source data: the
# rmtree below would otherwise delete the original images.
_src_resolved = original_dataset_dir.resolve()
_dst_resolved = split_dataset_dir.resolve()
if _dst_resolved == _src_resolved or _dst_resolved in _src_resolved.parents:
    raise ValueError(
        f"split_dataset_dir ({split_dataset_dir}) must not contain "
        f"original_dataset_dir ({original_dataset_dir})"
    )

# Start from an empty output folder so re-running this cell is idempotent.
if split_dataset_dir.exists():
    shutil.rmtree(split_dataset_dir)
split_dataset_dir.mkdir(parents=True)

# Set seed for reproducibility
random.seed(SEED)

def split_dataset(src_dir, dst_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, seed=None):
    """Copy each class folder of ``src_dir`` into train / val / test splits under ``dst_dir``.

    For every visible sub-directory ``src_dir/<class>``, its visible regular
    files whose names contain a dot are shuffled and copied to
    ``dst_dir/<split>/<class>/``. The train and val sizes are
    ``int(n * train_ratio)`` and ``int(n * val_ratio)`` (rounded down). The
    test split receives every remaining file. ``test_ratio`` is therefore used
    only to validate that the three ratios sum to 1.

    Args:
        src_dir: Source root (``Path`` or ``str``) containing one folder per class.
        dst_dir: Output root (``Path`` or ``str``); split/class folders are created as needed.
        train_ratio: Fraction of each class copied to ``train``.
        val_ratio: Fraction of each class copied to ``val``.
        test_ratio: Nominal fraction for ``test``; must make the ratios sum to 1.
        seed: Optional seed for a private ``random.Random``. When ``None`` the
            global ``random`` module is used, so an earlier ``random.seed(...)``
            still controls the shuffle.

    Raises:
        AssertionError: if any ratio is negative or the ratios do not sum to 1.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    ratios = (train_ratio, val_ratio, test_ratio)
    assert all(r >= 0 for r in ratios) and abs(sum(ratios) - 1.0) < 1e-6, (
        f"ratios must be non-negative and sum to 1, got {ratios}"
    )
    rng = random if seed is None else random.Random(seed)

    # iterdir()/glob() order is filesystem-dependent, so sort before shuffling
    # to make the seeded split reproducible across machines. Hidden folders
    # (e.g. Jupyter's .ipynb_checkpoints) would otherwise become bogus classes.
    class_dirs = sorted(
        d for d in src_dir.iterdir() if d.is_dir() and not d.name.startswith(".")
    )

    for class_dir in class_dirs:
        # "*.*" also matches dotted sub-directories (which shutil.copy cannot
        # copy) and hidden files such as .DS_Store; keep visible regular files only.
        images = sorted(
            p for p in class_dir.glob("*.*")
            if p.is_file() and not p.name.startswith(".")
        )
        rng.shuffle(images)

        n = len(images)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)

        splits = {
            "train": images[:n_train],
            "val": images[n_train:n_train + n_val],
            # Remainder goes to test, so no file is lost to rounding.
            "test": images[n_train + n_val:],
        }

        for split, files in splits.items():
            split_class_dir = dst_dir / split / class_dir.name
            split_class_dir.mkdir(parents=True, exist_ok=True)
            for img_path in files:
                shutil.copy(img_path, split_class_dir / img_path.name)

    print("✅ Dataset split into train / val / test")

# Copy every class folder into dataset_split/{train,val,test}/<class>/.
split_dataset(src_dir=original_dataset_dir, dst_dir=split_dataset_dir)

✅ Dataset split into train / val / test
