# 🚀 ViT-FishID: Resume Training from Epoch 19

**COLAB PRO EXTENDED TRAINING**
- Resume from: Epoch 19 checkpoint
- Target epochs: 100 total epochs (81 remaining)
- Expected training time: 6-8 hours with Colab Pro
- GPU: Tesla T4/V100/A100 (depending on availability)

This notebook will:
1. ✅ Resume training from your saved checkpoint at epoch 19
2. ✅ Train for 100 total epochs (81 more epochs)
3. ✅ Save checkpoints to Google Drive every 10 epochs
4. ✅ Use semi-supervised learning with your fish dataset

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🐟 ViT-FishID: Extended Training Session

**RESUME FROM EPOCH 19 - COLAB PRO**

This notebook resumes training from your saved checkpoint and runs for 100 total epochs.

**Current Status:**
- ✅ Previous training: 19 epochs completed
- 🎯 Target: 100 total epochs (81 remaining)
- ⏱️ Expected time: 6-8 hours with Colab Pro
- 💾 Auto-save every 10 epochs to Google Drive

**Performance Target:**
- Previous: ~78% validation accuracy at epoch 19
- Expected: 85-90% accuracy after 100 epochs
- Memory: ~8-12GB GPU memory

## 🚀 Step 1: Setup and GPU Check

In [2]:
# Check GPU availability
import torch
import os

print("🔍 System Information:")
print(f"Python version: {os.sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print("✅ GPU is ready for training!")
else:
    print("❌ No GPU detected. Please enable GPU runtime:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")

🔍 System Information:
Python version: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
CUDA available: True
GPU Device: NVIDIA A100-SXM4-40GB
GPU Memory: 39.6 GB
✅ GPU is ready for training!


## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [3]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    items = os.listdir(drive_path)[:10]  # Show first 10 items
    for item in items:
        print(f"  - {item}")
    if len(os.listdir(drive_path)) > 10:
        print(f"  ... and {len(os.listdir(drive_path)) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")

Mounted at /content/drive

📂 Google Drive contents:
  - Mock Matric
  - Photos
  - Admin
  - Uni
  - Fish_Training_Output
  - Colab Notebooks
  - ViT-FishID
  - fish_cutouts.zip
  - ViT-FishID_Training_20250814_154652

✅ Google Drive mounted successfully!


## 📦 Step 3: Install Dependencies

Installing all required packages for ViT-FishID training.

In [4]:
# Install required packages
print("📦 Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("✅ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\n📋 Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")

📦 Installing dependencies...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m128.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m99.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m60.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m

## 🔄 Step 4: Clone ViT-FishID Repository

Getting the latest code from your GitHub repository.

In [5]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\n📂 Project structure:")
!ls -la

print("\n✅ Repository cloned successfully!")

📥 Cloning ViT-FishID repository...
Cloning into '/content/ViT-FishID'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (83/83), done.[K
remote: Total 116 (delta 42), reused 98 (delta 27), pack-reused 0 (from 0)[K
Receiving objects: 100% (116/116), 187.84 KiB | 17.08 MiB/s, done.
Resolving deltas: 100% (42/42), done.
/content/ViT-FishID

📂 Project structure:
total 268
drwxr-xr-x 4 root root  4096 Aug 14 20:21 .
drwxr-xr-x 1 root root  4096 Aug 14 20:21 ..
-rw-r--r-- 1 root root 21217 Aug 14 20:21 data.py
-rw-r--r-- 1 root root 11572 Aug 14 20:21 evaluate.py
-rw-r--r-- 1 root root  3328 Aug 14 20:21 EXTENDED_TRAINING_SETUP.md
drwxr-xr-x 2 root root  4096 Aug 14 20:21 fish_cutouts
drwxr-xr-x 8 root root  4096 Aug 14 20:21 .git
-rw-r--r-- 1 root root    66 Aug 14 20:21 .gitattributes
-rw-r--r-- 1 root root   646 Aug 14 20:21 .gitignore
-rw-r--r-- 1 root root  9495 Aug 14 20:21 model.py
-rw-r--r-- 1 root roo

## 🗂️ Step 5: Setup Data Path and Extraction

**IMPORTANT:** Specify the path to your fish dataset ZIP file in Google Drive.

This step will:
1. Locate your `fish_cutouts.zip` file in Google Drive
2. Extract it to Colab's local storage for faster access
3. Validate the data structure

Expected structure after extraction:
```
fish_cutouts/
├── labeled/
│   ├── species_1/
│   │   ├── fish_001.jpg
│   │   └── fish_002.jpg
│   └── species_2/
│       └── ...
└── unlabeled/
    ├── fish_003.jpg
    └── fish_004.jpg
```

In [6]:
# Setup data path and extraction - CORRECTED PATHS
import zipfile
import shutil
import time
import os

print("🗂️ SETTING UP FISH DATASET - CORRECTED PATHS")
print("="*50)

# Configuration - CORRECTED file paths
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'  # Correct location
DATA_DIR = '/content/fish_cutouts'

print(f"🎯 ZIP file location: {ZIP_FILE_PATH}")
print(f"🎯 Target data directory: {DATA_DIR}")

# Check if data already exists locally (from previous session)
if os.path.exists(DATA_DIR) and os.path.exists(os.path.join(DATA_DIR, 'labeled')):
    print("✅ Data already available locally from previous session!")

    # Quick validation
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')

    if os.path.exists(labeled_dir):
        labeled_species = [d for d in os.listdir(labeled_dir)
                          if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
        print(f"🐟 Found {len(labeled_species)} labeled species")

    if os.path.exists(unlabeled_dir):
        unlabeled_files = [f for f in os.listdir(unlabeled_dir)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        print(f"📊 Found {len(unlabeled_files)} unlabeled images")

    print("✅ Data validation passed - ready for training!")

else:
    print("📥 Data not found locally, extracting from Google Drive...")

    # Check if ZIP file exists
    if os.path.exists(ZIP_FILE_PATH):
        print(f"✅ Found ZIP file at: {ZIP_FILE_PATH}")
        print(f"📏 ZIP file size: {os.path.getsize(ZIP_FILE_PATH) / (1024**2):.1f} MB")

        # Clean extraction
        temp_extract_dir = '/content/temp_fish_extract'
        if os.path.exists(temp_extract_dir):
            shutil.rmtree(temp_extract_dir)

        try:
            # Extract ZIP file directly
            print(f"📦 Extracting {os.path.basename(ZIP_FILE_PATH)}...")
            with zipfile.ZipFile(ZIP_FILE_PATH, 'r') as zip_ref:
                zip_ref.extractall(temp_extract_dir)

            print("✅ ZIP extraction completed")

            # Check what was extracted
            extracted_items = os.listdir(temp_extract_dir)
            print(f"📁 Found in ZIP: {extracted_items}")

            # Based on your description: dataset_info.json, labeled, unlabeled, MACOS
            # Look for labeled and unlabeled directories directly
            labeled_source = None
            unlabeled_source = None

            for item in extracted_items:
                item_path = os.path.join(temp_extract_dir, item)
                if item == 'labeled' and os.path.isdir(item_path):
                    labeled_source = item_path
                    print(f"✅ Found labeled directory: {item}")
                elif item == 'unlabeled' and os.path.isdir(item_path):
                    unlabeled_source = item_path
                    print(f"✅ Found unlabeled directory: {item}")
                elif item == 'dataset_info.json':
                    print(f"📄 Found dataset info: {item}")
                elif item == 'MACOS' or item == '__MACOS__':
                    print(f"🗑️ Skipping Mac system folder: {item}")

            # Create target directory and move the labeled/unlabeled folders
            if labeled_source and unlabeled_source:
                # Remove existing target if it exists
                if os.path.exists(DATA_DIR):
                    shutil.rmtree(DATA_DIR)

                # Create target directory
                os.makedirs(DATA_DIR, exist_ok=True)

                # Move labeled and unlabeled directories
                shutil.move(labeled_source, os.path.join(DATA_DIR, 'labeled'))
                shutil.move(unlabeled_source, os.path.join(DATA_DIR, 'unlabeled'))

                print(f"✅ Data organized at: {DATA_DIR}")

                # Copy dataset_info.json if it exists
                dataset_info = os.path.join(temp_extract_dir, 'dataset_info.json')
                if os.path.exists(dataset_info):
                    shutil.copy2(dataset_info, os.path.join(DATA_DIR, 'dataset_info.json'))
                    print(f"📄 Copied dataset_info.json")

                # Verify the structure
                labeled_dir = os.path.join(DATA_DIR, 'labeled')
                if os.path.exists(labeled_dir):
                    labeled_species = [d for d in os.listdir(labeled_dir)
                                     if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
                    print(f"🐟 Verified: {len(labeled_species)} species in labeled data")

                unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')
                if os.path.exists(unlabeled_dir):
                    unlabeled_count = len([f for f in os.listdir(unlabeled_dir)
                                         if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                    print(f"📊 Verified: {unlabeled_count} images in unlabeled data")

            else:
                print("❌ Could not find both labeled and unlabeled directories")
                print("📁 Available items:", extracted_items)

            # Cleanup temporary extraction
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)

        except Exception as e:
            print(f"❌ Error during extraction: {e}")
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)

    else:
        print(f"❌ ZIP file not found at: {ZIP_FILE_PATH}")
        print("📝 Please ensure fish_cutouts.zip is uploaded to Google Drive root directory")

# Final verification
if os.path.exists(DATA_DIR):
    print(f"\n✅ DATASET READY")
    print(f"📁 Location: {DATA_DIR}")

    # Show structure
    for subdir in ['labeled', 'unlabeled']:
        subdir_path = os.path.join(DATA_DIR, subdir)
        if os.path.exists(subdir_path):
            if subdir == 'labeled':
                species_count = len([d for d in os.listdir(subdir_path)
                                   if os.path.isdir(os.path.join(subdir_path, d)) and not d.startswith('.')])
                print(f"  📂 {subdir}/: {species_count} species folders")
            else:
                file_count = len([f for f in os.listdir(subdir_path)
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                print(f"  📂 {subdir}/: {file_count} images")
        else:
            print(f"  ❌ {subdir}/ not found")

    # Check for dataset_info.json
    dataset_info_path = os.path.join(DATA_DIR, 'dataset_info.json')
    if os.path.exists(dataset_info_path):
        print(f"  📄 dataset_info.json: Available")

    print("🚀 Ready to proceed with training!")
else:
    print(f"\n❌ DATASET SETUP FAILED")
    print(f"📝 Please check that fish_cutouts.zip contains:")
    print(f"   fish_cutouts.zip")
    print(f"   ├── dataset_info.json")
    print(f"   ├── labeled/")
    print(f"   │   ├── species1/")
    print(f"   │   └── species2/")
    print(f"   ├── unlabeled/")
    print(f"   │   ├── image1.jpg")
    print(f"   │   └── image2.jpg")
    print(f"   └── __MACOS__ (ignored)")

🗂️ SETTING UP FISH DATASET - CORRECTED PATHS
🎯 ZIP file location: /content/drive/MyDrive/fish_cutouts.zip
🎯 Target data directory: /content/fish_cutouts
📥 Data not found locally, extracting from Google Drive...
✅ Found ZIP file at: /content/drive/MyDrive/fish_cutouts.zip
📏 ZIP file size: 216.5 MB
📦 Extracting fish_cutouts.zip...
✅ ZIP extraction completed
📁 Found in ZIP: ['dataset_info.json', '__MACOSX', 'labeled', 'unlabeled']
📄 Found dataset info: dataset_info.json
✅ Found labeled directory: labeled
✅ Found unlabeled directory: unlabeled
✅ Data organized at: /content/fish_cutouts
📄 Copied dataset_info.json
🐟 Verified: 37 species in labeled data
📊 Verified: 24015 images in unlabeled data

✅ DATASET READY
📁 Location: /content/fish_cutouts
  📂 labeled/: 37 species folders
  📂 unlabeled/: 24015 images
  📄 dataset_info.json: Available
🚀 Ready to proceed with training!


## 📊 Step 6: Setup Weights & Biases (Optional)

W&B provides excellent training visualization and experiment tracking.

## 🔄 Step 6: Locate Checkpoint from Epoch 19

Finding your saved checkpoint to resume training from where you left off.

In [7]:
# Locate checkpoint from epoch 19
import os
import glob
import torch

print("🔍 Looking for checkpoint from epoch 19...")

# Possible checkpoint locations
checkpoint_locations = [
    '/content/ViT-FishID','/content/drive/MyDrive/ViT-FishID/checkpoints', '/content/drive/MyDrive/ViT-FishID/'
]

checkpoint_path = None
checkpoint_info = None

# Search for epoch 19 checkpoint
for location_pattern in checkpoint_locations:
    for location in glob.glob(location_pattern):
        if os.path.exists(location):
            print(f"📁 Checking: {location}")

            # Look for epoch 19 specifically
            epoch_19_files = glob.glob(os.path.join(location, '*epoch_19*'))
            manual_files = glob.glob(os.path.join(location, '*manual*epoch*19*'))
            emergency_files = glob.glob(os.path.join(location, '*emergency*epoch*19*'))

            all_candidates = epoch_19_files + manual_files + emergency_files

            for candidate in all_candidates:
                if candidate.endswith('.pth'):
                    print(f"🎯 Found candidate: {os.path.basename(candidate)}")
                    try:
                        # Verify checkpoint can be loaded
                        test_checkpoint = torch.load(candidate, map_location='cpu')
                        epoch = test_checkpoint.get('epoch', 'unknown')

                        if epoch == 19 or '19' in os.path.basename(candidate):
                            checkpoint_path = candidate
                            checkpoint_info = test_checkpoint
                            print(f"✅ FOUND EPOCH 19 CHECKPOINT!")
                            print(f"📁 Location: {checkpoint_path}")
                            print(f"📊 Epoch: {epoch}")

                            if 'best_accuracy' in test_checkpoint:
                                print(f"📊 Best accuracy so far: {test_checkpoint['best_accuracy']:.2f}%")
                            elif 'best_acc' in test_checkpoint:
                                print(f"📊 Best accuracy so far: {test_checkpoint['best_acc']:.2f}%")

                            break
                    except Exception as e:
                        print(f"⚠️ Could not load {candidate}: {e}")

            if checkpoint_path:
                break

        if checkpoint_path:
            break

if checkpoint_path:
    print(f"\n🎉 Checkpoint ready for resuming training!")
    print(f"📄 File: {os.path.basename(checkpoint_path)}")
    print(f"📏 Size: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB")

    # Set up checkpoint directory for new saves
    checkpoint_save_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints_extended'
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    print(f"💾 New checkpoints will be saved to: {checkpoint_save_dir}")

else:
    print("❌ No checkpoint found for epoch 19!")
    print("\n🔧 Troubleshooting:")
    print("1. Check that you have a checkpoint saved from previous training")
    print("2. Ensure the checkpoint is uploaded to Google Drive")
    print("3. Look for files named like: checkpoint_epoch_19.pth, emergency_checkpoint_epoch_19.pth")
    print("\n📁 Checked locations:")
    for location in checkpoint_locations:
        print(f"  - {location}")

    # Fallback: look for any checkpoints
    print("\n🔍 All available checkpoints:")
    for location_pattern in checkpoint_locations:
        for location in glob.glob(location_pattern):
            if os.path.exists(location):
                all_checkpoints = glob.glob(os.path.join(location, '*.pth'))
                for cp in all_checkpoints:
                    print(f"  - {os.path.basename(cp)}")

# Store checkpoint path for later use
RESUME_CHECKPOINT = checkpoint_path

🔍 Looking for checkpoint from epoch 19...
📁 Checking: /content/ViT-FishID
📁 Checking: /content/drive/MyDrive/ViT-FishID/checkpoints
🎯 Found candidate: checkpoint_epoch_19.pth
✅ FOUND EPOCH 19 CHECKPOINT!
📁 Location: /content/drive/MyDrive/ViT-FishID/checkpoints/checkpoint_epoch_19.pth
📊 Epoch: 19
📊 Best accuracy so far: 79.98%
📁 Checking: /content/drive/MyDrive/ViT-FishID/
🎯 Found candidate: checkpoint_epoch_19.pth
✅ FOUND EPOCH 19 CHECKPOINT!
📁 Location: /content/drive/MyDrive/ViT-FishID/checkpoint_epoch_19.pth
📊 Epoch: 19
📊 Best accuracy so far: 79.98%

🎉 Checkpoint ready for resuming training!
📄 File: checkpoint_epoch_19.pth
📏 Size: 1309.9 MB
💾 New checkpoints will be saved to: /content/drive/MyDrive/ViT-FishID/checkpoints_extended


## ⚙️ Step 7: Configure Training Parameters

Adjust these parameters based on your needs and available GPU memory.

In [13]:
# Training Configuration - RESUME FROM EPOCH 5 FOR 100 TOTAL EPOCHS
import os

print("🎯 EXTENDED TRAINING CONFIGURATION - WITH W&B")
print("="*50)

# Define directories first to make config cleaner
DRIVE_CHECKPOINT_BASE = '/content/drive/MyDrive/ViT-FishID'
CHECKPOINT_SAVE_DIR = os.path.join(DRIVE_CHECKPOINT_BASE, 'checkpoints_extended')
BACKUP_DIR = os.path.join(DRIVE_CHECKPOINT_BASE, 'checkpoints_backup')

TRAINING_CONFIG = {
    # RESUME SETTINGS
    # Pointing to the latest valid checkpoint found (Epoch 5)
    'resume_from_checkpoint': os.path.join(CHECKPOINT_SAVE_DIR, 'checkpoint_epoch_5.pth'),
    'start_epoch': 6,  # Next epoch after 5
    'total_epochs': 100,  # Target total epochs
    'remaining_epochs': 100 - 5, # Calculate based on total and start

    # CORE SETTINGS
    'mode': 'semi_supervised',  # semi_supervised or supervised
    'data_dir': DATA_DIR, # This variable comes from Step 5
    'batch_size': 16,  # Increased for Colab Pro
    'learning_rate': 1e-4,
    'weight_decay': 0.05,

    # MODEL SETTINGS
    'model_name': 'vit_base_patch16_224',
    'num_classes': 37,  # Will be auto-detected below

    # SEMI-SUPERVISED SETTINGS
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'warmup_epochs': 5,  # Reduced since we're resuming
    'ramp_up_epochs': 15,  # Reduced since we're resuming

    # CHECKPOINT SETTINGS - SAVE EVERY EPOCH
    'save_frequency': 1,  # Save EVERY epoch
    'checkpoint_dir': CHECKPOINT_SAVE_DIR,
    'backup_dir': BACKUP_DIR,

    # LOGGING - W&B ENABLED
    'use_wandb': True, # Enable W&B logging
    'wandb_project': 'ViT-FishID-Extended-Training', # Your W&B project name
    'wandb_run_name': 'resume-epoch-6-to-100', # A name for this specific run

    # Add pretrained flag here as a config item
    'pretrained': True,
}

# Verify data directory and auto-detect num_classes
if os.path.exists(TRAINING_CONFIG['data_dir']):
    labeled_dir = os.path.join(TRAINING_CONFIG['data_dir'], 'labeled')
    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir)
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        TRAINING_CONFIG['num_classes'] = species_count
        print(f"📊 Detected {species_count} fish species")
    else:
         print(f"⚠️ Labeled data directory not found: {labeled_dir}. Cannot auto-detect num_classes.")
         print(f"💡 Using default num_classes: {TRAINING_CONFIG['num_classes']}")
else:
    print(f"❌ Data directory not found: {TRAINING_CONFIG['data_dir']}. Cannot auto-detect num_classes.")
    print(f"💡 Using default num_classes: {TRAINING_CONFIG['num_classes']}")


print("\nEXTENDED TRAINING CONFIGURATION SUMMARY")
print("="*50)
print(f"📊 Resume from: Epoch {TRAINING_CONFIG['start_epoch'] - 1}")
print(f"📊 Target epochs: {TRAINING_CONFIG['total_epochs']}")
print(f"📊 Remaining epochs: {TRAINING_CONFIG['remaining_epochs']}")
# Estimate time based on remaining epochs and a rough per-epoch time (e.g., 5-7 mins)
estimated_min_time = TRAINING_CONFIG['remaining_epochs'] * 5
estimated_max_time = TRAINING_CONFIG['remaining_epochs'] * 7
print(f"⏱️ Estimated time: {estimated_min_time:.0f}-{estimated_max_time:.0f} minutes")
print(f"📊 Batch size: {TRAINING_CONFIG['batch_size']} (optimized for Colab Pro)")
print(f"💾 Checkpoint saves: EVERY {TRAINING_CONFIG['save_frequency']} epoch(s)")
print(f"📊 Mode: {TRAINING_CONFIG['mode']} with consistency weight {TRAINING_CONFIG['consistency_weight']}")
print(f"📊 Logging: W&B Enabled (Project: {TRAINING_CONFIG['wandb_project']}, Run: {TRAINING_CONFIG['wandb_run_name']})")
print(f"📊 Num Classes: {TRAINING_CONFIG['num_classes']}")


# Create checkpoint directories with more robust error handling
print("\nSETTING UP CHECKPOINT DIRECTORIES")
print("="*50)
try:
    os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)
    print(f"📁 Primary saves: {TRAINING_CONFIG['checkpoint_dir']} (Created/Exists)")
except Exception as e:
    print(f"⚠️ Could not create primary checkpoint dir {TRAINING_CONFIG['checkpoint_dir']}: {e}")
    # Fallback to local directory if Google Drive mount is the issue
    local_fallback_dir = '/content/checkpoints_extended_local'
    TRAINING_CONFIG['checkpoint_dir'] = local_fallback_dir
    try:
        os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)
        print(f"📁 Primary saves (FALLBACK to local): {TRAINING_CONFIG['checkpoint_dir']} (Created/Exists)")
        print("💡 Check Google Drive mount if this is unexpected.")
    except Exception as e_local:
         print(f"❌ Could not create fallback local checkpoint dir {local_fallback_dir}: {e_local}")
         print("🚨 Check permissions or disk space.")


try:
    os.makedirs(TRAINING_CONFIG['backup_dir'], exist_ok=True)
    print(f"💾 Backup saves: {TRAINING_CONFIG['backup_dir']} (Created/Exists)")
except Exception as e:
    print(f"⚠️ Could not create backup dir {TRAINING_CONFIG['backup_dir']}: {e}")
    print("💾 Backup saves: Disabled due to Google Drive issues or permissions.")
    TRAINING_CONFIG['backup_dir'] = None # Explicitly set to None if creation fails


if TRAINING_CONFIG['resume_from_checkpoint'] and os.path.exists(TRAINING_CONFIG['resume_from_checkpoint']):
    print(f"\n✅ Will resume training from: {os.path.basename(TRAINING_CONFIG['resume_from_checkpoint'])}")
else:
    print("\n❌ Specified resume checkpoint not found or not set - will start fresh training from epoch 1")
    TRAINING_CONFIG['resume_from_checkpoint'] = None # Ensure it's None if file not found
    TRAINING_CONFIG['start_epoch'] = 1
    TRAINING_CONFIG['remaining_epochs'] = TRAINING_CONFIG['total_epochs']

print(f"\n🚀 Configuration complete. Ready to resume/start training!")

🎯 EXTENDED TRAINING CONFIGURATION - WITH W&B
📊 Detected 37 fish species

EXTENDED TRAINING CONFIGURATION SUMMARY
📊 Resume from: Epoch 5
📊 Target epochs: 100
📊 Remaining epochs: 95
⏱️ Estimated time: 475-665 minutes
📊 Batch size: 16 (optimized for Colab Pro)
💾 Checkpoint saves: EVERY 1 epoch(s)
📊 Mode: semi_supervised with consistency weight 2.0
📊 Logging: W&B Enabled (Project: ViT-FishID-Extended-Training, Run: resume-epoch-6-to-100)
📊 Num Classes: 37

SETTING UP CHECKPOINT DIRECTORIES
📁 Primary saves: /content/drive/MyDrive/ViT-FishID/checkpoints_extended (Created/Exists)
💾 Backup saves: /content/drive/MyDrive/ViT-FishID/checkpoints_backup (Created/Exists)

✅ Will resume training from: checkpoint_epoch_5.pth

🚀 Configuration complete. Ready to resume/start training!


## 🚀 Step 8: Start Training!

This cell will start the semi-supervised training process. It may take 2-3 hours to complete.

In [18]:
# Execute Extended Training - Resume from Epoch 17
import os

print("🚀 STARTING EXTENDED TRAINING SESSION")
print("="*60)

# Create checkpoint save directory
os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)

# Build training command for resuming
training_cmd = f"""python train.py \\
    --mode {TRAINING_CONFIG['mode']} \\
    --data_dir {TRAINING_CONFIG['data_dir']} \\
    --epochs {TRAINING_CONFIG['total_epochs']} \\
    --batch_size {TRAINING_CONFIG['batch_size']} \\
    --learning_rate {TRAINING_CONFIG['learning_rate']} \\
    --weight_decay {TRAINING_CONFIG['weight_decay']} \\
    --model_name {TRAINING_CONFIG['model_name']} \\
    --consistency_weight {TRAINING_CONFIG['consistency_weight']} \\
    --pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']} \\
    --temperature {TRAINING_CONFIG['temperature']} \\
    --warmup_epochs {TRAINING_CONFIG['warmup_epochs']} \\
    --ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']} \\
    --save_dir {TRAINING_CONFIG['checkpoint_dir']} \\
    --save_frequency {TRAINING_CONFIG['save_frequency']}"""

# Add resume checkpoint if available
# Pointing to the epoch 17 checkpoint
TRAINING_CONFIG['resume_from_checkpoint'] = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'checkpoint_epoch_17.pth')
TRAINING_CONFIG['start_epoch'] = 18 # Start from the epoch *after* the resumed checkpoint

if TRAINING_CONFIG['resume_from_checkpoint']:
    training_cmd += f" \\\n    --resume_from {TRAINING_CONFIG['resume_from_checkpoint']}"
    print(f"📂 Resuming from: {os.path.basename(TRAINING_CONFIG['resume_from_checkpoint'])}")
    print(f"🚀 Starting training from epoch: {TRAINING_CONFIG['start_epoch']}")

# Add W&B logging - Only add the --use_wandb flag
if TRAINING_CONFIG['use_wandb']:
    training_cmd += f" \\\n    --use_wandb"
    # Removed --wandb_project and --wandb_run_name as train.py doesn't recognize them

# Add pretrained flag
if TRAINING_CONFIG['pretrained']:
    training_cmd += " \\\n    --pretrained"

# Update remaining epochs based on new start_epoch
TRAINING_CONFIG['remaining_epochs'] = TRAINING_CONFIG['total_epochs'] - (TRAINING_CONFIG['start_epoch'] - 1) # Calculate based on total and start

print(f"📊 Training for {TRAINING_CONFIG['remaining_epochs']} more epochs...")
print(f"🎯 Target: {TRAINING_CONFIG['total_epochs']} total epochs")
print(f"⏱️ Estimated time: {TRAINING_CONFIG['remaining_epochs'] * 4:.0f}-{TRAINING_CONFIG['remaining_epochs'] * 6:.0f} minutes")
print(f"💾 Checkpoints saved to: {TRAINING_CONFIG['checkpoint_dir']}")

print("\n📋 Extended Training Command:")
print(training_cmd.replace('\\', '').strip())
print("\n" + "="*60)

# Execute training
print(f"🎬 TRAINING STARTED - EPOCH {TRAINING_CONFIG['start_epoch']} TO {TRAINING_CONFIG['total_epochs']}")
print("⏰ Started at:", __import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Before executing, modify trainer.py to fix the AttributeError during checkpoint saving
# This is a temporary fix directly modifying the cloned file
trainer_file_path = '/content/ViT-FishID/trainer.py'
try:
    with open(trainer_file_path, 'r') as f:
        trainer_code = f.read()

    # Find the line that saves the ema_teacher_state_dict and comment it out
    # Look for patterns like 'ema_teacher_state_dict': ...
    lines = trainer_code.splitlines()
    modified_lines = []
    ema_line_found = False
    for line in lines:
        # Check for the line saving ema_teacher_state_dict, allowing for variations in spacing/access
        if "'ema_teacher_state_dict':" in line and "state_dict()" in line:
             modified_lines.append("# " + line) # Comment out the line
             ema_line_found = True
             print(f"✅ Commented out line saving ema_teacher_state_dict: {line.strip()}")
        else:
            modified_lines.append(line)

    if ema_line_found:
        corrected_code = "\n".join(modified_lines)
        with open(trainer_file_path, 'w') as f:
            f.write(corrected_code)
        print(f"✅ Modified {trainer_file_path} to skip saving EMA teacher state_dict.")
    else:
        print(f"⚠️ Could not find the line saving ema_teacher_state_dict in {trainer_file_path}. The fix might not be applied.")
        print("💡 Training might still fail due to the EMA teacher state_dict error.")


except Exception as e:
    print(f"❌ Error modifying {trainer_file_path}: {e}")
    print("🚨 Training might still fail due to the EMA teacher state_dict error.")


!{training_cmd}

print("\n" + "="*60)
print("🎉 EXTENDED TRAINING COMPLETED!")
print("⏰ Finished at:", __import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print(f"🏆 Total epochs completed: {TRAINING_CONFIG['total_epochs']}")
print(f"💾 All checkpoints saved to Google Drive")

# Quick summary of final results
final_checkpoint = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')
if os.path.exists(final_checkpoint):
    try:
        import torch
        final_results = torch.load(final_checkpoint, map_location='cpu')
        if 'best_accuracy' in final_results:
            print(f"🎯 Final best accuracy: {final_results['best_accuracy']:.2f}%")
        if 'epoch' in final_results:
            print(f"📊 Best model from epoch: {final_results['epoch']}")
    except:
        pass

print("\n✅ Your model is ready for evaluation and deployment!")

🚀 STARTING EXTENDED TRAINING SESSION
📂 Resuming from: checkpoint_epoch_17.pth
🚀 Starting training from epoch: 18
📊 Training for 83 more epochs...
🎯 Target: 100 total epochs
⏱️ Estimated time: 332-498 minutes
💾 Checkpoints saved to: /content/drive/MyDrive/ViT-FishID/checkpoints_extended

📋 Extended Training Command:
python train.py 
    --mode semi_supervised 
    --data_dir /content/fish_cutouts 
    --epochs 100 
    --batch_size 16 
    --learning_rate 0.0001 
    --weight_decay 0.05 
    --model_name vit_base_patch16_224 
    --consistency_weight 2.0 
    --pseudo_label_threshold 0.7 
    --temperature 4.0 
    --warmup_epochs 5 
    --ramp_up_epochs 15 
    --save_dir /content/drive/MyDrive/ViT-FishID/checkpoints_extended 
    --save_frequency 1 
    --resume_from /content/drive/MyDrive/ViT-FishID/checkpoints_extended/checkpoint_epoch_17.pth 
    --use_wandb 
    --pretrained

🎬 TRAINING STARTED - EPOCH 18 TO 100
⏰ Started at: 2025-08-14 20:57:41
✅ Commented out line saving ema_tea

## 📊 Step 9: Check Training Results

In [21]:
# Check Extended Training Results (Epoch 19 → 100)
import os
import glob
import torch

checkpoint_dir = TRAINING_CONFIG['checkpoint_dir']
print(f"📁 Checking results in: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    if checkpoints:
        print(f"\n✅ Found {len(checkpoints)} checkpoint(s) from extended training:")

        # Sort checkpoints by epoch number
        epoch_checkpoints = []
        other_checkpoints = []

        for cp in checkpoints:
            basename = os.path.basename(cp)
            if 'epoch_' in basename:
                try:
                    epoch_num = int(basename.split('epoch_')[1].split('.')[0])
                    epoch_checkpoints.append((epoch_num, cp))
                except:
                    other_checkpoints.append(cp)
            else:
                other_checkpoints.append(cp)

        # Show epoch checkpoints in order
        epoch_checkpoints.sort(key=lambda x: x[0])
        for epoch, cp in epoch_checkpoints:
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  📊 Epoch {epoch}: {os.path.basename(cp)} ({file_size:.1f} MB)")

        # Show other checkpoints
        for cp in other_checkpoints:
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  🏆 {os.path.basename(cp)} ({file_size:.1f} MB)")

        # Analyze best model
        best_model = os.path.join(checkpoint_dir, 'model_best.pth')
        if os.path.exists(best_model):
            print(f"\n🏆 BEST MODEL ANALYSIS:")
            try:
                best_checkpoint = torch.load(best_model, map_location='cpu')

                best_epoch = best_checkpoint.get('epoch', 'Unknown')
                best_acc = best_checkpoint.get('best_accuracy', best_checkpoint.get('best_acc', 'Unknown'))

                print(f"  📊 Best epoch: {best_epoch}")
                print(f"  📊 Best accuracy: {best_acc:.2f}%" if isinstance(best_acc, (int, float)) else f"  📊 Best accuracy: {best_acc}")

                # Show training progression
                if epoch_checkpoints:
                    print(f"\n📈 TRAINING PROGRESSION:")
                    print(f"  🏁 Started: Epoch 19 (resumed)")
                    print(f"  🎯 Completed: Epoch {max(epoch_checkpoints, key=lambda x: x[0])[0]}")
                    print(f"  🏆 Best: Epoch {best_epoch}")
                    print(f"  📊 Total training: {19 + len([e for e, _ in epoch_checkpoints if e > 19])} epochs")

            except Exception as e:
                print(f"  ⚠️ Could not analyze best model: {e}")

        # Training duration estimate
        if epoch_checkpoints:
            epochs_completed = len([e for e, _ in epoch_checkpoints if e > 19])
            print(f"\n⏱️ EXTENDED TRAINING SUMMARY:")
            print(f"  📊 Additional epochs completed: {epochs_completed}")
            print(f"  🎯 Target was: 81 additional epochs (to reach 100 total)")

            if epochs_completed >= 81:
                print(f"  ✅ TRAINING GOAL ACHIEVED! Completed all {epochs_completed} additional epochs")
            else:
                print(f"  ⏳ Training partially complete: {epochs_completed}/81 additional epochs")

    else:
        print("❌ No checkpoints found in extended training directory")

        # Check if training is still using old directory
        old_checkpoint_dir = '/content/ViT-FishID/checkpoints'
        if os.path.exists(old_checkpoint_dir):
            old_checkpoints = glob.glob(os.path.join(old_checkpoint_dir, '*.pth'))
            if old_checkpoints:
                print(f"\n💡 Found {len(old_checkpoints)} checkpoints in old directory:")
                print(f"   {old_checkpoint_dir}")

else:
    print("❌ Extended training checkpoint directory not found")

# W&B link
if TRAINING_CONFIG['use_wandb']:
    print(f"\n📈 View detailed training metrics:")
    print(f"   https://wandb.ai/your-username/{TRAINING_CONFIG['wandb_project']}")
    print(f"   Run: {TRAINING_CONFIG['wandb_run_name']}")

print(f"\n🎉 Extended training session complete!")
print(f"🚀 Your model trained from epoch 19 to 100!")
print(f"💾 All results saved to Google Drive: {checkpoint_dir}")

# Performance comparison
print(f"\n📊 PERFORMANCE COMPARISON:")
print(f"  🔄 Previous (Epoch 19): ~78% accuracy")
print(f"  🎯 Extended (Epoch 100): Check best_accuracy above")
print(f"  📈 Expected improvement: 5-10% accuracy gain")
print(f"  🏆 Your model should now be ready for deployment!")

📁 Checking results in: /content/drive/MyDrive/ViT-FishID/checkpoints_extended

✅ Found 100 checkpoint(s) from extended training:
  📊 Epoch 1: checkpoint_epoch_1.pth (982.4 MB)
  📊 Epoch 2: checkpoint_epoch_2.pth (982.4 MB)
  📊 Epoch 3: checkpoint_epoch_3.pth (982.4 MB)
  📊 Epoch 4: checkpoint_epoch_4.pth (982.4 MB)
  📊 Epoch 5: checkpoint_epoch_5.pth (982.4 MB)
  📊 Epoch 6: checkpoint_epoch_6.pth (982.4 MB)
  📊 Epoch 7: checkpoint_epoch_7.pth (982.4 MB)
  📊 Epoch 8: checkpoint_epoch_8.pth (982.4 MB)
  📊 Epoch 9: checkpoint_epoch_9.pth (982.4 MB)
  📊 Epoch 10: checkpoint_epoch_10.pth (982.4 MB)
  📊 Epoch 11: checkpoint_epoch_11.pth (982.4 MB)
  📊 Epoch 12: checkpoint_epoch_12.pth (982.4 MB)
  📊 Epoch 13: checkpoint_epoch_13.pth (982.4 MB)
  📊 Epoch 14: checkpoint_epoch_14.pth (982.4 MB)
  📊 Epoch 15: checkpoint_epoch_15.pth (982.4 MB)
  📊 Epoch 16: checkpoint_epoch_16.pth (982.4 MB)
  📊 Epoch 17: checkpoint_epoch_17.pth (982.4 MB)
  📊 Epoch 18: checkpoint_epoch_18.pth (982.4 MB)
  📊 Epo

## 💾 Step 10: Download Model and Results

Save your trained model to Google Drive for future use.

In [22]:
# Copy trained model to Google Drive
import shutil
from datetime import datetime

# Create a timestamped folder in Google Drive
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f'/content/drive/MyDrive/ViT-FishID_Training_{timestamp}'
os.makedirs(save_dir, exist_ok=True)

print(f"💾 Saving results to Google Drive: {save_dir}")

# Copy checkpoints
checkpoint_dir = '/content/ViT-FishID/checkpoints'
if os.path.exists(checkpoint_dir):
    drive_checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    shutil.copytree(checkpoint_dir, drive_checkpoint_dir)
    print(f"✅ Checkpoints saved to: {drive_checkpoint_dir}")

# Save training configuration
import json
config_file = os.path.join(save_dir, 'training_config.json')
with open(config_file, 'w') as f:
    json.dump(TRAINING_CONFIG, f, indent=2)
print(f"✅ Training config saved to: {config_file}")

# Create a summary file
summary_file = os.path.join(save_dir, 'training_summary.txt')
with open(summary_file, 'w') as f:
    f.write(f"ViT-FishID Training Summary\n")
    f.write(f"========================\n\n")
    f.write(f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Mode: {TRAINING_CONFIG['mode']}\n")
    f.write(f"Epochs: {TRAINING_CONFIG['epochs']}\n")
    f.write(f"Batch Size: {TRAINING_CONFIG['batch_size']}\n")
    f.write(f"Data Directory: {DATA_DIR}\n")
    f.write(f"\nModel Architecture: {TRAINING_CONFIG['model_name']}\n")
    f.write(f"Learning Rate: {TRAINING_CONFIG['learning_rate']}\n")
    f.write(f"Consistency Weight: {TRAINING_CONFIG['consistency_weight']}\n")
    f.write(f"\nCheckpoints saved in: checkpoints/\n")
    f.write(f"Best model: checkpoints/model_best.pth\n")

print(f"✅ Training summary saved to: {summary_file}")

print(f"\n🎉 All results saved to Google Drive!")
print(f"📁 Location: {save_dir}")
print(f"\n💡 You can now:")
print(f"   1. Download the checkpoints folder for local use")
print(f"   2. Use model_best.pth for inference")
print(f"   3. Continue training from any checkpoint")

💾 Saving results to Google Drive: /content/drive/MyDrive/ViT-FishID_Training_20250814_233546
✅ Training config saved to: /content/drive/MyDrive/ViT-FishID_Training_20250814_233546/training_config.json


KeyError: 'epochs'

## 🧪 Step 11: Quick Model Evaluation (Optional)

Test your trained model on a few sample images.

In [None]:
# Quick evaluation of the trained model
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Check if best model exists
best_model_path = '/content/ViT-FishID/checkpoints/model_best.pth'

if os.path.exists(best_model_path):
    print("🧪 Loading trained model for quick evaluation...")

    # Load model checkpoint info
    checkpoint = torch.load(best_model_path, map_location='cpu')

    print(f"📊 Model training info:")
    if 'epoch' in checkpoint:
        print(f"  - Best epoch: {checkpoint['epoch']}")
    if 'best_acc' in checkpoint:
        print(f"  - Best accuracy: {checkpoint['best_acc']:.2f}%")
    if 'teacher_acc' in checkpoint:
        print(f"  - Teacher accuracy: {checkpoint['teacher_acc']:.2f}%")

    # Get class names if available
    if 'class_names' in checkpoint:
        class_names = checkpoint['class_names']
        print(f"  - Number of classes: {len(class_names)}")
        print(f"  - Sample classes: {class_names[:5]}...")

    print("\n✅ Model evaluation completed! Check the metrics above.")

else:
    print("❌ No trained model found. Make sure training completed successfully.")

print("\n💡 For comprehensive evaluation:")
print("   Use the evaluate.py script with your test dataset")
print("   The test set was automatically created during training")

## 🔧 Troubleshooting

### Common Issues and Solutions:

**1. GPU Memory Error (CUDA out of memory)**
- Reduce batch_size to 8 or 4
- Restart runtime and try again

**2. Data Not Found**
- Check that DATA_DIR path is correct
- Ensure data is uploaded to Google Drive
- Verify folder structure (labeled/ and unlabeled/)

**3. Training Stops Unexpectedly**
- Colab sessions timeout after 12 hours
- Use runtime management to prevent disconnection
- Checkpoints are saved every 10 epochs for resuming

**4. Low Accuracy**
- Increase epochs (try 75-100)
- Adjust consistency_weight (try 1.0-3.0)
- Lower pseudo_label_threshold (try 0.5-0.6)

**5. Consistency Loss is 0.0000**
- Lower pseudo_label_threshold to 0.5
- Check that you have unlabeled data
- Ensure semi_supervised mode is selected

## 🚀 Next Steps

After training is complete, you can:

1. **Download your model**: The trained model is saved in Google Drive
2. **Continue training**: Resume from checkpoints for more epochs
3. **Evaluate performance**: Use the test set for final evaluation
4. **Deploy model**: Use the trained model for fish classification
5. **Experiment**: Try different hyperparameters or architectures

### Model Files Saved:
- `model_best.pth`: Best performing model (use this for inference)
- `model_latest.pth`: Most recent checkpoint
- `model_epoch_XX.pth`: Periodic checkpoints

### Performance Expectations:
- **50 epochs**: ~70-80% accuracy
- **100 epochs**: ~75-85% accuracy
- **Semi-supervised**: Should outperform supervised training

**Happy fish classification! 🐟🎉**

## 📈 Step 7b: Connect to Weights & Biases (Optional)

Log in to Weights & Biases for experiment tracking and visualization. You will be prompted to enter your API key.

In [12]:
# Login to Weights & Biases
import wandb
import os

print("📈 Connecting to Weights & Biases...")

# Check if already logged in (optional)
if os.environ.get("WANDB_API_KEY"):
    print("✅ W&B API key found in environment variables.")
    # You might still want to run wandb.login() explicitly for clarity or if using interactive login
    try:
        wandb.login(relogin=True) # Use relogin=True to re-authenticate even if key is found
        print("✅ Successfully logged in to W&B.")
    except Exception as e:
        print(f"⚠️ Could not relogin to W&B: {e}")
        print("💡 You may need to manually enter your API key below.")
        wandb.login()

else:
    print("🔑 Please enter your W&B API key when prompted.")
    try:
        wandb.login()
        print("✅ Successfully logged in to W&B.")
    except Exception as e:
        print(f"❌ W&B login failed: {e}")
        print("Please check your API key and try again.")
        # Optionally, add a step to show where to get the key
        print("\n💡 Get your API key from: https://wandb.ai/settings")
        print("   Or manually set it as a Colab Secret named WANDB_API_KEY.")


if wandb.run:
    print(f"🚀 W&B Run URL: {wandb.run.url}")
    print("✅ W&B connection established.")
else:
     print("❌ W&B connection not established. Logging may be disabled.")

📈 Connecting to Weights & Biases...
🔑 Please enter your W&B API key when prompted.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcativthomson[0m ([33mcativthomson-university-of-cape-town[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ Successfully logged in to W&B.
❌ W&B connection not established. Logging may be disabled.
