# Disease Normalisation Strategy

This is a crucial data-engineering step. If we simply concatenate the datasets, each one keeps its own folder-based labels, so the model would treat `PlantVillage/Tomato_Blight` and `NewPlantDiseases/Tomato_Blight` as two different classes.

By **merging them into a Unified Label Space**, you push the model to learn that "blight is blight" regardless of whether an image came from a lab setting (PlantVillage) or a Kaggle dataset (NewPlantDiseases). This should improve the domain invariance of your pre-trained weights.

### **The Strategy: "Fuzzy Alignment"**

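We scan the class folders of all eight datasets (~195k images in total) and normalize each folder name. Normalization lower-cases the name, turns `_`, `-` and `.` into spaces, and drops stop words such as `leaf`, `plant` and `dataset`. Classes whose normalized names are identical are then merged. The first word of each normalized name is recorded as the **Crop**, and the remainder typically names the **Pathology**. Matching is exact after normalization, so spelling variants are not merged automatically and need a manual review of the output CSV.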

This notebook implements the alignment logic.

### **Step 1: The Alignment Script**

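This script scans the processed dataset directories and normalizes the class names. It writes the unified mapping to `foundation_class_alignment.csv` (one row per source class) and `foundation_label_map.json` (normalized name → unified ID).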

In [12]:
import os
import pandas as pd
import json
import re
from collections import defaultdict
from pathlib import Path

# Common parent directory of every processed dataset listed below.
_PROCESSED_BASE = 'data/processed/dataset'

# CONFIG: the 8 source datasets, keyed by short dataset id and mapped to their
# processed directory (relative to the repository root). Iteration order of
# this dict is the order in which datasets are scanned.
DATASET_ROOTS = {
    'new_plant_diseases': f'{_PROCESSED_BASE}/NewPlantDiseases_processed',
    'plantvillage': f'{_PROCESSED_BASE}/PlantVillage_processed',
    'cassava': f'{_PROCESSED_BASE}/Cassava_processed',
    'plantwild': f'{_PROCESSED_BASE}/PlantWild_processed',
    'wheat': f'{_PROCESSED_BASE}/Wheat_processed',
    'plantseg': f'{_PROCESSED_BASE}/PlantSeg_processed',
    'plantdoc': f'{_PROCESSED_BASE}/PlantDoc_processed',
    'tomatoleaf': f'{_PROCESSED_BASE}/TomatoLeaf_processed',
}

# All paths in this notebook are relative to the repository root. When the
# kernel was started inside 'notebooks/', step one level up so they resolve.
if Path(os.getcwd()).name == 'notebooks':
    os.chdir('..')
    print(f"Changed working directory to: {os.getcwd()}")

# Tokens dropped during name normalization so that cosmetic differences
# between datasets ('Tomato_leaf_...', '..._images') do not block a match.
STOP_WORDS = ['leaf', 'leaves', 'plant', 'dataset', 'processed', 'images', 'train', 'val', 'test']

def normalize_name(class_name):
    """
    Turn a raw class-folder name into a canonical, space-separated form.

    The name is lower-cased, every space / underscore / hyphen / dot acts as
    a token boundary, and tokens listed in STOP_WORDS are dropped.

    Examples:
        'Tomato_Early_blight' -> 'tomato early blight'
        'Apple___Black_rot'   -> 'apple black rot'
    """
    spaced = re.sub(r'[ _\-\.]', ' ', class_name.lower())
    stop = set(STOP_WORDS)
    # split() with no argument collapses any run of whitespace, so separator
    # runs such as '___' never produce empty tokens.
    kept = (token for token in spaced.split() if token not in stop)
    return " ".join(kept)

def get_crop_name(normalized_name):
    """
    Return the crop part of a normalized class name.

    Heuristic: the first word is usually the crop (e.g. 'tomato', 'apple').

    Args:
        normalized_name: A normalized class name; may be empty or None.

    Returns:
        The first whitespace-separated token, or "unknown" when the name has
        no tokens at all (None, empty, or whitespace-only).
    """
    tokens = normalized_name.split() if normalized_name else []
    # A whitespace-only string is truthy but splits to [], so guard on the
    # token list rather than the raw string to avoid an IndexError.
    return tokens[0] if tokens else "unknown"

def analyze_datasets():
    """
    Scan every dataset in DATASET_ROOTS and describe its class folders.

    Each immediate sub-directory of a dataset root is treated as one class,
    except the structural folders 'images', 'labels' and 'annotated'.
    Missing roots are reported and skipped; a '../'-prefixed variant of each
    path is tried first.

    Returns:
        pandas.DataFrame with one row per (dataset, class folder) and the
        columns 'dataset', 'original_class', 'normalized_class', 'crop',
        'path'. The frame is empty (but keeps these columns) if nothing was
        found.
    """
    # Folders that hold raw files or annotations rather than a labeled class
    # (e.g. TomatoLeaf keeps unlabeled images in a flat 'images' folder).
    structural_dirs = {'images', 'labels', 'annotated'}
    columns = ['dataset', 'original_class', 'normalized_class', 'crop', 'path']
    all_classes = []

    print("Scanning Datasets...")
    for ds_name, path in DATASET_ROOTS.items():
        if not os.path.exists(path):
            # Try with ../ prefix in case we are one directory too deep.
            if os.path.exists(f"../{path}"):
                path = f"../{path}"
            else:
                print(f"!! Warning: Path not found for {ds_name}: {path}")
                continue

        try:
            entries = os.listdir(path)
        except OSError as e:
            # Best-effort scan: an unreadable dataset is reported and skipped.
            print(f"Error scanning {ds_name}: {e}")
            continue

        # os.listdir() order is arbitrary; sorting makes the CSV reproducible.
        for c in sorted(entries):
            if c in structural_dirs or not os.path.isdir(os.path.join(path, c)):
                continue
            norm = normalize_name(c)
            all_classes.append({
                'dataset': ds_name,
                'original_class': c,
                'normalized_class': norm,
                'crop': get_crop_name(norm),
                'path': os.path.join(path, c),
            })

    return pd.DataFrame(all_classes, columns=columns)

def generate_unified_mapping(df):
    """
    Give every distinct 'normalized_class' one shared 'foundation_class_id'.

    IDs are dense integers 0..K-1 assigned in alphabetical order of the
    normalized name, so classes from different datasets that normalize to
    the same string end up with the same ID.

    Args:
        df: DataFrame with a 'normalized_class' column. It is modified in
            place: a 'foundation_label_id' column is added.

    Returns:
        (df, class_to_id): the same DataFrame, and a dict mapping each
        normalized name to its integer ID.
    """
    ordered_names = sorted(df['normalized_class'].unique())
    class_to_id = dict(zip(ordered_names, range(len(ordered_names))))
    df['foundation_label_id'] = df['normalized_class'].map(class_to_id)
    return df, class_to_id

def main():
    """
    Build the unified label space and write it to disk.

    Scans all datasets (analyze_datasets), assigns unified IDs
    (generate_unified_mapping), writes the per-class alignment CSV and the
    name -> ID JSON used by the dataloader, then prints a short report.
    """
    out_dir = 'data/processed/dataset'
    csv_path = f'{out_dir}/foundation_class_alignment.csv'
    json_path = f'{out_dir}/foundation_label_map.json'

    df = analyze_datasets()

    if df.empty:
        print("No classes found. Check paths.")
        return

    # Generate Unified Map
    df, label_map = generate_unified_mapping(df)

    # Ensure the output directory exists before writing.
    os.makedirs(out_dir, exist_ok=True)

    # Save to CSV for inspection
    df.to_csv(csv_path, index=False)
    print(f"Saved alignment CSV to {csv_path}")

    # Save Mapping JSON (for the dataloader)
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(label_map, f, indent=4)
    print(f"Saved label map to {json_path}")

    # PRINT SUMMARY
    print("\n--- Alignment Report ---")
    print(f"Total Source Classes Scanned: {len(df)}")
    print(f"Total Unified Classes: {len(label_map)}")
    print(f"Reduction: {len(df) - len(label_map)} duplicate classes merged.")

    print("\n--- Top Crops Found ---")
    print(df['crop'].value_counts().head(10))

    print("\n--- Example Merges ---")
    # Show IDs fed by more than one *distinct dataset*; two folders of the
    # same dataset that normalize identically are not a cross-dataset merge.
    id_to_label = {label_id: name for name, label_id in label_map.items()}
    sources_per_id = df.groupby('foundation_label_id')['dataset'].nunique()
    merged = sources_per_id[sources_per_id > 1].sort_values(ascending=False, kind='stable')
    for idx in merged.head(5).index:
        print(f"\nUnified Label: '{id_to_label[idx]}' (ID {idx})")
        subset = df[df['foundation_label_id'] == idx]
        for _, row in subset.iterrows():
            print(f"  - {row['dataset']}: {row['original_class']}")

if __name__ == '__main__':
    main()

Scanning Datasets...
Saved alignment CSV to data/processed/dataset/foundation_class_alignment.csv
Saved label map to data/processed/dataset/foundation_label_map.json

--- Alignment Report ---
Total Source Classes Scanned: 374
Total Unified Classes: 286
Reduction: 88 duplicate classes merged.

--- Top Crops Found ---
crop
tomato       44
wheat        33
soybean      21
banana       18
pepper       16
apple        16
corn         16
grape        15
potato       13
blueberry    13
Name: count, dtype: int64

--- Example Merges ---

Unified Label: 'potato early blight' (ID 183)
  - new_plant_diseases: potato_early_blight
  - plantvillage: potato_early_blight
  - plantwild: potato_early_blight
  - plantseg: potato_early_blight
  - plantdoc: potato_early_blight

Unified Label: 'potato late blight' (ID 185)
  - new_plant_diseases: potato_late_blight
  - plantvillage: potato_late_blight
  - plantwild: potato_late_blight
  - plantseg: potato_late_blight
  - plantdoc: potato_late_blight

Unified 

### **Step 2: Update the Pre-training Loader**

Now we modify `src/pretrain_foundation.py` to use this mapping. `ImageFolder` assigns per-dataset IDs (0, 1, 2, … in alphabetical folder order), so the same index means different diseases in different datasets. Instead, we remap every target to its **Unified ID**.

**Update snippet for `src/pretrain_foundation.py`:**

In [13]:
# Add this class to your pre-training script
from torchvision import datasets

class UnifiedImageFolder(datasets.ImageFolder):
    """
    ImageFolder whose targets are remapped to the unified foundation label IDs.

    ImageFolder gives each dataset its own local class indices (0..C-1 in
    sorted folder order), which collide across datasets. This subclass
    translates each local index into the global 'foundation_label_id' from
    the alignment table, so several datasets can be concatenated safely.

    Args:
        root: Dataset root containing one sub-folder per class.
        mapping_df: Alignment DataFrame with at least the columns 'dataset',
            'original_class' and 'foundation_label_id'.
        dataset_name: Value of the 'dataset' column that describes this root
            (e.g. 'plantvillage').
        transform: Optional image transform, forwarded to ImageFolder.
    """

    # Target emitted for class folders missing from the mapping. -100 is the
    # default ignore_index of torch.nn.CrossEntropyLoss, so such samples are
    # skipped by the loss; -1 would be treated as an (invalid) class index.
    IGNORE_INDEX = -100

    def __init__(self, root, mapping_df, dataset_name, transform=None):
        super().__init__(root, transform=transform)

        # ImageFolder sorts class folders alphabetically to assign indices.
        self.original_classes = self.classes  # list of folder names

        # Build the folder-name -> unified-ID lookup once (first row wins if a
        # folder is listed twice). str() guards against read_csv having parsed
        # purely numeric folder names as integers.
        ds_map = mapping_df[mapping_df['dataset'] == dataset_name]
        folder_to_unified = {}
        for folder, label_id in zip(ds_map['original_class'], ds_map['foundation_label_id']):
            folder_to_unified.setdefault(str(folder), int(label_id))

        # Local ImageFolder index -> unified global ID.
        self.idx_to_unified = {}
        for idx, folder_name in enumerate(self.original_classes):
            if folder_name in folder_to_unified:
                self.idx_to_unified[idx] = folder_to_unified[folder_name]
            else:
                # Fallback (shouldn't happen if the alignment script ran well)
                print(f"Warning: No mapping for {dataset_name}/{folder_name}")
                self.idx_to_unified[idx] = self.IGNORE_INDEX

    def __getitem__(self, index):
        """Return (sample, unified_label_id) for the sample at *index*."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        # SWAP TARGET: Original Local ID -> Unified Global ID
        return sample, self.idx_to_unified[target]

**Why this is better:**
When the model sees "Tomato Early Blight" images from *PlantDoc* (noisy field photos) and *PlantVillage* (clean lab photos) under the same label ID (e.g. `Label ID = 234` in the run above), the gradients push the model to **ignore the background noise** and focus on the **shared lesion features**.

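This is a key ingredient for strong domain generalization. Run the alignment script and inspect the CSV. Check that classes which should merge actually did: names that differ only in punctuation (e.g. parentheses) or wording are not merged automatically.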

In [15]:
import pandas as pd

# Print full DataFrames in the report (no row / column / width truncation).
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

alignment_csv = 'data/processed/dataset/foundation_class_alignment.csv'

try:
    df = pd.read_csv(alignment_csv)
except (FileNotFoundError, pd.errors.EmptyDataError) as e:
    # Only a missing/empty CSV is expected here (alignment step not run yet);
    # schema or logic errors in the analysis below surface with a traceback.
    print(f"Error analyzing CSV: {e}")
else:
    # 1. Total unique unified classes
    num_unique_classes = df['foundation_label_id'].nunique()
    print(f"Total Unique Unified Classes: {num_unique_classes}")

    # 2. IDs that receive contributions from multiple *different* datasets:
    # group by label id, then count distinct datasets per id.
    merges = df.groupby('foundation_label_id')['dataset'].nunique()
    multi_source_classes = merges[merges > 1]
    print(f"\nNumber of classes with contributions from multiple datasets: {len(multi_source_classes)}")

    # 3. Show example merges (top 50 by number of contributing datasets)
    print("\nTop 50 Merged Classes (by source diversity):")
    top_merged_ids = multi_source_classes.sort_values(ascending=False).head(50).index

    for idx in top_merged_ids:
        subset = df[df['foundation_label_id'] == idx]
        # All rows of one ID share the same normalized name; take the first.
        name = subset['normalized_class'].iloc[0]
        print(f"\nID {idx}: {name}")
        print(subset[['dataset', 'original_class']].to_string(index=False))

    # 4. Specific check for Tomato classes: only IDs fed by >1 dataset
    print("\n--- Tomato Class Alignment Check ---")
    tomato_subset = df[df['crop'].str.lower() == 'tomato']
    for idx, group in tomato_subset.groupby('foundation_label_id'):
        if group['dataset'].nunique() > 1:  # only show interesting merges
            name = group['normalized_class'].iloc[0]
            print(f"\nTomato ID {idx}: {name}")
            print(group[['dataset', 'original_class']].to_string(index=False))

Total Unique Unified Classes: 286

Number of classes with contributions from multiple datasets: 41

Top 50 Merged Classes (by source diversity):

ID 183: potato early blight
           dataset      original_class
new_plant_diseases potato_early_blight
      plantvillage potato_early_blight
         plantwild potato_early_blight
          plantseg potato_early_blight
          plantdoc potato_early_blight

ID 233: tomato bacterial spot
           dataset        original_class
new_plant_diseases tomato_bacterial_spot
      plantvillage tomato_bacterial_spot
         plantwild tomato_bacterial_spot
          plantseg tomato_bacterial_spot
          plantdoc tomato_bacterial_spot

ID 234: tomato early blight
           dataset      original_class
new_plant_diseases tomato_early_blight
      plantvillage tomato_early_blight
         plantwild tomato_early_blight
          plantseg tomato_early_blight
          plantdoc tomato_early_blight

ID 237: tomato mold
           dataset original_cla