<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training_final_working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🐟 ViT-FishID: Semi-Supervised Fish Classification

**COMPLETE TRAINING PIPELINE WITH GOOGLE COLAB**

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🎯 What This Notebook Does

This notebook implements a **complete semi-supervised learning pipeline** for fish species classification using:

**🤖 Vision Transformer (ViT)**: State-of-the-art transformer architecture for image classification
**📊 Semi-Supervised Learning**: Leverages both labeled and unlabeled fish images
**🎓 EMA Teacher-Student Framework**: Uses exponential moving averages for consistency training
**☁️ Google Colab**: Cloud-based training with GPU acceleration

## 📊 Expected Performance

- **Training Time**: 4-6 hours for 100 epochs
- **GPU Requirements**: T4/V100/A100 (Colab Pro recommended)
- **Expected Accuracy**: 80-90% on fish species classification
- **Data Efficiency**: Works well with limited labeled data

## 🛠️ What You Need

1. **Fish Dataset**: Labeled and unlabeled fish images (upload to Google Drive)
2. **Google Colab Pro**: Recommended for longer training sessions
3. **Weights & Biases Account**: Optional for experiment tracking

## 🔧 Step 1: Environment Setup and GPU Check

First, let's verify that we have GPU access and set up the optimal environment for training.

In [1]:
# Check GPU availability and system information
import torch
import os
import gc

print("🔍 SYSTEM INFORMATION")
print("="*50)
print(f"Python version: {os.sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Device: {device_name}")
    print(f"GPU Memory: {device_memory:.1f} GB")
    print("✅ GPU is ready for training!")

    # Set optimal GPU settings
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Clear GPU cache
    torch.cuda.empty_cache()
    gc.collect()
    print("🚀 GPU optimized for training")

else:
    print("❌ No GPU detected!")
    print("📝 To enable GPU in Colab:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")
    print("   Then restart this notebook")

# Set device for later use
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🎯 Using device: {DEVICE}")

🔍 SYSTEM INFORMATION
Python version: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
PyTorch version: 2.8.0+cu126
CUDA available: True
GPU Device: NVIDIA A100-SXM4-40GB
GPU Memory: 39.6 GB
✅ GPU is ready for training!
🚀 GPU optimized for training

🎯 Using device: cuda


## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [2]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)[:10]  # Show first 10 items
    for item in items:
        print(f"  - {item}")
    if len(os.listdir(drive_path)) > 10:
        print(f"  ... and {len(os.listdir(drive_path)) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")

Mounted at /content/drive

📂 Google Drive contents:
  - Mock Matric
  - Photos
  - Admin
  - Uni
  - Colab Notebooks
  - ViT-FishID
  - ViT-FishID_BestModel_Backups
  - fish_cutouts_old.zip
  - fish_cutouts.zip
  - ViT_FishID_test_split.json
  ... and 4 more items

✅ Google Drive mounted successfully!


## 📦 Step 3: Install Dependencies

Install all required packages for ViT training and semi-supervised learning.

In [3]:
# Install required packages
print("📦 Installing required packages...")

!pip install -q timm
!pip install -q wandb
!pip install -q matplotlib seaborn
!pip install -q scikit-learn
!pip install -q tqdm
!pip install -q Pillow

print("✅ All packages installed successfully!")

# Verify installations
try:
    import timm
    import wandb
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import accuracy_score
    from tqdm import tqdm
    from PIL import Image
    print("🔍 All imports successful!")
except ImportError as e:
    print(f"❌ Import error: {e}")

📦 Installing required packages...
✅ All packages installed successfully!
🔍 All imports successful!


## 🔄 Step 4: Clone ViT-FishID Repository

Clone the repository to get all necessary training scripts and model definitions.

In [4]:
import os

# Clone repository if not already present
repo_dir = '/content/ViT-FishID'

if not os.path.exists(repo_dir):
    print("📥 Cloning ViT-FishID repository...")
    !git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID
    print("✅ Repository cloned successfully!")
else:
    print("✅ Repository already exists")

# Change to repository directory
os.chdir(repo_dir)
print(f"📁 Current directory: {os.getcwd()}")

# List repository contents
print("\n📂 Repository contents:")
for item in sorted(os.listdir('.')):
    if os.path.isfile(item):
        print(f"  📄 {item}")
    else:
        print(f"  📁 {item}/")

📥 Cloning ViT-FishID repository...
Cloning into '/content/ViT-FishID'...
remote: Enumerating objects: 215, done.[K
remote: Counting objects: 100% (215/215), done.[K
remote: Compressing objects: 100% (155/155), done.[K
remote: Total 215 (delta 98), reused 160 (delta 53), pack-reused 0 (from 0)[K
Receiving objects: 100% (215/215), 1.25 MiB | 14.55 MiB/s, done.
Resolving deltas: 100% (98/98), done.
✅ Repository cloned successfully!
📁 Current directory: /content/ViT-FishID

📂 Repository contents:
  📁 .git/
  📄 .gitattributes
  📄 .gitignore
  📄 CHECKPOINT_BACKUP_ENHANCEMENT.md
  📄 EXTENDED_TRAINING_SETUP.md
  📄 GOOGLE_DRIVE_METRICS_BACKUP_ENHANCEMENT.md
  📄 MAE_INTEGRATION_GUIDE.md
  📄 README.md
  📄 ViT_FishID_Colab_Training_FULLYWORKING.ipynb
  📄 ViT_FishID_Colab_Training_newest.ipynb
  📄 ViT_FishID_Colab_Training_old.ipynb
  📄 ViT_FishID_Colab_Training_working_withoutcheckpointbackup.ipynb
  📄 ViT_FishID_MAE_EMA_Training.ipynb
  📄 data.py
  📄 evaluate.py
  📁 fish_cutouts/
  📄 local_re

## 🐠 Step 5: Setup Fish Dataset

**Important**: Upload your `fish_cutouts.zip` file to Google Drive before running this step.

This step will locate and extract your fish dataset for training.

In [5]:
import os
import zipfile
import json
from pathlib import Path

print("🐠 FISH DATASET SETUP")
print("="*50)

# Search for fish_cutouts.zip in Google Drive
drive_path = '/content/drive/MyDrive'
zip_locations = []

print("🔍 Searching for fish_cutouts.zip...")
for root, dirs, files in os.walk(drive_path):
    for file in files:
        if 'fish_cutouts.zip' in file.lower():
            zip_locations.append(os.path.join(root, file))

if not zip_locations:
    print("❌ fish_cutouts.zip not found in Google Drive!")
    print("📝 Please upload fish_cutouts.zip to Google Drive and try again")
    raise FileNotFoundError("fish_cutouts.zip not found")

# Use the first found zip file
zip_path = zip_locations[0]
print(f"✅ Found dataset: {zip_path}")

# Extract dataset
extract_path = '/content/fish_cutouts'
if not os.path.exists(extract_path):
    print(f"📂 Extracting to: {extract_path}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall('/content')
    print("✅ Dataset extracted successfully!")
else:
    print("✅ Dataset already extracted")

# Verify dataset structure
labeled_path = os.path.join(extract_path, 'labeled')
unlabeled_path = os.path.join(extract_path, 'unlabeled')

if os.path.exists(labeled_path):
    labeled_species = os.listdir(labeled_path)
    total_labeled = sum(len(os.listdir(os.path.join(labeled_path, species)))
                       for species in labeled_species if os.path.isdir(os.path.join(labeled_path, species)))
    print(f"📊 Labeled data: {len(labeled_species)} species, {total_labeled} images")
else:
    print("❌ Labeled folder not found!")

if os.path.exists(unlabeled_path):
    unlabeled_images = len([f for f in os.listdir(unlabeled_path)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    print(f"📊 Unlabeled data: {unlabeled_images} images")
else:
    print("⚠️ Unlabeled folder not found (optional for supervised training)")

# Save dataset info for later use
dataset_info = {
    'labeled_path': labeled_path,
    'unlabeled_path': unlabeled_path if os.path.exists(unlabeled_path) else None,
    'num_species': len(labeled_species) if os.path.exists(labeled_path) else 0,
    'labeled_images': total_labeled if os.path.exists(labeled_path) else 0,
    'unlabeled_images': unlabeled_images if os.path.exists(unlabeled_path) else 0
}

with open('dataset_info.json', 'w') as f:
    json.dump(dataset_info, f, indent=2)

print(f"\n✅ Dataset setup complete!")
print(f"📁 Labeled path: {labeled_path}")
print(f"📁 Unlabeled path: {unlabeled_path}")

🐠 FISH DATASET SETUP
🔍 Searching for fish_cutouts.zip...
✅ Found dataset: /content/drive/MyDrive/fish_cutouts.zip
📂 Extracting to: /content/fish_cutouts
✅ Dataset extracted successfully!
📊 Labeled data: 34 species, 5137 images
📊 Unlabeled data: 24015 images

✅ Dataset setup complete!
📁 Labeled path: /content/fish_cutouts/labeled
📁 Unlabeled path: /content/fish_cutouts/unlabeled


## 📈 Step 6: Setup Weights & Biases (Optional)

Setup experiment tracking with Weights & Biases. Skip this cell if you don't want to use W&B.

In [6]:
# Setup Weights & Biases (optional)
import wandb

print("📈 Setting up Weights & Biases...")

try:
    # Login to wandb (you'll need to paste your API key)
    wandb.login()

    # Initialize project
    wandb.init(
        project="vit-fish-classification",
        name="colab-training-run",
        config={
            "model": "vit_base_patch16_224",
            "framework": "semi-supervised",
            "environment": "google-colab"
        }
    )

    print("✅ W&B setup complete!")
    USE_WANDB = True

except Exception as e:
    print(f"⚠️ W&B setup failed: {e}")
    print("📝 Training will continue without experiment tracking")
    USE_WANDB = False

📈 Setting up Weights & Biases...


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcativthomson[0m ([33mcativthomson-university-of-cape-town[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ W&B setup complete!


## ⚙️ Step 7: Configure Training Parameters

Set up all training configurations and hyperparameters.

In [7]:
# Training Configuration
import json

print("⚙️ TRAINING CONFIGURATION")
print("="*50)

# Load dataset info
with open('dataset_info.json', 'r') as f:
    dataset_info = json.load(f)

TRAINING_CONFIG = {
    # Model settings
    'model_name': 'vit_base_patch16_224',
    'num_classes': dataset_info['num_species'],
    'image_size': 224,

    # Training settings
    'epochs': 100,
    'batch_size': 32,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,

    # Semi-supervised settings
    'ema_decay': 0.999,
    'consistency_weight': 1.0,
    'confidence_threshold': 0.95,

    # Data settings
    'labeled_batch_size': 16,
    'unlabeled_batch_size': 16,
    'num_workers': 2,

    # Paths
    'labeled_data_path': dataset_info['labeled_path'],
    'unlabeled_data_path': dataset_info['unlabeled_path'],
    'checkpoint_dir': 'checkpoints',

    # Other settings
    'save_every': 10,
    'eval_every': 5,
    'device': str(DEVICE),
    'use_wandb': USE_WANDB if 'USE_WANDB' in globals() else False
}

# Create checkpoint directory
os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)

# Print configuration
print("📋 Training Configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

# Save configuration
with open('training_config.json', 'w') as f:
    json.dump(TRAINING_CONFIG, f, indent=2)

print("\n✅ Configuration saved!")

⚙️ TRAINING CONFIGURATION
📋 Training Configuration:
  model_name: vit_base_patch16_224
  num_classes: 34
  image_size: 224
  epochs: 100
  batch_size: 32
  learning_rate: 0.0001
  weight_decay: 0.0001
  ema_decay: 0.999
  consistency_weight: 1.0
  confidence_threshold: 0.95
  labeled_batch_size: 16
  unlabeled_batch_size: 16
  num_workers: 2
  labeled_data_path: /content/fish_cutouts/labeled
  unlabeled_data_path: /content/fish_cutouts/unlabeled
  checkpoint_dir: checkpoints
  save_every: 10
  eval_every: 5
  device: cuda
  use_wandb: True

✅ Configuration saved!


## 🤖 Step 8: Load MAE Pre-trained Model (Optional)

**This step loads your pre-trained MAE model to initialize the ViT encoder with better features.**

Skip this cell if you don't have a MAE pre-trained model or want to train from scratch.

In [None]:
# Load MAE pre-trained model (optional)
import torch
import glob

print("🤖 MAE PRE-TRAINED MODEL LOADING")
print("="*50)

# Search for MAE checkpoints
mae_patterns = [
    '/content/drive/MyDrive/ViT-FishID/mae_base_checkpoints/mae_final_model.pth',
]

mae_checkpoints = []
for pattern in mae_patterns:
    mae_checkpoints.extend(glob.glob(pattern, recursive=True))

if mae_checkpoints:
    # Use the first found MAE checkpoint
    mae_path = mae_checkpoints[0]
    print(f"✅ Found MAE model: {mae_path}")

    try:
        # Load MAE checkpoint with weights_only=False for compatibility
        # Note: Only use weights_only=False if you trust the checkpoint source
        print("🔧 Loading checkpoint with legacy compatibility mode...")
        mae_checkpoint = torch.load(mae_path, map_location='cpu', weights_only=False)

        # Debug: Print checkpoint structure
        print("🔍 Checkpoint structure:")
        if isinstance(mae_checkpoint, dict):
            print(f"   Checkpoint keys: {list(mae_checkpoint.keys())}")
        else:
            print(f"   Checkpoint type: {type(mae_checkpoint)}")

        # Extract model state dict from checkpoint
        model_state_dict = None
        if 'model_state_dict' in mae_checkpoint:
            model_state_dict = mae_checkpoint['model_state_dict']
            print("✅ Found 'model_state_dict' key in checkpoint")
        elif 'model' in mae_checkpoint:
            model_state_dict = mae_checkpoint['model']
            print("✅ Found 'model' key in checkpoint")
        elif 'state_dict' in mae_checkpoint:
            model_state_dict = mae_checkpoint['state_dict']
            print("✅ Found 'state_dict' key in checkpoint")
        else:
            # Check if the entire checkpoint is the state dict
            # Look for typical model weight keys (should contain tensors, not scalars)
            tensor_keys = [k for k, v in mae_checkpoint.items() if isinstance(v, torch.Tensor) and v.dim() > 0]
            if tensor_keys:
                model_state_dict = mae_checkpoint
                print("✅ Using entire checkpoint as state_dict")
            else:
                print("❌ No model weights found in checkpoint")
                model_state_dict = None

        if model_state_dict is not None:
            print(f"🔍 Model state dict has {len(model_state_dict)} keys")

            # Show sample keys to understand structure
            sample_keys = list(model_state_dict.keys())[:10]
            print(f"📋 Sample model keys: {sample_keys}")

            # Filter encoder weights for ViT
            mae_encoder_weights = {}
            encoder_count = 0

            for key, value in model_state_dict.items():
                # Skip non-tensor values
                if not isinstance(value, torch.Tensor):
                    continue

                new_key = None

                # Handle different MAE checkpoint formats
                if key.startswith('encoder.'):
                    # Standard MAE format: encoder.xxx -> backbone.xxx
                    new_key = key.replace('encoder.', 'backbone.')
                    encoder_count += 1
                elif key.startswith('backbone.'):
                    # Already has backbone prefix
                    new_key = key
                    encoder_count += 1
                elif key.startswith(('patch_embed', 'pos_embed', 'cls_token', 'blocks')):
                    # Direct ViT encoder components
                    new_key = f'backbone.{key}'
                    encoder_count += 1
                elif not key.startswith(('decoder', 'mask_token', 'loss', 'optimizer', 'scheduler', 'epoch')):
                    # Include other potential encoder weights, exclude decoder/metadata
                    if 'embed' in key or 'block' in key or 'norm' in key or 'head' in key:
                        new_key = f'backbone.{key}'
                        encoder_count += 1

                if new_key is not None:
                    mae_encoder_weights[new_key] = value

            if encoder_count > 0:
                print(f"✅ Successfully extracted {len(mae_encoder_weights)} encoder weights from MAE")
                print(f"🎯 Encoder weights found: {encoder_count} parameters")

                # Show actual weight key samples (not metadata)
                actual_weight_keys = [k for k in mae_encoder_weights.keys() if isinstance(mae_encoder_weights[k], torch.Tensor)][:5]
                print(f"🔍 Sample weight parameter keys: {actual_weight_keys}")

                # Verify we have actual model weights, not just metadata
                total_params = sum(v.numel() for v in mae_encoder_weights.values() if isinstance(v, torch.Tensor))
                print(f"📊 Total parameters loaded: {total_params:,}")

                if total_params > 1000:  # Reasonable check - should have many parameters
                    TRAINING_CONFIG['mae_pretrained_weights'] = mae_encoder_weights
                    TRAINING_CONFIG['use_mae_init'] = True
                else:
                    print("⚠️ Very few parameters found - might be loading metadata instead of weights")
                    TRAINING_CONFIG['use_mae_init'] = False
            else:
                print("⚠️ No encoder weights found in MAE checkpoint")
                TRAINING_CONFIG['use_mae_init'] = False
        else:
            print("❌ Could not extract model state dict from checkpoint")
            TRAINING_CONFIG['use_mae_init'] = False

    except Exception as e:
        print(f"❌ Failed to load MAE model: {e}")
        print("💡 Trying alternative loading methods...")

        try:
            # Fallback: Try to load and inspect structure more carefully
            mae_checkpoint = torch.load(mae_path, map_location='cpu', weights_only=False)

            # More detailed inspection
            def find_model_weights(obj, path=""):
                """Recursively find model weights in checkpoint"""
                if isinstance(obj, torch.Tensor) and obj.dim() > 0:
                    return {path: obj}
                elif isinstance(obj, dict):
                    weights = {}
                    for k, v in obj.items():
                        sub_weights = find_model_weights(v, f"{path}.{k}" if path else k)
                        weights.update(sub_weights)
                    return weights
                else:
                    return {}

            all_weights = find_model_weights(mae_checkpoint)
            encoder_weights = {k: v for k, v in all_weights.items()
                             if not any(exclude in k for exclude in ['decoder', 'mask_token', 'loss', 'optimizer', 'scheduler'])}

            if encoder_weights:
                # Convert keys to backbone format
                mae_encoder_weights = {}
                for k, v in encoder_weights.items():
                    if k.startswith('encoder.'):
                        new_key = k.replace('encoder.', 'backbone.')
                    elif not k.startswith('backbone.'):
                        new_key = f'backbone.{k}'
                    else:
                        new_key = k
                    mae_encoder_weights[new_key] = v

                total_params = sum(v.numel() for v in mae_encoder_weights.values())
                print(f"✅ Fallback method found {len(mae_encoder_weights)} weights ({total_params:,} parameters)")

                TRAINING_CONFIG['mae_pretrained_weights'] = mae_encoder_weights
                TRAINING_CONFIG['use_mae_init'] = True
            else:
                print("❌ No suitable weights found with fallback method")
                TRAINING_CONFIG['use_mae_init'] = False

        except Exception as e2:
            print(f"❌ All loading methods failed: {e2}")
            TRAINING_CONFIG['use_mae_init'] = False
else:
    print("⚠️ No MAE pre-trained model found")
    print("📝 Training will use standard ImageNet initialization")
    TRAINING_CONFIG['use_mae_init'] = False

print(f"\n✅ MAE initialization: {'Enabled' if TRAINING_CONFIG.get('use_mae_init', False) else 'Disabled'}")

# Print additional info if MAE is enabled
if TRAINING_CONFIG.get('use_mae_init', False):
    mae_weights = TRAINING_CONFIG.get('mae_pretrained_weights', {})
    if isinstance(mae_weights, dict):
        # Count actual tensor parameters
        tensor_count = len([k for k, v in mae_weights.items() if isinstance(v, torch.Tensor)])
        total_params = sum(v.numel() for v in mae_weights.values() if isinstance(v, torch.Tensor))

        print(f"📊 MAE weights loaded: {tensor_count} weight tensors")
        print(f"🔢 Total parameters: {total_params:,}")

        # Show actual weight parameter keys (not metadata)
        weight_keys = [k for k, v in mae_weights.items() if isinstance(v, torch.Tensor)][:5]
        print(f"🔍 Sample weight keys: {weight_keys}")

    print("🚀 ViT encoder will be initialized with MAE pre-trained weights!")
else:
    print("📝 ViT will use standard ImageNet pre-trained initialization")

🤖 MAE PRE-TRAINED MODEL LOADING
✅ Found MAE model: /content/drive/MyDrive/ViT-FishID/mae_base_checkpoints/mae_final_model.pth
🔧 Loading checkpoint with legacy compatibility mode...
🔍 Checkpoint structure:
   Checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'loss', 'config', 'timestamp', 'training_losses', 'epoch_times']
✅ Found 'model_state_dict' key in checkpoint
🔍 Model state dict has 254 keys
📋 Sample model keys: ['encoder.cls_token', 'encoder.pos_embed', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.qkv.bias', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias']
✅ Successfully extracted 150 encoder weights from MAE
🎯 Encoder weights found: 150 parameters
🔍 Sample weight parameter keys: ['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.

## 🚀 Step 9: Start Semi-Supervised Training

Begin the main training loop with semi-supervised learning and EMA teacher-student framework.

In [9]:
# 🚀 OPTION 2: A100 GPU OPTIMIZED TRAINING (Full MAE + Semi-Supervised)
import subprocess
import sys
import os
import time
import torch

print("🚀 A100 GPU OPTIMIZED TRAINING")
print("="*60)

# Check if we have A100 or high-memory GPU
print("🔍 GPU Check:")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"   GPU: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f}GB")

    if "A100" in gpu_name or gpu_memory > 20:
        print("✅ High-performance GPU detected - optimal for full training!")
        use_full_config = True
    else:
        print("⚠️ Standard GPU detected - consider upgrading for better performance")
        use_full_config = False
else:
    print("❌ No GPU detected")
    use_full_config = False

if use_full_config:
    print("\n🔥 FULL CONFIGURATION ENABLED:")
    print("✅ MAE pretraining: ENABLED")
    print("✅ Semi-supervised mode: ENABLED")
    print("✅ Batch size: 32 (full size)")
    print("✅ Mixed precision: AUTO (handled by PyTorch)")
    print("✅ All optimizations: ENABLED")

    # Get data path
    configured_labeled_path = TRAINING_CONFIG['labeled_data_path']
    parent_data_dir = os.path.dirname(configured_labeled_path)
    correct_data_path = parent_data_dir

    # A100 optimized training args (removed --mixed_precision)
    training_args_a100 = [
        sys.executable, 'train.py',
        '--data_dir', correct_data_path,
        '--epochs', '100',
        '--batch_size', '32',  # Full batch size
        '--learning_rate', '0.0001',
        '--save_dir', 'checkpoints',
        '--device', 'cuda',
        '--mode', 'semi_supervised',  # Full semi-supervised
        #'--pretrained',  # MAE pretraining
        '--val_split', '0.2',
        '--test_split', '0.2',
        '--seed', '42',
        '--warmup_epochs', '10',
        '--save_frequency', '10'
        # Removed --mixed_precision as it's not supported by the script
    ]

    print(f"\n📝 A100 Command: {' '.join(training_args_a100)}")

    try:
        print("\n🚀 STARTING A100 OPTIMIZED TRAINING...")
        print("Expected startup time: 3-5 minutes")
        print("Expected first epoch: 8-12 minutes")
        print("Expected total time: 3-4 hours")
        print("-" * 50)

        # Execute training with real-time monitoring
        process = subprocess.Popen(training_args_a100, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)

        start_time = time.time()
        line_count = 0

        while True:
            output = process.stdout.readline()
            if output == '' and process.poll() is not None:
                break
            if output:
                print(output.strip())
                line_count += 1

                # Show progress indicators
                if 'Epoch' in output:
                    elapsed = (time.time() - start_time) / 60
                    print(f"⏰ Time elapsed: {elapsed:.1f} minutes")

        rc = process.poll()
        if rc == 0:
            print(f"\n🎉 A100 TRAINING COMPLETED SUCCESSFULLY!")
            total_time = (time.time() - start_time) / 60
            print(f"⏰ Total time: {total_time:.1f} minutes")
        else:
            print(f"\n❌ Training failed with exit code: {rc}")

    except Exception as e:
        print(f"❌ A100 training error: {e}")

else:
    print(f"\n💡 TO USE THIS OPTION:")
    print("1. Go to Runtime → Change runtime type")
    print("2. Hardware accelerator: GPU")
    print("3. GPU type: A100 (if available)")
    print("4. Then restart this notebook and run this cell")
    print("5. A100 will handle the full workload easily!")

    print(f"\n🔧 OR TRY SMALLER CONFIG ON CURRENT GPU:")
    print("Run the previous cell (Option 1) for a working solution")

print(f"\n✅ Run this cell ONLY if you have A100 or similar high-memory GPU!")
print("Otherwise, use Option 1 (Fast Training) from the previous cell.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
⏰ Time elapsed: 110.1 minutes
Epoch 69:  92%|█████████▏| 266/288 [01:16<00:06,  3.48it/s, Total=0.1454, Sup=0.1418, Cons=0.0018, L-Acc=95.8%, P-Acc=94.6%]
⏰ Time elapsed: 110.1 minutes
Epoch 69:  92%|█████████▏| 266/288 [01:17<00:06,  3.48it/s, Total=0.1462, Sup=0.1426, Cons=0.0018, L-Acc=95.8%, P-Acc=94.6%]
⏰ Time elapsed: 110.1 minutes
Epoch 69:  93%|█████████▎| 267/288 [01:17<00:06,  3.48it/s, Total=0.1462, Sup=0.1426, Cons=0.0018, L-Acc=95.8%, P-Acc=94.6%]
⏰ Time elapsed: 110.1 minutes
Epoch 69:  93%|█████████▎| 267/288 [01:17<00:06,  3.48it/s, Total=0.1493, Sup=0.1457, Cons=0.0018, L-Acc=95.8%, P-Acc=94.6%]
⏰ Time elapsed: 110.1 minutes
Epoch 69:  93%|█████████▎| 268/288 [01:17<00:05,  3.48it/s, Total=0.1493, Sup=0.1457, Cons=0.0018, L-Acc=95.8%, P-Acc=94.6%]
⏰ Time elapsed: 110.1 minutes
Epoch 69:  93%|█████████▎| 268/288 [01:17<00:05,  3.48it/s, Total=0.1488, Sup=0.1452, Cons=0.0018, L-Acc=95.8%, P-Acc=94.7%]
⏰ Tim

## 🔄 Step 10: Resume Training (If Interrupted)

Use this cell to resume training from the latest checkpoint if your session was interrupted.

In [10]:
# 🚀 OPTION 2: A100 GPU OPTIMIZED TRAINING (Full MAE + Semi-Supervised)
import subprocess
import sys
import os
import time
import torch
import glob
import shutil

print("🚀 A100 GPU OPTIMIZED TRAINING")
print("="*60)

# Check for existing checkpoints in Google Drive
checkpoint_path = None
resume_epoch = 70 # Set to the desired epoch

drive_checkpoint_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints_backup'
local_checkpoint_dir = 'checkpoints'

print(f"🔍 Checking for specific checkpoint: epoch {resume_epoch}...")

# Define the expected checkpoint filename
specific_checkpoint_filename = f"checkpoint_epoch_{resume_epoch}.pth"

# Check Google Drive for the specific checkpoint
drive_specific_checkpoint_path = os.path.join(drive_checkpoint_dir, specific_checkpoint_filename)

if os.path.exists(drive_specific_checkpoint_path):
    print(f"✅ Found specific checkpoint in Google Drive: {specific_checkpoint_filename}")

    # Create local checkpoint directory
    os.makedirs(local_checkpoint_dir, exist_ok=True)

    # Define the local path for the specific checkpoint
    local_specific_checkpoint_path = os.path.join(local_checkpoint_dir, specific_checkpoint_filename)

    # Copy checkpoint to local directory if it's not already there
    if not os.path.exists(local_specific_checkpoint_path):
        shutil.copy2(drive_specific_checkpoint_path, local_specific_checkpoint_path)
        print(f"📋 Checkpoint copied to: {local_specific_checkpoint_path}")
    else:
        print(f"📋 Checkpoint already exists locally: {local_specific_checkpoint_path}")

    checkpoint_path = local_specific_checkpoint_path
    print(f"🔄 Will resume from epoch {resume_epoch}")
else:
    print(f"❌ Specific checkpoint not found in Google Drive: {specific_checkpoint_filename}")
    print("📝 Please ensure the checkpoint exists in your Google Drive backup.")
    # Optionally, you might want to exit or raise an error here if the specific checkpoint is required
    # For now, I'll let it proceed without a resume path, which will start from scratch

# Check if we have A100 or high-memory GPU
print("\n🔍 GPU Check:")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"   GPU: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f}GB")

    if "A100" in gpu_name or gpu_memory > 20:
        print("✅ High-performance GPU detected - optimal for full training!")
        use_full_config = True
    else:
        print("⚠️ Standard GPU detected - consider upgrading for better performance")
        use_full_config = False
else:
    print("❌ No GPU detected")
    use_full_config = False

if use_full_config:
    print("\n🔥 FULL CONFIGURATION ENABLED:")
    print("✅ MAE pretraining: ENABLED")
    print("✅ Semi-supervised mode: ENABLED")
    print("✅ Batch size: 32 (full size)")
    print("✅ Mixed precision: AUTO (handled by PyTorch)")
    print("✅ All optimizations: ENABLED")
    if checkpoint_path:
        print(f"✅ Resuming from checkpoint: epoch {resume_epoch}")
    else:
        print("✅ Starting training from scratch (no checkpoint found)")


    # Get data path
    configured_labeled_path = TRAINING_CONFIG['labeled_data_path']
    parent_data_dir = os.path.dirname(configured_labeled_path)
    correct_data_path = parent_data_dir

    # A100 optimized training args with checkpoint resumption
    training_args_a100 = [
        sys.executable, 'train.py',
        '--data_dir', correct_data_path,
        '--epochs', str(TRAINING_CONFIG['epochs']),
        '--batch_size', str(TRAINING_CONFIG['batch_size']),
        '--learning_rate', str(TRAINING_CONFIG['learning_rate']),
        '--save_dir', 'checkpoints',
        '--device', 'cuda',
        '--mode', 'semi_supervised',
        '--val_split', '0.2',
        '--test_split', '0.2',
        '--seed', '42',
        '--warmup_epochs', '10',
        '--save_frequency', '10'
    ]

    # Add resume checkpoint if available
    if checkpoint_path:
        training_args_a100.extend(['--resume', checkpoint_path])

    print(f"\n📝 A100 Command: {' '.join(training_args_a100)}")

    try:
        print("\n🚀 STARTING A100 OPTIMIZED TRAINING...")
        if checkpoint_path:
            print(f"🔄 Resuming from epoch {resume_epoch}")
            print("Expected startup time: 1-2 minutes (faster resume)")
        else:
            print("Expected startup time: 3-5 minutes")
        print("Expected first epoch: 8-12 minutes")
        print("Expected total time: 3-4 hours")
        print("-" * 50)

        # Execute training with real-time monitoring
        process = subprocess.Popen(training_args_a100, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)

        start_time = time.time()
        line_count = 0
        last_backup_time = time.time()

        while True:
            output = process.stdout.readline()
            if output == '' and process.poll() is not None:
                break
            if output:
                print(output.strip())
                line_count += 1

                # Show progress indicators
                if 'Epoch' in output:
                    elapsed = (time.time() - start_time) / 60
                    print(f"⏰ Time elapsed: {elapsed:.1f} minutes")

                # Backup checkpoints to Google Drive every 30 minutes
                current_time = time.time()
                if (current_time - last_backup_time) > 1800:  # 30 minutes
                    try:
                        # Find latest local checkpoint
                        local_checkpoints = glob.glob("checkpoints/checkpoint_epoch_*.pth")
                        if local_checkpoints:
                            local_checkpoints.sort(key=lambda x: int(os.path.basename(x).split('_')[-1].split('.')[0]))
                            latest_local = local_checkpoints[-1]
                            backup_name = os.path.basename(latest_local)
                            # Use the existing Google Drive checkpoint backup directory
                            backup_path = f"/content/drive/MyDrive/ViT-FishID/checkpoints_backup/{backup_name}"

                            if not os.path.exists(backup_path):
                                shutil.copy2(latest_local, backup_path)
                                print(f"💾 Backed up checkpoint to Google Drive: {backup_name}")

                            last_backup_time = current_time
                    except Exception as backup_error:
                        print(f"⚠️ Backup warning: {backup_error}")

        rc = process.poll()

        # Final backup of all checkpoints
        if rc == 0:
            try:
                print("\n💾 Backing up final checkpoints to Google Drive...")
                local_checkpoints = glob.glob("checkpoints/*.pth")
                for checkpoint in local_checkpoints:
                    backup_name = os.path.basename(checkpoint)
                     # Use the existing Google Drive checkpoint backup directory
                    backup_path = f"/content/drive/MyDrive/ViT-FishID/checkpoints_backup/{backup_name}"
                    if not os.path.exists(backup_path):
                        shutil.copy2(checkpoint, backup_path)
                print("✅ All checkpoints backed up successfully!")
            except Exception as e:
                print(f"⚠️ Final backup warning: {e}")

        if rc == 0:
            print(f"\n🎉 A100 TRAINING COMPLETED SUCCESSFULLY!")
            total_time = (time.time() - start_time) / 60
            print(f"⏰ Total time: {total_time:.1f} minutes")
        else:
            print(f"\n❌ Training failed with exit code: {rc}")

    except Exception as e:
        print(f"❌ A100 training error: {e}")

else:
    print(f"\n💡 TO USE THIS OPTION:")
    print("1. Go to Runtime → Change runtime type")
    print("2. Hardware accelerator: GPU")
    print("3. GPU type: A100 (if available)")
    print("4. Then restart this notebook and run this cell")
    print("5. A100 will handle the full workload easily!")

    print(f"\n🔧 OR TRY SMALLER CONFIG ON CURRENT GPU:")
    print("Run the previous cell (Option 1) for a working solution")

print(f"\n✅ Run this cell ONLY if you have A100 or similar high-memory GPU!")
print("Otherwise, use Option 1 (Fast Training) from the previous cell.")
print("\n📁 Checkpoints are automatically saved to Google Drive for persistence!")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 96:  91%|█████████▏| 263/288 [01:18<00:07,  3.40it/s, Total=0.1094, Sup=0.1090, Cons=0.0002, L-Acc=96.8%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  92%|█████████▏| 264/288 [01:18<00:07,  3.41it/s, Total=0.1094, Sup=0.1090, Cons=0.0002, L-Acc=96.8%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  92%|█████████▏| 264/288 [01:18<00:07,  3.41it/s, Total=0.1090, Sup=0.1086, Cons=0.0002, L-Acc=96.9%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  92%|█████████▏| 265/288 [01:18<00:06,  3.41it/s, Total=0.1090, Sup=0.1086, Cons=0.0002, L-Acc=96.9%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  92%|█████████▏| 265/288 [01:18<00:06,  3.41it/s, Total=0.1115, Sup=0.1110, Cons=0.0002, L-Acc=96.8%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  92%|█████████▏| 266/288 [01:18<00:06,  3.41it/s, Total=0.1115, Sup=0.1110, Cons=0.0002, L-Acc=96.8%, P-Acc=99.7%]
⏰ Time elapsed: 41.8 minutes
Epoch 96:  

In [10]:
# 💾 SAVE TEST SPLIT FOR POST-SESSION EVALUATION
import json
import os
import pickle
from sklearn.model_selection import train_test_split
import numpy as np

print("💾 SAVING TEST SPLIT FOR POST-SESSION EVALUATION")
print("="*60)

# This ensures you can evaluate on the EXACT same test data later
# even after Colab session times out

try:
    # Get the data directory
    data_dir = '/content/fish_cutouts'  # Your fish dataset location

    # Recreate the exact same split using the same parameters as training
    print("🔄 Recreating train/val/test split with same parameters...")

    # Collect all image paths and labels (same logic as in training)
    image_paths = []
    labels = []
    class_names = []

    labeled_dir = os.path.join(data_dir, 'labeled')

    if os.path.exists(labeled_dir):
        # Correctly populate class_names by sorting the directory list
        class_names = sorted([d for d in os.listdir(labeled_dir) if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])

        # Populate image_paths and labels using the sorted class_names
        for class_idx, class_name in enumerate(class_names):
            class_path = os.path.join(labeled_dir, class_name)
            if os.path.isdir(class_path):
                for img_file in os.listdir(class_path):
                    if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        image_paths.append(os.path.join(class_path, img_file))
                        labels.append(class_idx)

    print(f"📊 Found {len(image_paths)} images across {len(class_names)} classes")

    # Use the EXACT same split parameters as training
    SEED = 42  # Same seed as training
    VAL_SPLIT = 0.2  # Same as training
    TEST_SPLIT = 0.2  # Same as training

    # Recreate the exact same split
    np.random.seed(SEED)

    # First split: train+val vs test
    # Ensure stratification is used if you have multiple classes
    if len(class_names) > 1 and len(set(labels)) > 1:
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            image_paths, labels,
            test_size=TEST_SPLIT,
            random_state=SEED,
            stratify=labels
        )

        # Second split: train vs val
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            train_val_paths, train_val_labels,
            test_size=VAL_SPLIT/(1-TEST_SPLIT),  # Adjust for remaining data
            random_state=SEED,
            stratify=train_val_labels
        )
    else:
         # Handle case with only one class or no data (though previous checks should catch no data)
        print("⚠️ Only one class or no data found. Skipping stratification.")
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            image_paths, labels,
            test_size=TEST_SPLIT,
            random_state=SEED
        )
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            train_val_paths, train_val_labels,
            test_size=VAL_SPLIT/(1-TEST_SPLIT),
            random_state=SEED
        )


    print(f"✅ Split recreated:")
    print(f"   Train: {len(train_paths)} samples")
    print(f"   Val: {len(val_paths)} samples")
    print(f"   Test: {len(test_paths)} samples")

    # Save test split information
    test_split_data = {
        'test_image_paths': test_paths,
        'test_labels': test_labels,
        'class_names': class_names, # Ensure class_names is correctly populated here
        'data_dir': data_dir,
        'seed': SEED,
        'val_split': VAL_SPLIT,
        'test_split': TEST_SPLIT,
        'total_classes': len(class_names),
        'test_samples': len(test_paths),
        'split_timestamp': '2025-08-19_21:37:58'  # When training started (or update if needed)
    }

    # Save to multiple formats for reliability

    # 1. Save as JSON (human readable)
    with open('/content/test_split_info.json', 'w') as f:
        json.dump(test_split_data, f, indent=2)
    print("✅ Test split saved to: /content/test_split_info.json")

    # 2. Save as pickle (exact Python objects)
    with open('/content/test_split_data.pkl', 'wb') as f:
        pickle.dump(test_split_data, f)
    print("✅ Test split saved to: /content/test_split_data.pkl")

    # 3. Save to Google Drive (persistent storage)
    drive_save_path = '/content/drive/MyDrive/ViT_FishID_test_split.json'
    try:
        with open(drive_save_path, 'w') as f:
            json.dump(test_split_data, f, indent=2)
        print(f"✅ Test split backed up to Google Drive: {drive_save_path}")
    except Exception as e:
        print(f"⚠️ Could not save to Google Drive: {e}")

    # 4. Create a simple test list file
    test_list_path = '/content/test_image_list.txt'
    with open(test_list_path, 'w') as f:
        # Safely get class name using a check
        for path, label in zip(test_paths, test_labels):
            class_name = class_names[label] if 0 <= label < len(class_names) else f"UnknownClass_{label}"
            f.write(f"{path}\t{label}\t{class_name}\n")
    print(f"✅ Test image list saved to: {test_list_path}")

    print(f"\n🎯 POST-SESSION EVALUATION INSTRUCTIONS:")
    print("="*50)
    print("When your session times out, you can still evaluate by:")
    print("1. 📂 Download your trained checkpoint from 'checkpoints/' folder")
    print("2. 💾 Download the test split files created above")
    print("3. 🔄 Start a new Colab session")
    print("4. 📤 Upload the checkpoint and test split files")
    print("5. 🧪 Run evaluation using the saved test split")

    print(f"\n📋 Test Split Summary:")
    print(f"   Random seed: {SEED}")
    print(f"   Test samples: {len(test_paths)}")
    print(f"   Classes: {len(class_names)}")
    print(f"   Split ratios: Train=60%, Val=20%, Test=20%")

    # Show sample test images for verification
    print(f"\n🔍 Sample test images (first 5):")
    for i in range(min(5, len(test_paths))):
        # Safely get class name for printing sample
        class_name = class_names[test_labels[i]] if 0 <= test_labels[i] < len(class_names) else f"UnknownClass_{test_labels[i]}"
        print(f"   {test_paths[i]} -> {class_name}")


except Exception as e:
    print(f"❌ Error saving test split: {e}")
    print("💡 Make sure the fish dataset is properly loaded and has multiple classes if using stratification.")

print(f"\n✅ Test split preservation complete!")
print("Now your training can continue, and you can evaluate later even after timeout!")

💾 SAVING TEST SPLIT FOR POST-SESSION EVALUATION
🔄 Recreating train/val/test split with same parameters...
📊 Found 5137 images across 33 classes
✅ Split recreated:
   Train: 3081 samples
   Val: 1028 samples
   Test: 1028 samples
❌ Error saving test split: [Errno 28] No space left on device
💡 Make sure the fish dataset is properly loaded and has multiple classes if using stratification.

✅ Test split preservation complete!
Now your training can continue, and you can evaluate later even after timeout!


## 📊 Step 11: Check Training Results

Examine training progress and results.

In [None]:
# Check training results
import glob
import torch
import matplotlib.pyplot as plt
import os # Import os module

print("📊 TRAINING RESULTS")
print("="*50)

# Find all checkpoints in both local and Google Drive directories
checkpoint_locations = [
    TRAINING_CONFIG['checkpoint_dir'],
    '/content/drive/MyDrive/ViT-FishID/checkpoints_backup' # Add Google Drive backup path
]

checkpoints = []
for location in checkpoint_locations:
    # Ensure the directory exists before searching
    if os.path.exists(location):
        checkpoint_pattern = os.path.join(location, 'checkpoint_epoch_*.pth')
        checkpoints.extend(glob.glob(checkpoint_pattern))

# Sort checkpoints by epoch number
checkpoints = sorted(checkpoints,
                    key=lambda x: int(os.path.basename(x).split('epoch_')[1].split('.')[0]))


if not checkpoints:
    print("❌ No training checkpoints found in local or Google Drive backup.")
    print("📝 Make sure training has run and checkpoints are saved.")
else:
    print(f"✅ Found {len(checkpoints)} checkpoints across specified locations.")

    # Extract training metrics
    epochs = []
    train_losses = []
    val_accuracies = []

    # Use a set to keep track of processed epochs to avoid duplicates if checkpoints exist in both locations
    processed_epochs = set()

    for checkpoint_path in checkpoints:
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            epoch = checkpoint.get('epoch')

            # Only process if epoch exists and hasn't been processed from a later checkpoint
            if epoch is not None and epoch not in processed_epochs:
                epochs.append(epoch)
                # Safely get metrics, default to 0 if not found
                # Note: The key names for loss/accuracy might vary slightly depending on the save logic
                # Using common keys or inspecting checkpoint structure might be needed for robustness
                train_losses.append(checkpoint.get('train_loss', checkpoint.get('train_metrics', {}).get('total_loss', 0)))
                # Prioritize 'accuracy' or 'val_metrics.top1_accuracy'
                val_accuracies.append(checkpoint.get('accuracy', checkpoint.get('val_metrics', {}).get('top1_accuracy', 0)))
                processed_epochs.add(epoch) # Mark epoch as processed
        except Exception as e:
            print(f"⚠️ Could not load or process checkpoint {checkpoint_path}: {e}")
            continue

    # Sort metrics by epoch number after loading
    sorted_metrics = sorted(zip(epochs, train_losses, val_accuracies), key=lambda x: x[0])
    epochs, train_losses, val_accuracies = zip(*sorted_metrics) if sorted_metrics else ([], [], [])


    # Plot training progress
    if epochs:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Training loss
        ax1.plot(epochs, train_losses, 'b-', label='Training Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss Over Time')
        ax1.grid(True)
        ax1.legend()

        # Validation accuracy
        ax2.plot(epochs, val_accuracies, 'g-', label='Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Validation Accuracy Over Time')
        ax2.grid(True)
        ax2.legend()

        plt.tight_layout()
        plt.show()

        # Print best results
        if val_accuracies: # Check if list is not empty
            best_accuracy = max(val_accuracies)
            best_epoch = epochs[val_accuracies.index(best_accuracy)]
            final_accuracy = val_accuracies[-1]

            print(f"\n🏆 TRAINING SUMMARY:")
            print(f"  Total epochs represented: {len(epochs)}")
            print(f"  Best accuracy: {best_accuracy:.4f} (epoch {best_epoch})")
            print(f"  Final accuracy: {final_accuracy:.4f}")
            # Try to find the path of the latest checkpoint if available
            latest_checkpoint_path = checkpoints[-1] if checkpoints else "N/A"
            print(f"  Latest checkpoint file found: {os.path.basename(latest_checkpoint_path)}")

        else:
            print("\n⚠️ No valid validation accuracy data found to summarize.")
    else:
        print("❌ Could not extract valid training metrics from checkpoints")

## 💾 Step 12: Save Model and Results

Save the final trained model and results to Google Drive.

In [12]:
# Save final model and results
import shutil
import glob
from datetime import datetime

print("💾 SAVING MODEL AND RESULTS")
print("="*50)

# Create timestamped folder in Google Drive
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f'/content/drive/MyDrive/ViT_FishID_Results_{timestamp}'
os.makedirs(save_dir, exist_ok=True)

print(f"📁 Saving to: {save_dir}")

# Copy all checkpoints
checkpoint_save_dir = os.path.join(save_dir, 'checkpoints')
if os.path.exists(TRAINING_CONFIG['checkpoint_dir']):
    shutil.copytree(TRAINING_CONFIG['checkpoint_dir'], checkpoint_save_dir, dirs_exist_ok=True)
    print(f"✅ Saved checkpoints to {checkpoint_save_dir}")

# Copy configuration files
config_files = ['training_config.json', 'dataset_info.json']
for config_file in config_files:
    if os.path.exists(config_file):
        shutil.copy2(config_file, save_dir)
        print(f"✅ Saved {config_file}")

# Copy training scripts
script_files = ['train.py', 'model.py', 'data.py', 'utils.py']
for script_file in script_files:
    if os.path.exists(script_file):
        shutil.copy2(script_file, save_dir)
        print(f"✅ Saved {script_file}")

# Create results summary
results_summary = {
    'timestamp': timestamp,
    'training_config': TRAINING_CONFIG,
    'dataset_info': dataset_info,
    'total_checkpoints': len(glob.glob(os.path.join(checkpoint_save_dir, '*.pth'))) if os.path.exists(checkpoint_save_dir) else 0
}

with open(os.path.join(save_dir, 'results_summary.json'), 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\n✅ All results saved to Google Drive!")
print(f"📁 Location: {save_dir}")
print(f"\n📋 Saved files:")
for item in os.listdir(save_dir):
    if os.path.isdir(os.path.join(save_dir, item)):
        print(f"  📁 {item}/")
    else:
        print(f"  📄 {item}")

💾 SAVING MODEL AND RESULTS
📁 Saving to: /content/drive/MyDrive/ViT_FishID_Results_20250820_132015
✅ Saved checkpoints to /content/drive/MyDrive/ViT_FishID_Results_20250820_132015/checkpoints
✅ Saved training_config.json
✅ Saved dataset_info.json
✅ Saved train.py
✅ Saved model.py
✅ Saved data.py
✅ Saved utils.py

✅ All results saved to Google Drive!
📁 Location: /content/drive/MyDrive/ViT_FishID_Results_20250820_132015

📋 Saved files:
  📁 checkpoints/
  📄 training_config.json
  📄 dataset_info.json
  📄 train.py
  📄 model.py
  📄 data.py
  📄 utils.py
  📄 results_summary.json


## 🧪 Step 13: Model Evaluation

Evaluate the trained model on test data with comprehensive metrics and visualizations.

In [None]:
# Model selection for evaluation
import glob
import os

print("🧪 MODEL EVALUATION SETUP")
print("="*50)

# Find available models
checkpoint_locations = [
    TRAINING_CONFIG['checkpoint_dir'],
    '/content/drive/MyDrive/ViT-FishID/checkpoints_backup',
    '/content/drive/MyDrive*ViT-FishID/checkpoints_backup/*checkpoint*.pth'
]

found_models = []
for location in checkpoint_locations:
    if '*' in location:
        found_models.extend(glob.glob(location, recursive=True))
    elif os.path.exists(location):
        found_models.extend(glob.glob(os.path.join(location, '*.pth')))

if not found_models:
    print("❌ No trained models found for evaluation")
    print("📝 Please complete training first")
else:
    print(f"✅ Found {len(found_models)} model checkpoints")

    # Select best model (highest epoch number or best accuracy)
    latest_model = max(found_models, key=lambda x:
                      int(x.split('epoch_')[1].split('.')[0]) if 'epoch_' in x else 0)

    print(f"🎯 Selected model: {latest_model}")

    # Set paths for evaluation
    SELECTED_MODEL_PATH = latest_model
    TEST_DATA_PATH = TRAINING_CONFIG['labeled_data_path']  # Use labeled data for testing

    print(f"📊 Test data path: {TEST_DATA_PATH}")
    print(f"\n✅ Ready for model evaluation!")

In [None]:
# Enhanced Model Evaluation with Top-1, Top-5, Metrics Export & Google Drive Backup
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time
import sys
from tqdm import tqdm
import json
import os
import shutil
from datetime import datetime

def create_species_index_mapping(class_names):
    """Create a mapping from full species names to indices and shortened names"""
    species_index = {}
    index_to_species = {}
    shortened_names = []

    for idx, full_name in enumerate(class_names):
        # Create shortened name (first letter of family + first few letters of species)
        parts = full_name.split('_')
        if len(parts) >= 3:
            family = parts[0]
            genus = parts[1]
            species = parts[2]
            shortened = f"{family[0]}{genus[0]}{species[:3]}"
        else:
            shortened = full_name[:6]

        species_index[full_name] = idx
        index_to_species[idx] = full_name
        shortened_names.append(f"{idx:02d}_{shortened}")

    return species_index, index_to_species, shortened_names

def calculate_topk_accuracy(predictions_probs, labels, k=5):
    """Calculate top-k accuracy"""
    # Get top-k predictions for each sample
    _, topk_pred = torch.topk(torch.tensor(predictions_probs), k, dim=1)

    # Check if true label is in top-k predictions
    labels_expanded = torch.tensor(labels).unsqueeze(1).expand_as(topk_pred)
    correct = (topk_pred == labels_expanded).any(dim=1)

    return correct.float().mean().item()

def backup_results_to_google_drive(results_dir, timestamp):
    """Backup all evaluation results to Google Drive"""
    print(f"\n💾 BACKING UP RESULTS TO GOOGLE DRIVE...")

    # Define Google Drive backup locations
    drive_backup_locations = [
        '/content/drive/MyDrive/ViT-FishID/evaluation_results',
        '/content/drive/MyDrive/ViT-FishID/model_evaluation_backup',
        '/content/drive/MyDrive/ViT-FishID/metrics_backup'
    ]

    backup_success = 0

    for backup_location in drive_backup_locations:
        try:
            # Create backup directory structure
            drive_results_dir = os.path.join(backup_location, f"evaluation_{timestamp}")
            os.makedirs(drive_results_dir, exist_ok=True)

            # Copy entire results directory to Google Drive
            for item in os.listdir(results_dir):
                src_path = os.path.join(results_dir, item)
                dst_path = os.path.join(drive_results_dir, item)

                if os.path.isfile(src_path):
                    shutil.copy2(src_path, dst_path)
                    print(f"✅ Backed up: {item} to {backup_location}")

            # Create a summary backup info file
            backup_info = {
                'backup_timestamp': datetime.now().isoformat(),
                'original_results_dir': results_dir,
                'backup_location': drive_results_dir,
                'files_backed_up': os.listdir(results_dir),
                'backup_size_mb': sum(os.path.getsize(os.path.join(results_dir, f))
                                    for f in os.listdir(results_dir)
                                    if os.path.isfile(os.path.join(results_dir, f))) / (1024*1024)
            }

            backup_info_path = os.path.join(drive_results_dir, 'backup_info.json')
            with open(backup_info_path, 'w') as f:
                json.dump(backup_info, f, indent=2)

            backup_success += 1
            print(f"💾 Drive backup successful: {drive_results_dir}")

        except Exception as e:
            print(f"❌ Drive backup failed for {backup_location}: {e}")

    if backup_success > 0:
        print(f"✅ Successfully backed up to {backup_success}/{len(drive_backup_locations)} Google Drive locations")
    else:
        print(f"⚠️ No Google Drive backups successful - results saved locally only")

    return backup_success > 0

def save_training_metrics_to_drive():
    """Save training metrics from checkpoints to Google Drive"""
    print(f"\n📊 SAVING TRAINING METRICS TO GOOGLE DRIVE...")

    try:
        # Look for training metrics in checkpoints
        checkpoint_locations = [
            'checkpoints',
            '/content/drive/MyDrive/ViT-FishID/checkpoints_backup',
            'local_checkpoints'
        ]

        training_metrics = {
            'checkpoints_found': [],
            'training_history': {},
            'best_metrics': {}
        }

        for checkpoint_dir in checkpoint_locations:
            if os.path.exists(checkpoint_dir):
                checkpoints = [f for f in os.listdir(checkpoint_dir)
                             if f.startswith('checkpoint_epoch_') and f.endswith('.pth')]

                for checkpoint_file in sorted(checkpoints):
                    try:
                        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
                        checkpoint = torch.load(checkpoint_path, map_location='cpu')

                        epoch = checkpoint.get('epoch', 0)
                        training_metrics['checkpoints_found'].append(checkpoint_file)

                        # Extract metrics from checkpoint
                        if 'val_metrics' in checkpoint:
                            training_metrics['training_history'][epoch] = checkpoint['val_metrics']

                        if 'best_accuracy' in checkpoint:
                            if epoch not in training_metrics['best_metrics'] or \
                               checkpoint['best_accuracy'] > training_metrics['best_metrics'].get('accuracy', 0):
                                training_metrics['best_metrics'] = {
                                    'epoch': epoch,
                                    'accuracy': checkpoint['best_accuracy'],
                                    'checkpoint_file': checkpoint_file
                                }

                    except Exception as e:
                        print(f"⚠️ Could not read {checkpoint_file}: {e}")

        # Save training metrics to Google Drive
        drive_metrics_path = '/content/drive/MyDrive/ViT-FishID/training_metrics.json'
        os.makedirs(os.path.dirname(drive_metrics_path), exist_ok=True)

        with open(drive_metrics_path, 'w') as f:
            json.dump(training_metrics, f, indent=2)

        print(f"✅ Training metrics saved to: {drive_metrics_path}")
        return True

    except Exception as e:
        print(f"❌ Failed to save training metrics: {e}")
        return False

def enhanced_model_evaluation(model_path, data_path, device='cuda', max_batches=None, save_results=True):
    """
    Enhanced evaluation with comprehensive metrics, file export, and Google Drive backup
    """
    print(f"🚀 ENHANCED MODEL EVALUATION WITH GOOGLE DRIVE BACKUP")
    print(f"Time: {time.strftime('%H:%M:%S')}")
    print("=" * 70)

    start_time = time.time()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Load test dataset with progress
    print("📁 Loading test dataset...")
    try:
        from data import create_dataloaders

        # Load dataset info
        dataset_info_path = 'dataset_info.json'
        if not os.path.exists(dataset_info_path):
            repo_base_dir = '/content/ViT-FishID'
            dataset_info_path = os.path.join(repo_base_dir, 'dataset_info.json')

        if os.path.exists(dataset_info_path):
            with open(dataset_info_path, 'r') as f:
                dataset_info_loaded = json.load(f)
            num_classes_dataset = dataset_info_loaded['num_species']
            parent_data_dir = os.path.dirname(dataset_info_loaded['labeled_path'])
            print(f"✅ Loaded dataset info: {num_classes_dataset} classes, data from {parent_data_dir}")
        else:
            print("❌ dataset_info.json not found.")
            return None

        # Create data loaders
        train_loader, val_loader, test_loader, class_names = create_dataloaders(
            data_dir=parent_data_dir,
            batch_size=32,
            image_size=224,
            val_split=0.2,
            test_split=0.2,
            seed=42,
            num_workers=2
        )

        num_classes = len(class_names)
        dataset_load_time = time.time() - start_time
        print(f"✅ Dataset loaded: {len(test_loader.dataset)} test images in {dataset_load_time:.1f}s")

        # Create species indexing
        species_index, index_to_species, shortened_names = create_species_index_mapping(class_names)
        print(f"✅ Species indexing created: {len(species_index)} species")

    except Exception as e:
        print(f"❌ Dataset loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Load model
    print("\n🧠 Loading model...")
    model_start_time = time.time()

    try:
        from model import ViTForFishClassification

        model = ViTForFishClassification(num_classes=num_classes)
        checkpoint = torch.load(model_path, map_location=device)

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            epoch = checkpoint.get('epoch', 'Unknown')
            print(f"✅ Model loaded from epoch {epoch}")
        elif 'student_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['student_state_dict'])
            epoch = checkpoint.get('epoch', 'Unknown')
            print(f"✅ Student model loaded from epoch {epoch}")
        else:
            model.load_state_dict(checkpoint)
            print(f"✅ Model loaded from checkpoint")

        model.to(device)
        model.eval()

        model_load_time = time.time() - model_start_time
        print(f"   Model loading took: {model_load_time:.1f}s")

    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    # Evaluation with progress bar
    print(f"\n🔬 Starting comprehensive evaluation...")
    eval_start_time = time.time()

    all_predictions = []
    all_labels = []
    all_probabilities = []

    total_batches = len(test_loader)
    if max_batches:
        total_batches = min(total_batches, max_batches)

    # Progress bar
    pbar = tqdm(total=total_batches, desc="Evaluating",
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            if max_batches and batch_idx >= max_batches:
                break

            batch_start = time.time()

            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

            batch_time = time.time() - batch_start

            pbar.set_postfix({
                'Batch_Time': f'{batch_time:.2f}s',
                'Images/s': f'{len(images)/batch_time:.1f}',
                'Processed': f'{len(all_predictions)}'
            })
            pbar.update(1)

    pbar.close()
    eval_time = time.time() - eval_start_time

    print(f"\n✅ Evaluation complete!")
    print(f"   Total time: {eval_time:.1f}s")
    print(f"   Speed: {len(all_predictions)/eval_time:.1f} images/second")

    # Calculate comprehensive metrics
    print(f"\n📊 CALCULATING COMPREHENSIVE METRICS...")

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)

    # Top-1 and Top-5 accuracy
    top1_accuracy = accuracy_score(all_labels, all_predictions)
    top5_accuracy = calculate_topk_accuracy(all_probabilities, all_labels, k=5)

    # Per-class metrics
    precision, recall, f1_score, support = precision_recall_fscore_support(
        all_labels, all_predictions, average=None, zero_division=0
    )

    # Macro and weighted averages
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Confidence statistics
    confidences = np.max(all_probabilities, axis=1)
    mean_confidence = np.mean(confidences)
    correct_mask = (all_predictions == all_labels)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)

    print(f"\n🎯 EVALUATION RESULTS")
    print("=" * 50)
    print(f"📊 Top-1 Accuracy: {top1_accuracy:.4f} ({top1_accuracy*100:.2f}%)")
    print(f"📊 Top-5 Accuracy: {top5_accuracy:.4f} ({top5_accuracy*100:.2f}%)")
    print(f"📊 Macro F1-Score: {f1_macro:.4f}")
    print(f"📊 Weighted F1-Score: {f1_weighted:.4f}")
    print(f"📊 Mean Confidence: {mean_confidence:.4f}")
    print(f"📊 Total Test Samples: {len(all_labels)}")
    print(f"📊 Number of Classes: {num_classes}")

    # Create comprehensive results dictionary
    results = {
        'evaluation_metadata': {
            'timestamp': timestamp,
            'model_path': model_path,
            'data_path': data_path,
            'device': str(device),
            'evaluation_time_seconds': eval_time,
            'total_samples': len(all_labels),
            'num_classes': num_classes,
            'google_drive_backup': True
        },
        'accuracy_metrics': {
            'top1_accuracy': float(top1_accuracy),
            'top5_accuracy': float(top5_accuracy)
        },
        'aggregate_metrics': {
            'precision_macro': float(precision_macro),
            'recall_macro': float(recall_macro),
            'f1_score_macro': float(f1_macro),
            'precision_weighted': float(precision_weighted),
            'recall_weighted': float(recall_weighted),
            'f1_score_weighted': float(f1_weighted)
        },
        'confidence_metrics': {
            'mean_confidence': float(mean_confidence),
            'high_confidence_samples': int(np.sum(confidences > 0.9)),
            'low_confidence_samples': int(np.sum(confidences < 0.5)),
            'confidence_std': float(np.std(confidences))
        },
        'species_indexing': {
            'species_index': species_index,
            'index_to_species': index_to_species,
            'shortened_names': shortened_names
        },
        'per_class_metrics': {},
        'confusion_matrix': cm.tolist()  # Convert for JSON serialization
    }

    # Calculate per-class metrics
    for i, class_name in enumerate(class_names):
        mask = (all_labels == i)
        class_samples = np.sum(mask)

        if class_samples > 0:
            class_accuracy = np.sum((all_predictions == i) & mask) / class_samples
            results['per_class_metrics'][f'class_{i:02d}'] = {
                'species_name': class_name,
                'shortened_name': shortened_names[i],
                'samples': int(class_samples),
                'accuracy': float(class_accuracy),
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1_score': float(f1_score[i])
            }

    # Save results to files
    if save_results:
        print(f"\n💾 SAVING EVALUATION RESULTS...")

        # Create results directory
        results_dir = f"evaluation_results_{timestamp}"
        os.makedirs(results_dir, exist_ok=True)

        # Save comprehensive results as JSON
        json_path = os.path.join(results_dir, f"evaluation_metrics_{timestamp}.json")
        with open(json_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"✅ Metrics saved to: {json_path}")

        # Save species index mapping
        index_path = os.path.join(results_dir, f"species_index_{timestamp}.json")
        index_mapping = {
            'species_to_index': species_index,
            'index_to_species': {str(k): v for k, v in index_to_species.items()},
            'shortened_names': {str(i): name for i, name in enumerate(shortened_names)},
            'full_names': class_names
        }
        with open(index_path, 'w') as f:
            json.dump(index_mapping, f, indent=2)
        print(f"✅ Species indexing saved to: {index_path}")

        # Save confusion matrix
        cm_path = os.path.join(results_dir, f"confusion_matrix_{timestamp}.csv")
        np.savetxt(cm_path, cm, delimiter=',', fmt='%d')
        print(f"✅ Confusion matrix saved to: {cm_path}")

        # Save detailed classification report
        report_path = os.path.join(results_dir, f"classification_report_{timestamp}.txt")
        with open(report_path, 'w') as f:
            f.write(f"ENHANCED MODEL EVALUATION REPORT\n")
            f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Model: {model_path}\n")
            f.write(f"Data: {data_path}\n\n")
            f.write(f"ACCURACY METRICS:\n")
            f.write(f"Top-1 Accuracy: {top1_accuracy:.4f} ({top1_accuracy*100:.2f}%)\n")
            f.write(f"Top-5 Accuracy: {top5_accuracy:.4f} ({top5_accuracy*100:.2f}%)\n\n")
            f.write(f"AGGREGATE METRICS:\n")
            f.write(f"Macro F1-Score: {f1_macro:.4f}\n")
            f.write(f"Weighted F1-Score: {f1_weighted:.4f}\n")
            f.write(f"Macro Precision: {precision_macro:.4f}\n")
            f.write(f"Macro Recall: {recall_macro:.4f}\n\n")
            f.write(f"SPECIES INDEX MAPPING:\n")
            for i, (full_name, short_name) in enumerate(zip(class_names, shortened_names)):
                f.write(f"{i:02d}: {short_name} = {full_name}\n")
            f.write(f"\nDETAILED CLASSIFICATION REPORT:\n")
            f.write(classification_report(all_labels, all_predictions,
                                        target_names=[f"{i:02d}_{name}" for i, name in enumerate(shortened_names)]))
        print(f"✅ Detailed report saved to: {report_path}")

        # Save visualizations
        plt.figure(figsize=(20, 16))

        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(20, 16))
        fig.suptitle(f'Enhanced Evaluation Results - Top-1: {top1_accuracy:.3f}, Top-5: {top5_accuracy:.3f}',
                     fontsize=16)

        # Top-1 vs Top-5 Accuracy
        accuracies = [top1_accuracy, top5_accuracy]
        accuracy_labels = ['Top-1', 'Top-5']
        bars = axes[0,0].bar(accuracy_labels, accuracies, color=['skyblue', 'lightgreen'], edgecolor='navy')
        axes[0,0].set_ylim(0, 1)
        axes[0,0].set_title(f'Top-K Accuracy Comparison\n{len(all_labels)} test images')
        axes[0,0].set_ylabel('Accuracy')
        for bar, acc in zip(bars, accuracies):
            axes[0,0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                          f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')

        # Per-class F1-scores
        class_indices = list(range(min(len(f1_score), 20)))
        f1_subset = f1_score[:20] if len(f1_score) > 20 else f1_score
        colors = ['green' if f1 >= 0.8 else 'orange' if f1 >= 0.6 else 'red' for f1 in f1_subset]

        bars = axes[0,1].bar(class_indices, f1_subset, color=colors, alpha=0.7)
        axes[0,1].set_xlabel('Species Index')
        axes[0,1].set_ylabel('F1-Score')
        axes[0,1].set_title(f'Per-Class F1-Score (First 20 Species)\nMacro F1: {f1_macro:.3f}')
        axes[0,1].set_xticks(class_indices)
        axes[0,1].set_ylim(0, 1)

        # Confusion Matrix
        if num_classes <= 15:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1,0],
                       xticklabels=[f"{i:02d}" for i in range(num_classes)],
                       yticklabels=[f"{i:02d}" for i in range(num_classes)])
            axes[1,0].set_title('Confusion Matrix (Indexed)')
        else:
            im = axes[1,0].imshow(cm, cmap='Blues', aspect='auto')
            axes[1,0].set_title(f'Confusion Matrix ({num_classes} classes)')
            axes[1,0].set_xlabel('Predicted Class Index')
            axes[1,0].set_ylabel('Actual Class Index')
            plt.colorbar(im, ax=axes[1,0])

        # Confidence Distribution
        correct_confidences = confidences[correct_mask]
        incorrect_confidences = confidences[~correct_mask]

        axes[1,1].hist(correct_confidences, bins=20, alpha=0.7,
                      label=f'Correct ({len(correct_confidences)})', color='green', density=True)
        axes[1,1].hist(incorrect_confidences, bins=20, alpha=0.7,
                      label=f'Incorrect ({len(incorrect_confidences)})', color='red', density=True)
        axes[1,1].axvline(mean_confidence, color='black', linestyle='--',
                         label=f'Mean: {mean_confidence:.3f}')
        axes[1,1].set_xlabel('Prediction Confidence')
        axes[1,1].set_ylabel('Density')
        axes[1,1].set_title('Confidence Distribution')
        axes[1,1].legend()

        plt.tight_layout()

        # Save visualization
        viz_path = os.path.join(results_dir, f"evaluation_visualization_{timestamp}.png")
        plt.savefig(viz_path, dpi=300, bbox_inches='tight')
        print(f"✅ Visualization saved to: {viz_path}")
        plt.show()

        print(f"📁 All results saved in directory: {results_dir}")

        # 💾 BACKUP TO GOOGLE DRIVE
        drive_backup_success = backup_results_to_google_drive(results_dir, timestamp)

        # Save training metrics to Google Drive as well
        training_metrics_success = save_training_metrics_to_drive()

        if drive_backup_success:
            print(f"\n🎉 COMPLETE SUCCESS!")
            print(f"✅ Local results saved: {results_dir}")
            print(f"✅ Google Drive backup: Multiple locations")
            print(f"✅ Training metrics: Saved to Drive")
        else:
            print(f"\n⚠️ PARTIAL SUCCESS:")
            print(f"✅ Local results saved: {results_dir}")
            print(f"❌ Google Drive backup: Failed (check mount)")

    # Print species index mapping
    print(f"\n🏷️  SPECIES INDEX MAPPING:")
    print("=" * 60)
    print("Index | Short Name    | Full Species Name")
    print("-" * 60)
    for i, (full_name, short_name) in enumerate(zip(class_names, shortened_names)):
        print(f"{i:02d}   | {short_name:12s} | {full_name}")

    # Print top and worst performing species
    f1_with_names = [(f1_score[i], i, class_names[i], shortened_names[i])
                     for i in range(len(class_names))
                     if i < len(f1_score) and support[i] > 0]
    f1_with_names.sort(reverse=True, key=lambda x: x[0])

    print(f"\n🏆 TOP 5 PERFORMING SPECIES (by F1-Score):")
    for f1, idx, full_name, short_name in f1_with_names[:5]:
        print(f"  {idx:02d} ({short_name}): {f1:.3f} - {full_name}")

    print(f"\n⚠️  BOTTOM 5 PERFORMING SPECIES (by F1-Score):")
    for f1, idx, full_name, short_name in f1_with_names[-5:]:
        print(f"  {idx:02d} ({short_name}): {f1:.3f} - {full_name}")

    return results

# Run enhanced evaluation
if 'SELECTED_MODEL_PATH' in globals() and 'TEST_DATA_PATH' in globals():
    print(f"🎯 Starting enhanced evaluation with comprehensive metrics and Google Drive backup...")
    print(f"Model: {SELECTED_MODEL_PATH}")
    print(f"Test Data: {TEST_DATA_PATH}")
    print(f"Device: {DEVICE}")

    results = enhanced_model_evaluation(
        model_path=SELECTED_MODEL_PATH,
        data_path=TEST_DATA_PATH,
        device=DEVICE,
        save_results=True
    )

    if results:
        print(f"\n🎉 ENHANCED EVALUATION WITH GOOGLE DRIVE BACKUP COMPLETED!")
        print(f"📊 Top-1 Accuracy: {results['accuracy_metrics']['top1_accuracy']:.4f}")
        print(f"📊 Top-5 Accuracy: {results['accuracy_metrics']['top5_accuracy']:.4f}")
        print(f"📊 Macro F1-Score: {results['aggregate_metrics']['f1_score_macro']:.4f}")
        print(f"💾 Results exported to files with timestamp")
        print(f"☁️  All results backed up to Google Drive!")
else:
    print("⚠️ Please run the model selection cell first")

## 🎯 Quick Test: Evaluate Trained Model

**Use this cell to quickly test your already trained ViT model with proper train/val/test split (60/20/20)**

This will load your checkpoint from epoch 100 and evaluate it on the test set.

In [None]:
# Enhanced Test Trained ViT Model - Complete Evaluation with Google Drive Backup
%pip install -q seaborn

import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support
import time
from tqdm import tqdm
import json
import shutil
from datetime import datetime

def create_species_index_mapping(class_names):
    """Create a mapping from full species names to indices and shortened names"""
    species_index = {}
    index_to_species = {}
    shortened_names = []

    for idx, full_name in enumerate(class_names):
        # Create shortened name (first letter of family + first few letters of species)
        parts = full_name.split('_')
        if len(parts) >= 3:
            family = parts[0]
            genus = parts[1]
            species = parts[2]
            shortened = f"{family[0]}{genus[0]}{species[:3]}"
        else:
            shortened = full_name[:6]

        species_index[full_name] = idx
        index_to_species[idx] = full_name
        shortened_names.append(f"{idx:02d}_{shortened}")

    return species_index, index_to_species, shortened_names

def calculate_topk_accuracy(predictions_probs, labels, k=5):
    """Calculate top-k accuracy"""
    # Get top-k predictions for each sample
    _, topk_pred = torch.topk(torch.tensor(predictions_probs), k, dim=1)

    # Check if true label is in top-k predictions
    labels_expanded = torch.tensor(labels).unsqueeze(1).expand_as(topk_pred)
    correct = (topk_pred == labels_expanded).any(dim=1)

    return correct.float().mean().item()

def backup_to_google_drive(local_results_dir, timestamp):
    """Backup evaluation results to multiple Google Drive locations"""
    print(f"\n☁️  BACKING UP TO GOOGLE DRIVE...")

    drive_backup_paths = [
        '/content/drive/MyDrive/ViT-FishID/evaluation_results',
        '/content/drive/MyDrive/ViT-FishID/model_testing_results',
        '/content/drive/MyDrive/ViT-FishID/backup_evaluation'
    ]

    backup_success = 0

    for drive_path in drive_backup_paths:
        try:
            # Create timestamped backup directory
            drive_results_dir = os.path.join(drive_path, f"test_evaluation_{timestamp}")
            os.makedirs(drive_results_dir, exist_ok=True)

            # Copy all files to Google Drive
            files_copied = 0
            for filename in os.listdir(local_results_dir):
                src_path = os.path.join(local_results_dir, filename)
                dst_path = os.path.join(drive_results_dir, filename)

                if os.path.isfile(src_path):
                    shutil.copy2(src_path, dst_path)
                    files_copied += 1

            # Create backup summary
            backup_summary = {
                'backup_timestamp': datetime.now().isoformat(),
                'source_directory': local_results_dir,
                'backup_location': drive_results_dir,
                'files_copied': files_copied,
                'total_size_mb': sum(os.path.getsize(os.path.join(local_results_dir, f))
                                   for f in os.listdir(local_results_dir)
                                   if os.path.isfile(os.path.join(local_results_dir, f))) / (1024*1024)
            }

            summary_path = os.path.join(drive_results_dir, 'backup_summary.json')
            with open(summary_path, 'w') as f:
                json.dump(backup_summary, f, indent=2)

            print(f"✅ Drive backup: {files_copied} files → {drive_path}")
            backup_success += 1

        except Exception as e:
            print(f"❌ Drive backup failed for {drive_path}: {e}")

    print(f"📊 Google Drive backup: {backup_success}/{len(drive_backup_paths)} locations successful")
    return backup_success > 0

def test_trained_vit_model():
    """Test trained ViT model with comprehensive evaluation and Google Drive backup"""

    print("🧪 TESTING TRAINED ViT MODEL WITH GOOGLE DRIVE BACKUP")
    print("="*70)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🔧 Using device: {device}")

    # Model path - look for latest checkpoint
    checkpoint_paths = [
        'checkpoints',
        '/content/drive/MyDrive/ViT-FishID/checkpoints_backup',
        'local_checkpoints'
    ]

    model_checkpoint = None
    for checkpoint_dir in checkpoint_paths:
        if os.path.exists(checkpoint_dir):
            checkpoints = [f for f in os.listdir(checkpoint_dir)
                         if f.startswith('checkpoint_epoch_') and f.endswith('.pth')]
            if checkpoints:
                # Get latest checkpoint
                latest = max(checkpoints, key=lambda x: int(x.split('epoch_')[1].split('.')[0]))
                model_checkpoint = os.path.join(checkpoint_dir, latest)
                print(f"✅ Found model checkpoint: {model_checkpoint}")
                break

    if not model_checkpoint:
        print("❌ No trained model checkpoint found!")
        return None

    # Data configuration
    try:
        # Load dataset info
        dataset_info_path = 'dataset_info.json'
        if not os.path.exists(dataset_info_path):
            dataset_info_path = '/content/ViT-FishID/dataset_info.json'

        if not os.path.exists(dataset_info_path):
            print("❌ dataset_info.json not found")
            return None

        with open(dataset_info_path, 'r') as f:
            dataset_info = json.load(f)

        parent_data_dir = os.path.dirname(dataset_info['labeled_path'])
        print(f"✅ Data directory: {parent_data_dir}")

        # Create data loaders
        from data import create_dataloaders

        train_loader, val_loader, test_loader, class_names = create_dataloaders(
            data_dir=parent_data_dir,
            batch_size=32,
            image_size=224,
            val_split=0.2,
            test_split=0.2,
            seed=42,
            num_workers=2
        )

        num_classes = len(class_names)
        print(f"✅ Dataset loaded: {num_classes} classes, {len(test_loader.dataset)} test images")

        # Create species indexing
        species_index, index_to_species, shortened_names = create_species_index_mapping(class_names)

    except Exception as e:
        print(f"❌ Dataset loading failed: {e}")
        return None

    # Load model
    try:
        from model import ViTForFishClassification

        print(f"🧠 Loading model...")
        model = ViTForFishClassification(num_classes=num_classes)
        checkpoint = torch.load(model_checkpoint, map_location=device)

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'student_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['student_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()

        epoch = checkpoint.get('epoch', 'Unknown')
        print(f"✅ Model loaded from epoch {epoch}")

    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return None

    # Evaluation
    print(f"\n🔬 Running evaluation...")

    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)

    # Calculate metrics
    top1_accuracy = accuracy_score(all_labels, all_predictions)
    top5_accuracy = calculate_topk_accuracy(all_probabilities, all_labels, k=5)

    # Per-class metrics
    precision, recall, f1_score, support = precision_recall_fscore_support(
        all_labels, all_predictions, average=None, zero_division=0
    )

    # Macro and weighted averages
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )

    # Confidence statistics
    confidences = np.max(all_probabilities, axis=1)
    mean_confidence = np.mean(confidences)
    correct_mask = (all_predictions == all_labels)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)

    print(f"\n� TEST RESULTS")
    print("=" * 50)
    print(f"📊 Top-1 Accuracy: {top1_accuracy:.4f} ({top1_accuracy*100:.2f}%)")
    print(f"📊 Top-5 Accuracy: {top5_accuracy:.4f} ({top5_accuracy*100:.2f}%)")
    print(f"📊 Macro F1-Score: {f1_macro:.4f}")
    print(f"📊 Weighted F1-Score: {f1_weighted:.4f}")
    print(f"📊 Mean Confidence: {mean_confidence:.4f}")

    # Create comprehensive results
    results = {
        'test_metadata': {
            'timestamp': timestamp,
            'model_checkpoint': model_checkpoint,
            'data_directory': parent_data_dir,
            'device': str(device),
            'total_samples': len(all_labels),
            'num_classes': num_classes,
            'google_drive_backup': True
        },
        'accuracy_metrics': {
            'top1_accuracy': float(top1_accuracy),
            'top5_accuracy': float(top5_accuracy)
        },
        'aggregate_metrics': {
            'precision_macro': float(precision_macro),
            'recall_macro': float(recall_macro),
            'f1_score_macro': float(f1_macro),
            'precision_weighted': float(precision_weighted),
            'recall_weighted': float(recall_weighted),
            'f1_score_weighted': float(f1_weighted)
        },
        'confidence_metrics': {
            'mean_confidence': float(mean_confidence),
            'high_confidence_samples': int(np.sum(confidences > 0.9)),
            'low_confidence_samples': int(np.sum(confidences < 0.5)),
            'confidence_std': float(np.std(confidences))
        },
        'species_indexing': {
            'species_index': species_index,
            'index_to_species': index_to_species,
            'shortened_names': shortened_names
        },
        'per_class_metrics': {},
        'confusion_matrix': cm.tolist()
    }

    # Calculate per-class metrics
    for i, class_name in enumerate(class_names):
        if i < len(f1_score):
            results['per_class_metrics'][f'class_{i:02d}'] = {
                'species_name': class_name,
                'shortened_name': shortened_names[i],
                'samples': int(support[i]),
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1_score': float(f1_score[i])
            }

    # Save results locally
    results_dir = f"test_results_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    # JSON results
    json_path = os.path.join(results_dir, f"test_metrics_{timestamp}.json")
    with open(json_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"✅ Test metrics saved to: {json_path}")

    # Species index mapping
    index_path = os.path.join(results_dir, f"species_index_{timestamp}.json")
    index_mapping = {
        'species_to_index': species_index,
        'index_to_species': {str(k): v for k, v in index_to_species.items()},
        'shortened_names': {str(i): name for i, name in enumerate(shortened_names)},
        'full_names': class_names
    }
    with open(index_path, 'w') as f:
        json.dump(index_mapping, f, indent=2)
    print(f"✅ Species index saved to: {index_path}")

    # Confusion matrix CSV
    cm_path = os.path.join(results_dir, f"confusion_matrix_{timestamp}.csv")
    np.savetxt(cm_path, cm, delimiter=',', fmt='%d')
    print(f"✅ Confusion matrix saved to: {cm_path}")

    # Detailed text report
    report_path = os.path.join(results_dir, f"test_report_{timestamp}.txt")
    with open(report_path, 'w') as f:
        f.write(f"TRAINED ViT MODEL TEST REPORT\n")
        f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {model_checkpoint}\n")
        f.write(f"Test Images: {len(all_labels)}\n")
        f.write(f"Classes: {num_classes}\n\n")
        f.write(f"ACCURACY METRICS:\n")
        f.write(f"Top-1 Accuracy: {top1_accuracy:.4f} ({top1_accuracy*100:.2f}%)\n")
        f.write(f"Top-5 Accuracy: {top5_accuracy:.4f} ({top5_accuracy*100:.2f}%)\n\n")
        f.write(f"AGGREGATE METRICS:\n")
        f.write(f"Macro F1-Score: {f1_macro:.4f}\n")
        f.write(f"Weighted F1-Score: {f1_weighted:.4f}\n\n")
        f.write(f"SPECIES MAPPING:\n")
        for i, (full_name, short_name) in enumerate(zip(class_names, shortened_names)):
            if i < len(f1_score):
                f.write(f"{i:02d}: {short_name} = {full_name} (F1: {f1_score[i]:.3f})\n")
        f.write(f"\nCLASSIFICATION REPORT:\n")
        f.write(classification_report(all_labels, all_predictions,
                                    target_names=[f"{i:02d}_{name}" for i, name in enumerate(shortened_names)]))
    print(f"✅ Test report saved to: {report_path}")

    # Create and save visualizations
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'ViT Model Test Results - Top-1: {top1_accuracy:.3f}, Top-5: {top5_accuracy:.3f}',
                 fontsize=14)

    # Top-K accuracy comparison
    accuracies = [top1_accuracy, top5_accuracy]
    bars = axes[0,0].bar(['Top-1', 'Top-5'], accuracies, color=['skyblue', 'lightgreen'])
    axes[0,0].set_ylim(0, 1)
    axes[0,0].set_title('Top-K Accuracy')
    axes[0,0].set_ylabel('Accuracy')
    for bar, acc in zip(bars, accuracies):
        axes[0,0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                      f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')

    # F1-scores (first 20 classes)
    class_indices = list(range(min(len(f1_score), 20)))
    f1_subset = f1_score[:20] if len(f1_score) > 20 else f1_score
    colors = ['green' if f1 >= 0.8 else 'orange' if f1 >= 0.6 else 'red' for f1 in f1_subset]

    axes[0,1].bar(class_indices, f1_subset, color=colors, alpha=0.7)
    axes[0,1].set_xlabel('Species Index')
    axes[0,1].set_ylabel('F1-Score')
    axes[0,1].set_title(f'Per-Class F1 (First 20)\nMacro F1: {f1_macro:.3f}')
    axes[0,1].set_xticks(class_indices)
    axes[0,1].set_ylim(0, 1)

    # Confusion matrix
    if num_classes <= 15:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1,0],
                   xticklabels=[f"{i:02d}" for i in range(num_classes)],
                   yticklabels=[f"{i:02d}" for i in range(num_classes)])
        axes[1,0].set_title('Confusion Matrix')
    else:
        im = axes[1,0].imshow(cm, cmap='Blues', aspect='auto')
        axes[1,0].set_title(f'Confusion Matrix ({num_classes} classes)')
        axes[1,0].set_xlabel('Predicted')
        axes[1,0].set_ylabel('Actual')
        plt.colorbar(im, ax=axes[1,0])

    # Confidence distribution
    correct_confidences = confidences[correct_mask]
    incorrect_confidences = confidences[~correct_mask]

    axes[1,1].hist(correct_confidences, bins=20, alpha=0.7,
                  label=f'Correct ({len(correct_confidences)})', color='green', density=True)
    axes[1,1].hist(incorrect_confidences, bins=20, alpha=0.7,
                  label=f'Incorrect ({len(incorrect_confidences)})', color='red', density=True)
    axes[1,1].axvline(mean_confidence, color='black', linestyle='--',
                     label=f'Mean: {mean_confidence:.3f}')
    axes[1,1].set_xlabel('Confidence')
    axes[1,1].set_ylabel('Density')
    axes[1,1].set_title('Confidence Distribution')
    axes[1,1].legend()

    plt.tight_layout()

    # Save visualization
    viz_path = os.path.join(results_dir, f"test_visualization_{timestamp}.png")
    plt.savefig(viz_path, dpi=300, bbox_inches='tight')
    print(f"✅ Visualization saved to: {viz_path}")
    plt.show()

    # 💾 GOOGLE DRIVE BACKUP
    drive_success = backup_to_google_drive(results_dir, timestamp)

    if drive_success:
        print(f"\n� COMPLETE SUCCESS!")
        print(f"✅ Local test results: {results_dir}")
        print(f"☁️  Google Drive backup: Multiple locations")
        print(f"📊 All metrics and visualizations saved and backed up!")
    else:
        print(f"\n⚠️ PARTIAL SUCCESS:")
        print(f"✅ Local test results: {results_dir}")
        print(f"❌ Google Drive backup: Check Drive mount")

    # Print summary
    print(f"\n📋 SPECIES INDEX REFERENCE:")
    print("-" * 50)
    for i, (full_name, short_name) in enumerate(zip(class_names[:10], shortened_names[:10])):
        f1_val = f1_score[i] if i < len(f1_score) else 0.0
        print(f"{i:02d}: {short_name:12s} | {full_name} (F1: {f1_val:.3f})")
    if len(class_names) > 10:
        print(f"... and {len(class_names)-10} more species")

    return results

# Run the test
try:
    test_results = test_trained_vit_model()
    if test_results:
        print(f"\n� FINAL TEST SUMMARY:")
        print(f"Top-1 Accuracy: {test_results['accuracy_metrics']['top1_accuracy']:.4f}")
        print(f"Top-5 Accuracy: {test_results['accuracy_metrics']['top5_accuracy']:.4f}")
        print(f"Macro F1-Score: {test_results['aggregate_metrics']['f1_score_macro']:.4f}")
        print(f"☁️  Results backed up to Google Drive!")
except Exception as e:
    print(f"❌ Test failed: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# 📋 SPECIES INDEX REFERENCE - Quick Lookup Table
import json

def print_species_index_reference():
    """Print a clean species index reference table"""

    # Try to load from saved index file first
    index_files = [
        'species_index_*.json',
        'evaluation_results_*/species_index_*.json',
        'trained_model_evaluation_*/species_index_*.json'
    ]

    species_mapping = None

    # Try to find existing index file
    import glob
    for pattern in index_files:
        files = glob.glob(pattern)
        if files:
            try:
                with open(files[0], 'r') as f:
                    data = json.load(f)
                    species_mapping = data
                    print(f"✅ Loaded species index from: {files[0]}")
                    break
            except:
                continue

    # If no saved file found, create from class names if available
    if species_mapping is None:
        print("🔄 Creating species index from dataset...")
        try:
            # Try to load dataset info
            dataset_info_path = 'dataset_info.json'
            if not os.path.exists(dataset_info_path):
                dataset_info_path = '/content/ViT-FishID/dataset_info.json'

            if os.path.exists(dataset_info_path):
                with open(dataset_info_path, 'r') as f:
                    dataset_info = json.load(f)

                # Get class names from data loader
                from data import create_dataloaders
                parent_data_dir = os.path.dirname(dataset_info['labeled_path'])
                _, _, _, class_names = create_dataloaders(
                    data_dir=parent_data_dir, batch_size=32, seed=42
                )

                # Create mapping
                species_mapping = {
                    'full_names': class_names,
                    'shortened_names': {},
                    'species_to_index': {},
                    'index_to_species': {}
                }

                for idx, full_name in enumerate(class_names):
                    parts = full_name.split('_')
                    if len(parts) >= 3:
                        family = parts[0]
                        genus = parts[1]
                        species = parts[2]
                        shortened = f"{family[0]}{genus[0]}{species[:3]}"
                    else:
                        shortened = full_name[:6]

                    short_name = f"{idx:02d}_{shortened}"

                    species_mapping['shortened_names'][str(idx)] = short_name
                    species_mapping['species_to_index'][full_name] = idx
                    species_mapping['index_to_species'][str(idx)] = full_name

                print("✅ Species index created from dataset")
            else:
                print("❌ No dataset info found")
                return

        except Exception as e:
            print(f"❌ Error creating species index: {e}")
            return

    # Print the reference table
    print(f"\n🏷️  COMPLETE SPECIES INDEX REFERENCE")
    print("=" * 80)
    print("Index | Short Code    | Full Species Name")
    print("-" * 80)

    if 'full_names' in species_mapping:
        full_names = species_mapping['full_names']
        shortened_names = species_mapping.get('shortened_names', {})

        for idx, full_name in enumerate(full_names):
            short_name = shortened_names.get(str(idx), f"{idx:02d}_???")
            print(f"{idx:02d}   | {short_name:12s} | {full_name}")

    elif 'index_to_species' in species_mapping:
        index_to_species = species_mapping['index_to_species']
        shortened_names = species_mapping.get('shortened_names', {})

        for idx_str in sorted(index_to_species.keys(), key=int):
            idx = int(idx_str)
            full_name = index_to_species[idx_str]
            short_name = shortened_names.get(idx_str, f"{idx:02d}_???")
            print(f"{idx:02d}   | {short_name:12s} | {full_name}")

    print("=" * 80)
    print(f"Total Species: {len(species_mapping.get('full_names', species_mapping.get('index_to_species', {})))}")

    # Print family breakdown
    if 'full_names' in species_mapping or 'index_to_species' in species_mapping:
        full_names = species_mapping.get('full_names', list(species_mapping.get('index_to_species', {}).values()))

        families = {}
        for name in full_names:
            family = name.split('_')[0]
            families[family] = families.get(family, 0) + 1

        print(f"\n🐟 FAMILY BREAKDOWN:")
        print("-" * 30)
        for family, count in sorted(families.items()):
            print(f"{family:20s}: {count:2d} species")

    # Usage examples
    print(f"\n💡 USAGE IN CODE:")
    print("-" * 30)
    print("# In your graphs and analyses, use:")
    print("species_indices = list(range(num_classes))")
    print("species_labels = [f'{i:02d}' for i in species_indices]")
    print("")
    print("# For matplotlib/seaborn:")
    print("plt.xticks(species_indices, species_labels)")
    print("sns.heatmap(cm, xticklabels=species_labels, yticklabels=species_labels)")

    return species_mapping

# Generate the species index reference
print("📋 GENERATING SPECIES INDEX REFERENCE...")
species_data = print_species_index_reference()

# Save as a standalone reference file
if species_data:
    try:
        reference_file = "species_index_reference.json"
        with open(reference_file, 'w') as f:
            json.dump(species_data, f, indent=2)
        print(f"\n💾 Species reference saved to: {reference_file}")

        # Also save as a simple CSV for easy Excel import
        csv_file = "species_index_reference.csv"
        with open(csv_file, 'w') as f:
            f.write("Index,Short_Code,Full_Species_Name\n")

            if 'full_names' in species_data:
                for idx, full_name in enumerate(species_data['full_names']):
                    short_name = species_data.get('shortened_names', {}).get(str(idx), f"{idx:02d}_???")
                    f.write(f"{idx:02d},{short_name},{full_name}\n")

        print(f"📊 CSV reference saved to: {csv_file}")

    except Exception as e:
        print(f"⚠️ Could not save reference files: {e}")

print(f"\n✅ Species index reference complete!")