<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🐟 ViT-FishID: Semi-Supervised Training in Google Colab

This notebook demonstrates how to train the ViT-FishID model using Google Colab's free GPU resources.

**Expected Performance:**
- Training Time: 2-3 hours for 50 epochs on Tesla T4
- Memory Usage: ~6-8GB GPU memory
- Accuracy: ~75-85% validation accuracy

**Requirements:**
- Google account with Google Drive access
- Fish dataset uploaded to Google Drive
- GPU runtime enabled in Colab

## 🚀 Step 1: Setup and GPU Check

In [1]:
# Check GPU availability
import sys

import torch

print("🔍 System Information:")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print("✅ GPU is ready for training!")
else:
    print("❌ No GPU detected. Please enable GPU runtime:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")

🔍 System Information:
Python version: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
CUDA available: True
GPU Device: Tesla T4
GPU Memory: 14.7 GB
✅ GPU is ready for training!


## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [2]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)[:10]  # Show first 10 items
    for item in items:
        print(f"  - {item}")
    if len(os.listdir(drive_path)) > 10:
        print(f"  ... and {len(os.listdir(drive_path)) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")

Mounted at /content/drive

📂 Google Drive contents:
  - Mock Matric
  - Photos
  - Admin
  - Uni
  - Fish_Training_Output
  - fish_cutouts.zip
  - Colab Notebooks

✅ Google Drive mounted successfully!


## 📦 Step 3: Install Dependencies

Installing all required packages for ViT-FishID training.
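
Colab images already ship with several of these packages, so it can be useful to see what is actually missing before reinstalling everything. The snippet below is a minimal sketch of that check using only the standard library; the `REQUIRED` mapping and the `find_missing` helper are illustrative names, not part of the ViT-FishID repository.

```python
import importlib.util

# pip package name -> import name (the two differ for some packages).
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "timm": "timm",
    "albumentations": "albumentations",
    "wandb": "wandb",
    "opencv-python-headless": "cv2",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
}


def find_missing(required):
    """Return the pip names whose top-level module cannot be located."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


missing = find_missing(REQUIRED)
print("Missing packages:", missing or "none")
```

Anything listed as missing can then be installed with a single `pip install` call instead of reinstalling the full set.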

In [3]:
# Install required packages
print("📦 Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("✅ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\n📋 Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")

📦 Installing dependencies...

## 🔄 Step 4: Clone ViT-FishID Repository

Getting the latest code from your GitHub repository.

In [4]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\n📂 Project structure:")
!ls -la

print("\n✅ Repository cloned successfully!")

📥 Cloning ViT-FishID repository...
Cloning into '/content/ViT-FishID'...
remote: Enumerating objects: 88, done.
remote: Counting objects: 100% (88/88), done.
remote: Compressing objects: 100% (66/66), done.
remote: Total 88 (delta 25), reused 79 (delta 16), pack-reused 0 (from 0)
Receiving objects: 100% (88/88), 147.82 KiB | 2.69 MiB/s, done.
Resolving deltas: 100% (25/25), done.
/content/ViT-FishID

📂 Project structure:
total 264
drwxr-xr-x 6 root root  4096 Aug 13 11:21 .
drwxr-xr-x 1 root root  4096 Aug 13 11:21 ..
drwxr-xr-x 2 root root  4096 Aug 13 11:21 backup_old_files
-rw-r--r-- 1 root root  6938 Aug 13 11:21 COLAB_GUIDE.md
-rw-r--r-- 1 root root 11892 Aug 13 11:21 colab_setup.py
-rw-r--r-- 1 root root  5452 Aug 13 11:21 COLAB_SETUP_SUMMARY.md
-rw-r--r-- 1 root root 19713 Aug 13 11:21 data.py
-rw-r--r-- 1 root root 11572 Aug 13 11:21 evaluate.py
drwxr-xr-x 2 root root  4096 Aug 13 11:21 fish_cutouts
drwxr-xr-x 2 root root  4096 Aug 13 11:21 Frames
drwxr-xr-x 8 root 

## 🗂️ Step 5: Setup Data Path and Extraction

**IMPORTANT:** Specify the path to your fish dataset ZIP file in Google Drive.

This step will:
1. Locate your `fish_cutouts.zip` file in Google Drive
2. Extract it to Colab's local storage for faster access
3. Validate the data structure
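
ZIP archives created on macOS usually carry a `__MACOSX/` folder and `.DS_Store` files alongside the real data. As a rough sketch of the extraction step, assuming you prefer to skip those entries while extracting rather than filtering afterwards, something like the following works with the standard `zipfile` module. The helper name `extract_without_macos_artifacts` and the demo archive are made up for illustration.

```python
import os
import tempfile
import zipfile


def extract_without_macos_artifacts(zip_path, target_dir):
    """Extract zip_path into target_dir, skipping __MACOSX/ and .DS_Store entries."""
    extracted = []
    with zipfile.ZipFile(zip_path) as zf:
        for name in zf.namelist():
            parts = name.split("/")
            if "__MACOSX" in parts or parts[-1] == ".DS_Store":
                continue  # macOS metadata, not dataset content
            zf.extract(name, target_dir)
            extracted.append(name)
    return extracted


# Demo on a tiny archive built in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(demo_zip, "w") as zf:
        zf.writestr("fish_cutouts/labeled/species_1/fish_001.jpg", b"x")
        zf.writestr("fish_cutouts/unlabeled/fish_003.jpg", b"x")
        zf.writestr("__MACOSX/fish_cutouts/._fish_001.jpg", b"x")
    names = extract_without_macos_artifacts(demo_zip, os.path.join(tmp, "out"))
    print(names)
```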

Expected structure after extraction:
```
fish_cutouts/
├── labeled/
│   ├── species_1/
│   │   ├── fish_001.jpg
│   │   └── fish_002.jpg
│   └── species_2/
│       └── ...
└── unlabeled/
    ├── fish_003.jpg
    └── fish_004.jpg
```
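
The layout above can be checked programmatically before training starts. The following is a small sketch, not the repository's own validation code: `summarize_dataset` is a hypothetical helper that counts images per species folder and fails early if `labeled/` or `unlabeled/` is missing.

```python
import os
import tempfile

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def summarize_dataset(root):
    """Return per-class labeled counts and the unlabeled count for a dataset root."""
    labeled = os.path.join(root, "labeled")
    unlabeled = os.path.join(root, "unlabeled")
    for required in (labeled, unlabeled):
        if not os.path.isdir(required):
            raise FileNotFoundError(f"Missing directory: {required}")

    def count_images(folder):
        return sum(1 for f in os.listdir(folder) if f.lower().endswith(IMAGE_EXTS))

    per_class = {
        d: count_images(os.path.join(labeled, d))
        for d in sorted(os.listdir(labeled))
        if os.path.isdir(os.path.join(labeled, d))
    }
    return {"per_class": per_class, "unlabeled": count_images(unlabeled)}


# Demo: build the expected layout with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for rel in (
        "labeled/species_1/fish_001.jpg",
        "labeled/species_1/fish_002.jpg",
        "unlabeled/fish_003.jpg",
    ):
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()
    summary = summarize_dataset(root)
    print(summary)
```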

In [9]:
# MODIFY THIS PATH to point to your fish_cutouts.zip file in Google Drive
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'  # 👈 Your ZIP file path

# Alternative common paths (uncomment the one that matches your setup):
# ZIP_FILE_PATH = '/content/drive/MyDrive/ViT-FishID/fish_cutouts.zip'
# ZIP_FILE_PATH = '/content/drive/MyDrive/datasets/fish_cutouts.zip'
# ZIP_FILE_PATH = '/content/drive/MyDrive/data/fish_cutouts.zip'

# Local extraction directory (will be created in Colab)
DATA_DIR = '/content/fish_cutouts'

import zipfile
import shutil
import time

print(f"🎯 ZIP file location: {ZIP_FILE_PATH}")
print(f"📁 Extraction target: {DATA_DIR}")

# Check if ZIP file exists
if os.path.exists(ZIP_FILE_PATH):
    print("✅ ZIP file found!")

    # Get ZIP file size
    zip_size_mb = os.path.getsize(ZIP_FILE_PATH) / (1024 * 1024)
    print(f"📦 ZIP file size: {zip_size_mb:.1f} MB")

    # Remove existing extracted data if present
    if os.path.exists(DATA_DIR):
        print("🧹 Removing existing extracted data...")
        shutil.rmtree(DATA_DIR)

    # Extract ZIP file
    print("📤 Extracting ZIP file to local storage...")
    print("⏳ This may take a few minutes for large datasets...")

    start_time = time.time()

    try:
        with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
            # Extract to temporary location first
            temp_extract_dir = '/content/temp_extract'
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)
            zip_ref.extractall(temp_extract_dir)

            # DEBUG: Show what was extracted
            extracted_items = os.listdir(temp_extract_dir)
            print(f"🔍 DEBUG - Extracted items: {extracted_items}")

            # Look for the actual dataset (skip macOS artifacts)
            data_candidates = [item for item in extracted_items
                             if not item.startswith('.') and not item.startswith('__MACOSX')]
            print(f"🔍 DEBUG - Data candidates: {data_candidates}")

            # Find the directory that contains labeled/ and unlabeled/
            dataset_found = False

            for candidate in data_candidates:
                candidate_path = os.path.join(temp_extract_dir, candidate)
                if os.path.isdir(candidate_path):
                    print(f"🔍 DEBUG - Checking {candidate}: {os.listdir(candidate_path)}")

                    # Check if this directory has labeled/ and unlabeled/
                    if 'labeled' in os.listdir(candidate_path) and 'unlabeled' in os.listdir(candidate_path):
                        print(f"✅ Found dataset in: {candidate}")
                        shutil.move(candidate_path, DATA_DIR)
                        dataset_found = True
                        break

                    # Check one level deeper
                    subdirs = [d for d in os.listdir(candidate_path)
                             if os.path.isdir(os.path.join(candidate_path, d)) and not d.startswith('.') and not d.startswith('__')]

                    for subdir in subdirs:
                        subdir_path = os.path.join(candidate_path, subdir)
                        print(f"🔍 DEBUG - Checking {candidate}/{subdir}: {os.listdir(subdir_path)}")

                        if 'labeled' in os.listdir(subdir_path) and 'unlabeled' in os.listdir(subdir_path):
                            print(f"✅ Found dataset one level deeper in: {candidate}/{subdir}")
                            shutil.move(subdir_path, DATA_DIR)
                            dataset_found = True
                            break

                    if dataset_found:
                        break

            if not dataset_found:
                print("⚠️  Could not find labeled/ and unlabeled/ folders, moving first candidate")
                if data_candidates:
                    shutil.move(os.path.join(temp_extract_dir, data_candidates[0]), DATA_DIR)
                else:
                    shutil.move(temp_extract_dir, DATA_DIR)

            # Clean up temp directory if it still exists
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)

        extraction_time = time.time() - start_time
        print(f"✅ Extraction completed in {extraction_time:.1f} seconds!")

    except zipfile.BadZipFile:
        print("❌ Error: Invalid ZIP file format")
        print("Please check that your file is a valid ZIP archive")
    except Exception as e:
        print(f"❌ Error during extraction: {str(e)}")
        print("Please check the ZIP file path and try again")

else:
    print("❌ ZIP file not found!")
    print("\n🔧 To fix this:")
    print("1. Upload your fish_cutouts.zip file to Google Drive")
    print("2. Update the ZIP_FILE_PATH variable above with the correct path")
    print("3. Make sure the file name is exactly 'fish_cutouts.zip'")
    print("\n💡 Common locations to check:")
    print("   - /content/drive/MyDrive/fish_cutouts.zip")
    print("   - /content/drive/MyDrive/ViT-FishID/fish_cutouts.zip")
    print("   - /content/drive/MyDrive/datasets/fish_cutouts.zip")

# Now validate the extracted data
print(f"\n📊 Validating extracted dataset...")

# Check if data directory exists
if os.path.exists(DATA_DIR):
    print("✅ Data directory found!")

    # DEBUG: Show the actual structure
    print("🔍 DEBUG - Final directory structure:")
    for item in os.listdir(DATA_DIR):
        item_path = os.path.join(DATA_DIR, item)
        if os.path.isdir(item_path):
            print(f"  📂 {item}/")
            try:
                subitems = os.listdir(item_path)[:10]  # Show first 10 items
                for subitem in subitems:
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        print(f"    📂 {subitem}/")
                    else:
                        print(f"    📄 {subitem}")
                if len(os.listdir(item_path)) > 10:
                    print(f"    ... and {len(os.listdir(item_path)) - 10} more items")
            except (PermissionError, OSError):
                print("    (cannot read contents)")
        else:
            print(f"  📄 {item}")

    # Show directory size
    def get_dir_size(path):
        total = 0
        for dirpath, dirnames, filenames in os.walk(path):
            for filename in filenames:
                total += os.path.getsize(os.path.join(dirpath, filename))
        return total / (1024 * 1024)  # Convert to MB

    dir_size_mb = get_dir_size(DATA_DIR)
    print(f"📏 Extracted dataset size: {dir_size_mb:.1f} MB")

    # Check for labeled and unlabeled subdirectories
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')

    if os.path.exists(labeled_dir) and os.path.exists(unlabeled_dir):
        print("✅ Semi-supervised structure detected (labeled/ and unlabeled/ folders)")

        # Count classes and samples
        classes = [d for d in os.listdir(labeled_dir) if os.path.isdir(os.path.join(labeled_dir, d))]
        print(f"📊 Found {len(classes)} species classes")

        # Count labeled samples
        labeled_count = 0
        sample_classes = classes[:5]  # Show first 5 classes
        for class_dir in sample_classes:
            class_path = os.path.join(labeled_dir, class_dir)
            if os.path.isdir(class_path):
                class_samples = len([f for f in os.listdir(class_path)
                                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                labeled_count += class_samples
                print(f"  - {class_dir}: {class_samples} samples")

        if len(classes) > 5:
            # Count remaining classes
            remaining_count = 0
            for class_dir in classes[5:]:
                class_path = os.path.join(labeled_dir, class_dir)
                if os.path.isdir(class_path):
                    class_samples = len([f for f in os.listdir(class_path)
                                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                    remaining_count += class_samples
            labeled_count += remaining_count
            print(f"  ... and {len(classes) - 5} more classes with {remaining_count} samples")

        print(f"📊 Total labeled samples: {labeled_count:,}")

        # Count unlabeled samples
        unlabeled_files = [f for f in os.listdir(unlabeled_dir)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        print(f"📊 Unlabeled samples: {len(unlabeled_files):,}")

        print(f"\n🎯 Dataset ready for training!")
        print(f"   Total samples: {labeled_count + len(unlabeled_files):,}")
        print(f"   Labeled: {labeled_count:,} ({labeled_count/(labeled_count + len(unlabeled_files))*100:.1f}%)")
        print(f"   Unlabeled: {len(unlabeled_files):,} ({len(unlabeled_files)/(labeled_count + len(unlabeled_files))*100:.1f}%)")

    elif any(os.path.isdir(os.path.join(DATA_DIR, d)) for d in os.listdir(DATA_DIR)):
        print("ℹ️  Supervised structure detected (species folders directly in data dir)")
        classes = [d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d))]
        print(f"📊 Found {len(classes)} species classes")
        print("⚠️  Note: For semi-supervised training, organize data into labeled/ and unlabeled/ folders")

        # Show what the detected top-level folders actually contain
        print(f"🔍 DEBUG - The {len(classes)} 'classes' found are:")
        for class_dir in classes:
            class_path = os.path.join(DATA_DIR, class_dir)
            print(f"  📂 {class_dir}/")
            try:
                items = os.listdir(class_path)[:5]
                for item in items:
                    print(f"    - {item}")
                if len(os.listdir(class_path)) > 5:
                    print(f"    ... and {len(os.listdir(class_path)) - 5} more items")
            except (PermissionError, OSError):
                print("    (cannot read contents)")
    else:
        print("❌ No valid data structure found in extracted files")
        print("Expected: labeled/ and unlabeled/ subdirectories")
        print("\n📂 Current structure:")
        for item in os.listdir(DATA_DIR):
            print(f"  - {item}")
else:
    print("❌ Data extraction failed or directory not found!")
    print("Please check the ZIP file and extraction process above")

🎯 ZIP file location: /content/drive/MyDrive/fish_cutouts.zip
📁 Extraction target: /content/fish_cutouts
✅ ZIP file found!
📦 ZIP file size: 217.9 MB
🧹 Removing existing extracted data...
📤 Extracting ZIP file to local storage...
⏳ This may take a few minutes for large datasets...
🔍 DEBUG - Extracted items: ['__MACOSX', 'fish_cutouts']
🔍 DEBUG - Data candidates: ['fish_cutouts']
🔍 DEBUG - Checking fish_cutouts: ['dataset_info.json', 'unlabeled', 'labeled']
✅ Found dataset in: fish_cutouts
✅ Extraction completed in 8.4 seconds!

📊 Validating extracted dataset...
✅ Data directory found!
🔍 DEBUG - Final directory structure:
  📄 dataset_info.json
  📂 unlabeled/
    📄 21-05_WILD_061_R (2)_21-05_WILD_061_R (2)_frame_7317_det10_cls0.jpg
    📄 21-05_WILD_065_L (1)_21-05_WILD_065_L (1)_frame_8983_det00_cls31.jpg
    📄 17-04_NSC-S_080_17-04_NSC-S_080_frame_17371_det08_cls0.jpg
    📄 17-04_NSC-S_082_17-04_NSC-S_082_frame_22205_det22_cls0.jpg
    📄 17-04_NSC-S_012_17-04_NSC-S_012_frame_37333_det11_c

## 📊 Step 6: Setup Weights & Biases (Optional)

W&B provides excellent training visualization and experiment tracking.

In [10]:
# Setup Weights & Biases for experiment tracking
import wandb

# Login to W&B (you'll need to create a free account at wandb.ai)
print("🔐 Setting up Weights & Biases...")
print("\nTo use W&B:")
print("1. Go to https://wandb.ai and create a free account")
print("2. Get your API key from https://wandb.ai/authorize")
print("3. Run the cell below and paste your API key when prompted")
print("\nOr skip W&B by setting USE_WANDB = False below")

# Set this to True if you want to use W&B, False to skip
USE_WANDB = True  # 👈 Set to False if you don't want to use W&B

if USE_WANDB:
    try:
        # Try to login (will prompt for API key if not already logged in)
        wandb.login()
        print("✅ W&B login successful!")
    except Exception:
        print("⚠️  W&B login failed. Training will continue without W&B logging.")
        USE_WANDB = False
else:
    print("ℹ️  Skipping W&B setup. Training will run without experiment tracking.")

🔐 Setting up Weights & Biases...

To use W&B:
1. Go to https://wandb.ai and create a free account
2. Get your API key from https://wandb.ai/authorize
3. Run the cell below and paste your API key when prompted

Or skip W&B by setting USE_WANDB = False below


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcativthomson[0m ([33mcativthomson-university-of-cape-town[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ W&B login successful!


## ⚙️ Step 7: Configure Training Parameters

Adjust these parameters based on your needs and available GPU memory.
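
The `consistency_weight` and `ramp_up_epochs` settings in the configuration below suggest that the consistency loss is phased in gradually rather than applied at full strength from the first epoch. The actual schedule lives in the repository's trainer and is not shown in this notebook. The sketch below assumes the sigmoid-shaped ramp-up commonly used with Mean Teacher style training; `consistency_weight_at` is a hypothetical helper name.

```python
import math


def consistency_weight_at(epoch, max_weight=2.0, ramp_up_epochs=10):
    """Sigmoid-shaped ramp-up: near 0 at epoch 0, max_weight from ramp_up_epochs on.

    Assumed schedule for illustration; the repository may use a different one.
    """
    if ramp_up_epochs <= 0:
        return max_weight
    progress = min(max(epoch / ramp_up_epochs, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - progress) ** 2)


for epoch in (0, 5, 10, 20):
    print(f"epoch {epoch:2d}: weight = {consistency_weight_at(epoch):.4f}")
```

Under this assumption the unlabeled data contributes almost nothing during the first few epochs, while the model is still fitting the labeled examples, and reaches the full weight of 2.0 at epoch 10.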

In [11]:
# Training configuration
TRAINING_CONFIG = {
    # Basic settings
    'mode': 'semi_supervised',  # 'supervised' or 'semi_supervised'
    'epochs': 50,               # Reduced for Colab (normally 100)
    'batch_size': 16,           # Reduced for GPU memory (normally 32)
    'image_size': 224,
    'num_workers': 2,           # Reduced for Colab

    # Data splitting
    'val_split': 0.2,          # 20% for validation
    'test_split': 0.2,         # 20% for test

    # Model settings
    'model_name': 'vit_base_patch16_224',
    'pretrained': True,
    'dropout_rate': 0.1,

    # Training hyperparameters
    'learning_rate': 1e-4,
    'weight_decay': 0.05,
    'warmup_epochs': 5,        # Reduced for shorter training

    # Semi-supervised settings
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'unlabeled_ratio': 2.0,
    'ramp_up_epochs': 10,      # Reduced for shorter training

    # EMA settings
    'ema_momentum': 0.999,

    # Logging
    'use_wandb': USE_WANDB,
    'wandb_project': 'vit-fish-colab',
    'save_frequency': 10,      # Save every 10 epochs
    'seed': 42
}

print("🎛️  Training Configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

print("\n💡 Tips for Colab:")
print(f"  - Batch size {TRAINING_CONFIG['batch_size']} should work on most Colab GPUs")
print(f"  - {TRAINING_CONFIG['epochs']} epochs will take ~2-3 hours")
print("  - If you get GPU memory errors, reduce batch_size to 8")
print(f"  - Training will automatically save checkpoints every {TRAINING_CONFIG['save_frequency']} epochs")

🎛️  Training Configuration:
  mode: semi_supervised
  epochs: 50
  batch_size: 16
  image_size: 224
  num_workers: 2
  val_split: 0.2
  test_split: 0.2
  model_name: vit_base_patch16_224
  pretrained: True
  dropout_rate: 0.1
  learning_rate: 0.0001
  weight_decay: 0.05
  warmup_epochs: 5
  consistency_weight: 2.0
  pseudo_label_threshold: 0.7
  temperature: 4.0
  unlabeled_ratio: 2.0
  ramp_up_epochs: 10
  ema_momentum: 0.999
  use_wandb: True
  wandb_project: vit-fish-colab
  save_frequency: 10
  seed: 42

💡 Tips for Colab:
  - Batch size 16 should work on most Colab GPUs
  - 50 epochs will take ~2-3 hours
  - If you get GPU memory errors, reduce batch_size to 8
  - Training will automatically save checkpoints every 10 epochs
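
The `ema_momentum: 0.999` entry controls how quickly the teacher model tracks the student. As a rough illustration of the update rule, with plain Python lists standing in for model parameters (the repository's EMA class is not shown here, so this is a sketch of the standard formula rather than its implementation):

```python
def ema_update(teacher, student, momentum=0.999):
    """Return updated teacher values: momentum * teacher + (1 - momentum) * student."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]


# Toy example: the student is fixed, the teacher starts at zero.
teacher = [0.0, 0.0]
student = [1.0, -1.0]
for _ in range(1000):
    teacher = ema_update(teacher, student)
print(teacher)
```

With a momentum of 0.999 the teacher averages over roughly the last 1/(1 - 0.999) = 1000 optimizer steps, which is why it moves slowly toward the student in the toy loop above.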


In [None]:
# ========================================
# SAVE PROJECT FILES TO GOOGLE DRIVE
# ========================================

import shutil
import os

# Create project directory in Google Drive
project_dir = '/content/drive/MyDrive/ViT-FishID/code'
os.makedirs(project_dir, exist_ok=True)

# List of files to save (adjust paths as needed)
files_to_save = [
    'data.py',
    'model.py',
    'trainer.py',
    'utils.py',
    'train.py',
    'ViT_FishID_Colab_Training.ipynb'  # Your notebook
]

print("💾 Saving project files to Google Drive...")

for filename in files_to_save:
    try:
        if os.path.exists(filename):
            dest_path = os.path.join(project_dir, filename)
            shutil.copy2(filename, dest_path)
            print(f"✅ Saved: {filename}")
        else:
            print(f"⚠️  File not found: {filename}")
    except Exception as e:
        print(f"❌ Error saving {filename}: {e}")

print(f"\n📁 Project files saved to: {project_dir}")

# Also save the species mapping if it exists
mapping_file = 'species_mapping.txt'
if os.path.exists(mapping_file):
    try:
        shutil.copy2(mapping_file, os.path.join(project_dir, mapping_file))
        print(f"✅ Saved: {mapping_file}")
    except OSError as e:
        print(f"❌ Error saving {mapping_file}: {e}")
else:
    print(f"⚠️  {mapping_file} not found")

print("\n🎉 All files backed up to Google Drive!")

## 🚀 Step 8: Start Training!

This cell will start the semi-supervised training process. It may take 2-3 hours to complete.
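
In semi-supervised training of this kind, unlabeled images typically contribute only when the teacher is confident about them. The sketch below illustrates that idea with the `pseudo_label_threshold` of 0.7 from the configuration. How the repository combines the threshold with its `temperature` setting is not shown in this notebook, so the `softmax` and `pseudo_label` helpers here are assumptions for illustration only.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax; a higher temperature flattens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(teacher_logits, threshold=0.7):
    """Return (class_index, confidence) if the teacher is confident enough, else None."""
    probs = softmax(teacher_logits)
    confidence = max(probs)
    if confidence < threshold:
        return None  # sample is ignored for the pseudo-label loss
    return probs.index(confidence), confidence


print(pseudo_label([4.0, 1.0, 0.5]))  # confident prediction, kept
print(pseudo_label([1.0, 0.9, 0.8]))  # near-uniform prediction, rejected
print(max(softmax([4.0, 1.0, 0.5], temperature=4.0)))  # softened distribution
```

Raising the threshold keeps fewer but cleaner pseudo-labels; lowering it uses more unlabeled data at the cost of noisier targets.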

In [14]:
# Update the repository with the latest fixes
import os
os.chdir('/content/ViT-FishID')
!git pull origin main

print("✅ Repository updated with latest fixes!")

remote: Enumerating objects: 5, done.
remote: Counting objects: 100% (5/5), done.
remote: Compressing objects: 100% (1/1), done.
remote: Total 3 (delta 2), reused 3 (delta 2), pack-reused 0 (from 0)
Unpacking objects: 100% (3/3), 718 bytes | 718.00 KiB/s, done.
From https://github.com/cat-thomson/ViT-FishID
 * branch            main       -> FETCH_HEAD
   abd54d1..1739055  main       -> origin/main
Updating abd54d1..1739055
Fast-forward
 data.py | 38 ++++++++++++++++++++++++++++++++++----
 1 file changed, 34 insertions(+), 4 deletions(-)
✅ Repository updated with latest fixes!


In [None]:
# Build training command
training_cmd = f"""python train.py \
    --data_dir "{DATA_DIR}" \
    --mode {TRAINING_CONFIG['mode']} \
    --epochs {TRAINING_CONFIG['epochs']} \
    --batch_size {TRAINING_CONFIG['batch_size']} \
    --image_size {TRAINING_CONFIG['image_size']} \
    --num_workers {TRAINING_CONFIG['num_workers']} \
    --val_split {TRAINING_CONFIG['val_split']} \
    --test_split {TRAINING_CONFIG['test_split']} \
    --model_name {TRAINING_CONFIG['model_name']} \
    --learning_rate {TRAINING_CONFIG['learning_rate']} \
    --weight_decay {TRAINING_CONFIG['weight_decay']} \
    --warmup_epochs {TRAINING_CONFIG['warmup_epochs']} \
    --consistency_weight {TRAINING_CONFIG['consistency_weight']} \
    --pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']} \
    --temperature {TRAINING_CONFIG['temperature']} \
    --unlabeled_ratio {TRAINING_CONFIG['unlabeled_ratio']} \
    --ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']} \
    --ema_momentum {TRAINING_CONFIG['ema_momentum']} \
    --save_frequency {TRAINING_CONFIG['save_frequency']} \
    --seed {TRAINING_CONFIG['seed']}"""

if TRAINING_CONFIG['use_wandb']:
    training_cmd += f" --use_wandb --wandb_project {TRAINING_CONFIG['wandb_project']}"

if TRAINING_CONFIG['pretrained']:
    training_cmd += " --pretrained"

print("🚀 Starting ViT-FishID training...")
print("\n📋 Training command:")
print(training_cmd.replace('\\', '').strip())
print("\n" + "="*60)

# Execute training
!{training_cmd}

print("\n" + "="*60)
print("✅ Training completed!")

🚀 Starting ViT-FishID training...

📋 Training command:
python train.py     --data_dir "/content/fish_cutouts"     --mode semi_supervised     --epochs 50     --batch_size 16     --image_size 224     --num_workers 2     --val_split 0.2     --test_split 0.2     --model_name vit_base_patch16_224     --learning_rate 0.0001     --weight_decay 0.05     --warmup_epochs 5     --consistency_weight 2.0     --pseudo_label_threshold 0.7     --temperature 4.0     --unlabeled_ratio 2.0     --ramp_up_epochs 10     --ema_momentum 0.999     --save_frequency 10     --seed 42 --use_wandb --wandb_project vit-fish-colab --pretrained

2025-08-13 11:41:13.462656: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755085273.482871    7081 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been regist
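
Building the command as one long string works here, but it becomes fragile once a path contains spaces or a flag is added to the config without being added to the string. One alternative, sketched below, is to derive an argument list from the config dictionary and quote it with `shlex`. This assumes every key passed in corresponds to a flag that `train.py` accepts, which is not guaranteed (for example, the command above does not pass `dropout_rate`), so unsupported keys would need to be filtered out first. `build_train_command` and the demo values are hypothetical.

```python
import shlex


def build_train_command(config, data_dir, script="train.py"):
    """Turn a config dict into an argv list; True booleans become bare flags."""
    argv = ["python", script, "--data_dir", data_dir]
    for key, value in config.items():
        if isinstance(value, bool):
            if value:
                argv.append(f"--{key}")
        else:
            argv.extend([f"--{key}", str(value)])
    return argv


# Demo with a small subset of settings and a path that contains a space.
demo_config = {"mode": "semi_supervised", "epochs": 50, "pretrained": True, "use_wandb": False}
argv = build_train_command(demo_config, "/content/fish cutouts")
print(shlex.join(argv))
```

The resulting list can be passed directly to `subprocess.run(argv)`, which also makes it possible to check the exit code before reporting that training completed.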

In [None]:
# ========================================
# EMERGENCY CHECKPOINT SAVE - RUN NOW!
# ========================================

import os
import time

import torch
from google.colab import drive

# Mount Google Drive if not already mounted
try:
    drive.mount('/content/drive')
except Exception as e:
    print(f"Drive mount failed: {e}")

# Create checkpoints directory in Google Drive
checkpoint_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

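# NOTE: this cell assumes a `trainer` object exists in the notebook kernel.
# Step 8 launches train.py as a subprocess, so `trainer` is only defined if
# you created the trainer inside a notebook cell yourself.
current_epoch = getattr(trainer, 'current_epoch', 18)  # falls back to 18 if the attribute is missing
best_accuracy = getattr(trainer, 'best_accuracy', 0.0)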

print(f"💾 Saving manual checkpoint for epoch {current_epoch}")
print(f"📊 Current best accuracy: {best_accuracy:.2f}%")

# Create checkpoint dictionary with all necessary information
emergency_checkpoint = {
    'epoch': current_epoch,
    'student_state_dict': trainer.student_model.state_dict(),
    'ema_teacher_state_dict': trainer.ema_teacher.teacher.state_dict(),  # Note: .teacher not .teacher_model
    'optimizer_state_dict': trainer.optimizer.state_dict(),
    'best_accuracy': best_accuracy,
    'consistency_weight': trainer.consistency_weight,
    'pseudo_label_threshold': trainer.pseudo_label_threshold,
    'warmup_epochs': trainer.warmup_epochs,
    'ramp_up_epochs': trainer.ramp_up_epochs,
    'num_classes': trainer.num_classes,
    'timestamp': str(time.time()),
    'manual_save': True  # Flag to indicate this was a manual save
}

# Save the checkpoint
checkpoint_filename = f'emergency_checkpoint_epoch_{current_epoch}.pth'
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)

torch.save(emergency_checkpoint, checkpoint_path)

print(f"✅ Emergency checkpoint saved successfully!")
print(f"📁 Location: {checkpoint_path}")
print(f"📏 File size: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB")

# Verify the checkpoint can be loaded
try:
    test_load = torch.load(checkpoint_path, map_location='cpu')
    print(f"✅ Checkpoint verification passed - contains {len(test_load)} keys")
    print(f"🔍 Saved epoch: {test_load['epoch']}")
    print(f"🔍 Best accuracy: {test_load['best_accuracy']:.2f}%")
    del test_load  # Free memory
except Exception as e:
    print(f"❌ Checkpoint verification failed: {e}")

print("\n🚨 Your training progress is now safely saved!")
print("🚨 If Colab times out, you can resume from this checkpoint.")
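
When a session does time out, the first task in the new session is finding the most recent checkpoint in Drive. The sketch below picks the file with the highest epoch number, comparing epochs numerically so that epoch 18 sorts after epoch 9. It relies only on the `..._epoch_{N}.pth` naming used in the cell above; `latest_checkpoint` is an illustrative helper, and how `train.py` resumes from a checkpoint is not covered here.

```python
import os
import re
import tempfile

EPOCH_PATTERN = re.compile(r"epoch_(\d+)\.pth$")


def latest_checkpoint(checkpoint_dir):
    """Return the path of the .pth file with the highest epoch number, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = EPOCH_PATTERN.search(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path


# Demo with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as d:
    for name in (
        "emergency_checkpoint_epoch_9.pth",
        "emergency_checkpoint_epoch_18.pth",
        "model_best.pth",
    ):
        open(os.path.join(d, name), "wb").close()
    latest = os.path.basename(latest_checkpoint(d))
    print(latest)
```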

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 📊 Step 9: Check Training Results

In [None]:
# Check for saved checkpoints
import os
import glob

checkpoint_dir = '/content/ViT-FishID/checkpoints'
print(f"📁 Checking for checkpoints in: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    if checkpoints:
        print(f"\n✅ Found {len(checkpoints)} checkpoint(s):")
        for checkpoint in sorted(checkpoints):
            file_size = os.path.getsize(checkpoint) / (1024**2)  # MB
            print(f"  - {os.path.basename(checkpoint)} ({file_size:.1f} MB)")

        # Check if best model exists
        best_model = os.path.join(checkpoint_dir, 'model_best.pth')
        if os.path.exists(best_model):
            print(f"\n🏆 Best model saved: model_best.pth")
    else:
        print("❌ No checkpoints found")
else:
    print("❌ Checkpoint directory not found")

# Show training logs summary
print("\n📊 Training Summary:")
print("Check the training output above for final accuracy metrics")

if TRAINING_CONFIG['use_wandb']:
    print("\n📈 For detailed metrics and visualizations, check your W&B dashboard:")
    print(f"https://wandb.ai/your-username/{TRAINING_CONFIG['wandb_project']}")

## 💾 Step 10: Save Model and Results to Google Drive

Save your trained model to Google Drive for future use.

In [None]:
# Copy trained model to Google Drive
import os
import shutil
from datetime import datetime

# Create a timestamped folder in Google Drive
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f'/content/drive/MyDrive/ViT-FishID_Training_{timestamp}'
os.makedirs(save_dir, exist_ok=True)

print(f"💾 Saving results to Google Drive: {save_dir}")

# Copy checkpoints
checkpoint_dir = '/content/ViT-FishID/checkpoints'
if os.path.exists(checkpoint_dir):
    drive_checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    shutil.copytree(checkpoint_dir, drive_checkpoint_dir)
    print(f"✅ Checkpoints saved to: {drive_checkpoint_dir}")

# Save training configuration
import json
config_file = os.path.join(save_dir, 'training_config.json')
with open(config_file, 'w') as f:
    json.dump(TRAINING_CONFIG, f, indent=2)
print(f"✅ Training config saved to: {config_file}")

# Create a summary file
summary_file = os.path.join(save_dir, 'training_summary.txt')
with open(summary_file, 'w') as f:
    f.write("ViT-FishID Training Summary\n")
    f.write("===========================\n\n")
    f.write(f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Mode: {TRAINING_CONFIG['mode']}\n")
    f.write(f"Epochs: {TRAINING_CONFIG['epochs']}\n")
    f.write(f"Batch Size: {TRAINING_CONFIG['batch_size']}\n")
    f.write(f"Data Directory: {DATA_DIR}\n")
    f.write(f"\nModel Architecture: {TRAINING_CONFIG['model_name']}\n")
    f.write(f"Learning Rate: {TRAINING_CONFIG['learning_rate']}\n")
    f.write(f"Consistency Weight: {TRAINING_CONFIG['consistency_weight']}\n")
    f.write("\nCheckpoints saved in: checkpoints/\n")
    f.write("Best model: checkpoints/model_best.pth\n")

print(f"✅ Training summary saved to: {summary_file}")

print("\n🎉 All results saved to Google Drive!")
print(f"📁 Location: {save_dir}")
print("\n💡 You can now:")
print("   1. Download the checkpoints folder for local use")
print("   2. Use model_best.pth for inference")
print("   3. Continue training from any checkpoint")

## 🧪 Step 11: Quick Model Evaluation (Optional)

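Inspect the metadata stored in the best checkpoint (epoch, accuracy, class names). This cell does not run inference on images; use `evaluate.py` for a full evaluation.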

In [None]:
# Quick check of the trained model: print the metadata stored in the best checkpoint
import os
import torch

# Check if best model exists
best_model_path = '/content/ViT-FishID/checkpoints/model_best.pth'

if os.path.exists(best_model_path):
    print("🧪 Loading trained model for quick evaluation...")

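    # Load the checkpoint on CPU. PyTorch 2.6+ defaults to weights_only=True,
    # which can reject checkpoints that hold non-tensor Python objects. This file
    # was written by this notebook's own training run, so a full load is acceptable.
    checkpoint = torch.load(best_model_path, map_location='cpu', weights_only=False)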

    print(f"📊 Model training info:")
    if 'epoch' in checkpoint:
        print(f"  - Best epoch: {checkpoint['epoch']}")
    if 'best_acc' in checkpoint:
        print(f"  - Best accuracy: {checkpoint['best_acc']:.2f}%")
    if 'teacher_acc' in checkpoint:
        print(f"  - Teacher accuracy: {checkpoint['teacher_acc']:.2f}%")

    # Get class names if available
    if 'class_names' in checkpoint:
        class_names = checkpoint['class_names']
        print(f"  - Number of classes: {len(class_names)}")
        print(f"  - Sample classes: {class_names[:5]}...")

    print("\n✅ Checkpoint metadata loaded. Check the metrics above.")

else:
    print("❌ No trained model found. Make sure training completed successfully.")

print("\n💡 For comprehensive evaluation:")
print("   Use the evaluate.py script with your test dataset")
print("   The test set was automatically created during training")

## 🔧 Troubleshooting

### Common Issues and Solutions:

**1. GPU Memory Error (CUDA out of memory)**
- Reduce batch_size to 8 or 4
- Restart runtime and try again
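
The batch-size reduction above can be planned in advance. The following is a minimal sketch; the helper name `batch_size_fallbacks` and the starting value of 32 are illustrative assumptions, not part of the ViT-FishID code.

```python
def batch_size_fallbacks(start: int, minimum: int = 4) -> list[int]:
    """Return successively halved batch sizes from `start` down to `minimum`."""
    sizes = []
    size = start
    while size >= minimum:
        sizes.append(size)
        size //= 2  # halve on each out-of-memory failure
    return sizes


# Hypothetical starting batch size of 32
print(batch_size_fallbacks(32))
```

After an out-of-memory error, set `TRAINING_CONFIG['batch_size']` to the next value in the list, restart the runtime, and rerun the training cell.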

**2. Data Not Found**
- Check that DATA_DIR path is correct
- Ensure data is uploaded to Google Drive
- Verify folder structure (labeled/ and unlabeled/)
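
These checks can be run as a small script before starting training. This is a sketch only: the function name `check_data_dir` and the example path are hypothetical, and it verifies only that the `labeled/` and `unlabeled/` folders exist and are not empty.

```python
import os


def check_data_dir(data_dir: str, required=("labeled", "unlabeled")) -> list[str]:
    """Return a list of problems found; an empty list means the layout looks fine."""
    if not os.path.isdir(data_dir):
        return [f"DATA_DIR does not exist: {data_dir}"]
    problems = []
    for name in required:
        sub = os.path.join(data_dir, name)
        if not os.path.isdir(sub):
            problems.append(f"Missing folder: {sub}")
        elif not os.listdir(sub):
            problems.append(f"Folder is empty: {sub}")
    return problems


# Hypothetical path; substitute the value of DATA_DIR used in this notebook
for problem in check_data_dir("/content/drive/MyDrive/fish_data"):
    print(f"❌ {problem}")
```

An empty result means the basic folder layout is in place; it does not validate the images themselves.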

**3. Training Stops Unexpectedly**
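- Colab sessions time out: free-tier runtimes last at most about 12 hours and can disconnect sooner when idle
- Keep the notebook tab open and active during training to reduce idle disconnects
- Checkpoints are saved every 10 epochs for resuming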
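
To resume after a disconnect, you first need to find the most recent periodic checkpoint. The sketch below assumes the `model_epoch_XX.pth` naming used by this notebook; the function name `latest_epoch_checkpoint` is hypothetical, and how the path is passed to the training script is not covered here.

```python
import os
import re


def latest_epoch_checkpoint(checkpoint_dir: str):
    """Return (epoch, path) for the highest-numbered model_epoch_XX.pth, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    pattern = re.compile(r"model_epoch_(\d+)\.pth$")
    found = []
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match:
            # Compare epochs as integers so that epoch 100 sorts after epoch 20
            found.append((int(match.group(1)), os.path.join(checkpoint_dir, name)))
    return max(found) if found else None
```

Calling `latest_epoch_checkpoint('/content/ViT-FishID/checkpoints')` returns `None` when no periodic checkpoint exists yet, which makes it easy to fall back to training from scratch.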

**4. Low Accuracy**
- Increase epochs (try 75-100)
- Adjust consistency_weight (try 1.0-3.0)
- Lower pseudo_label_threshold (try 0.5-0.6)

**5. Consistency Loss is 0.0000**
- Lower pseudo_label_threshold to 0.5
- Check that you have unlabeled data
- Ensure semi_supervised mode is selected
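
A consistency loss of exactly zero usually means that no unlabeled sample passed the confidence threshold. The sketch below illustrates this under an assumption: that unlabeled samples are kept only when the top softmax probability reaches `pseudo_label_threshold`, a common pseudo-labeling scheme that may differ from the actual ViT-FishID implementation. The logits and function names are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confident_fraction(batch_logits, threshold):
    """Fraction of samples whose top class probability reaches `threshold`."""
    if not batch_logits:
        return 0.0
    kept = sum(1 for logits in batch_logits if max(softmax(logits)) >= threshold)
    return kept / len(batch_logits)


# Hypothetical logits for three unlabeled images with three classes each
logits = [[2.0, 0.5, 0.1], [0.3, 0.2, 0.1], [1.0, 1.0, 0.9]]
for threshold in (0.9, 0.7, 0.5):
    print(f"threshold={threshold}: kept fraction={confident_fraction(logits, threshold):.2f}")
```

When the kept fraction is zero for every batch, no unlabeled sample contributes to the loss, which is why lowering the threshold is the first thing to try.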

## 🚀 Next Steps

After training is complete, you can:

1. **Download your model**: The trained model is saved in Google Drive
2. **Continue training**: Resume from checkpoints for more epochs
3. **Evaluate performance**: Use the test set for final evaluation
4. **Deploy model**: Use the trained model for fish classification
5. **Experiment**: Try different hyperparameters or architectures

### Model Files Saved:
- `model_best.pth`: Best performing model (use this for inference)
- `model_latest.pth`: Most recent checkpoint
- `model_epoch_XX.pth`: Periodic checkpoints

### Performance Expectations:
- **50 epochs**: ~70-80% accuracy
- **100 epochs**: ~75-85% accuracy
- **Semi-supervised**: Expected to outperform supervised-only training on the same labeled data

**Happy fish classification! 🐟🎉**