In [41]:
import os
import numpy as np
import imageio
from pathlib import Path
import orbax.checkpoint as ocp
import sys


# Make the local metaworld-algorithms checkout importable from this notebook.
# NOTE(review): this is an absolute, machine-specific path — adjust it when
# running the notebook anywhere else.
REPO_ROOT = Path("/home/mlic/kawon/251222_metaworld/metaworld-algorithms")

# Keep exactly one sys.path entry for the checkout, placed first, so that
# re-running this cell does not pile up duplicate entries.
_repo_root_str = str(REPO_ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _repo_root_str]
sys.path.insert(0, _repo_root_str)

import metaworld_algorithms
# Print the resolved module file to confirm which checkout was actually imported.
print("ok:", metaworld_algorithms.__file__)
from metaworld_algorithms.checkpoint import get_agent_checkpoint_restore_args

ok: /home/mlic/kawon/251222_metaworld/metaworld-algorithms/metaworld_algorithms/__init__.py


In [None]:
# Directory that holds one sub-directory per saved training step.
CKPT_ROOT = Path(
    "/home/mlic/kawon/251222_metaworld/metaworld-algorithms/"
    "examples/multi_task/run_results9/mt10_custom_mtsac_1/checkpoints"
)

# Training step of the checkpoint you want to inspect.
# NOTE(review): the outputs recorded further down in this notebook mention
# ".../checkpoints/199990/agent", i.e. they were produced with a different STEP
# than the value set here. Re-run the notebook top to bottom after changing it.
STEP = 1599990
agent_dir = CKPT_ROOT / str(STEP) / "agent"
SEED = 1

# MT10 task list. The rollout cells index this list with the argmax of the
# one-hot block at the end of each observation — the ordering is assumed to
# match the environment's task ids (TODO confirm).
MT10_TASKS = [
    "reach-v2", "push-v2", "pick-place-v2", "door-open-v2", "drawer-open-v2", "drawer-close-v2",
    "button-press-topdown-v2", "peg-insert-side-v2", "window-open-v2", "window-close-v2"
]
# Index (into MT10_TASKS and into the vectorised env batch) of the task to visualise.
TARGET_TASK_IDX = 9

# Fail early, with a readable message, on the two settings that later cells rely on.
assert 0 <= TARGET_TASK_IDX < len(MT10_TASKS), (
    f"TARGET_TASK_IDX={TARGET_TASK_IDX} is outside 0..{len(MT10_TASKS) - 1}"
)
# Check the directory that is actually restored from, not only its parent.
assert agent_dir.exists(), f"agent checkpoint directory not found: {agent_dir}"

In [43]:
from metaworld_algorithms.envs import MetaworldConfig

# Environment settings for the MT10 benchmark, collected in one mapping and
# then unpacked into the config object. render_mode has been removed from this
# config; it is not one of the fields set here.
_env_settings = {
    "env_id": "MT10",
    "use_one_hot": True,
    "terminate_on_success": False,
    "max_episode_steps": 200,
    "reward_func_version": "v2",
    "num_goals": 50,
    "reward_normalization_method": None,
    "normalize_observations": False,
}

env_config = MetaworldConfig(**_env_settings)

In [51]:
import gymnasium as gym

# Create 10 environments at once (enough to cover every MT10 task).
# Most values are forwarded from env_config; render_mode is passed here directly.
_vec_env_id = f"Meta-World/{env_config.env_id}"
_vec_env_kwargs = dict(
    num_envs=10,  # create 10 environments
    seed=SEED,
    use_one_hot=env_config.use_one_hot,
    terminate_on_success=env_config.terminate_on_success,
    max_episode_steps=env_config.max_episode_steps,
    vector_strategy="async",
    reward_function_version=env_config.reward_func_version,
    num_goals=env_config.num_goals,
    reward_normalization_method=env_config.reward_normalization_method,
    normalize_observations=env_config.normalize_observations,
    render_mode="rgb_array",
)
envs = gym.make_vec(_vec_env_id, **_vec_env_kwargs)

# Initial reset; the printed shape shows the batched observation layout.
obs, _ = envs.reset()
print(f"Created {envs.num_envs} envs. Obs shape: {obs.shape}")

Created 10 envs. Obs shape: (10, 49)


In [45]:
from metaworld_algorithms.rl.algorithms.mtsac import MTSAC, MTSACConfig
from metaworld_algorithms.config.networks import ContinuousActionPolicyConfig, QValueFunctionConfig
from metaworld_algorithms.config.nn import VanillaNetworkConfig
from metaworld_algorithms.config.optim import OptimizerConfig

# Actor and critic each get their own vanilla network config whose optimizer
# is configured with max_grad_norm=1.0.
_actor_config = ContinuousActionPolicyConfig(
    network_config=VanillaNetworkConfig(optimizer=OptimizerConfig(max_grad_norm=1.0))
)
_critic_config = QValueFunctionConfig(
    network_config=VanillaNetworkConfig(optimizer=OptimizerConfig(max_grad_norm=1.0))
)

# Algorithm configuration used to build the agent skeleton that the
# checkpoint is later restored into.
# NOTE(review): presumably these values must match the training run that
# produced the checkpoint — confirm against that run's configuration.
algo_config = MTSACConfig(
    num_tasks=10,
    gamma=0.99,
    actor_config=_actor_config,
    critic_config=_critic_config,
    num_critics=2,
    use_inter_task_sampling=True,
    use_intra_task_sampling=True,
    use_success_based_il=True,
    success_ema_tau=0.01,
    il_weight_mode="sigmoid",
    il_weight_temp=0.1,
    il_weight_power=2.0,
    il_loss_type="mse",
    il_coef=1.0,
    # Playing it safe: set explicitly even if the defaults are the same.
    il_qfilter_top_p=0.2,
    il_qfilter_min_good=8,
)

agent = MTSAC.initialize(algo_config, env_config, seed=SEED)

In [46]:
from orbax.checkpoint import checkpoint_utils

# Restore the saved agent state into the freshly initialised agent.
# The restore arguments are built from the live agent; per the original note
# this is meant to make mesh/sharding follow the currently available device
# (e.g. CPU).
restore_args = checkpoint_utils.construct_restore_args(agent)
agent_ckptr = ocp.PyTreeCheckpointer()
agent = agent_ckptr.restore(
    str(agent_dir),
    item=agent,
    restore_args=restore_args,
)

print("agent restored from:", agent_dir)

agent restored from: /home/mlic/kawon/251222_metaworld/metaworld-algorithms/examples/multi_task/run_results9/mt10_custom_mtsac_1/checkpoints/199990/agent


In [47]:
# Rebuild the checkpoint paths from the CKPT_ROOT / STEP values set in the
# configuration cell. The root path literal is intentionally not repeated here:
# a second copy would silently override a changed CKPT_ROOT while STEP still
# came from the configuration cell.
step_dir = CKPT_ROOT / str(STEP)
agent_dir = step_dir / "agent"
if not agent_dir.exists():
    raise FileNotFoundError(f"agent checkpoint directory not found: {agent_dir}")

# Checkpoint restoration with fix
from orbax.checkpoint import checkpoint_utils
import orbax.checkpoint as ocp

agent_ckptr = ocp.PyTreeCheckpointer()

# Build restore_args from the live agent. Per the original note this is meant
# to set mesh/sharding to the currently available device (e.g. CPU) and to
# prevent the "sharding passed to deserialization should be specified" error.
restore_args = checkpoint_utils.construct_restore_args(agent)

agent = agent_ckptr.restore(str(agent_dir), item=agent, restore_args=restore_args)

print("agent restored from:", agent_dir)

agent restored from: /home/mlic/kawon/251222_metaworld/metaworld-algorithms/examples/multi_task/run_results9/mt10_custom_mtsac_1/checkpoints/199990/agent


In [48]:
# Rollout and Render
import imageio
import numpy as np
import jax
os.environ["EGL_LOG_LEVEL"] = "fatal"

print_prefix_done = "Rollout complete. Frames caught:"

obs, info = envs.reset()
frames = []
print("Starting rollout...")

for i in range(200):  # max_episode_steps
    # Evaluation-mode action for the whole batch, then advance every env.
    action = agent.eval_action(obs)
    obs, reward, terminated, truncated, info = envs.step(np.array(action))

    # Async vector env: call("render") collects one frame per environment;
    # only the frame of environment 0 is kept for the video.
    try:
        frames_list = envs.call("render")
        frames.append(frames_list[0])
    except Exception as e:
        print(f"Render failed at step {i}: {e}")
        break

    # Stop as soon as any environment in the batch reports the end of its episode.
    episode_over = terminated.any() or truncated.any()
    if episode_over:
        print(f"Episode terminated at step {i}")
        break

print(f"{print_prefix_done} {len(frames)}")

# Write the collected frames to disk, if there are any.
if not frames:
    print("No frames captured.")
else:
    save_path = "rollout.mp4"
    imageio.mimsave(save_path, frames, fps=30)
    print(f"Video saved to {save_path}")

Starting rollout...


KeyboardInterrupt: 

In [None]:
# Rollout and Render with Analytics (Success + Background)
import imageio
import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont


frames = []
# NOTE(review): the output recorded for this cell is an AlreadyPendingCallError
# ("Calling `reset_async` while waiting for a pending call to `call` to
# complete"). Presumably an earlier rollout was interrupted in the middle of
# envs.call("render"); re-creating `envs` before running this cell should
# clear that state — TODO confirm.
obs, info = envs.reset()
print("Starting rollout...")

# Font setup: fall back to PIL's built-in font when the TrueType file cannot be
# loaded. Only font-loading failures are caught so that interrupts still propagate.
try:
    font = ImageFont.truetype("DejaVuSans-Bold.ttf", 12)
except (OSError, ImportError):
    font = ImageFont.load_default()

# Width of the one-hot task block read from the end of each observation.
num_tasks = len(MT10_TASKS)

for i in range(200):
    # 1. Action & Q-value (the overlay only reports environment 0)
    action = agent.eval_action(obs)
    current_q = agent.q_min(obs, np.array(action))[0]

    # 2. Environment Step
    next_obs, reward, terminated, truncated, info = envs.step(np.array(action))

    # 3. Target Q & TD Error: one-step target built from the target critic
    #    parameters evaluated at the mode of the policy on the next observation.
    next_action_dist = agent.actor.apply_fn(agent.actor.params, next_obs)
    next_action = next_action_dist.mode()
    q_next_ens = agent.critic.apply_fn(agent.critic.target_params, next_obs, next_action)
    min_q_next = jnp.min(q_next_ens, axis=0)  # minimum over axis 0 of the critic output
    target_q = reward + (1 - terminated) * agent.gamma * min_q_next.flatten()
    td_error = abs(float(target_q[0]) - current_q)

    # 4. Task & Success Info (task id decoded from the one-hot tail of the pre-step obs)
    task_onehot = obs[0, -num_tasks:]
    task_idx = int(np.argmax(task_onehot))
    task_name = MT10_TASKS[task_idx] if task_idx < len(MT10_TASKS) else f"Task {task_idx}"

    # info["success"] is indexed per environment; a missing key counts as failure.
    is_success = bool(info["success"][0]) if "success" in info else False
    success_str = "SUCCESS" if is_success else "FAIL"

    # 5. Render & Overlay
    try:
        frames_list = envs.call("render")
        frame_array = frames_list[0]

        img = Image.fromarray(frame_array)
        draw = ImageDraw.Draw(img, "RGBA")  # draw in RGBA mode (supports transparency)

        # Overlay text
        text = (
            f"Task: {task_name}\n"
            f"Step: {i}\n"
            f"Q-value: {current_q:.2f}\n"
            f"TD-Error: {td_error:.4f}\n"
            f"Status: {success_str}"
        )

        # Measure the text and draw a padded, semi-transparent black box behind it.
        bbox = draw.textbbox((10, 10), text, font=font)
        draw.rectangle(
            (bbox[0]-5, bbox[1]-5, bbox[2]+5, bbox[3]+5),
            fill=(0, 0, 0, 160)
        )

        # Text colour: green on success, white otherwise.
        text_color = (100, 255, 100) if is_success else (255, 255, 255)
        draw.text((10, 10), text, font=font, fill=text_color)

        frames.append(np.array(img))
    except Exception as e:
        print(f"Render failed at step {i}: {e}")
        break

    # Stop as soon as any environment in the batch ends its episode.
    if terminated.any() or truncated.any():
        print(f"Episode terminated at step {i}")
        break

    obs = next_obs

print(f"Rollout complete. Frames caught: {len(frames)}")

if frames:
    save_path = "rollout_overlay.gif"
    imageio.mimsave(save_path, frames, fps=30, loop=0)
    print(f"Video saved to {save_path}")
else:
    print("No frames captured.")

AlreadyPendingCallError: Calling `reset_async` while waiting for a pending call to `call` to complete

In [57]:
# Rollout and Render (Multi-env support)
import imageio
import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image, ImageDraw, ImageFont
import os
# Configure MuJoCo to use GPU/EGL in headless mode.
# This must run *before* 'import gymnasium' or 'import mujoco'.
# NOTE(review): in this notebook gymnasium is imported and `envs` is created in
# an earlier cell, so setting these variables here is likely too late to affect
# the existing environments — TODO move them into the first setup cell.
os.environ["MUJOCO_GL"] = "egl"
os.environ["EGL_LOG_LEVEL"] = "fatal"
# The task to visualise is chosen via TARGET_TASK_IDX (0-9).

print(f"Targeting Task {TARGET_TASK_IDX}: {MT10_TASKS[TARGET_TASK_IDX]}")

frames = []
episode_return = 0.0  # cumulative reward of the target environment

obs, info = envs.reset()

# Overlay font; fall back to PIL's built-in font when the TrueType file cannot
# be loaded. Only font-loading failures are caught so interrupts still propagate.
try:
    font = ImageFont.truetype("DejaVuSans-Bold.ttf", 12)
except (OSError, ImportError):
    font = ImageFont.load_default()

# Width of the one-hot task block at the end of each observation, and the
# length of the fallback success list.
num_tasks = len(MT10_TASKS)

for i in range(200):
    # 1. Action (all envs)
    action = agent.eval_action(obs)

    # Q-value of the target env only (for the overlay); a leading axis of
    # length 1 is added so the agent receives a batch of one.
    target_obs = obs[TARGET_TASK_IDX][None, ...]
    target_action = np.array(action)[TARGET_TASK_IDX][None, ...]
    current_q = agent.q_min(target_obs, target_action)[0]

    # 2. Environment Step (all envs)
    next_obs, reward, terminated, truncated, info = envs.step(np.array(action))

    # Reward / termination flag of the target env for THIS step. The return is
    # accumulated only here, after the step, so that it never reads a `reward`
    # left over from a previous cell (or an undefined one on a fresh kernel)
    # and the displayed Return includes the current step's reward.
    r_val = float(reward[TARGET_TASK_IDX])
    done_val = float(terminated[TARGET_TASK_IDX])
    episode_return += r_val

    # 3. Target Q & TD Error (target env only)
    target_next_obs = next_obs[TARGET_TASK_IDX][None, ...]

    next_action_dist = agent.actor.apply_fn(agent.actor.params, target_next_obs)
    next_act_mode = next_action_dist.mode()
    q_next_ens = agent.critic.apply_fn(agent.critic.target_params, target_next_obs, next_act_mode)
    min_q_next = jnp.min(q_next_ens, axis=0)  # minimum over axis 0 of the critic output

    target_q_val = r_val + (1 - done_val) * agent.gamma * float(min_q_next.flatten()[0])
    td_error = abs(target_q_val - current_q)

    # 4. Info: decode which task the target env is actually running from the
    #    one-hot tail of its (pre-step) observation.
    task_onehot = obs[TARGET_TASK_IDX, -num_tasks:]
    task_idx_found = int(np.argmax(task_onehot))
    task_name = MT10_TASKS[task_idx_found]

    success_flags = info["success"] if "success" in info else [False] * num_tasks
    is_success = bool(success_flags[TARGET_TASK_IDX])
    success_str = "SUCCESS" if is_success else "FAIL"

    # 5. Render (target env only)
    try:
        # call("render") returns one frame per environment; keep the target's.
        frames_list = envs.call("render")
        frame_array = frames_list[TARGET_TASK_IDX]

        img = Image.fromarray(frame_array)
        draw = ImageDraw.Draw(img, "RGBA")

        text = (
            f"Target: {MT10_TASKS[TARGET_TASK_IDX]}\n"
            f"Actual: {task_name} (idx {task_idx_found})\n"
            f"Reward: {r_val:.4f}\n"      # reward of this step
            f"Return: {episode_return:.4f}\n" # cumulative reward so far
            f"Step: {i}\n"
            f"Q-value: {current_q:.2f}\n"
            f"TD-Error: {td_error:.4f}\n"
            f"Status: {success_str}"
        )
        # Padded, semi-transparent black box behind the text.
        bbox = draw.textbbox((10, 10), text, font=font)
        draw.rectangle((bbox[0]-5, bbox[1]-5, bbox[2]+5, bbox[3]+5), fill=(0, 0, 0, 160))
        # Green text on success, white otherwise.
        text_color = (100, 255, 100) if is_success else (255, 255, 255)
        draw.text((10, 10), text, font=font, fill=text_color)

        frames.append(np.array(img))
    except Exception as e:
        print(f"Render failed at step {i}: {e}")
        break

    # Stop once the target environment finishes its episode (optional).
    if terminated[TARGET_TASK_IDX] or truncated[TARGET_TASK_IDX]:
        print(f"Target episode terminated at step {i}")
        break

    obs = next_obs

print(f"Rollout complete. Frames caught: {len(frames)}")

if frames:
    save_path = f"rollout_task{TARGET_TASK_IDX}-{STEP}.gif"
    imageio.mimsave(save_path, frames, fps=30, loop=0)
    print(f"Video saved to {save_path}")
else:
    print("No frames captured.")

Targeting Task 9: window-close-v2
Target episode terminated at step 199
Rollout complete. Frames caught: 200
Video saved to rollout_task9-199990.gif
