In [None]:
# Authenticate the W&B CLI (IPython shell escape); `--relogin` forces a fresh login even if a key is already stored.
!wandb login --relogin

In [None]:
# Authenticate with the Hugging Face Hub (interactive prompt). Presumably needed so the
# meta-llama checkpoint configured below can be downloaded — confirm the account has access.
from huggingface_hub import login
login()

In [None]:
import os  # this cell runs before the main import cell, so `os` must be imported here

# Checkpoint directory on the Colab VM.
# NOTE(review): the training loop saves to the relative "checkpoints/" directory,
# which only coincides with LOCAL_DIR when the working directory is /content.
LOCAL_DIR = "/content/checkpoints"
os.makedirs(LOCAL_DIR, exist_ok=True)

In [None]:
import os, math, random, io, requests, json, time, torch
from torch import nn
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset, Image as HFImage
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    SiglipVisionModel,
)
from torch.amp import autocast, GradScaler
from concurrent.futures import ThreadPoolExecutor, as_completed


# 0) CONFIG
LLM_NAME    = "meta-llama/Llama-3.2-1B" # "google/gemma-3-1b-it"
VISION_NAME = "google/siglip-so400m-patch14-384"

cache_dir        = "/content/livis_cache"
train_size       = 95000
val_size         = 5000       # train + val = 100k
# Upper bound on cached image+caption pairs. The cache scan queues exactly this
# many candidates and then stops, so with no headroom a single dead URL leaves
# fewer than train_size + val_size pairs and the size check fails on every run.
# Keep ~10% headroom for failed downloads (the extra pairs are trimmed later).
max_cached_items = (train_size + val_size) * 11 // 10

batch_size   = 8
max_txt_len  = 128
total_steps  = 50000          # counted in micro-batches (one optimizer step per grad_accum)
warmup_steps = 1000           # also in micro-batches
grad_accum   = 8
val_interval = 1000
save_step_10k = 10000
save_step_5k = 5000

use_wandb   = True
wandb_project = "blip2-llama-mode-livis-50k"

# resume flags (only affect Q-Former + Projector)
RESUME = False
RESUME_QFORMER_PATH   = "checkpoints/qformer_best.pt"
RESUME_PROJECTOR_PATH = "checkpoints/projector_best.pt"

device = "cuda"
# Allow TF32 for fp32 matmuls (speed over a little precision on GPUs that support it).
torch.backends.cuda.matmul.allow_tf32 = True
print("Device:", device)

os.makedirs(cache_dir, exist_ok=True)
os.makedirs("logs", exist_ok=True)

existing_jpgs = [f for f in os.listdir(cache_dir) if f.lower().endswith(".jpg")]
print(f"Found existing cached JPGs in {cache_dir}: {len(existing_jpgs)}")


# LOAD LLM — frozen causal LM that consumes the visual prefix
# (Llama-3.2-1B gave the best results).
print(f"Loading LLM: {LLM_NAME}")
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so padding="max_length" in collate_fn works.
    tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    LLM_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)
llm.eval()  # its parameters are frozen further down; only Q-Former/projector train

# The Q-Former and projector must produce vectors of the LM's embedding width.
d_model = llm.config.hidden_size
print("LLM hidden size:", d_model)


# LOAD SIGLIP VISION ENCODER (frozen, fp16)
print(f"Loading vision encoder: {VISION_NAME}")
vision_model = SiglipVisionModel.from_pretrained(VISION_NAME, torch_dtype=torch.float16)
vision_model = vision_model.to(device)
processor = AutoProcessor.from_pretrained(VISION_NAME)
vision_model.eval()

d_vision = vision_model.config.hidden_size
# Informational only; the image processor may not define "shortest_edge", hence the fallback.
resize_size = processor.image_processor.size.get("shortest_edge", 384)
print("SigLIP hidden size:", d_vision)
print("Resize size:", resize_size)

# Patch grid (tokens per side), used to lay Q-Former attention out as an image.
img_size = getattr(vision_model.config, "image_size", 384)
patch_sz = getattr(vision_model.config, "patch_size", 14)
grid_H = grid_W = img_size // patch_sz


# CACHE LIVIS DATASET LOCALLY — first pull the metadata rows (URLs + captions).

print("Loading raw LIVIS dataset from Hugging Face...")
raw = load_dataset(
    "laion/220k-GPT4Vision-captions-from-LIVIS",
    split="train",
)
print("Raw dataset size:", len(raw))

def extract_caption(ex):
    """
    Build the training caption for one LIVIS row.

    Prefers ``short_caption``, falls back to ``caption``, and finally to
    "An image." when both are missing, empty or whitespace-only. The text is
    always prefixed with 'Short caption:' so training matches the inference
    prompt.

    Args:
        ex: mapping with optional "short_caption" / "caption" string fields.
    Returns:
        str: "Short caption: <text>".
    """
    for key in ("short_caption", "caption"):
        text = (ex.get(key) or "").strip()
        if text:
            return f"Short caption: {text}"
    # Previously a whitespace-only "caption" produced the bare "Short caption: ".
    return "Short caption: An image."

paths_ok = []
caps_ok  = []
jobs     = []  # (idx, url, caption) to be downloaded

print(f"Scanning dataset to reuse cache & queue missing images (target {max_cached_items})...")

# NOTE: every queued URL that later fails to download lowers the final count, so
# max_cached_items needs headroom over train_size + val_size for the size check
# further down to pass.
for idx, ex in enumerate(raw):
    if len(paths_ok) + len(jobs) >= max_cached_items:
        break

    url = ex.get("url")
    if not url:
        continue

    cap = extract_caption(ex)
    img_path = os.path.join(cache_dir, f"{idx:06d}.jpg")
    txt_path = os.path.join(cache_dir, f"{idx:06d}.txt")

    if not os.path.exists(img_path):
        jobs.append((idx, url, cap))  # missing: schedule for download
        continue

    # Reuse the existing image; if it fails verification, schedule a re-download.
    try:
        # Context manager closes the file handle now instead of leaving it to GC.
        with Image.open(img_path) as im:
            im.verify()
        if not os.path.exists(txt_path):
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(cap)
    except Exception:
        jobs.append((idx, url, cap))
        continue
    paths_ok.append(img_path)
    caps_ok.append(cap)

print(f"Already valid images (reused): {len(paths_ok)}")
print(f"Images queued for download:   {len(jobs)}")

# Keep reused + queued <= max_cached_items. The scan above already guarantees
# this; the guard just keeps the invariant explicit if the loop changes.
paths_ok = paths_ok[:max_cached_items]
caps_ok  = caps_ok[:max_cached_items]
jobs     = jobs[: max_cached_items - len(paths_ok)]

print(f"Final plan -> keep {len(paths_ok)} existing + download {len(jobs)} new.")

def fetch_one(job):
    """
    Worker for the parallel download: make sure one (image, caption) pair is cached.

    job: (idx, url, caption). Files are stored as ``{idx:06d}.jpg`` and
    ``{idx:06d}.txt`` in ``cache_dir``.

    Returns (img_path, caption) on success, or None if the image could not be
    fetched or decoded. Per-item failures never raise (best-effort download).
    """
    idx, url, cap = job
    img_path = os.path.join(cache_dir, f"{idx:06d}.jpg")
    txt_path = os.path.join(cache_dir, f"{idx:06d}.txt")

    def write_caption():
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(cap)

    # Already cached and readable: only make sure the caption file exists.
    if os.path.exists(img_path):
        try:
            with Image.open(img_path) as im:  # closes the handle immediately
                im.verify()
            if not os.path.exists(txt_path):
                write_caption()
            return (img_path, cap)
        except Exception:
            pass  # unreadable cache entry: fall through and re-download

    # Write to a temp file and rename, so an interrupted run never leaves a
    # truncated JPEG under the final name (which a later cache check might accept).
    tmp_path = img_path + ".part"
    try:
        r = requests.get(url, timeout=7)
        r.raise_for_status()  # don't try to decode 4xx/5xx error pages
        with Image.open(io.BytesIO(r.content)) as im:
            img = im.convert("RGB")
        img.save(tmp_path, "JPEG", quality=92)
        os.replace(tmp_path, img_path)
        write_caption()
        return (img_path, cap)
    except Exception:
        # Dead link, timeout, HTTP error or undecodable payload: skip this item.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created
        return None

if jobs:
    print("üöÄ Parallel downloading with workers...")
    max_workers = 100
    good = 0

    # fetch_one handles its own failures, so fut.result() returns a pair or None.
    with ThreadPoolExecutor(max_workers=max_workers) as ex_pool:
        pending = [ex_pool.submit(fetch_one, job) for job in jobs]
        for done, fut in enumerate(as_completed(pending), start=1):
            res = fut.result()
            if res is not None:
                fetched_path, fetched_cap = res
                paths_ok.append(fetched_path)
                caps_ok.append(fetched_cap)
                good += 1
            if done % 500 == 0:
                print(f"[Download progress] finished={done}/{len(jobs)}, successful={good}")

print(f"‚úÖ Cache ready. Total valid pairs: {len(paths_ok)}")

n_keep = train_size + val_size
if len(paths_ok) < n_keep:
    raise RuntimeError(
        f"Not enough cached images+captions ({len(paths_ok)}) "
        f"for requested train+val={train_size+val_size}."
    )

# Parallel downloads finish in arbitrary order. Without sorting, both the subset
# kept here and the seeded shuffle below would differ between a first
# (downloading) run and later cached runs, silently mixing train and val when
# resuming. Cache file names are zero-padded dataset indices, so a lexicographic
# sort restores dataset order and makes the split reproducible.
pairs = sorted(zip(paths_ok, caps_ok))[:n_keep]
paths_ok = [p for p, _ in pairs]
caps_ok  = [c for _, c in pairs]

print("Building local HF dataset from cached files...")

hf_ds = Dataset.from_dict({
    "image": paths_ok,
    "caption": caps_ok,
})
hf_ds = hf_ds.cast_column("image", HFImage())  # images decoded to PIL on access

print("Full cached dataset size:", len(hf_ds))
print("Example row:", hf_ds[0])


# TRAIN / VAL SPLIT

hf_ds = hf_ds.shuffle(seed=42)
train_ds = hf_ds.select(range(train_size))
val_ds   = hf_ds.select(range(train_size, train_size + val_size))

print("Train size:", len(train_ds))
print("Val size:  ", len(val_ds))


# PYTORCH DATALOADERS

def collate_fn(batch):
    """
    Collate HF rows ({"image": PIL image, "caption": str}) into a model batch.

    Rows whose image cannot be converted to RGB are dropped together with their
    caption. Returns None if no row is usable (the train/val loops skip None).

    Returns a dict with:
        pixel_values:   processor output for the kept images.
        input_ids:      caption tokens followed by EOS, padded/truncated to max_txt_len.
        attention_mask: 1 for real tokens (including EOS), 0 for padding.
        captions:       the original caption strings (without EOS), for logging.
        pil_images:     the RGB PIL images, for logging/visualisation.
    """
    images = []
    caps = []
    for ex in batch:
        img = ex["image"]
        try:
            # Guard PIL and non-PIL values alike: a corrupt or truncated file can
            # raise (e.g. OSError) when convert() forces decoding, and one bad
            # row should not kill the DataLoader worker.
            img = img.convert("RGB")
        except Exception:
            continue
        images.append(img)
        caps.append(ex["caption"])

    if not images:
        return None

    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]

    # Append EOS so the LM learns where a caption ends. Padding is masked out of
    # the loss, so without an explicit end token generation never learns to stop.
    eos = tokenizer.eos_token or ""
    enc = tokenizer(
        [c + eos for c in caps],
        padding="max_length",
        truncation=True,
        max_length=max_txt_len,
        return_tensors="pt",
    )

    return {
        "pixel_values": pixel_values,
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "captions": caps,
        "pil_images": images,
    }

# Identical loader settings for both splits; only training data is shuffled.
train_loader, val_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
        collate_fn=collate_fn,
    )
    for split, shuffle in ((train_ds, True), (val_ds, False))
)

print("Dataloaders ready.")


# STRONGER Q-FORMER (BLIP-2 style: self-attn + cross-attn blocks)

class QFormerBlock(nn.Module):
    """
    One block of a BLIP2-style Q-Former (pre-LN, residual around each sub-layer):
      - LN + self-attention on the query tokens
      - LN + cross-attention from queries -> vision tokens
      - LN + MLP

    ``ln2`` is shared by the queries and the vision tokens in the cross-attention
    (kept as-is so existing checkpoints still load).
    """
    def __init__(self, d_model, n_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True, dropout=dropout
        )

        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True, dropout=dropout
        )

        self.ln3 = nn.LayerNorm(d_model)
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, q, v, attn_mask_v=None, need_weights=True):
        """
        Args:
            q: query tokens, (B, K, d_model).
            v: vision tokens, (B, N, d_model).
            attn_mask_v: optional key_padding_mask over the N vision tokens.
            need_weights: return per-head cross-attention weights. Pass False to
                skip materialising them (may let PyTorch use a fused kernel).
        Returns:
            (q, attn): updated queries (B, K, d_model) and cross-attention
            weights (B, n_heads, K, N), or None when need_weights is False.
        """
        # Self-attention on q; the LayerNorm is computed once and reused for Q/K/V.
        h = self.ln1(q)
        q = q + self.self_attn(h, h, h, need_weights=False)[0]

        # Cross-attention: queries attend to the vision tokens.
        hq = self.ln2(q)
        hv = self.ln2(v)
        q2, attn = self.cross_attn(
            hq, hv, hv,
            key_padding_mask=attn_mask_v,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        q = q + q2

        # Feed-forward
        q = q + self.mlp(self.ln3(q))

        return q, attn


class QFormer(nn.Module):
    """
    Stack of QFormerBlocks that maps a variable number of vision tokens to a
    fixed set of ``n_queries`` learned query embeddings of width ``d_model``.

    After each forward, ``last_attn`` holds the last block's cross-attention
    weights (B, heads, K, N), detached, for visualisation.
    """
    def __init__(self, d_vis, d_model,
                 n_queries=32,
                 n_layers=8,
                 heads=8,
                 mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.vis_proj = nn.Linear(d_vis, d_model)

        self.layers = nn.ModuleList([
            QFormerBlock(d_model, heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.final_ln = nn.LayerNorm(d_model)

        self.last_attn = None  # last layer's attention, for visualization

    def forward(self, vis_tokens: torch.Tensor) -> torch.Tensor:
        """
        vis_tokens: (B, N, d_vis)
        returns:    (B, K, d_model)
        """
        v = self.vis_proj(vis_tokens)          # (B, N, d_model)

        B = v.size(0)
        q = self.query.unsqueeze(0).expand(B, -1, -1)  # (B, K, d_model)

        last = len(self.layers) - 1
        attn = None
        for i, blk in enumerate(self.layers):
            # Only the last block's weights are kept, so only it materialises them.
            q, attn = blk(q, v, attn_mask_v=None, need_weights=(i == last))

        # Detached: the weights are only for plotting, and a graph-attached tensor
        # on the module would keep its autograd history alive between steps.
        self.last_attn = attn.detach() if attn is not None else None
        return self.final_ln(q)

# Trainable Q-Former in the LLM's hidden width: 32 learned queries, 8 blocks of 8 heads.
qformer = QFormer(
    d_vis=d_vision, d_model=d_model,
    n_queries=32, n_layers=8, heads=8,
    mlp_ratio=4.0, dropout=0.1,
).to(device)

print("Q-Former params (M):", sum(t.numel() for t in qformer.parameters()) / 1e6)


# IMAGE-LEVEL MoE PROJECTOR (CLS-routed, 4 experts)

class MoEProjector(nn.Module):
    """
    Image-level soft Mixture-of-Experts projector.

    Routing:
      - A router MLP (LN -> Linear -> GELU -> Linear) maps one routing vector
        per image (``cls_token``, width ``d_router_in``) to softmax weights
        over ``num_experts``.
      - The same per-image mixture is applied to all K query tokens.

    Experts:
      - Each expert is LN -> Linear(d_model, hidden) -> GELU -> Dropout ->
        Linear(hidden, d_model) -> Dropout -> LN, with hidden = d_model * mlp_ratio
        (no residual connection).
      - Every expert runs on every input (dense mixture, no top-k selection).

    NOTE(review): the caller passes SigLIP ``last_hidden_state[:, 0]`` as the
    "CLS" token. As far as I know SigLIP's vision tower has no class token, so
    that vector is the top-left patch embedding — confirm.
    """
    def __init__(self, d_model, d_router_in, num_experts=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.num_experts = num_experts
        width = int(d_model * mlp_ratio)

        def build_expert():
            return nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, width),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(width, d_model),
                nn.Dropout(dropout),
                nn.LayerNorm(d_model),
            )

        # Built before the router so parameter-init order and state_dict keys
        # ("experts.<i>.<j>.*", then "router.*") stay the same for checkpoints.
        self.experts = nn.ModuleList([build_expert() for _ in range(num_experts)])

        # Router works in vision space (d_router_in).
        self.router = nn.Sequential(
            nn.LayerNorm(d_router_in),
            nn.Linear(d_router_in, d_router_in // 2),
            nn.GELU(),
            nn.Linear(d_router_in // 2, num_experts),
        )

    def forward(self, q, cls_token):
        """
        Args:
            q:         (B, K, d_model) Q-Former outputs.
            cls_token: (B, d_router_in) per-image routing vector.
        Returns:
            (out, weights): out is (B, K, d_model), the router-weighted sum of
            the expert outputs; weights is (B, num_experts), softmax over experts.
        """
        gate = self.router(cls_token).softmax(dim=-1)                     # (B, E)
        per_expert = torch.stack([expert(q) for expert in self.experts])  # (E, B, K, d_model)
        mix = gate.transpose(0, 1)[:, :, None, None]                      # (E, B, 1, 1)
        return (mix * per_expert).sum(dim=0), gate

# Trainable MoE projector: 4 experts, routed in SigLIP feature space (d_vision).
projector = MoEProjector(
    d_model=d_model, d_router_in=d_vision,
    num_experts=4, mlp_ratio=4.0, dropout=0.1,
).to(device)

print("MoE Projector params (M):", sum(t.numel() for t in projector.parameters()) / 1e6)


# BLIP-2 WRAPPER (visual prefix + MoE projector + text)

class BLIP2(nn.Module):
    """
    BLIP-2-style captioner: frozen vision encoder -> trainable Q-Former ->
    trainable MoE projector -> visual prefix prepended to the frozen LLM's
    text embeddings.

    NOTE(review): token 0 of the vision output is used as the router's "CLS"
    vector and dropped from the Q-Former input. As far as I know SigLIP has no
    class token, so this is the top-left patch; ``pooler_output`` may be the
    intended global vector. Changing it would invalidate existing checkpoints,
    so confirm first.
    """
    def __init__(self, llm, vision, qformer, projector):
        super().__init__()
        self.llm = llm
        self.vision = vision
        self.qformer = qformer
        self.projector = projector  # MoEProjector
        self.last_router_weights = None  # (B, num_experts) from the latest forward

    def forward(self, input_ids, pixel_values, attention_mask):
        """
        Args:
            input_ids: (B, T) caption token ids.
            pixel_values: image batch from the vision processor.
            attention_mask: (B, T), 1 for real tokens and 0 for padding.
        Returns:
            (logits, loss) from the LLM over the K visual + T text positions.
            The loss ignores the visual prefix and padded text positions.
        """
        bsz = input_ids.size(0)
        dev = input_ids.device

        # Frozen vision encoder: no autograd graph needed.
        with torch.no_grad():
            hidden = self.vision(pixel_values=pixel_values).last_hidden_state
            cls_token = hidden[:, 0, :]   # (B, d_vision)
            vtoks = hidden[:, 1:, :]      # (B, N-1, d_vision)

        # Q-Former + projector receive fp32 inputs (under the caller's autocast
        # their matmuls still run in reduced precision); output cast to LLM dtype.
        prefix, router_w = self.projector(
            self.qformer(vtoks.to(torch.float32)), cls_token.to(torch.float32)
        )
        self.last_router_weights = router_w
        prefix = prefix.to(self.llm.dtype)
        n_prefix = prefix.size(1)

        text_emb = self.llm.get_input_embeddings()(input_ids)   # (B, T, d_model)
        inputs_embeds = torch.cat([prefix, text_emb], dim=1)    # (B, K+T, d_model)

        # The visual prefix is always attended to.
        prefix_mask = torch.ones(bsz, n_prefix, device=dev, dtype=attention_mask.dtype)
        full_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # Loss targets: -100 (ignored) for the visual prefix and for padding.
        text_labels = input_ids.masked_fill(attention_mask == 0, -100)
        prefix_labels = torch.full((bsz, n_prefix), -100, device=dev, dtype=torch.long)
        full_labels = torch.cat([prefix_labels, text_labels], dim=1)

        out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            labels=full_labels,
        )
        return out.logits, out.loss

model = BLIP2(llm, vision_model, qformer, projector)

# Freeze LLM + Vision (only Q-Former & Projector train): requires_grad_(False)
# clears requires_grad on every parameter of each module.
for frozen in (llm, vision_model):
    frozen.requires_grad_(False)


# RESUME FROM WEIGHTS (OPTIONAL)

if RESUME:
    print("üîÅ RESUME is True ‚Äì trying to load Q-Former & Projector weights...")
    # Same best-effort procedure for both trainable modules: a missing file or
    # a failed load is reported and training continues from the fresh init.
    for label, module, path in (
        ("Q-Former", qformer, RESUME_QFORMER_PATH),
        ("Projector", projector, RESUME_PROJECTOR_PATH),
    ):
        if not os.path.isfile(path):
            print(f"‚ö†Ô∏è {label} resume path not found: {path}")
            continue
        try:
            # These files are plain state_dicts, so refuse to unpickle
            # arbitrary objects from them.
            state = torch.load(path, map_location=device, weights_only=True)
            module.load_state_dict(state)
            print(f"Loaded {label} from {path}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to load {label} weights:", e)


# OPTIMIZER + SCHEDULER + AMP

# Only the Q-Former and projector are optimised (LLM and vision encoder are frozen).
train_params = [*qformer.parameters(), *projector.parameters()]
optimizer = torch.optim.AdamW(train_params, weight_decay=0.01, lr=5e-5)

def lr_lambda(step, warmup=None, total=None, accum=None):
    """
    LR multiplier for LambdaLR: linear warmup, then cosine decay floored at 0.1.

    ``step`` is the scheduler's counter. It advances once per *optimizer* step,
    because scheduler.step() only runs every ``grad_accum`` micro-batches.
    ``warmup_steps`` and ``total_steps`` count micro-batches, so the step is
    converted to micro-batches before comparing. Without the conversion, warmup
    lasts ``grad_accum`` times longer than configured and the cosine barely
    decays. Progress is clamped at 1 so stepping past the end keeps the floor.

    Args:
        step: optimizer steps taken so far.
        warmup, total, accum: overrides for warmup_steps, total_steps and
            grad_accum (default: the module-level config values).
    Returns:
        float multiplier in [0, 1].
    """
    warmup = warmup_steps if warmup is None else warmup
    total = total_steps if total is None else total
    accum = grad_accum if accum is None else accum

    micro = step * accum
    if micro < warmup:
        return micro / max(1, warmup)
    progress = min(1.0, (micro - warmup) / max(1, total - warmup))
    return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
scaler = GradScaler("cuda")  # loss scaling for the fp16 autocast forward/backward

if use_wandb:
    import wandb

    # Run metadata. The lr and expert count are read from the objects they
    # describe, so the logged config cannot drift from what is trained.
    run_config = {
        "llm": LLM_NAME,
        "vision": VISION_NAME,
        "lr": optimizer.defaults["lr"],
        "warmup": warmup_steps,
        "steps": total_steps,
        "grad_accum": grad_accum,
        "batch_size": batch_size,
        "train_size": train_size,
        "val_size": val_size,
        "cache_dir": cache_dir,
        "moe_experts": projector.num_experts,
    }
    wandb.init(project=wandb_project, config=run_config)


# INFERENCE HELPER (single image -> caption)

@torch.no_grad()
def generate_caption_for_image(img: Image.Image, max_new_tokens=32):
    """
    Greedy-decode a caption for a single PIL image.

    Mirrors the training forward pass (vision -> Q-Former -> MoE projector ->
    visual prefix + "Short caption:" prompt) and runs ``llm.generate`` on the
    resulting embeddings.

    Args:
        img: any PIL image (converted to RGB).
        max_new_tokens: generation budget.
    Returns:
        str: the generated caption, whitespace-stripped, with a repeated
        "Short caption:" prefix removed if the model emits one.
    """
    # Dropout must be off while decoding. This helper is also called mid-training
    # and after the training loop (where the modules are in train mode), so the
    # previous mode of every module is restored afterwards.
    modules = (vision_model, qformer, projector, llm)
    prev_modes = [m.training for m in modules]
    for m in modules:
        m.eval()
    try:
        img = img.convert("RGB")
        v = processor(images=img, return_tensors="pt")["pixel_values"].to(device)

        hidden = vision_model(pixel_values=v).last_hidden_state
        cls_token = hidden[:, 0, :]   # (1, d_vision), same selection as BLIP2.forward
        vtoks = hidden[:, 1:, :]      # (1, N, d_vision)

        q = qformer(vtoks.to(torch.float32))          # (1, K, d_model)
        q, _ = projector(q, cls_token.to(torch.float32))
        q = q.to(llm.dtype)

        # No trailing space: with BPE tokenizers such as Llama-3's, training text
        # tokenizes as "...caption" ":" " word", so a prompt ending in a bare-space
        # token is a sequence the model never saw during training.
        prompt = "Short caption:"
        ids = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]  # (1, T)
        txt_emb = llm.get_input_embeddings()(ids)                             # (1, T, d_model)

        all_emb = torch.cat([q, txt_emb], dim=1)       # (1, K+T, d_model)
        attn_mask = torch.ones(1, all_emb.size(1), device=device, dtype=torch.long)

        out_ids = llm.generate(
            inputs_embeds=all_emb,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    finally:
        for m, was_training in zip(modules, prev_modes):
            m.train(was_training)

    # With only inputs_embeds, generate() typically returns just the new tokens,
    # which begin with a space; strip unconditionally.
    text = tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()
    if text.lower().startswith(prompt.lower()):
        text = text[len(prompt):].strip()
    return text


# ATTENTION HEATMAP HELPER

def save_attention_heatmap(img: Image.Image, step: int, max_queries: int = 4):
    """
    Plot the Q-Former's last-layer cross-attention of the first queries over ``img``.

    Runs the vision encoder + Q-Former on ``img``, averages the stored
    cross-attention over heads, lays each query's weights out on the
    ``grid_H`` x ``grid_W`` patch grid and overlays them on the image.

    Args:
        img: PIL image.
        step: training step, used in the output file name.
        max_queries: number of query tokens to plot (one panel each).
    Returns:
        Path of the saved PNG under logs/, or None when the Q-Former stored
        no attention weights.
    """
    img = img.convert("RGB")
    v = processor(images=img, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        vout = vision_model(pixel_values=v)
        vtoks = vout.last_hidden_state[:, 1:, :]   # same token selection as training
        _ = qformer(vtoks.to(torch.float32))
        attn = qformer.last_attn  # (B, heads, K, N)

    if attn is None:
        return None

    # Average over heads, take the first (only) image -> (K, N).
    attn = attn.mean(dim=1)[0].float()
    Kq, N = attn.shape

    # Place the N token weights on the patch grid (assumes patches are flattened
    # row-major). The slice above drops the first token, so N is typically one
    # short of the grid. The missing leading cell(s) are padded with NaN (drawn
    # transparent) instead of squeezing the weights into a smaller square, which
    # misaligned every row. Extra leading tokens (N larger than the grid) are dropped.
    n_cells = grid_H * grid_W
    if N >= n_cells:
        attn = attn[:, N - n_cells:]
    else:
        pad = attn.new_full((Kq, n_cells - N), float("nan"))
        attn = torch.cat([pad, attn], dim=1)
    attn = attn.reshape(Kq, grid_H, grid_W)

    num_q = min(max_queries, Kq)
    fig, axes = plt.subplots(1, num_q, figsize=(4 * num_q, 4), squeeze=False)
    axes = axes[0]

    img_np = np.array(img)
    h_img, w_img = img_np.shape[:2]
    # Stretch the coarse grid over the whole picture (same extent imshow uses for
    # the photo). This assumes the processor resizes the full image without a
    # crop; confirm if the processor changes.
    extent = (-0.5, w_img - 0.5, h_img - 0.5, -0.5)

    try:
        for qi, ax in enumerate(axes):
            heat = attn[qi].detach().cpu().numpy()
            lo, hi = np.nanmin(heat), np.nanmax(heat)
            heat = (heat - lo) / (hi - lo + 1e-8)

            ax.imshow(img_np)
            ax.imshow(heat, cmap="jet", alpha=0.45, extent=extent)
            ax.set_title(f"Query {qi}")
            ax.axis("off")

        fig.tight_layout()
        out_path = os.path.join("logs", f"attn_step_{step}.png")
        fig.savefig(out_path)
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails
    return out_path


# VALIDATION LOOP (console + jsonl + wandb)

val_pred_log_path = os.path.join(cache_dir, "val_predictions.jsonl")

def run_validation(global_step):
    """
    Evaluate on the whole val_loader and report the results.

    Steps:
      - Compute the batch-size-weighted mean validation loss (fp16 autocast).
      - Caption the first image of the first usable batch and print GT vs.
        prediction.
      - Append that pair to ``val_pred_log_path`` (JSONL) and save an
        attention heatmap.
      - If ``use_wandb``, log the loss, a one-row sample table, the heatmap and
        a histogram of the router weights from the last validation batch.

    ``model`` is switched back to train mode before returning.

    Args:
        global_step: step number used in the logs and as the W&B step.
    Returns:
        float: mean validation loss (0.0 if no batch was usable).
    """
    model.eval()
    loss_sum = 0.0
    n_examples = 0
    sample = None  # (image, ground truth, prediction) from the first usable batch

    with torch.no_grad():
        for batch in val_loader:
            if batch is None:
                continue

            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            with autocast("cuda", dtype=torch.float16):
                _, loss = model(input_ids, pixel_values, attn_mask)

            bs = input_ids.size(0)
            loss_sum += loss.item() * bs
            n_examples += bs

            if sample is None:
                first_img = batch["pil_images"][0]
                first_gt = batch["captions"][0]
                sample = (first_img, first_gt, generate_caption_for_image(first_img))

    avg_loss = loss_sum / max(1, n_examples)
    print(f"[VAL @ step {global_step}] loss={avg_loss:.4f}")

    example_img = example_gt = example_pred = None
    if sample is not None:
        example_img, example_gt, example_pred = sample
        print(f"[VAL] GT : {example_gt}")
        print(f"[VAL] PR : {example_pred}")

        # Append example GT+Pred to the jsonl log (best effort).
        try:
            with open(val_pred_log_path, "a", encoding="utf-8") as f:
                record = {"step": global_step, "gt": example_gt, "pred": example_pred}
                f.write(json.dumps(record) + "\n")
        except Exception as e:
            print("‚ö†Ô∏è Error writing val prediction log:", e)

    attn_path = save_attention_heatmap(example_img, global_step) if example_img is not None else None

    if use_wandb:
        import wandb
        log_dict = {"val_loss": avg_loss, "step": global_step}

        # Single-row table with GT+Pred
        table = wandb.Table(columns=["step", "image", "gt", "pred"])
        if example_img is not None:
            table.add_data(global_step, wandb.Image(example_img), example_gt, example_pred)
        log_dict["val_samples"] = table

        if attn_path is not None and os.path.exists(attn_path):
            log_dict["val_attention"] = wandb.Image(
                attn_path, caption=f"Q-Former attention @ step {global_step}"
            )

        # Router weights of the last validation batch, if the model recorded any.
        router_w = getattr(model, "last_router_weights", None)
        if router_w is not None:
            log_dict["router_weights"] = wandb.Histogram(router_w.detach().cpu().numpy())

        wandb.log(log_dict, step=global_step)

    model.train()
    return avg_loss


# TRAINING LOOP (advanced logs + validation + checkpoints)

print("üöÄ Training starting...")
best_val_loss = float("inf")
global_step = 0  # counts micro-batches; the optimizer steps every grad_accum of them

train_iter = iter(train_loader)
running_loss = 0.0  # sum of *unscaled* per-micro-batch losses
start_time = time.time()


def _save_trainable(tag):
    """Save Q-Former and projector state_dicts as checkpoints/{qformer,projector}_{tag}.pt.

    Returns the written paths (Q-Former first).
    """
    os.makedirs("checkpoints", exist_ok=True)
    paths = []
    for name, module in (("qformer", qformer), ("projector", projector)):
        path = f"checkpoints/{name}_{tag}.pt"
        torch.save(module.state_dict(), path)
        paths.append(path)
    return paths


while global_step < total_steps:
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        batch = next(train_iter)

    if batch is None:
        continue

    pixel_values = batch["pixel_values"].to(device)
    input_ids    = batch["input_ids"].to(device)
    attn_mask    = batch["attention_mask"].to(device)

    with autocast("cuda", dtype=torch.float16):
        _, loss = model(input_ids, pixel_values, attn_mask)

    # Backprop the accumulation-scaled loss, but track the raw per-batch loss so
    # logged train numbers are on the same scale as the validation loss.
    step_loss = loss.item()
    scaler.scale(loss / grad_accum).backward()
    running_loss += step_loss

    # Occasionally log caption lengths histogram
    if use_wandb and (global_step % 1000 == 0):
        import wandb
        cap_lengths = attn_mask.sum(dim=1).detach().cpu().tolist()
        wandb.log({"caption_lengths": wandb.Histogram(cap_lengths)}, step=global_step)

    grad_norm = None
    if (global_step + 1) % grad_accum == 0:
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(train_params, 1.0).item()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

        if use_wandb:
            import wandb
            wandb.log({"grad_norm": grad_norm}, step=global_step)

    steps_done = global_step + 1
    avg_train_loss = running_loss / steps_done

    # Training logs
    if global_step % 200 == 0:
        lr = scheduler.get_last_lr()[0]
        elapsed = time.time() - start_time
        eta_mins = (total_steps - steps_done) * (elapsed / steps_done) / 60.0

        print(
            f"[{global_step:05d}/{total_steps}] "
            f"loss={step_loss:.4f} (avg={avg_train_loss:.4f}) "
            f"lr={lr:.6e} grad_norm={grad_norm} "
            f"ETA={eta_mins:.1f} min"
        )

        if use_wandb:
            import wandb
            wandb.log(
                {
                    "train_loss": step_loss,
                    "train_loss_avg": avg_train_loss,
                    "lr": lr,
                    "eta_min": eta_mins,
                    "step": global_step,
                },
                step=global_step,
            )

    # Validation
    if steps_done % val_interval == 0:
        val_loss = run_validation(steps_done)
        if use_wandb:
            import wandb
            wandb.log({"train_val_gap": avg_train_loss - val_loss}, step=steps_done)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            try:
                saved = _save_trainable("best")
                print(f"‚úÖ Saved new best checkpoint at step {steps_done}, val_loss={val_loss:.4f}")
                if use_wandb:
                    import wandb
                    for path in saved:
                        wandb.save(path)
            except Exception as e:
                print("‚ö†Ô∏è Error while saving best checkpoint (ignored):", e)

    # Explicit snapshots at the configured milestones. Each milestone gets its
    # own file name, so the 5k snapshot is not overwritten at 10k.
    if steps_done in (save_step_5k, save_step_10k):
        label = f"{steps_done // 1000}k" if steps_done % 1000 == 0 else str(steps_done)
        try:
            saved = _save_trainable(f"step{label}")
            print(f"üíæ Saved {label}-step checkpoint.")
            if use_wandb:
                import wandb
                for path in saved:
                    wandb.save(path)
        except Exception as e:
            print(f"‚ö†Ô∏è Error while saving {label}-step checkpoint (ignored):", e)

    global_step += 1

print("üéâ Training Finished!")


# QUICK INFERENCE EXAMPLE (after training)

# Index a single random row: random.choice(list(val_ds)) would decode every
# validation image into memory just to pick one.
ex = val_ds[random.randrange(len(val_ds))]
img = ex["image"]
if isinstance(img, Image.Image):
    img = img.convert("RGB")
pred = generate_caption_for_image(img)
gt   = ex["caption"]
print("\n[INFERENCE EXAMPLE]")
print("GT:", gt)
print("Pred:", pred)
