# ECE1512 Project A - Part B / Section 5
## CLIP (ViT-B/16) Vision Token Pruning - Efficiency Toy Profiling

In [None]:
!pip -q install transformers==4.44.2 timm==1.0.9

import os, time, csv, math, random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import CLIPVisionModel, CLIPImageProcessor

# Reproducibility: seed the Python, NumPy and PyTorch RNGs with one value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    torch.cuda.manual_seed_all(SEED)
# Disable cuDNN autotuning and request deterministic kernels (run-to-run stability over speed).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

DEVICE = "cuda" if _cuda_ok else "cpu"
print("Device:", DEVICE)
if DEVICE == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

### Load CLIP Vision & Define Pruning Helpers

In [None]:
from transformers import CLIPVisionModel, CLIPImageProcessor

VISION_ID = "openai/clip-vit-base-patch16"

# Vision tower in eval (inference) mode on DEVICE, plus its image preprocessor.
vision: CLIPVisionModel = CLIPVisionModel.from_pretrained(VISION_ID)
vision = vision.to(DEVICE).eval()
processor = CLIPImageProcessor.from_pretrained(VISION_ID)

# ===========================================
# Synthetic batch of images (efficiency focus).
# Real images could be substituted later (e.g. preprocessed with `processor`).
# ===========================================
BATCH = 8    # images per batch
RES = 224    # input side length; CLIP default
IMAGE_SHAPE = (BATCH, 3, RES, RES)
DUMMY_IMAGES = torch.randn(IMAGE_SHAPE, device=DEVICE)  # standard-normal noise, not real images

# ===========================================
# Patch embedding -> [B, N, D] tokens (no class/pos)
# ===========================================
def get_patch_tokens(model: CLIPVisionModel, pixel_values: torch.Tensor) -> torch.Tensor:
    """
    Return patch tokens BEFORE the class token and positional embeddings are added.

    Args:
        model: CLIP vision model; only `vision_model.embeddings.patch_embedding`
            (a Conv2d) is used.
        pixel_values: images [B, 3, H, W].

    Returns:
        Patch tokens [B, N, D], where N = H' * W' is the conv output grid
        flattened row-major and D is the conv's output channel count.
    """
    patch_embedding = model.vision_model.embeddings.patch_embedding
    # Match the conv's dtype (as the HF CLIP embedding layer does) so a
    # half-precision model accepts float32 pixels instead of raising a dtype error.
    pixel_values = pixel_values.to(dtype=patch_embedding.weight.dtype)
    x = patch_embedding(pixel_values)     # Conv2d: [B, 3, H, W] -> [B, D, H', W']
    return x.flatten(2).transpose(1, 2)   # [B, N, D]

# ===========================================
# Token pruning by L2-norm scoring
# Keep ratio p in (0,1]
# ===========================================
def prune_tokens(x_tokens: torch.Tensor, keep_ratio: float = 0.7, return_indices: bool = False):
    """
    Keep the K = max(1, int(N * keep_ratio)) tokens with the largest L2 norm and
    prepend a mean "summary" token computed over ALL N input tokens.

    Kept tokens are returned in their original (raster) order, not score order,
    so downstream positional handling sees the surviving patches in the same
    relative order as in the full sequence.

    Args:
        x_tokens: patch tokens [B, N, D].
        keep_ratio: fraction of tokens to keep, in (0, 1].
        return_indices: if True, also return the kept token indices [B, K]
            (ascending). These allow looking up each token's original
            positional embedding; in the Hugging Face CLIP layout the class
            token occupies position 0, so patch i corresponds to position i + 1.

    Returns:
        x_pruned: [B, 1+K, D] (summary token first), or
        (x_pruned, kept_idx) when `return_indices` is True.

    Raises:
        ValueError: if keep_ratio is not in (0, 1] (including NaN).
    """
    # `not (a < x <= b)` is also True for NaN, so NaN is rejected too.
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError(f"keep_ratio must be in (0, 1], got {keep_ratio}")
    B, N, D = x_tokens.shape
    K = max(1, int(N * keep_ratio))
    scores = x_tokens.norm(dim=-1)                                         # [B, N]
    # topk yields indices in score order; sort them to restore spatial order.
    kept_idx = scores.topk(K, dim=1).indices.sort(dim=1).values            # [B, K]
    x_kept = x_tokens.gather(1, kept_idx.unsqueeze(-1).expand(-1, -1, D))  # [B, K, D]
    summary = x_tokens.mean(dim=1, keepdim=True)                           # [B, 1, D]
    x_pruned = torch.cat([summary, x_kept], dim=1)                         # [B, 1+K, D]
    if return_indices:
        return x_pruned, kept_idx
    return x_pruned

# ===========================================
# Baseline forward: the model's own full forward()
# ===========================================
@torch.no_grad()
def forward_baseline(model: CLIPVisionModel, pixel_values: torch.Tensor):
    """
    Run the unmodified CLIP vision forward (embeddings + encoder) without autograd.

    Returns whatever `model(pixel_values=...)` returns.
    """
    outputs = model(pixel_values=pixel_values)
    return outputs

# ===========================================
# Pruned forward: run encoder on a reduced sequence
# (optionally adding positional embeddings for a better approximation).
# HF's CLIPEncoder.forward names its first argument `inputs_embeds`, not
# `hidden_states`, so the sequence is passed positionally.
# ===========================================
@torch.no_grad()
def forward_pruned_encoder_only(
    model: CLIPVisionModel,
    x_pruned: torch.Tensor,
    add_positional: bool = True,
    position_ids: "torch.Tensor | None" = None,
    return_pooled: bool = False,
):
    """
    Run the CLIP vision encoder directly on a reduced token sequence.

    Steps: (optional) positional embeddings -> pre_layrnorm -> encoder. As in
    the Hugging Face CLIP vision forward, post_layernorm is applied only to the
    first token (the pooled output), never to the whole sequence. The returned
    hidden states are therefore comparable to the baseline's `last_hidden_state`.

    Args:
        model: CLIP vision model whose `vision_model` submodules are reused.
        x_pruned: token sequence [B, L, D] (e.g. from `prune_tokens`); the
            caller is responsible for matching the model's device/dtype.
        add_positional: add positional embeddings before the encoder.
        position_ids: optional integer indices into the position-embedding
            table, shape [L] or [B, L]. When given, each token receives the
            embedding of its own original position. When None, the first L
            table rows (positions 0..L-1) are used, which only approximates
            the original layout after pruning.
        return_pooled: if True, also return post_layernorm(last_hidden[:, 0]).

    Returns:
        Last hidden states [B, L, D], or (last_hidden, pooled [B, D]) when
        `return_pooled` is True.

    Raises:
        ValueError: if positional embeddings are taken by slicing and L exceeds
            the number of rows in the position-embedding table.
    """
    vm = model.vision_model  # CLIPVisionTransformer
    hidden_states = x_pruned  # [B, L, D]

    if add_positional and hasattr(vm, "embeddings") and hasattr(vm.embeddings, "position_embedding"):
        pos_weight = vm.embeddings.position_embedding.weight  # [num_positions, D]
        if position_ids is not None:
            # Per-token lookup: [L, D] or [B, L, D]; both broadcast against [B, L, D].
            pos = pos_weight[position_ids.to(pos_weight.device)]
        else:
            seq_len = hidden_states.size(1)
            if seq_len > pos_weight.size(0):
                raise ValueError(
                    f"sequence length {seq_len} exceeds the {pos_weight.size(0)} available "
                    "positional embeddings; pass position_ids or set add_positional=False"
                )
            # Approximation: positions 0..L-1 regardless of which patches were kept.
            pos = pos_weight[:seq_len]  # [L, D], broadcasts over the batch
        hidden_states = hidden_states + pos.to(device=hidden_states.device, dtype=hidden_states.dtype)

    # `pre_layrnorm` (sic) is the attribute name HF CLIP uses; apply it when present.
    if getattr(vm, "pre_layrnorm", None) is not None:
        hidden_states = vm.pre_layrnorm(hidden_states)

    out = vm.encoder(
        hidden_states,  # positional on purpose (see header comment)
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    last_hidden = out.last_hidden_state  # [B, L, D]

    if not return_pooled:
        return last_hidden

    # Pooled output: first token (summary/CLS slot), normalized like HF's pooler path.
    pooled = last_hidden[:, 0, :]
    if getattr(vm, "post_layernorm", None) is not None:
        pooled = vm.post_layernorm(pooled)
    return last_hidden, pooled

### Measurement & Run the Sweep (p = 1.0, 0.9, 0.7, 0.5)

In [None]:
N_WARMUP = 3   # untimed calls before each measurement
N_RUNS = 10    # timed calls averaged per measurement


def cuda_sync():
    """Wait for all queued CUDA work to finish; does nothing unless DEVICE is "cuda"."""
    if DEVICE != "cuda":
        return
    torch.cuda.synchronize()

@torch.no_grad()
def measure_latency_mem(fn, *args, **kwargs):
    """
    Benchmark `fn(*args, **kwargs)`: N_WARMUP untimed calls, then N_RUNS timed calls.

    Every call is followed by `cuda_sync()` so asynchronous CUDA work is inside
    the measured interval. Timing uses `time.perf_counter()`, a monotonic,
    high-resolution clock meant for benchmarking. Unlike `time.time()`, it is
    not affected by system clock adjustments.

    Returns:
        (mean latency in ms per call,
         throughput in images/s, assuming BATCH images per call,
         peak CUDA memory allocated in MiB since the reset at the start of this
         call, which includes tensors already resident such as model weights;
         NaN when DEVICE is not "cuda").
    """
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    # Outputs are not bound to a name, so each result is released before the
    # next call and cannot inflate that call's peak memory.
    for _ in range(N_WARMUP):
        fn(*args, **kwargs)
        cuda_sync()

    times = []
    for _ in range(N_RUNS):
        t0 = time.perf_counter()
        fn(*args, **kwargs)
        cuda_sync()
        times.append(time.perf_counter() - t0)

    mean_s = float(np.mean(times))
    lat_ms = mean_s * 1000.0
    thr = BATCH / mean_s
    peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 2) if DEVICE == "cuda" else float("nan")
    return lat_ms, thr, peak_mem

# ===========================================
# Build pixel values from synthetic images
# (We pass raw tensors; processor is not strictly needed for random data)
# ===========================================
pixel_values = DUMMY_IMAGES  # random normal noise standing in for normalized pixels

# Full patch tokens, kept for reference/inspection. The timed sweep below does
# NOT reuse them: each pruned call re-embeds, so that cost is counted.
with torch.no_grad():
    tokens_full = get_patch_tokens(vision, pixel_values)  # [B, N, D]


@torch.no_grad()
def forward_pruned_end_to_end(model, pixel_values, keep_ratio):
    """
    Pruned pipeline timed as a whole: patch embedding -> token pruning -> encoder.

    The baseline's forward() pays for the patch-embedding conv, and a real
    pruned model must also pay for token scoring/selection. Both are therefore
    included here. Timing only the encoder on precomputed tokens would leave
    out work the baseline is charged for.

    keep_ratio >= 1.0 skips scoring and keeps all N tokens in their original
    order behind the mean summary token; otherwise `prune_tokens` is used.
    Returns the output of `forward_pruned_encoder_only`.
    """
    tokens = get_patch_tokens(model, pixel_values)  # [B, N, D]
    if keep_ratio >= 1.0:
        x = torch.cat([tokens.mean(dim=1, keepdim=True), tokens], dim=1)  # [B, 1+N, D]
    else:
        x = prune_tokens(tokens, keep_ratio=keep_ratio)  # [B, 1+K, D]
    return forward_pruned_encoder_only(model, x)


# ===========================================
# Run sweep over keep ratios
# ===========================================
keep_ratios = [1.0, 0.9, 0.7, 0.5]  # 100%, 90%, 70%, 50% tokens kept
results = []

# 1) Baseline (full forward)
lat, thr, mem = measure_latency_mem(forward_baseline, vision, pixel_values)
results.append(("Baseline", 1.0, lat, thr, mem))

# 2) Pruned variants, timed end-to-end (patch embedding + pruning + encoder)
for p in keep_ratios:
    lat, thr, mem = measure_latency_mem(forward_pruned_end_to_end, vision, pixel_values, p)
    results.append((f"VTP p={p:.1f}", p, lat, thr, mem))

# ===========================================
# Save CSV
# ===========================================
os.makedirs("vlm/results", exist_ok=True)
csv_path = "vlm/results/clip_vtp_results.csv"
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["Variant", "KeepRatio", "Latency(ms)", "Throughput(img/s)", "PeakMem(MiB)"])
    for r in results:
        w.writerow([r[0], f"{r[1]:.2f}", f"{r[2]:.3f}", f"{r[3]:.3f}", f"{r[4]:.1f}"])
print("Saved CSV ->", csv_path)

# Print summary
print("\n==== Results ====")
print("{:<12} {:>10} {:>14} {:>18} {:>14}".format("Variant","KeepRatio","Latency(ms)","Throughput(img/s)","PeakMem(MiB)"))
for lab, p, lat, thr, mem in results:
    print("{:<12} {:>10.2f} {:>14.2f} {:>18.2f} {:>14.1f}".format(lab, p, lat, thr, mem))

### Plots (Latency & Memory)

In [None]:
# Columns of each results row: (label, keep_ratio, latency_ms, throughput, peak_mem_mib)
labels = [row[0] for row in results]
latencies = [row[2] for row in results]
mems = [row[4] for row in results]

lat_fig = "vlm/results/clip_vtp_latency.png"
mem_fig = "vlm/results/clip_vtp_memory.png"

# One bar chart per metric, each in a new figure saved at 150 dpi.
for values, y_label, title, out_path in (
    (latencies, "Latency (ms)", "CLIP VTP — Latency", lat_fig),
    (mems, "Peak Memory (MiB)", "CLIP VTP — Peak Memory", mem_fig),
):
    plt.figure()
    plt.bar(labels, values)
    plt.ylabel(y_label)
    plt.title(title)
    plt.xticks(rotation=20)
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    print("Saved ->", out_path)