# Initialization and Config


In [None]:
# Switch for the hosted-notebook path: when True the repository is cloned in the
# next cell, and the autoreload extension is not loaded further down.
IS_COLAB = True

In [None]:
# Hosted-notebook bootstrap: clone the project repository, change into it and
# initialise its git submodules.
# NOTE(review): the token is interpolated straight into the clone URL, so a real
# value will be visible in the cell source and presumably in the clone's remote
# configuration — avoid saving this notebook with a real token filled in.
if IS_COLAB:   
    import os

    github_token = ""  # personal access token; intentionally blank here
    github_username = "" # Replace with your GitHub username
    repository_url = f"https://{github_username}:{github_token}@github.com/neskech/Multimodal-2025.git"

    !git clone {repository_url}
    

    # %cd (unlike !cd) changes the working directory of the notebook itself.
    %cd Multimodal-2025
    !git submodule update --init --recursive
    


In [None]:
# One requirement specifier per entry; the two git-based packages are pinned to
# exact commits.
_requirement_lines = (
    "torch>=1.9.0",
    "torchvision>=0.10.0",
    "numpy>=1.21.0",
    "scipy>=1.7.0",
    "scikit-learn>=1.0.0",
    "matplotlib>=3.4.0",
    "seaborn>=0.11.0",
    "Pillow>=8.3.0",
    "clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1",
    "power-spherical @ git+https://github.com/nicola-decao/power_spherical.git@3d4619a9d6c01bc9b427533d386271a233e304cd",
    "dotenv",
)

# The text starts and ends with a newline, exactly like the file written before.
reqs = "\n" + "\n".join(_requirement_lines) + "\n"

# Overwrite the repository's requirements file with the list above.
with open("/root/Multimodal-2025/requirements.txt", "w") as requirements_file:
    requirements_file.write(reqs)

In [None]:
%uv pip install -r requirements.txt
%uv pip install wandb
%uv pip install webdataset
%uv pip install datasets
!cd Multimodal-2025

In [None]:
# Outside the hosted environment, load IPython's autoreload extension in mode 2
# so imported modules are reloaded before each cell runs.
if not IS_COLAB: # colab does not seem to support these
    %load_ext autoreload
    %autoreload 2
    %reload_ext autoreload


import torch
import gc
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import logging
import wandb
import os
import sys
from dotenv import load_dotenv
from typing import Literal, Union

sys.path.append("..")
sys.path.append("Multimodal-2025")
from Datasets.coco import CocoDataset
from Datasets.cc12m import CC12mDataset
from Datasets.cood import CoodDataset
from Datasets.laion import LaionDataset
from Models.variationalClip import VariationalCLIPModel
from losses.vclipLoss import VClipLoss
from power_spherical import PowerSpherical

In [None]:
# Dataset selector. Written as a plain assignment (not the `type X = ...`
# statement, which only parses on Python 3.12+) so the alias also works on
# older interpreters.
Dataset = Literal['COCO', 'COOD', 'CC12M', 'LAION']
DATASET: Dataset = 'COCO'

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# `torch.backends.mps` is looked up defensively because torch builds without an
# MPS backend do not expose it.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get WANDB Key (Use a .env file to store the key)
# WANDB_API_KEY is None when the variable is not set.
load_dotenv()
WANDB_API_KEY = os.environ.get('WANDB_API_KEY')

In [None]:
# Single source of truth for the dataset size, so the W&B run name below cannot
# drift away from CONFIG['TOTAL_DATAPOINTS'] when the size is edited.
_TOTAL_DATAPOINTS = 50_000

# Experiment configuration read by the data, optimisation, training and
# logging cells.
CONFIG = {
    'NUM_EPOCHS': 20,
    'BATCH_SIZE': 128,
    'LEARNING_RATE': 1e-3,
    'WEIGHT_DECAY': 1e-8,

    # Scheduler parameters: number of linear-warmup steps, and T_max of the
    # cosine annealing phase that follows.
    'WARMUP_EPOCHS': 2,
    'DECAY_EPOCHS': 30,

    # Number of batches whose gradients are accumulated before each optimizer
    # step (the per-batch loss is divided by this value). Set to 1 to disable.
    'GRAD_ACCUMULATION_STEPS': 1,
    # Clip the gradient norm to 1.0 before each optimizer step
    'CLIP_GRADIENTS': False,
    # Call torch.cuda.empty_cache() and gc.collect() after every batch
    'EMPTY_CACHE_AFTER_BATCH': False,

    # False to disable preloaded weights
    'LOAD_PRETRAINED_WEIGHTS': False,

    # Final KL weight and the epoch count that parameterises kl_schedule()
    'KL_WEIGHT': 100,
    'NUM_EPOCHS_TO_FULL_KL': 5,

    'DATA_DIR': '/mnt/content/Data',
    'TRAIN_RATIO': 0.9,
    'TOTAL_DATAPOINTS': _TOTAL_DATAPOINTS,

    'USE_WANDB': True,
    'WANDB_RUN_NAME': f'VCLIP_Train_On_{DATASET}_{_TOTAL_DATAPOINTS}',
    'WANDB_PREVIOUS_RUN_ID': None,  # set to None if not resuming
    'WANDB_PROJECT_NAME': 'multimodal_2025',

    'FREEZE_BACKBONE': True,
}

# Data


In [None]:
# Split sizes. The validation size is derived by subtraction rather than
# int(total * (1.0 - ratio)): with 50_000 samples and a 0.9 ratio the float
# product is 4999.999..., which int() truncates to 4999 and drops a sample.
num_train = int(CONFIG['TOTAL_DATAPOINTS'] * CONFIG['TRAIN_RATIO'])
num_val = CONFIG['TOTAL_DATAPOINTS'] - num_train

if DATASET == 'COCO':
    # COCO has separate train/val splits, so each one is loaded directly.
    train = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='train2017',
        tokenize=True,
        max_samples=num_train
    )
    val = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='val2017',
        tokenize=True,
        max_samples=num_val
    )
    collate_fn = CocoDataset.collate_function
elif DATASET in ('COOD', 'LAION', 'CC12M'):
    # These datasets are downloaded once and then cut into contiguous
    # train / validation index ranges.
    if DATASET == 'COOD':
        dataset_cls = CoodDataset
        # COOD's download() is called without a max_samples argument.
        dataset_cls.download(data_dir=CONFIG['DATA_DIR'])
    else:
        dataset_cls = LaionDataset if DATASET == 'LAION' else CC12mDataset
        dataset_cls.download(
            max_samples=CONFIG['TOTAL_DATAPOINTS'], data_dir=CONFIG['DATA_DIR'])

    all_data = dataset_cls(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train = torch.utils.data.Subset(
        all_data,
        range(0, num_train)
    )
    val = torch.utils.data.Subset(
        all_data,
        range(num_train, CONFIG['TOTAL_DATAPOINTS'])
    )
    collate_fn = dataset_cls.collate_function
else:
    # Fail here instead of with a NameError on `train` in a later cell.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")

In [None]:
# Options common to both loaders; only the dataset and the shuffling differ.
# Memory pinning is requested only when running on CUDA.
_loader_options = dict(
    batch_size=CONFIG['BATCH_SIZE'],
    num_workers=0,
    pin_memory=(DEVICE == 'cuda'),
    collate_fn=collate_fn,
)

# Training batches are reshuffled each epoch; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train, shuffle=True, **_loader_options)
val_loader = torch.utils.data.DataLoader(val, shuffle=False, **_loader_options)

In [None]:
# Report how many samples ended up in each split.
logger.info(f"Training on {len(train)} samples, validating on {len(val)} samples.")

# Model


In [None]:
# Convert the model with .float(), move it to the selected device, and freeze
# the backbone according to CONFIG['FREEZE_BACKBONE'] (the flag was previously
# ignored because True was hard-coded here).
# NOTE(review): `model` is not constructed anywhere in the visible notebook, and
# the later cells use `model` rather than `wmodel` — confirm where the model is
# created and what freeze_backbone() returns.
wmodel = model.float().to(DEVICE).freeze_backbone(CONFIG['FREEZE_BACKBONE'])

In [None]:
# Training criterion, constructed with the configured KL weight. The train and
# validation loops additionally pass a per-epoch `kl_weight_override`.
loss = VClipLoss(
    kl_weight=CONFIG['KL_WEIGHT'],
    use_mean_only=False,
    label_smoothing=0.0,
)

In [None]:
# AdamW over all model parameters with the configured LR and weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=CONFIG['WEIGHT_DECAY'],
    lr=CONFIG['LEARNING_RATE'],
)

# Phase 1 - warmup: the LR factor ramps linearly from 0.01 to 1.0 over
# WARMUP_EPOCHS scheduler steps.
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    total_iters=CONFIG['WARMUP_EPOCHS'],
    start_factor=0.01,
    end_factor=1.0,
)

# Phase 2 - decay: cosine annealing with T_max = DECAY_EPOCHS.
decay_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['DECAY_EPOCHS'],
)

# Chain both phases; control passes from warmup to decay at the milestone.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    milestones=[CONFIG['WARMUP_EPOCHS']],
    schedulers=[warmup_scheduler, decay_scheduler],
)

In [None]:
def kl_schedule(epoch: int) -> float:
    """
    Return the KL weight to use for a zero-based `epoch`.

    With N = CONFIG['NUM_EPOCHS_TO_FULL_KL'] and e = epoch + 1:
      * e < N        -> 0
      * N <= e < 2N  -> CONFIG['KL_WEIGHT'] * e / (2N)
      * e >= 2N      -> CONFIG['KL_WEIGHT']
    """
    one_based_epoch = epoch + 1  # callers pass 0-indexed epochs
    hold_until = CONFIG['NUM_EPOCHS_TO_FULL_KL']
    full_from = 2 * hold_until

    # KL term is switched off for the first N - 1 epochs.
    if one_based_epoch < hold_until:
        return 0
    # Ramp region: proportional to the elapsed fraction of the 2N window.
    if one_based_epoch < full_from:
        return CONFIG['KL_WEIGHT'] * (one_based_epoch / full_from)
    return CONFIG['KL_WEIGHT']

# Training


In [None]:
def riemannian_gradient_hook(grad, mean):
    """
    Strip from `grad` its component along `mean`, taken over the last dimension.

    When `mean` holds unit vectors this is the projection of a Euclidean
    gradient onto the tangent space of the unit sphere at `mean`:
    grad_tangent = grad - <grad, mean> * mean

    Args:
        grad: Euclidean gradient, expected as [batch_size, dim]
        mean: Points on the unit sphere, expected as [batch_size, dim]

    Returns:
        Tensor shaped like `grad` with the component along `mean` removed
    """
    # Per-row inner product <grad, mean>, kept as a trailing singleton axis so
    # it broadcasts against `mean` below.
    radial_coefficient = torch.sum(grad * mean, dim=-1, keepdim=True)
    return grad - radial_coefficient * mean


def train_epoch(
        model: VariationalCLIPModel,
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        epoch: int,
):
    """
    Train `model` for one pass over `dataloader`.

    A batch is skipped, and any accumulated gradients are cleared, when NaNs
    are found in the inputs, the encoder outputs, the loss, or the gradients.

    Args:
        model: Model providing `encode_image_tensors`, `encode_text_tokens`
            and `get_logits_scale`.
        dataloader: Yields `(images, text_tokens, _)` batches.
        optimizer: Stepped once every CONFIG['GRAD_ACCUMULATION_STEPS'] batches.
        criterion: Callable returning a dict with the keys 'total_loss',
            'clip_loss', 'image_kl_loss' and 'text_kl_loss'.
        epoch: Zero-based epoch index; selects the KL weight via `kl_schedule`.

    Returns:
        Tuple `(total_loss, clip_loss, image_kl_loss, text_kl_loss, kl_loss)`
        averaged over the batches whose loss was accumulated. Every entry is
        NaN when no batch contributed.
    """
    model.train()

    progress_bar = tqdm.tqdm(dataloader, desc="Training Epoch")
    total_loss = 0.0
    total_clip_loss = 0.0
    total_image_kl_loss = 0.0
    total_text_kl_loss = 0.0
    total_kl_loss = 0.0
    nan_count = 0
    # Batches that actually contributed to the running sums. Used as the
    # denominator so skipped batches do not drag the averages towards zero.
    counted_batches = 0
    # The KL weight depends only on the epoch, so compute it once.
    kl_weight = kl_schedule(epoch)

    for batch_idx, (images, text_tokens, _) in enumerate(progress_bar):
        images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
        images = images.float()

        # Check for NaN in input (LAION yields NaNs if it can't load images)
        if torch.isnan(images).any() or torch.isnan(text_tokens).any():
            logger.warning(f"NaN in input batch {batch_idx}")
            optimizer.zero_grad()
            continue

        image_means, image_concentrations = model.encode_image_tensors(images)
        text_means, text_concentrations = model.encode_text_tokens(text_tokens)

        # Register hooks to project gradients onto tangent space
        # This ensures gradients respect the spherical constraint.
        # Only active once `epoch` exceeds NUM_EPOCHS_TO_FULL_KL.
        if epoch > CONFIG['NUM_EPOCHS_TO_FULL_KL']:
            if image_means.requires_grad:
                image_means.register_hook(lambda grad: riemannian_gradient_hook(grad, image_means))
            if text_means.requires_grad:
                text_means.register_hook(lambda grad: riemannian_gradient_hook(grad, text_means))

        # Check for NaN in features
        nan_image = torch.isnan(image_means).any(
        ) or torch.isnan(image_concentrations).any()
        nan_text = torch.isnan(text_means).any(
        ) or torch.isnan(text_concentrations).any()
        if nan_image or nan_text:
            logger.warning(f"NaN in features at batch {batch_idx}")
            nan_count += 1
            optimizer.zero_grad()
            continue

        # Debug log for scale parameters before constructing distributions
        for name, p in model.named_parameters():
            if "log_concentration_scale" in name and not torch.isfinite(p).all():
                print(f"Non-finite scale param {name}: {p.data}")

        image_distribution = PowerSpherical(image_means, image_concentrations)
        text_distribution = PowerSpherical(text_means, text_concentrations)
        loss_dict = criterion(
            image_distribution,
            text_distribution,
            model.get_logits_scale(),
            kl_weight_override=kl_weight,
        )
        loss = loss_dict['total_loss']

        # Check for NaN in loss
        if torch.isnan(loss):
            logger.warning(f"NaN loss detected at batch {batch_idx}")
            nan_count += 1
            optimizer.zero_grad()
            continue

        # If not NaN, then add to total loss
        total_loss += loss.item()
        total_clip_loss += loss_dict['clip_loss'].item()
        total_image_kl_loss += loss_dict['image_kl_loss'].item()
        total_text_kl_loss += loss_dict['text_kl_loss'].item()
        total_kl_loss += loss_dict['image_kl_loss'].item() + loss_dict['text_kl_loss'].item()
        counted_batches += 1

        # Scale loss for gradient accumulation
        scaled_loss = loss / CONFIG['GRAD_ACCUMULATION_STEPS']
        # Backward pass
        scaled_loss.backward()

        has_nan_grads = False
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                has_nan_grads = True
                logger.warning(f"NaN gradient in {name}")
                break

        if has_nan_grads:
            logger.warning(
                f"NaN gradients detected at batch {batch_idx}, skipping update")
            optimizer.zero_grad()
            continue

        # Gradient accumulation and optimization step
        if (batch_idx + 1) % CONFIG['GRAD_ACCUMULATION_STEPS'] == 0:
            # Clip gradients to prevent explosion
            if CONFIG['CLIP_GRADIENTS']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            # counted_batches >= 1 here: this point is only reached after the
            # current batch was added to the running sums.
            progress_bar.set_postfix({
                'loss': total_loss / counted_batches,
                'clip_loss': total_clip_loss / counted_batches,
                'total_kl_loss': total_kl_loss / counted_batches,
                'nan_count': nan_count
            })

        if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
            torch.cuda.empty_cache()
            gc.collect()

    # Handle remaining gradients if not accumulated evenly
    if len(dataloader) % CONFIG['GRAD_ACCUMULATION_STEPS'] != 0:
        if CONFIG['CLIP_GRADIENTS']:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad()

    if counted_batches == 0:
        # Empty loader or every batch skipped: report NaN rather than a
        # misleading 0.0 (or a ZeroDivisionError on an empty loader).
        logger.warning("No valid training batches this epoch; returning NaN losses")
        nan = float('nan')
        return nan, nan, nan, nan, nan

    total_loss = total_loss / counted_batches
    total_clip_loss = total_clip_loss / counted_batches
    total_image_kl_loss = total_image_kl_loss / counted_batches
    total_text_kl_loss = total_text_kl_loss / counted_batches
    total_kl_loss = total_kl_loss / counted_batches
    return total_loss, total_clip_loss, total_image_kl_loss, total_text_kl_loss, total_kl_loss

# Validation


In [None]:
def validate(
    model: VariationalCLIPModel,
    dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    epoch: int
):
    """
    Evaluate `model` on `dataloader` without computing gradients.

    Batches containing NaNs in the inputs, the encoder outputs or the loss are
    skipped. The function touches no optimizer: nothing is back-propagated
    under `torch.no_grad()`, so there are no gradients to clear.

    Args:
        model: Model providing `encode_image_tensors`, `encode_text_tokens`
            and `get_logits_scale`.
        dataloader: Yields `(images, text_tokens, _)` batches.
        criterion: Callable returning a dict with the keys 'total_loss',
            'clip_loss', 'image_kl_loss' and 'text_kl_loss'.
        epoch: Zero-based epoch index; selects the KL weight via `kl_schedule`.

    Returns:
        Tuple `(total_loss, clip_loss, image_kl_loss, text_kl_loss, kl_loss)`
        averaged over the batches that were evaluated. Every entry is NaN
        when no batch could be evaluated.
    """
    model.eval()

    progress_bar = tqdm.tqdm(dataloader, desc="Evaluating")
    total_loss = 0.0
    total_clip_loss = 0.0
    total_image_kl_loss = 0.0
    total_text_kl_loss = 0.0
    total_kl_loss = 0.0
    # Batches that actually contributed to the running sums. Used as the
    # denominator so skipped batches do not drag the averages towards zero.
    counted_batches = 0
    # The KL weight depends only on the epoch, so compute it once.
    kl_weight = kl_schedule(epoch)

    with torch.no_grad():
        for batch_idx, (images, text_tokens, _) in enumerate(progress_bar):
            images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
            images = images.float()

            # Check for NaN in input (LAION yields NaNs if it can't load images)
            if torch.isnan(images).any() or torch.isnan(text_tokens).any():
                logger.warning(f"NaN in input batch {batch_idx}")
                continue

            image_means, image_concentrations = model.encode_image_tensors(
                images)
            text_means, text_concentrations = model.encode_text_tokens(
                text_tokens)

            # Check for NaN in features
            nan_image = torch.isnan(image_means).any() or torch.isnan(
                image_concentrations).any()
            nan_text = torch.isnan(text_means).any() or torch.isnan(
                text_concentrations).any()
            if nan_image or nan_text:
                logger.warning(f"NaN in features at batch {batch_idx}")
                continue

            # Debug log for scale parameters before constructing distributions
            for name, p in model.named_parameters():
                if "log_concentration_scale" in name and not torch.isfinite(p).all():
                    print(f"Non-finite scale param {name}: {p.data}")

            image_distribution = PowerSpherical(image_means,
                                                image_concentrations)
            text_distribution = PowerSpherical(text_means, text_concentrations)
            loss_dict = criterion(
                image_distribution,
                text_distribution,
                model.get_logits_scale(),
                kl_weight_override=kl_weight,
            )
            loss = loss_dict['total_loss']

            # Check for NaN loss
            if torch.isnan(loss):
                logger.warning("NaN in validation loss, skipping batch")
                continue

            total_loss += loss.item()
            total_clip_loss += loss_dict['clip_loss'].item()
            total_image_kl_loss += loss_dict['image_kl_loss'].item()
            total_text_kl_loss += loss_dict['text_kl_loss'].item()
            total_kl_loss += loss_dict['image_kl_loss'].item() + loss_dict['text_kl_loss'].item()
            counted_batches += 1
            progress_bar.set_postfix({
                'loss': total_loss / counted_batches,
                'clip_loss':  total_clip_loss / counted_batches,
                'total_kl_loss': total_kl_loss / counted_batches,
            })

            if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
                torch.cuda.empty_cache()
                gc.collect()

    if counted_batches == 0:
        # Empty loader or every batch skipped: report NaN rather than a
        # misleading 0.0 that would look like a perfect validation loss.
        logger.warning("No valid validation batches; returning NaN losses")
        nan = float('nan')
        return nan, nan, nan, nan, nan

    total_loss = total_loss / counted_batches
    total_clip_loss = total_clip_loss / counted_batches
    total_image_kl_loss = total_image_kl_loss / counted_batches
    total_text_kl_loss = total_text_kl_loss / counted_batches
    total_kl_loss = total_kl_loss / counted_batches
    return total_loss, total_clip_loss, total_image_kl_loss, total_text_kl_loss, total_kl_loss

# Full Train Eval Pipeline


In [None]:
def save_checkpoint(
    filename: str,
    model: VariationalCLIPModel,
    optimizer: torch.optim.Optimizer,
    train_losses: list[float],
    val_losses: list[float],
):
    """
    Write a training checkpoint to `filename` with `torch.save`.

    The stored dict holds the model and optimizer state dicts, both loss
    histories, and the CONFIG dict.
    """
    payload = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses=train_losses,
        val_losses=val_losses,
        config=CONFIG,
    )
    torch.save(payload, filename)
    logger.info(f"Checkpoint saved: {filename}")

In [None]:
def train(
    num_epochs: int,
    model: VariationalCLIPModel,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    criterion: torch.nn.Module
):
    """
    Run the train/validate loop for `num_epochs` epochs.

    Each epoch trains, validates, steps the LR scheduler, checkpoints the model
    when the validation loss improves, and optionally logs metrics to W&B.
    The loop stops early if either epoch loss is NaN.

    Returns:
        `(train_losses, val_losses)`: per-epoch total losses, including the
        epoch on which a NaN stop occurred.
    """
    logger.info(f"Starting training on {DEVICE} for {num_epochs} epochs...")

    best_val_loss = float('inf')
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        kl_weight = kl_schedule(epoch)
        logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")
        logger.info(f"KL Weight: {kl_weight:.6f}")

        # --- training pass ---
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, epoch)
        train_loss, train_clip, _, _, train_kl = train_metrics
        train_losses.append(train_loss)
        logger.info(f"Train Loss: {train_loss:.6f}")

        # --- validation pass ---
        val_metrics = validate(model, val_loader, criterion, epoch)
        val_loss, val_clip, _, _, val_kl = val_metrics
        val_losses.append(val_loss)
        logger.info(f"Val Loss: {val_loss:.6f}")

        # Abort as soon as either epoch loss is NaN.
        if torch.isnan(torch.tensor([train_loss, val_loss])).any():
            logger.error("NaN loss detected! Stopping training.")
            break

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        logger.info(
            f"Learning Rate adjusted to: {current_lr:.6f}")

        # Checkpoint whenever validation improves on the best seen so far.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            save_checkpoint(
                f"best_VCLIP_model_on_{DATASET}.pt",
                model,
                optimizer,
                train_losses,
                val_losses
            )
            logger.info(f"✓ Saved best model (val_loss: {val_loss:.6f})")

        if not CONFIG['USE_WANDB']:
            continue

        # The learning rate logged here is the value after scheduler.step().
        wandb.log({
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_clip_loss': train_clip,
            'val_clip_loss': val_clip,
            'train_kl_loss': train_kl,
            'val_kl_loss': val_kl,
            'kl_weight': kl_weight,
            'learning_rate': current_lr,
        })

    return train_losses, val_losses

# WandB


In [None]:
# Weights & Biases setup: log in, then either start a fresh run or resume the
# run whose id is given in CONFIG['WANDB_PREVIOUS_RUN_ID'].
PROJECT_NAME = CONFIG['WANDB_PROJECT_NAME']
USE_WANDB = CONFIG['USE_WANDB']
RESUME_LOGGING = CONFIG['WANDB_PREVIOUS_RUN_ID'] is not None
run_name = CONFIG['WANDB_RUN_NAME']

if USE_WANDB:
    wandb.login(key=WANDB_API_KEY)

    if not RESUME_LOGGING:
        # Fresh run: named from the config, which is uploaded alongside it.
        run = wandb.init(
            project=PROJECT_NAME,
            entity="multimodal_2025",
            name=run_name,
            config=CONFIG,
            reinit=True,
        )
    else:
        # Resume: resume="must" is passed together with the previous run id.
        run_id = CONFIG['WANDB_PREVIOUS_RUN_ID']
        run = wandb.init(
            project=PROJECT_NAME,
            entity="multimodal_2025",
            id=run_id,
            resume="must",
            settings=wandb.Settings(symlink=False),
        )

# Run Training


In [None]:
# Ask PyTorch to fall back to the CPU for operators the MPS backend lacks.
# The former `!export PYTORCH_ENABLE_MPS_FALLBACK=1` line was removed: `!`
# commands run in a separate subshell, so it could not affect this Python
# process. Setting os.environ is the part that takes effect.
# NOTE(review): this variable may need to be set before torch is first
# imported to have any effect — confirm.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Run the full train/validate loop with the objects built in earlier cells.
# NOTE(review): `model` is passed here, while the visible model cell only
# assigns `wmodel` — confirm which object is intended.
train_losses, val_losses = train(
    CONFIG["NUM_EPOCHS"],
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    loss  # type: ignore
)

In [None]:
# Tag the output file with the current POSIX timestamp.
from datetime import datetime
timestamp = datetime.now().timestamp()

# Store the weights the model holds at the end of training under
# /mnt/content/Models, then log the path that was written.
os.makedirs('/mnt/content/Models', exist_ok=True)
weights_path = f"/mnt/content/Models/best_vclip_model_at_{timestamp}.pth"
torch.save(model.state_dict(), weights_path)
logger.info(weights_path)