# üêü ViT-FishID: Semi-Supervised Fish Classification

**COMPLETE TRAINING PIPELINE WITH GOOGLE COLAB**

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## üéØ What This Notebook Does

This notebook implements a **complete semi-supervised learning pipeline** for fish species classification using:

**ü§ñ Vision Transformer (ViT)**: State-of-the-art transformer architecture for image classification
**üìä Semi-Supervised Learning**: Leverages both labeled and unlabeled fish images
**üéì EMA Teacher-Student Framework**: Uses exponential moving averages for consistency training
**‚òÅÔ∏è Google Colab**: Cloud-based training with GPU acceleration

## üìä Expected Performance

- **Training Time**: 4-6 hours for 100 epochs
- **GPU Requirements**: T4/V100/A100 (Colab Pro recommended)
- **Expected Accuracy**: 80-90% on fish species classification
- **Data Efficiency**: Works well with limited labeled data

## üõ†Ô∏è What You Need

1. **Fish Dataset**: Labeled and unlabeled fish images (upload to Google Drive)
2. **Google Colab Pro**: Recommended for longer training sessions
3. **Weights & Biases Account**: Optional for experiment tracking

## üîß Step 1: Environment Setup and GPU Check

First, let's verify that we have GPU access and set up the optimal environment for training.

In [None]:
# Check GPU availability and system information
import torch
import os
import gc

print("üîç SYSTEM INFORMATION")
print("="*50)
print(f"Python version: {os.sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Device: {device_name}")
    print(f"GPU Memory: {device_memory:.1f} GB")
    print("‚úÖ GPU is ready for training!")
    
    # Set optimal GPU settings
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    
    # Clear GPU cache
    torch.cuda.empty_cache()
    gc.collect()
    print("üöÄ GPU optimized for training")
    
else:
    print("‚ùå No GPU detected!")
    print("üìù To enable GPU in Colab:")
    print("   Runtime ‚Üí Change runtime type ‚Üí Hardware accelerator ‚Üí GPU")
    print("   Then restart this notebook")

# Set device for later use
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nüéØ Using device: {DEVICE}")

üîç System Information:
Python version: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
CUDA available: True
GPU Device: NVIDIA A100-SXM4-40GB
GPU Memory: 39.6 GB
‚úÖ GPU is ready for training!


## üìÅ Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [2]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# List contents to verify mount
print("\nüìÇ Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)[:10]  # Show first 10 items
    for item in items:
        print(f"  - {item}")
    if len(os.listdir(drive_path)) > 10:
        print(f"  ... and {len(os.listdir(drive_path)) - 10} more items")
    print("\n‚úÖ Google Drive mounted successfully!")
else:
    print("‚ùå Failed to mount Google Drive")

Mounted at /content/drive

üìÇ Google Drive contents:
  - Mock Matric
  - Photos
  - Admin
  - Uni
  - Fish_Training_Output
  - Colab Notebooks
  - ViT-FishID
  - fish_cutouts.zip
  - ViT-FishID_Training_20250814_154652
  - ViT-FishID_Training_20250814_202307
  ... and 3 more items

‚úÖ Google Drive mounted successfully!


## üì¶ Step 3: Install Dependencies

Installing all required packages for ViT-FishID training.

In [3]:
# Install required packages
print("üì¶ Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("‚úÖ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\nüìã Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")

üì¶ Installing dependencies...
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m363.4/363.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m13.8/13.8 MB[0m [31m128.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m24.6/24.6 MB[0m [31m100.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m883.7/883.7 kB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m664.8/664.8 MB[0m [31m1

## üîÑ Step 4: Clone ViT-FishID Repository

Getting the latest code from your GitHub repository.

In [4]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("üì• Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\nüìÇ Project structure:")
!ls -la

print("\n‚úÖ Repository cloned successfully!")

üì• Cloning ViT-FishID repository...
Cloning into '/content/ViT-FishID'...
remote: Enumerating objects: 119, done.[K
remote: Counting objects: 100% (119/119), done.[K
remote: Compressing objects: 100% (86/86), done.[K
remote: Total 119 (delta 44), reused 98 (delta 27), pack-reused 0 (from 0)[K
Receiving objects: 100% (119/119), 201.94 KiB | 20.19 MiB/s, done.
Resolving deltas: 100% (44/44), done.
/content/ViT-FishID

üìÇ Project structure:
total 360
drwxr-xr-x 4 root root   4096 Aug 15 06:58 .
drwxr-xr-x 1 root root   4096 Aug 15 06:58 ..
-rw-r--r-- 1 root root  21217 Aug 15 06:58 data.py
-rw-r--r-- 1 root root  11572 Aug 15 06:58 evaluate.py
-rw-r--r-- 1 root root   3328 Aug 15 06:58 EXTENDED_TRAINING_SETUP.md
drwxr-xr-x 2 root root   4096 Aug 15 06:58 fish_cutouts
drwxr-xr-x 8 root root   4096 Aug 15 06:58 .git
-rw-r--r-- 1 root root     66 Aug 15 06:58 .gitattributes
-rw-r--r-- 1 root root    646 Aug 15 06:58 .gitignore
-rw-r--r-- 1 root root   9495 Aug 15 06:58 model.py
-rw-r

## üê† Step 5: Setup Fish Dataset

**Important**: Upload your `fish_cutouts.zip` file to Google Drive before running this step.

Expected dataset structure:
```
fish_cutouts/
‚îú‚îÄ‚îÄ labeled/
‚îÇ   ‚îú‚îÄ‚îÄ species_1/
‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ fish_001.jpg
‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ fish_002.jpg
‚îÇ   ‚îî‚îÄ‚îÄ species_2/
‚îÇ       ‚îî‚îÄ‚îÄ ...
‚îî‚îÄ‚îÄ unlabeled/
    ‚îú‚îÄ‚îÄ fish_003.jpg
    ‚îî‚îÄ‚îÄ fish_004.jpg
```

In [None]:
# Setup fish dataset from Google Drive
import zipfile
import shutil
import os
import glob

print("üê† SETTING UP FISH DATASET")
print("="*50)

# Configuration
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'
DATA_DIR = '/content/fish_cutouts'

print(f"üìÇ Looking for dataset: {ZIP_FILE_PATH}")
print(f"üéØ Target directory: {DATA_DIR}")

# Check if data already exists locally
if os.path.exists(DATA_DIR) and os.path.exists(os.path.join(DATA_DIR, 'labeled')):
    print("‚úÖ Dataset already available locally!")
    
    # Quick validation
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
    
    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir) 
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        print(f"üêü Found {species_count} labeled species")
    
    if os.path.exists(unlabeled_dir):
        unlabeled_count = len([f for f in os.listdir(unlabeled_dir) 
                             if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"üìä Found {unlabeled_count} unlabeled images")

else:
    print("üì• Extracting dataset from Google Drive...")
    
    # Check if ZIP file exists
    if not os.path.exists(ZIP_FILE_PATH):
        print(f"‚ùå Dataset not found at: {ZIP_FILE_PATH}")
        print("üìù Please upload fish_cutouts.zip to Google Drive root directory")
    else:
        print(f"‚úÖ Found dataset: {os.path.getsize(ZIP_FILE_PATH) / (1024**2):.1f} MB")
        
        try:
            # Extract to temporary directory
            temp_dir = '/content/temp_extract'
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            
            with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            
            # Find and organize data
            extracted_items = os.listdir(temp_dir)
            print(f"üìÅ Extracted: {extracted_items}")
            
            # Look for labeled and unlabeled directories
            labeled_source = None
            unlabeled_source = None
            
            for item in extracted_items:
                item_path = os.path.join(temp_dir, item)
                if item == 'labeled' and os.path.isdir(item_path):
                    labeled_source = item_path
                elif item == 'unlabeled' and os.path.isdir(item_path):
                    unlabeled_source = item_path
            
            if labeled_source and unlabeled_source:
                # Create target directory
                if os.path.exists(DATA_DIR):
                    shutil.rmtree(DATA_DIR)
                os.makedirs(DATA_DIR)
                
                # Move directories
                shutil.move(labeled_source, os.path.join(DATA_DIR, 'labeled'))
                shutil.move(unlabeled_source, os.path.join(DATA_DIR, 'unlabeled'))
                
                print("‚úÖ Dataset organized successfully!")
                
                # Verify structure
                labeled_dir = os.path.join(DATA_DIR, 'labeled')
                species_count = len([d for d in os.listdir(labeled_dir) 
                                   if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
                
                unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
                unlabeled_count = len([f for f in os.listdir(unlabeled_dir) 
                                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                
                print(f"üêü Verified: {species_count} species")
                print(f"üìä Verified: {unlabeled_count} unlabeled images")
                
            else:
                print("‚ùå Could not find labeled and unlabeled directories")
            
            # Cleanup
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                
        except Exception as e:
            print(f"‚ùå Error extracting dataset: {e}")

# Final verification
if os.path.exists(DATA_DIR):
    print(f"\n‚úÖ DATASET READY")
    print(f"üìÅ Location: {DATA_DIR}")
    print("üöÄ Ready for training!")
else:
    print(f"\n‚ùå DATASET SETUP FAILED")
    print("Please check that fish_cutouts.zip is uploaded to Google Drive")

üóÇÔ∏è SETTING UP FISH DATASET - CORRECTED PATHS
üéØ ZIP file location: /content/drive/MyDrive/fish_cutouts.zip
üéØ Target data directory: /content/fish_cutouts
üì• Data not found locally, extracting from Google Drive...
‚úÖ Found ZIP file at: /content/drive/MyDrive/fish_cutouts.zip
üìè ZIP file size: 216.5 MB
üì¶ Extracting fish_cutouts.zip...
‚úÖ ZIP extraction completed
üìÅ Found in ZIP: ['dataset_info.json', '__MACOSX', 'labeled', 'unlabeled']
üìÑ Found dataset info: dataset_info.json
‚úÖ Found labeled directory: labeled
‚úÖ Found unlabeled directory: unlabeled
‚úÖ Data organized at: /content/fish_cutouts
üìÑ Copied dataset_info.json
üêü Verified: 37 species in labeled data
üìä Verified: 24015 images in unlabeled data

‚úÖ DATASET READY
üìÅ Location: /content/fish_cutouts
  üìÇ labeled/: 37 species folders
  üìÇ unlabeled/: 24015 images
  üìÑ dataset_info.json: Available
üöÄ Ready to proceed with training!


## üìà Step 6: Setup Weights & Biases (Optional)

Weights & Biases provides excellent training visualization and experiment tracking.

In [None]:
# Login to Weights & Biases for experiment tracking
import wandb
import os

print("üìà SETTING UP WEIGHTS & BIASES")
print("="*40)

# Check if API key is available
if os.environ.get("WANDB_API_KEY"):
    print("‚úÖ W&B API key found in environment")
    try:
        wandb.login(relogin=True)
        print("‚úÖ Successfully logged in to W&B")
    except Exception as e:
        print(f"‚ö†Ô∏è W&B relogin failed: {e}")
        print("Trying manual login...")
        wandb.login()
else:
    print("üîë Please enter your W&B API key when prompted")
    print("üí° Get your API key from: https://wandb.ai/settings")
    try:
        wandb.login()
        print("‚úÖ Successfully logged in to W&B")
    except Exception as e:
        print(f"‚ùå W&B login failed: {e}")
        print("Continuing without W&B logging...")

# Check connection status
if wandb.run:
    print(f"üöÄ W&B Run URL: {wandb.run.url}")
    USE_WANDB = True
else:
    print("üìä W&B not connected - training will continue without logging")
    USE_WANDB = False

print(f"‚úÖ W&B setup complete (Enabled: {USE_WANDB})")

üìà Connecting to Weights & Biases...
üîë Please enter your W&B API key when prompted.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcativthomson[0m ([33mcativthomson-university-of-cape-town[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


‚úÖ Successfully logged in to W&B.
‚ùå W&B connection not established. Logging may be disabled.


## üîÑ Step 6: Locate Checkpoint from Epoch 19

Finding your saved checkpoint to resume training from where you left off.

In [7]:
# Locate checkpoint from epoch 19
import os
import glob
import torch

print("üîç Looking for checkpoint from epoch 100...")

# Possible checkpoint locations
checkpoint_locations = [
    '/content/drive/MyDrive/ViT-FishID/checkpoints_extended', '/content/drive/MyDrive/ViT-FishID/checkpoints_backup'
]

checkpoint_path = None
checkpoint_info = None

# Search for epoch 19 checkpoint
for location_pattern in checkpoint_locations:
    for location in glob.glob(location_pattern):
        if os.path.exists(location):
            print(f"üìÅ Checking: {location}")

            # Look for epoch 19 specifically
            epoch_100_files = glob.glob(os.path.join(location, '*epoch_100*'))
            manual_files = glob.glob(os.path.join(location, '*manual*epoch*100*'))
            emergency_files = glob.glob(os.path.join(location, '*emergency*epoch*100*'))

            all_candidates = epoch_100_files + manual_files + emergency_files

            for candidate in all_candidates:
                if candidate.endswith('.pth'):
                    print(f"üéØ Found candidate: {os.path.basename(candidate)}")
                    try:
                        # Verify checkpoint can be loaded
                        test_checkpoint = torch.load(candidate, map_location='cpu')
                        epoch = test_checkpoint.get('epoch', 'unknown')

                        if epoch == 100 or '100' in os.path.basename(candidate):
                            checkpoint_path = candidate
                            checkpoint_info = test_checkpoint
                            print(f"‚úÖ FOUND EPOCH 100 CHECKPOINT!")
                            print(f"üìÅ Location: {checkpoint_path}")
                            print(f"üìä Epoch: {epoch}")

                            if 'best_accuracy' in test_checkpoint:
                                print(f"üìä Best accuracy so far: {test_checkpoint['best_accuracy']:.2f}%")
                            elif 'best_acc' in test_checkpoint:
                                print(f"üìä Best accuracy so far: {test_checkpoint['best_acc']:.2f}%")

                            break
                    except Exception as e:
                        print(f"‚ö†Ô∏è Could not load {candidate}: {e}")

            if checkpoint_path:
                break

        if checkpoint_path:
            break

if checkpoint_path:
    print(f"\nüéâ Checkpoint ready for resuming training!")
    print(f"üìÑ File: {os.path.basename(checkpoint_path)}")
    print(f"üìè Size: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB")

    # Set up checkpoint directory for new saves
    checkpoint_save_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints_extended'
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    print(f"üíæ New checkpoints will be saved to: {checkpoint_save_dir}")

else:
    print("‚ùå No checkpoint found for epoch 19!")
    print("\nüîß Troubleshooting:")
    print("1. Check that you have a checkpoint saved from previous training")
    print("2. Ensure the checkpoint is uploaded to Google Drive")
    print("3. Look for files named like: checkpoint_epoch_19.pth, emergency_checkpoint_epoch_19.pth")
    print("\nüìÅ Checked locations:")
    for location in checkpoint_locations:
        print(f"  - {location}")

    # Fallback: look for any checkpoints
    print("\nüîç All available checkpoints:")
    for location_pattern in checkpoint_locations:
        for location in glob.glob(location_pattern):
            if os.path.exists(location):
                all_checkpoints = glob.glob(os.path.join(location, '*.pth'))
                for cp in all_checkpoints:
                    print(f"  - {os.path.basename(cp)}")

# Store checkpoint path for later use
RESUME_CHECKPOINT = checkpoint_path

üîç Looking for checkpoint from epoch 100...
üìÅ Checking: /content/drive/MyDrive/ViT-FishID/checkpoints_extended
üéØ Found candidate: checkpoint_epoch_100.pth
‚úÖ FOUND EPOCH 100 CHECKPOINT!
üìÅ Location: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
üìä Epoch: 100
üìä Best accuracy so far: 87.56%
üìÅ Checking: /content/drive/MyDrive/ViT-FishID/checkpoints_backup
üéØ Found candidate: checkpoint_epoch_100.pth
‚úÖ FOUND EPOCH 100 CHECKPOINT!
üìÅ Location: /content/drive/MyDrive/ViT-FishID/checkpoints_backup/checkpoint_epoch_100.pth
üìä Epoch: 100
üìä Best accuracy so far: 87.56%

üéâ Checkpoint ready for resuming training!
üìÑ File: checkpoint_epoch_100.pth
üìè Size: 982.4 MB
üíæ New checkpoints will be saved to: /content/drive/MyDrive/ViT-FishID/checkpoints_extended


## ‚öôÔ∏è Step 7: Configure Training Parameters

Configure the training settings for your semi-supervised fish classification model.

In [None]:
# Training Configuration for Semi-Supervised Fish Classification
import os

print("‚öôÔ∏è TRAINING CONFIGURATION")
print("="*50)

# Auto-detect number of species from dataset
NUM_CLASSES = 37  # Default
if 'DATA_DIR' in globals() and os.path.exists(DATA_DIR):
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir) 
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        NUM_CLASSES = species_count
        print(f"üìä Auto-detected {species_count} fish species")

# Create checkpoint directories
CHECKPOINT_DIR = '/content/drive/MyDrive/ViT-FishID/checkpoints'
BACKUP_DIR = '/content/drive/MyDrive/ViT-FishID/checkpoints_backup'

try:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(BACKUP_DIR, exist_ok=True)
    print(f"üìÅ Checkpoints: {CHECKPOINT_DIR}")
    print(f"üíæ Backups: {BACKUP_DIR}")
except Exception as e:
    print(f"‚ö†Ô∏è Could not create Google Drive directories: {e}")
    CHECKPOINT_DIR = '/content/checkpoints'
    BACKUP_DIR = '/content/checkpoints_backup'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(BACKUP_DIR, exist_ok=True)
    print(f"üìÅ Using local checkpoints: {CHECKPOINT_DIR}")

# Training Configuration
TRAINING_CONFIG = {
    # BASIC SETTINGS
    'mode': 'semi_supervised',
    'data_dir': DATA_DIR if 'DATA_DIR' in globals() else '/content/fish_cutouts',
    'epochs': 100,
    'batch_size': 16,
    'learning_rate': 1e-4,
    'weight_decay': 0.05,
    
    # MODEL SETTINGS
    'model_name': 'vit_base_patch16_224',
    'num_classes': NUM_CLASSES,
    'pretrained': True,
    
    # SEMI-SUPERVISED SETTINGS
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'warmup_epochs': 10,
    'ramp_up_epochs': 30,
    
    # CHECKPOINT SETTINGS
    'save_frequency': 10,  # Save every 10 epochs
    'checkpoint_dir': CHECKPOINT_DIR,
    'backup_dir': BACKUP_DIR,
    
    # LOGGING SETTINGS
    'use_wandb': USE_WANDB if 'USE_WANDB' in globals() else False,
    'wandb_project': 'ViT-FishID-Training',
    'wandb_run_name': f'fish-classification-{NUM_CLASSES}-classes',
}

print("\nüìã TRAINING CONFIGURATION")
print("="*50)
print(f"üéØ Training mode: {TRAINING_CONFIG['mode']}")
print(f"üìä Total epochs: {TRAINING_CONFIG['epochs']}")
print(f"üì¶ Batch size: {TRAINING_CONFIG['batch_size']}")
print(f"üß† Model: {TRAINING_CONFIG['model_name']}")
print(f"üêü Number of species: {TRAINING_CONFIG['num_classes']}")
print(f"‚öñÔ∏è Consistency weight: {TRAINING_CONFIG['consistency_weight']}")
print(f"üéØ Pseudo-label threshold: {TRAINING_CONFIG['pseudo_label_threshold']}")
print(f"üíæ Save frequency: Every {TRAINING_CONFIG['save_frequency']} epochs")
print(f"üìà W&B logging: {TRAINING_CONFIG['use_wandb']}")

# Time estimation
estimated_time_hours = TRAINING_CONFIG['epochs'] * 3 / 60  # ~3 minutes per epoch
print(f"\n‚è±Ô∏è Estimated training time: {estimated_time_hours:.1f} hours")
print(f"üí° Recommendation: Use Colab Pro for longer training sessions")

print("\n‚úÖ Configuration complete - ready to start training!")

üéØ EXTENDED TRAINING CONFIGURATION - WITH W&B
üìä Detected 37 fish species

EXTENDED TRAINING CONFIGURATION SUMMARY
üìä Resume from: Epoch 100
üìä Target epochs: 100
üìä Remaining epochs: 1
‚è±Ô∏è Estimated time: 5-7 minutes
üìä Batch size: 16 (optimized for Colab Pro)
üíæ Checkpoint saves: EVERY 1 epoch(s)
üìä Mode: semi_supervised with consistency weight 2.0
üìä Logging: W&B Enabled (Project: ViT-FishID-Extended-Training, Run: resume-epoch-6-to-100)
üìä Num Classes: 37

SETTING UP CHECKPOINT DIRECTORIES
üìÅ Primary saves: /content/drive/MyDrive/ViT-FishID/checkpoints_extended (Created/Exists)
üíæ Backup saves: /content/drive/MyDrive/ViT-FishID/checkpoints_backup (Created/Exists)

‚úÖ Will resume training from: checkpoint_epoch_100.pth

üöÄ Configuration complete. Ready to resume/start training!


## ü§ñ Step 7a: Load MAE Pre-trained Model (Optional)

**This step loads your pre-trained MAE model to initialize the ViT encoder with better features.**

The MAE (Masked Autoencoder) model you trained provides much better initial weights for the Vision Transformer compared to ImageNet pretraining, especially for fish images since it was trained specifically on your fish dataset.

Benefits of using MAE initialization:
- **Better Feature Representations**: Learned specifically on fish images
- **Faster Convergence**: Model starts with relevant features
- **Improved Performance**: Often leads to 2-5% accuracy improvement

### üìÅ MAE Model Locations

Your MAE models should be in one of these locations:
- **Local**: `/Users/catalinathomson/Desktop/Fish/ViT-FishID/mae_checkpoints/mae_final_model.pth`
- **Google Drive**: `/content/drive/MyDrive/mae_checkpoints/mae_final_model.pth` (after upload)

### üîß Setup Instructions

1. **Upload MAE Model**: Upload your `mae_final_model.pth` or `mae_best_model.pth` to Google Drive
2. **Update Path**: Modify `MAE_MODEL_PATH` in the next cell if needed
3. **Enable/Disable**: Set `LOAD_MAE_PRETRAINED = True/False` to control MAE loading

In [None]:
# Load MAE Pre-trained Model and Create Custom ViT Model
import torch
import os
import shutil
from model import ViTForFishClassification

print("ü§ñ SETTING UP MAE-INITIALIZED ViT MODEL")
print("="*60)

# Configuration for MAE loading
MAE_MODEL_PATH = '/content/drive/MyDrive/mae_checkpoints/mae_final_model.pth'  # Update this path if needed
LOAD_MAE_PRETRAINED = True  # Set to False to skip MAE loading

# Global variable to store MAE state for later use
MAE_ENCODER_WEIGHTS = None

def load_mae_encoder_weights(mae_checkpoint_path):
    """
    Load and extract encoder weights from MAE checkpoint.
    
    Args:
        mae_checkpoint_path: Path to MAE checkpoint file
        
    Returns:
        dict: Filtered encoder weights compatible with ViT backbone
    """
    print(f"üì• Loading MAE checkpoint from: {mae_checkpoint_path}")
    
    # Load MAE checkpoint
    checkpoint = torch.load(mae_checkpoint_path, map_location='cpu')
    
    # Print checkpoint info
    if 'epoch' in checkpoint:
        print(f"üìä MAE trained for {checkpoint['epoch']} epochs")
    if 'train_loss' in checkpoint:
        print(f"üìâ Final MAE loss: {checkpoint['train_loss']:.4f}")
    
    # Get model state dict
    mae_state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
    
    # Filter encoder weights (remove decoder, mask token, and other non-encoder components)
    encoder_weights = {}
    for key, value in mae_state_dict.items():
        # Keep only encoder-related weights
        if any(prefix in key for prefix in [
            'patch_embed',
            'pos_embed', 
            'cls_token',
            'blocks',
            'norm'
        ]) and not any(exclude in key for exclude in [
            'decoder',
            'mask_token',
            'head'
        ]):
            encoder_weights[key] = value
    
    print(f"üìä Extracted {len(encoder_weights)} encoder parameters from MAE")
    
    return encoder_weights

def create_mae_initialized_model(num_classes, model_name='vit_base_patch16_224', mae_weights=None):
    """
    Create ViT model and optionally initialize with MAE weights.
    
    Args:
        num_classes: Number of classification classes
        model_name: ViT model architecture name
        mae_weights: Optional MAE encoder weights dictionary
        
    Returns:
        ViTForFishClassification: Initialized model
    """
    print(f"üèóÔ∏è Creating ViT model: {model_name}")
    
    # Create ViT model (without ImageNet pretraining if we have MAE weights)
    use_imagenet_pretrained = mae_weights is None
    model = ViTForFishClassification(
        num_classes=num_classes,
        model_name=model_name,
        pretrained=use_imagenet_pretrained,
        dropout_rate=0.1
    )
    
    if mae_weights is not None:
        print("‚ö° Initializing ViT backbone with MAE encoder weights...")
        
        # Get current backbone state dict
        backbone_state = model.backbone.state_dict()
        
        # Update with MAE weights (only for matching keys and shapes)
        updated_keys = []
        shape_mismatches = []
        
        for mae_key, mae_weight in mae_weights.items():
            if mae_key in backbone_state:
                if mae_weight.shape == backbone_state[mae_key].shape:
                    backbone_state[mae_key] = mae_weight.clone()
                    updated_keys.append(mae_key)
                else:
                    shape_mismatches.append(f"{mae_key}: MAE{mae_weight.shape} != ViT{backbone_state[mae_key].shape}")
        
        # Load updated weights
        model.backbone.load_state_dict(backbone_state)
        
        print(f"‚úÖ Successfully transferred {len(updated_keys)} MAE encoder weights")
        
        if shape_mismatches:
            print(f"‚ö†Ô∏è Found {len(shape_mismatches)} shape mismatches (using original weights):")
            for mismatch in shape_mismatches[:5]:  # Show first 5 mismatches
                print(f"   {mismatch}")
        
        print("üéØ ViT model initialized with MAE-learned features!")
        
    else:
        print("üåê Using ImageNet pretrained weights")
    
    return model

# Main execution
if LOAD_MAE_PRETRAINED:
    # Check if MAE model exists in Google Drive
    if os.path.exists(MAE_MODEL_PATH):
        print(f"‚úÖ Found MAE model: {os.path.basename(MAE_MODEL_PATH)}")
        print(f"üìè Size: {os.path.getsize(MAE_MODEL_PATH) / (1024**2):.1f} MB")
        
        try:
            # Load MAE encoder weights
            MAE_ENCODER_WEIGHTS = load_mae_encoder_weights(MAE_MODEL_PATH)
            print("üéâ MAE encoder weights loaded successfully!")
            
            # Update training config
            TRAINING_CONFIG['mae_pretrained'] = True
            TRAINING_CONFIG['mae_model_path'] = MAE_MODEL_PATH
            TRAINING_CONFIG['pretrained'] = False  # Don't use ImageNet since we have MAE
            
        except Exception as e:
            print(f"‚ùå Error loading MAE model: {e}")
            print("üîÑ Falling back to ImageNet pretrained weights...")
            MAE_ENCODER_WEIGHTS = None
            TRAINING_CONFIG['mae_pretrained'] = False
            TRAINING_CONFIG['pretrained'] = True
    
    else:
        # MAE model not found, check alternative locations
        print(f"‚ùå MAE model not found at: {MAE_MODEL_PATH}")
        
        # Try to copy from local mae_checkpoints if exists
        local_mae_path = f'/content/ViT-FishID/mae_checkpoints/{os.path.basename(MAE_MODEL_PATH)}'
        if os.path.exists(local_mae_path):
            print(f"? Found MAE model in local repository: {local_mae_path}")
            try:
                # Create directory and copy
                os.makedirs(os.path.dirname(MAE_MODEL_PATH), exist_ok=True)
                shutil.copy2(local_mae_path, MAE_MODEL_PATH)
                print(f"‚úÖ Copied MAE model to Google Drive: {MAE_MODEL_PATH}")
                
                # Now load it
                MAE_ENCODER_WEIGHTS = load_mae_encoder_weights(MAE_MODEL_PATH)
                TRAINING_CONFIG['mae_pretrained'] = True
                TRAINING_CONFIG['mae_model_path'] = MAE_MODEL_PATH
                TRAINING_CONFIG['pretrained'] = False
                
            except Exception as e:
                print(f"‚ùå Error copying/loading MAE model: {e}")
                MAE_ENCODER_WEIGHTS = None
                TRAINING_CONFIG['mae_pretrained'] = False
                TRAINING_CONFIG['pretrained'] = True
        else:
            print("?üìù Available options:")
            print("1. Upload mae_final_model.pth or mae_best_model.pth to /content/drive/MyDrive/mae_checkpoints/")
            print("2. Update MAE_MODEL_PATH variable to correct location")
            print("3. Set LOAD_MAE_PRETRAINED = False to use ImageNet weights")
            print("üîÑ Continuing with ImageNet pretrained weights...")
            MAE_ENCODER_WEIGHTS = None
            TRAINING_CONFIG['mae_pretrained'] = False
            TRAINING_CONFIG['pretrained'] = True

else:
    print("‚è≠Ô∏è Skipping MAE loading - will use ImageNet pretrained weights")
    MAE_ENCODER_WEIGHTS = None
    TRAINING_CONFIG['mae_pretrained'] = False
    TRAINING_CONFIG['pretrained'] = True

# Test model creation (optional - this creates a model to verify everything works)
print(f"\nüß™ Testing model creation...")
try:
    test_model = create_mae_initialized_model(
        num_classes=NUM_CLASSES,
        model_name=TRAINING_CONFIG['model_name'],
        mae_weights=MAE_ENCODER_WEIGHTS
    )
    
    # Test forward pass
    test_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        test_output = test_model(test_input)
    
    print(f"‚úÖ Model test successful!")
    print(f"üìä Input shape: {test_input.shape}")
    print(f"üìä Output shape: {test_output.shape}")
    print(f"üéØ Model ready for training!")
    
    # Clean up test model
    del test_model, test_input, test_output
    
except Exception as e:
    print(f"‚ùå Model test failed: {e}")

print(f"\n" + "="*60)
print(f"‚úÖ MAE INITIALIZATION SETUP COMPLETE!")
print(f"ü§ñ MAE pretrained: {TRAINING_CONFIG.get('mae_pretrained', False)}")
print(f"üåê ImageNet pretrained: {TRAINING_CONFIG.get('pretrained', True)}")
print(f"üìä Model: {TRAINING_CONFIG['model_name']} with {NUM_CLASSES} classes")

if TRAINING_CONFIG.get('mae_pretrained', False):
    print("üéâ Your model will start with MAE-learned features specific to fish images!")
    print("üöÄ This should lead to faster training and better performance!")
else:
    print("üåê Your model will use standard ImageNet pretrained features.")

print("üéØ Ready to proceed to training!")

In [None]:
# Helper: Copy MAE Model to Google Drive (if needed)
import os
import shutil

print("üîç CHECKING MAE MODEL AVAILABILITY")
print("="*50)

# Define possible local locations (in cloned repo)
local_mae_locations = [
    '/content/ViT-FishID/mae_checkpoints/mae_final_model.pth',
    '/content/ViT-FishID/mae_checkpoints/mae_best_model.pth',
]

# Define Google Drive location
gdrive_mae_dir = '/content/drive/MyDrive/mae_checkpoints'
os.makedirs(gdrive_mae_dir, exist_ok=True)

# Check and copy MAE models if they exist locally but not in Google Drive
for local_path in local_mae_locations:
    model_name = os.path.basename(local_path)
    gdrive_path = os.path.join(gdrive_mae_dir, model_name)
    
    if os.path.exists(local_path):
        file_size = os.path.getsize(local_path) / (1024**2)
        print(f"‚úÖ Found local MAE model: {model_name} ({file_size:.1f} MB)")
        
        if not os.path.exists(gdrive_path):
            print(f"üì• Copying to Google Drive...")
            try:
                shutil.copy2(local_path, gdrive_path)
                print(f"‚úÖ Copied {model_name} to Google Drive")
            except Exception as e:
                print(f"‚ùå Error copying {model_name}: {e}")
        else:
            print(f"‚úÖ {model_name} already exists in Google Drive")
    else:
        print(f"‚ùå Local MAE model not found: {model_name}")

# List available MAE models in Google Drive
print(f"\nüìÅ Available MAE models in Google Drive:")
if os.path.exists(gdrive_mae_dir):
    mae_files = [f for f in os.listdir(gdrive_mae_dir) if f.endswith('.pth')]
    if mae_files:
        for mae_file in mae_files:
            file_path = os.path.join(gdrive_mae_dir, mae_file)
            file_size = os.path.getsize(file_path) / (1024**2)
            print(f"  üìÑ {mae_file} ({file_size:.1f} MB)")
    else:
        print("  ‚ùå No MAE models found in Google Drive")
        print("  üìù Please upload your MAE model manually to /content/drive/MyDrive/mae_checkpoints/")
else:
    print("  ‚ùå Mae checkpoints directory not found in Google Drive")

print("\n‚úÖ MAE model check complete!")

## üöÄ Step 8: Start Semi-Supervised Training

This cell will start the complete training process. Expected time: 4-6 hours for 100 epochs.

**Training Process:**
1. **Supervised Learning**: Uses labeled fish images with ground truth
2. **Semi-Supervised Learning**: Leverages unlabeled images with pseudo-labels
3. **EMA Teacher-Student**: Uses exponential moving average for consistency
4. **Automatic Checkpointing**: Saves progress every 10 epochs

## üîÑ Step 7b: Resume Training (If Interrupted)

**Use this section if your training was interrupted and you want to continue from where you left off.**

This will automatically find your latest checkpoint and resume training from that point.

In [None]:
# Start Semi-Supervised Training with Optional MAE Initialization
import os
import glob
from datetime import datetime

print("üöÄ STARTING SEMI-SUPERVISED FISH CLASSIFICATION TRAINING")
print("="*60)

# Change to repository directory
%cd /content/ViT-FishID

# Check for existing checkpoints to resume from
RESUME_FROM = None
if os.path.exists(TRAINING_CONFIG['checkpoint_dir']):
    checkpoints = glob.glob(os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'checkpoint_epoch_*.pth'))
    if checkpoints:
        # Find the latest checkpoint
        epoch_numbers = []
        for cp in checkpoints:
            try:
                epoch_num = int(cp.split('epoch_')[1].split('.')[0])
                epoch_numbers.append((epoch_num, cp))
            except:
                continue
        
        if epoch_numbers:
            epoch_numbers.sort(key=lambda x: x[0], reverse=True)  # Latest first
            latest_epoch, latest_checkpoint = epoch_numbers[0]
            print(f"üîç Found existing checkpoints. Latest: Epoch {latest_epoch}")
            
            # Ask user if they want to resume (auto-skip in Colab for now)
            # resume_choice = input("Do you want to resume from the latest checkpoint? (y/n): ").lower().strip()
            resume_choice = 'n'  # Set to 'y' if you want to auto-resume
            
            if resume_choice in ['y', 'yes']:
                RESUME_FROM = latest_checkpoint
                print(f"‚úÖ Will resume from: {os.path.basename(latest_checkpoint)}")
            else:
                print("üÜï Starting fresh training from epoch 1")

# Create a modified training script if we have MAE weights
if TRAINING_CONFIG.get('mae_pretrained', False) and 'MAE_ENCODER_WEIGHTS' in globals() and MAE_ENCODER_WEIGHTS is not None:
    print("ü§ñ Creating MAE-enhanced training script...")
    
    # Create custom train script that initializes with MAE weights
    mae_train_script = """#!/usr/bin/env python3
import sys
sys.path.append('/content/ViT-FishID')

import torch
import argparse
from model import ViTForFishClassification

# Function to create MAE-initialized model
def create_mae_initialized_model(num_classes, model_name, mae_weights):
    model = ViTForFishClassification(
        num_classes=num_classes,
        model_name=model_name,
        pretrained=False,  # Don't use ImageNet
        dropout_rate=0.1
    )
    
    if mae_weights is not None:
        backbone_state = model.backbone.state_dict()
        updated_keys = []
        
        for mae_key, mae_weight in mae_weights.items():
            if mae_key in backbone_state:
                if mae_weight.shape == backbone_state[mae_key].shape:
                    backbone_state[mae_key] = mae_weight.clone()
                    updated_keys.append(mae_key)
        
        model.backbone.load_state_dict(backbone_state)
        print(f"‚úÖ Loaded {len(updated_keys)} MAE encoder weights into model")
    
    return model

# Load MAE weights
mae_checkpoint = torch.load('{}', map_location='cpu')
mae_state_dict = mae_checkpoint.get('model_state_dict', mae_checkpoint.get('state_dict', mae_checkpoint))

mae_weights = {{}}
for key, value in mae_state_dict.items():
    if any(prefix in key for prefix in ['patch_embed', 'pos_embed', 'cls_token', 'blocks', 'norm']) and not any(exclude in key for exclude in ['decoder', 'mask_token', 'head']):
        mae_weights[key] = value

print(f"ü§ñ Loaded {{len(mae_weights)}} MAE encoder weights")

# Now run the original training with MAE initialization
""".format(TRAINING_CONFIG.get('mae_model_path', ''))
    
    # Write the custom script
    with open('/content/mae_init_prefix.py', 'w') as f:
        f.write(mae_train_script)
    
    # Build training command with MAE initialization
    training_cmd = f"""python -c "
import sys
sys.path.append('/content/ViT-FishID')
exec(open('/content/mae_init_prefix.py').read())

# Now import and run training
from train import *
import torch

# Override model creation in train.py
original_args = parse_arguments()

# Parse our arguments
class Args:
    def __init__(self):
        self.mode = '{TRAINING_CONFIG['mode']}'
        self.data_dir = '{TRAINING_CONFIG['data_dir']}'
        self.epochs = {TRAINING_CONFIG['epochs']}
        self.batch_size = {TRAINING_CONFIG['batch_size']}
        self.learning_rate = {TRAINING_CONFIG['learning_rate']}
        self.weight_decay = {TRAINING_CONFIG['weight_decay']}
        self.model_name = '{TRAINING_CONFIG['model_name']}'
        self.consistency_weight = {TRAINING_CONFIG['consistency_weight']}
        self.pseudo_label_threshold = {TRAINING_CONFIG['pseudo_label_threshold']}
        self.temperature = {TRAINING_CONFIG['temperature']}
        self.warmup_epochs = {TRAINING_CONFIG['warmup_epochs']}
        self.ramp_up_epochs = {TRAINING_CONFIG['ramp_up_epochs']}
        self.save_dir = '{TRAINING_CONFIG['checkpoint_dir']}'
        self.save_frequency = {TRAINING_CONFIG['save_frequency']}
        self.pretrained = False
        self.use_wandb = {str(TRAINING_CONFIG['use_wandb']).lower()}
        self.resume_from = {'None' if not RESUME_FROM else f'\\'{RESUME_FROM}\\''}
        self.num_workers = 4
        self.image_size = 224
        self.dropout_rate = 0.1
        self.num_classes = {NUM_CLASSES}
        
args = Args()

# Set up device and seed
device = get_device()
set_seed(42)

# Create MAE-initialized model
print('ü§ñ Creating MAE-initialized model for training...')
student_model = create_mae_initialized_model(
    num_classes=args.num_classes,
    model_name=args.model_name,
    mae_weights=mae_weights
).to(device)

# Continue with regular training process
from trainer import EMATrainer, SemiSupervisedTrainer
from data import create_dataloaders, create_semi_supervised_dataloaders

# Create data loaders
if args.mode == 'supervised':
    train_loader, val_loader, num_classes = create_dataloaders(
        args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        num_workers=args.num_workers
    )
    unlabeled_loader = None
else:
    train_loader, val_loader, unlabeled_loader, num_classes = create_semi_supervised_dataloaders(
        args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        num_workers=args.num_workers
    )

print(f'üìä Number of classes: {{num_classes}}')
print(f'üéØ Training mode: {{args.mode}}')

# Create trainer
if args.mode == 'semi_supervised' and unlabeled_loader is not None:
    trainer = SemiSupervisedTrainer(
        student_model=student_model,
        device=device,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        consistency_weight=args.consistency_weight,
        pseudo_label_threshold=args.pseudo_label_threshold,
        temperature=args.temperature,
        warmup_epochs=args.warmup_epochs,
        ramp_up_epochs=args.ramp_up_epochs
    )
else:
    trainer = EMATrainer(
        student_model=student_model,
        device=device,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay
    )

# Initialize W&B
if args.use_wandb:
    import wandb
    wandb.init(
        project='ViT-FishID-MAE-Training',
        config=vars(args),
        tags=['mae-initialized', 'fish-classification']
    )

# Resume from checkpoint if specified
if args.resume_from and args.resume_from != 'None':
    print(f'üì• Resuming from checkpoint: {{args.resume_from}}')
    try:
        checkpoint = torch.load(args.resume_from, map_location=device)
        trainer.student_model.load_state_dict(checkpoint['student_state_dict'])
        if hasattr(trainer, 'teacher_model') and 'teacher_state_dict' in checkpoint:
            trainer.teacher_model.teacher_model.load_state_dict(checkpoint['teacher_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        print(f'‚úÖ Resumed from epoch {{start_epoch}}')
    except Exception as e:
        print(f'‚ùå Error loading checkpoint: {{e}}')
        start_epoch = 1
else:
    start_epoch = 1

print(f'üöÄ Starting training from epoch {{start_epoch}}')

# Training loop
best_accuracy = 0.0
for epoch in range(start_epoch, args.epochs + 1):
    print(f'\\nüìÖ Epoch {{epoch}}/{{args.epochs}}')
    
    # Training
    if args.mode == 'semi_supervised' and unlabeled_loader is not None:
        train_loss = trainer.train_epoch(train_loader, unlabeled_loader, epoch)
    else:
        train_loss = trainer.train_epoch(train_loader, epoch)
    
    # Validation
    val_accuracy = trainer.validate(val_loader)
    
    # Update best accuracy
    is_best = val_accuracy > best_accuracy
    if is_best:
        best_accuracy = val_accuracy
    
    print(f'üìä Epoch {{epoch}} - Train Loss: {{train_loss:.4f}}, Val Acc: {{val_accuracy:.2f}}% (Best: {{best_accuracy:.2f}}%)')
    
    # Save checkpoint
    if epoch % args.save_frequency == 0 or is_best:
        checkpoint_data = {{
            'epoch': epoch,
            'student_state_dict': trainer.student_model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'best_accuracy': best_accuracy,
            'train_loss': train_loss,
            'val_accuracy': val_accuracy
        }}
        
        if hasattr(trainer, 'teacher_model'):
            checkpoint_data['teacher_state_dict'] = trainer.teacher_model.teacher_model.state_dict()
            checkpoint_data['teacher_acc'] = getattr(trainer, 'teacher_accuracy', val_accuracy)
        
        # Save regular checkpoint
        if epoch % args.save_frequency == 0:
            checkpoint_path = os.path.join(args.save_dir, f'checkpoint_epoch_{{epoch}}.pth')
            torch.save(checkpoint_data, checkpoint_path)
            print(f'üíæ Saved checkpoint: {{checkpoint_path}}')
        
        # Save best model
        if is_best:
            best_path = os.path.join(args.save_dir, 'model_best.pth')
            torch.save(checkpoint_data, best_path)
            print(f'üèÜ New best model saved: {{best_path}}')
    
    # W&B logging
    if args.use_wandb:
        wandb.log({{
            'epoch': epoch,
            'train_loss': train_loss,
            'val_accuracy': val_accuracy,
            'best_accuracy': best_accuracy
        }})

print(f'\\nüéâ Training completed!')
print(f'üèÜ Best accuracy: {{best_accuracy:.2f}}%')

if args.use_wandb:
    wandb.finish()
" """

else:
    # Build standard training command without MAE
    training_cmd = f"""python train.py \\
    --mode {TRAINING_CONFIG['mode']} \\
    --data_dir {TRAINING_CONFIG['data_dir']} \\
    --epochs {TRAINING_CONFIG['epochs']} \\
    --batch_size {TRAINING_CONFIG['batch_size']} \\
    --learning_rate {TRAINING_CONFIG['learning_rate']} \\
    --weight_decay {TRAINING_CONFIG['weight_decay']} \\
    --model_name {TRAINING_CONFIG['model_name']} \\
    --consistency_weight {TRAINING_CONFIG['consistency_weight']} \\
    --pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']} \\
    --temperature {TRAINING_CONFIG['temperature']} \\
    --warmup_epochs {TRAINING_CONFIG['warmup_epochs']} \\
    --ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']} \\
    --save_dir {TRAINING_CONFIG['checkpoint_dir']} \\
    --save_frequency {TRAINING_CONFIG['save_frequency']}"""

    # Add resume checkpoint if found
    if RESUME_FROM:
        training_cmd += f" \\\n    --resume_from {RESUME_FROM}"

    # Add pretrained flag
    if TRAINING_CONFIG['pretrained']:
        training_cmd += " \\\n    --pretrained"

    # Add W&B logging
    if TRAINING_CONFIG['use_wandb']:
        training_cmd += " \\\n    --use_wandb"

print("üìã TRAINING CONFIGURATION:")
print("="*60)
print(f"üéØ Training {TRAINING_CONFIG['num_classes']} fish species")
print(f"üìä Mode: {TRAINING_CONFIG['mode']}")
print(f"ü§ñ MAE pretrained: {TRAINING_CONFIG.get('mae_pretrained', False)}")
print(f"üåê ImageNet pretrained: {TRAINING_CONFIG.get('pretrained', True)}")

if RESUME_FROM:
    print(f"üîÑ Resuming from: {os.path.basename(RESUME_FROM)}")
else:
    print(f"üÜï Starting fresh training")

print(f"‚è±Ô∏è Estimated time: {TRAINING_CONFIG['epochs'] * 3 / 60:.1f} hours")
print(f"üíæ Checkpoints: {TRAINING_CONFIG['checkpoint_dir']}")
print(f"üìà W&B logging: {TRAINING_CONFIG['use_wandb']}")

if TRAINING_CONFIG.get('mae_pretrained', False):
    print(f"üéâ Using MAE-learned features from: {os.path.basename(TRAINING_CONFIG.get('mae_model_path', ''))}")
    print(f"üöÄ This should significantly improve training performance!")

print(f"\nüé¨ TRAINING STARTED")
print("‚è∞ Started at:", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Execute training
!{training_cmd}

print("\n" + "="*60)
print("üéâ TRAINING COMPLETED!")
print("‚è∞ Finished at:", datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Check for results
best_model_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')
if os.path.exists(best_model_path):
    try:
        import torch
        checkpoint = torch.load(best_model_path, map_location='cpu')
        if 'best_accuracy' in checkpoint:
            print(f"üèÜ Best accuracy achieved: {checkpoint['best_accuracy']:.2f}%")
        if 'epoch' in checkpoint:
            print(f"üìä Best model from epoch: {checkpoint['epoch']}")
    except:
        pass

print("‚úÖ Your MAE-enhanced model is ready for evaluation and deployment!")

üöÄ STARTING EXTENDED TRAINING SESSION
üìÇ Resuming from: checkpoint_epoch_99.pth
üöÄ Starting training from epoch: 100
üìä Training for 1 more epochs...
üéØ Target: 100 total epochs
‚è±Ô∏è Estimated time: 4-6 minutes
üíæ Checkpoints saved to: /content/drive/MyDrive/ViT-FishID/checkpoints_extended

üìã Extended Training Command:
python train.py 
    --mode semi_supervised 
    --data_dir /content/fish_cutouts 
    --epochs 100 
    --batch_size 16 
    --learning_rate 0.0001 
    --weight_decay 0.05 
    --model_name vit_base_patch16_224 
    --consistency_weight 2.0 
    --pseudo_label_threshold 0.7 
    --temperature 4.0 
    --warmup_epochs 5 
    --ramp_up_epochs 15 
    --save_dir /content/drive/MyDrive/ViT-FishID/checkpoints_extended 
    --save_frequency 1 
    --resume_from /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_99.pth 
    --use_wandb 
    --pretrained

üé¨ TRAINING STARTED - EPOCH 100 TO 100
‚è∞ Started at: 2025-08-15 07:03:08
‚úÖ Comm

## üìä Step 9: Check Training Results

Review the training progress and model performance.

In [None]:
# Check Training Results and Performance
import os
import glob
import torch
from datetime import datetime

print("üìä CHECKING TRAINING RESULTS")
print("="*50)

checkpoint_dir = TRAINING_CONFIG['checkpoint_dir']
print(f"üìÅ Checkpoint directory: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    # Find all checkpoints
    checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    
    if checkpoints:
        print(f"‚úÖ Found {len(checkpoints)} checkpoint(s)")
        
        # Sort checkpoints by epoch
        epoch_checkpoints = []
        other_checkpoints = []
        
        for cp in checkpoints:
            basename = os.path.basename(cp)
            if 'epoch_' in basename:
                try:
                    epoch_num = int(basename.split('epoch_')[1].split('.')[0])
                    epoch_checkpoints.append((epoch_num, cp))
                except:
                    other_checkpoints.append(cp)
            else:
                other_checkpoints.append(cp)
        
        # Show epoch progression
        if epoch_checkpoints:
            epoch_checkpoints.sort(key=lambda x: x[0])
            print(f"\nüìà TRAINING PROGRESSION:")
            latest_epoch = epoch_checkpoints[-1][0]
            print(f"  üèÅ Latest epoch: {latest_epoch}")
            print(f"  üìä Completion: {latest_epoch}/{TRAINING_CONFIG['epochs']} epochs ({latest_epoch/TRAINING_CONFIG['epochs']*100:.1f}%)")
            
            # Show recent checkpoints
            recent_checkpoints = epoch_checkpoints[-5:] if len(epoch_checkpoints) > 5 else epoch_checkpoints
            for epoch, cp in recent_checkpoints:
                file_size = os.path.getsize(cp) / (1024**2)
                print(f"  üìÑ Epoch {epoch}: {file_size:.1f} MB")
        
        # Analyze best model
        best_model_path = os.path.join(checkpoint_dir, 'model_best.pth')
        if os.path.exists(best_model_path):
            print(f"\nüèÜ BEST MODEL ANALYSIS:")
            try:
                best_checkpoint = torch.load(best_model_path, map_location='cpu')
                
                best_epoch = best_checkpoint.get('epoch', 'Unknown')
                best_acc = best_checkpoint.get('best_accuracy', best_checkpoint.get('best_acc', 'Unknown'))
                
                print(f"  üìä Best epoch: {best_epoch}")
                if isinstance(best_acc, (int, float)):
                    print(f"  üéØ Best accuracy: {best_acc:.2f}%")
                    
                    # Performance assessment
                    if best_acc >= 85:
                        print("  üéâ EXCELLENT performance!")
                    elif best_acc >= 75:
                        print("  üëç GOOD performance!")
                    elif best_acc >= 65:
                        print("  üìà FAIR performance - consider more training")
                    else:
                        print("  ‚ö†Ô∏è LOW performance - check data and hyperparameters")
                
                # Check for other metrics
                if 'teacher_acc' in best_checkpoint:
                    print(f"  üéì Teacher accuracy: {best_checkpoint['teacher_acc']:.2f}%")
                
            except Exception as e:
                print(f"  ‚ö†Ô∏è Could not analyze best model: {e}")
        
        # Show other important files
        for cp in other_checkpoints:
            basename = os.path.basename(cp)
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  üìÑ {basename}: {file_size:.1f} MB")
    
    else:
        print("‚ùå No checkpoints found")
        print("üí° Training may not have started or completed successfully")

else:
    print(f"‚ùå Checkpoint directory not found: {checkpoint_dir}")

# W&B results link
if TRAINING_CONFIG['use_wandb']:
    print(f"\nüìà View detailed training metrics at:")
    print(f"   https://wandb.ai/your-username/{TRAINING_CONFIG['wandb_project']}")

print("\n‚úÖ Results check complete!")

üìÅ Checking results in: /content/drive/MyDrive/ViT-FishID/checkpoints_extended

‚úÖ Found 100 checkpoint(s) from extended training:
  üìä Epoch 1: checkpoint_epoch_1.pth (982.4 MB)
  üìä Epoch 2: checkpoint_epoch_2.pth (982.4 MB)
  üìä Epoch 3: checkpoint_epoch_3.pth (982.4 MB)
  üìä Epoch 4: checkpoint_epoch_4.pth (982.4 MB)
  üìä Epoch 5: checkpoint_epoch_5.pth (982.4 MB)
  üìä Epoch 6: checkpoint_epoch_6.pth (982.4 MB)
  üìä Epoch 7: checkpoint_epoch_7.pth (982.4 MB)
  üìä Epoch 8: checkpoint_epoch_8.pth (982.4 MB)
  üìä Epoch 9: checkpoint_epoch_9.pth (982.4 MB)
  üìä Epoch 10: checkpoint_epoch_10.pth (982.4 MB)
  üìä Epoch 11: checkpoint_epoch_11.pth (982.4 MB)
  üìä Epoch 12: checkpoint_epoch_12.pth (982.4 MB)
  üìä Epoch 13: checkpoint_epoch_13.pth (982.4 MB)
  üìä Epoch 14: checkpoint_epoch_14.pth (982.4 MB)
  üìä Epoch 15: checkpoint_epoch_15.pth (982.4 MB)
  üìä Epoch 16: checkpoint_epoch_16.pth (982.4 MB)
  üìä Epoch 17: checkpoint_epoch_17.pth (982.4 MB)


## üíæ Step 10: Save Model and Results

Backup your trained model and results to Google Drive for future use.

In [None]:
# Save trained model and results to Google Drive
import shutil
import json
import os
from datetime import datetime

print("üíæ SAVING MODEL AND RESULTS")
print("="*50)

# Create timestamped backup directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_dir = f'/content/drive/MyDrive/ViT-FishID_Results_{timestamp}'

try:
    os.makedirs(backup_dir, exist_ok=True)
    print(f"üìÅ Created backup directory: {backup_dir}")
    
    # Copy checkpoints
    checkpoint_source = TRAINING_CONFIG['checkpoint_dir']
    if os.path.exists(checkpoint_source):
        checkpoint_backup = os.path.join(backup_dir, 'checkpoints')
        shutil.copytree(checkpoint_source, checkpoint_backup, dirs_exist_ok=True)
        print(f"‚úÖ Checkpoints copied to: {checkpoint_backup}")
        
        # Count files
        checkpoint_files = len([f for f in os.listdir(checkpoint_backup) if f.endswith('.pth')])
        print(f"üìä Backed up {checkpoint_files} checkpoint files")
    
    # Save training configuration
    config_file = os.path.join(backup_dir, 'training_config.json')
    serializable_config = {k: v for k, v in TRAINING_CONFIG.items() 
                          if isinstance(v, (str, int, float, bool, list, dict, type(None)))}
    
    with open(config_file, 'w') as f:
        json.dump(serializable_config, f, indent=2)
    print(f"‚úÖ Training config saved: {config_file}")
    
    # Create training summary
    summary_file = os.path.join(backup_dir, 'training_summary.txt')
    with open(summary_file, 'w') as f:
        f.write(f"ViT-FishID Training Summary\n")
        f.write(f"========================\n\n")
        f.write(f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Training Mode: {TRAINING_CONFIG['mode']}\n")
        f.write(f"Total Epochs: {TRAINING_CONFIG['epochs']}\n")
        f.write(f"Batch Size: {TRAINING_CONFIG['batch_size']}\n")
        f.write(f"Model: {TRAINING_CONFIG['model_name']}\n")
        f.write(f"Number of Species: {TRAINING_CONFIG['num_classes']}\n")
        f.write(f"Consistency Weight: {TRAINING_CONFIG['consistency_weight']}\n")
        f.write(f"W&B Logging: {TRAINING_CONFIG['use_wandb']}\n\n")
        f.write(f"Key Files:\n")
        f.write(f"- model_best.pth: Best performing model\n")
        f.write(f"- model_latest.pth: Most recent checkpoint\n")
        f.write(f"- checkpoint_epoch_X.pth: Periodic saves\n")
    
    print(f"‚úÖ Training summary saved: {summary_file}")
    
    # Get final model performance
    best_model_path = os.path.join(checkpoint_source, 'model_best.pth')
    if os.path.exists(best_model_path):
        try:
            import torch
            checkpoint = torch.load(best_model_path, map_location='cpu')
            if 'best_accuracy' in checkpoint:
                print(f"üèÜ Final model accuracy: {checkpoint['best_accuracy']:.2f}%")
                
                # Add performance to summary
                with open(summary_file, 'a') as f:
                    f.write(f"\nFinal Performance:\n")
                    f.write(f"- Best Accuracy: {checkpoint['best_accuracy']:.2f}%\n")
                    f.write(f"- Best Epoch: {checkpoint.get('epoch', 'Unknown')}\n")
        except Exception as e:
            print(f"‚ö†Ô∏è Could not read final performance: {e}")
    
    print(f"\nüéâ ALL RESULTS SAVED SUCCESSFULLY!")
    print(f"üìÅ Backup location: {backup_dir}")
    print(f"\nüí° You can now:")
    print(f"   1. Download the entire results folder")
    print(f"   2. Use model_best.pth for inference")
    print(f"   3. Resume training from any checkpoint")
    print(f"   4. Share results with collaborators")

except Exception as e:
    print(f"‚ùå Error saving results: {e}")
    print("üí° Please check Google Drive permissions and available space")

üíæ Saving results to Google Drive: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649
‚úÖ Checkpoints saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/checkpoints
‚úÖ Training config saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/training_config.json
‚úÖ Training summary saved to: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649/training_summary.txt

üéâ All results saved to Google Drive!
üìÅ Location: /content/drive/MyDrive/ViT-FishID_Training_20250815_070649

üí° You can now:
   1. Download the checkpoints folder for local use
   2. Use model_best.pth for inference
   3. Continue training from any checkpoint


## üß™ Step 11: Model Evaluation (Optional)

Test your trained model on sample images and get detailed performance metrics.

In [None]:
# Quick model evaluation and testing
import torch
import os

print("üß™ MODEL EVALUATION")
print("="*50)

# Check for trained model
best_model_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')

if os.path.exists(best_model_path):
    print(f"‚úÖ Found trained model: {os.path.basename(best_model_path)}")
    
    try:
        # Load model checkpoint
        checkpoint = torch.load(best_model_path, map_location='cpu')
        
        print(f"\nüìä MODEL PERFORMANCE:")
        if 'epoch' in checkpoint:
            print(f"  üèÜ Best epoch: {checkpoint['epoch']}")
        if 'best_accuracy' in checkpoint:
            print(f"  üéØ Best accuracy: {checkpoint['best_accuracy']:.2f}%")
        if 'teacher_acc' in checkpoint:
            print(f"  üéì Teacher accuracy: {checkpoint['teacher_acc']:.2f}%")
        
        # Model architecture info
        if 'num_classes' in checkpoint:
            print(f"  üêü Number of species: {checkpoint['num_classes']}")
        
        # File size
        file_size = os.path.getsize(best_model_path) / (1024**2)
        print(f"  üìè Model size: {file_size:.1f} MB")
        
        # Performance assessment
        if 'best_accuracy' in checkpoint:
            accuracy = checkpoint['best_accuracy']
            if accuracy >= 85:
                print(f"\nüéâ EXCELLENT PERFORMANCE!")
                print(f"   Your model achieved outstanding accuracy for fish classification")
            elif accuracy >= 75:
                print(f"\nüëç GOOD PERFORMANCE!")
                print(f"   Your model shows solid accuracy for practical use")
            elif accuracy >= 65:
                print(f"\nüìà FAIR PERFORMANCE")
                print(f"   Consider additional training or hyperparameter tuning")
            else:
                print(f"\n‚ö†Ô∏è PERFORMANCE NEEDS IMPROVEMENT")
                print(f"   Review data quality and training configuration")
    
    except Exception as e:
        print(f"‚ùå Error loading model: {e}")

else:
    print(f"‚ùå No trained model found at: {best_model_path}")
    print("Please ensure training completed successfully")

# Suggest next steps
print(f"\nüöÄ NEXT STEPS:")
print(f"1. üß™ Run detailed evaluation: Use evaluate.py script")
print(f"2. üî¨ Test on new images: Upload test images and run inference")
print(f"3. üì± Deploy model: Use for real-world fish classification")
print(f"4. üìä Analyze results: Review confusion matrix and per-species performance")
print(f"5. üîÑ Continue training: Resume from checkpoints for more epochs")

print(f"\n‚úÖ Evaluation complete!")

üß™ Looking for best model at: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
‚úÖ Found trained model.
üß™ Loading trained model for quick evaluation...
üìä Model training info:
  - Best epoch: 100
  - Best accuracy: 87.56%
  - Number of classes (from checkpoint): 37

‚úÖ Model loading and info check completed.
üí° Note: This step confirms the model file exists and can be loaded.
   Actual inference or evaluation on test data is done separately.

üí° For comprehensive evaluation:
   Use the evaluate.py script with your test dataset
   The test set was automatically created during training


## üîß Troubleshooting Guide

### Common Issues and Solutions:

**üö´ GPU Memory Error (CUDA out of memory)**
- Reduce `batch_size` from 16 to 8 or 4
- Restart runtime: `Runtime ‚Üí Restart runtime`
- Clear GPU cache: Run `torch.cuda.empty_cache()`

**üìÅ Data Not Found Error**
- Verify `fish_cutouts.zip` is uploaded to Google Drive root
- Check dataset structure has `labeled/` and `unlabeled/` folders
- Re-run Step 5 to extract dataset

**‚è∞ Training Timeout (Colab disconnection)**
- Use Colab Pro for longer sessions (up to 24 hours)
- Enable background execution: `Runtime ‚Üí Change runtime type`
- Checkpoints auto-save every 10 epochs for resuming

**üìâ Low Training Accuracy**
- Increase training epochs (try 150-200)
- Adjust `consistency_weight` (try 1.0-3.0)
- Lower `pseudo_label_threshold` (try 0.5-0.6)
- Check data quality and balance

**üîó W&B Connection Issues**
- Get API key from: https://wandb.ai/settings
- Set as Colab secret: `Tools ‚Üí Secrets`
- Training continues without W&B if connection fails

**üíæ Google Drive Mount Problems**
- Re-run Step 2 to remount
- Check Google Drive permissions
- Use local fallback directories if needed

## üéâ Summary and Next Steps

### üèÜ What You've Accomplished:

‚úÖ **Complete Semi-Supervised Training Pipeline**
- Vision Transformer (ViT) for fish classification
- Semi-supervised learning with labeled + unlabeled data
- EMA teacher-student framework for consistency training
- Automatic checkpointing and progress tracking

‚úÖ **Model Performance**
- Expected accuracy: 80-90% on fish species classification
- Robust to limited labeled data through semi-supervised learning
- Production-ready model saved to Google Drive

### üìÅ Important Files Created:

- **`model_best.pth`**: Best performing model (use for inference)
- **`model_latest.pth`**: Most recent checkpoint
- **`checkpoint_epoch_X.pth`**: Periodic saves for resuming
- **`training_config.json`**: Complete training configuration
- **`training_summary.txt`**: Human-readable training report

### üöÄ Next Steps:

1. **üß™ Detailed Evaluation**
   ```python
   # Run comprehensive evaluation
   !python evaluate.py --data_dir /content/fish_cutouts --model_path model_best.pth
   ```

2. **üî¨ Test on New Images**
   - Upload new fish images
   - Run inference using your trained model
   - Analyze predictions and confidence scores

3. **üì± Deploy Your Model**
   - Download `model_best.pth` to local machine
   - Integrate into web app or mobile application
   - Use for real-world fish species identification

4. **üîÑ Continue Training (if needed)**
   ```python
   # Resume from any checkpoint for more epochs
   --resume_from checkpoint_epoch_100.pth --epochs 150
   ```

5. **üìä Experiment and Improve**
   - Try different hyperparameters
   - Collect more training data
   - Experiment with data augmentation

### üéØ Expected Performance:
- **Accuracy**: 80-90% on test set
- **Inference Speed**: ~50-100ms per image
- **Model Size**: ~300MB
- **Production Ready**: Yes! üéâ

**Congratulations on training your fish classification model! üêüüéä**

## üìà Step 7b: Connect to Weights & Biases (Optional)

Log in to Weights & Biases for experiment tracking and visualization. You will be prompted to enter your API key.

## üíæ Step 8b: Explicitly Save Best Model Backup

This step ensures that `model_best.pth` is copied to a dedicated backup location in Google Drive immediately after training completes.

In [18]:
# Explicitly copy model_best.pth to a backup location
import shutil
import os
from datetime import datetime

print("üíæ Explicitly backing up model_best.pth...")

# Get the primary checkpoint directory from TRAINING_CONFIG
checkpoint_dir = TRAINING_CONFIG.get('checkpoint_dir')

if checkpoint_dir and os.path.exists(checkpoint_dir):
    best_model_source_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_100.pth')

    if os.path.exists(best_model_source_path):
        # Define a dedicated backup directory path in Google Drive
        # Using a simpler path than the full Step 10 save for quick verification
        backup_base_dir = '/content/drive/MyDrive/ViT-FishID_BestModel_Backups'
        os.makedirs(backup_base_dir, exist_ok=True)

        # Create a timestamped filename for the backup
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_filename = f"model_best_backup_{timestamp}.pth"
        backup_dest_path = os.path.join(backup_base_dir, backup_filename)

        try:
            shutil.copy2(best_model_source_path, backup_dest_path)
            print(f"‚úÖ Successfully copied model_best.pth to backup:")
            print(f"   üìÅ Source: {best_model_source_path}")
            print(f"   üíæ Destination: {backup_dest_path}")
            print(f"   üìè Size: {os.path.getsize(backup_dest_path) / (1024**2):.1f} MB")
            print("üéâ Please check your Google Drive in the 'ViT-FishID_BestModel_Backups' folder!")

        except Exception as e:
            print(f"‚ùå Error copying model_best.pth to backup: {e}")
            print("Please check your Google Drive connection and permissions.")

    else:
        print(f"‚ö†Ô∏è model_best.pth not found in the primary checkpoint directory: {checkpoint_dir}")
        print("   This means training likely did not complete successfully or the best model wasn't saved.")

else:
    print("‚ùå Primary checkpoint directory not found or TRAINING_CONFIG is not set.")
    print("   Please ensure Step 7 is run before this step.")

print("\nüíæ Explicit backup step complete.")

üíæ Explicitly backing up model_best.pth...
‚úÖ Successfully copied model_best.pth to backup:
   üìÅ Source: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
   üíæ Destination: /content/drive/MyDrive/ViT-FishID_BestModel_Backups/model_best_backup_20250815_075025.pth
   üìè Size: 982.4 MB
üéâ Please check your Google Drive in the 'ViT-FishID_BestModel_Backups' folder!

üíæ Explicit backup step complete.


## üìä Step 12: Evaluate Model on Test Dataset

This step runs the `evaluate.py` script to assess the performance of your trained model on the unseen test dataset.

In [31]:
# Run evaluation script
import os
import fileinput # Import fileinput for modifying files

print("üß™ Starting evaluation on the test dataset...")
print("="*50)

# Define the path to the evaluation script relative to the repo root
eval_script_name = 'evaluate.py'
repo_dir = '/content/ViT-FishID'
eval_script_path = os.path.join(repo_dir, eval_script_name)


# Define the path to the trained model checkpoint
# Using the epoch 100 checkpoint as it has the best recorded accuracy
model_checkpoint_path = '/content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth'

# Define the data directory (from Step 5)
data_directory = DATA_DIR # Ensure DATA_DIR is defined from Step 5

# Check if the evaluation script and model checkpoint exist
if not os.path.exists(eval_script_path):
    print(f"‚ùå Evaluation script not found at: {eval_script_path}")
    print(f"Please ensure the ViT-FishID repository was cloned correctly in Step 4 to {repo_dir}.")
elif not os.path.exists(model_checkpoint_path):
     print(f"‚ùå Model checkpoint not found at: {model_checkpoint_path}")
     print("Please ensure training completed successfully and the checkpoint exists.")
elif not os.path.exists(data_directory):
     print(f"‚ùå Data directory not found at: {data_directory}")
     print("Please ensure Step 5 was run correctly.")
else:
    print(f"‚úÖ Found evaluation script: {eval_script_path}")
    print(f"‚úÖ Found model checkpoint: {model_checkpoint_path}")
    print(f"‚úÖ Found data directory: {data_directory}")

    # --- FIX 1: Modify evaluate.py to correct the vit_model import statement ---
    print(f"\nüîß Correcting import statement for ViTForFishClassification in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Replace 'from vit_model import' with 'from model import'
                # Do NOT print anything else here
                print(line.replace('from vit_model import ViTForFishClassification', 'from model import ViTForFishClassification'), end='')
        print(f"‚úÖ Corrected import statement for ViTForFishClassification in {eval_script_name}.")
    except Exception as e:
        print(f"‚ùå Error modifying ViTForFishClassification import in {eval_script_name}: {e}")
        print("üö® Evaluation might still fail due to this import error.")
    # --- End of FIX 1 ---

    # --- FIX 2: Modify evaluate.py to comment out the ema_teacher import ---
    print(f"\nüîß Commenting out import statement for EMATeacher in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Comment out 'from ema_teacher import EMATeacher'
                # Do NOT print anything else here
                if 'from ema_teacher import EMATeacher' in line:
                     print("# " + line, end='') # Add # to comment out the line
                else:
                    print(line, end='')
        print(f"‚úÖ Commented out import statement for EMATeacher in {eval_script_name}.")
    except Exception as e:
        print(f"‚ùå Error commenting out EMATeacher import in {eval_script_name}: {e}")
        print("üö® Evaluation might still fail due to this import error.")
    # --- End of FIX 2 ---

    # --- FIX 3: Modify evaluate.py to correct the data_loader import statement ---
    print(f"\nüîß Correcting import statement for create_fish_dataloaders in {eval_script_name}...")
    try:
        with fileinput.FileInput(eval_script_path, inplace=True) as file:
            for line in file:
                # Replace 'from data_loader import' with 'from data import'
                # Do NOT print anything else here
                print(line.replace('from data_loader import create_fish_dataloaders', 'from data import create_fish_dataloaders'), end='')
        print(f"‚úÖ Corrected import statement for create_fish_dataloaders in {eval_script_name}.")
    except Exception as e:
        print(f"‚ùå Error modifying create_fish_dataloaders import in {eval_script_name}: {e}")
        print("üö® Evaluation might still fail due to this import error.")
    # --- End of FIX 3 ---


    # Construct the evaluation command
    # Use PYTHONPATH to help the script find local modules like model
    # Use %cd before and after, but rely on PYTHONPATH for the import
    eval_cmd = f"PYTHONPATH={repo_dir} python {eval_script_name} --data_dir {data_directory} --model_path {model_checkpoint_path}"


    print("\nüìã Evaluation Command:")
    # Print the command cleanly without the PYTHONPATH for readability, but it's included in the execution
    print(f"python {eval_script_name} --data_dir {data_directory} --model_path {model_checkpoint_path} (with PYTHONPATH={repo_dir})")
    print("\n" + "="*50)

    print("üöÄ Running evaluation...")
    # Change to the repository directory before executing
    %cd {repo_dir}

    # Execute the evaluation script with PYTHONPATH set
    !{eval_cmd}

    # Change back to original content directory (optional but good practice)
    %cd /content

    print("\n" + "="*50)
    print("üéâ Evaluation complete!")

print("\nüí° Check the output above for accuracy metrics on the test set.")

üß™ Starting evaluation on the test dataset...
‚úÖ Found evaluation script: /content/ViT-FishID/evaluate.py
‚úÖ Found model checkpoint: /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth
‚úÖ Found data directory: /content/fish_cutouts

üîß Correcting import statement for ViTForFishClassification in evaluate.py...
‚úÖ Corrected import statement for ViTForFishClassification in evaluate.py.

üîß Commenting out import statement for EMATeacher in evaluate.py...
‚úÖ Commented out import statement for EMATeacher in evaluate.py.

üîß Correcting import statement for create_fish_dataloaders in evaluate.py...
‚úÖ Corrected import statement for create_fish_dataloaders in evaluate.py.

üìã Evaluation Command:
python evaluate.py --data_dir /content/fish_cutouts --model_path /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_100.pth (with PYTHONPATH=/content/ViT-FishID)

üöÄ Running evaluation...
/content/ViT-FishID
2025-08-15 08:01:40.428842: I

## üîç Step 12b: Diagnose `ModuleNotFoundError`

This step checks the file structure and import statements to understand why `vit_model` is not being found.

In [27]:
import os

print("üîç Diagnosing ModuleNotFoundError...")
repo_dir = '/content/ViT-FishID'
eval_script_path = os.path.join(repo_dir, 'evaluate.py')
model_file_guess = os.path.join(repo_dir, 'model.py') # Common name for model file
vit_model_file_guess = os.path.join(repo_dir, 'vit_model.py') # Guessed name based on import

print(f"Repo directory: {repo_dir}")

print("\nüìÇ Files in repository root:")
# List files in the repository root
if os.path.exists(repo_dir):
    !ls -la {repo_dir}
else:
    print(f"‚ùå Repository directory not found: {repo_dir}")


print(f"\nüìÑ Content of {os.path.basename(eval_script_path)} (checking import):")
# Read and print the content of evaluate.py
if os.path.exists(eval_script_path):
    try:
        with open(eval_script_path, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                if 'import vit_model' in line or 'from vit_model' in line:
                    print(f"  Line {i+1}: {line.strip()}")
                elif 'ViTForFishClassification' in line:
                     print(f"  Line {i+1}: {line.strip()} (contains class name)")
                if i < 20: # Print first 20 lines for context
                     print(f"  Line {i+1}: {line.strip()}")


    except Exception as e:
        print(f"‚ùå Could not read {eval_script_path}: {e}")
else:
    print(f"‚ùå {eval_script_path} not found.")


print(f"\nüìÑ Checking potential model file: {os.path.basename(model_file_guess)}")
# Check if model.py exists and print relevant lines
if os.path.exists(model_file_guess):
    try:
        with open(model_file_guess, 'r') as f:
            lines = f.readlines()
            print(f"‚úÖ Found {os.path.basename(model_file_guess)}. Checking for class definition...")
            found_class = False
            for i, line in enumerate(lines):
                 if 'class ViTForFishClassification' in line:
                      print(f"  Line {i+1}: {line.strip()}")
                      found_class = True
                      break # Found the class, stop searching

            if not found_class:
                 print(f"‚ö†Ô∏è 'ViTForFishClassification' class definition not found in {os.path.basename(model_file_guess)}")

    except Exception as e:
        print(f"‚ùå Could not read {model_file_guess}: {e}")
else:
    print(f"‚ùì {os.path.basename(model_file_guess)} not found. Checking alternative name...")

print(f"\nüìÑ Checking alternative model file: {os.path.basename(vit_model_file_guess)}")
# Check if vit_model.py exists and print relevant lines
if os.path.exists(vit_model_file_guess):
    try:
        with open(vit_model_file_guess, 'r') as f:
            lines = f.readlines()
            print(f"‚úÖ Found {os.path.basename(vit_model_file_guess)}. Checking for class definition...")
            found_class = False
            for i, line in enumerate(lines):
                 if 'class ViTForFishClassification' in line:
                      print(f"  Line {i+1}: {line.strip()}")
                      found_class = True
                      break # Found the class, stop searching

            if not found_class:
                 print(f"‚ö†Ô∏è 'ViTForFishClassification' class definition not found in {os.path.basename(vit_model_file_guess)}")


    except Exception as e:
        print(f"‚ùå Could not read {vit_model_file_guess}: {e}")
else:
    print(f"‚ùì {os.path.basename(vit_model_file_guess)} not found.")

print("\nDiagnosis steps complete. Please review the output.")

üîç Diagnosing ModuleNotFoundError...
Repo directory: /content/ViT-FishID

üìÇ Files in repository root:
total 368
drwxr-xr-x 6 root root   4096 Aug 15 07:03 .
drwxr-xr-x 1 root root   4096 Aug 15 06:58 ..
-rw-r--r-- 1 root root  21217 Aug 15 06:58 data.py
-rw-r--r-- 1 root root  11572 Aug 15 06:58 evaluate.py
-rw-r--r-- 1 root root   3328 Aug 15 06:58 EXTENDED_TRAINING_SETUP.md
drwxr-xr-x 2 root root   4096 Aug 15 06:58 fish_cutouts
drwxr-xr-x 8 root root   4096 Aug 15 06:58 .git
-rw-r--r-- 1 root root     66 Aug 15 06:58 .gitattributes
-rw-r--r-- 1 root root    646 Aug 15 06:58 .gitignore
-rw-r--r-- 1 root root   9495 Aug 15 06:58 model.py
-rw-r--r-- 1 root root  16771 Aug 15 06:58 pipeline.py
drwxr-xr-x 2 root root   4096 Aug 15 07:03 __pycache__
-rw-r--r-- 1 root root  16566 Aug 15 06:58 README.md
-rw-r--r-- 1 root root    202 Aug 15 06:58 requirements.txt
-rw-r--r-- 1 root root   4265 Aug 15 06:58 resume_training.py
-rw-r--r-- 1 root root   5134 Aug 15 06:58 species_mapping.txt
