In [None]:
import os
import glob
import math
import numpy as np
import torch
import torch.nn.functional as F

from PIL import Image

# Video writer
import imageio.v2 as imageio

# Repo imports
import sys
sys.path.append('/workspace')
sys.path.append('/workspace/planner')

from planner.modules.config_planner import ConfigPlanner
from planner.modules.transformer import LatentPlanner

from models.diffusion import AsymmetricUNet
from models.scheduler import NoiseScheduler

from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from scipy import ndimage

from typing import List, Tuple, Optional

def to_uint8_img(x_bchw: torch.Tensor) -> np.ndarray:
    """Convert the first image of a batch from [-1, 1] floats to an HWC uint8 array.

    Args:
        x_bchw: Tensor of shape [B, 3, H, W] with values nominally in [-1, 1];
            out-of-range values are clamped. Only batch element 0 is used.

    Returns:
        A NumPy uint8 array of shape [H, W, 3] (on the CPU).
    """
    img_chw = x_bchw.detach()[0].float().clamp(-1, 1)
    # [-1, 1] -> [0, 1] -> [0, 255], rounded to the nearest integer level.
    img_chw = ((img_chw + 1) * 0.5 * 255.0).round().to(torch.uint8)
    return img_chw.permute(1, 2, 0).cpu().numpy()


def extract_ball_position(frame: np.ndarray, threshold: float = 0.7) -> Optional[Tuple[float, float]]:
    """Estimate the (row, col) centre of the bright region in a frame.

    The frame is reduced to grayscale in normalised units, and every pixel
    brighter than ``threshold`` contributes to an intensity-weighted centroid.

    Args:
        frame: Either an [H, W, C] image, whose channels are averaged and then
            divided by 255 (0..255 input is assumed), or a 2-D grayscale image,
            which is divided by 255 only if its maximum exceeds 1.
        threshold: Brightness cut-off in the normalised units above.

    Returns:
        ``(y, x)`` as floats. If no pixel exceeds ``threshold``, the position of
        the single brightest pixel is returned instead. The ``Optional`` return
        type is kept for callers that test for ``None``, but this function
        always returns a position.
    """
    if frame.ndim == 3:
        gray = np.mean(frame, axis=2).astype(np.float32) / 255.0
    else:
        gray = frame.astype(np.float32)
        if gray.max() > 1.0:
            gray = gray / 255.0

    mask = gray > threshold
    if not mask.any():
        # Nothing above threshold: fall back to the brightest pixel.
        y_max, x_max = np.unravel_index(np.argmax(gray), gray.shape)
        return (float(y_max), float(x_max))

    # mask has at least one True entry here, so the coordinate arrays are non-empty.
    y_coords, x_coords = np.nonzero(mask)
    weights = gray[y_coords, x_coords]

    if np.sum(weights) > 0:
        y_center = np.average(y_coords, weights=weights)
        x_center = np.average(x_coords, weights=weights)
    else:
        # Non-positive total weight (e.g. a negative threshold over black pixels)
        # would make np.average divide by zero; use the plain mean instead.
        y_center = np.mean(y_coords)
        x_center = np.mean(x_coords)

    return (float(y_center), float(x_center))


def extract_trajectory(video: torch.Tensor, threshold: float = 0.7, debug: bool = False) -> List[Tuple[float, float]]:
    """Track the ball centre through every frame of a video.

    Args:
        video: [T, C, H, W] tensor, or [B, T, C, H, W] (only batch element 0 is
            used), with values in [-1, 1].
        threshold: Brightness cut-off forwarded to ``extract_ball_position``.
        debug: If True, print a line for every frame without a detection and
            a summary count at the end.

    Returns:
        One ``(y, x)`` tuple per frame. A frame without a detection repeats the
        previous position, or ``(0.0, 0.0)`` if there is no previous position.
    """
    frames = video[0] if video.ndim == 5 else video
    num_frames = frames.shape[0]

    positions: List[Tuple[float, float]] = []
    misses = 0

    for idx in range(num_frames):
        # [-1, 1] float CHW -> uint8 HWC image for the detector.
        img = frames[idx].detach().float().clamp(-1, 1)
        img = ((img + 1) * 0.5 * 255.0).round().to(torch.uint8)
        hwc = img.permute(1, 2, 0).cpu().numpy()  # [H, W, 3]

        found = extract_ball_position(hwc, threshold)
        if found is not None:
            positions.append(found)
            continue

        misses += 1
        if positions:
            positions.append(positions[-1])
            if debug:
                print(f"  Frame {idx}: pos=None, using previous position {positions[-1]}")
        else:
            positions.append((0.0, 0.0))
            if debug:
                print(f"  Frame {idx}: pos=None, using default (0, 0)")

    if debug and misses > 0:
        print(f"  Warning: {misses}/{num_frames} frames had invalid positions")

    return positions

def get_resolution_for_timestep(t: int, num_timesteps: int, high_res: int, latent_res: int, k_step: int) -> int:
    """Map a diffusion timestep to the square working resolution for that step.

    The size is interpolated linearly from ``high_res`` at ``t == 0`` to
    ``latent_res`` at ``t == num_timesteps - 1``. It is then snapped to the
    nearest multiple of ``k_step`` and clamped to ``[latent_res, high_res]``.

    Args:
        t: Current timestep.
        num_timesteps: Total number of scheduler timesteps. With a single
            timestep (or fewer), ``t`` is treated as 0 (full ``high_res``)
            instead of dividing by zero.
        high_res: Side length at ``t == 0``.
        latent_res: Side length at the last timestep.
        k_step: Snapping granularity of the returned size.

    Returns:
        Integer side length in pixels.
    """
    denom = num_timesteps - 1
    ratio = t / denom if denom > 0 else 0.0
    size_float = high_res - ratio * (high_res - latent_res)
    size_int = int(round(size_float / k_step) * k_step)
    return max(latent_res, min(high_res, size_int))


def planner_generate(planner: LatentPlanner, cond_btchw: torch.Tensor, total_T: int, show_progress: bool = True) -> torch.Tensor:
    """Autoregressively roll the planner forward from conditioning frames.

    Args:
        planner: Model called as ``planner(seq, attn_mask=mask)``. The entry at
            index ``seq.shape[1]`` of its output is taken as the next frame.
            It is switched to eval mode.
        cond_btchw: [B, k, C, H, W] conditioning frames.
        total_T: Desired total length (conditioning + generated). If it is
            not larger than k, no frames are generated.
        show_progress: Wrap the generation loop in a tqdm progress bar.

    Returns:
        [B, max(total_T, k), C, H, W]: the k conditioning frames followed by
        the generated frames.
    """
    planner.eval()
    batch, num_cond, _, _, _ = cond_btchw.shape
    history = list(cond_btchw.unbind(dim=1))  # k tensors of [B, C, H, W]

    steps = range(total_T - num_cond)
    if show_progress:
        steps = tqdm(steps, desc="Generating frames")

    with torch.no_grad():
        for _ in steps:
            context = torch.stack(history, dim=1)  # [B, t, C, H, W]
            ctx_len = context.shape[1]
            mask = torch.ones(batch, ctx_len, device=context.device)
            prediction = planner(context, attn_mask=mask)  # assumed [B, t+1, C, H, W]
            history.append(prediction[:, ctx_len])  # the frame right after the context

    return torch.stack(history, dim=1)

# Backward compatibility
def planner_generate_from_5(planner: LatentPlanner, cond_5_btchw: torch.Tensor, total_T: int) -> torch.Tensor:
    """Legacy alias of ``planner_generate`` that always shows a progress bar.

    Kept for older notebooks; all arguments are forwarded unchanged.
    """
    return planner_generate(
        planner=planner,
        cond_btchw=cond_5_btchw,
        total_T=total_T,
        show_progress=True,
    )


def refiner_refine_sequence(
    refiner: AsymmetricUNet,
    scheduler: NoiseScheduler,
    lowres_btchw: torch.Tensor,
    high_res: int = 128,
    latent_res: int = 32,
    k_step: int = 1,
    t_start_frac: float = 0.1,
    batch_frames: bool = True,
) -> torch.Tensor:
    """Refine low-res frames into high-res frames with a progressive-resolution sampler.

    Each frame is noised to timestep ``t_start`` with ``scheduler.add_noise``
    and then denoised step by step down to 0. At every step:
    - the working tensor is resized to ``get_resolution_for_timestep(t, ...)``;
    - the refiner predicts x0 at ``high_res``;
    - that prediction is downsampled to the working resolution and passed to
      ``scheduler.step_x0`` (x0-prediction sampling).

    Args:
        refiner: x0-predicting model, called as
            ``refiner(x, t, target_shape=(high_res, high_res))``.
        scheduler: Provides ``num_timesteps``, ``add_noise`` and ``step_x0``.
        lowres_btchw: [B, T, 3, latent_res, latent_res] tensor in [-1, 1].
        high_res: Output side length.
        latent_res: Input side length (asserted).
        k_step: Resolution snapping granularity, see ``get_resolution_for_timestep``.
        t_start_frac: Fraction of the schedule to noise to;
            ``t_start = int((num_timesteps - 1) * t_start_frac)``.
        batch_frames: If True, loop over the B videos and denoise all T frames of
            one video together (batch of T). If False, loop over the T frame
            indices and denoise that frame for all B videos together (batch of B).
            With the usual T > B, True is faster but needs more memory.

    Returns:
        [B, T, 3, high_res, high_res] tensor, nominally in [-1, 1]. No clamping
        is applied here.
    """
    refiner.eval()
    B, T, C, H, W = lowres_btchw.shape
    assert H == latent_res and W == latent_res

    # Partial noising: only the last t_start + 1 steps of the schedule are run.
    t_start = int((scheduler.num_timesteps - 1) * float(t_start_frac))

    print(f"  [Refiner Debug] num_timesteps={scheduler.num_timesteps}, t_start_frac={t_start_frac}")
    print(f"  [Refiner Debug] t_start={t_start}, will run from {t_start} to 0 (total {t_start + 1} steps)")

    # NOTE: every progress bar below advances itself because it is iterated.
    # Calling .update() inside the same loop would count each iteration twice.
    with torch.no_grad():
        if batch_frames:
            out_frames_list = []

            pbar_videos = tqdm(range(B), desc="Refining videos", total=B, miniters=1, mininterval=1.0)

            for b_idx in pbar_videos:
                latent_video = lowres_btchw[b_idx]  # [T, 3, latent_res, latent_res]

                noise = torch.randn_like(latent_video)
                timesteps = torch.full((T,), t_start, device=latent_video.device, dtype=torch.long)
                curr = scheduler.add_noise(latent_video, noise, timesteps)  # [T, 3, latent_res, latent_res]

                num_steps = t_start + 1
                pbar_diffusion = tqdm(
                    range(t_start, -1, -1),
                    desc=f"Video {b_idx+1}/{B} diffusion",
                    total=num_steps,
                    leave=False,
                    miniters=max(1, num_steps // 20),
                    mininterval=0.5
                )

                for t in pbar_diffusion:
                    t_batch = torch.full((T,), t, device=latent_video.device, dtype=torch.long)
                    target_res_t = get_resolution_for_timestep(t, scheduler.num_timesteps, high_res, latent_res, k_step)

                    # Ensure curr is at target_res_t
                    curr_h, curr_w = curr.shape[-2:]
                    if curr_h != target_res_t or curr_w != target_res_t:
                        curr = F.interpolate(curr, size=(target_res_t, target_res_t), mode='bilinear', align_corners=False)
                        curr_h, curr_w = target_res_t, target_res_t

                    # Predict x0 at high res (all T frames of this video in one batch)
                    pred_x0_high = refiner(curr, t_batch, target_shape=(high_res, high_res))  # [T, 3, high_res, high_res]

                    # Downsample x0 to curr's actual resolution
                    pred_x0_curr = F.interpolate(pred_x0_high, size=(curr_h, curr_w), mode='bilinear', align_corners=False)

                    assert pred_x0_curr.shape == curr.shape, \
                        f"Shape mismatch: pred_x0_curr {pred_x0_curr.shape}, curr {curr.shape}"

                    # This branch passes a Python int timestep, while the sequential
                    # branch passes a [B] tensor. Presumably step_x0 accepts both
                    # (TODO confirm against NoiseScheduler).
                    t_val = int(t)
                    prev = scheduler.step_x0(pred_x0_curr, t_val, curr)

                    if t > 0:
                        # Hand the next step a tensor already at its working resolution.
                        next_res = get_resolution_for_timestep(t - 1, scheduler.num_timesteps, high_res, latent_res, k_step)
                        if prev.shape[-1] != next_res or prev.shape[-2] != next_res:
                            curr = F.interpolate(prev, size=(next_res, next_res), mode='bilinear', align_corners=False)
                        else:
                            curr = prev
                    else:
                        curr = prev

                pbar_diffusion.close()

                # Final upscale to high_res
                if curr.shape[-1] != high_res or curr.shape[-2] != high_res:
                    curr = F.interpolate(curr, size=(high_res, high_res), mode='bilinear', align_corners=False)

                out_frames_list.append(curr)  # [T, 3, high_res, high_res]

            pbar_videos.close()

            # Stack all videos: [B, T, C, H, W]
            out = torch.stack(out_frames_list, dim=0)

        else:
            # Sequential mode (original): one frame index at a time, batched over B videos
            out_frames = []
            num_steps = t_start + 1

            pbar_frames = tqdm(
                range(T),
                desc="Refining frames",
                total=T,
                miniters=1,
                mininterval=1.0
            )

            for i in pbar_frames:
                latent_img = lowres_btchw[:, i]  # [B, 3, latent_res, latent_res]

                # Start from noisy latent at timestep t_start
                noise = torch.randn_like(latent_img)
                timesteps = torch.full((B,), t_start, device=latent_img.device, dtype=torch.long)
                curr = scheduler.add_noise(latent_img, noise, timesteps)

                pbar_diffusion = tqdm(
                    range(t_start, -1, -1),
                    desc=f"Frame {i+1}/{T} diffusion",
                    total=num_steps,
                    leave=False,
                    miniters=max(1, num_steps // 20),
                    mininterval=0.5
                )

                for t in pbar_diffusion:
                    t_batch = torch.full((B,), t, device=latent_img.device, dtype=torch.long)
                    target_res_t = get_resolution_for_timestep(t, scheduler.num_timesteps, high_res, latent_res, k_step)

                    # Frames are square, so checking the width is sufficient.
                    if curr.shape[-1] != target_res_t:
                        curr = F.interpolate(curr, size=(target_res_t, target_res_t), mode='bilinear', align_corners=False)

                    # Predict x0 at high res
                    pred_x0_high = refiner(curr, t_batch, target_shape=(high_res, high_res))

                    # Downsample x0 to current resolution for step
                    pred_x0_curr = F.interpolate(pred_x0_high, size=(target_res_t, target_res_t), mode='bilinear', align_corners=False)

                    prev = scheduler.step_x0(pred_x0_curr, t_batch, curr)

                    if t > 0:
                        next_res = get_resolution_for_timestep(t - 1, scheduler.num_timesteps, high_res, latent_res, k_step)
                        if prev.shape[-1] != next_res:
                            curr = F.interpolate(prev, size=(next_res, next_res), mode='bilinear', align_corners=False)
                        else:
                            curr = prev
                    else:
                        curr = prev

                pbar_diffusion.close()

                # Final upscale to high_res
                if curr.shape[-1] != high_res:
                    curr = F.interpolate(curr, size=(high_res, high_res), mode='bilinear', align_corners=False)

                out_frames.append(curr)

            pbar_frames.close()

            out = torch.stack(out_frames, dim=1)  # [B, T, 3, H, W]

    return out



In [None]:
# ====== Config (EDIT ONLY THIS CELL) ======
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Checkpoints ---
PLANNER_CKPT = '/workspace/planner/experiments/20260123_073354_planner_32_large/checkpoints/model_epoch_1000.pth'   # e.g. '/workspace/planner/experiments/XXXX/checkpoints/model_epoch_500.pth'
REFINER_CKPT = '/workspace/experiments/20260120_100656_obj_weight_5/checkpoints/model_epoch_30.pt'   # e.g. '/workspace/experiments/XXXX/checkpoints/model_epoch_500.pt'

# --- Data (FORCED: 2nd-stage preprocessed low-res) ---
LOWRES_ROOT = '/workspace/data/processed_32'  # must be 32x32 .pt
SPLIT = 'val'
OUT_DIR = '/workspace/inferences/012345'

# Single-sample mode
SAMPLE_INDEX = 0  # which video file to use from LOWRES_ROOT/SPLIT

# Multi-sample mode (batch inference)
# - set to a list like [0,1,2,3] to run multiple videos in one batch
# - if empty/None, falls back to [SAMPLE_INDEX]
# - Example: SAMPLE_INDICES = [0, 1, 2, 3]  # Process 4 samples in parallel
SAMPLE_INDICES = [0, 1, 2, 3, 4, 5]

# How many samples to render to mp4 (one mp4 per sample)
# - None: use all samples in SAMPLE_INDICES (or just SAMPLE_INDEX if SAMPLE_INDICES is empty)
# - e.g. 4: only render first 4 samples
NUM_SAMPLES_TO_RENDER = None  # e.g. 4, or None to use all in SAMPLE_INDICES

# --- Sequence lengths ---
# Every conditioning length to run through planner + refiner. The first entry drives the
# previews and default outputs; the largest one decides how many GT frames are kept per sample.
COND_FRAMES_LIST = [1, 3, 5, 10, 15]
COND_FRAMES = 5  # fallback used only when COND_FRAMES_LIST is not a non-empty list
TOTAL_FRAMES = 100   # total frames to generate/refine (clipped by sample length)
T_RENDER = 100        # how many frames to run the refiner + write mp4 (cost control)

# Figure 1: Qualitative Video Generation (1-based frame numbers)
FIGURE1_TIMESTEPS = [10, 30, 60, 100]
FIGURE1_OUT_DIR = os.path.join(OUT_DIR, 'figures')

# Figure 2: Trajectory Comparison
FIGURE2_NUM_SAMPLES = 4
FIGURE2_OUT_DIR = os.path.join(OUT_DIR, 'figures')

# Figure 3: Conditioning Effect Analysis
FIGURE3_COND_FRAMES_LIST = COND_FRAMES_LIST
FIGURE3_SEQUENCE_INDEX = 0
FIGURE3_OUT_DIR = os.path.join(OUT_DIR, 'figures')

# Figure 4: Temporal Consistency Analysis
FIGURE4_SEQUENCE_INDEX = 0
FIGURE4_OUT_DIR = os.path.join(OUT_DIR, 'figures')

DATASET_FIG_NUM_SAMPLES_GRID = 12
DATASET_FIG_NUM_SEQUENCES_TIMELINE = 3
DATASET_FIG_TIMESTEPS = [0, 20, 40, 60, 80, 100]
DATASET_FIG_NUM_TRAJECTORIES = 4
DATASET_FIG_OUT_DIR = os.path.join(OUT_DIR, 'figures')


# --- Resolutions / schedule ---
LATENT_RES = 32
HIGH_RES = 128
K_STEP = 1

# --- Refiner sampling strength ---
# Fraction (0.0~1.0) of the diffusion schedule to noise the planner output to before
# denoising: t_start = int((num_timesteps - 1) * T_START_FRAC).
# Lower => less noise is added, so the result stays closer to the planner frames;
# higher => the refiner regenerates more of the content. Recommended: 0.05~0.2
T_START_FRAC = 0.1

# --- Refiner batch processing ---
# If True, denoise all T frames of one video per forward pass (loops over the B videos).
# If False, denoise one frame index for all B videos per forward pass (loops over T frames).
# With T > B, True is faster but uses more GPU memory.
REFINER_BATCH_FRAMES = True

# --- Refiner model architecture (must match the checkpoint you load) ---
REF_MODEL_CHANNELS = 128
REF_CHANNEL_MULT = (1, 2, 4, 8)
REF_NUM_RES_BLOCKS = 2

# --- Outputs ---
OUT_MP4 = os.path.join(OUT_DIR, 'out_end2end.mp4')  # copy of sample 0's refined video
FPS = 10

# NOTE: despite the file name, this preview shows the first entry of COND_FRAMES_LIST frames.
COND_PREVIEW_PNG = os.path.join(OUT_DIR, 'cond_5.png')
LOWRES_PREVIEW_PNG = os.path.join(OUT_DIR, 'lowres_preview.png')
HIGHRES_PREVIEW_PNG = os.path.join(OUT_DIR, 'highres_preview.png')

# Create the output directory up front. Preview PNGs are saved into it by the
# loading/generation cells, which run before the mp4 cell that also creates it.
os.makedirs(OUT_DIR, exist_ok=True)

print('DEVICE:', DEVICE)
print('LOWRES_ROOT:', LOWRES_ROOT)
print('SPLIT:', SPLIT, 'SAMPLE_INDEX:', SAMPLE_INDEX)
print('OUT_MP4:', OUT_MP4)
print('T_RENDER:', T_RENDER)

In [None]:
# Quick sanity check of the multi-sample settings from the config cell.
print('COND_FRAMES_LIST:', COND_FRAMES_LIST)
# SAMPLE_INDICES may be empty or None (single-sample mode); len(None) would raise.
print('Total SAMPLE_INDICES :', len(SAMPLE_INDICES) if SAMPLE_INDICES else 0)

In [None]:
# ====== Load low-res GT samples (for conditioning) ======
pt_files = sorted(glob.glob(os.path.join(LOWRES_ROOT, SPLIT, '*.pt')))
assert len(pt_files) > 0, f'No .pt files found in {LOWRES_ROOT}/{SPLIT}'

# Determine which samples to load (indices wrap around the file list)
if SAMPLE_INDICES:
    indices_to_load = [int(i) % len(pt_files) for i in SAMPLE_INDICES]
else:
    indices_to_load = [int(SAMPLE_INDEX) % len(pt_files)]

if NUM_SAMPLES_TO_RENDER is not None:
    indices_to_load = indices_to_load[:int(NUM_SAMPLES_TO_RENDER)]

print(f'Loading {len(indices_to_load)} samples: {indices_to_load}')

# Resolve the conditioning lengths before loading. An empty or non-list COND_FRAMES_LIST
# falls back to COND_FRAMES (instead of crashing on max([])), and the largest length
# decides how many leading frames are kept as conditioning per sample.
if isinstance(COND_FRAMES_LIST, list) and len(COND_FRAMES_LIST) > 0:
    cond_frames_list_to_process = COND_FRAMES_LIST
else:
    cond_frames_list_to_process = [COND_FRAMES] if isinstance(COND_FRAMES, int) else [5]
max_cond_frames = max(int(c) for c in cond_frames_list_to_process)

# Load all samples
conds_list = []
seqs_list = []
sample_paths = []

for idx in indices_to_load:
    sample_path = pt_files[idx]
    sample_paths.append(sample_path)
    print(f'  [{idx}] {os.path.basename(sample_path)}')

    x = torch.load(sample_path, map_location='cpu')  # expected uint8 [T,3,32,32]
    # getattr keeps the message itself from failing when x is not a tensor.
    assert torch.is_tensor(x) and x.ndim == 4 and x.shape[1] == 3, f"{sample_path}: {getattr(x, 'shape', type(x))}"

    # normalize uint8 [0,255] to [-1,1]
    seq = (x.float() / 255.0 - 0.5) * 2.0  # [T,3,32,32]

    # pick first TOTAL_FRAMES (or shorter)
    T_avail = seq.shape[0]
    T_use = min(int(TOTAL_FRAMES), int(T_avail))
    seq = seq[:T_use]
    if T_use < max_cond_frames:
        print(f'  WARNING: only {T_use} frames available, fewer than max cond_frames={max_cond_frames}')

    cond = seq[:int(max_cond_frames)]
    conds_list.append(cond)
    seqs_list.append(seq)

print(f'Will process cond_frames: {cond_frames_list_to_process}')

cond_frames_to_use = cond_frames_list_to_process[0]
cond_batch = torch.stack([cond[:int(cond_frames_to_use)] for cond in conds_list], dim=0).to(DEVICE)
print(f'Preview cond_batch: {cond_batch.shape} (using cond_frames={cond_frames_to_use})')

# Visualize first sample's condition frames. OUT_DIR may not exist yet on a fresh run.
os.makedirs(OUT_DIR, exist_ok=True)
imgs = [to_uint8_img(cond_batch[0:1, i]) for i in range(cond_batch.shape[1])]
Image.fromarray(np.hstack(imgs)).save(COND_PREVIEW_PNG)
print('Saved', COND_PREVIEW_PNG)

In [None]:
# ====== Load Planner ======
# Fail fast, before building the model, if no checkpoint is configured.
assert PLANNER_CKPT, 'Set PLANNER_CKPT path in the config cell'

config = ConfigPlanner()

# Force settings consistent with low-res
config.target_res = LATENT_RES
config.max_seq_len = TOTAL_FRAMES
config.latent_dim = 3 * LATENT_RES * LATENT_RES

planner = LatentPlanner(config).to(DEVICE)

state = torch.load(PLANNER_CKPT, map_location='cpu')
# train_planner.py saves state_dict directly.
# With strict=False, mismatched keys are skipped without an error. Report them, because
# otherwise those parameters would silently keep their freshly initialised values.
planner_load_result = planner.load_state_dict(state, strict=False)
if planner_load_result.missing_keys:
    print(f'[WARN] planner: {len(planner_load_result.missing_keys)} missing keys, e.g. {planner_load_result.missing_keys[:5]}')
if planner_load_result.unexpected_keys:
    print(f'[WARN] planner: {len(planner_load_result.unexpected_keys)} unexpected keys, e.g. {planner_load_result.unexpected_keys[:5]}')
planner.eval()

print('Loaded planner:', PLANNER_CKPT)

In [None]:
# ====== Load Refiner ======
assert REFINER_CKPT, 'Set REFINER_CKPT path in the config cell'

# The architecture hyper-parameters must match the checkpoint (see config cell).
refiner = AsymmetricUNet(
    in_channels=3,
    out_channels=3,
    model_channels=int(REF_MODEL_CHANNELS),
    channel_mult=tuple(REF_CHANNEL_MULT),
    num_res_blocks=int(REF_NUM_RES_BLOCKS),
).to(DEVICE)

ref_state = torch.load(REFINER_CKPT, map_location='cpu')
# With strict=False, mismatched keys are skipped without an error. Report them, because
# otherwise those parameters would silently keep their freshly initialised values.
refiner_load_result = refiner.load_state_dict(ref_state, strict=False)
if refiner_load_result.missing_keys:
    print(f'[WARN] refiner: {len(refiner_load_result.missing_keys)} missing keys, e.g. {refiner_load_result.missing_keys[:5]}')
if refiner_load_result.unexpected_keys:
    print(f'[WARN] refiner: {len(refiner_load_result.unexpected_keys)} unexpected keys, e.g. {refiner_load_result.unexpected_keys[:5]}')
refiner.eval()

noise_scheduler = NoiseScheduler(device=DEVICE)

print('Loaded refiner:', REFINER_CKPT)

In [None]:
# ====== Model summaries (Planner / Refiner) ======

def count_params(m: torch.nn.Module):
    """Return ``(total, trainable)`` parameter counts of module ``m``."""
    all_params = list(m.parameters())
    n_total = sum(p.numel() for p in all_params)
    n_trainable = sum(p.numel() for p in all_params if p.requires_grad)
    return n_total, n_trainable


# Summarise each model that has been loaded into the notebook namespace and
# accumulate the total parameter count across them.
params_sum = 0
for header, var_name, load_cell in (
    ("\n=== LatentPlanner Model Summary ===", 'planner', 'Load Planner'),
    ("\n=== AsymmetricUNet (Refiner) Model Summary ===", 'refiner', 'Load Refiner'),
):
    print(header)
    if var_name not in globals():
        print(f"[WARN] {var_name} is not loaded yet. Run the '{load_cell}' cell first.")
        continue
    total, trainable = count_params(globals()[var_name])
    print(f"Params: {total/1e6:.3f}M (trainable {trainable/1e6:.3f}M)")
    params_sum += total

print(f"\nTotal params: {params_sum/1e6:.3f}M")
print(f"\nTotal params: {params_sum/1e9:.3f}B")



In [None]:
# ====== Planner generation (LLM-style, batch processing) ======
# For every conditioning length in cond_frames_list_to_process, roll the planner out
# autoregressively on all loaded samples at once. Each result is stored in
# lowres_gen_batches[cond_frames].
T_uses = [min(int(TOTAL_FRAMES), seq.shape[0]) for seq in seqs_list]
T_use_batch = min(T_uses)  # Use minimum length for batch consistency

lowres_gen_batches = {}  # {cond_frames: lowres_gen_batch}
highres_gen_batches = {}  # {cond_frames: highres_gen_batch} (filled by the refine cell)

for cond_frames_val in cond_frames_list_to_process:
    print(f'\n{"="*60}')
    print(f'Processing cond_frames={cond_frames_val}')
    print(f'{"="*60}')

    # Prompt = first cond_frames_val GT frames of every sample, stacked to [B, k, 3, 32, 32].
    cond_batch_cf = torch.stack([cond[:int(cond_frames_val)] for cond in conds_list], dim=0).to(DEVICE)
    print(f'cond_batch: {cond_batch_cf.shape} (using cond_frames={cond_frames_val})')

    lowres_gen_batch = planner_generate(planner, cond_btchw=cond_batch_cf, total_T=T_use_batch, show_progress=True)
    print(f'lowres_gen_batch: {lowres_gen_batch.shape}')  # [B, T, 3, 32, 32]

    lowres_gen_batches[cond_frames_val] = lowres_gen_batch

    # Preview strip (sample 0, up to 10 frames) only for the first conditioning length.
    if cond_frames_val == cond_frames_list_to_process[0]:
        preview = [to_uint8_img(lowres_gen_batch[0:1, i]) for i in range(min(10, lowres_gen_batch.shape[1]))]
        Image.fromarray(np.hstack(preview)).save(LOWRES_PREVIEW_PNG)
        print('Saved', LOWRES_PREVIEW_PNG)

In [None]:
# ====== Refine low-res sequences to high-res (batch processing) ======
# NOTE: this is expensive. Control cost via T_RENDER in the config cell.
# Only the first T_RENDER_EFF frames of each planner rollout are refined. Results are
# stored per conditioning length in highres_gen_batches.
T_RENDER_EFF = min(int(T_RENDER), T_use_batch)

for cond_frames_val in cond_frames_list_to_process:
    print(f'\n{"="*60}')
    print(f'Refining cond_frames={cond_frames_val}')
    print(f'{"="*60}')

    lowres_gen_batch = lowres_gen_batches[cond_frames_val]
    lowres_subset_batch = lowres_gen_batch[:, :T_RENDER_EFF]

    print(f'Refining {lowres_subset_batch.shape[0]} samples, {T_RENDER_EFF} frames each...')

    highres_batch = refiner_refine_sequence(
        refiner=refiner,
        scheduler=noise_scheduler,
        lowres_btchw=lowres_subset_batch,
        high_res=int(HIGH_RES),
        latent_res=int(LATENT_RES),
        k_step=int(K_STEP),
        t_start_frac=float(T_START_FRAC),
        batch_frames=bool(REFINER_BATCH_FRAMES),
    )
    print(f'highres_batch: {highres_batch.shape}')  # [B, T, 3, 128, 128]

    highres_gen_batches[cond_frames_val] = highres_batch

    # Preview strip (sample 0, up to 10 frames) only for the first conditioning length.
    if cond_frames_val == cond_frames_list_to_process[0]:
        preview_hr = [to_uint8_img(highres_batch[0:1, i]) for i in range(min(10, highres_batch.shape[1]))]
        Image.fromarray(np.hstack(preview_hr)).save(HIGHRES_PREVIEW_PNG)
        print('Saved', HIGHRES_PREVIEW_PNG)


In [None]:
# Use the results of the first conditioning length as the default batches for the
# video-writing and figure cells.
highres_batch = highres_gen_batches[cond_frames_list_to_process[0]]
lowres_gen_batch = lowres_gen_batches[cond_frames_list_to_process[0]]

In [None]:
# Write one mp4 per refined sample, frames converted from [-1, 1] to uint8 HWC.
# Sample 0 is also written a second time to OUT_MP4.
os.makedirs(OUT_DIR, exist_ok=True)

B = highres_batch.shape[0]
T = highres_batch.shape[1]

for b_idx in range(B):
    # File name carries the batch position and the source .pt stem.
    sample_name = os.path.splitext(os.path.basename(sample_paths[b_idx]))[0]
    out_mp4_path = os.path.join(OUT_DIR, f'out_sample_{b_idx:02d}_{sample_name}.mp4')

    with imageio.get_writer(out_mp4_path, fps=FPS, codec='libx264', quality=8) as w:
        for t_idx in range(T):
            frame = to_uint8_img(highres_batch[b_idx:b_idx+1, t_idx])
            w.append_data(frame)

    print(f'[{b_idx+1}/{B}] Wrote: {out_mp4_path}')

if B > 0:
    with imageio.get_writer(OUT_MP4, fps=FPS, codec='libx264', quality=8) as w:
        for t_idx in range(T):
            frame = to_uint8_img(highres_batch[0:1, t_idx])
            w.append_data(frame)
    print(f'Also wrote default: {OUT_MP4}')

In [None]:
GT_HIGHRES_ROOT = '/workspace/data/processed'


# High-res GT files. A sample is matched to its GT primarily by file name. If no file with
# the same name exists, the same position in the sorted list is used instead, which only
# lines up when both directories contain the same set of files.
gt_pt_files = sorted(glob.glob(os.path.join(GT_HIGHRES_ROOT, SPLIT, '*.pt')))
assert len(gt_pt_files) > 0, f'No .pt files found in {GT_HIGHRES_ROOT}/{SPLIT}'

print(f'Loading GT high-res data using same indices: {indices_to_load}')


def _upscaled_lowres_gt(sample_idx):
    """Fallback GT: bilinearly upscale the first T_RENDER_EFF low-res input frames to HIGH_RES."""
    lowres = seqs_list[sample_idx][:T_RENDER_EFF].to(DEVICE)  # [T, 3, LATENT_RES, LATENT_RES]
    return F.interpolate(lowres, size=(HIGH_RES, HIGH_RES), mode='bilinear', align_corners=False)


gt_highres_list = []
for idx, orig_idx in enumerate(indices_to_load):
    same_name_path = os.path.join(GT_HIGHRES_ROOT, SPLIT, os.path.basename(sample_paths[idx]))
    if os.path.exists(same_name_path):
        gt_path = same_name_path
    else:
        gt_path = gt_pt_files[orig_idx] if orig_idx < len(gt_pt_files) else None

    if not gt_path or not os.path.exists(gt_path):
        print(f'  [{idx}] (orig_idx={orig_idx}) WARNING: GT file not found, using upscaled low-res')
        gt_highres_list.append(_upscaled_lowres_gt(idx))
        continue

    try:
        x_gt = torch.load(gt_path, map_location='cpu')  # expected uint8 [T, 3, H, W]
        assert torch.is_tensor(x_gt) and x_gt.ndim == 4 and x_gt.shape[1] == 3, f"{gt_path}: {getattr(x_gt, 'shape', type(x_gt))}"

        # Normalize to [-1,1] and take same length as used
        seq_gt = (x_gt.float() / 255.0 - 0.5) * 2.0  # [T, 3, H, W]
        T_use_gt = min(seq_gt.shape[0], T_RENDER_EFF)
        if T_use_gt < T_RENDER_EFF:
            # A shorter GT clip could not be stacked with the others; use the fallback instead.
            raise ValueError(f'GT has only {T_use_gt} frames, need {T_RENDER_EFF}')
        seq_gt = seq_gt[:T_use_gt]

        if seq_gt.shape[-1] != HIGH_RES or seq_gt.shape[-2] != HIGH_RES:
            # seq_gt is already [T, 3, H, W], so it can be resized as a batch of T images.
            seq_gt = F.interpolate(seq_gt, size=(HIGH_RES, HIGH_RES), mode='bilinear', align_corners=False)
            print(f'  [{idx}] (orig_idx={orig_idx}) Upscaled GT: {os.path.basename(gt_path)} ({T_use_gt} frames, {seq_gt.shape[-1]}x{seq_gt.shape[-2]})')
        else:
            print(f'  [{idx}] (orig_idx={orig_idx}) Loaded GT: {os.path.basename(gt_path)} ({T_use_gt} frames, {seq_gt.shape[-1]}x{seq_gt.shape[-2]})')

        gt_highres_list.append(seq_gt.to(DEVICE))
    except Exception as e:
        # Best-effort: any failure falls back to the upscaled low-res clip so figures still render.
        print(f'  [{idx}] (orig_idx={orig_idx}) Error loading GT: {e}')
        gt_highres_list.append(_upscaled_lowres_gt(idx))
        print(f'  [{idx}] Using upscaled low-res as GT fallback')

gt_highres_batch = torch.stack(gt_highres_list, dim=0)  # [B, T, 3, 128, 128]
print(f'gt_highres_batch: {gt_highres_batch.shape}')

In [None]:
# ====== Ensure GT high-res batch has correct shape ======
# Safety net: resize the stacked GT to HIGH_RES x HIGH_RES if it is not already.
# The reshape below assumes the GT batch has the same B and T as highres_batch.
B = highres_batch.shape[0]
T = highres_batch.shape[1]

if gt_highres_batch.shape[-1] != HIGH_RES or gt_highres_batch.shape[-2] != HIGH_RES:
    print(f'Upscaling GT from {gt_highres_batch.shape[-2]}x{gt_highres_batch.shape[-1]} to {HIGH_RES}x{HIGH_RES}')
    # Flatten [B, T] into one image batch for F.interpolate, then restore it.
    gt_highres_batch = F.interpolate(
        gt_highres_batch.reshape(B * T, 3, gt_highres_batch.shape[-2], gt_highres_batch.shape[-1]),
        size=(HIGH_RES, HIGH_RES),
        mode='bilinear',
        align_corners=False
    ).reshape(B, T, 3, HIGH_RES, HIGH_RES)

print(f'Final shapes - highres_batch: {highres_batch.shape}, gt_highres_batch: {gt_highres_batch.shape}')



In [None]:
# ====== Figure 1: Qualitative Video Generation ======
# One figure per sample. Rows: GT frames, refined frames ("Ours"), and the GT frame with
# circles at the detected ball centres (red = GT, lime = ours). Columns: the 1-based
# frame numbers from FIGURE1_TIMESTEPS that exist in the clip.

os.makedirs(FIGURE1_OUT_DIR, exist_ok=True)

B = highres_batch.shape[0]
T = highres_batch.shape[1]

# Only 1..T are valid 1-based frame numbers. t <= 0 would turn into a negative index and
# silently show a frame counted from the end of the clip.
timesteps_to_plot = [t for t in FIGURE1_TIMESTEPS if 1 <= t <= T]
timesteps_indices = [t - 1 for t in timesteps_to_plot]

if len(timesteps_to_plot) == 0:
    print(f"Warning: No valid timesteps found. T={T}, requested={FIGURE1_TIMESTEPS}")
    timesteps_to_plot = [min(10, T), min(30, T), min(60, T), min(100, T)]
    timesteps_to_plot = sorted(set(timesteps_to_plot))
    timesteps_indices = [t - 1 for t in timesteps_to_plot]

print(f"Creating Figure 1 with timesteps: {timesteps_to_plot}")


def _add_caption(ax, text):
    """Draw a bold caption centred in ``ax`` and hide the axes."""
    ax.text(0.5, 0.5, text, fontsize=20, fontweight='bold',
            ha='center', va='center', transform=ax.transAxes)
    ax.axis('off')


for b_idx in range(B):
    n_timesteps = len(timesteps_to_plot)
    fig = plt.figure(figsize=(2 + 3*n_timesteps, 9))
    gs = fig.add_gridspec(4, n_timesteps + 1,
                          width_ratios=[0.3] + [1]*n_timesteps,
                          height_ratios=[0.2, 1, 1, 1],
                          hspace=0.15, wspace=0.1)

    # Convert each selected frame once; the GT/Ours/Overlay rows all reuse them.
    gt_frames = [to_uint8_img(gt_highres_batch[b_idx:b_idx+1, t_idx]) for t_idx in timesteps_indices]
    pred_frames = [to_uint8_img(highres_batch[b_idx:b_idx+1, t_idx]) for t_idx in timesteps_indices]

    # Header row: frame numbers.
    for col_idx, t in enumerate(timesteps_to_plot):
        _add_caption(fig.add_subplot(gs[0, col_idx + 1]), f't={t}')

    _add_caption(fig.add_subplot(gs[1, 0]), 'GT')
    for col_idx, gt_frame in enumerate(gt_frames):
        ax = fig.add_subplot(gs[1, col_idx + 1])
        ax.imshow(gt_frame)
        ax.axis('off')

    _add_caption(fig.add_subplot(gs[2, 0]), 'Ours')
    for col_idx, pred_frame in enumerate(pred_frames):
        ax = fig.add_subplot(gs[2, col_idx + 1])
        ax.imshow(pred_frame)
        ax.axis('off')

    _add_caption(fig.add_subplot(gs[3, 0]), 'Overlay')
    for col_idx, (gt_frame, pred_frame) in enumerate(zip(gt_frames, pred_frames)):
        ax = fig.add_subplot(gs[3, col_idx + 1])
        ax.imshow(gt_frame, alpha=0.7)

        # Positions are (row, col); matplotlib circles take (x, y) = (col, row).
        gt_pos = extract_ball_position(gt_frame, threshold=0.7)
        if gt_pos is not None:
            y_gt, x_gt = gt_pos
            circle_gt = plt.Circle((x_gt, y_gt), radius=5, color='red',
                                   fill=False, linewidth=3, label='GT' if col_idx == 0 else '')
            ax.add_patch(circle_gt)

        pred_pos = extract_ball_position(pred_frame, threshold=0.7)
        if pred_pos is not None:
            y_pred, x_pred = pred_pos
            circle_pred = plt.Circle((x_pred, y_pred), radius=5, color='lime',
                                     fill=False, linewidth=3, label='Ours' if col_idx == 0 else '')
            ax.add_patch(circle_pred)

        ax.axis('off')

    sample_name = os.path.splitext(os.path.basename(sample_paths[b_idx]))[0] if b_idx < len(sample_paths) else f'sample_{b_idx}'
    save_path = os.path.join(FIGURE1_OUT_DIR, f'figure1_sample_{b_idx:02d}_{sample_name}.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f'Saved: {save_path}')

    plt.show()
    # Release the figure so one-per-sample figures do not accumulate in pyplot's registry.
    plt.close(fig)

print(f'\nFigure 1 saved to: {FIGURE1_OUT_DIR}')

In [None]:

# ====== Figure 2: Trajectory Comparison ======
# NOTE: these two assignments override the config-cell values for this figure.
FIGURE2_NUM_SAMPLES = 35
FIGURE2_OUT_DIR = os.path.join(OUT_DIR, 'figures')

os.makedirs(FIGURE2_OUT_DIR, exist_ok=True)

B = highres_batch.shape[0]
T = highres_batch.shape[1]
# Frame height and width (gt_highres_batch is [B, T, 3, H, W]).
H, W = gt_highres_batch.shape[-2:]

num_samples_to_plot = min(B, FIGURE2_NUM_SAMPLES)

print(f"Creating Figure 2 with {num_samples_to_plot} samples")

# Number of leading frames drawn as the "input" segment. The same value is used for every
# panel and for the legend.
if 'cond_frames_list_to_process' in globals() and len(cond_frames_list_to_process) > 0:
    cond_frames_used = cond_frames_list_to_process[0]
elif 'cond_frames_to_use' in globals():
    cond_frames_used = cond_frames_to_use
else:
    cond_frames_used = COND_FRAMES if isinstance(COND_FRAMES, int) else 5
cond_frames_for_legend = cond_frames_used

n_cols = 7
# At least one row, so plt.subplots never receives 0 rows.
n_rows = max(1, (num_samples_to_plot + n_cols - 1) // n_cols)

# 5 inches per panel. For the intended 35 samples (5 x 7) this is a 35 x 25 figure.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
fig.subplots_adjust(hspace=0.5)
axes_flat = axes.flatten()

for b_idx in range(num_samples_to_plot):
    gt_trajectory = extract_trajectory(gt_highres_batch[b_idx:b_idx+1], threshold=0.7)
    pred_trajectory = extract_trajectory(highres_batch[b_idx:b_idx+1], threshold=0.7)

    ax = axes_flat[b_idx]

    if len(gt_trajectory) > 0 and len(pred_trajectory) > 0:
        # Trajectory entries are (row, col); plot x = col, y = row.
        gt_x = [pos[1] for pos in gt_trajectory]
        gt_y = [pos[0] for pos in gt_trajectory]
        pred_x = [pos[1] for pos in pred_trajectory]
        pred_y = [pos[0] for pos in pred_trajectory]

        cond_end_idx = min(cond_frames_used, len(gt_trajectory), len(pred_trajectory))

        # The input segment includes one extra point so it joins the GT continuation.
        ax.plot(gt_x[:cond_end_idx+1], gt_y[:cond_end_idx+1],
                color='blue', linewidth=5, linestyle='-',
                label='input ('+str(cond_end_idx)+' frames)', alpha=0.9, zorder=4)
        ax.plot(gt_x[cond_end_idx:], gt_y[cond_end_idx:],
                'r-.', linewidth=5, label='GT', alpha=0.9, zorder=2)

        ax.plot(pred_x[cond_end_idx:], pred_y[cond_end_idx:],
                'g--', linewidth=5, label='Ours', alpha=0.9, dashes=(5, 3), zorder=2)

    ax.set_aspect('equal')
    ax.set_xlim(0, W-1)
    ax.set_ylim(H-1, 0)  # image convention: row 0 at the top
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_facecolor('white')
    ax.tick_params(labelsize=20)

# Hide the unused panels in the last row.
for idx in range(num_samples_to_plot, len(axes_flat)):
    axes_flat[idx].axis('off')

from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color='blue', linewidth=3, linestyle='-', label=f'Input ({cond_frames_for_legend} frames)'),
    Line2D([0], [0], color='r', linewidth=3, linestyle='-.', label='GT'),
    Line2D([0], [0], color='g', linewidth=3, linestyle='--', label='Ours'),
]
fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, -0.02),
           ncol=4, fontsize=30, frameon=True, fancybox=True, shadow=True)

plt.tight_layout()
plt.subplots_adjust(bottom=0.05)
save_path = os.path.join(FIGURE2_OUT_DIR, 'figure2_trajectory_comparison.png')
fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f'Saved: {save_path}')

plt.show()
plt.close(fig)

print(f'\nFigure 2 saved to: {FIGURE2_OUT_DIR}')

In [None]:
os.makedirs(FIGURE3_OUT_DIR, exist_ok=True)

# ---- Conditioning-frame counts to compare ----
# Prefer the list processed by the planner/refiner cells; otherwise fall back to
# the configured FIGURE3_COND_FRAMES_LIST.
if 'cond_frames_list_to_process' in globals():
    cond_frames_to_use_for_fig3 = cond_frames_list_to_process
    print(f"Using cond_frames from cond_frames_list_to_process: {cond_frames_to_use_for_fig3}")
else:
    cond_frames_to_use_for_fig3 = FIGURE3_COND_FRAMES_LIST
    print(f"Using cond_frames from FIGURE3_COND_FRAMES_LIST: {cond_frames_to_use_for_fig3}")

# The last entry is left out of this figure (the trajectory-error cell applies the
# same `[:-1]` slice to its list).
cond_frames_to_use_for_fig3 = cond_frames_to_use_for_fig3[:-1]
print(f"Using cond_frames values: {cond_frames_to_use_for_fig3}")
print(f"Using fixed sequence index: {FIGURE3_SEQUENCE_INDEX}")

if 'highres_gen_batches' not in globals():
    raise RuntimeError("highres_gen_batches not found. Please run planner and refiner cells first.")
print(f"Found highres_gen_batches with keys: {list(highres_gen_batches.keys())}")

gt_video_fig3 = gt_highres_batch[FIGURE3_SEQUENCE_INDEX:FIGURE3_SEQUENCE_INDEX+1]  # [1, T, C, H, W]
gt_trajectory_full = extract_trajectory(gt_video_fig3, threshold=0.7)
print(f"GT trajectory length: {len(gt_trajectory_full)}")

# ---- Trajectories per cond_frames value ----
# 'gt'    : full GT trajectory (shared by every panel)
# 'pred'  : trajectory of the generated high-res video
# 'input' : trajectory over the first `cond_frames_val` GT frames
cond_frames_results = {}
for cond_frames_val in cond_frames_to_use_for_fig3:
    print(f"\n--- Processing cond_frames={cond_frames_val} ---")

    if cond_frames_val not in highres_gen_batches:
        available_keys = list(highres_gen_batches.keys())
        raise KeyError(
            f"cond_frames={cond_frames_val} not found in highres_gen_batches. "
            f"Available keys: {available_keys}. "
            f"Please run planner and refiner cells with cond_frames={cond_frames_val} first."
        )

    print("  Using pre-generated result from highres_gen_batches")
    highres_pred = highres_gen_batches[cond_frames_val]
    pred_trajectory = extract_trajectory(
        highres_pred[FIGURE3_SEQUENCE_INDEX:FIGURE3_SEQUENCE_INDEX+1], threshold=0.7
    )
    input_trajectory = extract_trajectory(gt_video_fig3[:, :cond_frames_val], threshold=0.7)

    cond_frames_results[cond_frames_val] = {
        'gt': gt_trajectory_full,
        'pred': pred_trajectory,
        'input': input_trajectory,
    }

    print(f"  Pred trajectory length: {len(pred_trajectory)}")
    print(f"  Input trajectory length: {len(input_trajectory)}")

# ---- Plot: one panel per cond_frames value, final panel holds the legend ----
# Keep the 2x3 layout for up to five values and add rows only when more panels
# are needed; a fixed 2x3 grid would draw over the legend panel or run out of axes.
n_cols_fig3 = 3
n_panels_fig3 = len(cond_frames_to_use_for_fig3) + 1
n_rows_fig3 = max(2, math.ceil(n_panels_fig3 / n_cols_fig3))
fig, axes = plt.subplots(n_rows_fig3, n_cols_fig3,
                         figsize=(6*n_cols_fig3, 6*n_rows_fig3), squeeze=False)
axes_flat = axes.flatten()

for col_idx, cond_frames_val in enumerate(cond_frames_to_use_for_fig3):
    ax = axes_flat[col_idx]

    results = cond_frames_results[cond_frames_val]
    gt_traj = results['gt']
    pred_traj = results['pred']
    input_traj = results['input']

    # Positions are (y, x) pairs; matplotlib wants x then y.
    gt_x = [pos[1] for pos in gt_traj]
    gt_y = [pos[0] for pos in gt_traj]
    pred_x = [pos[1] for pos in pred_traj]
    pred_y = [pos[0] for pos in pred_traj]
    input_x = [pos[1] for pos in input_traj]
    input_y = [pos[0] for pos in input_traj]

    if len(input_x) > 0:
        ax.plot(input_x, input_y, color='blue', linewidth=5, linestyle='-',
                label='Input', alpha=0.9, zorder=4)

    if len(gt_x) > 0:
        ax.plot(gt_x, gt_y, color='red', linewidth=5, linestyle='-',
                label='GT', alpha=0.9, zorder=2)

    if len(pred_x) > 0:
        ax.plot(pred_x, pred_y, color='green', linewidth=5, linestyle='--',
                dashes=(5, 3), label='Predicted', alpha=0.9, zorder=3)

    ax.set_aspect('equal')
    ax.invert_yaxis()  # image coordinates: row 0 at the top
    ax.set_title(f'cond frames={cond_frames_val}', fontsize=30, fontweight='bold')
    ax.tick_params(labelsize=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_facecolor('white')

# Hide panels between the last trajectory panel and the legend panel so they do
# not render as empty axes.
for spare_ax in axes_flat[len(cond_frames_to_use_for_fig3):-1]:
    spare_ax.axis('off')

from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color='blue', linewidth=3, linestyle='-', label='Input'),
    Line2D([0], [0], color='red', linewidth=2.5, linestyle='-', label='GT'),
    Line2D([0], [0], color='green', linewidth=2.5, linestyle='--', dashes=(5, 3), label='Predicted'),
]
legend_ax = axes_flat[-1]
legend_ax.axis('off')

legend_ax.legend(
    handles=legend_elements,
    loc='center',
    shadow=True,
    fontsize=30,
    ncol=1
)
plt.tight_layout()
save_path = os.path.join(FIGURE3_OUT_DIR, 'figure3_conditioning_effect.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f'\nSaved: {save_path}')

plt.show()

print(f'\nFigure 3 saved to: {FIGURE3_OUT_DIR}')

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

os.makedirs(FIGURE4_OUT_DIR, exist_ok=True)

# Announce which figure is being built and for which GT sequence.
print("Creating Figure 4: Temporal Consistency Analysis (GT-deviation)")
print("Using sequence index: {}".format(FIGURE4_SEQUENCE_INDEX))

def compute_frame_diff_magnitude(video: torch.Tensor) -> np.ndarray:
    """Mean absolute frame-to-frame difference of a video.

    Args:
        video: [B, T, C, H, W] (only batch element 0 is used) or [T, C, H, W].

    Returns:
        float64 array of length T-1 whose entry k is mean(|x_{k+1} - x_k|);
        empty when T < 2.
    """
    if video.ndim == 5:  # [B, T, C, H, W]
        video = video[0]  # use first batch

    T = video.shape[0]
    if T < 2:
        return np.array([], dtype=np.float64)

    # Integer frames (e.g. uint8) would wrap around on subtraction and torch.mean
    # rejects integer dtypes; half precision loses accuracy in the mean.
    if video.dtype not in (torch.float32, torch.float64):
        video = video.float()

    # Δ_t = mean(|x_t - x_{t-1}|)
    diffs = torch.mean(torch.abs(video[1:] - video[:-1]), dim=(1, 2, 3))  # [T-1]
    return diffs.detach().cpu().numpy().astype(np.float64)

def compute_temporal_deviation(pred_video: torch.Tensor, gt_video: torch.Tensor) -> np.ndarray:
    """Absolute gap between predicted and GT frame-to-frame change magnitudes.

    err_k = | mean(|pred_{k+1} - pred_k|) - mean(|gt_{k+1} - gt_k|) |

    Both curves come from `compute_frame_diff_magnitude` and are truncated to the
    shorter one, so the result has length min(T_pred, T_gt) - 1 (empty when either
    video has fewer than two frames).
    """
    pred_curve = compute_frame_diff_magnitude(pred_video)
    gt_curve = compute_frame_diff_magnitude(gt_video)

    overlap = min(pred_curve.shape[0], gt_curve.shape[0])
    if not overlap:
        return np.array([], dtype=np.float64)
    return np.abs(pred_curve[:overlap] - gt_curve[:overlap])

def downsample_video_to_match(gt_video: torch.Tensor, ref_video: torch.Tensor) -> torch.Tensor:
    """Bilinearly resize `gt_video` to the spatial size of `ref_video`.

    Args:
        gt_video: [1, T, C, H, W] high-res video (only batch element 0 is used).
        ref_video: [1, T, C, h, w] reference (e.g. low-res planner output); only
            its shape is read.

    Returns:
        [1, T, C, h, w] tensor.

    Raises:
        ValueError: if either input is not 5-D, or T / C differ between them.
            (Explicit raises instead of `assert`, which is stripped under -O.)
    """
    if gt_video.ndim != 5 or ref_video.ndim != 5:
        raise ValueError(
            f"Expected 5-D [B, T, C, H, W] tensors, got gt {tuple(gt_video.shape)} "
            f"and ref {tuple(ref_video.shape)}."
        )
    _, T, C, _, _ = gt_video.shape
    _, T2, C2, h, w = ref_video.shape
    if T != T2 or C != C2:
        raise ValueError("GT and ref must have same T and C to match temporal comparison.")

    # F.interpolate expects [N, C, H, W]: treat the T frames as the batch.
    x = gt_video[0]  # [T, C, H, W]
    x_ds = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)  # [T, C, h, w]
    return x_ds.unsqueeze(0)  # [1, T, C, h, w]

# Conditioning-frame count analysed in this figure.
# NOTE(review): this cell used to set `idx = 3` and read
# `cond_frames_list_to_process[idx]`, then immediately overwrote that value with
# the expression below. The dead lookup could only raise NameError/IndexError
# (missing or short list), so it is removed; if the list entry was the intended
# choice, select it here instead.
cond_frames = COND_FRAMES if isinstance(COND_FRAMES, int) else 5

print(f"Using cond_frames: {cond_frames}")

# ---- GT (high-res) ----
gt_high = gt_highres_batch[FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, H, W]

# ---- Planner (low-res) ----
if 'lowres_gen_batches' not in globals():
    raise RuntimeError("lowres_gen_batches not found. Please run planner generation first.")

if cond_frames not in lowres_gen_batches:
    available_keys = list(lowres_gen_batches.keys())
    raise KeyError(f"cond_frames={cond_frames} not found in lowres_gen_batches. Available: {available_keys}")

planner_low = lowres_gen_batches[cond_frames][FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, h, w]

# Compare the planner against GT at the planner's own resolution.
gt_low = downsample_video_to_match(gt_high, planner_low)  # [1, T, C, h, w]

# ---- End-to-End (high-res) ----
if 'highres_gen_batches' not in globals():
    raise RuntimeError("highres_gen_batches not found. Please run end-to-end (high-res) generation first.")

if cond_frames not in highres_gen_batches:
    available_keys = list(highres_gen_batches.keys())
    raise KeyError(f"cond_frames={cond_frames} not found in highres_gen_batches. Available: {available_keys}")

e2e_high = highres_gen_batches[cond_frames][FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, H, W]

# ---- Compute temporal deviation curves ----
planner_temporal_err = compute_temporal_deviation(planner_low, gt_low)
print(f"Planner temporal deviation shape: {planner_temporal_err.shape}")

e2e_temporal_err = compute_temporal_deviation(e2e_high, gt_high)
print(f"End-to-End temporal deviation shape: {e2e_temporal_err.shape}")

# Entry k of each curve compares frames k and k+1 (0-based), so it is plotted at
# k + 2, the 1-based index of the later frame.
L = min(len(planner_temporal_err), len(e2e_temporal_err))
frame_indices = np.arange(2, L + 2)

# ---- Plot ----
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

ax.plot(frame_indices, planner_temporal_err[:L],
        linewidth=2.5, linestyle='-',
        label='Planner (low-res)', alpha=0.9, zorder=2)

ax.plot(frame_indices, e2e_temporal_err[:L],
        linewidth=2.5, linestyle='--',
        dashes=(5, 3), label='End-to-End (high-res)', alpha=0.9, zorder=1)

# baseline: a perfect match of GT dynamics has zero deviation
ax.axhline(0.0, linewidth=2.0, linestyle='-',
           label='GT (zero deviation)', alpha=0.8, zorder=3)

ax.set_xlabel('Video Time Step (frame index)', fontsize=14, fontweight='bold')
ax.set_ylabel(r'Temporal Deviation from GT  $|\Delta x^{pred}-\Delta x^{GT}|$', fontsize=14, fontweight='bold')
ax.set_title('Temporal Consistency Analysis (Deviation from GT Dynamics)', fontsize=16, fontweight='bold')

ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(loc='best', fontsize=12, frameon=True, fancybox=True, shadow=True)
ax.set_facecolor('white')

plt.tight_layout()

save_path = os.path.join(FIGURE4_OUT_DIR, 'figure4_temporal_consistency_deviation.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f'\nSaved: {save_path}')

plt.show()
print(f'\nFigure 4 saved to: {FIGURE4_OUT_DIR}')


In [None]:

os.makedirs(DATASET_FIG_OUT_DIR, exist_ok=True)

print("Creating 3 independent Dataset Example Figures")
print(f"Grid samples: {DATASET_FIG_NUM_SAMPLES_GRID}")
print(f"Timeline sequences: {DATASET_FIG_NUM_SEQUENCES_TIMELINE}")
print(f"Trajectory sequences: {DATASET_FIG_NUM_TRAJECTORIES}")

if 'gt_highres_batch' not in globals():
    raise RuntimeError("gt_highres_batch not found. Please load GT data first.")

B = gt_highres_batch.shape[0]
T = gt_highres_batch.shape[1]

num_samples_available = min(B, DATASET_FIG_NUM_SAMPLES_GRID)
num_timeline_available = min(B, DATASET_FIG_NUM_SEQUENCES_TIMELINE)
num_traj_available = min(B, DATASET_FIG_NUM_TRAJECTORIES)

print(f"Available samples: {B}, using: grid={num_samples_available}, timeline={num_timeline_available}, trajectory={num_traj_available}")

print("\n--- Creating Figure (A): Diverse Initial Conditions ---")
n_cols_grid = 3
n_rows_grid = 2
t_representative = 0  # frame index displayed for every sample

fig_a = plt.figure(figsize=(25, 17))
gs_a = fig_a.add_gridspec(n_rows_grid, n_cols_grid)

for idx in range(num_samples_available):
    row = idx // n_cols_grid
    col = idx % n_cols_grid
    if row >= n_rows_grid:
        continue  # samples beyond the grid capacity are not shown

    ax_small = fig_a.add_subplot(gs_a[row, col])
    frame = to_uint8_img(gt_highres_batch[idx:idx+1, t_representative])
    ax_small.imshow(frame)
    ax_small.axis('off')
    ax_small.text(10, 15, f'S{idx+1}', fontsize=40, color='white',
                  bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

    # Motion-direction arrow from the ball in the displayed frame toward its
    # position in the following frame. Anchoring on t_representative (rather than
    # a hard-coded frame 0) keeps the arrow on the ball that is actually shown.
    t_next = t_representative + 1
    if t_next >= T:
        continue
    next_frame = to_uint8_img(gt_highres_batch[idx:idx+1, t_next])
    pos_now = extract_ball_position(frame, threshold=0.7)
    pos_next = extract_ball_position(next_frame, threshold=0.7)
    if pos_now is None or pos_next is None:
        continue

    y0, x0 = pos_now
    y1, x1 = pos_next
    dx = x1 - x0
    dy = y1 - y0
    norm = np.hypot(dx, dy)
    if norm <= 1.0:
        continue  # skip near-stationary balls (displacement of at most 1 px)

    arrow_len = 25  # drawn length in pixels; only the direction comes from the data
    end_x = x0 + dx / norm * arrow_len
    end_y = y0 + dy / norm * arrow_len
    ax_small.annotate('',
                      xy=(end_x, end_y),
                      xytext=(x0, y0),
                      arrowprops=dict(arrowstyle='->, head_length=1.5, head_width=1.5',
                                      color='yellow', lw=10, zorder=30))

# Hide unused grid cells.
for idx in range(num_samples_available, n_rows_grid * n_cols_grid):
    row = idx // n_cols_grid
    col = idx % n_cols_grid
    ax_empty = fig_a.add_subplot(gs_a[row, col])
    ax_empty.axis('off')

plt.tight_layout()
save_path_a = os.path.join(DATASET_FIG_OUT_DIR, 'dataset_figure_A_initial_conditions.png')
plt.savefig(save_path_a, dpi=300, bbox_inches='tight', facecolor='white')
print(f'Saved: {save_path_a}')
plt.show()


In [None]:

# ====== Figure (B): Temporal Evolution ======
# Rows are sequences. Columns are the configured timesteps that exist in the data,
# preceded by a narrow column that holds each row's "S<k>" label.
timesteps_to_plot = [t for t in DATASET_FIG_TIMESTEPS if t < T]
timesteps_indices = list(timesteps_to_plot)
n_timesteps = len(timesteps_indices)
n_timeline = num_timeline_available
print(f"Timeline sequences: {n_timeline}")

fig_b = plt.figure(figsize=(17, 10))
column_widths = [0.15] + [1] * n_timesteps
gs_b = fig_b.add_gridspec(n_timeline, n_timesteps + 1, width_ratios=column_widths)

for seq_idx in range(n_timeline):
    label_ax = fig_b.add_subplot(gs_b[seq_idx, 0])
    label_ax.text(0.5, 0.5, f'S{seq_idx+1}', fontsize=20, fontweight='bold',
                  ha='center', va='center', transform=label_ax.transAxes)
    label_ax.axis('off')

    is_top_row = seq_idx == 0
    for grid_col, t_idx in enumerate(timesteps_indices, start=1):
        frame_ax = fig_b.add_subplot(gs_b[seq_idx, grid_col])
        frame_ax.imshow(to_uint8_img(gt_highres_batch[seq_idx:seq_idx+1, t_idx]))
        frame_ax.axis('off')
        # Only the top row carries the timestep titles.
        if is_top_row:
            frame_ax.set_title(f't={t_idx}', fontsize=25, pad=6)

plt.tight_layout()

save_path_b = os.path.join(DATASET_FIG_OUT_DIR, 'dataset_figure_B_temporal_evolution.png')
plt.savefig(save_path_b, dpi=300, bbox_inches='tight', facecolor='white')
print(f'Saved: {save_path_b}')
plt.show()

In [None]:

def compute_trajectory_error_over_time(
    pred_video: torch.Tensor,
    gt_video: torch.Tensor,
    threshold: float = 0.7
) -> np.ndarray:
    """Per-frame Euclidean distance (pixels) between predicted and GT ball positions.

    Args:
        pred_video: predicted video, [B, T, C, H, W] or [T, C, H, W], values in
            [-1, 1] (clamped). Only batch element 0 is used.
        gt_video: ground-truth video in either of the same layouts.
        threshold: brightness threshold forwarded to `extract_ball_position`.

    Returns:
        float array of length min(T_pred, T_gt). If a position cannot be extracted
        for a frame, the previous frame's error is repeated (0.0 for the first frame).
    """
    # Promote each input independently so mixed 4-D / 5-D inputs work.
    if pred_video.ndim == 4:
        pred_video = pred_video.unsqueeze(0)
    if gt_video.ndim == 4:
        gt_video = gt_video.unsqueeze(0)

    # Only compare frames present in both videos (a shorter GT used to raise IndexError).
    num_frames = min(pred_video.shape[1], gt_video.shape[1])
    errors = []

    for t in range(num_frames):
        # to_uint8_img maps [1, 3, H, W] in [-1, 1] to a uint8 [H, W, 3] array.
        pred_pos = extract_ball_position(to_uint8_img(pred_video[0:1, t]), threshold)
        gt_pos = extract_ball_position(to_uint8_img(gt_video[0:1, t]), threshold)

        if pred_pos is not None and gt_pos is not None:
            errors.append(np.hypot(pred_pos[0] - gt_pos[0], pred_pos[1] - gt_pos[1]))
        else:
            errors.append(errors[-1] if errors else 0.0)

    return np.array(errors)

TRAJ_ERROR_OUT_DIR = os.path.join(OUT_DIR, 'figures')
os.makedirs(TRAJ_ERROR_OUT_DIR, exist_ok=True)

TRAJ_ERROR_SEQUENCE_INDEX = 0

print("Creating Trajectory Error Analysis Figure")
print(f"Sequence index: {TRAJ_ERROR_SEQUENCE_INDEX}")

if 'gt_highres_batch' not in globals():
    raise RuntimeError("gt_highres_batch not found. Please load GT data first.")

if 'highres_gen_batches' not in globals():
    raise RuntimeError("highres_gen_batches not found. Please run planner and refiner cells first.")

# NOTE: when cond_frames_list_to_process is missing, this defines it as a notebook
# global, so cells run afterwards will also see the fallback list.
if 'cond_frames_list_to_process' not in globals():
    if 'COND_FRAMES_LIST' in globals():
        cond_frames_list_to_process = COND_FRAMES_LIST
    else:
        cond_frames_list_to_process = [1, 3, 5, 10]
    print(f"Using default cond_frames_list: {cond_frames_list_to_process}")
else:
    print(f"Using cond_frames_list_to_process: {cond_frames_list_to_process}")

gt_video = gt_highres_batch[TRAJ_ERROR_SEQUENCE_INDEX:TRAJ_ERROR_SEQUENCE_INDEX+1]  # [1, T, C, H, W]

trajectory_errors = {}  # {cond_frames: error_array}

# The last cond_frames entry is excluded (same `[:-1]` slice as the Figure 3 cell).
for cond_frames_val in cond_frames_list_to_process[:-1]:
    if cond_frames_val not in highres_gen_batches:
        print(f"WARNING: cond_frames={cond_frames_val} not found in highres_gen_batches, skipping...")
        continue

    pred_video = highres_gen_batches[cond_frames_val][TRAJ_ERROR_SEQUENCE_INDEX:TRAJ_ERROR_SEQUENCE_INDEX+1]

    T_min = min(pred_video.shape[1], gt_video.shape[1])
    pred_video = pred_video[:, :T_min]
    gt_video_aligned = gt_video[:, :T_min]

    errors = compute_trajectory_error_over_time(pred_video, gt_video_aligned, threshold=0.7)
    if errors.size == 0:
        # np.max raises on an empty array, and there is nothing to plot anyway.
        print(f"WARNING: cond_frames={cond_frames_val} has no frames to compare, skipping...")
        continue
    trajectory_errors[cond_frames_val] = errors
    print(f"cond_frames={cond_frames_val}: error shape={errors.shape}, mean={np.mean(errors):.2f}, max={np.max(errors):.2f}")

fig, ax = plt.subplots(1, 1, figsize=(20, 10))

strong_colors = [
    '#FF0000',
    '#0000FF',
    '#00AA00',
    '#FF6600',
    '#9900CC',
    '#CC0000',
    '#0066CC',
    '#000000',
]

line_styles = ['-', '--', '-.', ':', '-', '--', '-.', ':']

markers = ['o', 's', '^', 'D', 'v', 'p', '*', 'h']

# Assign color / line style / marker by sorted cond_frames value so the styling is
# stable regardless of dict insertion order.
sorted_keys = sorted(trajectory_errors.keys())
color_map = {}
style_map = {}
marker_map = {}

for i, cond_frames_val in enumerate(sorted_keys):
    color_map[cond_frames_val] = strong_colors[i % len(strong_colors)]
    style_map[cond_frames_val] = line_styles[i % len(line_styles)]
    marker_map[cond_frames_val] = markers[i % len(markers)]

for cond_frames_val in sorted_keys:
    errors = trajectory_errors[cond_frames_val]
    time_steps = np.arange(len(errors))
    color = color_map[cond_frames_val]
    linestyle = style_map[cond_frames_val]
    marker = marker_map[cond_frames_val]

    # At most ~20 markers per curve to keep overlapping lines readable.
    marker_every = max(1, len(time_steps) // 20)

    ax.plot(time_steps, errors,
            linewidth=4.5,
            linestyle=linestyle,
            marker=marker,
            markevery=marker_every,
            markersize=10,
            label=f'cond_frames={cond_frames_val}',
            color=color,
            alpha=1.0,
            zorder=2,
            markeredgecolor='white',
            markeredgewidth=1.5)

ax.set_xlabel('Time (frame index)', fontsize=20, fontweight='bold')
ax.set_ylabel('Trajectory Error (pixels)', fontsize=20, fontweight='bold')
ax.set_title('Trajectory Error Over Time for Different Conditioning Frames',
             fontsize=22, fontweight='bold')

ax.grid(True, alpha=0.5, linestyle='--', linewidth=1.5, zorder=0)

ax.tick_params(axis='both', which='major', labelsize=16, width=2, length=6)
ax.tick_params(axis='both', which='minor', labelsize=14)

# A legend with no labelled artists only emits a warning; skip it in that case.
if trajectory_errors:
    ax.legend(loc='best', fontsize=16, frameon=True, fancybox=True, shadow=True,
              ncol=2, framealpha=0.95, edgecolor='gray', columnspacing=1.2,
              handlelength=3.0, handletextpad=0.8, markerscale=1.5)

ax.set_facecolor('white')

plt.tight_layout()

save_path = os.path.join(TRAJ_ERROR_OUT_DIR, 'trajectory_error_multiple_cond_frames.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f'\nSaved: {save_path}')

plt.show()

print(f'\nTrajectory Error Analysis Figure saved to: {TRAJ_ERROR_OUT_DIR}')

In [None]:
print("\n--- Creating Figure (C): Trajectory Patterns ---")
n_cols_traj = 3
n_rows_traj = 2
n_traj_slots = n_rows_traj * n_cols_traj
# Only as many sequences as there are grid cells can be drawn. More would index
# past the gridspec; Figure (A) applies the same limit.
n_traj_to_plot = min(num_traj_available, n_traj_slots)

# Frame size for the axis limits. The batch is laid out [B, T, C, H, W], so the
# last two dims are (H, W). Unpacking them as `W, H` swapped the limits for
# non-square frames.
H, W = gt_highres_batch.shape[-2:]

fig_c = plt.figure(figsize=(6*n_cols_traj, 6*n_rows_traj))
gs_c = fig_c.add_gridspec(n_rows_traj, n_cols_traj, hspace=0.3, wspace=0.2)

for idx in range(n_traj_to_plot):
    row = idx // n_cols_traj
    col = idx % n_cols_traj
    ax_traj = fig_c.add_subplot(gs_c[row, col])

    gt_trajectory = extract_trajectory(gt_highres_batch[idx:idx+1], threshold=0.7)

    if len(gt_trajectory) > 0:
        # Positions are (y, x) pairs; matplotlib wants x then y.
        gt_x = [pos[1] for pos in gt_trajectory]
        gt_y = [pos[0] for pos in gt_trajectory]
        ax_traj.plot(gt_x, gt_y, 'b-', linewidth=5, alpha=0.8, zorder=2)
        ax_traj.scatter([gt_x[0]], [gt_y[0]], c='green', s=200, marker='o',
                        zorder=3, label='Start', edgecolors='white', linewidths=1)
        ax_traj.scatter([gt_x[-1]], [gt_y[-1]], c='red', s=200, marker='s',
                        zorder=3, label='End', edgecolors='white', linewidths=1)

    ax_traj.set_aspect('equal')
    ax_traj.set_xlim(0, W-1)
    ax_traj.set_ylim(H-1, 0)  # image coordinates: row 0 at the top
    ax_traj.set_title(f'S{idx+1}', fontsize=30, fontweight='bold')
    ax_traj.grid(True, alpha=0.3, linestyle='--')
    ax_traj.set_facecolor('white')
    ax_traj.tick_params(labelsize=20)

# Hide unused grid cells.
for idx in range(n_traj_to_plot, n_traj_slots):
    row = idx // n_cols_traj
    col = idx % n_cols_traj
    ax_empty = fig_c.add_subplot(gs_c[row, col])
    ax_empty.axis('off')

from matplotlib.lines import Line2D

# Shared figure-level legend built from proxy artists.
legend_elements = [
    Line2D(
        [0], [0],
        color='blue',
        linewidth=5,
        label='Trajectory'
    ),
    Line2D(
        [0], [0],
        marker='o',
        color='green',
        linestyle='None',
        markersize=30,
        markeredgecolor='white',
        markeredgewidth=1.5,
        label='Start'
    ),
    Line2D(
        [0], [0],
        marker='s',
        color='red',
        linestyle='None',
        markersize=30,
        markeredgecolor='white',
        markeredgewidth=1.5,
        label='End'
    ),
]

fig_c.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, -0.02),
             ncol=4, fontsize=30, frameon=True, fancybox=True, shadow=True)

save_path_c = os.path.join(DATASET_FIG_OUT_DIR, 'dataset_figure_C_trajectory_patterns.png')
plt.savefig(save_path_c, dpi=300, bbox_inches='tight', facecolor='white')
print(f'Saved: {save_path_c}')
plt.show()

print(f'\nAll 3 Dataset Example Figures saved to: {DATASET_FIG_OUT_DIR}')

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

os.makedirs(FIGURE4_OUT_DIR, exist_ok=True)

# Announce which figure is being built and for which GT sequence.
print("Creating Figure 4: Temporal Consistency Analysis (GT-deviation)")
print("Using sequence index: {}".format(FIGURE4_SEQUENCE_INDEX))

def compute_frame_diff_magnitude(video: torch.Tensor) -> np.ndarray:
    """
    Compute mean absolute frame-to-frame difference magnitude.
    video: [B, T, C, H, W] (only batch element 0 is used) or [T, C, H, W],
           value range [-1, 1]; non-float32/64 inputs are cast to float32 first
    returns: [T-1] float64 numpy array (empty when T < 2)
    """
    if video.ndim == 5:  # [B, T, C, H, W]
        video = video[0]  # use first batch

    T = video.shape[0]
    if T < 2:
        return np.array([], dtype=np.float64)

    # Integer frames (e.g. uint8) would wrap around on subtraction and torch.mean
    # rejects integer dtypes; half precision loses accuracy in the mean.
    if video.dtype not in (torch.float32, torch.float64):
        video = video.float()

    # Δ_t = mean(|x_t - x_{t-1}|)
    diffs = torch.mean(torch.abs(video[1:] - video[:-1]), dim=(1, 2, 3))  # [T-1]
    return diffs.detach().cpu().numpy().astype(np.float64)

def compute_temporal_deviation(pred_video: torch.Tensor, gt_video: torch.Tensor) -> np.ndarray:
    """
    How far the predicted video's motion magnitude strays from GT, per step:
      err_k = | mean(|pred_{k+1} - pred_k|) - mean(|gt_{k+1} - gt_k|) |
    Both difference curves are truncated to the shorter one.
    returns: [min(T_pred, T_gt)-1] numpy array (empty if either has < 2 frames)
    """
    pred_curve = compute_frame_diff_magnitude(pred_video)
    gt_curve = compute_frame_diff_magnitude(gt_video)

    overlap = min(pred_curve.shape[0], gt_curve.shape[0])
    if not overlap:
        return np.array([], dtype=np.float64)
    return np.abs(pred_curve[:overlap] - gt_curve[:overlap])

def downsample_video_to_match(gt_video: torch.Tensor, ref_video: torch.Tensor) -> torch.Tensor:
    """
    Downsample gt_video spatially to match ref_video resolution (bilinear).
    gt_video:  [1, T, C, H, W] (high-res GT; only batch element 0 is used)
    ref_video: [1, T, C, h, w] (e.g., low-res planner output; only its shape is read)
    returns:   [1, T, C, h, w]
    raises:    ValueError if either input is not 5-D or T / C differ
               (explicit raise instead of `assert`, which is stripped under -O)
    """
    if gt_video.ndim != 5 or ref_video.ndim != 5:
        raise ValueError(
            f"Expected 5-D [B, T, C, H, W] tensors, got gt {tuple(gt_video.shape)} "
            f"and ref {tuple(ref_video.shape)}."
        )
    _, T, C, _, _ = gt_video.shape
    _, T2, C2, h, w = ref_video.shape
    if T != T2 or C != C2:
        raise ValueError("GT and ref must have same T and C to match temporal comparison.")

    # reshape to [T, C, H, W] for interpolate (treat T as batch)
    x = gt_video[0]  # [T, C, H, W]
    x_ds = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)  # [T, C, h, w]
    return x_ds.unsqueeze(0)  # [1, T, C, h, w]

# Conditioning-frame count analysed in this figure.
# NOTE(review): this cell used to set `idx = 3` and read
# `cond_frames_list_to_process[idx]`, then immediately overwrote that value with
# the expression below. The dead lookup could only raise NameError/IndexError
# (missing or short list), so it is removed; if the list entry was the intended
# choice, select it here instead.
cond_frames = COND_FRAMES if isinstance(COND_FRAMES, int) else 5

print(f"Using cond_frames: {cond_frames}")

# ---- GT (high-res) ----
gt_high = gt_highres_batch[FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, H, W]

# ---- Planner (low-res) ----
if 'lowres_gen_batches' not in globals():
    raise RuntimeError("lowres_gen_batches not found. Please run planner generation first.")

if cond_frames not in lowres_gen_batches:
    available_keys = list(lowres_gen_batches.keys())
    raise KeyError(f"cond_frames={cond_frames} not found in lowres_gen_batches. Available: {available_keys}")

planner_low = lowres_gen_batches[cond_frames][FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, h, w]

# Compare the planner against GT at the planner's own resolution.
gt_low = downsample_video_to_match(gt_high, planner_low)  # [1, T, C, h, w]

# ---- End-to-End (high-res) ----
if 'highres_gen_batches' not in globals():
    raise RuntimeError("highres_gen_batches not found. Please run end-to-end (high-res) generation first.")

if cond_frames not in highres_gen_batches:
    available_keys = list(highres_gen_batches.keys())
    raise KeyError(f"cond_frames={cond_frames} not found in highres_gen_batches. Available: {available_keys}")

e2e_high = highres_gen_batches[cond_frames][FIGURE4_SEQUENCE_INDEX:FIGURE4_SEQUENCE_INDEX+1]  # [1, T, C, H, W]

# ---- Compute temporal deviation curves ----
planner_temporal_err = compute_temporal_deviation(planner_low, gt_low)
print(f"Planner temporal deviation shape: {planner_temporal_err.shape}")

e2e_temporal_err = compute_temporal_deviation(e2e_high, gt_high)
print(f"End-to-End temporal deviation shape: {e2e_temporal_err.shape}")

# Entry k of each curve compares frames k and k+1 (0-based), so it is plotted at
# k + 2, the 1-based index of the later frame.
L = min(len(planner_temporal_err), len(e2e_temporal_err))
frame_indices = np.arange(2, L + 2)

# ---- Plot ----
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

ax.plot(frame_indices, planner_temporal_err[:L],
        linewidth=2.5, linestyle='-',
        label='Planner (low-res)', alpha=0.9, zorder=2)

ax.plot(frame_indices, e2e_temporal_err[:L],
        linewidth=2.5, linestyle='--',
        dashes=(5, 3), label='End-to-End (high-res)', alpha=0.9, zorder=1)

# baseline: a perfect match of GT dynamics has zero deviation
ax.axhline(0.0, linewidth=2.0, linestyle='-',
           label='GT (zero deviation)', alpha=0.8, zorder=3)

ax.set_xlabel('Video Time Step (frame index)', fontsize=14, fontweight='bold')
ax.set_ylabel(r'Temporal Deviation from GT  $|\Delta x^{pred}-\Delta x^{GT}|$', fontsize=14, fontweight='bold')
ax.set_title('Temporal Consistency Analysis (Deviation from GT Dynamics)', fontsize=16, fontweight='bold')

ax.grid(True, alpha=0.3, linestyle='--')
ax.legend(loc='best', fontsize=12, frameon=True, fancybox=True, shadow=True)
ax.set_facecolor('white')

plt.tight_layout()

save_path = os.path.join(FIGURE4_OUT_DIR, 'figure4_temporal_consistency_deviation.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f'\nSaved: {save_path}')

plt.show()
print(f'\nFigure 4 saved to: {FIGURE4_OUT_DIR}')
