In [None]:
# --- Version 2: Static Plan Overlay (single-shot plan) ---
# record_dp_video_static.py
import os
import cv2
import torch
import numpy as np
from tqdm import trange
import tyro

from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper
from diffusion_policy.make_env import make_eval_envs
from train_checkp import Agent, Args

def draw_traj(frame, pts, color):
    """Draw a polyline through ``pts`` on ``frame`` and return ``frame``.

    Each pair of consecutive points is joined by a ``cv2.line`` segment of
    thickness 2 in ``color``. Point coordinates are converted with ``int``
    before drawing. With fewer than two points nothing is drawn and
    ``frame`` is returned as given.
    """
    for start, end in zip(pts, pts[1:]):
        seg_from = tuple(int(c) for c in start)
        seg_to = tuple(int(c) for c in end)
        cv2.line(frame, seg_from, seg_to, color, 2)
    return frame

def main(checkpoint:str, env_id:str="PickCube-v1", control_mode:str="pd_ee_delta_pos",
         output_dir:str="recorded_dp_videos", num_episodes:int=5, max_steps:int=100,
         obs_horizon:int=2, act_horizon:int=8, pred_horizon:int=16,
         px_per_m_x:float=50.94, px_per_m_y:float=33.21) -> None:
    """Record rollouts of a checkpointed agent with a static plan overlay.

    For every episode the environment is reset once, the agent is queried
    once for an action sequence (the "static plan"), and that plan is then
    executed action by action. Each recorded frame shows the planned path
    (drawn with colour ``(255, 0, 0)``) and the path actually travelled so
    far (``(0, 255, 0)``). One ``static_ep{N}.mp4`` per episode is written
    to ``output_dir``.

    Args:
        checkpoint: Path to a ``torch.save`` file holding an ``"agent"``
            state dict.
        env_id: Environment id passed to ``make_eval_envs``.
        control_mode: Controller name passed to the environment.
        output_dir: Directory for the videos; created if missing.
        num_episodes: Number of episodes to record.
        max_steps: Upper bound on executed steps per episode; also used as
            the environment's ``max_episode_steps``.
        obs_horizon: Observation history length given to the agent/env.
        act_horizon: Stored on ``Args`` for the agent.
        pred_horizon: Stored on ``Args`` for the agent.
        px_per_m_x: Horizontal scale used to map x coordinates to pixels.
        px_per_m_y: Vertical scale used to map y coordinates to pixels.
    """
    args = Args()
    args.env_id = env_id
    args.obs_mode = "rgb"
    args.obs_horizon = obs_horizon
    args.act_horizon = act_horizon
    args.pred_horizon = pred_horizon
    args.control_mode = control_mode
    args.sim_backend = "physx_cpu"
    args.max_episode_steps = max_steps
    args.cuda = True
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    env = make_eval_envs(
        env_id, 1, args.sim_backend,
        dict(control_mode=control_mode, reward_mode="sparse", obs_mode="rgb",
             render_mode="rgb_array",
             human_render_camera_configs={"shader_pack": "default"},
             max_episode_steps=max_steps),
        {"obs_horizon": obs_horizon}, [FlattenRGBDObservationWrapper])
    try:
        agent = Agent(env, args).to(device)
        # NOTE(review): torch.load unpickles the file -- only pass trusted checkpoints.
        ckpt = torch.load(checkpoint, map_location=device)
        agent.load_state_dict(ckpt["agent"])
        agent.eval()
        os.makedirs(output_dir, exist_ok=True)

        for ep in range(num_episodes):
            # Reset exactly once: the plan, the overlay anchor and the
            # rollout must all start from the same initial state.
            raw, _ = env.reset()
            raw_obs = {k: raw[k][0] for k in raw}  # drop the env-batch axis

            # Single-shot plan from the initial observation.
            # Assumes observations are numpy arrays -- TODO confirm.
            obs_batch = {
                'rgb': torch.from_numpy(raw_obs['rgb']).unsqueeze(0).to(device),
                'state': torch.from_numpy(raw_obs['state']).unsqueeze(0).to(device),
            }
            with torch.no_grad():
                act_seq = agent.get_action(obs_batch)
            plan = act_seq[0].cpu().numpy()

            # Project the start position to pixels, then accumulate the
            # planned x/y deltas to obtain the static overlay polyline.
            # Assumes state[-1, :3] is the end-effector position of the
            # latest observation -- TODO confirm against the env's state layout.
            frame0 = env.envs[0].render()
            h, w, _ = frame0.shape
            ee0 = raw_obs['state'][-1, :3]
            up = w/2 + ee0[0]*px_per_m_x
            vp = h/2 - ee0[1]*px_per_m_y  # image v axis points down
            stat_pts = [(up, vp)]
            # NOTE(review): action components are scaled like metric
            # displacements here; verify for normalised controllers.
            for dx, dy in plan[:, :2]:
                up += dx*px_per_m_x
                vp -= dy*px_per_m_y
                stat_pts.append((up, vp))

            frames = []
            traj_pts = []
            # Execute the plan in order; the rollout ends when the plan is
            # exhausted so the video shows exactly the single-shot plan.
            for step in trange(min(max_steps, len(plan)), desc=f"Ep{ep}"):
                ee = raw_obs['state'][-1, :3]
                frame = env.envs[0].render()
                h, w, _ = frame.shape
                traj_pts.append((w/2 + ee[0]*px_per_m_x, h/2 - ee[1]*px_per_m_y))
                over = draw_traj(frame.copy(), stat_pts, (255, 0, 0))
                over = draw_traj(over, traj_pts, (0, 255, 0))
                frames.append(cv2.cvtColor(over, cv2.COLOR_RGB2BGR))

                rawn, _, tb, tb2, _ = env.step(plan[step][None])
                raw_obs = {k: rawn[k][0] for k in rawn}
                if tb[0] or tb2[0]:
                    break

            out = os.path.join(output_dir, f"static_ep{ep}.mp4")
            wr = cv2.VideoWriter(out, cv2.VideoWriter_fourcc(*"mp4v"), 30, (w, h))
            try:
                if not wr.isOpened():
                    # Do not claim success when the file could not be created.
                    print("Could not open video writer for", out)
                    continue
                for f in frames:
                    wr.write(f)
            finally:
                wr.release()
            print("Saved", out)
    finally:
        # Always release the simulator, even if an episode fails.
        env.close()

if __name__ == "__main__":
    # Hand main to tyro, which runs it with arguments parsed from the command line.
    tyro.cli(main)
