# Test: Training on Common Classes Across All Three Datasets

This notebook implements:
1. **Test 1**: Training both EfficientNet and ViT on a collective dataset containing only classes that exist in ALL THREE datasets (Main, Plant_doc, and FieldPlant)
2. **Validation Block**: Performance statistics on validation images from all three datasets separately

## Common Classes Identified:
- Corn___Cercospora_leaf_spot Gray_leaf_spot
- Corn___Common_rust
- Corn___Northern_Leaf_Blight
- Tomato___healthy
- Tomato___Tomato_mosaic_virus
- Tomato___Tomato_Yellow_Leaf_Curl_Virus

**Total: 6 common classes**


In [1]:
# Imports and Setup
import os
import time
import copy
from pathlib import Path
from collections import Counter, defaultdict
import json

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from PIL import Image
import torchvision.transforms as T
import numpy as np
import pandas as pd

import timm
from sklearn.metrics import f1_score, classification_report, confusion_matrix

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Paths
METADATA_DIR = Path("./metadata")
LABEL_MAPPING_PATH = METADATA_DIR / "label_mapping.json"
DATASET_INDEX_PATH = METADATA_DIR / "dataset_index.json"
DATA_DIR = Path("./data")
PLANT_DOC_TRAIN = DATA_DIR / "Plant_doc" / "train"
PLANT_DOC_TEST = DATA_DIR / "Plant_doc" / "test"
FIELDPLANT_DIR = DATA_DIR / "FieldPlant_reformatted"

MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)


Using device: cpu


## Step 1: Identify Common Classes Across All Three Datasets


In [2]:
# Load main dataset metadata
with open(LABEL_MAPPING_PATH, "r") as f:
    label_mapping = json.load(f)

with open(DATASET_INDEX_PATH, "r") as f:
    dataset_index = json.load(f)

# Create mapping from main dataset
id_to_label = {c["id"]: c["canonical_label"] for c in label_mapping["classes"]}
label_to_id = {v: k for k, v in id_to_label.items()}
folder_to_label = {}
for c in label_mapping["classes"]:
    for folder in c.get("pv_folders", []):
        folder_to_label[folder] = c["canonical_label"]

print(f"Main dataset: {len(label_mapping['classes'])} classes")

# Define common classes (classes that exist in ALL THREE datasets)
# Based on dataset_classes_diagram.txt analysis
COMMON_CLASSES_MAIN = [
    "Corn___Cercospora_leaf_spot Gray_leaf_spot",  # ID 8
    "Corn___Common_rust",                          # ID 9
    "Corn___Northern_Leaf_Blight",                 # ID 11
    "Tomato___healthy",                            # ID 31
    "Tomato___Tomato_mosaic_virus",               # ID 37
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",     # ID 38
]

# Map to canonical labels
common_canonical_labels = []
common_class_ids_original = []
for folder_name in COMMON_CLASSES_MAIN:
    canonical = folder_to_label.get(folder_name)
    if canonical:
        common_canonical_labels.append(canonical)
        common_class_ids_original.append(label_to_id[canonical])

# Create new label mapping for common classes (0-5)
common_id_to_original_id = {new_id: orig_id for new_id, orig_id in enumerate(common_class_ids_original)}
original_id_to_common_id = {orig_id: new_id for new_id, orig_id in enumerate(common_class_ids_original)}

print(f"\nCommon classes found: {len(common_canonical_labels)}")
print("Common classes:")
for i, (orig_id, canonical) in enumerate(zip(common_class_ids_original, common_canonical_labels)):
    print(f"  {i}: {canonical} (original ID: {orig_id})")


Main dataset: 39 classes

Common classes found: 6
Common classes:
  0: corn_cercospora_leaf_spot_gray_leaf_spot (original ID: 8)
  1: corn_common_rust (original ID: 9)
  2: corn_northern_leaf_blight (original ID: 11)
  3: tomato_healthy (original ID: 31)
  4: tomato_tomato_mosaic_virus (original ID: 37)
  5: tomato_tomato_yellow_leaf_curl_virus (original ID: 38)


## Step 2: Create Mapping for Plant_doc and FieldPlant Datasets


In [3]:
# Mapping from Plant_doc folder names to canonical labels
plant_doc_to_canonical = {
    "Corn_Gray_leaf_spot": "corn_cercospora_leaf_spot_gray_leaf_spot",
    "Corn_rust_leaf": "corn_common_rust",
    "Corn_leaf_blight": "corn_northern_leaf_blight",
    "Tomato_leaf": "tomato_healthy",
    "Tomato_leaf_mosaic_virus": "tomato_tomato_mosaic_virus",
    "Tomato_leaf_yellow_virus": "tomato_tomato_yellow_leaf_curl_virus",
}

# Mapping from FieldPlant folder names to canonical labels
fieldplant_to_canonical = {
    "Corn___Gray_leaf_spot": "corn_cercospora_leaf_spot_gray_leaf_spot",
    "Corn___rust_leaf": "corn_common_rust",
    "Corn___leaf_blight": "corn_northern_leaf_blight",
    "Corn___healthy": "corn_healthy",  # Note: this is NOT in common classes
    "Tomato___healthy": "tomato_healthy",
    "Tomato___leaf_mosaic_virus": "tomato_tomato_mosaic_virus",
    "Tomato___leaf_yellow_virus": "tomato_tomato_yellow_leaf_curl_virus",
}

print("Plant_doc mapping:")
for folder, canonical in plant_doc_to_canonical.items():
    if canonical in common_canonical_labels:
        print(f"  {folder} -> {canonical}")

print("\nFieldPlant mapping:")
for folder, canonical in fieldplant_to_canonical.items():
    if canonical in common_canonical_labels:
        print(f"  {folder} -> {canonical}")


Plant_doc mapping:
  Corn_Gray_leaf_spot -> corn_cercospora_leaf_spot_gray_leaf_spot
  Corn_rust_leaf -> corn_common_rust
  Corn_leaf_blight -> corn_northern_leaf_blight
  Tomato_leaf -> tomato_healthy
  Tomato_leaf_mosaic_virus -> tomato_tomato_mosaic_virus
  Tomato_leaf_yellow_virus -> tomato_tomato_yellow_leaf_curl_virus

FieldPlant mapping:
  Corn___Gray_leaf_spot -> corn_cercospora_leaf_spot_gray_leaf_spot
  Corn___rust_leaf -> corn_common_rust
  Corn___leaf_blight -> corn_northern_leaf_blight
  Tomato___healthy -> tomato_healthy
  Tomato___leaf_mosaic_virus -> tomato_tomato_mosaic_virus
  Tomato___leaf_yellow_virus -> tomato_tomato_yellow_leaf_curl_virus


## Step 3: Create Filtered Dataset with Only Common Classes


In [4]:
# Filter main dataset to only common classes
def filter_main_dataset(dataset_index, common_class_ids):
    """Filter main dataset entries to only include common classes."""
    filtered = []
    for entry in dataset_index:
        if entry["class_id"] in common_class_ids:
            # Create new entry with remapped class_id
            new_entry = entry.copy()
            new_entry["class_id"] = original_id_to_common_id[entry["class_id"]]
            new_entry["dataset"] = "main"
            filtered.append(new_entry)
    return filtered

# Filter training data
train_entries_main = [e for e in dataset_index if e["split"] == "train"]
train_entries_filtered = filter_main_dataset(train_entries_main, common_class_ids_original)

print(f"Original training samples: {len(train_entries_main)}")
print(f"Filtered training samples (common classes only): {len(train_entries_filtered)}")
print(f"\nClass distribution in filtered training set:")
class_counts = Counter(e["class_id"] for e in train_entries_filtered)
for class_id, count in sorted(class_counts.items()):
    canonical = common_canonical_labels[class_id]
    print(f"  Class {class_id} ({canonical}): {count} samples")


Original training samples: 49179
Filtered training samples (common classes only): 8910

Class distribution in filtered training set:
  Class 0 (corn_cercospora_leaf_spot_gray_leaf_spot): 800 samples
  Class 1 (corn_common_rust): 953 samples
  Class 2 (corn_northern_leaf_blight): 800 samples
  Class 3 (tomato_healthy): 1272 samples
  Class 4 (tomato_tomato_mosaic_virus): 800 samples
  Class 5 (tomato_tomato_yellow_leaf_curl_virus): 4285 samples


## Step 4: Add Plant_doc and FieldPlant Training Data


In [None]:
# Add Plant_doc training data
def load_plant_doc_data(data_dir, split="train"):
    """Load Plant_doc dataset entries."""
    entries = []
    split_dir = data_dir / split
    
    if not split_dir.exists():
        print(f"Warning: {split_dir} does not exist!")
        return entries
    
    print(f"Loading Plant_doc data from {split_dir}")
    print(f"Available folders: {[f.name for f in split_dir.iterdir() if f.is_dir()]}")
    print(f"Mapping keys: {list(plant_doc_to_canonical.keys())}")
    
    for folder in split_dir.iterdir():
        if not folder.is_dir():
            continue
        
        canonical = plant_doc_to_canonical.get(folder.name)
        if canonical is None:
            print(f"  Warning: No mapping found for folder '{folder.name}'")
            continue
            
        if canonical not in common_canonical_labels:
            print(f"  Warning: Canonical label '{canonical}' not in common classes")
            continue
            
        # Get common class ID
        common_class_id = common_canonical_labels.index(canonical)
        
        # Get all images in folder
        image_files = list(folder.glob("*.jpg")) + list(folder.glob("*.JPG"))
        print(f"  Found {len(image_files)} images in {folder.name} -> {canonical} (class_id={common_class_id})")
        for img_path in image_files:
            entries.append({
                "path": str(img_path),
                "class_id": common_class_id,
                "dataset": "plant_doc",
                "split": split
            })
    
    return entries

# Add FieldPlant training dataInitial attempt at domain adaptation through fine-tuning.
def load_fieldplant_data(data_dir):
    """Load FieldPlant dataset entries."""
    entries = []
    
    if not data_dir.exists():
        return entries
    
    for folder in data_dir.iterdir():
        if not folder.is_dir():
            continue
        
        canonical = fieldplant_to_canonical.get(folder.name)
        if canonical and canonical in common_canonical_labels:
            # Get common class ID
            common_class_id = common_canonical_labels.index(canonical)
            
            # Get all images in folder
            image_files = list(folder.glob("*.jpg")) + list(folder.glob("*.JPG"))
            for img_path in image_files:
                entries.append({
                    "path": str(img_path),
                    "class_id": common_class_id,
                    "dataset": "fieldplant",
                    "split": "train"  # FieldPlant doesn't have split, use all as train
                })
    
    return entries

# Load additional datasets
train_entries_plant_doc = load_plant_doc_data(PLANT_DOC_TRAIN, split="train")
train_entries_fieldplant = load_fieldplant_data(FIELDPLANT_DIR)

print(f"Plant_doc training samples: {len(train_entries_plant_doc)}")
print(f"FieldPlant training samples: {len(train_entries_fieldplant)}")

# Combine all training data
all_train_entries = train_entries_filtered + train_entries_plant_doc + train_entries_fieldplant
print(f"\nTotal combined training samples: {len(all_train_entries)}")

# Show distribution by dataset
dataset_counts = Counter(e["dataset"] for e in all_train_entries)
print("\nTraining samples by dataset:")
for dataset, count in sorted(dataset_counts.items()):
    print(f"  {dataset}: {count}")

# Show distribution by class
print("\nFinal class distribution:")
final_class_counts = Counter(e["class_id"] for e in all_train_entries)
for class_id, count in sorted(final_class_counts.items()):
    canonical = common_canonical_labels[class_id]
    print(f"  Class {class_id} ({canonical}): {count} samples")


Plant_doc training samples: 0
FieldPlant training samples: 4336

Total combined training samples: 13246

Training samples by dataset:
  fieldplant: 4336
  main: 8910

Final class distribution:
  Class 0 (corn_cercospora_leaf_spot_gray_leaf_spot): 908 samples
  Class 1 (corn_common_rust): 1051 samples
  Class 2 (corn_northern_leaf_blight): 4104 samples
  Class 3 (tomato_healthy): 1830 samples
  Class 4 (tomato_tomato_mosaic_virus): 838 samples
  Class 5 (tomato_tomato_yellow_leaf_curl_virus): 4515 samples


## Step 5: Create Validation Sets from All Three Datasets


In [6]:
# Create validation sets from each dataset
val_entries_main = [e for e in dataset_index if e["split"] == "val"]
val_entries_main_filtered = filter_main_dataset(val_entries_main, common_class_ids_original)

val_entries_plant_doc = load_plant_doc_data(PLANT_DOC_TEST, split="test")  # Use test as validation
val_entries_fieldplant = load_fieldplant_data(FIELDPLANT_DIR)  # Use all as validation

# For FieldPlant, we'll use a subset for validation (e.g., 20% of each class)
# For now, use all as validation
np.random.seed(42)
if len(val_entries_fieldplant) > 0:
    indices = np.arange(len(val_entries_fieldplant))
    np.random.shuffle(indices)
    val_size = int(len(val_entries_fieldplant) * 0.2)
    val_entries_fieldplant = [val_entries_fieldplant[i] for i in indices[:val_size]]

print(f"Main dataset validation samples: {len(val_entries_main_filtered)}")
print(f"Plant_doc validation samples: {len(val_entries_plant_doc)}")
print(f"FieldPlant validation samples: {len(val_entries_fieldplant)}")

# Store separately for per-dataset evaluation
validation_sets = {
    "main": val_entries_main_filtered,
    "plant_doc": val_entries_plant_doc,
    "fieldplant": val_entries_fieldplant
}


Main dataset validation samples: 1114
Plant_doc validation samples: 0
FieldPlant validation samples: 867


## Step 6: Define Dataset Class and Transforms


In [7]:
# Transforms
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMG_SIZE = 224

transform_train = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=20),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.1),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2)),
    T.ToTensor(),
    T.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform_eval = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Dataset class
class CommonClassDataset(Dataset):
    def __init__(self, entries, transform_train=True):
        self.entries = entries
        self.transform_train = transform_train

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        item = self.entries[idx]
        img_path = item["path"]
        class_id = item["class_id"]

        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image as fallback
            img = Image.new("RGB", (224, 224), (0, 0, 0))
            class_id = 0

        if self.transform_train:
            img = transform_train(img)
        else:
            img = transform_eval(img)

        return img, class_id

# Create datasets
train_dataset = CommonClassDataset(all_train_entries, transform_train=True)
val_dataset_main = CommonClassDataset(val_entries_main_filtered, transform_train=False)
val_dataset_plant_doc = CommonClassDataset(val_entries_plant_doc, transform_train=False)
val_dataset_fieldplant = CommonClassDataset(val_entries_fieldplant, transform_train=False)

print("Datasets created successfully!")


Datasets created successfully!


## Step 7: Create DataLoaders


In [8]:
# Create weighted sampler for training
train_class_counts = Counter(e["class_id"] for e in all_train_entries)
max_count = max(train_class_counts.values())
class_weights = {cid: max_count / cnt for cid, cnt in train_class_counts.items()}
sample_weights = [class_weights[e["class_id"]] for e in all_train_entries]

sampler = WeightedRandomSampler(
    weights=torch.DoubleTensor(sample_weights),
    num_samples=len(sample_weights),
    replacement=True
)

BATCH_SIZE = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader_main = DataLoader(
    val_dataset_main,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader_plant_doc = DataLoader(
    val_dataset_plant_doc,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader_fieldplant = DataLoader(
    val_dataset_fieldplant,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True if torch.cuda.is_available() else False
)

print(f"Train batches: {len(train_loader)}")
print(f"Main validation batches: {len(val_loader_main)}")
print(f"Plant_doc validation batches: {len(val_loader_plant_doc)}")
print(f"FieldPlant validation batches: {len(val_loader_fieldplant)}")


Train batches: 414
Main validation batches: 35
Plant_doc validation batches: 0
FieldPlant validation batches: 28


## Step 8: Training Functions


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_targets = []
    all_preds = []

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        all_targets.append(targets.detach().cpu())
        all_preds.append(preds.detach().cpu())

    # Handle empty loader case
    if len(all_targets) == 0:
        raise ValueError("Training loader is empty! Cannot train on empty dataset.")
    
    all_targets = torch.cat(all_targets).numpy()
    all_preds = torch.cat(all_preds).numpy()

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = (all_targets == all_preds).mean()
    epoch_f1 = f1_score(all_targets, all_preds, average="macro")

    return epoch_loss, epoch_acc, epoch_f1

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_targets = []
    all_preds = []

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, targets)

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        all_targets.append(targets.detach().cpu())
        all_preds.append(preds.detach().cpu())

    # Handle empty loader case
    if len(all_targets) == 0:
        # Return default values for empty validation set
        return 0.0, 0.0, 0.0, np.array([]), np.array([])
    
    all_targets = torch.cat(all_targets).numpy()
    all_preds = torch.cat(all_preds).numpy()

    epoch_loss = running_loss / len(loader.dataset) if len(loader.dataset) > 0 else 0.0
    epoch_acc = (all_targets == all_preds).mean() if len(all_targets) > 0 else 0.0
    epoch_f1 = f1_score(all_targets, all_preds, average="macro") if len(all_targets) > 0 else 0.0

    return epoch_loss, epoch_acc, epoch_f1, all_targets, all_preds

def create_model_and_optim(model_name, num_classes, lr=3e-4, weight_decay=1e-4, device=DEVICE):
    model = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=num_classes
    )
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=20)
    criterion = nn.CrossEntropyLoss()
    return model, criterion, optimizer, scheduler


## Step 9: Training Loop for Common Classes


In [12]:
def train_model_common_classes(
    model_name,
    num_classes,
    train_loader,
    val_loaders_dict,  # Dictionary of validation loaders
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5
):
    print(f"Starting training for model: {model_name}")
    print(f"  -> Training on {num_classes} common classes")
    print(f"  -> Validation sets: {list(val_loaders_dict.keys())}")
    
    model, criterion, optimizer, scheduler = create_model_and_optim(
        model_name=model_name,
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        device=device
    )

    best_val_f1 = -1.0
    best_state = None
    history = {
        "train_loss": [],
        "train_acc": [],
        "train_f1": [],
    }
    # Add validation history for each dataset
    for dataset_name in val_loaders_dict.keys():
        history[f"val_{dataset_name}_loss"] = []
        history[f"val_{dataset_name}_acc"] = []
        history[f"val_{dataset_name}_f1"] = []

    epochs_without_improvement = 0

    for epoch in range(1, max_epochs + 1):
        start_time = time.time()

        train_loss, train_acc, train_f1 = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        
        # Evaluate on all validation sets (skip empty ones)
        val_results = {}
        for dataset_name, val_loader in val_loaders_dict.items():
            # Skip empty validation loaders
            if len(val_loader) == 0:
                print(f"  Warning: {dataset_name} validation set is empty, skipping...")
                val_results[dataset_name] = {
                    "loss": 0.0,
                    "acc": 0.0,
                    "f1": 0.0
                }
            else:
                val_loss, val_acc, val_f1, _, _ = evaluate(
                    model, val_loader, criterion, device
                )
                val_results[dataset_name] = {
                    "loss": val_loss,
                    "acc": val_acc,
                    "f1": val_f1
                }

        scheduler.step()

        # Update history
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["train_f1"].append(train_f1)
        
        for dataset_name, results in val_results.items():
            history[f"val_{dataset_name}_loss"].append(results["loss"])
            history[f"val_{dataset_name}_acc"].append(results["acc"])
            history[f"val_{dataset_name}_f1"].append(results["f1"])

        # Use average F1 across all non-empty validation sets for best model selection
        non_empty_f1s = [r["f1"] for r in val_results.values() if r["f1"] > 0.0]
        if len(non_empty_f1s) > 0:
            avg_val_f1 = np.mean(non_empty_f1s)
        else:
            avg_val_f1 = 0.0
            print("  Warning: All validation sets are empty!")

        elapsed = time.time() - start_time

        print(f"Epoch {epoch}/{max_epochs} ({elapsed:.1f}s)")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
        for dataset_name, results in val_results.items():
            print(f"  Val {dataset_name} - Loss: {results['loss']:.4f}, Acc: {results['acc']:.4f}, F1: {results['f1']:.4f}")
        print(f"  Avg Val F1: {avg_val_f1:.4f}")

        if avg_val_f1 > best_val_f1:
            best_val_f1 = avg_val_f1
            best_state = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
            print(f"  -> New best! Avg Val F1: {best_val_f1:.4f}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(f"  -> Early stopping after {epoch} epochs")
                break

    model.load_state_dict(best_state)
    return model, history, best_val_f1

# Prepare validation loaders dictionary
val_loaders = {
    "main": val_loader_main,
    "plant_doc": val_loader_plant_doc,
    "fieldplant": val_loader_fieldplant
}

NUM_COMMON_CLASSES = len(common_canonical_labels)
print(f"\nReady to train on {NUM_COMMON_CLASSES} common classes")



Ready to train on 6 common classes


## Step 10: Train EfficientNet-B0 on Common Classes


In [13]:
# Train EfficientNet-B0
print("="*70)
print("TRAINING EFFICIENTNET-B0 ON COMMON CLASSES")
print("="*70)

model_eff, history_eff, best_val_f1_eff = train_model_common_classes(
    model_name="efficientnet_b0",
    num_classes=NUM_COMMON_CLASSES,
    train_loader=train_loader,
    val_loaders_dict=val_loaders,
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5
)

# Save model
checkpoint_eff = {
    "model_name": "efficientnet_b0",
    "num_classes": NUM_COMMON_CLASSES,
    "state_dict": model_eff.state_dict(),
    "best_val_f1": best_val_f1_eff,
    "history": history_eff,
    "common_classes": common_canonical_labels,
    "common_class_ids_original": common_class_ids_original,
}

torch.save(checkpoint_eff, MODELS_DIR / "efficientnet_b0_common_classes.pt")
print(f"\nModel saved to: {MODELS_DIR / 'efficientnet_b0_common_classes.pt'}")


TRAINING EFFICIENTNET-B0 ON COMMON CLASSES
Starting training for model: efficientnet_b0
  -> Training on 6 common classes
  -> Validation sets: ['main', 'plant_doc', 'fieldplant']
Epoch 1/20 (512.0s)
  Train - Loss: 1.8157, Acc: 0.3748, F1: 0.3694
  Val main - Loss: 0.7480, Acc: 0.6652, F1: 0.5172
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 1.1761, Acc: 0.5363, F1: 0.1589
  Avg Val F1: 0.3380
  -> New best! Avg Val F1: 0.3380
Epoch 2/20 (543.7s)
  Train - Loss: 1.0689, Acc: 0.6299, F1: 0.6276
  Val main - Loss: 0.3151, Acc: 0.8869, F1: 0.7983
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 1.2054, Acc: 0.5294, F1: 0.2941
  Avg Val F1: 0.5462
  -> New best! Avg Val F1: 0.5462
Epoch 3/20 (1518.3s)
  Train - Loss: 0.7964, Acc: 0.7175, F1: 0.7177
  Val main - Loss: 0.2193, Acc: 0.9084, F1: 0.8347
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 0.6840, Acc: 0.7555, F1: 0.4062
  Avg Val F1:

## Step 11: Train ViT-Base on Common Classes


In [14]:
# Train ViT-Base
print("="*70)
print("TRAINING VIT-BASE ON COMMON CLASSES")
print("="*70)

model_vit, history_vit, best_val_f1_vit = train_model_common_classes(
    model_name="vit_base_patch16_224",
    num_classes=NUM_COMMON_CLASSES,
    train_loader=train_loader,
    val_loaders_dict=val_loaders,
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5
)

# Save model
checkpoint_vit = {
    "model_name": "vit_base_patch16_224",
    "num_classes": NUM_COMMON_CLASSES,
    "state_dict": model_vit.state_dict(),
    "best_val_f1": best_val_f1_vit,
    "history": history_vit,
    "common_classes": common_canonical_labels,
    "common_class_ids_original": common_class_ids_original,
}

torch.save(checkpoint_vit, MODELS_DIR / "vit_base_patch16_224_common_classes.pt")
print(f"\nModel saved to: {MODELS_DIR / 'vit_base_patch16_224_common_classes.pt'}")


TRAINING VIT-BASE ON COMMON CLASSES
Starting training for model: vit_base_patch16_224
  -> Training on 6 common classes
  -> Validation sets: ['main', 'plant_doc', 'fieldplant']
Epoch 1/20 (1963.2s)
  Train - Loss: 1.3613, Acc: 0.4681, F1: 0.4649
  Val main - Loss: 0.6732, Acc: 0.7666, F1: 0.6592
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 0.9760, Acc: 0.6401, F1: 0.2155
  Avg Val F1: 0.4373
  -> New best! Avg Val F1: 0.4373
Epoch 2/20 (2080.3s)
  Train - Loss: 0.9481, Acc: 0.6560, F1: 0.6500
  Val main - Loss: 0.5450, Acc: 0.7513, F1: 0.6021
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 0.8492, Acc: 0.6909, F1: 0.2305
  Avg Val F1: 0.4163
Epoch 3/20 (2200.0s)
  Train - Loss: 0.8474, Acc: 0.6924, F1: 0.6921
  Val main - Loss: 0.3658, Acc: 0.8402, F1: 0.7101
  Val plant_doc - Loss: 0.0000, Acc: 0.0000, F1: 0.0000
  Val fieldplant - Loss: 0.8200, Acc: 0.7370, F1: 0.2568
  Avg Val F1: 0.4835
  -> New best! Avg Val F1:

KeyboardInterrupt: 

## Step 12: Validation Block - Performance Statistics on All Three Datasets


In [None]:
def evaluate_on_all_datasets(model, val_loaders_dict, criterion, device, model_name):
    """Evaluate model on all validation datasets and return detailed statistics."""
    results = {}
    
    print(f"\n{'='*70}")
    print(f"VALIDATION RESULTS FOR {model_name.upper()}")
    print(f"{'='*70}\n")
    
    for dataset_name, val_loader in val_loaders_dict.items():
        print(f"Evaluating on {dataset_name.upper()} dataset...")
        val_loss, val_acc, val_f1, all_targets, all_preds = evaluate(
            model, val_loader, criterion, device
        )
        
        # Per-class metrics
        class_report = classification_report(
            all_targets, all_preds,
            target_names=[common_canonical_labels[i] for i in range(NUM_COMMON_CLASSES)],
            output_dict=True,
            zero_division=0
        )
        
        # Confusion matrix
        cm = confusion_matrix(all_targets, all_preds)
        
        results[dataset_name] = {
            "loss": val_loss,
            "accuracy": val_acc,
            "f1_macro": val_f1,
            "all_targets": all_targets,
            "all_preds": all_preds,
            "classification_report": class_report,
            "confusion_matrix": cm
        }
        
        print(f"  Loss: {val_loss:.4f}")
        print(f"  Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
        print(f"  Macro F1: {val_f1:.4f} ({val_f1*100:.2f}%)")
        print(f"  Total samples: {len(all_targets)}")
        print()
        
        # Per-class accuracy
        print(f"  Per-class accuracy on {dataset_name}:")
        for class_id in range(NUM_COMMON_CLASSES):
            class_mask = all_targets == class_id
            if np.sum(class_mask) > 0:
                class_acc = (all_preds[class_mask] == class_id).mean()
                class_name = common_canonical_labels[class_id]
                print(f"    Class {class_id} ({class_name}): {class_acc:.4f} ({np.sum(class_mask)} samples)")
        print()
    
    return results

# Evaluate EfficientNet
criterion = nn.CrossEntropyLoss()
print("\n" + "="*70)
print("EFFICIENTNET-B0 VALIDATION RESULTS")
print("="*70)
results_eff = evaluate_on_all_datasets(model_eff, val_loaders, criterion, DEVICE, "EfficientNet-B0")


In [None]:
# Evaluate ViT
print("\n" + "="*70)
print("VIT-BASE VALIDATION RESULTS")
print("="*70)
results_vit = evaluate_on_all_datasets(model_vit, val_loaders, criterion, DEVICE, "ViT-Base")


## Step 13: Summary Statistics and Comparison


In [None]:
# Create summary comparison table
print("\n" + "="*70)
print("SUMMARY COMPARISON TABLE")
print("="*70)
print("\nPerformance across all validation datasets:\n")

# Create comparison DataFrame
comparison_data = []
for dataset_name in ["main", "plant_doc", "fieldplant"]:
    comparison_data.append({
        "Dataset": dataset_name.upper(),
        "Model": "EfficientNet-B0",
        "Accuracy": results_eff[dataset_name]["accuracy"],
        "F1 Macro": results_eff[dataset_name]["f1_macro"],
        "Loss": results_eff[dataset_name]["loss"],
        "Samples": len(results_eff[dataset_name]["all_targets"])
    })
    comparison_data.append({
        "Dataset": dataset_name.upper(),
        "Model": "ViT-Base",
        "Accuracy": results_vit[dataset_name]["accuracy"],
        "F1 Macro": results_vit[dataset_name]["f1_macro"],
        "Loss": results_vit[dataset_name]["loss"],
        "Samples": len(results_vit[dataset_name]["all_targets"])
    })

df_comparison = pd.DataFrame(comparison_data)
print(df_comparison.to_string(index=False))

# Overall statistics
print("\n" + "="*70)
print("OVERALL STATISTICS")
print("="*70)

print("\nEfficientNet-B0:")
print(f"  Average Accuracy across datasets: {df_comparison[df_comparison['Model']=='EfficientNet-B0']['Accuracy'].mean():.4f}")
print(f"  Average F1 Macro across datasets: {df_comparison[df_comparison['Model']=='EfficientNet-B0']['F1 Macro'].mean():.4f}")

print("\nViT-Base:")
print(f"  Average Accuracy across datasets: {df_comparison[df_comparison['Model']=='ViT-Base']['Accuracy'].mean():.4f}")
print(f"  Average F1 Macro across datasets: {df_comparison[df_comparison['Model']=='ViT-Base']['F1 Macro'].mean():.4f}")

# Best model per dataset
print("\n" + "="*70)
print("BEST MODEL PER DATASET")
print("="*70)
for dataset_name in ["main", "plant_doc", "fieldplant"]:
    eff_f1 = results_eff[dataset_name]["f1_macro"]
    vit_f1 = results_vit[dataset_name]["f1_macro"]
    best_model = "EfficientNet-B0" if eff_f1 > vit_f1 else "ViT-Base"
    print(f"{dataset_name.upper()}: {best_model} (F1: {max(eff_f1, vit_f1):.4f})")


## Step 14: Detailed Per-Class Analysis


In [None]:
# Detailed per-class analysis
print("\n" + "="*70)
print("DETAILED PER-CLASS ANALYSIS")
print("="*70)

for dataset_name in ["main", "plant_doc", "fieldplant"]:
    print(f"\n{dataset_name.upper()} Dataset:")
    print("-" * 70)
    
    for model_name, results in [("EfficientNet-B0", results_eff), ("ViT-Base", results_vit)]:
        print(f"\n{model_name}:")
        all_targets = results[dataset_name]["all_targets"]
        all_preds = results[dataset_name]["all_preds"]
        
        for class_id in range(NUM_COMMON_CLASSES):
            class_name = common_canonical_labels[class_id]
            class_mask = all_targets == class_id
            if np.sum(class_mask) > 0:
                class_acc = (all_preds[class_mask] == class_id).mean()
                class_samples = np.sum(class_mask)
                class_correct = np.sum((all_preds[class_mask] == class_id))
                
                # Precision, Recall, F1 from classification report
                if class_name in results[dataset_name]["classification_report"]:
                    metrics = results[dataset_name]["classification_report"][class_name]
                    precision = metrics.get("precision", 0)
                    recall = metrics.get("recall", 0)
                    f1 = metrics.get("f1-score", 0)
                else:
                    precision = recall = f1 = 0
                
                print(f"  {class_name}:")
                print(f"    Accuracy: {class_acc:.4f} ({class_correct}/{class_samples})")
                print(f"    Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
