**Uncertainty-Aware Domain Adaptation for Plant Disease Detection**


This notebook implements the complete pipeline from data preparation through training to evaluation.

In [1]:
# ============================================================================
# CELL 1: SETUP AND INSTALLATIONS
# ============================================================================

import torch
import os
# Report the PyTorch version and check GPU availability
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU Device:", torch.cuda.get_device_name(0))
    print("GPU Memory:", torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")

print("\nSetup completed successfully!")

os.makedirs('./checkpoints', exist_ok=True)
os.makedirs('./results', exist_ok=True)
print("\nDirectories created successfully")

PyTorch version: 2.8.0+cu126
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 15.828320256 GB

Setup completed successfully!

Directories created successfully


# Downloading Datasets from Kaggle

In [3]:
# Usually pre-installed in Colab, but just in case
!pip install -q kaggle

from google.colab import files
import os

print("📁 Please upload your kaggle.json file")
print("(You can download it from Kaggle.com → Settings → API → Create New Token)")
print()

uploaded = files.upload()

# Setup Kaggle
os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

print("\n Kaggle credentials successfully configured!")

📁 Please upload your kaggle.json file
(You can download it from Kaggle.com → Settings → API → Create New Token)



Saving kaggle.json to kaggle.json

 Kaggle credentials successfully configured!
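The `cp` and `chmod 600` shell steps above can also be done from Python, which is handy when the username and key come from somewhere other than an uploaded file (for example a Colab secret). The sketch below assumes you already have both values; `write_kaggle_credentials` is a hypothetical helper, and the JSON layout is the standard `{"username": ..., "key": ...}` format of `kaggle.json`. The Kaggle CLI warns when the key file is readable by other users, which is why the permissions are restricted.

```python
import json
import os
import stat
from pathlib import Path


def write_kaggle_credentials(username: str, key: str, config_dir: str = "~/.kaggle") -> Path:
    """Write kaggle.json and restrict it to owner read/write (like `chmod 600`).

    Hypothetical helper; equivalent to the upload + cp + chmod steps above.
    """
    target_dir = Path(config_dir).expanduser()
    target_dir.mkdir(parents=True, exist_ok=True)
    path = target_dir / "kaggle.json"
    path.write_text(json.dumps({"username": username, "key": key}))
    # Owner read + write only, i.e. mode 0o600
    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
    return path
```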


In [None]:
# Test if Kaggle API is working
!kaggle datasets list | head -5

print("\n✓ Kaggle API is working!")

ref                                                           title                                                    size  lastUpdated                 downloadCount  voteCount  usabilityRating  
------------------------------------------------------------  -------------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
ahmadrazakashif/bmw-worldwide-sales-records-20102024          BMW Worldwide Sales Records (2010–2024)                853348  2025-09-20 14:39:45.280000           7762        168  1.0              
jockeroika/life-style-data                                    Life Style Data                                       6289184  2025-10-13 02:11:38.793000           2239         62  0.9411765        
grandmaster07/student-exam-score-dataset-analysis             Student exam score dataset analysis                      2430  2025-09-26 07:44:12.677000           3215         72  1.0              

✓ Kaggle API is working!

# **PlantDoc Dataset**
Available at [PlantDoc Classification dataset](https://www.kaggle.com/datasets/nirmalsankalana/plantdoc-dataset)

In [4]:
import os

print(" Downloading PlantDoc dataset...")
print("Dataset: nirmalsankalana/plantdoc-dataset")
print()

# Download
!kaggle datasets download -d nirmalsankalana/plantdoc-dataset

print("\n✓ Download complete!")

# Check file size
!ls -lh plantdoc-dataset.zip

 Downloading PlantDoc dataset...
Dataset: nirmalsankalana/plantdoc-dataset

Dataset URL: https://www.kaggle.com/datasets/nirmalsankalana/plantdoc-dataset
License(s): CC0-1.0
Downloading plantdoc-dataset.zip to /content
100% 896M/896M [00:07<00:00, 119MB/s]

✓ Download complete!
-rw-r--r-- 1 root root 896M Sep 16  2024 plantdoc-dataset.zip


In [5]:
print("📦 Extracting dataset...")

# Create directory and extract
!mkdir -p /content/plantdoc
!unzip -q plantdoc-dataset.zip -d /content/plantdoc

print("✓ Extraction complete!")

# Verify structure
print("\n📂 Dataset structure:")
!ls -lh /content/plantdoc

print("\n📂 Train folder:")
!ls /content/plantdoc/train | head -10

print("\n📂 Test folder:")
!ls /content/plantdoc/test | head -10

📦 Extracting dataset...
✓ Extraction complete!

📂 Dataset structure:
total 16K
-rw-r--r--  1 root root 1.2K Sep 16  2024 file_renamer.py
-rw-r--r--  1 root root  595 Sep 16  2024 folder_renamer.py
drwxr-xr-x 29 root root 4.0K Oct 27 08:47 test
drwxr-xr-x 30 root root 4.0K Oct 27 08:47 train

📂 Train folder:
Apple_leaf
Apple_rust_leaf
Apple_Scab_Leaf
Bell_pepper_leaf
Bell_pepper_leaf_spot
Blueberry_leaf
Cherry_leaf
Corn_Gray_leaf_spot
Corn_leaf_blight
Corn_rust_leaf

📂 Test folder:
Apple_leaf
Apple_rust_leaf
Apple_Scab_Leaf
Bell_pepper_leaf
Bell_pepper_leaf_spot
Blueberry_leaf
Cherry_leaf
Corn_Gray_leaf_spot
Corn_leaf_blight
Corn_rust_leaf
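The train and test folders of a dataset should normally expose the same class names. A quick set comparison, sketched below, flags any class folder that appears in only one split; `compare_split_classes` is a hypothetical helper that assumes the `<root>/<split>/<class_name>/` layout shown above.

```python
import os


def compare_split_classes(dataset_root, split_a="train", split_b="test"):
    """Return (classes only in split_a, classes only in split_b), both sorted."""
    def class_dirs(split):
        split_dir = os.path.join(dataset_root, split)
        # Only sub-directories count as classes; stray files are ignored
        return {d for d in os.listdir(split_dir)
                if os.path.isdir(os.path.join(split_dir, d))}

    a, b = class_dirs(split_a), class_dirs(split_b)
    return sorted(a - b), sorted(b - a)
```

For example, `compare_split_classes('/content/plantdoc')` lists PlantDoc classes that have training images but no test images (or vice versa).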


# **PlantVillage Dataset**
Available at [PlantVillage](https://www.kaggle.com/datasets/mohitsingh1804/plantvillage)

In [6]:
import os

print(" Downloading PlantVillage dataset...")
print("Dataset: mohitsingh1804/plantvillage")
print()

# Download
!kaggle datasets download -d mohitsingh1804/plantvillage

print("\n✓ Download complete!")

# Check file size
!ls -lh plantvillage.zip

 Downloading PlantVillage dataset...
Dataset: mohitsingh1804/plantvillage

Dataset URL: https://www.kaggle.com/datasets/mohitsingh1804/plantvillage
License(s): GPL-2.0
Downloading plantvillage.zip to /content
100% 818M/818M [00:04<00:00, 197MB/s] 

✓ Download complete!
-rw-r--r-- 1 root root 818M Aug 20  2021 plantvillage.zip


In [7]:
print("📦 Extracting dataset...")

# Create directory and extract
!mkdir -p /content/plantvillage
!unzip -q plantvillage.zip -d /content/plantvillage

print("✓ Extraction complete!")

📦 Extracting dataset...
✓ Extraction complete!


In [8]:
print("\n📂 Dataset structure:")
!ls -lh /content/plantvillage/PlantVillage

print("\n📂 Train folder:")
!ls /content/plantvillage/PlantVillage/train | head -10

print("\n📂 Val folder:")
!ls /content/plantvillage/PlantVillage/val | head -10


📂 Dataset structure:
total 8.0K
drwxr-xr-x 40 root root 4.0K Oct 27 08:48 train
drwxr-xr-x 40 root root 4.0K Oct 27 08:48 val

📂 Train folder:
Apple___Apple_scab
Apple___Black_rot
Apple___Cedar_apple_rust
Apple___healthy
Blueberry___healthy
Cherry_(including_sour)___healthy
Cherry_(including_sour)___Powdery_mildew
Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
Corn_(maize)___Common_rust_
Corn_(maize)___healthy

📂 Val folder:
Apple___Apple_scab
Apple___Black_rot
Apple___Cedar_apple_rust
Apple___healthy
Blueberry___healthy
Cherry_(including_sour)___healthy
Cherry_(including_sour)___Powdery_mildew
Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot
Corn_(maize)___Common_rust_
Corn_(maize)___healthy


# Data Exploration and Visualization Utilities


This module provides functions for exploring and visualizing the datasets before training: verifying the folder structure, counting images per split and class, and plotting class distributions and sample images.
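One data characteristic worth quantifying, beyond the class-distribution plots, is class imbalance. A minimal sketch, assuming a `{class_name: image_count}` dict such as the `class_counts` entries that `analyze_and_visualize` stores per split; `imbalance_summary` is a hypothetical helper.

```python
def imbalance_summary(class_counts):
    """Summarize per-class counts: smallest class, largest class, max/min ratio.

    Empty classes are ignored so the ratio stays finite.
    """
    nonempty = {c: n for c, n in class_counts.items() if n > 0}
    if not nonempty:
        raise ValueError("no non-empty classes")
    smallest = min(nonempty, key=nonempty.get)
    largest = max(nonempty, key=nonempty.get)
    return {
        "smallest": (smallest, nonempty[smallest]),
        "largest": (largest, nonempty[largest]),
        "imbalance_ratio": nonempty[largest] / nonempty[smallest],
    }
```

For example, `imbalance_summary(pd_stats['splits']['train']['class_counts'])` would report how uneven the PlantDoc training classes are.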

# DATASET VERIFICATION SCRIPT

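This script checks that both datasets are properly structured and accessible: it lists the top-level contents, looks for `train/`, `val/`, and `test/` folders, counts images per split and class, and test-loads one sample image.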
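The core of this check is a per-class image count. The sketch below expresses similar counting logic with `pathlib` and returns the counts as a dict of `Counter`s, which is easier to reuse than printed output. `count_images` is a hypothetical name, and the extension set mirrors the `.jpg`/`.jpeg`/`.png` filter used in the script.

```python
from collections import Counter
from pathlib import Path

IMG_EXTS = {".jpg", ".jpeg", ".png"}  # same extensions the verification script accepts


def count_images(dataset_root, splits=("train", "val", "test")):
    """Return {split: Counter(class_name -> image count)} for the splits that exist."""
    root = Path(dataset_root)
    result = {}
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            continue  # e.g. PlantVillage has no test/ folder
        counts = Counter()
        for class_dir in split_dir.iterdir():
            if class_dir.is_dir():
                counts[class_dir.name] = sum(
                    1 for f in class_dir.iterdir()
                    if f.is_file() and f.suffix.lower() in IMG_EXTS)
        result[split] = counts
    return result
```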

In [9]:
import os
from PIL import Image


def verify_dataset_structure(dataset_path, dataset_name):
    """Verify dataset structure and count images"""

    print(f"\n{'='*70}")
    print(f"VERIFYING {dataset_name.upper()} DATASET")
    print(f"{'='*70}\n")

    print(f"Path: {dataset_path}")
    print(f"Exists: {os.path.exists(dataset_path)}\n")

    if not os.path.exists(dataset_path):
        print("❌ Dataset path does not exist!")
        return False

    # Check top-level structure
    print("📂 Top-level contents:")
    top_level = os.listdir(dataset_path)
    for item in top_level:
        item_path = os.path.join(dataset_path, item)
        if os.path.isdir(item_path):
            print(f"  📁 {item}/")
        else:
            print(f"  📄 {item}")

    # Check for train/test folders
    has_train = os.path.exists(os.path.join(dataset_path, 'train'))
    has_test = os.path.exists(os.path.join(dataset_path, 'test'))
    has_val = os.path.exists(os.path.join(dataset_path, 'val'))

    print(f"\n📊 Structure check:")
    print(f"  train/ folder: {'✓' if has_train else '✗'}")
    print(f"  test/ folder: {'✓' if has_test else '✗'}")
    print(f"  val/ folder: {'✓' if has_val else '✗'}")

    # Count images in each split
    total_images = 0
    splits_info = {}

    for split_name in ['train', 'test', 'val']:
        split_path = os.path.join(dataset_path, split_name)
        if os.path.exists(split_path):
            split_images = 0
            classes = []

            # Iterate through class folders
            for class_name in os.listdir(split_path):
                class_path = os.path.join(split_path, class_name)
                if os.path.isdir(class_path):
                    classes.append(class_name)

                    # Count images in this class
                    images = [f for f in os.listdir(class_path)
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
                    split_images += len(images)

            splits_info[split_name] = {
                'images': split_images,
                'classes': len(classes),
                'class_names': sorted(classes)
            }
            total_images += split_images

    # Print summary
    print(f"\n📈 Dataset summary:")
    for split_name, info in splits_info.items():
        print(f"\n  {split_name.upper()}:")
        print(f"    Images: {info['images']:,}")
        print(f"    Classes: {info['classes']}")
        print(f"    Sample classes: {', '.join(info['class_names'][:5])}")
        if len(info['class_names']) > 5:
            print(f"                    (and {len(info['class_names']) - 5} more...)")

    print(f"\n  TOTAL IMAGES: {total_images:,}")

    # Test loading a sample image
    print(f"\n🖼️  Testing image loading...")
    for split_name, info in splits_info.items():
        if info['images'] > 0:
            # Find first image
            split_path = os.path.join(dataset_path, split_name)
            for class_name in info['class_names']:
                class_path = os.path.join(split_path, class_name)
                images = [f for f in os.listdir(class_path)
                         if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
                if images:
                    sample_image = os.path.join(class_path, images[0])
                    try:
                        with Image.open(sample_image) as img:
                            img.load()  # force a full decode; Image.open() only reads the header
                            print(f"  ✓ Successfully loaded sample from {split_name}/{class_name}")
                            print(f"    Size: {img.size}, Mode: {img.mode}")
                        break
                    except Exception as e:
                        print(f"  ✗ Error loading {sample_image}: {e}")
            break

    if total_images == 0:
        print("\n❌ No images found! Check your dataset structure.")
        return False
    else:
        print(f"\n✓ Dataset verification successful!")
        return True


def verify_both_datasets(plantvillage_path, plantdoc_path):
    """Verify both datasets"""

    print("\n" + "="*70)
    print("DATASET VERIFICATION")
    print("="*70)

    # Verify PlantVillage
    pv_ok = verify_dataset_structure(plantvillage_path, "PlantVillage")

    # Verify PlantDoc
    pd_ok = verify_dataset_structure(plantdoc_path, "PlantDoc")

    # Final summary
    print("\n" + "="*70)
    print("VERIFICATION SUMMARY")
    print("="*70)
    print(f"PlantVillage: {'✓ PASS' if pv_ok else '✗ FAIL'}")
    print(f"PlantDoc: {'✓ PASS' if pd_ok else '✗ FAIL'}")

    if pv_ok and pd_ok:
        print("\n🎉 Both datasets verified successfully!")
        print("You can now proceed with training.")
    else:
        print("\n⚠️  Dataset issues detected. Please fix before training.")

    print("="*70 + "\n")

    return pv_ok and pd_ok


# Run verification
if __name__ == "__main__":
    # Set your paths here
    PLANTVILLAGE_PATH = '/content/plantvillage/PlantVillage'
    PLANTDOC_PATH = '/content/plantdoc'

    # Run verification
    verify_both_datasets(PLANTVILLAGE_PATH, PLANTDOC_PATH)


DATASET VERIFICATION

VERIFYING PLANTVILLAGE DATASET

Path: /content/plantvillage/PlantVillage
Exists: True

📂 Top-level contents:
  📁 val/
  📁 train/

📊 Structure check:
  train/ folder: ✓
  test/ folder: ✗
  val/ folder: ✓

📈 Dataset summary:

  TRAIN:
    Images: 43,444
    Classes: 38
    Sample classes: Apple___Apple_scab, Apple___Black_rot, Apple___Cedar_apple_rust, Apple___healthy, Blueberry___healthy
                    (and 33 more...)

  VAL:
    Images: 10,861
    Classes: 38
    Sample classes: Apple___Apple_scab, Apple___Black_rot, Apple___Cedar_apple_rust, Apple___healthy, Blueberry___healthy
                    (and 33 more...)

  TOTAL IMAGES: 54,305

🖼️  Testing image loading...
  ✓ Successfully loaded sample from train/Apple___Apple_scab
    Size: (256, 256), Mode: RGB

✓ Dataset verification successful!

VERIFYING PLANTDOC DATASET

Path: /content/plantdoc
Exists: True

📂 Top-level contents:
  📄 folder_renamer.py
  📁 test/
  📄 file_renamer.py
  📁 train/

📊 Structure check:

# COMPLETE DATA EXPLORATION

In [10]:
# ===================================================================
# COMPLETE DATA EXPLORATION
# ===================================================================

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def analyze_and_visualize(plantvillage_path, plantdoc_path, save_dir='./exploration_results'):
    """Count images per split and class for both datasets and save summary plots to save_dir"""
    os.makedirs(save_dir, exist_ok=True)

    print("\n" + "="*70)
    print("DATASET EXPLORATION")
    print("="*70 + "\n")

    # ========================
    # ANALYZE PLANTVILLAGE
    # ========================
    print("Analyzing PlantVillage...")
    pv_stats = {'total_images': 0, 'splits': {}}

    for split in ['train', 'val']:
        split_path = os.path.join(plantvillage_path, split)
        if not os.path.exists(split_path):
            continue

        split_images = 0
        split_classes = []
        class_counts = {}

        for class_name in os.listdir(split_path):
            class_path = os.path.join(split_path, class_name)
            if not os.path.isdir(class_path):
                continue

            split_classes.append(class_name)
            images = [f for f in os.listdir(class_path)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            count = len(images)
            split_images += count
            class_counts[class_name] = count

        pv_stats['splits'][split] = {
            'images': split_images,
            'classes': len(split_classes),
            'class_names': sorted(split_classes),
            'class_counts': class_counts
        }
        pv_stats['total_images'] += split_images
        print(f"  {split}: {split_images:,} images, {len(split_classes)} classes")

    # ========================
    # ANALYZE PLANTDOC
    # ========================
    print("\nAnalyzing PlantDoc...")
    pd_stats = {'total_images': 0, 'splits': {}}

    for split in ['train', 'test']:
        split_path = os.path.join(plantdoc_path, split)
        if not os.path.exists(split_path):
            continue

        split_images = 0
        split_classes = []
        class_counts = {}

        for class_name in os.listdir(split_path):
            class_path = os.path.join(split_path, class_name)
            if not os.path.isdir(class_path):
                continue

            split_classes.append(class_name)
            images = [f for f in os.listdir(class_path)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            count = len(images)
            split_images += count
            class_counts[class_name] = count

        pd_stats['splits'][split] = {
            'images': split_images,
            'classes': len(split_classes),
            'class_names': sorted(split_classes),
            'class_counts': class_counts
        }
        pd_stats['total_images'] += split_images
        print(f"  {split}: {split_images:,} images, {len(split_classes)} classes")

    # ========================
    # SUMMARY
    # ========================
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f"PlantVillage total: {pv_stats['total_images']:,} images")
    print(f"PlantDoc total: {pd_stats['total_images']:,} images")

    # ========================
    # VISUALIZATIONS
    # ========================
    print(f"\n{'='*70}")
    print("GENERATING VISUALIZATIONS")
    print(f"{'='*70}\n")

    # 1. Dataset comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    datasets = ['PlantVillage', 'PlantDoc']
    totals = [pv_stats['total_images'], pd_stats['total_images']]

    bars = axes[0].bar(datasets, totals, color=['skyblue', 'lightcoral'],
                       edgecolor='black', alpha=0.8, width=0.6)
    axes[0].set_ylabel('Number of Images', fontsize=12, fontweight='bold')
    axes[0].set_title('Total Images Comparison', fontsize=14, fontweight='bold')
    axes[0].grid(axis='y', alpha=0.3)
    for bar, val in zip(bars, totals):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(totals)*0.02,
                    f'{val:,}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    pv_train_classes = pv_stats['splits'].get('train', {}).get('classes', 0)
    pd_train_classes = pd_stats['splits'].get('train', {}).get('classes', 0)
    classes = [pv_train_classes, pd_train_classes]

    bars = axes[1].bar(datasets, classes, color=['skyblue', 'lightcoral'],
                       edgecolor='black', alpha=0.8, width=0.6)
    axes[1].set_ylabel('Number of Classes', fontsize=12, fontweight='bold')
    axes[1].set_title('Number of Classes', fontsize=14, fontweight='bold')
    axes[1].grid(axis='y', alpha=0.3)
    for bar, val in zip(bars, classes):
        axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(classes)*0.02,
                    f'{val}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'dataset_comparison.png'), dpi=300, bbox_inches='tight')
    print("✓ Dataset comparison saved")
    plt.close()

    # 2. PlantVillage class distribution
    if 'train' in pv_stats['splits']:
        class_counts = pv_stats['splits']['train']['class_counts']
        sorted_classes = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:20]

        plt.figure(figsize=(16, 6))
        classes = [c[0] for c in sorted_classes]
        counts = [c[1] for c in sorted_classes]

        bars = plt.bar(range(len(classes)), counts, color='skyblue',
                      edgecolor='black', alpha=0.8)
        plt.xlabel('Class', fontsize=12, fontweight='bold')
        plt.ylabel('Number of Images', fontsize=12, fontweight='bold')
        plt.title('PlantVillage - Top 20 Classes', fontsize=14, fontweight='bold')
        plt.xticks(range(len(classes)), classes, rotation=45, ha='right', fontsize=9)
        plt.grid(axis='y', alpha=0.3)

        for bar, count in zip(bars, counts):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(counts)*0.01,
                    f'{count}', ha='center', va='bottom', fontsize=8)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'plantvillage_classes.png'), dpi=300, bbox_inches='tight')
        print("✓ PlantVillage class distribution saved")
        plt.close()

    # 3. PlantDoc class distribution
    if 'train' in pd_stats['splits']:
        class_counts = pd_stats['splits']['train']['class_counts']
        sorted_classes = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)

        plt.figure(figsize=(16, 6))
        classes = [c[0] for c in sorted_classes]
        counts = [c[1] for c in sorted_classes]

        bars = plt.bar(range(len(classes)), counts, color='lightcoral',
                      edgecolor='black', alpha=0.8)
        plt.xlabel('Class', fontsize=12, fontweight='bold')
        plt.ylabel('Number of Images', fontsize=12, fontweight='bold')
        plt.title('PlantDoc - Class Distribution', fontsize=14, fontweight='bold')
        plt.xticks(range(len(classes)), classes, rotation=45, ha='right', fontsize=9)
        plt.grid(axis='y', alpha=0.3)

        for bar, count in zip(bars, counts):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(counts)*0.01,
                    f'{count}', ha='center', va='bottom', fontsize=8)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'plantdoc_classes.png'), dpi=300, bbox_inches='tight')
        print("✓ PlantDoc class distribution saved")
        plt.close()

    # 4. Sample images from PlantVillage
    pv_train_path = os.path.join(plantvillage_path, 'train')
    if os.path.exists(pv_train_path):
        classes = [d for d in os.listdir(pv_train_path)
                  if os.path.isdir(os.path.join(pv_train_path, d))][:16]

        fig, axes = plt.subplots(4, 4, figsize=(16, 16))
        axes = axes.flatten()

        for idx, class_name in enumerate(classes):
            class_path = os.path.join(pv_train_path, class_name)
            images = [f for f in os.listdir(class_path)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            if images:
                img_path = os.path.join(class_path, images[0])
                try:
                    img = Image.open(img_path)
                    axes[idx].imshow(img)
                    axes[idx].set_title(class_name, fontsize=8, fontweight='bold')
                    axes[idx].axis('off')
                except Exception:  # leave the panel blank if the image can't be read
                    axes[idx].axis('off')

        for idx in range(len(classes), 16):
            axes[idx].axis('off')

        plt.suptitle('PlantVillage - Sample Images', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'plantvillage_samples.png'), dpi=300, bbox_inches='tight')
        print("✓ PlantVillage samples saved")
        plt.close()

    # 5. Sample images from PlantDoc
    pd_train_path = os.path.join(plantdoc_path, 'train')
    if os.path.exists(pd_train_path):
        classes = [d for d in os.listdir(pd_train_path)
                  if os.path.isdir(os.path.join(pd_train_path, d))][:16]

        fig, axes = plt.subplots(4, 4, figsize=(16, 16))
        axes = axes.flatten()

        for idx, class_name in enumerate(classes):
            class_path = os.path.join(pd_train_path, class_name)
            images = [f for f in os.listdir(class_path)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            if images:
                img_path = os.path.join(class_path, images[0])
                try:
                    img = Image.open(img_path)
                    axes[idx].imshow(img)
                    axes[idx].set_title(class_name, fontsize=8, fontweight='bold')
                    axes[idx].axis('off')
                except Exception:  # leave the panel blank if the image can't be read
                    axes[idx].axis('off')

        for idx in range(len(classes), 16):
            axes[idx].axis('off')

        plt.suptitle('PlantDoc - Sample Images', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'plantdoc_samples.png'), dpi=300, bbox_inches='tight')
        print("✓ PlantDoc samples saved")
        plt.close()

    print(f"\n✓ All visualizations saved to {save_dir}/")
    print("="*70 + "\n")

    return pv_stats, pd_stats


# ===================================================================
# RUN THE EXPLORATION
# ===================================================================
PLANTVILLAGE_PATH = '/content/plantvillage/PlantVillage'
PLANTDOC_PATH = '/content/plantdoc'

pv_stats, pd_stats = analyze_and_visualize(PLANTVILLAGE_PATH, PLANTDOC_PATH)

print("✓ EXPLORATION COMPLETE!")
print(f"\nPlantVillage: {pv_stats['total_images']:,} images")
print(f"PlantDoc: {pd_stats['total_images']:,} images")


DATASET EXPLORATION

Analyzing PlantVillage...
  train: 43,444 images, 38 classes
  val: 10,861 images, 38 classes

Analyzing PlantDoc...
  train: 2,670 images, 28 classes
  test: 252 images, 27 classes

SUMMARY
PlantVillage total: 54,305 images
PlantDoc total: 2,922 images

GENERATING VISUALIZATIONS

✓ Dataset comparison saved
✓ PlantVillage class distribution saved
✓ PlantDoc class distribution saved
✓ PlantVillage samples saved
✓ PlantDoc samples saved

✓ All visualizations saved to ./exploration_results/

✓ EXPLORATION COMPLETE!

PlantVillage: 54,305 images
PlantDoc: 2,922 images
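The totals make the size gap between source and target concrete. Worked out from the counts printed above, the source domain has roughly 18.6 times as many images as the target, and the two datasets use different split ratios: PlantVillage is an exact 80/20 train/val split, while PlantDoc is roughly 91/9 train/test.

```python
# Counts copied from the exploration output above.
pv = {"train": 43_444, "val": 10_861}   # PlantVillage (source)
pdoc = {"train": 2_670, "test": 252}    # PlantDoc (target)

pv_total, pdoc_total = sum(pv.values()), sum(pdoc.values())
size_ratio = pv_total / pdoc_total          # how much larger the source set is
pv_train_frac = pv["train"] / pv_total      # PlantVillage train share
pdoc_train_frac = pdoc["train"] / pdoc_total  # PlantDoc train share

print(f"PlantVillage / PlantDoc size: {size_ratio:.1f}x")  # 18.6x
print(f"Train share: PlantVillage {pv_train_frac:.1%}, PlantDoc {pdoc_train_frac:.1%}")  # 80.0%, 91.4%
```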


# MODULE 1: Data Loading and Preprocessing Pipeline Implementation

This module handles data loading, preprocessing, and augmentation for both
PlantVillage (source) and PlantDoc (target) datasets.

Figure 1: Data Loading and Preprocessing Pipeline
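Whatever augmentation is used, source and target images must go through the same normalization at evaluation time. As a minimal NumPy sketch of the normalization step only, assuming 3-channel RGB input and ImageNet channel statistics (a common default for ImageNet-pretrained backbones; this notebook does not state which statistics it uses), the function below does what torchvision's `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)` does: scale to [0, 1], standardize each channel, and move channels first. `to_normalized_chw` is a hypothetical name.

```python
import numpy as np

# Assumption: ImageNet channel statistics, not taken from this notebook.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_hwc_uint8: np.ndarray) -> np.ndarray:
    """Map an HxWx3 uint8 RGB array to a normalized 3xHxW float32 array."""
    if img_hwc_uint8.ndim != 3 or img_hwc_uint8.shape[2] != 3:
        raise ValueError("expected an HxWx3 RGB array")
    x = img_hwc_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD        # per-channel standardization, broadcast over H and W
    return np.transpose(x, (2, 0, 1))             # HWC -> CHW, the layout PyTorch models expect
```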

**CLASS MAPPING ANALYSIS TOOL**

Analyzes and finds common classes between PlantVillage and PlantDoc
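Matching hinges on name normalization. PlantVillage names follow a `Plant___Disease` pattern that often repeats the plant name (`Apple___Apple_scab`), while PlantDoc names tend to end in `leaf` (`Apple_Scab_Leaf`). As a standalone illustration of one possible refinement (a sketch, not what `class_mapper.py` below does), a normalizer that drops `leaf`/`leaves` and consecutively repeated words lets such pairs collapse to the same key. `normalize_dedup` and `find_exact_matches` are hypothetical names, and treating `leaf` as noise is an assumption. Healthy classes (PlantVillage `Apple___healthy` vs PlantDoc `Apple_leaf`) still do not match and would need a manual rule.

```python
import re

GENERIC_WORDS = {"leaf", "leaves"}  # assumption: treated as noise, not as part of the disease name


def normalize_dedup(name: str) -> str:
    # Split on underscores, whitespace, commas and parentheses into lowercase tokens
    tokens = [t for t in re.split(r"[_\s(),]+", name.lower()) if t]
    tokens = [t for t in tokens if t not in GENERIC_WORDS]
    # Drop consecutive duplicates, e.g. 'apple apple scab' -> 'apple scab'
    deduped = [t for i, t in enumerate(tokens) if i == 0 or t != tokens[i - 1]]
    return " ".join(deduped)


def find_exact_matches(source_names, target_names):
    """Pair source and target class names whose normalized keys are identical."""
    src = {normalize_dedup(n): n for n in source_names}
    tgt = {normalize_dedup(n): n for n in target_names}
    return sorted((src[k], tgt[k]) for k in src.keys() & tgt.keys())
```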

In [12]:
%%writefile class_mapper.py

"""
CLASS MAPPING ANALYSIS TOOL
Analyzes and finds common classes between PlantVillage and PlantDoc

Run this FIRST to see what classes will be used for training
"""

import os
import json


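def normalize_class_name(class_name):
    """
    Normalize class names to find matches
    Examples:
      'Apple_Scab_Leaf'    -> 'apple scab'
      'Apple___Apple_scab' -> 'apple apple scab'
    The repeated plant name in PlantVillage-style names means pairs like these
    are not exact matches; they are picked up by the fuzzy matching step instead.
    """
    # Replace underscores with spaces
    normalized = class_name.replace('_', ' ').lower()

    # Remove common suffixes ('(including sour)' must come before
    # 'including sour', otherwise an empty '()' is left behind)
    suffixes = ['(including sour)', 'including sour', ' leaves', ' leaf']
    for suffix in suffixes:
        normalized = normalized.replace(suffix, '')

    # Remove parentheses before collapsing spaces, so no double spaces remain
    normalized = normalized.replace('(', '').replace(')', '')

    # Collapse multiple spaces
    normalized = ' '.join(normalized.split())

    return normalized.strip()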


def extract_disease_keywords(class_name):
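    """Extract plant and disease keywords for fuzzy matching"""
    # Treat underscores as spaces so multi-word keywords like 'bell pepper' can match
    name_lower = class_name.replace('_', ' ').lower()

    # Disease keywords to look for. 'leaf' is deliberately excluded: it appears in
    # most PlantDoc class names and would pair unrelated classes of the same plant
    diseases = [
        'scab', 'rust', 'blight', 'rot', 'spot', 'mildew',
        'healthy', 'bacterial', 'fungal', 'early', 'late',
        'cercospora', 'gray', 'common', 'northern', 'septoria'
    ]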

    # Plant keywords
    plants = [
        'apple', 'tomato', 'potato', 'corn', 'grape', 'cherry',
        'peach', 'pepper', 'strawberry', 'raspberry', 'blueberry',
        'bell pepper', 'maize'
    ]

    found_diseases = [d for d in diseases if d in name_lower]
    found_plants = [p for p in plants if p in name_lower]

    return found_plants, found_diseases


def analyze_class_overlap(plantvillage_path, plantdoc_path):
    """
    Analyze class overlap between datasets

    Returns:
        Dictionary with analysis results and class mappings
    """
    print("\n" + "="*80)
    print("CLASS OVERLAP ANALYSIS")
    print("="*80 + "\n")

    # Get PlantVillage classes
    pv_path = os.path.join(plantvillage_path, 'PlantVillage', 'train') \
              if os.path.exists(os.path.join(plantvillage_path, 'PlantVillage')) \
              else os.path.join(plantvillage_path, 'train')

    pv_classes = sorted([d for d in os.listdir(pv_path)
                        if os.path.isdir(os.path.join(pv_path, d))])

    # Get PlantDoc classes
    pd_path = os.path.join(plantdoc_path, 'train')
    pd_classes = sorted([d for d in os.listdir(pd_path)
                        if os.path.isdir(os.path.join(pd_path, d))])

    print(f"📊 Dataset Statistics:")
    print(f"   PlantVillage classes: {len(pv_classes)}")
    print(f"   PlantDoc classes: {len(pd_classes)}")

    # Normalize class names
    pv_normalized = {cls: normalize_class_name(cls) for cls in pv_classes}
    pd_normalized = {cls: normalize_class_name(cls) for cls in pd_classes}

    # Find exact normalized matches
    pv_norm_set = set(pv_normalized.values())
    pd_norm_set = set(pd_normalized.values())
    exact_matches = pv_norm_set & pd_norm_set

    print(f"\n🔍 Exact matches (after normalization): {len(exact_matches)}")

    # Create mapping for exact matches
    class_mapping = []

    for norm_name in sorted(exact_matches):
        # Find original names
        pv_original = [k for k, v in pv_normalized.items() if v == norm_name][0]
        pd_original = [k for k, v in pd_normalized.items() if v == norm_name][0]

        class_mapping.append({
            'plantvillage': pv_original,
            'plantdoc': pd_original,
            'normalized': norm_name
        })

    # Find fuzzy matches (same plant + disease keywords)
    print(f"\n🔎 Finding additional fuzzy matches...")

    fuzzy_matches = []
    for pv_cls in pv_classes:
        pv_plants, pv_diseases = extract_disease_keywords(pv_cls)

        for pd_cls in pd_classes:
            # Skip if already matched
            if any(m['plantvillage'] == pv_cls or m['plantdoc'] == pd_cls
                   for m in class_mapping):
                continue

            pd_plants, pd_diseases = extract_disease_keywords(pd_cls)

            # Check if same plant and disease
            common_plants = set(pv_plants) & set(pd_plants)
            common_diseases = set(pv_diseases) & set(pd_diseases)

            if common_plants and common_diseases:
                fuzzy_matches.append({
                    'plantvillage': pv_cls,
                    'plantdoc': pd_cls,
                    'plants': list(common_plants),
                    'diseases': list(common_diseases)
                })

    print(f"   Found {len(fuzzy_matches)} potential fuzzy matches")

    # Display results
    print(f"\n{'='*80}")
    print(f"MATCHED CLASSES ({len(class_mapping)} exact matches)")
    print(f"{'='*80}\n")

    for i, match in enumerate(class_mapping, 1):
        print(f"{i:2d}. PlantVillage: {match['plantvillage']}")
        print(f"    PlantDoc:     {match['plantdoc']}")
        print(f"    Normalized:   {match['normalized']}")
        print()

    if fuzzy_matches:
        print(f"\n{'='*80}")
        print(f"POTENTIAL FUZZY MATCHES ({len(fuzzy_matches)} found)")
        print(f"{'='*80}\n")
        print("⚠️  These require manual verification:")
        print()

        for i, match in enumerate(fuzzy_matches, 1):
            print(f"{i:2d}. PlantVillage: {match['plantvillage']}")
            print(f"    PlantDoc:     {match['plantdoc']}")
            print(f"    Common: {match['plants']} + {match['diseases']}")
            print()

    # Summary statistics
    print(f"\n{'='*80}")
    print(f"SUMMARY")
    print(f"{'='*80}")
    print(f"   Total PlantVillage classes: {len(pv_classes)}")
    print(f"   Total PlantDoc classes: {len(pd_classes)}")
    print(f"   Exact matches: {len(class_mapping)}")
    print(f"   Fuzzy matches: {len(fuzzy_matches)}")
    print(f"   Recommended classes for training: {len(class_mapping)}")

    # Coverage analysis
    if class_mapping:
        # Count images in matched classes
        pv_matched_images = 0
        pd_matched_images = 0

        for match in class_mapping:
            pv_class_path = os.path.join(pv_path, match['plantvillage'])
            pv_images = len([f for f in os.listdir(pv_class_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            pv_matched_images += pv_images

            pd_class_path = os.path.join(pd_path, match['plantdoc'])
            pd_images = len([f for f in os.listdir(pd_class_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            pd_matched_images += pd_images

        # Total images
        pv_total_images = sum(len([f for f in os.listdir(os.path.join(pv_path, cls))
                                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                            for cls in pv_classes)
        pd_total_images = sum(len([f for f in os.listdir(os.path.join(pd_path, cls))
                                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                            for cls in pd_classes)

        print(f"\n📊 Dataset Coverage:")
        print(f"   PlantVillage: {pv_matched_images:,} / {pv_total_images:,} images "
              f"({100*pv_matched_images/pv_total_images:.1f}%)")
        print(f"   PlantDoc: {pd_matched_images:,} / {pd_total_images:,} images "
              f"({100*pd_matched_images/pd_total_images:.1f}%)")

    print(f"\n{'='*80}\n")

    # Save mapping to file
    mapping_data = {
        'exact_matches': class_mapping,
        'fuzzy_matches': fuzzy_matches,
        'num_classes': len(class_mapping),
        'plantvillage_total': len(pv_classes),
        'plantdoc_total': len(pd_classes)
    }

    with open('class_mapping.json', 'w') as f:
        json.dump(mapping_data, f, indent=2)

    print("✓ Class mapping saved to 'class_mapping.json'")

    return mapping_data


def create_class_mapping_dict(mapping_data):
    """
    Create dictionaries for easy class mapping during data loading

    Returns:
        pv_to_common: PlantVillage class name -> common index
        pd_to_common: PlantDoc class name -> common index
        common_names: List of common class names
    """
    exact_matches = mapping_data['exact_matches']

    pv_to_common = {}
    pd_to_common = {}
    common_names = []

    for idx, match in enumerate(exact_matches):
        pv_to_common[match['plantvillage']] = idx
        pd_to_common[match['plantdoc']] = idx
        common_names.append(match['normalized'])

    return pv_to_common, pd_to_common, common_names


# Main execution
if __name__ == "__main__":
    # Set your paths
    PLANTVILLAGE_PATH = '/content/plantvillage'
    PLANTDOC_PATH = '/content/plantdoc'

    # Run analysis
    mapping_data = analyze_class_overlap(PLANTVILLAGE_PATH, PLANTDOC_PATH)

    # Create mapping dictionaries
    pv_map, pd_map, common_names = create_class_mapping_dict(mapping_data)

    print("\n📋 Usage in training:")
    print(f"   NUM_CLASSES = {len(common_names)}")
    print(f"   Common classes: {common_names[:5]}...")

    print("\n✅ Ready to proceed with filtered datasets!")

Writing class_mapper.py
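The run below reports matches "after normalization", but the normalization rule used by `analyze_class_overlap` is not shown in this section. As a rough illustration only, here is a hypothetical `normalize_class_name`. It lowercases, splits on underscores, and drops the filler token `leaf`, which reproduces the normalized names printed below for these pairs. The notebook's real rules may differ. Note that dropping `leaf` naively would also strip it from disease names such as `Leaf_Mold`.

```python
import re

def normalize_class_name(name):
    """Hypothetical normalizer (not the notebook's own): lowercase,
    split on runs of underscores/whitespace, drop the token 'leaf'."""
    tokens = re.split(r'[_\s]+', name.lower())
    return ' '.join(t for t in tokens if t and t != 'leaf')

# Folder names taken from matched pairs in the output; 'Apple___healthy' shows a non-match
pv_classes = ['Grape___Black_rot', 'Potato___Early_blight',
              'Squash___Powdery_mildew', 'Apple___healthy']
pd_classes = ['grape_leaf_black_rot', 'Potato_leaf_early_blight',
              'Squash_Powdery_mildew_leaf']

# Index PlantDoc names by their normalized form, then look up each PlantVillage name
pd_by_norm = {normalize_class_name(c): c for c in pd_classes}
exact_matches = [
    {'plantvillage': c,
     'plantdoc': pd_by_norm[normalize_class_name(c)],
     'normalized': normalize_class_name(c)}
    for c in pv_classes if normalize_class_name(c) in pd_by_norm
]

for m in exact_matches:
    print(f"{m['plantvillage']:<25} <-> {m['plantdoc']:<28} ({m['normalized']})")
```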


In [13]:
# Run the class-overlap analysis (prints the report and writes class_mapping.json)
%run class_mapper.py


CLASS OVERLAP ANALYSIS

📊 Dataset Statistics:
   PlantVillage classes: 38
   PlantDoc classes: 28

🔍 Exact matches (after normalization): 9

🔎 Finding additional fuzzy matches...
   Found 18 potential fuzzy matches

MATCHED CLASSES (9 exact matches)

 1. PlantVillage: Grape___Black_rot
    PlantDoc:     grape_leaf_black_rot
    Normalized:   grape black rot

 2. PlantVillage: Potato___Early_blight
    PlantDoc:     Potato_leaf_early_blight
    Normalized:   potato early blight

 3. PlantVillage: Potato___Late_blight
    PlantDoc:     Potato_leaf_late_blight
    Normalized:   potato late blight

 4. PlantVillage: Squash___Powdery_mildew
    PlantDoc:     Squash_Powdery_mildew_leaf
    Normalized:   squash powdery mildew

 5. PlantVillage: Tomato___Bacterial_spot
    PlantDoc:     Tomato_leaf_bacterial_spot
    Normalized:   tomato bacterial spot

 6. PlantVillage: Tomato___Early_blight
    PlantDoc:     Tomato_Early_blight_leaf
    Normalized:   tomato early blight

 7. PlantVillage: T
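Each matched pair shares one integer label, which is what makes the two filtered datasets comparable. Below is a toy run of the loop used by `create_class_mapping_dict` (and repeated in `load_class_mapping`). It operates on a hand-written two-entry `mapping_data` in the `class_mapping.json` layout, with the entries copied from the output above.

```python
# Two entries in the same layout that class_mapper.py writes to class_mapping.json
mapping_data = {
    'exact_matches': [
        {'plantvillage': 'Grape___Black_rot', 'plantdoc': 'grape_leaf_black_rot',
         'normalized': 'grape black rot'},
        {'plantvillage': 'Potato___Early_blight', 'plantdoc': 'Potato_leaf_early_blight',
         'normalized': 'potato early blight'},
    ]
}

pv_to_common, pd_to_common, common_names = {}, {}, []
for idx, match in enumerate(mapping_data['exact_matches']):
    # Both domains map their own folder name to the same shared index
    pv_to_common[match['plantvillage']] = idx
    pd_to_common[match['plantdoc']] = idx
    common_names.append(match['normalized'])

print(pv_to_common)
print(pd_to_common)
print(common_names)
```

Any folder missing from these dicts is skipped by the `c in class_filter` check in the dataset classes. As a result, the label space is exactly `range(len(common_names))` in both domains.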

In [14]:
%%writefile data_preprocessing.py

"""
MODULE 1: DATA PREPROCESSING (WITH COMMON CLASS FILTERING)
Filters datasets to only use classes that exist in both PlantVillage and PlantDoc
"""

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import json


def load_class_mapping(mapping_file='class_mapping.json'):
    """Load class mapping from JSON file"""
    if not os.path.exists(mapping_file):
        raise FileNotFoundError(
            f"Class mapping file not found: {mapping_file}\n"
            f"Please run class_mapper.py first to generate the mapping!"
        )

    with open(mapping_file, 'r') as f:
        mapping_data = json.load(f)

    # Create mapping dictionaries
    pv_to_common = {}
    pd_to_common = {}
    common_names = []

    for idx, match in enumerate(mapping_data['exact_matches']):
        pv_to_common[match['plantvillage']] = idx
        pd_to_common[match['plantdoc']] = idx
        common_names.append(match['normalized'])

    return pv_to_common, pd_to_common, common_names, mapping_data


class PlantVillageDataset(Dataset):
    """Custom Dataset for PlantVillage (Source Domain) - Filtered to common classes"""

    def __init__(self, root_dir, split='train', transform=None, class_filter=None):
        """
        Args:
            root_dir: Root directory containing PlantVillage dataset
            split: 'train', 'val', or 'test'
            transform: Optional transform to be applied on images
            class_filter: Dictionary mapping class names to indices (only use these classes)
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.class_filter = class_filter

        self.image_paths = []
        self.labels = []
        self.class_names = []

        # Handle nested PlantVillage folder
        if os.path.exists(os.path.join(root_dir, 'PlantVillage')):
            root_dir = os.path.join(root_dir, 'PlantVillage')

        # Determine data directory
        if split == 'train':
            data_dir = os.path.join(root_dir, 'train')
        elif split == 'val':
            data_dir = os.path.join(root_dir, 'val')
        else:  # test
            data_dir = os.path.join(root_dir, 'val')

        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Directory not found: {data_dir}")

        # Get all class folders
        all_classes = sorted([d for d in os.listdir(data_dir)
                             if os.path.isdir(os.path.join(data_dir, d))])

        # Filter to common classes if specified
        if class_filter is not None:
            classes_to_use = [c for c in all_classes if c in class_filter]
            print(f"  Filtering from {len(all_classes)} to {len(classes_to_use)} common classes")
        else:
            classes_to_use = all_classes

        # Load images from filtered classes
        for class_name in classes_to_use:
            class_path = os.path.join(data_dir, class_name)

            # Get mapped index
            if class_filter is not None:
                class_idx = class_filter[class_name]
            else:
                class_idx = len(self.class_names)

            if class_name not in [c for c, _ in self.class_names]:
                self.class_names.append((class_name, class_idx))

            # Get all images (sorted so the seeded shuffle below is reproducible)
            images = sorted(f for f in os.listdir(class_path)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))

            # The 'val' folder is split 50/50 per class: first half -> val, second half -> test
            if split in ('val', 'test'):
                rng = np.random.RandomState(42)  # local RNG; global NumPy state untouched
                rng.shuffle(images)
                mid = len(images) // 2
                images = images[:mid] if split == 'val' else images[mid:]

            for img_name in images:
                self.image_paths.append(os.path.join(class_path, img_name))
                self.labels.append(class_idx)

        print(f"  Loaded {len(self.image_paths)} images from {len(set(self.labels))} classes")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, 'source'


class PlantDocDataset(Dataset):
    """Custom Dataset for PlantDoc (Target Domain) - Filtered to common classes"""

    def __init__(self, root_dir, split='train', transform=None, labeled=True, class_filter=None):
        """
        Args:
            root_dir: Root directory containing PlantDoc dataset
            split: 'train', 'val', or 'test'
            transform: Optional transform to be applied on images
            labeled: Whether to return labels
            class_filter: Dictionary mapping class names to indices
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.labeled = labeled
        self.class_filter = class_filter

        self.image_paths = []
        self.labels = []
        self.class_names = []

        # Determine data directory
        if split == 'train':
            data_dir = os.path.join(root_dir, 'train')
        else:
            data_dir = os.path.join(root_dir, 'test')

        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Directory not found: {data_dir}")

        # Get all class folders
        all_classes = sorted([d for d in os.listdir(data_dir)
                             if os.path.isdir(os.path.join(data_dir, d))])

        # Filter to common classes
        if class_filter is not None:
            classes_to_use = [c for c in all_classes if c in class_filter]
            print(f"  Filtering from {len(all_classes)} to {len(classes_to_use)} common classes")
        else:
            classes_to_use = all_classes

        # Collect all images
        all_images = []

        for class_name in classes_to_use:
            class_path = os.path.join(data_dir, class_name)

            # Get mapped index
            if class_filter is not None:
                class_idx = class_filter[class_name]
            else:
                class_idx = len(self.class_names)

            if class_name not in self.class_names:
                self.class_names.append(class_name)

            # Get all images (sorted so the seeded shuffle below is reproducible)
            images = sorted(f for f in os.listdir(class_path)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))

            for img_name in images:
                img_path = os.path.join(class_path, img_name)
                all_images.append((img_path, class_idx))

        # The 'test' folder is split 50/50 over the pooled images: first half -> val, second half -> test
        if split in ['val', 'test']:
            rng = np.random.RandomState(42)  # local RNG; global NumPy state untouched
            rng.shuffle(all_images)
            mid = len(all_images) // 2
            if split == 'val':
                all_images = all_images[:mid]
            else:
                all_images = all_images[mid:]

        # Store
        for img_path, label in all_images:
            self.image_paths.append(img_path)
            self.labels.append(label)

        print(f"  Loaded {len(self.image_paths)} images from {len(set(self.labels))} classes")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx] if self.labeled else -1

        if self.transform:
            image = self.transform(image)

        return image, label, 'target'


def get_transforms(split='train', augment=True):
    """Get appropriate transforms"""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    if split == 'train' and augment:
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            transforms.ToTensor(),
            normalize
        ])
    else:
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ])


def create_data_loaders(plantvillage_path, plantdoc_path, batch_size=32,
                       use_common_classes=True, mapping_file='class_mapping.json'):
    """
    Create data loaders with optional common class filtering

    Args:
        plantvillage_path: Path to PlantVillage
        plantdoc_path: Path to PlantDoc
        batch_size: Batch size
        use_common_classes: If True, filter to common classes only
        mapping_file: Path to class mapping JSON file

    Returns:
        Dictionary of data loaders and mapping info
    """
    print("\n" + "="*70)
    print("CREATING FILTERED DATASETS (COMMON CLASSES ONLY)")
    print("="*70 + "\n")

    # Load class mapping
    if use_common_classes:
        print("📋 Loading class mapping...")
        pv_class_map, pd_class_map, common_names, mapping_data = load_class_mapping(mapping_file)
        num_classes = len(common_names)
        print(f"   Using {num_classes} common classes")
        print(f"   Classes: {', '.join(common_names[:5])}{'...' if len(common_names) > 5 else ''}\n")
    else:
        pv_class_map = None
        pd_class_map = None
        common_names = None
        num_classes = None

    # Create datasets
    print("📂 Loading PlantVillage (Source Domain)...")
    source_train = PlantVillageDataset(
        plantvillage_path, 'train', get_transforms('train', True), pv_class_map
    )
    source_val = PlantVillageDataset(
        plantvillage_path, 'val', get_transforms('val'), pv_class_map
    )
    source_test = PlantVillageDataset(
        plantvillage_path, 'test', get_transforms('test'), pv_class_map
    )

    print("\n📂 Loading PlantDoc (Target Domain)...")
    target_train = PlantDocDataset(
        plantdoc_path, 'train', get_transforms('train', True), False, pd_class_map
    )
    target_val = PlantDocDataset(
        plantdoc_path, 'val', get_transforms('val'), True, pd_class_map
    )
    target_test = PlantDocDataset(
        plantdoc_path, 'test', get_transforms('test'), True, pd_class_map
    )

    # Create loaders
    loaders = {
        'source_train': DataLoader(source_train, batch_size=batch_size, shuffle=True,
                                   num_workers=2, pin_memory=True, drop_last=True),
        'source_val': DataLoader(source_val, batch_size=batch_size, shuffle=False,
                                num_workers=2, pin_memory=True),
        'source_test': DataLoader(source_test, batch_size=batch_size, shuffle=False,
                                 num_workers=2, pin_memory=True),
        'target_train': DataLoader(target_train, batch_size=batch_size//2, shuffle=True,
                                   num_workers=2, pin_memory=True, drop_last=True),
        'target_val': DataLoader(target_val, batch_size=batch_size, shuffle=False,
                                num_workers=2, pin_memory=True),
        'target_test': DataLoader(target_test, batch_size=batch_size, shuffle=False,
                                 num_workers=2, pin_memory=True)
    }

    # Summary
    print("\n" + "="*70)
    print("DATASET SUMMARY (FILTERED TO COMMON CLASSES)")
    print("="*70)
    print(f"\n📊 Common Classes: {num_classes}")
    print(f"\n   Source Domain (PlantVillage):")
    print(f"      Train: {len(source_train):>6,} images")
    print(f"      Val:   {len(source_val):>6,} images")
    print(f"      Test:  {len(source_test):>6,} images")
    print(f"\n   Target Domain (PlantDoc):")
    print(f"      Train: {len(target_train):>6,} images")
    print(f"      Val:   {len(target_val):>6,} images")
    print(f"      Test:  {len(target_test):>6,} images")
    print("\n" + "="*70 + "\n")

    return loaders, num_classes, common_names


if __name__ == "__main__":
    print("Module 1: Data Preprocessing (with class filtering) loaded!")

Writing data_preprocessing.py
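Both dataset classes carve validation and test sets out of a single folder (PlantVillage `val/`, PlantDoc `test/`). They do this by shuffling with seed 42 and taking halves. PlantVillage splits each class separately, while PlantDoc splits the pooled list. The hypothetical helper `split_half` below is a small check of that idea. It sorts first so the result does not depend on `os.listdir` order, and it uses a local `RandomState` so NumPy's global RNG is not reset.

```python
import numpy as np

def split_half(items, split, seed=42):
    """Hypothetical helper: deterministic 50/50 val/test split of one folder."""
    items = sorted(items)                 # fixed order regardless of os.listdir
    rng = np.random.RandomState(seed)     # local RNG, global state untouched
    rng.shuffle(items)
    mid = len(items) // 2
    return items[:mid] if split == 'val' else items[mid:]

files = [f'img_{i:03d}.jpg' for i in range(11)]
val = split_half(files, 'val')
test = split_half(files, 'test')
print(len(val), len(test))                # with an odd count, the extra image goes to test
```

Because both calls rebuild the same permutation, the two halves are disjoint and together cover the folder. This holds even though val and test are constructed in separate dataset objects.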


In [15]:

from data_preprocessing import create_data_loaders

loaders, num_classes, class_names = create_data_loaders(
    plantvillage_path='/content/plantvillage/PlantVillage',  # PlantVillageDataset also auto-detects a nested 'PlantVillage' subfolder
    plantdoc_path='/content/plantdoc',
    batch_size=32,
    use_common_classes=True,  # IMPORTANT: Enable filtering
    mapping_file='class_mapping.json'  # Must exist (run class_mapper.py first)
)

# Now you can access the values
print(f"\n✓ Number of classes: {num_classes}")
print(f"✓ Class names: {class_names[:5] if len(class_names) > 5 else class_names}")

# Access data loaders
print(f"\n✓ Data loaders available:")
for key in loaders.keys():
    print(f"   - {key}")


CREATING FILTERED DATASETS (COMMON CLASSES ONLY)

📋 Loading class mapping...
   Using 9 common classes
   Classes: grape black rot, potato early blight, potato late blight, squash powdery mildew, tomato bacterial spot...

📂 Loading PlantVillage (Source Domain)...
  Filtering from 38 to 9 common classes
  Loaded 10219 images from 9 classes
  Filtering from 38 to 9 common classes
  Loaded 1276 images from 9 classes
  Filtering from 38 to 9 common classes
  Loaded 1279 images from 9 classes

📂 Loading PlantDoc (Target Domain)...
  Filtering from 28 to 9 common classes
  Loaded 1063 images from 9 classes
  Filtering from 27 to 9 common classes
  Loaded 41 images from 9 classes
  Filtering from 27 to 9 common classes
  Loaded 41 images from 9 classes

DATASET SUMMARY (FILTERED TO COMMON CLASSES)

📊 Common Classes: 9

   Source Domain (PlantVillage):
      Train: 10,219 images
      Val:    1,276 images
      Test:   1,279 images

   Target Domain (PlantDoc):
      Train:  1,063 images
    

# MODULE 2: MODEL ARCHITECTURE
UncertaintyResNet with domain adaptation components
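The evidential head turns logits into Dirichlet parameters as follows. Evidence is $e_k = \mathrm{softplus}(z_k)$, the Dirichlet parameters are $\alpha_k = e_k + 1$, and the strength is $S = \sum_k \alpha_k$. From these come the expected probabilities $\hat p_k = \alpha_k / S$ and the uncertainty $u = K/S$. The worked example below uses $K = 3$ and starts directly from evidence values, skipping the softplus so the numbers are easy to check by hand.

```python
import torch

K = 3
# Two samples: strong evidence for class 0, and no evidence at all
evidence = torch.tensor([[8.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0]])
alpha = evidence + 1.0                      # Dirichlet parameters
S = alpha.sum(dim=1, keepdim=True)          # Dirichlet strength
probs = alpha / S                           # expected class probabilities
uncertainty = K / S                         # u = K / S, in (0, 1]

print(probs)
print(uncertainty.squeeze(1))
```

With no evidence, the model falls back to a uniform prediction with $u = 1$, and additional evidence can only lower $u$. As in `EvidentialClassifier.forward`, `uncertainty` keeps a trailing dimension of size 1 because of `keepdim=True`.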

In [16]:
%%writefile model_architecture.py

"""
MODULE 2: MODEL ARCHITECTURE
UncertaintyResNet with domain adaptation components
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.autograd import Function
import numpy as np


class GradientReversalFunction(Function):
    """Gradient Reversal Layer from DANN (Ganin & Lempitsky, 2015)"""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambda_, None


class GradientReversalLayer(nn.Module):
    """Wrapper for gradient reversal function"""

    def __init__(self, lambda_=1.0):
        super(GradientReversalLayer, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)

    def set_lambda(self, lambda_):
        """Dynamically adjust lambda during training"""
        self.lambda_ = lambda_


class DomainDiscriminator(nn.Module):
    """Domain discriminator for adversarial domain adaptation"""

    def __init__(self, input_dim=2048):
        super(DomainDiscriminator, self).__init__()

        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.discriminator(x)


class EvidentialClassifier(nn.Module):
    """
    Evidential classification head using Subjective Logic
    Based on Sensoy et al. (2018)
    """

    def __init__(self, input_dim=2048, num_classes=38):
        super(EvidentialClassifier, self).__init__()

        self.num_classes = num_classes

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """
        Forward pass producing Dirichlet parameters

        Returns:
            evidence: Evidence for each class
            alpha: Dirichlet parameters
            uncertainty: Epistemic uncertainty measure
        """
        logits = self.classifier(x)
        evidence = F.softplus(logits)
        alpha = evidence + 1.0
        S = torch.sum(alpha, dim=1, keepdim=True)
        uncertainty = self.num_classes / S

        return evidence, alpha, uncertainty

    def predict_proba(self, alpha):
        """Compute predicted class probabilities from Dirichlet parameters"""
        S = torch.sum(alpha, dim=1, keepdim=True)
        probs = alpha / S
        return probs


class UncertaintyResNet(nn.Module):
    """
    Complete uncertainty-aware domain adaptation model
    Combines ResNet50 + evidential classifier + domain discriminator
    """

    def __init__(self, num_classes=38, pretrained=True, freeze_backbone=True):
        """
        Args:
            num_classes: Number of disease classes
            pretrained: Whether to use ImageNet pretrained weights
            freeze_backbone: Whether to freeze early ResNet layers
        """
        super(UncertaintyResNet, self).__init__()

        self.num_classes = num_classes

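        # Load ResNet50 (ImageNet weights via the torchvision >= 0.13 `weights` API;
        # IMAGENET1K_V1 are the weights the deprecated pretrained=True used to load)
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet50(weights=weights)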

        # Feature extractor (all layers except final FC)
        self.feature_extractor = nn.Sequential(
            *list(resnet.children())[:-1]
        )

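        # Freeze everything except layer4. Names are checked on the original ResNet:
        # inside nn.Sequential the children are renamed '0'...'8' (layer4 becomes '7'),
        # so a 'layer4' substring test there would match nothing and freeze the whole
        # backbone. The parameters are shared, so freezing them here also freezes them
        # in self.feature_extractor.
        if freeze_backbone:
            for name, param in resnet.named_parameters():
                if not name.startswith('layer4'):
                    param.requires_grad = False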

        # Feature dimension
        self.feature_dim = 2048

        # Domain adaptation components
        self.grl = GradientReversalLayer(lambda_=1.0)
        self.domain_discriminator = DomainDiscriminator(self.feature_dim)

        # Evidential classifier
        self.classifier = EvidentialClassifier(self.feature_dim, num_classes)

    def forward(self, x, alpha=1.0):
        """
        Forward pass with optional alpha for GRL

        Args:
            x: Input images (batch_size, 3, 224, 224)
            alpha: Gradient reversal coefficient (the GRL lambda; unrelated to the Dirichlet alpha_params)

        Returns:
            features, evidence, alpha_params, uncertainty, domain_pred
        """
        # Extract features
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)

        # Classification path (evidential)
        evidence, alpha_params, uncertainty = self.classifier(features)

        # Domain discrimination path (with gradient reversal)
        self.grl.set_lambda(alpha)
        reversed_features = self.grl(features)
        domain_pred = self.domain_discriminator(reversed_features)

        return features, evidence, alpha_params, uncertainty, domain_pred

    def predict(self, x):
        """
        Make predictions on input images

        Returns:
            probs, uncertainty, predicted_class
        """
        self.eval()
        with torch.no_grad():
            _, evidence, alpha, uncertainty, _ = self.forward(x, alpha=0.0)
            probs = self.classifier.predict_proba(alpha)
            predicted_class = torch.argmax(probs, dim=1)

        return probs, uncertainty, predicted_class

    def get_num_params(self):
        """Get total number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def calculate_lambda_schedule(epoch, max_epochs, gamma=10):
    """
    Calculate lambda for gradient reversal layer following DANN paper

    Args:
        epoch: Current epoch
        max_epochs: Total number of epochs
        gamma: Adjustment parameter

    Returns:
        lambda_: Gradient reversal coefficient
    """
    p = epoch / max_epochs
    lambda_ = 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0
    return lambda_


def create_model(num_classes, device='cuda', pretrained=True, freeze_backbone=True):
    """
    Create and initialize the complete model

    Args:
        num_classes: Number of disease classes
        device: Device to place model on
        pretrained: Use ImageNet pretrained weights
        freeze_backbone: Freeze early ResNet layers

    Returns:
        model: Initialized model on specified device
    """
    model = UncertaintyResNet(
        num_classes=num_classes,
        pretrained=pretrained,
        freeze_backbone=freeze_backbone
    )

    model = model.to(device)

    print(f"Model created with {model.get_num_params():,} trainable parameters")
    print(f"Number of classes: {num_classes}")

    return model

Writing model_architecture.py
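The following is a quick sanity check of the gradient reversal trick. The layer is the identity in the forward pass and multiplies the incoming gradient by $-\lambda$ in the backward pass. The function is re-declared here as `GradReverse`, a local copy of `GradientReversalFunction`, so the cell runs on its own.

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Local copy of GradientReversalFunction: identity forward, -lambda * grad backward."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient flows to lambda_, hence the trailing None
        return grad_output.neg() * ctx.lambda_, None

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = GradReverse.apply(x, 0.5)   # lambda = 0.5
y.sum().backward()              # d(sum)/dy = 1 for every element

print(y.detach())   # same values as x
print(x.grad)       # -0.5 for every element: reversed and scaled by lambda
```

In `UncertaintyResNet.forward`, the reversal sits between the pooled features and the domain discriminator. The discriminator therefore learns to tell the domains apart, while the backbone receives the opposite signal and is pushed toward domain-invariant features.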


# MODULE 3: LOSS FUNCTIONS
Evidential loss, adversarial loss, and combined loss
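`EvidentialLoss._kl_divergence` computes the closed-form KL divergence from $\mathrm{Dir}(\tilde\alpha)$ to the uniform $\mathrm{Dir}(\mathbf 1)$:

$$\mathrm{KL} = \ln\Gamma(\tilde S) - \ln\Gamma(K) - \sum_k \ln\Gamma(\tilde\alpha_k) + \sum_k (\tilde\alpha_k - 1)\big(\psi(\tilde\alpha_k) - \psi(\tilde S)\big)$$

This can be cross-checked against `torch.distributions.kl_divergence`, which has a registered Dirichlet-to-Dirichlet rule. The helper `kl_to_uniform_dirichlet` below is a standalone rewrite of the same terms.

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

def kl_to_uniform_dirichlet(alpha):
    """Closed form of KL(Dir(alpha) || Dir(1)), same terms as _kl_divergence."""
    K = alpha.size(1)
    S = alpha.sum(dim=1)
    return (torch.lgamma(S) - torch.lgamma(torch.tensor(float(K)))
            - torch.lgamma(alpha).sum(dim=1)
            + ((alpha - 1) * (torch.digamma(alpha) - torch.digamma(S).unsqueeze(1))).sum(dim=1))

# alpha_tilde = y + (1 - y) * alpha: the true-class entry is reset to 1
alpha = torch.tensor([[9.0, 1.0, 1.0], [1.0, 4.0, 2.5]])
y = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
alpha_tilde = y + (1 - y) * alpha

closed_form = kl_to_uniform_dirichlet(alpha_tilde)
reference = kl_divergence(Dirichlet(alpha_tilde), Dirichlet(torch.ones_like(alpha_tilde)))
print(closed_form, reference)
```

The first sample puts all its evidence on the true class, so $\tilde\alpha = \mathbf 1$ and the penalty is zero. The second sample carries evidence for wrong classes, so it is penalized.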

In [17]:
%%writefile loss_functions.py

"""
MODULE 3: LOSS FUNCTIONS
Evidential loss, adversarial loss, and combined loss
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class EvidentialLoss(nn.Module):
    """
    Evidential classification loss based on subjective logic
    From Sensoy et al. (2018)
    """

    def __init__(self, num_classes, lambda_kl=0.1):
        """
        Args:
            num_classes: Number of classes
            lambda_kl: Weight for KL divergence regularization
        """
        super(EvidentialLoss, self).__init__()
        self.num_classes = num_classes
        self.lambda_kl = lambda_kl

    def forward(self, evidence, alpha, targets, epoch):
        """
        Compute evidential loss

        Args:
            evidence: Evidence for each class (batch_size, num_classes)
            alpha: Dirichlet parameters (batch_size, num_classes)
            targets: True class labels (batch_size,)
            epoch: Current epoch (for annealing KL weight)

        Returns:
            total_loss, classification_loss, kl_loss
        """
        # Convert targets to one-hot
        y = F.one_hot(targets, num_classes=self.num_classes).float()

        # Sum of Dirichlet parameters
        S = torch.sum(alpha, dim=1, keepdim=True)

        # Classification loss: expected cross-entropy under the Dirichlet (Bayes-risk form)
        classification_loss = torch.sum(
            y * (torch.digamma(S) - torch.digamma(alpha)),
            dim=1
        )
        classification_loss = torch.mean(classification_loss)

        # KL divergence regularization
        alpha_tilde = y + (1 - y) * alpha
        S_tilde = torch.sum(alpha_tilde, dim=1, keepdim=True)

        kl_loss = self._kl_divergence(alpha_tilde, S_tilde)
        kl_loss = torch.mean(kl_loss)

        # Anneal KL weight
        annealing_coef = min(1.0, epoch / 10.0)
        weighted_kl_loss = annealing_coef * self.lambda_kl * kl_loss

        # Total loss
        total_loss = classification_loss + weighted_kl_loss

        return total_loss, classification_loss, kl_loss

    def _kl_divergence(self, alpha, S):
        """Compute KL divergence between Dirichlet distributions"""
        beta = torch.ones_like(alpha)
        S_beta = torch.sum(beta, dim=1, keepdim=True)

        kl = torch.lgamma(S) - torch.lgamma(S_beta) - torch.sum(
            torch.lgamma(alpha) - torch.lgamma(beta), dim=1, keepdim=True
        ) + torch.sum(
            (alpha - beta) * (torch.digamma(alpha) - torch.digamma(S)),
            dim=1, keepdim=True
        )

        return kl


class DomainAdversarialLoss(nn.Module):
    """Domain adversarial loss for domain confusion"""

    def __init__(self):
        super(DomainAdversarialLoss, self).__init__()
        self.criterion = nn.BCELoss()

    def forward(self, domain_pred, domain_labels):
        """
        Compute domain adversarial loss

        Args:
            domain_pred: Domain predictions (batch_size, 1)
            domain_labels: True domain labels as floats, same shape as domain_pred (0=source, 1=target)

        Returns:
            loss: Domain adversarial loss
        """
        loss = self.criterion(domain_pred, domain_labels)
        return loss


class CombinedDomainAdaptationLoss(nn.Module):
    """
    Combined loss for uncertainty-aware domain adaptation
    L_total = L_cls + a(epoch) * lambda_kl * L_kl + lambda_adv * lambda_schedule * L_adv,
    where a(epoch) = min(1, epoch / 10) anneals the KL term
    """

    def __init__(self, num_classes, lambda_adv=1.0, lambda_kl=0.1):
        """
        Args:
            num_classes: Number of classes
            lambda_adv: Weight for adversarial loss
            lambda_kl: Weight for KL regularization
        """
        super(CombinedDomainAdaptationLoss, self).__init__()

        self.evidential_loss = EvidentialLoss(num_classes, lambda_kl)
        self.adversarial_loss = DomainAdversarialLoss()
        self.lambda_adv = lambda_adv

    def forward(self, evidence, alpha, targets, domain_pred, domain_labels,
                epoch, lambda_schedule=1.0):
        """
        Compute combined loss

        Args:
            evidence: Evidence for each class
            alpha: Dirichlet parameters
            targets: True class labels
            domain_pred: Domain predictions
            domain_labels: True domain labels
            epoch: Current epoch
            lambda_schedule: Scheduled weight for adversarial loss

        Returns:
            Dictionary containing all loss components
        """
        # Classification loss (evidential)
        cls_loss, cls_component, kl_component = self.evidential_loss(
            evidence, alpha, targets, epoch
        )

        # Domain adversarial loss
        adv_loss = self.adversarial_loss(domain_pred, domain_labels)

        # Weighted adversarial loss (with scheduling)
        weighted_adv_loss = self.lambda_adv * lambda_schedule * adv_loss

        # Total loss
        total_loss = cls_loss + weighted_adv_loss

        return {
            'total': total_loss,
            'classification': cls_component,
            'kl_divergence': kl_component,
            'adversarial': adv_loss,
            'weighted_adversarial': weighted_adv_loss
        }


def calculate_accuracy(predictions, targets):
    """Calculate classification accuracy"""
    correct = (predictions == targets).sum().item()
    total = targets.size(0)
    accuracy = 100.0 * correct / total
    return accuracy


def calculate_expected_calibration_error(probs, targets, n_bins=15):
    """
    Calculate Expected Calibration Error (ECE)
    Measures reliability of confidence estimates
    """
    # Get confidence and predictions
    confidences, predictions = torch.max(probs, dim=1)
    accuracies = (predictions == targets).float()

    # Create bins
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0

    for i in range(n_bins):
        in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
        prop_in_bin = in_bin.float().mean()

        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return float(ece)  # works whether ece is still the initial float or a tensor


class UncertaintyMetrics:
    """Calculate various uncertainty-related metrics"""

    @staticmethod
    def epistemic_uncertainty(alpha):
        """Calculate epistemic uncertainty from Dirichlet parameters"""
        S = torch.sum(alpha, dim=1)
        num_classes = alpha.size(1)
        uncertainty = num_classes / S
        return uncertainty

    @staticmethod
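    def uncertainty_stratified_accuracy(probs, uncertainty, targets,
                                       thresholds=(0.15, 0.35)):
        """Calculate accuracy for different uncertainty strata"""
        # Flatten so the (batch, 1) uncertainty returned by EvidentialClassifier also works
        uncertainty = uncertainty.reshape(-1)
        predictions = torch.argmax(probs, dim=1)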
        correct = (predictions == targets).float()

        # High confidence (low uncertainty)
        high_conf_mask = uncertainty < thresholds[0]
        high_conf_acc = correct[high_conf_mask].mean().item() if high_conf_mask.any() else 0.0
        high_conf_count = high_conf_mask.sum().item()

        # Medium confidence
        med_conf_mask = (uncertainty >= thresholds[0]) & (uncertainty < thresholds[1])
        med_conf_acc = correct[med_conf_mask].mean().item() if med_conf_mask.any() else 0.0
        med_conf_count = med_conf_mask.sum().item()

        # Low confidence (high uncertainty)
        low_conf_mask = uncertainty >= thresholds[1]
        low_conf_acc = correct[low_conf_mask].mean().item() if low_conf_mask.any() else 0.0
        low_conf_count = low_conf_mask.sum().item()

        return {
            'high_confidence': {
                'accuracy': high_conf_acc * 100,
                'count': high_conf_count,
                'percentage': 100 * high_conf_count / len(uncertainty)
            },
            'medium_confidence': {
                'accuracy': med_conf_acc * 100,
                'count': med_conf_count,
                'percentage': 100 * med_conf_count / len(uncertainty)
            },
            'low_confidence': {
                'accuracy': low_conf_acc * 100,
                'count': low_conf_count,
                'percentage': 100 * low_conf_count / len(uncertainty)
            }
        }

Writing loss_functions.py
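The Dirichlet quantities above can be traced on a toy example. `epistemic_uncertainty` computes u = K / S with S = Σ_k α_k. The sketch below assumes the standard evidential-deep-learning parameterisation α = evidence + 1 with expected probabilities α / S (the notebook's `predict_proba` lives in `model_architecture.py`, which is not shown here), and reuses the (0.15, 0.35) thresholds from `uncertainty_stratified_accuracy`. The helper names and evidence values are hypothetical.

```python
import numpy as np

def dirichlet_summary(evidence):
    """Toy evidential quantities for a batch of non-negative evidence vectors.

    Assumes alpha = evidence + 1 (standard EDL parameterisation).
    """
    evidence = np.asarray(evidence, dtype=float)
    alpha = evidence + 1.0                      # Dirichlet parameters
    strength = alpha.sum(axis=1)                # S = sum_k alpha_k
    num_classes = alpha.shape[1]
    probs = alpha / strength[:, None]           # expected class probabilities
    uncertainty = num_classes / strength        # u = K / S, in (0, 1]
    return probs, uncertainty

def stratum(u, thresholds=(0.15, 0.35)):
    """Map an uncertainty value to the same three strata used above."""
    if u < thresholds[0]:
        return 'high_confidence'
    if u < thresholds[1]:
        return 'medium_confidence'
    return 'low_confidence'

evidence = [[0, 0, 0],      # no evidence at all
            [9, 0, 0],      # moderate evidence for class 0
            [57, 0, 0]]     # strong evidence for class 0
probs, u = dirichlet_summary(evidence)
for e, p, ui in zip(evidence, probs, u):
    print(e, np.round(p, 3), round(float(ui), 3), stratum(ui))
```

With zero evidence the Dirichlet is uniform and u = 1; as evidence accumulates for one class, u falls toward 0, which is what moves a sample from the low-confidence into the high-confidence stratum.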


# MODULE 4: TRAINING LOOP
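Two-stage training: Stage 1 pre-trains on the labelled source domain (PlantVillage); Stage 2 adapts to the unlabelled target domain (PlantDoc) through an adversarial domain discriminator with gradient reversal.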
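`stage2_domain_adaptation` scales the gradient-reversal layer by `calculate_lambda_schedule(epoch, epochs)`, which is defined in `model_architecture.py` and not shown in this section. The gradient-reversal layer passes features through unchanged on the forward pass and multiplies the discriminator's gradient by -λ on the backward pass, so a small λ early on keeps the adversarial signal from disrupting features the source classifier still relies on. A common ramp, from the DANN work of Ganin and Lempitsky, is λ = 2 / (1 + exp(-γp)) - 1 with γ = 10 and p the training progress in [0, 1]. The sketch below assumes that form; `dann_lambda` is a hypothetical name, and the notebook's actual schedule may differ.

```python
import math

def dann_lambda(epoch, total_epochs, gamma=10.0):
    """DANN-style ramp for the gradient-reversal weight (assumed form, not the notebook's code).

    p runs from 0 at the first epoch towards 1 at the last, so the adversarial
    signal starts at exactly 0 and saturates just below 1.
    """
    p = epoch / total_epochs
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

STAGE2_EPOCHS = 25  # matches STAGE2_EPOCHS in the execution script
for epoch in (0, 1, 2, 5, 10, 24):
    print(f"epoch {epoch:2d}: lambda = {dann_lambda(epoch, STAGE2_EPOCHS):.3f}")
```

If `calculate_lambda_schedule` follows this form, the `lambda` field in the Stage 2 progress bar should rise quickly over the first few epochs and then flatten out.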

In [18]:
%%writefile training_loop.py

"""
MODULE 4: TRAINING LOOP
Two-stage training procedure with domain adaptation
"""

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import os
from tqdm import tqdm


class DomainAdaptationTrainer:
    """Trainer for uncertainty-aware domain adaptation"""

    def __init__(self, model, device, num_classes, save_dir='./checkpoints'):
        """
        Args:
            model: The UncertaintyResNet model
            device: Device to train on
            num_classes: Number of classes
            save_dir: Directory to save checkpoints
        """
        self.model = model
        self.device = device
        self.num_classes = num_classes
        self.save_dir = save_dir

        os.makedirs(save_dir, exist_ok=True)

        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_acc': [],
            'target_val_acc': []
        }

        self.best_target_acc = 0.0
        self.best_epoch = 0

    def stage1_pretrain(self, train_loader, val_loader, epochs=30,
                       lr=0.001, weight_decay=1e-4, patience=15):
        """Stage 1: Pre-training on source domain"""

        print("\n" + "="*60)
        print("STAGE 1: PRE-TRAINING ON SOURCE DOMAIN")
        print("="*60)

        # Import here to avoid circular imports
        from loss_functions import EvidentialLoss

        # Optimizer and scheduler
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=lr,
            weight_decay=weight_decay
        )

        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        criterion = EvidentialLoss(self.num_classes, lambda_kl=0.1)

        best_val_acc = -1.0  # below any real accuracy, so the first epoch always saves stage1_best.pth
        patience_counter = 0

        for epoch in range(epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
            for images, labels, _ in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()

                # Forward pass
                _, evidence, alpha, _, _ = self.model(images, alpha=0.0)

                # Compute loss
                loss, cls_loss, kl_loss = criterion(evidence, alpha, labels, epoch)

                # Backward pass
                loss.backward()
                optimizer.step()

                # Calculate accuracy
                probs = self.model.classifier.predict_proba(alpha)
                predictions = torch.argmax(probs, dim=1)

                train_loss += loss.item()
                train_correct += (predictions == labels).sum().item()
                train_total += labels.size(0)

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100.0 * train_correct / train_total:.2f}%'
                })

            # Validation
            val_acc = self._validate(val_loader, domain='source')

            # Update learning rate
            scheduler.step()

            # Save metrics
            self.history['train_loss'].append(train_loss / len(train_loader))
            self.history['train_acc'].append(100.0 * train_correct / train_total)
            self.history['val_acc'].append(val_acc)

            print(f"Epoch {epoch+1}: Train Acc = {100.0 * train_correct / train_total:.2f}%, "
                  f"Val Acc = {val_acc:.2f}%")

            # Early stopping and checkpointing
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                self._save_checkpoint(epoch, val_acc, 'stage1_best.pth')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Load best model
        self._load_checkpoint('stage1_best.pth')
        print(f"\nStage 1 completed. Best validation accuracy: {best_val_acc:.2f}%")

    def stage2_domain_adaptation(self, source_loader, target_loader,
                                 target_val_loader, epochs=50,
                                 lr=0.0001, weight_decay=1e-4, patience=15):
        """Stage 2: Domain adaptation with target domain data"""

        print("\n" + "="*60)
        print("STAGE 2: DOMAIN ADAPTATION")
        print("="*60)

        # Import here to avoid circular imports
        from loss_functions import CombinedDomainAdaptationLoss
        from model_architecture import calculate_lambda_schedule

        # Optimizer and scheduler
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=lr,
            weight_decay=weight_decay
        )

        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        criterion = CombinedDomainAdaptationLoss(self.num_classes, lambda_adv=1.0, lambda_kl=0.1)

        best_target_acc = -1.0  # guarantees stage2_best.pth exists before it is reloaded
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            # Calculate gradient reversal lambda
            grl_lambda = calculate_lambda_schedule(epoch, epochs)

            # Iterate through both source and target data
            source_iter = iter(source_loader)
            target_iter = iter(target_loader)

            n_batches = min(len(source_loader), len(target_loader))

            pbar = tqdm(range(n_batches), desc=f'Epoch {epoch+1}/{epochs}')
            for _ in pbar:
                # Get source batch
                try:
                    source_images, source_labels, _ = next(source_iter)
                except StopIteration:
                    source_iter = iter(source_loader)
                    source_images, source_labels, _ = next(source_iter)

                # Get target batch
                try:
                    target_images, _, _ = next(target_iter)
                except StopIteration:
                    target_iter = iter(target_loader)
                    target_images, _, _ = next(target_iter)

                source_images = source_images.to(self.device)
                source_labels = source_labels.to(self.device)
                target_images = target_images.to(self.device)

                # Combine source and target
                combined_images = torch.cat([source_images, target_images], dim=0)
                batch_size_source = source_images.size(0)
                batch_size_target = target_images.size(0)

                # Domain labels
                domain_labels = torch.cat([
                    torch.zeros(batch_size_source, 1),
                    torch.ones(batch_size_target, 1)
                ], dim=0).to(self.device)

                optimizer.zero_grad()

                # Forward pass
                _, evidence, alpha, _, domain_pred = self.model(
                    combined_images, alpha=grl_lambda
                )

                # Split outputs
                source_evidence = evidence[:batch_size_source]
                source_alpha = alpha[:batch_size_source]

                # Compute combined loss
                losses = criterion(
                    source_evidence, source_alpha, source_labels,
                    domain_pred, domain_labels, epoch, lambda_schedule=grl_lambda
                )

                # Backward pass
                losses['total'].backward()
                optimizer.step()

                # Calculate accuracy
                probs = self.model.classifier.predict_proba(source_alpha)
                predictions = torch.argmax(probs, dim=1)

                train_loss += losses['total'].item()
                train_correct += (predictions == source_labels).sum().item()
                train_total += source_labels.size(0)

                pbar.set_postfix({
                    'loss': f'{losses["total"].item():.4f}',
                    'acc': f'{100.0 * train_correct / train_total:.2f}%',
                    'lambda': f'{grl_lambda:.3f}'
                })

            # Validation on target domain
            target_val_acc = self._validate(target_val_loader, domain='target')

            # Update learning rate
            scheduler.step()

            # Save metrics
            self.history['train_loss'].append(train_loss / n_batches)
            self.history['train_acc'].append(100.0 * train_correct / train_total)
            self.history['target_val_acc'].append(target_val_acc)

            print(f"Epoch {epoch+1}: Train Acc = {100.0 * train_correct / train_total:.2f}%, "
                  f"Target Val Acc = {target_val_acc:.2f}%, Lambda = {grl_lambda:.3f}")

            # Early stopping and checkpointing
            if target_val_acc > best_target_acc:
                best_target_acc = target_val_acc
                self.best_target_acc = best_target_acc
                self.best_epoch = epoch
                patience_counter = 0
                self._save_checkpoint(epoch, target_val_acc, 'stage2_best.pth')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Load best model
        self._load_checkpoint('stage2_best.pth')
        print(f"\nStage 2 completed. Best target validation accuracy: {best_target_acc:.2f}%")

    def _validate(self, val_loader, domain='source'):
        """Validate model on validation set"""
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels, _ in val_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                probs, _, predictions = self.model.predict(images)

                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        accuracy = 100.0 * correct / total
        return accuracy

    def _save_checkpoint(self, epoch, accuracy, filename):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'accuracy': accuracy,
            'history': self.history
        }
        torch.save(checkpoint, os.path.join(self.save_dir, filename))

    def _load_checkpoint(self, filename):
        """Load model checkpoint"""
        # map_location lets a checkpoint saved on GPU load on the current device
        checkpoint = torch.load(os.path.join(self.save_dir, filename), map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])


def train_model(model, data_loaders, device, num_classes, config):
    """
    Complete training pipeline

    Args:
        model: The model to train
        data_loaders: Dictionary of data loaders
        device: Device to train on
        num_classes: Number of classes
        config: Training configuration dictionary

    Returns:
        trained model and trainer object
    """
    trainer = DomainAdaptationTrainer(
        model=model,
        device=device,
        num_classes=num_classes,
        save_dir=config.get('save_dir', './checkpoints')
    )

    # Stage 1: Pre-training
    trainer.stage1_pretrain(
        train_loader=data_loaders['source_train'],
        val_loader=data_loaders['source_val'],
        epochs=config.get('stage1_epochs', 30),
        lr=config.get('stage1_lr', 0.001),
        weight_decay=config.get('weight_decay', 1e-4),
        patience=config.get('patience', 15)
    )

    # Stage 2: Domain adaptation
    trainer.stage2_domain_adaptation(
        source_loader=data_loaders['source_train'],
        target_loader=data_loaders['target_train'],
        target_val_loader=data_loaders['target_val'],
        epochs=config.get('stage2_epochs', 50),
        lr=config.get('stage2_lr', 0.0001),
        weight_decay=config.get('weight_decay', 1e-4),
        patience=config.get('patience', 15)
    )

    return model, trainer

Writing training_loop.py
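Both stages step `CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)` once per epoch. Under PyTorch's documented schedule, the learning rate follows a half cosine from its base value toward `eta_min` (0 by default) over T_i epochs, then jumps back to the base value, with each cycle T_mult times longer than the last. The pure-Python sketch below reproduces that closed form so the per-epoch values can be inspected without building an optimizer; `cosine_warm_restart_lr` is a hypothetical helper, not a PyTorch function.

```python
import math

def cosine_warm_restart_lr(epoch, base_lr, T_0=10, T_mult=2, eta_min=0.0):
    """LR in effect during `epoch` (0-indexed) when scheduler.step() is called once per epoch.

    Mirrors the closed-form schedule documented for
    torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.
    """
    T_i, T_cur = T_0, epoch
    while T_cur >= T_i:          # skip over completed cycles
        T_cur -= T_i
        T_i *= T_mult            # each new cycle is T_mult times longer
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2

# Stage 1 settings from the execution script: 15 epochs at lr = 1e-3
for epoch in range(15):
    print(epoch, f"{cosine_warm_restart_lr(epoch, 1e-3):.2e}")
```

With T_0 = 10, the learning rate returns to its base value at epoch index 10 (the 11th epoch) in both stages, so the 15-epoch Stage 1 and 25-epoch Stage 2 runs configured in the execution script each end partway through the second, 20-epoch cycle.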


# MODULE 5: EVALUATION CODE
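Model evaluation: overall accuracy, macro precision/recall/F1, confusion matrix, Expected Calibration Error (ECE) and uncertainty-stratified accuracy, plus confusion-matrix, reliability-diagram and uncertainty-distribution plots.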
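Of these metrics, ECE is the least self-explanatory. It bins predictions by confidence (the maximum class probability) and sums, over bins, the gap between mean confidence and accuracy weighted by the fraction of samples in the bin: ECE = Σ_b (n_b / N) · |acc(b) - conf(b)|. The sketch below mirrors the binning in `_calculate_ece` on four hand-made predictions, using 5 bins instead of 15 so the arithmetic is easy to check; `expected_calibration_error` is a hypothetical standalone name.

```python
import numpy as np

def expected_calibration_error(probs, targets, n_bins=15):
    """Confidence-binned ECE, using the same (lo, hi] bins as _calculate_ece."""
    confidences = probs.max(axis=1)
    accuracies = (probs.argmax(axis=1) == targets).astype(float)
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(confidences[in_bin].mean() - accuracies[in_bin].mean())
            ece += gap * in_bin.mean()      # weight by fraction of samples in the bin
    return ece

probs = np.array([[0.95, 0.05],    # confident and correct
                  [0.75, 0.25],    # confident but wrong (target is 1)
                  [0.35, 0.65],    # correct
                  [0.55, 0.45]])   # barely confident, correct
targets = np.array([0, 1, 1, 0])
print(expected_calibration_error(probs, targets, n_bins=5))
```

By hand: the (0.4, 0.6] bin contributes 0.25 · |0.55 - 1| = 0.1125, the (0.6, 0.8] bin 0.5 · |0.70 - 0.5| = 0.1, and the (0.8, 1.0] bin 0.25 · |0.95 - 1| = 0.0125, for an ECE of 0.225.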

In [19]:
%%writefile evaluation_code.py

"""
MODULE 5: EVALUATION CODE
Comprehensive model evaluation with metrics and visualizations
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from tqdm import tqdm
import os


class ModelEvaluator:
    """Comprehensive model evaluation"""

    def __init__(self, model, device, class_names=None):
        """
        Args:
            model: Trained model
            device: Device to evaluate on
            class_names: List of class names
        """
        self.model = model
        self.device = device
        self.class_names = class_names
        self.model.eval()

    def evaluate(self, test_loader):
        """
        Comprehensive evaluation on test set

        Returns:
            Dictionary containing all evaluation metrics
        """
        print(f"\n{'='*60}")
        print(f"EVALUATING MODEL")
        print(f"{'='*60}\n")

        all_predictions = []
        all_targets = []
        all_probs = []
        all_uncertainties = []

        with torch.no_grad():
            for images, labels, _ in tqdm(test_loader, desc='Evaluating'):
                images = images.to(self.device)
                labels = labels.to(self.device)

                probs, uncertainty, predictions = self.model.predict(images)

                all_predictions.extend(predictions.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())
                all_probs.append(probs.cpu().numpy())
                # reshape(-1), not squeeze(): a final batch of one would squeeze to a 0-d array
                all_uncertainties.extend(uncertainty.reshape(-1).cpu().numpy())

        all_predictions = np.array(all_predictions)
        all_targets = np.array(all_targets)
        all_probs = np.vstack(all_probs)
        all_uncertainties = np.array(all_uncertainties)

        # Calculate metrics
        results = {
            'accuracy': self._calculate_accuracy(all_predictions, all_targets),
            'precision_recall_f1': self._calculate_prf1(all_predictions, all_targets),
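            # Fix the label set so the matrix stays K x K even if a class never appears
            'confusion_matrix': confusion_matrix(
                all_targets, all_predictions, labels=np.arange(all_probs.shape[1])
            ),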
            'ece': self._calculate_ece(all_probs, all_targets),
            'uncertainty_metrics': self._analyze_uncertainty(
                all_probs, all_uncertainties, all_targets
            )
        }

        # Print results
        self._print_results(results)

        return results, all_predictions, all_targets, all_probs, all_uncertainties

    def _calculate_accuracy(self, predictions, targets):
        """Calculate overall accuracy"""
        correct = (predictions == targets).sum()
        total = len(targets)
        accuracy = 100.0 * correct / total
        return accuracy

    def _calculate_prf1(self, predictions, targets):
        """Calculate precision, recall, F1-score"""
        precision, recall, f1, _ = precision_recall_fscore_support(
            targets, predictions, average='macro', zero_division=0
        )

        return {
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }

    def _calculate_ece(self, probs, targets, n_bins=15):
        """Calculate Expected Calibration Error"""
        confidences = np.max(probs, axis=1)
        predictions = np.argmax(probs, axis=1)
        accuracies = (predictions == targets).astype(float)

        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        ece = 0.0

        for i in range(n_bins):
            in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                accuracy_in_bin = accuracies[in_bin].mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

    def _analyze_uncertainty(self, probs, uncertainties, targets,
                             thresholds=(0.15, 0.35)):
        """Analyze uncertainty-stratified performance"""
        predictions = np.argmax(probs, axis=1)
        correct = (predictions == targets).astype(float)

        # High confidence
        high_conf_mask = uncertainties < thresholds[0]
        high_conf_acc = correct[high_conf_mask].mean() if high_conf_mask.any() else 0.0
        high_conf_count = high_conf_mask.sum()

        # Medium confidence
        med_conf_mask = (uncertainties >= thresholds[0]) & (uncertainties < thresholds[1])
        med_conf_acc = correct[med_conf_mask].mean() if med_conf_mask.any() else 0.0
        med_conf_count = med_conf_mask.sum()

        # Low confidence
        low_conf_mask = uncertainties >= thresholds[1]
        low_conf_acc = correct[low_conf_mask].mean() if low_conf_mask.any() else 0.0
        low_conf_count = low_conf_mask.sum()

        total = len(uncertainties)

        return {
            'high_confidence': {
                'accuracy': high_conf_acc * 100,
                'count': int(high_conf_count),
                'percentage': 100 * high_conf_count / total
            },
            'medium_confidence': {
                'accuracy': med_conf_acc * 100,
                'count': int(med_conf_count),
                'percentage': 100 * med_conf_count / total
            },
            'low_confidence': {
                'accuracy': low_conf_acc * 100,
                'count': int(low_conf_count),
                'percentage': 100 * low_conf_count / total
            }
        }

    def _print_results(self, results):
        """Print evaluation results"""
        print("\n" + "="*60)
        print("EVALUATION RESULTS")
        print("="*60)

        print(f"\nOverall Accuracy: {results['accuracy']:.2f}%")

        prf = results['precision_recall_f1']
        print(f"Precision (Macro): {prf['precision']:.4f}")
        print(f"Recall (Macro): {prf['recall']:.4f}")
        print(f"F1-Score (Macro): {prf['f1_score']:.4f}")

        print(f"\nExpected Calibration Error: {results['ece']:.4f}")

        print("\nUncertainty-Stratified Performance:")
        for conf_level, metrics in results['uncertainty_metrics'].items():
            print(f"  {conf_level.replace('_', ' ').title()}:")
            print(f"    Accuracy: {metrics['accuracy']:.2f}%")
            print(f"    Samples: {metrics['count']} ({metrics['percentage']:.1f}%)")

    def plot_confusion_matrix(self, confusion_mat, save_path='confusion_matrix.png'):
        """Plot confusion matrix"""
        plt.figure(figsize=(12, 10))

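        # Row-normalize; rows for classes with no test samples stay 0 instead of NaN
        row_sums = confusion_mat.sum(axis=1, keepdims=True)
        confusion_mat = np.divide(
            confusion_mat.astype(float), row_sums,
            out=np.zeros(confusion_mat.shape, dtype=float),
            where=row_sums > 0
        )
        tick_labels = self.class_names if self.class_names is not None else 'auto'

        sns.heatmap(
            confusion_mat,
            annot=False,
            cmap='Blues',
            square=True,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            cbar_kws={'label': 'Proportion'}
        )
        plt.xticks(rotation=45, ha='right')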

        plt.title('Normalized Confusion Matrix', fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nConfusion matrix saved to {save_path}")
        plt.close()

    def plot_calibration_curve(self, probs, targets, save_path='calibration_curve.png'):
        """Plot reliability diagram"""
        confidences = np.max(probs, axis=1)
        predictions = np.argmax(probs, axis=1)
        accuracies = (predictions == targets).astype(float)

        n_bins = 15
        bin_boundaries = np.linspace(0, 1, n_bins + 1)

        bin_accuracies = []
        bin_confidences = []

        for i in range(n_bins):
            in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
            if in_bin.sum() > 0:
                bin_accuracies.append(accuracies[in_bin].mean())
                bin_confidences.append(confidences[in_bin].mean())

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Calibration curve
        axes[0].plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
        axes[0].plot(bin_confidences, bin_accuracies, 'o-', label='Model Calibration')
        axes[0].set_xlabel('Confidence', fontsize=12)
        axes[0].set_ylabel('Accuracy', fontsize=12)
        axes[0].set_title('Reliability Diagram', fontsize=14, fontweight='bold')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Confidence histogram
        # range=(0, 1) aligns the histogram bins with the calibration bins
        axes[1].hist(confidences, bins=n_bins, range=(0, 1), edgecolor='black', alpha=0.7)
        axes[1].set_xlabel('Confidence', fontsize=12)
        axes[1].set_ylabel('Count', fontsize=12)
        axes[1].set_title('Confidence Distribution', fontsize=14, fontweight='bold')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Calibration curve saved to {save_path}")
        plt.close()

    def plot_uncertainty_distribution(self, uncertainties, predictions, targets,
                                     save_path='uncertainty_distribution.png'):
        """Plot uncertainty distribution"""
        correct_mask = predictions == targets

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Histogram
        axes[0].hist(uncertainties[correct_mask], bins=30, alpha=0.6,
                    label='Correct', color='green', edgecolor='black')
        axes[0].hist(uncertainties[~correct_mask], bins=30, alpha=0.6,
                    label='Incorrect', color='red', edgecolor='black')
        axes[0].set_xlabel('Uncertainty', fontsize=12)
        axes[0].set_ylabel('Count', fontsize=12)
        axes[0].set_title('Uncertainty Distribution', fontsize=14, fontweight='bold')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Box plot
        data = [uncertainties[correct_mask], uncertainties[~correct_mask]]
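        axes[1].boxplot(data)
        # Set tick labels explicitly: boxplot's `labels` keyword was renamed in Matplotlib 3.9
        axes[1].set_xticks([1, 2])
        axes[1].set_xticklabels(['Correct', 'Incorrect'])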
        axes[1].set_ylabel('Uncertainty', fontsize=12)
        axes[1].set_title('Uncertainty Comparison', fontsize=14, fontweight='bold')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Uncertainty distribution saved to {save_path}")
        plt.close()


def evaluate_model(model, test_loader, device, class_names=None, save_dir='./results'):
    """
    Complete model evaluation with all visualizations

    Args:
        model: Trained model
        test_loader: Test data loader
        device: Device to evaluate on
        class_names: List of class names
        save_dir: Directory to save results

    Returns:
        Dictionary with all evaluation results
    """
    os.makedirs(save_dir, exist_ok=True)

    evaluator = ModelEvaluator(model, device, class_names)

    # Run evaluation
    results, predictions, targets, probs, uncertainties = evaluator.evaluate(test_loader)

    # Generate visualizations
    print("\nGenerating visualizations...")

    evaluator.plot_confusion_matrix(
        results['confusion_matrix'],
        save_path=os.path.join(save_dir, 'confusion_matrix.png')
    )

    evaluator.plot_calibration_curve(
        probs, targets,
        save_path=os.path.join(save_dir, 'calibration_curve.png')
    )

    evaluator.plot_uncertainty_distribution(
        uncertainties, predictions, targets,
        save_path=os.path.join(save_dir, 'uncertainty_distribution.png')
    )

    print(f"\nAll evaluation results saved to {save_dir}")

    return results

Writing evaluation_code.py
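The per-batch uncertainty tensors collected in `evaluate` must be flattened to 1-D before `list.extend`. If `model.predict` returns them with a trailing singleton dimension (an assumption; `model_architecture.py` is not shown here), `squeeze()` is fragile: a final batch holding a single sample collapses to a 0-d array, and extending a list with a 0-d array raises `TypeError`. The NumPy sketch below, with made-up values, shows the difference; `reshape(-1)` gives a 1-D array for every batch size.

```python
import numpy as np

# Per-batch uncertainty arrays with shape (batch, 1); values are made up
batches = [np.array([[0.10], [0.40], [0.20]]),   # full batch of 3
           np.array([[0.30]])]                    # final batch holding a single sample

collected = []
for u in batches:
    squeezed = u.squeeze()
    print(u.shape, "->", squeezed.shape)         # (1, 1) squeezes to shape ()
    collected.extend(u.reshape(-1))               # reshape(-1) is always 1-D

print(np.array(collected))
```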


# Complete Pipeline
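`train_model` reads every setting through `config.get(key, default)`, so a misspelled key (for example `stage1_epoch`) is silently ignored and the default is used instead. The sketch below lists those defaults as they appear in `train_model` and shows one way to resolve a config and reject unknown keys before a long run; `TRAIN_MODEL_DEFAULTS` and `resolve_config` are hypothetical names, not part of the notebook's modules.

```python
# Defaults hard-coded in train_model() via config.get(key, default)
TRAIN_MODEL_DEFAULTS = {
    'stage1_epochs': 30, 'stage2_epochs': 50,
    'stage1_lr': 0.001, 'stage2_lr': 0.0001,
    'weight_decay': 1e-4, 'patience': 15,
    'save_dir': './checkpoints',
}

def resolve_config(config):
    """Return the settings train_model() would actually use (hypothetical helper)."""
    unknown = set(config) - set(TRAIN_MODEL_DEFAULTS)
    if unknown:
        # config.get() would silently ignore these keys
        raise KeyError(f"Unrecognised config keys: {sorted(unknown)}")
    return {**TRAIN_MODEL_DEFAULTS, **config}   # user values override defaults

config = {'stage1_epochs': 15, 'stage2_epochs': 25, 'patience': 10}
for key, value in resolve_config(config).items():
    print(f"{key:>14}: {value}")
```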

In [20]:
# Check all modules exist
import os
modules = ['data_preprocessing', 'model_architecture', 'loss_functions',
           'training_loop', 'evaluation_code']
for m in modules:
    print(f"{'✓' if os.path.exists(f'{m}.py') else '✗'} {m}.py")

✓ data_preprocessing.py
✓ model_architecture.py
✓ loss_functions.py
✓ training_loop.py
✓ evaluation_code.py
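The execution script below reports two derived figures before training starts. The sketch reproduces them from values that appear in its printed output (9 common classes, 38 PlantVillage classes, 28 PlantDoc classes), keeping only the three `class_mapping.json` fields the script reads. The match rate divides by the smaller label set, and the time estimate is the script's own rough per-epoch heuristic, not a measurement.

```python
# Only the class_mapping.json fields the script reads; values from its printed output
mapping_data = {'num_classes': 9, 'plantvillage_total': 38, 'plantdoc_total': 28}

# Match rate is measured against the smaller of the two label sets
match_rate = 100 * mapping_data['num_classes'] / min(mapping_data['plantvillage_total'],
                                                     mapping_data['plantdoc_total'])
print(f"Match rate: {match_rate:.1f}%")

# Same heuristic as the script: ~2 min per Stage 1 epoch, ~3 min per Stage 2 epoch
STAGE1_EPOCHS, STAGE2_EPOCHS = 15, 25
approx_time = STAGE1_EPOCHS * 2 + STAGE2_EPOCHS * 3
print(f"Estimated time: ~{approx_time} minutes (early stopping can only shorten this)")
```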


In [None]:
"""
COMPLETE EXECUTION SCRIPT - WITH COMMON CLASS FILTERING
This version only trains on classes that exist in both datasets
"""

import torch
import os
import json

print("="*80)
print("UNCERTAINTY-AWARE DOMAIN ADAPTATION (COMMON CLASSES)")
print("="*80)

# ============================================================================
# STEP 0: CHECK CLASS MAPPING EXISTS
# ============================================================================

if not os.path.exists('class_mapping.json'):
    print("\n❌ ERROR: class_mapping.json not found!")
    print("\n⚠️  You must run the Class Mapping Analysis first:")
    print("   1. Save and run class_mapper.py")
    print("   2. Review the common classes found")
    print("   3. Then run this script")
    raise FileNotFoundError("Run class_mapper.py first!")

# Load mapping to show info
with open('class_mapping.json', 'r') as f:
    mapping_data = json.load(f)

print(f"\n📋 Class Mapping Info:")
print(f"   Common classes found: {mapping_data['num_classes']}")
print(f"   PlantVillage total: {mapping_data['plantvillage_total']}")
print(f"   PlantDoc total: {mapping_data['plantdoc_total']}")
# Share of the smaller dataset's classes that are common to both datasets
print(f"   Match rate: {100*mapping_data['num_classes']/min(mapping_data['plantvillage_total'], mapping_data['plantdoc_total']):.1f}%")

# ============================================================================
# CONFIGURATION
# ============================================================================

PLANTVILLAGE_PATH = '/content/plantvillage/PlantVillage'
PLANTDOC_PATH = '/content/plantdoc'

# TRAINING SETTINGS
STAGE1_EPOCHS = 15
STAGE2_EPOCHS = 25
BATCH_SIZE = 32
LEARNING_RATE_STAGE1 = 0.001
LEARNING_RATE_STAGE2 = 0.0001
WEIGHT_DECAY = 1e-4
PATIENCE = 10

print(f"\n⚙️  Training Configuration:")
print(f"   Stage 1 epochs: {STAGE1_EPOCHS}")
print(f"   Stage 2 epochs: {STAGE2_EPOCHS}")
print(f"   Batch size: {BATCH_SIZE}")

# ============================================================================
# SETUP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

os.makedirs('./checkpoints', exist_ok=True)
os.makedirs('./results', exist_ok=True)

# ============================================================================
# IMPORT MODULES
# ============================================================================

print("\n" + "="*80)
print("IMPORTING MODULES")
print("="*80 + "\n")

try:
    from data_preprocessing import create_data_loaders
    print("✓ data_preprocessing (with class filtering)")
except ImportError as e:
    print(f"✗ Error: {e}")
    print("\n⚠️  Make sure you've saved the UPDATED Module 1")
    raise

try:
    from model_architecture import create_model
    print("✓ model_architecture")
except ImportError:
    print("✗ model_architecture not found")
    raise

try:
    from training_loop import train_model
    print("✓ training_loop")
except ImportError:
    print("✗ training_loop not found")
    raise

try:
    from evaluation_code import evaluate_model
    print("✓ evaluation_code")
except ImportError:
    print("✗ evaluation_code not found")
    raise

# ============================================================================
# CREATE DATA LOADERS (WITH CLASS FILTERING)
# ============================================================================

print("\n" + "="*80)
print("CREATING DATA LOADERS (FILTERED TO COMMON CLASSES)")
print("="*80)

data_loaders, NUM_CLASSES, common_names = create_data_loaders(
    plantvillage_path=PLANTVILLAGE_PATH,
    plantdoc_path=PLANTDOC_PATH,
    batch_size=BATCH_SIZE,
    use_common_classes=True,  # KEY: Enable class filtering
    mapping_file='class_mapping.json'
)

print(f"\n✓ Loaded filtered datasets with {NUM_CLASSES} common classes")

# ============================================================================
# CREATE MODEL
# ============================================================================

print("\n" + "="*80)
print("CREATING MODEL")
print("="*80)

model = create_model(
    num_classes=NUM_CLASSES,  # Uses only common classes
    device=device,
    pretrained=True,
    freeze_backbone=True
)

print(f"\n✓ Model created with {NUM_CLASSES} output classes")

# ============================================================================
# TRAINING
# ============================================================================

print("\n" + "="*80)
print("TRAINING MODEL")
print("="*80)

approx_time = (STAGE1_EPOCHS * 2) + (STAGE2_EPOCHS * 3)  # rough heuristic: ~2 min per Stage 1 epoch, ~3 min per Stage 2 epoch
print(f"\n⏱️  Estimated time: ~{approx_time} minutes")
print(f"🚀 Starting training...\n")

config = {
    'stage1_epochs': STAGE1_EPOCHS,
    'stage2_epochs': STAGE2_EPOCHS,
    'stage1_lr': LEARNING_RATE_STAGE1,
    'stage2_lr': LEARNING_RATE_STAGE2,
    'weight_decay': WEIGHT_DECAY,
    'patience': PATIENCE,
    'save_dir': './checkpoints'
}

try:
    model, trainer = train_model(
        model=model,
        data_loaders=data_loaders,
        device=device,
        num_classes=NUM_CLASSES,
        config=config
    )

    print("\n" + "="*80)
    print("✓ TRAINING COMPLETED")
    print("="*80)
    print(f"   Best target accuracy: {trainer.best_target_acc:.2f}%")

except Exception as e:
    print(f"\n✗ Training error: {e}")
    import traceback
    traceback.print_exc()
    raise

# ============================================================================
# EVALUATION
# ============================================================================

print("\n" + "="*80)
print("EVALUATING MODEL")
print("="*80)

try:
    results = evaluate_model(
        model=model,
        test_loader=data_loaders['target_test'],
        device=device,
        class_names=common_names,  # Pass common class names
        save_dir='./results'
    )

    print("\n✓ Evaluation completed")

except Exception as e:
    print(f"\n✗ Evaluation error: {e}")
    raise

# ============================================================================
# FINAL SUMMARY
# ============================================================================

print("\n" + "="*80)
print("🎉 EXPERIMENT COMPLETED!")
print("="*80)

print("\n📊 Dataset Info:")
print(f"   Common classes used: {NUM_CLASSES}")
print(f"   Class names: {', '.join(common_names[:5])}...")

print("\n📈 Results:")
print(f"   Target Accuracy: {results['accuracy']:.2f}%")
print(f"   Precision: {results['precision_recall_f1']['precision']:.4f}")
print(f"   Recall: {results['precision_recall_f1']['recall']:.4f}")
print(f"   F1-Score: {results['precision_recall_f1']['f1_score']:.4f}")
print(f"   ECE: {results['ece']:.4f}")

print("\n💾 Saved Files:")
print("   class_mapping.json - Class mapping used")
print("   ./checkpoints/stage1_best.pth")
print("   ./checkpoints/stage2_best.pth")
print("   ./results/confusion_matrix.png")
print("   ./results/calibration_curve.png")
print("   ./results/uncertainty_distribution.png")

# Save a detailed plain-text summary (UTF-8 so any non-ASCII class names are written safely)
with open('./results/experiment_summary.txt', 'w', encoding='utf-8') as f:
    f.write("EXPERIMENT SUMMARY (COMMON CLASSES)\n")
    f.write("="*80 + "\n\n")
    f.write("Dataset Configuration:\n")
    f.write(f"  Common classes: {NUM_CLASSES}\n")
    f.write(f"  Class names: {', '.join(common_names)}\n\n")
    f.write("Training Configuration:\n")
    # Configured maxima; early stopping may end a stage sooner
    f.write(f"  Stage 1 epochs (max): {STAGE1_EPOCHS}\n")
    f.write(f"  Stage 2 epochs (max): {STAGE2_EPOCHS}\n\n")
    f.write("Results:\n")
    f.write(f"  Accuracy: {results['accuracy']:.2f}%\n")
    f.write(f"  Precision (macro): {results['precision_recall_f1']['precision']:.4f}\n")
    f.write(f"  Recall (macro): {results['precision_recall_f1']['recall']:.4f}\n")
    f.write(f"  F1-Score (macro): {results['precision_recall_f1']['f1_score']:.4f}\n")
    f.write(f"  ECE: {results['ece']:.4f}\n")

print("\n✓ Detailed summary saved to ./results/experiment_summary.txt")
print("\n" + "="*80 + "\n")
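Here is a quick consistency check on the results printed below. The three uncertainty buckets hold 2 + 23 + 16 = 41 target test images, and their accuracies account for 19 correct predictions, which reproduces the overall 46.34%. Every logged Stage 2 target validation accuracy also matches a multiple of 1/41 to two decimals. So each single prediction appears to shift accuracy by about 2.44 points, and the epoch-to-epoch swings between 34.15% and 41.46% amount to only a few images. The numbers in the snippet are copied from the output; the snippet itself is only arithmetic.

```python
# Bucket sizes and accuracies copied from the "Uncertainty-Stratified Performance" output
buckets = {"high": (2, 1.0000), "medium": (23, 0.5652), "low": (16, 0.2500)}

n_total = sum(n for n, _ in buckets.values())
n_correct = sum(round(n * acc) for n, acc in buckets.values())
print(n_total, n_correct, f"{100 * n_correct / n_total:.2f}%")  # 41 19 46.34%

# One prediction out of 41 moves accuracy by 100 / 41 (about 2.44 points);
# convert the logged Stage 2 target validation accuracies back to counts
for acc in (34.15, 36.59, 39.02, 41.46):
    print(f"{acc}% -> {round(acc / 100 * n_total)} / {n_total}")
```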

UNCERTAINTY-AWARE DOMAIN ADAPTATION (COMMON CLASSES)

📋 Class Mapping Info:
   Common classes found: 9
   PlantVillage total: 38
   PlantDoc total: 28
   Match rate: 32.1%

⚙️  Training Configuration:
   Stage 1 epochs: 15
   Stage 2 epochs: 25
   Batch size: 32

🖥️  Device: cuda
   GPU: Tesla T4

IMPORTING MODULES

✓ data_preprocessing (with class filtering)
✓ model_architecture
✓ training_loop
✓ evaluation_code

CREATING DATA LOADERS (FILTERED TO COMMON CLASSES)

CREATING FILTERED DATASETS (COMMON CLASSES ONLY)

📋 Loading class mapping...
   Using 9 common classes
   Classes: grape black rot, potato early blight, potato late blight, squash powdery mildew, tomato bacterial spot...

📂 Loading PlantVillage (Source Domain)...
  Filtering from 38 to 9 common classes
  Loaded 10219 images from 9 classes
  Filtering from 38 to 9 common classes
  Loaded 1276 images from 9 classes
  Filtering from 38 to 9 common classes
  Loaded 1279 images from 9 classes

📂 Loading PlantDoc (Target Domain)..

Epoch 1/15: 100%|██████████| 319/319 [01:27<00:00,  3.64it/s, loss=0.8828, acc=81.33%]


Epoch 1: Train Acc = 81.33%, Val Acc = 86.52%


Epoch 2/15: 100%|██████████| 319/319 [01:25<00:00,  3.74it/s, loss=0.6614, acc=86.94%]


Epoch 2: Train Acc = 86.94%, Val Acc = 90.20%


Epoch 3/15: 100%|██████████| 319/319 [01:25<00:00,  3.71it/s, loss=0.6006, acc=88.10%]


Epoch 3: Train Acc = 88.10%, Val Acc = 89.18%


Epoch 4/15: 100%|██████████| 319/319 [01:26<00:00,  3.69it/s, loss=0.4862, acc=88.94%]


Epoch 4: Train Acc = 88.94%, Val Acc = 90.44%


Epoch 5/15: 100%|██████████| 319/319 [01:25<00:00,  3.73it/s, loss=0.5813, acc=89.77%]


Epoch 5: Train Acc = 89.77%, Val Acc = 91.46%


Epoch 6/15: 100%|██████████| 319/319 [01:26<00:00,  3.68it/s, loss=0.6629, acc=90.49%]


Epoch 6: Train Acc = 90.49%, Val Acc = 91.93%


Epoch 7/15: 100%|██████████| 319/319 [01:25<00:00,  3.72it/s, loss=0.6114, acc=90.87%]


Epoch 7: Train Acc = 90.87%, Val Acc = 92.71%


Epoch 8/15: 100%|██████████| 319/319 [01:25<00:00,  3.72it/s, loss=0.5178, acc=91.79%]


Epoch 8: Train Acc = 91.79%, Val Acc = 93.10%


Epoch 9/15: 100%|██████████| 319/319 [01:25<00:00,  3.72it/s, loss=0.4310, acc=91.52%]


Epoch 9: Train Acc = 91.52%, Val Acc = 93.34%


Epoch 10/15: 100%|██████████| 319/319 [01:26<00:00,  3.70it/s, loss=0.4327, acc=91.96%]


Epoch 10: Train Acc = 91.96%, Val Acc = 93.34%


Epoch 11/15: 100%|██████████| 319/319 [01:25<00:00,  3.73it/s, loss=0.6151, acc=90.38%]


Epoch 11: Train Acc = 90.38%, Val Acc = 90.91%


Epoch 12/15: 100%|██████████| 319/319 [01:25<00:00,  3.73it/s, loss=0.6578, acc=90.37%]


Epoch 12: Train Acc = 90.37%, Val Acc = 92.71%


Epoch 13/15: 100%|██████████| 319/319 [01:26<00:00,  3.70it/s, loss=0.5453, acc=90.68%]


Epoch 13: Train Acc = 90.68%, Val Acc = 91.46%


Epoch 14/15: 100%|██████████| 319/319 [01:25<00:00,  3.72it/s, loss=0.6607, acc=90.76%]


Epoch 14: Train Acc = 90.76%, Val Acc = 91.77%


Epoch 15/15: 100%|██████████| 319/319 [01:26<00:00,  3.67it/s, loss=0.4611, acc=91.08%]


Epoch 15: Train Acc = 91.08%, Val Acc = 93.03%

Stage 1 completed. Best validation accuracy: 93.34%

STAGE 2: DOMAIN ADAPTATION


Epoch 1/25: 100%|██████████| 66/66 [00:41<00:00,  1.58it/s, loss=0.7212, acc=90.39%, lambda=0.000]


Epoch 1: Train Acc = 90.39%, Target Val Acc = 39.02%, Lambda = 0.000


Epoch 2/25: 100%|██████████| 66/66 [00:42<00:00,  1.55it/s, loss=0.5481, acc=92.71%, lambda=0.197]


Epoch 2: Train Acc = 92.71%, Target Val Acc = 36.59%, Lambda = 0.197


Epoch 3/25: 100%|██████████| 66/66 [00:41<00:00,  1.60it/s, loss=0.3942, acc=92.28%, lambda=0.380]


Epoch 3: Train Acc = 92.28%, Target Val Acc = 34.15%, Lambda = 0.380


Epoch 4/25: 100%|██████████| 66/66 [00:41<00:00,  1.61it/s, loss=0.4679, acc=92.23%, lambda=0.537]


Epoch 4: Train Acc = 92.23%, Target Val Acc = 39.02%, Lambda = 0.537


Epoch 5/25: 100%|██████████| 66/66 [00:42<00:00,  1.56it/s, loss=0.5176, acc=91.71%, lambda=0.664]


Epoch 5: Train Acc = 91.71%, Target Val Acc = 41.46%, Lambda = 0.664


Epoch 6/25: 100%|██████████| 66/66 [00:41<00:00,  1.60it/s, loss=0.4587, acc=92.61%, lambda=0.762]


Epoch 6: Train Acc = 92.61%, Target Val Acc = 39.02%, Lambda = 0.762


Epoch 7/25: 100%|██████████| 66/66 [00:41<00:00,  1.61it/s, loss=0.4550, acc=92.00%, lambda=0.834]


Epoch 7: Train Acc = 92.00%, Target Val Acc = 39.02%, Lambda = 0.834


Epoch 8/25: 100%|██████████| 66/66 [00:42<00:00,  1.55it/s, loss=0.4800, acc=92.23%, lambda=0.885]


Epoch 8: Train Acc = 92.23%, Target Val Acc = 41.46%, Lambda = 0.885


Epoch 9/25: 100%|██████████| 66/66 [00:41<00:00,  1.60it/s, loss=0.3703, acc=92.90%, lambda=0.922]


Epoch 9: Train Acc = 92.90%, Target Val Acc = 36.59%, Lambda = 0.922


Epoch 10/25: 100%|██████████| 66/66 [00:41<00:00,  1.57it/s, loss=0.4171, acc=92.90%, lambda=0.947]


Epoch 10: Train Acc = 92.90%, Target Val Acc = 36.59%, Lambda = 0.947


Epoch 11/25: 100%|██████████| 66/66 [00:41<00:00,  1.60it/s, loss=0.7071, acc=93.28%, lambda=0.964]


Epoch 11: Train Acc = 93.28%, Target Val Acc = 39.02%, Lambda = 0.964


Epoch 12/25: 100%|██████████| 66/66 [00:41<00:00,  1.59it/s, loss=0.5231, acc=93.13%, lambda=0.976]


Epoch 12: Train Acc = 93.13%, Target Val Acc = 36.59%, Lambda = 0.976


Epoch 13/25: 100%|██████████| 66/66 [00:42<00:00,  1.56it/s, loss=0.6778, acc=93.09%, lambda=0.984]


Epoch 13: Train Acc = 93.09%, Target Val Acc = 36.59%, Lambda = 0.984


Epoch 14/25: 100%|██████████| 66/66 [00:41<00:00,  1.58it/s, loss=0.4532, acc=91.86%, lambda=0.989]


Epoch 14: Train Acc = 91.86%, Target Val Acc = 36.59%, Lambda = 0.989


Epoch 15/25: 100%|██████████| 66/66 [00:41<00:00,  1.60it/s, loss=0.6095, acc=92.42%, lambda=0.993]


Epoch 15: Train Acc = 92.42%, Target Val Acc = 39.02%, Lambda = 0.993
Early stopping at epoch 15

Stage 2 completed. Best target validation accuracy: 41.46%

✓ TRAINING COMPLETED
   Best target accuracy: 41.46%

EVALUATING MODEL

EVALUATING MODEL



Evaluating: 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]



EVALUATION RESULTS

Overall Accuracy: 46.34%
Precision (Macro): 0.5838
Recall (Macro): 0.4407
F1-Score (Macro): 0.4622

Expected Calibration Error: 0.2306

Uncertainty-Stratified Performance:
  High Confidence:
    Accuracy: 100.00%
    Samples: 2 (4.9%)
  Medium Confidence:
    Accuracy: 56.52%
    Samples: 23 (56.1%)
  Low Confidence:
    Accuracy: 25.00%
    Samples: 16 (39.0%)

Generating visualizations...

Confusion matrix saved to ./results/confusion_matrix.png
Calibration curve saved to ./results/calibration_curve.png
Uncertainty distribution saved to ./results/uncertainty_distribution.png

All evaluation results saved to ./results

✓ Evaluation completed

🎉 EXPERIMENT COMPLETED!

📊 Dataset Info:
   Common classes used: 9
   Class names: grape black rot, potato early blight, potato late blight, squash powdery mildew, tomato bacterial spot...

📈 Results:
   Target Accuracy: 46.34%
   Precision: 0.5838
   Recall: 0.4407
   F1-Score: 0.4622
   ECE: 0.2306

💾 Saved Files:
   class_