# Model Training Notebook

This notebook provides a simple interface to train different models on the BBBC021 dataset.

## Available Models:
1. **Vanilla SimCLR** - Standard contrastive learning with data augmentations (optionally uses weak compound labels so that images of the same compound as the positive pair are not used as negatives)
2. **Weak Supervision SimCLR** - Uses compound labels to create positive pairs
3. **WS-DINO** - Teacher-student distillation approach

## Quick Start:
1. Set your training parameters in the configuration section (see our training module for details on which parameters apply to each training approach)
2. Choose your model type
3. Run the training cell

In [None]:
import os
import sys
import torch
import gc
from pathlib import Path

# Add the directory two levels above the current working directory (assumed to be
# the repository root) to sys.path so the `training` package can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Import our training functions
from training.simclr_vanilla_train import train_simclr_vanilla
from training.simclr_ws_train import train_simclr
from training.wsdino_resnet_train import train_wsdino

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"CUDA device 0: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("CPU only")

# Free unreferenced Python objects first, then release cached GPU memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
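The `sys.path` line above only works if the notebook is started exactly two directory levels below the repository root. A slightly more robust sketch, assuming the root is the nearest ancestor directory that contains the `training/` package (the helper `find_repo_root` is hypothetical, not part of our modules):

```python
from pathlib import Path


def find_repo_root(start, marker="training"):
    """Walk up from `start` and return the first directory containing `marker/`.

    Hypothetical helper: assumes the repository root is the nearest ancestor
    that contains the `training/` package imported above.
    """
    start = Path(start).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No ancestor of {start} contains '{marker}/'")


# Possible usage instead of the nested os.path.dirname calls:
# repo_root = find_repo_root(Path.cwd())
# sys.path.append(str(repo_root))
```

This keeps working if the notebook is moved to a different depth inside the repository, as long as the `training/` folder stays at the root.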

## Configuration

Set your training parameters here. You can modify these values based on your computational resources and requirements.

In [None]:
# TRAINING CONFIGURATION

# Data path - Update this to point to your BBBC021 dataset
DATA_ROOT = "/scratch/cv-course2025/group8"

# Model selection - Choose one of: 'vanilla_simclr', 'ws_simclr', 'wsdino'
MODEL_TYPE = "vanilla_simclr"

# Training parameters
EPOCHS = 50  # Number of training epochs (reduce for testing)
BATCH_SIZE = 128  # Batch size (reduce if you get out of memory errors)
LEARNING_RATE = 0.0003  # Learning rate
TEMPERATURE = 0.1  # Temperature for contrastive loss
PROJECTION_DIM = 128  # Projection head output dimension

# Saving options
SAVE_EVERY = 10  # Save model every N epochs
SAVE_DIR = "/scratch/cv-course2025/group8/model_weights"  # Directory to save models

# Advanced options (usually don't need to change)
COMPOUND_AWARE = True  # For vanilla SimCLR: use compound-aware loss
MOMENTUM = 0.996  # For WS-DINO: teacher momentum

print("Training Configuration:")
print(f"  Model Type: {MODEL_TYPE}")
print(f"  Data Root: {DATA_ROOT}")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Save Directory: {SAVE_DIR}")

# Create save directory if it doesn't exist
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"  Save directory ready: {os.path.exists(SAVE_DIR)}")
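Training runs on the cluster can take hours, so it can pay off to catch obvious configuration mistakes before starting. The sketch below is a hypothetical helper (`validate_config` is not part of our training module); the ranges it checks are generic assumptions rather than limits enforced by the trainers.

```python
VALID_MODEL_TYPES = {"vanilla_simclr", "ws_simclr", "wsdino"}


def validate_config(model_type, epochs, batch_size, learning_rate,
                    temperature, save_every, momentum):
    """Return a list of problems found in the configuration (empty if none).

    Hypothetical sanity check; the bounds are generic assumptions.
    """
    problems = []
    if model_type not in VALID_MODEL_TYPES:
        problems.append(
            f"MODEL_TYPE must be one of {sorted(VALID_MODEL_TYPES)}, got {model_type!r}"
        )
    if epochs < 1:
        problems.append("EPOCHS must be at least 1")
    if batch_size < 2:
        # Contrastive losses need other samples in the batch to act as negatives
        problems.append("BATCH_SIZE must be at least 2")
    if learning_rate <= 0:
        problems.append("LEARNING_RATE must be positive")
    if temperature <= 0:
        problems.append("TEMPERATURE must be positive")
    if not 1 <= save_every <= epochs:
        problems.append("SAVE_EVERY should be between 1 and EPOCHS")
    if not 0.0 <= momentum < 1.0:
        problems.append("MOMENTUM must be in [0, 1)")
    return problems


# Possible usage with the configuration cell above:
# for issue in validate_config(MODEL_TYPE, EPOCHS, BATCH_SIZE, LEARNING_RATE,
#                              TEMPERATURE, SAVE_EVERY, MOMENTUM):
#     print("Config problem:", issue)
```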

## Model Information

Here's a brief overview of each model type:

### 1. Vanilla SimCLR
- **Method**: Standard contrastive learning with data augmentations
- **Positive pairs**: Two augmented versions of the same image
- **Weak labels (optional)**: Set `compound_aware=True` (the `COMPOUND_AWARE` option above) to use compound labels so that images of the same compound are not used as negative pairs
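To make the compound-aware option concrete, here is a minimal pure-Python sketch of an NT-Xent (SimCLR) loss in which same-compound views are dropped from the denominator. It assumes `z[i]` and `z[i + N]` are the two augmented views of image `i` and that `compounds[i]` is that image's compound label; the function name is hypothetical, and the actual masking in `train_simclr_vanilla` may differ in detail.

```python
import math


def compound_aware_nt_xent(z, compounds, temperature=0.1):
    """NT-Xent loss over 2N embeddings with same-compound negatives masked out.

    Sketch, assuming z[i] and z[i + N] are the two views of image i and
    compounds[i] is the compound label of image i (length N).
    """
    n2 = len(z)
    if n2 % 2 or len(compounds) != n2 // 2:
        raise ValueError("expected 2N embeddings and N compound labels")
    n = n2 // 2
    labels = list(compounds) + list(compounds)  # one label per view

    # L2-normalise so the dot product equals cosine similarity
    normed = []
    for v in z:
        norm = math.sqrt(sum(x * x for x in v))
        normed.append([x / norm for x in v])

    def sim(a, b):
        return sum(x * y for x, y in zip(normed[a], normed[b])) / temperature

    total = 0.0
    for i in range(n2):
        pos = (i + n) % n2  # the other view of the same image
        # Denominator: the positive plus every view of a *different* compound
        terms = [sim(i, pos)]
        terms += [sim(i, k) for k in range(n2)
                  if k != i and k != pos and labels[k] != labels[i]]
        # log-sum-exp with max subtraction for numerical stability
        m = max(terms)
        log_denom = m + math.log(sum(math.exp(t - m) for t in terms))
        total += log_denom - sim(i, pos)
    return total / n2
```

If every image in the batch has a distinct compound, nothing extra is masked and this reduces to the standard SimCLR NT-Xent loss.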

### 2. Weak Supervision SimCLR (WS-SimCLR)
- **Method**: Uses compound labels to create positive pairs
- **Positive pairs**: Two different images from the same compound
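A minimal sketch of how such positive pairs could be drawn from compound labels: group image indices by compound, then pick two distinct images of one compound. `sample_compound_pairs` is a hypothetical helper for illustration; the sampling in `train_simclr` may be organised differently (for example per batch).

```python
import random
from collections import defaultdict


def sample_compound_pairs(compounds, num_pairs, seed=None):
    """Sample (i, j) index pairs of two *different* images sharing a compound.

    Hypothetical sketch: compounds[i] is the compound label of image i.
    Compounds with only one image cannot form a pair and are skipped.
    """
    rng = random.Random(seed)
    by_compound = defaultdict(list)
    for idx, compound in enumerate(compounds):
        by_compound[compound].append(idx)
    eligible = [idxs for idxs in by_compound.values() if len(idxs) >= 2]
    if not eligible:
        raise ValueError("No compound has at least two images")
    pairs = []
    for _ in range(num_pairs):
        idxs = rng.choice(eligible)
        i, j = rng.sample(idxs, 2)  # two distinct images of the same compound
        pairs.append((i, j))
    return pairs
```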

### 3. WS-DINO
- **Method**: Teacher-student distillation with weak supervision
- **Positive pairs**: Uses compound labels for supervision
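The `MOMENTUM` setting controls how the teacher follows the student. In standard DINO the teacher is an exponential moving average of the student, `teacher = m * teacher + (1 - m) * student`; we assume `train_wsdino` uses this rule, though whether it also applies a momentum schedule is not stated here. A sketch on plain lists of floats (`ema_update` is a hypothetical name):

```python
def ema_update(teacher, student, momentum=0.996):
    """Return teacher parameters updated as an exponential moving average.

    teacher <- momentum * teacher + (1 - momentum) * student
    Sketch on lists of floats; with PyTorch modules the same rule is
    typically applied in place to every parameter tensor, e.g.:

        with torch.no_grad():
            for p_t, p_s in zip(teacher.parameters(), student.parameters()):
                p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)
    """
    if len(teacher) != len(student):
        raise ValueError("teacher and student must have the same number of parameters")
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]
```

With `momentum=0.996`, each step moves the teacher only 0.4% of the way toward the student, which is what keeps the teacher targets stable.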

## Training

Run the cell below to start training with your configured parameters.

In [None]:
# =============================================================================
# TRAINING EXECUTION
# =============================================================================

def train_model(model_type, **kwargs):
    """
    Train a model based on the specified type and parameters.
    """
    # Free unreferenced Python objects first, then release cached GPU memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print(f"Starting training for {model_type}")
    print("=" * 50)
    
    try:
        if model_type == "vanilla_simclr":
            print("Training Vanilla SimCLR...")
            model = train_simclr_vanilla(
                root_path=kwargs['root_path'],
                epochs=kwargs['epochs'],
                batch_size=kwargs['batch_size'],
                learning_rate=kwargs['learning_rate'],
                temperature=kwargs['temperature'],
                projection_dim=kwargs['projection_dim'],
                save_every=kwargs['save_every'],
                save_dir=kwargs['save_dir'],
                compound_aware=kwargs.get('compound_aware', True)
            )
            
        elif model_type == "ws_simclr":
            print("Training Weak Supervision SimCLR...")
            model = train_simclr(
                root_path=kwargs['root_path'],
                epochs=kwargs['epochs'],
                batch_size=kwargs['batch_size'],
                learning_rate=kwargs['learning_rate'],
                temperature=kwargs['temperature'],
                projection_dim=kwargs['projection_dim'],
                save_every=kwargs['save_every']
            )
            
        elif model_type == "wsdino":
            print("Training WS-DINO...")
            model = train_wsdino(
                root_path=kwargs['root_path'],
                epochs=kwargs['epochs'],
                batch_size=kwargs['batch_size'],
                lr=kwargs['learning_rate'],
                momentum=kwargs.get('momentum', 0.996),
                temperature=kwargs['temperature'],
                save_every=kwargs['save_every']
            )
            
        else:
            raise ValueError(f"Unknown model type: {model_type}")
            
        print("=" * 50)
        print("Training completed successfully!")
        # Note: save_dir is only forwarded to train_simclr_vanilla above
        print(f"Models saved in: {kwargs['save_dir']}")
        
        return model
        
    except Exception as e:
        print(f"Training failed with error: {e}")
        print("Please check your configuration and try again.")
        raise  # re-raise with the original traceback

# Prepare training parameters
training_params = {
    'root_path': DATA_ROOT,
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'temperature': TEMPERATURE,
    'projection_dim': PROJECTION_DIM,
    'save_every': SAVE_EVERY,
    'save_dir': SAVE_DIR,
    'compound_aware': COMPOUND_AWARE,
    'momentum': MOMENTUM
}

print("Training parameters:")
for key, value in training_params.items():
    print(f"  {key}: {value}")

# Start training
print(f"\nStarting training with model type: {MODEL_TYPE}")
trained_model = train_model(MODEL_TYPE, **training_params)
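The `if`/`elif` branches in `train_model` pick by hand which settings each trainer receives. One possible alternative, sketched below under the assumption that the trainers are ordinary Python functions with named parameters, is to filter the shared `training_params` through `inspect.signature`. `call_with_supported_kwargs` is a hypothetical helper, and renamed parameters (such as `lr` for `train_wsdino`) still need an explicit mapping.

```python
import inspect


def call_with_supported_kwargs(fn, kwargs, rename=None):
    """Call `fn` with only the keyword arguments its signature accepts.

    `rename` maps notebook keys to the function's parameter names
    (e.g. {"learning_rate": "lr"}). Keys the function does not accept are
    returned as `dropped` so they can be reported instead of silently lost.
    """
    rename = rename or {}
    params = inspect.signature(fn).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    passed, dropped = {}, []
    for key, value in kwargs.items():
        name = rename.get(key, key)
        if accepts_var_kw or name in params:
            passed[name] = value
        else:
            dropped.append(key)
    return fn(**passed), dropped


# Possible usage (mapping assumed from the branches in train_model):
# TRAINERS = {
#     "vanilla_simclr": (train_simclr_vanilla, {}),
#     "ws_simclr": (train_simclr, {}),
#     "wsdino": (train_wsdino, {"learning_rate": "lr"}),
# }
# fn, rename = TRAINERS[MODEL_TYPE]
# trained_model, ignored = call_with_supported_kwargs(fn, training_params, rename)
# print("Settings not used by this trainer:", ignored)
```

Note the behavioural difference: if a trainer accepts a parameter that `train_model` currently omits (for example `save_dir` for `train_simclr`), this version would start passing it.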

## Save your model

Depending on your training approach, your model will be saved under `/scratch/cv-course2025/group8/model_weights/<training_approach>`. You can then use the extractor and evaluator to see how your model performed. If you think you have trained a model worth keeping, we recommend giving it a unique, descriptive name by renaming the folders that contain your model and its extracted features.
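A small sketch for finding checkpoints and renaming a run folder without overwriting an existing one. The `*.pth` extension and both helper names are assumptions; adjust `pattern` to whatever files the training module actually writes.

```python
from pathlib import Path


def list_checkpoints(save_dir, pattern="*.pth"):
    """Return checkpoint files under save_dir (searched recursively), sorted by path.

    The "*.pth" pattern is an assumption about the checkpoint file names.
    """
    return sorted(p for p in Path(save_dir).rglob(pattern) if p.is_file())


def rename_run(run_dir, new_name):
    """Rename a run folder in place, refusing to overwrite an existing folder."""
    run_dir = Path(run_dir)
    target = run_dir.with_name(new_name)
    if target.exists():
        raise FileExistsError(f"{target} already exists")
    return run_dir.rename(target)  # returns the new path (Python 3.8+)


# Possible usage (the new folder name is only an example):
# for ckpt in list_checkpoints(SAVE_DIR):
#     print(ckpt)
# rename_run(Path(SAVE_DIR) / "<training_approach>", "wsdino_e50_bs128_lr3e-4")
```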