In [None]:
from google.colab import userdata
import os

# Credentials used to clone the repository over HTTPS; both are left blank here
# and must be filled in before running the cell.
# NOTE(review): `userdata` is imported above but never used in this cell —
# presumably the token was meant to be read from Colab secrets; confirm.
github_token = ""
github_username = "" # Replace with your GitHub username
# The username and token are embedded directly in the clone URL, so do not
# share or commit this cell with a real token filled in.
repository_url = f"https://{github_username}:{github_token}@github.com/neskech/Multimodal-2025.git"

# Clone using the authenticated URL built above.
!git clone {repository_url}


# Move into the checkout, switch to the `align` branch and fetch its submodules.
%cd Multimodal-2025
!git checkout align
!git submodule update --init --recursive

# Install the dependencies listed in the repository's requirements file.
!pip install -r requirements.txt

In [None]:
import torch
import gc
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import logging
import wandb
import os
import sys
import peft
from dotenv import load_dotenv
from typing import Literal, Union

sys.path.append("..")
from Datasets.coco import CocoDataset
from Datasets.cc12m import CC12mDataset
from Datasets.cood import CoodDataset
from Datasets.laion import LaionDataset
from Models.clipModel import CLIPModel
from Models.cloobModel import CLOOBModel
from Models.vClipModel import VariationalCLIPModel
from Models.alignClipModel import AlignCLIPModel
from losses.clipLoss import ClipLoss
from losses.cloobLoss import CLOOBLoss

In [None]:
# Type aliases for compatibility with older Python versions
# `Model` lists the selectable model names; `ModelClass` is the union of the
# model wrapper classes and is used in the type hints of the training helpers.
Model = Literal['CLIP', 'CLOOB', 'ALIGN']
ModelClass = Union[CLIPModel, CLOOBModel, AlignCLIPModel]
# Model selected for this run.
MODEL: Model = 'ALIGN'

# Names of the datasets this notebook knows how to load. Written as a plain
# assignment rather than a PEP 695 `type` statement: the latter is a
# SyntaxError before Python 3.12, and the `Model` alias in this cell already
# uses the plain form for compatibility with older interpreters.
Dataset = Literal['COCO', 'COOD', 'CC12M', 'LAION']
# Dataset selected for this run.
DATASET: Dataset = 'COOD'

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
# `torch.backends.mps.is_available()` is used instead of
# `torch.mps.is_available()` because the latter is missing in older PyTorch
# releases and would raise AttributeError on any machine without CUDA.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Set random seeds for reproducibility
# (seeds both the PyTorch and the NumPy global generators with the same value)
torch.manual_seed(42)
np.random.seed(42)

# Setup logging
# INFO level so the progress messages emitted by the training helpers are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get WANDB Key (Use a .env file to store the key)
# `load_dotenv()` is called before the lookup so a key stored in a .env file is
# visible in os.environ; WANDB_API_KEY is None when the variable is not set.
load_dotenv()
WANDB_API_KEY = os.environ.get('WANDB_API_KEY')

In [None]:
# Central run configuration read by the data-loading, training and W&B cells.
CONFIG = {
    # Number of passes over the training loader
    'NUM_EPOCHS': 15,
    # Batch size used for both the training and the validation loader
    'BATCH_SIZE': 32,
    # AdamW learning rate and weight decay
    'LEARNING_RATE': 1e-6,
    'WEIGHT_DECAY': 1e-2,

    # Scheduler parameters (counted in epochs: the scheduler is stepped once per epoch)
    'WARMUP_EPOCHS': 5,
    'DECAY_EPOCHS': 10,

    # Number of batches whose gradients are accumulated before each optimizer
    # step (the loss is divided by this value). Set to 1 to disable accumulation
    'GRAD_ACCUMULATION_STEPS': 1,
    # Clip the gradient norm (to 1.0 in train_epoch) to avoid explosion
    'CLIP_GRADIENTS': True,
    # When True, torch.cuda.empty_cache() and gc.collect() run after every batch
    'EMPTY_CACHE_AFTER_BATCH': False,

    # Directory passed as `data_dir` to the dataset download helpers and constructors
    'DATA_DIR': '../Data',
    # Fraction of TOTAL_DATAPOINTS used for training; the rest is used for validation
    'TRAIN_RATIO': 0.9,
    'TOTAL_DATAPOINTS': 9_000,

    # NOTE(review): not read anywhere in the visible cells — confirm whether it is still used
    'USE_LORA': False,

    # Weights & Biases logging
    'USE_WANDB': True,
    'WANDB_RUN_NAME': f'{MODEL}_finetune_on_{DATASET}',
    'WANDB_PREVIOUS_RUN_ID': None, # id of the run to resume; set to None if not resuming
    'WANDB_PROJECT_NAME': 'multimodal_2025',
}

In [None]:
# Sizes of the train/validation splits. The validation size is taken as the
# remainder of the total so both splits always add up to TOTAL_DATAPOINTS and
# every dataset branch uses the same split sizes. Computing it as
# int(total * (1.0 - ratio)) loses a sample to floating-point error:
# 9000 * (1.0 - 0.9) == 899.9999999999998, which truncates to 899.
num_train = int(CONFIG['TOTAL_DATAPOINTS'] * CONFIG['TRAIN_RATIO'])
num_val = CONFIG['TOTAL_DATAPOINTS'] - num_train

# Each branch downloads the selected dataset if needed and defines `train`,
# `val` and the matching `collate_fn` consumed by the DataLoader cell.
if DATASET == 'COCO':
    # COCO ships separate train/val splits, so each one is capped independently.
    CocoDataset.download(download_script_path='../Datasets/download_coco.sh', data_dir=CONFIG['DATA_DIR'])
    train = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='train2017',
        tokenize=True,
        max_samples=num_train
    )
    val = CocoDataset(
        data_dir=CONFIG['DATA_DIR'],
        split='val2017',
        tokenize=True,
        max_samples=num_val
    )
    collate_fn = CocoDataset.collate_function
elif DATASET == 'COOD':
    # Single dataset object: the first `num_train` indices train, the rest validate.
    CoodDataset.download(data_dir=CONFIG['DATA_DIR'])
    all_data = CoodDataset(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train = torch.utils.data.Subset(
        all_data,
        range(0, num_train)
    )
    val = torch.utils.data.Subset(
        all_data,
        range(num_train, CONFIG['TOTAL_DATAPOINTS'])
    )
    collate_fn = CoodDataset.collate_function
elif DATASET == 'LAION':
    LaionDataset.download(max_samples=CONFIG['TOTAL_DATAPOINTS'], data_dir=CONFIG['DATA_DIR'])
    all_data = LaionDataset(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train = torch.utils.data.Subset(
        all_data,
        range(0, num_train)
    )
    val = torch.utils.data.Subset(
        all_data,
        range(num_train, CONFIG['TOTAL_DATAPOINTS'])
    )
    collate_fn = LaionDataset.collate_function
elif DATASET == 'CC12M':
    CC12mDataset.download(max_samples=CONFIG['TOTAL_DATAPOINTS'], data_dir=CONFIG['DATA_DIR'])
    all_data = CC12mDataset(
        data_dir=CONFIG['DATA_DIR'],
        tokenize=True,
        max_samples=CONFIG['TOTAL_DATAPOINTS']
    )
    train = torch.utils.data.Subset(
        all_data,
        range(0, num_train)
    )
    val = torch.utils.data.Subset(
        all_data,
        range(num_train, CONFIG['TOTAL_DATAPOINTS'])
    )
    collate_fn = CC12mDataset.collate_function
else:
    # Fail here instead of with a NameError on `train` in a later cell.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")


In [None]:
# Training loader: reshuffled every epoch, loaded in the main process.
# Pinned memory is requested only when running on CUDA.
train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=(DEVICE == 'cuda'),
)
# Validation loader: same settings, but iterated in a fixed order.
val_loader = torch.utils.data.DataLoader(
    val,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=(DEVICE == 'cuda'),
)

In [None]:
# Sanity check: report how many samples ended up in each split.
logger.info(f"Training on {len(train)} samples, validating on {len(val)} samples.")

In [None]:
# Instantiate the model wrapper selected by MODEL on the chosen device.
# CLIP and CLOOB are frozen for fine-tuning and get their own loss here.
if MODEL == 'CLIP':
    model = CLIPModel(DEVICE)
    model.freeze_for_finetuning()
    loss = ClipLoss()
elif MODEL == 'CLOOB':
    model = CLOOBModel(DEVICE)
    model.freeze_for_finetuning()
    # CLOOB's loss is parameterised from the model's own configuration.
    config = model.get_config()
    loss = CLOOBLoss(config['inv_tau'], config['scale_hopfield'], device=DEVICE)
elif MODEL == 'ALIGN':
    # No loss is created in this branch; the next cell assigns `loss`
    # (a ClipInModalityLoss) unconditionally.
    model = AlignCLIPModel(DEVICE)
else:
    # Previously any unknown value silently fell through to the ALIGN model.
    raise ValueError(f"Unsupported MODEL: {MODEL!r}")

# Cast the model via `.float()` before training; the training helpers cast the
# image batches the same way.
model = model.float()

In [None]:
from AlignCLIP.align_clip.loss import ClipInModalityLoss

# Loss passed to `train` at the end of the notebook. This assignment is
# unconditional, so it replaces any `loss` created in the model-selection cell.
# NOTE(review): the meaning of `alpha`/`beta` is defined by ClipInModalityLoss — confirm the weights.
loss = ClipInModalityLoss(
    alpha=1.0,
    beta=0.5,
    nl_semantic_supervision=True,
)

In [None]:
from sentence_transformers import SentenceTransformer as SBERT
# Sentence encoder used by train_epoch/validate: `sbert.encode` is called on the
# raw captions of each batch and the result is passed to the loss as
# `semantic_features`.
sbert = SBERT('all-mpnet-base-v2')

In [None]:
# AdamW over all model parameters, configured from CONFIG.
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=CONFIG['WEIGHT_DECAY'],
    lr=CONFIG['LEARNING_RATE'],
)

# Warmup scheduler: linear increase from 1% of the target LR up to the target LR
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.01,
    end_factor=1.0,
    total_iters=CONFIG['WARMUP_EPOCHS'],
)

# Decay scheduler: cosine annealing after warmup
decay_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['DECAY_EPOCHS'],
)

# Chain the two: warmup runs first, then the cosine decay takes over once
# WARMUP_EPOCHS scheduler steps have been taken.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, decay_scheduler],
    milestones=[CONFIG['WARMUP_EPOCHS']],
)

In [None]:
def train_epoch(
        model: ModelClass,
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
):
    """Run one training epoch and return the mean loss of the usable batches.

    Each batch yields ``(images, text_tokens, caption)``. Captions are encoded
    with the module-level ``sbert`` model and passed to ``criterion`` as
    ``semantic_features`` together with the image/text features and the
    exponentiated ``model.model.logit_scale``.

    Batches with NaNs in the inputs, features, loss or gradients are skipped:
    a warning is logged and the gradients are zeroed (which also discards any
    gradients accumulated so far in the current accumulation window).

    Args:
        model: Model wrapper exposing ``encode_image_tensors``,
            ``encode_text_tokens`` and ``model.logit_scale``.
        dataloader: Loader yielding ``(images, text_tokens, caption)`` batches.
        optimizer: Optimizer stepped every ``CONFIG['GRAD_ACCUMULATION_STEPS']``
            batches (counted by batch index, skipped batches included).
        criterion: Loss called as
            ``criterion(image_features, text_features, logit_scale, semantic_features=...)``.

    Returns:
        Mean loss over the batches that produced a non-NaN loss, or ``nan``
        when no batch did. Averaging over ``len(dataloader)`` instead would
        under-report the loss whenever batches are skipped.
    """
    model.train()

    accumulation_steps = CONFIG['GRAD_ACCUMULATION_STEPS']
    progress_bar = tqdm.tqdm(dataloader, desc="Training Epoch")
    total_loss = 0.0
    # Number of batches whose loss was added to `total_loss`.
    num_loss_batches = 0
    nan_count = 0

    for batch_idx, (images, text_tokens, caption) in enumerate(progress_bar):
        images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
        images = images.float()

        # Check for NaN in input (LAION gives NaNs if it can't load images)
        if torch.isnan(images).any() or torch.isnan(text_tokens).any():
            logger.warning(f"NaN in input batch {batch_idx}")
            nan_count += 1
            optimizer.zero_grad()
            continue

        # Sentence embeddings of the raw captions, moved to the training device.
        semantic_features = sbert.encode(sentences=caption, show_progress_bar=False)
        semantic_features = torch.from_numpy(semantic_features)
        semantic_features = semantic_features.to(device=DEVICE, non_blocking=True)

        image_features = model.encode_image_tensors(images)
        text_features = model.encode_text_tokens(text_tokens)
        logit_scale = model.model.logit_scale.exp()

        # Check for NaN in features
        if torch.isnan(image_features).any() or torch.isnan(text_features).any():
            logger.warning(f"NaN in features at batch {batch_idx}: Image features stats - min={image_features.min()}, max={image_features.max()}, mean={image_features.mean()}; Text features stats - min={text_features.min()}, max={text_features.max()}, mean={text_features.mean()}")
            nan_count += 1
            optimizer.zero_grad()
            continue

        loss = criterion(image_features, text_features, logit_scale, semantic_features = semantic_features)

        # Check for NaN in loss
        if torch.isnan(loss):
            logger.warning(f"NaN loss detected at batch {batch_idx}")
            nan_count += 1
            optimizer.zero_grad()
            continue

        # If not NaN, then add to total loss
        total_loss += loss.item()
        num_loss_batches += 1

        # Scale loss for gradient accumulation
        scaled_loss = loss / accumulation_steps
        # Backward pass
        scaled_loss.backward()

        # Stop at the first parameter with a NaN gradient; one is enough to skip the update.
        has_nan_grads = False
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                has_nan_grads = True
                logger.warning(f"NaN gradient in {name}")
                break

        if has_nan_grads:
            logger.warning(f"NaN gradients detected at batch {batch_idx}, skipping update")
            optimizer.zero_grad()
            continue

        # Gradient accumulation and optimization step
        if (batch_idx + 1) % accumulation_steps == 0:
            # Clip gradients to prevent explosion
            if CONFIG['CLIP_GRADIENTS']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            # Running mean over the batches that actually contributed a loss.
            progress_bar.set_postfix({
                'loss': total_loss / num_loss_batches,
                'nan_count': nan_count
            })

        if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
            torch.cuda.empty_cache()
            gc.collect()

    # Handle remaining gradients if not accumulated evenly
    if len(dataloader) % accumulation_steps != 0:
        if CONFIG['CLIP_GRADIENTS']:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad()

    if num_loss_batches == 0:
        # Nothing usable this epoch: report NaN rather than a misleading 0.0.
        logger.warning("No usable batches in this training epoch")
        return float('nan')

    return total_loss / num_loss_batches

In [None]:
def validate(
        model: ModelClass,
        dataloader: torch.utils.data.DataLoader,
        criterion: torch.nn.Module,
):
    """Evaluate the model on ``dataloader`` and return the mean validation loss.

    Runs under ``torch.no_grad()`` with the model in eval mode. Captions are
    encoded with the module-level ``sbert`` model and passed to ``criterion``
    as ``semantic_features``. Batches with NaNs in the inputs, features or
    loss are skipped with a warning.

    This function does not touch any optimizer: no gradients are produced
    during validation, and the function receives no optimizer argument.

    Args:
        model: Model wrapper exposing ``encode_image_tensors``,
            ``encode_text_tokens`` and ``model.logit_scale``.
        dataloader: Loader yielding ``(images, text_tokens, caption)`` batches.
        criterion: Loss called as
            ``criterion(image_features, text_features, logit_scale, semantic_features=...)``.

    Returns:
        Mean loss over the batches that produced a non-NaN loss, or ``nan``
        when no batch did (a 0.0 here would be mistaken for a perfect score by
        the best-checkpoint comparison in ``train``).
    """
    model.eval()

    progress_bar = tqdm.tqdm(dataloader, desc="Evaluating")
    total_loss = 0.0
    # Number of batches whose loss was added to `total_loss`.
    num_loss_batches = 0

    with torch.no_grad():
        for batch_idx, (images, text_tokens, caption) in enumerate(progress_bar):
            images, text_tokens = images.to(DEVICE), text_tokens.to(DEVICE)
            images = images.float()

            # Check for NaN in input (LAION gives NaNs if it can't load images)
            if torch.isnan(images).any() or torch.isnan(text_tokens).any():
                logger.warning(f"NaN in input batch {batch_idx}")
                continue

            # Sentence embeddings of the raw captions, moved to the evaluation device.
            semantic_features = sbert.encode(sentences=caption, show_progress_bar=False)
            semantic_features = torch.from_numpy(semantic_features)
            semantic_features = semantic_features.to(device=DEVICE, non_blocking=True)
            image_features = model.encode_image_tensors(images)
            text_features = model.encode_text_tokens(text_tokens)
            logit_scale = model.model.logit_scale.exp()

            # Check for NaN in features
            if torch.isnan(image_features).any() or torch.isnan(text_features).any():
                logger.warning(f"NaN in features at batch {batch_idx}")
                continue

            loss = criterion(image_features, text_features, logit_scale, semantic_features = semantic_features)

            # Check for NaN loss
            if torch.isnan(loss):
                logger.warning("NaN in validation loss, skipping batch")
                continue

            total_loss += loss.item()
            num_loss_batches += 1
            # Running mean over the batches that actually contributed a loss.
            progress_bar.set_postfix({'loss': total_loss / num_loss_batches})

            if CONFIG['EMPTY_CACHE_AFTER_BATCH']:
                torch.cuda.empty_cache()
                gc.collect()

    if num_loss_batches == 0:
        # Nothing usable: report NaN rather than a misleading 0.0.
        logger.warning("No usable batches during validation")
        return float('nan')

    return total_loss / num_loss_batches

In [None]:
def save_checkpoint(
    filename: str,
    model: ModelClass,
    optimizer: torch.optim.Optimizer,
    train_losses: list[float],
    val_losses: list[float],
):
    """Serialize the current training state to ``filename`` with ``torch.save``.

    The saved dictionary holds the model and optimizer state dicts, the loss
    histories collected so far, and the global ``CONFIG`` used for the run.

    Args:
        filename: Destination path of the checkpoint file.
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        train_losses: Per-epoch training losses recorded so far.
        val_losses: Per-epoch validation losses recorded so far.
    """
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'config': CONFIG,
        },
        filename,
    )
    logger.info(f"Checkpoint saved: {filename}")

def plot_losses(model_name: str, train_losses: list[float], val_losses: list[float]):
    """Draw the per-epoch training and validation loss curves on one figure.

    Logs a warning and returns without plotting when ``train_losses`` is empty.

    Args:
        model_name: Name shown (upper-cased) in the figure title.
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
    """
    if len(train_losses) == 0:
        logger.warning("No losses to plot")
        return

    plt.figure(figsize=(10, 6))

    # One curve per loss history, each with its own legend label and marker.
    curves = (
        (train_losses, 'Train Loss', 'o'),
        (val_losses, 'Val Loss', 's'),
    )
    for values, label, marker in curves:
        plt.plot(values, label=label, marker=marker)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{model_name.upper()} Training Curves')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
def train(
    num_epochs: int,
    model: ModelClass,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    criterion: torch.nn.Module
):
    """Train model for specified epochs.

    Each epoch runs ``train_epoch`` and ``validate``, steps the scheduler once,
    saves a checkpoint whenever the validation loss improves, and (when
    ``CONFIG['USE_WANDB']`` is set) logs the epoch metrics to Weights & Biases.
    Training stops early if either epoch loss is NaN.

    Args:
        num_epochs: Maximum number of epochs to run.
        model: Model to train.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        optimizer: Optimizer passed to ``train_epoch`` and saved in checkpoints.
        scheduler: Learning-rate scheduler, stepped once per completed epoch.
        criterion: Loss passed to ``train_epoch`` and ``validate``.

    Returns:
        ``(train_losses, val_losses)``: the per-epoch losses, including the
        epoch on which training was stopped because of a NaN loss.
    """
    logger.info(f"Starting training on {DEVICE} for {num_epochs} epochs...")

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Learning rate in effect for this epoch's updates. It is read before
        # the scheduler advances so the value logged to W&B matches the losses
        # of the same epoch rather than the next epoch's rate.
        epoch_lr = optimizer.param_groups[0]['lr']

        # Train
        train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            criterion
        )
        train_losses.append(train_loss)
        logger.info(f"Train Loss: {train_loss:.6f}")

        # Validate
        val_loss = validate(
            model,
            val_loader,
            criterion
        )
        val_losses.append(val_loss)
        logger.info(f"Val Loss: {val_loss:.6f}")

        # Stop if losses are NaN
        if torch.isnan(torch.tensor(train_loss)) or torch.isnan(torch.tensor(val_loss)):
            logger.error("NaN loss detected! Stopping training.")
            break

        scheduler.step()
        # Scientific notation: fixed-point with six decimals prints rates
        # below 1e-6 (e.g. the warmup values) as 0.000000.
        logger.info(f"Learning Rate adjusted to: {scheduler.get_last_lr()[0]:.3e}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                f"best_{MODEL}_model_on_{DATASET}.pt",
                model,
                optimizer,
                train_losses,
                val_losses
            )
            logger.info(f"✓ Saved best model (val_loss: {val_loss:.6f})")

        if CONFIG['USE_WANDB']:
            wandb.log({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'learning_rate': epoch_lr,
            })

    return train_losses, val_losses

In [None]:
# Weights & Biases settings pulled out of CONFIG for readability.
USE_WANDB = CONFIG['USE_WANDB']
PROJECT_NAME = CONFIG['WANDB_PROJECT_NAME']
run_name = CONFIG['WANDB_RUN_NAME']
# A previous run id in CONFIG means "continue logging into that run".
RESUME_LOGGING = CONFIG['WANDB_PREVIOUS_RUN_ID'] is not None

if USE_WANDB:
    wandb.login(key=WANDB_API_KEY)

    if not RESUME_LOGGING:
        # Fresh run: named from CONFIG, with the full CONFIG attached to it.
        run = wandb.init(
            project=PROJECT_NAME,
            entity="multimodal_2025",
            name=run_name,
            config=CONFIG,
            reinit=True,
        )
    else:
        # resume="must" makes W&B fail if the given run id does not exist.
        run_id = CONFIG['WANDB_PREVIOUS_RUN_ID']
        run = wandb.init(
            project=PROJECT_NAME,
            entity="multimodal_2025",
            id=run_id,
            resume="must",
            settings=wandb.Settings(symlink=False),
        )

In [None]:
# Launch the full training run with the objects built in the cells above and
# keep the per-epoch loss histories for later inspection.
train_losses, val_losses = train(
    num_epochs=CONFIG["NUM_EPOCHS"],
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=loss, # type: ignore
)