In [11]:
import pandas as pd
import os
import sys
import hashlib
from pathlib import Path
from tqdm import tqdm
import numpy as np
import albumentations as A
import random
from sklearn.model_selection import train_test_split
import shutil # <--- NEW: Required for moving files
from PIL import Image

# --- Configuration: adjust these values to match your machine ---
# Scalar settings.
IMAGE_SIZE = 512     # Default edge length (pixels) used by normalize_and_save.
RANDOM_SEED = 42     # random_state passed to the train/val/test splits.

# Filesystem layout, all derived from the project root.
ROOT = Path("/Users/sanadmadani/plant-disease-detection/plant-disease-detection")
DATA_DIR = ROOT.joinpath('data_raw')          # Raw datasets ('archive', 'Cauliflower', ...).
OUT_DIR = ROOT.joinpath('jordan_dataset2')    # Receives the metadata CSVs and 'images'.
IMG_OUT = OUT_DIR.joinpath('images')          # Normalised images are written below here.
METADATA_CSV = OUT_DIR.joinpath('metadata_all.csv')
# --- End Configuration ---


def normalize_and_save(src_path, dest_path, size=IMAGE_SIZE):
    """Convert an image to RGB, resize it to a square, and save it as JPEG.

    Args:
        src_path: Path of the image to read.
        dest_path: ``pathlib.Path`` to write; its parent folder is created
            if it does not exist yet.
        size: Edge length in pixels of the (size x size) output.

    Returns:
        True if the JPEG was written, False if reading, converting or saving
        failed. Failures are reported on stdout and never raised, so one bad
        file does not abort a long batch.
    """
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        # The context manager closes the source file handle right away;
        # convert() returns an independent, fully loaded image, so it stays
        # usable after the handle is gone.
        with Image.open(src_path) as raw:
            img = raw.convert('RGB')
        img = img.resize((size, size), Image.Resampling.LANCZOS)
        img.save(dest_path, format='JPEG', quality=90)
    except Exception as e:
        print(f'Failed processing {src_path}: {e}')
        return False
    return True

def file_hash_name(path):
    """Build a content-derived filename: the SHA-1 hex digest of the file plus '.jpg'.

    Returns None (after printing the error) when the file cannot be read.
    """
    digest = hashlib.sha1()
    try:
        with open(path, 'rb') as stream:
            # Read in fixed-size blocks so large files are never held in memory whole.
            while True:
                block = stream.read(4096)
                if not block:
                    break
                digest.update(block)
    except (IOError, OSError) as e:
        print(f"Error hashing file {path}: {e}")
        return None
    return digest.hexdigest() + '.jpg'

# --- Class Mapping ---
# Maps a source folder name (exactly as it appears on disk) to the unified
# (crop, disease) pair. The gathering loop looks each key up under the
# directories in COMMON_ROOTS and uses the first folder that exists; images are
# then written to IMG_OUT / crop / disease, and the split step joins the pair
# into the label 'crop___disease'. Keys must match the folder names exactly.
CLASS_MAPPING = {
    # --- From 'archive/train' (Wheat) ---
    'Healthy Wheat': ('Wheat', 'healthy'),
    'Wheat aphid': ('Wheat', 'Aphid'),
    'Wheat black rust': ('Wheat', 'Black_rust'),
    'Wheat Brown leaf Rust': ('Wheat', 'Brown_leaf_Rust'),
    'Wheat leaf blight': ('Wheat', 'Leaf_blight'),
    'Wheat mite': ('Wheat', 'Mite'),
    'Wheat powdery mildew': ('Wheat', 'Powdery_mildew'),
    'Wheat scab': ('Wheat', 'Scab'),
    'Wheat Stem fly': ('Wheat', 'Stem_fly'),
    'Wheat___Yellow_Rust': ('Wheat', 'Yellow_Rust'),

    # --- From 'Cauliflower/train' (Cauliflower & Eggplant) ---
    'Cauliflower_Bacterial_spot_rot': ('Cauliflower', 'Bacterial_spot_rot'),
    'Cauliflower_Black_Rot': ('Cauliflower', 'Black_Rot'),
    'Cauliflower_Downy_Mildew': ('Cauliflower', 'Downy_Mildew'),
    'Cauliflower_Healthy': ('Cauliflower', 'healthy'),
    'EggPlant_Healthy_Leaf': ('Eggplant', 'healthy'),
    'EggPlant_Insect_Pest_Disease': ('Eggplant', 'Insect_Pest_Disease'),
    'EggPlant_Leaf_Spot_Disease': ('Eggplant', 'Leaf_Spot_Disease'),
    'EggPlant_Mosaic_Virus_Disease': ('Eggplant', 'Mosaic_Virus_Disease'),
    'EggPlant_Small_Leaf_Disease': ('Eggplant', 'Small_Leaf_Disease'),
    'EggPlant_White_Mold_Disease': ('Eggplant', 'White_Mold_Disease'),
    'EggPlant_Wilt_Disease': ('Eggplant', 'Wilt_Disease'),

    # --- From 'mult_classes/train' (PlantVillage subset) ---
    'Apple___Apple_scab': ('Apple', 'Apple_scab'),
    'Apple___Black_rot': ('Apple', 'Black_rot'),
    'Apple___Cedar_apple_rust': ('Apple', 'Cedar_apple_rust'),
    'Apple___healthy': ('Apple', 'healthy'),
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': ('Maize', 'Cercospora_leaf_spot_Gray_leaf_spot'),
    'Corn_(maize)___Common_rust_': ('Maize', 'Common_rust'),
    'Corn_(maize)___healthy': ('Maize', 'healthy'),
    'Corn_(maize)___Northern_Leaf_Blight': ('Maize', 'Northern_Leaf_Blight'),
    'Grape___Black_rot': ('Grape', 'Black_rot'),
    'Grape___Esca_(Black_Measles)': ('Grape', 'Esca_Black_Measles'),
    'Grape___healthy': ('Grape', 'healthy'),
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': ('Grape', 'Leaf_blight_Isariopsis_Leaf_Spot'),
    'Orange___Haunglongbing_(Citrus_greening)': ('Orange', 'Citrus_greening'),
    'Peach___Bacterial_spot': ('Peach', 'Bacterial_spot'),
    'Peach___healthy': ('Peach', 'healthy'),
    'Potato___Early_blight': ('Potato', 'Early_blight'),
    'Potato___healthy': ('Potato', 'healthy'),
    'Potato___Late_blight': ('Potato', 'Late_blight'),
    'Tomato___Bacterial_spot': ('Tomato', 'Bacterial_spot'),
    'Tomato___Early_blight': ('Tomato', 'Early_blight'),
    'Tomato___healthy': ('Tomato', 'healthy'),
    'Tomato___Late_blight': ('Tomato', 'Late_blight'),
    'Tomato___Leaf_Mold': ('Tomato', 'Leaf_Mold'),
    'Tomato___Septoria_leaf_spot': ('Tomato', 'Septoria_leaf_spot'),
    'Tomato___Spider_mites Two-spotted_spider_mite': ('Tomato', 'Spider_mites'),
    'Tomato___Target_Spot': ('Tomato', 'Target_Spot'),
    'Tomato___Tomato_mosaic_virus': ('Tomato', 'Mosaic_virus'),
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus': ('Tomato', 'Yellow_Leaf_Curl_Virus'),

    # --- From 'olive/train' (Olive) ---
    'aculus_olearius': ('Olive', 'Aculus_olearius_mite'),
    # NOTE(review): 'Healthy' is a generic folder name and the lookup takes the
    # first root that contains it -- confirm no non-olive dataset root has a
    # folder called 'Healthy', or its images would be filed under Olive.
    'Healthy': ('Olive', 'healthy'),
    'olive_peacock_spot': ('Olive', 'Peacock_spot'),
}

# --- Image Gathering and Renaming Logic (intended to run once on a fresh OUT_DIR) ---
# For every class in CLASS_MAPPING: locate its source folder, normalise each
# image into IMG_OUT / crop / disease under a content-hash filename, and record
# one metadata row per image that was actually written.
print('--- Standardizing, Renaming, and Gathering Images ---')
IMG_OUT.mkdir(parents=True, exist_ok=True)
print(f"Using Data Directory: {DATA_DIR}")

METADATA = []

# Candidate parent folders, searched in this order. Only the FIRST root that
# contains a class folder is used for that class, so 'test'/'valid' act as
# fallbacks for classes that have no 'train' folder.
COMMON_ROOTS = [
    DATA_DIR / 'archive' / 'train',
    DATA_DIR / 'Cauliflower' / 'train',
    DATA_DIR / 'mult_classes' / 'train',
    DATA_DIR / 'olive' / 'train',

    # Also check test/valid folders
    DATA_DIR / 'archive' / 'test',
    DATA_DIR / 'Cauliflower' / 'test',
    DATA_DIR / 'mult_classes' / 'test',
    DATA_DIR / 'olive' / 'test',

    DATA_DIR / 'archive' / 'valid',
    DATA_DIR / 'Cauliflower' / 'valid',
    DATA_DIR / 'mult_classes' / 'valid',
    DATA_DIR / 'olive' / 'valid',
]

# Sanity check: a missing primary root almost certainly means ROOT is wrong.
if not (DATA_DIR / 'archive' / 'train').is_dir():
    print(f"FATAL ERROR: Could not find {DATA_DIR / 'archive' / 'train'}")
    print("Please verify the hardcoded 'ROOT' path in the script.")
    sys.exit(1)


# Class folders that were not found under any root; reported after the loop so
# the message does not break up the progress bar.
missing_class_folders = []

for src_folder_name, (crop, disease) in tqdm(CLASS_MAPPING.items(), desc="Processing Classes"):
    src = None

    for root in COMMON_ROOTS:
        potential_src = root / src_folder_name
        if potential_src.is_dir():
            src = potential_src
            break

    if src is None:
        missing_class_folders.append(src_folder_name)
        continue

    # All images of a class land in one unified folder before the split step.
    dst_folder = IMG_OUT / crop / disease

    for root_dir, _dirs, files in os.walk(src):
        for f in files:
            if not f.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            src_file = Path(root_dir) / f

            try:
                new_name = file_hash_name(src_file)
                if new_name is None:
                    continue  # Unreadable file; already reported by file_hash_name.

                dst_file = dst_folder / new_name

                # Same content hash already written for this class: duplicate image.
                if dst_file.exists():
                    continue

                normalize_and_save(src_file, dst_file)

                # normalize_and_save reports failures itself instead of raising.
                # Record metadata only when the output file really exists, so the
                # CSV never lists images that are missing on disk.
                if not dst_file.exists():
                    continue

                METADATA.append({
                    'image_path': str(dst_file.relative_to(OUT_DIR)),
                    'crop': crop,
                    'disease': disease,
                    'source_folder': src_folder_name,
                })
            except Exception as e:
                print(f"Error processing {src_file}: {e}")

if missing_class_folders:
    print(f"\nWarning: no source folder found for {len(missing_class_folders)} "
          f"mapped classes: {missing_class_folders}")

# Save collected metadata
meta = pd.DataFrame(METADATA)
if len(meta) > 0:
    meta.to_csv(METADATA_CSV, index=False)
else:
    print("\n--- ERROR ---")
    print("Collected 0 images. This should not happen with the new mapping.")
    sys.exit(1)

print(f'\nCollected {len(meta)} images.')
print(f'The unified dataset images are located in: {IMG_OUT}')


# --- Train / Val / Test split ---
# Stratified 70/15/15 split on the 'crop___disease' label. Each image is moved
# from images/<crop>/<disease>/ to images/<split>/<crop>/<disease>/ and every
# metadata CSV is written with the path the file has AFTER the move.
print('\n--- Splitting Data and Moving Files ---')
if len(meta) > 1 and len(meta['crop'].unique()) > 1:
    # Unfiltered snapshot, used at the end to refresh metadata_all.csv.
    all_meta = meta.copy()

    meta['label'] = meta['crop'] + '___' + meta['disease']

    # Both splits are stratified, and a stratified split needs at least two rows
    # per class. The hold-out is 30% of each class, so a class needs
    # ceil(2 / 0.3) = 7 rows to be sure of two rows in the hold-out; smaller
    # classes would make the second train_test_split raise ValueError.
    holdout_fraction = 0.3
    min_per_class = int(np.ceil(2 / holdout_fraction))

    counts = meta['label'].value_counts()
    small_classes = counts[counts < min_per_class].index
    if len(small_classes) > 0:
        print(f'Dropping {len(small_classes)} classes with fewer than '
              f'{min_per_class} images: {sorted(small_classes)}')
        meta = meta[~meta['label'].isin(small_classes)]

    if len(meta) > 1:
        # Perform stratified split: 70% train, then the remaining 30% halved.
        train, temp = train_test_split(meta, stratify=meta['label'], test_size=holdout_fraction, random_state=RANDOM_SEED)
        val, test = train_test_split(temp, stratify=temp['label'], test_size=0.5, random_state=RANDOM_SEED)

        # Independent copies, so the path updates below do not write into
        # slices of `meta`.
        train, val, test = train.copy(), val.copy(), test.copy()

        # Old relative path -> new relative path, for every file found at its
        # split destination.
        path_updates = {}

        for df, name in [(train, 'train'), (val, 'val'), (test, 'test')]:
            print(f'\nProcessing {name} split...')
            print(f'{name}: {len(df)} images.')

            # 1. Physical file movement
            final_paths = []
            for rel_path in tqdm(df['image_path'], total=len(df), desc=f'Moving {name} files'):
                # rel_path is relative to OUT_DIR, e.g. 'images/Apple/healthy/hash.jpg'
                src_file = OUT_DIR / rel_path

                # Destination: images/<split>/<crop>/<disease>/hash.jpg
                dst_file = IMG_OUT / name / Path(rel_path).relative_to('images')
                dst_file.parent.mkdir(parents=True, exist_ok=True)

                try:
                    # Only move if the source exists and the destination is still free.
                    if src_file.exists() and not dst_file.exists():
                        shutil.move(src_file, dst_file)
                except Exception as e:
                    print(f"\nWarning: Failed to move {src_file} to {dst_file}. Error: {e}")

                # Record where the file really is now; keep the old path if the
                # move did not happen.
                if dst_file.exists():
                    new_rel = str(dst_file.relative_to(OUT_DIR))
                    path_updates[rel_path] = new_rel
                    final_paths.append(new_rel)
                else:
                    final_paths.append(rel_path)

            # 2. Save the split metadata with post-move paths ('label' is a
            #    helper column and is not exported).
            df['image_path'] = final_paths
            df.drop(columns=['label']).to_csv(OUT_DIR / f'metadata_{name}.csv', index=False)

        # 3. Refresh metadata_all.csv so it no longer points at the emptied
        #    unified folders. Rows of dropped classes keep their original path.
        all_meta['image_path'] = all_meta['image_path'].map(lambda p: path_updates.get(p, p))
        all_meta.drop(columns=['label'], errors='ignore').to_csv(METADATA_CSV, index=False)
    else:
        print("Warning: Not enough data left after filtering to perform stratified split.")
else:
    print("Warning: Not enough data or classes to perform stratified Train/Val/Test split.")

# --- Simple augmentation (operates on the 'train' split folders only) ---
# Every training class with fewer than MIN_SAMPLES images is topped up with
# randomly augmented copies of its own images, saved as '<stem>_aug_<i>.jpg'
# next to the originals.
print('\n--- Augmenting Small Classes ---')
MIN_SAMPLES = 500
augmenter = A.Compose([
    A.RandomRotate90(),
    A.HorizontalFlip(),
    A.Transpose(),
    A.RandomBrightnessContrast(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.7),
])

augmented_count = 0
# Counts come from the 'train' split, since only training data is augmented.
# `train` exists only if the split section above actually ran.
if 'train' in locals() and 'label' in train.columns:
    counts = train['label'].value_counts()

    # Save the training set class counts (before augmentation).
    print("Saving training set class counts...")
    counts.to_frame('n_images').to_csv(OUT_DIR / 'train_class_counts.csv')

    for label, n in counts.items():
        if n >= MIN_SAMPLES:
            continue

        need = MIN_SAMPLES - n
        print(f'Augmenting {label}, need {need} copies.')
        # maxsplit=1: only the first '___' separates crop from disease.
        crop, disease = label.split('___', 1)

        folder = IMG_OUT / 'train' / crop / disease

        # Augment originals only, never previously generated copies.
        images = [p for p in folder.glob('*.jpg') if '_aug_' not in p.name]
        if not images:
            print(f"Warning: No source images found to augment for {label}")
            continue

        i = 0
        while need > 0:
            src = random.choice(images)
            try:
                # Close the file handle as soon as the pixels are in memory.
                with Image.open(src) as raw:
                    arr = np.array(raw.convert('RGB'))
                aug = augmenter(image=arr)['image']

                new_name = src.stem + f'_aug_{i}.jpg'
                outp = folder / new_name
                Image.fromarray(aug).save(outp, format='JPEG', quality=90)
            except Exception as e:
                print(f'Augment failed for {src}: {e}')
                # `need` only shrinks on success, so retire a failing source;
                # otherwise a class whose images all fail would loop forever.
                images.remove(src)
                if not images:
                    print(f'Warning: ran out of usable source images for {label}; '
                          f'still short by {need}.')
                    break
                continue

            augmented_count += 1
            need -= 1
            i += 1

print('Augmented images created:', augmented_count)

print('\n✨ Dataset Consolidation and Splitting Complete! ✨')
print("The data is ready in the standard PyTorch format:")
print(f"{IMG_OUT}/train/[Crop]/[Disease]/*.jpg")
print(f"{IMG_OUT}/val/[Crop]/[Disease]/*.jpg")
print(f"{IMG_OUT}/test/[Crop]/[Disease]/*.jpg")

--- Standardizing, Renaming, and Gathering Images ---
Using Data Directory: /Users/sanadmadani/plant-disease-detection/plant-disease-detection/data_raw


Processing Classes: 100%|██████████| 52/52 [05:38<00:00,  6.51s/it]



Collected 57560 images.
The unified dataset images are located in: /Users/sanadmadani/plant-disease-detection/plant-disease-detection/jordan_dataset2/images

--- Splitting Data and Moving Files ---

Processing train split...
train: 40292 images.


Moving train files: 100%|██████████| 40292/40292 [00:14<00:00, 2800.76it/s]



Processing val split...
val: 8634 images.


Moving val files: 100%|██████████| 8634/8634 [00:02<00:00, 4206.02it/s]



Processing test split...
test: 8634 images.


Moving test files: 100%|██████████| 8634/8634 [00:02<00:00, 3023.44it/s]
  original_init(self, **validated_kwargs)



--- Augmenting Small Classes ---
Saving training set class counts...
Augmenting Olive___Aculus_olearius_mite, need 23 copies.
Augmenting Wheat___healthy, need 298 copies.
Augmenting Wheat___Stem_fly, need 380 copies.
Augmenting Wheat___Powdery_mildew, need 382 copies.
Augmenting Wheat___Black_rust, need 388 copies.
Augmenting Wheat___Aphid, need 391 copies.
Augmenting Eggplant___Leaf_Spot_Disease, need 392 copies.
Augmenting Wheat___Mite, need 392 copies.
Augmenting Eggplant___Insect_Pest_Disease, need 396 copies.
Augmenting Eggplant___healthy, need 396 copies.
Augmenting Eggplant___White_Mold_Disease, need 397 copies.
Augmenting Eggplant___Mosaic_Virus_Disease, need 398 copies.
Augmenting Eggplant___Small_Leaf_Disease, need 398 copies.
Augmenting Cauliflower___healthy, need 400 copies.
Augmenting Wheat___Leaf_blight, need 401 copies.
Augmenting Eggplant___Wilt_Disease, need 403 copies.
Augmenting Cauliflower___Bacterial_spot_rot, need 414 copies.
Augmenting Cauliflower___Downy_Mildew