# 🐟 ViT Fish Classification - Semi-Supervised Learning with EMA Teacher-Student

This notebook trains a Vision Transformer (ViT) using **semi-supervised learning** with an **EMA Teacher-Student framework** for fish species classification.

**Key Features:**
- ✅ **Fixed Consistency Loss** - Now properly uses unlabeled data
- ✅ **Optimized Pseudo-label Threshold** (0.7 instead of 0.95)
- ✅ **Temperature Scaling** for better probability calibration
- ✅ **Exponential Moving Average (EMA) Teacher**
- ✅ **Consistency Regularization** between student and teacher
- ✅ **Auto-GPU Detection** and optimization

**Expected Performance:**
- **Without Semi-Supervised**: ~60% accuracy (your current result)
- **With Semi-Supervised**: 65-75% accuracy (leveraging 13,908 unlabeled images)
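
The EMA teacher in the feature list can be pictured as a slowly moving average of the student's weights. The sketch below is a framework-free illustration on plain floats, not the repository's implementation: `ema_update` and the dictionary layout are hypothetical names chosen for the example, and the default momentum mirrors the `--ema_momentum 0.999` value passed in Step 6.

```python
def ema_update(teacher, student, momentum=0.999):
    """Move each teacher weight a small step toward the matching student weight.

    teacher, student: dicts mapping a parameter name to a float (hypothetical layout).
    """
    for name, student_value in student.items():
        teacher[name] = momentum * teacher[name] + (1.0 - momentum) * student_value
    return teacher


# Toy example: with momentum 0.9 the teacher moves 10% of the way to the student.
teacher = {"w": 1.0, "b": 0.0}
student = {"w": 2.0, "b": 1.0}
ema_update(teacher, student, momentum=0.9)
print(teacher)
```

With a momentum close to 1 the teacher changes slowly, which is why its predictions tend to be smoother targets than the student's own.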

## 🚀 Step 1: Environment Setup

Install required packages and check GPU availability.

In [None]:
# Install required packages
!pip install -q torch torchvision timm transformers
!pip install -q albumentations opencv-python
!pip install -q wandb tqdm scikit-learn
!pip install -q Pillow numpy pandas matplotlib seaborn

# Check GPU availability
import torch
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎯 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU available - training will be slow!")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🎮 Using device: {device}")

## 📁 Step 2: Mount Google Drive & Clone Repository

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Clone your GitHub repository
!git clone https://github.com/cat-thomson/ViT-FishID.git
%cd ViT-FishID

# List files to confirm
!ls -la

print("\n✅ Repository cloned successfully!")
print("📂 Available files:")
!ls *.py | head -10

## 🗂️ Step 3: Setup Fish Dataset

**Important**: Make sure your fish images are uploaded to Google Drive!

In [None]:
import os
import zipfile
from pathlib import Path

# Configuration - UPDATE THESE PATHS
DRIVE_DATA_PATH = "/content/drive/MyDrive/fish_dataset.zip"  # 👈 UPDATE THIS
# Alternative if your data is in a folder:
# DRIVE_DATA_PATH = "/content/drive/MyDrive/fish_images/"  # 👈 UPDATE THIS

# Extract dataset if it's a zip file
if DRIVE_DATA_PATH.endswith('.zip'):
    print(f"📦 Extracting {DRIVE_DATA_PATH}...")
    with zipfile.ZipFile(DRIVE_DATA_PATH, 'r') as zip_ref:
        zip_ref.extractall('/content/')
    
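    # Find the extracted folder, skipping the cloned ViT-FishID repo
    # (its name also contains "fish" and would otherwise match)
    extracted_folders = [f for f in os.listdir('/content/') if os.path.isdir(f'/content/{f}') and 'fish' in f.lower() and f != 'ViT-FishID']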
    if extracted_folders:
        data_dir = f'/content/{extracted_folders[0]}'
        print(f"✅ Dataset extracted to: {data_dir}")
    else:
        print("❌ Could not find fish dataset folder")
        data_dir = '/content/fish_images'  # fallback
else:
    # Copy folder from Drive
    data_dir = '/content/fish_images'
    !cp -r "{DRIVE_DATA_PATH}" "{data_dir}"
    print(f"✅ Dataset copied to: {data_dir}")

# Check dataset structure
print(f"\n📊 Dataset contents:")
if os.path.exists(data_dir):
    print(f"📁 Total files: {len(list(Path(data_dir).rglob('*.*')))}")
    print(f"📁 Directories: {len([d for d in Path(data_dir).iterdir() if d.is_dir()])}")
    
    # Show first few directories
    dirs = [d.name for d in Path(data_dir).iterdir() if d.is_dir()][:5]
    print(f"📂 Sample directories: {dirs}")
else:
    print(f"❌ Dataset not found at {data_dir}")
    print("Please update DRIVE_DATA_PATH in the cell above!")

## 🏗️ Step 4: Organize Dataset for Semi-Supervised Learning

In [None]:
# Organize the dataset for semi-supervised learning
organize_command = f"python organize_fish_data.py --input_dir {data_dir} --output_dir /content/organized_fish_dataset --train_ratio 0.7 --val_ratio 0.2 --test_ratio 0.1 --seed 42"

print(f"🔄 Organizing dataset...")
print(f"Command: {organize_command}")

!{organize_command}

# Verify organization
print("\n✅ Dataset organization complete!")
print("📊 Final structure:")
!ls -la /content/organized_fish_dataset/

# Count files in each split
labeled_dir = "/content/organized_fish_dataset/labeled"
unlabeled_dir = "/content/organized_fish_dataset/unlabeled"
val_dir = "/content/organized_fish_dataset/val"

IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}

def count_images(directory):
    """Count image files under directory (case-insensitive); return 0 if it is missing."""
    if not os.path.exists(directory):
        return 0
    return sum(1 for p in Path(directory).rglob('*') if p.suffix.lower() in IMAGE_SUFFIXES)

labeled_count = count_images(labeled_dir)
unlabeled_count = count_images(unlabeled_dir)
val_count = count_images(val_dir)

print(f"📈 Labeled images: {labeled_count:,}")
print(f"📊 Unlabeled images: {unlabeled_count:,}")
print(f"📋 Validation images: {val_count:,}")

# Guard against a missing or empty labeled split before computing the ratio
if labeled_count > 0:
    print(f"\n🎯 Semi-supervised ratio: {unlabeled_count/labeled_count:.1f}x unlabeled data")
else:
    print("\n⚠️ No labeled images found - check the organize step above")

## 📊 Step 5: Setup Weights & Biases (Optional)

Track your training progress with W&B!

In [None]:
import wandb

# Login to W&B (optional - comment out if you don't want to use it)
try:
    wandb.login()
    use_wandb = True
    print("✅ W&B login successful!")
except Exception:
    use_wandb = False
    print("⚠️ W&B login skipped - training will run without logging")

# Alternatively, you can skip W&B entirely
# use_wandb = False
# print("📊 Training without W&B logging")

## 🎯 Step 6: Start Semi-Supervised Training (IMPROVED)

**🔧 Fixed Issues:**
- ✅ **Consistency Loss**: Now properly computed on ALL unlabeled data
- ✅ **Pseudo-label Threshold**: Lowered to 0.7 (from 0.95) for better utilization
- ✅ **Temperature Scaling**: Added for better probability calibration
- ✅ **MSE Consistency Loss**: More stable than KL divergence
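
To make the threshold and temperature settings concrete, here is a minimal sketch of confidence-based pseudo-labeling in plain Python. It is an illustration under assumptions, not the repository's code: `softmax` and `pseudo_label` are hypothetical helpers, and whether the training script applies the threshold before or after temperature scaling is not stated in this notebook.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperatures flatten the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(logits, threshold=0.7, temperature=1.0):
    """Return (label, confidence, accepted) for one unlabeled sample."""
    probs = softmax(logits, temperature)
    confidence = max(probs)
    return probs.index(confidence), confidence, confidence >= threshold


teacher_logits = [4.0, 1.0, 0.0]  # made-up logits for a 3-class example
for thr in (0.95, 0.7):
    label, conf, accepted = pseudo_label(teacher_logits, threshold=thr)
    print(f"threshold={thr}: label={label} conf={conf:.3f} accepted={accepted}")
```

In this toy case the prediction clears 0.7 but not 0.95, which is the kind of sample a lower threshold brings into training. Note that dividing the same logits by a temperature above 1 lowers the top probability, so the two settings interact.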

**Expected Results:**
- **Consistency Loss**: Should be > 0.001 (not 0.0000)
- **High-conf Pseudo**: Should be 10-30% (not 0.0%)
- **Validation Accuracy**: 65-75% (improved from 60%)
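
When reading the consistency numbers, keep in mind that the command below passes `--ramp_up_epochs 15` and `--consistency_weight 1.0`, so the weight on the consistency term is expected to grow during the first epochs. The sketch assumes the sigmoid-shaped ramp-up commonly used with Mean Teacher style training; the schedule in `main_semi_supervised.py` may differ, and `consistency_weight_at` is a hypothetical name.

```python
import math


def consistency_weight_at(epoch, max_weight=1.0, ramp_up_epochs=15):
    """Sigmoid-shaped ramp-up from near 0 to max_weight (assumed schedule)."""
    if ramp_up_epochs <= 0:
        return max_weight
    progress = min(max(epoch / ramp_up_epochs, 0.0), 1.0)
    return max_weight * math.exp(-5.0 * (1.0 - progress) ** 2)


for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: weight = {consistency_weight_at(epoch):.4f}")
```

Under this assumption the weighted consistency term is close to zero in the first few epochs even when the raw loss is not, so a small value early on is not by itself a sign of a bug.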

In [None]:
# Improved training command with optimized parameters
training_command = f"""
python main_semi_supervised.py \
    --data_dir /content/organized_fish_dataset \
    --epochs 50 \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --weight_decay 0.05 \
    --warmup_epochs 5 \
    --ramp_up_epochs 15 \
    --ema_momentum 0.999 \
    --consistency_loss mse \
    --consistency_weight 1.0 \
    --pseudo_label_threshold 0.7 \
    --temperature 4.0 \
    --unlabeled_ratio 2.0 \
    --save_frequency 10 \
    --model_name vit_base_patch16_224 \
    --pretrained \
    {'--use_wandb' if use_wandb else ''} \
    {'--wandb_project vit-fish-semi-supervised-colab' if use_wandb else ''}
""".strip()

print("🚀 Starting IMPROVED Semi-Supervised Training...")
print("\n🔧 Key Improvements:")
print("  ✅ Pseudo-label threshold: 0.7 (was 0.95)")
print("  ✅ Fixed consistency loss computation")
print("  ✅ Temperature scaling: 4.0")
print("  ✅ MSE consistency loss (more stable)")
print("  ✅ Shorter warmup: 5 epochs (was 10)")

print(f"\n📋 Command: {training_command}")
print("\n⏰ Training will take 2-4 hours depending on your GPU...")
print("\n🎯 Watch for:")
print("  - Cons Loss > 0.001 (should NOT be 0.0000)")
print("  - High-conf Pseudo: 10-30% (should NOT be 0.0%)")
print("  - Validation Acc: 65-75% (improvement from 60%)")

# Execute training
!{training_command}

## 📈 Step 7: Monitor Training Progress

Check if the semi-supervised learning is working correctly!

In [None]:
# Check if training is producing the right outputs
import re
import matplotlib.pyplot as plt

print("🔍 Checking Semi-Supervised Training Health...")

# Check if checkpoints are being saved
checkpoint_dir = "/content/ViT-FishID/semi_supervised_checkpoints"
if os.path.exists(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    print(f"✅ Checkpoints saved: {len(checkpoints)}")
    if checkpoints:
        print(f"📂 Latest: {sorted(checkpoints)[-1]}")
else:
    print("⚠️ No checkpoints found yet")

# Tips for monitoring
print("\n🎯 What to Look For:")
print("""
✅ **Good Signs:**
- Cons Loss: 0.001-0.1 (NOT 0.0000)
- High-conf Pseudo: 10-50%
- Validation accuracy increasing
- Loss decreasing steadily

❌ **Bad Signs:**
- Cons Loss: 0.0000 (semi-supervised not working)
- High-conf Pseudo: 0.0% (threshold too high)
- Validation accuracy stuck
- Training loss not decreasing

🔧 **If Problems:**
1. Lower pseudo_label_threshold to 0.5
2. Increase temperature to 5.0
3. Check consistency_weight ramping
""")

# If using W&B, show link
if use_wandb:
    print(f"\n📊 Monitor training at: https://wandb.ai/")
    print("🔗 Look for project: vit-fish-semi-supervised-colab")
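
The good and bad signs above can also be checked automatically if the training output is captured as text. The sketch below assumes a log line format such as `Epoch 2 | Cons Loss: 0.0123 | High-conf Pseudo: 12.5%`; that format is hypothetical, so the regular expression will need adjusting to whatever `main_semi_supervised.py` actually prints.

```python
import re

# Hypothetical log format; adapt the pattern to the real training output.
LINE_RE = re.compile(
    r"Cons Loss:\s*(?P<cons>\d+\.\d+).*?High-conf Pseudo:\s*(?P<pseudo>\d+(?:\.\d+)?)%"
)


def check_health(log_text, min_cons=0.001, min_pseudo=10.0):
    """Return (cons_loss, pseudo_percent, healthy) for each matching log line."""
    reports = []
    for line in log_text.splitlines():
        match = LINE_RE.search(line)
        if not match:
            continue
        cons = float(match.group("cons"))
        pseudo = float(match.group("pseudo"))
        reports.append((cons, pseudo, cons > min_cons and pseudo >= min_pseudo))
    return reports


sample_log = """Epoch 1 | Cons Loss: 0.0000 | High-conf Pseudo: 0.0%
Epoch 2 | Cons Loss: 0.0123 | High-conf Pseudo: 12.5%"""

for cons, pseudo, healthy in check_health(sample_log):
    print(f"cons={cons:.4f} pseudo={pseudo:.1f}% healthy={healthy}")
```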

## 📊 Step 8: Analyze Results

Compare supervised vs semi-supervised performance!

In [None]:
# Load and analyze the best model
import torch
import json
from pathlib import Path

print("📊 Analyzing Training Results...")

# Check if best model exists
best_model_path = "/content/ViT-FishID/semi_supervised_checkpoints/model_best.pth"

if os.path.exists(best_model_path):
    print(f"✅ Best model found: {best_model_path}")
    
    # Load checkpoint info
    try:
        checkpoint = torch.load(best_model_path, map_location='cpu')
        
        print(f"\n🎯 Best Results:")
        print(f"  📈 Best Epoch: {checkpoint.get('epoch', 'N/A')}")
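        # The key may be absent; formatting the 'N/A' fallback with :.2f would raise
        best_acc = checkpoint.get('best_accuracy')
        best_acc_text = f"{float(best_acc):.2f}%" if best_acc is not None else "N/A"
        print(f"  🎯 Best Accuracy: {best_acc_text}")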
        print(f"  📉 Final Loss: {checkpoint.get('loss', 'N/A')}")
        
        # Model info
        if 'model_state_dict' in checkpoint:
            model_size = sum(p.numel() for p in checkpoint['model_state_dict'].values())
            print(f"  🧠 Model Parameters: {model_size:,}")
            
    except Exception as e:
        print(f"⚠️ Could not load checkpoint details: {e}")
        
else:
    print("❌ Best model not found - training may still be running")

# Performance comparison
print("\n📈 Expected Performance Comparison:")
print("""
**Supervised Learning Only** (your previous result):
- Validation Accuracy: ~60%
- Uses only: 3,273 labeled images
- Consistency Loss: 0.0000

**Semi-Supervised Learning** (this improved version):
- Validation Accuracy: 65-75% 📈
- Uses: 3,273 labeled + 13,908 unlabeled images
- Consistency Loss: >0.001
- High-conf Pseudo: 10-30%

🚀 **Expected Improvement: +5-15% accuracy**
""")

# Show training artifacts
checkpoint_dir = "/content/ViT-FishID/semi_supervised_checkpoints"
if os.path.exists(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    print(f"\n📁 Training Artifacts ({len(files)} files):")
    for f in sorted(files)[:10]:  # Show first 10
        print(f"  📄 {f}")
    if len(files) > 10:
        print(f"  ... and {len(files)-10} more files")

## 💾 Step 9: Download Results

Package and download your trained model!

In [None]:
# Create a results package
import zipfile
from datetime import datetime

print("📦 Packaging results for download...")

# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_zip = f"/content/vit_fish_semi_supervised_results_{timestamp}.zip"

with zipfile.ZipFile(results_zip, 'w') as zipf:
    # Add checkpoints
    checkpoint_dir = "/content/ViT-FishID/semi_supervised_checkpoints"
    if os.path.exists(checkpoint_dir):
        for file in os.listdir(checkpoint_dir):
            if file.endswith('.pth'):
                zipf.write(os.path.join(checkpoint_dir, file), f"checkpoints/{file}")
    
    # Add code files
    code_files = [
        "main_semi_supervised.py",
        "semi_supervised_trainer.py", 
        "vit_model.py",
        "semi_supervised_data.py"
    ]
    
    for file in code_files:
        if os.path.exists(f"/content/ViT-FishID/{file}"):
            zipf.write(f"/content/ViT-FishID/{file}", f"code/{file}")
    
    # Add this notebook
    notebook_files = [f for f in os.listdir("/content/ViT-FishID") if f.endswith('.ipynb')]
    for nb in notebook_files:
        zipf.write(f"/content/ViT-FishID/{nb}", f"notebooks/{nb}")

print(f"✅ Results packaged: {results_zip}")
print(f"📦 Size: {os.path.getsize(results_zip) / 1e6:.1f} MB")

# Download the results
from google.colab import files
print("⬇️ Starting download...")
files.download(results_zip)

print("\n🎉 Training Complete!")
print("📁 Your package contains:")
print("  - 🏆 Best trained model (model_best.pth)")
print("  - 💾 Training checkpoints")
print("  - 📝 All source code")
print("  - 📓 This notebook")

# Copy to Google Drive as backup
drive_backup = f"/content/drive/MyDrive/vit_fish_results_{timestamp}.zip"
!cp {results_zip} {drive_backup}
print(f"\n☁️ Backup saved to Google Drive: {drive_backup}")

## 🚀 Next Steps

### 🎯 **Key Improvements Made:**
1. **Fixed Consistency Loss** - Now properly computed on ALL unlabeled data
2. **Optimized Threshold** - Lowered from 0.95 to 0.7 for better utilization
3. **Added Temperature Scaling** - Better probability calibration
4. **Improved Loss Function** - MSE consistency loss for stability

### 📊 **Expected Results:**
- **Consistency Loss**: Should be > 0.001 (not 0.0000)
- **High-confidence Pseudo-labels**: 10-30% (not 0.0%)
- **Validation Accuracy**: 65-75% (improvement from 60%)

### 🔧 **If Results Aren't Good:**
1. **Lower threshold further**: Try 0.5 or 0.6
2. **Adjust temperature**: Try 5.0 or 6.0
3. **Different consistency loss**: Try 'kl' instead of 'mse'
4. **More unlabeled data**: Increase unlabeled_ratio to 3.0
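
For the choice between `'mse'` and `'kl'`, it can help to see how the two losses react to the same disagreement between student and teacher. This is a plain-Python sketch on probability lists, not the repository's loss code: the function names are hypothetical, and the KL direction (teacher as the target distribution) is an assumption.

```python
import math


def mse_consistency(student_probs, teacher_probs):
    """Mean squared difference between the two probability vectors."""
    diffs = [(s - t) ** 2 for s, t in zip(student_probs, teacher_probs)]
    return sum(diffs) / len(diffs)


def kl_consistency(student_probs, teacher_probs, eps=1e-8):
    """KL(teacher || student), with eps to avoid log(0). Direction is assumed."""
    return sum(
        t * math.log((t + eps) / (s + eps))
        for s, t in zip(student_probs, teacher_probs)
    )


teacher = [0.90, 0.05, 0.05]
for student in ([0.90, 0.05, 0.05], [0.60, 0.30, 0.10], [0.01, 0.98, 0.01]):
    print(
        f"student={student}: "
        f"mse={mse_consistency(student, teacher):.4f} "
        f"kl={kl_consistency(student, teacher):.4f}"
    )
```

MSE on probabilities is bounded, while KL grows without limit as the student assigns near-zero probability to a class the teacher favors. That difference is one common explanation for MSE being the gentler option.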

### 📈 **For Better Performance:**
1. **Longer training**: Increase epochs to 100
2. **Learning rate scheduling**: Add cosine annealing
3. **Data augmentation**: More aggressive augmentations
4. **Model size**: Try ViT-Large for better capacity
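
As an illustration of the learning rate scheduling suggestion, the sketch below computes a linear warmup followed by cosine annealing in plain Python. The defaults mirror the Step 6 command (`--learning_rate 1e-4`, `--warmup_epochs 5`, `--epochs 50`), but `lr_at_epoch` is a hypothetical helper and the linear warmup shape is an assumption; the notebook does not say which schedule the training script currently uses.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=50, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay toward min_lr."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 4, 5, 27, 49):
    print(f"epoch {epoch:2d}: lr = {lr_at_epoch(epoch):.2e}")
```

In a PyTorch training loop the decay part is usually delegated to `torch.optim.lr_scheduler.CosineAnnealingLR` rather than computed by hand.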

### 🏆 **Success Metrics:**
- ✅ Cons Loss > 0.001
- ✅ High-conf Pseudo > 10%
- ✅ Validation Accuracy > 65%
- ✅ Steady improvement over epochs

**🎉 Happy Training!**