# 環境構築

In [None]:
# Unpack the MuJoCo 2.1.0 archive, copy it to ~/.mujoco/mujoco210, and copy
# the contents of its bin/ directory into /usr/lib/.
!tar -zxvf mujoco210-linux-x86_64.tar.gz
!mkdir ~/.mujoco
!cp -r mujoco210 ~/.mujoco/mujoco210
!cp -r ~/.mujoco/mujoco210/bin/* /usr/lib/

In [None]:
# Constrain setuptools/wheel, install the Mesa/GLFW/xvfb system packages,
# then install mujoco-py 2.1.x and pyvirtualdisplay.
!pip install setuptools==65.5.0 "wheel<0.40.0"
!apt update && apt install libosmesa6-dev libgl1-mesa-glx libglfw3-dev patchelf xvfb freeglut3-dev libgles2-mesa-dev -y
!pip3 install -U 'mujoco-py<2.2,>=2.1' pyvirtualdisplay

In [None]:
import os
# Set the library search path to the MuJoCo bin directory (this overwrites any
# existing LD_LIBRARY_PATH value) and select OSMesa for PyOpenGL and MuJoCo.
os.environ["LD_LIBRARY_PATH"] = "/root/.mujoco/mujoco210/bin"
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
os.environ["MUJOCO_GL"] = "osmesa"

In [None]:
# Install pinned mujoco and dm_control releases, plus gym (version not pinned).
!pip install mujoco==2.3.5
!pip install dm_control==1.0.9
!pip install gym

In [None]:
import gc
import glob
import math
import random
import time
from functools import partial
from typing import Any, List, Tuple

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

%load_ext tensorboard

In [None]:
from pyvirtualdisplay import Display
# Start a non-visible 400x300 virtual display via pyvirtualdisplay.
pydisplay = Display(visible=0, size=(400, 300))
pydisplay.start()

In [None]:
# Smoke test: build the Humanoid task, reset it and show one rendered frame.
# The notebook imports `gym` only (`gymnasium` is never imported), so the
# environment has to be created through `gym.make`.
env = gym.make("Humanoid-v4", render_mode="rgb_array")
env.reset()
image = env.render()
plt.imshow(image)  # (480, 480, 3)
plt.show()
env.close()

In [None]:
class DeepMindControl:
    """Gym-style adapter around a ``dm_control`` suite task.

    ``reset`` and ``step`` return only the rendered RGB frame as the
    observation; the remaining dm_control observation entries are described
    by ``observation_space`` but are not returned.
    """

    metadata = {}

    def __init__(self, name, size=(64, 64), camera=None):
        """Load the suite task named ``"<domain>_<task>"``.

        Args:
            name: Task identifier, split on the first underscore into domain
                and task (``"cup_..."`` is mapped to the ``ball_in_cup`` domain).
            size: Two-element tuple unpacked into ``physics.render``.
            camera: Camera id; defaults to 2 for ``quadruped`` and 0 otherwise.
        """
        domain, task = name.split('_', 1)
        if domain == 'cup':  # Only domain with multiple words.
            domain = 'ball_in_cup'
        # `str.split` always yields strings, so the suite loader is the only
        # reachable construction path. Imported lazily so that defining the
        # class does not require dm_control.
        from dm_control import suite
        self._env = suite.load(domain, task)
        self._size = size
        if camera is None:
            camera = dict(quadruped=2).get(domain, 0)
        self._camera = camera

    @property
    def observation_space(self):
        """Dict space with one unbounded Box per spec entry plus ``'image'``."""
        spaces = {}
        for key, value in self._env.observation_spec().items():
            spaces[key] = gym.spaces.Box(
                -np.inf, np.inf, value.shape, dtype=np.float32)
        spaces['image'] = gym.spaces.Box(
            0, 255, self._size + (3,), dtype=np.uint8)
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Box space bounded by the task's action spec minimum/maximum."""
        spec = self._env.action_spec()
        return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)

    def step(self, action):
        """Advance the task one step.

        Returns:
            ``(image, reward, done, info)`` where ``reward`` falls back to 0
            when the time step carries no reward and ``info['discount']`` is
            the time step's discount as a float32 array.
        """
        time_step = self._env.step(action)
        reward = time_step.reward or 0
        done = time_step.last()
        info = {'discount': np.array(time_step.discount, np.float32)}
        return self.render(), reward, done, info

    def reset(self):
        """Reset the task and return the first rendered frame."""
        self._env.reset()
        return self.render()

    def render(self, *args, **kwargs):
        """Render the current physics state with the configured size/camera.

        Raises:
            ValueError: If a ``mode`` other than ``'rgb_array'`` is requested.
        """
        if kwargs.get('mode', 'rgb_array') != 'rgb_array':
            raise ValueError("Only render mode 'rgb_array' is supported.")
        return self._env.physics.render(*self._size, camera_id=self._camera)

    def close(self):
        """No-op; nothing is released here."""
        pass

In [None]:
class RepeatAction(gym.Wrapper):
    """Wrapper that applies every action ``skip`` times and sums the rewards."""

    def __init__(self, env, skip=2):
        super().__init__(env)
        self._skip = skip

    def reset(self):
        """Reset the wrapped environment and return its result unchanged."""
        return self.env.reset()

    def step(self, action):
        """Repeat ``action`` up to ``skip`` times, stopping early on ``done``.

        Returns the last observation, the summed reward, and the last
        ``done`` flag and ``info`` from the wrapped environment.
        """
        accumulated = 0.0
        for _repeat in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            accumulated += reward
            if done:
                break
        return obs, accumulated, done, info

In [None]:
def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and pin cuDNN settings."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Smoke test: wrap the dm_control walker task with action repeat, take one
# random action and show the returned image observation.
env = DeepMindControl('walker_walk')
env = RepeatAction(env)
env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
plt.imshow(obs) # (64, 64, 3)
plt.show()

# Diffusion Model

In [None]:
def exists(x):
    """Return ``False`` only when ``x`` is ``None``; any other value is ``True``."""
    if x is None:
        return False
    return True


class Residual(nn.Module):
    """Skip connection: computes ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class PreNorm(nn.Module):
    """Apply a single-group ``GroupNorm`` over ``dim`` channels, then ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(num_groups=1, num_channels=dim)

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)


def Upsample(dim, dim_out=None):
    """Build a 2x nearest-neighbour upsampling followed by a 3x3 convolution.

    The convolution maps ``dim`` channels to ``dim_out`` (``dim`` when
    ``dim_out`` is not given).
    """
    out_channels = dim_out if exists(dim_out) else dim
    upscale = nn.Upsample(scale_factor=2, mode="nearest")
    project = nn.Conv2d(dim, out_channels, kernel_size=3, padding=1)
    return nn.Sequential(upscale, project)

def Downsample(dim, dim_out=None):
    """Build a stride-2 4x4 convolution from ``dim`` to ``dim_out`` channels.

    ``dim_out`` falls back to ``dim`` when it is not given.
    """
    out_channels = dim_out if exists(dim_out) else dim
    return nn.Conv2d(dim, out_channels, kernel_size=4, stride=2, padding=1)

In [None]:
class SinusoidalPositionEncoding(nn.Module):
    """Sinusoidal embedding of a batch of scalar time steps.

    The output concatenates ``dim // 2`` sine features and ``dim // 2`` cosine
    features whose frequencies fall geometrically from 1 to 1/10000.
    """

    def __init__(self, dim):
        """
        Args:
            dim: Embedding width; must be even and at least 4.

        Raises:
            ValueError: If ``dim`` is odd (the output would be ``dim - 1``
                wide) or smaller than 4 (``half_dim - 1`` would be zero and
                the frequency step would divide by zero).
        """
        super().__init__()
        if dim < 4 or dim % 2 != 0:
            raise ValueError(
                "dim must be an even integer >= 4, got %r" % (dim,))
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time ( batch, )
        Returns:
            pos_enc ( batch, dim )
        """
        half_dim = self.dim // 2

        # Log-spaced frequencies: exp(-i * log(10000) / (half_dim - 1)).
        step = math.log(10000) / (half_dim - 1)  # scalar
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -step)  # ( dim // 2, )
        angles = time[:, None] * freqs[None, :]  # ( batch, dim // 2 )

        return torch.cat((angles.sin(), angles.cos()), dim=-1)  # ( batch, dim )

In [None]:
class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        # https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """
        Args:
            x ( b, dim, h', w' )
            scale_shift: optional ``(scale, shift)`` pair applied after the
                normalization as ``x * (scale + 1) + shift``
        Returns:
            x ( b, dim_out, h', w' )
        """
        x = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)


class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual path and optional time conditioning."""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()

        # Projects the time embedding to a per-channel (scale, shift) pair.
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if time_emb_dim is not None else None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # 1x1 convolution on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """
        Args:
            x ( b, dim, h', w' )
            time_emb ( b, time_emb_dim ), optional
        Returns:
            x ( b, dim_out, h', w' )
        """
        scale_shift = None

        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)  # ( b, dim_out * 2 )
            # Add two trailing singleton axes with plain indexing so that no
            # einops dependency is needed.
            time_emb = time_emb[:, :, None, None]  # ( b, dim_out * 2, 1, 1 )
            scale_shift = time_emb.chunk(2, dim=1)  # tuple(( b, dim_out, 1, 1 ), ( b, dim_out, 1, 1 ))

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)

        return h + self.res_conv(x)

In [None]:
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    ``num_mem_kv`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, dim, heads=4, dim_head=32, num_mem_kv=4):
        super().__init__()

        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x ( b, dim, h', w' )
        Returns:
            out ( b, dim, h', w' )
        """
        b, _, h, w = x.shape

        # Split channels into (heads, dim_head) and flatten the spatial axes.
        # `reshape` is used instead of einops.rearrange, which the notebook
        # never imports; `chunk` returns non-contiguous views, so `view`
        # would not be safe here.
        q, k, v = (
            part.reshape(b, self.heads, -1, h * w)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        # ( b, heads, dim_head, h * w )

        q = q.softmax(dim=-2) * self.scale  # softmax over the head channels
        k = k.softmax(dim=-1)  # softmax over the spatial positions

        context = torch.einsum("bhdn,bhen->bhde", k, v)

        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = out.reshape(b, -1, h, w)  # ( b, heads * dim_head, h', w' )

        return self.to_out(out)

In [None]:
class Unet(nn.Module):
    """Time-conditioned U-Net built from ResnetBlock and LinearAttention stages.

    The encoder runs one stage per entry of ``dim_mults``; every stage except
    the last halves the spatial resolution. The decoder mirrors it and
    concatenates the two skip tensors stored by the matching encoder stage.

    NOTE(review): ``Block`` normalizes with ``GroupNorm(groups, channels)``,
    so every stage width, including the default ``init_dim = dim // 3 * 2``,
    presumably has to be divisible by ``resnet_block_groups`` - confirm for
    the ``dim`` you pass.
    """

    def __init__(
        self,
        dim, # image_size
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        convnext_mult=2,
    ):
        """
        Args:
            dim: Base width; stage widths are ``dim * m`` for ``m`` in
                ``dim_mults`` (the original inline comment labels it image_size).
            init_dim: Width after the first convolution (default ``dim // 3 * 2``).
            out_dim: Output channels (default ``channels``).
            dim_mults: Width multipliers, one per resolution stage.
            channels: Input channels.
            with_time_emb: Whether to build the time-embedding MLP.
            resnet_block_groups: GroupNorm groups used inside every ResnetBlock.
            convnext_mult: Accepted but not used by this implementation.
        """
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = init_dim if exists(init_dim) else dim // 3 * 2
        self.init_conv = nn.Conv2d(channels, init_dim, kernel_size=7, padding=3)

        # Wide ResNet
        ResnetBlock_ = partial(ResnetBlock, groups=resnet_block_groups)

        # position (time) encoding
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEncoding(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        dims = [init_dim, *(dim * m for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        num_resolutions = len(in_out)

        self.downs = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList(
                    [
                        ResnetBlock_(dim_in, dim_in, time_emb_dim=time_dim),
                        ResnetBlock_(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # The last stage changes width without reducing resolution.
                        Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock_(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock_(mid_dim, mid_dim, time_emb_dim=time_dim)

        self.ups = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)

            self.ups.append(nn.ModuleList(
                    [
                        # Inputs are wider by dim_in because a skip tensor is
                        # concatenated before each block.
                        ResnetBlock_(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        ResnetBlock_(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        out_dim = out_dim if exists(out_dim) else channels
        self.final_res_block = ResnetBlock_(init_dim * 2, init_dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(init_dim, out_dim, 1)

    def forward(self, x, time):  # time: ( b, )
        """Run the U-Net on ``x`` conditioned on ``time``.

        Args:
            x ( b, channels, h, w )
            time ( b, )
        Returns:
            ( b, out_dim, h, w )
        """
        x = self.init_conv(x)
        r = x.clone()  # kept for the final skip connection

        t = self.time_mlp(time) if exists(self.time_mlp) else None # ( B, time_emb_dim )

        h = []  # stack of skip tensors, two per encoder stage

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)

        return self.final_conv(x)

In [None]:
# Number of diffusion steps: length of the beta schedule built below and the
# number of iterations of the sampling loop.
timesteps = 800

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return ``timesteps`` values evenly spaced from ``beta_start`` to ``beta_end``."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

# Per-step beta values of the schedule.
betas = linear_beta_schedule(timesteps)

# Cumulative product of (1 - beta_t), and the same sequence shifted right by
# one step with a leading 1.0.
alphas_cumprod = torch.cumprod(1. - betas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

In [None]:
def extract(a, t, x_shape: torch.Size):
    """Look up ``a[t]`` per batch element, shaped to broadcast against ``x_shape``.

    The result has shape ``(batch, 1, ..., 1)`` with as many axes as
    ``x_shape`` and lives on the device of ``t``.
    """
    gathered = a.gather(-1, t.to(a.device))
    trailing_ones = (1,) * (len(x_shape) - 1)
    broadcast_shape = (t.shape[0],) + trailing_ones
    return gathered.reshape(broadcast_shape).to(t.device)

def q_sample(x_start, t, noise):
    """Mix ``x_start`` with ``noise`` according to the schedule at step ``t``.

    Args:
        x_start ( b, c, h, w )
        t ( b, )
        noise ( b, c, h, w )
    Returns:
        x_noisy: ( b, c, h, w )
    """
    keep = extract(alphas_cumprod, t, x_start.shape)  # ( b, 1, 1, 1 )
    signal_part = torch.sqrt(keep) * x_start
    noise_part = torch.sqrt(1.0 - keep) * noise
    return signal_part + noise_part

In [None]:
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l2"):
    """Noise-prediction loss for one batch.

    Args:
        x_start ( b, c, h, w )
        t ( b, )
        noise: optional noise tensor; sampled from a standard normal when omitted
        loss_type: one of 'l1', 'l2', "huber"
    Raises:
        NotImplementedError: for any other ``loss_type`` (raised after the
            model has been evaluated).
    """
    noise = torch.randn_like(x_start) if noise is None else noise

    predicted_noise = denoise_model(q_sample(x_start, t, noise), t)

    if loss_type == 'l1':
        criterion = F.l1_loss
    elif loss_type == 'l2':
        criterion = F.mse_loss
    elif loss_type == "huber":
        # https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html
        criterion = F.smooth_l1_loss
    else:
        raise NotImplementedError()

    return criterion(noise, predicted_noise)

In [None]:
# calculations for posterior q(x_{t-1} | x_t, x_0)
# Its first entry is 0 because alphas_cumprod_prev starts with 1.0.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

@torch.no_grad()
def p_sample(model, x, t_index):
    """Run one reverse-diffusion step from step ``t_index``.

    Args:
        x ( b, c, h, w )
        t_index ( int )
    Returns:
        x ( b, c, h, w ); no noise is added when ``t_index`` is 0
    """
    batch = x.shape[0]
    t = torch.full((batch,), t_index, device=x.device, dtype=torch.long)  # ( b, )

    beta_t = extract(betas, t, x.shape)
    noise_scale_t = extract(torch.sqrt(1.0 - alphas_cumprod), t, x.shape)
    recip_sqrt_alpha_t = extract(torch.sqrt(1.0 / (1. - betas)), t, x.shape)

    predicted_noise = model(x, t)
    model_mean = recip_sqrt_alpha_t * (
        x - beta_t * predicted_noise / noise_scale_t
    )

    if t_index != 0:
        variance_t = extract(posterior_variance, t, x.shape)
        fresh_noise = torch.randn_like(x)
        return model_mean + torch.sqrt(variance_t) * fresh_noise

    return model_mean

@torch.no_grad()
def p_sample_loop(model, image_size, batch_size=16, channels=3):
    """Generate samples by running ``p_sample`` from step ``timesteps - 1`` to 0.

    Args:
        model: denoising network; sampling runs on the device of its parameters
        image_size ( int )
        batch_size ( int )
        channels ( int )
    Returns:
        imgs: list with one NumPy array of shape ( b, c, h, w ) per step, in
            sampling order; the last entry is the final sample.
    """
    device = next(model.parameters()).device

    shape = (batch_size, channels, image_size, image_size)
    img = torch.randn(shape, device=device)

    imgs = []
    # Plain loop: the notebook never imports tqdm, so wrapping the iterator in
    # a progress bar raised NameError.
    for t in reversed(range(timesteps)):
        img = p_sample(model, img, t)
        imgs.append(img.cpu().numpy())

    return imgs

# Dreamer







In [None]:
class TransitionModel(nn.Module):
    """Recurrent state model with a deterministic GRU path and Gaussian states.

    Provides the prior p(s_t+1 | h_t+1) and the posterior
    q(s_t+1 | h_t+1, e_t+1); the posterior expects a 1024-wide observation
    embedding.
    """

    def __init__(self, state_dim, action_dim, rnn_hidden_dim,
                 hidden_dim=200, min_stddev=0.1, act=F.elu):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.fc_state_action = nn.Linear(state_dim + action_dim, hidden_dim)

        self.fc_rnn_hidden = nn.Linear(rnn_hidden_dim, hidden_dim)
        self.fc_state_mean_prior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_prior = nn.Linear(hidden_dim, state_dim)
        self.fc_rnn_hidden_embedded_obs = nn.Linear(rnn_hidden_dim + 1024, hidden_dim)
        self.fc_state_mean_posterior = nn.Linear(hidden_dim, state_dim)
        self.fc_state_stddev_posterior = nn.Linear(hidden_dim, state_dim)

        self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim)
        self._min_stddev = min_stddev
        self.act = act

    def forward(self, state, action, rnn_hidden, embedded_next_obs):
        """Advance one step; return ``(prior, posterior, next rnn_hidden)``."""
        next_rnn_hidden = self.recurrent(state, action, rnn_hidden)
        next_state_prior, next_rnn_hidden = self.prior(next_rnn_hidden)
        next_state_posterior = self.posterior(next_rnn_hidden, embedded_next_obs)
        return next_state_prior, next_state_posterior, next_rnn_hidden

    def recurrent(self, state, action, rnn_hidden):
        """
        h_t+1 = f(h_t, s_t, a_t)
        """
        state_action = torch.cat([state, action], dim=1)
        gru_input = self.act(self.fc_state_action(state_action))
        return self.rnn(gru_input, rnn_hidden)

    def prior(self, rnn_hidden):
        """
        prior p(s_t+1 | h_t+1); also returns ``rnn_hidden`` unchanged
        """
        features = self.act(self.fc_rnn_hidden(rnn_hidden))

        mean = self.fc_state_mean_prior(features)
        raw_stddev = self.fc_state_stddev_prior(features)
        stddev = F.softplus(raw_stddev) + self._min_stddev
        return Normal(mean, stddev), rnn_hidden

    def posterior(self, rnn_hidden, embedded_obs):
        """
        posterior q(s_t+1 | h_t+1, e_t+1)
        """
        joint = torch.cat([rnn_hidden, embedded_obs], dim=1)
        features = self.act(self.fc_rnn_hidden_embedded_obs(joint))
        mean = self.fc_state_mean_posterior(features)
        raw_stddev = self.fc_state_stddev_posterior(features)
        stddev = F.softplus(raw_stddev) + self._min_stddev
        return Normal(mean, stddev)

In [None]:
class ObservationModel(nn.Module):
    """
    p(o_t | s_t, h_t): decodes (state, rnn_hidden) with a linear layer and
    four transposed convolutions into a 3-channel image.
    """

    def __init__(self, state_dim, rnn_hidden_dim):
        super().__init__()
        self.fc = nn.Linear(state_dim + rnn_hidden_dim, 1024)
        self.dc1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2)
        self.dc2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2)
        self.dc3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
        self.dc4 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2)

    def forward(self, state, rnn_hidden):
        features = self.fc(torch.cat([state, rnn_hidden], dim=1))
        # Treat the 1024 features as a 1x1 feature map.
        features = features.view(features.size(0), 1024, 1, 1)
        for deconv in (self.dc1, self.dc2, self.dc3):
            features = F.relu(deconv(features))
        # The final layer has no activation.
        return self.dc4(features)

In [None]:
class RewardModel(nn.Module):
    """
    p(r_t | s_t, h_t): four-layer MLP producing one scalar per sample.
    """

    def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.act = act

    def forward(self, state, rnn_hidden):
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            features = self.act(layer(features))
        # The output layer has no activation.
        return self.fc4(features)

In [None]:
class RSSM(nn.Module):
    """Container bundling the transition, observation and reward models."""

    def __init__(self, state_dim, action_dim, rnn_hidden_dim, device=None):
        """
        Args:
            state_dim: size of the stochastic state
            action_dim: size of the action vector
            rnn_hidden_dim: size of the deterministic GRU state
            device: device the sub-models are moved to. When omitted, CUDA is
                used if available and the CPU otherwise, so the class no
                longer depends on a module-level ``device`` variable being
                defined before instantiation.
        """
        super().__init__()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.transition = TransitionModel(state_dim, action_dim, rnn_hidden_dim).to(device)
        self.observation = ObservationModel(state_dim, rnn_hidden_dim).to(device)
        self.reward = RewardModel(state_dim, rnn_hidden_dim).to(device)

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions with chunked sampling."""

    def __init__(self, capacity, observation_shape, action_dim):
        self.capacity = capacity

        # Preallocated storage; observations are stored as uint8.
        self.observations = np.zeros((capacity, *observation_shape), dtype=np.uint8)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.done = np.zeros((capacity, 1), dtype=np.bool_)

        self.index = 0
        self.is_filled = False

    def push(self, observation, action, reward, done):
        """Store one transition at the write cursor and advance it (wrapping)."""
        slot = self.index
        self.observations[slot] = observation
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.done[slot] = done

        if slot == self.capacity - 1:
            self.is_filled = True
        self.index = (slot + 1) % self.capacity

    def sample(self, batch_size, chunk_length):
        """Draw ``batch_size`` runs of ``chunk_length`` consecutive transitions.

        A candidate run is redrawn while a ``done`` flag lies anywhere before
        its last element. Returns observations, actions, rewards and done
        flags, each with leading shape ``(batch_size, chunk_length)``.
        """
        episode_ends = np.where(self.done)[0]
        flat_indexes = []
        for _ in range(batch_size):
            while True:
                first = np.random.randint(len(self) - chunk_length + 1)
                last = first + chunk_length - 1
                crosses = ((first <= episode_ends) & (episode_ends < last)).any()
                if not crosses:
                    break
            flat_indexes.extend(range(first, last + 1))

        leading = (batch_size, chunk_length)
        sampled_observations = self.observations[flat_indexes].reshape(
            *leading, *self.observations.shape[1:])
        sampled_actions = self.actions[flat_indexes].reshape(
            *leading, self.actions.shape[1])
        sampled_rewards = self.rewards[flat_indexes].reshape(*leading, 1)
        sampled_done = self.done[flat_indexes].reshape(*leading, 1)
        return sampled_observations, sampled_actions, sampled_rewards, sampled_done

    def __len__(self):
        """Number of stored transitions (``capacity`` once the buffer has wrapped)."""
        if self.is_filled:
            return self.capacity
        return self.index

In [None]:
def preprocess_obs(obs):
    """
    Cast to float32 and rescale: [0, 255] -> [-0.5, 0.5]
    """
    return obs.astype(np.float32) / 255.0 - 0.5

In [None]:
def lambda_target(rewards, values, gamma, lambda_):
    """
    λ-return: a lambda_-weighted mix of n-step returns along the first axis.

    ``rewards`` and ``values`` share a shape whose first axis has H + 1
    entries. Positions too close to the end for an n-step return keep the
    longest return available to them.
    """
    rows = rewards.shape[0]
    H = rows - 1

    V_lambda = torch.zeros_like(rewards)
    V_n = torch.zeros_like(rewards)
    V_n[H] = values[H]

    for n in range(1, H + 1):
        # n-step return for the positions that have n steps left: bootstrap
        # value first, then the discounted rewards in increasing k.
        n_step = (gamma ** n) * values[n:]
        for k in range(1, n + 1):
            n_step = n_step + (gamma ** (k - 1)) * rewards[k:rows - n + k]
        V_n[:-n] = n_step

        # The longest return takes the remaining weight lambda_^(H-1).
        if n == H:
            weight = lambda_ ** (H - 1)
        else:
            weight = (1 - lambda_) * (lambda_ ** (n - 1))
        V_lambda = V_lambda + weight * V_n

    return V_lambda

In [None]:
class Encoder(nn.Module):
    """Four stride-2 convolutions with ReLU, flattened to one vector per image."""

    def __init__(self):
        super().__init__()
        self.cv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

    def forward(self, obs):
        features = obs
        for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
            features = F.relu(conv(features))
        # Flatten everything except the batch axis.
        return features.reshape(features.size(0), -1)

In [None]:
class ValueModel(nn.Module):
    """Four-layer MLP mapping (state, rnn_hidden) to one scalar per sample."""

    def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=400, act=F.elu):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.act = act

    def forward(self, state, rnn_hidden):
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            features = self.act(layer(features))
        # The output layer has no activation.
        return self.fc4(features)

In [None]:
class ActionModel(nn.Module):
    """Policy network producing tanh-squashed actions from (state, rnn_hidden)."""

    def __init__(self, state_dim, rnn_hidden_dim, action_dim,
                 hidden_dim=400, act=F.elu, min_stddev=1e-4, init_stddev=5.0):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, action_dim)
        self.fc_stddev = nn.Linear(hidden_dim, action_dim)
        self.act = act
        self.min_stddev = min_stddev
        # Inverse softplus of init_stddev, added to the raw stddev output.
        self.init_stddev = np.log(np.exp(init_stddev) - 1)

    def forward(self, state, rnn_hidden, training=True):
        """Return actions in (-1, 1).

        With ``training=True`` the action is a reparameterized sample passed
        through tanh; otherwise it is tanh of the mean.
        """
        features = torch.cat([state, rnn_hidden], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            features = self.act(layer(features))

        # Soft-clip the mean to (-5, 5).
        mean = 5.0 * torch.tanh(self.fc_mean(features) / 5.0)
        raw_stddev = self.fc_stddev(features)
        stddev = F.softplus(raw_stddev + self.init_stddev) + self.min_stddev

        if not training:
            return torch.tanh(mean)
        return torch.tanh(Normal(mean, stddev).rsample())

In [None]:
class Agent:
    """Stateful policy: encodes an image, infers the state and picks an action.

    ``rssm`` must provide ``rnn_hidden_dim``, ``posterior``, ``prior`` and
    ``recurrent``. The recurrent state is kept between calls until ``reset``.
    """

    def __init__(self, encoder, rssm, action_model):
        self.encoder = encoder
        self.rssm = rssm
        self.action_model = action_model

        self.device = next(self.action_model.parameters()).device
        self.rnn_hidden = torch.zeros(1, rssm.rnn_hidden_dim, device=self.device)

    def __call__(self, obs, training=True):
        """Return the action for one image observation as a NumPy array."""
        frame = torch.as_tensor(preprocess_obs(obs), device=self.device)
        # Move the last axis to the front and add a batch axis of size 1.
        frame = frame.permute(2, 0, 1).unsqueeze(0)

        with torch.no_grad():
            embedded_obs = self.encoder(frame)
            posterior = self.rssm.posterior(self.rnn_hidden, embedded_obs)
            state = posterior.sample()
            action = self.action_model(state, self.rnn_hidden, training=training)
            # Advance the recurrent state with the chosen action.
            next_hidden = self.rssm.recurrent(state, action, self.rnn_hidden)
            _, self.rnn_hidden = self.rssm.prior(next_hidden)

        return action.squeeze().cpu().numpy()

    def reset(self):
        """Zero the recurrent state."""
        self.rnn_hidden = torch.zeros(1, self.rssm.rnn_hidden_dim, device=self.device)

# 学習

In [None]:
# Experiment setup: seed, environment, replay buffer, networks, optimizers and
# hyperparameters used by the collection and training loops.
set_seed(42)
# NOTE(review): `humanoid_env` is not defined anywhere in the visible
# notebook - confirm where it comes from.
env = humanoid_env()
device = "cuda" if torch.cuda.is_available() else "cpu"

buffer_capacity = 200000
replay_buffer = ReplayBuffer(
    capacity=buffer_capacity,
    observation_shape=env.observation_space.shape,
    action_dim=env.action_space.shape[0]
)

# Sizes of the stochastic state and of the deterministic GRU state.
state_dim = 30
rnn_hidden_dim = 200

encoder = Encoder().to(device)
rssm = RSSM(state_dim,env.action_space.shape[0],rnn_hidden_dim, )
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim,
                             env.action_space.shape[0]).to(device)

# NOTE(review): `TrainedModels` is not defined in the visible notebook and
# `trained_models` is not used afterwards - confirm it is still needed.
trained_models = TrainedModels(
    encoder, rssm, value_model, action_model
)


model_lr = 6e-4
value_lr = 8e-5
action_lr = 8e-5
eps = 1e-4
# The world-model optimizer updates the encoder together with all three RSSM
# sub-models; the value and action models have their own optimizers.
model_params = (list(encoder.parameters()) +
                  list(rssm.transition.parameters()) +
                  list(rssm.observation.parameters()) +
                  list(rssm.reward.parameters()))
model_optimizer = torch.optim.Adam(model_params, lr=model_lr, eps=eps)
value_optimizer = torch.optim.Adam(value_model.parameters(), lr=value_lr, eps=eps)
action_optimizer = torch.optim.Adam(action_model.parameters(), lr=action_lr, eps=eps)


# Schedule, in episodes (collect_interval is in update steps per episode).
# The trailing comments hold alternative values left by the author.
seed_episodes = 2 # 5
all_episodes = 10 # 100
test_interval = 2 # 10
model_save_interval = 4 # 20
collect_interval = 10 # 100

# Variance of the Gaussian exploration noise added to policy actions.
action_noise_var = 0.3

batch_size = 50
chunk_length = 50
imagination_horizon = 15

gamma = 0.9
lambda_ = 0.95
clip_grad_norm = 100
# Lower clamp applied to the per-sample KL term in the model loss.
free_nats = 3

In [None]:
# Fill the replay buffer with `seed_episodes` episodes of uniformly sampled
# random actions before any training happens.
for episode in range(seed_episodes):
    obs = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        next_obs, reward, done, _= env.step(action)
        # Each stored observation is paired with the action taken from it and
        # the reward/done flag that action produced.
        replay_buffer.push(obs, action, reward, done)
        obs = next_obs

In [None]:
# TensorBoard log directory; the training loop also writes its checkpoints
# into sub-directories of it.
log_dir = "logs"
writer = SummaryWriter(log_dir)

In [None]:
# Main training loop. Per episode: (1) collect one episode with the current
# policy plus exploration noise, (2) run `collect_interval` update steps of
# the world model, the action model and the value model, (3) periodically
# run a noise-free test episode and save checkpoints.
for episode in range(seed_episodes, all_episodes):
    start = time.time()
    policy = Agent(encoder, rssm.transition, action_model)

    # NOTE(review): `CLIPRewardedHumanoidEnv` is not defined in the visible
    # notebook - confirm where it comes from. A fresh instance is created
    # every episode and deleted at the end of the iteration.
    env = CLIPRewardedHumanoidEnv()
    obs = env.reset()
    done = False
    total_reward = 0
    while not done:
        action = policy(obs)
        # Gaussian exploration noise with variance `action_noise_var`.
        action += np.random.normal(0, np.sqrt(action_noise_var),
                                     env.action_space.shape[0])
        next_obs, reward, done, _, = env.step(action)

        replay_buffer.push(obs, action, reward, done)

        obs = next_obs
        total_reward += reward

    print('episode [%4d/%4d] is collected. Total reward is %f' %
            (episode+1, all_episodes, total_reward))
    print('elasped time for interaction: %.2fs' % (time.time() - start))

    start = time.time()
    for update_step in range(collect_interval):
        # ---- World-model update ----
        observations, actions, rewards, _ = \
            replay_buffer.sample(batch_size, chunk_length)

        observations = preprocess_obs(observations)
        observations = torch.as_tensor(observations, device=device)
        # Move the last (channel) axis in front of the two spatial axes, then
        # swap the batch and chunk axes so that time comes first.
        observations = observations.transpose(3, 4).transpose(2, 3)
        observations = observations.transpose(0, 1)
        actions = torch.as_tensor(actions, device=device).transpose(0, 1)
        rewards = torch.as_tensor(rewards, device=device).transpose(0, 1)

        # Encode all frames in one batch, then restore (chunk, batch, features).
        embedded_observations = encoder(
            observations.reshape(-1, 3, 64, 64)).view(chunk_length, batch_size, -1)

        states = torch.zeros(chunk_length, batch_size, state_dim, device=device)
        rnn_hiddens = torch.zeros(chunk_length, batch_size, rnn_hidden_dim, device=device)

        # The rollout starts from all-zero stochastic and recurrent states.
        state = torch.zeros(batch_size, state_dim, device=device)
        rnn_hidden = torch.zeros(batch_size, rnn_hidden_dim, device=device)

        kl_loss = 0
        for l in range(chunk_length-1):
            next_state_prior, next_state_posterior, rnn_hidden = \
                rssm.transition(state, actions[l], rnn_hidden, embedded_observations[l+1])
            state = next_state_posterior.rsample()
            states[l+1] = state
            rnn_hiddens[l+1] = rnn_hidden
            # KL summed over state dimensions, clamped from below at
            # `free_nats` per sample before averaging over the batch.
            kl = kl_divergence(next_state_prior, next_state_posterior).sum(dim=1)
            kl_loss += kl.clamp(min=free_nats).mean()
        kl_loss /= (chunk_length - 1)

        # Index 0 was never filled by the rollout, so drop it.
        states = states[1:]
        rnn_hiddens = rnn_hiddens[1:]

        flatten_states = states.view(-1, state_dim)
        flatten_rnn_hiddens = rnn_hiddens.view(-1, rnn_hidden_dim)
        recon_observations = rssm.observation(flatten_states, flatten_rnn_hiddens).view(chunk_length-1, batch_size, 3, 64, 64)
        predicted_rewards = rssm.reward(flatten_states, flatten_rnn_hiddens).view(chunk_length-1, batch_size, 1)

        # State l+1 reconstructs observation l+1 and predicts reward l.
        obs_loss = 0.5 * F.mse_loss(recon_observations, observations[1:], reduction='none').mean([0, 1]).sum()
        reward_loss = 0.5 * F.mse_loss(predicted_rewards, rewards[:-1])

        model_loss = kl_loss + obs_loss + reward_loss
        model_optimizer.zero_grad()
        model_loss.backward()
        clip_grad_norm_(model_params, clip_grad_norm)
        model_optimizer.step()

        # ---- Imagination rollout for the action/value updates ----
        # Detach so that the updates below do not backpropagate into the
        # world-model rollout above.
        flatten_states = flatten_states.detach()
        flatten_rnn_hiddens = flatten_rnn_hiddens.detach()

        imagined_states = torch.zeros(imagination_horizon + 1,
                                         *flatten_states.shape,
                                          device=flatten_states.device)
        imagined_rnn_hiddens = torch.zeros(imagination_horizon + 1,
                                                *flatten_rnn_hiddens.shape,
                                                device=flatten_rnn_hiddens.device)

        imagined_states[0] = flatten_states
        imagined_rnn_hiddens[0] = flatten_rnn_hiddens

        # Roll the prior forward with actions from the action model; every
        # posterior state of the batch is a starting point.
        for h in range(1, imagination_horizon + 1):
            actions = action_model(flatten_states, flatten_rnn_hiddens)
            flatten_states_prior, flatten_rnn_hiddens = rssm.transition.prior(rssm.transition.recurrent(flatten_states,
                                                                   actions,
                                                                   flatten_rnn_hiddens))
            flatten_states = flatten_states_prior.rsample()
            imagined_states[h] = flatten_states
            imagined_rnn_hiddens[h] = flatten_rnn_hiddens

        flatten_imagined_states = imagined_states.view(-1, state_dim)
        flatten_imagined_rnn_hiddens = imagined_rnn_hiddens.view(-1, rnn_hidden_dim)
        imagined_rewards = \
            rssm.reward(flatten_imagined_states,
                        flatten_imagined_rnn_hiddens).view(imagination_horizon + 1, -1)
        imagined_values = \
            value_model(flatten_imagined_states,
                        flatten_imagined_rnn_hiddens).view(imagination_horizon + 1, -1)

        lambda_target_values = lambda_target(imagined_rewards, imagined_values, gamma, lambda_)

        # Action model: maximize the lambda-return of the imagined trajectories.
        action_loss = -lambda_target_values.mean()
        action_optimizer.zero_grad()
        action_loss.backward()
        clip_grad_norm_(action_model.parameters(), clip_grad_norm)
        action_optimizer.step()

        # Value model: regress onto the detached lambda-return, re-evaluating
        # the values on detached imagined states.
        imagined_values = value_model(flatten_imagined_states.detach(), flatten_imagined_rnn_hiddens.detach()).view(imagination_horizon + 1, -1)
        value_loss = 0.5 * F.mse_loss(imagined_values, lambda_target_values.detach())
        value_optimizer.zero_grad()
        value_loss.backward()
        clip_grad_norm_(value_model.parameters(), clip_grad_norm)
        value_optimizer.step()

        print('update_step: %3d model loss: %.5f, kl_loss: %.5f, '
             'obs_loss: %.5f, reward_loss: %.5f, '
             'value_loss: %.5f action_loss: %.5f'
                % (update_step + 1, model_loss.item(), kl_loss.item(),
                    obs_loss.item(), reward_loss.item(),
                    value_loss.item(), action_loss.item()))
        # NOTE(review): computed but not used in the visible code, and nothing
        # is ever written to `writer` - presumably scalar logging was removed.
        total_update_step = episode * collect_interval + update_step

    print('elasped time for update: %.2fs' % (time.time() - start))

    # Periodic evaluation episode without exploration noise.
    if (episode + 1) % test_interval == 0:
        policy = Agent(encoder, rssm.transition, action_model)
        start = time.time()
        obs = env.reset()
        done = False
        total_reward = 0
        _returns = []
        while not done:
            action = policy(obs, training=False)
            obs, reward, done, _ = env.step(action)
            total_reward += reward

        print('Total test reward at episode [%4d/%4d] is %f' %
                (episode+1, all_episodes, total_reward))
        print('elasped time for test: %.2fs' % (time.time() - start))

    # Periodic checkpoint. Only the transition model of the RSSM is saved
    # (as 'rssm.pth'); the observation and reward models are not.
    if (episode + 1) % model_save_interval == 0:
        model_log_dir = os.path.join(log_dir, 'episode_%04d' % (episode + 1))
        os.makedirs(model_log_dir)
        torch.save(encoder.state_dict(), os.path.join(model_log_dir, 'encoder.pth'))
        torch.save(rssm.transition.state_dict(), os.path.join(model_log_dir, 'rssm.pth'))
        torch.save(value_model.state_dict(), os.path.join(model_log_dir, 'value_model.pth'))
        torch.save(action_model.state_dict(), os.path.join(model_log_dir, 'action_model.pth'))
    del env
    gc.collect()
writer.close()

# 結果

In [None]:
%tensorboard --logdir='./logs'

In [None]:
# Rebuild the networks and restore the most recent checkpoint written by the
# training loop.
# NOTE(review): `humanoid_env` is not defined in the visible notebook.
env = humanoid_env()
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = Encoder().to(device)
rssm = RSSM(state_dim,env.action_space.shape[0],rnn_hidden_dim, )
value_model = ValueModel(state_dim, rnn_hidden_dim).to(device)
action_model = ActionModel(state_dim, rnn_hidden_dim,
                             env.action_space.shape[0]).to(device)

# Checkpoints live in `<log_dir>/episode_XXXX/`; the zero-padded episode
# number makes the lexicographically last directory the newest one.
checkpoint_dirs = sorted(glob.glob(os.path.join(log_dir, 'episode_*')))
if not checkpoint_dirs:
    raise FileNotFoundError('no checkpoint directory found under %r' % log_dir)
model_log_dir = checkpoint_dirs[-1]

# Load each state dict into its model. 'rssm.pth' holds the transition model
# only, which is all that was saved.
encoder.load_state_dict(
    torch.load(os.path.join(model_log_dir, 'encoder.pth'), map_location=device))
rssm.transition.load_state_dict(
    torch.load(os.path.join(model_log_dir, 'rssm.pth'), map_location=device))
value_model.load_state_dict(
    torch.load(os.path.join(model_log_dir, 'value_model.pth'), map_location=device))
action_model.load_state_dict(
    torch.load(os.path.join(model_log_dir, 'action_model.pth'), map_location=device))

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


def display_video(frames):
    """Show ``frames`` as an inline JS/HTML animation, 50 ms per frame."""
    figure = plt.figure(figsize=(8, 8), dpi=50)
    image_artist = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        # Swap the displayed image and show the step index as the title.
        image_artist.set_data(frames[i])
        plt.title("Step %d" % (i))

    anim = animation.FuncAnimation(figure, animate, frames=len(frames), interval=50)
    html = HTML(anim.to_jshtml(default_mode='once'))
    display(html)
    plt.close()

In [None]:
# Run one episode with the policy in evaluation mode (training=False),
# recording every observation and action.
policy = Agent(encoder, rssm.transition, action_model)

obs = env.reset()
done = False
total_reward = 0
frames = [obs]  # starts with the observation returned by reset
actions = []

while not done:
    action = policy(obs, training=False)
    obs, reward, done, _ = env.step(action)

    total_reward += reward
    frames.append(obs)
    actions.append(action)

print('Total Reward:', total_reward)

In [None]:
# Play back the observations recorded during the evaluation episode.
display_video(frames=frames)

In [None]:
# Stack the recorded actions into one array and save it with np.save under
# the name "actions".
actions = np.stack(actions)
np.save("actions", actions)