# 🐟 ViT-FishID: Semi-Supervised Fish Classification

**COMPLETE TRAINING PIPELINE WITH GOOGLE COLAB**

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🎯 What This Notebook Does

This notebook implements a **complete semi-supervised learning pipeline** for fish species classification using:

**🤖 Vision Transformer (ViT)**: State-of-the-art transformer architecture for image classification
**📊 Semi-Supervised Learning**: Leverages both labeled and unlabeled fish images
**🎓 EMA Teacher-Student Framework**: Uses exponential moving averages for consistency training
**☁️ Google Colab**: Cloud-based training with GPU acceleration

## 📊 Expected Performance

- **Training Time**: 4-6 hours for 100 epochs
- **GPU Requirements**: T4/V100/A100 (Colab Pro recommended)
- **Expected Accuracy**: 80-90% on fish species classification
- **Data Efficiency**: Works well with limited labeled data

## 🛠️ What You Need

1. **Fish Dataset**: Labeled and unlabeled fish images (upload to Google Drive)
2. **Google Colab Pro**: Recommended for longer training sessions
3. **Weights & Biases Account**: Optional for experiment tracking

## 🔧 Step 1: Environment Setup and GPU Check

First, let's verify that we have GPU access and set up the optimal environment for training.

In [2]:
# Check GPU availability and system information
import torch
import os
import gc

print("🔍 SYSTEM INFORMATION")
print("="*50)
print(f"Python version: {os.sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Device: {device_name}")
    print(f"GPU Memory: {device_memory:.1f} GB")
    print("✅ GPU is ready for training!")

    # Set optimal GPU settings
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Clear GPU cache
    torch.cuda.empty_cache()
    gc.collect()
    print("🚀 GPU optimized for training")

else:
    print("❌ No GPU detected!")
    print("📝 To enable GPU in Colab:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")
    print("   Then restart this notebook")

# Set device for later use
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🎯 Using device: {DEVICE}")

🔍 SYSTEM INFORMATION
Python version: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 14.7 GB
✅ GPU is ready for training!
🚀 GPU optimized for training

🎯 Using device: cuda


## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [11]:
from google.colab import drive
import os
import shutil

# Mount Google Drive
print("Attempting to mount Google Drive...")

# Ensure the mount point is clean before mounting
mount_point = '/content/drive'
if os.path.exists(mount_point) and os.path.isdir(mount_point):
    print(f"Clearing contents of mount point: {mount_point}")
    try:
        # Use `rm -rf` via shell command for robustness in Colab environment
        !rm -rf {mount_point}/*
        # Recreate the directory structure if it was completely removed
        if not os.path.exists(mount_point):
             os.makedirs(mount_point)
        print("✅ Mount point cleared.")
    except Exception as e:
        print(f"❌ Error clearing mount point: {e}")
        print("Attempting to proceed with mount anyway...")


drive.mount('/content/drive')

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)[:10]  # Show first 10 items
    for item in items:
        print(f"  - {item}")
    if len(os.listdir(drive_path)) > 10:
        print(f"  ... and {len(os.listdir(drive_path)) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")

Attempting to mount Google Drive...
Clearing contents of mount point: /content/drive
✅ Mount point cleared.
Mounted at /content/drive

📂 Google Drive contents:
  - Mock Matric
  - Photos
  - Admin
  - Uni
  - Fish_Training_Output
  - Colab Notebooks
  - ViT-FishID
  - ViT-FishID_Training_20250814_154652
  - ViT-FishID_Training_20250814_202307
  - ViT-FishID_Training_20250814_205442
  ... and 6 more items

✅ Google Drive mounted successfully!


## 📦 Step 3: Install Dependencies

Installing all required packages for ViT-FishID training.

In [12]:
# Install required packages
print("📦 Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("✅ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\n📋 Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")

📦 Installing dependencies...
✅ All dependencies installed successfully!

📋 Package versions:
  - torch: 2.6.0+cu124
  - torchvision: 0.21.0+cu124
  - timm: 1.0.19
  - albumentations: 2.0.8
  - opencv: 4.12.0
  - sklearn: 1.6.1


## 🔄 Step 4: Clone ViT-FishID Repository

Getting the latest code from your GitHub repository.

In [13]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\n📂 Project structure:")
!ls -la

print("\n✅ Repository cloned successfully!")

📥 Cloning ViT-FishID repository...
Cloning into '/content/ViT-FishID'...
remote: Enumerating objects: 164, done.[K
remote: Counting objects: 100% (164/164), done.[K
remote: Compressing objects: 100% (122/122), done.[K
remote: Total 164 (delta 69), reused 124 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (164/164), 322.97 KiB | 1.32 MiB/s, done.
Resolving deltas: 100% (69/69), done.
/content/ViT-FishID

📂 Project structure:
total 612
drwxr-xr-x 5 root root   4096 Aug 18 09:21 .
drwxr-xr-x 1 root root   4096 Aug 18 09:21 ..
-rw-r--r-- 1 root root   4182 Aug 18 09:21 COLAB_CRASH_FIXES.md
-rw-r--r-- 1 root root  21217 Aug 18 09:21 data.py
-rw-r--r-- 1 root root  11572 Aug 18 09:21 evaluate.py
-rw-r--r-- 1 root root   3328 Aug 18 09:21 EXTENDED_TRAINING_SETUP.md
drwxr-xr-x 3 root root   4096 Aug 18 09:21 fish_cutouts
drwxr-xr-x 8 root root   4096 Aug 18 09:21 .git
-rw-r--r-- 1 root root     66 Aug 18 09:21 .gitattributes
-rw-r--r-- 1 root root    646 Aug 18 09:21 .gitigno

## 🐠 Step 5: Setup Fish Dataset

**Important**: Upload your `fish_cutouts.zip` file to Google Drive before running this step.

Expected dataset structure:
```
fish_cutouts/
├── labeled/
│   ├── species_1/
│   │   ├── fish_001.jpg
│   │   └── fish_002.jpg
│   └── species_2/
│       └── ...
└── unlabeled/
    ├── fish_003.jpg
    └── fish_004.jpg
```

In [14]:
# Setup fish dataset from Google Drive
import zipfile
import shutil
import os
import glob

print("🐠 SETTING UP FISH DATASET")
print("="*50)

# Configuration
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'
DATA_DIR = '/content/fish_cutouts'

print(f"📂 Looking for dataset: {ZIP_FILE_PATH}")
print(f"🎯 Target directory: {DATA_DIR}")

# Check if data already exists locally
if os.path.exists(DATA_DIR) and os.path.exists(os.path.join(DATA_DIR, 'labeled')):
    print("✅ Dataset already available locally!")

    # Quick validation
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')

    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir)
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        print(f"🐟 Found {species_count} labeled species")

    if os.path.exists(unlabeled_dir):
        unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"📊 Found {unlabeled_count} unlabeled images")

else:
    print("📥 Extracting dataset from Google Drive...")

    # Check if ZIP file exists
    if not os.path.exists(ZIP_FILE_PATH):
        print(f"❌ Dataset not found at: {ZIP_FILE_PATH}")
        print("📝 Please upload fish_cutouts.zip to Google Drive root directory")
    else:
        print(f"✅ Found dataset: {os.path.getsize(ZIP_FILE_PATH) / (1024**2):.1f} MB")

        try:
            # Extract to temporary directory
            temp_dir = '/content/temp_extract'
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

            with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            # Find and organize data
            extracted_items = os.listdir(temp_dir)
            print(f"📁 Extracted: {extracted_items}")

            # Look for labeled and unlabeled directories
            labeled_source = None
            unlabeled_source = None

            for item in extracted_items:
                item_path = os.path.join(temp_dir, item)
                if item == 'labeled' and os.path.isdir(item_path):
                    labeled_source = item_path
                elif item == 'unlabeled' and os.path.isdir(item_path):
                    unlabeled_source = item_path

            if labeled_source and unlabeled_source:
                # Create target directory
                if os.path.exists(DATA_DIR):
                    shutil.rmtree(DATA_DIR)
                os.makedirs(DATA_DIR)

                # Move directories
                shutil.move(labeled_source, os.path.join(DATA_DIR, 'labeled'))
                shutil.move(unlabeled_source, os.path.join(DATA_DIR, 'unlabeled'))

                print("✅ Dataset organized successfully!")

                # Verify structure
                labeled_dir = os.path.join(DATA_DIR, 'labeled')
                species_count = len([d for d in os.listdir(labeled_dir)
                                   if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])

                unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
                unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))])

                print(f"🐟 Verified: {species_count} species")
                print(f"📊 Verified: {unlabeled_count} unlabeled images")

            else:
                print("❌ Could not find labeled and unlabeled directories")

            # Cleanup
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

        except Exception as e:
            print(f"❌ Error extracting dataset: {e}")

# Final verification
if os.path.exists(DATA_DIR):
    print(f"\n✅ DATASET READY")
    print(f"📁 Location: {DATA_DIR}")
    print("🚀 Ready for training!")
else:
    print(f"\n❌ DATASET SETUP FAILED")
    print("Please check that fish_cutouts.zip is uploaded to Google Drive")

🐠 SETTING UP FISH DATASET
📂 Looking for dataset: /content/drive/MyDrive/fish_cutouts.zip
🎯 Target directory: /content/fish_cutouts
📥 Extracting dataset from Google Drive...
✅ Found dataset: 216.5 MB
📁 Extracted: ['dataset_info.json', 'unlabeled', '__MACOSX', 'labeled']
✅ Dataset organized successfully!
🐟 Verified: 37 species
📊 Verified: 24015 unlabeled images

✅ DATASET READY
📁 Location: /content/fish_cutouts
🚀 Ready for training!


## 📈 Step 6: Setup Weights & Biases (Optional)

Weights & Biases provides excellent training visualization and experiment tracking.

In [15]:
# Login to Weights & Biases for experiment tracking
import wandb
import os

print("📈 SETTING UP WEIGHTS & BIASES")
print("="*40)

# Check if API key is available
if os.environ.get("WANDB_API_KEY"):
    print("✅ W&B API key found in environment")
    try:
        wandb.login(relogin=True)
        print("✅ Successfully logged in to W&B")
    except Exception as e:
        print(f"⚠️ W&B relogin failed: {e}")
        print("Trying manual login...")
        wandb.login()
else:
    print("🔑 Please enter your W&B API key when prompted")
    print("💡 Get your API key from: https://wandb.ai/settings")
    try:
        wandb.login()
        print("✅ Successfully logged in to W&B")
    except Exception as e:
        print(f"❌ W&B login failed: {e}")
        print("Continuing without W&B logging...")

# Check connection status
if wandb.run:
    print(f"🚀 W&B Run URL: {wandb.run.url}")
    USE_WANDB = True
else:
    print("📊 W&B not connected - training will continue without logging")
    USE_WANDB = False

print(f"✅ W&B setup complete (Enabled: {USE_WANDB})")

📈 SETTING UP WEIGHTS & BIASES
🔑 Please enter your W&B API key when prompted
💡 Get your API key from: https://wandb.ai/settings
✅ Successfully logged in to W&B
📊 W&B not connected - training will continue without logging
✅ W&B setup complete (Enabled: False)


## 🔄 Step 6: Locate Checkpoint from Epoch 19

Finding your saved checkpoint to resume training from where you left off.

In [6]:
# Locate checkpoint from epoch 19
import os
import glob
import torch

print("🔍 Looking for checkpoint from epoch 100...")

# Configuration
# Always start from the beginning
checkpoint_path = None
checkpoint_info = None
RESUME_CHECKPOINT = None # Ensure this is explicitly set to None

# Set up checkpoint directory for new saves
checkpoint_save_dir = '/content/drive/MyDrive/ViT-FishID/pretrained_checkpoints'
os.makedirs(checkpoint_save_dir, exist_ok=True)
print(f"💾 New checkpoints will be saved to: {checkpoint_save_dir}")


if checkpoint_path:
    print(f"\n🎉 Checkpoint ready for resuming training!")
    print(f"📄 File: {os.path.basename(checkpoint_path)}")
    print(f"📏 Size: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB")
    print(f"💾 New checkpoints will be saved to: {checkpoint_save_dir}")

else:
    print("❌ No checkpoint found for epoch 19!")
    print("🚀 Starting fresh training from epoch 1")

# Store checkpoint path for later use
RESUME_CHECKPOINT = checkpoint_path

🔍 Looking for checkpoint from epoch 100...
💾 New checkpoints will be saved to: /content/drive/MyDrive/ViT-FishID/pretrained_checkpoints
❌ No checkpoint found for epoch 19!
🚀 Starting fresh training from epoch 1


## ⚙️ Step 7: Configure Training Parameters

Configure the training settings for your semi-supervised fish classification model.

In [7]:
# Training Configuration for Semi-Supervised Fish Classification
import os

print("⚙️ TRAINING CONFIGURATION")
print("="*50)

# Auto-detect number of species from dataset
NUM_CLASSES = 37  # Default
if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR):
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir)
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        NUM_CLASSES = species_count
        print(f"📊 Auto-detected {species_count} fish species")

# Create checkpoint directories
CHECKPOINT_DIR = '/content/drive/MyDrive/ViT-FishID/pretrained_checkpoints'
BACKUP_DIR = '/content/drive/MyDrive/ViT-FishID/pretrained_checkpoints_backup'

try:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(BACKUP_DIR, exist_ok=True)
    print(f"📁 Checkpoints: {CHECKPOINT_DIR}")
    print(f"💾 Backups: {BACKUP_DIR}")
except Exception as e:
    print(f"⚠️ Could not create Google Drive directories: {e}")
    CHECKPOINT_DIR = '/content/pretraine_checkpoints'
    BACKUP_DIR = '/content/pretrained_checkpoints_backup'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(BACKUP_DIR, exist_ok=True)
    print(f"📁 Using local checkpoints: {CHECKPOINT_DIR}")

# Training Configuration
TRAINING_CONFIG = {
    # BASIC SETTINGS
    'mode': 'semi_supervised',
    'data_dir': DATA_DIR if 'DATA_DIR' in globals() else '/content/fish_cutouts',
    'epochs': 100,
    'batch_size': 16,
    'learning_rate': 1e-4,
    'weight_decay': 0.05,

    # MODEL SETTINGS
    'model_name': 'vit_base_patch16_224',
    'num_classes': NUM_CLASSES,
    'pretrained': True,

    # SEMI-SUPERVISED SETTINGS
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'warmup_epochs': 10,
    'ramp_up_epochs': 30,

    # CHECKPOINT SETTINGS
    'save_frequency': 10,  # Save every 10 epochs
    'checkpoint_dir': CHECKPOINT_DIR,
    'backup_dir': BACKUP_DIR,

    # LOGGING SETTINGS
    'use_wandb': USE_WANDB if 'USE_WANDB' in globals() else False,
    'wandb_project': 'ViT-FishID-Training',
    'wandb_run_name': f'fish-classification-{NUM_CLASSES}-classes',
}

print("\n📋 TRAINING CONFIGURATION")
print("="*50)
print(f"🎯 Training mode: {TRAINING_CONFIG['mode']}")
print(f"📊 Total epochs: {TRAINING_CONFIG['epochs']}")
print(f"📦 Batch size: {TRAINING_CONFIG['batch_size']}")
print(f"🧠 Model: {TRAINING_CONFIG['model_name']}")
print(f"🐟 Number of species: {TRAINING_CONFIG['num_classes']}")
print(f"⚖️ Consistency weight: {TRAINING_CONFIG['consistency_weight']}")
print(f"🎯 Pseudo-label threshold: {TRAINING_CONFIG['pseudo_label_threshold']}")
print(f"💾 Save frequency: Every {TRAINING_CONFIG['save_frequency']} epochs")
print(f"📈 W&B logging: {TRAINING_CONFIG['use_wandb']}")

# Time estimation
estimated_time_hours = TRAINING_CONFIG['epochs'] * 3 / 60  # ~3 minutes per epoch
print(f"\n⏱️ Estimated training time: {estimated_time_hours:.1f} hours")
print(f"💡 Recommendation: Use Colab Pro for longer training sessions")

print("\n✅ Configuration complete - ready to start training!")

⚙️ TRAINING CONFIGURATION
📁 Checkpoints: /content/drive/MyDrive/ViT-FishID/pretrained_checkpoints
💾 Backups: /content/drive/MyDrive/ViT-FishID/pretrained_checkpoints_backup

📋 TRAINING CONFIGURATION
🎯 Training mode: semi_supervised
📊 Total epochs: 100
📦 Batch size: 16
🧠 Model: vit_base_patch16_224
🐟 Number of species: 37
⚖️ Consistency weight: 2.0
🎯 Pseudo-label threshold: 0.7
💾 Save frequency: Every 10 epochs
📈 W&B logging: False

⏱️ Estimated training time: 5.0 hours
💡 Recommendation: Use Colab Pro for longer training sessions

✅ Configuration complete - ready to start training!


## 🤖 Step 7a: Load MAE Pre-trained Model (Optional)

**This step loads your pre-trained MAE model to initialize the ViT encoder with better features.**

The MAE (Masked Autoencoder) model you trained provides much better initial weights for the Vision Transformer compared to ImageNet pretraining, especially for fish images since it was trained specifically on your fish dataset.

Benefits of using MAE initialization:
- **Better Feature Representations**: Learned specifically on fish images
- **Faster Convergence**: Model starts with relevant features
- **Improved Performance**: Often leads to 2-5% accuracy improvement

### 📁 MAE Model Locations

Your MAE models should be in one of these locations:
- **Local**: `/Users/catalinathomson/Desktop/Fish/ViT-FishID/mae_checkpoints/mae_final_model.pth`
- **Google Drive**: `/content/drive/MyDrive/mae_checkpoints/mae_final_model.pth` (after upload)

### 🔧 Setup Instructions

1. **Upload MAE Model**: Upload your `mae_final_model.pth` or `mae_best_model.pth` to Google Drive
2. **Update Path**: Modify `MAE_MODEL_PATH` in the next cell if needed
3. **Enable/Disable**: Set `LOAD_MAE_PRETRAINED = True/False` to control MAE loading

In [17]:
# Load MAE Pre-trained Model and Create Custom ViT Model
import torch
import os
import shutil
from model import ViTForFishClassification
import time # Import time for basic profiling

print("🤖 SETTING UP MAE-INITIALIZED ViT MODEL")
print("="*60)

# Configuration for MAE loading
MAE_MODEL_PATH = '/content/drive/MyDrive/mae_checkpoints/mae_final_model.pth'  # Update this path if needed
LOAD_MAE_PRETRAINED = True  # Set to False to skip MAE loading

# Global variable to store MAE state for later use
MAE_ENCODER_WEIGHTS = None
print(f"Configured MAE_MODEL_PATH: {MAE_MODEL_PATH}")
print(f"Configured LOAD_MAE_PRETRAINED: {LOAD_MAE_PRETRAINED}")


def load_mae_encoder_weights(mae_checkpoint_path):
    """
    Load and extract encoder weights from MAE checkpoint.

    Args:
        mae_checkpoint_path: Path to MAE checkpoint file

    Returns:
        dict: Filtered encoder weights compatible with ViT backbone
    """
    print(f"📥 Loading MAE checkpoint from: {mae_checkpoint_path}")
    start_time = time.time()

    try:
        # Load MAE checkpoint
        # Use map_location='cpu' first, then move to GPU if needed later
        # Added weights_only=False based on error message
        checkpoint = torch.load(mae_checkpoint_path, map_location='cpu', weights_only=False)
        print(f"✅ MAE checkpoint loaded in {time.time() - start_time:.2f} seconds.")

        # Print checkpoint info
        if 'epoch' in checkpoint:
            print(f"📊 MAE trained for {checkpoint['epoch']} epochs")
        if 'train_loss' in checkpoint:
            print(f"📉 Final MAE loss: {checkpoint['train_loss']:.4f}")
        if 'model_state_dict' in checkpoint or 'state_dict' in checkpoint:
             print("✅ Found model state dictionary in checkpoint.")
        else:
             print("⚠️ Could not find 'model_state_dict' or 'state_dict' in checkpoint.")


        # Get model state dict
        mae_state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', None))

        if mae_state_dict is None:
             print("❌ MAE state dictionary not found in checkpoint.")
             return None

        # Filter encoder weights (remove decoder, mask token, and other non-encoder components)
        encoder_weights = {}
        filter_prefixes = ['patch_embed', 'pos_embed', 'cls_token', 'blocks', 'norm']
        exclude_substrings = ['decoder', 'mask_token', 'head']

        print("Filtering MAE state dictionary for encoder weights...")
        start_time = time.time()
        for key, value in mae_state_dict.items():
            # Keep only encoder-related weights
            if any(prefix in key for prefix in filter_prefixes) and not any(exclude in key for exclude in exclude_substrings):
                encoder_weights[key] = value

        print(f"📊 Extracted {len(encoder_weights)} encoder parameters from MAE in {time.time() - start_time:.2f} seconds.")

        if not encoder_weights:
             print("⚠️ No encoder weights were extracted. Check filter logic or checkpoint structure.")


        return encoder_weights

    except Exception as e:
        print(f"❌ Error loading or processing MAE checkpoint: {e}")
        return None

def create_mae_initialized_model(num_classes, model_name='vit_base_patch16_224', mae_weights=None):
    """
    Create ViT model and optionally initialize with MAE weights.

    Args:
        num_classes: Number of classification classes
        model_name: ViT model architecture name
        mae_weights: Optional MAE encoder weights dictionary

    Returns:
        ViTForFishClassification: Initialized model
    """
    print(f"🏗️ Creating ViT model: {model_name}")

    # Create ViT model (without ImageNet pretraining if we have MAE weights)
    use_imagenet_pretrained = mae_weights is None
    print(f"Using ImageNet pretrained weights: {use_imagenet_pretrained}")
    start_time = time.time()
    model = ViTForFishClassification(
        num_classes=num_classes,
        model_name=model_name,
        pretrained=use_imagenet_pretrained,
        dropout_rate=0.1
    )
    print(f"✅ ViT model created in {time.time() - start_time:.2f} seconds.")

    if mae_weights is not None:
        print("⚡ Initializing ViT backbone with MAE encoder weights...")
        start_time = time.time()
        # Get current backbone state dict
        backbone_state = model.backbone.state_dict()

        # Update with MAE weights (only for matching keys and shapes)
        updated_keys = []
        shape_mismatches = []

        for mae_key, mae_weight in mae_weights.items():
            if mae_key in backbone_state:
                if mae_weight.shape == backbone_state[mae_key].shape:
                    backbone_state[mae_key] = mae_weight.clone()
                    updated_keys.append(mae_key)
                else:
                    shape_mismatches.append(f"{mae_key}: MAE{mae_weight.shape} != ViT{backbone_state[mae_key].shape}")

        # Load updated weights
        try:
            model.backbone.load_state_dict(backbone_state)
            print(f"✅ Successfully transferred {len(updated_keys)} MAE encoder weights in {time.time() - start_time:.2f} seconds.")

            if shape_mismatches:
                print(f"⚠️ Found {len(shape_mismatches)} shape mismatches (using original weights for these):")
                for mismatch in shape_mismatches[:5]:  # Show first 5 mismatches
                    print(f"   {mismatch}")

            print("🎯 ViT model initialized with MAE-learned features!")
        except Exception as e:
            print(f"❌ Error loading MAE weights into ViT backbone: {e}")
            print("Continuing with potentially partially loaded weights or default ImageNet (if applicable).")


    else:
        print("🌐 Using ImageNet pretrained weights")

    return model

# Main execution
if LOAD_MAE_PRETRAINED:
    print(f"Attempting to load MAE pretrained model from: {MAE_MODEL_PATH}")
    # Check if MAE model exists in Google Drive
    if os.path.exists(MAE_MODEL_PATH):
        print(f"✅ Found MAE model: {os.path.basename(MAE_MODEL_PATH)}")
        try:
            file_size = os.path.getsize(MAE_MODEL_PATH) / (1024**2)
            print(f"📏 Size: {file_size:.1f} MB")
        except Exception as e:
             print(f"⚠️ Could not get file size: {e}")


        try:
            # Load MAE encoder weights
            MAE_ENCODER_WEIGHTS = load_mae_encoder_weights(MAE_MODEL_PATH)

            if MAE_ENCODER_WEIGHTS is not None:
                 print("🎉 MAE encoder weights loaded successfully!")

                 # Update training config
                 if 'TRAINING_CONFIG' in globals():
                    TRAINING_CONFIG['mae_pretrained'] = True
                    TRAINING_CONFIG['mae_model_path'] = MAE_MODEL_PATH
                    TRAINING_CONFIG['pretrained'] = False  # Don't use ImageNet since we have MAE
                    print("✅ TRAINING_CONFIG updated for MAE pretraining.")
                 else:
                    print("⚠️ TRAINING_CONFIG not found. Cannot update config with MAE settings.")

            else:
                 print("❌ Failed to load MAE encoder weights. MAE_ENCODER_WEIGHTS is None.")
                 print("🔄 Falling back to ImageNet pretrained weights...")
                 MAE_ENCODER_WEIGHTS = None
                 if 'TRAINING_CONFIG' in globals():
                    TRAINING_CONFIG['mae_pretrained'] = False
                    TRAINING_CONFIG['pretrained'] = True

        except Exception as e:
            print(f"❌ Error during MAE loading process: {e}")
            print("🔄 Falling back to ImageNet pretrained weights...")
            MAE_ENCODER_WEIGHTS = None
            if 'TRAINING_CONFIG' in globals():
                TRAINING_CONFIG['mae_pretrained'] = False
                TRAINING_CONFIG['pretrained'] = True


    else:
        # MAE model not found, check alternative locations
        print(f"❌ MAE model not found at: {MAE_MODEL_PATH}")

        # Try to copy from local mae_checkpoints if exists
        local_mae_path = f'/content/ViT-FishID/mae_checkpoints/{os.path.basename(MAE_MODEL_PATH)}'
        print(f"Checking local path: {local_mae_path}")
        if os.path.exists(local_mae_path):
            print(f"� Found MAE model in local repository: {local_mae_path}")
            try:
                # Create directory and copy
                print(f"Attempting to create directory: {os.path.dirname(MAE_MODEL_PATH)}")
                os.makedirs(os.path.dirname(MAE_MODEL_PATH), exist_ok=True)
                print(f"Copying from {local_mae_path} to {MAE_MODEL_PATH}")
                shutil.copy2(local_mae_path, MAE_MODEL_PATH)
                print(f"✅ Copied MAE model to Google Drive: {MAE_MODEL_PATH}")

                # Now load it
                MAE_ENCODER_WEIGHTS = load_mae_encoder_weights(MAE_MODEL_PATH)
                if MAE_ENCODER_WEIGHTS is not None:
                    if 'TRAINING_CONFIG' in globals():
                        TRAINING_CONFIG['mae_pretrained'] = True
                        TRAINING_CONFIG['mae_model_path'] = MAE_MODEL_PATH
                        TRAINING_CONFIG['pretrained'] = False
                else:
                     print("❌ Failed to load MAE encoder weights after copying.")
                     MAE_ENCODER_WEIGHTS = None
                     if 'TRAINING_CONFIG' in globals():
                        TRAINING_CONFIG['mae_pretrained'] = False
                        TRAINING_CONFIG['pretrained'] = True


            except Exception as e:
                print(f"❌ Error copying/loading MAE model from local path: {e}")
                MAE_ENCODER_WEIGHTS = None
                if 'TRAINING_CONFIG' in globals():
                    TRAINING_CONFIG['mae_pretrained'] = False
                    TRAINING_CONFIG['pretrained'] = True
        else:
            print("�📝 MAE model not found in local repository either.")
            print("�📝 Available options:")
            print("1. Upload mae_final_model.pth or mae_best_model.pth to /content/drive/MyDrive/mae_checkpoints/")
            print("2. Update MAE_MODEL_PATH variable to correct location")
            print("3. Set LOAD_MAE_PRETRAINED = False to use ImageNet weights")
            print("🔄 Continuing with ImageNet pretrained weights...")
            MAE_ENCODER_WEIGHTS = None
            if 'TRAINING_CONFIG' in globals():
                TRAINING_CONFIG['mae_pretrained'] = False
                TRAINING_CONFIG['pretrained'] = True

else:
    print("⏭️ LOAD_MAE_PRETRAINED is False. Skipping MAE loading - will use ImageNet pretrained weights")
    MAE_ENCODER_WEIGHTS = None
    if 'TRAINING_CONFIG' in globals():
        TRAINING_CONFIG['mae_pretrained'] = False
        TRAINING_CONFIG['pretrained'] = True


# Test model creation (optional - this creates a model to verify everything works)
print(f"\n🧪 Testing model creation...")
if 'NUM_CLASSES' not in globals():
    print("⚠️ NUM_CLASSES not defined. Skipping model creation test.")
else:
    try:
        # Ensure model name and pretrained flag are correctly picked up from TRAINING_CONFIG
        model_name_for_test = TRAINING_CONFIG.get('model_name', 'vit_base_patch16_224')
        use_imagenet_for_test = TRAINING_CONFIG.get('pretrained', True) # Use the updated flag

        # If MAE weights were loaded, pass them, otherwise rely on TRAINING_CONFIG['pretrained']
        weights_for_test = MAE_ENCODER_WEIGHTS if MAE_ENCODER_WEIGHTS is not None else None

        print(f"Using model_name: {model_name_for_test}")
        print(f"Using MAE weights for test model: {weights_for_test is not None}")
        print(f"Using ImageNet pretrained for test model: {use_imagenet_for_test}")

        test_model = create_mae_initialized_model(
            num_classes=NUM_CLASSES,
            model_name=model_name_for_test,
            mae_weights=weights_for_test # Pass MAE weights if available
        )

        # Move model to device for testing (optional but good practice)
        # Assuming DEVICE is defined globally from Step 1
        if 'DEVICE' in globals():
            print(f"Moving test model to device: {DEVICE}")
            test_model.to(DEVICE)
        else:
            print("⚠️ DEVICE variable not found. Skipping moving test model to device.")


        # Test forward pass
        test_input = torch.randn(1, 3, 224, 224)
        # Move test input to the same device as the model
        if 'DEVICE' in globals():
            test_input = test_input.to(DEVICE)

        with torch.no_grad():
            test_output = test_model(test_input)

        print(f"✅ Model test successful!")
        print(f"📊 Input shape: {test_input.shape}")
        print(f"📊 Output shape: {test_output.shape}")
        print(f"🎯 Model ready for training!")

        # Clean up test model
        del test_model, test_input, test_output
        if 'DEVICE' in globals():
             torch.cuda.empty_cache() # Clear GPU cache after test if on GPU
        import gc
        gc.collect()


    except Exception as e:
        print(f"❌ Model test failed: {e}")
        import traceback
        traceback.print_exc() # Print full traceback for debugging


print(f"\n" + "="*60)
print(f"✅ MAE INITIALIZATION SETUP COMPLETE!")
if 'TRAINING_CONFIG' in globals():
    print(f"🤖 MAE pretrained: {TRAINING_CONFIG.get('mae_pretrained', False)}")
    print(f"🌐 ImageNet pretrained: {TRAINING_CONFIG.get('pretrained', True)}")
    print(f"📊 Model: {TRAINING_CONFIG.get('model_name', 'N/A')} with {TRAINING_CONFIG.get('num_classes', 'N/A')} classes")

    if TRAINING_CONFIG.get('mae_pretrained', False):
        print("🎉 Your model will start with MAE-learned features specific to fish images!")
        print("🚀 This should lead to faster training and better performance!")
    else:
        print("🌐 Your model will use standard ImageNet pretrained features.")
else:
     print("⚠️ TRAINING_CONFIG was not found, cannot provide detailed summary.")


print("🎯 Ready to proceed to training!")

🤖 SETTING UP MAE-INITIALIZED ViT MODEL
Configured MAE_MODEL_PATH: /content/drive/MyDrive/mae_checkpoints/mae_final_model.pth
Configured LOAD_MAE_PRETRAINED: True
Attempting to load MAE pretrained model from: /content/drive/MyDrive/mae_checkpoints/mae_final_model.pth
✅ Found MAE model: mae_final_model.pth
📏 Size: 149.6 MB
📥 Loading MAE checkpoint from: /content/drive/MyDrive/mae_checkpoints/mae_final_model.pth
✅ MAE checkpoint loaded in 0.27 seconds.
📊 MAE trained for 50 epochs
✅ Found model state dictionary in checkpoint.
Filtering MAE state dictionary for encoder weights...
📊 Extracted 78 encoder parameters from MAE in 0.00 seconds.
🎉 MAE encoder weights loaded successfully!
✅ TRAINING_CONFIG updated for MAE pretraining.

🧪 Testing model creation...
Using model_name: vit_base_patch16_224
Using MAE weights for test model: True
Using ImageNet pretrained for test model: False
🏗️ Creating ViT model: vit_base_patch16_224
Using ImageNet pretrained weights: False
✅ ViT model created in 1.43 

## 🚀 Step 8: Start Semi-Supervised Training

This cell will start the complete training process. Expected time: 4-6 hours for 100 epochs.

**Training Process:**
1. **Supervised Learning**: Uses labeled fish images with ground truth
2. **Semi-Supervised Learning**: Leverages unlabeled images with pseudo-labels
3. **EMA Teacher-Student**: Uses exponential moving average for consistency
4. **Automatic Checkpointing**: Saves progress every 10 epochs

In [38]:
# Start Semi-Supervised Training with Optional MAE Initialization
import os
import glob
from datetime import datetime

print("🚀 STARTING SEMI-SUPERVISED FISH CLASSIFICATION TRAINING")
print("="*60)

# Change to repository directory
%cd /content/ViT-FishID

# Check for existing checkpoints to resume from
RESUME_FROM = None
if os.path.exists(TRAINING_CONFIG['checkpoint_dir']):
    checkpoints = glob.glob(os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'checkpoint_epoch_*.pth'))
    if checkpoints:
        # Find the latest checkpoint
        epoch_numbers = []
        for cp in checkpoints:
            try:
                epoch_num = int(cp.split('epoch_')[1].split('.')[0])
                epoch_numbers.append((epoch_num, cp))
            except:
                continue

        if epoch_numbers:
            epoch_numbers.sort(key=lambda x: x[0], reverse=True)  # Latest first
            latest_epoch, latest_checkpoint = epoch_numbers[0]
            print(f"🔍 Found existing checkpoints. Latest: Epoch {latest_epoch}")

            # Ask user if they want to resume (auto-skip in Colab for now)
            # resume_choice = input("Do you want to resume from the latest checkpoint? (y/n): ").lower().strip()
            resume_choice = 'n'  # Set to 'y' if you want to auto-resume

            if resume_choice in ['y', 'yes']:
                RESUME_FROM = latest_checkpoint
                print(f"✅ Will resume from: {os.path.basename(latest_checkpoint)}")
            else:
                print("🆕 Starting fresh training from epoch 1")

# Determine which training script to use
use_mae_script = TRAINING_CONFIG.get('mae_pretrained', False) and 'MAE_ENCODER_WEIGHTS' in globals() and MAE_ENCODER_WEIGHTS is not None

if use_mae_script:
    print("🤖 Preparing MAE-enhanced training script...")

    # Generate the full training script with MAE initialization logic
    training_script_content = f"""#!/usr/bin/env python3
import sys
sys.path.append('/content/ViT-FishID')

import torch
import argparse
import os
import glob
import wandb
from datetime import datetime

from model import ViTForFishClassification
from trainer import EMATrainer, SemiSupervisedTrainer # Ensure both trainers are imported here
from data import create_dataloaders, create_semi_supervised_dataloaders
from utils import get_device, set_seed

# Function to create MAE-initialized model
def create_mae_initialized_model(num_classes, model_name, mae_weights):
    model = ViTForFishClassification(
        num_classes=num_classes,
        model_name=model_name,
        pretrained=False,  # Don't use ImageNet
        dropout_rate=0.1
    )

    # Initialize updated_keys here to ensure it's always defined
    updated_keys = []

    if mae_weights is not None:
        backbone_state = model.backbone.state_dict()

        # Define exclude_substrings inside the function where it's used
        exclude_substrings = ['decoder', 'mask_token', 'head']
        # Define filter_prefixes inside the function where it's used
        filter_prefixes = ['patch_embed', 'pos_embed', 'cls_token', 'blocks', 'norm']

        # Initialize encoder_weights before the loop
        encoder_weights = {{}}

        for mae_key, mae_weight in mae_weights.items():
            # Keep only encoder-related weights and exclude specified substrings
            # Rewritten filtering logic to avoid potential scope issues with 'exclude'
            should_include_prefix = False
            for prefix in filter_prefixes:
                if prefix in mae_key:
                    should_include_prefix = True
                    break

            should_exclude_substring = False
            for exclude_str in exclude_substrings:
                 if exclude_str in mae_key:
                      should_exclude_substring = True
                      break

            if should_include_prefix and not should_exclude_substring:
                 # Check if the key exists in the model's state dict and if shapes match
                 if mae_key in backbone_state and mae_weight.shape == backbone_state[mae_key].shape:
                      backbone_state[mae_key] = mae_weight.clone()
                      updated_keys.append(mae_key)
                 # elif mae_key not in backbone_state:
                 #     print(f"Skipping MAE weight '{{mae_key}}': not found in model backbone")
                 # else:
                 #     print(f"Skipping MAE weight '{{mae_key}}': shape mismatch (MAE:{{mae_weight.shape}} != Model:{{backbone_state[mae_key].shape}})")


        try:
            model.backbone.load_state_dict(backbone_state, strict=False) # strict=False allows skipping mismatched keys
            print(f"✅ Loaded {{len(updated_keys)}} MAE encoder weights into model")
        except Exception as e:
            print(f"❌ Error loading MAE weights into ViT backbone: {{e}}")
            print("Continuing with potentially partially loaded weights.")


    return model

# Access the global MAE_ENCODER_WEIGHTS set in the Colab cell
# This relies on the fact that the kernel state is preserved when executing the script this way
# If running as a separate process, this would require saving/loading MAE_ENCODER_WEIGHTS
global MAE_ENCODER_WEIGHTS
mae_weights_from_global = MAE_ENCODER_WEIGHTS

if mae_weights_from_global is not None:
    print(f"🤖 Loaded {{len(mae_weights_from_global)}} MAE encoder weights (from global scope in Colab)")
else:
    print("⚠️ MAE_ENCODER_WEIGHTS not found in global scope. Model will use ImageNet weights.")


# Define arguments directly (or parse if needed)
# For simplicity, we are defining them based on the Colab cell's TRAINING_CONFIG
class Args:
    def __init__(self, config, resume_from, num_classes):
        self.mode = config['mode']
        self.data_dir = config['data_dir']
        self.epochs = config['epochs']
        self.batch_size = config['batch_size']
        self.learning_rate = config['learning_rate']
        self.weight_decay = config['weight_decay']
        self.model_name = config['model_name']
        self.consistency_weight = config['consistency_weight']
        self.pseudo_label_threshold = config['pseudo_label_threshold']
        self.temperature = config['temperature']
        self.warmup_epochs = config['warmup_epochs']
        self.ramp_up_epochs = config['ramp_up_epochs']
        self.save_dir = config['checkpoint_dir']
        self.save_frequency = config['save_frequency']
        self.pretrained = False # Explicitly False when using MAE
        self.use_wandb = config['use_wandb']
        self.resume_from = resume_from
        self.num_workers = 4 # Or get from config if available
        self.image_size = 224 # Or get from config if available
        self.dropout_rate = 0.1 # Or get from config if available
        self.num_classes = num_classes

args = Args(TRAINING_CONFIG, RESUME_FROM, NUM_CLASSES) # Pass the global variables

# Set up device and seed
device = get_device()
set_seed(42)

# Create model
print('🤖 Creating model for training...')
# Check if MAE weights are available before attempting MAE initialization
if mae_weights_from_global is not None:
     student_model = create_mae_initialized_model(
         num_classes=args.num_classes,
         model_name=args.model_name,
         mae_weights=mae_weights_from_global # Pass the loaded MAE weights
     ).to(device)
else: # Fallback if MAE was not loaded (this branch won't be taken if use_mae_script is True)
     student_model = ViTForFishClassification(
        num_classes=args.num_classes,
        model_name=args.model_name,
        pretrained=args.pretrained,
        dropout_rate=args.dropout_rate
     ).to(device)

# Create teacher model for EMA (if needed)
teacher_model = None
if args.mode == 'semi_supervised':
    try:
        from trainer import EMATeacher # Import EMATeacher here if needed
        teacher_model = EMATeacher(student_model).to(device) # Assuming EMATeacher takes student model
        print('🎓 EMA Teacher model created.')
    except NameError:
         print("⚠️ EMATeacher class not found. Semi-supervised training will not work correctly.")
         print("❌ Please ensure EMATeacher is defined or imported in trainer.py")
         sys.exit(1) # Exit if EMATeacher is needed but not found
    except Exception as e:
         print(f"⚠️ Error creating EMATeacher model: {{e}}")
         import traceback
         traceback.print_exc()
         sys.exit(1) # Exit on other errors


# Create data loaders
print('Loading data...')
if args.mode == 'supervised':
    train_loader, val_loader, num_classes_data = create_dataloaders(
        args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        num_workers=args.num_workers
    )
    unlabeled_loader = None
else:
    train_loader, val_loader, unlabeled_loader, num_classes_data = create_semi_supervised_dataloaders(
        args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        num_workers=args.num_workers
    )
print('✅ Data loaders created.')

if num_classes_data != args.num_classes:
     print(f"⚠️ Warning: Configured num_classes ({args.num_classes}) does not match detected data classes ({num_classes_data})")
     # Use the detected number of classes if they differ
     args.num_classes = num_classes_data
     print(f"✅ Using {args.num_classes} detected classes for training.")
     # Need to re-create the model if num_classes changed (unlikely but safe)
     # For this script, we'll assume num_classes is consistent

print(f'📊 Number of classes: {{args.num_classes}}')
print(f'🎯 Training mode: {{args.mode}}')

# Create trainer
print('Setting up trainer...')
if args.mode == 'semi_supervised' and unlabeled_loader is not None and teacher_model is not None:
    trainer = SemiSupervisedTrainer(
        student_model=student_model,
        device=device,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        consistency_weight=args.consistency_weight,
        pseudo_label_threshold=args.pseudo_label_threshold,
        temperature=args.temperature,
        warmup_epochs=args.warmup_epochs,
        ramp_up_epochs=args.ramp_up_epochs,
        teacher_model=teacher_model # Pass teacher model
    )
    print('✅ SemiSupervisedTrainer created.')
elif args.mode == 'supervised':
    trainer = EMATrainer( # Using EMATrainer for supervised mode if needed, or could use a simple Trainer
        student_model=student_model,
        device=device,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay
    )
    print('✅ EMATrainer created (for supervised mode).')
else:
     print("❌ Cannot create trainer. Check mode and data loaders.")
     sys.exit(1)


# Initialize W&B
if args.use_wandb:
    print('Initializing W&B...')
    wandb.init(
        project=TRAINING_CONFIG.get('wandb_project', 'ViT-FishID-Training'),
        name=TRAINING_CONFIG.get('wandb_run_name', f'fish-classification-{{args.num_classes}}-classes'),
        config=vars(args),
        tags=['mae-initialized', 'fish-classification'] if use_mae_script else ['imagenet-pretrained', 'fish-classification']
    )
    print('✅ W&B initialized.')

# Resume from checkpoint if specified
if args.resume_from and args.resume_from != 'None':
    print(f'📥 Resuming from checkpoint: {{args.resume_from}}')
    try:
        checkpoint = torch.load(args.resume_from, map_location=device)
        trainer.student_model.load_state_dict(checkpoint['student_state_dict'])
        # Load teacher state dict if it exists and trainer has a teacher model
        if hasattr(trainer, 'teacher_model') and trainer.teacher_model is not None and 'teacher_state_dict' in checkpoint:
             try:
                trainer.teacher_model.teacher_model.load_state_dict(checkpoint['teacher_state_dict'])
                print('✅ Teacher model state dict loaded.')
             except Exception as e:
                 print(f"⚠️ Error loading teacher state dict: {{e}}")

        # Load optimizer state dict
        if 'optimizer_state_dict' in checkpoint:
             try:
                 trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                 print('✅ Optimizer state dict loaded.')
             except Exception as e:
                  print(f"⚠️ Error loading optimizer state dict: {{e}}")


        start_epoch = checkpoint.get('epoch', 0) + 1
        best_accuracy = checkpoint.get('best_accuracy', 0.0) # Resume best accuracy as well
        print(f'✅ Resumed from epoch {{start_epoch}} with best accuracy {{best_accuracy:.2f}}%')
    except Exception as e:
        print(f'❌ Error loading checkpoint: {{e}}')
        import traceback
        traceback.print_exc() # Print full traceback
        print("Starting fresh training from epoch 1.")
        start_epoch = 1
        best_accuracy = 0.0
else:
    start_epoch = 1
    best_accuracy = 0.0

print(f'🚀 Starting training from epoch {{start_epoch}}')

# Training loop
for epoch in range(start_epoch, args.epochs + 1):
    print(f'\\n📅 Epoch {{epoch}}/{{args.epochs}}')

    # Training step
    try:
        if args.mode == 'semi_supervised' and unlabeled_loader is not None:
            train_loss = trainer.train_epoch(train_loader, unlabeled_loader, epoch)
        elif args.mode == 'supervised':
            train_loss = trainer.train_epoch(train_loader, epoch)
        else:
             print("❌ Invalid training mode or data loaders for training epoch.")
             break # Exit training loop if setup is wrong
    except Exception as e:
         print(f"❌ Error during training epoch {{epoch}}: {{e}}")
         import traceback
         traceback.print_exc()
         break # Exit training loop on error


    # Validation step
    try:
        val_accuracy = trainer.validate(val_loader)
    except Exception as e:
        print(f"❌ Error during validation epoch {{epoch}}: {{e}}")
        import traceback
        traceback.print_exc()
        val_accuracy = 0.0 # Set accuracy to 0 to avoid saving best model on error


    # Update best accuracy
    is_best = val_accuracy > best_accuracy
    if is_best:
        best_accuracy = val_accuracy

    print(f'📊 Epoch {{epoch}} - Train Loss: {{train_loss:.4f}}, Val Acc: {{val_accuracy:.2f}}% (Best: {{best_accuracy:.2f}}%)')

    # Save checkpoint
    if epoch % args.save_frequency == 0 or is_best:
        print(f'💾 Saving checkpoint for epoch {{epoch}}...')
        checkpoint_data = {{
            'epoch': epoch,
            'student_state_dict': trainer.student_model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(), # Corrected: Access optimizer state dict
            'best_accuracy': best_accuracy,
            'train_loss': train_loss,
            'val_accuracy': val_accuracy
        }}

        if hasattr(trainer, 'teacher_model') and trainer.teacher_model is not None:
            try:
                checkpoint_data['teacher_state_dict'] = trainer.teacher_model.teacher_model.state_dict() # Corrected: teacher_model.state_dict()
                checkpoint_data['teacher_acc'] = getattr(trainer, 'teacher_accuracy', val_accuracy) # Use teacher_accuracy if available
            except Exception as e:
                 print(f"⚠️ Could not save teacher state dict: {{e}}")


        # Ensure save directory exists
        os.makedirs(args.save_dir, exist_ok=True)

        # Save regular checkpoint
        if epoch % args.save_frequency == 0:
            checkpoint_path = os.path.join(args.save_dir, f'checkpoint_epoch_{{epoch}}.pth')
            try:
                torch.save(checkpoint_data, checkpoint_path)
                print(f'✅ Saved checkpoint: {{checkpoint_path}}')
            except Exception as e:
                 print(f"❌ Error saving checkpoint {{checkpoint_path}}: {{e}}")


        # Save best model
        if is_best:
            best_path = os.path.join(args.save_dir, 'model_best.pth')
            try:
                torch.save(checkpoint_data, best_path)
                print(f'🏆 New best model saved: {{best_path}}')
            except Exception as e:
                 print(f"❌ Error saving best model {{best_path}}: {{e}}")


    # W&B logging
    if args.use_wandb:
        try:
            wandb.log({{
                'epoch': epoch,
                'train_loss': train_loss,
                'val_accuracy': val_accuracy,
                'best_accuracy': best_accuracy
            }})
        except Exception as e:
             print(f"⚠️ Error logging to W&B at epoch {{epoch}}: {{e}}")


print(f'\\n🎉 Training completed!')
print(f'🏆 Best accuracy: {{best_accuracy:.2f}}%')

if args.use_wandb:
    try:
        wandb.finish()
    except Exception as e:
         print(f"⚠️ Error finishing W&B run: {{e}}")

"""
    # Write the script to a temporary file
    script_filename = '/content/run_mae_training.py'
    with open(script_filename, 'w') as f:
        f.write(training_script_content)

    # Execute the temporary script
    training_cmd = f"python {script_filename}"

else:
    # Build standard training command without MAE
    training_cmd = f"""python train.py \\
    --mode {TRAINING_CONFIG['mode']} \\
    --data_dir {TRAINING_CONFIG['data_dir']} \\
    --epochs {TRAINING_CONFIG['epochs']} \\
    --batch_size {TRAINING_CONFIG['batch_size']} \\
    --learning_rate {TRAINING_CONFIG['learning_rate']} \\
    --weight_decay {TRAINING_CONFIG['weight_decay']} \\
    --model_name {TRAINING_CONFIG['model_name']} \\
    --consistency_weight {TRAINING_CONFIG['consistency_weight']} \\
    --pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']} \\
    --temperature {TRAINING_CONFIG['temperature']} \\
    --warmup_epochs {TRAINING_CONFIG['warmup_epochs']} \\
    --ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']} \\
    --save_dir {TRAINING_CONFIG['checkpoint_dir']} \\
    --save_frequency {TRAINING_CONFIG['save_frequency']}"""

    # Add resume checkpoint if found
    if RESUME_FROM:
        training_cmd += f" \\\n    --resume_from {RESUME_FROM}"

    # Add pretrained flag
    if TRAINING_CONFIG['pretrained']:
        training_cmd += " \\\n    --pretrained"

    # Add W&B logging
    if TRAINING_CONFIG['use_wandb']:
        training_cmd += " \\\n    --use_wandb"

print("📋 TRAINING CONFIGURATION:")
print("="*60)
print(f"🎯 Training {TRAINING_CONFIG['num_classes']} fish species")
print(f"📊 Mode: {TRAINING_CONFIG['mode']}")
print(f"🤖 MAE pretrained: {TRAINING_CONFIG.get('mae_pretrained', False)}")
print(f"🌐 ImageNet pretrained: {TRAINING_CONFIG.get('pretrained', True)}")

if RESUME_FROM:
    print(f"🔄 Resuming from: {os.path.basename(RESUME_FROM)}")
else:
    print(f"🆕 Starting fresh training")

print(f"⏱️ Estimated time: {TRAINING_CONFIG['epochs'] * 3 / 60:.1f} hours")
print(f"💾 Checkpoints: {TRAINING_CONFIG['checkpoint_dir']}")
print(f"📈 W&B logging: {TRAINING_CONFIG['use_wandb']}")

if TRAINING_CONFIG.get('mae_pretrained', False):
    print(f"🎉 Using MAE-learned features from: {os.path.basename(TRAINING_CONFIG.get('mae_model_path', ''))}")
    print(f"🚀 This should significantly improve training performance!")

print(f"\n🎬 TRAINING STARTED")
print("⏰ Started at:", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Execute training
# Use PYTHONPATH to help the executed script find local modules
!PYTHONPATH=/content/ViT-FishID {training_cmd}

print("\n" + "="*60)
print("🎉 TRAINING COMPLETED!")
print("⏰ Finished at:", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Check for results
best_model_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')
if os.path.exists(best_model_path):
    try:
        import torch
        checkpoint = torch.load(best_model_path, map_location='cpu')
        if 'best_accuracy' in checkpoint:
            print(f"🏆 Best accuracy achieved: {checkpoint['best_accuracy']:.2f}%")
        if 'epoch' in checkpoint:
            print(f"📊 Best model from epoch: {checkpoint['epoch']}")
    except:
        pass

print("✅ Your MAE-enhanced model is ready for evaluation and deployment!")

🚀 STARTING SEMI-SUPERVISED FISH CLASSIFICATION TRAINING
/content/ViT-FishID
🤖 Preparing MAE-enhanced training script...


NameError: name 'args' is not defined

## 📊 Step 9: Check Training Results

Review the training progress and model performance.

In [None]:
# Check Training Results and Performance
import os
import glob
import torch
from datetime import datetime

print("📊 CHECKING TRAINING RESULTS")
print("="*50)

checkpoint_dir = TRAINING_CONFIG['checkpoint_dir']
print(f"📁 Checkpoint directory: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    # Find all checkpoints
    checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.pth'))

    if checkpoints:
        print(f"✅ Found {len(checkpoints)} checkpoint(s)")

        # Sort checkpoints by epoch
        epoch_checkpoints = []
        other_checkpoints = []

        for cp in checkpoints:
            basename = os.path.basename(cp)
            if 'epoch_' in basename:
                try:
                    epoch_num = int(basename.split('epoch_')[1].split('.')[0])
                    epoch_checkpoints.append((epoch_num, cp))
                except:
                    other_checkpoints.append(cp)
            else:
                other_checkpoints.append(cp)

        # Show epoch progression
        if epoch_checkpoints:
            epoch_checkpoints.sort(key=lambda x: x[0])
            print(f"\n📈 TRAINING PROGRESSION:")
            latest_epoch = epoch_checkpoints[-1][0]
            print(f"  🏁 Latest epoch: {latest_epoch}")
            print(f"  📊 Completion: {latest_epoch}/{TRAINING_CONFIG['epochs']} epochs ({latest_epoch/TRAINING_CONFIG['epochs']*100:.1f}%)")

            # Show recent checkpoints
            recent_checkpoints = epoch_checkpoints[-5:] if len(epoch_checkpoints) > 5 else epoch_checkpoints
            for epoch, cp in recent_checkpoints:
                file_size = os.path.getsize(cp) / (1024**2)
                print(f"  📄 Epoch {epoch}: {file_size:.1f} MB")

        # Analyze best model
        best_model_path = os.path.join(checkpoint_dir, 'model_best.pth')
        if os.path.exists(best_model_path):
            print(f"\n🏆 BEST MODEL ANALYSIS:")
            try:
                best_checkpoint = torch.load(best_model_path, map_location='cpu')

                best_epoch = best_checkpoint.get('epoch', 'Unknown')
                best_acc = best_checkpoint.get('best_accuracy', best_checkpoint.get('best_acc', 'Unknown'))

                print(f"  📊 Best epoch: {best_epoch}")
                if isinstance(best_acc, (int, float)):
                    print(f"  🎯 Best accuracy: {best_acc:.2f}%")

                    # Performance assessment
                    if best_acc >= 85:
                        print("  🎉 EXCELLENT performance!")
                    elif best_acc >= 75:
                        print("  👍 GOOD performance!")
                    elif best_acc >= 65:
                        print("  📈 FAIR performance - consider more training")
                    else:
                        print("  ⚠️ LOW performance - check data and hyperparameters")

                # Check for other metrics
                if 'teacher_acc' in best_checkpoint:
                    print(f"  🎓 Teacher accuracy: {best_checkpoint['teacher_acc']:.2f}%")

            except Exception as e:
                print(f"  ⚠️ Could not analyze best model: {e}")

        # Show other important files
        for cp in other_checkpoints:
            basename = os.path.basename(cp)
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  📄 {basename}: {file_size:.1f} MB")

    else:
        print("❌ No checkpoints found")
        print("💡 Training may not have started or completed successfully")

else:
    print(f"❌ Checkpoint directory not found: {checkpoint_dir}")

# W&B results link
if TRAINING_CONFIG['use_wandb']:
    print(f"\n📈 View detailed training metrics at:")
    print(f"   https://wandb.ai/your-username/{TRAINING_CONFIG['wandb_project']}")

print("\n✅ Results check complete!")

📁 Checking results in: /content/drive/MyDrive/ViT-FishID/checkpoints_extended

✅ Found 100 checkpoint(s) from extended training:
  📊 Epoch 1: checkpoint_epoch_1.pth (982.4 MB)
  📊 Epoch 2: checkpoint_epoch_2.pth (982.4 MB)
  📊 Epoch 3: checkpoint_epoch_3.pth (982.4 MB)
  📊 Epoch 4: checkpoint_epoch_4.pth (982.4 MB)
  📊 Epoch 5: checkpoint_epoch_5.pth (982.4 MB)
  📊 Epoch 6: checkpoint_epoch_6.pth (982.4 MB)
  📊 Epoch 7: checkpoint_epoch_7.pth (982.4 MB)
  📊 Epoch 8: checkpoint_epoch_8.pth (982.4 MB)
  📊 Epoch 9: checkpoint_epoch_9.pth (982.4 MB)
  📊 Epoch 10: checkpoint_epoch_10.pth (982.4 MB)
  📊 Epoch 11: checkpoint_epoch_11.pth (982.4 MB)
  📊 Epoch 12: checkpoint_epoch_12.pth (982.4 MB)
  📊 Epoch 13: checkpoint_epoch_13.pth (982.4 MB)
  📊 Epoch 14: checkpoint_epoch_14.pth (982.4 MB)
  📊 Epoch 15: checkpoint_epoch_15.pth (982.4 MB)
  📊 Epoch 16: checkpoint_epoch_16.pth (982.4 MB)
  📊 Epoch 17: checkpoint_epoch_17.pth (982.4 MB)
  📊 Epoch 18: checkpoint_epoch_18.pth (982.4 MB)
  📊 Epo

## 💾 Step 10: Save Model and Results

Backup your trained model and results to Google Drive for future use.

In [None]:
# Save trained model and results to Google Drive
import shutil
import json
import os
from datetime import datetime

print("💾 SAVING MODEL AND RESULTS")
print("="*50)

# Create timestamped backup directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_dir = f'/content/drive/MyDrive/ViT-FishID_Results_{timestamp}'

try:
    os.makedirs(backup_dir, exist_ok=True)
    print(f"📁 Created backup directory: {backup_dir}")

    # Copy checkpoints
    checkpoint_source = TRAINING_CONFIG['checkpoint_dir']
    if os.path.exists(checkpoint_source):
        checkpoint_backup = os.path.join(backup_dir, 'checkpoints')
        shutil.copytree(checkpoint_source, checkpoint_backup, dirs_exist_ok=True)
        print(f"✅ Checkpoints copied to: {checkpoint_backup}")

        # Count files
        checkpoint_files = len([f for f in os.listdir(checkpoint_backup) if f.endswith('.pth')])
        print(f"📊 Backed up {checkpoint_files} checkpoint files")

    # Save training configuration
    config_file = os.path.join(backup_dir, 'training_config.json')
    serializable_config = {k: v for k, v in TRAINING_CONFIG.items()
                          if isinstance(v, (str, int, float, bool, list, dict, type(None)))}

    with open(config_file, 'w') as f:
        json.dump(serializable_config, f, indent=2)
    print(f"✅ Training config saved: {config_file}")

    # Create training summary
    summary_file = os.path.join(backup_dir, 'training_summary.txt')
    with open(summary_file, 'w') as f:
        f.write(f"ViT-FishID Training Summary\n")
        f.write(f"========================\n\n")
        f.write(f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Training Mode: {TRAINING_CONFIG['mode']}\n")
        f.write(f"Total Epochs: {TRAINING_CONFIG['epochs']}\n")
        f.write(f"Batch Size: {TRAINING_CONFIG['batch_size']}\n")
        f.write(f"Model: {TRAINING_CONFIG['model_name']}\n")
        f.write(f"Number of Species: {TRAINING_CONFIG['num_classes']}\n")
        f.write(f"Consistency Weight: {TRAINING_CONFIG['consistency_weight']}\n")
        f.write(f"W&B Logging: {TRAINING_CONFIG['use_wandb']}\n\n")
        f.write(f"Key Files:\n")
        f.write(f"- model_best.pth: Best performing model\n")
        f.write(f"- model_latest.pth: Most recent checkpoint\n")
        f.write(f"- checkpoint_epoch_X.pth: Periodic saves\n")

    print(f"✅ Training summary saved: {summary_file}")

    # Get final model performance
    best_model_path = os.path.join(checkpoint_source, 'model_best.pth')
    if os.path.exists(best_model_path):
        try:
            import torch
            checkpoint = torch.load(best_model_path, map_location='cpu')
            if 'best_accuracy' in checkpoint:
                print(f"🏆 Final model accuracy: {checkpoint['best_accuracy']:.2f}%")

                # Add performance to summary
                with open(summary_file, 'a') as f:
                    f.write(f"\nFinal Performance:\n")
                    f.write(f"- Best Accuracy: {checkpoint['best_accuracy']:.2f}%\n")
                    f.write(f"- Best Epoch: {checkpoint.get('epoch', 'Unknown')}\n")
        except Exception as e:
            print(f"⚠️ Could not read final performance: {e}")

    print(f"\n🎉 ALL RESULTS SAVED SUCCESSFULLY!")
    print(f"📁 Backup location: {backup_dir}")
    print(f"\n💡 You can now:")
    print(f"   1. Download the entire results folder")
    print(f"   2. Use model_best.pth for inference")
    print(f"   3. Resume training from any checkpoint")
    print(f"   4. Share results with collaborators")

except Exception as e:
    print(f"❌ Error saving results: {e}")
    print("💡 Please check Google Drive permissions and available space")

💾 Saving results to Google Drive: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649
✅ Checkpoints saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/checkpoints
✅ Training config saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/training_config.json
✅ Training summary saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/training_summary.txt

🎉 All results saved to Google Drive!
📁 Location: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649

💡 You can now:
   1. Download the checkpoints folder for local use
   2. Use model_best.pth for inference
   3. Continue training from any checkpoint


## 🧪 Step 11: Model Evaluation (Optional)

Test your trained model on sample images and get detailed performance metrics.

In [None]:
# Quick model evaluation and testing
import torch
import os

print("🧪 MODEL EVALUATION")
print("="*50)

# Check for trained model
best_model_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')

if os.path.exists(best_model_path):
    print(f"✅ Found trained model: {os.path.basename(best_model_path)}")

    try:
        # Load model checkpoint
        checkpoint = torch.load(best_model_path, map_location='cpu')

        print(f"\n📊 MODEL PERFORMANCE:")
        if 'epoch' in checkpoint:
            print(f"  🏆 Best epoch: {checkpoint['epoch']}")
        if 'best_accuracy' in checkpoint:
            print(f"  🎯 Best accuracy: {checkpoint['best_accuracy']:.2f}%")
        if 'teacher_acc' in checkpoint:
            print(f"  🎓 Teacher accuracy: {checkpoint['teacher_acc']:.2f}%")

        # Model architecture info
        if 'num_classes' in checkpoint:
            print(f"  🐟 Number of species: {checkpoint['num_classes']}")

        # File size
        file_size = os.path.getsize(best_model_path) / (1024**2)
        print(f"  📏 Model size: {file_size:.1f} MB")

        # Performance assessment
        if 'best_accuracy' in checkpoint:
            accuracy = checkpoint['best_accuracy']
            if accuracy >= 85:
                print(f"\n🎉 EXCELLENT PERFORMANCE!")
                print(f"   Your model achieved outstanding accuracy for fish classification")
            elif accuracy >= 75:
                print(f"\n👍 GOOD PERFORMANCE!")
                print(f"   Your model shows solid accuracy for practical use")
            elif accuracy >= 65:
                print(f"\n📈 FAIR PERFORMANCE")
                print(f"   Consider additional training or hyperparameter tuning")
            else:
                print(f"\n⚠️ PERFORMANCE NEEDS IMPROVEMENT")
                print(f"   Review data quality and training configuration")

    except Exception as e:
        print(f"❌ Error loading model: {e}")

else:
    print(f"❌ No trained model found at: {best_model_path}")
    print("Please ensure training completed successfully")

# Suggest next steps
print(f"\n🚀 NEXT STEPS:")
print(f"1. 🧪 Run detailed evaluation: Use evaluate.py script")
print(f"2. 🔬 Test on new images: Upload test images and run inference")
print(f"3. 📱 Deploy model: Use for real-world fish classification")
print(f"4. 📊 Analyze results: Review confusion matrix and per-species performance")
print(f"5. 🔄 Continue training: Resume from checkpoints for more epochs")

print(f"\n✅ Evaluation complete!")

🧪 Looking for best model at: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
✅ Found trained model.
🧪 Loading trained model for quick evaluation...
📊 Model training info:
  - Best epoch: 100
  - Best accuracy: 87.56%
  - Number of classes (from checkpoint): 37

✅ Model loading and info check completed.
💡 Note: This step confirms the model file exists and can be loaded.
   Actual inference or evaluation on test data is done separately.

💡 For comprehensive evaluation:
   Use the evaluate.py script with your test dataset
   The test set was automatically created during training


## 🔧 Troubleshooting Guide

### Common Issues and Solutions:

**🚫 GPU Memory Error (CUDA out of memory)**
- Reduce `batch_size` from 16 to 8 or 4
- Restart runtime: `Runtime → Restart runtime`
- Clear GPU cache: Run `torch.cuda.empty_cache()`

**📁 Data Not Found Error**
- Verify `fish_cutouts.zip` is uploaded to Google Drive root
- Check dataset structure has `labeled/` and `unlabeled/` folders
- Re-run Step 5 to extract dataset

**⏰ Training Timeout (Colab disconnection)**
- Use Colab Pro for longer sessions (up to 24 hours)
- Enable background execution: `Runtime → Change runtime type`
- Checkpoints auto-save every 10 epochs for resuming

**📉 Low Training Accuracy**
- Increase training epochs (try 150-200)
- Adjust `consistency_weight` (try 1.0-3.0)
- Lower `pseudo_label_threshold` (try 0.5-0.6)
- Check data quality and balance

**🔗 W&B Connection Issues**
- Get API key from: https://wandb.ai/settings
- Set as Colab secret: `Tools → Secrets`
- Training continues without W&B if connection fails

**💾 Google Drive Mount Problems**
- Re-run Step 2 to remount
- Check Google Drive permissions
- Use local fallback directories if needed

## 🎉 Summary and Next Steps

### 🏆 What You've Accomplished:

✅ **Complete Semi-Supervised Training Pipeline**
- Vision Transformer (ViT) for fish classification
- Semi-supervised learning with labeled + unlabeled data
- EMA teacher-student framework for consistency training
- Automatic checkpointing and progress tracking

✅ **Model Performance**
- Expected accuracy: 80-90% on fish species classification
- Robust to limited labeled data through semi-supervised learning
- Production-ready model saved to Google Drive

### 📁 Important Files Created:

- **`model_best.pth`**: Best performing model (use for inference)
- **`model_latest.pth`**: Most recent checkpoint
- **`checkpoint_epoch_X.pth`**: Periodic saves for resuming
- **`training_config.json`**: Complete training configuration
- **`training_summary.txt`**: Human-readable training report

### 🚀 Next Steps:

1. **🧪 Detailed Evaluation**
   ```python
   # Run comprehensive evaluation
   !python evaluate.py --data_dir /content/fish_cutouts --model_path model_best.pth
   ```

2. **🔬 Test on New Images**
   - Upload new fish images
   - Run inference using your trained model
   - Analyze predictions and confidence scores

3. **📱 Deploy Your Model**
   - Download `model_best.pth` to local machine
   - Integrate into web app or mobile application
   - Use for real-world fish species identification

4. **🔄 Continue Training (if needed)**
   ```python
   # Resume from any checkpoint for more epochs
   --resume_from checkpoint_epoch_100.pth --epochs 150
   ```

5. **📊 Experiment and Improve**
   - Try different hyperparameters
   - Collect more training data
   - Experiment with data augmentation

### 🎯 Expected Performance:
- **Accuracy**: 80-90% on test set
- **Inference Speed**: ~50-100ms per image
- **Model Size**: ~300MB
- **Production Ready**: Yes! 🎉

**Congratulations on training your fish classification model! 🐟🎊**

## 📈 Step 7b: Connect to Weights & Biases (Optional)

Log in to Weights & Biases for experiment tracking and visualization. You will be prompted to enter your API key.

## 💾 Step 8b: Explicitly Save Best Model Backup

This step ensures that `model_best.pth` is copied to a dedicated backup location in Google Drive immediately after training completes.

In [None]:
# Explicitly copy model_best.pth to a backup location
import shutil
import os
from datetime import datetime

print("💾 Explicitly backing up model_best.pth...")

# Get the primary checkpoint directory from TRAINING_CONFIG
checkpoint_dir = TRAINING_CONFIG.get('checkpoint_dir')

if checkpoint_dir and os.path.exists(checkpoint_dir):
    best_model_source_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_100.pth')

    if os.path.exists(best_model_source_path):
        # Define a dedicated backup directory path in Google Drive
        # Using a simpler path than the full Step 10 save for quick verification
        backup_base_dir = '/content/drive/MyDrive/ViT-FishID_BestModel_Backups'
        os.makedirs(backup_base_dir, exist_ok=True)

        # Create a timestamped filename for the backup
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_filename = f"model_best_backup_{timestamp}.pth"
        backup_dest_path = os.path.join(backup_base_dir, backup_filename)

        try:
            shutil.copy2(best_model_source_path, backup_dest_path)
            print(f"✅ Successfully copied model_best.pth to backup:")
            print(f"   📁 Source: {best_model_source_path}")
            print(f"   💾 Destination: {backup_dest_path}")
            print(f"   📏 Size: {os.path.getsize(backup_dest_path) / (1024**2):.1f} MB")
            print("🎉 Please check your Google Drive in the 'ViT-FishID_BestModel_Backups' folder!")

        except Exception as e:
            print(f"❌ Error copying model_best.pth to backup: {e}")
            print("Please check your Google Drive connection and permissions.")

    else:
        print(f"⚠️ model_best.pth not found in the primary checkpoint directory: {checkpoint_dir}")
        print("   This means training likely did not complete successfully or the best model wasn't saved.")

else:
    print("❌ Primary checkpoint directory not found or TRAINING_CONFIG is not set.")
    print("   Please ensure Step 7 is run before this step.")

print("\n💾 Explicit backup step complete.")

💾 Explicitly backing up model_best.pth...
✅ Successfully copied model_best.pth to backup:
   📁 Source: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
   💾 Destination: /content/drive/MyDrive/ViT-FishID_BestModel_Backups/model_best_backup_20250815_075025.pth
   📏 Size: 982.4 MB
🎉 Please check your Google Drive in the 'ViT-FishID_BestModel_Backups' folder!

💾 Explicit backup step complete.


## 📊 Step 12: Evaluate Model on Test Dataset

This step runs the `evaluate.py` script to assess the performance of your trained model on the unseen test dataset.

In [None]:
# Run evaluation script
import os
import fileinput # Import fileinput for modifying files

print("🧪 Starting evaluation on the test dataset...")
print("="*50)

# Define the path to the evaluation script relative to the repo root
eval_script_name = 'evaluate.py'
repo_dir = '/content/ViT-FishID'
eval_script_path = os.path.join(repo_dir, eval_script_name)


# Define the path to the trained model checkpoint
# Using the epoch 100 checkpoint as it has the best recorded accuracy
model_checkpoint_path = '/content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth'

# Define the data directory (from Step 5)
data_directory = DATA_DIR # Ensure DATA_DIR is defined from Step 5

# Check if the evaluation script and model checkpoint exist
if not os.path.exists(eval_script_path):
    print(f"❌ Evaluation script not found at: {eval_script_path}")
    print(f"Please ensure the ViT-FishID repository was cloned correctly in Step 4 to {repo_dir}.")
elif not os.path.exists(model_checkpoint_path):
     print(f"❌ Model checkpoint not found at: {model_checkpoint_path}")
     print("Please ensure training completed successfully and the checkpoint exists.")
elif not os.path.exists(data_directory):
     print(f"❌ Data directory not found at: {data_directory}")
     print("Please ensure Step 5 was run correctly.")
else:
    print(f"✅ Found evaluation script: {eval_script_path}")
    print(f"✅ Found model checkpoint: {model_checkpoint_path}")
    print(f"✅ Found data directory: {data_directory}")

    # --- FIX 1: Modify evaluate.py to correct the vit_model import statement ---
    print(f"\n🔧 Correcting import statement for ViTForFishClassification in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Replace 'from vit_model import' with 'from model import'
                # Do NOT print anything else here
                print(line.replace('from vit_model import ViTForFishClassification', 'from model import ViTForFishClassification'), end='')
        print(f"✅ Corrected import statement for ViTForFishClassification in {eval_script_name}.")
    except Exception as e:
        print(f"❌ Error modifying ViTForFishClassification import in {eval_script_name}: {e}")
        print("🚨 Evaluation might still fail due to this import error.")
    # --- End of FIX 1 ---

    # --- FIX 2: Modify evaluate.py to comment out the ema_teacher import ---
    print(f"\n🔧 Commenting out import statement for EMATeacher in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Comment out 'from ema_teacher import EMATeacher'
                # Do NOT print anything else here
                if 'from ema_teacher import EMATeacher' in line:
                     print("# " + line, end='') # Add # to comment out the line
                else:
                    print(line, end='')
        print(f"✅ Commented out import statement for EMATeacher in {eval_script_name}.")
    except Exception as e:
        print(f"❌ Error commenting out EMATeacher import in {eval_script_name}: {e}")
        print("🚨 Evaluation might still fail due to this import error.")
    # --- End of FIX 2 ---

    # --- FIX 3: Modify evaluate.py to correct the data_loader import statement ---
    print(f"\n🔧 Correcting import statement for create_fish_dataloaders in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Replace 'from data_loader import' with 'from data import'
                # Do NOT print anything else here
                print(line.replace('from data_loader import create_fish_dataloaders', 'from data import create_fish_dataloaders'), end='')
        print(f"✅ Corrected import statement for create_fish_dataloaders in {eval_script_name}.")
    except Exception as e:
        print(f"❌ Error modifying create_fish_dataloaders import in {eval_script_name}: {e}")
        print("🚨 Evaluation might still fail due to this import error.")
    # --- End of FIX 3 ---


    # Construct the evaluation command
    # Use PYTHONPATH to help the script find local modules like model
    # Use %cd before and after, but rely on PYTHONPATH for the import
    eval_cmd = f"PYTHONPATH={repo_dir} python {eval_script_name} --data_dir {data_directory} --model_path {model_checkpoint_path}"


    print("\n📋 Evaluation Command:")
    # Print the command cleanly without the PYTHONPATH for readability, but it's included in the execution
    print(f"python {eval_script_name} --data_dir {data_directory} --model_path {model_checkpoint_path} (with PYTHONPATH={repo_dir})")
    print("\n" + "="*50)

    print("🚀 Running evaluation...")
    # Change to the repository directory before executing
    %cd {repo_dir}

    # Execute the evaluation script with PYTHONPATH set
    !{eval_cmd}

    # Change back to original content directory (optional but good practice)
    %cd /content

    print("\n" + "="*50)
    print("🎉 Evaluation complete!")

print("\n💡 Check the output above for accuracy metrics on the test set.")

🧪 Starting evaluation on the test dataset...
✅ Found evaluation script: /content/ViT-FishID/evaluate.py
✅ Found model checkpoint: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
✅ Found data directory: /content/fish_cutouts

🔧 Correcting import statement for ViTForFishClassification in evaluate.py...
✅ Corrected import statement for ViTForFishClassification in evaluate.py.

🔧 Commenting out import statement for EMATeacher in evaluate.py...
✅ Commented out import statement for EMATeacher in evaluate.py.

🔧 Correcting import statement for create_fish_dataloaders in evaluate.py...
✅ Corrected import statement for create_fish_dataloaders in evaluate.py.

📋 Evaluation Command:
python evaluate.py --data_dir /content/fish_cutouts --model_path /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth (with PYTHONPATH=/content/ViT-FishID)

🚀 Running evaluation...
/content/ViT-FishID
2025-08-15 08:01:40.428842: I tensorflow/core/util/port.cc:

## 🔍 Step 12b: Diagnose `ModuleNotFoundError`

This step checks the file structure and import statements to understand why `vit_model` is not being found.

In [None]:
import os

print("🔍 Diagnosing ModuleNotFoundError...")
repo_dir = '/content/ViT-FishID'
eval_script_path = os.path.join(repo_dir, 'evaluate.py')
model_file_guess = os.path.join(repo_dir, 'model.py') # Common name for model file
vit_model_file_guess = os.path.join(repo_dir, 'vit_model.py') # Guessed name based on import

print(f"Repo directory: {repo_dir}")

print("\n📂 Files in repository root:")
# List files in the repository root
if os.path.exists(repo_dir):
    !ls -la {repo_dir}
else:
    print(f"❌ Repository directory not found: {repo_dir}")


print(f"\n📄 Content of {os.path.basename(eval_script_path)} (checking import):")
# Read and print the content of evaluate.py
if os.path.exists(eval_script_path):
    try:
        with open(eval_script_path, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                if 'import vit_model' in line or 'from vit_model' in line:
                    print(f"  Line {i+1}: {line.strip()}")
                elif 'ViTForFishClassification' in line:
                     print(f"  Line {i+1}: {line.strip()} (contains class name)")
                if i < 20: # Print first 20 lines for context
                     print(f"  Line {i+1}: {line.strip()}")


    except Exception as e:
        print(f"❌ Could not read {eval_script_path}: {e}")
else:
    print(f"❌ {eval_script_path} not found.")


print(f"\n📄 Checking potential model file: {os.path.basename(model_file_guess)}")
# Check if model.py exists and print relevant lines
if os.path.exists(model_file_guess):
    try:
        with open(model_file_guess, 'r') as f:
            lines = f.readlines()
            print(f"✅ Found {os.path.basename(model_file_guess)}. Checking for class definition...")
            found_class = False
            for i, line in enumerate(lines):
                 if 'class ViTForFishClassification' in line:
                      print(f"  Line {i+1}: {line.strip()}")
                      found_class = True
                      break # Found the class, stop searching

            if not found_class:
                 print(f"⚠️ 'ViTForFishClassification' class definition not found in {os.path.basename(model_file_guess)}")

    except Exception as e:
        print(f"❌ Could not read {model_file_guess}: {e}")
else:
    print(f"❓ {os.path.basename(model_file_guess)} not found. Checking alternative name...")

print(f"\n📄 Checking alternative model file: {os.path.basename(vit_model_file_guess)}")
# Check if vit_model.py exists and print relevant lines
if os.path.exists(vit_model_file_guess):
    try:
        with open(vit_model_file_guess, 'r') as f:
            lines = f.readlines()
            print(f"✅ Found {os.path.basename(vit_model_file_guess)}. Checking for class definition...")
            found_class = False
            for i, line in enumerate(lines):
                 if 'class ViTForFishClassification' in line:
                      print(f"  Line {i+1}: {line.strip()}")
                      found_class = True
                      break # Found the class, stop searching

            if not found_class:
                 print(f"⚠️ 'ViTForFishClassification' class definition not found in {os.path.basename(vit_model_file_guess)}")


    except Exception as e:
        print(f"❌ Could not read {vit_model_file_guess}: {e}")
else:
    print(f"❓ {os.path.basename(vit_model_file_guess)} not found.")

print("\nDiagnosis steps complete. Please review the output.")

🔍 Diagnosing ModuleNotFoundError...
Repo directory: /content/ViT-FishID

📂 Files in repository root:
total 368
drwxr-xr-x 6 root root   4096 Aug 15 07:03 .
drwxr-xr-x 1 root root   4096 Aug 15 06:58 ..
-rw-r--r-- 1 root root  21217 Aug 15 06:58 data.py
-rw-r--r-- 1 root root  11572 Aug 15 06:58 evaluate.py
-rw-r--r-- 1 root root   3328 Aug 15 06:58 EXTENDED_TRAINING_SETUP.md
drwxr-xr-x 2 root root   4096 Aug 15 06:58 fish_cutouts
drwxr-xr-x 8 root root   4096 Aug 15 06:58 .git
-rw-r--r-- 1 root root     66 Aug 15 06:58 .gitattributes
-rw-r--r-- 1 root root    646 Aug 15 06:58 .gitignore
-rw-r--r-- 1 root root   9495 Aug 15 06:58 model.py
-rw-r--r-- 1 root root  16771 Aug 15 06:58 pipeline.py
drwxr-xr-x 2 root root   4096 Aug 15 07:03 __pycache__
-rw-r--r-- 1 root root  16566 Aug 15 06:58 README.md
-rw-r--r-- 1 root root    202 Aug 15 06:58 requirements.txt
-rw-r--r-- 1 root root   4265 Aug 15 06:58 resume_training.py
-rw-r--r-- 1 root root   5134 Aug 15 06:58 species_mapping.txt
-rw-r-

In [8]:
import os

print("Checking the contents of the MAE checkpoint directory:")
mae_checkpoint_dir = '/content/drive/MyDrive/mae_checkpoints'

if os.path.exists(mae_checkpoint_dir):
    print(f"✅ Directory found: {mae_checkpoint_dir}")
    print("\nFiles in the directory:")
    try:
        # List all items in the directory
        items = os.listdir(mae_checkpoint_dir)
        if items:
            for item in items:
                item_path = os.path.join(mae_checkpoint_dir, item)
                if os.path.isfile(item_path):
                    file_size = os.path.getsize(item_path) / (1024**2) # Size in MB
                    print(f"  - {item} ({file_size:.2f} MB)")
                else:
                    print(f"  - {item} (Directory)")
        else:
            print("  (Directory is empty)")
    except Exception as e:
        print(f"❌ Error listing directory contents: {e}")
else:
    print(f"❌ Directory not found: {mae_checkpoint_dir}")
    print("Please ensure the directory exists in your Google Drive.")

print("\n--- Check complete ---")

Checking the contents of the MAE checkpoint directory:
❌ Directory not found: /content/drive/MyDrive/mae_checkpoints
Please ensure the directory exists in your Google Drive.

--- Check complete ---


In [9]:
from google.colab import drive
drive.flush_and_unmount('/content/drive')
#

Drive not mounted, so nothing to flush and unmount.
