In [None]:

"""
Penn-Action → fourM preprocessing helpers
========================================
Execute this cell once.  It defines functions but does **no I/O** until you call
`run(...)` from Cell 2.

**Output schema (per video)**
```
<OUTPUT_ROOT>/
  tok_rgb/        <video>/<ref>.npy          – Cosmos tokens of reference frame
  coords/         <video>/<ref>.npy          – 3×13 (x_tok, y_tok, vis)
  captions/       <video>/<ref>.json         – BLIP caption

  tok_rgb_next/   <video>/<ref>_n1.npy       – Cosmos tokens of next‑frame #1 …
  coords_next/    <video>/<ref>_n1.npy       – 3×13 (x_tok, y_tok, vis) joint coords of next‑frame #1 …
```
Every RGB token file has a **matching coords file**.
"""

from pathlib import Path
import json
from typing import List, Optional

import numpy as np
import scipy.io
import torch
import torchvision.transforms.functional as TF
from huggingface_hub import snapshot_download
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor

# ------------------------ constants ------------------------
TARGET_RES       = 256     # side length (px) of the square canvas every frame is resized to before encoding
SSIM_THRESHOLD   = 0.89   # duplicate-frame threshold (SSIM > thr against the last kept frame ⇒ skip)
REF_PLUS_NEXT    = 7      # max frames kept per video: 1 reference + up to 6 distinct successors
COORD_QLEVELS    = 8192    # discrete bins for x & y (tokens span 0‥COORD_QLEVELS-1)

# ------------------------ similarity ------------------------

def is_similar(img1: Image.Image, img2: Image.Image, thr: float = SSIM_THRESHOLD) -> bool:
    """Tell whether two frames are near-duplicates according to SSIM.

    Both images are converted to 8-bit greyscale, cast to float32 and compared
    with structural similarity over a 0‥255 data range.

    Args:
        img1: First frame.
        img2: Second frame.  NOTE(review): presumably must have the same
            size as `img1` for the SSIM call to succeed — confirm.
        thr: Similarity score above which the frames count as duplicates.

    Returns:
        True when the SSIM score is strictly greater than `thr`.
    """
    grey_a, grey_b = (
        np.asarray(frame.convert("L"), dtype=np.float32) for frame in (img1, img2)
    )
    score = ssim(grey_a, grey_b, data_range=255.0)
    return score > thr

# ------------------------ tokenisers ------------------------

def build_tokenisers(device: str):
    """Load the Cosmos image tokenizer and the BLIP captioning model.

    The Cosmos checkpoint is cached under ``/tmp/cosmos_DI16x16`` and fetched
    from the Hugging Face Hub when the encoder/decoder files are not present.

    Args:
        device: Torch device string the models are moved to (e.g. ``"cuda"``).

    Returns:
        A ``(image_tok, proc, blip)`` tuple: the Cosmos ``ImageTokenizer``, the
        BLIP processor and the BLIP captioning model (both models in eval mode).
    """
    ckpt_dir = Path("/tmp/cosmos_DI16x16")
    enc_path = ckpt_dir / "encoder.jit"
    dec_path = ckpt_dir / "decoder.jit"
    # Test for the files that are actually loaded below rather than for the
    # directory: an interrupted download leaves the directory behind, which
    # would otherwise suppress every later download attempt.
    if not (enc_path.is_file() and dec_path.is_file()):
        snapshot_download("nvidia/Cosmos-0.1-Tokenizer-DI16x16", local_dir=str(ckpt_dir))
    # Imported lazily so that merely defining this function does not require
    # the cosmos_tokenizer package.
    from cosmos_tokenizer.image_lib import ImageTokenizer

    image_tok = ImageTokenizer(
        checkpoint_enc=str(enc_path),
        checkpoint_dec=str(dec_path),
    ).to(device).eval()

    blip_id = "Salesforce/blip-image-captioning-base"
    proc = BlipProcessor.from_pretrained(blip_id)
    blip = BlipForConditionalGeneration.from_pretrained(blip_id).to(device).eval()
    return image_tok, proc, blip

# ------------------------ helpers ------------------------

def encode_rgb(img: Image.Image, image_tok, device: str) -> np.ndarray:
    """Encode a single image into tokens with the Cosmos tokenizer.

    Args:
        img: Image to encode (callers in this file pass a frame already
            resized to the target canvas).
        image_tok: Tokenizer exposing ``encode(tensor)``; the first element of
            its return value is kept, the second is discarded.
        device: Torch device string the input tensor is moved to.

    Returns:
        The tokens as a NumPy array with the leading batch axis squeezed away.
    """
    batch = TF.to_tensor(img).unsqueeze(0).to(device)
    # Affine rescale x -> 2x - 1 (maps a 0‥1 input range onto -1‥1).
    batch = batch * 2 - 1
    with torch.no_grad():
        tokens, _ = image_tok.encode(batch)
    return tokens.squeeze(0).cpu().numpy()


def caption_image(img: Image.Image, proc, blip, device: str) -> str:
    """Generate a caption for `img` with BLIP.

    Args:
        img: Image to caption.
        proc: BLIP processor used both to prepare the model inputs and to
            decode the generated ids.
        blip: BLIP conditional-generation model.
        device: Torch device string the processed inputs are moved to.

    Returns:
        The decoded first generated sequence with special tokens removed
        (generation is capped at ``max_length=30``).
    """
    with torch.no_grad():
        model_inputs = proc(images=img, return_tensors="pt").to(device)
        generated = blip.generate(**model_inputs, max_length=30)
    first_sequence = generated[0]
    return proc.decode(first_sequence, skip_special_tokens=True)


def quantise_coords(x_px: np.ndarray, y_px: np.ndarray, vis: np.ndarray, img_size: int = TARGET_RES) -> np.ndarray:
    """Pixel coordinates → discrete tokens.

    Coordinates are *normalised* by `img_size` internally, so callers pass
    pixel coordinates in the resized frame space (0‥img_size).  Normalised
    values are scaled to ``COORD_QLEVELS - 1``, rounded to the nearest integer
    and clipped, so points outside the frame land on the first/last bin.

    Args:
        x_px: Per-joint x coordinates in pixels of the resized frame.
        y_px: Per-joint y coordinates in pixels of the resized frame.
        vis: Per-joint visibility flags.
        img_size: Side length of the (square) resized frame used to normalise
            both axes.

    Returns:
        A uint16 array stacking ``[x_tok, y_tok, vis_tok]`` along axis 0, i.e.
        shape (3, N) for N joints (3×13 per this file's output schema).
    """
    q = COORD_QLEVELS - 1
    # Normalise to 0‥1 first; the token range is defined on the unit interval.
    x_norm = np.asarray(x_px, dtype=np.float64) / img_size
    y_norm = np.asarray(y_px, dtype=np.float64) / img_size
    x_tok = np.clip(np.rint(x_norm * q), 0, q).astype(np.uint16)
    y_tok = np.clip(np.rint(y_norm * q), 0, q).astype(np.uint16)
    vis_tok = np.asarray(vis).astype(np.uint16)
    return np.stack([x_tok, y_tok, vis_tok])

# ------------------------ core routine ------------------------

def _load_rgb(path: Path):
    """Decode the image at `path` as RGB and close the underlying file right away."""
    with Image.open(path) as handle:
        return handle.convert("RGB")


def _write_frame_tokens(img, x_px, y_px, vis, rgb_path: Path, coords_path: Path, image_tok, device: str):
    """Resize `img` to the square canvas, save its RGB tokens and joint tokens, return the resized image."""
    resized = TF.resize(img, (TARGET_RES, TARGET_RES), interpolation=Image.BICUBIC)
    np.save(rgb_path, encode_rgb(resized, image_tok, device))
    np.save(coords_path, quantise_coords(x_px, y_px, vis))
    return resized


def process_video(vid: str,
                  frames_root: Path,
                  labels_root: Path,
                  out_root: Path,
                  image_tok,
                  proc,
                  blip,
                  device: str):
    """Tokenise one video: a reference frame plus visually distinct successors.

    Frame 0 is the reference.  Later frames are scanned in order and kept when
    they are not SSIM-similar to the most recently kept frame, until
    ``REF_PLUS_NEXT`` frames are collected or the video ends.  For every kept
    frame the RGB tokens and the quantised joint coordinates are saved; the
    reference frame additionally gets a BLIP caption written as JSON.

    Args:
        vid: Video id; names both the frame directory and the ``.mat`` label file.
        frames_root: Directory containing ``<vid>/*.jpg``.
        labels_root: Directory containing ``<vid>.mat``.
        out_root: Root under which the per-modality folders are created.
        image_tok: Cosmos image tokenizer (see ``encode_rgb``).
        proc: BLIP processor.
        blip: BLIP captioning model.
        device: Torch device string.

    Returns:
        None.  The video is skipped (with a printed notice) when its frames or
        labels are missing, or when it has no frames.

    Raises:
        AssertionError: If the number of ``.jpg`` files differs from the
            ``nframes`` value stored in the label file.
    """
    frame_dir  = frames_root / vid
    label_path = labels_root / f"{vid}.mat"
    if not frame_dir.exists() or not label_path.exists():
        print(f"⚠️  Skip {vid}: missing data")
        return

    mat = scipy.io.loadmat(label_path, squeeze_me=True, struct_as_record=False)
    # NOTE(review): indexing below assumes x / y / visibility are
    # (nframes, n_joints) arrays in original-frame pixels — confirm.
    x_all, y_all = mat["x"], mat["y"]
    vis_all     = mat["visibility"].astype(bool)
    T           = int(mat["nframes"])
    H0, W0, _   = mat["dimensions"]
    # Frames are squashed to a TARGET_RES square, so each axis has its own scale.
    sx, sy      = TARGET_RES / W0, TARGET_RES / H0

    frame_files: List[Path] = sorted(frame_dir.glob("*.jpg"))
    # Raised explicitly (rather than via `assert`) so the check survives
    # `python -O`; the exception type is the one `assert` would produce.
    if len(frame_files) != T:
        raise AssertionError(f"{vid}: frame/label mismatch")
    if not frame_files:
        print(f"⚠️  Skip {vid}: no frames")
        return

    # Choose the reference + up to REF_PLUS_NEXT-1 visually distinct successors.
    # Decoded images are retained so each kept frame is read from disk once.
    kept = [(0, _load_rgb(frame_files[0]))]
    for j in range(1, T):
        if len(kept) == REF_PLUS_NEXT:
            break
        cand = _load_rgb(frame_files[j])
        if not is_similar(kept[-1][1], cand):
            kept.append((j, cand))

    (ref_idx, ref_img), successors = kept[0], kept[1:]
    ref_stem = f"{ref_idx+1:05d}"

    # output dirs
    for sub in ("tok_rgb", "coords", "captions", "tok_rgb_next", "coords_next"):
        (out_root / sub / vid).mkdir(parents=True, exist_ok=True)

    # ---- reference ----
    ref_res = _write_frame_tokens(
        ref_img,
        x_all[ref_idx] * sx,
        y_all[ref_idx] * sy,
        vis_all[ref_idx],
        out_root/"tok_rgb"/vid/f"{ref_stem}.npy",
        out_root/"coords"/vid/f"{ref_stem}.npy",
        image_tok,
        device,
    )

    cap = caption_image(ref_res, proc, blip, device)
    with open(out_root/"captions"/vid/f"{ref_stem}.json", "w") as fp:
        json.dump({"video": vid, "frame": ref_stem, "caption": cap}, fp, indent=2)

    # ---- successors ----
    for k, (idx, img) in enumerate(successors, 1):
        stem_next = f"{ref_stem}_n{k}"
        _write_frame_tokens(
            img,
            x_all[idx] * sx,
            y_all[idx] * sy,
            vis_all[idx],
            out_root/"tok_rgb_next"/vid/f"{stem_next}.npy",
            out_root/"coords_next"/vid/f"{stem_next}.npy",
            image_tok,
            device,
        )

    print(f"✅ {vid}: reference + {len(successors)} next frames written → {out_root}")

# ------------------------ wrapper ------------------------

def run(frames_root: Path,
        labels_root: Path,
        output_root: Path,
        video: Optional[str] = None):
    """Pre-process one video or every video found under `frames_root`.

    Args:
        frames_root: Directory holding one sub-directory of frames per video.
        labels_root: Directory holding the per-video label files.
        output_root: Root directory handed to ``process_video`` for its output.
        video: Id of a single video to process; when falsy, every
            sub-directory of `frames_root` is processed in sorted order.

    Raises:
        AssertionError: If `frames_root` or `labels_root` does not exist.
    """
    assert frames_root.exists(), "frames_root missing"
    assert labels_root.exists(), "labels_root missing"

    # Prefer the GPU whenever torch can see one.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    tokenisers = build_tokenisers(device)  # (image_tok, proc, blip)

    if video:
        vids = [video]
    else:
        vids = sorted(entry.name for entry in frames_root.iterdir() if entry.is_dir())

    for vid in tqdm(vids, desc="videos"):
        process_video(vid, frames_root, labels_root, output_root, *tokenisers, device)

In [None]:
"""Execute **after** Cell 1.
Edit the three paths and (optionally) `VIDEO_ID`, then run the cell to start
pre‑processing.
"""

# --- user-editable paths ---------------------------------
FRAMES_ROOT = Path("/home/skalli/COM-304-FM/project/penn_action_raw/Penn_Action/frames")  # one sub-directory of *.jpg frames per video
LABELS_ROOT = Path("/home/skalli/COM-304-FM/project/penn_action_raw/Penn_Action/labels")  # one <video>.mat label file per video
OUTPUT_ROOT = Path("/home/skalli/COM-304-FM/project/new/output/")  # per-modality folders are created underneath
VIDEO_ID    = None  # e.g. "0003" for a single clip; None processes every sub-directory of FRAMES_ROOT
# ---------------------------------------------------------

# Builds the tokenisers, then pre-processes the selected video(s).
run(FRAMES_ROOT, LABELS_ROOT, OUTPUT_ROOT, VIDEO_ID)

In [None]:
from pathlib import Path
import shutil
import random

def split_output_dataset(
    output_root: Path,
    train_ratio: float = 0.8,
    eval_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42
):
    """Move per-video output folders into ``train`` / ``eval`` / ``test`` trees.

    Video ids are the all-digit sub-directory names of ``output_root/tok_rgb``.
    They are sorted, shuffled deterministically with `seed`, and cut into three
    consecutive slices; the test split receives whatever remains after the
    (floored) train and eval counts.  For each id, every modality folder that
    exists is moved to ``output_root/<split>/<modality>/<id>``.

    Args:
        output_root: Directory containing the per-modality folders.
        train_ratio: Fraction of ids assigned to the train split.
        eval_ratio: Fraction of ids assigned to the eval split.
        test_ratio: Fraction of ids assigned to the test split.
        seed: Seed of the shuffle; the same seed and id set give the same split.

    Raises:
        AssertionError: If the three ratios do not sum to 1.
        FileExistsError: If any destination folder already exists.  Nothing is
            moved in that case.
    """
    assert abs(train_ratio + eval_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1."

    modalities = ["tok_rgb", "coords", "captions", "tok_rgb_next", "coords_next"]

    # Find all pair IDs from one modality (they should be the same across all)
    base_dir = output_root / "tok_rgb"
    all_ids = sorted(d.name for d in base_dir.iterdir() if d.is_dir() and d.name.isdigit())
    # A private generator yields the same permutation as seeding the
    # module-level RNG with `seed`, without disturbing global random state.
    random.Random(seed).shuffle(all_ids)

    N = len(all_ids)
    n_train = int(N * train_ratio)
    n_eval = int(N * eval_ratio)

    splits = {
        "train": all_ids[:n_train],
        "eval":  all_ids[n_train:n_train + n_eval],
        "test":  all_ids[n_train + n_eval:]
    }

    print(f"Found {N} pairs: {len(splits['train'])} train, {len(splits['eval'])} eval, {len(splits['test'])} test")

    # Plan every move up front so that clashes are detected before any folder
    # has been relocated.
    moves = []
    for split_name, split_ids in splits.items():
        for pid in split_ids:
            for mod in modalities:
                src = output_root / mod / pid
                if src.exists():
                    moves.append((src, output_root / split_name / mod / pid))

    # shutil.move into an existing directory would nest the source inside it
    # (…/<id>/<id>), silently corrupting the layout — refuse instead.
    clashes = [dst for _, dst in moves if dst.exists()]
    if clashes:
        raise FileExistsError(
            f"{len(clashes)} destination folder(s) already exist, e.g. {clashes[0]}; nothing was moved."
        )

    for src, dst in moves:
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(src), str(dst))

    print("✅ Split completed. Original folders now empty (only train/eval/test remain).")

# Example usage:
# NOTE(review): "output" is relative to the current working directory, whereas
# Cell 2 writes to the absolute OUTPUT_ROOT — confirm both point at the same folder.
split_output_dataset(Path("output"))