# Generate Data Splits for PlantVillage Dataset

This notebook prepares JPG images from the PlantVillage dataset for TensorFlow training: it filters the image metadata, creates stratified train/validation/test splits, and copies the images into one directory per split and class.


In [56]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
import shutil
from tqdm import tqdm
import os


## 1. Load Metadata and Configure Parameters


In [57]:
# Configuration
PLANT_TYPES = ["Strawberry"]
BASE_DATA_PATH = Path("..")
SPLITS_PATH = BASE_DATA_PATH / "data" / "splits"
IMAGE_SIZE = (224, 224)  # Standard size for many pre-trained models
TRAIN_SPLIT = 0.7
VAL_SPLIT = 0.15
TEST_SPLIT = 0.15
RANDOM_STATE = 42

# Image type to keep; it must be one of the values found in df['image_type']
# (the available values are printed below)
IMAGE_TYPE_TO_USE = 'color'

# The split step only passes VAL_SPLIT and TEST_SPLIT to train_test_split, so
# TRAIN_SPLIT is purely informational there. Fail fast if the three ratios
# disagree instead of silently producing a different train share.
split_total = TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT
if abs(split_total - 1.0) > 1e-9:
    raise ValueError(
        f"TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT must equal 1.0, got {split_total}"
    )

# Load metadata (same location as before, now derived from BASE_DATA_PATH)
metadata_path = BASE_DATA_PATH / "data" / "plantvillage_images_metadata.parquet"
df = pd.read_parquet(metadata_path)

print(f"Plants to use: {PLANT_TYPES}")
print(f"Total images in metadata: {len(df):,}")
print(f"Image types available: {df['image_type'].unique()}")
print(f"\nUsing image type: {IMAGE_TYPE_TO_USE}")
print("\nSplit ratios:")
print(f"  Train: {TRAIN_SPLIT*100}%")
print(f"  Validation: {VAL_SPLIT*100}%")
print(f"  Test: {TEST_SPLIT*100}%")

# An unknown image type would leave an empty frame after filtering and only
# surface later as an obscure error in the split step.
if IMAGE_TYPE_TO_USE not in set(df['image_type'].unique()):
    raise ValueError(
        f"IMAGE_TYPE_TO_USE={IMAGE_TYPE_TO_USE!r} is not present in the metadata"
    )


Plants to use: ['Strawberry']
Total images in metadata: 108,610
Image types available: ['grayscale' 'color']

Using image type: color

Split ratios:
  Train: 70.0%
  Validation: 15.0%
  Test: 15.0%


In [58]:
# Display the configured base path; image paths from the metadata are joined onto it
BASE_DATA_PATH

PosixPath('..')

In [59]:
# Restrict the metadata to the plant types listed in PLANT_TYPES
df = df.loc[df['plant_type'].isin(PLANT_TYPES)]

In [60]:
# Conditions remaining after the plant-type filter. The row count is placed
# last so that it is the cell's displayed value instead of being discarded.
print(df['condition'].unique())
len(df)

['healthy' 'Leaf_scorch']


In [61]:
# Keep only the rows of the selected image type and expose the 'condition'
# column under the name 'label'. No combined plant/condition label is built;
# the label is the condition alone. rename() returns a new frame, so `df`
# itself is left unchanged and re-running this cell gives the same result.
df_filtered = (
    df.loc[df['image_type'] == IMAGE_TYPE_TO_USE]
    .rename(columns={'condition': 'label'})
)

In [62]:
# Distinct labels after filtering. The preview is placed last so that it is
# rendered as the cell's output instead of being discarded.
print(df_filtered['label'].unique())
df_filtered.head()

['healthy' 'Leaf_scorch']


In [63]:


# Join every image path from the metadata onto BASE_DATA_PATH
# (the result is only absolute if BASE_DATA_PATH itself is absolute)
df_filtered['full_image_path'] = [
    BASE_DATA_PATH / relative_path for relative_path in df_filtered['image_path']
]

# Summary of what is left after filtering
label_column = df_filtered['label']
print(f"Images after filtering for {IMAGE_TYPE_TO_USE}: {len(df_filtered):,}")
print(f"\nNumber of classes: {label_column.nunique()}")
print("\nClass distribution:")
print(label_column.value_counts())
print(df_filtered.head())


Images after filtering for color: 1,565

Number of classes: 2

Class distribution:
label
Leaf_scorch    1109
healthy         456
Name: count, dtype: int64
                                              image_path image_type  \
54305  data/plantvillage dataset/color/Strawberry___h...      color   
54306  data/plantvillage dataset/color/Strawberry___h...      color   
54307  data/plantvillage dataset/color/Strawberry___h...      color   
54308  data/plantvillage dataset/color/Strawberry___h...      color   
54309  data/plantvillage dataset/color/Strawberry___h...      color   

       plant_type    label  file_size_bytes  width  height  file_size_kb  \
54305  Strawberry  healthy            21078    256     256     20.583984   
54306  Strawberry  healthy            21575    256     256     21.069336   
54307  Strawberry  healthy            18671    256     256     18.233398   
54308  Strawberry  healthy            21061    256     256     20.567383   
54309  Strawberry  healthy            

## 2. Create Stratified Train/Validation/Test Splits


In [64]:
# Two-stage stratified split: carve off a hold-out set (validation + test)
# first, then divide that hold-out set into validation and test.
holdout_fraction = VAL_SPLIT + TEST_SPLIT
train_df, temp_df = train_test_split(
    df_filtered,
    test_size=holdout_fraction,
    stratify=df_filtered['label'],
    random_state=RANDOM_STATE,
)

# TEST_SPLIT is a share of the whole dataset; rescale it to a share of the hold-out set
val_df, test_df = train_test_split(
    temp_df,
    test_size=TEST_SPLIT / holdout_fraction,
    stratify=temp_df['label'],
    random_state=RANDOM_STATE,
)

# Report the size of each split, absolute and relative to the filtered dataset
total_images = len(df_filtered)
for split_title, split_df in (
    ("Train", train_df),
    ("Validation", val_df),
    ("Test", test_df),
):
    print(f"{split_title} set size: {len(split_df):,} ({len(split_df)/total_images*100:.1f}%)")

# Every split should still contain all classes
print(f"\nTrain classes: {train_df['label'].nunique()}")
print(f"Val classes: {val_df['label'].nunique()}")
print(f"Test classes: {test_df['label'].nunique()}")


Train set size: 1,095 (70.0%)
Validation set size: 235 (15.0%)
Test set size: 235 (15.0%)

Train classes: 2
Val classes: 2
Test classes: 2


## 3. Create Directory Structure and Save Images


In [65]:
# Create base splits directory
SPLITS_PATH.mkdir(exist_ok=True, parents=True)

class_labels = df_filtered['label'].unique()

# Create <SPLITS_PATH>/<split>/<label>/ for each split and class
for split_name in ['train', 'validation', 'test']:
    split_path = SPLITS_PATH / split_name

    # Files left behind by an earlier run (other seed, ratios or plant types)
    # are not removed by the copy step and would silently become part of the
    # new split, so report them before anything is written. Nothing is deleted.
    if split_path.exists():
        stale_count = sum(1 for entry in split_path.rglob('*') if entry.is_file())
        if stale_count:
            print(
                f"Warning: {split_path} already contains {stale_count:,} file(s); "
                f"files not overwritten by this run will remain in the split"
            )

    split_path.mkdir(exist_ok=True, parents=True)

    # Create class subdirectories
    for label in class_labels:
        (split_path / label).mkdir(exist_ok=True, parents=True)

print("Directory structure created successfully!")
print(f"\nBase path: {SPLITS_PATH}")
print("Subdirectories: train, validation, test")
print(f"Classes per subdirectory: {df_filtered['label'].nunique()}")


Directory structure created successfully!

Base path: ../data/splits
Subdirectories: train, validation, test
Classes per subdirectory: 2


In [66]:
def copy_images_to_split(df, split_name):
    """Copy the images listed in ``df`` into the class folders of one split.

    Each file is copied (with ``shutil.copy2``) to
    ``SPLITS_PATH / split_name / <label> / <source file name>``, where
    ``SPLITS_PATH`` is the notebook-level output directory.

    Args:
        df: DataFrame with a 'full_image_path' column (path of each source
            file) and a 'label' column (name of the class sub-directory).
        split_name: Name of the split directory below SPLITS_PATH,
            e.g. 'train', 'validation' or 'test'.

    Source files that do not exist are reported with a warning and skipped;
    the final message states how many files were actually copied.
    """
    print(f"\nCopying images to {split_name}...")

    split_root = SPLITS_PATH / split_name
    prepared_dirs = set()  # class directories already ensured to exist
    copied = 0
    missing = 0

    # Iterate over the two needed columns directly instead of building a
    # Series per row with iterrows()
    rows = zip(df['full_image_path'], df['label'])
    for src_path, label in tqdm(rows, total=len(df)):
        src_path = Path(src_path)

        # Create destination path
        dst_dir = split_root / label
        if dst_dir not in prepared_dirs:
            # Do not rely on the directory tree having been created beforehand
            dst_dir.mkdir(parents=True, exist_ok=True)
            prepared_dirs.add(dst_dir)
        dst_path = dst_dir / src_path.name

        # Copy file if source exists
        if src_path.exists():
            shutil.copy2(src_path, dst_path)
            copied += 1
        else:
            missing += 1
            print(f"Warning: Source file not found: {src_path}")

    print(f"Completed copying {copied:,} images to {split_name}")
    if missing:
        print(f"Warning: {missing:,} of {len(df):,} source files were missing and skipped")

# Fill each split directory from its DataFrame, in train -> validation -> test order
split_frames = {'train': train_df, 'validation': val_df, 'test': test_df}
for split_name, split_df in split_frames.items():
    copy_images_to_split(split_df, split_name)

print("\n✓ All images copied successfully!")



Copying images to train...


100%|██████████| 1095/1095 [00:00<00:00, 2160.63it/s]


Completed copying 1,095 images to train

Copying images to validation...


100%|██████████| 235/235 [00:00<00:00, 2189.31it/s]


Completed copying 235 images to validation

Copying images to test...


100%|██████████| 235/235 [00:00<00:00, 784.56it/s]

Completed copying 235 images to test

✓ All images copied successfully!



