# 🚀 ViT-FishID: Resume Training from Epoch 19

**COLAB PRO EXTENDED TRAINING**
- Resume from: Epoch 19 checkpoint
- Target epochs: 100 total epochs (81 remaining)
- Expected training time: 6-8 hours with Colab Pro
- GPU: Tesla T4/V100/A100 (depending on availability)

This notebook will:
1. ✅ Resume training from your saved checkpoint at epoch 19
2. ✅ Train for 100 total epochs (81 more epochs)
3. ✅ Save checkpoints to Google Drive every epoch
4. ✅ Use semi-supervised learning with your fish dataset

<a href="https://colab.research.google.com/github/cat-thomson/ViT-FishID/blob/main/ViT_FishID_Colab_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🐟 ViT-FishID: Extended Training Session

**RESUME FROM EPOCH 19 - COLAB PRO**

This notebook resumes training from your saved checkpoint and runs for 100 total epochs.

**Current Status:**
- ✅ Previous training: 19 epochs completed
- 🎯 Target: 100 total epochs (81 remaining)
- ⏱️ Expected time: 6-8 hours with Colab Pro
- 💾 Auto-save every epoch to Google Drive

**Performance Target:**
- Previous: ~78% validation accuracy at epoch 19
- Expected: 85-90% accuracy after 100 epochs
- Memory: ~8-12GB GPU memory

## 🚀 Step 1: Setup and GPU Check

In [None]:
# Check GPU availability
import sys

import torch

print("🔍 System Information:")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print("✅ GPU is ready for training!")
else:
    print("❌ No GPU detected. Please enable GPU runtime:")
    print("   Runtime → Change runtime type → Hardware accelerator → GPU")

## 📁 Step 2: Mount Google Drive

This will give us access to your fish dataset stored in Google Drive.

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# List contents to verify mount
print("\n📂 Google Drive contents:")
drive_path = '/content/drive/MyDrive'
if os.path.exists(drive_path):
    all_items = os.listdir(drive_path)  # List the directory once and reuse the result
    for item in all_items[:10]:  # Show first 10 items
        print(f"  - {item}")
    if len(all_items) > 10:
        print(f"  ... and {len(all_items) - 10} more items")
    print("\n✅ Google Drive mounted successfully!")
else:
    print("❌ Failed to mount Google Drive")

## 📦 Step 3: Install Dependencies

Installing all required packages for ViT-FishID training.

In [None]:
# Install required packages
print("📦 Installing dependencies...")

!pip install -q torch torchvision torchaudio
!pip install -q timm transformers
!pip install -q albumentations
!pip install -q wandb
!pip install -q opencv-python-headless
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

print("✅ All dependencies installed successfully!")

# Verify installations
import torch
import torchvision
import timm
import albumentations
import cv2
import sklearn

print("\n📋 Package versions:")
print(f"  - torch: {torch.__version__}")
print(f"  - torchvision: {torchvision.__version__}")
print(f"  - timm: {timm.__version__}")
print(f"  - albumentations: {albumentations.__version__}")
print(f"  - opencv: {cv2.__version__}")
print(f"  - sklearn: {sklearn.__version__}")

## 🔄 Step 4: Clone ViT-FishID Repository

Getting the latest code from your GitHub repository.

In [None]:
# Clone the repository
import os

# Remove existing directory if it exists
if os.path.exists('/content/ViT-FishID'):
    !rm -rf /content/ViT-FishID

# Clone the repository
print("📥 Cloning ViT-FishID repository...")
!git clone https://github.com/cat-thomson/ViT-FishID.git /content/ViT-FishID

# Change to project directory
%cd /content/ViT-FishID

# List project files
print("\n📂 Project structure:")
!ls -la

print("\n✅ Repository cloned successfully!")

## 🗂️ Step 5: Setup Data Path and Extraction

**IMPORTANT:** Specify the path to your fish dataset ZIP file in Google Drive.

This step will:
1. Locate your `fish_cutouts.zip` file in Google Drive
2. Extract it to Colab's local storage for faster access
3. Validate the data structure

Expected structure after extraction:
```
fish_cutouts/
├── labeled/
│   ├── species_1/
│   │   ├── fish_001.jpg
│   │   └── fish_002.jpg
│   └── species_2/
│       └── ...
└── unlabeled/
    ├── fish_003.jpg
    └── fish_004.jpg
```
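The layout check can be sketched as a small helper. This is a minimal illustration, not code from the repository: `summarize_dataset` and `IMAGE_EXTENSIONS` are hypothetical names, and the rule (count non-hidden species folders under `labeled/`, count image files under `unlabeled/`) mirrors what the extraction cell below reports.

```python
import os
import tempfile

# Assumption: the same image extensions the extraction cell accepts
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def summarize_dataset(data_dir):
    """Return (species_count, unlabeled_image_count) for a fish_cutouts-style tree."""
    labeled_dir = os.path.join(data_dir, 'labeled')
    unlabeled_dir = os.path.join(data_dir, 'unlabeled')

    species = []
    if os.path.isdir(labeled_dir):
        # One sub-folder per species; hidden folders are ignored
        species = [d for d in os.listdir(labeled_dir)
                   if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]

    images = []
    if os.path.isdir(unlabeled_dir):
        # Only files with a known image extension count as unlabeled samples
        images = [f for f in os.listdir(unlabeled_dir)
                  if f.lower().endswith(IMAGE_EXTENSIONS)]

    return len(species), len(images)


# Small demonstration on a throwaway directory tree
with tempfile.TemporaryDirectory() as demo_root:
    os.makedirs(os.path.join(demo_root, 'labeled', 'species_1'))
    os.makedirs(os.path.join(demo_root, 'unlabeled'))
    open(os.path.join(demo_root, 'unlabeled', 'fish_003.jpg'), 'w').close()
    print(summarize_dataset(demo_root))
```

A result of `(0, 0)` usually means the ZIP extracted with a different nesting than the structure shown above.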

In [None]:
# Setup data path and extraction - PROPER ZIP HANDLING
import zipfile
import shutil
import time
import os

print("🗂️ SETTING UP FISH DATASET")
print("="*50)

# Configuration - Update these paths as needed
ZIP_FILE_PATH = '/content/drive/MyDrive/fish_cutouts.zip'  # Main location
BACKUP_ZIP_PATH = '/content/drive/MyDrive/ViT-FishID/checkpoints/fish_cutouts.zip'  # Backup location
DATA_DIR = '/content/fish_cutouts'

print(f"🎯 Target data directory: {DATA_DIR}")

# Check if data already exists locally (from previous session)
if os.path.exists(DATA_DIR) and os.path.exists(os.path.join(DATA_DIR, 'labeled')):
    print("✅ Data already available locally from previous session!")

    # Quick validation
    labeled_dir = os.path.join(DATA_DIR, 'labeled')
    unlabeled_dir = os.path.join(DATA_DIR, 'unlabeled')

    if os.path.exists(labeled_dir):
        labeled_species = [d for d in os.listdir(labeled_dir)
                          if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')]
        print(f"🐟 Found {len(labeled_species)} labeled species")

    if os.path.exists(unlabeled_dir):
        unlabeled_files = [f for f in os.listdir(unlabeled_dir)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        print(f"📊 Found {len(unlabeled_files)} unlabeled images")

    print("✅ Data validation passed - ready for training!")

else:
    print("📥 Data not found locally, extracting from Google Drive...")
    
    # Try to find the ZIP file
    zip_file_to_use = None
    if os.path.exists(ZIP_FILE_PATH):
        zip_file_to_use = ZIP_FILE_PATH
        print(f"✅ Found ZIP file at: {ZIP_FILE_PATH}")
    elif os.path.exists(BACKUP_ZIP_PATH):
        zip_file_to_use = BACKUP_ZIP_PATH
        print(f"✅ Found ZIP file at backup location: {BACKUP_ZIP_PATH}")
    else:
        print("❌ ZIP file not found at either location!")
        print(f"   Tried: {ZIP_FILE_PATH}")
        print(f"   Tried: {BACKUP_ZIP_PATH}")
        print("📝 Please ensure fish_cutouts.zip is uploaded to Google Drive")

    if zip_file_to_use:
        print(f"\n📦 Extracting {os.path.basename(zip_file_to_use)}...")
        print(f"📏 ZIP file size: {os.path.getsize(zip_file_to_use) / (1024**2):.1f} MB")
        
        # Clean extraction
        temp_extract_dir = '/content/temp_fish_extract'
        if os.path.exists(temp_extract_dir):
            shutil.rmtree(temp_extract_dir)
        
        try:
            # Extract ZIP file
            with zipfile.ZipFile(zip_file_to_use, 'r') as zip_ref:
                zip_ref.extractall(temp_extract_dir)
            
            print("✅ ZIP extraction completed")

            # Handle nested folder structure: fish_cutouts.zip contains fish_cutouts/ folder
            extracted_items = os.listdir(temp_extract_dir)
            print(f"📁 Found in ZIP: {extracted_items}")

            # Look for the fish_cutouts folder inside the extraction
            fish_cutouts_source = None
            for item in extracted_items:
                item_path = os.path.join(temp_extract_dir, item)
                if os.path.isdir(item_path):
                    # Accept the folder if it has a 'labeled' subdirectory or any
                    # subdirectory whose name contains 'label' (e.g. 'unlabeled')
                    sub_items = os.listdir(item_path)
                    if 'labeled' in sub_items or any('label' in sub.lower() for sub in sub_items):
                        fish_cutouts_source = item_path
                        print(f"✅ Found fish data in: {item}")
                        break
            
            # If we found the nested fish_cutouts folder, move it to the target location
            if fish_cutouts_source:
                # Remove existing target if it exists
                if os.path.exists(DATA_DIR):
                    shutil.rmtree(DATA_DIR)
                
                # Move the fish_cutouts folder to the correct location
                shutil.move(fish_cutouts_source, DATA_DIR)
                print(f"✅ Data moved to: {DATA_DIR}")

                # Verify the structure
                if os.path.exists(os.path.join(DATA_DIR, 'labeled')):
                    labeled_species = [d for d in os.listdir(os.path.join(DATA_DIR, 'labeled'))
                                     if os.path.isdir(os.path.join(DATA_DIR, 'labeled', d)) and not d.startswith('.')]
                    print(f"🐟 Verified: {len(labeled_species)} species in labeled data")

                if os.path.exists(os.path.join(DATA_DIR, 'unlabeled')):
                    unlabeled_count = len([f for f in os.listdir(os.path.join(DATA_DIR, 'unlabeled'))
                                         if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                    print(f"📊 Verified: {unlabeled_count} images in unlabeled data")

            else:
                print("❌ Could not find fish_cutouts folder structure in ZIP")
                print("📁 Available items:", extracted_items)
            
            # Cleanup temporary extraction
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)
                
        except Exception as e:
            print(f"❌ Error during extraction: {e}")
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)

# Final verification
if os.path.exists(DATA_DIR):
    print("\n✅ DATASET READY")
    print(f"📁 Location: {DATA_DIR}")

    # Show structure
    for subdir in ['labeled', 'unlabeled']:
        subdir_path = os.path.join(DATA_DIR, subdir)
        if os.path.exists(subdir_path):
            if subdir == 'labeled':
                species_count = len([d for d in os.listdir(subdir_path)
                                   if os.path.isdir(os.path.join(subdir_path, d)) and not d.startswith('.')])
                print(f"  📂 {subdir}/: {species_count} species folders")
            else:
                file_count = len([f for f in os.listdir(subdir_path)
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                print(f"  📂 {subdir}/: {file_count} images")
        else:
            print(f"  ❌ {subdir}/ not found")

    print("Ready to proceed with training!")
else:
    print("\n❌ DATASET SETUP FAILED")
    print("Please check that fish_cutouts.zip contains the proper folder structure:")
    print("   fish_cutouts.zip")
    print("   └── fish_cutouts/")
    print("       ├── labeled/")
    print("       │   ├── species1/")
    print("       │   └── species2/")
    print("       └── unlabeled/")
    print("           ├── image1.jpg")
    print("           └── image2.jpg")

## 📊 Step 6: Setup Weights & Biases (Optional)

W&B provides excellent training visualization and experiment tracking.

In [None]:
# Setup Weights & Biases for experiment tracking
import wandb

# Login to W&B (you'll need to create a free account at wandb.ai)
print("🔐 Setting up Weights & Biases...")
print("\nTo use W&B:")
print("1. Go to https://wandb.ai and create a free account")
print("2. Get your API key from https://wandb.ai/authorize")
print("3. Run the cell below and paste your API key when prompted")
print("\nOr skip W&B by setting USE_WANDB = False below")

# Set this to True if you want to use W&B, False to skip
USE_WANDB = True  # 👈 Set to False if you don't want to use W&B

if USE_WANDB:
    try:
        # Try to login (will prompt for API key if not already logged in)
        wandb.login()
        print("✅ W&B login successful!")
    except Exception as e:
        # Catch Exception rather than using a bare except, so Ctrl+C still interrupts
        print(f"⚠️  W&B login failed ({e}). Training will continue without W&B logging.")
        USE_WANDB = False
else:
    print("ℹ️  Skipping W&B setup. Training will run without experiment tracking.")

## 🔄 Step 6: Locate Checkpoint from Epoch 19

Finding your saved checkpoint to resume training from where you left off.
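Matching on the substring `19` alone would also accept a name such as `checkpoint_epoch_190.pth`. As a hedged sketch, the epoch number can instead be parsed from the filename and compared as an integer. `epoch_from_filename` and `EPOCH_PATTERN` are hypothetical helpers, and the `epoch_<N>` naming is assumed from the example filenames mentioned in this notebook.

```python
import re

# Assumption: checkpoint names embed the epoch as "epoch_<N>" (or "epoch<N>")
EPOCH_PATTERN = re.compile(r'epoch_?(\d+)')


def epoch_from_filename(filename):
    """Return the epoch number embedded in a checkpoint filename, or None."""
    match = EPOCH_PATTERN.search(filename)
    return int(match.group(1)) if match else None


for name in ['checkpoint_epoch_19.pth',
             'emergency_checkpoint_epoch_19.pth',
             'checkpoint_epoch_190.pth',
             'model_best.pth']:
    print(name, '->', epoch_from_filename(name))
```

The `epoch` value stored inside the checkpoint remains the authoritative source; the filename check is only a cheap pre-filter before calling `torch.load`.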

In [None]:
# Locate checkpoint from epoch 19
import os
import glob
import torch

print("🔍 Looking for checkpoint from epoch 19...")

# Possible checkpoint locations
checkpoint_locations = [
    '/content/drive/MyDrive/ViT-FishID/checkpoints',
    '/content/drive/MyDrive/ViT-FishID_Training_*/checkpoints',
    '/content/drive/MyDrive/checkpoints',
    '/content/ViT-FishID/checkpoints'
]

checkpoint_path = None
checkpoint_info = None

# Search for epoch 19 checkpoint
for location_pattern in checkpoint_locations:
    for location in glob.glob(location_pattern):
        if os.path.exists(location):
            print(f"📁 Checking: {location}")
            
            # Look for epoch 19 specifically
            epoch_19_files = glob.glob(os.path.join(location, '*epoch_19*'))
            manual_files = glob.glob(os.path.join(location, '*manual*epoch*19*'))
            emergency_files = glob.glob(os.path.join(location, '*emergency*epoch*19*'))
            
            all_candidates = epoch_19_files + manual_files + emergency_files
            
            for candidate in all_candidates:
                if candidate.endswith('.pth'):
                    print(f"🎯 Found candidate: {os.path.basename(candidate)}")
                    try:
                        # Verify checkpoint can be loaded
                        test_checkpoint = torch.load(candidate, map_location='cpu')
                        epoch = test_checkpoint.get('epoch', 'unknown')
                        
                        if epoch == 19 or '19' in os.path.basename(candidate):
                            checkpoint_path = candidate
                            checkpoint_info = test_checkpoint
                            print("✅ FOUND EPOCH 19 CHECKPOINT!")
                            print(f"📁 Location: {checkpoint_path}")
                            print(f"📊 Epoch: {epoch}")

                            if 'best_accuracy' in test_checkpoint:
                                print(f"📊 Best accuracy so far: {test_checkpoint['best_accuracy']:.2f}%")
                            elif 'best_acc' in test_checkpoint:
                                print(f"📊 Best accuracy so far: {test_checkpoint['best_acc']:.2f}%")
                                
                            break
                    except Exception as e:
                        print(f"⚠️ Could not load {candidate}: {e}")
            
            if checkpoint_path:
                break
        
        if checkpoint_path:
            break

if checkpoint_path:
    print("\n🎉 Checkpoint ready for resuming training!")
    print(f"📄 File: {os.path.basename(checkpoint_path)}")
    print(f"📏 Size: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB")

    # Set up checkpoint directory for new saves
    checkpoint_save_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints_extended'
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    print(f"💾 New checkpoints will be saved to: {checkpoint_save_dir}")

else:
    print("❌ No checkpoint found for epoch 19!")
    print("\n🔧 Troubleshooting:")
    print("1. Check that you have a checkpoint saved from previous training")
    print("2. Ensure the checkpoint is uploaded to Google Drive")
    print("3. Look for files named like: checkpoint_epoch_19.pth, emergency_checkpoint_epoch_19.pth")
    print("\n📁 Checked locations:")
    for location in checkpoint_locations:
        print(f"  - {location}")

    # Fallback: look for any checkpoints
    print("\n🔍 All available checkpoints:")
    for location_pattern in checkpoint_locations:
        for location in glob.glob(location_pattern):
            if os.path.exists(location):
                all_checkpoints = glob.glob(os.path.join(location, '*.pth'))
                for cp in all_checkpoints:
                    print(f"  - {os.path.basename(cp)}")

# Store checkpoint path for later use
RESUME_CHECKPOINT = checkpoint_path

In [None]:
# Backup: Copy local checkpoint to Google Drive if not found there
import shutil

if not checkpoint_path:
    print("🔍 Checkpoint not found in Google Drive, checking local copy...")
    
    # Check if there's a local checkpoint file we uploaded
    local_checkpoint = '/content/ViT-FishID/checkpoint_epoch_19.pth'
    
    if os.path.exists(local_checkpoint):
        print("✅ Found local checkpoint file!")
        
        # Copy to Google Drive
        drive_backup_dir = '/content/drive/MyDrive/ViT-FishID/checkpoints'
        os.makedirs(drive_backup_dir, exist_ok=True)
        
        drive_checkpoint_path = os.path.join(drive_backup_dir, 'checkpoint_epoch_19.pth')
        shutil.copy2(local_checkpoint, drive_checkpoint_path)
        
        print(f"💾 Copied checkpoint to Google Drive: {drive_checkpoint_path}")
        
        # Verify the copied checkpoint
        try:
            test_checkpoint = torch.load(drive_checkpoint_path, map_location='cpu')
            epoch = test_checkpoint.get('epoch', 'unknown')
            if 'best_accuracy' in test_checkpoint:
                accuracy = test_checkpoint['best_accuracy']
                print(f"✅ Verification passed - Epoch {epoch}, Accuracy: {accuracy:.2f}%")

            # Update our variables
            checkpoint_path = drive_checkpoint_path
            checkpoint_info = test_checkpoint
            RESUME_CHECKPOINT = checkpoint_path

            print("🎉 Checkpoint ready for resuming training!")

        except Exception as e:
            print(f"❌ Error verifying copied checkpoint: {e}")

    else:
        print("❌ No local checkpoint found either")
        print("📝 Please ensure you have the checkpoint_epoch_19.pth file")
        print("   You can upload it to Colab or place it in Google Drive")

# Final check
if checkpoint_path:
    print("\n✅ FINAL CHECKPOINT STATUS:")
    print(f"📁 Using checkpoint: {checkpoint_path}")
    print("📊 Ready to resume from epoch 19")
else:
    print("\n❌ NO CHECKPOINT AVAILABLE")
    print("🔄 Training will start from epoch 1 instead")
    print("⚠️  This will take much longer than resuming!")

## ⚙️ Step 7: Configure Training Parameters

Adjust these parameters based on your needs and available GPU memory.
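The resume arithmetic used in the configuration can be checked with a short sketch. `remaining_schedule` is a hypothetical helper, and the 4-6 minutes per epoch range is the estimate this notebook prints, not a measured value.

```python
def remaining_schedule(last_completed_epoch, total_epochs, minutes_per_epoch=(4, 6)):
    """Compute the resume point and a rough time estimate.

    minutes_per_epoch is an assumed (low, high) range taken from the
    notebook's own estimate; adjust it to match your GPU.
    """
    remaining = max(total_epochs - last_completed_epoch, 0)
    low, high = minutes_per_epoch
    return {
        'start_epoch': last_completed_epoch + 1,   # first epoch still to run
        'remaining_epochs': remaining,
        'hours_low': remaining * low / 60,
        'hours_high': remaining * high / 60,
    }


plan = remaining_schedule(last_completed_epoch=19, total_epochs=100)
print(plan)
```

For epoch 19 of 100 this gives 81 remaining epochs starting at epoch 20, or roughly 5.4 to 8.1 hours, which matches the 6-8 hour estimate quoted at the top of the notebook.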

In [None]:
# Training Configuration - RESUME FROM EPOCH 19 FOR 100 TOTAL EPOCHS
TRAINING_CONFIG = {
    # RESUME SETTINGS
    'resume_from_checkpoint': RESUME_CHECKPOINT,
    'start_epoch': 20,  # Next epoch after 19
    'total_epochs': 100,  # Target total epochs
    'remaining_epochs': 81,  # 100 - 19 = 81 epochs left
    
    # CORE SETTINGS
    'mode': 'semi_supervised',  # semi_supervised or supervised
    'data_dir': DATA_DIR,
    'batch_size': 16,  # Increased for Colab Pro
    'learning_rate': 1e-4,
    'weight_decay': 0.05,
    
    # MODEL SETTINGS
    'model_name': 'vit_base_patch16_224',
    'num_classes': 37,  # Will be auto-detected
    'pretrained': True,
    
    # SEMI-SUPERVISED SETTINGS - FIXED CONSISTENCY LOSS
    'consistency_weight': 2.0,
    'pseudo_label_threshold': 0.7,
    'temperature': 4.0,
    'warmup_epochs': 5,  # Reduced since we're resuming
    'ramp_up_epochs': 15,  # Reduced since we're resuming
    
    # CHECKPOINT SETTINGS - SAVE EVERY EPOCH
    'save_frequency': 1,  # Save EVERY epoch (changed from 10)
    'checkpoint_dir': '/content/drive/MyDrive/ViT-FishID/checkpoints_extended',
    'backup_dir': '/content/drive/MyDrive/ViT-FishID/checkpoints_backup',
    
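    # LOGGING
    # Honor the USE_WANDB choice from the W&B setup step; default to False if that cell was skipped
    'use_wandb': globals().get('USE_WANDB', False),
    'wandb_project': 'vit-fishid-extended',
    'wandb_run_name': 'resume_epoch19_to_100_fixed'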
}

print("🎯 EXTENDED TRAINING CONFIGURATION - FIXED VERSION")
print("="*50)
print(f"📊 Resume from: Epoch {TRAINING_CONFIG['start_epoch'] - 1}")
print(f"📊 Target epochs: {TRAINING_CONFIG['total_epochs']}")
print(f"📊 Remaining epochs: {TRAINING_CONFIG['remaining_epochs']}")
print(f"📊 Expected time: {TRAINING_CONFIG['remaining_epochs'] * 4:.0f}-{TRAINING_CONFIG['remaining_epochs'] * 6:.0f} minutes")
print(f"📊 Batch size: {TRAINING_CONFIG['batch_size']} (optimized for Colab Pro)")
print("💾 Checkpoint saves: EVERY epoch (enhanced backup)")
print(f"📊 Mode: {TRAINING_CONFIG['mode']} with FIXED consistency loss")

# Create checkpoint directories
os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(TRAINING_CONFIG['backup_dir'], exist_ok=True)
print(f"📁 Primary saves: {TRAINING_CONFIG['checkpoint_dir']}")
print(f"💾 Backup saves: {TRAINING_CONFIG['backup_dir']}")

if TRAINING_CONFIG['resume_from_checkpoint']:
    print(f"✅ Will resume from: {os.path.basename(TRAINING_CONFIG['resume_from_checkpoint'])}")
else:
    print("❌ No checkpoint found - will start fresh training")
    TRAINING_CONFIG['start_epoch'] = 1
    TRAINING_CONFIG['remaining_epochs'] = TRAINING_CONFIG['total_epochs']

print("\n🔧 FIXES APPLIED:")
print("  ✅ Consistency loss: Fixed tensor initialization")
print("  ✅ Checkpoint saving: Every epoch + Google Drive backup")
print("  ✅ Error handling: Enhanced for robustness")
print("  ✅ State dict keys: Fixed ema_teacher key naming")

print("\n💡 With Colab Pro, this training should complete without timeout!")
print("💾 Every epoch will be saved with Google Drive backup every 5 epochs")

# Verify data directory
if os.path.exists(TRAINING_CONFIG['data_dir']):
    labeled_dir = os.path.join(TRAINING_CONFIG['data_dir'], 'labeled')
    if os.path.exists(labeled_dir):
        species_count = len([d for d in os.listdir(labeled_dir) 
                           if os.path.isdir(os.path.join(labeled_dir, d)) and not d.startswith('.')])
        TRAINING_CONFIG['num_classes'] = species_count
        print(f"📊 Detected {species_count} fish species")

print(f"\n🚀 Ready to resume training for {TRAINING_CONFIG['remaining_epochs']} more epochs!")
print("🔧 All consistency loss and checkpoint issues have been resolved!")

## 🚀 Step 8: Start Training!

This cell will start the semi-supervised training process. With 81 epochs remaining at roughly 4-6 minutes per epoch, expect it to take about 6-8 hours to complete.
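Because the training command is assembled as a shell string, a path containing spaces or shell metacharacters would be split into several arguments. A minimal sketch of a safer construction follows. `build_training_command` is a hypothetical helper and covers only a subset of the `train.py` flags used in the cell below; the point is the quoting with `shlex.quote`, not the full flag list.

```python
import shlex


def build_training_command(config, script='train.py'):
    """Assemble a shell-safe command string from a config dictionary."""
    args = [
        'python', script,
        '--mode', config['mode'],
        '--data_dir', config['data_dir'],
        '--epochs', str(config['total_epochs']),
        '--save_dir', config['checkpoint_dir'],
    ]
    # Optional flags are appended only when they apply
    if config.get('resume_from_checkpoint'):
        args += ['--resume_from', config['resume_from_checkpoint']]
    if config.get('use_wandb'):
        args.append('--use_wandb')
    # Quote every argument so spaces in paths survive the shell
    return ' '.join(shlex.quote(str(arg)) for arg in args)


# Illustrative values only; the path with a space is a made-up example
demo_config = {
    'mode': 'semi_supervised',
    'data_dir': '/content/fish_cutouts',
    'total_epochs': 100,
    'checkpoint_dir': '/content/drive/MyDrive/ViT-FishID/checkpoints_extended',
    'resume_from_checkpoint': '/content/drive/My Drive/checkpoint_epoch_19.pth',
    'use_wandb': False,
}
print(build_training_command(demo_config))
```

The resulting string can be passed to `!{...}` in Colab in the same way as the hand-built command.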

In [None]:
# Execute Extended Training - Resume from Epoch 19
import os

print("🚀 STARTING EXTENDED TRAINING SESSION")
print("="*60)

# Create checkpoint save directory
os.makedirs(TRAINING_CONFIG['checkpoint_dir'], exist_ok=True)

# Build training command for resuming
training_cmd = f"""python train.py \\
    --mode {TRAINING_CONFIG['mode']} \\
    --data_dir {TRAINING_CONFIG['data_dir']} \\
    --epochs {TRAINING_CONFIG['total_epochs']} \\
    --batch_size {TRAINING_CONFIG['batch_size']} \\
    --learning_rate {TRAINING_CONFIG['learning_rate']} \\
    --weight_decay {TRAINING_CONFIG['weight_decay']} \\
    --model_name {TRAINING_CONFIG['model_name']} \\
    --num_classes {TRAINING_CONFIG['num_classes']} \\
    --consistency_weight {TRAINING_CONFIG['consistency_weight']} \\
    --pseudo_label_threshold {TRAINING_CONFIG['pseudo_label_threshold']} \\
    --temperature {TRAINING_CONFIG['temperature']} \\
    --warmup_epochs {TRAINING_CONFIG['warmup_epochs']} \\
    --ramp_up_epochs {TRAINING_CONFIG['ramp_up_epochs']} \\
    --save_dir {TRAINING_CONFIG['checkpoint_dir']} \\
    --save_frequency {TRAINING_CONFIG['save_frequency']}"""

# Add resume checkpoint if available
if TRAINING_CONFIG['resume_from_checkpoint']:
    training_cmd += f" \\\n    --resume_from {TRAINING_CONFIG['resume_from_checkpoint']}"
    print(f"📂 Resuming from: {os.path.basename(TRAINING_CONFIG['resume_from_checkpoint'])}")

# Add W&B logging
if TRAINING_CONFIG['use_wandb']:
    training_cmd += f" \\\n    --use_wandb --wandb_project {TRAINING_CONFIG['wandb_project']} --wandb_run_name {TRAINING_CONFIG['wandb_run_name']}"

# Add pretrained flag
if TRAINING_CONFIG['pretrained']:
    training_cmd += " \\\n    --pretrained"

print(f"📊 Training for {TRAINING_CONFIG['remaining_epochs']} more epochs...")
print(f"🎯 Target: {TRAINING_CONFIG['total_epochs']} total epochs")
print(f"⏱️ Estimated time: {TRAINING_CONFIG['remaining_epochs'] * 4:.0f}-{TRAINING_CONFIG['remaining_epochs'] * 6:.0f} minutes")
print(f"💾 Checkpoints saved to: {TRAINING_CONFIG['checkpoint_dir']}")

print("\n📋 Extended Training Command:")
print(training_cmd.replace('\\', '').strip())
print("\n" + "="*60)

# Execute training
print("🎬 TRAINING STARTED - EPOCH 20 TO 100")
print("⏰ Started at:", __import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

!{training_cmd}

print("\n" + "="*60)
print("🎉 EXTENDED TRAINING COMPLETED!")
print("⏰ Finished at:", __import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print(f"🏆 Total epochs completed: {TRAINING_CONFIG['total_epochs']}")
print("💾 All checkpoints saved to Google Drive")

# Quick summary of final results
final_checkpoint = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'model_best.pth')
if os.path.exists(final_checkpoint):
    try:
        import torch
        final_results = torch.load(final_checkpoint, map_location='cpu')
        if 'best_accuracy' in final_results:
            print(f"🎯 Final best accuracy: {final_results['best_accuracy']:.2f}%")
        if 'epoch' in final_results:
            print(f"📊 Best model from epoch: {final_results['epoch']}")
    except Exception as e:
        # Report the problem instead of silently swallowing it
        print(f"⚠️ Could not read final results: {e}")

print("\n✅ Your model is ready for evaluation and deployment!")

In [None]:
# 🔧 Pre-Training Verification - Test Fixes
# NOTE: Run this cell before the Step 8 training cell above.
import os
import torch

print("🔧 TESTING FIXES BEFORE TRAINING")
print("="*50)

# Test 1: Verify checkpoint loading
if TRAINING_CONFIG['resume_from_checkpoint']:
    print("✅ Test 1: Checkpoint Loading")
    try:
        test_checkpoint = torch.load(TRAINING_CONFIG['resume_from_checkpoint'], map_location='cpu')
        print(f"   📊 Checkpoint epoch: {test_checkpoint.get('epoch', 'Unknown')}")
        print(f"   📊 Best accuracy: {test_checkpoint.get('best_accuracy', 'Unknown')}")
        print(f"   🔑 Available keys: {list(test_checkpoint.keys())}")

        # Check for correct key names
        required_keys = ['student_state_dict', 'epoch']
        missing_keys = [key for key in required_keys if key not in test_checkpoint]
        if missing_keys:
            print(f"   ⚠️ Missing keys: {missing_keys}")
        else:
            print("   ✅ All required keys present")

    except Exception as e:
        print(f"   ❌ Checkpoint test failed: {e}")

# Test 2: Verify consistency loss function
print("\n✅ Test 2: Consistency Loss Function")
try:
    from model import ConsistencyLoss
    consistency_loss_fn = ConsistencyLoss(temperature=4.0)

    # Create dummy logits. The student tensor tracks gradients so that the
    # requires_grad check below is meaningful (plain randn tensors never require grad).
    dummy_student = torch.randn(4, 37, requires_grad=True)  # batch_size=4, num_classes=37
    dummy_teacher = torch.randn(4, 37)

    loss = consistency_loss_fn(dummy_student, dummy_teacher)
    print(f"   📊 Test loss value: {loss.item():.6f}")
    print(f"   📊 Loss tensor shape: {loss.shape}")
    print(f"   📊 Loss requires grad: {loss.requires_grad}")

    if loss.requires_grad and loss.item() > 0:
        print("   ✅ Consistency loss function working correctly")
    else:
        print("   ⚠️ Potential issue with consistency loss")

except Exception as e:
    print(f"   ❌ Consistency loss test failed: {e}")

# Test 3: Verify Google Drive paths
print("\n✅ Test 3: Google Drive Paths")
for path_name, path in [
    ("Primary checkpoint dir", TRAINING_CONFIG['checkpoint_dir']),
    ("Backup checkpoint dir", TRAINING_CONFIG['backup_dir'])
]:
    try:
        os.makedirs(path, exist_ok=True)
        test_file = os.path.join(path, 'test_write.txt')
        with open(test_file, 'w') as f:
            f.write('test')
        os.remove(test_file)
        print(f"   ✅ {path_name}: {path} (writable)")
    except Exception as e:
        print(f"   ❌ {path_name}: {path} (error: {e})")

# Test 4: Verify training configuration
print("\n✅ Test 4: Training Configuration")
config_checks = [
    ("Save frequency", TRAINING_CONFIG['save_frequency'] == 1, "Every epoch"),
    ("Consistency weight", TRAINING_CONFIG['consistency_weight'] > 0, f"{TRAINING_CONFIG['consistency_weight']}"),
    ("Batch size", TRAINING_CONFIG['batch_size'] >= 8, f"{TRAINING_CONFIG['batch_size']}"),
    ("Learning rate", 0 < TRAINING_CONFIG['learning_rate'] < 1, f"{TRAINING_CONFIG['learning_rate']}"),
]

for check_name, condition, value in config_checks:
    status = "✅" if condition else "❌"
    print(f"   {status} {check_name}: {value}")

print("\n🎯 PRE-TRAINING VERIFICATION COMPLETE")
print(f"📊 Configuration looks {'✅ GOOD' if all([check[1] for check in config_checks]) else '⚠️ NEEDS ATTENTION'}")
print("🚀 Ready to start training with all fixes applied!")

## 📊 Step 9: Check Training Results

In [None]:
# Check Extended Training Results (Epoch 19 → 100)
import os
import glob
import torch

checkpoint_dir = TRAINING_CONFIG['checkpoint_dir']
print(f"📁 Checking results in: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    if checkpoints:
        print(f"\n✅ Found {len(checkpoints)} checkpoint(s) from extended training:")
        
        # Sort checkpoints by epoch number
        epoch_checkpoints = []
        other_checkpoints = []
        
        for cp in checkpoints:
            basename = os.path.basename(cp)
            if 'epoch_' in basename:
                try:
                    epoch_num = int(basename.split('epoch_')[1].split('.')[0])
                    epoch_checkpoints.append((epoch_num, cp))
                except (ValueError, IndexError):
                    # Name contains 'epoch_' but no parseable epoch number
                    other_checkpoints.append(cp)
            else:
                other_checkpoints.append(cp)

        # Show epoch checkpoints in order
        epoch_checkpoints.sort(key=lambda x: x[0])
        for epoch, cp in epoch_checkpoints:
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  📊 Epoch {epoch}: {os.path.basename(cp)} ({file_size:.1f} MB)")

        # Show other checkpoints
        for cp in other_checkpoints:
            file_size = os.path.getsize(cp) / (1024**2)
            print(f"  🏆 {os.path.basename(cp)} ({file_size:.1f} MB)")
        
        # Analyze best model
        best_model = os.path.join(checkpoint_dir, 'model_best.pth')
        if os.path.exists(best_model):
            print("\n🏆 BEST MODEL ANALYSIS:")
            try:
                best_checkpoint = torch.load(best_model, map_location='cpu')

                best_epoch = best_checkpoint.get('epoch', 'Unknown')
                best_acc = best_checkpoint.get('best_accuracy', best_checkpoint.get('best_acc', 'Unknown'))

                print(f"  📊 Best epoch: {best_epoch}")
                print(f"  📊 Best accuracy: {best_acc:.2f}%" if isinstance(best_acc, (int, float)) else f"  📊 Best accuracy: {best_acc}")

                # Show training progression
                if epoch_checkpoints:
                    print("\n📈 TRAINING PROGRESSION:")
                    print("  🏁 Started: Epoch 19 (resumed)")
                    print(f"  🎯 Completed: Epoch {max(epoch_checkpoints, key=lambda x: x[0])[0]}")
                    print(f"  🏆 Best: Epoch {best_epoch}")
                    print(f"  📊 Total training: {19 + len([e for e, _ in epoch_checkpoints if e > 19])} epochs")

            except Exception as e:
                print(f"  ⚠️ Could not analyze best model: {e}")
        
        # Extended training summary
        if epoch_checkpoints:
            # Epoch checkpoints are written only every 10 epochs, so estimate progress
            # from the highest saved epoch number, not from the number of files.
            last_epoch = max(e for e, _ in epoch_checkpoints)
            epochs_completed = max(0, last_epoch - 19)
            print(f"\n⏱️ EXTENDED TRAINING SUMMARY:")
            print(f"  📊 Additional epochs completed (per latest epoch checkpoint): {epochs_completed}")
            print(f"  🎯 Target was: 81 additional epochs (to reach 100 total)")

            if epochs_completed >= 81:
                print(f"  ✅ TRAINING GOAL ACHIEVED! Completed all {epochs_completed} additional epochs")
            else:
                print(f"  ⏳ Training partially complete: {epochs_completed}/81 additional epochs")
    
    else:
        print("❌ No checkpoints found in extended training directory")

        # Check if training is still using old directory
        old_checkpoint_dir = '/content/ViT-FishID/checkpoints'
        if os.path.exists(old_checkpoint_dir):
            old_checkpoints = glob.glob(os.path.join(old_checkpoint_dir, '*.pth'))
            if old_checkpoints:
                print(f"\n💡 Found {len(old_checkpoints)} checkpoints in old directory:")
                print(f"   {old_checkpoint_dir}")

else:
    print("❌ Extended training checkpoint directory not found")

# W&B link (replace "your-username" with your own W&B username or team)
if TRAINING_CONFIG['use_wandb']:
    print(f"\n📈 View detailed training metrics:")
    print(f"   https://wandb.ai/your-username/{TRAINING_CONFIG['wandb_project']}")
    print(f"   Run: {TRAINING_CONFIG['wandb_run_name']}")

print(f"\n🎉 Extended training session complete!")
print(f"🚀 Target range: epoch 19 to 100 (see the summary above for the epochs actually completed)")
print(f"💾 All results saved to Google Drive: {checkpoint_dir}")

# Performance comparison
print(f"\n📊 PERFORMANCE COMPARISON:")
print(f"  🔄 Previous (Epoch 19): ~78% accuracy")
print(f"  🎯 Extended (Epoch 100): Check best_accuracy above")
print(f"  📈 Expected improvement: 5-10% accuracy gain")
print(f"  🏆 Evaluate on the test set before using the model for deployment")

## 💾 Step 10: Download Model and Results

Save your trained model to Google Drive for future use.

In [None]:
# Copy trained model to Google Drive
import os
import shutil
from datetime import datetime

# Create a timestamped folder in Google Drive
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = f'/content/drive/MyDrive/ViT-FishID_Training_{timestamp}'
os.makedirs(save_dir, exist_ok=True)

print(f"💾 Saving results to Google Drive: {save_dir}")

# Copy checkpoints from the local runtime directory.
# A separate variable name keeps the `checkpoint_dir` defined in the
# previous step from being overwritten.
local_checkpoint_dir = '/content/ViT-FishID/checkpoints'
if os.path.exists(local_checkpoint_dir):
    drive_checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    shutil.copytree(local_checkpoint_dir, drive_checkpoint_dir)
    print(f"✅ Checkpoints saved to: {drive_checkpoint_dir}")
else:
    print(f"⚠️ No local checkpoints found at: {local_checkpoint_dir}")

# Save training configuration
import json
config_file = os.path.join(save_dir, 'training_config.json')
with open(config_file, 'w') as f:
    # default=str keeps the dump from failing on values that are not JSON-serializable
    json.dump(TRAINING_CONFIG, f, indent=2, default=str)
print(f"✅ Training config saved to: {config_file}")

# Create a summary file
summary_file = os.path.join(save_dir, 'training_summary.txt')
with open(summary_file, 'w') as f:
    f.write(f"ViT-FishID Training Summary\n")
    f.write(f"========================\n\n")
    f.write(f"Training Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Mode: {TRAINING_CONFIG['mode']}\n")
    f.write(f"Epochs: {TRAINING_CONFIG['epochs']}\n")
    f.write(f"Batch Size: {TRAINING_CONFIG['batch_size']}\n")
    f.write(f"Data Directory: {DATA_DIR}\n")
    f.write(f"\nModel Architecture: {TRAINING_CONFIG['model_name']}\n")
    f.write(f"Learning Rate: {TRAINING_CONFIG['learning_rate']}\n")
    f.write(f"Consistency Weight: {TRAINING_CONFIG['consistency_weight']}\n")
    f.write(f"\nCheckpoints saved in: checkpoints/\n")
    f.write(f"Best model: checkpoints/model_best.pth\n")

print(f"✅ Training summary saved to: {summary_file}")

print(f"\n🎉 All results saved to Google Drive!")
print(f"📁 Location: {save_dir}")
print(f"\n💡 You can now:")
print(f"   1. Download the checkpoints folder for local use")
print(f"   2. Use model_best.pth for inference")
print(f"   3. Continue training from any checkpoint")

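## 🧪 Step 11: Quick Model Evaluation (Optional)

Inspect the metadata stored in the best checkpoint (epoch, accuracy, class names) as a quick sanity check. For a full evaluation on images, use the `evaluate.py` script.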

In [None]:
# Quick evaluation of the trained model (checkpoint metadata only)
import os
import torch

# Check if best model exists
best_model_path = '/content/ViT-FishID/checkpoints/model_best.pth'

if os.path.exists(best_model_path):
    print("🧪 Loading trained model for quick evaluation...")

    # Load model checkpoint info
    checkpoint = torch.load(best_model_path, map_location='cpu')

    print(f"📊 Model training info:")
    if 'epoch' in checkpoint:
        print(f"  - Best epoch: {checkpoint['epoch']}")
    # Accept either key name, matching the analysis in the previous step
    best_acc = checkpoint.get('best_accuracy', checkpoint.get('best_acc'))
    if isinstance(best_acc, (int, float)):
        print(f"  - Best accuracy: {best_acc:.2f}%")
    if 'teacher_acc' in checkpoint:
        print(f"  - Teacher accuracy: {checkpoint['teacher_acc']:.2f}%")

    # Get class names if available
    if 'class_names' in checkpoint:
        class_names = checkpoint['class_names']
        print(f"  - Number of classes: {len(class_names)}")
        print(f"  - Sample classes: {list(class_names)[:5]}...")

    print("\n✅ Checkpoint metadata loaded. Check the metrics above.")

else:
    print("❌ No trained model found. Make sure training completed successfully.")

print("\n💡 For comprehensive evaluation:")
print("   Use the evaluate.py script with your test dataset")
print("   The test set was automatically created during training")

## 🔧 Troubleshooting

### Common Issues and Solutions:

**1. GPU Memory Error (CUDA out of memory)**
- Reduce batch_size to 8 or 4
- Restart runtime and try again
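
The batch-size reduction above can be sketched as a simple fallback loop. This is only an illustration: `run_with_batch_fallback` and `fake_train` are hypothetical names, and `train_fn` stands in for whatever call starts training in your setup. The one PyTorch fact it relies on is that a CUDA out-of-memory failure surfaces as a `RuntimeError` whose message contains "out of memory".

```python
def run_with_batch_fallback(train_fn, batch_sizes=(16, 8, 4)):
    """Call train_fn(batch_size) with progressively smaller batch sizes.

    train_fn is a hypothetical callable standing in for one training attempt.
    Returns the first batch size that completes without an out-of-memory error.
    """
    for batch_size in batch_sizes:
        try:
            train_fn(batch_size)
            return batch_size
        except RuntimeError as err:
            # Only swallow out-of-memory errors; re-raise anything else.
            if "out of memory" not in str(err).lower():
                raise
            print(f"batch_size={batch_size} ran out of memory, trying a smaller one")
    raise RuntimeError(f"out of memory even at batch_size={batch_sizes[-1]}")


# Stand-in for a real training call: pretend only batches of 8 or fewer fit.
def fake_train(batch_size):
    if batch_size > 8:
        raise RuntimeError("CUDA out of memory (simulated)")


chosen = run_with_batch_fallback(fake_train)
print(f"Training ran with batch_size={chosen}")
```

In a real session, calling `torch.cuda.empty_cache()` between attempts, or restarting the runtime as suggested above, may be needed before the smaller batch size actually fits.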

**2. Data Not Found**
- Check that DATA_DIR path is correct
- Ensure data is uploaded to Google Drive
- Verify folder structure (labeled/ and unlabeled/)
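
A quick way to run the folder-structure check above is to count the images under each expected subfolder. The sketch below assumes that images live somewhere under `labeled/` and `unlabeled/` and use common extensions; the helper name `summarize_data_dir` and the extension list are illustrative assumptions, not part of the training code.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumed image formats


def summarize_data_dir(data_dir):
    """Count image files under labeled/ and unlabeled/.

    A value of None means the subfolder does not exist.
    """
    root = Path(data_dir)
    summary = {}
    for name in ("labeled", "unlabeled"):
        folder = root / name
        if not folder.is_dir():
            summary[name] = None
            continue
        summary[name] = sum(
            1
            for path in folder.rglob("*")
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
        )
    return summary


# Demo on a throwaway directory that has labeled/ but no unlabeled/ folder.
with tempfile.TemporaryDirectory() as tmp:
    species_dir = Path(tmp) / "labeled" / "species_a"
    species_dir.mkdir(parents=True)
    (species_dir / "img_001.jpg").touch()
    (species_dir / "notes.txt").touch()  # not an image, so it is not counted
    demo_summary = summarize_data_dir(tmp)

print(demo_summary)
```

In the notebook this could be called as `summarize_data_dir(DATA_DIR)`; a `None` entry or a count of zero points to a wrong path or an incomplete upload.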

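**3. Training Stops Unexpectedly**
- Colab sessions have a maximum runtime (commonly cited as about 12 hours; the exact limit depends on your plan and usage) and can also disconnect when idle
- Keep the browser tab open and active to reduce idle disconnections
- Checkpoints are saved every 10 epochs, so you can resume from the most recent one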
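
To resume after a disconnect, you first need the most recent epoch checkpoint. The sketch below picks it by epoch number, assuming the `model_epoch_XX.pth` naming used in this notebook; `latest_epoch_checkpoint` is a hypothetical helper, not a function from the repository.

```python
import re
import tempfile
from pathlib import Path

EPOCH_PATTERN = re.compile(r"epoch_(\d+)")


def latest_epoch_checkpoint(checkpoint_dir):
    """Return (epoch, path) for the highest-numbered epoch checkpoint, or None."""
    candidates = []
    for path in Path(checkpoint_dir).glob("*.pth"):
        match = EPOCH_PATTERN.search(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare epochs as integers so that epoch 100 ranks above epoch 20.
    return max(candidates, key=lambda item: item[0])


# Demo with empty placeholder files in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("model_epoch_20.pth", "model_epoch_100.pth", "model_best.pth"):
        (Path(tmp) / name).touch()
    demo_epoch, demo_path = latest_epoch_checkpoint(tmp)
    demo_name = demo_path.name

print(f"Resume from epoch {demo_epoch}: {demo_name}")
```

Sorting file names as strings would rank `model_epoch_20.pth` above `model_epoch_100.pth`, which is why the epoch is parsed as an integer first.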

**4. Low Accuracy**
- Increase epochs (try 75-100)
- Adjust consistency_weight (try 1.0-3.0)
- Lower pseudo_label_threshold (try 0.5-0.6)

**5. Consistency Loss is 0.0000**
- Lower pseudo_label_threshold to 0.5
- Check that you have unlabeled data
- Ensure semi_supervised mode is selected
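
One plausible reason for a zero consistency loss is that no unlabeled sample passes the confidence threshold. The sketch below assumes the loss is applied only to samples whose top predicted probability is at least `pseudo_label_threshold` (a confidence-masking scheme); the notebook does not show the actual implementation, so treat the helper names and the example logits as hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label_mask_rate(batch_logits, threshold):
    """Fraction of samples whose top predicted probability is >= threshold."""
    if not batch_logits:
        return 0.0
    kept = sum(1 for logits in batch_logits if max(softmax(logits)) >= threshold)
    return kept / len(batch_logits)


# Hypothetical teacher logits for three unlabeled images over three classes.
example_logits = [
    [2.0, 0.5, 0.1],
    [0.3, 0.2, 0.1],
    [1.2, 1.0, 0.1],
]

for threshold in (0.95, 0.7, 0.5):
    rate = pseudo_label_mask_rate(example_logits, threshold)
    print(f"threshold={threshold}: {rate:.0%} of samples contribute to the loss")
```

If this fraction is zero for every batch, a confidence-masked loss has nothing to average over, which is consistent with the advice above to lower the threshold.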

## 🚀 Next Steps

After training is complete, you can:

1. **Download your model**: The trained model is saved in Google Drive
2. **Continue training**: Resume from checkpoints for more epochs
3. **Evaluate performance**: Use the test set for final evaluation
4. **Deploy model**: Use the trained model for fish classification
5. **Experiment**: Try different hyperparameters or architectures

### Model Files Saved:
- `model_best.pth`: Best performing model (use this for inference)
- `model_latest.pth`: Most recent checkpoint
- `model_epoch_XX.pth`: Periodic checkpoints

### Performance Expectations:
- **50 epochs**: ~70-80% accuracy
- **100 epochs**: ~75-85% accuracy
- **Semi-supervised**: Expected to outperform supervised-only training when enough unlabeled data is available

**Happy fish classification! 🐟🎉**