# Fine-Tuning Pre-trained Models on Combined Datasets

This notebook implements fine-tuning of pre-trained EfficientNet-B0 and ViT-Base models on a combined dataset that includes:
1. **Main Dataset**: Plant_leaf_diseases_dataset_with_augmentation (all 39 classes)
2. **Plant_doc Dataset**: Additional training data for overlapping classes
3. **FieldPlant Dataset**: Real-world field images for domain adaptation

## Fine-Tuning Strategy:
- **Lower Learning Rate**: 1e-4 (10x lower than original 3e-4)
- **Fewer Epochs**: 5 epochs (vs 20 in original training)
- **Layer-wise Learning Rates**: Different rates for early vs later layers
- **Evaluation**: Separate validation on all three datasets


## Performance Optimizations for RTX 3070 Super

**Optimizations Applied:**
1. **Mixed Precision Training (FP16)** - ~1.5-2x speedup, uses less VRAM
2. **Increased Batch Size** - 32 → 64 (RTX 3070 Super can handle it with FP16)
3. **Path Resolution** - Fixed for experiment_1/ subdirectory
4. **Pretrained Models** - Ensures ImageNet pretrained weights are used
5. **num_workers=0** - Avoids multiprocessing issues on Windows

**Expected Speedup:** 2-3x faster training per epoch


In [12]:
# Imports and Setup
import os
import time
import copy
from pathlib import Path
from collections import Counter, defaultdict
import json

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast
from torch.amp import GradScaler

from PIL import Image
import torchvision.transforms as T
import numpy as np
import pandas as pd

import timm
from sklearn.metrics import f1_score, classification_report, confusion_matrix

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Check for mixed precision support
USE_AMP = torch.cuda.is_available() and hasattr(torch.cuda, 'amp')
print(f"Mixed Precision (AMP): {'Enabled' if USE_AMP else 'Disabled'}")

# Check PyTorch version for torch.compile (disabled for stability)
USE_COMPILE = False  # hasattr(torch, 'compile') and torch.__version__ >= "2.0.0"
print(f"torch.compile available (but disabled): {hasattr(torch, 'compile')}")

# Paths - adjusted for notebook location in experiment_1/ subdirectory
import os
current_dir = Path(os.getcwd())

# Check if we're in experiment_1 subdirectory
if current_dir.name == "experiment_1":
    BASE_DIR = current_dir.parent
else:
    # If running from root PLANT_LEAF_DISEASE_DETECTION, use current directory
    BASE_DIR = current_dir

METADATA_DIR = BASE_DIR / "metadata"
LABEL_MAPPING_PATH = METADATA_DIR / "label_mapping.json"
DATASET_INDEX_PATH = METADATA_DIR / "dataset_index.json"
DATA_DIR = BASE_DIR / "data"
PLANT_DOC_DIR = DATA_DIR / "Plant_doc"  # Base directory - function will append train/test
FIELDPLANT_DIR = DATA_DIR / "FieldPlant_reformatted"

MODELS_DIR = BASE_DIR / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

print(f"Current working directory: {current_dir}")
print(f"Base directory: {BASE_DIR}")
print(f"Metadata directory: {METADATA_DIR}")
print(f"Models directory: {MODELS_DIR}")


Using device: cuda
Mixed Precision (AMP): Enabled
torch.compile available (but disabled): True
Current working directory: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\experiment_1
Base directory: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION
Metadata directory: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\metadata
Models directory: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\models


## Step 1: Load Metadata and Create Class Mappings


In [13]:
# Load main dataset metadata
with open(LABEL_MAPPING_PATH, "r") as f:
    label_mapping = json.load(f)

with open(DATASET_INDEX_PATH, "r") as f:
    dataset_index = json.load(f)

# Create mappings
id_to_label = {c["id"]: c["canonical_label"] for c in label_mapping["classes"]}
label_to_id = {v: k for k, v in id_to_label.items()}
folder_to_label = {}
for c in label_mapping["classes"]:
    for folder in c.get("pv_folders", []):
        folder_to_label[folder] = c["canonical_label"]

num_classes = len(label_mapping["classes"])
print(f"Main dataset: {num_classes} classes")
print(f"Total samples in dataset_index: {len(dataset_index)}")

# Field-poor classes (for augmentations)
FIELD_POOR_THRESHOLD = 5
field_count_by_class = {
    c["id"]: c.get("field_count", 0)
    for c in label_mapping["classes"]
}
field_poor_classes = {
    cid for cid, cnt in field_count_by_class.items()
    if cnt <= FIELD_POOR_THRESHOLD
}
print(f"Field-poor classes: {len(field_poor_classes)}")


Main dataset: 39 classes
Total samples in dataset_index: 61486
Field-poor classes: 39


## Step 2: Create Mappings for Plant_doc and FieldPlant to Main Dataset Classes


In [14]:
# Mapping from Plant_doc folder names to canonical labels
# Based on dataset_classes_diagram.txt analysis
plant_doc_to_canonical = {
    "Apple_leaf": "apple_healthy",
    "Apple_rust_leaf": "apple_cedar_apple_rust",
    "Apple_Scab_Leaf": "apple_apple_scab",
    "Bell_pepper_leaf": "pepper,_bell_healthy",
    "Bell_pepper_leaf_spot": "pepper,_bell_bacterial_spot",
    "Blueberry_leaf": "blueberry_healthy",
    "Cherry_leaf": "cherry_healthy",
    "Corn_Gray_leaf_spot": "corn_cercospora_leaf_spot_gray_leaf_spot",
    "Corn_leaf_blight": "corn_northern_leaf_blight",
    "Corn_rust_leaf": "corn_common_rust",
    "grape_leaf": "grape_healthy",
    "grape_leaf_black_rot": "grape_black_rot",
    "Peach_leaf": "peach_healthy",
    "Potato_leaf_early_blight": "potato_early_blight",
    "Potato_leaf_late_blight": "potato_late_blight",
    "Raspberry_leaf": "raspberry_healthy",
    "Soyabean_leaf": "soybean_healthy",
    "Squash_Powdery_mildew_leaf": "squash_powdery_mildew",
    "Strawberry_leaf": "strawberry_healthy",
    "Tomato_Early_blight_leaf": "tomato_early_blight",
    "Tomato_leaf": "tomato_healthy",
    "Tomato_leaf_bacterial_spot": "tomato_bacterial_spot",
    "Tomato_leaf_late_blight": "tomato_late_blight",
    "Tomato_leaf_mosaic_virus": "tomato_tomato_mosaic_virus",
    "Tomato_leaf_yellow_virus": "tomato_tomato_yellow_leaf_curl_virus",
    "Tomato_mold_leaf": "tomato_leaf_mold",
    "Tomato_Septoria_leaf_spot": "tomato_septoria_leaf_spot",
}

# Mapping from FieldPlant folder names to canonical labels
fieldplant_to_canonical = {
    "Corn___Gray_leaf_spot": "corn_cercospora_leaf_spot_gray_leaf_spot",
    "Corn___rust_leaf": "corn_common_rust",
    "Corn___leaf_blight": "corn_northern_leaf_blight",
    "Corn___healthy": "corn_healthy",
    "Tomato___healthy": "tomato_healthy",
    "Tomato___leaf_mosaic_virus": "tomato_tomato_mosaic_virus",
    "Tomato___leaf_yellow_virus": "tomato_tomato_yellow_leaf_curl_virus",
    # Add more mappings as needed based on FieldPlant structure
}

print(f"Plant_doc mappings: {len(plant_doc_to_canonical)} classes")
print(f"FieldPlant mappings: {len(fieldplant_to_canonical)} classes")


Plant_doc mappings: 27 classes
FieldPlant mappings: 7 classes


## Step 3: Define Data Loading Functions


In [15]:
def load_plant_doc_data(data_dir, split="train"):
    """Load Plant_doc dataset entries and map to main dataset class IDs."""
    entries = []
    split_dir = data_dir / split
    
    if not split_dir.exists():
        print(f"Warning: {split_dir} does not exist!")
        return entries
    
    for folder in split_dir.iterdir():
        if not folder.is_dir():
            continue
        
        canonical = plant_doc_to_canonical.get(folder.name)
        if canonical is None:
            continue
        
        if canonical not in label_to_id:
            continue
        
        class_id = label_to_id[canonical]
        
        # Get all images in folder
        image_files = list(folder.glob("*.jpg")) + list(folder.glob("*.JPG"))
        for img_path in image_files:
            entries.append({
                "path": str(img_path),
                "class_id": class_id,
                "dataset": "plant_doc",
                "domain": "pv",
                "split": split
            })
    
    return entries

def load_fieldplant_data(data_dir):
    """Load FieldPlant dataset entries and map to main dataset class IDs."""
    entries = []
    
    if not data_dir.exists():
        return entries
    
    for folder in data_dir.iterdir():
        if not folder.is_dir():
            continue
        
        canonical = fieldplant_to_canonical.get(folder.name)
        if canonical is None:
            continue
        
        if canonical not in label_to_id:
            continue
        
        class_id = label_to_id[canonical]
        
        # Get all images in folder
        image_files = list(folder.glob("*.jpg")) + list(folder.glob("*.JPG"))
        for img_path in image_files:
            entries.append({
                "path": str(img_path),
                "class_id": class_id,
                "dataset": "fieldplant",
                "domain": "field",
                "split": "train"  # FieldPlant doesn't have split, use all as train
            })
    
    return entries

print("Data loading functions defined.")


Data loading functions defined.


## Step 4: Load and Combine All Training Data


In [16]:
# Load main dataset training data
train_entries_main = [e for e in dataset_index if e["split"] == "train"]
print(f"Main dataset training samples: {len(train_entries_main)}")

# Load Plant_doc training data
train_entries_plant_doc = load_plant_doc_data(PLANT_DOC_DIR, split="train")
print(f"Plant_doc training samples: {len(train_entries_plant_doc)}")

# Load FieldPlant training data
train_entries_fieldplant = load_fieldplant_data(FIELDPLANT_DIR)
print(f"FieldPlant training samples: {len(train_entries_fieldplant)}")

# Combine all training data
all_train_entries = train_entries_main + train_entries_plant_doc + train_entries_fieldplant
print(f"\nTotal combined training samples: {len(all_train_entries)}")

# Show distribution by dataset
dataset_counts = Counter(e.get("dataset", "main") for e in all_train_entries)
print("\nTraining samples by dataset:")
for dataset, count in sorted(dataset_counts.items()):
    print(f"  {dataset}: {count}")

# Show class distribution
print("\nClass distribution (top 10):")
class_counts = Counter(e["class_id"] for e in all_train_entries)
for class_id, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
    class_name = id_to_label[class_id]
    print(f"  Class {class_id} ({class_name}): {count} samples")


Main dataset training samples: 49179
Plant_doc training samples: 5336
FieldPlant training samples: 4640

Total combined training samples: 59155

Training samples by dataset:
  fieldplant: 4640
  main: 49179
  plant_doc: 5336

Class distribution (top 10):
  Class 38 (tomato_tomato_yellow_leaf_curl_virus): 4961 samples
  Class 11 (corn_northern_leaf_blight): 4468 samples
  Class 16 (orange_haunglongbing_citrus_greening): 4405 samples
  Class 25 (soybean_healthy): 4186 samples
  Class 31 (tomato_healthy): 1918 samples
  Class 29 (tomato_bacterial_spot): 1903 samples
  Class 17 (peach_bacterial_spot): 1837 samples
  Class 32 (tomato_late_blight): 1729 samples
  Class 26 (squash_powdery_mildew): 1716 samples
  Class 34 (tomato_septoria_leaf_spot): 1706 samples


## Step 5: Create Validation Sets from All Three Datasets


In [17]:
# Main dataset validation
val_entries_main = [e for e in dataset_index if e["split"] == "val"]
print(f"Main dataset validation samples: {len(val_entries_main)}")

# Plant_doc validation (use test set)
val_entries_plant_doc = load_plant_doc_data(PLANT_DOC_DIR, split="test")
print(f"Plant_doc validation samples: {len(val_entries_plant_doc)}")

# FieldPlant validation (use 20% of data)
val_entries_fieldplant_all = load_fieldplant_data(FIELDPLANT_DIR)
np.random.seed(42)
if len(val_entries_fieldplant_all) > 0:
    indices = np.arange(len(val_entries_fieldplant_all))
    np.random.shuffle(indices)
    val_size = int(len(val_entries_fieldplant_all) * 0.2)
    val_entries_fieldplant = [val_entries_fieldplant_all[i] for i in indices[:val_size]]
    # Remove validation samples from training
    train_entries_fieldplant = [val_entries_fieldplant_all[i] for i in indices[val_size:]]
    # Update training entries
    all_train_entries = train_entries_main + train_entries_plant_doc + train_entries_fieldplant
    print(f"FieldPlant validation samples: {len(val_entries_fieldplant)}")
    print(f"FieldPlant training samples (after split): {len(train_entries_fieldplant)}")
else:
    val_entries_fieldplant = []
    print("FieldPlant validation samples: 0")

# Store separately for per-dataset evaluation
validation_sets = {
    "main": val_entries_main,
    "plant_doc": val_entries_plant_doc,
    "fieldplant": val_entries_fieldplant
}


Main dataset validation samples: 6148
Plant_doc validation samples: 504
FieldPlant validation samples: 928
FieldPlant training samples (after split): 3712


## Step 6: Define Dataset Class and Transforms


In [18]:
# Transforms (same as 3_train_model.ipynb)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMG_SIZE = 224

transform_pv_basic = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform_pv_field_style = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=20),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.1),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    T.ToTensor(),
    T.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform_field = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.3, contrast=0.3),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform_eval = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Dataset class (same as 3_train_model.ipynb)
class PlantDataset(Dataset):
    def __init__(self, entries, transform_train=True, base_dir=None):
        self.entries = entries
        self.transform_train = transform_train
        self.base_dir = base_dir or BASE_DIR  # Use BASE_DIR from cell 1

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        item = self.entries[idx]
        img_path = item["path"]
        class_id = item["class_id"]
        domain = item.get("domain", "pv")

        # Resolve path relative to BASE_DIR if it's a relative path
        if not Path(img_path).is_absolute():
            img_path = self.base_dir / img_path
        else:
            img_path = Path(img_path)

        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new("RGB", (224, 224), (0, 0, 0))
            class_id = 0

        if self.transform_train:
            if domain == "field":
                img = transform_field(img)
            elif domain == "pv":
                if class_id in field_poor_classes and torch.rand(1).item() < 0.5:
                    img = transform_pv_field_style(img)
                else:
                    img = transform_pv_basic(img)
            else:
                img = transform_pv_field_style(img)
        else:
            img = transform_eval(img)

        return img, class_id

# Create datasets
train_dataset = PlantDataset(all_train_entries, transform_train=True)
val_dataset_main = PlantDataset(val_entries_main, transform_train=False)
val_dataset_plant_doc = PlantDataset(val_entries_plant_doc, transform_train=False)
val_dataset_fieldplant = PlantDataset(val_entries_fieldplant, transform_train=False)

print("Datasets created successfully!")


Datasets created successfully!


## Step 7: Create DataLoaders


In [19]:
# Create weighted sampler for training
train_class_counts = Counter(e["class_id"] for e in all_train_entries)
max_count = max(train_class_counts.values())
class_weights = {cid: max_count / cnt for cid, cnt in train_class_counts.items()}
sample_weights = [class_weights[e["class_id"]] for e in all_train_entries]

sampler = WeightedRandomSampler(
    weights=torch.DoubleTensor(sample_weights),
    num_samples=len(sample_weights),
    replacement=True
)

# Optimized batch size for RTX 3070 Super (8GB VRAM)
# Can handle larger batches with mixed precision
BATCH_SIZE = 64  # Increased from 32

# Combined training loader
combined_train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

# Combined validation loader (for fine-tuning monitoring)
combined_val_loader = DataLoader(
    val_dataset_main,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

# Separate validation loaders for evaluation
val_loader_main = DataLoader(
    val_dataset_main,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader_plant_doc = DataLoader(
    val_dataset_plant_doc,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader_fieldplant = DataLoader(
    val_dataset_fieldplant,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

print(f"Train batches: {len(combined_train_loader)}")
print(f"Combined validation batches: {len(combined_val_loader)}")
print(f"Main validation batches: {len(val_loader_main)}")
print(f"Plant_doc validation batches: {len(val_loader_plant_doc)}")
print(f"FieldPlant validation batches: {len(val_loader_fieldplant)}")


Train batches: 910
Combined validation batches: 97
Main validation batches: 97
Plant_doc validation batches: 8
FieldPlant validation batches: 15


## Step 8: Define Training and Evaluation Functions


In [20]:
def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None, use_amp=False):
    model.train()
    running_loss = 0.0
    all_targets = []
    all_preds = []

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        
        if use_amp and scaler is not None:
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        all_targets.append(targets.detach().cpu())
        all_preds.append(preds.detach().cpu())

    if len(all_targets) == 0:
        raise ValueError("Training loader is empty!")
    
    all_targets = torch.cat(all_targets).numpy()
    all_preds = torch.cat(all_preds).numpy()

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = (all_targets == all_preds).mean()
    epoch_f1 = f1_score(all_targets, all_preds, average="macro")

    return epoch_loss, epoch_acc, epoch_f1

@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=False):
    model.eval()
    running_loss = 0.0
    all_targets = []
    all_preds = []

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if use_amp:
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
        else:
            outputs = model(images)
            loss = criterion(outputs, targets)

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        all_targets.append(targets.detach().cpu())
        all_preds.append(preds.detach().cpu())

    if len(all_targets) == 0:
        return 0.0, 0.0, 0.0, np.array([]), np.array([])
    
    all_targets = torch.cat(all_targets).numpy()
    all_preds = torch.cat(all_preds).numpy()

    epoch_loss = running_loss / len(loader.dataset) if len(loader.dataset) > 0 else 0.0
    epoch_acc = (all_targets == all_preds).mean() if len(all_targets) > 0 else 0.0
    epoch_f1 = f1_score(all_targets, all_preds, average="macro") if len(all_targets) > 0 else 0.0

    return epoch_loss, epoch_acc, epoch_f1, all_targets, all_preds

print("Training and evaluation functions defined.")


Training and evaluation functions defined.


## Step 9: Fine-Tuning Function with Lower Learning Rate


In [21]:
def fine_tune_model(
    checkpoint_path,
    train_loader_combined,
    val_loader_combined,
    max_epochs=5,
    lr=1e-4,  # Lower learning rate for fine-tuning
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=3,
    use_layerwise_lr=True  # Use different LRs for different layers
):
    """
    Fine-tune a pre-trained model on additional datasets.
    
    Args:
        checkpoint_path: Path to pre-trained model checkpoint
        train_loader_combined: Combined training DataLoader
        val_loader_combined: Combined validation DataLoader (for monitoring)
        max_epochs: Maximum number of fine-tuning epochs (default: 5)
        lr: Learning rate (default: 1e-4, 10x lower than original)
        use_layerwise_lr: If True, use different LRs for different layers
    """
    print(f"Loading pre-trained model from: {checkpoint_path}")
    
    # Load the pre-trained model
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_name = checkpoint["model_name"]
    num_classes = checkpoint["num_classes"]
    
    # Recreate model (use pretrained=True to ensure ImageNet weights if checkpoint fails)
    model = timm.create_model(
        model_name,
        pretrained=True,  # Always use pretrained ImageNet weights
        num_classes=num_classes
    )
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    
    # Use torch.compile for faster execution (if enabled)
    if USE_COMPILE:
        print(f"  -> Compiling model with torch.compile()...")
        model = torch.compile(model, mode='reduce-overhead')
        print(f"  -> Model compiled!")
    
    print(f"Model loaded: {model_name}, num_classes: {num_classes}")
    print(f"Original best validation F1: {checkpoint.get('best_val_f1', 'N/A'):.4f}")
    
    # Create optimizer with layer-wise learning rates
    if use_layerwise_lr and hasattr(model, 'blocks'):  # For EfficientNet/ViT
        # Different learning rates for different layers
        # Early layers: very low LR (frozen-like)
        # Middle layers: low LR
        # Later layers: normal LR
        # Classifier: higher LR
        
        # Get layer groups (this is model-specific)
        if 'efficientnet' in model_name.lower():
            # EfficientNet structure
            early_params = []
            middle_params = []
            late_params = []
            classifier_params = []
            
            for name, param in model.named_parameters():
                if 'classifier' in name:
                    classifier_params.append(param)
                elif 'blocks.0' in name or 'blocks.1' in name or 'blocks.2' in name:
                    early_params.append(param)
                elif 'blocks.3' in name or 'blocks.4' in name or 'blocks.5' in name:
                    middle_params.append(param)
                else:
                    late_params.append(param)
            
            optimizer = AdamW([
                {'params': early_params, 'lr': lr * 0.1},      # 10% of base LR
                {'params': middle_params, 'lr': lr * 0.5},    # 50% of base LR
                {'params': late_params, 'lr': lr},             # 100% of base LR
                {'params': classifier_params, 'lr': lr * 2}    # 200% of base LR
            ], weight_decay=weight_decay)
            
            print("Using layer-wise learning rates:")
            print(f"  Early layers: {lr * 0.1:.6f}")
            print(f"  Middle layers: {lr * 0.5:.6f}")
            print(f"  Late layers: {lr:.6f}")
            print(f"  Classifier: {lr * 2:.6f}")
        else:
            # For ViT or other models, use simpler grouping
            optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            print(f"Using uniform learning rate: {lr:.6f}")
    else:
        # Uniform learning rate
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        print(f"Using uniform learning rate: {lr:.6f}")
    
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
    criterion = nn.CrossEntropyLoss()
    
    # Initialize gradient scaler for mixed precision
    scaler = GradScaler('cuda') if USE_AMP else None
    if USE_AMP:
        print("  -> Mixed precision training enabled (FP16)")
    
    # Track best model
    best_val_f1 = checkpoint.get("best_val_f1", -1.0)
    best_state = copy.deepcopy(model.state_dict())
    history = {
        "train_loss": [],
        "train_acc": [],
        "train_f1": [],
        "val_loss": [],
        "val_acc": [],
        "val_f1": []
    }
    
    epochs_without_improvement = 0
    
    print(f"\nStarting fine-tuning:")
    print(f"  -> Training samples: {len(train_loader_combined.dataset)}")
    print(f"  -> Validation samples: {len(val_loader_combined.dataset)}")
    print(f"  -> Max epochs: {max_epochs}")
    print(f"  -> Base learning rate: {lr}")
    print(f"  -> Early stopping patience: {early_stopping_patience}")
    
    for epoch in range(1, max_epochs + 1):
        start_time = time.time()
        
        train_loss, train_acc, train_f1 = train_one_epoch(
            model, train_loader_combined, criterion, optimizer, device, scaler, USE_AMP
        )
        val_loss, val_acc, val_f1, _, _ = evaluate(
            model, val_loader_combined, criterion, device, USE_AMP
        )
        
        scheduler.step()
        
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["train_f1"].append(train_f1)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_f1"].append(val_f1)
        
        elapsed = time.time() - start_time
        
        print(f"Fine-tune Epoch {epoch:02d}/{max_epochs} ({elapsed:.1f}s)")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
        
        # Track best by validation F1
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_state = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
            print(f"  -> New best! Val F1: {best_val_f1:.4f}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(f"  -> Early stopping after {epoch} epochs")
                break
    
    # Load best weights
    model.load_state_dict(best_state)
    
    # Save fine-tuned model
    ckpt_path = MODELS_DIR / f"{model_name}_fine_tuned.pt"
    torch.save({
        "model_name": model_name,
        "state_dict": model.state_dict(),
        "num_classes": num_classes,
        "best_val_f1": best_val_f1,
        "history": history,
        "base_checkpoint": str(checkpoint_path),
        "fine_tuned": True,
        "fine_tune_epochs": epoch,
        "fine_tune_lr": lr
    }, ckpt_path)
    print(f"\nSaved fine-tuned checkpoint to: {ckpt_path}")
    
    return model, history, best_val_f1


## Step 10: Fine-Tune EfficientNet-B0


In [22]:
# Fine-tune EfficientNet-B0
print("="*70)
print("FINE-TUNING EFFICIENTNET-B0")
print("="*70)

checkpoint_path_eff = MODELS_DIR / "efficientnet_b0_best.pt"

model_eff_finetuned, history_eff, best_val_f1_eff = fine_tune_model(
    checkpoint_path=checkpoint_path_eff,
    train_loader_combined=combined_train_loader,
    val_loader_combined=combined_val_loader,
    max_epochs=5,
    lr=1e-4,  # 10x lower than original 3e-4
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=3,
    use_layerwise_lr=True
)

print(f"\nFine-tuning complete!")
print(f"Best validation F1: {best_val_f1_eff:.4f}")


FINE-TUNING EFFICIENTNET-B0
Loading pre-trained model from: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\models\efficientnet_b0_best.pt
Model loaded: efficientnet_b0, num_classes: 39
Original best validation F1: 0.9988
Using layer-wise learning rates:
  Early layers: 0.000010
  Middle layers: 0.000050
  Late layers: 0.000100
  Classifier: 0.000200
  -> Mixed precision training enabled (FP16)

Starting fine-tuning:
  -> Training samples: 58227
  -> Validation samples: 6148
  -> Max epochs: 5
  -> Base learning rate: 0.0001
  -> Early stopping patience: 3


  with autocast():
  with autocast():


Fine-tune Epoch 01/5 (557.1s)
  Train - Loss: 0.1966, Acc: 0.9455, F1: 0.9456
  Val   - Loss: 0.0070, Acc: 0.9982, F1: 0.9976


  with autocast():
  with autocast():


Fine-tune Epoch 02/5 (582.1s)
  Train - Loss: 0.0817, Acc: 0.9743, F1: 0.9743
  Val   - Loss: 0.0107, Acc: 0.9974, F1: 0.9966


  with autocast():
  with autocast():


Fine-tune Epoch 03/5 (628.1s)
  Train - Loss: 0.0491, Acc: 0.9847, F1: 0.9847
  Val   - Loss: 0.0089, Acc: 0.9979, F1: 0.9972
  -> Early stopping after 3 epochs

Saved fine-tuned checkpoint to: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\models\efficientnet_b0_fine_tuned.pt

Fine-tuning complete!
Best validation F1: 0.9988


## Step 11: Fine-Tune ViT-Base


In [23]:
# Fine-tune ViT-Base
print("="*70)
print("FINE-TUNING VIT-BASE")
print("="*70)

checkpoint_path_vit = MODELS_DIR / "vit_base_patch16_224_best.pt"

model_vit_finetuned, history_vit, best_val_f1_vit = fine_tune_model(
    checkpoint_path=checkpoint_path_vit,
    train_loader_combined=combined_train_loader,
    val_loader_combined=combined_val_loader,
    max_epochs=5,
    lr=1e-4,  # 10x lower than original 3e-4
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=3,
    use_layerwise_lr=False  # ViT structure is different, use uniform LR
)

print(f"\nFine-tuning complete!")
print(f"Best validation F1: {best_val_f1_vit:.4f}")


FINE-TUNING VIT-BASE
Loading pre-trained model from: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\models\vit_base_patch16_224_best.pt
Model loaded: vit_base_patch16_224, num_classes: 39
Original best validation F1: 0.9711
Using uniform learning rate: 0.000100
  -> Mixed precision training enabled (FP16)

Starting fine-tuning:
  -> Training samples: 58227
  -> Validation samples: 6148
  -> Max epochs: 5
  -> Base learning rate: 0.0001
  -> Early stopping patience: 3


  with autocast():
  with autocast():


Fine-tune Epoch 01/5 (572.2s)
  Train - Loss: 0.6271, Acc: 0.8455, F1: 0.8464
  Val   - Loss: 0.1428, Acc: 0.9546, F1: 0.9499


  with autocast():
  with autocast():


Fine-tune Epoch 02/5 (454.9s)
  Train - Loss: 0.4894, Acc: 0.8591, F1: 0.8598
  Val   - Loss: 0.1071, Acc: 0.9636, F1: 0.9606


  with autocast():
  with autocast():


Fine-tune Epoch 03/5 (445.2s)
  Train - Loss: 0.4239, Acc: 0.8769, F1: 0.8778
  Val   - Loss: 0.1056, Acc: 0.9660, F1: 0.9613
  -> Early stopping after 3 epochs

Saved fine-tuned checkpoint to: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION\models\vit_base_patch16_224_fine_tuned.pt

Fine-tuning complete!
Best validation F1: 0.9711


## Step 12: Evaluate on All Three Validation Sets Separately


In [24]:
def evaluate_on_all_datasets(model, val_loaders_dict, criterion, device, model_name):
    """Evaluate model on all validation datasets and return detailed statistics."""
    results = {}
    
    print(f"\n{'='*70}")
    print(f"VALIDATION RESULTS FOR {model_name.upper()}")
    print(f"{'='*70}\n")

    # Specify all class labels explicitly to include all 39 classes
    all_class_labels = list(range(num_classes))
    
    for dataset_name, val_loader in val_loaders_dict.items():
        if len(val_loader) == 0:
            print(f"Skipping {dataset_name.upper()} (empty dataset)")
            continue
            
        print(f"Evaluating on {dataset_name.upper()} dataset...")
        val_loss, val_acc, val_f1, all_targets, all_preds = evaluate(
            model, val_loader, criterion, device, USE_AMP
        )
        
        # Per-class metrics
        if len(all_targets) > 0:
            class_report = classification_report(
                all_targets, all_preds,
                labels=all_class_labels, 
                target_names=[id_to_label[i] for i in range(num_classes)],
                output_dict=True,
                zero_division=0
            )
            
            # Confusion matrix
            cm = confusion_matrix(all_targets, all_preds)
            
            results[dataset_name] = {
                "loss": val_loss,
                "accuracy": val_acc,
                "f1_macro": val_f1,
                "all_targets": all_targets,
                "all_preds": all_preds,
                "classification_report": class_report,
                "confusion_matrix": cm
            }
            
            print(f"  Loss: {val_loss:.4f}")
            print(f"  Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
            print(f"  Macro F1: {val_f1:.4f} ({val_f1*100:.2f}%)")
            print(f"  Total samples: {len(all_targets)}")
            print()
    
    return results

# Prepare validation loaders dictionary
val_loaders_dict = {
    "main": val_loader_main,
    "plant_doc": val_loader_plant_doc,
    "fieldplant": val_loader_fieldplant
}

criterion = nn.CrossEntropyLoss()

# Evaluate EfficientNet
print("\n" + "="*70)
print("EFFICIENTNET-B0 FINE-TUNED VALIDATION RESULTS")
print("="*70)
results_eff = evaluate_on_all_datasets(model_eff_finetuned, val_loaders_dict, criterion, DEVICE, "EfficientNet-B0 Fine-Tuned")



EFFICIENTNET-B0 FINE-TUNED VALIDATION RESULTS

VALIDATION RESULTS FOR EFFICIENTNET-B0 FINE-TUNED

Evaluating on MAIN dataset...


  with autocast():


  Loss: 0.0068
  Accuracy: 0.9990 (99.90%)
  Macro F1: 0.9988 (99.88%)
  Total samples: 6148

Evaluating on PLANT_DOC dataset...


  with autocast():


  Loss: 6.4742
  Accuracy: 0.2421 (24.21%)
  Macro F1: 0.2061 (20.61%)
  Total samples: 504

Evaluating on FIELDPLANT dataset...


  with autocast():


  Loss: 7.9304
  Accuracy: 0.1573 (15.73%)
  Macro F1: 0.0395 (3.95%)
  Total samples: 928



In [25]:
# Evaluate ViT
print("\n" + "="*70)
print("VIT-BASE FINE-TUNED VALIDATION RESULTS")
print("="*70)
results_vit = evaluate_on_all_datasets(model_vit_finetuned, val_loaders_dict, criterion, DEVICE, "ViT-Base Fine-Tuned")



VIT-BASE FINE-TUNED VALIDATION RESULTS

VALIDATION RESULTS FOR VIT-BASE FINE-TUNED

Evaluating on MAIN dataset...


  with autocast():


  Loss: 0.0848
  Accuracy: 0.9741 (97.41%)
  Macro F1: 0.9710 (97.10%)
  Total samples: 6148

Evaluating on PLANT_DOC dataset...


  with autocast():


  Loss: 9.4252
  Accuracy: 0.0516 (5.16%)
  Macro F1: 0.0381 (3.81%)
  Total samples: 504

Evaluating on FIELDPLANT dataset...


  with autocast():


  Loss: 6.2760
  Accuracy: 0.1746 (17.46%)
  Macro F1: 0.0268 (2.68%)
  Total samples: 928



## Step 13: Summary Comparison


In [26]:
# Create summary comparison table
print("\n" + "="*70)
print("SUMMARY COMPARISON TABLE")
print("="*70)
print("\nPerformance across all validation datasets:\n")

# Create comparison DataFrame
comparison_data = []
for dataset_name in ["main", "plant_doc", "fieldplant"]:
    if dataset_name in results_eff and dataset_name in results_vit:
        comparison_data.append({
            "Dataset": dataset_name.upper(),
            "Model": "EfficientNet-B0 (Fine-Tuned)",
            "Accuracy": results_eff[dataset_name]["accuracy"],
            "F1 Macro": results_eff[dataset_name]["f1_macro"],
            "Loss": results_eff[dataset_name]["loss"],
            "Samples": len(results_eff[dataset_name]["all_targets"])
        })
        comparison_data.append({
            "Dataset": dataset_name.upper(),
            "Model": "ViT-Base (Fine-Tuned)",
            "Accuracy": results_vit[dataset_name]["accuracy"],
            "F1 Macro": results_vit[dataset_name]["f1_macro"],
            "Loss": results_vit[dataset_name]["loss"],
            "Samples": len(results_vit[dataset_name]["all_targets"])
        })

df_comparison = pd.DataFrame(comparison_data)
print(df_comparison.to_string(index=False))

# Overall statistics
print("\n" + "="*70)
print("OVERALL STATISTICS")
print("="*70)

print("\nEfficientNet-B0 (Fine-Tuned):")
eff_df = df_comparison[df_comparison['Model']=='EfficientNet-B0 (Fine-Tuned)']
if len(eff_df) > 0:
    print(f"  Average Accuracy across datasets: {eff_df['Accuracy'].mean():.4f}")
    print(f"  Average F1 Macro across datasets: {eff_df['F1 Macro'].mean():.4f}")

print("\nViT-Base (Fine-Tuned):")
vit_df = df_comparison[df_comparison['Model']=='ViT-Base (Fine-Tuned)']
if len(vit_df) > 0:
    print(f"  Average Accuracy across datasets: {vit_df['Accuracy'].mean():.4f}")
    print(f"  Average F1 Macro across datasets: {vit_df['F1 Macro'].mean():.4f}")

# Best model per dataset
print("\n" + "="*70)
print("BEST MODEL PER DATASET")
print("="*70)
for dataset_name in ["main", "plant_doc", "fieldplant"]:
    if dataset_name in results_eff and dataset_name in results_vit:
        eff_f1 = results_eff[dataset_name]["f1_macro"]
        vit_f1 = results_vit[dataset_name]["f1_macro"]
        best_model = "EfficientNet-B0" if eff_f1 > vit_f1 else "ViT-Base"
        print(f"{dataset_name.upper()}: {best_model} (F1: {max(eff_f1, vit_f1):.4f})")



SUMMARY COMPARISON TABLE

Performance across all validation datasets:

   Dataset                        Model  Accuracy  F1 Macro     Loss  Samples
      MAIN EfficientNet-B0 (Fine-Tuned)  0.999024  0.998801 0.006806     6148
      MAIN        ViT-Base (Fine-Tuned)  0.974138  0.971001 0.084781     6148
 PLANT_DOC EfficientNet-B0 (Fine-Tuned)  0.242063  0.206061 6.474166      504
 PLANT_DOC        ViT-Base (Fine-Tuned)  0.051587  0.038147 9.425156      504
FIELDPLANT EfficientNet-B0 (Fine-Tuned)  0.157328  0.039494 7.930360      928
FIELDPLANT        ViT-Base (Fine-Tuned)  0.174569  0.026753 6.276009      928

OVERALL STATISTICS

EfficientNet-B0 (Fine-Tuned):
  Average Accuracy across datasets: 0.4661
  Average F1 Macro across datasets: 0.4148

ViT-Base (Fine-Tuned):
  Average Accuracy across datasets: 0.4001
  Average F1 Macro across datasets: 0.3453

BEST MODEL PER DATASET
MAIN: EfficientNet-B0 (F1: 0.9988)
PLANT_DOC: EfficientNet-B0 (F1: 0.2061)
FIELDPLANT: EfficientNet-B0 (F1: 0.0

## Step 14: Visual Inspection of Misclassified Images


In [27]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path

def visualize_misclassifications(
    model, 
    val_loader, 
    results_dict, 
    dataset_name, 
    id_to_label,
    num_samples=30,
    save_dir=None
):
    """
    Visualize misclassified images to identify failure patterns.
    
    Args:
        model: Trained model
        val_loader: Validation DataLoader
        results_dict: Dictionary containing evaluation results
        dataset_name: Name of the dataset ('plant_doc' or 'fieldplant')
        id_to_label: Mapping from class ID to label name
        num_samples: Number of misclassified samples to visualize
        save_dir: Directory to save images (if None, creates 'misclassification_analysis')
    """
    if dataset_name not in results_dict:
        print(f"No results found for {dataset_name}")
        return
    
    results = results_dict[dataset_name]
    targets = results['all_targets']
    preds = results['all_preds']
    
    # Find misclassified samples
    misclassified_indices = []
    misclassified_info = []
    
    # Get all entries from the dataset
    dataset_entries = []
    if dataset_name == "plant_doc":
        dataset_entries = val_entries_plant_doc
    elif dataset_name == "fieldplant":
        dataset_entries = val_entries_fieldplant
    else:
        print(f"Unknown dataset: {dataset_name}")
        return
    
    for idx, (true_label, pred_label) in enumerate(zip(targets, preds)):
        if true_label != pred_label:
            misclassified_indices.append(idx)
            misclassified_info.append({
                'index': idx,
                'true_label': true_label,
                'pred_label': pred_label,
                'true_name': id_to_label[true_label],
                'pred_name': id_to_label[pred_label],
                'entry': dataset_entries[idx] if idx < len(dataset_entries) else None
            })
    
    if len(misclassified_indices) == 0:
        print(f"No misclassifications found for {dataset_name}")
        return
    
    print(f"\n{'='*70}")
    print(f"VISUAL INSPECTION: {dataset_name.upper()}")
    print(f"{'='*70}")
    print(f"Total misclassifications: {len(misclassified_indices)}")
    print(f"Sampling {min(num_samples, len(misclassified_indices))} examples\n")
    
    # Sample misclassified images
    import random
    random.seed(42)
    sampled_indices = random.sample(
        misclassified_indices, 
        min(num_samples, len(misclassified_indices))
    )
    
    # Create save directory
    if save_dir is None:
        save_dir = Path("./misclassification_analysis") / dataset_name
    else:
        save_dir = Path(save_dir) / dataset_name
    save_dir.mkdir(parents=True, exist_ok=True)
    
    # Load and visualize images
    model.eval()
    device = next(model.parameters()).device
    
    # Create figure for grid display
    num_cols = 5
    num_rows = (len(sampled_indices) + num_cols - 1) // num_cols
    fig = plt.figure(figsize=(20, 4 * num_rows))
    gs = gridspec.GridSpec(num_rows, num_cols, figure=fig, hspace=0.4, wspace=0.3)
    
    for plot_idx, mis_idx in enumerate(sampled_indices):
        row = plot_idx // num_cols
        col = plot_idx % num_cols
        ax = fig.add_subplot(gs[row, col])
        
        # Get the entry
        if mis_idx < len(dataset_entries):
            entry = dataset_entries[mis_idx]
            img_path = Path(entry['path'])
            
            try:
                # Load and display image
                img = Image.open(img_path).convert('RGB')
                ax.imshow(img)
                
                # Get true and predicted labels
                true_label = targets[mis_idx]
                pred_label = preds[mis_idx]
                true_name = id_to_label[true_label]
                pred_name = id_to_label[pred_label]
                
                # Format labels (truncate if too long)
                true_display = true_name[:25] + "..." if len(true_name) > 25 else true_name
                pred_display = pred_name[:25] + "..." if len(pred_name) > 25 else pred_name
                
                # Set title with color coding
                ax.set_title(
                    f"True: {true_display}\nPred: {pred_display}",
                    fontsize=9,
                    color='red' if true_label != pred_label else 'green',
                    fontweight='bold'
                )
                ax.axis('off')
                
                # Save individual image
                save_path = save_dir / f"misclass_{mis_idx:04d}_true_{true_label}_pred_{pred_label}.jpg"
                img.save(save_path, quality=95)
                
            except Exception as e:
                ax.text(0.5, 0.5, f"Error loading\nimage {mis_idx}:\n{str(e)}", 
                       ha='center', va='center', fontsize=8, color='red')
                ax.axis('off')
                print(f"Warning: Could not load image {mis_idx}: {e}")
    
    # Save grid figure
    plt.suptitle(
        f"Misclassified Images: {dataset_name.upper()} Dataset\n"
        f"Showing {len(sampled_indices)} of {len(misclassified_indices)} misclassifications",
        fontsize=14, fontweight='bold', y=0.995
    )
    grid_save_path = save_dir / f"misclassification_grid_{dataset_name}.png"
    plt.savefig(grid_save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved grid visualization: {grid_save_path}")
    plt.close()
    
    # Create summary statistics
    print(f"\n{'='*70}")
    print(f"MISCLASSIFICATION SUMMARY: {dataset_name.upper()}")
    print(f"{'='*70}\n")
    
    # Count misclassifications by true class
    misclass_by_true = {}
    for mis_idx in misclassified_indices:
        true_label = targets[mis_idx]
        true_name = id_to_label[true_label]
        if true_name not in misclass_by_true:
            misclass_by_true[true_name] = 0
        misclass_by_true[true_name] += 1
    
    print("Misclassifications by True Class (Top 10):")
    sorted_misclass = sorted(misclass_by_true.items(), key=lambda x: x[1], reverse=True)
    for class_name, count in sorted_misclass[:10]:
        print(f"  {class_name}: {count} misclassifications")
    
    # Count misclassifications by predicted class
    misclass_by_pred = {}
    for mis_idx in misclassified_indices:
        pred_label = preds[mis_idx]
        pred_name = id_to_label[pred_label]
        if pred_name not in misclass_by_pred:
            misclass_by_pred[pred_name] = 0
        misclass_by_pred[pred_name] += 1
    
    print("\nMost Common Incorrect Predictions (Top 10):")
    sorted_pred = sorted(misclass_by_pred.items(), key=lambda x: x[1], reverse=True)
    for class_name, count in sorted_pred[:10]:
        print(f"  {class_name}: predicted {count} times (incorrectly)")
    
    print(f"\n✓ Individual images saved to: {save_dir}")
    print(f"✓ Total images saved: {len(sampled_indices)}")
    
    return save_dir

# Visualize misclassifications for both datasets
print("\n" + "="*70)
print("VISUAL INSPECTION OF MISCLASSIFIED IMAGES")
print("="*70)

# Visualize Plant_doc misclassifications
if "plant_doc" in results_eff:
    print("\nAnalyzing Plant_doc misclassifications...")
    plant_doc_save_dir = visualize_misclassifications(
        model=model_eff_finetuned,
        val_loader=val_loader_plant_doc,
        results_dict=results_eff,
        dataset_name="plant_doc",
        id_to_label=id_to_label,
        num_samples=30,
        save_dir="./misclassification_analysis"
    )

# Visualize FieldPlant misclassifications
if "fieldplant" in results_eff:
    print("\nAnalyzing FieldPlant misclassifications...")
    fieldplant_save_dir = visualize_misclassifications(
        model=model_eff_finetuned,
        val_loader=val_loader_fieldplant,
        results_dict=results_eff,
        dataset_name="fieldplant",
        id_to_label=id_to_label,
        num_samples=30,
        save_dir="./misclassification_analysis"
    )

print("\n" + "="*70)
print("VISUAL INSPECTION COMPLETE")
print("="*70)
print("\nNext steps:")
print("1. Review the saved images in ./misclassification_analysis/")
print("2. Look for patterns in misclassifications:")
print("   - Lighting conditions")
print("   - Background differences")
print("   - Image quality/artifacts")
print("   - Similar-looking classes")
print("3. Use insights to improve data augmentation and training strategy")



VISUAL INSPECTION OF MISCLASSIFIED IMAGES

Analyzing Plant_doc misclassifications...

VISUAL INSPECTION: PLANT_DOC
Total misclassifications: 382
Sampling 30 examples

✓ Saved grid visualization: misclassification_analysis\plant_doc\misclassification_grid_plant_doc.png

MISCLASSIFICATION SUMMARY: PLANT_DOC

Misclassifications by True Class (Top 10):
  potato_early_blight: 24 misclassifications
  tomato_septoria_leaf_spot: 24 misclassifications
  apple_cedar_apple_rust: 20 misclassifications
  tomato_tomato_mosaic_virus: 20 misclassifications
  tomato_bacterial_spot: 18 misclassifications
  apple_apple_scab: 16 misclassifications
  corn_northern_leaf_blight: 16 misclassifications
  soybean_healthy: 16 misclassifications
  strawberry_healthy: 16 misclassifications
  tomato_healthy: 16 misclassifications

Most Common Incorrect Predictions (Top 10):
  background_without_leaves: predicted 158 times (incorrectly)
  tomato_late_blight: predicted 30 times (incorrectly)
  tomato_tomato_yellow_l

In [28]:
def analyze_failures(results_dict, dataset_name, id_to_label):
    """Analyze failure patterns."""
    if dataset_name not in results_dict:
        return
    
    results = results_dict[dataset_name]
    cm = results['confusion_matrix']
    targets = results['all_targets']
    preds = results['all_preds']
    
    # Find most confused classes
    print(f"\n{'='*70}")
    print(f"FAILURE ANALYSIS: {dataset_name.upper()}")
    print(f"{'='*70}\n")
    
    # Per-class accuracy
    print("Per-Class Accuracy:")
    for class_id in range(len(cm)):
        if cm[class_id].sum() > 0:
            acc = cm[class_id, class_id] / cm[class_id].sum()
            print(f"  Class {class_id} ({id_to_label[class_id]}): {acc:.3f}")
    
    # Most common misclassifications
    print("\nTop 10 Misclassifications:")
    misclass_pairs = []
    for i in range(len(cm)):
        for j in range(len(cm)):
            if i != j and cm[i, j] > 0:
                misclass_pairs.append((i, j, cm[i, j]))
    misclass_pairs.sort(key=lambda x: x[2], reverse=True)
    for i, j, count in misclass_pairs[:10]:
        print(f"  {id_to_label[i]} → {id_to_label[j]}: {count} times")

# Run analysis
analyze_failures(results_eff, "plant_doc", id_to_label)
analyze_failures(results_eff, "fieldplant", id_to_label)


FAILURE ANALYSIS: PLANT_DOC

Per-Class Accuracy:
  Class 0 (apple_apple_scab): 0.200
  Class 2 (apple_cedar_apple_rust): 0.000
  Class 3 (apple_healthy): 0.222
  Class 5 (blueberry_healthy): 0.364
  Class 6 (cherry_healthy): 0.300
  Class 7 (cherry_powdery_mildew): 0.500
  Class 8 (corn_cercospora_leaf_spot_gray_leaf_spot): 0.300
  Class 10 (corn_healthy): 0.333
  Class 11 (corn_northern_leaf_blight): 0.250
  Class 13 (grape_esca_black_measles): 0.667
  Class 16 (orange_haunglongbing_citrus_greening): 0.222
  Class 17 (peach_bacterial_spot): 0.333
  Class 18 (peach_healthy): 0.625
  Class 19 (pepper,_bell_bacterial_spot): 0.143
  Class 20 (pepper,_bell_healthy): 0.125
  Class 21 (potato_early_blight): 0.429
  Class 22 (potato_healthy): 0.000
  Class 23 (potato_late_blight): 0.333
  Class 24 (raspberry_healthy): 0.000
  Class 26 (squash_powdery_mildew): 0.000
  Class 27 (strawberry_healthy): 0.222
  Class 28 (strawberry_leaf_scorch): 0.000
  Class 29 (tomato_bacterial_spot): 0.400
  Cl