# 🐟 ViT-FishID: Semi-Supervised Fish Classification in Google Colab

This notebook adapts the ViT-FishID project to run in Google Colab, providing a complete end-to-end pipeline for semi-supervised fish classification using Vision Transformers and an EMA teacher-student framework.

## 🚀 What This Notebook Does:
- Sets up the complete environment in Google Colab
- Handles data upload and organization from Google Drive
- Implements ViT-Base with EMA teacher-student semi-supervised learning
- Provides interactive training with progress monitoring
- Saves results and checkpoints to Google Drive

## 📋 Before You Start:
1. **Enable GPU**: Go to Runtime → Change runtime type → GPU (T4 or better recommended)
2. **Prepare Your Data**: Upload your fish images to Google Drive
3. **Get W&B API Key**: Optional but recommended for experiment tracking

Let's get started! 🎣

## 📦 Section 1: Install Required Dependencies

First, let's install all the required packages that aren't pre-installed in Colab.

In [None]:
# Install the third-party packages this notebook imports (pip quiet mode).
# torch is not in this list, so it has to be present in the runtime already.
!pip install -q timm transformers wandb pillow opencv-python scikit-learn tqdm

# Report the PyTorch build and whether a CUDA device is visible.
# Nothing is installed here: this only imports torch and prints what it finds.
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Verify installations by importing each package once.
# An ImportError is printed rather than re-raised, so the cell still finishes
# and execution continues even if a package is missing.
try:
    import timm
    import transformers
    import wandb
    from PIL import Image
    import cv2
    import sklearn
    print("✅ All packages installed successfully!")
except ImportError as e:
    print(f"❌ Error importing package: {e}")

# Set up for deterministic training
import numpy as np
import random

def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN into deterministic mode
    with benchmark mode switched off.

    Args:
        seed: Integer seed handed to each generator.
    """
    # The generators are independent of each other, so one pass over the
    # seeding functions is enough.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN flags: deterministic on, benchmark off.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed all generators once, at the point this cell runs.
set_seed(42)
print("🎯 Random seed set to 42 for reproducibility")

## 💾 Section 2: Mount Google Drive

Mount Google Drive to access your datasets and save training outputs.

In [None]:
from google.colab import drive
import os

# Mount Google Drive under /content/drive (Colab-only API)
drive.mount('/content/drive')

# Set up project directories in Google Drive.
# Everything the notebook reads or writes lives below PROJECT_DIR; these
# module-level names are referenced by the cells that follow.
DRIVE_ROOT = "/content/drive/MyDrive"
PROJECT_DIR = f"{DRIVE_ROOT}/ViT-FishID"
DATA_DIR = f"{PROJECT_DIR}/data"
CHECKPOINTS_DIR = f"{PROJECT_DIR}/checkpoints"
RESULTS_DIR = f"{PROJECT_DIR}/results"

# Create directories if they don't exist (exist_ok makes re-running the cell safe)
os.makedirs(PROJECT_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Echo the resolved locations so the user can find them in Drive
print("📂 Google Drive mounted successfully!")
print(f"Project directory: {PROJECT_DIR}")
print(f"Data directory: {DATA_DIR}")
print(f"Checkpoints directory: {CHECKPOINTS_DIR}")
print(f"Results directory: {RESULTS_DIR}")

# List contents of project directory.
# PROJECT_DIR and its data/checkpoints/results sub-directories are created
# earlier in this cell, so an "exists" test on PROJECT_DIR can never fail and
# the directory is never empty. The entries are therefore always listed
# (sorted for a stable display), and the "ready for your data" notice is
# driven by whether DATA_DIR itself has any entries yet.
print(f"\n📋 Contents of {PROJECT_DIR}:")
for item in sorted(os.listdir(PROJECT_DIR)):
    item_path = os.path.join(PROJECT_DIR, item)
    if os.path.isdir(item_path):
        print(f"  📁 {item}/")
    else:
        print(f"  📄 {item}")

if not os.listdir(DATA_DIR):
    print(f"🆕 {DATA_DIR} is empty - ready for your data!")

## 📤 Section 3: Upload Dataset to Google Drive

Choose one of the following methods to get your fish dataset into Colab:

In [None]:
# Method 1: Upload files directly to Colab (for small datasets)
from google.colab import files
import zipfile
import shutil

def upload_and_extract_zip():
    """Upload files through the Colab file picker and place them in DATA_DIR.

    Uploaded files whose name ends in ``.zip`` (matched case-insensitively)
    are extracted into ``DATA_DIR`` and the uploaded archive is deleted
    afterwards. Every other uploaded file is moved into ``DATA_DIR`` as is.
    """
    # Import locally under a private alias so the uploader keeps working even
    # if the notebook-level name ``files`` gets rebound to something else (it
    # is a very common loop-variable name, e.g. when unpacking ``os.walk``).
    from google.colab import files as colab_files

    print("📤 Upload your fish dataset as a ZIP file:")
    uploaded = colab_files.upload()

    for filename in uploaded:
        # Case-insensitive match so "DATA.ZIP" is unpacked, not just moved.
        if filename.lower().endswith('.zip'):
            print(f"📦 Extracting {filename}...")
            with zipfile.ZipFile(filename, 'r') as zip_ref:
                zip_ref.extractall(DATA_DIR)
            print(f"✅ Extracted to {DATA_DIR}")
            os.remove(filename)  # Clean up the uploaded archive
        else:
            # Move non-zip files to data directory
            shutil.move(filename, os.path.join(DATA_DIR, filename))

# Method 2: Use existing data from Google Drive
def check_existing_data():
    """Report whether a dataset is already present under DATA_DIR.

    Looks for ``organized_fish_dataset`` first and ``Images`` second inside
    ``DATA_DIR`` and prints which one was found.

    Returns:
        True if either entry exists, otherwise False.
    """
    known_layouts = (
        ("organized_fish_dataset", "✅ Found existing organized dataset!"),
        ("Images", "✅ Found raw images directory!"),
    )
    for entry, message in known_layouts:
        if os.path.exists(f"{DATA_DIR}/{entry}"):
            print(message)
            return True

    print("❌ No existing data found in Google Drive")
    return False

# Method 3: Download from URL (if you have a dataset URL)
def download_from_url(url, filename):
    """Download a dataset file and unpack it into DATA_DIR if it is a ZIP.

    Args:
        url: Location to fetch; passed straight to ``urllib.request.urlretrieve``.
        filename: Local path the download is written to. If it ends in
            ``.zip`` (matched case-insensitively) the archive is extracted
            into ``DATA_DIR`` and then deleted; any other file is left at
            ``filename``.
    """
    import urllib.request

    print(f"📥 Downloading {filename} from {url}...")
    urllib.request.urlretrieve(url, filename)

    # Case-insensitive match so an archive named "DATA.ZIP" is unpacked too.
    if filename.lower().endswith('.zip'):
        print("📦 Extracting...")
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall(DATA_DIR)
        os.remove(filename)
    print("✅ Download complete!")

# Check what we have
print("🔍 Checking for existing data...")
has_data = check_existing_data()

if not has_data:
    print("\n🎯 Choose your upload method:")
    print("1. Run upload_and_extract_zip() to upload a ZIP file")
    print("2. Manually upload to Google Drive and restart this cell")
    print("3. Use download_from_url(url, filename) if you have a URL")
    print("\nExample: upload_and_extract_zip()")
else:
    print("📊 Listing data directory contents:")
    # The per-directory file list is bound to ``filenames`` rather than
    # ``files``: this cell imports ``files`` from google.colab, and reusing
    # that name as the loop variable would replace the module with a list.
    for root, dirs, filenames in os.walk(DATA_DIR):
        # Depth below DATA_DIR determines the indentation of each entry
        level = root.replace(DATA_DIR, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}📁 {os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in filenames[:5]:  # Show first 5 files
            print(f"{subindent}📄 {file}")
        if len(filenames) > 5:
            print(f"{subindent}... and {len(filenames) - 5} more files")

## 🏗️ Section 4: Set Up Directory Structure

Create the necessary directory structure and organize data for training.

In [None]:
import json
from pathlib import Path
from collections import defaultdict
import glob

def _unique_destination(dest_dir, filename, taken):
    """Return a path in dest_dir for filename that no earlier image in this run has claimed."""
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(dest_dir, filename)
    attempt = 0
    while candidate in taken:
        attempt += 1
        candidate = os.path.join(dest_dir, f"{stem}_{attempt}{ext}")
    taken.add(candidate)
    return candidate


def organize_fish_data(input_dir, output_dir, labeled_species=None):
    """
    Organize fish images into labeled and unlabeled directories.

    Every image found recursively under ``input_dir`` is copied to
    ``output_dir/labeled/<species>/`` when its lower-cased file name contains
    one of the species names (the first match in ``labeled_species`` wins),
    and to ``output_dir/unlabeled/`` otherwise. A summary is written to
    ``output_dir/dataset_info.json``.

    Image extensions are matched case-insensitively, so ``.JPG`` files are
    picked up too. Two source images that share a base name and land in the
    same destination folder both survive: the later one gets a ``_1``,
    ``_2``, ... suffix. Sources are processed in sorted order so these names
    are the same on every run.

    Args:
        input_dir: Directory searched recursively for images.
        output_dir: Directory that receives ``labeled/``, ``unlabeled/`` and
            ``dataset_info.json``.
        labeled_species: Species names to look for in file names. Defaults
            to a built-in list of six species.

    Returns:
        ``output_dir``.
    """
    if labeled_species is None:
        # Common fish species - you can modify this list
        labeled_species = ['bass', 'trout', 'salmon', 'tuna', 'cod', 'mackerel']

    # Create output structure
    labeled_dir = os.path.join(output_dir, 'labeled')
    unlabeled_dir = os.path.join(output_dir, 'unlabeled')
    os.makedirs(labeled_dir, exist_ok=True)
    os.makedirs(unlabeled_dir, exist_ok=True)
    for species in labeled_species:
        os.makedirs(os.path.join(labeled_dir, species), exist_ok=True)

    # Find all image files. One recursive glob plus a lower-cased suffix test
    # replaces per-extension patterns, which miss upper-case extensions on
    # case-sensitive filesystems.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    candidates = glob.glob(os.path.join(input_dir, '**', '*'), recursive=True)
    all_images = sorted(
        path for path in candidates
        if os.path.isfile(path) and path.lower().endswith(image_extensions)
    )

    print(f"🔍 Found {len(all_images)} images")

    # Organize images
    species_counts = defaultdict(int)
    unlabeled_count = 0
    claimed_destinations = set()  # destination paths handed out during this run

    for img_path in all_images:
        original_name = os.path.basename(img_path)
        lowered_name = original_name.lower()

        # First species whose name occurs in the file name, or None
        matched_species = next(
            (species for species in labeled_species if species.lower() in lowered_name),
            None,
        )

        if matched_species is None:
            dest_dir = unlabeled_dir
            unlabeled_count += 1
        else:
            dest_dir = os.path.join(labeled_dir, matched_species)
            species_counts[matched_species] += 1

        dest_path = _unique_destination(dest_dir, original_name, claimed_destinations)

        # Copy file if not already there (keeps re-runs cheap)
        if not os.path.exists(dest_path):
            shutil.copy2(img_path, dest_path)

    # Create dataset info
    dataset_info = {
        'total_images': len(all_images),
        'labeled_species': dict(species_counts),
        'unlabeled_count': unlabeled_count,
        'species_list': labeled_species
    }

    # Save dataset info
    with open(os.path.join(output_dir, 'dataset_info.json'), 'w') as f:
        json.dump(dataset_info, f, indent=2)

    print("\n📊 Dataset Organization Complete!")
    print(f"Total images: {len(all_images)}")
    print(f"Labeled images: {sum(species_counts.values())}")
    print(f"Unlabeled images: {unlabeled_count}")
    print("\nSpecies breakdown:")
    for species, count in species_counts.items():
        print(f"  🐟 {species}: {count} images")

    return output_dir

# Set up the organized dataset directory (referenced again by later cells)
ORGANIZED_DATA_DIR = f"{DATA_DIR}/organized_fish_dataset"

# Check if we need to organize data.
# The dataset counts as organized only when both the directory and its
# dataset_info.json summary are present.
if os.path.exists(ORGANIZED_DATA_DIR) and os.path.exists(f"{ORGANIZED_DATA_DIR}/dataset_info.json"):
    print("✅ Found existing organized dataset!")
    
    # Load and display dataset info
    with open(f"{ORGANIZED_DATA_DIR}/dataset_info.json", 'r') as f:
        dataset_info = json.load(f)
    
    print(f"📊 Dataset Summary:")
    print(f"  Total images: {dataset_info['total_images']}")
    print(f"  Labeled images: {sum(dataset_info['labeled_species'].values())}")
    print(f"  Unlabeled images: {dataset_info['unlabeled_count']}")
    print(f"  Species: {', '.join(dataset_info['species_list'])}")
    
else:
    print("🔧 Need to organize data...")
    
    # Look for raw images: the first candidate directory that exists wins
    raw_images_dir = None
    possible_dirs = [f"{DATA_DIR}/Images", f"{DATA_DIR}/images", f"{DATA_DIR}/fish_images"]
    
    for pdir in possible_dirs:
        if os.path.exists(pdir):
            raw_images_dir = pdir
            break
    
    if raw_images_dir:
        print(f"📁 Found raw images in: {raw_images_dir}")
        print("🔄 Organizing data...")
        
        # You can customize these species based on your dataset.
        # An image is labeled when its file name contains one of these names.
        my_species = ['bass', 'trout', 'salmon', 'tuna', 'cod', 'mackerel', 'snapper', 'grouper']
        
        organize_fish_data(raw_images_dir, ORGANIZED_DATA_DIR, my_species)
    else:
        # Nothing is organized in this case; the message only tells the user what to upload
        print("❌ No raw images found. Please upload your data first!")
        print("Expected directories: Images/, images/, or fish_images/")

## 🧠 Section 5: Model Definitions

Now let's define our Vision Transformer and EMA teacher-student framework.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from copy import deepcopy

class ViTFishClassifier(nn.Module):
    """Fish species classifier backed by a timm Vision Transformer.

    Args:
        num_classes: Size of the classification head.
        model_name: timm model identifier handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
    """

    def __init__(self, num_classes, model_name='vit_base_patch16_224', pretrained=True):
        super().__init__()
        self.num_classes = num_classes

        # global_pool='token' is requested so the model output has shape
        # [batch_size, num_classes].
        backbone_kwargs = {
            'pretrained': pretrained,
            'num_classes': num_classes,
            'global_pool': 'token',
        }
        self.model = timm.create_model(model_name, **backbone_kwargs)

        param_total = sum(p.numel() for p in self.model.parameters())
        print(f"✅ Created ViT model: {model_name}")
        print(f"   Parameters: {param_total / 1e6:.1f}M")
        print(f"   Classes: {num_classes}")

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def get_features(self, x):
        """Return pooled features taken before the classification head.

        The wrapped model's ``forward_features`` output is reduced as follows:
        index 0 along dim 1 when the model reports ``global_pool == 'token'``,
        otherwise the mean over dim 1 when the output has more than two
        dimensions, otherwise the output unchanged.
        """
        tokens = self.model.forward_features(x)

        pools_on_cls_token = getattr(self.model, 'global_pool', None) == 'token'
        if pools_on_cls_token:
            return tokens[:, 0]  # CLS token
        if tokens.dim() > 2:
            return tokens.mean(dim=1)  # average over dim 1
        return tokens

class EMATeacher:
    """Exponential Moving Average Teacher.

    Holds a frozen deep copy of the student (``teacher_model``) whose
    parameters track the student's as
    ``teacher = momentum * teacher + (1 - momentum) * student``.

    Args:
        student_model: Module to copy; it is not modified.
        momentum: EMA decay applied to the teacher weights on each ``update``.
    """

    def __init__(self, student_model, momentum=0.999):
        self.momentum = momentum
        # deepcopy already clones every parameter and buffer value, so no
        # second weight copy is needed; only freezing remains.
        self.teacher_model = deepcopy(student_model)

        for teacher_param in self.teacher_model.parameters():
            teacher_param.requires_grad = False

        self.teacher_model.eval()
        print(f"✅ EMA Teacher initialized with momentum: {momentum}")

    def update(self, student_model):
        """Update teacher model using EMA.

        Parameters are blended in place, so no new tensor is allocated per
        step and existing references to the teacher weights stay valid.
        Buffers (for example normalization running statistics, if the model
        has any) are copied from the student so they do not stay frozen at
        their initial values.

        Args:
            student_model: Module with the same structure as the one passed
                to the constructor.
        """
        student_weight = 1.0 - self.momentum
        with torch.no_grad():
            for teacher_param, student_param in zip(
                self.teacher_model.parameters(),
                student_model.parameters()
            ):
                teacher_param.mul_(self.momentum).add_(
                    student_param.detach(), alpha=student_weight
                )

            for teacher_buffer, student_buffer in zip(
                self.teacher_model.buffers(),
                student_model.buffers()
            ):
                teacher_buffer.copy_(student_buffer)

    def __call__(self, x):
        """Forward pass through teacher model with gradients disabled."""
        with torch.no_grad():
            return self.teacher_model(x)

# Test the models
def test_models():
    """Smoke-test the student and teacher on a random batch.

    Builds a ``ViTFishClassifier`` with 10 classes, pushes a random
    ``4 x 3 x 224 x 224`` batch through it and through an ``EMATeacher``
    wrapped around it, and asserts both outputs have shape ``(4, 10)``.

    Returns:
        Tuple ``(student, teacher)`` of the objects that were created.
    """
    n_samples = 4
    n_classes = 10  # Placeholder; the real value comes from the dataset later
    images = torch.randn(n_samples, 3, 224, 224)
    expected_shape = (n_samples, n_classes)

    print("🧪 Testing model creation...")

    # Student: logits and pooled features
    student = ViTFishClassifier(num_classes=n_classes)
    with torch.no_grad():
        logits = student(images)
        print(f"✅ Student output shape: {logits.shape}")
        assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}"

        pooled = student.get_features(images)
        print(f"✅ Feature shape: {pooled.shape}")

    # Teacher: EMA wrapper around a copy of the student
    teacher = EMATeacher(student)
    with torch.no_grad():
        teacher_logits = teacher(images)
        print(f"✅ Teacher output shape: {teacher_logits.shape}")
        assert teacher_logits.shape == expected_shape

    print("🎉 All model tests passed!")

    return student, teacher

# Run tests when this cell executes. The returned pair is kept at notebook
# level; note that teacher_model is the EMATeacher wrapper, not a bare module.
student_model, teacher_model = test_models()

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

class FishDataset(Dataset):
    """Dataset for fish images.

    Args:
        image_paths: Sequence of image file paths.
        labels: Optional sequence of labels aligned with ``image_paths``.
            When ``None`` the dataset is treated as unlabeled.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_paths, labels=None, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.is_labeled = labels is not None

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            ``(image, label)`` for a labeled dataset, otherwise just ``image``.
        """
        image_path = self.image_paths[idx]

        # Open inside a context manager so the file handle is closed as soon
        # as the RGB copy exists, instead of whenever the source image object
        # happens to be garbage collected.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.is_labeled:
            return image, self.labels[idx]
        return image

def create_transforms(image_size=224, is_training=True):
    """Build the torchvision preprocessing pipeline.

    Both modes resize to ``image_size x image_size``, convert to a tensor and
    normalize with fixed per-channel mean/std. Training mode additionally
    inserts a random horizontal flip, a random rotation of up to 15 degrees
    and colour jitter between the resize and the tensor conversion.

    Args:
        image_size: Side length of the square output image.
        is_training: Whether to include the random augmentations.

    Returns:
        A ``transforms.Compose`` with the steps described above.
    """
    steps = [transforms.Resize((image_size, image_size))]

    if is_training:
        # Random augmentations, applied while the data is still an image
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ]

    # Shared tail: tensor conversion followed by channel normalization
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    return transforms.Compose(steps)

def prepare_data_loaders(data_dir, batch_size=32, val_split=0.2):
    """Prepare data loaders for semi-supervised learning.

    Reads ``dataset_info.json`` from ``data_dir``, gathers the files under
    ``labeled/<species>/`` and ``unlabeled/``, splits the labeled files into
    train and validation sets with a fixed seed, and wraps all three sets in
    ``DataLoader`` objects.

    File lists are sorted before splitting so the seeded split does not
    depend on the order the filesystem returns them in. The split is
    stratified by species when possible; if scikit-learn rejects the
    stratified split (for example because a species has a single image) a
    plain random split is used instead and a warning is printed.

    Args:
        data_dir: Organized dataset directory containing ``dataset_info.json``.
        batch_size: Batch size used for every loader.
        val_split: Fraction of the labeled images held out for validation.

    Returns:
        Dict with keys ``train_labeled``, ``validation`` and ``unlabeled``
        (loaders; ``unlabeled`` is ``None`` when there are no unlabeled
        images), ``num_classes``, ``species_to_id``, ``id_to_species`` and
        ``dataset_info``.
    """

    def _make_loader(dataset, shuffle):
        # All loaders share the same worker and pinning settings
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=2,
            pin_memory=True
        )

    # Load dataset info
    with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as f:
        dataset_info = json.load(f)

    # Create species to ID mapping (IDs follow the key order in the JSON file)
    species_list = list(dataset_info['labeled_species'])
    species_to_id = {species: idx for idx, species in enumerate(species_list)}
    id_to_species = {idx: species for species, idx in species_to_id.items()}
    num_classes = len(species_list)

    print(f"📊 Found {num_classes} species: {species_list}")

    # Collect labeled data
    labeled_paths = []
    labeled_labels = []

    labeled_dir = os.path.join(data_dir, 'labeled')
    for species, species_id in species_to_id.items():
        species_dir = os.path.join(labeled_dir, species)
        if os.path.exists(species_dir):
            species_paths = sorted(glob.glob(os.path.join(species_dir, '*.*')))
            labeled_paths.extend(species_paths)
            labeled_labels.extend([species_id] * len(species_paths))

    print(f"📚 Labeled data: {len(labeled_paths)} images")

    # Collect unlabeled data
    unlabeled_dir = os.path.join(data_dir, 'unlabeled')
    unlabeled_paths = []
    if os.path.exists(unlabeled_dir):
        unlabeled_paths = sorted(glob.glob(os.path.join(unlabeled_dir, '*.*')))

    print(f"🔄 Unlabeled data: {len(unlabeled_paths)} images")

    # Split labeled data into train/val
    from sklearn.model_selection import train_test_split

    if len(labeled_paths) > 0:
        try:
            train_paths, val_paths, train_labels, val_labels = train_test_split(
                labeled_paths, labeled_labels,
                test_size=val_split,
                stratify=labeled_labels,
                random_state=42
            )
        except ValueError as err:
            # Stratification needs enough images of every species; fall back
            # to an unstratified split rather than aborting the whole setup.
            print(f"⚠️ Stratified split not possible ({err}); using a random split instead")
            train_paths, val_paths, train_labels, val_labels = train_test_split(
                labeled_paths, labeled_labels,
                test_size=val_split,
                random_state=42
            )
    else:
        train_paths, val_paths, train_labels, val_labels = [], [], [], []

    print(f"🏋️ Training labeled: {len(train_paths)} images")
    print(f"🎯 Validation: {len(val_paths)} images")

    # Create transforms
    train_transform = create_transforms(is_training=True)
    val_transform = create_transforms(is_training=False)

    # Create datasets (unlabeled images get the training augmentations)
    train_labeled_dataset = FishDataset(train_paths, train_labels, train_transform)
    val_dataset = FishDataset(val_paths, val_labels, val_transform)
    unlabeled_dataset = FishDataset(unlabeled_paths, None, train_transform)

    # Create data loaders
    train_labeled_loader = _make_loader(train_labeled_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    unlabeled_loader = (
        _make_loader(unlabeled_dataset, shuffle=True)
        if len(unlabeled_paths) > 0 else None
    )

    return {
        'train_labeled': train_labeled_loader,
        'validation': val_loader,
        'unlabeled': unlabeled_loader,
        'num_classes': num_classes,
        'species_to_id': species_to_id,
        'id_to_species': id_to_species,
        'dataset_info': dataset_info
    }

# Test data loading.
# Only the directory is checked here; prepare_data_loaders additionally needs
# dataset_info.json inside it and will raise if that file is missing.
if os.path.exists(ORGANIZED_DATA_DIR):
    print("🔍 Testing data loading...")
    # Small batch size for this smoke test; the result stays in data_info
    data_info = prepare_data_loaders(ORGANIZED_DATA_DIR, batch_size=8)
    print(f"✅ Data loaders created successfully!")
    print(f"   Number of classes: {data_info['num_classes']}")
    print(f"   Species: {list(data_info['species_to_id'].keys())}")
else:
    print("⚠️ No organized data found. Please run the data organization step first.")

## 🏋️ Section 6: Training Pipeline

Now let's implement the semi-supervised training loop with the EMA teacher-student framework.