# 環境構築

In [None]:
# Fetch the vlmrm sources; later cells `cd` into this checkout and install it in editable mode.
!git clone https://github.com/AlignmentResearch/vlmrm.git

In [None]:
!tar -zxvf mujoco210-linux-x86_64.tar.gz
!mkdir ~/.mujoco
!cp -r mujoco210 ~/.mujoco/mujoco210
!cp -r ~/.mujoco/mujoco210/bin/* /usr/lib/

In [None]:
# Work from the repository root so the editable install below finds the project.
%cd vlmrm

In [None]:
# Pin the build tooling before installing the project.
# NOTE(review): presumably needed for an older dependency to build — confirm.
!pip install setuptools==65.5.0 "wheel<0.40.0"

# Editable install of vlmrm together with its "dev" extras.
!pip install -e ".[dev]"

# Swap opencv-contrib-python 4.8.0.76 for 4.8.0.74 (`echo y` answers the uninstall prompt).
!echo y | pip uninstall opencv-contrib-python=='4.8.0.76'
!pip install opencv-contrib-python=='4.8.0.74'
# System packages for off-screen rendering (OSMesa / GL / GLFW / Xvfb) plus patchelf.
!apt update && apt install libosmesa6-dev libgl1-mesa-glx libglfw3-dev patchelf xvfb freeglut3-dev libgles2-mesa-dev -y
# mujoco-py 2.1.x bindings and the virtual display helper used further down.
!pip3 install -U 'mujoco-py<2.2,>=2.1' pyvirtualdisplay

In [None]:
# Move into src/; the environment class below loads its MuJoCo XML via a path relative to here.
%cd 'src'

In [None]:
import os

# Make the MuJoCo shared libraries discoverable without discarding whatever is
# already on LD_LIBRARY_PATH (e.g. GPU driver directories). The path is derived
# from the home directory so it matches the `~/.mujoco` install location used
# by the setup cell above.
_mujoco_bin = os.path.expanduser("~/.mujoco/mujoco210/bin")
_ld_paths = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
if _mujoco_bin not in _ld_paths:  # keep the cell idempotent across re-runs
    _ld_paths.append(_mujoco_bin)
os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(_ld_paths)

# Select the software (OSMesa) OpenGL backend for both PyOpenGL and MuJoCo.
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
os.environ["MUJOCO_GL"] = "osmesa"

In [None]:
import gc
import os
import pathlib
import random
import time
from typing import Any, Dict, List, Optional, Tuple, overload
from numpy.typing import NDArray

import numpy as np
import matplotlib.pyplot as plt
import open_clip

import mujoco_py
import gymnasium
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.envs.mujoco.humanoid_v4 import HumanoidEnv as GymHumanoidEnv
from gymnasium.spaces import Box

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter


from vlmrm.contrib.open_clip.transform import image_transform
from vlmrm.trainer.config import CLIPRewardConfig, Config

%load_ext tensorboard

In [None]:
# Start an invisible 400x300 virtual display.
# NOTE(review): presumably required so rendering works on a headless machine — confirm.
from pyvirtualdisplay import Display
pydisplay = Display(visible=0, size=(400, 300))
pydisplay.start()

In [None]:
# Sanity check: build the stock Humanoid-v4 environment, render one frame, and show it.
env = gymnasium.make("Humanoid-v4", render_mode="rgb_array")
env.reset()
image = env.render()
plt.imshow(image) # (480, 480, 3)
plt.show()
env.close()

# Reward Model

In [None]:
# Experiment configuration, unpacked into vlmrm's `Config` in the next cell.
# Only the "reward" section is read later in this notebook (to build the CLIP
# reward model); the "rl" and "logging" sections are not referenced elsewhere here.
vlm = {"env_name": "Humanoid-v4",
       "base_path": "data/runs/training",
       "seed": 42,
       "description": "Humanoid kneeling",
       "tags": ["kneeling", "humanoid", "clip", "model-scaling"],
       "reward": {
                  "name": "clip",
                  "pretrained_model": "ViT-bigG-14/laion2b_s39b_b160k",
                  "batch_size": 1,
                  "alpha": 0.0,
                  "target_prompts": ["a humanoid robot kneeling"],
                  "baseline_prompts": ["a humanoid robot"],
                  "cache_dir": "/data/cache",
                  "camera_config": {
                      "trackbodyid": 1,
                      "distance": 3.5,
                      "lookat": [0.0, 0.0, 1.0],
                      "elevation": -10.0,
                      "azimuth": 180.0,
                  }
       },
       "rl": {"policy_name": "MlpPolicy",
            "n_steps": 10000000,
            "n_envs_per_worker": 8,
            "episode_length": 100,
            "learning_starts": 50000,
            "train_freq": 100,
            "batch_size": 512,
            "gradient_steps": 100,
            "tau": 0.005,
            "gamma": 0.95,
            "learning_rate": 6e-4
        },
        "logging": {
                  "checkpoint_freq": 128000,
                  "video_freq": 128000
                }
       }

In [None]:
# Build the vlmrm Config object from the dictionary above.
config = Config(**vlm)

In [None]:
# src/vlmrm/envs/mujoco/clip_rewarded_humanoid.py

# Default camera placement, forwarded to MujocoEnv as `default_camera_config`.
DEFAULT_CAMERA_CONFIG = {
    "trackbodyid": 1,
    "distance": 3.5,
    "lookat": np.array((0.0, 0.0, 1.0)),
    "elevation": -10.0,
    "azimuth": 180.0,
}


class CLIPRewardedHumanoidEnv(GymHumanoidEnv):
    """Humanoid environment that returns rendered 64x64 frames as observations.

    Episodes never terminate on "unhealthy" states; they end after
    `episode_length` calls to `step`. The reward returned by `step` is the one
    computed by the parent Humanoid environment; a CLIP-based reward is
    substituted by a wrapper elsewhere in this notebook.

    Note: `step` returns a 4-tuple ``(obs, reward, done, info)`` and `reset`
    returns only the observation, which is what the rest of this notebook
    expects (this differs from the 5-tuple / ``(obs, info)`` gymnasium API).
    """

    def __init__(
        self,
        episode_length: int = 100,
        render_mode: str = "rgb_array",
        forward_reward_weight: float = 1.25,
        ctrl_cost_weight: float = 0.1,
        healthy_reward: float = 5.0,
        healthy_z_range: Tuple[float, float] = (1.0, 2.0),
        reset_noise_scale: float = 1e-2,
        exclude_current_positions_from_observation: bool = True,
        camera_config: Optional[Dict[str, Any]] = DEFAULT_CAMERA_CONFIG,
        textured: bool = True,
        **kwargs,
    ):
        """
        Args:
            episode_length: Number of `step` calls after which the episode ends.
            render_mode: Render mode forwarded to MujocoEnv.
            forward_reward_weight, ctrl_cost_weight, healthy_reward,
            healthy_z_range, reset_noise_scale: Stored on the private attributes
                that the parent Humanoid environment reads.
            exclude_current_positions_from_observation: Selects which
                observation-space shape is declared (see below).
            camera_config: Camera settings forwarded to MujocoEnv.
            textured: Load "humanoid_textured.xml" instead of "humanoid.xml".
            **kwargs: Forwarded to MujocoEnv.
        """
        terminate_when_unhealthy = False
        # Record the constructor arguments by keyword: EzPickle replays them
        # into __init__ on unpickling, and this signature's positional order
        # differs from the parent's, so positional recording would mis-assign
        # them (and collide with the `render_mode` keyword).
        utils.EzPickle.__init__(
            self,
            episode_length=episode_length,
            render_mode=render_mode,
            forward_reward_weight=forward_reward_weight,
            ctrl_cost_weight=ctrl_cost_weight,
            healthy_reward=healthy_reward,
            healthy_z_range=healthy_z_range,
            reset_noise_scale=reset_noise_scale,
            exclude_current_positions_from_observation=exclude_current_positions_from_observation,
            camera_config=camera_config,
            textured=textured,
            **kwargs,
        )

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range

        self._reset_noise_scale = reset_noise_scale

        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

        # NOTE(review): step()/reset() always return the rendered frame, so the
        # (378,) branch does not describe what callers actually receive — confirm
        # whether it is still needed.
        if exclude_current_positions_from_observation:
            observation_space = Box(
                low=-np.inf, high=np.inf, shape=(64, 64, 3), dtype=np.float64
            )
        else:
            observation_space = Box(
                low=-np.inf, high=np.inf, shape=(378,), dtype=np.float64
            )

        env_file_name = "humanoid_textured.xml" if textured else "humanoid.xml"
        # Hand MujocoEnv an absolute path: a bare relative path is looked up
        # inside gymnasium's bundled assets directory, not the current working
        # directory where the vlmrm checkout lives.
        model_path = str(pathlib.Path("vlmrm/envs/mujoco", env_file_name).resolve())
        MujocoEnv.__init__(
            self,
            model_path,
            5,
            observation_space=observation_space,
            width=64,
            height=64,
            default_camera_config=camera_config,
            render_mode=render_mode,
            **kwargs,
        )
        self.episode_length = episode_length
        self.num_steps = 0
        if camera_config:
            self.camera_id = -1

    def step(self, action) -> Tuple[NDArray, float, bool, Dict]:
        """Advance the simulation and return ``(frame, reward, done, info)``.

        The parent's state-vector observation and its terminated/truncated
        flags are discarded; `done` becomes True once `episode_length` steps
        have been taken since the last reset.
        """
        _, reward, _, _, info = super().step(action)
        obs = self.render()
        self.num_steps += 1
        terminated = self.num_steps >= self.episode_length
        return obs, reward, terminated, info

    def reset(self, *, seed: Optional[int] = 42, options: Optional[Dict] = None):
        """Reset the simulation and return the first rendered frame.

        The parent reset is what actually restores the MuJoCo state; without it
        a new episode would simply continue from wherever the previous one
        ended. Because `seed` defaults to 42, a bare ``reset()`` re-seeds the
        environment's RNG every time; pass ``seed=None`` for varied starts.
        """
        super().reset(seed=seed, options=options)
        self.num_steps = 0
        return self.render()

In [None]:
# src/vlmrm/reward_model.py

class CLIPEmbed(nn.Module):
    """Wraps a CLIP model so that a batch of images maps to image embeddings."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model
        # `image_size` may be a single int or a sequence; use its first entry then.
        size = clip_model.visual.image_size
        image_size = size if isinstance(size, int) else size[0]
        self.transform = image_transform(image_size)

    @torch.inference_mode()
    def forward(self, x):
        """Preprocess `x` and return `encode_image(..., normalize=True)`.

        If the second axis does not have length 3 the batch is treated as
        channels-last and permuted to channels-first.
        """
        channels_first = x.shape[1] == 3
        if not channels_first:
            x = x.permute(0, 3, 1, 2)

        use_cuda = torch.cuda.is_available()
        with torch.no_grad(), torch.autocast("cuda", enabled=use_cuda):
            preprocessed = self.transform(x)
            x = self.clip_model.encode_image(preprocessed, normalize=True)
        return x


class CLIPReward(nn.Module):
    """Scores image embeddings by their distance to a target text embedding."""

    def __init__(
        self,
        *,
        model: CLIPEmbed,
        alpha: float,
        target_prompts: torch.Tensor,
        baseline_prompts: torch.Tensor,
    ) -> None:
        """
        Args:
            model (CLIPEmbed): Embedding module wrapping the CLIP model.
            alpha (float): Coefficient of projection; blends the projection
                onto the (target - baseline) direction with the identity.
            target_prompts (torch.Tensor): Tokenized prompts describing
                the target state.
            baseline_prompts (torch.Tensor): Tokenized prompts describing
                the baseline state.
        """
        super().__init__()
        self.embed_module = model

        # Average the embeddings of each prompt set into one row vector.
        target = self.embed_prompts(target_prompts).mean(dim=0, keepdim=True)
        baseline = self.embed_prompts(baseline_prompts).mean(dim=0, keepdim=True)

        # Buffers follow the module across devices with `.to(...)`.
        self.register_buffer("target", target)
        self.register_buffer("baseline", baseline)
        self.register_buffer("direction", target - baseline)

        self.alpha = alpha
        self.register_buffer("projection", self.compute_projection(alpha))

    def compute_projection(self, alpha: float) -> torch.Tensor:
        """Return ``alpha * P + (1 - alpha) * I`` with P projecting onto `direction`."""
        direction = self.direction
        onto_direction = direction.T @ direction / torch.norm(direction) ** 2
        identity = torch.eye(onto_direction.shape[0], device=onto_direction.device)
        return alpha * onto_direction + (1 - alpha) * identity

    def update_alpha(self, alpha: float) -> None:
        """Store a new `alpha` and recompute the projection matrix."""
        self.alpha = alpha
        self.projection = self.compute_projection(alpha)

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``1 - ||(x_normalized - target) @ projection||^2 / 2`` per row."""
        unit = x / torch.norm(x, dim=-1, keepdim=True)
        projected_gap = (unit - self.target) @ self.projection
        squared_distance = torch.norm(projected_gap, dim=-1) ** 2
        return 1 - squared_distance / 2

    @staticmethod
    def tokenize_prompts(x: List[str]) -> torch.Tensor:
        """Tokenize a list of prompts."""
        return open_clip.tokenize(x)

    def embed_prompts(self, x) -> torch.Tensor:
        """Embed tokenized prompts and scale each embedding to unit length."""
        with torch.no_grad():
            embedded = self.embed_module.clip_model.encode_text(x).float()
        return embedded / embedded.norm(dim=-1, keepdim=True)

    def embed_images(self, x):
        """Embed a batch of images with the wrapped embedding module."""
        return self.embed_module.forward(x)


def load_reward_model(model_name, target_prompts, baseline_prompts, alpha, cache_dir: Optional[str] = None):
    """Build an eval-mode CLIPReward from a "<architecture>/<pretrained tag>" name.

    The prompts are given as lists of strings and tokenized here.
    """
    architecture, pretrained = model_name.split("/")
    clip_model = open_clip.create_model(
        model_name=architecture, pretrained=pretrained, cache_dir=cache_dir
    )
    tokenized_target = CLIPReward.tokenize_prompts(target_prompts)
    tokenized_baseline = CLIPReward.tokenize_prompts(baseline_prompts)
    reward = CLIPReward(
        model=CLIPEmbed(clip_model),
        alpha=alpha,
        target_prompts=tokenized_target,
        baseline_prompts=tokenized_baseline,
    )
    return reward.eval()


def load_reward_model_from_config(config: CLIPRewardConfig) -> CLIPReward:
    """Build the reward model from the fields of a CLIPRewardConfig."""
    reward_kwargs = dict(
        model_name=config.pretrained_model,
        target_prompts=config.target_prompts,
        baseline_prompts=config.baseline_prompts,
        alpha=config.alpha,
        cache_dir=config.cache_dir,
    )
    return load_reward_model(**reward_kwargs)


def compute_reward(model: CLIPReward, frame: torch.Tensor) -> NDArray:
    """Embed `frame` and score the embedding with the reward model.

    Args:
        model: Reward model; its `embed_module` produces the embedding and the
            model itself maps the embedding to a reward.
        frame: Batch of images accepted by `model.embed_module`.

    Returns:
        The rewards as a NumPy array on the CPU.
    """
    reward_model = model.eval()
    with torch.no_grad():
        embedding = reward_model.embed_module(frame)
        reward = reward_model(embedding)

    return reward.to('cpu').detach().numpy()

In [None]:
# src/vlmrm/contrib/sb3/clip_rewarded_sac.py

# Build the CLIP reward model once from the "reward" section of the config and
# move it to the GPU when one is available. Later cells use both module-level names.
device = 'cuda' if torch.cuda.is_available() else "cpu"
reward_model = load_reward_model_from_config(config.reward).to(device)

def clip_reward(obs: NDArray, model: CLIPReward) -> NDArray:
    """Score a single channels-last frame with the CLIP reward model.

    The frame is copied, given a batch axis, moved to channels-first layout and
    bilinearly upsampled by 3.5x before being passed to `compute_reward`.
    """
    batch = torch.from_numpy(obs.copy()).unsqueeze(dim=0).permute(0, 3, 1, 2)
    enlarge = nn.Upsample(scale_factor=3.5, mode='bilinear')
    enlarged = enlarge(batch)
    return compute_reward(model=model, frame=enlarged.to(device))

In [None]:
class RepeatAction(gymnasium.Wrapper):
    """Repeats each action `skip` times and replaces the reward with the CLIP reward.

    The wrapped environment must return ``(obs, reward, done, info)`` from
    `step` and a bare observation from `reset`.
    """

    def __init__(self, env, skip=2, model=reward_model):
        """
        Args:
            env: Environment to wrap.
            skip: Number of times each action is applied (must be >= 1).
            model: Reward model used to score every intermediate frame.
        """
        gymnasium.Wrapper.__init__(self, env)
        if skip < 1:
            # With zero repeats `step` would have nothing to return.
            raise ValueError("skip must be at least 1, got %r" % (skip,))
        self._skip = skip
        # Use the model that was passed in rather than the module-level global.
        self.reward_model = model

    def reset(self):
        """Reset the wrapped environment and return its observation."""
        return self.env.reset()

    def step(self, action):
        """Apply `action` up to `skip` times, summing the CLIP reward of each frame.

        Stops early when the wrapped environment reports `done`.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            # The environment's own reward is discarded in favour of the CLIP score.
            obs, _, done, info = self.env.step(action)
            total_reward += clip_reward(obs, self.reward_model)

            if done:
                break
        return obs, total_reward, done, info

In [None]:
def humanoid_env(seed=42):
    """Create the image-observation humanoid env wrapped in `RepeatAction`.

    `seed` seeds the action and observation spaces of the underlying env.
    """
    base_env = CLIPRewardedHumanoidEnv()
    for space in (base_env.action_space, base_env.observation_space):
        space.seed(seed)
    return RepeatAction(base_env)

In [None]:
def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Sanity check: take one random step in the wrapped environment and show the frame.
env = humanoid_env()
env.reset()
obs, reward, terminated, info = env.step(env.action_space.sample())
plt.imshow(obs) # (64, 64, 3)
plt.show()

# モデルの実装







In [None]:
class TransitionModel(nn.Module):
    """Recurrent state-space transition model.

    Combines a deterministic GRU path (`recurrent`) with Gaussian state
    distributions: a prior computed from the RNN hidden state alone and a
    posterior that additionally sees the embedded observation.

    Args:
        state_dim: Size of the stochastic state.
        action_dim: Size of the action vector.
        rnn_hidden_dim: Size of the GRU hidden state.
        hidden_dim: Width of the intermediate fully connected layers.
        min_stddev: Constant added to every predicted standard deviation.
        act: Activation applied after the intermediate layers.
        embedded_obs_dim: Size of the embedded observation fed to the
            posterior. Defaults to 1024, the value that was previously
            hard-coded.
    """

    def __init__(self, state_dim, action_dim, rnn_hidden_dim,
                 hidden_dim=200, min_stddev=0.1, act=F.elu,
                 embedded_obs_dim=1024):
        super(TransitionModel, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.embedded_obs_dim = embedded_obs_dim
        self.fc_state_action = nn.Linear(state_dim + action_dim, hidden_dim)

        self.fc_rnn_hidden = nn.Linear(rnn_hidden_dim, hidden_dim)
        self.fc_state_mean_prior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_prior = nn.Linear(hidden_dim, state_dim)
        self.fc_rnn_hidden_embedded_obs = nn.Linear(rnn_hidden_dim + embedded_obs_dim, hidden_dim)
        self.fc_state_mean_posterior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_posterior = nn.Linear(hidden_dim, state_dim)

        self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim)
        self._min_stddev = min_stddev
        self.act = act

    def forward(self, state, action, rnn_hidden, embedded_next_obs):
        """Return ``(prior, posterior, rnn_hidden)`` for the next time step."""
        next_state_prior, rnn_hidden = self.prior(self.recurrent(state, action, rnn_hidden))
        next_state_posterior = self.posterior(rnn_hidden, embedded_next_obs)
        return next_state_prior, next_state_posterior, rnn_hidden

    def recurrent(self, state, action, rnn_hidden):
        """
        h_t+1 = f(h_t, s_t, a_t)
        """
        hidden = self.act(self.fc_state_action(torch.cat([state, action], dim=1)))
        rnn_hidden = self.rnn(hidden, rnn_hidden)
        return rnn_hidden

    def prior(self, rnn_hidden):
        """
        prior p(s_t+1 | h_t+1)

        Returns the distribution together with the unchanged `rnn_hidden`.
        """
        hidden = self.act(self.fc_rnn_hidden(rnn_hidden))

        mean = self.fc_state_mean_prior(hidden)
        # softplus keeps the stddev positive; min_stddev bounds it away from 0.
        stddev = F.softplus(self.fc_state_stddev_prior(hidden)) + self._min_stddev
        return Normal(mean, stddev), rnn_hidden

    def posterior(self, rnn_hidden, embedded_obs):
        """
        posterior q(s_t+1 | h_t+1, e_t+1)
        """
        hidden = self.act(
            self.fc_rnn_hidden_embedded_obs(
                torch.cat([rnn_hidden, embedded_obs], dim=1)
            )
        )
        mean = self.fc_state_mean_posterior(hidden)
        stddev = F.softplus(self.fc_state_stddev_posterior(hidden)) + self._min_stddev
        return Normal(mean, stddev)

In [None]:
class ObservationModel(nn.Module):
    """
    p(o_t | s_t, h_t)

    Decodes the concatenated state and RNN hidden state into a 3-channel image
    through one linear layer and four transposed convolutions.
    """
    def __init__(self, state_dim, rnn_hidden_dim):
        super().__init__()
        self.fc = nn.Linear(state_dim + rnn_hidden_dim, 1024)
        self.dc1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2)
        self.dc2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2)
        self.dc3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
        self.dc4 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2)

    def forward(self, state, rnn_hidden):
        latent = self.fc(torch.cat([state, rnn_hidden], dim=1))
        # Treat the 1024 features as a 1x1 feature map for the deconvolutions.
        feature_map = latent.view(latent.size(0), 1024, 1, 1)
        for deconv in (self.dc1, self.dc2, self.dc3):
            feature_map = F.relu(deconv(feature_map))
        # The last layer has no activation.
        return self.dc4(feature_map)

In [None]:
class RewardModel(nn.Module):
    """
    p(r_t | s_t, h_t)

    Four-layer MLP mapping the concatenated state and RNN hidden state to a
    single reward value.
    """
    def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.act = act

    def forward(self, state, rnn_hidden):
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            features = self.act(layer(features))
        # Linear output head, no activation.
        return self.fc4(features)

In [None]:
class RSSM(nn.Module):
    """Container bundling the transition, observation and reward models.

    Each sub-model is moved to the module-level `device` on construction.
    """
    def __init__(self, state_dim, action_dim, rnn_hidden_dim):
        super(RSSM, self).__init__()

        transition = TransitionModel(state_dim, action_dim, rnn_hidden_dim)
        self.transition = transition.to(device)
        observation = ObservationModel(state_dim, rnn_hidden_dim)
        self.observation = observation.to(device)
        reward = RewardModel(state_dim, rnn_hidden_dim)
        self.reward = reward.to(device)

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions that serves contiguous chunks.

    Args:
        capacity: Maximum number of stored transitions.
        observation_shape: Shape of a single observation (stored as uint8).
        action_dim: Length of an action vector (stored as float32).
    """

    def __init__(self, capacity, observation_shape, action_dim):
        self.capacity = capacity

        self.observations = np.zeros((capacity, *observation_shape), dtype=np.uint8)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.done = np.zeros((capacity, 1), dtype=np.bool_)

        self.index = 0          # next write position
        self.is_filled = False  # True once the buffer has wrapped at least once

    def push(self, observation, action, reward, done):
        """Store one transition, overwriting the oldest entry when full."""
        self.observations[self.index] = observation
        self.actions[self.index] = action
        self.rewards[self.index] = reward
        self.done[self.index] = done

        if self.index == self.capacity - 1:
            self.is_filled = True
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size, chunk_length):
        """Draw `batch_size` chunks of `chunk_length` consecutive transitions.

        A chunk never continues past an episode end: a `done` flag may appear
        only on the last transition of a chunk. Start positions are drawn
        uniformly (with replacement) from all positions that satisfy this.

        Returns:
            Tuple ``(observations, actions, rewards, done)``, each with leading
            shape ``(batch_size, chunk_length)``.

        Raises:
            ValueError: If `chunk_length` is not between 1 and the number of
                stored transitions, or if no chunk of that length fits inside
                an episode.
        """
        available = len(self)
        if chunk_length < 1 or chunk_length > available:
            raise ValueError(
                "cannot sample chunks of length %d from a buffer holding %d transitions"
                % (chunk_length, available))

        # done_before[j] = number of episode ends among indices [0, j).
        done_before = np.concatenate(([0], np.cumsum(self.done[:available, 0])))
        starts = np.arange(available - chunk_length + 1)
        # Episode ends among the first chunk_length - 1 positions of each
        # candidate chunk; any such end means the chunk would cross a border.
        crossings = done_before[starts + chunk_length - 1] - done_before[starts]
        valid_starts = starts[crossings == 0]
        if valid_starts.size == 0:
            # The previous rejection loop would never terminate in this case.
            raise ValueError(
                "no chunk of length %d fits inside a single episode" % chunk_length)

        initial_indexes = np.random.choice(valid_starts, size=batch_size)
        # (batch_size, chunk_length) grid of buffer positions.
        sampled_indexes = initial_indexes[:, None] + np.arange(chunk_length)[None, :]

        sampled_observations = self.observations[sampled_indexes]
        sampled_actions = self.actions[sampled_indexes]
        sampled_rewards = self.rewards[sampled_indexes]
        sampled_done = self.done[sampled_indexes]
        return sampled_observations, sampled_actions, sampled_rewards, sampled_done

    def __len__(self):
        """Number of transitions currently stored."""
        return self.capacity if self.is_filled else self.index

In [None]:
def preprocess_obs(obs):
    """
    Cast to float32 and rescale pixel values: [0, 255] -> [-0.5, 0.5]
    """
    return obs.astype(np.float32) / 255.0 - 0.5

In [None]:
def lambda_target(rewards, values, gamma, lambda_):
    """
    λ-return

    Mixes n-step returns over a horizon of ``H = rewards.shape[0] - 1`` steps:
    the n-step return gets weight ``(1 - λ) * λ^(n-1)`` for n < H and
    ``λ^(H-1)`` for n = H. Time runs along the first axis of both inputs, and
    the result has the same shape as `rewards`.
    """
    horizon = rewards.shape[0] - 1
    lambda_return = torch.zeros_like(rewards, device=rewards.device)

    # Reused across n; rows that cannot look n steps ahead keep whatever an
    # earlier iteration (or the initialisation below) left in them.
    n_step_return = torch.zeros_like(rewards, device=rewards.device)
    n_step_return[horizon] = values[horizon]

    for n in range(1, horizon + 1):
        # Bootstrap from the value n steps ahead ...
        n_step_return[:-n] = (gamma ** n) * values[n:]
        # ... then add the rewards in between, the n-th one last.
        for k in range(1, n):
            n_step_return[:-n] += (gamma ** (k - 1)) * rewards[k:-n + k]
        n_step_return[:-n] += (gamma ** (n - 1)) * rewards[n:]

        if n == horizon:
            weight = lambda_ ** (horizon - 1)
        else:
            weight = (1 - lambda_) * (lambda_ ** (n - 1))
        lambda_return += weight * n_step_return

    return lambda_return

In [None]:
class Encoder(nn.Module):
    """Four stride-2 convolutions that turn an image batch into flat feature vectors."""

    def __init__(self):
        super().__init__()
        self.cv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

    def forward(self, obs):
        features = obs
        for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
            features = F.relu(conv(features))
        # Flatten everything except the batch axis.
        return features.reshape(obs.size(0), -1)

In [None]:
class ValueModel(nn.Module):
    """Four-layer MLP mapping (state, rnn_hidden) to a single state value."""

    def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.act = act

    def forward(self, state, rnn_hidden):
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            features = self.act(layer(features))
        # Linear output head, no activation.
        return self.fc4(features)

In [None]:
class ActionModel(nn.Module):
    """Policy network producing tanh-squashed actions from (state, rnn_hidden).

    With ``training=True`` the action is a reparameterised sample from a
    Normal distribution; otherwise it is the squashed mean.
    """

    def __init__(self, state_dim, rnn_hidden_dim, action_dim,
                 hidden_dim=400, act=F.elu, min_stddev=1e-4, init_stddev=5.0):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, action_dim)
        self.fc_stddev = nn.Linear(hidden_dim, action_dim)
        self.act = act
        self.min_stddev = min_stddev
        # Inverse softplus of init_stddev: added to the raw output before the
        # softplus, so a raw output of 0 yields a stddev of init_stddev.
        self.init_stddev = np.log(np.exp(init_stddev) - 1)

    def forward(self, state, rnn_hidden, training=True):
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            features = self.act(layer(features))

        # Soft-clip the mean into (-5, 5).
        mean = 5.0 * torch.tanh(self.fc_mean(features) / 5.0)
        raw_stddev = self.fc_stddev(features)
        stddev = F.softplus(raw_stddev + self.init_stddev) + self.min_stddev

        if not training:
            return torch.tanh(mean)
        return torch.tanh(Normal(mean, stddev).rsample())

In [None]:
class Agent:
    """Stateful policy that acts on raw frames and tracks the RNN hidden state.

    `rssm` must provide `rnn_hidden_dim`, `posterior`, `prior` and `recurrent`.
    """

    def __init__(self, encoder, rssm, action_model):
        self.encoder = encoder
        self.rssm = rssm
        self.action_model = action_model

        # Run on whichever device the action model's parameters live on.
        self.device = next(self.action_model.parameters()).device
        self.rnn_hidden = torch.zeros(1, rssm.rnn_hidden_dim, device=self.device)

    def __call__(self, obs, training=True):
        """Return an action (NumPy array) for one channels-last frame."""
        frame = torch.as_tensor(preprocess_obs(obs), device=self.device)
        # Channels-last -> channels-first, plus a batch axis of size 1.
        frame = frame.permute(2, 0, 1).unsqueeze(0)

        with torch.no_grad():
            embedded_obs = self.encoder(frame)
            posterior = self.rssm.posterior(self.rnn_hidden, embedded_obs)
            state = posterior.sample()
            action = self.action_model(state, self.rnn_hidden, training=training)
            # Advance the hidden state using the chosen action.
            next_hidden = self.rssm.recurrent(state, action, self.rnn_hidden)
            _, self.rnn_hidden = self.rssm.prior(next_hidden)

        return action.squeeze().cpu().numpy()

    def reset(self):
        """Zero the RNN hidden state."""
        self.rnn_hidden = torch.zeros(1, self.rssm.rnn_hidden_dim, device=self.device)

# モデルの学習

In [None]:
set_seed(42)
env = humanoid_env()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Replay buffer sized from the environment's observation and action spaces.
buffer_capacity = 200000
replay_buffer = ReplayBuffer(
    capacity=buffer_capacity,
    observation_shape=env.observation_space.shape,
    action_dim=env.action_space.shape[0]
)

# Latent sizes shared by the world model, the value model and the policy.
state_dim = 30
rnn_hidden_dim = 200

encoder = Encoder().to(device)
rssm = RSSM(state_dim, env.action_space.shape[0], rnn_hidden_dim)
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim,
                             env.action_space.shape[0]).to(device)

# NOTE: the former `trained_models = TrainedModels(...)` line was removed.
# `TrainedModels` is not defined anywhere in this notebook (it raised NameError
# on a fresh kernel) and its result was never used.

# One optimizer for the world model (encoder + RSSM parts) and one each for
# the value model and the policy.
model_lr = 6e-4
value_lr = 8e-5
action_lr = 8e-5
eps = 1e-4
model_params = (list(encoder.parameters()) +
                  list(rssm.transition.parameters()) +
                  list(rssm.observation.parameters()) +
                  list(rssm.reward.parameters()))
model_optimizer = torch.optim.Adam(model_params, lr=model_lr, eps=eps)
value_optimizer = torch.optim.Adam(value_model.parameters(), lr=value_lr, eps=eps)
action_optimizer = torch.optim.Adam(action_model.parameters(), lr=action_lr, eps=eps)


# Schedule. The trailing comments are the larger values left by the author,
# presumably for a full-length run.
seed_episodes = 2 # 5
all_episodes = 10 # 100
test_interval = 2 # 10
model_save_interval = 4 # 20
collect_interval = 10 # 100

# Variance of the Gaussian exploration noise added to actions while collecting.
action_noise_var = 0.3

batch_size = 50
chunk_length = 50
imagination_horizon = 15

gamma = 0.9
lambda_ = 0.95
clip_grad_norm = 100
free_nats = 3

In [None]:
# Fill the replay buffer with `seed_episodes` episodes of uniformly random
# actions before any learning starts.
for episode in range(seed_episodes):
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        next_obs, reward, done, _= env.step(action)
        # Store the observation the action was taken from, not the resulting one.
        replay_buffer.push(obs, action, reward, done)
        obs = next_obs

In [None]:
# Directory for TensorBoard event files; model checkpoints are written below it too.
log_dir = "logs"
writer = SummaryWriter(log_dir)

In [None]:
for episode in range(seed_episodes, all_episodes):
    # ------------------------------------------------------------------
    # 1. Collect one episode with the current policy plus Gaussian noise.
    # ------------------------------------------------------------------
    start = time.time()
    policy = Agent(encoder, rssm.transition, action_model)

    # Use the same CLIP-rewarded, action-repeating wrapper as the seed episodes
    # and the final evaluation, so every reward in the replay buffer comes from
    # the same reward function. A fresh environment is created per episode.
    env = humanoid_env()
    obs = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        action = policy(obs)
        action += np.random.normal(0, np.sqrt(action_noise_var),
                                     env.action_space.shape[0])
        next_obs, reward, done, _ = env.step(action)

        replay_buffer.push(obs, action, reward, done)

        obs = next_obs
        # The wrapper returns the reward as a one-element array.
        total_reward += float(np.sum(reward))

    writer.add_scalar('total reward at train', total_reward, episode)
    print('episode [%4d/%4d] is collected. Total reward is %f' %
            (episode+1, all_episodes, total_reward))
    print('elasped time for interaction: %.2fs' % (time.time() - start))

    # ------------------------------------------------------------------
    # 2. Update the world model, the policy and the value model.
    # ------------------------------------------------------------------
    start = time.time()
    for update_step in range(collect_interval):
        observations, actions, rewards, _ = \
            replay_buffer.sample(batch_size, chunk_length)

        # (batch, time, H, W, C) -> (time, batch, C, H, W)
        observations = preprocess_obs(observations)
        observations = torch.as_tensor(observations, device=device)
        observations = observations.transpose(3, 4).transpose(2, 3)
        observations = observations.transpose(0, 1)
        actions = torch.as_tensor(actions, device=device).transpose(0, 1)
        rewards = torch.as_tensor(rewards, device=device).transpose(0, 1)

        embedded_observations = encoder(
            observations.reshape(-1, 3, 64, 64)).view(chunk_length, batch_size, -1)

        states = torch.zeros(chunk_length, batch_size, state_dim, device=device)
        rnn_hiddens = torch.zeros(chunk_length, batch_size, rnn_hidden_dim, device=device)

        state = torch.zeros(batch_size, state_dim, device=device)
        rnn_hidden = torch.zeros(batch_size, rnn_hidden_dim, device=device)

        # Roll the transition model through the chunk, accumulating the KL
        # between prior and posterior (clamped from below at `free_nats`).
        kl_loss = 0
        for t in range(chunk_length-1):
            next_state_prior, next_state_posterior, rnn_hidden = \
                rssm.transition(state, actions[t], rnn_hidden, embedded_observations[t+1])
            state = next_state_posterior.rsample()
            states[t+1] = state
            rnn_hiddens[t+1] = rnn_hidden
            kl = kl_divergence(next_state_prior, next_state_posterior).sum(dim=1)
            kl_loss += kl.clamp(min=free_nats).mean()
        kl_loss /= (chunk_length - 1)

        # Index 0 holds the all-zero initial state; drop it.
        states = states[1:]
        rnn_hiddens = rnn_hiddens[1:]

        flatten_states = states.view(-1, state_dim)
        flatten_rnn_hiddens = rnn_hiddens.view(-1, rnn_hidden_dim)
        recon_observations = rssm.observation(flatten_states, flatten_rnn_hiddens).view(chunk_length-1, batch_size, 3, 64, 64)
        predicted_rewards = rssm.reward(flatten_states, flatten_rnn_hiddens).view(chunk_length-1, batch_size, 1)

        obs_loss = 0.5 * F.mse_loss(recon_observations, observations[1:], reduction='none').mean([0, 1]).sum()
        reward_loss = 0.5 * F.mse_loss(predicted_rewards, rewards[:-1])

        model_loss = kl_loss + obs_loss + reward_loss
        model_optimizer.zero_grad()
        model_loss.backward()
        clip_grad_norm_(model_params, clip_grad_norm)
        model_optimizer.step()

        # The imagination rollout starts from the inferred states but must not
        # backpropagate into the world-model update above.
        flatten_states = flatten_states.detach()
        flatten_rnn_hiddens = flatten_rnn_hiddens.detach()

        imagined_states = torch.zeros(imagination_horizon + 1,
                                         *flatten_states.shape,
                                          device=flatten_states.device)
        imagined_rnn_hiddens = torch.zeros(imagination_horizon + 1,
                                                *flatten_rnn_hiddens.shape,
                                                device=flatten_rnn_hiddens.device)

        imagined_states[0] = flatten_states
        imagined_rnn_hiddens[0] = flatten_rnn_hiddens

        # Imagine `imagination_horizon` steps ahead using the prior only.
        for h in range(1, imagination_horizon + 1):
            actions = action_model(flatten_states, flatten_rnn_hiddens)
            flatten_states_prior, flatten_rnn_hiddens = rssm.transition.prior(rssm.transition.recurrent(flatten_states,
                                                                   actions,
                                                                   flatten_rnn_hiddens))
            flatten_states = flatten_states_prior.rsample()
            imagined_states[h] = flatten_states
            imagined_rnn_hiddens[h] = flatten_rnn_hiddens

        flatten_imagined_states = imagined_states.view(-1, state_dim)
        flatten_imagined_rnn_hiddens = imagined_rnn_hiddens.view(-1, rnn_hidden_dim)
        imagined_rewards = \
            rssm.reward(flatten_imagined_states,
                        flatten_imagined_rnn_hiddens).view(imagination_horizon + 1, -1)
        imagined_values = \
            value_model(flatten_imagined_states,
                        flatten_imagined_rnn_hiddens).view(imagination_horizon + 1, -1)

        lambda_target_values = lambda_target(imagined_rewards, imagined_values, gamma, lambda_)

        # Policy: maximise the λ-return of the imagined trajectories.
        action_loss = -lambda_target_values.mean()
        action_optimizer.zero_grad()
        action_loss.backward()
        clip_grad_norm_(action_model.parameters(), clip_grad_norm)
        action_optimizer.step()

        # Value model: regress onto the (detached) λ-return.
        imagined_values = value_model(flatten_imagined_states.detach(), flatten_imagined_rnn_hiddens.detach()).view(imagination_horizon + 1, -1)
        value_loss = 0.5 * F.mse_loss(imagined_values, lambda_target_values.detach())
        value_optimizer.zero_grad()
        value_loss.backward()
        clip_grad_norm_(value_model.parameters(), clip_grad_norm)
        value_optimizer.step()

        print('update_step: %3d model loss: %.5f, kl_loss: %.5f, '
             'obs_loss: %.5f, reward_loss: %.5f, '
             'value_loss: %.5f action_loss: %.5f'
                % (update_step + 1, model_loss.item(), kl_loss.item(),
                    obs_loss.item(), reward_loss.item(),
                    value_loss.item(), action_loss.item()))

        # Log the losses so the TensorBoard cell at the end has data to show.
        total_update_step = episode * collect_interval + update_step
        writer.add_scalar('model loss', model_loss.item(), total_update_step)
        writer.add_scalar('kl loss', kl_loss.item(), total_update_step)
        writer.add_scalar('obs loss', obs_loss.item(), total_update_step)
        writer.add_scalar('reward loss', reward_loss.item(), total_update_step)
        writer.add_scalar('value loss', value_loss.item(), total_update_step)
        writer.add_scalar('action loss', action_loss.item(), total_update_step)

    print('elasped time for update: %.2fs' % (time.time() - start))

    # ------------------------------------------------------------------
    # 3. Periodically evaluate without exploration noise.
    # ------------------------------------------------------------------
    if (episode + 1) % test_interval == 0:
        policy = Agent(encoder, rssm.transition, action_model)
        start = time.time()
        obs = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            action = policy(obs, training=False)
            obs, reward, done, _ = env.step(action)
            total_reward += float(np.sum(reward))

        writer.add_scalar('total reward at test', total_reward, episode)
        print('Total test reward at episode [%4d/%4d] is %f' %
                (episode+1, all_episodes, total_reward))
        print('elasped time for test: %.2fs' % (time.time() - start))

    # ------------------------------------------------------------------
    # 4. Periodically checkpoint the models.
    # ------------------------------------------------------------------
    if (episode + 1) % model_save_interval == 0:
        model_log_dir = os.path.join(log_dir, 'episode_%04d' % (episode + 1))
        # exist_ok lets the notebook be re-run over an existing log directory.
        os.makedirs(model_log_dir, exist_ok=True)
        torch.save(encoder.state_dict(), os.path.join(model_log_dir, 'encoder.pth'))
        torch.save(rssm.transition.state_dict(), os.path.join(model_log_dir, 'rssm.pth'))
        torch.save(value_model.state_dict(), os.path.join(model_log_dir, 'value_model.pth'))
        torch.save(action_model.state_dict(), os.path.join(model_log_dir, 'action_model.pth'))

    # Release the renderer before dropping the per-episode environment.
    env.close()
    del env
    gc.collect()
writer.close()

# 結果

In [None]:
# Open TensorBoard on the directory the training cells write to.
%tensorboard --logdir='./logs'

In [None]:
env = humanoid_env()
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = Encoder().to(device)
rssm = RSSM(state_dim, env.action_space.shape[0], rnn_hidden_dim)
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim,
                             env.action_space.shape[0]).to(device)

# Restore the weights from the most recent checkpoint. `log_dir` itself is a
# directory, so it cannot be handed to torch.load; each model has its own
# state-dict file inside an `episode_XXXX` sub-directory. The zero-padded
# names sort chronologically.
checkpoint_names = sorted(
    name for name in os.listdir(log_dir)
    if name.startswith('episode_') and os.path.isdir(os.path.join(log_dir, name))
)
if not checkpoint_names:
    raise FileNotFoundError(
        "no 'episode_*' checkpoint directory found in %r; run the training cell first"
        % (log_dir,))
checkpoint_dir = os.path.join(log_dir, checkpoint_names[-1])

# Only the transition part of the RSSM is checkpointed (as 'rssm.pth').
for module, filename in (
    (encoder, 'encoder.pth'),
    (rssm.transition, 'rssm.pth'),
    (value_model, 'value_model.pth'),
    (action_model, 'action_model.pth'),
):
    module.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, filename), map_location=device))

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


def display_video(frames):
    """Render `frames` as an inline JavaScript animation (50 ms per frame, played once)."""
    plt.figure(figsize=(8, 8), dpi=50)
    image = plt.imshow(frames[0])
    plt.axis('off')

    def animate(step):
        # Swap in the frame for this step and show the step number as the title.
        image.set_data(frames[step])
        plt.title("Step %d" % (step))

    figure = plt.gcf()
    anim = animation.FuncAnimation(figure, animate, frames=len(frames), interval=50)
    html = anim.to_jshtml(default_mode='once')
    display(HTML(html))
    plt.close()

In [None]:
# Roll out one episode with the deterministic policy (training=False),
# recording every frame and every action taken.
policy = Agent(encoder, rssm.transition, action_model)

obs = env.reset()
done = False
total_reward = 0
frames = [obs]  # start with the initial frame
actions = []

while not done:
    action = policy(obs, training=False)
    obs, reward, done, _ = env.step(action)

    total_reward += reward
    frames.append(obs)
    actions.append(action)

print('Total Reward:', total_reward)

In [None]:
# Play back the recorded evaluation episode.
display_video(frames=frames)

In [None]:
# Stack the recorded actions into one array and save it with np.save.
actions = np.stack(actions)
np.save("actions", actions)