In [None]:
# Copyright 2025 EPFL – Apache 2.0
#
# Process *all* Penn-Action videos:
#   RGB     → Cosmos tokens → output/train/tok_rgb/0001/00001.npy …
#   Pose    → Cosmos tokens → output/train/tok_pose/0001/00001.npy …
#   Coords  → .npy          → output/train/coords/0001/00001.npy …
#   Caption → .json         → output/train/captions/0001/00001.json …
# --------------------------------------------------------------------

import os, json
import numpy as np
import torch
from pathlib import Path
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw
import scipy.io
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration
from cosmos_tokenizer.image_lib import ImageTokenizer

# ---------------- paths ----------------
# Inputs: a folder of per-video frame directories and a folder of per-video
# .mat label files.
frames_root  = "../project/penn_action_raw/Penn_Action/frames"
labels_root  = "../project/penn_action_raw/Penn_Action/labels"
# Output: everything this script writes lives under this folder.
out_root     = "output/train"

# One top-level folder per modality, created up front. The main loop adds a
# per-video sub-folder inside each of them.
for modality_name in ("tok_rgb", "tok_pose", "coords", "captions"):
    modality_dir = Path(out_root, modality_name)
    modality_dir.mkdir(parents=True, exist_ok=True)

# --------------- models ----------------
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Cosmos image tokenizer, built from separate encoder and decoder checkpoints.
image_tokenizer = ImageTokenizer(
    checkpoint_enc="/tmp/nvidiaxxxx/Cosmos-0.1-Tokenizer-DI16x16/encoder.jit",
    checkpoint_dec="/tmp/nvidiaxxxx/Cosmos-0.1-Tokenizer-DI16x16/decoder.jit",
)
image_tokenizer = image_tokenizer.to(device)

# BLIP captioning: the processor prepares inputs and decodes outputs, the
# model generates the caption token ids.
processor  = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
blip_model = blip_model.to(device)

# ------------- helpers -----------------
def encode_rgb(img: Image.Image) -> np.ndarray:
    """Tokenize one PIL image with the module-level Cosmos image tokenizer.

    The image is converted to 3-channel RGB, turned into a float tensor,
    rescaled from [0, 1] to [-1, 1], given a leading batch axis and passed to
    ``image_tokenizer.encode``. Only the first element of the returned pair
    is kept.

    Args:
        img: Input image in any PIL mode; it is converted to RGB first so
            grayscale / RGBA / palette inputs get the expected 3 channels.

    Returns:
        The token indices as an ``int32`` NumPy array with the batch axis
        removed.
    """
    pixels = TF.to_tensor(img.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
    with torch.inference_mode():
        tok, _ = image_tokenizer.encode(pixels)
    # Store as int32 rather than int16: int16 tops out at 32767, and the
    # Cosmos discrete tokenizers are documented with a ~64K-entry codebook
    # (NOTE(review): confirm for this checkpoint), so larger indices would
    # silently wrap around to negative values.
    return tok.squeeze(0).cpu().to(torch.int32).numpy()

def draw_pose(x, y, vis, size=640):
    """Render a 13-joint skeleton in red on a black ``size`` x ``size`` image.

    Args:
        x, y: Per-joint pixel coordinates, indexable by joint number 0..12.
        vis: Per-joint visibility flags; a joint or limb is drawn only when
            every joint it involves is flagged visible.
        size: Side length of the square canvas in pixels.

    Returns:
        The rendered skeleton as an RGB PIL image.
    """
    bones = [(1,2),(1,3),(2,4),(3,5),(4,6),(7,8),(7,9),
             (8,10),(9,11),(10,12)]
    img = Image.new("RGBA", (size, size), (0,0,0,255))
    pen = ImageDraw.Draw(img)

    def pixel(k):
        # Coordinates are truncated to integers before drawing.
        return (int(x[k]), int(y[k]))

    # Joints: joint 0 is a larger hollow ring, all others are small filled dots.
    for k in range(13):
        if not vis[k]:
            continue
        px, py = pixel(k)
        if k == 0:
            radius = 5
            pen.ellipse((px-radius, py-radius, px+radius, py+radius),
                        outline="red", width=2, fill=None)
        else:
            radius = 3
            pen.ellipse((px-radius, py-radius, px+radius, py+radius),
                        outline=None, width=0, fill=(255,0,0,255))

    # Limbs between pairs of visible joints.
    for a, b in bones:
        if not (vis[a] and vis[b]):
            continue
        pen.line([pixel(a), pixel(b)], fill="red", width=2)

    # Extra line from joint 0 to the midpoint of joints 7 and 8
    # (presumably head to hip centre -- verify against the dataset's joint order).
    if vis[0] and vis[7] and vis[8]:
        midpoint = tuple(np.mean([[x[7],y[7]],[x[8],y[8]]],0).astype(int))
        pen.line([pixel(0), midpoint], fill="red", width=2)

    return img.convert("RGB")

def encode_pose(x,y,vis):
    """Tokenize a pose: render it with ``draw_pose``, then encode the image with ``encode_rgb``."""
    skeleton_img = draw_pose(x, y, vis)
    return encode_rgb(skeleton_img)

def caption_image(img):
    """Return a BLIP-generated caption string for ``img``.

    The image is prepared by the module-level ``processor``, the caption is
    generated by ``blip_model`` with ``max_length=30``, and the first
    generated sequence is decoded with special tokens removed.
    """
    batch = processor(images=img, return_tensors="pt").to(device)
    generated = blip_model.generate(**batch, max_length=30)
    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)

# ------------- main loop --------------
# Every sub-directory of frames_root is treated as one video, in sorted order.
video_dirs = [d for d in sorted(os.listdir(frames_root))
              if (Path(frames_root)/d).is_dir()]

for vid in tqdm(video_dirs, desc="videos"):
    label_file = Path(labels_root)/f"{vid}.mat"
    frame_dir  = Path(frames_root)/vid
    if not label_file.exists():
        print(f"⚠️  skip {vid}: no label file"); continue

    # ------ load MATLAB struct ------
    mat = scipy.io.loadmat(label_file, squeeze_me=True, struct_as_record=False)
    # squeeze_me=True collapses the arrays of a single-frame video to 1-D;
    # atleast_2d restores the (frame, joint) layout that the per-frame
    # indexing below relies on.
    x_all   = np.atleast_2d(mat['x'])
    y_all   = np.atleast_2d(mat['y'])
    vis_all = np.atleast_2d(mat['visibility']).astype(bool)
    T                 = int(mat['nframes'])
    H0, W0, _         = mat['dimensions']
    # Frames are resized to 640x640 below, so joint coordinates are rescaled
    # by the same factors.
    scale_x, scale_y  = 640.0/W0, 640.0/H0

    frame_files = sorted(frame_dir.glob("*.jpg"))
    # Skip (rather than assert) on a mismatch: one inconsistent video should
    # not abort the whole run, and assert statements vanish under `python -O`.
    if len(frame_files) != T:
        print(f"⚠️  skip {vid}: frame/label mismatch "
              f"({len(frame_files)} frames vs {T} labelled)")
        continue

    # make per-video sub-folders (paths built once per video, reused per frame)
    out_dirs = {sub: Path(out_root)/sub/vid
                for sub in ("tok_rgb","tok_pose","coords","captions")}
    for out_dir in out_dirs.values():
        out_dir.mkdir(parents=True, exist_ok=True)

    for i, fpath in enumerate(tqdm(frame_files, desc=vid, leave=False)):
        # Outputs are named by 1-based position in the sorted frame list,
        # not by the source file name.
        stem = f"{i+1:05d}"

        # --- RGB ---
        # Force 3-channel RGB so grayscale / palette JPEGs still produce the
        # channel layout the tokenizer and captioner expect; the context
        # manager releases the file handle promptly.
        with Image.open(fpath) as raw:
            img = raw.convert("RGB").resize((640,640))
        np.save(out_dirs["tok_rgb"]/f"{stem}.npy", encode_rgb(img))

        # --- Pose & coords ---
        x   = x_all[i]*scale_x
        y   = y_all[i]*scale_y
        vis = vis_all[i]
        np.save(out_dirs["tok_pose"]/f"{stem}.npy", encode_pose(x,y,vis))
        # Rows: x and y normalised by the 640-pixel canvas, then visibility as 0/1.
        coords = np.stack([x/640.0, y/640.0, vis.astype(float)])
        np.save(out_dirs["coords"]/f"{stem}.npy", coords)

        # --- Caption ---
        cap = caption_image(img)
        with open(out_dirs["captions"]/f"{stem}.json", "w", encoding="utf-8") as fp:
            json.dump({"video": vid, "frame": stem, "caption": cap}, fp, indent=2)

print("✅  All videos processed →", out_root)


videos:   0%|                                                                                                                      | 0/2326 [00:00<?, ?it/s]
0001:   0%|                                                                                                                         | 0/151 [00:00<?, ?it/s][A
0001:   1%|▋                                                                                                                | 1/151 [00:01<04:59,  1.99s/it][A
0001:   1%|█▍                                                                                                               | 2/151 [00:02<02:25,  1.02it/s][A
0001:   2%|██▏                                                                                                              | 3/151 [00:02<01:34,  1.56it/s][A
0001:   3%|██▉                                                                                                              | 4/151 [00:02<01:10,  2.08it/s][A
0001:   3%|███▋                            