<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🐟 ViT-FishID: Semi-Supervised Fish Classification in Google Colab

This notebook runs the ViT-FishID semi-supervised learning pipeline in Google Colab with images stored in Google Drive.

## Setup Requirements:
1. **GitHub Repository**: Your code should be in a GitHub repo (images excluded via .gitignore)
2. **Google Drive**: Upload your fish images to a folder in Google Drive (you will point `FISH_IMAGES_PATH` at that folder below)
3. **GPU Runtime**: Enable GPU in Runtime → Change runtime type → Hardware accelerator → GPU

## 🔧 Setup and Installation

In [None]:
# Report whether PyTorch can see a CUDA GPU and, if so, its name and total memory.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️ No GPU detected. Enable GPU: Runtime → Change runtime type → Hardware accelerator → GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {gpu_name}")
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Clone your GitHub repository
import os

# Replace with your actual GitHub repository URL
REPO_URL = "https://github.com/cat-thomson/ViT-FishID.git"
REPO_NAME = "ViT-FishID"

# Clone repository if not already cloned
if not os.path.exists(REPO_NAME):
    !git clone {REPO_URL}
    print(f"✅ Cloned {REPO_NAME}")
else:
    print(f"📁 {REPO_NAME} already exists")

# Change to repository directory
os.chdir(REPO_NAME)
print(f"📂 Current directory: {os.getcwd()}")

In [None]:
# Install required packages
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q timm transformers
!pip install -q wandb
!pip install -q pillow opencv-python
!pip install -q scikit-learn matplotlib seaborn
!pip install -q tqdm

print("✅ Dependencies installed")

## 📁 Google Drive Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set your Google Drive paths
# UPDATE THESE PATHS TO MATCH YOUR GOOGLE DRIVE STRUCTURE
DRIVE_ROOT = "/content/drive/MyDrive"
FISH_IMAGES_PATH = f"{DRIVE_ROOT}/Fish_Images"  # ⚠️ UPDATE THIS PATH
OUTPUT_PATH = f"{DRIVE_ROOT}/Fish_Training_Output"  # Where to save results

print(f"📁 Drive mounted at: {DRIVE_ROOT}")
print(f"🐟 Fish images path: {FISH_IMAGES_PATH}")
print(f"💾 Output path: {OUTPUT_PATH}")

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_PATH, exist_ok=True)

# List contents of fish images directory
if os.path.exists(FISH_IMAGES_PATH):
    print(f"\n📋 Contents of {FISH_IMAGES_PATH}:")
    !ls -la "{FISH_IMAGES_PATH}"
else:
    print(f"⚠️ Path {FISH_IMAGES_PATH} not found. Please update FISH_IMAGES_PATH variable above.")

## 📊 Data Organization and Preparation

This section helps you organize your fish images from Google Drive into the proper structure for training.

In [None]:
# Output location of organize_fish_data.py (labeled/ and unlabeled/ sub-folders).
ORGANIZED_DATA_PATH = OUTPUT_PATH + "/organized_fish_dataset"

def check_data_structure(path, image_extensions=('.jpg', '.jpeg', '.png')):
    """Check if data is properly organized for training.

    Expects ``path/labeled/<species>/<images>`` and ``path/unlabeled/<images>``,
    and prints the species found plus labeled / unlabeled / total image counts.

    Args:
        path: Root directory of the organized dataset.
        image_extensions: File suffixes (case-insensitive) that count as images.
            The SAME filter is applied to labeled and unlabeled folders, so stray
            files such as ``.DS_Store`` are not counted as labeled images.

    Returns:
        True if both ``labeled`` and ``unlabeled`` directories exist, else False.
    """
    labeled_path = os.path.join(path, "labeled")
    unlabeled_path = os.path.join(path, "unlabeled")

    # isdir (not exists): a stray *file* named "labeled" would otherwise crash listdir below.
    if not (os.path.isdir(labeled_path) and os.path.isdir(unlabeled_path)):
        print(f"❌ No organized dataset found at: {path}")
        return False

    print(f"✅ Organized dataset found at: {path}")

    extensions = tuple(ext.lower() for ext in image_extensions)

    def count_images(directory):
        # Number of entries in `directory` whose name ends with an image extension.
        return sum(1 for name in os.listdir(directory) if name.lower().endswith(extensions))

    # Sorted so the printed species list is stable across runs and file systems.
    species = sorted(
        d for d in os.listdir(labeled_path) if os.path.isdir(os.path.join(labeled_path, d))
    )
    print(f"📋 Labeled species ({len(species)}): {species}")

    labeled_count = sum(count_images(os.path.join(labeled_path, s)) for s in species)
    unlabeled_count = count_images(unlabeled_path)

    print(f"🏷️ Labeled images: {labeled_count}")
    print(f"🔄 Unlabeled images: {unlabeled_count}")
    print(f"📊 Total images: {labeled_count + unlabeled_count}")

    return True

data_organized = check_data_structure(path=ORGANIZED_DATA_PATH)  # True when labeled/ and unlabeled/ exist

In [None]:
# Organize data if not already done
if not data_organized:
    print("🔄 Organizing fish data...")
    
    # Run the organize script with your fish images
    # You may need to adjust these parameters based on your data structure
    # Update the species list below based on your fish species
    !python organize_fish_data.py \
        --input_dir "{FISH_IMAGES_PATH}" \
        --output_dir "{ORGANIZED_DATA_PATH}" \
        --labeled_species bass trout salmon tuna cod \
        --no-interactive
    
    # Check if organization was successful
    data_organized = check_data_structure(ORGANIZED_DATA_PATH)
    
    if data_organized:
        print("✅ Data organization complete!")
    else:
        print("❌ Data organization failed. Please check your input path and data structure.")
        print("\nTroubleshooting tips:")
        print("1. Make sure FISH_IMAGES_PATH points to your fish images in Google Drive")
        print("2. Update the --labeled_species list to match your fish species")
        print("3. Your images should be in common formats (jpg, jpeg, png)")
else:
    print("✅ Data already organized, skipping organization step.")

## 🎯 Training Configuration

In [None]:
# Setup Weights & Biases (optional but recommended)
import wandb

# Login to wandb (you'll need to enter your API key)
print("🔑 Setting up Weights & Biases...")
print("If you don't have a W&B account, create one at https://wandb.ai/")
print("Get your API key from https://wandb.ai/authorize")

try:
    wandb.login()
    USE_WANDB = True
    print("✅ W&B login successful!")
except KeyboardInterrupt:
    # Interrupting the API-key prompt is an explicit "skip W&B", not a notebook failure.
    print("⏭️ W&B login skipped. Training will continue without logging.")
    USE_WANDB = False
except Exception as e:
    # Report the real cause instead of the previous bare `except:`, which hid it
    # (and also swallowed SystemExit).
    print(f"⚠️ W&B login failed ({type(e).__name__}: {e}). Training will continue without logging.")
    USE_WANDB = False

In [None]:
# Training configuration optimized for Google Colab.
# Most keys mirror command-line flags of main_semi_supervised.py and are turned
# into flags by the training cell. NOTE(review): 'pretrained' and 'dropout_rate'
# are NOT passed by that cell — confirm the script's defaults match these values.
from datetime import datetime

TRAINING_CONFIG = dict(
    # --- Data ---
    data_dir=ORGANIZED_DATA_PATH,
    batch_size=16,        # Reduced for Colab memory limitations
    image_size=224,
    num_workers=2,        # Reduced for Colab
    unlabeled_ratio=2.0,

    # --- Model ---
    model_name='vit_base_patch16_224',
    pretrained=True,
    dropout_rate=0.1,

    # --- Optimisation (shortened schedule for Colab) ---
    epochs=50,
    learning_rate=1e-4,
    weight_decay=0.05,
    warmup_epochs=5,
    ramp_up_epochs=10,

    # --- Semi-supervised ---
    ema_momentum=0.999,
    consistency_loss='mse',
    consistency_weight=1.0,
    pseudo_label_threshold=0.95,
    temperature=4.0,

    # --- System ---
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=42,

    # --- Saving / logging ---
    save_dir=f'{OUTPUT_PATH}/checkpoints',
    save_frequency=5,     # Save every 5 epochs
    use_wandb=USE_WANDB,
    wandb_project='vit-fish-colab',
    wandb_run_name=f'colab-run-{datetime.now().strftime("%Y%m%d-%H%M%S")}',
)

print("🎯 Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in TRAINING_CONFIG.items()))

# Checkpoints live on Drive so they survive a Colab runtime reset.
checkpoint_root = TRAINING_CONFIG['save_dir']
os.makedirs(checkpoint_root, exist_ok=True)
print(f"\n📁 Checkpoints will be saved to: {checkpoint_root}")

## 🚀 Training

In [None]:
# Run semi-supervised training
print(f"🚀 Starting semi-supervised training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"📊 Training for {TRAINING_CONFIG['epochs']} epochs with batch size {TRAINING_CONFIG['batch_size']}")
print(f"🎯 Using device: {TRAINING_CONFIG['device']}")

# Build command arguments
args = [
    f"--data_dir \"{TRAINING_CONFIG['data_dir']}\"",
    f"--batch_size {TRAINING_CONFIG['batch_size']}",
    f"--image_size {TRAINING_CONFIG['image_size']}",
    f"--num_workers {TRAINING_CONFIG['num_workers']}",
    f"--unlabeled_ratio {TRAINING_CONFIG['unlabeled_ratio']}",
    f"--model_name {TRAINING_CONFIG['model_name']}",
    f"--epochs {TRAINING_CONFIG['epochs']}",
    f"--learning_rate {TRAINING_CONFIG['learning_rate']}",
    f"--weight_decay {TRAINING_CONFIG['weight_decay']}",
    f"--warmup_epochs {TRAINING_CONFIG['warmup_epochs']}",
    f"--ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']}",
    f"--ema_momentum {TRAINING_CONFIG['ema_momentum']}",
    f"--consistency_loss {TRAINING_CONFIG['consistency_loss']}",
    f"--consistency_weight {TRAINING_CONFIG['consistency_weight']}",
    f"--pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']}",
    f"--temperature {TRAINING_CONFIG['temperature']}",
    f"--device {TRAINING_CONFIG['device']}",
    f"--seed {TRAINING_CONFIG['seed']}",
    f"--save_dir \"{TRAINING_CONFIG['save_dir']}\"",
    f"--save_frequency {TRAINING_CONFIG['save_frequency']}"
]

if TRAINING_CONFIG['use_wandb']:
    args.extend([
        "--use_wandb",
        f"--wandb_project {TRAINING_CONFIG['wandb_project']}",
        f"--wandb_run_name {TRAINING_CONFIG['wandb_run_name']}"
    ])

cmd = f"python main_semi_supervised.py {' '.join(args)}"

print("\nExecuting command:")
print(cmd)
print("\n" + "="*80)

# Execute training
!{cmd}

## 📊 Training Monitoring and Results

In [None]:
# Check training results: list checkpoints, identify best/latest, list log files.
import glob
import os
from datetime import datetime

checkpoint_dir = TRAINING_CONFIG['save_dir']
checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pth")))


def _newest_matching(paths, *keywords):
    """Return the most recently modified path whose FILE NAME contains any keyword.

    Only the base name is inspected: matching on the full path made every
    checkpoint match when the Drive folder itself contained e.g. "best".
    Returns None when nothing matches (glob order is arbitrary, so mtime decides).
    """
    matches = [p for p in paths if any(k in os.path.basename(p) for k in keywords)]
    return max(matches, key=os.path.getmtime) if matches else None


print(f"🔍 Checking results in: {checkpoint_dir}")

if checkpoints:
    print(f"✅ Training completed! Found {len(checkpoints)} checkpoints:")
    for ckpt in checkpoints:
        size_mb = os.path.getsize(ckpt) / (1024 * 1024)
        mod_time = datetime.fromtimestamp(os.path.getmtime(ckpt)).strftime('%Y-%m-%d %H:%M:%S')
        print(f"  📁 {os.path.basename(ckpt)} ({size_mb:.1f} MB, modified: {mod_time})")

    # Find best model
    best_model = _newest_matching(checkpoints, 'best')
    if best_model:
        print(f"\n🏆 Best model: {os.path.basename(best_model)}")

    # Find latest model
    latest_model = _newest_matching(checkpoints, 'latest', 'final')
    if latest_model:
        print(f"📊 Latest model: {os.path.basename(latest_model)}")

else:
    print("❌ No checkpoints found. Training may have failed or is still running.")
    print("Check the output above for any error messages.")

# Check if W&B was used
if USE_WANDB:
    print(f"\n📈 View training metrics at: https://wandb.ai/{wandb.api.default_entity}/{TRAINING_CONFIG['wandb_project']}")
    print(f"🔗 Run name: {TRAINING_CONFIG['wandb_run_name']}")

# Check for log files
log_files = sorted(
    glob.glob(os.path.join(checkpoint_dir, "*.json")) + glob.glob(os.path.join(checkpoint_dir, "*.txt"))
)
if log_files:
    print(f"\n📋 Found {len(log_files)} log files:")
    for log in log_files:
        print(f"  📄 {os.path.basename(log)}")

In [None]:
# Plot training results (if available) from the newest JSON history in save_dir.
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np

checkpoint_dir = TRAINING_CONFIG['save_dir']

# One file can match several patterns (e.g. "training_log.json"); the set de-duplicates.
LOG_PATTERNS = ("*history*.json", "*log*.json", "training_*.json")
log_files = sorted({
    match
    for pattern in LOG_PATTERNS
    for match in glob.glob(os.path.join(checkpoint_dir, pattern))
})


def _series(history, key):
    """Return history[key] if it is a list, otherwise [] (missing or malformed metric)."""
    values = history.get(key)
    return values if isinstance(values, list) else []


def _epochs(values):
    """1-based epoch numbers for a per-epoch metric list."""
    return range(1, len(values) + 1)


if log_files:
    print("📊 Plotting training results...")
    # Reset so the error message below never reports a stale history from an earlier run.
    history = None

    try:
        # Load the most recent log file
        latest_log = max(log_files, key=os.path.getmtime)
        print(f"📄 Loading results from: {os.path.basename(latest_log)}")

        with open(latest_log, 'r') as f:
            history = json.load(f)
        if not isinstance(history, dict):
            raise ValueError(f"expected a JSON object of metric lists, got {type(history).__name__}")

        train_loss = _series(history, 'train_loss')
        val_loss = _series(history, 'val_loss')
        val_acc_key = 'val_acc_student' if _series(history, 'val_acc_student') else 'val_acc'
        # NOTE(review): accuracies are assumed to be logged as fractions in [0, 1]
        # (scaled by 100 for display) — confirm against main_semi_supervised.py.
        student_acc = [acc * 100 for acc in _series(history, val_acc_key)]
        teacher_acc = [acc * 100 for acc in _series(history, 'val_acc_teacher')]
        consistency = _series(history, 'consistency_loss')
        learning_rate = _series(history, 'learning_rate')

        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('ViT-FishID Training Results', fontsize=16, fontweight='bold')

        # Plot 1: Training and Validation Loss. Each curve gets its own x-range:
        # the lists may differ in length, and a shared range made plot() raise.
        ax1 = axes[0, 0]
        if train_loss:
            ax1.plot(_epochs(train_loss), train_loss, 'b-', label='Training Loss', linewidth=2)
        if val_loss:
            ax1.plot(_epochs(val_loss), val_loss, 'r-', label='Validation Loss', linewidth=2)
        if train_loss or val_loss:
            ax1.set_title('Training and Validation Loss', fontweight='bold')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

        # Plot 2: Validation Accuracy (student, plus teacher when logged)
        ax2 = axes[0, 1]
        if student_acc:
            ax2.plot(_epochs(student_acc), student_acc, 'g-', linewidth=2, label='Student')
            if teacher_acc:
                ax2.plot(_epochs(teacher_acc), teacher_acc, color='orange', linewidth=2, label='Teacher')
                ax2.legend()

            ax2.set_title('Validation Accuracy', fontweight='bold')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Accuracy (%)')
            ax2.grid(True, alpha=0.3)

            # Show best accuracy
            best_acc = max(student_acc)
            ax2.axhline(y=best_acc, color='red', linestyle='--', alpha=0.7)
            ax2.text(0.02, 0.98, f'Best: {best_acc:.1f}%', transform=ax2.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        # Plot 3: Consistency Loss
        ax3 = axes[1, 0]
        if consistency:
            ax3.plot(_epochs(consistency), consistency, color='purple', linewidth=2)
            ax3.set_title('Consistency Loss (Teacher-Student)', fontweight='bold')
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('Consistency Loss')
            ax3.grid(True, alpha=0.3)

        # Plot 4: Learning Rate Schedule
        ax4 = axes[1, 1]
        if learning_rate:
            ax4.plot(_epochs(learning_rate), learning_rate, color='brown', linewidth=2)
            ax4.set_title('Learning Rate Schedule', fontweight='bold')
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Learning Rate')
            ax4.set_yscale('log')
            ax4.grid(True, alpha=0.3)

        # Label panels whose metric was not logged instead of leaving blank axes.
        for ax in axes.flat:
            if not ax.lines:
                ax.text(0.5, 0.5, 'No data logged', ha='center', va='center', transform=ax.transAxes)
                ax.set_axis_off()

        plt.tight_layout()
        plt.show()

        # Print summary statistics
        print("\n📈 Training Summary:")
        if student_acc:
            print(f"  🎯 Best Validation Accuracy: {max(student_acc):.2f}%")
            print(f"  📊 Final Validation Accuracy: {student_acc[-1]:.2f}%")

        if train_loss:
            initial_loss, final_loss = train_loss[0], train_loss[-1]
            if initial_loss:
                reduction = (initial_loss - final_loss) / initial_loss * 100
                print(f"  📉 Loss Reduction: {initial_loss:.3f} → {final_loss:.3f} ({reduction:.1f}% improvement)")
            else:
                # A zero initial loss would make the relative reduction divide by zero.
                print(f"  📉 Loss: {initial_loss:.3f} → {final_loss:.3f}")

        print(f"  ⏱️ Total Epochs Completed: {len(train_loss)}")

    except Exception as e:
        print(f"❌ Could not plot results: {e}")
        print("Available keys in history:", list(history.keys()) if isinstance(history, dict) else "No history loaded")
else:
    print("📊 No training log files found for plotting.")
    print("Training logs are typically saved as JSON files in the checkpoint directory.")

## 💾 Download Results

In [None]:
# Create and download a zip file with all training results
import glob
import os
import zipfile
from datetime import datetime

from google.colab import files

# Create a timestamped zip file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_zip = f"{OUTPUT_PATH}/vit_fish_results_{timestamp}.zip"
checkpoint_dir = TRAINING_CONFIG['save_dir']

# Collect inputs once so the archive contents and the README counts always agree.
checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pth")))
json_logs = sorted(glob.glob(os.path.join(checkpoint_dir, "*.json")))
text_logs = sorted(glob.glob(os.path.join(checkpoint_dir, "*.txt")))
dataset_info = os.path.join(ORGANIZED_DATA_PATH, "dataset_info.json")

print("📦 Packaging training results...")

with zipfile.ZipFile(results_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Counts real result files only. The generated README.txt is NOT counted:
    # counting it made files_added always > 0, so the "nothing to download"
    # branch below was unreachable and failed runs downloaded a README-only zip.
    files_added = 0

    # Add model checkpoints
    for ckpt in checkpoint_files:
        zipf.write(ckpt, f"checkpoints/{os.path.basename(ckpt)}")
        files_added += 1
        print(f"  ✅ Added checkpoint: {os.path.basename(ckpt)}")

    # Add log files
    for log in json_logs:
        zipf.write(log, f"logs/{os.path.basename(log)}")
        files_added += 1
        print(f"  ✅ Added log: {os.path.basename(log)}")

    # Add text files (training logs, etc.)
    for txt in text_logs:
        zipf.write(txt, f"logs/{os.path.basename(txt)}")
        files_added += 1
        print(f"  ✅ Added text file: {os.path.basename(txt)}")

    # Add dataset info if available
    if os.path.exists(dataset_info):
        zipf.write(dataset_info, "dataset_info.json")
        files_added += 1
        print("  ✅ Added dataset info")

    # Add a summary file
    summary_content = f"""ViT-FishID Training Results Summary
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

Configuration:
- Model: {TRAINING_CONFIG['model_name']}
- Epochs: {TRAINING_CONFIG['epochs']}
- Batch Size: {TRAINING_CONFIG['batch_size']}
- Learning Rate: {TRAINING_CONFIG['learning_rate']}
- Device: {TRAINING_CONFIG['device']}
- Dataset: {TRAINING_CONFIG['data_dir']}

Files included:
- {len(checkpoint_files)} model checkpoints
- {len(json_logs)} JSON log files
- {len(text_logs)} text log files

Usage:
1. Extract the zip file
2. Load the best model checkpoint with PyTorch
3. Use for inference on new fish images

For questions or support, refer to the ViT-FishID repository.
"""

    zipf.writestr("README.txt", summary_content)

print("\n📦 Results packaged successfully!")
print(f"📁 File: {os.path.basename(results_zip)}")
print(f"📊 Size: {os.path.getsize(results_zip) / (1024*1024):.1f} MB")
print(f"📋 Files included: {files_added} (plus README.txt)")

# Download the zip file only if it contains actual results
if files_added > 0:
    print("\n⬇️ Downloading results to your computer...")
    files.download(results_zip)
    print("✅ Download complete!")
else:
    print("⚠️ No files to download. Training may not have completed successfully.")

## 🔮 Quick Inference Demo (Optional)

In [None]:
# Quick sanity check of the trained checkpoint (if checkpoints exist).
# NOTE: this cell does not run the network — rebuilding the ViT needs the model
# code from the repository. It loads the best checkpoint, reports its top-level
# contents, and checks that a few labeled images pass through preprocessing.
import glob
import os

import torch
import torchvision.transforms as transforms
from PIL import Image

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
MAX_SPECIES = 3           # species folders sampled
IMAGES_PER_SPECIES = 2    # images collected per species folder
MAX_SAMPLES = 3           # images actually preprocessed

checkpoint_dir = TRAINING_CONFIG['save_dir']
best_models = glob.glob(os.path.join(checkpoint_dir, "*best*.pth"))

if best_models and os.path.exists(ORGANIZED_DATA_PATH):
    print("🔮 Testing the trained model...")

    try:
        # Several runs may leave several "*best*" files and glob order is arbitrary:
        # use the most recently written one.
        model_path = max(best_models, key=os.path.getmtime)
        print(f"📁 Loading model: {os.path.basename(model_path)}")

        # torch.load unpickles the file — only load checkpoints you produced yourself.
        # On PyTorch >= 2.6 the default is weights_only=True, so checkpoints that also
        # pickle non-tensor objects may fail to load here.
        checkpoint = torch.load(model_path, map_location='cpu')
        if isinstance(checkpoint, dict):
            keys = [str(k) for k in checkpoint]
            preview = ", ".join(keys[:10]) + (", …" if len(keys) > 10 else "")
            print(f"🔑 Checkpoint has {len(keys)} top-level keys: {preview}")

        # Collect sample images deterministically (sorted listings) from the first species.
        test_images = []
        labeled_dir = os.path.join(ORGANIZED_DATA_PATH, "labeled")

        if os.path.isdir(labeled_dir):
            # Filter to directories BEFORE slicing so stray files don't use up species slots.
            species_dirs = sorted(
                d for d in os.listdir(labeled_dir) if os.path.isdir(os.path.join(labeled_dir, d))
            )
            for species_dir in species_dirs[:MAX_SPECIES]:
                species_path = os.path.join(labeled_dir, species_dir)
                images = sorted(
                    f for f in os.listdir(species_path) if f.lower().endswith(IMAGE_EXTENSIONS)
                )[:IMAGES_PER_SPECIES]
                test_images.extend((os.path.join(species_path, img), species_dir) for img in images)

        if test_images:
            samples = test_images[:MAX_SAMPLES]
            print(f"🐟 Testing on {len(samples)} sample images...")

            # Resize to the training resolution from TRAINING_CONFIG. ImageNet mean/std
            # are assumed to match training preprocessing — NOTE(review): confirm in repo.
            image_size = TRAINING_CONFIG['image_size']
            transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            for img_path, true_species in samples:
                try:
                    # The context manager closes the file handle that Image.open keeps open.
                    with Image.open(img_path) as img:
                        image = img.convert('RGB')
                    input_tensor = transform(image).unsqueeze(0)

                    print(f"  📸 {os.path.basename(img_path)} (True: {true_species})")
                    print(f"    ✅ Image loaded and preprocessed successfully (tensor shape {tuple(input_tensor.shape)})")

                except Exception as e:
                    print(f"    ❌ Error processing {img_path}: {e}")
        else:
            print("🔍 No test images found in labeled directory")

    except Exception as e:
        print(f"❌ Error testing model: {e}")
        print("The checkpoint may be incomplete, or (PyTorch >= 2.6) may need weights_only=False — use that only for files you trust.")

else:
    print("🔮 Skipping inference demo (no trained model or dataset found)")

print("\n💡 To use your trained model for inference:")
print("1. Download the results zip file")
print("2. Load the best checkpoint with your ViT model architecture")
print("3. Use the same preprocessing transforms as during training")
print("4. Run inference on new fish images")

## 🧹 Cleanup (Optional)

In [None]:
# Optional: free Google Drive space by deleting intermediate artifacts.
import shutil

cleanup_menu = (
    "🧹 Cleanup options to free Google Drive space:",
    "1. Keep everything (recommended if you want to resume training)",
    "2. Remove organized dataset (saves space, keep only results)",
    "3. Remove checkpoints (keep only the downloaded zip)",
    "4. Remove everything except results zip",
)
for menu_line in cleanup_menu:
    print(menu_line)
# Calculate current usage
def get_directory_size(path):
    """Return the total size in bytes of every file below ``path`` (recursive).

    A missing ``path`` yields 0. Files whose size cannot be read (``OSError``)
    contribute 0 instead of aborting the walk.
    """
    if not os.path.exists(path):
        return 0

    def _size_or_zero(file_path):
        try:
            return os.path.getsize(file_path)
        except OSError:  # IOError is an alias of OSError on Python 3
            return 0

    return sum(
        _size_or_zero(os.path.join(root, name))
        for root, _subdirs, names in os.walk(path)
        for name in names
    )

import os
import shutil

# Set to "1"-"4" to choose a cleanup option without the interactive prompt
# (e.g. for Restart & Run All); leave as None to be asked.
CLEANUP_CHOICE = None

# Current Drive usage (GB) of the artifacts this cell may delete.
dataset_size = get_directory_size(ORGANIZED_DATA_PATH) / (1024**3)
checkpoint_size = get_directory_size(TRAINING_CONFIG['save_dir']) / (1024**3)
total_size = dataset_size + checkpoint_size

print(f"\n📊 Current usage:")
print(f"  📁 Organized dataset: {dataset_size:.2f} GB")
print(f"  💾 Checkpoints: {checkpoint_size:.2f} GB")
print(f"  📊 Total: {total_size:.2f} GB")

if CLEANUP_CHOICE is None:
    cleanup_choice = input("\nEnter choice (1-4): ")
else:
    cleanup_choice = str(CLEANUP_CHOICE)
# Tolerate stray whitespace such as " 2" — previously it silently fell through to "no cleanup".
cleanup_choice = cleanup_choice.strip()

checkpoint_dir = TRAINING_CONFIG['save_dir']

if cleanup_choice == "2":
    if os.path.exists(ORGANIZED_DATA_PATH):
        shutil.rmtree(ORGANIZED_DATA_PATH)
        print(f"🗑️ Removed organized dataset: {ORGANIZED_DATA_PATH}")
        print(f"💾 Freed ~{dataset_size:.2f} GB of space")

elif cleanup_choice == "3":
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
        print(f"🗑️ Removed checkpoints: {checkpoint_dir}")
        print(f"💾 Freed ~{checkpoint_size:.2f} GB of space")

elif cleanup_choice == "4":
    if os.path.exists(ORGANIZED_DATA_PATH):
        shutil.rmtree(ORGANIZED_DATA_PATH)
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)
    print("🗑️ Cleaned up all temporary files")
    print(f"💾 Freed ~{total_size:.2f} GB of space")
    print("📦 Your results are still available in the downloaded zip file")

else:
    print("✅ No cleanup performed - all files retained")

print("\n🎉 Training complete!")
print("\n📋 Summary of what you have:")
print("  ✅ Trained ViT model for fish classification")
print("  ✅ Training logs and metrics")
print("  ✅ Downloaded results zip file")
if USE_WANDB:
    print(f"  ✅ W&B dashboard: https://wandb.ai/{wandb.api.default_entity}/{TRAINING_CONFIG['wandb_project']}")
print("\n🚀 Ready to classify fish with your trained model!")