# Initialization and Config


In [1]:
# Set to True when running on Google Colab: the setup cell guarded by this flag
# clones the project repository and installs its requirements.
IS_COLAB = False

In [2]:
# Colab-only setup: clone the project repository, change into it and install
# its dependencies. Put all Colab initialization code inside this block, or
# make a copy of this notebook.
if IS_COLAB:   
    # NOTE(review): `userdata` and `os` are imported here but not used in this cell.
    from google.colab import userdata
    import os

    # Both values are interpolated into the clone URL below. Fill them in
    # locally and never commit real credentials with the notebook.
    github_token = ""
    github_username = "" # Replace with your GitHub username
    repository_url = f"https://{github_username}:{github_token}@github.com/neskech/Multimodal-2025.git"

    # `{repository_url}` is expanded by IPython before the shell command runs.
    !git clone {repository_url}

    # Move into the clone so that requirements.txt resolves relative to the repo root.
    %cd Multimodal-2025
    !pip install -r requirements.txt

In [3]:
%load_ext autoreload
%autoreload 2

import torch
import gc
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import logging
import wandb
import os
import sys
import peft
from dotenv import load_dotenv
from typing import Literal

sys.path.append("..")
from Datasets.coco import CocoDataset
from Datasets.cc12m import CC12mDataset
from Datasets.cood import CoodDataset
from Datasets.laion import LaionDataset
from Models.clipModel import CLIPModel
from Models.cloobModel import CLOOBModel
from Models.vClipModel import VariationalCLIPModel
from Models.alignClipModel import AlignCLIPModel
from losses.clipLoss import ClipLoss
from losses.cloobLoss import CLOOBLoss

  from .autonotebook import tqdm as notebook_tqdm
2025-10-22 21:53:26,270 - INFO - JAX version 0.8.0 available.


In [4]:
type Model = Literal['CLIP', 'CLOOB', 'ALIGN']
type ModelClass = CLIPModel | CLOOBModel | AlignCLIPModel
MODEL: Model = 'CLIP'

type Dataset = Literal['COCO', 'COOD', 'CC12M', 'LAION']
DATASET: Dataset = 'LAION'

DEVICE = None
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get WANDB Key (Use a .env file to store the key)
load_dotenv()
WANDB_API_KEY = os.environ.get('WANDB_API_KEY')

In [5]:
# Hyperparameters and run settings read by the data, model, training and
# WandB cells of this notebook.
CONFIG = {
    'NUM_EPOCHS': 20,
    'BATCH_SIZE': 32,
    'LEARNING_RATE': 1e-5,
    'WEIGHT_DECAY': 1e-3,

    # StepLR scheduler parameters: the scheduler is stepped once per epoch and
    # multiplies the learning rate by STEP_LR_GAMMA every STEP_LR_STEP_SIZE steps.
    'STEP_LR_STEP_SIZE': 5,
    'STEP_LR_GAMMA': 0.5,

    # Number of batches whose gradients are accumulated before each optimizer
    # step (the loss is divided by this value). Set to 1 to disable accumulation.
    'GRAD_ACCUMULATION_STEPS': 1,
    # Clip the gradient norm to 1.0 before each optimizer step to avoid explosion
    'CLIP_GRADIENTS': False,
    # Call torch.cuda.empty_cache() and gc.collect() after every batch
    'EMPTY_CACHE_AFTER_BATCH': False,

    'DATA_DIR': '../Data',
    # Fraction of TOTAL_DATAPOINTS used for training; the remainder is validation
    'TRAIN_RATIO': 0.8,
    'TOTAL_DATAPOINTS': 10_000,

    # Wrap the model in a PEFT LoRA adapter before building the optimizer
    'USE_LORA': False,

    # Log per-epoch metrics to Weights & Biases
    'USE_WANDB': True,
    'WANDB_RUN_NAME': 'CLIP_LAION',
    'WANDB_PREVIOUS_RUN_ID': None, # set to None if not resuming
    'WANDB_PROJECT_NAME': 'multimodal_2025',
}

# Data


In [6]:
# Split sizes. The validation count is derived by subtraction: computing it as
# int(TOTAL * (1.0 - TRAIN_RATIO)) truncates floating-point error, e.g.
# 10_000 * (1.0 - 0.8) evaluates to 1999.9999999999995 and would yield 1999.
num_train = int(CONFIG['TOTAL_DATAPOINTS'] * CONFIG['TRAIN_RATIO'])
num_val = CONFIG['TOTAL_DATAPOINTS'] - num_train


def _split_train_val(dataset, train_count, total_count):
    """Split `dataset` by index into (train, val) Subsets.

    Training uses indices [0, train_count); validation uses
    [train_count, total_count).
    """
    train_subset = torch.utils.data.Subset(dataset, range(0, train_count))
    val_subset = torch.utils.data.Subset(dataset, range(train_count, total_count))
    return train_subset, val_subset


if DATASET == 'COCO':
    # COCO ships separate train/val splits, so each is loaded on its own.
    CocoDataset.download(download_script_path='../Datasets/download_coco.sh', data_dir=CONFIG['DATA_DIR'])
    train = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='train2017',
        tokenize=True,
        max_samples=num_train
    )
    val = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='val2017',
        tokenize=True,
        max_samples=num_val
    )
    collate_fn = CocoDataset.collate_function
elif DATASET == 'COOD':
    CoodDataset.download(data_dir=CONFIG['DATA_DIR'])
    all_data = CoodDataset(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train, val = _split_train_val(all_data, num_train, CONFIG['TOTAL_DATAPOINTS'])
    collate_fn = CoodDataset.collate_function
elif DATASET == 'LAION':
    LaionDataset.download(max_samples=CONFIG['TOTAL_DATAPOINTS'], data_dir=CONFIG['DATA_DIR'])
    # NOTE(review): unlike the other datasets, no data_dir is passed to the
    # constructor here although download() receives one — confirm this is intended.
    all_data = LaionDataset(
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train, val = _split_train_val(all_data, num_train, CONFIG['TOTAL_DATAPOINTS'])
    collate_fn = LaionDataset.collate_function
elif DATASET == 'CC12M':
    CC12mDataset.download(max_samples=CONFIG['TOTAL_DATAPOINTS'], data_dir=CONFIG['DATA_DIR'])
    all_data = CC12mDataset(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train, val = _split_train_val(all_data, num_train, CONFIG['TOTAL_DATAPOINTS'])
    collate_fn = CC12mDataset.collate_function
else:
    # Fail here rather than with a NameError on `train` in a later cell.
    raise ValueError(f"Unknown DATASET: {DATASET!r}")


Downloading LAION samples:   2%|▏         | 167/10000 [01:45<1:43:39,  1.58it/s]


KeyboardInterrupt: 

In [None]:
# Settings shared by both loaders. Pinned host memory is only requested when
# training on CUDA.
_loader_kwargs = dict(
    batch_size=CONFIG['BATCH_SIZE'],
    num_workers=0,
    pin_memory=(DEVICE == 'cuda'),
    collate_fn=collate_fn,
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train, shuffle=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val, shuffle=False, **_loader_kwargs)

In [None]:
# Report how many samples ended up in the training and validation splits.
logger.info(f"Training on {len(train)} samples, validating on {len(val)} samples.")

2025-10-22 21:34:48,919 - INFO - Training on 8000 samples, validating on 2000 samples.


# Model


In [None]:
# Instantiate the selected architecture together with its matching loss.
if MODEL == 'CLIP':
    model = CLIPModel(DEVICE)
    loss = ClipLoss()
elif MODEL == 'CLOOB':
    model = CLOOBModel(DEVICE)
    # The CLOOB loss is parameterised from the model's own configuration.
    config = model.get_config()
    loss = CLOOBLoss(config['inv_tau'], config['scale_hopfield'])
elif MODEL == 'ALIGN':
    model = AlignCLIPModel(DEVICE)
    loss = None # TODO: Implement
else:
    # Previously any unrecognised value silently fell through to AlignCLIPModel.
    raise ValueError(f"Unknown MODEL: {MODEL!r}")

# Cast all parameters to float32.
model = model.float()

In [None]:
# Optionally wrap the model in a PEFT LoRA adapter (enabled via CONFIG['USE_LORA']).
# TODO: This only works if you apply it to some modules (see `target_modules` below)
# See https://github.com/huggingface/peft/blob/main/examples/multilayer_perceptron/multilayer_perceptron_lora.ipynb
# See https://huggingface.co/docs/peft/en/developer_guides/custom_models
if CONFIG['USE_LORA']:
    # LoRA hyperparameters. `target_modules` is still commented out, so no
    # layers are explicitly selected yet (see the TODO above).
    lora_config = peft.LoraConfig(
        r=8, # Rank of the low-rank matrices
        lora_alpha=16, # Scaling factor for LoRA updates
        # target_modules=["model.visual.conv1"], # Layers to apply LoRA to (e.g., in a Transformer)
        lora_dropout=0.1,
        bias="none",
        task_type=peft.TaskType.FEATURE_EXTRACTION
    )
    # Rebinds `model`, so re-running this cell wraps the already-wrapped model again.
    model = peft.PeftModel(model, lora_config)

In [None]:
# AdamW over every model parameter, configured from CONFIG.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=CONFIG['WEIGHT_DECAY'],
    lr=CONFIG['LEARNING_RATE'],
)

# Step decay: the training loop steps this scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    CONFIG['STEP_LR_STEP_SIZE'],
    CONFIG['STEP_LR_GAMMA'],
)

# Training


In [None]:
def train_epoch(
        model: ModelClass, 
        dataloader: torch.utils.data.DataLoader, 
        optimizer: torch.optim.Optimizer, 
        criterion: torch.nn.Module, 
):
    """Train `model` for one pass over `dataloader`.

    Batches whose inputs, features, loss or gradients contain NaN are skipped
    (pending gradients are cleared and no optimizer step is taken for them).

    Args:
        model: Model exposing `encode_image_tensors` and `encode_text_tokens`.
        dataloader: Yields `(images, text_tokens)` batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Callable mapping `(image_features, text_features)` to a scalar loss.

    Returns:
        The mean loss over the batches whose loss was accumulated, i.e. skipped
        batches do not dilute the average. Returns NaN when no batch produced a
        usable loss (including an empty dataloader).
    """
    model.train()

    progress_bar = tqdm.tqdm(dataloader, desc="Training Epoch")
    total_loss = 0.0
    # Number of batches that contributed to total_loss; used as the divisor so
    # that skipped batches do not deflate the reported loss.
    num_loss_batches = 0
    nan_count = 0

    for batch_idx, (images, text_tokens) in enumerate(progress_bar):
        images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
        images = images.float()
    
        # Check for NaN in input (LAION yields NaNs if it can't load an image)
        if torch.isnan(images).any() or torch.isnan(text_tokens).any():
            logger.warning(f"NaN in input batch {batch_idx}")
            optimizer.zero_grad()
            continue
        
        image_features = model.encode_image_tensors(images)
        text_features = model.encode_text_tokens(text_tokens)
    
        # Check for NaN in features
        if torch.isnan(image_features).any() or torch.isnan(text_features).any():
            logger.warning(f"NaN in features at batch {batch_idx}: Image features stats - min={image_features.min()}, max={image_features.max()}, mean={image_features.mean()}; Text features stats - min={text_features.min()}, max={text_features.max()}, mean={text_features.mean()}")
            nan_count += 1
            optimizer.zero_grad()
            continue
        
        loss = criterion(image_features, text_features)

        # Check for NaN in loss
        if torch.isnan(loss):
            logger.warning(f"NaN loss detected at batch {batch_idx}")
            nan_count += 1
            optimizer.zero_grad()
            continue
        
        # If not NaN, then add to total loss
        total_loss += loss.item()
        num_loss_batches += 1
    
        # Scale loss for gradient accumulation
        scaled_loss = loss / CONFIG['GRAD_ACCUMULATION_STEPS']
        # Backward pass
        scaled_loss.backward()

        # Refuse to apply an update if any parameter received a NaN gradient.
        has_nan_grads = False
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                has_nan_grads = True
                logger.warning(f"NaN gradient in {name}")
                break

        if has_nan_grads:
            logger.warning(f"NaN gradients detected at batch {batch_idx}, skipping update")
            optimizer.zero_grad()
            continue

        # Gradient accumulation and optimization step
        if (batch_idx + 1) % CONFIG['GRAD_ACCUMULATION_STEPS'] == 0:
            # Clip gradients to prevent explosion
            if CONFIG['CLIP_GRADIENTS']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            # num_loss_batches >= 1 here because this batch was just counted.
            progress_bar.set_postfix({
                'loss': total_loss / num_loss_batches,
                'nan_count': nan_count
            })


        if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
            torch.cuda.empty_cache()
            gc.collect()

    # Handle remaining gradients if not accumulated evenly
    if len(dataloader) % CONFIG['GRAD_ACCUMULATION_STEPS'] != 0:
        if CONFIG['CLIP_GRADIENTS']:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad()

    if num_loss_batches == 0:
        # Nothing usable was seen; report NaN rather than a misleading 0.0.
        return float('nan')

    epoch_loss = total_loss / num_loss_batches
    return epoch_loss

# Validation

In [None]:
def validate(
        model: ModelClass, 
        dataloader: torch.utils.data.DataLoader, 
        criterion: torch.nn.Module,
):
    """Evaluate `model` on `dataloader` without computing gradients.

    Batches whose inputs, features or loss contain NaN are skipped. This
    function does not touch any optimizer: no gradients exist under
    `torch.no_grad()`, so there is nothing to clear.

    Args:
        model: Model exposing `encode_image_tensors` and `encode_text_tokens`.
        dataloader: Yields `(images, text_tokens)` batches.
        criterion: Callable mapping `(image_features, text_features)` to a scalar loss.

    Returns:
        The mean loss over the batches whose loss was accumulated. Returns NaN
        when no batch produced a usable loss (including an empty dataloader).
    """
    model.eval()

    progress_bar = tqdm.tqdm(dataloader, desc="Evaluating")
    total_loss = 0.0
    # Number of batches that contributed to total_loss; used as the divisor so
    # that skipped batches do not deflate the reported loss.
    num_loss_batches = 0

    with torch.no_grad():
        for batch_idx, (images, text_tokens) in enumerate(progress_bar):
            images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
            images = images.float()
    
            # Check for NaN in input (LAION yields NaNs if it can't load an image)
            if torch.isnan(images).any() or torch.isnan(text_tokens).any():
                logger.warning(f"NaN in input batch {batch_idx}")
                continue
            
            image_features = model.encode_image_tensors(images)
            text_features = model.encode_text_tokens(text_tokens)
            
            # Check for NaN in features
            if torch.isnan(image_features).any() or torch.isnan(text_features).any():
                logger.warning(f"NaN in features at batch {batch_idx}")
                continue
        
            loss = criterion(image_features, text_features)

            # Check for NaN loss
            if torch.isnan(loss):
                logger.warning("NaN in validation loss, skipping batch")
                continue
        
            total_loss += loss.item()
            num_loss_batches += 1
            progress_bar.set_postfix({'loss': total_loss / num_loss_batches})

            if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
                torch.cuda.empty_cache()
                gc.collect()

    if num_loss_batches == 0:
        # Nothing usable was seen; report NaN rather than a misleading 0.0.
        return float('nan')

    epoch_loss = total_loss / num_loss_batches
    return epoch_loss


# Full Train Eval Pipeline

In [None]:
def save_checkpoint(
    filename: str,
    model: ModelClass, 
    optimizer: torch.optim.Optimizer, 
    train_losses: list[float],
    val_losses: list[float],
):
    """Write model/optimizer state, the loss histories and CONFIG to `filename`."""
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            train_losses=train_losses,
            val_losses=val_losses,
            config=CONFIG,
        ),
        filename,
    )
    logger.info(f"Checkpoint saved: {filename}")

def plot_losses(model_name: str, train_losses: list[float], val_losses: list[float]):
    """Draw per-epoch training and validation loss curves on a single figure.

    Logs a warning and returns without plotting when `train_losses` is empty.
    """
    if not len(train_losses):
        logger.warning("No losses to plot")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    for series, legend_label, marker_style in (
        (train_losses, 'Train Loss', 'o'),
        (val_losses, 'Val Loss', 's'),
    ):
        ax.plot(series, label=legend_label, marker=marker_style)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'{model_name.upper()} Training Curves')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

In [None]:
def train( 
    num_epochs: int,
    model: ModelClass, 
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer, 
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    criterion: torch.nn.Module
):
    """Run the train/validate loop for up to `num_epochs` epochs.

    Each epoch trains, validates, steps the scheduler, checkpoints the model
    when the validation loss improves, and optionally logs to WandB. The loop
    stops early as soon as either epoch loss is NaN.

    Returns:
        `(train_losses, val_losses)`: the per-epoch loss histories, including
        the NaN entry of an epoch that triggered the early stop.
    """
    logger.info(f"Starting training on {DEVICE} for {num_epochs} epochs...")

    best_val_loss = float('inf')
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Training pass
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        train_losses.append(train_loss)
        logger.info(f"Train Loss: {train_loss:.6f}")

        # Validation pass
        val_loss = validate(model, val_loader, criterion)
        val_losses.append(val_loss)
        logger.info(f"Val Loss: {val_loss:.6f}")

        # Abort on the first NaN epoch loss
        if any(torch.isnan(torch.tensor(value)) for value in (train_loss, val_loss)):
            logger.error("NaN loss detected! Stopping training.")
            break

        scheduler.step()
        logger.info(f"Learning Rate adjusted to: {scheduler.get_last_lr()[0]:.6f}")

        # Checkpoint whenever the validation loss reaches a new minimum
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            checkpoint_path = f"best_{MODEL}_model_on_{DATASET}.pt"
            save_checkpoint(checkpoint_path, model, optimizer, train_losses, val_losses)
            logger.info(f"✓ Saved best model (val_loss: {val_loss:.6f})")

        if CONFIG['USE_WANDB']:
            epoch_metrics = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'learning_rate': scheduler.get_last_lr()[0],
            }
            wandb.log(epoch_metrics)

    return train_losses, val_losses

# WandB

In [None]:
# WandB settings pulled from CONFIG. A previous run id means "resume that run".
PROJECT_NAME = CONFIG['WANDB_PROJECT_NAME']
USE_WANDB = CONFIG['USE_WANDB']
run_name = CONFIG['WANDB_RUN_NAME']
RESUME_LOGGING = CONFIG['WANDB_PREVIOUS_RUN_ID'] is not None

if USE_WANDB:
    wandb.login(key=WANDB_API_KEY)  # your wandb key

    if not RESUME_LOGGING:
        # Fresh run: record the full CONFIG alongside it.
        run = wandb.init(
            project=PROJECT_NAME,
            name=run_name,
            reinit=True,
            config=CONFIG,
        )
    else:
        # Resume: resume="must" is passed together with the stored run id.
        run_id = CONFIG['WANDB_PREVIOUS_RUN_ID']
        run = wandb.init(
            project=PROJECT_NAME,
            id=run_id,
            resume="must",
            settings=wandb.Settings(symlink=False),
        )

[34m[1mwandb[0m: Currently logged in as: [33mcmellor[0m ([33mcmellor-carnegie-mellon-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Run Training


In [None]:
# Launch the full train/validate loop and keep the per-epoch loss histories.
train_losses, val_losses = train(
    num_epochs=CONFIG["NUM_EPOCHS"],
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=loss,  # type: ignore
)

2025-10-22 21:34:52,030 - INFO - Starting training on cuda for 20 epochs...
2025-10-22 21:34:52,030 - INFO - 
Epoch 1/20
Training Epoch:   2%|▏         | 4/250 [01:29<1:31:37, 22.35s/it, loss=1.08, nan_count=0]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7c9383f14980>> (for post_run_cell), with arguments args (<ExecutionResult object at 7c939815de50, execution_count=17 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7c937fafe350, raw_cell="train_losses, val_losses = train(
    CONFIG["NUM_.." transformed_cell="train_losses, val_losses = train(
    CONFIG["NUM_.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu/home/ness/School/Senior/Multimodal-2025/Notebooks/finetune.ipynb#X33sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost

In [None]:
# Plot the per-epoch training and validation loss histories returned by train().
plot_losses(MODEL, train_losses, val_losses)