# In [None]:
# train_clip.py
import os
import math
import csv
from dataclasses import dataclass
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import open_clip

# -----------------------
# Config
# -----------------------
@dataclass
class Config:
    """Settings for one fine-tuning run; every field has a default."""
    # Architecture name and pretrained-weights tag handed to open_clip when the
    # model is created; model_name also selects the tokenizer.
    model_name: str = "ViT-B-32"
    pretrained: str = "laion400m_e32"   # or laion400m_e31
    # Training CSV; must contain the columns "image_path" and "caption".
    train_csv: str = "data/train.csv"
    # DataLoader settings.
    batch_size: int = 64
    num_workers: int = 4
    # Number of passes over the training set (also sizes the cosine LR schedule).
    max_epochs: int = 5
    # AdamW hyperparameters.
    lr: float = 5e-5
    weight_decay: float = 0.02
    # Max gradient norm passed to clip_grad_norm_; None disables clipping.
    grad_clip: float = 1.0
    # Whether to use automatic mixed precision (autocast + GradScaler).
    amp: bool = True
    # Evaluated once when the class is defined: "cuda" if available, else "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Directory that receives one checkpoint per epoch.
    out_dir: str = "checkpoints"
    # Selects which parameters set_trainable() leaves trainable.
    finetune_mode: str = "proj_only"  # one of: "full", "proj_only", "text_only", "vision_only"
    # Seed passed to torch.manual_seed.
    seed: int = 42

# Single module-level configuration instance used by everything below.
cfg = Config()

# Ensure the checkpoint directory exists, then seed torch's RNG for reproducibility.
os.makedirs(cfg.out_dir, exist_ok=True)
torch.manual_seed(cfg.seed)

# -----------------------
# Dataset
# -----------------------
class ImageTextCSV(Dataset):
    """Image/caption pairs listed in a CSV file with ``image_path`` and ``caption`` columns.

    Each item is ``(preprocess(image), token_ids)``: the image is opened from
    disk, converted to RGB and passed through ``preprocess``; the caption is
    tokenized as a one-element batch and its single row is returned.

    Args:
        csv_path: Path to a UTF-8 CSV file with a header row.
        preprocess: Callable applied to the RGB ``PIL.Image``.
        tokenizer: Callable taking a list of strings and returning an
            indexable batch of token ids (e.g. an open_clip tokenizer).
        text_ctx_len: Optional context length forwarded to the tokenizer as
            ``context_length``. ``None`` (default) keeps the tokenizer's own
            default; a custom value must match what the model accepts.

    Raises:
        ValueError: If the CSV has no header row or lacks a required column.
    """

    _REQUIRED_COLUMNS = ("image_path", "caption")

    def __init__(self, csv_path: str, preprocess, tokenizer, text_ctx_len: int = None):
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.text_ctx_len = text_ctx_len

        with open(csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for an empty file; treat that as "no columns".
            header = reader.fieldnames or []
            missing = [col for col in self._REQUIRED_COLUMNS if col not in header]
            if missing:
                # Explicit raise instead of assert: asserts vanish under ``python -O``.
                raise ValueError(
                    f"CSV must have columns: image_path, caption (found: {list(header)})"
                )
            self.items: List[tuple[str, str]] = [
                (row["image_path"], row["caption"]) for row in reader
            ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        img_path, caption = self.items[idx]
        # The context manager closes the file handle deterministically;
        # convert() returns a separate, fully loaded image that outlives it.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        image = self.preprocess(image)  # tensor CHW

        if self.text_ctx_len is None:
            # Tokenizer pads/truncates to its own default context length.
            token_batch = self.tokenizer([caption])
        else:
            token_batch = self.tokenizer([caption], context_length=self.text_ctx_len)
        return image, token_batch[0]

def collate_fn(batch):
    """Stack a list of ``(image_tensor, text_ids)`` samples into two batch tensors.

    Returns ``(images, texts)``, each stacked along a new leading dimension.
    """
    # Transpose the list of pairs into one tuple of images and one of token rows.
    image_seq, token_seq = zip(*batch)
    return torch.stack(image_seq, dim=0), torch.stack(token_seq, dim=0)

# -----------------------
# Build model + data
# -----------------------
# create_model_and_transforms returns (model, train_transform, eval_transform);
# the train transform is discarded and the eval transform is applied to the
# training images.
model, _, preprocess = open_clip.create_model_and_transforms(cfg.model_name, pretrained=cfg.pretrained)
tokenizer = open_clip.get_tokenizer(cfg.model_name)

device = torch.device(cfg.device)
model = model.to(device)

dataset = ImageTextCSV(cfg.train_csv, preprocess, tokenizer)
loader = DataLoader(
    dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    # Pinned host memory only speeds up host-to-GPU copies; asking for it on a
    # CPU-only run just produces a warning.
    pin_memory=(device.type == "cuda"),
    collate_fn=collate_fn,
)

# -----------------------
# Choose what to fine-tune
# -----------------------
_FINETUNE_MODES = ("full", "proj_only", "text_only", "vision_only")


def _enable_grad(target) -> None:
    """Mark ``target`` trainable; accepts a module, a bare tensor/parameter, or None."""
    if target is None:
        return
    if isinstance(target, torch.nn.Module):
        for p in target.parameters():
            p.requires_grad = True
    else:
        target.requires_grad_(True)


def set_trainable(mode: str, target=None):
    """Freeze the model, then re-enable gradients for the parts selected by ``mode``.

    Modes:
        "full":        every parameter is trainable.
        "proj_only":   ``visual.proj``, ``text_projection`` and ``logit_scale``.
        "text_only":   the text tower (token/positional embeddings, transformer,
                       final layer norm, ``text_projection``) and ``logit_scale``.
        "vision_only": everything under ``visual`` and ``logit_scale``.

    Attributes a given model does not have are skipped silently.

    Args:
        mode: One of the mode names above.
        target: Model to modify. Defaults to the module-level ``model``.

    Raises:
        ValueError: If ``mode`` is unknown. The model is left untouched in that case.
    """
    # Validate before mutating so a typo does not leave the model fully frozen.
    if mode not in _FINETUNE_MODES:
        raise ValueError(f"Unknown finetune_mode: {mode}")

    net = model if target is None else target

    # Freeze everything first
    for p in net.parameters():
        p.requires_grad = False

    if mode == "full":
        for p in net.parameters():
            p.requires_grad = True
        return

    if mode == "proj_only":
        # The projections may be bare nn.Parameter tensors or nn.Module layers.
        _enable_grad(getattr(getattr(net, "visual", None), "proj", None))
        _enable_grad(getattr(net, "text_projection", None))
    elif mode == "text_only":
        # The text tower is more than the transformer blocks: the embeddings and
        # the final layer norm belong to it too. "text" covers models that wrap
        # the whole tower in a single submodule.
        for attr in (
            "token_embedding",
            "positional_embedding",
            "transformer",
            "ln_final",
            "text_projection",
            "text",
        ):
            _enable_grad(getattr(net, attr, None))
    else:  # "vision_only"
        _enable_grad(getattr(net, "visual", None))

    # The learnable temperature is trained in every partial mode.
    net.logit_scale.requires_grad_(True)

set_trainable(cfg.finetune_mode)

# Verify
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

# -----------------------
# Optimizer & Scheduler
# -----------------------
# Flat list of trainable parameters (also used for gradient clipping).
params = [p for p in model.parameters() if p.requires_grad]

# Apply weight decay to weight matrices only. Scalars and vectors (logit_scale,
# biases, norm gains) are left undecayed; decaying logit_scale in particular
# would keep pulling the temperature away from its learned value.
decay_params = [p for p in params if p.ndim >= 2]
no_decay_params = [p for p in params if p.ndim < 2]
param_groups = [
    group
    for group in (
        {"params": decay_params, "weight_decay": cfg.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    )
    if group["params"]
]
optimizer = torch.optim.AdamW(param_groups, lr=cfg.lr)

# One scheduler step per optimizer step; guard against T_max == 0 for an empty loader.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, len(loader) * cfg.max_epochs)
)

# Loss scaling only applies to CUDA mixed precision; keep it off elsewhere.
scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp and device.type == "cuda")

# -----------------------
# Training (CLIP loss)
# -----------------------
def clip_contrastive_loss(logits_per_image, logits_per_text):
    """Symmetric cross-entropy over image->text and text->image logits.

    Row ``i`` of each logit matrix is scored against target class ``i``, i.e.
    the matching pair sits on the diagonal. Returns the mean of the two losses.
    """
    targets = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
    cross_entropy = torch.nn.functional.cross_entropy
    image_to_text = cross_entropy(logits_per_image, targets)
    text_to_image = cross_entropy(logits_per_text, targets)
    return (image_to_text + text_to_image) / 2

best_loss = math.inf  # lowest mean training loss seen over a full epoch

# Mixed precision is only used on CUDA; other devices run in full precision.
use_amp = cfg.amp and device.type == "cuda"
amp_device_type = "cuda" if device.type == "cuda" else "cpu"
log_every = 50
# CLIP keeps the learned temperature within [0, ln(100)] for training stability.
max_logit_scale = math.log(100)


def _clip_logits(outputs):
    """Convert a CLIP forward result into ``(logits_per_image, logits_per_text)``.

    open_clip models return normalized features plus the exponentiated logit
    scale (as a tuple, or a dict when built with ``output_dict=True``) rather
    than logits. A 2-tuple is passed through unchanged as an already-computed
    pair of logit matrices.
    """
    if isinstance(outputs, dict):
        image_features = outputs["image_features"]
        text_features = outputs["text_features"]
        logit_scale = outputs["logit_scale"]
    elif len(outputs) == 2:
        return outputs[0], outputs[1]
    else:
        # Any extra trailing entries (e.g. a logit bias) are not needed for this loss.
        image_features, text_features, logit_scale = outputs[:3]
    logits_per_image = logit_scale * image_features @ text_features.t()
    return logits_per_image, logits_per_image.t()


for epoch in range(cfg.max_epochs):
    model.train()
    running = 0.0       # loss accumulated since the last progress line
    epoch_total = 0.0   # loss accumulated over the whole epoch
    steps_done = 0

    for step, (images, texts) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        texts = texts.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=amp_device_type, enabled=use_amp):
            logits_per_image, logits_per_text = _clip_logits(model(images, texts))
            loss = clip_contrastive_loss(logits_per_image, logits_per_text)

        scaler.scale(loss).backward()
        if cfg.grad_clip is not None:
            # Gradients must be unscaled before their norm is measured and clipped.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(params, cfg.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Keep the trainable temperature in its valid range after the update.
        with torch.no_grad():
            model.logit_scale.clamp_(0, max_logit_scale)

        step_loss = loss.item()
        running += step_loss
        epoch_total += step_loss
        steps_done += 1
        if (step + 1) % log_every == 0:
            avg = running / log_every
            print(f"Epoch {epoch+1}/{cfg.max_epochs} | Step {step+1}/{len(loader)} | Loss {avg:.4f}")
            running = 0.0

    # Epoch summary: covers the trailing steps that never reach a progress line.
    if steps_done:
        epoch_loss = epoch_total / steps_done
        best_loss = min(best_loss, epoch_loss)
        print(f"Epoch {epoch+1}/{cfg.max_epochs} | Mean loss {epoch_loss:.4f} | Best {best_loss:.4f}")

    # Save checkpoint each epoch
    ckpt_path = os.path.join(cfg.out_dir, f"{cfg.model_name}-{cfg.pretrained}-epoch{epoch+1}.pt")
    torch.save({"model": model.state_dict(), "config": cfg.__dict__}, ckpt_path)
    print(f"Saved: {ckpt_path}")
