In [1]:
import gc
import os
import time
import random
from typing import Any, List, Tuple

import gymnasium as gym
import gymnasium_robotics

import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
import numpy as np
# import pybullet_envs  # PyBulletの環境をgymに登録する
import torch
import cv2

%env MUJOCO_GL=egl

env: MUJOCO_GL=egl


In [2]:
%cd /home/afs/tmp/in-hand_manipulation_wm_2024/
from models.dreamer.agent import Agent, Encoder, RSSM, ValueModel, ActionModel, preprocess_obs
from models.wrapper import GymWrapper, RepeatAction

/home/afs/tmp/in-hand_manipulation_wm_2024


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
class CustomManipulateBoxEnv(gym.Env):
    def __init__(self, env):
        super().__init__()
        self.env = env
        self._render_width = 64
        self._render_height = 64
        self.box_drop_threshold = 0.05  # 卵が落下したと判断する高さの閾値（メートル単位）

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        obs_hand_all = np.concatenate([observation['observation'], observation['desired_goal'], observation['achieved_goal']], axis=-1)
        img = self.render() # observation
        return img, obs_hand_all

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        obs_hand_all = np.concatenate([observation['observation'], observation['desired_goal'], observation['achieved_goal']], axis=-1)
        if observation["observation"][56] < self.box_drop_threshold:
            reward = -1000
            terminated = True
        image = self.env.render() # observation
        return image, obs_hand_all, reward, terminated, info

    def render(self):
        img = self.env.render()
        img = img[120:400, 200:480]
        img = cv2.resize(img, (self._render_height, self._render_width), interpolation=cv2.INTER_LINEAR)
        return img

    def close(self):
        self.env.close()

    def __getattr__(self, name):
        return getattr(self.env, name)
    
class RepeatAction(CustomManipulateBoxEnv):
    """
    同じ行動を指定された回数自動的に繰り返すラッパー. 観測は最後の行動に対応するものになる
    """
    def __init__(self, env, skip=4, max_steps=100):
        gym.Wrapper.__init__(self, env)
        self.max_steps = max_steps if max_steps else float("inf")  # イテレーションの制限
        self.steps = 0  # イテレーション回数のカウント
        self._skip = skip

    @property
    def observation_space():
        img = self.env.render()
        return gym.spaces.Box(img)
        
    def reset(self):
        img, obs = self.env.reset()
        self.steps = 0
        return img, obs

    def step(self, action):
        # if self.steps >= self.max_steps:  # 100kに達したら何も返さないようにする
        #     print("Reached max iterations.")
        #     return None

        total_reward = 0.0
        self.steps += 1
        for _ in range(self._skip):
            img, obs, reward, done, info = self.env.step(action)
            img = self.env.render()

            total_reward += reward
            if self.steps >= self.max_steps:  # 100kに達したら終端にする
                done = True

            if done:
                break

        return img, obs, total_reward, done, info

gym.register_envs(gymnasium_robotics)
def make_env(seed=None, max_steps=100) -> RepeatAction:
    """
    作成たラッパーをまとめて適用して環境を作成する関数．

    Returns
    -------
    env : RepeatAction
        ラッパーを適用した環境．
    """
    env = gym.make('HandManipulateBlockRotateZDense-v1', render_mode="rgb_array", max_episode_steps=max_steps)
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    env = CustomManipulateBoxEnv(env)
    env = RepeatAction(env, skip=2)  # DreamerではActionRepeatは2
    return env

# 結果を動画で観てみるための関数
def display_video(frames: List[np.ndarray]) -> None:
    """
    結果を動画にするための関数．

    frames : List[np.ndarray]
        観測画像をリスト化したもの．
    """
    plt.figure(figsize=(8, 8), dpi=50)
    patch = plt.imshow(frames[0])
    plt.axis("off")

    def animate(i):
        patch.set_data(frames[i])
        plt.title("Step %d" % (i))

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    display(HTML(anim.to_jshtml(default_mode="once")))
    plt.close()

def save_video(frames: List[np.ndarray], filename: str = "output.mp4", fps: int = 20) -> None:
    """
    結果を動画に保存するための関数．

    frames : List[np.ndarray]
        観測画像をリスト化したもの．
    filename : str
        保存する動画のファイル名．デフォルトは "output.mp4"．
    fps : int
        動画のフレームレート．デフォルトは20．
    """
    plt.figure(figsize=(8, 8), dpi=50)
    patch = plt.imshow(frames[0])
    plt.axis("off")

    def animate(i):
        patch.set_data(frames[i])
        plt.title("Step %d" % (i))

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=1000/fps)

    # 動画を保存
    anim.save(filename, writer="ffmpeg", fps=fps)
    plt.close()


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# モデルの宣言
state_dim = 32  # 確率的状態の次元
rnn_hidden_dim = 600  # 決定的状態（RNNの隠れ状態）の次元
obs_hand_dim = 75  # ロボットの観測の次元
action_spaces = 20  # 行動の次元
# 確率的状態の次元と決定的状態（RNNの隠れ状態）の次元は一致しなくて良い
encoder = Encoder().to(device)
rssm = RSSM(
    state_dim,
    action_spaces,
    rnn_hidden_dim,
    obs_hand_dim,
    device
)
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim, action_spaces).to(device)

model_log_dir = "/home/afs/tmp/in-hand_manipulation_wm_2024/logs/logs-Dreamer_Box_epocks-1000_episode-_CustomEnv_v1.0_20250129/episode_1000"

# Load the saved state dictionaries
encoder.load_state_dict(torch.load(os.path.join(model_log_dir, 'encoder.pth'), weights_only=True))
rssm.transition.load_state_dict(torch.load(os.path.join(model_log_dir, 'rssm.pth'), weights_only=True))
rssm.observation.load_state_dict(torch.load(os.path.join(model_log_dir, 'obs_model.pth'), weights_only=True))
rssm.reward.load_state_dict(torch.load(os.path.join(model_log_dir, 'reward_model.pth'), weights_only=True))
value_model.load_state_dict(torch.load(os.path.join(model_log_dir, 'value_model.pth'), weights_only=True))
action_model.load_state_dict(torch.load(os.path.join(model_log_dir, 'action_model.pth'), weights_only=True))

<All keys matched successfully>

In [5]:
max_episodes = 100
env = make_env(seed=1234, max_steps=max_episodes)

policy = Agent(encoder, rssm.transition, action_model)
obs, obs_hand = env.reset()
terminated = False
total_reward = 0
frames = [obs]
for i in range(max_episodes):
    action = policy(obs, obs_hand, training=False)
    obs, obs_hand, reward, terminated, _ = env.step(action)
    total_reward += reward
    frames.append(obs)

print("Total Reward:", total_reward)





Total Reward: -162.33374515162447


In [6]:
display_video(frames)

In [7]:
save_video(frames, "videos/anim_box_003.mp4", fps=10)

In [23]:
!pwd

/home/afs/in-hand_manipulation_wm_2024
