In [1]:
import gc
import os
import time
import random
from typing import Any, List, Tuple

import gymnasium as gym
import gymnasium_robotics

import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
import numpy as np
# import pybullet_envs  # PyBulletの環境をgymに登録する
import torch

%env MUJOCO_GL=egl

env: MUJOCO_GL=egl


In [8]:
%cd ../
from models.dreamer.agent import Agent, Encoder, RSSM, ValueModel, ActionModel, preprocess_obs
from models.wrapper import GymWrapper, RepeatAction

/home/flet_pro/in-hand_manipulation_wm_2024


In [9]:
def make_env(seed=None, max_steps=50) -> RepeatAction:
    """
    作成たラッパーをまとめて適用して環境を作成する関数．

    Returns
    -------
    env : RepeatAction
        ラッパーを適用した環境．
    """
    gym.register_envs(gymnasium_robotics)
    env = gym.make('HandManipulateBlockRotateZ_BooleanTouchSensorsDense-v1', render_mode="rgb_array", max_episode_steps=max_steps)
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    # Dreamerでは観測は64x64のRGB画像
    env = GymWrapper(
        env, render_width=64, render_height=64
    )
    env = RepeatAction(env, skip=2)  # DreamerではActionRepeatは2
    return env

In [15]:
# 結果を動画で観てみるための関数
def display_video(frames: List[np.ndarray], save=False) -> None:
    """
    結果を動画にするための関数．

    frames : List[np.ndarray]
        観測画像をリスト化したもの．
    """
    plt.figure(figsize=(8, 8), dpi=50)
    patch = plt.imshow(frames[0])
    plt.axis("off")

    def animate(i):
        patch.set_data(frames[i])
        plt.title("Step %d" % (i))

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    display(HTML(anim.to_jshtml(default_mode="once")))
    if save:
        anim.save('videos/anim.mp4', writer="ffmpeg")
    plt.close()

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state_dim = 32  # 確率的状態の次元
rnn_hidden_dim = 600  # 決定的状態（RNNの隠れ状態）の次元
action_spaces = 20
encoder = Encoder().to(device)
rssm = RSSM(
    state_dim,
    action_spaces,
    rnn_hidden_dim,
    device
)
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim, action_spaces).to(
    device
)

model_log_dir = "logs/logs-Dreamer_Box_epocks-2000_v2.0_20250126/episode_0600"

# Load the saved state dictionaries
encoder.load_state_dict(torch.load(os.path.join(model_log_dir, 'encoder.pth'), weights_only=True))
rssm.transition.load_state_dict(torch.load(os.path.join(model_log_dir, 'rssm.pth'), weights_only=True))
rssm.observation.load_state_dict(torch.load(os.path.join(model_log_dir, 'obs_model.pth'), weights_only=True))
rssm.reward.load_state_dict(torch.load(os.path.join(model_log_dir, 'reward_model.pth'), weights_only=True))
value_model.load_state_dict(torch.load(os.path.join(model_log_dir, 'value_model.pth'), weights_only=True))
action_model.load_state_dict(torch.load(os.path.join(model_log_dir, 'action_model.pth'), weights_only=True))

<All keys matched successfully>

In [23]:
max_episodes = 50
env = make_env(seed=1234, max_steps=max_episodes)

policy = Agent(encoder, rssm.transition, action_model)
obs, obs_hand = env.reset()
terminated = False
total_reward = 0
frames = [env.render()]
for i in range(max_episodes):
    action = policy(obs, obs_hand, training=False)
    obs, obs_hand, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    frames.append(env.render())

print("Total Reward:", total_reward)

Total Reward: -71.87982444916737


In [24]:
display_video(frames, save=True)