# Validation Notebook: Multi-Task BEiT for Hierarchical Fungi Classification

This notebook validates the trained multi-task BEiT model for hierarchical fungi classification across 6 taxonomic ranks, with a special focus on recognizing **Amanita phalloides** (Death Cap).

## Overview
- **Model**: BEiT multi-task model with 6 classification heads
- **Dataset**: FungiTastic validation set (~1M samples)
- **Ranks**: Phylum, Class, Order, Family, Genus, Species
- **Special Focus**: Amanita phalloides (148 specimens in validation set)

---
## Section 1: Setup

In [None]:
# Cell 1: Import libraries
import os
import sys
import json
import pickle
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from PIL import Image
from tqdm.auto import tqdm

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Metrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)

# Image transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Seed the global NumPy and torch RNGs so sampling cells are repeatable
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Plot defaults for every figure in this notebook
plt.rcParams.update({'figure.figsize': (12, 8), 'font.size': 10})
sns.set_style('whitegrid')

# Environment report
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 2: Configuration
# Machine-specific roots. Either can be overridden with an environment variable
# so the notebook runs on another machine without code edits; the defaults
# reproduce the original hard-coded absolute paths exactly.
REPO_ROOT = os.environ.get('AMANITA_REPO_ROOT', '/home/j/Documents/git/amanita')
FUNGITASTIC_ROOT = os.environ.get('FUNGITASTIC_ROOT', '/media/j/Extra FAT/FungiTastic/dataset/FungiTastic')

CONFIG = {
    # Paths
    'checkpoint_path': os.path.join(
        REPO_ROOT,
        'artifacts/exp-organized-valley-fig-260101/checkpoints/AtomicDirectory_checkpoint_64/best_model.pt',
    ),
    'val_csv_path': os.path.join(FUNGITASTIC_ROOT, 'metadata/FungiTastic/FungiTastic-ClosedSet-Val.csv'),
    'image_root': os.path.join(FUNGITASTIC_ROOT, 'FungiTastic/'),
    'taxonomic_mappings_path': os.path.join(REPO_ROOT, 'taxonomic_mappings.json'),

    # Model settings. num_classes_dict gives the output size of each
    # classification head; it must match the checkpoint being loaded.
    'image_size': 224,
    'num_classes_dict': {
        'phylum': 7,
        'class': 28,
        'order': 95,
        'family': 308,
        'genus': 918,
        'species': 2786
    },

    # Dataloader settings
    'batch_size': 64,
    'num_workers': 8,

    # Thresholds for analysis (compared against the top-1 softmax confidence)
    'high_confidence_threshold': 0.9,
    'low_confidence_threshold': 0.5,

    # Output paths
    'output_dir': os.path.join(REPO_ROOT, 'validation_results'),
}

# Create output directory
os.makedirs(CONFIG['output_dir'], exist_ok=True)

# Taxonomic ranks in hierarchical order (coarsest -> finest)
TAXONOMIC_RANKS = ['phylum', 'class', 'order', 'family', 'genus', 'species']

# Amanita phalloides reference info. The numeric IDs here are hard-coded; the
# analysis cells look IDs up by name from the taxonomic mappings instead.
AMANITA_PHALLOIDES = {
    'species_name': 'Amanita phalloides',
    'species_id': 58,
    'genus_id': 14,
    'genus_name': 'Amanita',
    'family_name': 'Amanitaceae',
    'order_name': 'Agaricales',
    'class_name': 'Agaricomycetes',
    'phylum_name': 'Basidiomycota'
}

print("Configuration loaded.")
print(f"\nCheckpoint: {CONFIG['checkpoint_path']}")
print(f"Validation CSV: {CONFIG['val_csv_path']}")
print(f"Image root: {CONFIG['image_root']}")

In [None]:
# Cell 3: Load taxonomic mappings
# Layout (as used below): {'name_to_id': {rank: {name: id}},
#                          'id_to_name': {rank: {str(id): name}}}
# JSON is UTF-8 by spec; be explicit so non-UTF-8 locales don't mis-decode names.
with open(CONFIG['taxonomic_mappings_path'], 'r', encoding='utf-8') as f:
    taxonomic_mappings = json.load(f)

name_to_id = taxonomic_mappings['name_to_id']
id_to_name = taxonomic_mappings['id_to_name']

# Report class counts and flag any mismatch with the model head sizes. The
# confusion-matrix cells look up id_to_name[rank][str(i)] for every
# i in range(num_classes), so a mismatch would surface there as a KeyError.
print("Taxonomic mappings loaded:")
for rank in TAXONOMIC_RANKS:
    n_mapped = len(name_to_id[rank])
    n_head = CONFIG['num_classes_dict'][rank]
    mismatch_note = '' if n_mapped == n_head else f"  <-- WARNING: model head has {n_head} outputs"
    print(f"  {rank:8s}: {n_mapped:5d} classes{mismatch_note}")

# Verify Amanita phalloides mappings against the hard-coded reference IDs
mapped_species_id = name_to_id['species']['Amanita phalloides']
mapped_genus_id = name_to_id['genus']['Amanita']
print(f"\nVerifying Amanita phalloides:")
print(f"  Species ID: {mapped_species_id}")
print(f"  Genus ID (Amanita): {mapped_genus_id}")
if (mapped_species_id, mapped_genus_id) != (AMANITA_PHALLOIDES['species_id'], AMANITA_PHALLOIDES['genus_id']):
    print("  WARNING: mapped IDs differ from AMANITA_PHALLOIDES['species_id'] / ['genus_id']")

---
## Section 2: Model Loading

In [None]:
# Cell 4: Import model class and create model
from models.beit_multitask import BEiTMultiTask, create_beit_multitask

# Build the bare architecture only: trained weights come from the checkpoint
# loaded in the next cell, so pretrained initialisation is skipped.
model_kwargs = dict(
    num_classes_dict=CONFIG['num_classes_dict'],
    pretrained=False,
)
print("Creating BEiT multi-task model...")
model = create_beit_multitask(**model_kwargs)

In [None]:
# Cell 5: Load checkpoint with DDP handling
print(f"Loading checkpoint from: {CONFIG['checkpoint_path']}")
# NOTE: weights_only=False unpickles arbitrary Python objects (code execution
# risk). Only load checkpoints you produced yourself or otherwise trust.
checkpoint = torch.load(CONFIG['checkpoint_path'], map_location='cpu', weights_only=False)

# Get state dict
state_dict = checkpoint['model_state_dict']

# DDP-wrapped models save parameters as 'module.<name>'. Strip only that
# leading prefix: str.replace would also rewrite any 'module.' occurring inside
# a key (e.g. a '...submodule.weight' key would become '...subweight').
DDP_PREFIX = 'module.'
if any(k.startswith(DDP_PREFIX) for k in state_dict):
    print("Removing 'module.' prefix from DDP checkpoint...")
    state_dict = {
        (k[len(DDP_PREFIX):] if k.startswith(DDP_PREFIX) else k): v
        for k, v in state_dict.items()
    }

# Load weights (strict by default: missing/unexpected keys raise)
model.load_state_dict(state_dict)
print("Weights loaded successfully!")

# Print checkpoint info if available
if 'epoch' in checkpoint:
    print(f"\nCheckpoint info:")
    print(f"  Epoch: {checkpoint['epoch']}")
    print(f"  Best val accuracy: {checkpoint.get('best_val_acc', 'N/A')}")

In [None]:
# Cell 6: Move model to GPU and set to eval mode
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# nn.Module.to() and .eval() both return the module itself, so they chain.
model = model.to(device).eval()

n_params = sum(param.numel() for param in model.parameters())
print(f"\nModel ready for inference.")
print(f"Total parameters: {n_params / 1e6:.2f}M")

---
## Section 2b: Data Loading

In [None]:
# Cell 6b: Load validation dataset
from dataset_multitask import FungiTasticMultiTask

val_csv = CONFIG['val_csv_path']
print(f"Loading validation data from: {val_csv}")
val_df = pd.read_csv(val_csv)
print(f"Loaded {len(val_df)} validation samples")

# Rows whose CSV species label is the Death Cap
amanita_mask = val_df['species'].eq('Amanita phalloides')
print(f"\nAmanita phalloides samples: {amanita_mask.sum()}")

# Deterministic eval pipeline: resize -> normalise -> tensor (no augmentation).
# NOTE(review): these are the ImageNet mean/std. They must match whatever the
# training pipeline used; if the backbone came from timm, its BEiT configs use
# 0.5/0.5 instead — confirm against the training code.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
side = CONFIG['image_size']
val_transform = A.Compose([
    A.Resize(side, side),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

val_dataset = FungiTasticMultiTask(
    df=val_df,
    transform=val_transform,
    taxonomic_mappings=taxonomic_mappings,
    image_root=CONFIG['image_root'],
    split='val',
)

print(f"\nDataset created with {len(val_dataset)} samples")

In [None]:
# Cell 6c: Create DataLoader with custom collate function
def collate_fn(batch):
    """Collate ``(image, label_dict, filepath)`` samples into one batch.

    Args:
        batch: list of dataset items; item[0] is an image tensor, item[1] a
            dict mapping each rank in TAXONOMIC_RANKS to an integer label,
            item[2] the image filepath.

    Returns:
        (images, labels, filepaths): images stacked along a new first dim,
        labels as a dict rank -> LongTensor (one entry per sample), and the
        filepaths as a list in batch order.
    """
    images = torch.stack([item[0] for item in batch])
    labels = {
        rank: torch.tensor([item[1][rank] for item in batch], dtype=torch.long)
        for rank in TAXONOMIC_RANKS
    }
    filepaths = [item[2] for item in batch]
    return images, labels, filepaths


val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,  # keep dataset order so result indices are deterministic
    num_workers=CONFIG['num_workers'],
    # Pinned host memory only speeds host->GPU copies; PyTorch warns when it
    # is requested without an accelerator, so enable it only with CUDA.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
)

print(f"DataLoader created with {len(val_loader)} batches")
print(f"Batch size: {CONFIG['batch_size']}")

---
## Section 3: Validation Loop

In [None]:
# Cell 7: Run inference on validation set
print("Running validation...")
print("="*60)

# Number of top-ranked classes stored per rank (clamped per head below).
TOP_K = 5

# Storage for results: per-rank lists of per-sample values, filled batch by batch
results = {
    'predictions': {rank: [] for rank in TAXONOMIC_RANKS},
    'labels': {rank: [] for rank in TAXONOMIC_RANKS},
    'confidences': {rank: [] for rank in TAXONOMIC_RANKS},
    'top5_predictions': {rank: [] for rank in TAXONOMIC_RANKS},
    'top5_confidences': {rank: [] for rank in TAXONOMIC_RANKS},
    'filepaths': []
}

with torch.no_grad():
    for images, labels, filepaths in tqdm(val_loader, desc="Validating"):
        # non_blocking lets the copy overlap compute when the batch is pinned
        images = images.to(device, non_blocking=True)

        # Forward pass; outputs is indexed by rank name (one head per rank)
        outputs = model(images)

        for rank in TAXONOMIC_RANKS:
            probs = F.softmax(outputs[rank], dim=1)

            # Top-1 prediction and its softmax probability ("confidence")
            confidence, predictions = probs.max(dim=1)

            # topk raises if k exceeds the number of classes in this head
            k = min(TOP_K, probs.size(1))
            top_conf, top_pred = probs.topk(k, dim=1)

            results['predictions'][rank].extend(predictions.cpu().numpy())
            results['confidences'][rank].extend(confidence.cpu().numpy())
            results['labels'][rank].extend(labels[rank].numpy())
            results['top5_predictions'][rank].extend(top_pred.cpu().numpy())
            results['top5_confidences'][rank].extend(top_conf.cpu().numpy())

        # Store filepaths (same order as the per-rank arrays)
        results['filepaths'].extend(filepaths)

print(f"\nValidation complete!")
print(f"Total samples processed: {len(results['filepaths'])}")

In [None]:
# Cell 8: Convert results to numpy arrays
# The inference loop accumulated Python lists; turn every per-rank field into an
# ndarray so later cells can use vectorised masks.
for field in ('predictions', 'labels', 'confidences', 'top5_predictions', 'top5_confidences'):
    per_rank = results[field]
    for rank in TAXONOMIC_RANKS:
        per_rank[rank] = np.array(per_rank[rank])

print("Results converted to numpy arrays.")
print(f"\nShape verification:")
for rank in TAXONOMIC_RANKS:
    pred_shape = results['predictions'][rank].shape
    top5_shape = results['top5_predictions'][rank].shape
    print(f"  {rank:8s}: predictions={pred_shape}, top5={top5_shape}")

---
## Section 4: Metrics Calculation

In [None]:
# Cell 9: Calculate per-rank metrics
def _mean_or_zero(values):
    """Mean of a 1-D array, or 0 when it is empty (avoids NaN + RuntimeWarning)."""
    return values.mean() if values.size > 0 else 0


metrics = {}
for rank in TAXONOMIC_RANKS:
    preds = results['predictions'][rank]
    lbls = results['labels'][rank]
    confs = results['confidences'][rank]
    top5_preds = results['top5_predictions'][rank]
    correct_mask = preds == lbls

    rank_metrics = {
        'accuracy': accuracy_score(lbls, preds),
        # Macro averages weight every class equally regardless of support
        'precision': precision_score(lbls, preds, average='macro', zero_division=0),
        'recall': recall_score(lbls, preds, average='macro', zero_division=0),
        'f1': f1_score(lbls, preds, average='macro', zero_division=0),
        # Hit if the true label appears anywhere in the stored top-k row
        'top5_accuracy': (top5_preds == lbls[:, None]).any(axis=1).mean(),
        'avg_confidence': confs.mean(),
        'avg_confidence_correct': _mean_or_zero(confs[correct_mask]),
        'avg_confidence_incorrect': _mean_or_zero(confs[~correct_mask]),
    }
    rank_metrics['confidence_gap'] = (
        rank_metrics['avg_confidence_correct'] - rank_metrics['avg_confidence_incorrect']
    )
    metrics[rank] = rank_metrics

# Hierarchical accuracy: a sample counts only if every rank is correct
all_correct = np.logical_and.reduce(
    [results['predictions'][r] == results['labels'][r] for r in TAXONOMIC_RANKS]
)
metrics['hierarchical_accuracy'] = all_correct.mean()

print("Metrics calculated!")

print("Metrics calculated!")

In [None]:
# Cell 10: Display metrics table
banner = "=" * 80
print("\n" + banner)
print("VALIDATION METRICS SUMMARY")
print(banner)

# Formatters: rates as percentages, confidences as plain 3-decimal floats
pct = '{:.2%}'.format
flt = '{:.3f}'.format

metrics_data = []
for rank in TAXONOMIC_RANKS:
    rm = metrics[rank]
    metrics_data.append({
        'Rank': rank.capitalize(),
        'Top-1 Acc': pct(rm['accuracy']),
        'Top-5 Acc': pct(rm['top5_accuracy']),
        'Precision': pct(rm['precision']),
        'Recall': pct(rm['recall']),
        'F1 (Macro)': pct(rm['f1']),
        'Avg Conf': flt(rm['avg_confidence']),
        'Conf Gap': flt(rm['confidence_gap']),
    })

metrics_df = pd.DataFrame(metrics_data)
print(metrics_df.to_string(index=False))

print(f"\nHierarchical Accuracy (all ranks correct): {metrics['hierarchical_accuracy']:.2%}")

# Unweighted means over the six ranks
avg_acc = np.mean([metrics[r]['accuracy'] for r in TAXONOMIC_RANKS])
avg_f1 = np.mean([metrics[r]['f1'] for r in TAXONOMIC_RANKS])
print(f"Average Accuracy across ranks: {avg_acc:.2%}")
print(f"Average F1 across ranks: {avg_f1:.2%}")

---
## Section 5: Confidence Analysis

In [None]:
# Cell 11: Plot confidence distributions
def _plot_confidence_hist(ax, values, label, color):
    """Overlay a density histogram of confidence ``values`` on ``ax``.

    No-op for an empty array. Uses 50 equal-width bins over the fixed range
    [0, 1], so the correct and incorrect series share identical bin edges and
    are directly comparable. The fixed range also avoids the ValueError numpy
    can raise when float32 confidences span too narrow a range (e.g. all near
    1.0) to build 50 distinct edges. If plotting still fails, only THIS series
    falls back to ``bins='auto'``, so a series that already succeeded is never
    drawn twice.
    """
    if values.size == 0:
        return
    hist_kwargs = dict(alpha=0.6, label=label, color=color, density=True)
    try:
        ax.hist(values, bins=50, range=(0, 1), **hist_kwargs)
    except ValueError:
        ax.hist(values, bins='auto', **hist_kwargs)


fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for ax, rank in zip(axes.flat, TAXONOMIC_RANKS):
    confs = results['confidences'][rank]
    correct_mask = results['predictions'][rank] == results['labels'][rank]

    _plot_confidence_hist(ax, confs[correct_mask], 'Correct', 'green')
    _plot_confidence_hist(ax, confs[~correct_mask], 'Incorrect', 'red')

    ax.set_xlabel('Confidence')
    ax.set_ylabel('Density')
    ax.set_title(f'{rank.capitalize()}\nAcc: {metrics[rank]["accuracy"]:.2%}')
    ax.legend()
    ax.set_xlim(0, 1)

fig.suptitle('Confidence Distributions: Correct vs Incorrect Predictions', fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(os.path.join(CONFIG['output_dir'], 'confidence_distributions.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved: {CONFIG['output_dir']}/confidence_distributions.png")

---
## Section 6: Confusion Matrices

In [None]:
# Cell 12: Confusion matrix for Phylum (7 classes)
rank = 'phylum'
preds = results['predictions'][rank]
lbls = results['labels'][rank]
n_classes = CONFIG['num_classes_dict'][rank]

# labels= pins the matrix to all n_classes ids in id order. Without it sklearn
# keeps only ids that occur in lbls or preds, so any class missing from the
# validation set shrinks the matrix and the tick labels below no longer line up
# with the rows/columns (or have the wrong length).
cm = confusion_matrix(lbls, preds, labels=np.arange(n_classes))

# Class names: row/column i corresponds to class id i
class_names = [id_to_name[rank][str(i)] for i in range(n_classes)]

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'Confusion Matrix - Phylum (Acc: {metrics[rank]["accuracy"]:.2%})')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['output_dir'], 'confusion_matrix_phylum.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved: {CONFIG['output_dir']}/confusion_matrix_phylum.png")

In [None]:
# Cell 13: Confusion matrix for Class (28 classes)
rank = 'class'
preds = results['predictions'][rank]
lbls = results['labels'][rank]
n_classes = CONFIG['num_classes_dict'][rank]

# labels= pins the matrix to all n_classes ids in id order. Without it sklearn
# keeps only ids that occur in lbls or preds; with 28 classes some are likely
# absent from validation, which would shift every tick label below.
cm = confusion_matrix(lbls, preds, labels=np.arange(n_classes))

# Class names: row/column i corresponds to class id i
class_names = [id_to_name[rank][str(i)] for i in range(n_classes)]

plt.figure(figsize=(14, 12))
sns.heatmap(cm, annot=False, cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'Confusion Matrix - Class (Acc: {metrics[rank]["accuracy"]:.2%})')
plt.xticks(rotation=90, ha='center', fontsize=8)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['output_dir'], 'confusion_matrix_class.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved: {CONFIG['output_dir']}/confusion_matrix_class.png")

In [None]:
# Cell 14: Top confused pairs for higher ranks (Order, Family, Genus, Species)
def get_top_confused_pairs(predictions, labels, id_to_name_dict, top_n=15):
    """Return the ``top_n`` most frequent (true, predicted) misclassification pairs.

    Args:
        predictions: 1-D sequence of predicted class ids.
        labels: 1-D sequence of true class ids, aligned with ``predictions``.
        id_to_name_dict: mapping from ``str(class_id)`` to class name.
        top_n: maximum number of pairs to return.

    Returns:
        DataFrame with columns ``True``, ``Predicted``, ``Count``, sorted by
        descending count (ties keep first-seen order). When there are no
        misclassifications it is empty but still has those columns, so callers
        can rely on the schema.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    wrong = predictions != labels

    # Count (true_id, pred_id) pairs among misclassified samples only
    pair_counts = Counter(zip(labels[wrong].tolist(), predictions[wrong].tolist()))

    rows = [
        {
            'True': id_to_name_dict[str(true_id)],
            'Predicted': id_to_name_dict[str(pred_id)],
            'Count': count,
        }
        for (true_id, pred_id), count in pair_counts.most_common(top_n)
    ]
    return pd.DataFrame(rows, columns=['True', 'Predicted', 'Count'])

print("Top-15 Confused Pairs by Rank")
print("=" * 80)

# Only the finer ranks: phylum/class are covered by full confusion matrices above
FINE_RANKS = ('order', 'family', 'genus', 'species')
for rank in FINE_RANKS:
    print(f"\n{rank.upper()}:")
    confused_df = get_top_confused_pairs(
        predictions=results['predictions'][rank],
        labels=results['labels'][rank],
        id_to_name_dict=id_to_name[rank],
    )
    print(confused_df.to_string(index=False))

---
## Section 7: Well vs Poorly Recognized Specimens

In [None]:
# Cell 15: Identify well and poorly recognized specimens (at species level)
rank = 'species'
preds, lbls, confs = (results[key][rank] for key in ('predictions', 'labels', 'confidences'))

correct_mask = preds == lbls
high_conf = confs >= CONFIG['high_confidence_threshold']
low_conf = confs < CONFIG['low_confidence_threshold']

# Well-recognized: correct prediction AND high confidence
well_recognized_mask = correct_mask & high_conf
# Poorly-recognized: incorrect OR low confidence
poorly_recognized_mask = ~correct_mask | low_conf

print(f"Well-recognized specimens (correct + conf >= {CONFIG['high_confidence_threshold']}): {well_recognized_mask.sum():,}")
print(f"Poorly-recognized specimens (incorrect OR conf < {CONFIG['low_confidence_threshold']}): {poorly_recognized_mask.sum():,}")

# Indices for display; the "poor" set is narrowed to incorrect predictions only
well_recognized_indices = np.flatnonzero(well_recognized_mask)
poorly_recognized_indices = np.flatnonzero(poorly_recognized_mask & ~correct_mask)


# Helper functions to filter out microscopy/spore images

# Caption substrings that suggest a microscopy/spore image. Matching is a plain
# substring test on the lower-cased caption.
# NOTE(review): short keys such as 'cell' / 'asci' / 'slide' also match inside
# unrelated words ('excellent', 'fascinating', 'landslide'); tighten if too many
# macroscopic photos get dropped.
_MICROSCOPY_CAPTION_KEYWORDS = (
    'spore', 'microscop', 'magnif', 'slide', 'cell', 'hypha', 'hyphae',
    'basidi', 'ascus', 'asci', 'cystid', 'gill section', 'cross section',
    # Micrometre unit: Greek small mu (U+03BC) and the micro sign (U+00B5).
    # (This keyword was previously stored as mojibake and never matched.)
    '\u03bcm', '\u00b5m',
    'micron', '400x', '1000x', 'oil immersion',
)

# Filename substrings that suggest a microscopy image
_MICROSCOPY_FILENAME_PATTERNS = ('micro', 'spore', 'slide', 'section')


def is_microscopy_image(filepath, caption=None):
    """
    Heuristically decide whether an image is a microscopy/spore shot rather
    than a macroscopic photo.

    Heuristics:
    1. Caption keywords (only when ``caption`` is a non-empty string).
    2. Filename patterns.

    Args:
        filepath: path of the image; only its basename is inspected.
        caption: optional caption text. Non-string values (None, or the float
            NaN pandas uses for missing cells) are ignored.

    Returns:
        bool: True if the image should be excluded.
    """
    # pandas reports missing captions as float NaN, which is truthy and has no
    # .lower(); guard on type so a missing caption cannot raise.
    if isinstance(caption, str) and caption:
        caption_lower = caption.lower()
        if any(keyword in caption_lower for keyword in _MICROSCOPY_CAPTION_KEYWORDS):
            return True

    filename = os.path.basename(filepath).lower()
    return any(pattern in filename for pattern in _MICROSCOPY_FILENAME_PATTERNS)


def _build_caption_lookup(val_df):
    """Map filename -> caption (first row wins); empty dict if unavailable."""
    if val_df is None or 'captions' not in val_df.columns or 'filename' not in val_df.columns:
        return {}
    first_rows = val_df.drop_duplicates(subset='filename', keep='first')
    return dict(zip(first_rows['filename'], first_rows['captions']))


def filter_non_microscopy_indices(indices, filepaths, val_df=None):
    """
    Filter indices to exclude microscopy images.

    Args:
        indices: Array of indices to filter
        filepaths: List of all filepaths (indexed by results indices)
        val_df: Optional DataFrame with 'filename' and 'captions' columns; the
            caption is matched on the basename of each filepath.

    Returns:
        Integer array of the indices that were kept, in input order.
    """
    # Build the lookup once: scanning val_df per index was O(len(indices) * len(val_df)).
    caption_by_filename = _build_caption_lookup(val_df)

    kept = []
    for idx in indices:
        filepath = filepaths[idx]
        caption = caption_by_filename.get(os.path.basename(filepath))
        if not is_microscopy_image(filepath, caption):
            kept.append(idx)

    # dtype=int keeps an empty result usable as an index array
    return np.array(kept, dtype=int)


# Filter indices to exclude microscopy images (best effort: on any failure the
# unfiltered indices are kept and a warning is printed).
print("\nFiltering out potential microscopy/spore images...")
try:
    well_recognized_indices_filtered = filter_non_microscopy_indices(
        well_recognized_indices, results['filepaths'], val_df
    )
    poorly_recognized_indices_filtered = filter_non_microscopy_indices(
        poorly_recognized_indices, results['filepaths'], val_df
    )
    n_well_removed = len(well_recognized_indices) - len(well_recognized_indices_filtered)
    n_poor_removed = len(poorly_recognized_indices) - len(poorly_recognized_indices_filtered)
    print(f"Well-recognized after filtering: {len(well_recognized_indices_filtered):,} (removed {n_well_removed})")
    print(f"Poorly-recognized after filtering: {len(poorly_recognized_indices_filtered):,} (removed {n_poor_removed})")

    # Swap in the filtered versions for the display cells
    well_recognized_indices, poorly_recognized_indices = (
        well_recognized_indices_filtered,
        poorly_recognized_indices_filtered,
    )
except Exception as e:
    print(f"Warning: Could not filter microscopy images: {e}")
    print("Using unfiltered indices.")

In [None]:
# Cell 16: Display well-recognized specimens
def display_specimens(indices, title, n_samples=12, n_cols=4):
    """Plot a grid of validation images annotated with species-level predictions.

    Args:
        indices: indices into the global ``results`` arrays.
        title: figure super-title.
        n_samples: maximum number of images; when more indices are given a
            random subset is drawn without replacement (global NumPy RNG).
        n_cols: number of grid columns.

    Returns:
        The matplotlib Figure, or None when ``indices`` is empty.
    """
    if len(indices) == 0:
        print(f"No specimens found for: {title}")
        return None

    # Sample indices if too many
    if len(indices) > n_samples:
        sample_indices = np.random.choice(indices, n_samples, replace=False)
    else:
        sample_indices = indices[:n_samples]

    n_rows = (len(sample_indices) + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows))
    # subplots returns a bare Axes or 1-D/2-D array depending on the grid; flatten uniformly
    axes = np.array(axes).flatten()

    for ax, sample_idx in zip(axes, sample_indices):
        filepath = results['filepaths'][sample_idx]
        try:
            # The context manager closes the file handle; convert('RGB') makes
            # greyscale/palette/CMYK images render as photos rather than through
            # a colormap.
            with Image.open(filepath) as img:
                ax.imshow(np.asarray(img.convert('RGB')))
        except Exception:
            ax.text(0.5, 0.5, 'Image\nNot Found', ha='center', va='center', transform=ax.transAxes)

        # Species-level prediction info
        pred_id = results['predictions']['species'][sample_idx]
        true_id = results['labels']['species'][sample_idx]
        conf = results['confidences']['species'][sample_idx]

        pred_name = id_to_name['species'][str(pred_id)]
        true_name = id_to_name['species'][str(true_id)]

        color = 'green' if pred_id == true_id else 'red'
        ax.set_title(f"True: {true_name[:25]}\nPred: {pred_name[:25]}\nConf: {conf:.3f}",
                     fontsize=8, color=color)
        ax.axis('off')

    # Hide axes left empty in the last row
    for ax in axes[len(sample_indices):]:
        ax.axis('off')

    fig.suptitle(title, fontsize=14, y=1.02)
    fig.tight_layout()
    return fig

# Display well-recognized specimens and save the grid
well_title = 'Well-Recognized Specimens (High Confidence Correct)'
fig = display_specimens(well_recognized_indices, well_title)
if fig is not None:
    well_png = os.path.join(CONFIG['output_dir'], 'well_recognized_specimens.png')
    plt.savefig(well_png, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {CONFIG['output_dir']}/well_recognized_specimens.png")

In [None]:
# Cell 17: Display poorly-recognized specimens and save the grid
poor_title = 'Poorly-Recognized Specimens (Incorrect Predictions)'
fig = display_specimens(poorly_recognized_indices, poor_title)
if fig is not None:
    poor_png = os.path.join(CONFIG['output_dir'], 'poorly_recognized_specimens.png')
    plt.savefig(poor_png, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {CONFIG['output_dir']}/poorly_recognized_specimens.png")

---
## Section 8: Single Observation Validation

In [None]:
# Cell 18: Single observation validation function
def validate_single_observation(idx_or_path, show_image=True):
    """
    Print true vs. predicted labels at every rank for one validation
    observation, plus its top-5 species predictions.

    Reads the pre-computed arrays in the global ``results``; running inference
    on files that were not part of the validation pass is not implemented.

    Args:
        idx_or_path: Either an index into the validation results (Python or
            NumPy integer) or a filepath string present in ``results['filepaths']``.
        show_image: Whether to display the image

    Returns:
        dict: Per-rank results (true/pred ids and names, confidence, correctness
        flag, top-5 list), or None if the file is not in ``results``.
    """
    # NumPy integers (e.g. elements of an np.where result) are not ``int``
    # instances; accept them explicitly instead of misrouting them to the
    # filepath branch, where they can never be found.
    if isinstance(idx_or_path, (int, np.integer)):
        idx = int(idx_or_path)
        filepath = results['filepaths'][idx]
    else:
        filepath = idx_or_path
        try:
            idx = results['filepaths'].index(filepath)
        except ValueError:
            idx = None  # not part of the validation pass

    print(f"Validating: {filepath}")
    print("="*80)

    if show_image:
        try:
            # Load pixels inside the context manager so the file handle is closed
            with Image.open(filepath) as img:
                pixels = np.asarray(img.convert('RGB'))
            plt.figure(figsize=(6, 6))
            plt.imshow(pixels)
            plt.axis('off')
            plt.title(os.path.basename(filepath))
            plt.show()
        except Exception as e:
            print(f"Could not display image: {e}")

    if idx is None:
        # Would need to run inference - not implemented for external files
        print(f"External file inference not implemented. Use validation set index.")
        return None

    observation_results = {}

    print(f"\n{'Rank':<10} {'True':<30} {'Predicted':<30} {'Conf':<8} {'Correct'}")
    print("-"*90)

    for rank in TAXONOMIC_RANKS:
        true_id = results['labels'][rank][idx]
        pred_id = results['predictions'][rank][idx]
        conf = results['confidences'][rank][idx]
        top5_preds = results['top5_predictions'][rank][idx]
        top5_confs = results['top5_confidences'][rank][idx]

        true_name = id_to_name[rank][str(true_id)]
        pred_name = id_to_name[rank][str(pred_id)]
        is_correct = bool(true_id == pred_id)  # plain bool (JSON-friendly)

        print(f"{rank.capitalize():<10} {true_name:<30} {pred_name:<30} {conf:.4f}  {'YES' if is_correct else 'NO'}")

        observation_results[rank] = {
            'true_id': int(true_id),
            'true_name': true_name,
            'pred_id': int(pred_id),
            'pred_name': pred_name,
            'confidence': float(conf),
            'correct': is_correct,
            'top5': [(id_to_name[rank][str(p)], float(c)) for p, c in zip(top5_preds, top5_confs)]
        }

    # Top-5 species predictions; '*' marks the true species
    print(f"\nTop-5 Species Predictions:")
    for i, (name, conf) in enumerate(observation_results['species']['top5']):
        marker = '*' if name == observation_results['species']['true_name'] else ' '
        print(f"  {i+1}. {name:<40} {conf:.4f} {marker}")

    return observation_results

In [None]:
# Cell 19: Example - validate one randomly chosen validation sample
n_scored = len(results['filepaths'])
random_idx = np.random.randint(0, n_scored)
example_results = validate_single_observation(random_idx)

---
## Section 9: Amanita Phalloides Analysis

In [None]:
# Cell 20: Filter Amanita phalloides specimens (by true species label)
amanita_species_id = name_to_id['species']['Amanita phalloides']
species_labels = results['labels']['species']
amanita_mask = species_labels == amanita_species_id
amanita_indices = np.flatnonzero(amanita_mask)

header = "Amanita phalloides (Death Cap) Analysis"
print(header)
print("=" * 80)
print(f"Total specimens in validation set: {len(amanita_indices)}")
print(f"Species ID: {amanita_species_id}")

In [None]:
# Cell 21: Species-level analysis for Amanita phalloides
amanita_preds = results['predictions']['species'][amanita_mask]
amanita_confs = results['confidences']['species'][amanita_mask]
n_amanita = len(amanita_indices)

# Accuracy at species level (NaN if there are no Death Cap specimens at all,
# set explicitly to avoid numpy's empty-mean RuntimeWarning)
species_correct = amanita_preds == amanita_species_id
species_accuracy = species_correct.mean() if n_amanita else float('nan')

print(f"\nSpecies-Level Performance:")
print(f"  Correctly identified: {species_correct.sum()} / {n_amanita} ({species_accuracy:.2%})")
# Label the N/A case too, so the output stays readable when a subset is empty
if species_correct.any():
    print(f"  Average confidence when correct: {amanita_confs[species_correct].mean():.4f}")
else:
    print("  Average confidence when correct: N/A")
if (~species_correct).any():
    print(f"  Average confidence when incorrect: {amanita_confs[~species_correct].mean():.4f}")
else:
    print("  Average confidence when incorrect: N/A")

# What is it confused with?
if (~species_correct).any():
    print(f"\nMisclassified as (species level):")
    confused_species = pd.Series(amanita_preds[~species_correct]).value_counts().head(10)
    for pred_id, count in confused_species.items():
        pred_name = id_to_name['species'][str(pred_id)]
        print(f"  {pred_name:<40}: {count} times")

In [None]:
# Cell 22: Hierarchical analysis - accuracy at each rank
print(f"\nHierarchical Accuracy for Amanita phalloides:")
print("-"*60)

amanita_hierarchical = {}
for rank in TAXONOMIC_RANKS:
    # Expected name at this rank, from the '<rank>_name' keys of AMANITA_PHALLOIDES
    true_name = AMANITA_PHALLOIDES.get(f'{rank}_name', 'Amanita phalloides' if rank == 'species' else None)
    if true_name is None:
        continue

    true_id = name_to_id[rank][true_name]
    rank_correct = results['predictions'][rank][amanita_mask] == true_id
    rank_accuracy = rank_correct.mean()

    amanita_hierarchical[rank] = {
        'true_name': true_name,
        'true_id': true_id,
        'correct': rank_correct.sum(),
        'total': len(rank_correct),
        'accuracy': rank_accuracy
    }

    print(f"  {rank.capitalize():8s}: {rank_correct.sum():4d}/{len(rank_correct):4d} ({rank_accuracy:.2%}) - {true_name}")

genus_correct = amanita_hierarchical['genus']['accuracy']
species_correct_pct = amanita_hierarchical['species']['accuracy']

# Each head's prediction is an independent argmax, so a correct species does not
# imply a correct genus. "genus acc - species acc" is only a net difference;
# count the specimens with the wrong species but the right genus directly.
genus_hit = results['predictions']['genus'][amanita_mask] == amanita_hierarchical['genus']['true_id']
species_hit = results['predictions']['species'][amanita_mask] == amanita_hierarchical['species']['true_id']
genus_rescued = genus_hit & ~species_hit
genus_rescue_rate = genus_rescued.mean() if genus_rescued.size else 0.0

print(f"\n" + "="*60)
print(f"KEY FINDING: Amanita phalloides Recognition")
print(f"="*60)
print(f"Species-level accuracy: {species_correct_pct:.2%}")
print(f"Genus-level accuracy (Amanita): {genus_correct:.2%}")

if genus_rescued.any():
    print(f"\n> Even when species is wrong, {genus_rescued.sum()} specimens ({genus_rescue_rate:.2%} of all)")
    print(f"  are still correctly identified at the genus level (Amanita).")
    print(f"  This is valuable for poisonous mushroom detection!")

In [None]:
# Cell 23: Display Amanita phalloides specimens
# Split Death Cap specimens by whether the species head got them right.
species_correct_mask = results['predictions']['species'][amanita_mask] == amanita_species_id
correct_amanita_indices, incorrect_amanita_indices = (
    amanita_indices[species_correct_mask],
    amanita_indices[~species_correct_mask],
)
n_correct = len(correct_amanita_indices)

print(f"Correctly classified Amanita phalloides: {n_correct}")
print(f"Incorrectly classified Amanita phalloides: {len(incorrect_amanita_indices)}")

# Grid of (up to 8) correctly classified specimens
if n_correct > 0:
    fig = display_specimens(
        correct_amanita_indices,
        f'Correctly Identified Amanita phalloides ({n_correct} specimens)',
        n_samples=8,
        n_cols=4,
    )
    if fig is not None:
        plt.show()

In [None]:
# Cell 24: Display misclassified Amanita phalloides with analysis
if len(incorrect_amanita_indices) > 0:
    print(f"\nMisclassified Amanita phalloides specimens:")
    print("="*80)

    # Detailed predictions for the first 5. .tolist() yields plain Python ints
    # (NumPy int64 is not an ``int`` instance), so validate_single_observation
    # treats each one as a results index rather than as a file path.
    for idx in incorrect_amanita_indices[:5].tolist():
        print(f"\n--- Specimen at index {idx} ---")
        _ = validate_single_observation(idx, show_image=True)

    # Summary figure
    fig = display_specimens(incorrect_amanita_indices,
                            f'Misclassified Amanita phalloides ({len(incorrect_amanita_indices)} specimens)',
                            n_samples=8, n_cols=4)
    if fig:
        plt.savefig(os.path.join(CONFIG['output_dir'], 'amanita_phalloides_analysis.png'),
                    dpi=150, bbox_inches='tight')
        plt.show()
        print(f"\nSaved: {CONFIG['output_dir']}/amanita_phalloides_analysis.png")
else:
    print("\nAll Amanita phalloides specimens were correctly classified!")

---
## Section 10: HTML Report Generation

In [None]:
# Cell 25: Generate HTML report
def _metric_css_class(value, good_threshold, medium_threshold):
    """Return the CSS class used to colour a score in the HTML report.

    Scores strictly above ``good_threshold`` map to 'metric-good', scores
    strictly above ``medium_threshold`` map to 'metric-medium', and anything
    else maps to 'metric-poor'.
    """
    if value > good_threshold:
        return 'metric-good'
    if value > medium_threshold:
        return 'metric-medium'
    return 'metric-poor'


def generate_html_report(metrics, results, amanita_hierarchical, output_path, ranks=None):
    """Generate a comprehensive HTML validation report and write it to disk.

    The report contains a summary box, a per-rank metrics table, an Amanita
    phalloides per-rank table, a key-finding box, and <img> tags that reference
    PNG files by bare filename. Those PNGs are therefore expected to sit in the
    same directory as ``output_path``.

    Args:
        metrics: Dict with a 'hierarchical_accuracy' float and, for every rank in
            ``ranks``, a dict with 'accuracy', 'top5_accuracy', 'precision',
            'recall', 'f1' and 'avg_confidence'.
        results: Dict whose 'filepaths' entry is a sized collection. Only its
            length (the number of validation samples) is used.
        amanita_hierarchical: Dict keyed by rank name. Each value has
            'true_name', 'correct', 'total' and 'accuracy'. Ranks absent from it
            are skipped in the Amanita table.
        output_path: Destination path of the HTML file. It is written as UTF-8.
        ranks: Iterable of rank names to report, in order. Defaults to the
            notebook-level TAXONOMIC_RANKS.

    Returns:
        ``output_path``, unchanged.
    """
    from html import escape  # local import keeps this cell self-contained

    if ranks is None:
        ranks = TAXONOMIC_RANKS

    hier_acc = metrics['hierarchical_accuracy']
    # Hierarchical accuracy requires every rank to be correct, so it can never
    # exceed any single rank's accuracy. It therefore gets looser thresholds
    # (0.5 / 0.3) than the per-rank cells (0.8 / 0.5).
    hier_class = _metric_css_class(hier_acc, 0.5, 0.3)

    parts = [f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Validation Report - Multi-Task BEiT Fungi Classification</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; background-color: #f5f5f5; }}
        h1 {{ color: #2c3e50; border-bottom: 3px solid #3498db; padding-bottom: 10px; }}
        h2 {{ color: #34495e; margin-top: 30px; }}
        h3 {{ color: #7f8c8d; }}
        table {{ border-collapse: collapse; width: 100%; margin: 20px 0; background: white; }}
        th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
        th {{ background-color: #3498db; color: white; }}
        tr:nth-child(even) {{ background-color: #f9f9f9; }}
        tr:hover {{ background-color: #f1f1f1; }}
        .metric-good {{ color: #27ae60; font-weight: bold; }}
        .metric-medium {{ color: #f39c12; font-weight: bold; }}
        .metric-poor {{ color: #e74c3c; font-weight: bold; }}
        .key-finding {{ background-color: #e8f6f3; border-left: 4px solid #1abc9c; padding: 15px; margin: 20px 0; }}
        .warning {{ background-color: #fdf2e9; border-left: 4px solid #e67e22; padding: 15px; margin: 20px 0; }}
        .image-container {{ text-align: center; margin: 20px 0; }}
        img {{ max-width: 100%; height: auto; border: 1px solid #ddd; }}
        .summary-box {{ background: white; padding: 20px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); margin: 20px 0; }}
    </style>
</head>
<body>
    <h1>Validation Report: Multi-Task BEiT for Hierarchical Fungi Classification</h1>
    
    <div class="summary-box">
        <h2>Summary</h2>
        <p><strong>Model:</strong> BEiT Multi-Task (6 classification heads)</p>
        <p><strong>Validation Samples:</strong> {len(results['filepaths']):,}</p>
        <p><strong>Hierarchical Accuracy:</strong> <span class="{hier_class}">{hier_acc:.2%}</span></p>
    </div>
    
    <h2>Per-Rank Metrics</h2>
    <table>
        <tr>
            <th>Rank</th>
            <th>Top-1 Accuracy</th>
            <th>Top-5 Accuracy</th>
            <th>Precision</th>
            <th>Recall</th>
            <th>F1 (Macro)</th>
            <th>Avg Confidence</th>
        </tr>
"""]

    for rank in ranks:
        m = metrics[rank]
        acc_class = _metric_css_class(m['accuracy'], 0.8, 0.5)
        parts.append(f"""
        <tr>
            <td><strong>{escape(rank.capitalize())}</strong></td>
            <td class="{acc_class}">{m['accuracy']:.2%}</td>
            <td>{m['top5_accuracy']:.2%}</td>
            <td>{m['precision']:.2%}</td>
            <td>{m['recall']:.2%}</td>
            <td>{m['f1']:.2%}</td>
            <td>{m['avg_confidence']:.3f}</td>
        </tr>
""")

    parts.append("""    </table>
    
    <h2>Confidence Distributions</h2>
    <div class="image-container">
        <img src="confidence_distributions.png" alt="Confidence Distributions">
    </div>
    
    <h2>Confusion Matrices</h2>
    <h3>Phylum (7 classes)</h3>
    <div class="image-container">
        <img src="confusion_matrix_phylum.png" alt="Confusion Matrix - Phylum">
    </div>
    
    <h3>Class (28 classes)</h3>
    <div class="image-container">
        <img src="confusion_matrix_class.png" alt="Confusion Matrix - Class">
    </div>
    
    <h2>Amanita phalloides (Death Cap) Analysis</h2>
    <div class="warning">
        <strong>Important:</strong> Amanita phalloides is one of the most poisonous mushrooms. 
        Accurate identification is critical for public safety.
    </div>
    
    <table>
        <tr>
            <th>Taxonomic Rank</th>
            <th>Expected Value</th>
            <th>Correct</th>
            <th>Total</th>
            <th>Accuracy</th>
        </tr>
""")

    for rank in ranks:
        if rank not in amanita_hierarchical:
            continue
        h = amanita_hierarchical[rank]
        acc_class = _metric_css_class(h['accuracy'], 0.8, 0.5)
        # Taxon names come from the dataset metadata: escape them so any
        # '<', '&' or quotes cannot break the markup.
        parts.append(f"""
        <tr>
            <td><strong>{escape(rank.capitalize())}</strong></td>
            <td>{escape(str(h['true_name']))}</td>
            <td>{h['correct']}</td>
            <td>{h['total']}</td>
            <td class="{acc_class}">{h['accuracy']:.2%}</td>
        </tr>
""")

    # Key findings. Both accuracies are computed over ALL A. phalloides
    # specimens, so genus_acc is not "accuracy among species-level misses".
    species_acc = amanita_hierarchical.get('species', {}).get('accuracy', 0)
    genus_acc = amanita_hierarchical.get('genus', {}).get('accuracy', 0)

    parts.append(f"""
    </table>
    
    <div class="key-finding">
        <h3>Key Finding</h3>
        <p><strong>Species-level recognition:</strong> {species_acc:.2%}</p>
        <p><strong>Genus-level recognition (Amanita):</strong> {genus_acc:.2%}</p>
        <p>Both figures are measured over all Amanita phalloides specimens. When 
           genus-level accuracy exceeds species-level accuracy, some specimens with 
           a wrong species prediction were still placed in the genus Amanita, which 
           is valuable for flagging potentially dangerous specimens.</p>
    </div>
    
    <h2>Specimen Examples</h2>
    <h3>Well-Recognized Specimens</h3>
    <div class="image-container">
        <img src="well_recognized_specimens.png" alt="Well-Recognized Specimens">
    </div>
    
    <h3>Poorly-Recognized Specimens</h3>
    <div class="image-container">
        <img src="poorly_recognized_specimens.png" alt="Poorly-Recognized Specimens">
    </div>
    
    <footer style="margin-top: 40px; padding-top: 20px; border-top: 1px solid #ddd; color: #7f8c8d;">
        <p>Generated by validation_notebook.ipynb</p>
    </footer>
</body>
</html>
""")

    # Explicit UTF-8, matching the <meta charset>. The platform default
    # encoding may not be able to represent non-ASCII taxon names.
    with open(output_path, 'w', encoding='utf-8') as report_file:
        report_file.write(''.join(parts))

    return output_path

# Build the HTML report next to the saved figures; the function hands back the
# path it wrote to, which is then echoed for the reader.
report_path = generate_html_report(
    metrics,
    results,
    amanita_hierarchical,
    os.path.join(CONFIG['output_dir'], 'validation_report.html'),
)
print(f"HTML report generated: {report_path}")

In [None]:
# Cell 26: Save results to pickle for future analysis
# NOTE: unpickling can execute arbitrary code, so only load
# validation_results.pkl from a location you trust.
save_data = {
    'metrics': metrics,
    'results': results,
    'amanita_hierarchical': amanita_hierarchical,
    'config': CONFIG,
    'taxonomic_mappings': taxonomic_mappings
}

pickle_path = os.path.join(CONFIG['output_dir'], 'validation_results.pkl')
with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(save_data, pickle_file)

print(f"Results saved to: {pickle_path}")

# Summary of files in the output directory. The listing is sorted because
# os.listdir order is arbitrary. Subdirectories are skipped because
# os.path.getsize on a directory does not report a meaningful file size.
print("\n" + "=" * 60)
print("OUTPUT FILES GENERATED:")
print("=" * 60)
for filename in sorted(os.listdir(CONFIG['output_dir'])):
    filepath = os.path.join(CONFIG['output_dir'], filename)
    if not os.path.isfile(filepath):
        continue
    size_kb = os.path.getsize(filepath) / 1024
    print(f"  {filename:<40} ({size_kb:.1f} KB)")

---
## Final Summary

In [None]:
# Cell 27: Final summary
print("\n" + "=" * 80)
print("VALIDATION COMPLETE")
print("=" * 80)

print("\nModel Performance Summary:")
print("-" * 40)
for rank in TAXONOMIC_RANKS:
    rank_metrics = metrics[rank]
    print(f"  {rank.capitalize():8s}: {rank_metrics['accuracy']:.2%} (Top-5: {rank_metrics['top5_accuracy']:.2%})")

print(f"\nHierarchical Accuracy (all ranks correct): {metrics['hierarchical_accuracy']:.2%}")

print("\nAmanita phalloides Recognition:")
print("-" * 40)
# Look ranks up with .get (as the HTML report does) so a rank missing from
# amanita_hierarchical prints 'n/a' instead of raising KeyError at the very
# end of the run. Labels are padded so the values line up.
for label, rank in (("Species-level:", 'species'), ("Genus-level:  ", 'genus')):
    rank_stats = amanita_hierarchical.get(rank)
    acc_text = f"{rank_stats['accuracy']:.2%}" if rank_stats else "n/a"
    print(f"  {label} {acc_text}")

print(f"\nOutput files saved to: {CONFIG['output_dir']}")
print(f"Open {CONFIG['output_dir']}/validation_report.html in a browser to view the full report.")