In [None]:
# Shell out to the W&B CLI; `--relogin` asks for credentials again even if a login already exists.
!wandb login --relogin

In [None]:
# Authenticate with the Hugging Face Hub (interactive prompt).
# NOTE(review): presumably required because the Llama checkpoint loaded later is gated — confirm.
from huggingface_hub import login
login()

In [3]:
import os
# Same path that the caching cell refers to as `cache_dir`; that cell does not
# create the directory itself, so this cell has to run first.
LOCAL_DIR = "/content/livis_cache"
os.makedirs(LOCAL_DIR, exist_ok=True)  # exist_ok: re-running the cell is harmless

In [None]:
import os, math, random, io, requests, json, time, torch
from torch import nn
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset, Image as HFImage
from concurrent.futures import ThreadPoolExecutor, as_completed

# Directory holding cached "<idx>.jpg" / "<idx>.txt" pairs. It is not created in
# this cell; it must already exist (same path as LOCAL_DIR in the previous cell).
cache_dir        = "/content/livis_cache"
max_cached_items = 200      # max (train+val) cached image+caption pairs
train_size       = 180
val_size         = 20
# NOTE(review): max_cached_items equals train_size + val_size, so there is no
# slack: one failed download leaves fewer pairs than the size check below
# requires. Consider queueing a few extra items.


# CACHE LIVIS DATASET LOCALLY

print("Loading raw LIVIS dataset from Hugging Face...")
# Only the "train" split is loaded; examples are read below via .get("url"),
# .get("short_caption") and .get("caption").
raw = load_dataset("laion/220k-GPT4Vision-captions-from-LIVIS", split="train")
print("Raw dataset size:", len(raw))

def extract_caption(ex):
    """
    Build the caption string for one dataset example.

    Prefers ``short_caption``; falls back to ``caption``; falls back to the
    literal "An image." when both are missing, empty, or whitespace-only.
    The result is always prefixed with 'Short caption:' so the model learns
    consistent prompting. (Matches inference prompt.)

    Args:
        ex: mapping-like example supporting ``.get(key)``.

    Returns:
        str: ``"Short caption: <text>"`` with surrounding whitespace stripped
        from ``<text>``.
    """
    for key in ("short_caption", "caption"):
        text = (ex.get(key) or "").strip()
        if text:
            return f"Short caption: {text}"
    # Neither field carried usable text (previously a blank `caption` produced
    # an empty caption after the prefix).
    return "Short caption: An image."

# Parallel lists of (image path, caption) pairs already usable from the cache.
paths_ok = []
caps_ok  = []
jobs     = []  # (idx, url, caption) to be downloaded

print(f"Scanning dataset to reuse cache & queue missing images (target {max_cached_items})...")

for idx, ex in enumerate(raw):
    # Stop as soon as reused + queued items reach the target count.
    if len(paths_ok) + len(jobs) >= max_cached_items:
        break

    url = ex.get("url")
    if not url:
        continue

    cap = extract_caption(ex)
    # Cache files are keyed by the example's position in the raw dataset.
    img_path = os.path.join(cache_dir, f"{idx:06d}.jpg")
    txt_path = os.path.join(cache_dir, f"{idx:06d}.txt")

    if os.path.exists(img_path):
        # Try to reuse existing image; if broken, schedule for re-download
        try:
            # Context manager closes the file handle right after verification
            # instead of leaving it to the garbage collector.
            with Image.open(img_path) as cached_img:
                cached_img.verify()
            if not os.path.exists(txt_path):
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(cap)
            paths_ok.append(img_path)
            caps_ok.append(cap)
        except Exception:
            jobs.append((idx, url, cap))
        continue

    # Missing: schedule for download
    jobs.append((idx, url, cap))

print(f"Already valid images (reused): {len(paths_ok)}")
print(f"Images queued for download:   {len(jobs)}")

# Clip job list so final count <= max_cached_items.
# The loop above already breaks once the target is reached, so this is only a
# safety net in case the loop condition is changed later.
if len(paths_ok) > max_cached_items:
    paths_ok = paths_ok[:max_cached_items]
    caps_ok  = caps_ok [:max_cached_items]
    jobs = []
elif len(paths_ok) + len(jobs) > max_cached_items:
    jobs = jobs[: max_cached_items - len(paths_ok)]

print(f"Final plan -> keep {len(paths_ok)} existing + download {len(jobs)} new.")

def fetch_one(job):
    """
    Worker for parallel download.

    Args:
        job: tuple ``(idx, url, caption)``; ``idx`` determines the cache file
            names ``<idx:06d>.jpg`` / ``<idx:06d>.txt`` inside ``cache_dir``.

    Returns:
        ``(img_path, caption)`` on success, or ``None`` if the image could not
        be downloaded or decoded. Failures are printed, never raised, so one
        bad URL does not abort the whole pool.
    """
    idx, url, cap = job
    img_path = os.path.join(cache_dir, f"{idx:06d}.jpg")
    txt_path = os.path.join(cache_dir, f"{idx:06d}.txt")

    # If already exists and ok, just ensure caption
    if os.path.exists(img_path):
        try:
            # Close the handle deterministically after verification.
            with Image.open(img_path) as cached_img:
                cached_img.verify()
            if not os.path.exists(txt_path):
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(cap)
            return (img_path, cap)
        except Exception as e:
            # Fall through to the download path below.
            print(f"[Error] Existing image {img_path} failed verification: {e}. Re-downloading.")

    try:
        r = requests.get(url, timeout=7)
        # Surface HTTP errors (404/403/5xx) as such instead of handing an
        # error page to PIL and reporting a confusing decode failure.
        r.raise_for_status()
        with Image.open(io.BytesIO(r.content)) as downloaded:
            img = downloaded.convert("RGB")  # normalise mode so JPEG saving works
        img.save(img_path, "JPEG", quality=92)

        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(cap)

        return (img_path, cap)
    except Exception as e:
        print(f"[Error] Failed to download or process image {url} (idx {idx}): {e}")
        return None

if jobs:
    print("üöÄ Parallel downloading with workers...")
    max_workers = 10
    done = 0
    good = 0
    fetched = []  # (idx, img_path, caption) for every successful download
    # Report roughly ten times per run. The previous fixed interval of 500
    # never fired because at most max_cached_items jobs are queued.
    progress_every = max(1, len(jobs) // 10)

    with ThreadPoolExecutor(max_workers=max_workers) as ex_pool:
        # Map each future back to its dataset index.
        futures = {ex_pool.submit(fetch_one, j): j[0] for j in jobs}
        for fut in as_completed(futures):
            res = fut.result()  # fetch_one catches its own errors and returns None
            done += 1
            if res is not None:
                p, c = res
                fetched.append((futures[fut], p, c))
                good += 1
            if done % progress_every == 0 or done == len(jobs):
                print(f"[Download progress] finished={done}/{len(jobs)}, successful={good}")

    # as_completed yields in completion order, which varies between runs.
    # Sort by dataset index so the later train/val slicing is reproducible.
    fetched.sort(key=lambda item: item[0])
    paths_ok.extend(p for _idx, p, _c in fetched)
    caps_ok.extend(c for _idx, _p, c in fetched)

print(f"‚úÖ Cache ready. Total valid pairs: {len(paths_ok)}")

# Fail fast when the cache cannot cover the requested train + val sizes.
if len(paths_ok) < (train_size + val_size):
    raise RuntimeError(
        f"Not enough cached images+captions ({len(paths_ok)}) "
        f"for requested train+val={train_size+val_size}."
    )

# Trim to exactly train+val (paths_ok and caps_ok stay index-aligned)
paths_ok = paths_ok[:train_size + val_size]
caps_ok  = caps_ok [:train_size + val_size]

# NOTE(review): no dataset-building code follows this message in the visible
# cell — confirm whether the rest of the cell was removed or lives elsewhere.
print("Building local HF dataset from cached files...")

In [None]:
# Run: eren23/blip2-llama-mode-livis-50k / 0i2u1k0s

import wandb
import os


# Configure
# Entity / project / run id are combined into the run path "entity/project/run_id" below.
WANDB_ENTITY = "eren23"
WANDB_PROJECT = "blip2-llama-mode-livis-50k"
WANDB_RUN_ID = "0i2u1k0s"

# Local folder that receives the downloaded files; created if missing.
OUT_DIR = "/content/wandb_weights"
os.makedirs(OUT_DIR, exist_ok=True)

print("‚û°Ô∏è Logging in to W&B...")
wandb.login()   # Will prompt you for API key in Colab

# Connect to the run through the public API client
api = wandb.Api()
run_path = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{WANDB_RUN_ID}"
run = api.run(run_path)

print(f"üì¶ Found run: {run.name}")
print("üîç Searching for .pt model files...\n")


# Download all .pt files from artifacts + files tab

downloaded = []

# if you have your weights in the files tab
for f in run.files():
    if f.name.endswith(".pt"):
        print(f"‚¨áÔ∏è Downloading file_tab weight: {f.name}")
        f.download(root=OUT_DIR, replace=True)
        downloaded.append(f.name)

# if you have your weights in the artifacts tab
for artifact in run.logged_artifacts():
    art = api.artifact(artifact.id)
    art_files = art.files()

    for af in art_files:
        if af.name.endswith(".pt"):
            print(f"‚¨áÔ∏è Downloading artifact weight: {af.name}")
            target_dir = os.path.join(OUT_DIR, art.name)
            os.makedirs(target_dir, exist_ok=True)
            af.download(root=target_dir)
            downloaded.append(os.path.join(art.name, af.name))

print("\n===================================================")
print("üì• Download complete")
print("Saved .pt files:")
for x in downloaded:
    print("   ‚Ä¢", x)
print("‚û°Ô∏è Output folder:", OUT_DIR)
print("===================================================")


In [None]:
# ============================================================
#  BLIP-2 + MoE Projector ‚Äî Inference + Router + Attn Viz
# ============================================================

import os, math, json
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    SiglipVisionModel,
)


# CONFIG

# Everything below is moved to this device; there is no CPU fallback in this cell.
device = "cuda"
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for CUDA matmuls

LLM_NAME    = "meta-llama/Llama-3.2-1B"
VISION_NAME = "google/siglip-so400m-patch14-384"

# paths to weights you downloaded from W&B
QFORMER_CKPT   = "/content/wandb_weights/checkpoints/qformer_step10k.pt"
PROJECTOR_CKPT = "/content/wandb_weights/checkpoints/projector_step10k.pt"

# folder with images to run inference on
IMG_DIR = "/content/livis_cache"

# Prefix fed to the LLM after the visual tokens; stripped from the decoded text if present.
PROMPT_PREFIX = "Short caption: "
# NOTE(review): this rebinds OUT_DIR, which the W&B download cell set to
# "/content/wandb_weights"; re-running that cell afterwards changes where this
# cell's outputs go. A distinct name would avoid the hidden-state coupling.
OUT_DIR = "/content/moe_out"
os.makedirs(OUT_DIR, exist_ok=True)

print("Device:", device)
print("Q-Former ckpt:", QFORMER_CKPT)
print("Projector ckpt:", PROJECTOR_CKPT)
print("Image dir:", IMG_DIR)
print("Output dir:", OUT_DIR)


# LOAD LLM

# Tokenizer + causal LM used for caption generation.
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
# Reuse EOS as the padding token when the tokenizer defines none (the prompt is tokenized with padding=True).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Loaded in float16; device placement is delegated to device_map="auto".
llm = AutoModelForCausalLM.from_pretrained(
    LLM_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)
llm.eval()
# Width of the LLM embeddings; the Q-Former and projector are built with the same width.
d_model = llm.config.hidden_size
print("LLM hidden size:", d_model)


# LOAD VISION ENCODER

# Image preprocessor + SigLIP vision tower (float16, moved to `device`).
processor = AutoProcessor.from_pretrained(VISION_NAME)
vision = SiglipVisionModel.from_pretrained(
    VISION_NAME,
    torch_dtype=torch.float16,
).to(device)
vision.eval()
# Width of the vision features; used as the Q-Former input size and the router input size.
d_vis = vision.config.hidden_size
print("SigLIP hidden size:", d_vis)


# Q-FORMER BLOCK WITH ATTENTION OUTPUT

class QFormerBlock(nn.Module):
    """
    One Q-Former layer: three residual sub-layers applied in sequence.

    Sub-module names and order are kept so checkpoint keys line up with training:
      ln1 -> self_attn   (queries attend to queries)
      ln2 -> cross_attn  (queries attend to visual tokens; ln2 is also applied to v)
      ln3 -> mlp
    """
    def __init__(self, d_model, n_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        ffn_width = int(d_model * mlp_ratio)

        # Creation order is unchanged: ln1, self_attn, ln2, cross_attn, ln3, mlp.
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        self.ln3 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ffn_width),
            nn.GELU(),
            nn.Linear(ffn_width, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, q, v, need_weights=False):
        """
        Args:
            q: query states, batch-first (the attention layers use batch_first=True).
            v: visual tokens already projected to the same width as q.
            need_weights: when True, also return per-head attention maps.

        Returns:
            (q, self_w, cross_w); both weight entries are None unless
            need_weights is True.
        """
        attn_opts = {"need_weights": need_weights, "average_attn_weights": False}

        # Self-attention: the norm is evaluated separately for each argument,
        # exactly as before, so the attention layer receives three distinct tensors.
        sa_out, self_w = self.self_attn(self.ln1(q), self.ln1(q), self.ln1(q), **attn_opts)
        q = q + sa_out

        # Cross-attention: queries come from q, keys/values from v (both through ln2).
        ca_out, cross_w = self.cross_attn(self.ln2(q), self.ln2(v), self.ln2(v), **attn_opts)
        q = q + ca_out

        # Feed-forward
        q = q + self.mlp(self.ln3(q))

        if not need_weights:
            return q, None, None
        # self_w, cross_w: (B, H, Q, N/Q)
        return q, self_w, cross_w


# Q-FORMER

class QFormer(nn.Module):
    """
    Stack of QFormerBlock layers driven by a learned set of query vectors.

    Attribute names mirror the training checkpoint:
      - query
      - vis_proj
      - layers.{i}.*
      - final_ln

    forward() returns:
      q:          (B, K, d_model)
      self_atts:  list[L] of (B, H, K, K)   (empty unless collect_attn)
      cross_atts: list[L] of (B, H, K, N)   (empty unless collect_attn)
    """
    def __init__(self, d_vis, d_model,
                 n_queries=32,
                 n_layers=8,
                 heads=8,
                 mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()

        # Learned queries, then the linear map from vision width to model width.
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.vis_proj = nn.Linear(d_vis, d_model)

        blocks = [QFormerBlock(d_model, heads, mlp_ratio, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)

        self.final_ln = nn.LayerNorm(d_model)

    def forward(self, vis_tokens, collect_attn: bool = False):
        """
        vis_tokens: (B, N, d_vis); cast to float32 before projection.
        collect_attn: when true, gather every layer's attention maps.
        """
        batch = vis_tokens.size(0)
        visual = self.vis_proj(vis_tokens.to(torch.float32))      # (B, N, d_model)
        queries = self.query.unsqueeze(0).expand(batch, -1, -1)   # (B, K, d_model)

        want_maps = bool(collect_attn)
        self_atts, cross_atts = [], []

        for layer in self.layers:
            queries, sw, cw = layer(queries, visual, need_weights=want_maps)
            if want_maps:
                self_atts.append(sw)   # (B, H, K, K)
                cross_atts.append(cw)  # (B, H, K, N)

        return self.final_ln(queries), self_atts, cross_atts


# MoE PROJECTOR

class MoEProjector(nn.Module):
    """
    Image-level mixture-of-experts projector (4 experts by default).

    Every expert processes all query tokens; a router fed with one vector per
    image produces softmax weights that blend the expert outputs (dense
    mixture, no top-k selection).

    Expert layout (indices matter for checkpoint keys):
      0: LayerNorm(d_model)
      1: Linear(d_model -> 4*d_model)
      2: GELU
      3: Dropout
      4: Linear(4*d_model -> d_model)
      5: Dropout
      6: LayerNorm(d_model)

    Router layout:
      0: LayerNorm(d_router_in)
      1: Linear(d_router_in -> 576)
      2: GELU
      3: Linear(576 -> num_experts)
    """
    def __init__(self, d_model=2048, d_router_in=1152, num_experts=4, dropout=0.1):
        super().__init__()
        self.num_experts = num_experts
        wide = d_model * 4  # 8192 for d_model=2048

        def build_expert():
            return nn.Sequential(
                nn.LayerNorm(d_model),       # 0
                nn.Linear(d_model, wide),    # 1
                nn.GELU(),                   # 2
                nn.Dropout(dropout),         # 3
                nn.Linear(wide, d_model),    # 4
                nn.Dropout(dropout),         # 5
                nn.LayerNorm(d_model),       # 6
            )

        # Experts are created first, then the router (same order as before).
        self.experts = nn.ModuleList([build_expert() for _ in range(num_experts)])

        # Router over the per-image token
        self.router = nn.Sequential(
            nn.LayerNorm(d_router_in),       # router.0
            nn.Linear(d_router_in, 576),     # router.1
            nn.GELU(),                       # router.2
            nn.Linear(576, num_experts),     # router.3
        )

    def forward(self, q, cls_token):
        """
        q:         (B, K, d_model)   Q-Former outputs
        cls_token: (B, d_router_in)  per-image routing vector

        Returns (out, weights, logits):
          out:     (B, K, d_model) weighted sum of expert outputs
          weights: (B, E) softmax of logits
          logits:  (B, E) raw router scores
        Both inputs are cast to float32 first.
        """
        route_in = cls_token.to(torch.float32)
        tokens = q.to(torch.float32)

        logits = self.router(route_in)              # (B, E)
        weights = logits.softmax(dim=-1)            # (B, E)

        # (E, B, K, d_model): every expert sees every token
        per_expert = torch.stack([expert(tokens) for expert in self.experts], dim=0)

        # (B, E) -> (E, B, 1, 1) so it broadcasts over tokens and channels
        gate = weights.t()[:, :, None, None]

        mixed = (gate * per_expert).sum(dim=0)      # (B, K, d_model)
        return mixed, weights, logits


# LOAD Q-FORMER + PROJECTOR WEIGHTS

# Build the Q-Former with default hyper-parameters (32 queries, 8 layers, 8 heads)
# and load its trained weights; load_state_dict is strict by default, so any
# key or shape mismatch with the checkpoint raises here.
# NOTE(review): torch.load unpickles the file. These checkpoints were fetched
# from W&B; consider passing weights_only=True if the installed torch version
# supports it and the files are plain state dicts.
qformer = QFormer(d_vis=d_vis, d_model=d_model).to(device)
qformer.load_state_dict(torch.load(QFORMER_CKPT, map_location=device))
qformer.eval()

# Router input width is the vision hidden size because the router is fed a vision token.
projector = MoEProjector(d_model=d_model, d_router_in=d_vis, num_experts=4).to(device)
projector.load_state_dict(torch.load(PROJECTOR_CKPT, map_location=device))
projector.eval()

print("Loaded Q-Former and MoEProjector.")


# ATTENTION UTILS

def tokens_to_grid(attn_1d: torch.Tensor):
    """
    Reshape attention over N tokens into an (h, w) grid with h * w == N.

    h is the largest divisor of N that does not exceed sqrt(N), so the grid is
    as close to square as N allows (a prime N yields a single row).

    Args:
        attn_1d: tensor holding N attention values; it is flattened first, so
            any shape is accepted.

    Returns:
        Tensor of shape (h, w).

    Raises:
        ValueError: if the tensor is empty (previously a ZeroDivisionError).
    """
    flat = attn_1d.reshape(-1)
    N = flat.numel()
    if N == 0:
        raise ValueError("tokens_to_grid: attention tensor is empty")

    h = int(math.sqrt(N))
    # Walk down to the nearest divisor; h == 1 always divides N.
    while h > 1 and (N % h) != 0:
        h -= 1
    w = N // h
    return flat.reshape(h, w)

def upsample_to_image(grid: np.ndarray, img_np: np.ndarray):
    """
    Bilinearly upsample a (h x w) attention grid to the (H x W) resolution of an image.

    Args:
        grid: 2-D array of attention values.
        img_np: image array whose first two dimensions are (H, W).

    Returns:
        np.ndarray of shape (H, W), float32.

    Raises:
        ValueError: if ``grid`` is not 2-D.
    """
    if grid.ndim != 2:
        raise ValueError(f"grid must be 2-D, got shape {grid.shape}")
    H, W = img_np.shape[:2]

    t = torch.from_numpy(grid).float().unsqueeze(0).unsqueeze(0)  # (1,1,h,w)
    t_up = F.interpolate(t, size=(H, W), mode="bilinear", align_corners=False)
    # Index the batch/channel dims explicitly: .squeeze() would also drop a
    # spatial dimension of size 1 (images that are one pixel high or wide).
    return t_up[0, 0].numpy()  # (H, W)

def _attention_heatmap(attn_1d: torch.Tensor, img_np: np.ndarray):
    """Min-max normalise one attention vector and upsample it to the image size."""
    grid = tokens_to_grid(attn_1d).detach().cpu().numpy()
    # The epsilon keeps a constant map from dividing by zero.
    grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)
    return upsample_to_image(grid, img_np)


def save_cross_attention_overlays(
    img: Image.Image,
    cross_attn_last: torch.Tensor,
    fname_prefix: str,
    out_dir: str,
    num_queries_vis: int = 4,
):
    """
    Render cross-attention heatmaps on top of an image and save them as PNGs.

    Args:
        img: source image (converted to RGB internally).
        cross_attn_last: (H, K, N) attention after dropping the batch dim
            (H=heads, K=queries, N=vis tokens).
        fname_prefix: prefix for the output file names.
        out_dir: directory for the PNGs (created if missing).
        num_queries_vis: how many leading queries get their own panel.

    Saves:
      - per-query overlay PNG: <fname_prefix>_attn_queries.png
      - averaged overlay PNG:  <fname_prefix>_attn_mean.png

    Returns:
        dict with keys "per_query" and "global" holding the two file paths.
    """
    os.makedirs(out_dir, exist_ok=True)
    img_np = np.array(img.convert("RGB"))
    _, K, _ = cross_attn_last.shape  # also rejects anything that is not 3-D

    # Average over heads: (K, N)
    attn_q = cross_attn_last.mean(dim=0)

    num_q = min(num_queries_vis, K)
    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # needs no special handling.
    fig, axes = plt.subplots(
        1, num_q + 1, figsize=(3 * (num_q + 1), 3), squeeze=False
    )
    axes = axes[0]

    # Column 0: original image
    axes[0].imshow(img_np)
    axes[0].set_title("Original")
    axes[0].axis("off")

    # Per-query overlays
    for qi in range(num_q):
        ax = axes[qi + 1]
        ax.imshow(img_np)
        ax.imshow(_attention_heatmap(attn_q[qi], img_np), cmap="jet", alpha=0.42)
        ax.set_title(f"Query {qi}")
        ax.axis("off")

    fig.tight_layout()
    out_path_queries = os.path.join(out_dir, f"{fname_prefix}_attn_queries.png")
    fig.savefig(out_path_queries, dpi=140)
    plt.close(fig)  # many figures are produced in a loop; free each one

    # Mean over queries as well -> global map
    fig_g, ax_g = plt.subplots(figsize=(4, 4))
    ax_g.imshow(img_np)
    ax_g.imshow(_attention_heatmap(attn_q.mean(dim=0), img_np), cmap="jet", alpha=0.42)
    ax_g.set_title("Global cross-attn (mean over heads & queries)")
    ax_g.axis("off")
    out_path_global = os.path.join(out_dir, f"{fname_prefix}_attn_mean.png")
    fig_g.tight_layout()
    fig_g.savefig(out_path_global, dpi=140)
    plt.close(fig_g)

    return {
        "per_query": out_path_queries,
        "global": out_path_global,
    }


# CORE INFERENCE: MoE + ATTENTION

@torch.no_grad()
def generate_moe_caption_with_attn(
    img: Image.Image,
    fname: str,
    max_new_tokens: int = 64,
):
    """
    Caption one image and save its routing / attention diagnostics.

    Pipeline:
      - SigLIP -> token 0 (used as the routing vector) + remaining tokens
      - Q-Former over the remaining tokens (attention maps collected)
      - MoE projector (router weights)
      - LLaMA greedy generation from [projected queries ; prompt embeddings]
      - router bar plot + cross-attention overlays written to OUT_DIR

    NOTE(review): token 0 is treated as a CLS token. Hugging Face documents the
    SigLIP vision tower without a CLS token, so this may really be the first
    patch token — confirm against the training code before changing anything.

    Returns:
      caption: str
      weights: np.ndarray (E,)
      top_expert: int
      attn_paths: dict with overlay paths
    """
    rgb = img.convert("RGB")
    pixel_values = processor(images=rgb, return_tensors="pt")["pixel_values"].to(device)

    # Vision forward, then split the sequence as described above.
    hidden = vision(pixel_values=pixel_values).last_hidden_state   # (1, 1+N, d_vis)
    cls_token = hidden[:, 0, :]                                    # (1, d_vis)
    vis_tokens = hidden[:, 1:, :]                                  # (1, N, d_vis)

    # Q-Former; only the last layer's cross-attention is visualised.
    q, _self_atts, cross_atts = qformer(vis_tokens, collect_attn=True)
    last_cross = cross_atts[-1][0]                                 # (H, K, N), batch dropped

    # Mixture-of-experts projection + routing statistics for this image.
    q_proj, weights, _logits = projector(q, cls_token)
    weights_np = weights[0].detach().cpu().numpy()
    top_expert = int(weights_np.argmax())

    # Prompt embeddings are appended after the visual tokens.
    prompt_tok = tokenizer([PROMPT_PREFIX], return_tensors="pt", padding=True).to(device)
    txt_emb = llm.get_input_embeddings()(prompt_tok["input_ids"])  # (1, T, d_model)
    all_emb = torch.cat([q_proj.to(llm.dtype), txt_emb], dim=1)    # (1, K+T, d_model)
    attn_mask = torch.ones(all_emb.shape[:2], device=device, dtype=torch.long)

    generate_kwargs = dict(
        inputs_embeds=all_emb,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    gen_ids = llm.generate(**generate_kwargs)

    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    if text.startswith(PROMPT_PREFIX):
        text = text[len(PROMPT_PREFIX):].strip()

    # Router bar plot (one bar per expert).
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.bar(np.arange(weights_np.shape[0]), weights_np)
    ax.set_xlabel("Expert ID")
    ax.set_ylabel("Weight")
    ax.set_title(f"{fname} ‚Üí Expert {top_expert}")
    fig.tight_layout()
    router_plot_path = os.path.join(OUT_DIR, f"router_{fname}.png")
    fig.savefig(router_plot_path, dpi=140)
    plt.close(fig)

    # Attention overlays for the first four queries plus the global mean.
    attn_paths = save_cross_attention_overlays(
        rgb,
        last_cross,
        fname_prefix=fname,
        out_dir=OUT_DIR,
        num_queries_vis=4,
    )

    return text, weights_np, top_expert, attn_paths


# RUN ON FOLDER + LOGGING + GLOBAL HISTOGRAM

# Index-aligned lists: images[i] was loaded from IMG_DIR/paths[i].
images = []
paths = []

# os.listdir returns entries in arbitrary order; sort so logs, plots and the
# JSON output come out in the same order on every run.
for fname in sorted(os.listdir(IMG_DIR)):
    # The cache folder also holds caption .txt files; keep image files only.
    if fname.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
        fpath = os.path.join(IMG_DIR, fname)
        try:
            # convert() returns a fully loaded copy, so the file handle can be
            # closed as soon as the with-block ends.
            with Image.open(fpath) as fh:
                img = fh.convert("RGB")
            images.append(img)
            paths.append(fname)
        except Exception as e:
            # Best effort: an unreadable file is reported and skipped.
            print(f"‚ö†Ô∏è Skipping {fname}: {e}")

if not images:
    print("‚ö†Ô∏è No images found in IMG_DIR. Put some images into:", IMG_DIR)
else:
    print(f"Found {len(images)} images. Running MoE inference...")

router_hist = []   # top expert id per image, in processing order
expert_logs = {}   # file name -> caption / router weights / overlay paths

for fname, img in zip(paths, images):
    # Each call also writes the router bar plot and attention overlays to OUT_DIR.
    cap, w, top, attn_paths = generate_moe_caption_with_attn(img, fname)

    router_hist.append(top)
    expert_logs[fname] = {
        "caption": cap,
        "weights": w.tolist(),  # ndarray -> list so json.dump can serialise it
        "top_expert": int(top),
        "attn_overlays": attn_paths,
    }

    print(f"\nImage: {fname}")
    print(f"  Caption    : {cap}")
    print(f"  Top expert : {top}")
    print(f"  Weights    : {w}")

# Save JSON log (written even when no image was processed; it is then an empty object)
log_path = os.path.join(OUT_DIR, "moe_router_log.json")
with open(log_path, "w", encoding="utf-8") as f:
    json.dump(expert_logs, f, indent=2)

# Global router histogram: how often each expert was the top choice.
if router_hist:
    # Take the expert count from the loaded projector instead of hard-coding
    # 4, so bins and ticks stay correct if the projector is built differently.
    n_experts = projector.num_experts
    plt.figure(figsize=(5, 3))
    # Bin edges at -0.5, 0.5, ... centre each bar on its integer expert id.
    plt.hist(router_hist, bins=np.arange(n_experts + 1) - 0.5, rwidth=0.8)
    plt.xticks(np.arange(n_experts))
    plt.xlabel("Expert ID")
    plt.ylabel("Count")
    plt.title("Router Histogram ‚Äî Expert Usage")
    plt.tight_layout()
    hist_path = os.path.join(OUT_DIR, "router_hist.png")
    plt.savefig(hist_path, dpi=140)
    plt.close()
    print("\nRouter histogram saved to:", hist_path)

# Final summary: number of logged images and where the outputs were written.
print("\n===================================================")
print("   MoE Routing + Attention Summary")
print("===================================================")
print("Images processed:", len(expert_logs))
print("Log JSON        :", log_path)
print("Output directory:", OUT_DIR)
