# 🚀 ViT-FishID: MAE Pretraining + EMA Student-Teacher Training

**ADVANCED FISH CLASSIFICATION WITH MASKED AUTOENCODERS & SEMI-SUPERVISED LEARNING**

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_MAE_EMA_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🎯 Training Pipeline Overview

This notebook implements a comprehensive two-stage training approach:

**Stage 1: Masked Autoencoder (MAE) Pretraining** 🎭
- Self-supervised pretraining on unlabeled fish images
- Learns robust visual representations by reconstructing masked patches
- Uses 75% masking ratio for strong representation learning
- Expected training time: 2-3 hours
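
To make the 75% masking ratio concrete, the sketch below counts how many patches the encoder would actually see. It assumes a 224x224 input split into 16x16 patches; `visible_patches` is a hypothetical helper written for illustration, not part of the repository.

```python
# Sketch: patch counts at a given masking ratio.
# Assumes square images; 224 and 16 are assumed defaults.

def visible_patches(img_size: int = 224, patch_size: int = 16,
                    mask_ratio: float = 0.75) -> tuple[int, int]:
    """Return (total_patches, kept_patches) for a square image."""
    per_side = img_size // patch_size
    total = per_side * per_side
    kept = int(total * (1 - mask_ratio))
    return total, kept

total, kept = visible_patches()
print(f"{total} patches, {kept} visible, {total - kept} masked")
```

Under these assumptions the encoder processes only a quarter of the patches per image, which is where most of the compute saving in MAE pretraining comes from.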

**Stage 2: EMA Student-Teacher Semi-Supervised Learning** 🎓
- Fine-tunes MAE-pretrained backbone for fish classification
- Combines labeled supervision with unlabeled consistency learning
- Uses exponential moving average teacher for pseudo-labeling
- Expected training time: 4-6 hours
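
The exponential moving average teacher can be illustrated with a minimal sketch. Weights are plain floats in a dict here so the update rule is easy to follow; the `ema_update` name and the decay values are assumptions for illustration, and the notebook's real trainer may differ.

```python
# Sketch: EMA teacher update, teacher <- decay * teacher + (1 - decay) * student.
# Plain dicts stand in for model parameters (an assumption for readability).

def ema_update(teacher, student, decay=0.999):
    """Blend student weights into the teacher in place and return the teacher."""
    for name, student_value in student.items():
        teacher[name] = decay * teacher[name] + (1.0 - decay) * student_value
    return teacher

teacher = {"w": 1.0}
student = {"w": 0.0}
for _ in range(3):
    ema_update(teacher, student, decay=0.9)
print(teacher["w"])
```

A decay close to 1 makes the teacher change slowly, so its pseudo-labels are less noisy than the student's own predictions.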

## 📊 Expected Performance Improvements

- **Without MAE**: ~75-80% accuracy after 100 epochs
- **With MAE + EMA**: ~85-92% accuracy after 100 epochs
- **Data efficiency**: Better performance with limited labeled data
- **Generalization**: Improved robustness to unseen fish species

## 🛠️ Requirements

- **GPU**: Colab Pro recommended (T4/V100/A100)
- **Memory**: ~12-16GB GPU memory
- **Runtime**: 6-9 hours total training time
- **Data**: Fish cutouts dataset with labeled/unlabeled images

## 🔧 Section 1: Environment Setup and GPU Check

Setting up the optimal environment for MAE pretraining and EMA training.
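
A GPU runtime can be detected without importing PyTorch by looking for the `nvidia-smi` binary and checking its exit code. The sketch below uses only the standard library; `gpu_runtime_detected` is a hypothetical helper name, and the 10-second timeout is an assumed value.

```python
import shutil
import subprocess

def gpu_runtime_detected(timeout=10):
    """Return True if nvidia-smi exists and exits successfully."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        result = subprocess.run([exe], capture_output=True, timeout=timeout)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0

has_gpu = gpu_runtime_detected()
print("GPU runtime detected" if has_gpu else "No GPU runtime detected")
```

Checking the exit code matters because a shell call made with `!` in a notebook does not raise a Python exception when the command fails.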

In [None]:
# Basic environment setup and GPU check (lightweight)
import sys
import os
import platform

print("🔍 BASIC SYSTEM INFORMATION")
print("="*50)

# System info (no heavy imports)
print(f"Python version: {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"Architecture: {platform.machine()}")

# Check if we're in Colab
try:
    import google.colab
    IN_COLAB = True
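    print("✅ Running in Google Colab")
    
    # Check Colab GPU status. A `!nvidia-smi` shell call never raises,
    # so look the binary up and test its exit code instead.
    import shutil
    import subprocess
    smi = shutil.which("nvidia-smi")
    if smi and subprocess.run([smi], capture_output=True).returncode == 0:
        print("✅ nvidia-smi available - GPU runtime detected")
    else:
        print("⚠️  nvidia-smi not available - may be CPU runtime")
        print("   Please enable GPU: Runtime → Change runtime type → GPU")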
    
except ImportError:
    IN_COLAB = False
    print("ℹ️  Not running in Google Colab")

print("\n🎯 Training Pipeline Overview:")
print("  - Stage 1: MAE Pretraining (Self-supervised)")
print("  - Stage 2: EMA Student-Teacher (Semi-supervised)")
print("  - Expected total time: 6-9 hours")
print("  - Memory requirements: 8-12GB GPU (optimized)")

print("\n⚠️  IMPORTANT SETUP NOTES:")
print("1. Ensure GPU runtime is enabled")
print("2. Upload fish_cutouts.zip to Google Drive root")
print("3. Allow kernel restart in Section 2 (this fixes cuDNN issues)")
print("4. Section 1 should run in <30 seconds")

print("\n✅ Basic environment check complete!")
print("🚀 Proceed to Section 2 to mount Drive and install dependencies")

## 🚨 TROUBLESHOOTING: If Section 1 is Running Too Long

**If the previous cell has been executing for >2 minutes:**

1. **INTERRUPT THE CELL**: Click the ⏹️ **Stop** button in Colab
2. **RESTART RUNTIME**: Runtime → Restart Runtime
3. **RE-RUN**: Execute the cell again

**Common causes of long execution:**
- Missing GPU runtime (switches to slow CPU mode)
- Automatic package installation in background
- Import conflicts from previous runs

**Expected behavior:**
- Section 1 should complete in **<30 seconds**
- Should show "✅ Basic environment check complete!"
- No heavy library imports (PyTorch comes later)

**Next steps after Section 1:**
1. Section 2: Mount Google Drive (quick)
2. Section 2: Install Dependencies (will restart kernel - normal!)
3. Section 2: Verify Installation (checks PyTorch/CUDA)
4. Then proceed with data setup and MAE training

## 📁 Section 2: Mount Google Drive and Install Dependencies

Setting up data access and installing packages for MAE and EMA training.
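
Because the install step pins exact versions, it can help to compare what is actually installed against the pins. The sketch below does this with `importlib.metadata` from the standard library; `check_pins` is a hypothetical helper, and `PINS` lists only two of the pinned packages as an example.

```python
from importlib import metadata

# Example subset of the version pins used by the install cell.
PINS = {"timm": "0.9.7", "transformers": "4.33.0"}

def check_pins(pins):
    """Map each package to (wanted, installed); installed is None if absent."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed)
    return report

for name, (wanted, installed) in check_pins(PINS).items():
    status = "ok" if installed == wanted else "MISMATCH"
    print(f"{name}: wanted {wanted}, found {installed} [{status}]")
```

This reads package metadata only, so it stays fast and does not import the heavy libraries themselves.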

In [None]:
# Mount Google Drive for data access
from google.colab import drive
import os

print("📁 MOUNTING GOOGLE DRIVE")
print("="*50)

# Mount Google Drive
drive.mount('/content/drive')

# Verify mount and show available space
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    # Get drive info
    statvfs = os.statvfs(drive_path)
    free_space = statvfs.f_frsize * statvfs.f_bavail / (1024**3)  # GB
    
    print("✅ Google Drive mounted successfully!")
    print(f"💾 Available space: {free_space:.1f} GB")
    
    # List some contents to verify (read the directory only once)
    all_items = os.listdir(drive_path)
    items = all_items[:10]
    print("\n📂 Drive contents (first 10 items):")
    for item in items:
        print(f"  - {item}")
    
    if len(all_items) > 10:
        print(f"  ... and {len(all_items) - 10} more items")
    
    # Check for required dataset
    dataset_path = '/content/drive/MyDrive/fish_cutouts.zip'
    if os.path.exists(dataset_path):
        dataset_size = os.path.getsize(dataset_path) / (1024**2)  # MB
        print(f"\n🐟 Found fish dataset: {dataset_size:.1f} MB")
    else:
        print(f"\n⚠️  Fish dataset not found at: {dataset_path}")
        print("   Please ensure fish_cutouts.zip is uploaded to Google Drive root")
        
else:
    print("❌ Failed to mount Google Drive")
    print("   Please check your Google account permissions")

print("\n💡 Ready for data setup and model training!")

In [None]:
# Install comprehensive dependencies for MAE + EMA training
print("📦 INSTALLING ADVANCED DEPENDENCIES")
print("="*50)

# CRITICAL: Fix cuDNN compatibility issues first
print("🔧 Fixing cuDNN compatibility...")

# Uninstall existing PyTorch to avoid conflicts
!pip uninstall -y torch torchvision torchaudio

# Install compatible PyTorch version for Colab
print("🔧 Installing compatible PyTorch...")
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118

# Core ML libraries
print("🔧 Installing core ML libraries...")
!pip install -q timm==0.9.7  # Specific version to avoid conflicts
!pip install -q transformers==4.33.0

# Vision and augmentation
print("🖼️  Installing vision libraries...")
!pip install -q albumentations==1.3.1
!pip install -q opencv-python-headless==4.8.0.76
!pip install -q pillow==9.5.0

# Training utilities
print("⚙️ Installing training utilities...")
!pip install -q wandb==0.15.8
!pip install -q scikit-learn==1.3.0
!pip install -q matplotlib==3.7.2
!pip install -q seaborn==0.12.2
!pip install -q tqdm==4.66.1

# Additional utilities for MAE (minimal versions)
print("🎭 Installing MAE-specific utilities...")
!pip install -q accelerate==0.21.0
!pip install -q datasets==2.14.4

print("✅ All dependencies installed with version pinning!")

# Restart Python kernel to ensure clean imports
print("\n🔄 RESTARTING PYTHON KERNEL")
print("="*50)
print("⚠️  This cell kills the Python process, so Colab will restart the kernel.")
print("   This is NORMAL and fixes cuDNN compatibility issues.")
print("   Simply continue with the next cell.")

import os
os.kill(os.getpid(), 9)  # Force restart to clear cuDNN conflicts

In [None]:
# Verify installations and setup PyTorch environment
print("📋 VERIFYING PYTORCH INSTALLATION & GPU SETUP")
print("="*50)

# Import PyTorch and related libraries (now that they're installed)
import torch
import torchvision
import timm
import transformers
import albumentations
import cv2
import sklearn
import wandb
import gc

print("✅ Package Versions:")
print(f"  torch: {torch.__version__}")
print(f"  torchvision: {torchvision.__version__}")
print(f"  timm: {timm.__version__}")
print(f"  transformers: {transformers.__version__}")
print(f"  albumentations: {albumentations.__version__}")
print(f"  opencv: {cv2.__version__}")
print(f"  sklearn: {sklearn.__version__}")

# Comprehensive CUDA and GPU verification
print("\n🔍 CUDA & GPU VERIFICATION")
print("="*30)
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    
    # Test cuDNN compatibility (critical fix)
    try:
        device = torch.device('cuda')
        test_tensor = torch.randn(2, 3, 224, 224).to(device)
        
        # Configure cuDNN for stability
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False  # Stable for variable inputs
        torch.backends.cudnn.deterministic = True  # Reproducible results
        
        # Run a convolution so the test actually exercises cuDNN
        # (an elementwise multiply alone would not touch it)
        test_kernel = torch.randn(8, 3, 3, 3, device=device)
        result = torch.nn.functional.conv2d(test_tensor, test_kernel)
        result = result.cpu()
        
        print("✅ cuDNN compatibility test PASSED!")
        print("✅ GPU operations working correctly")
        
        # Clean up test tensors
        del test_tensor, test_kernel, result
        torch.cuda.empty_cache()
        gc.collect()
        
    except Exception as e:
        print(f"❌ cuDNN compatibility issue: {e}")
        print("🔧 This indicates PyTorch/CUDA version mismatch")
        print("   Try restarting runtime and running from the beginning")
    
    # Set global device and optimize
    DEVICE = torch.device('cuda')
    
    # Memory optimization for long training
    torch.cuda.empty_cache()
    gc.collect()
    
    print(f"\n🎯 Using device: {DEVICE}")
    print("💡 Memory optimization enabled for long training sessions")
    
else:
    DEVICE = torch.device('cpu')
    print("\n⚠️  No GPU detected!")
    print("🔧 Enable GPU: Runtime → Change runtime type → Hardware accelerator → GPU")
    print("⚠️  MAE pretraining will be extremely slow on CPU")

# Set random seeds for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("\n🚀 PyTorch environment ready for MAE + EMA training!")
print("💡 Key success indicators:")
print("  ✅ cuDNN compatibility test passed")
print("  ✅ GPU operations working")
print("  ✅ No import errors")

# Store device for use in other cells
globals()['DEVICE'] = DEVICE

## 🔄 Section 3: Clone Repository and Setup Data

Cloning ViT-FishID repository and preparing fish dataset for MAE pretraining and classification.
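
The dataset is expected to contain a `labeled/` folder with one subfolder per species and a flat `unlabeled/` folder. The sketch below summarises such a layout with `pathlib`; `summarize_dataset` is a hypothetical helper, and the demo builds a tiny throwaway dataset in a temporary directory so it runs without the real data.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def summarize_dataset(root):
    """Count labeled images per species folder and unlabeled images."""
    root = Path(root)
    per_species = {}
    for species_dir in sorted((root / "labeled").iterdir()):
        if species_dir.is_dir() and not species_dir.name.startswith("."):
            per_species[species_dir.name] = sum(
                1 for f in species_dir.iterdir()
                if f.suffix.lower() in IMAGE_SUFFIXES
            )
    unlabeled = sum(
        1 for f in (root / "unlabeled").iterdir()
        if f.suffix.lower() in IMAGE_SUFFIXES
    )
    return per_species, unlabeled

# Demo on a fake dataset (file names are made up for illustration).
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "labeled" / "species_a").mkdir(parents=True)
    (root / "unlabeled").mkdir()
    (root / "labeled" / "species_a" / "0001.jpg").touch()
    (root / "unlabeled" / "0002.PNG").touch()
    per_species, unlabeled = summarize_dataset(root)

print(per_species, unlabeled)
```

The number of keys in `per_species` is the class count that the classification head would need.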

In [None]:
# Clone ViT-FishID repository and prepare codebase
import os

print("📥 CLONING ViT-FishID REPOSITORY")
print("="*50)

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID
    print("🗑️  Removed existing repository")

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# Verify repository structure
print("\n📂 Repository structure:")
!ls -la

# Check for key files
required_files = ['model.py', 'trainer.py', 'data.py', 'train.py']
missing_files = []

for file in required_files:
    if os.path.exists(file):
        print(f"✅ Found: {file}")
    else:
        print(f"❌ Missing: {file}")
        missing_files.append(file)

if missing_files:
    print(f"\n⚠️  Missing files: {missing_files}")
    print("   These will be created as part of the MAE implementation")
else:
    print("\n✅ All required files found!")

# Set up Python path for imports
import sys
if '/content/ViT-FishID' not in sys.path:
    sys.path.append('/content/ViT-FishID')
    print("🔧 Added repository to Python path")

print("\n🚀 Repository ready for MAE and EMA implementation!")

In [None]:
# Extract and setup fish dataset for MAE + EMA training
import zipfile
import shutil
import os
from pathlib import Path

print("🐟 EXTRACTING FISH DATASET")
print("="*50)

# Configuration
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'
DATA_DIR = '/content/fish_cutouts'

print(f"🎯 ZIP location: {ZIP_FILE_PATH}")
print(f"🎯 Target directory: {DATA_DIR}")

# Check if data already exists
if os.path.exists(DATA_DIR) and os.path.exists(os.path.join(DATA_DIR, 'labeled')):
    print("✅ Data already available locally!")
    
    # Quick validation
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
    
    if os.path.exists(labeled_dir):
        labeled_species = [d for d in os.listdir(labeled_dir) 
                         if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
        labeled_count = sum([len([f for f in os.listdir(os.path.join(labeled_dir, species))
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                           for species in labeled_species])
        print(f"🐟 Labeled: {len(labeled_species)} species, {labeled_count} images")
    
    if os.path.exists(unlabeled_dir):
        unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"📊 Unlabeled: {unlabeled_count} images (for MAE pretraining)")
    
    print("✅ Data validation passed!")

else:
    print("📥 Extracting dataset from Google Drive...")
    
    if not os.path.exists(ZIP_FILE_PATH):
        print(f"❌ ZIP file not found: {ZIP_FILE_PATH}")
        print("🔧 Please upload fish_cutouts.zip to Google Drive root")
        raise FileNotFoundError("Dataset ZIP file not found")
    
    print(f"📏 ZIP size: {os.path.getsize(ZIP_FILE_PATH) / (1024**2):.1f} MB")
    
    # Extract with progress
    temp_dir = '/content/temp_extract'
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    
    try:
        print("📦 Extracting ZIP file...")
        with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
        
        # Organize extracted files
        extracted_items = os.listdir(temp_dir)
        print(f"📁 Extracted items: {extracted_items}")
        
        # Find and move labeled/unlabeled directories
        labeled_source = None
        unlabeled_source = None
        
        for item in extracted_items:
            item_path = os.path.join(temp_dir, item)
            if item == 'labeled' and os.path.isdir(item_path):
                labeled_source = item_path
            elif item == 'unlabeled' and os.path.isdir(item_path):
                unlabeled_source = item_path
        
        if labeled_source and unlabeled_source:
            # Create target and move directories
            os.makedirs(DATA_DIR, exist_ok=True)
            shutil.move(labeled_source, os.path.join(DATA_DIR, 'labeled'))
            shutil.move(unlabeled_source, os.path.join(DATA_DIR, 'unlabeled'))
            
            print(f"✅ Data organized at: {DATA_DIR}")
            
            # Verification
            labeled_dir = os.path.join(DATA_DIR, 'labeled')
            unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
            
            species_count = len([d for d in os.listdir(labeled_dir)
                               if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
            unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
            
            print(f"🐟 Labeled: {species_count} species")
            print(f"📊 Unlabeled: {unlabeled_count} images")
            
        else:
            print("❌ Could not find labeled/unlabeled directories")
            raise FileNotFoundError("Invalid dataset structure")
        
        # Cleanup
        shutil.rmtree(temp_dir)
        
    except Exception as e:
        print(f"❌ Extraction failed: {e}")
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        raise

# Final verification and stats
print("\n📊 DATASET SUMMARY")
print("="*50)

labeled_dir = os.path.join(DATA_DIR, 'labeled')
unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')

if os.path.exists(labeled_dir):
    species = [d for d in os.listdir(labeled_dir) 
              if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
    total_labeled = sum([len([f for f in os.listdir(os.path.join(labeled_dir, s))
                            if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                        for s in species])
    print(f"🎯 Labeled data: {len(species)} species, {total_labeled} images")
    
    # Show species distribution
    print("\n🐟 Species distribution:")
    # Use a separate loop variable so the `species` list is not overwritten;
    # NUM_CLASSES below relies on len(species) being the number of species.
    for species_name in species[:10]:  # Show first 10 species
        species_path = os.path.join(labeled_dir, species_name)
        count = len([f for f in os.listdir(species_path)
                    if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"  - {species_name}: {count} images")
    
    if len(species) > 10:
        print(f"  ... and {len(species) - 10} more species")

if os.path.exists(unlabeled_dir):
    unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                         if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    print(f"\n🎭 Unlabeled data: {unlabeled_count} images (for MAE pretraining)")

print("\n✅ Dataset ready for MAE pretraining and EMA fine-tuning!")

# Store global variables for later use
LABELED_DIR = labeled_dir
UNLABELED_DIR = unlabeled_dir
NUM_CLASSES = len(species) if 'species' in locals() else 37

## 🎭 Section 4: Implement Masked Autoencoder (MAE) Components

Creating the complete MAE architecture with ViT encoder, lightweight decoder, and masking strategy.
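
The masking strategy can be previewed in plain Python before reading the tensor version: draw one random score per patch, sort the patches by score, keep the first quarter, and record how to undo the shuffle. This is a minimal sketch of that idea; `random_masking` here is a standalone illustration with a fixed seed, not the notebook's batched method.

```python
import random

def random_masking(num_patches, mask_ratio=0.75, seed=0):
    """Return (ids_keep, mask, ids_restore) for one image."""
    rng = random.Random(seed)
    len_keep = int(num_patches * (1 - mask_ratio))
    noise = [rng.random() for _ in range(num_patches)]
    # Patch indices sorted by noise; the first len_keep are kept.
    ids_shuffle = sorted(range(num_patches), key=noise.__getitem__)
    # ids_restore[i] is the position of patch i in the shuffled order.
    ids_restore = [0] * num_patches
    for rank, patch in enumerate(ids_shuffle):
        ids_restore[patch] = rank
    ids_keep = ids_shuffle[:len_keep]
    # 0 = visible, 1 = masked, expressed in the original patch order.
    mask = [0 if ids_restore[i] < len_keep else 1 for i in range(num_patches)]
    return ids_keep, mask, ids_restore

ids_keep, mask, ids_restore = random_masking(16)
print("kept patches:", sorted(ids_keep))
print("mask:", mask)
```

The decoder needs `ids_restore` to put mask tokens and encoded patches back into their original positions before reconstruction.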

In [None]:
# Implement Masked Autoencoder (MAE) for Fish Images
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple

print("🎭 IMPLEMENTING MASKED AUTOENCODER (MAE)")
print("="*50)

# CRITICAL: Check GPU memory and compatibility first
print("🔍 CHECKING GPU COMPATIBILITY")
print("="*30)

if torch.cuda.is_available():
    # Check CUDA and CUDNN compatibility
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch CUDA Available: {torch.cuda.is_available()}")
    
    # Check GPU memory
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    allocated_memory = torch.cuda.memory_allocated(0) / (1024**3)
    cached_memory = torch.cuda.memory_reserved(0) / (1024**3)
    
    print(f"GPU Total Memory: {gpu_memory:.1f} GB")
    print(f"Allocated Memory: {allocated_memory:.1f} GB")
    print(f"Cached Memory: {cached_memory:.1f} GB")
    print(f"Available Memory: {gpu_memory - cached_memory:.1f} GB")
    
    # Clear any existing GPU memory
    torch.cuda.empty_cache()
    print("✅ GPU cache cleared")
    
    # Set memory management
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False  # More stable for variable input sizes
    torch.backends.cudnn.deterministic = True  # For reproducible results
    
    print("✅ cuDNN configured for stability")
    
    # Check if we have enough memory for MAE (needs ~8GB minimum)
    if gpu_memory < 8.0:
        print("⚠️  WARNING: GPU has less than 8GB memory")
        print("   Consider using smaller model or batch sizes")
        USE_LIGHTWEIGHT_MAE = True
    else:
        USE_LIGHTWEIGHT_MAE = False
        print("✅ Sufficient GPU memory for full MAE model")
        
else:
    print("❌ No GPU available - MAE requires GPU for reasonable training time")
    raise RuntimeError("GPU required for MAE training")

print("\n🏗️  BUILDING MAE COMPONENTS")
print("="*30)

class PatchEmbedding(nn.Module):
    """Image to Patch Embedding for MAE"""
    
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                             kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

# Lightweight MultiHead Attention (to avoid timm dependency issues)
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Lightweight Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class MAEEncoder(nn.Module):
    """Memory-efficient MAE Encoder"""
    
    def __init__(self, 
                 img_size=224,
                 patch_size=16, 
                 embed_dim=512,  # Reduced from 768
                 depth=8,        # Reduced from 12
                 num_heads=8,    # Reduced from 12
                 mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()
        
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        
        # Use custom lightweight blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.pos_embed, std=.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    def random_masking(self, x, mask_ratio=0.75):
        """Perform random masking by per-sample shuffling"""
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))
        
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        # Keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        
        # Generate binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward(self, x, mask_ratio=0.75):
        # Patch embedding
        x = self.patch_embed(x)
        
        # Add pos embed without cls token
        x = x + self.pos_embed[:, 1:, :]
        
        # Masking: sequence length N -> N * (1 - mask_ratio) visible patches
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        # Append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        return x, mask, ids_restore

class MAEDecoder(nn.Module):
    """Lightweight MAE Decoder"""
    
    def __init__(self, 
                 num_patches=196,
                 encoder_embed_dim=512,
                 decoder_embed_dim=256,  # Reduced from 512
                 decoder_depth=4,        # Reduced from 8
                 decoder_num_heads=8,    # Reduced from 16
                 mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()
        
        self.num_patches = num_patches
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))
        
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, dropout)
            for _ in range(decoder_depth)
        ])
        
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, 16 * 16 * 3, bias=True)  # patch_size^2 * 3
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed, std=.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x, ids_restore):
        # Embed tokens
        x = self.decoder_embed(x)
        
        # Append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        
        # Add pos embed
        x = x + self.decoder_pos_embed
        
        # Apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        
        # Predictor projection
        x = self.decoder_pred(x)
        
        # Remove cls token
        x = x[:, 1:, :]
        
        return x

# Test with minimal memory allocation first
print("🧪 Testing basic components...")

try:
    # Test patch embedding first
    patch_embed = PatchEmbedding(224, 16, 3, 512).to(DEVICE)
    test_input = torch.randn(1, 3, 224, 224).to(DEVICE)
    with torch.no_grad():
        patches = patch_embed(test_input)
    print(f"✅ Patch embedding works: {patches.shape}")
    
    # Clear memory
    del patch_embed, test_input, patches
    torch.cuda.empty_cache()
    
    print("✅ Basic components tested successfully")
    
except Exception as e:
    print(f"❌ Component test failed: {e}")
    print("🔧 Try reducing batch size or model dimensions")
    raise

print("\n🎭 MAE components ready for model creation!")
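
MAE pretraining scores the reconstruction only on the patches that were masked, so visible patches contribute nothing to the loss. The sketch below shows that averaging on nested Python lists; `masked_mse` and the toy values are assumptions for illustration rather than the notebook's tensor implementation.

```python
def masked_mse(pred, target, mask):
    """Mean squared error averaged over masked patches only (mask == 1)."""
    total = 0.0
    for pred_patch, target_patch, m in zip(pred, target, mask):
        per_patch = sum(
            (p - t) ** 2 for p, t in zip(pred_patch, target_patch)
        ) / len(pred_patch)
        total += per_patch * m
    return total / sum(mask)

# Three patches with two pixel values each; patch 0 is visible (mask 0).
pred = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
target = [[5.0, 5.0], [1.0, 3.0], [2.0, 2.0]]
mask = [0, 1, 1]
loss = masked_mse(pred, target, mask)
print(loss)
```

The large error on patch 0 is ignored because that patch was visible to the encoder; only the two masked patches set the loss.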

In [None]:
# Create the complete MAE model with memory management
print("🏗️  CREATING COMPLETE MAE MODEL")
print("="*50)

class MaskedAutoEncoder(nn.Module):
    """Complete Masked Autoencoder for Fish Images - Memory Optimized"""
    
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 encoder_embed_dim=512,  # Reduced for memory efficiency
                 encoder_depth=8,        # Reduced from 12
                 encoder_num_heads=8,    # Reduced from 12
                 decoder_embed_dim=256,  # Reduced from 512
                 decoder_depth=4,        # Reduced from 8
                 decoder_num_heads=8,    # Reduced from 16
                 mlp_ratio=4.0,
                 norm_pix_loss=False):
        super().__init__()
        
        self.patch_size = patch_size
        self.norm_pix_loss = norm_pix_loss
        
        print(f"🔧 Encoder: {encoder_embed_dim}d, {encoder_depth} layers, {encoder_num_heads} heads")
        print(f"🔧 Decoder: {decoder_embed_dim}d, {decoder_depth} layers, {decoder_num_heads} heads")
        
        # MAE encoder
        self.encoder = MAEEncoder(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio
        )
        
        # MAE decoder
        num_patches = (img_size // patch_size) ** 2
        self.decoder = MAEDecoder(
            num_patches=num_patches,
            encoder_embed_dim=encoder_embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio
        )
        
    def patchify(self, imgs):
        """Convert images to patches"""
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x
    
    def unpatchify(self, x):
        """Convert patches back to images"""
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs
    
    def forward_loss(self, imgs, pred, mask):
        """Compute reconstruction loss"""
        target = self.patchify(imgs)
        
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    
    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.encoder(imgs, mask_ratio)
        pred = self.decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask, latent

# Create MAE model with progressive memory checking
print("üöÄ Creating MAE model...")

try:
    # Check memory before model creation
    if torch.cuda.is_available():
        memory_before = torch.cuda.memory_allocated(0) / (1024**3)
        print(f"Memory before model: {memory_before:.2f} GB")
    
    # Create model with reduced dimensions for Colab compatibility
    mae_model = MaskedAutoEncoder(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=512,  # Reduced from 768
        encoder_depth=8,        # Reduced from 12
        encoder_num_heads=8,    # Reduced from 12
        decoder_embed_dim=256,  # Reduced from 512
        decoder_depth=4,        # Reduced from 8
        decoder_num_heads=8,    # Reduced from 16
        mlp_ratio=4.0,
        norm_pix_loss=True
    )
    
    print("‚úÖ Model created successfully in CPU memory")
    
    # Move the whole model to the target device in one call
    print("🔄 Moving model to GPU...")
    mae_model = mae_model.to(DEVICE)
    
    # Check memory after model creation
    if torch.cuda.is_available():
        memory_after = torch.cuda.memory_allocated(0) / (1024**3)
        print(f"Memory after model: {memory_after:.2f} GB")
        print(f"Model memory usage: {memory_after - memory_before:.2f} GB")
    
    print("‚úÖ Model successfully moved to GPU!")
    
    # Count parameters
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    total_params = count_parameters(mae_model)
    encoder_params = count_parameters(mae_model.encoder)
    decoder_params = count_parameters(mae_model.decoder)

    print(f"\n📊 MODEL STATISTICS")
    print(f"📊 Total parameters: {total_params:,}")
    print(f"📊 Encoder parameters: {encoder_params:,}")
    print(f"📊 Decoder parameters: {decoder_params:,}")
    print(f"🎯 Model ready for pretraining")

except torch.cuda.OutOfMemoryError as e:
    print(f"❌ GPU Out of Memory: {e}")
    print("🔧 Solutions:")
    print("   1. Restart runtime and run again")
    print("   2. Use smaller batch sizes (try batch_size=16)")
    print("   3. Further reduce model dimensions")
    print("   4. Use Colab Pro for more GPU memory")
    raise

except Exception as e:
    print(f"‚ùå Model creation failed: {e}")
    print("üîß Check CUDA installation and GPU compatibility")
    raise

# Test model with a small batch
print("\nüß™ TESTING MAE MODEL")
print("="*30)

try:
    # Test with minimal batch size
    batch_size = 2
    test_input = torch.randn(batch_size, 3, 224, 224).to(DEVICE)
    
    print(f"üß™ Testing with batch size {batch_size}")
    
    with torch.no_grad():
        loss, pred, mask, latent = mae_model(test_input, mask_ratio=0.75)
        
    print(f"✅ Forward pass successful!")
    print(f"📊 Loss: {loss.item():.4f}")
    print(f"📊 Prediction shape: {pred.shape}")
    print(f"📊 Mask shape: {mask.shape}")
    print(f"📊 Latent shape: {latent.shape}")
    print(f"📊 Masked patches: {mask.sum(dim=1).float().mean().item():.1f} / {mask.shape[1]}")
    
    # Clean up test
    del test_input, loss, pred, mask, latent
    torch.cuda.empty_cache()
    
    print(f"\nüé≠ MAE MODEL READY FOR PRETRAINING!")
    
except Exception as e:
    print(f"‚ùå Model test failed: {e}")
    print("üîß Try reducing batch size or restarting runtime")
    raise

## ⚙️ Section 5: Configure MAE Pretraining Parameters

Setting up optimal hyperparameters for self-supervised MAE pretraining on fish images.
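
The warmup-plus-cosine learning-rate schedule used for pretraining can be sketched without PyTorch. This is a minimal illustration rather than the notebook's scheduler: the function name `lr_multiplier` is hypothetical, and the `+ 1` in the warmup branch is an assumption of this sketch so that the first epoch runs at a nonzero rate. The 5 warmup epochs, 50 total epochs and base rate of 1e-4 mirror the `MAE_CONFIG` values in this section.

```python
import math

def lr_multiplier(epoch, warmup_epochs, total_epochs):
    """Multiplier applied to the base learning rate at a 0-indexed epoch."""
    if epoch < warmup_epochs:
        # Linear warmup; the +1 keeps the first epoch from running at a zero rate
        return (epoch + 1) / warmup_epochs
    # Cosine decay from 1 toward 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))

base_lr = 1e-4
for epoch in (0, 4, 5, 27, 49):
    print(epoch, f"{base_lr * lr_multiplier(epoch, 5, 50):.2e}")
```

The multiplier reaches 1.0 at the end of warmup and then decays smoothly, so the largest updates happen early, once the optimizer statistics have settled.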

In [None]:
# Configure MAE Pretraining Parameters and Data Pipeline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import glob
import wandb
from datetime import datetime

print("‚öôÔ∏è CONFIGURING MAE PRETRAINING")
print("="*50)

# MAE Pretraining Configuration
MAE_CONFIG = {
    # Model settings
    'mask_ratio': 0.75,  # Aggressive masking for strong representation learning
    'img_size': 224,
    'patch_size': 16,
    
    # Training settings
    'epochs': 50,  # Sufficient for good representations on fish data
    'batch_size': 64,  # Optimized for GPU memory
    'learning_rate': 1e-4,  # Conservative LR for stable training
    'weight_decay': 0.05,
    'warmup_epochs': 5,
    
    # Optimization
    'beta1': 0.9,
    'beta2': 0.95,
    'clip_grad': 1.0,
    
    # Saving
    'save_frequency': 10,  # Save every 10 epochs
    'checkpoint_dir': '/content/drive/MyDrive/ViT-FishID/mae_checkpoints',
    
    # Logging
    'use_wandb': True,
    'wandb_project': 'ViT-FishID-MAE-Pretraining',
    'wandb_run_name': f'mae-pretrain-{datetime.now().strftime("%Y%m%d-%H%M%S")}',
    
    # Data
    'data_dir': UNLABELED_DIR,
    'num_workers': 4,
}

print("üìä MAE Configuration:")
for key, value in MAE_CONFIG.items():
    print(f"  {key}: {value}")

# Create checkpoint directory
os.makedirs(MAE_CONFIG['checkpoint_dir'], exist_ok=True)
print(f"\nüíæ Checkpoint directory: {MAE_CONFIG['checkpoint_dir']}")

# MAE Dataset for Unlabeled Images
class MAEDataset(Dataset):
    """Dataset for MAE pretraining using unlabeled fish images"""
    
    def __init__(self, data_dir, img_size=224):
        self.data_dir = data_dir
        self.img_size = img_size
        
        # Find all image files
        self.image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
            self.image_paths.extend(glob.glob(os.path.join(data_dir, ext)))
        
        print(f"üìä Found {len(self.image_paths)} unlabeled images for MAE pretraining")
        
        # MAE-specific transforms - minimal augmentation to preserve structure
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        
        try:
            # Load and process image
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image)
            return image
        except Exception as e:
            print(f"Warning: Could not load {img_path}: {e}")
            # Return an all-zero tensor as fallback (the ImageNet mean color after normalization, not black)
            return torch.zeros(3, self.img_size, self.img_size)

# Create MAE dataset and dataloader
print("\nüì¶ Creating MAE dataset...")
mae_dataset = MAEDataset(MAE_CONFIG['data_dir'], MAE_CONFIG['img_size'])

mae_dataloader = DataLoader(
    mae_dataset,
    batch_size=MAE_CONFIG['batch_size'],
    shuffle=True,
    num_workers=MAE_CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True
)

print(f"✅ MAE DataLoader created:")
print(f"  📊 Dataset size: {len(mae_dataset)} images")
print(f"  📊 Batch size: {MAE_CONFIG['batch_size']}")
print(f"  📊 Batches per epoch: {len(mae_dataloader)}")
print(f"  ⏱️  Estimated epoch time (assuming ~0.5 s per batch): {len(mae_dataloader) * 0.5:.1f}s")

# Setup MAE optimizer and scheduler
print("\nüîß Setting up MAE optimizer...")

mae_optimizer = optim.AdamW(
    mae_model.parameters(),
    lr=MAE_CONFIG['learning_rate'],
    betas=(MAE_CONFIG['beta1'], MAE_CONFIG['beta2']),
    weight_decay=MAE_CONFIG['weight_decay']
)

# Cosine annealing scheduler with warmup
import math

def cosine_scheduler(optimizer, warmup_epochs, total_epochs):
    """Cosine annealing scheduler with linear warmup"""
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Offset by one so the first epoch does not run with a zero learning rate
            return (epoch + 1) / warmup_epochs
        else:
            # max(1, ...) guards against division by zero when warmup covers all epochs
            progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
            return 0.5 * (1 + math.cos(math.pi * progress))
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

mae_scheduler = cosine_scheduler(
    mae_optimizer, 
    MAE_CONFIG['warmup_epochs'], 
    MAE_CONFIG['epochs']
)

print(f"‚úÖ Optimizer: AdamW with LR={MAE_CONFIG['learning_rate']}")
print(f"‚úÖ Scheduler: Cosine annealing with {MAE_CONFIG['warmup_epochs']} warmup epochs")

# Initialize Weights & Biases for MAE pretraining
if MAE_CONFIG['use_wandb']:
    print("\nüìà Initializing Weights & Biases for MAE pretraining...")
    try:
        wandb.init(
            project=MAE_CONFIG['wandb_project'],
            name=MAE_CONFIG['wandb_run_name'],
            config=MAE_CONFIG,
            tags=['mae', 'pretraining', 'fish', 'self-supervised']
        )
        print(f"‚úÖ W&B initialized: {wandb.run.url}")
    except Exception as e:
        print(f"‚ö†Ô∏è  W&B initialization failed: {e}")
        MAE_CONFIG['use_wandb'] = False

# Test data loading
print("\nüß™ Testing data pipeline...")
try:
    test_batch = next(iter(mae_dataloader))
    print(f"‚úÖ Data loading successful!")
    print(f"üìä Batch shape: {test_batch.shape}")
    print(f"üìä Batch dtype: {test_batch.dtype}")
    print(f"üìä Value range: [{test_batch.min():.3f}, {test_batch.max():.3f}]")
    
    # Test MAE forward pass
    with torch.no_grad():
        test_batch = test_batch.to(DEVICE)
        loss, pred, mask, latent = mae_model(test_batch, MAE_CONFIG['mask_ratio'])
        print(f"‚úÖ MAE forward pass successful!")
        print(f"üìä Reconstruction loss: {loss.item():.4f}")
        
except Exception as e:
    print(f"‚ùå Data pipeline test failed: {e}")
    raise

print(f"\n🎭 MAE pretraining configuration complete!")
print(f"🚀 Ready to start self-supervised pretraining on {len(mae_dataset)} fish images")
print(f"⏱️  Estimated total pretraining time (assuming ~0.5 s per batch): {MAE_CONFIG['epochs'] * len(mae_dataloader) * 0.5 / 3600:.1f} hours")

## 🚀 Section 6: Execute MAE Pretraining Phase

Running self-supervised pretraining to learn robust visual representations from unlabeled fish images.
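
As a rough picture of what each pretraining step optimizes, the sketch below hides 75% of the 196 patches at random and averages a per-patch loss over the hidden patches only, in the same way `forward_loss` divides by `mask.sum()`. The helper names `random_patch_mask` and `masked_mean` are hypothetical, the per-patch losses are dummy values, and the encoder's actual masking code may differ.

```python
import random

def random_patch_mask(num_patches, mask_ratio, rng):
    """Return a 0/1 list where 1 marks a patch hidden from the encoder."""
    num_keep = int(num_patches * (1 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)                      # random permutation of patch indices
    visible = set(order[:num_keep])         # the first num_keep indices stay visible
    return [0 if i in visible else 1 for i in range(num_patches)]

def masked_mean(per_patch_loss, mask):
    """Average the per-patch loss over masked patches only."""
    return sum(l * m for l, m in zip(per_patch_loss, mask)) / sum(mask)

num_patches = (224 // 16) ** 2              # 196 patches for 224x224 images, 16x16 patches
mask = random_patch_mask(num_patches, 0.75, random.Random(0))
per_patch_loss = [1.0] * num_patches        # dummy values, purely illustrative
print(sum(mask), "of", num_patches, "patches masked")
print(masked_mean(per_patch_loss, mask))
```

Because visible patches are excluded from the average, the model gets no credit for copying pixels it can already see.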

In [None]:
# Execute MAE Pretraining on Unlabeled Fish Images
import torch
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

print("üé≠ STARTING MAE PRETRAINING")
print("="*60)

def save_mae_checkpoint(model, optimizer, scheduler, epoch, loss, config, filename):
    """Save MAE checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config,
    }
    torch.save(checkpoint, filename)
    print(f"üíæ Saved checkpoint: {filename}")

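def visualize_mae_reconstruction(model, dataloader, device, epoch, num_samples=4):
    """Plot original vs. reconstructed images for one batch and return its loss"""
    model.eval()
    
    with torch.no_grad():
        # Get a batch of images
        images = next(iter(dataloader))[:num_samples].to(device)
        
        # Forward pass
        loss, pred, mask, latent = model(images, mask_ratio=MAE_CONFIG['mask_ratio'])
        
        # Undo the per-patch normalization applied to the targets in forward_loss
        target = model.patchify(images)
        if model.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            pred = pred * (var + 1.e-6)**.5 + mean
        
        # Keep visible patches from the input, fill masked patches with predictions
        m = mask.unsqueeze(-1).to(pred.dtype)  # [N, L, 1], 1 = masked
        recon = model.unpatchify(target * (1 - m) + pred * m)
        
        # Denormalize (ImageNet statistics) and convert to [N, H, W, C] for plotting
        img_mean = np.array([0.485, 0.456, 0.406])
        img_std = np.array([0.229, 0.224, 0.225])
        
        def to_numpy(x):
            x = x.cpu().numpy().transpose(0, 2, 3, 1)
            return np.clip(x * img_std + img_mean, 0, 1)
        
        images_np, recon_np = to_numpy(images), to_numpy(recon)
        
        # Top row: originals, bottom row: reconstructions
        n = images_np.shape[0]
        fig, axes = plt.subplots(2, n, figsize=(3 * n, 6), squeeze=False)
        for i in range(n):
            axes[0, i].imshow(images_np[i])
            axes[0, i].set_title('Original')
            axes[0, i].axis('off')
            axes[1, i].imshow(recon_np[i])
            axes[1, i].set_title('Reconstruction')
            axes[1, i].axis('off')
        plt.tight_layout()
        plt.show()
        
        print(f"📊 Reconstruction loss at epoch {epoch}: {loss.item():.4f}")
        
        if wandb.run and MAE_CONFIG['use_wandb']:
            try:
                # Log some metrics to wandb
                wandb.log({
                    'mae_reconstruction_loss': loss.item(),
                    'epoch': epoch
                })
            except Exception as e:
                print(f"⚠️  W&B logging failed: {e}")
    
    model.train()
    return loss.item()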

# Training loop
print(f"üé¨ Starting MAE pretraining for {MAE_CONFIG['epochs']} epochs...")
print(f"üìä Training on {len(mae_dataset)} unlabeled fish images")
print(f"‚è∞ Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

mae_model.train()
best_loss = float('inf')
training_losses = []

start_time = time.time()

for epoch in range(MAE_CONFIG['epochs']):
    epoch_start_time = time.time()
    epoch_losses = []
    
    # Progress bar for epoch
    pbar = tqdm(mae_dataloader, desc=f"MAE Epoch {epoch+1}/{MAE_CONFIG['epochs']}")
    
    for batch_idx, images in enumerate(pbar):
        images = images.to(DEVICE)
        
        # Zero gradients
        mae_optimizer.zero_grad()
        
        # Forward pass
        loss, pred, mask, latent = mae_model(images, MAE_CONFIG['mask_ratio'])
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        if MAE_CONFIG['clip_grad'] > 0:
            torch.nn.utils.clip_grad_norm_(mae_model.parameters(), MAE_CONFIG['clip_grad'])
        
        # Optimizer step
        mae_optimizer.step()
        
        # Record loss
        epoch_losses.append(loss.item())
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f"{loss.item():.4f}",
            'LR': f"{mae_optimizer.param_groups[0]['lr']:.6f}"
        })
        
        # Log to wandb
        if MAE_CONFIG['use_wandb'] and wandb.run:
            try:
                wandb.log({
                    'mae_batch_loss': loss.item(),
                    'mae_learning_rate': mae_optimizer.param_groups[0]['lr'],
                    'mae_step': epoch * len(mae_dataloader) + batch_idx
                })
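            except Exception:
                # Logging failures should not stop training; a bare except would also swallow KeyboardInterrupt
                pass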
    
    # Scheduler step
    mae_scheduler.step()
    
    # Epoch statistics
    epoch_loss = np.mean(epoch_losses)
    training_losses.append(epoch_loss)
    epoch_time = time.time() - epoch_start_time
    
    print(f"\n📊 Epoch {epoch+1} Summary:")
    print(f"  📉 Average Loss: {epoch_loss:.4f}")
    print(f"  ⏱️  Time: {epoch_time:.1f}s")
    print(f"  📈 Learning Rate: {mae_optimizer.param_groups[0]['lr']:.6f}")
    
    # Visualize reconstruction periodically
    if (epoch + 1) % 10 == 0:
        recon_loss = visualize_mae_reconstruction(mae_model, mae_dataloader, DEVICE, epoch + 1)
    
    # Save checkpoint periodically
    if (epoch + 1) % MAE_CONFIG['save_frequency'] == 0:
        checkpoint_path = os.path.join(
            MAE_CONFIG['checkpoint_dir'], 
            f'mae_checkpoint_epoch_{epoch+1}.pth'
        )
        save_mae_checkpoint(
            mae_model, mae_optimizer, mae_scheduler, 
            epoch + 1, epoch_loss, MAE_CONFIG, checkpoint_path
        )
    
    # Save best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_checkpoint_path = os.path.join(
            MAE_CONFIG['checkpoint_dir'], 
            'mae_best_model.pth'
        )
        save_mae_checkpoint(
            mae_model, mae_optimizer, mae_scheduler, 
            epoch + 1, epoch_loss, MAE_CONFIG, best_checkpoint_path
        )
        print(f"üèÜ New best model saved! Loss: {best_loss:.4f}")
    
    # Memory cleanup
    torch.cuda.empty_cache()

# Training completed
total_time = time.time() - start_time
print(f"\n🎉 MAE PRETRAINING COMPLETED!")
print("="*60)
print(f"⏰ Total training time: {total_time/3600:.2f} hours")
print(f"🏆 Best reconstruction loss: {best_loss:.4f}")
print(f"📈 Final learning rate: {mae_optimizer.param_groups[0]['lr']:.6f}")

# Save final checkpoint
final_checkpoint_path = os.path.join(
    MAE_CONFIG['checkpoint_dir'], 
    'mae_final_model.pth'
)
save_mae_checkpoint(
    mae_model, mae_optimizer, mae_scheduler, 
    MAE_CONFIG['epochs'], training_losses[-1], MAE_CONFIG, final_checkpoint_path
)

# Plot training curve
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(training_losses) + 1), training_losses, 'b-', linewidth=2)
plt.title('MAE Pretraining Loss Curve', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(MAE_CONFIG['checkpoint_dir'], 'mae_training_curve.png'), dpi=300, bbox_inches='tight')
plt.show()

# Log final results to wandb
if MAE_CONFIG['use_wandb'] and wandb.run:
    try:
        wandb.log({
            'mae_final_loss': training_losses[-1],
            'mae_best_loss': best_loss,
            'mae_training_time_hours': total_time/3600,
            'mae_total_epochs': MAE_CONFIG['epochs']
        })
        wandb.finish()
    except Exception as e:
        print(f"⚠️  W&B final logging failed: {e}")

print(f"\n✅ MAE encoder is now pretrained on fish images!")
print(f"💾 Checkpoints saved to: {MAE_CONFIG['checkpoint_dir']}")
print(f"🎯 Ready to extract pretrained weights for classification training!")

# Store checkpoint path for next stage
MAE_PRETRAINED_PATH = best_checkpoint_path

## 🔄 Section 7: Load MAE Pretrained Weights for ViT

Extracting pretrained encoder weights from MAE and loading them into the ViT classification model.
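
An alternative to rebuilding the full autoencoder is to filter the checkpoint's state dict down to the encoder entries. The sketch below assumes the keys carry an `encoder.` prefix (which would follow from the `encoder` attribute on the MAE model); the helper name `extract_encoder_weights`, the key names, and the placeholder values are hypothetical.

```python
def extract_encoder_weights(state_dict, prefix="encoder."):
    """Keep only entries under `prefix` and strip it from the key names."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Placeholder checkpoint: key names and values are made up for illustration
fake_state_dict = {
    "encoder.cls_token": "tensor A",
    "encoder.blocks.0.norm1.weight": "tensor B",
    "decoder.mask_token": "tensor C",
}
encoder_only = extract_encoder_weights(fake_state_dict)
print(sorted(encoder_only))
```

Either way, `load_state_dict` only succeeds when the receiving model has the same layer shapes as the model that produced the checkpoint.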

In [None]:
# Load MAE Pretrained Weights into ViT Classification Model
import torch
import torch.nn as nn
from collections import OrderedDict

print("üîÑ LOADING MAE PRETRAINED WEIGHTS")
print("="*50)

# Enhanced ViT model with MAE pretraining support
class MAEPretrainedViT(nn.Module):
    """ViT model that can load MAE pretrained encoder weights"""
    
    def __init__(self, num_classes, mae_encoder=None, dropout_rate=0.1):
        super().__init__()
        self.num_classes = num_classes
        # Track the backbone type explicitly: timm ViTs also expose `patch_embed`,
        # so an attribute check cannot tell the two backbones apart
        self.uses_mae_encoder = mae_encoder is not None
        
        if self.uses_mae_encoder:
            # Use pretrained MAE encoder
            self.backbone = mae_encoder
            self.feature_dim = mae_encoder.blocks[0].norm1.normalized_shape[0]  # Get embed_dim
            print(f"✅ Using MAE pretrained encoder with {self.feature_dim} features")
        else:
            # Fallback to timm ViT
            import timm
            self.backbone = timm.create_model(
                'vit_base_patch16_224',
                pretrained=True,
                num_classes=0,
                global_pool='token'
            )
            self.feature_dim = self.backbone.num_features
            print(f"⚠️  Using ImageNet pretrained ViT with {self.feature_dim} features")
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, num_classes)
        )
        
        # Initialize classification head
        self._init_classifier()
    
    def _init_classifier(self):
        """Initialize the classification head weights"""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    
    def forward(self, x):
        """Forward pass for classification"""
        return self.classifier(self.get_features(x))
    
    def get_features(self, x):
        """Extract features without classification"""
        if self.uses_mae_encoder:
            # MAE encoder forward pass without masking
            x = self.backbone.patch_embed(x)
            x = x + self.backbone.pos_embed[:, 1:, :]  # Add pos embed without cls
            
            # Add cls token
            cls_token = self.backbone.cls_token + self.backbone.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
            
            # Apply transformer blocks
            for blk in self.backbone.blocks:
                x = blk(x)
            x = self.backbone.norm(x)
            
            return x[:, 0]  # Return cls token features
        
        # Standard timm ViT forward (num_classes=0 returns the pooled cls features)
        return self.backbone(x)

print(f"üèóÔ∏è  Creating ViT classification model...")

# Load MAE checkpoint
if 'MAE_PRETRAINED_PATH' in globals() and os.path.exists(MAE_PRETRAINED_PATH):
    print(f"üìÇ Loading MAE checkpoint: {MAE_PRETRAINED_PATH}")
    
    mae_checkpoint = torch.load(MAE_PRETRAINED_PATH, map_location='cpu')
    print(f"‚úÖ MAE checkpoint loaded (epoch {mae_checkpoint['epoch']})")
    print(f"üìä MAE training loss: {mae_checkpoint['loss']:.4f}")
    
    # Create new MAE model and load weights.
    # The dimensions must match the reduced model trained above;
    # otherwise load_state_dict fails with size-mismatch errors.
    mae_pretrained = MaskedAutoEncoder(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=512,
        encoder_depth=8,
        encoder_num_heads=8,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=8,
        mlp_ratio=4.0,
        norm_pix_loss=True
    )
    
    mae_pretrained.load_state_dict(mae_checkpoint['model_state_dict'])
    mae_encoder = mae_pretrained.encoder
    
    # Create ViT with MAE pretrained encoder
    vit_model = MAEPretrainedViT(
        num_classes=NUM_CLASSES,
        mae_encoder=mae_encoder,
        dropout_rate=0.1
    ).to(DEVICE)
    
    print(f"‚úÖ ViT model created with MAE pretrained encoder!")
    
else:
    print("‚ö†Ô∏è  MAE checkpoint not found, using ImageNet pretrained ViT")
    vit_model = MAEPretrainedViT(
        num_classes=NUM_CLASSES,
        mae_encoder=None,
        dropout_rate=0.1
    ).to(DEVICE)

# Model statistics
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(vit_model)
backbone_params = count_parameters(vit_model.backbone)
classifier_params = count_parameters(vit_model.classifier)

print(f"\nüìä ViT Model Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Backbone parameters: {backbone_params:,}")
print(f"  Classifier parameters: {classifier_params:,}")
print(f"  Number of classes: {NUM_CLASSES}")

# Test model forward pass
print(f"\nüß™ Testing ViT model...")
test_input = torch.randn(2, 3, 224, 224).to(DEVICE)

with torch.no_grad():
    # Test classification
    logits = vit_model(test_input)
    print(f"‚úÖ Classification forward pass successful!")
    print(f"üìä Logits shape: {logits.shape}")
    
    # Test feature extraction
    features = vit_model.get_features(test_input)
    print(f"‚úÖ Feature extraction successful!")
    print(f"üìä Features shape: {features.shape}")

# Enhanced ViT model is now ready with MAE pretrained weights
# Use the same condition as the checkpoint-loading branch so the report matches what was loaded
used_mae = 'MAE_PRETRAINED_PATH' in globals() and os.path.exists(MAE_PRETRAINED_PATH)
print(f"\n🎯 ViT MODEL READY FOR EMA TRAINING")
print("="*50)
print(f"✅ Model architecture: Vision Transformer")
print(f"✅ Pretraining: {'MAE self-supervised' if used_mae else 'ImageNet supervised'}")
print(f"✅ Classification head: Initialized for {NUM_CLASSES} fish species")
print(f"🚀 Ready for EMA student-teacher semi-supervised training!")

# Clean up MAE model to free memory
if 'mae_model' in globals():
    del mae_model
    torch.cuda.empty_cache()
    print(f"üóëÔ∏è  Cleaned up MAE model to free GPU memory")

## 🎓 Section 8: Configure EMA Student-Teacher Framework

Setting up the EMA teacher model and semi-supervised learning pipeline with the MAE-pretrained backbone.
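
The EMA update itself is a one-line weighted average. The toy sketch below uses plain Python floats instead of model parameters, with the student held fixed to make the teacher's lag visible; the function name `ema_update` and the toy values are hypothetical, while the decay of 0.995 matches the value used in this section.

```python
def ema_update(teacher, student, decay):
    """Return teacher weights moved a fraction (1 - decay) toward the student."""
    return [decay * t + (1 - decay) * s for t, s in zip(teacher, student)]

teacher = [0.0, 0.0]        # toy "weights", not real model parameters
student = [1.0, -2.0]       # held fixed here to make the lag visible
for step in range(200):
    teacher = ema_update(teacher, student, decay=0.995)

print([round(w, 3) for w in teacher])
```

With a decay of 0.995 the teacher averages over roughly 1 / (1 - 0.995) = 200 steps, so after 200 updates it has covered about 63% of the distance to a fixed student. This slow drift is what makes the teacher's predictions more stable than the student's.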

In [None]:
# Configure EMA Student-Teacher Framework for Semi-Supervised Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image
import copy
import os
import glob
import numpy as np
from datetime import datetime

print("üéì CONFIGURING EMA STUDENT-TEACHER FRAMEWORK")
print("="*50)

# EMA Teacher Implementation
class EMATeacher(nn.Module):
    """Exponential Moving Average Teacher for Semi-Supervised Learning"""
    
    def __init__(self, student_model, ema_decay=0.995):
        super().__init__()
        self.ema_decay = ema_decay
        self.student_model = student_model
        
        # Create teacher as copy of student
        self.teacher_model = copy.deepcopy(student_model)
        
        # Disable gradients for teacher
        for param in self.teacher_model.parameters():
            param.requires_grad = False
        
        print(f"‚úÖ EMA Teacher created with decay: {ema_decay}")
    
    def update_teacher(self):
        """Update teacher weights using EMA"""
        with torch.no_grad():
            for teacher_param, student_param in zip(
                self.teacher_model.parameters(), 
                self.student_model.parameters()
            ):
                teacher_param.data = (
                    self.ema_decay * teacher_param.data + 
                    (1 - self.ema_decay) * student_param.data
                )
    
    def forward(self, x, use_teacher=False):
        """Forward pass through student or teacher"""
        if use_teacher:
            return self.teacher_model(x)
        else:
            return self.student_model(x)
    
    def get_teacher_predictions(self, x):
        """Get teacher predictions for pseudo-labeling"""
        self.teacher_model.eval()
        with torch.no_grad():
            return self.teacher_model(x)

# Consistency Loss for Semi-Supervised Learning
class ConsistencyLoss(nn.Module):
    """Consistency loss between student and teacher predictions"""
    
    def __init__(self, temperature=4.0):
        super().__init__()
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
    
    def forward(self, student_logits, teacher_logits):
        """Compute consistency loss using KL divergence"""
        # Apply temperature scaling; KLDivLoss expects log-probabilities as input
        # and probabilities as target
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        
        # KL divergence loss, rescaled by T^2 to keep gradient magnitudes comparable
        consistency_loss = self.kl_div(student_log_probs, teacher_probs)
        consistency_loss = consistency_loss * (self.temperature ** 2)
        
        return consistency_loss

# Semi-Supervised Dataset
class FishSemiSupervisedDataset(Dataset):
    """Dataset combining labeled and unlabeled fish images"""
    
    def __init__(self, labeled_dir, unlabeled_dir, img_size=224, mode='train'):
        self.img_size = img_size
        self.mode = mode
        
        # Load labeled data
        self.labeled_data = []
        self.class_to_idx = {}
        self.idx_to_class = {}
        
        if os.path.exists(labeled_dir):
            species_dirs = [d for d in os.listdir(labeled_dir) 
                          if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
            species_dirs.sort()
            
            for idx, species in enumerate(species_dirs):
                self.class_to_idx[species] = idx
                self.idx_to_class[idx] = species
                
                species_path = os.path.join(labeled_dir, species)
                for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
                    for img_path in glob.glob(os.path.join(species_path, ext)):
                        self.labeled_data.append((img_path, idx))
        
        # Load unlabeled data
        self.unlabeled_data = []
        if os.path.exists(unlabeled_dir):
            for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
                for img_path in glob.glob(os.path.join(unlabeled_dir, ext)):
                    self.unlabeled_data.append(img_path)
        
        print(f"üìä Loaded {len(self.labeled_data)} labeled images from {len(self.class_to_idx)} species")
        print(f"üìä Loaded {len(self.unlabeled_data)} unlabeled images")
        
        # Data augmentation transforms
        if mode == 'train':
            self.labeled_transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            
            self.unlabeled_transform_weak = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            
            self.unlabeled_transform_strong = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=20),
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.labeled_transform = transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    
    def __len__(self):
        return len(self.labeled_data) + len(self.unlabeled_data)
    
    def __getitem__(self, idx):
        if idx < len(self.labeled_data):
            # Labeled data
            img_path, label = self.labeled_data[idx]
            try:
                image = Image.open(img_path).convert('RGB')
                image = self.labeled_transform(image)
                return image, label, True  # True indicates labeled data
            except Exception as e:
                print(f"Warning: Could not load {img_path}: {e}")
                return torch.zeros(3, self.img_size, self.img_size), 0, True
        else:
            # Unlabeled data
            unlabeled_idx = idx - len(self.labeled_data)
            img_path = self.unlabeled_data[unlabeled_idx]
            
            try:
                image = Image.open(img_path).convert('RGB')
                
                if self.mode == 'train':
                    # Return both weak and strong augmentations for consistency training
                    image_weak = self.unlabeled_transform_weak(image)
                    image_strong = self.unlabeled_transform_strong(image)
                    return (image_weak, image_strong), -1, False  # -1 indicates no label, False indicates unlabeled
                else:
                    image = self.labeled_transform(image)
                    return image, -1, False
                    
            except Exception as e:
                print(f"Warning: Could not load {img_path}: {e}")
                if self.mode == 'train':
                    zero_img = torch.zeros(3, self.img_size, self.img_size)
                    return (zero_img, zero_img), -1, False
                else:
                    return torch.zeros(3, self.img_size, self.img_size), -1, False

# EMA Training Configuration
EMA_CONFIG = {
    # Model settings
    'num_classes': NUM_CLASSES,
    'img_size': 224,
    
    # Training settings
    'epochs': 100,
    'batch_size': 32,  # Balanced for labeled + unlabeled data
    'learning_rate': 1e-4,
    'weight_decay': 0.05,
    'warmup_epochs': 10,
    
    # EMA settings
    'ema_decay': 0.995,
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'ramp_up_epochs': 20,
    
    # Optimization
    'clip_grad': 1.0,
    'label_smoothing': 0.1,
    
    # Saving
    'save_frequency': 10,
    'checkpoint_dir': '/content/drive/MyDrive/ViT-FishID/ema_checkpoints',
    
    # Logging
    'use_wandb': True,
    'wandb_project': 'ViT-FishID-EMA-Training',
    'wandb_run_name': f'ema-mae-pretrained-{datetime.now().strftime("%Y%m%d-%H%M%S")}',
    
    # Data
    'labeled_dir': LABELED_DIR,
    'unlabeled_dir': UNLABELED_DIR,
    'train_split': 0.8,
    'num_workers': 4,
}

print("📊 EMA Configuration:")
for key, value in EMA_CONFIG.items():
    print(f"  {key}: {value}")

# Create checkpoint directory
os.makedirs(EMA_CONFIG['checkpoint_dir'], exist_ok=True)
print(f"\n💾 EMA Checkpoint directory: {EMA_CONFIG['checkpoint_dir']}")

# Create datasets and dataloaders
print("\n📦 Creating semi-supervised datasets...")

# Full dataset
full_dataset = FishSemiSupervisedDataset(
    EMA_CONFIG['labeled_dir'], 
    EMA_CONFIG['unlabeled_dir'], 
    EMA_CONFIG['img_size'], 
    mode='train'
)

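# Split labeled data into train/validation.
# Shuffle with a fixed seed first: if labeled_data is ordered by class, a plain
# slice would put only the last classes in the validation set.
import random

labeled_samples = list(full_dataset.labeled_data)
random.Random(42).shuffle(labeled_samples)  # 42 is an arbitrary fixed seed

labeled_size = len(labeled_samples)
train_size = int(EMA_CONFIG['train_split'] * labeled_size)
val_size = labeled_size - train_size

# Create train and validation datasets
train_labeled_data = labeled_samples[:train_size]
val_labeled_data = labeled_samples[train_size:]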

# Training dataset (includes all unlabeled data)
class TrainDataset(Dataset):
    def __init__(self, labeled_data, unlabeled_data, class_to_idx, img_size):
        self.labeled_data = labeled_data
        self.unlabeled_data = unlabeled_data
        self.class_to_idx = class_to_idx
        self.img_size = img_size
        
        # Same transforms as before
        self.labeled_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.unlabeled_transform_weak = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.unlabeled_transform_strong = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=20),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __len__(self):
        return len(self.labeled_data) + len(self.unlabeled_data)
    
    def __getitem__(self, idx):
        if idx < len(self.labeled_data):
            # Labeled data
            img_path, label = self.labeled_data[idx]
            try:
                image = Image.open(img_path).convert('RGB')
                image = self.labeled_transform(image)
                return image, label, True
            except Exception as e:
                print(f"Warning: Could not load {img_path}: {e}")
                return torch.zeros(3, self.img_size, self.img_size), 0, True
        else:
            # Unlabeled data
            unlabeled_idx = idx - len(self.labeled_data)
            img_path = self.unlabeled_data[unlabeled_idx]
            
            try:
                image = Image.open(img_path).convert('RGB')
                image_weak = self.unlabeled_transform_weak(image)
                image_strong = self.unlabeled_transform_strong(image)
                return (image_weak, image_strong), -1, False
            except Exception as e:
                print(f"Warning: Could not load {img_path}: {e}")
                zero_img = torch.zeros(3, self.img_size, self.img_size)
                return (zero_img, zero_img), -1, False

# Validation dataset
class ValDataset(Dataset):
    def __init__(self, labeled_data, class_to_idx, img_size):
        self.labeled_data = labeled_data
        self.class_to_idx = class_to_idx
        self.img_size = img_size
        
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def __len__(self):
        return len(self.labeled_data)
    
    def __getitem__(self, idx):
        img_path, label = self.labeled_data[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Warning: Could not load {img_path}: {e}")
            return torch.zeros(3, self.img_size, self.img_size), 0

# Create datasets
train_dataset = TrainDataset(
    train_labeled_data, 
    full_dataset.unlabeled_data, 
    full_dataset.class_to_idx, 
    EMA_CONFIG['img_size']
)

val_dataset = ValDataset(
    val_labeled_data, 
    full_dataset.class_to_idx, 
    EMA_CONFIG['img_size']
)

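# Create dataloaders
def semi_supervised_collate(batch):
    """Keep samples un-stacked. Labeled items are single tensors while unlabeled
    items are (weak, strong) tuples, which the default collate cannot merge."""
    data, labels, is_labeled = zip(*batch)
    return data, labels, is_labeled

train_dataloader = DataLoader(
    train_dataset,
    batch_size=EMA_CONFIG['batch_size'],
    shuffle=True,
    num_workers=EMA_CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True,
    collate_fn=semi_supervised_collate
)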

val_dataloader = DataLoader(
    val_dataset,
    batch_size=EMA_CONFIG['batch_size'],
    shuffle=False,
    num_workers=EMA_CONFIG['num_workers'],
    pin_memory=True
)

print(f"✅ Datasets created:")
print(f"  📊 Training: {len(train_dataset)} samples ({len(train_labeled_data)} labeled + {len(full_dataset.unlabeled_data)} unlabeled)")
print(f"  📊 Validation: {len(val_dataset)} samples (labeled)")
print(f"  📊 Classes: {len(full_dataset.class_to_idx)}")

# Create EMA teacher
ema_teacher = EMATeacher(vit_model, ema_decay=EMA_CONFIG['ema_decay']).to(DEVICE)
consistency_loss_fn = ConsistencyLoss(temperature=EMA_CONFIG['temperature']).to(DEVICE)

print(f"\n✅ EMA Framework Ready:")
print(f"  🎓 Student model: MAE-pretrained ViT")
print(f"  👨‍🏫 Teacher model: EMA with decay {EMA_CONFIG['ema_decay']}")
print(f"  🔄 Consistency loss: KL divergence with temperature {EMA_CONFIG['temperature']}")
print(f"🚀 Ready for semi-supervised training!")

## 🚀 Section 9: Execute Semi-Supervised Training with EMA

Running the complete semi-supervised training pipeline combining labeled supervision with unlabeled consistency learning.
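
The three mechanisms this stage relies on (the EMA teacher update, the consistency-weight ramp-up, and confidence-thresholded pseudo-labels) can be sketched in plain Python to make the arithmetic explicit. This is a minimal sketch under stated assumptions: it uses the standard EMA rule, teacher = decay * teacher + (1 - decay) * student, which is assumed rather than confirmed to be what `EMATeacher.update_teacher()` does, since that class is defined elsewhere in the notebook. The helper names `ema_update`, `ramp_up_weight`, and `confident_pseudo_label` are illustrative only; the default values mirror `EMA_CONFIG`.

```python
import math


def ema_update(teacher, student, decay=0.995):
    """Return teacher weights after one EMA step (assumed standard rule)."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


def ramp_up_weight(epoch, ramp_up_epochs=20, max_weight=2.0):
    """Linear ramp from 0 to max_weight, mirroring get_consistency_weight."""
    if epoch < ramp_up_epochs:
        return max_weight * epoch / ramp_up_epochs
    return max_weight


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confident_pseudo_label(teacher_logits, threshold=0.7):
    """Return (class_index, probability) if the teacher is confident, else None."""
    probs = softmax(teacher_logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] >= threshold:
        return best, probs[best]
    return None


if __name__ == "__main__":
    # Hypothetical toy values, chosen only to exercise each helper.
    print(ema_update([1.0, 0.0], [0.0, 1.0]))
    print([ramp_up_weight(e) for e in (0, 10, 20, 50)])
    print(confident_pseudo_label([5.0, 0.0, 0.0]))
    print(confident_pseudo_label([1.0, 0.9, 0.8]))
```

Two consequences follow from these definitions. With a decay of 0.995, the teacher averages over roughly the last 1 / (1 - 0.995) = 200 optimizer steps. And because the ramp starts at 0, the training loop below skips unlabeled data entirely during the first epoch.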

In [None]:
# Execute Semi-Supervised EMA Training with MAE Pretrained Model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

print("🎓 STARTING EMA SEMI-SUPERVISED TRAINING")
print("="*60)

# Setup optimizer and scheduler
ema_optimizer = optim.AdamW(
    vit_model.parameters(),
    lr=EMA_CONFIG['learning_rate'],
    weight_decay=EMA_CONFIG['weight_decay']
)

# Cosine annealing scheduler with warmup
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

total_steps = len(train_dataloader) * EMA_CONFIG['epochs']
warmup_steps = len(train_dataloader) * EMA_CONFIG['warmup_epochs']

ema_scheduler = get_cosine_schedule_with_warmup(
    ema_optimizer, 
    warmup_steps, 
    total_steps
)

# Loss functions
supervised_loss_fn = nn.CrossEntropyLoss(label_smoothing=EMA_CONFIG['label_smoothing'])

# Consistency weight ramp-up function
def get_consistency_weight(epoch, ramp_up_epochs):
    """Gradually ramp up consistency weight"""
    if epoch < ramp_up_epochs:
        return EMA_CONFIG['consistency_weight'] * (epoch / ramp_up_epochs)
    return EMA_CONFIG['consistency_weight']

# Pseudo-labeling function
def get_pseudo_labels(teacher_logits, threshold):
    """Generate pseudo-labels from teacher predictions"""
    teacher_probs = F.softmax(teacher_logits, dim=1)
    max_probs, pseudo_labels = torch.max(teacher_probs, dim=1)
    
    # Create mask for confident predictions
    confident_mask = max_probs >= threshold
    
    return pseudo_labels, confident_mask, max_probs

# Validation function
def validate_model(model, val_dataloader, device):
    """Validate model on labeled validation set"""
    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0
    
    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            total_loss += loss.item()
    
    accuracy = 100.0 * total_correct / total_samples
    avg_loss = total_loss / len(val_dataloader)
    
    return accuracy, avg_loss

# Save checkpoint function
def save_ema_checkpoint(student_model, teacher_model, optimizer, scheduler, epoch, 
                       best_acc, config, filename):
    """Save EMA training checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'student_state_dict': student_model.state_dict(),
        'teacher_state_dict': teacher_model.teacher_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_accuracy': best_acc,
        'config': config,
        'class_to_idx': full_dataset.class_to_idx,
        'num_classes': config['num_classes']
    }
    torch.save(checkpoint, filename)
    print(f"💾 Saved checkpoint: {filename}")

# Initialize Weights & Biases
if EMA_CONFIG['use_wandb']:
    print("📈 Initializing Weights & Biases...")
    try:
        wandb.init(
            project=EMA_CONFIG['wandb_project'],
            name=EMA_CONFIG['wandb_run_name'],
            config=EMA_CONFIG,
            tags=['ema', 'semi-supervised', 'fish', 'mae-pretrained']
        )
        print(f"✅ W&B initialized: {wandb.run.url}")
    except Exception as e:
        print(f"⚠️  W&B initialization failed: {e}")
        EMA_CONFIG['use_wandb'] = False

# Training loop
print(f"🎬 Starting EMA training for {EMA_CONFIG['epochs']} epochs...")
print(f"📊 Training data: {len(train_dataset)} samples")
print(f"📊 Validation data: {len(val_dataset)} samples")
print(f"⏰ Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

vit_model.train()
ema_teacher.student_model.train()
ema_teacher.teacher_model.eval()

best_accuracy = 0.0
training_history = {
    'supervised_loss': [],
    'consistency_loss': [],
    'total_loss': [],
    'val_accuracy': [],
    'val_loss': []
}

start_time = time.time()

for epoch in range(EMA_CONFIG['epochs']):
    epoch_start_time = time.time()
    
    # Training phase
    vit_model.train()
    epoch_supervised_loss = 0.0
    epoch_consistency_loss = 0.0
    epoch_total_loss = 0.0
    num_labeled_samples = 0
    num_unlabeled_samples = 0
    
    # Get current consistency weight
    current_consistency_weight = get_consistency_weight(epoch, EMA_CONFIG['ramp_up_epochs'])
    
    pbar = tqdm(train_dataloader, desc=f"EMA Epoch {epoch+1}/{EMA_CONFIG['epochs']}")
    
    for batch_idx, batch_data in enumerate(pbar):
        # Separate labeled and unlabeled data
        labeled_data = []
        unlabeled_data = []
        
        for data, label, is_labeled in zip(*batch_data):
            if is_labeled:
                labeled_data.append((data, label))
            else:
                unlabeled_data.append(data)
        
        # Zero-valued placeholders, overwritten below when data is available
        supervised_loss = torch.tensor(0.0, device=DEVICE)
        consistency_loss = torch.tensor(0.0, device=DEVICE)
        
        # Process labeled data
        if labeled_data:
            labeled_images = torch.stack([data for data, _ in labeled_data]).to(DEVICE)
            labeled_targets = torch.tensor([label for _, label in labeled_data]).to(DEVICE)
            
            # Student forward pass
            student_outputs = vit_model(labeled_images)
            supervised_loss = supervised_loss_fn(student_outputs, labeled_targets)
            
            num_labeled_samples += len(labeled_data)
        
        # Process unlabeled data
        if unlabeled_data and current_consistency_weight > 0:
            # Unlabeled data comes as (weak_aug, strong_aug) tuples
            weak_images = torch.stack([data[0] for data in unlabeled_data]).to(DEVICE)
            strong_images = torch.stack([data[1] for data in unlabeled_data]).to(DEVICE)
            
            # Teacher predictions on weakly augmented images
            teacher_outputs = ema_teacher.get_teacher_predictions(weak_images)
            
            # Generate pseudo-labels
            pseudo_labels, confident_mask, max_probs = get_pseudo_labels(
                teacher_outputs, EMA_CONFIG['pseudo_label_threshold']
            )
            
            if confident_mask.any():
                # Student predictions on strongly augmented images
                student_outputs_unlabeled = vit_model(strong_images)
                
                # Consistency loss only for confident predictions
                student_confident = student_outputs_unlabeled[confident_mask]
                teacher_confident = teacher_outputs[confident_mask]
                
                consistency_loss = consistency_loss_fn(student_confident, teacher_confident)
                consistency_loss = current_consistency_weight * consistency_loss
            
            num_unlabeled_samples += len(unlabeled_data)
        
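        # Total loss
        total_loss = supervised_loss + consistency_loss
        
        # Backward pass. Skip it when the loss has no graph (a batch with no
        # labeled samples and no confident pseudo-labels), because backward()
        # raises on a tensor that does not require grad.
        ema_optimizer.zero_grad()
        if total_loss.requires_grad:
            total_loss.backward()
            
            # Gradient clipping
            if EMA_CONFIG['clip_grad'] > 0:
                torch.nn.utils.clip_grad_norm_(vit_model.parameters(), EMA_CONFIG['clip_grad'])
            
            ema_optimizer.step()
        ema_scheduler.step()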
        
        # Update teacher with EMA
        ema_teacher.update_teacher()
        
        # Record losses
        epoch_supervised_loss += supervised_loss.item()
        epoch_consistency_loss += consistency_loss.item()
        epoch_total_loss += total_loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'Sup_Loss': f"{supervised_loss.item():.4f}",
            'Con_Loss': f"{consistency_loss.item():.4f}",
            'Con_Weight': f"{current_consistency_weight:.3f}",
            'LR': f"{ema_optimizer.param_groups[0]['lr']:.6f}"
        })
        
        # Log to wandb
        if EMA_CONFIG['use_wandb'] and wandb.run:
            try:
                wandb.log({
                    'batch_supervised_loss': supervised_loss.item(),
                    'batch_consistency_loss': consistency_loss.item(),
                    'batch_total_loss': total_loss.item(),
                    'consistency_weight': current_consistency_weight,
                    'learning_rate': ema_optimizer.param_groups[0]['lr'],
                    'step': epoch * len(train_dataloader) + batch_idx
                })
            except Exception:
                pass
    
    # Validation phase
    val_accuracy, val_loss = validate_model(vit_model, val_dataloader, DEVICE)
    
    # Epoch statistics
    avg_supervised_loss = epoch_supervised_loss / len(train_dataloader)
    avg_consistency_loss = epoch_consistency_loss / len(train_dataloader)
    avg_total_loss = epoch_total_loss / len(train_dataloader)
    epoch_time = time.time() - epoch_start_time
    
    training_history['supervised_loss'].append(avg_supervised_loss)
    training_history['consistency_loss'].append(avg_consistency_loss)
    training_history['total_loss'].append(avg_total_loss)
    training_history['val_accuracy'].append(val_accuracy)
    training_history['val_loss'].append(val_loss)
    
    print(f"\n📊 Epoch {epoch+1} Summary:")
    print(f"  📉 Supervised Loss: {avg_supervised_loss:.4f}")
    print(f"  📉 Consistency Loss: {avg_consistency_loss:.4f}")
    print(f"  📉 Total Loss: {avg_total_loss:.4f}")
    print(f"  📈 Validation Accuracy: {val_accuracy:.2f}%")
    print(f"  📉 Validation Loss: {val_loss:.4f}")
    print(f"  ⏱️  Time: {epoch_time:.1f}s")
    print(f"  🎓 Labeled samples: {num_labeled_samples}")
    print(f"  🔄 Unlabeled samples: {num_unlabeled_samples}")
    
    # Save checkpoint periodically
    if (epoch + 1) % EMA_CONFIG['save_frequency'] == 0:
        checkpoint_path = os.path.join(
            EMA_CONFIG['checkpoint_dir'], 
            f'ema_checkpoint_epoch_{epoch+1}.pth'
        )
        save_ema_checkpoint(
            vit_model, ema_teacher, ema_optimizer, ema_scheduler,
            epoch + 1, val_accuracy, EMA_CONFIG, checkpoint_path
        )
    
    # Save best model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_checkpoint_path = os.path.join(
            EMA_CONFIG['checkpoint_dir'], 
            'ema_best_model.pth'
        )
        save_ema_checkpoint(
            vit_model, ema_teacher, ema_optimizer, ema_scheduler,
            epoch + 1, val_accuracy, EMA_CONFIG, best_checkpoint_path
        )
        print(f"🏆 New best model saved! Accuracy: {best_accuracy:.2f}%")
    
    # Log epoch results to wandb
    if EMA_CONFIG['use_wandb'] and wandb.run:
        try:
            wandb.log({
                'epoch': epoch + 1,
                'epoch_supervised_loss': avg_supervised_loss,
                'epoch_consistency_loss': avg_consistency_loss,
                'epoch_total_loss': avg_total_loss,
                'val_accuracy': val_accuracy,
                'val_loss': val_loss,
                'best_accuracy': best_accuracy,
                'epoch_time': epoch_time
            })
        except Exception:
            pass
    
    # Memory cleanup
    torch.cuda.empty_cache()

# Training completed
total_time = time.time() - start_time
print(f"\n🎉 EMA TRAINING COMPLETED!")
print("="*60)
print(f"⏰ Total training time: {total_time/3600:.2f} hours")
print(f"🏆 Best validation accuracy: {best_accuracy:.2f}%")
print(f"📈 Final learning rate: {ema_optimizer.param_groups[0]['lr']:.6f}")

# Save final checkpoint
final_checkpoint_path = os.path.join(
    EMA_CONFIG['checkpoint_dir'], 
    'ema_final_model.pth'
)
save_ema_checkpoint(
    vit_model, ema_teacher, ema_optimizer, ema_scheduler,
    EMA_CONFIG['epochs'], training_history['val_accuracy'][-1], 
    EMA_CONFIG, final_checkpoint_path
)

# Plot training curves
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
epochs = range(1, len(training_history['total_loss']) + 1)
ax1.plot(epochs, training_history['supervised_loss'], 'b-', label='Supervised Loss', linewidth=2)
ax1.plot(epochs, training_history['consistency_loss'], 'r-', label='Consistency Loss', linewidth=2)
ax1.plot(epochs, training_history['total_loss'], 'g-', label='Total Loss', linewidth=2)
ax1.set_title('Training Loss Curves', fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Validation accuracy
ax2.plot(epochs, training_history['val_accuracy'], 'purple', linewidth=2)
ax2.set_title('Validation Accuracy', fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.grid(True, alpha=0.3)

# Validation loss
ax3.plot(epochs, training_history['val_loss'], 'orange', linewidth=2)
ax3.set_title('Validation Loss', fontweight='bold')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Loss')
ax3.grid(True, alpha=0.3)

# Loss ratio
consistency_ratio = [c/(s+1e-8) for s, c in zip(training_history['supervised_loss'], 
                                                training_history['consistency_loss'])]
ax4.plot(epochs, consistency_ratio, 'brown', linewidth=2)
ax4.set_title('Consistency/Supervised Loss Ratio', fontweight='bold')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Ratio')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(EMA_CONFIG['checkpoint_dir'], 'ema_training_curves.png'), 
           dpi=300, bbox_inches='tight')
plt.show()

# Final wandb logging
if EMA_CONFIG['use_wandb'] and wandb.run:
    try:
        wandb.log({
            'final_best_accuracy': best_accuracy,
            'final_val_accuracy': training_history['val_accuracy'][-1],
            'total_training_time_hours': total_time/3600,
            'total_epochs': EMA_CONFIG['epochs']
        })
        wandb.finish()
    except Exception:
        pass

print(f"\n✅ Semi-supervised EMA training with MAE pretraining completed!")
print(f"💾 Checkpoints saved to: {EMA_CONFIG['checkpoint_dir']}")
print(f"🎯 Best model achieved {best_accuracy:.2f}% accuracy!")
print(f"🚀 Model ready for evaluation and deployment!")

# Store path for evaluation. Built directly so the name is defined even if
# no epoch ever improved on the initial best accuracy.
EMA_BEST_MODEL_PATH = os.path.join(EMA_CONFIG['checkpoint_dir'], 'ema_best_model.pth')

## 📊 Section 10: Monitor Training Progress and Save Checkpoints

Tracking training metrics, analyzing model performance, and managing checkpoint saves.
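
Periodic checkpoints are written as `ema_checkpoint_epoch_<N>.pth`, so ordering them means extracting `<N>` from the file name. The following is a minimal sketch of that step, matching a regular expression against the base name so that directory names cannot interfere. `epoch_from_checkpoint` and `latest_checkpoints` are hypothetical helper names, not functions defined in this notebook.

```python
import os
import re

# Matches only the periodic checkpoint naming scheme used by the training cell.
_EPOCH_PATTERN = re.compile(r"ema_checkpoint_epoch_(\d+)\.pth$")


def epoch_from_checkpoint(path):
    """Return the epoch encoded in a checkpoint file name, or None if absent."""
    match = _EPOCH_PATTERN.search(os.path.basename(path))
    return int(match.group(1)) if match else None


def latest_checkpoints(paths, keep=3):
    """Return the `keep` most recent periodic checkpoints, oldest first."""
    numbered = [(epoch_from_checkpoint(p), p) for p in paths]
    numbered = sorted((e, p) for e, p in numbered if e is not None)
    return [p for _, p in numbered[-keep:]]
```

Sorting on the parsed integer matters here: a plain string sort would place `epoch_100` before `epoch_20`. Files such as `ema_best_model.pth` do not match the pattern and are ignored.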

In [None]:
# Monitor Training Progress and Manage Checkpoints
import os
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import shutil

print("📊 TRAINING PROGRESS MONITORING & CHECKPOINT MANAGEMENT")
print("="*60)

def analyze_training_checkpoints(checkpoint_dir):
    """Analyze all training checkpoints and extract metrics"""
    
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'ema_checkpoint_epoch_*.pth'))
    # Parse the epoch from the file name only, so directory names cannot interfere
    checkpoint_files.sort(key=lambda x: int(os.path.basename(x).split('epoch_')[1].split('.')[0]))
    
    if not checkpoint_files:
        print("❌ No checkpoint files found")
        return None
    
    checkpoint_data = []
    
    print(f"📂 Found {len(checkpoint_files)} checkpoint files")
    print("\n📊 Checkpoint Analysis:")
    print("-" * 60)
    print(f"{'Epoch':<8} {'Accuracy':<12} {'File Size':<12} {'Timestamp'}")
    print("-" * 60)
    
    for checkpoint_file in checkpoint_files:
        try:
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            
            epoch = checkpoint.get('epoch', 0)
            accuracy = checkpoint.get('best_accuracy', 0)
            file_size = os.path.getsize(checkpoint_file) / (1024 * 1024)  # MB
            timestamp = datetime.fromtimestamp(os.path.getmtime(checkpoint_file))
            
            checkpoint_data.append({
                'epoch': epoch,
                'accuracy': accuracy,
                'file_size': file_size,
                'timestamp': timestamp,
                'file_path': checkpoint_file
            })
            
            print(f"{epoch:<8} {accuracy:<12.2f} {file_size:<12.1f} {timestamp.strftime('%H:%M:%S')}")
            
        except Exception as e:
            print(f"⚠️  Could not load {checkpoint_file}: {e}")
    
    return checkpoint_data

def visualize_training_progress(checkpoint_dir, training_history=None):
    """Create comprehensive training progress visualizations"""
    
    if training_history is None:
        print("⚠️  Training history not available; skipping visualization")
        return
    
    # Create comprehensive visualization
    fig = plt.figure(figsize=(20, 12))
    
    # 1. Loss progression
    ax1 = plt.subplot(2, 4, 1)
    epochs = range(1, len(training_history['total_loss']) + 1)
    ax1.plot(epochs, training_history['supervised_loss'], 'b-', label='Supervised', linewidth=2)
    ax1.plot(epochs, training_history['consistency_loss'], 'r-', label='Consistency', linewidth=2)
    ax1.plot(epochs, training_history['total_loss'], 'g-', label='Total', linewidth=2)
    ax1.set_title('Training Loss Progression', fontweight='bold', fontsize=12)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Validation accuracy
    ax2 = plt.subplot(2, 4, 2)
    ax2.plot(epochs, training_history['val_accuracy'], 'purple', linewidth=3)
    ax2.set_title('Validation Accuracy', fontweight='bold', fontsize=12)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)
    
    # Add best accuracy line
    best_acc = max(training_history['val_accuracy'])
    ax2.axhline(y=best_acc, color='red', linestyle='--', alpha=0.7, 
                label=f'Best: {best_acc:.2f}%')
    ax2.legend()
    
    # 3. Loss smoothed (moving average)
    ax3 = plt.subplot(2, 4, 3)
    window = min(10, len(epochs) // 4)
    if window > 1:
        smoothed_total = np.convolve(training_history['total_loss'], 
                                   np.ones(window)/window, mode='valid')
        smoothed_epochs = epochs[window-1:]
        ax3.plot(smoothed_epochs, smoothed_total, 'darkgreen', linewidth=2)
    ax3.set_title(f'Smoothed Total Loss (window={window})', fontweight='bold', fontsize=12)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Loss')
    ax3.grid(True, alpha=0.3)
    
    # 4. Learning dynamics
    ax4 = plt.subplot(2, 4, 4)
    consistency_ratio = [c/(s+1e-8) for s, c in zip(training_history['supervised_loss'], 
                                                    training_history['consistency_loss'])]
    ax4.plot(epochs, consistency_ratio, 'brown', linewidth=2)
    ax4.set_title('Consistency/Supervised Ratio', fontweight='bold', fontsize=12)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Ratio')
    ax4.grid(True, alpha=0.3)
    
    # 5. Validation loss vs accuracy
    ax5 = plt.subplot(2, 4, 5)
    scatter = ax5.scatter(training_history['val_loss'], training_history['val_accuracy'], 
                         c=epochs, cmap='viridis', alpha=0.7)
    ax5.set_title('Validation Loss vs Accuracy', fontweight='bold', fontsize=12)
    ax5.set_xlabel('Validation Loss')
    ax5.set_ylabel('Validation Accuracy (%)')
    plt.colorbar(scatter, ax=ax5, label='Epoch')
    ax5.grid(True, alpha=0.3)
    
    # 6. Training efficiency
    ax6 = plt.subplot(2, 4, 6)
    total_loss_diff = np.diff(training_history['total_loss'])
    ax6.plot(epochs[1:], total_loss_diff, 'orange', linewidth=2)
    ax6.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax6.set_title('Loss Change Rate', fontweight='bold', fontsize=12)
    ax6.set_xlabel('Epoch')
    ax6.set_ylabel('Loss Δ')
    ax6.grid(True, alpha=0.3)
    
    # 7. Accuracy improvement
    ax7 = plt.subplot(2, 4, 7)
    acc_diff = np.diff(training_history['val_accuracy'])
    ax7.plot(epochs[1:], acc_diff, 'darkblue', linewidth=2)
    ax7.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax7.set_title('Accuracy Change Rate', fontweight='bold', fontsize=12)
    ax7.set_xlabel('Epoch')
    ax7.set_ylabel('Accuracy Δ (%)')
    ax7.grid(True, alpha=0.3)
    
    # 8. Training summary stats
    ax8 = plt.subplot(2, 4, 8)
    ax8.axis('off')
    
    # Calculate statistics
    final_acc = training_history['val_accuracy'][-1]
    best_acc = max(training_history['val_accuracy'])
    best_epoch = training_history['val_accuracy'].index(best_acc) + 1
    acc_improvement = final_acc - training_history['val_accuracy'][0]
    
    summary_text = f"""Training Summary
    
📊 Final Accuracy: {final_acc:.2f}%
🏆 Best Accuracy: {best_acc:.2f}%
🎯 Best Epoch: {best_epoch}
📈 Total Improvement: {acc_improvement:.2f}%
🔥 Epochs Trained: {len(epochs)}

💡 Loss Components:
📉 Final Supervised: {training_history['supervised_loss'][-1]:.4f}
🔄 Final Consistency: {training_history['consistency_loss'][-1]:.4f}
⚖️ Final Total: {training_history['total_loss'][-1]:.4f}
    """
    
    ax8.text(0.1, 0.9, summary_text, transform=ax8.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    
    plt.tight_layout()
    
    # Save the visualization
    viz_path = os.path.join(checkpoint_dir, 'training_analysis.png')
    plt.savefig(viz_path, dpi=300, bbox_inches='tight')
    print(f"📊 Training analysis saved to: {viz_path}")
    
    plt.show()
    
    return fig

def backup_best_models(checkpoint_dir, backup_dir=None):
    """Backup the best models to Google Drive"""
    
    if backup_dir is None:
        backup_dir = '/content/drive/MyDrive/ViT-FishID_MAE_EMA_Backup'
    
    os.makedirs(backup_dir, exist_ok=True)
    
    # Find best model files
    best_files = ['ema_best_model.pth', 'ema_final_model.pth']
    
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    for filename in best_files:
        source_path = os.path.join(checkpoint_dir, filename)
        if os.path.exists(source_path):
            # Create timestamped backup
            backup_filename = f"{timestamp}_{filename}"
            backup_path = os.path.join(backup_dir, backup_filename)
            
            try:
                shutil.copy2(source_path, backup_path)
                file_size = os.path.getsize(backup_path) / (1024 * 1024)  # MB
                print(f"✅ Backed up {filename} -> {backup_filename} ({file_size:.1f} MB)")
            except Exception as e:
                print(f"❌ Failed to backup {filename}: {e}")
        else:
            print(f"⚠️  {filename} not found in {checkpoint_dir}")
    
    print(f"💾 Backups saved to: {backup_dir}")

# Run analysis if checkpoints exist
if 'EMA_CONFIG' in globals() and os.path.exists(EMA_CONFIG['checkpoint_dir']):
    print(f"📁 Analyzing checkpoints in: {EMA_CONFIG['checkpoint_dir']}")
    
    # Analyze checkpoints
    checkpoint_data = analyze_training_checkpoints(EMA_CONFIG['checkpoint_dir'])
    
    # Visualize training progress
    if 'training_history' in globals():
        print("\n📊 Creating training progress visualization...")
        fig = visualize_training_progress(EMA_CONFIG['checkpoint_dir'], training_history)
    
    # Backup best models
    print("\n💾 Backing up best models to Google Drive...")
    backup_best_models(EMA_CONFIG['checkpoint_dir'])
    
    # Model size analysis
    print("\n📊 MODEL SIZE ANALYSIS:")
    print("="*40)
    
    if checkpoint_data:
        latest_checkpoint = max(checkpoint_data, key=lambda x: x['epoch'])
        print(f"Latest checkpoint size: {latest_checkpoint['file_size']:.1f} MB")
        
        total_size = sum(cp['file_size'] for cp in checkpoint_data)
        print(f"Total checkpoint storage: {total_size:.1f} MB")
        
        avg_size = total_size / len(checkpoint_data)
        print(f"Average checkpoint size: {avg_size:.1f} MB")
    
    # Training efficiency metrics
    if 'training_history' in globals() and 'total_time' in globals():
        print("\n⚡ TRAINING EFFICIENCY:")
        print("="*40)
        
        total_epochs = len(training_history['val_accuracy'])
        time_per_epoch = total_time / total_epochs
        
        print(f"Total training time: {total_time/3600:.2f} hours")
        print(f"Time per epoch: {time_per_epoch:.1f} seconds")
        print(f"Final accuracy: {training_history['val_accuracy'][-1]:.2f}%")
        print(f"Best accuracy: {max(training_history['val_accuracy']):.2f}%")
        
        # Accuracy per hour
        acc_per_hour = max(training_history['val_accuracy']) / (total_time / 3600)
        print(f"Accuracy gained per hour: {acc_per_hour:.2f}%/hr")
    
    # Disk usage summary
    print("\n💾 STORAGE SUMMARY:")
    print("="*40)
    
    checkpoint_size = sum(os.path.getsize(os.path.join(EMA_CONFIG['checkpoint_dir'], f)) 
                         for f in os.listdir(EMA_CONFIG['checkpoint_dir']) 
                         if f.endswith('.pth')) / (1024 * 1024)  # MB
    
    mae_checkpoint_size = 0
    if 'MAE_CONFIG' in globals() and os.path.exists(MAE_CONFIG['checkpoint_dir']):
        mae_checkpoint_size = sum(os.path.getsize(os.path.join(MAE_CONFIG['checkpoint_dir'], f)) 
                                 for f in os.listdir(MAE_CONFIG['checkpoint_dir']) 
                                 if f.endswith('.pth')) / (1024 * 1024)  # MB
    
    total_storage = checkpoint_size + mae_checkpoint_size
    
    print(f"EMA checkpoints: {checkpoint_size:.1f} MB")
    print(f"MAE checkpoints: {mae_checkpoint_size:.1f} MB")
    print(f"Total storage used: {total_storage:.1f} MB")
    
else:
    print("⚠️  No checkpoint directory found. Training may not have completed.")

print(f"\n✅ Training monitoring and checkpoint management complete!")
print(f"📊 All analysis saved to checkpoint directories")
print(f"💾 Best models backed up to Google Drive")
print(f"🎯 Ready for model evaluation and deployment!")

## 🧪 Section 11: Evaluate Final Model Performance

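Evaluates the best EMA student checkpoint on a 20% slice of the labeled images. The cell reports top-1, top-3 and top-5 accuracy, per-class precision, recall and F1, and the mean prediction confidence for correct versus incorrect predictions, then saves a summary figure and a JSON results file to the checkpoint directory.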
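For reference, top-k accuracy counts a prediction as correct when the true class is among the k highest-scoring classes. The evaluation cell uses scikit-learn's `top_k_accuracy_score` for this; the NumPy helper below is only an illustrative sketch of the same idea, and the name `top_k_accuracy` and the toy arrays are made up for the example.

```python
import numpy as np

def top_k_accuracy(probs: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of rows whose true label is among the k highest-scoring classes."""
    k = min(k, probs.shape[1])
    # argpartition moves the k largest scores of each row into the last k columns
    top_k = np.argpartition(probs, -k, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Toy example: 4 images, 3 classes (values are illustrative, not real results)
probs = np.array([
    [0.70, 0.20, 0.10],
    [0.10, 0.30, 0.60],
    [0.40, 0.35, 0.25],
    [0.05, 0.50, 0.45],
])
labels = np.array([0, 1, 1, 0])

print(top_k_accuracy(probs, labels, k=1))  # 0.25
print(top_k_accuracy(probs, labels, k=2))  # 0.75
```

Top-k accuracy can never decrease as k grows, so a large gap between top-1 and top-3 usually means the model confuses a few visually similar species rather than failing outright.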

In [None]:
# Comprehensive Model Evaluation and Performance Analysis
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
from sklearn.metrics import top_k_accuracy_score
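import pandas as pd
from PIL import Image
from torchvision import transforms  # used by evaluate_model_comprehensive
from tqdm.auto import tqdm          # progress bar for the evaluation loop
import glob                         # used by create_test_dataset
import json
import random
import os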

print("🧪 COMPREHENSIVE MODEL EVALUATION")
print("="*60)

def load_best_model(checkpoint_path, device):
    """Load the best trained model"""
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        return None, None
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    print(f"📂 Loading checkpoint from epoch {checkpoint['epoch']}")
    print(f"🏆 Best accuracy: {checkpoint['best_accuracy']:.2f}%")
    
    # Recreate model architecture
    model = MAEPretrainedViT(
        num_classes=checkpoint['num_classes'],
        mae_encoder=None,  # Will load weights directly
        dropout_rate=0.1
    ).to(device)
    
    # Load state dict
    model.load_state_dict(checkpoint['student_state_dict'])
    model.eval()
    
    class_to_idx = checkpoint.get('class_to_idx', {})
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    
    return model, idx_to_class

def create_test_dataset(labeled_dir, class_to_idx, img_size=224, test_split=0.2):
    """Create a test dataset from labeled data"""

    test_data = []

    for species, class_idx in class_to_idx.items():
        species_path = os.path.join(labeled_dir, species)
        if not os.path.exists(species_path):
            continue

        # Get all images for this species
        images = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
            images.extend(glob.glob(os.path.join(species_path, ext)))

        # glob order is filesystem-dependent, and case-insensitive filesystems
        # match both '*.jpg' and '*.JPG': de-duplicate and sort for a stable split
        images = sorted(set(images))
        if not images:
            continue

        # Take last 20% as test set (assuming first 80% used for training)
        test_count = max(1, int(len(images) * test_split))
        test_images = images[-test_count:]

        for img_path in test_images:
            test_data.append((img_path, class_idx))

    print(f"📊 Created test set with {len(test_data)} images")
    return test_data

def evaluate_model_comprehensive(model, test_data, idx_to_class, device, img_size=224):
    """Comprehensive model evaluation"""

    model.eval()

    # Prepare data transforms
    test_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    all_predictions = []
    all_labels = []
    all_probabilities = []
    prediction_details = []

    print(f"🧪 Evaluating on {len(test_data)} test images...")

    with torch.no_grad():
        for img_path, true_label in tqdm(test_data, desc="Evaluating"):
            try:
                # Load and preprocess image
                image = Image.open(img_path).convert('RGB')
                image_tensor = test_transform(image).unsqueeze(0).to(device)

                # Get model prediction
                outputs = model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)
                predicted_class = torch.argmax(outputs, dim=1).item()
                confidence = probabilities[0, predicted_class].item()

                # Resolve names before appending so a failed lookup skips the
                # whole sample and the result lists stay the same length
                true_species = idx_to_class[true_label]
                predicted_species = idx_to_class[predicted_class]

                all_predictions.append(predicted_class)
                all_labels.append(true_label)
                all_probabilities.append(probabilities.cpu().numpy()[0])

                prediction_details.append({
                    'image_path': img_path,
                    'true_label': true_label,
                    'true_species': true_species,
                    'predicted_label': predicted_class,
                    'predicted_species': predicted_species,
                    'confidence': confidence,
                    'correct': predicted_class == true_label
                })

            except Exception as e:
                print(f"⚠️  Error processing {img_path}: {e}")
                continue

    return all_predictions, all_labels, all_probabilities, prediction_details

def safe_mean(values):
    """Mean that returns NaN instead of warning on an empty list"""
    return float(np.mean(values)) if len(values) else float('nan')

def analyze_results(predictions, labels, probabilities, prediction_details, idx_to_class):
    """Analyze evaluation results"""

    # Pass the full class list to scikit-learn so the metrics still line up
    # when some species are missing from the test set
    class_ids = list(range(len(idx_to_class)))
    class_names = [idx_to_class[i] for i in class_ids]

    # Basic metrics
    accuracy = np.mean(np.array(predictions) == np.array(labels)) * 100

    # Top-k accuracy
    top3_acc = top_k_accuracy_score(labels, probabilities, k=3, labels=class_ids) * 100
    top5_acc = top_k_accuracy_score(labels, probabilities, k=5, labels=class_ids) * 100

    print(f"\n📊 EVALUATION RESULTS")
    print("="*50)
    print(f"🎯 Top-1 Accuracy: {accuracy:.2f}%")
    print(f"🎯 Top-3 Accuracy: {top3_acc:.2f}%")
    print(f"🎯 Top-5 Accuracy: {top5_acc:.2f}%")

    # Per-class metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, labels=class_ids, average=None, zero_division=0
    )

    # Create detailed classification report
    report = classification_report(
        labels, predictions,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    # Convert to DataFrame for better visualization
    report_df = pd.DataFrame(report).transpose()

    print(f"\n📊 PER-CLASS PERFORMANCE (Top 10 by F1-Score):")
    print("-"*70)

    # Keep only the per-class rows (drops the summary rows), sorted by F1-score
    class_metrics = report_df.loc[class_names].sort_values('f1-score', ascending=False)
    top_classes = class_metrics.head(10)

    for idx, (species, metrics) in enumerate(top_classes.iterrows()):
        print(f"{idx+1:2d}. {species[:25]:<25} Precision: {metrics['precision']:.3f} "
              f"Recall: {metrics['recall']:.3f} F1: {metrics['f1-score']:.3f}")

    # Confidence analysis
    confidences = [detail['confidence'] for detail in prediction_details]
    correct_confidences = [detail['confidence'] for detail in prediction_details if detail['correct']]
    incorrect_confidences = [detail['confidence'] for detail in prediction_details if not detail['correct']]

    print(f"\n📊 CONFIDENCE ANALYSIS:")
    print("-"*40)
    print(f"Average confidence (all): {safe_mean(confidences):.3f}")
    print(f"Average confidence (correct): {safe_mean(correct_confidences):.3f}")
    print(f"Average confidence (incorrect): {safe_mean(incorrect_confidences):.3f}")

    return {
        'accuracy': accuracy,
        'top3_accuracy': top3_acc,
        'top5_accuracy': top5_acc,
        'report_df': report_df,
        'class_metrics': class_metrics,
        'confidences': confidences,
        'correct_confidences': correct_confidences,
        'incorrect_confidences': incorrect_confidences
    }

def visualize_evaluation_results(results, predictions, labels, idx_to_class, save_dir):
    """Create comprehensive evaluation visualizations"""

    # Create figure with multiple subplots
    fig = plt.figure(figsize=(20, 15))

    # 1. Confusion Matrix
    ax1 = plt.subplot(3, 3, 1)
    class_ids = sorted(idx_to_class)
    cm = confusion_matrix(labels, predictions, labels=class_ids)

    # Normalize each row; classes with no test samples keep an all-zero row
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = cm.astype('float') / np.maximum(row_sums, 1)

    # Only show top 15 classes for readability
    top_class_names = set(results['class_metrics'].head(15).index)
    top_class_indices = [i for i in class_ids if idx_to_class[i] in top_class_names]

    if len(top_class_indices) > 1:
        cm_subset = cm_normalized[np.ix_(top_class_indices, top_class_indices)]
        class_names_subset = [idx_to_class[i][:10] for i in top_class_indices]

        sns.heatmap(cm_subset, annot=True, fmt='.2f', cmap='Blues',
                   xticklabels=class_names_subset, yticklabels=class_names_subset,
                   ax=ax1)
        ax1.set_title('Confusion Matrix (Top 15 Classes)', fontweight='bold')
        ax1.set_xlabel('Predicted')
        ax1.set_ylabel('Actual')

    # 2. F1-score by class
    ax2 = plt.subplot(3, 3, 2)
    top_f1_classes = results['class_metrics'].head(15)
    ax2.barh(range(len(top_f1_classes)), top_f1_classes['f1-score'])
    ax2.set_yticks(range(len(top_f1_classes)))
    ax2.set_yticklabels([name[:15] for name in top_f1_classes.index])
    ax2.set_xlabel('F1-Score')
    ax2.set_title('F1-Score by Species (Top 15)', fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # 3. Confidence distribution
    ax3 = plt.subplot(3, 3, 3)
    ax3.hist(results['correct_confidences'], bins=30, alpha=0.7, label='Correct', color='green')
    ax3.hist(results['incorrect_confidences'], bins=30, alpha=0.7, label='Incorrect', color='red')
    ax3.set_xlabel('Confidence')
    ax3.set_ylabel('Frequency')
    ax3.set_title('Confidence Distribution', fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # 4. Top-k accuracy
    ax4 = plt.subplot(3, 3, 4)
    k_values = [1, 3, 5]
    k_accuracies = [results['accuracy'], results['top3_accuracy'], results['top5_accuracy']]
    ax4.bar(k_values, k_accuracies, color=['blue', 'orange', 'green'])
    ax4.set_xlabel('K (Top-K)')
    ax4.set_ylabel('Accuracy (%)')
    ax4.set_title('Top-K Accuracy', fontweight='bold')
    ax4.grid(True, alpha=0.3)

    # 5. Precision vs Recall scatter
    ax5 = plt.subplot(3, 3, 5)
    # 'class_metrics' already excludes the summary rows, so use every row
    class_metrics = results['class_metrics']
    scatter = ax5.scatter(class_metrics['recall'], class_metrics['precision'],
                         c=class_metrics['f1-score'], cmap='viridis', alpha=0.7)
    ax5.set_xlabel('Recall')
    ax5.set_ylabel('Precision')
    ax5.set_title('Precision vs Recall by Species', fontweight='bold')
    plt.colorbar(scatter, ax=ax5, label='F1-Score')
    ax5.grid(True, alpha=0.3)

    # 6. Support distribution
    ax6 = plt.subplot(3, 3, 6)
    support_counts = results['class_metrics']['support']
    ax6.hist(support_counts, bins=20, color='purple', alpha=0.7)
    ax6.set_xlabel('Number of Test Samples')
    ax6.set_ylabel('Number of Species')
    ax6.set_title('Test Sample Distribution', fontweight='bold')
    ax6.grid(True, alpha=0.3)

    # 7. Performance summary text
    ax7 = plt.subplot(3, 3, 7)
    ax7.axis('off')

    accuracy = results['accuracy']
    quality = ("Excellent" if accuracy > 90 else "Good" if accuracy > 80
               else "Fair" if accuracy > 70 else "Needs Improvement")
    # Joining the names works even when fewer than three classes were evaluated
    best_species = "\n".join(f"{name[:20]}..." for name in results['class_metrics'].head(3).index)

    summary_text = f"""Model Performance Summary

🎯 Overall Accuracy: {results['accuracy']:.2f}%
🎯 Top-3 Accuracy: {results['top3_accuracy']:.2f}%
🎯 Top-5 Accuracy: {results['top5_accuracy']:.2f}%

📊 Best Performing Species:
{best_species}

💡 Average Confidence:
✅ Correct: {safe_mean(results['correct_confidences']):.3f}
❌ Incorrect: {safe_mean(results['incorrect_confidences']):.3f}

📈 Model Quality:
{quality}
    """

    ax7.text(0.1, 0.9, summary_text, transform=ax7.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    plt.tight_layout()

    # Save visualization
    viz_path = os.path.join(save_dir, 'evaluation_results.png')
    plt.savefig(viz_path, dpi=300, bbox_inches='tight')
    print(f"📊 Evaluation visualization saved to: {viz_path}")

    plt.show()

    return fig

# Run comprehensive evaluation
if 'EMA_BEST_MODEL_PATH' in globals() and os.path.exists(EMA_BEST_MODEL_PATH):
    print(f"🧪 Loading best EMA model: {EMA_BEST_MODEL_PATH}")

    # Load the best model
    best_model, idx_to_class = load_best_model(EMA_BEST_MODEL_PATH, DEVICE)

    if best_model is not None:
        # Create test dataset
        test_data = create_test_dataset(LABELED_DIR,
                                       {v: k for k, v in idx_to_class.items()},
                                       img_size=224)

        if test_data:
            # Run comprehensive evaluation
            predictions, labels, probabilities, prediction_details = evaluate_model_comprehensive(
                best_model, test_data, idx_to_class, DEVICE
            )

            # Analyze results
            results = analyze_results(predictions, labels, probabilities,
                                    prediction_details, idx_to_class)

            # Create visualizations
            if 'EMA_CONFIG' in globals():
                fig = visualize_evaluation_results(results, predictions, labels,
                                                  idx_to_class, EMA_CONFIG['checkpoint_dir'])

            # Save detailed results
            if 'EMA_CONFIG' in globals():
                results_path = os.path.join(EMA_CONFIG['checkpoint_dir'], 'evaluation_results.json')

                # Prepare results for JSON serialization
                json_results = {
                    'overall_accuracy': float(results['accuracy']),
                    'top3_accuracy': float(results['top3_accuracy']),
                    'top5_accuracy': float(results['top5_accuracy']),
                    'average_confidence_correct': safe_mean(results['correct_confidences']),
                    'average_confidence_incorrect': safe_mean(results['incorrect_confidences']),
                    'total_test_samples': len(test_data),
                    'num_classes': len(idx_to_class)
                }

                import json
                with open(results_path, 'w') as f:
                    json.dump(json_results, f, indent=2)

                print(f"📊 Detailed results saved to: {results_path}")

            print(f"\n🎉 EVALUATION COMPLETED SUCCESSFULLY!")
            print(f"🏆 Final model achieved {results['accuracy']:.2f}% accuracy")
            print(f"🚀 Model ready for deployment!")

        else:
            print("❌ No test data available for evaluation")

    else:
        print("❌ Could not load the best model")

else:
    print("⚠️  Best model checkpoint not found. Please complete training first.")

print(f"\n✅ Model evaluation and analysis complete!")
print(f"📊 All evaluation results saved to checkpoint directory")
print(f"🎯 Ready for model deployment and production use!")
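
The confidence analysis above compares mean confidence for correct and incorrect predictions. One common way to condense that comparison into a single number is expected calibration error (ECE): bin predictions by confidence and average the gap between accuracy and confidence in each bin, weighted by bin size. The notebook does not compute ECE, so the helper below is a hypothetical addition, sketched with NumPy only; the function name, the 10-bin default and the toy values are assumptions for illustration.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Weighted mean |accuracy - confidence| over equal-width confidence bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)

    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Half-open bins (lo, hi]; softmax top-1 confidence is always > 0
        in_bin = (confidences > lo) & (confidences <= hi)
        if not in_bin.any():
            continue
        gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
        ece += in_bin.mean() * gap  # in_bin.mean() is the share of samples in the bin
    return float(ece)

# Toy values, not real results: two confident hits, one hit and one miss at 0.55
toy_conf = [0.95, 0.95, 0.55, 0.55]
toy_correct = [True, True, True, False]
print(f"ECE: {expected_calibration_error(toy_conf, toy_correct):.3f}")  # ECE: 0.050
```

With the evaluation outputs, the inputs would come from `prediction_details`, for example `[d['confidence'] for d in prediction_details]` and `[d['correct'] for d in prediction_details]`. A value near zero suggests the reported confidences can be read roughly as probabilities of being correct.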