In [None]:
import os
import json
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
import sys

sys.path.append('/workspace')
sys.path.append('/workspace/planner')

from planner.modules.config_planner import ConfigPlanner
from planner.modules.transformer import LatentPlanner
from planner.modules.dataset_lowres import LowResVideoDataset
from models.diffusion import AsymmetricUNet
from models.scheduler import NoiseScheduler

print("Imports completed")

In [None]:
def compute_mse(pred: torch.Tensor, gt: torch.Tensor) -> float:
    """Return the mean squared error between ``pred`` and ``gt`` as a Python float."""
    squared_error = F.mse_loss(pred, gt)
    return float(squared_error.item())


def compute_mae(pred: torch.Tensor, gt: torch.Tensor) -> float:
    """Return the mean absolute error between ``pred`` and ``gt`` as a Python float."""
    absolute_error = F.l1_loss(pred, gt)
    return float(absolute_error.item())


def compute_psnr(pred: torch.Tensor, gt: torch.Tensor, max_val: float = 2.0) -> float:
    """Peak signal-to-noise ratio in dB between ``pred`` and ``gt``.

    Args:
        pred: Predicted tensor.
        gt: Reference tensor of the same shape.
        max_val: Peak-to-peak signal range; the default of 2.0 matches data
            spanning [-1, 1].

    Returns:
        ``10 * log10(max_val**2 / MSE)``, or ``inf`` when the tensors are
        identical (MSE is exactly zero).
    """
    mean_sq_err = F.mse_loss(pred, gt).item()
    if mean_sq_err == 0:
        # Identical inputs: PSNR is unbounded.
        return float('inf')
    peak_power = max_val ** 2
    return float(10 * np.log10(peak_power / mean_sq_err))


def compute_ssim(pred: torch.Tensor, gt: torch.Tensor, window_size: int = 11) -> float:
    """Global (single-window) SSIM between two images or image batches.

    Inputs are mapped from [-1, 1] to [0, 1] via ``(x + 1) * 0.5``, reduced to
    grayscale by averaging the channel axis (dim 1 for 4-D input, dim 0
    otherwise), and then one SSIM value is computed from the statistics of the
    whole tensor. No sliding window is used, so ``window_size`` is accepted
    but currently ignored.

    Returns:
        The SSIM score clamped to [0, 1] as a Python float.
    """
    # The channel axis is chosen from ``pred`` and reused for ``gt``.
    channel_axis = 1 if pred.ndim == 4 else 0
    pred_gray = ((pred + 1.0) * 0.5).mean(dim=channel_axis)
    gt_gray = ((gt + 1.0) * 0.5).mean(dim=channel_axis)

    # Standard SSIM stabilisers for a dynamic range of 1.0.
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    mu_pred, mu_gt = pred_gray.mean(), gt_gray.mean()
    pred_centered = pred_gray - mu_pred
    gt_centered = gt_gray - mu_gt

    var_pred = (pred_centered ** 2).mean()
    var_gt = (gt_centered ** 2).mean()
    covariance = (pred_centered * gt_centered).mean()

    top = (2 * mu_pred * mu_gt + c1) * (2 * covariance + c2)
    bottom = (mu_pred ** 2 + mu_gt ** 2 + c1) * (var_pred + var_gt + c2)

    score = top / (bottom + 1e-8)
    # Negative correlation can push the raw score below zero; report it as 0.
    return float(score.clamp(0, 1).item())


def extract_ball_position(frame: np.ndarray, threshold: float = 0.7) -> Optional[Tuple[float, float]]:
    """Locate the bright object in a frame as an intensity-weighted centroid.

    Args:
        frame: Either an ``[H, W, C]`` image or an ``[H, W]`` grayscale image.
            Integer images and images whose values exceed 1.0 are treated as
            0-255 data and rescaled to [0, 1]; float images already within
            [0, 1] are used as-is.
        threshold: Normalised brightness above which a pixel is considered
            part of the object.

    Returns:
        ``(row, col)`` of the centroid of all pixels brighter than
        ``threshold``, weighted by brightness. If no pixel passes the
        threshold, the location of the single brightest pixel is returned
        instead. ``None`` is returned only for an empty frame.
    """
    if frame.size == 0:
        # Nothing to search; argmax on an empty array would raise.
        return None

    if frame.ndim == 3:
        gray = np.mean(frame, axis=2).astype(np.float32)
        # Integer frames are always 0-255 data. Float frames are rescaled only
        # when they are clearly not in [0, 1] already; dividing a [0, 1] frame
        # by 255 would make the threshold unreachable.
        if np.issubdtype(frame.dtype, np.integer) or gray.max() > 1.0:
            gray = gray / 255.0
    else:
        gray = frame.astype(np.float32)
        if gray.max() > 1.0:
            gray = gray / 255.0

    mask = gray > threshold

    if not np.any(mask):
        # Fallback: no pixel is bright enough, so use the brightest one.
        max_idx = np.unravel_index(np.argmax(gray), gray.shape)
        return (float(max_idx[0]), float(max_idx[1]))

    y_coords, x_coords = np.where(mask)
    weights = gray[y_coords, x_coords]

    if np.sum(weights) > 0:
        y_center = np.average(y_coords, weights=weights)
        x_center = np.average(x_coords, weights=weights)
    else:
        # Only reachable with a non-positive threshold; use the plain centroid.
        y_center = np.mean(y_coords)
        x_center = np.mean(x_coords)

    return (float(y_center), float(x_center))


def compute_trajectory_error(
    pred_video: torch.Tensor,
    gt_video: torch.Tensor,
    threshold: float = 0.7
) -> Dict[str, float]:
    """Per-frame pixel distance between predicted and ground-truth object positions.

    Each frame is clamped to [-1, 1], converted to a uint8 ``[H, W, C]`` image
    and passed to ``extract_ball_position``; the Euclidean distance between
    the two positions is recorded. Frames where either position is ``None``
    are skipped.

    Args:
        pred_video: ``[B, T, C, H, W]`` or ``[T, C, H, W]`` prediction.
        gt_video: Ground truth with the same layout as ``pred_video``.
        threshold: Brightness threshold forwarded to ``extract_ball_position``.

    Returns:
        Dict with ``'mean'``, ``'std'`` and ``'max'`` of the distances; all
        three are NaN when no frame produced a distance.
    """
    def frame_to_hwc_uint8(frame):
        # [-1, 1] float CHW -> [0, 255] uint8 HWC numpy image.
        unit_range = (frame.clamp(-1, 1) + 1) * 0.5
        as_bytes = (unit_range * 255.0).round().to(torch.uint8)
        return as_bytes.permute(1, 2, 0).cpu().numpy()

    if pred_video.ndim == 4:
        # Promote a single video to a batch of one.
        pred_video, gt_video = pred_video.unsqueeze(0), gt_video.unsqueeze(0)

    num_videos, num_frames, _, _, _ = pred_video.shape
    distances = []

    for vid in range(num_videos):
        for step in range(num_frames):
            pred_pos = extract_ball_position(frame_to_hwc_uint8(pred_video[vid, step]), threshold)
            gt_pos = extract_ball_position(frame_to_hwc_uint8(gt_video[vid, step]), threshold)

            if pred_pos is None or gt_pos is None:
                continue

            d_row = pred_pos[0] - gt_pos[0]
            d_col = pred_pos[1] - gt_pos[1]
            distances.append(np.sqrt(d_row**2 + d_col**2))

    if len(distances) == 0:
        return {'mean': float('nan'), 'std': float('nan'), 'max': float('nan')}

    distances = np.array(distances)
    return {
        'mean': float(distances.mean()),
        'std': float(distances.std()),
        'max': float(distances.max())
    }


def compute_temporal_consistency(video: torch.Tensor) -> float:
    """Mean absolute difference between consecutive frames of a video.

    Args:
        video: ``[B, T, C, H, W]`` or ``[T, C, H, W]`` tensor.

    Returns:
        The average, over every video and every adjacent frame pair, of the
        L1 distance between the two frames. Returns 0.0 when there are fewer
        than two frames (or no videos).
    """
    if video.ndim == 4:
        video = video.unsqueeze(0)

    _, num_frames, _, _, _ = video.shape
    if num_frames < 2:
        return 0.0

    frame_deltas = [
        F.l1_loss(clip[step], clip[step - 1]).item()
        for clip in video
        for step in range(1, num_frames)
    ]

    if not frame_deltas:
        return 0.0
    return float(np.mean(frame_deltas))


print("Metric functions defined")

In [None]:
def get_resolution_for_timestep(
    t: int, num_timesteps: int, high_res: int, latent_res: int, k_step: int
) -> int:
    """Map a diffusion timestep to the spatial resolution used at that step.

    The resolution is interpolated linearly from ``high_res`` at ``t == 0`` to
    ``latent_res`` at ``t == num_timesteps - 1``, snapped to the nearest
    multiple of ``k_step`` and clamped to ``[latent_res, high_res]``.

    Args:
        t: Current timestep.
        num_timesteps: Total number of timesteps in the schedule.
        high_res: Resolution at the clean end of the schedule.
        latent_res: Resolution at the noisy end of the schedule.
        k_step: Granularity the resolution is rounded to.

    Returns:
        The square side length, in pixels, for timestep ``t``.
    """
    # A one-step schedule has no span to interpolate over; treat it as t == 0
    # instead of dividing by zero.
    ratio = t / (num_timesteps - 1) if num_timesteps > 1 else 0.0
    size_float = high_res - ratio * (high_res - latent_res)
    size_int = int(round(size_float / k_step) * k_step)
    # Snapping (or an out-of-range t) may overshoot either bound.
    return max(latent_res, min(high_res, size_int))


def planner_generate(
    planner: LatentPlanner,
    cond_frames: torch.Tensor,
    total_T: int,
    device: torch.device,
    show_progress: bool = False
) -> torch.Tensor:
    """Autoregressively extend ``cond_frames`` to ``total_T`` frames.

    At every step the planner is run on all frames generated so far (with an
    all-ones attention mask) and the output at index ``t`` (the position right
    after the current sequence) is appended as the next frame.

    Args:
        planner: Model called as ``planner(seq, attn_mask=...)``; it is put in
            eval mode and is expected to return at least ``t + 1`` frames for
            a ``t``-frame input.
        cond_frames: ``[B, cond_T, C, H, W]`` conditioning frames.
        total_T: Desired total number of frames, conditioning included.
        device: Device on which the attention mask is created.
        show_progress: Wrap the generation loop in a tqdm progress bar.

    Returns:
        ``[B, max(total_T, cond_T), C, H, W]`` tensor whose first ``cond_T``
        frames are the conditioning frames.
    """
    planner.eval()
    B, cond_T, C, H, W = cond_frames.shape
    frames = list(cond_frames.unbind(dim=1))

    steps = range(total_T - cond_T)
    if show_progress:
        steps = tqdm(steps, desc="  Planner generating", leave=False)

    with torch.no_grad():
        for _ in steps:
            history = torch.stack(frames, dim=1)  # [B, t, C, H, W]
            cur_len = history.shape[1]
            full_mask = torch.ones(B, cur_len, device=device)
            prediction = planner(history, attn_mask=full_mask)  # [B, t+1, C, H, W]
            # Index t is the newly predicted frame following the history.
            frames.append(prediction[:, cur_len])

    return torch.stack(frames, dim=1)


def _refine_latent_batch(
    refiner,
    scheduler,
    latents: torch.Tensor,
    t_start: int,
    high_res: int,
    latent_res: int,
    k_step: int,
    tensor_timestep: bool,
) -> torch.Tensor:
    """Noise ``latents`` to ``t_start`` and denoise to t=0 while growing resolution.

    ``latents`` is an ``[N, C, H, W]`` batch. ``tensor_timestep`` selects
    whether ``scheduler.step_x0`` receives the per-item timestep tensor or a
    plain Python int. Must be called under ``torch.no_grad()`` by the caller.
    """
    n_items = latents.shape[0]
    dev = latents.device
    interp = dict(mode='bilinear', align_corners=False)

    noise = torch.randn_like(latents)
    start_steps = torch.full((n_items,), t_start, device=dev, dtype=torch.long)
    curr = scheduler.add_noise(latents, noise, start_steps)

    for t in range(t_start, -1, -1):
        t_batch = torch.full((n_items,), t, device=dev, dtype=torch.long)
        target_res_t = get_resolution_for_timestep(
            t, scheduler.num_timesteps, high_res, latent_res, k_step
        )

        # Bring the working sample to this step's resolution (in practice only
        # needed on the first iteration; later ones are resized below).
        if curr.shape[-2] != target_res_t or curr.shape[-1] != target_res_t:
            curr = F.interpolate(curr, size=(target_res_t, target_res_t), **interp)

        # The refiner predicts x0 at full resolution; bring it back down to
        # the working resolution before taking the scheduler step.
        pred_x0_high = refiner(curr, t_batch, target_shape=(high_res, high_res))
        pred_x0_curr = F.interpolate(
            pred_x0_high, size=(target_res_t, target_res_t), **interp
        )

        step_t = t_batch if tensor_timestep else int(t)
        prev = scheduler.step_x0(pred_x0_curr, step_t, curr)

        if t > 0:
            next_res = get_resolution_for_timestep(
                t - 1, scheduler.num_timesteps, high_res, latent_res, k_step
            )
            if prev.shape[-2] != next_res or prev.shape[-1] != next_res:
                curr = F.interpolate(prev, size=(next_res, next_res), **interp)
            else:
                curr = prev
        else:
            curr = prev

    if curr.shape[-2] != high_res or curr.shape[-1] != high_res:
        curr = F.interpolate(curr, size=(high_res, high_res), **interp)

    return curr


def refiner_refine_sequence(
    refiner: AsymmetricUNet,
    scheduler: NoiseScheduler,
    lowres_btchw: torch.Tensor,
    high_res: int = 128,
    latent_res: int = 32,
    k_step: int = 1,
    t_start_frac: float = 0.1,
    batch_frames: bool = True,
) -> torch.Tensor:
    """Upscale a low-resolution video with the diffusion refiner.

    Each latent is noised to timestep ``t_start`` and denoised back to 0,
    with the working resolution following ``get_resolution_for_timestep``.

    Args:
        refiner: Model called as ``refiner(x, t, target_shape=(H, W))``; put
            in eval mode.
        scheduler: Provides ``num_timesteps``, ``add_noise`` and ``step_x0``.
        lowres_btchw: ``[B, T, C, latent_res, latent_res]`` input video.
        high_res: Output side length.
        latent_res: Expected input side length (asserted).
        k_step: Resolution snapping granularity.
        t_start_frac: Fraction of ``num_timesteps - 1`` used as the starting
            noise level.
        batch_frames: If True, process one video at a time with its ``T``
            frames as the batch; otherwise process one frame index at a time
            with the ``B`` videos as the batch.

    Returns:
        ``[B, T, C, high_res, high_res]`` refined video.
    """
    refiner.eval()
    B, T, C, H, W = lowres_btchw.shape
    assert H == latent_res and W == latent_res, (
        f"expected {latent_res}x{latent_res} input, got {H}x{W}"
    )

    t_start = int((scheduler.num_timesteps - 1) * float(t_start_frac))

    with torch.no_grad():
        if batch_frames:
            # NOTE(review): this path hands step_x0 a Python int while the
            # per-frame path below hands it a tensor. Both are kept as found;
            # confirm which form NoiseScheduler.step_x0 expects.
            refined_videos = [
                _refine_latent_batch(
                    refiner, scheduler, lowres_btchw[b_idx], t_start,
                    high_res, latent_res, k_step, tensor_timestep=False,
                )
                for b_idx in range(B)
            ]
            return torch.stack(refined_videos, dim=0)

        refined_frames = [
            _refine_latent_batch(
                refiner, scheduler, lowres_btchw[:, i], t_start,
                high_res, latent_res, k_step, tensor_timestep=True,
            )
            for i in range(T)
        ]
        return torch.stack(refined_frames, dim=1)


print("Helper functions defined")

In [None]:
# ====== Config (EDIT ONLY THIS CELL) ======
# Every tunable for the evaluation run lives here; later cells read these
# names as globals.
# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Checkpoints ---
# NOTE(review): absolute, machine-specific paths; adjust when running elsewhere.
PLANNER_CKPT = '/workspace/planner/experiments/20260123_073354_planner_32_large/checkpoints/model_epoch_1000.pth'
REFINER_CKPT = '/workspace/experiments/20260120_100656_obj_weight_5/checkpoints/model_epoch_30.pt'

# --- Data ---
LOWRES_ROOT = '/workspace/data/processed_32'  # must be 32x32 .pt
HIGHRES_ROOT = '/workspace/data/processed'   # 128x128 high-res data (optional)
SPLIT = 'val'  # 'train', 'val', 'test'

# The dataset cell loads high-res ground truth when EVAL_MODE is 'refiner',
# 'end2end' or 'all'. Other accepted values are not visible from this cell
# (presumably 'planner'; TODO confirm).
EVAL_MODE = 'all'
# Numbers of conditioning frames; the evaluation loop runs once per entry.
COND_FRAMES_LIST = [1,3,5,10,15]
# Presumably forwarded as `max_samples` to the evaluate_* functions, where a
# falsy value (None or 0) means "no limit". TODO confirm at the call site.
MAX_SAMPLES = 1

# --- Resolutions / schedule ---
LATENT_RES = 32   # side length of the planner / refiner-input frames
HIGH_RES = 128    # side length of the refined output
K_STEP = 1        # presumably the k_step used to snap intermediate resolutions

# --- Refiner sampling strength ---
# Presumably forwarded to refiner_refine_sequence, which starts denoising at
# timestep int((num_timesteps - 1) * t_start_frac).
T_START_FRAC = 0.1
# True: refine one video at a time with its frames as the batch;
# False: refine one frame index at a time across the batch (see
# refiner_refine_sequence's `batch_frames`).
REFINER_BATCH_FRAMES = True

# --- Refiner model architecture (must match the checkpoint) ---
REF_MODEL_CHANNELS = 128
REF_CHANNEL_MULT = (1, 2, 4, 8)
REF_NUM_RES_BLOCKS = 2

# --- Outputs ---
OUTPUT_DIR = '/workspace/evaluation_results'

# Echo the key settings so the run is self-describing in the notebook output.
print('DEVICE:', DEVICE)
print('LOWRES_ROOT:', LOWRES_ROOT)
print('SPLIT:', SPLIT)
print('EVAL_MODE:', EVAL_MODE)
print('COND_FRAMES_LIST:', COND_FRAMES_LIST)
print('OUTPUT_DIR:', OUTPUT_DIR)

In [None]:
def evaluate_planner(
    planner: LatentPlanner,
    dataloader: DataLoader,
    device: torch.device,
    cond_frames: int = 5,
    max_samples: Optional[int] = None
) -> Dict[str, List[float]]:
    """Score the planner's autoregressive rollout against ground truth.

    For every clip longer than ``cond_frames`` the planner is conditioned on
    the first ``cond_frames`` frames and rolled out to the clip's true length;
    only the generated (non-conditioning) frames are scored.

    Args:
        planner: Model passed to ``planner_generate``.
        dataloader: Yields dicts with ``"seq"`` (``[B, T, C, H, W]``) and
            ``"lengths"`` (true frame count per clip).
        device: Device the batches are moved to.
        cond_frames: Number of leading frames given to the planner.
        max_samples: Stop after this many evaluated clips; a falsy value
            (``None`` or 0) means no limit.

    Returns:
        Dict mapping metric name to a list with one value per evaluated clip.
        Frame metrics (mse/mae/psnr/ssim) are averaged over the clip's frames.
    """
    planner.eval()

    frame_metric_fns = {
        'mse': compute_mse,
        'mae': compute_mae,
        'psnr': compute_psnr,
        'ssim': compute_ssim,
    }

    results = {name: [] for name in frame_metric_fns}
    results['trajectory_error'] = []
    results['temporal_consistency_pred'] = []
    results['temporal_consistency_gt'] = []

    evaluated = 0

    num_batches = len(dataloader)
    bar_total = min(num_batches, max_samples) if max_samples else num_batches

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Evaluating Planner", total=bar_total, miniters=1, mininterval=1.0)

        for batch in progress:
            if max_samples and evaluated >= max_samples:
                break

            clips = batch["seq"].to(device)
            clip_lengths = batch["lengths"]

            for idx in range(clips.shape[0]):
                if max_samples and evaluated >= max_samples:
                    break

                num_frames = int(clip_lengths[idx].item())
                if num_frames <= cond_frames:
                    # Nothing would be left to predict for this clip.
                    continue

                progress.set_postfix({
                    'sample': evaluated + 1,
                    'frames': num_frames,
                    'cond': cond_frames
                })

                # Drop the padding and keep a batch dimension of one.
                clip = clips[idx:idx + 1, :num_frames]
                rollout = planner_generate(
                    planner, clip[:, :cond_frames], num_frames, device, show_progress=False
                )

                pred_future = rollout[:, cond_frames:]
                gt_future = clip[:, cond_frames:]

                per_frame = {name: [] for name in frame_metric_fns}
                for step in range(pred_future.shape[1]):
                    for name, metric_fn in frame_metric_fns.items():
                        per_frame[name].append(
                            metric_fn(pred_future[0, step], gt_future[0, step])
                        )

                for name, scores in per_frame.items():
                    results[name].append(np.mean(scores))

                results['trajectory_error'].append(
                    compute_trajectory_error(pred_future, gt_future)['mean']
                )
                results['temporal_consistency_pred'].append(
                    compute_temporal_consistency(pred_future)
                )
                results['temporal_consistency_gt'].append(
                    compute_temporal_consistency(gt_future)
                )

                evaluated += 1

    return results


def evaluate_refiner(
    refiner: AsymmetricUNet,
    scheduler: NoiseScheduler,
    lowres_dataloader: DataLoader,
    highres_dataloader: DataLoader,
    device: torch.device,
    high_res: int = 128,
    latent_res: int = 32,
    k_step: int = 1,
    t_start_frac: float = 0.1,
    batch_frames: bool = True,
    max_samples: Optional[int] = None
) -> Dict[str, List[float]]:
    """Score the refiner when fed ground-truth low-resolution clips.

    The two dataloaders are iterated in lockstep and are assumed to yield the
    same clips in the same order (TODO confirm at the call site). Each
    low-res clip is refined with ``refiner_refine_sequence`` and compared
    frame by frame with the high-res clip; if the high-res clip is not
    already ``high_res`` wide it is bilinearly resized first.

    Args:
        refiner: Diffusion refiner model.
        scheduler: Noise scheduler used by ``refiner_refine_sequence``.
        lowres_dataloader: Yields dicts with ``"seq"`` and ``"lengths"``.
        highres_dataloader: Yields dicts with ``"seq"`` (ground truth).
        device: Device the batches are moved to.
        high_res, latent_res, k_step, t_start_frac, batch_frames: Forwarded
            to ``refiner_refine_sequence``.
        max_samples: Stop after this many clips; a falsy value means no limit.

    Returns:
        Dict mapping metric name to a list with one value per evaluated clip.
    """
    refiner.eval()

    all_metrics = {
        'mse': [], 'mae': [], 'psnr': [], 'ssim': [],
        'trajectory_error': [],
        'temporal_consistency_pred': [],
        'temporal_consistency_gt': []
    }

    sample_count = 0

    total_samples = min(len(lowres_dataloader), max_samples) if max_samples else len(lowres_dataloader)

    with torch.no_grad():
        # tqdm advances once per yielded batch pair. Do not also call
        # pbar.update() in the loop, or every step is counted twice.
        pbar = tqdm(
            zip(lowres_dataloader, highres_dataloader),
            desc="Evaluating Refiner",
            total=total_samples,
            miniters=1,
            mininterval=1.0
        )

        for (lowres_batch, highres_batch) in pbar:
            if max_samples and sample_count >= max_samples:
                break

            lowres_seq = lowres_batch["seq"].to(device)
            highres_seq = highres_batch["seq"].to(device)
            # Clip lengths are taken from the low-res batch for both streams.
            lengths = lowres_batch["lengths"]

            B = lowres_seq.shape[0]
            for b in range(B):
                if max_samples and sample_count >= max_samples:
                    break

                T_actual = int(lengths[b].item())

                pbar.set_postfix({
                    'sample': sample_count + 1,
                    'frames': T_actual
                })

                lowres_video = lowres_seq[b:b+1, :T_actual]
                highres_gt = highres_seq[b:b+1, :T_actual]

                if highres_gt.shape[-1] != high_res:
                    # F.interpolate needs a 4-D tensor, so fold batch, time
                    # and channel into the batch axis for the resize.
                    B_dim, T, C, H, W = highres_gt.shape
                    highres_gt_flat = highres_gt.reshape(B_dim * T * C, 1, H, W)
                    highres_gt_upscaled = F.interpolate(
                        highres_gt_flat,
                        size=(high_res, high_res),
                        mode='bilinear',
                        align_corners=False
                    )
                    highres_gt = highres_gt_upscaled.reshape(B_dim, T, C, high_res, high_res)

                highres_pred = refiner_refine_sequence(
                    refiner, scheduler, lowres_video,
                    high_res=high_res, latent_res=latent_res,
                    k_step=k_step, t_start_frac=t_start_frac,
                    batch_frames=batch_frames
                )

                T = highres_pred.shape[1]
                frame_metrics = {'mse': [], 'mae': [], 'psnr': [], 'ssim': []}

                for t in range(T):
                    pred_frame = highres_pred[0, t]
                    gt_frame = highres_gt[0, t]

                    frame_metrics['mse'].append(compute_mse(pred_frame, gt_frame))
                    frame_metrics['mae'].append(compute_mae(pred_frame, gt_frame))
                    frame_metrics['psnr'].append(compute_psnr(pred_frame, gt_frame))
                    frame_metrics['ssim'].append(compute_ssim(pred_frame, gt_frame))

                # One value per clip: the mean of each per-frame metric.
                for name, scores in frame_metrics.items():
                    all_metrics[name].append(np.mean(scores))

                traj_err = compute_trajectory_error(highres_pred, highres_gt)
                all_metrics['trajectory_error'].append(traj_err['mean'])

                all_metrics['temporal_consistency_pred'].append(
                    compute_temporal_consistency(highres_pred)
                )
                all_metrics['temporal_consistency_gt'].append(
                    compute_temporal_consistency(highres_gt)
                )

                sample_count += 1

    # Needed because an early `break` leaves the bar unfinished.
    pbar.close()
    return all_metrics


def _score_clip_pair(pred_clip: torch.Tensor, gt_clip: torch.Tensor) -> Dict[str, float]:
    """Clip-level mse/mae/psnr/ssim (frame means) plus mean trajectory error.

    Both inputs are ``[1, T, C, H, W]``; frames are paired by index over the
    ``T`` frames of ``pred_clip``.
    """
    per_frame = {'mse': [], 'mae': [], 'psnr': [], 'ssim': []}

    for t in range(pred_clip.shape[1]):
        pred_frame = pred_clip[0, t]
        gt_frame = gt_clip[0, t]

        per_frame['mse'].append(compute_mse(pred_frame, gt_frame))
        per_frame['mae'].append(compute_mae(pred_frame, gt_frame))
        per_frame['psnr'].append(compute_psnr(pred_frame, gt_frame))
        per_frame['ssim'].append(compute_ssim(pred_frame, gt_frame))

    scores = {name: np.mean(values) for name, values in per_frame.items()}
    scores['trajectory_error'] = compute_trajectory_error(pred_clip, gt_clip)['mean']
    return scores


def evaluate_end2end(
    planner: LatentPlanner,
    refiner: AsymmetricUNet,
    scheduler: NoiseScheduler,
    lowres_dataloader: DataLoader,
    highres_dataloader: DataLoader,
    device: torch.device,
    cond_frames: int = 5,
    high_res: int = 128,
    latent_res: int = 32,
    k_step: int = 1,
    t_start_frac: float = 0.1,
    batch_frames: bool = True,
    max_samples: Optional[int] = None
) -> Dict[str, Dict[str, List[float]]]:
    """Evaluate planner, refiner and the full planner->refiner pipeline.

    For every clip longer than ``cond_frames`` three comparisons are made,
    all restricted to the frames after the conditioning prefix:

    * ``'planner'``: planner rollout vs. low-res ground truth.
    * ``'refiner'``: refiner applied to low-res ground truth vs. high-res
      ground truth.
    * ``'e2e'``: refiner applied to the planner rollout vs. high-res ground
      truth (also includes temporal-consistency values).

    The two dataloaders are iterated in lockstep and are assumed to yield the
    same clips in the same order (TODO confirm at the call site).

    Args:
        planner: Low-res autoregressive model.
        refiner: Diffusion refiner model.
        scheduler: Noise scheduler used by ``refiner_refine_sequence``.
        lowres_dataloader: Yields dicts with ``"seq"`` and ``"lengths"``.
        highres_dataloader: Yields dicts with ``"seq"`` (ground truth).
        device: Device the batches are moved to.
        cond_frames: Number of leading frames given to the planner.
        high_res, latent_res, k_step, t_start_frac, batch_frames: Forwarded
            to ``refiner_refine_sequence``.
        max_samples: Stop after this many clips; a falsy value means no limit.

    Returns:
        ``{'e2e': ..., 'planner': ..., 'refiner': ...}`` where each value maps
        metric name to a list with one entry per evaluated clip.
    """
    planner.eval()
    refiner.eval()

    all_metrics = {
        'mse': [], 'mae': [], 'psnr': [], 'ssim': [],
        'trajectory_error': [],
        'temporal_consistency_pred': [],
        'temporal_consistency_gt': []
    }

    planner_metrics = {
        'mse': [], 'mae': [], 'psnr': [], 'ssim': [],
        'trajectory_error': []
    }

    refiner_metrics = {
        'mse': [], 'mae': [], 'psnr': [], 'ssim': [],
        'trajectory_error': []
    }

    sample_count = 0

    total_samples = min(len(lowres_dataloader), max_samples) if max_samples else len(lowres_dataloader)

    with torch.no_grad():
        # tqdm advances once per yielded batch pair. Do not also call
        # pbar.update() in the loop, or every step is counted twice.
        pbar = tqdm(
            zip(lowres_dataloader, highres_dataloader),
            desc="Evaluating End-to-End",
            total=total_samples,
            miniters=1,
            mininterval=1.0
        )

        for (lowres_batch, highres_batch) in pbar:
            if max_samples and sample_count >= max_samples:
                break

            lowres_seq = lowres_batch["seq"].to(device)
            highres_seq = highres_batch["seq"].to(device)
            # Clip lengths are taken from the low-res batch for both streams.
            lengths = lowres_batch["lengths"]

            B = lowres_seq.shape[0]
            for b in range(B):
                if max_samples and sample_count >= max_samples:
                    break

                T_actual = int(lengths[b].item())
                if T_actual < cond_frames + 1:
                    # Nothing would be left to predict for this clip.
                    continue

                pbar.set_postfix({
                    'sample': sample_count + 1,
                    'frames': T_actual,
                    'cond': cond_frames
                })

                cond_seq = lowres_seq[b:b+1, :cond_frames]

                lowres_gt = lowres_seq[b:b+1, :T_actual]
                highres_gt = highres_seq[b:b+1, :T_actual]

                if highres_gt.shape[-1] != high_res:
                    # F.interpolate needs a 4-D tensor, so fold batch, time
                    # and channel into the batch axis for the resize.
                    B_dim, T, C, H, W = highres_gt.shape
                    highres_gt_flat = highres_gt.reshape(B_dim * T * C, 1, H, W)
                    highres_gt_upscaled = F.interpolate(
                        highres_gt_flat,
                        size=(high_res, high_res),
                        mode='bilinear',
                        align_corners=False
                    )
                    highres_gt = highres_gt_upscaled.reshape(B_dim, T, C, high_res, high_res)

                highres_gt_eval = highres_gt[:, cond_frames:]

                # --- Stage 1: planner alone, in low-res space ---
                lowres_pred = planner_generate(planner, cond_seq, T_actual, device, show_progress=False)

                planner_scores = _score_clip_pair(
                    lowres_pred[:, cond_frames:], lowres_gt[:, cond_frames:]
                )
                for name, value in planner_scores.items():
                    planner_metrics[name].append(value)

                # --- Stage 2: refiner alone, fed ground-truth low-res ---
                highres_from_gt_lowres = refiner_refine_sequence(
                    refiner, scheduler, lowres_gt,
                    high_res=high_res, latent_res=latent_res,
                    k_step=k_step, t_start_frac=t_start_frac,
                    batch_frames=batch_frames
                )

                refiner_scores = _score_clip_pair(
                    highres_from_gt_lowres[:, cond_frames:], highres_gt_eval
                )
                for name, value in refiner_scores.items():
                    refiner_metrics[name].append(value)

                # --- Stage 3: full pipeline, refiner fed the planner rollout ---
                highres_pred = refiner_refine_sequence(
                    refiner, scheduler, lowres_pred,
                    high_res=high_res, latent_res=latent_res,
                    k_step=k_step, t_start_frac=t_start_frac,
                    batch_frames=batch_frames
                )

                e2e_pred_eval = highres_pred[:, cond_frames:]

                e2e_scores = _score_clip_pair(e2e_pred_eval, highres_gt_eval)
                for name, value in e2e_scores.items():
                    all_metrics[name].append(value)

                all_metrics['temporal_consistency_pred'].append(
                    compute_temporal_consistency(e2e_pred_eval)
                )
                all_metrics['temporal_consistency_gt'].append(
                    compute_temporal_consistency(highres_gt_eval)
                )

                sample_count += 1

    # Needed because an early `break` leaves the bar unfinished.
    pbar.close()

    return {
        'e2e': all_metrics,
        'planner': planner_metrics,
        'refiner': refiner_metrics
    }


def compute_statistics(values: List[float]) -> Dict[str, float]:
    """Summarise a list of numbers, ignoring NaN and infinite entries.

    Returns:
        Dict with ``'mean'``, ``'std'`` (population), ``'min'`` and ``'max'``
        of the finite values; all four are NaN when no finite value remains.
    """
    # np.isfinite is False for NaN as well as +/-inf.
    finite_values = [v for v in values if np.isfinite(v)]

    if not finite_values:
        return dict.fromkeys(('mean', 'std', 'min', 'max'), float('nan'))

    reducers = {'mean': np.mean, 'std': np.std, 'min': np.min, 'max': np.max}
    return {name: float(reduce(finite_values)) for name, reduce in reducers.items()}


def save_results(metrics: Dict[str, List[float]], output_path: str, config: Dict):
    """Write raw metrics, their summary statistics and the run config to JSON.

    Args:
        metrics: Metric name -> list of per-sample values.
        output_path: Destination file; overwritten if it exists.
        config: JSON-serialisable description of the run, stored verbatim.

    The file has three top-level keys: ``'config'``, ``'metrics'`` (the raw
    lists) and ``'statistics'`` (``compute_statistics`` of each list).
    """
    payload = {
        'config': config,
        'metrics': dict(metrics),
        'statistics': {
            name: compute_statistics(per_sample)
            for name, per_sample in metrics.items()
        },
    }

    with open(output_path, 'w') as out_file:
        json.dump(payload, out_file, indent=2)

    print(f"Results saved to {output_path}")


def print_summary(metrics: Dict[str, List[float]], title: str = "Evaluation Results"):
    """Print a banner followed by mean/std/min/max/count for every metric.

    Statistics come from ``compute_statistics``; ``Count`` is the raw list
    length of each metric.
    """
    print(f"\n{'='*60}")
    print(f"{title}")
    print(f"{'='*60}")

    for metric_name in metrics:
        samples = metrics[metric_name]
        summary = compute_statistics(samples)

        print(f"\n{metric_name.upper()}:")
        print(f"  Mean:   {summary['mean']:.4f}")
        print(f"  Std:    {summary['std']:.4f}")
        print(f"  Min:    {summary['min']:.4f}")
        print(f"  Max:    {summary['max']:.4f}")
        print(f"  Count:  {len(samples)}")


print("Evaluation functions defined")

In [None]:
# ====== Load Models ======
device = torch.device(DEVICE)
# Set gpu number as 1
# NOTE(review): CUDA_VISIBLE_DEVICES is normally only honoured if set before
# the process first touches CUDA. The config cell has already called
# torch.cuda.is_available(), so this assignment may have no effect. Consider
# moving it above `import torch` in the first cell. TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(f"Using device: {device}")

# Planner
config_planner = ConfigPlanner()
config_planner.target_res = LATENT_RES
config_planner.max_seq_len = 100

planner = LatentPlanner(config_planner).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True if the checkpoint allows it).
planner_state = torch.load(PLANNER_CKPT, map_location='cpu')

model_state = planner.state_dict()
checkpoint_state = planner_state

# A checkpoint trained with a different max_seq_len has a positional-encoding
# table of a different size; drop it and keep the freshly initialised one.
if 'pe.pe' in checkpoint_state and 'pe.pe' in model_state:
    if checkpoint_state['pe.pe'].shape != model_state['pe.pe'].shape:
        print(f"Warning: PositionalEncoding size mismatch. "
              f"Checkpoint: {checkpoint_state['pe.pe'].shape}, "
              f"Model: {model_state['pe.pe'].shape}")
        print("Skipping PE loading (will use model's initialized PE)")
        checkpoint_state = {k: v for k, v in checkpoint_state.items() if not k.startswith('pe.')}

# strict=False tolerates key mismatches, so report them instead of silently
# evaluating a partially initialised model. Skipped 'pe.*' entries will be
# listed as missing here; that is expected.
planner_load_result = planner.load_state_dict(checkpoint_state, strict=False)
if planner_load_result.missing_keys or planner_load_result.unexpected_keys:
    print(f"Warning: Planner checkpoint does not fully match the model. "
          f"Missing keys: {list(planner_load_result.missing_keys)}, "
          f"unexpected keys: {list(planner_load_result.unexpected_keys)}")
planner.eval()
print(f"Loaded Planner: {PLANNER_CKPT}")

refiner = AsymmetricUNet(
    in_channels=3,
    out_channels=3,
    model_channels=REF_MODEL_CHANNELS,
    channel_mult=REF_CHANNEL_MULT,
    num_res_blocks=REF_NUM_RES_BLOCKS,
).to(device)

# NOTE(review): same torch.load trust caveat as for the planner checkpoint.
refiner_state = torch.load(REFINER_CKPT, map_location='cpu')
refiner_load_result = refiner.load_state_dict(refiner_state, strict=False)
if refiner_load_result.missing_keys or refiner_load_result.unexpected_keys:
    print(f"Warning: Refiner checkpoint does not fully match the model. "
          f"Missing keys: {list(refiner_load_result.missing_keys)}, "
          f"unexpected keys: {list(refiner_load_result.unexpected_keys)}")
refiner.eval()
print(f"Loaded Refiner: {REFINER_CKPT}")

scheduler = NoiseScheduler(device=device)
print("Models loaded successfully")

In [None]:
def collate_fn(batch):
    """Zero-pad variable-length sequence samples into a single batch.

    Args:
        batch: List of dicts, each providing ``"seq"`` (tensor whose trailing
            three dimensions are unpacked as C, H, W), ``"attention_mask"``,
            ``"length"`` and ``"path"``.

    Returns:
        Dict with ``"seq"`` of shape [B, max_len, C, H, W], ``"attention_mask"``
        of shape [B, max_len], ``"lengths"`` (long tensor of shape [B]) and
        ``"paths"`` (list of the items' paths). Positions past an item's length
        stay zero. Output dtypes for ``seq`` and ``attention_mask`` follow the
        first item in the batch.
    """
    longest = max(entry["length"] for entry in batch)
    template = batch[0]
    channels, height, width = template["seq"].shape[1:]
    batch_size = len(batch)

    padded_seq = torch.zeros(
        batch_size, longest, channels, height, width, dtype=template["seq"].dtype
    )
    padded_mask = torch.zeros(batch_size, longest, dtype=template["attention_mask"].dtype)
    lengths = torch.zeros(batch_size, dtype=torch.long)

    for row, entry in enumerate(batch):
        n_frames = entry["length"]
        padded_seq[row, :n_frames] = entry["seq"]
        padded_mask[row, :n_frames] = entry["attention_mask"]
        lengths[row] = n_frames

    paths = [entry["path"] for entry in batch]

    return {"seq": padded_seq, "attention_mask": padded_mask, "lengths": lengths, "paths": paths}

# Low-resolution sequences: always loaded, one video per batch in dataset order.
lowres_dataset = LowResVideoDataset(
    root=LOWRES_ROOT, split=SPLIT, target_res=LATENT_RES, max_seq_len=100
)
lowres_dataloader = DataLoader(
    lowres_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn
)

print(f"Loaded low-res dataset: {len(lowres_dataset)} samples")

# High-resolution ground truth is only loaded for the modes listed below.
highres_dataset = None
highres_dataloader = None

needs_highres = EVAL_MODE in ('refiner', 'end2end', 'all')
if needs_highres:
    try:
        highres_dataset = LowResVideoDataset(
            root=HIGHRES_ROOT, split=SPLIT, target_res=HIGH_RES, max_seq_len=100
        )
        print(f"Loaded high-res dataset from {HIGHRES_ROOT}: {len(highres_dataset)} samples")
    except (FileNotFoundError, ValueError) as e:
        # Fall back to the low-res dataset object when the high-res one cannot be built.
        for message in (
            f"Warning: Could not load high-res dataset from {HIGHRES_ROOT}",
            f"  Error: {e}",
            f"  Will use upscaled low-res data as GT",
        ):
            print(message)
        highres_dataset = lowres_dataset

    highres_dataloader = DataLoader(
        highres_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn
    )

print("Datasets loaded successfully")

In [None]:
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _save_mode_results(metrics, mode, cond_frames):
    """Pass ``metrics`` to ``save_results`` for one evaluation mode.

    The target path is ``<OUTPUT_DIR>/<mode>_cond<cond_frames>.json`` and the
    accompanying config dict records ``cond_frames`` and ``mode``.
    """
    target = os.path.join(OUTPUT_DIR, f'{mode}_cond{cond_frames}.json')
    save_results(metrics, target, {'cond_frames': cond_frames, 'mode': mode})


all_results = {}

for cond_frames_val in COND_FRAMES_LIST:
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Evaluating with cond_frames={cond_frames_val}")
    print(f"{banner}")

    cond_tag = f"cond_frames={cond_frames_val}"
    cond_results = {}

    run_planner = EVAL_MODE in ('planner', 'all')
    run_refiner = EVAL_MODE in ('refiner', 'all')
    run_end2end = EVAL_MODE in ('end2end', 'all')

    if run_planner:
        print("\n--- Planner Evaluation ---")
        planner_metrics = evaluate_planner(
            planner, lowres_dataloader, device,
            cond_frames=cond_frames_val,
            max_samples=MAX_SAMPLES,
        )
        cond_results['planner'] = planner_metrics
        print_summary(planner_metrics, f"Planner ({cond_tag})")
        _save_mode_results(planner_metrics, 'planner', cond_frames_val)

    if run_refiner:
        print("\n--- Refiner Evaluation ---")
        if highres_dataset is None or highres_dataset.target_res != HIGH_RES:
            print("  Note: Using upscaled low-res as GT")

        refiner_metrics = evaluate_refiner(
            refiner, scheduler, lowres_dataloader, highres_dataloader,
            device,
            high_res=HIGH_RES,
            latent_res=LATENT_RES,
            k_step=K_STEP,
            t_start_frac=T_START_FRAC,
            batch_frames=REFINER_BATCH_FRAMES,
            max_samples=MAX_SAMPLES,
        )
        cond_results['refiner'] = refiner_metrics
        print_summary(refiner_metrics, f"Refiner ({cond_tag})")
        _save_mode_results(refiner_metrics, 'refiner', cond_frames_val)

    if run_end2end:
        print("\n--- End-to-End Evaluation (with stage-wise errors) ---")
        e2e_results = evaluate_end2end(
            planner, refiner, scheduler,
            lowres_dataloader, highres_dataloader,
            device,
            cond_frames=cond_frames_val,
            high_res=HIGH_RES,
            latent_res=LATENT_RES,
            k_step=K_STEP,
            t_start_frac=T_START_FRAC,
            batch_frames=REFINER_BATCH_FRAMES,
            max_samples=MAX_SAMPLES,
        )

        e2e_metrics = e2e_results['e2e']
        planner_stage_metrics = e2e_results['planner']
        refiner_stage_metrics = e2e_results['refiner']

        cond_results.update(
            end2end=e2e_metrics,
            planner_stage=planner_stage_metrics,
            refiner_stage=refiner_stage_metrics,
        )

        # (metrics, summary heading, mode used for the file name and config)
        e2e_outputs = (
            (e2e_metrics, "End-to-End", 'end2end'),
            (planner_stage_metrics, "Planner Stage Error", 'planner_stage'),
            (refiner_stage_metrics, "Refiner Stage Error", 'refiner_stage'),
        )
        # All summaries are printed first, then all files are saved.
        for _metrics, _heading, _ in e2e_outputs:
            print_summary(_metrics, f"{_heading} ({cond_tag})")
        for _metrics, _, _mode in e2e_outputs:
            _save_mode_results(_metrics, _mode, cond_frames_val)

    all_results[f'cond_frames_{cond_frames_val}'] = cond_results

# Combined dump of every mode and conditioning length evaluated above.
summary_path = os.path.join(OUTPUT_DIR, 'summary_all.json')
with open(summary_path, 'w') as f:
    json.dump(all_results, f, indent=2)
print(f"\nSummary saved to {summary_path}")

print("\nEvaluation completed!")

In [None]:
# End-to-end result files to analyse, one per conditioning length.
_RESULT_DIR = '/workspace/evaluation_results'
RESULT_PATHS = [f'{_RESULT_DIR}/end2end_cond{n}.json' for n in (1, 3, 5, 10)]

print(f"Will load {len(RESULT_PATHS)} result files")

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

# Load every result file that exists and has a usable 'config' section.
all_data = []
cond_frames_list = []

for result_path in RESULT_PATHS:
    if not os.path.exists(result_path):
        print(f"Warning: File not found: {result_path}")
        continue

    try:
        with open(result_path, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        print(f"Error loading {result_path}: {e}")
        continue

    # Validate before appending: the grouping loop below reads data['config']
    # unguarded, so a file without it must never reach all_data.
    config = data.get('config') if isinstance(data, dict) else None
    if not isinstance(config, dict):
        print(f"Error loading {result_path}: missing 'config' section")
        continue

    all_data.append(data)
    cond_frames = config.get('cond_frames', None)
    if cond_frames is not None:
        cond_frames_list.append(cond_frames)
    print(f"Loaded: {result_path} (cond_frames={cond_frames})")

cond_frames_list = sorted(set(cond_frames_list))
print(f"\nTotal condition frames found: {cond_frames_list}")
print(f"Total result files loaded: {len(all_data)}")

# metric name -> cond_frames -> list of per-sample values
metrics_by_cond = defaultdict(lambda: defaultdict(list))

for data in all_data:
    cond_frames = data['config'].get('cond_frames', None)
    if cond_frames is None:
        continue

    metrics = data.get('metrics') or {}
    for metric_name, values in metrics.items():
        # Only non-empty lists are collected; other entries are ignored.
        if isinstance(values, list) and len(values) > 0:
            metrics_by_cond[metric_name][cond_frames].extend(values)

print("\nMetrics collected:")
for metric_name in sorted(metrics_by_cond.keys()):
    cond_count = len(metrics_by_cond[metric_name])
    total_samples = sum(len(v) for v in metrics_by_cond[metric_name].values())
    print(f"  - {metric_name}: {cond_count} cond_frames, {total_samples} total samples")

In [None]:

# ====== Plot 1: Box plots of every metric, one box per conditioning length ======
metric_names = ['mse', 'mae', 'psnr', 'ssim', 'trajectory_error']
n_metrics = len(metric_names)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, metric_name in enumerate(metric_names):
    ax = axes[idx]

    if metric_name not in metrics_by_cond:
        ax.text(0.5, 0.5, f'{metric_name}\n(no data)',
                ha='center', va='center', transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    data_to_plot = []
    labels = []

    for cond_frames in sorted(metrics_by_cond[metric_name].keys()):
        values = metrics_by_cond[metric_name][cond_frames]
        # Filter out NaN and inf values
        values = [v for v in values if not (np.isnan(v) or np.isinf(v))]
        if len(values) > 0:
            data_to_plot.append(values)
            labels.append(f'cond={cond_frames}')

    if len(data_to_plot) > 0:
        bp = ax.boxplot(data_to_plot, patch_artist=True)
        # Label the boxes afterwards instead of via boxplot's `labels=` keyword,
        # which newer Matplotlib releases deprecate in favour of `tick_labels=`.
        ax.set_xticklabels(labels)

        # Color the boxes
        colors = plt.cm.Set3(np.linspace(0, 1, len(bp['boxes'])))
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        ax.set_title(f'{metric_name.upper()}', fontsize=12, fontweight='bold')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='x', rotation=45)
    else:
        ax.text(0.5, 0.5, f'{metric_name}\n(no valid data)',
                ha='center', va='center', transform=ax.transAxes)
        # Match the "no data" panel: a text-only panel carries no ticks.
        ax.set_xticks([])
        ax.set_yticks([])

# Remove every subplot that has no metric assigned to it.
for unused_ax in axes[len(metric_names):]:
    unused_ax.remove()

plt.tight_layout()

# Save to first result file's directory or current directory
save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'metrics_boxplots.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ====== Plot 2: Mean Metrics vs Condition Frames ======
# One panel per metric: mean over samples with a +/- one standard deviation bar.

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for ax, metric_name in zip(axes, metric_names):
    if metric_name not in metrics_by_cond:
        ax.text(0.5, 0.5, f'{metric_name}\n(no data)',
                ha='center', va='center', transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        continue

    per_cond = metrics_by_cond[metric_name]
    cond_frames_sorted = sorted(per_cond.keys())

    means, stds = [], []
    for cond_frames in cond_frames_sorted:
        # Only finite samples contribute; a cond with none becomes a NaN point.
        finite = [v for v in per_cond[cond_frames] if np.isfinite(v)]
        means.append(np.mean(finite) if finite else np.nan)
        stds.append(np.std(finite) if finite else np.nan)

    has_valid_point = any(not np.isnan(m) for m in means)
    if not has_valid_point:
        ax.text(0.5, 0.5, f'{metric_name}\n(no valid data)',
                ha='center', va='center', transform=ax.transAxes)
        continue

    ax.errorbar(cond_frames_sorted, means, yerr=stds,
                marker='o', markersize=8, capsize=5, capthick=2,
                linewidth=2, label='Mean ± Std')
    ax.set_title(f'{metric_name.upper()} vs Condition Frames',
                 fontsize=12, fontweight='bold')
    ax.set_xlabel('Condition Frames')
    ax.set_ylabel('Value')
    ax.grid(True, alpha=0.3)
    ax.legend()

# The sixth panel of the 2x3 grid is unused.
axes[-1].remove()

plt.tight_layout()

save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'metrics_vs_cond_frames.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ====== Plot 3: Statistics Summary Table ======
# Text table: one section per metric, one row per conditioning length.

print("\n" + "="*80)
print("SUMMARY STATISTICS (Computed from loaded data)")
print("="*80)

# Row body used whenever a metric/cond pair has no finite samples.
na_cells = ' '.join(f"{'N/A':<15}" for _ in range(4))

for metric_name in metric_names:
    print(f"\n{metric_name.upper()}:")
    print(f"{'Cond Frames':<15} {'Mean':<15} {'Std':<15} {'Min':<15} {'Max':<15} {'Samples':<10}")
    print("-" * 85)

    # .get() keeps the lookup from inserting empty entries into the defaultdicts.
    per_cond = metrics_by_cond.get(metric_name, {})

    for cond_frames in sorted(cond_frames_list):
        finite = [v for v in per_cond.get(cond_frames, []) if np.isfinite(v)]

        if not finite:
            print(f"{cond_frames:<15} {na_cells} {'0':<10}")
            continue

        stat_cells = ' '.join(
            f"{stat(finite):<15.6f}" for stat in (np.mean, np.std, np.min, np.max)
        )
        print(f"{cond_frames:<15} {stat_cells} {len(finite):<10}")

In [None]:
# ====== Plot 4: Heatmap of Mean Metrics ======
# Cells are annotated with the raw mean. Cell colour is the mean rescaled
# within its own metric row to [0, 1] where 1 is the best value of that row,
# so metrics with very different magnitudes stay comparable and
# lower-is-better metrics are not coloured as if higher were better.

lower_is_better_metrics = {'mse', 'mae', 'trajectory_error'}

# Prepare data for heatmap
heatmap_data = []
row_labels = []
row_metric_names = []  # metric behind each kept row (rows with no data are skipped)
sorted_cond_frames = sorted(cond_frames_list)

for metric_name in metric_names:
    per_cond = metrics_by_cond.get(metric_name, {})
    row = []
    for cond_frames in sorted_cond_frames:
        values = [v for v in per_cond.get(cond_frames, []) if np.isfinite(v)]
        row.append(np.mean(values) if len(values) > 0 else np.nan)

    if not all(np.isnan(row)):
        heatmap_data.append(row)
        row_labels.append(metric_name.upper())
        row_metric_names.append(metric_name)

if len(heatmap_data) > 0:
    heatmap_data = np.array(heatmap_data, dtype=float)
    col_labels = [f'cond={cf}' for cf in sorted_cond_frames]

    # Row-wise min-max scaling; every kept row has at least one finite value.
    heatmap_colors = np.full(heatmap_data.shape, np.nan)
    for i, metric_name in enumerate(row_metric_names):
        row_values = heatmap_data[i]
        row_min = np.nanmin(row_values)
        row_max = np.nanmax(row_values)
        if row_max > row_min:
            scaled = (row_values - row_min) / (row_max - row_min)
        else:
            # Constant row: neutral colour for every populated cell.
            scaled = np.where(np.isnan(row_values), np.nan, 0.5)
        if metric_name in lower_is_better_metrics:
            scaled = 1.0 - scaled
        heatmap_colors[i] = scaled

    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(heatmap_colors, cmap='RdYlGn', aspect='auto', vmin=0.0, vmax=1.0)

    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Add text annotations (raw means, not the rescaled colour values)
    for i in range(len(row_labels)):
        for j in range(len(col_labels)):
            if not np.isnan(heatmap_data[i, j]):
                ax.text(j, i, f'{heatmap_data[i, j]:.4f}',
                        ha="center", va="center", color="black", fontsize=9)

    ax.set_title('Mean Metrics Heatmap', fontsize=14, fontweight='bold', pad=20)
    colorbar = plt.colorbar(im, ax=ax)
    colorbar.set_label('Relative quality within metric (0 = worst, 1 = best)')
    plt.tight_layout()

    save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'metrics_heatmap.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved: {save_path}")
    plt.show()
else:
    print("No data available for heatmap")

In [None]:
# ====== Plot 5: Distribution Histograms ======
# Overlaid histograms of the key metrics, one series per conditioning length.

key_metrics = ['psnr', 'ssim', 'trajectory_error']
n_key = len(key_metrics)

fig, axes = plt.subplots(1, n_key, figsize=(15, 5))

for ax, metric_name in zip(axes, key_metrics):
    if metric_name not in metrics_by_cond:
        ax.text(0.5, 0.5, f'{metric_name}\n(no data)',
                ha='center', va='center', transform=ax.transAxes)
        continue

    per_cond = metrics_by_cond[metric_name]
    for cond_frames in sorted(per_cond.keys()):
        # NaN and inf samples are excluded from the histogram.
        finite = [v for v in per_cond[cond_frames] if np.isfinite(v)]
        if finite:
            ax.hist(finite, alpha=0.6, label=f'cond={cond_frames}', bins=15)

    ax.set_title(f'{metric_name.upper()} Distribution', fontsize=12, fontweight='bold')
    ax.set(xlabel='Value', ylabel='Frequency')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()

save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'metrics_distributions.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ====== Plot 6: Temporal Consistency Comparison ======
# Left: mean +/- std bars for predicted vs GT. Right: paired box plots of the
# same finite samples, grouped around one integer position per cond_frames.

if 'temporal_consistency_pred' in metrics_by_cond and 'temporal_consistency_gt' in metrics_by_cond:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    pred_by_cond = metrics_by_cond['temporal_consistency_pred']
    gt_by_cond = metrics_by_cond['temporal_consistency_gt']
    cond_frames_sorted = sorted(pred_by_cond.keys())

    # Finite samples per cond_frames; .get() avoids inserting keys for conds
    # that exist only on the predicted side.
    pred_clean = {cf: [v for v in pred_by_cond[cf] if np.isfinite(v)]
                  for cf in cond_frames_sorted}
    gt_clean = {cf: [v for v in gt_by_cond.get(cf, []) if np.isfinite(v)]
                for cf in cond_frames_sorted}

    def _mean_std_series(clean_by_cond):
        """Return (means, stds) aligned with cond_frames_sorted; NaN where empty."""
        means, stds = [], []
        for cf in cond_frames_sorted:
            samples = clean_by_cond[cf]
            means.append(np.mean(samples) if samples else np.nan)
            stds.append(np.std(samples) if samples else np.nan)
        return means, stds

    pred_means, pred_stds = _mean_std_series(pred_clean)
    gt_means, gt_stds = _mean_std_series(gt_clean)

    # Plot comparison
    x = np.arange(len(cond_frames_sorted))
    width = 0.35

    ax1.bar(x - width/2, pred_means, width, yerr=pred_stds, label='Predicted', alpha=0.8)
    ax1.bar(x + width/2, gt_means, width, yerr=gt_stds, label='GT', alpha=0.8)
    ax1.set_xlabel('Condition Frames')
    ax1.set_ylabel('Temporal Consistency')
    ax1.set_title('Temporal Consistency: Predicted vs GT', fontweight='bold')
    ax1.set_xticks(x)
    ax1.set_xticklabels([f'cond={cf}' for cf in cond_frames_sorted])
    ax1.legend()
    ax1.grid(True, alpha=0.3, axis='y')

    # Box plot comparison. Data and positions are built together so a cond
    # missing on one side can never make their lengths disagree.
    centers = list(range(1, len(cond_frames_sorted) + 1))
    legend_handles = []
    legend_names = []
    box_series = (
        (pred_clean, -0.2, 'lightblue', 'Predicted'),
        (gt_clean, 0.2, 'lightcoral', 'GT'),
    )
    for clean_by_cond, offset, face_color, series_name in box_series:
        slots = [(center + offset, clean_by_cond[cf])
                 for center, cf in zip(centers, cond_frames_sorted)
                 if clean_by_cond[cf]]
        if not slots:
            continue
        # manage_ticks=False: ticks are set once below, so the second call
        # does not add its own numeric tick labels next to the cond labels.
        bp = ax2.boxplot([samples for _, samples in slots],
                         positions=[position for position, _ in slots],
                         widths=0.35, patch_artist=True, manage_ticks=False)
        for patch in bp['boxes']:
            patch.set_facecolor(face_color)
        legend_handles.append(bp['boxes'][0])
        legend_names.append(series_name)

    ax2.set_xticks(centers)
    ax2.set_xticklabels([f'cond={cf}' for cf in cond_frames_sorted])
    if centers:
        ax2.set_xlim(0.5, len(centers) + 0.5)

    ax2.set_xlabel('Condition Frames')
    ax2.set_ylabel('Temporal Consistency')
    ax2.set_title('Temporal Consistency Distribution', fontweight='bold')
    if legend_handles:
        ax2.legend(legend_handles, legend_names)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'temporal_consistency_comparison.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved: {save_path}")
    plt.show()
else:
    print("Temporal consistency data not available")

In [None]:
# ====== Load Stage-wise Error Results ======
# Collect, per stage, the result files that exist next to the end-to-end results.
STAGE_RESULT_PATHS = {'planner': [], 'refiner': [], 'e2e': []}

import glob

# Stage key -> file-name stem used when the results were written.
_stage_file_stems = {'planner': 'planner_stage', 'refiner': 'refiner_stage', 'e2e': 'end2end'}

for cond_frames in cond_frames_list:
    stage_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
    for stage_key, stem in _stage_file_stems.items():
        candidate = os.path.join(stage_dir, f'{stem}_cond{cond_frames}.json')
        if os.path.exists(candidate):
            STAGE_RESULT_PATHS[stage_key].append(candidate)

print("Stage-wise result files:")
for display_name, stage_key in (("Planner", 'planner'), ("Refiner", 'refiner'), ("E2E", 'e2e')):
    print(f"  {display_name}: {len(STAGE_RESULT_PATHS[stage_key])} files")

In [None]:
# ====== Load and Organize Stage-wise Metrics ======
# stage -> metric name -> cond_frames -> list of per-sample values

_stage_keys = ('planner', 'refiner', 'e2e')
stage_metrics = {
    stage_key: defaultdict(lambda: defaultdict(list)) for stage_key in _stage_keys
}

# Planner stage files first, then refiner stage, then the end2end files.
for stage_key in _stage_keys:
    for file_path in STAGE_RESULT_PATHS[stage_key]:
        try:
            with open(file_path, 'r') as f:
                data = json.load(f)
                cond_frames = data['config'].get('cond_frames', None)
                if cond_frames is None:
                    continue
                for metric_name, values in data.get('metrics', {}).items():
                    # Non-list entries are skipped.
                    if isinstance(values, list):
                        stage_metrics[stage_key][metric_name][cond_frames].extend(values)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")

print("\nStage-wise metrics loaded:")
for stage in ['planner', 'refiner', 'e2e']:
    print(f"\n{stage.upper()}:")
    loaded = stage_metrics[stage]
    for metric_name in sorted(loaded.keys()):
        per_cond = loaded[metric_name]
        total_samples = sum(len(samples) for samples in per_cond.values())
        print(f"  - {metric_name}: {len(per_cond)} cond_frames, {total_samples} samples")

In [None]:
# ====== Plot: Stage-wise Error Comparison ======
# One panel per metric; each stage is drawn as a mean +/- std line over cond_frames.

key_metrics = ['mse', 'mae', 'psnr', 'ssim', 'trajectory_error']
n_metrics = len(key_metrics)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

stage_order = ['planner', 'refiner', 'e2e']
# Per-stage line appearance and legend text.
stage_style = {
    'planner': dict(color='blue', marker='o', label='Planner Stage'),
    'refiner': dict(color='green', marker='s', label='Refiner Stage'),
    'e2e': dict(color='red', marker='^', label='End-to-End'),
}

for idx, metric_name in enumerate(key_metrics):
    ax = axes[idx]

    # Gather mean/std series for every stage that recorded this metric.
    stages_data = {}
    for stage in stage_order:
        if metric_name not in stage_metrics[stage]:
            continue

        per_cond = stage_metrics[stage][metric_name]
        cond_frames_sorted = sorted(per_cond.keys())
        means, stds = [], []
        for cond_frames in cond_frames_sorted:
            finite = [v for v in per_cond[cond_frames] if np.isfinite(v)]
            means.append(np.mean(finite) if finite else np.nan)
            stds.append(np.std(finite) if finite else np.nan)

        # Keep the stage only if at least one cond has a finite mean.
        if any(not np.isnan(m) for m in means):
            stages_data[stage] = {
                'cond_frames': cond_frames_sorted,
                'means': means,
                'stds': stds,
            }

    if not stages_data:
        ax.text(0.5, 0.5, f'{metric_name}\n(no data)',
                ha='center', va='center', transform=ax.transAxes)
        continue

    # stages_data was filled in stage_order, so iteration keeps that order.
    for stage, series in stages_data.items():
        ax.errorbar(series['cond_frames'], series['means'], yerr=series['stds'],
                    markersize=8, capsize=5, capthick=2,
                    linewidth=2, alpha=0.8, **stage_style[stage])

    ax.set_title(f'{metric_name.upper()} - Stage Comparison', fontsize=12, fontweight='bold')
    ax.set_xlabel('Condition Frames')
    ax.set_ylabel('Value')
    ax.grid(True, alpha=0.3)
    ax.legend()

# The sixth panel of the 2x3 grid is unused.
axes[-1].remove()

plt.tight_layout()

save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'stage_wise_error_comparison.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ====== Plot: Error Contribution Analysis ======
# For every metric and cond_frames, express the planner-stage and refiner-stage
# means relative to the end-to-end mean and draw them as grouped bars.


def _stage_contribution(stage, metric_name, cond_frames, e2e_mean):
    """Return one stage's contribution (in percent) relative to the E2E mean.

    Args:
        stage: Key into ``stage_metrics`` ('planner' or 'refiner').
        metric_name: Metric whose samples are compared.
        cond_frames: Conditioning length to look up.
        e2e_mean: Mean of the finite end-to-end samples for the same pair.

    Returns:
        ``stage_mean / e2e_mean * 100`` for mse/mae/trajectory_error and
        ``(1 - stage_mean / e2e_mean) * 100`` for the other metrics. Returns 0
        when the stage has no finite samples for this metric/cond pair.
    """
    per_metric = stage_metrics[stage]
    if metric_name not in per_metric or cond_frames not in per_metric[metric_name]:
        return 0
    finite = [v for v in per_metric[metric_name][cond_frames] if np.isfinite(v)]
    if len(finite) == 0:
        return 0
    # The small epsilon keeps the ratio defined when the E2E mean is zero.
    ratio = np.mean(finite) / (e2e_mean + 1e-8)
    if metric_name in ['mse', 'mae', 'trajectory_error']:
        return ratio * 100
    return (1 - ratio) * 100


fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, metric_name in enumerate(key_metrics):
    ax = axes[idx]

    if metric_name not in stage_metrics['e2e']:
        ax.text(0.5, 0.5, f'{metric_name}\n(no E2E data)',
                ha='center', va='center', transform=ax.transAxes)
        continue

    cond_frames_sorted = sorted(stage_metrics['e2e'][metric_name].keys())

    # Conds that actually get a bar group. Conds without finite E2E samples are
    # skipped, so bar positions and tick labels must come from this list, not
    # from cond_frames_sorted, or the array lengths would disagree.
    plotted_cond_frames = []
    planner_contrib = []
    refiner_contrib = []
    e2e_values = []

    for cond_frames in cond_frames_sorted:
        e2e_vals = stage_metrics['e2e'][metric_name][cond_frames]
        e2e_vals = [v for v in e2e_vals if np.isfinite(v)]
        if len(e2e_vals) == 0:
            continue

        e2e_mean = np.mean(e2e_vals)
        plotted_cond_frames.append(cond_frames)
        e2e_values.append(e2e_mean)
        planner_contrib.append(_stage_contribution('planner', metric_name, cond_frames, e2e_mean))
        refiner_contrib.append(_stage_contribution('refiner', metric_name, cond_frames, e2e_mean))

    if len(e2e_values) > 0:
        x = np.arange(len(plotted_cond_frames))
        width = 0.25

        ax.bar(x - width, planner_contrib, width, label='Planner Contribution', color='blue', alpha=0.7)
        ax.bar(x, refiner_contrib, width, label='Refiner Contribution', color='green', alpha=0.7)
        ax.bar(x + width, [100 - p - r for p, r in zip(planner_contrib, refiner_contrib)],
               width, label='Other/Interaction', color='gray', alpha=0.7)

        ax.set_title(f'{metric_name.upper()} - Error Contribution by Stage', fontsize=12, fontweight='bold')
        ax.set_xlabel('Condition Frames')
        ax.set_ylabel('Contribution (%)')
        ax.set_xticks(x)
        ax.set_xticklabels([f'cond={cf}' for cf in plotted_cond_frames])
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        # Bars below 0 or above 100 are cut off by this fixed range.
        ax.set_ylim([0, 100])
    else:
        ax.text(0.5, 0.5, f'{metric_name}\n(no valid data)',
                ha='center', va='center', transform=ax.transAxes)

axes[-1].remove()

plt.tight_layout()

save_dir = os.path.dirname(RESULT_PATHS[0]) if RESULT_PATHS else OUTPUT_DIR
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'error_contribution_by_stage.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()

In [None]:
# ====== Table: Stage-wise Error Summary ======
# One section per metric, one row per cond_frames: mean ± std of each stage and
# the change between the planner-stage and refiner-stage means.


def _stage_mean_std(stage, metric_name, cond_frames):
    """Return (mean, std) of the finite samples for one stage/metric/cond.

    Returns None when the stage has no entry for the pair or when every
    recorded sample is NaN or infinite.
    """
    per_metric = stage_metrics[stage]
    if metric_name not in per_metric or cond_frames not in per_metric[metric_name]:
        return None
    finite = [v for v in per_metric[metric_name][cond_frames] if np.isfinite(v)]
    if len(finite) == 0:
        return None
    return np.mean(finite), np.std(finite)


print("\n" + "="*100)
print("STAGE-WISE ERROR SUMMARY")
print("="*100)

for metric_name in key_metrics:
    print(f"\n{metric_name.upper()}:")
    print(f"{'Cond':<8} {'Planner Stage':<20} {'Refiner Stage':<20} {'End-to-End':<20} {'Planner→Refiner':<20}")
    print(f"{'Frames':<8} {'Mean ± Std':<20} {'Mean ± Std':<20} {'Mean ± Std':<20} {'Error Increase':<20}")
    print("-" * 100)

    # Union of the cond_frames seen by any stage for this metric.
    all_cond_frames = set()
    for stage in ['planner', 'refiner', 'e2e']:
        if metric_name in stage_metrics[stage]:
            all_cond_frames.update(stage_metrics[stage][metric_name].keys())

    for cond_frames in sorted(all_cond_frames):
        planner_str = "N/A"
        refiner_str = "N/A"
        e2e_str = "N/A"
        increase_str = "N/A"

        planner_stats = _stage_mean_std('planner', metric_name, cond_frames)
        refiner_stats = _stage_mean_std('refiner', metric_name, cond_frames)
        e2e_stats = _stage_mean_std('e2e', metric_name, cond_frames)

        # Planner
        if planner_stats is not None:
            planner_mean, planner_std = planner_stats
            planner_str = f"{planner_mean:.4f} ± {planner_std:.4f}"

        # Refiner
        if refiner_stats is not None:
            refiner_mean, refiner_std = refiner_stats
            refiner_str = f"{refiner_mean:.4f} ± {refiner_std:.4f}"

        # E2E
        if e2e_stats is not None:
            e2e_mean, e2e_std = e2e_stats
            e2e_str = f"{e2e_mean:.4f} ± {e2e_std:.4f}"

        # Planner → Refiner
        if planner_stats is not None and refiner_stats is not None:
            if metric_name in ['mse', 'mae', 'trajectory_error']:
                increase = refiner_mean - planner_mean
                # The '+' format flag prints the real sign; a hard-coded "+"
                # prefix rendered negative differences as "+-0.1234".
                increase_str = f"{increase:+.4f}"
            else:
                decrease = planner_mean - refiner_mean
                increase_str = f"{decrease:.4f}"

        print(f"{cond_frames:<8} {planner_str:<20} {refiner_str:<20} {e2e_str:<20} {increase_str:<20}")