# 🐟 ViT-FishID: Supervised Fish Classification Training

**COMPLETE SUPERVISED TRAINING PIPELINE WITH GOOGLE COLAB**

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Supervised_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🎯 What This Notebook Does

This notebook implements a **fully supervised learning pipeline** for fish species classification using:

- **🤖 Vision Transformer (ViT)**: Transformer architecture for image classification
- **📊 Supervised Learning**: Trains only on labeled fish images with ground-truth species labels
- **🔍 Data Filtering**: Automatically removes species with fewer than 2 images
- **☁️ Google Colab**: Cloud-based training with GPU acceleration
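
The data-filtering rule can be sketched in a few lines. This is a minimal illustration, not the repository's implementation: the function name `filter_rare_species`, the `(path, label)` sample format, and the `min_images` parameter are assumptions made for the example.

```python
from collections import Counter

def filter_rare_species(samples, min_images=2):
    """Drop every sample whose species has fewer than `min_images` images.

    `samples` is a list of (image_path, species_label) pairs (hypothetical format).
    Returns the kept samples and the sorted list of excluded species.
    """
    counts = Counter(label for _, label in samples)
    kept = [(path, label) for path, label in samples if counts[label] >= min_images]
    excluded = sorted(label for label, n in counts.items() if n < min_images)
    return kept, excluded

samples = [
    ("img/a1.jpg", "snapper"),
    ("img/a2.jpg", "snapper"),
    ("img/b1.jpg", "grouper"),  # only one image, so this species is excluded
]
kept, excluded = filter_rare_species(samples)
print(len(kept), excluded)  # 2 ['grouper']
```

Returning the excluded species alongside the kept samples makes it easy to log which classes were dropped before training starts.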

## 📊 Expected Performance

- **Training Time**: 2-4 hours for 100 epochs
- **GPU Requirements**: T4/V100/A100 (Colab Pro recommended)
- **Expected Accuracy**: 85-95% on fish species classification
- **Data Requirements**: Minimum 2 images per species class
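
The training-time estimate implies a per-epoch budget, which helps when judging whether a run fits inside a single Colab session. The helper below is plain arithmetic on the figures quoted above; `seconds_per_epoch` is a hypothetical name used only for this example.

```python
def seconds_per_epoch(total_hours, epochs):
    """Average wall-clock seconds per epoch for a run lasting `total_hours`."""
    return total_hours * 3600 / epochs

low = seconds_per_epoch(2, 100)   # 72.0 seconds per epoch
high = seconds_per_epoch(4, 100)  # 144.0 seconds per epoch
print(f"{low:.0f}-{high:.0f} s per epoch")
```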

## 🛠️ What You Need

1. **Fish Dataset**: Labeled fish images organized by species (upload to Google Drive)
2. **Google Colab Pro**: Recommended for longer training sessions
3. **Weights & Biases Account**: Optional for experiment tracking

## 🔄 Key Differences from Semi-Supervised Version

- ✅ **Supervised Only**: No unlabeled data used
- ✅ **Data Filtering**: Species with <2 images automatically excluded
- ✅ **Simplified Training**: Standard supervised learning approach
- ✅ **Faster Training**: No consistency loss or pseudo-labeling overhead
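
In loss terms, the difference can be written as follows. This is the generic form of the two objectives, stated as an assumption for illustration rather than the exact losses used in this repository; $\lambda$ and $\mathcal{L}_{\text{cons}}$ are placeholder symbols for the consistency weight and the consistency (or pseudo-label) term.

$$
\mathcal{L}_{\text{sup}} = -\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(y_i \mid x_i),
\qquad
\mathcal{L}_{\text{semi}} = \mathcal{L}_{\text{sup}} + \lambda\,\mathcal{L}_{\text{cons}}
$$

Here $(x_i, y_i)$ are the $N$ labeled images and their species labels. Supervised training optimizes only $\mathcal{L}_{\text{sup}}$, so each step needs labeled batches alone.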

## 🔧 Step 1: Environment Setup and GPU Check

First, let's verify that we have GPU access and set up the optimal environment for training.

In [None]:
# Check GPU availability and system information
import torch
import sys
import gc

print("🔍 SYSTEM INFORMATION")
print("="*50)
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Device: {device_name}")
    print(f"GPU Memory: {device_memory:.1f} GB")
    print("✅ GPU is ready for training!")

    # Let cuDNN auto-tune convolution algorithms: faster for fixed input
    # sizes, but results may vary slightly between runs
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # Clear GPU cache
    torch.cuda.empty_cache()
    gc.collect()
    print("🚀 GPU optimized for training")

else:
    print("❌ No GPU detected!")
    print("📝 To enable GPU in Colab:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")
    print("   Then restart this notebook")

# Set device for later use
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🎯 Using device: {DEVICE}")
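
As a cross-check that does not depend on PyTorch, the driver's own `nvidia-smi` tool can be queried directly. This is a minimal sketch, assuming `nvidia-smi` is on the PATH whenever a GPU runtime is attached; `query_gpus` is a hypothetical helper name.

```python
import shutil
import subprocess

def query_gpus():
    """Return a list of 'name, memory' strings, or None if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None  # no NVIDIA driver tools on this machine
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None  # tool is present but reports no usable GPU
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]

gpus = query_gpus()
print(gpus if gpus else "No GPU visible to nvidia-smi")
```

If PyTorch reports no CUDA device but `nvidia-smi` lists one, the likely cause is a PyTorch build that does not match the runtime's CUDA version rather than a missing GPU.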

## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [None]:
from google.colab import drive
import os
import shutil

# Mount Google Drive
print("Attempting to mount Google Drive...")

mount_point = '/content/drive'
if os.path.ismount(mount_point):
    # Drive is already mounted here. Do NOT delete anything under it:
    # removing files in a mounted Drive removes them from your Drive.
    print(f"Drive is already mounted at {mount_point}; skipping cleanup.")
elif os.path.isdir(mount_point) and os.listdir(mount_point):
    # Leftover local files (not Drive contents) make drive.mount() fail,
    # so clear them before mounting.
    print(f"Clearing stale local files at mount point: {mount_point}")
    try:
        shutil.rmtree(mount_point)
        os.makedirs(mount_point)
        print("✅ Mount point cleared.")
    except OSError as e:
        print(f"❌ Error clearing mount point: {e}")
        print("Attempting to proceed with mount anyway...")

drive.mount(mount_point)

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)  # List the directory once and reuse it
    for item in items[:10]:  # Show first 10 items
        print(f"  - {item}")
    if len(items) > 10:
        print(f"  ... and {len(items) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")
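
Once Drive is mounted, it is worth confirming that the dataset follows the one-folder-per-species layout before training. The sketch below counts images per species folder. The function name, the extension list, and the example path in the comment are assumptions, so substitute the location of your own dataset.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumed image formats

def count_images_per_species(dataset_root):
    """Map each species sub-folder name to the number of image files inside it."""
    root = Path(dataset_root)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {root}")
    counts = {}
    for species_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[species_dir.name] = sum(
            1 for f in species_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts

# Hypothetical location; replace with the path to your dataset in Drive.
# counts = count_images_per_species("/content/drive/MyDrive/fish_dataset")
# print(f"{len(counts)} species, {sum(counts.values())} images")
```

Any species that shows fewer than 2 images in this summary will be excluded by the data-filtering step.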

## 📦 Step 3: Install Dependencies

Installing all required packages for ViT-FishID supervised training.

In [None]:
# Install required packages
print("📦 Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("✅ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\n📋 Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")
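
If one of the imports above fails, the traceback stops at the first missing package. A gentler check, sketched here with the standard library's `importlib.metadata`, reports every package at once. The helper name `installed_versions` is hypothetical; the distribution names are the ones passed to `pip install` above.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(distributions):
    """Map each pip distribution name to its installed version, or None if missing."""
    found = {}
    for name in distributions:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

required = ["torch", "torchvision", "timm", "albumentations", "wandb",
            "opencv-python-headless", "scikit-learn", "tqdm"]
for name, ver in installed_versions(required).items():
    print(f"  - {name}: {ver if ver else 'NOT INSTALLED'}")
```

Note that this looks up pip distribution names (`scikit-learn`, `opencv-python-headless`), which differ from the import names (`sklearn`, `cv2`).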

## 🔄 Step 4: Clone ViT-FishID Repository

Clone the latest code from the ViT-FishID GitHub repository.

In [None]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\n📂 Project structure:")
!ls -la

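# Report success only if the clone actually produced a git checkout
if os.path.isdir('/content/ViT-FishID/.git'):
    print("\n✅ Repository cloned successfully!")
else:
    print("\n❌ Clone failed: check the repository URL and your network connection")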