In [210]:
import os
import copy
import pathlib

import gymnasium as gym
import numpy as np
import torch
import torchvision
from lightning.fabric import Fabric
from omegaconf import OmegaConf
from PIL import Image

from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.data.buffers import SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

In [211]:
# Path to the DreamerV3 checkpoint file inspected in this notebook.
# NOTE(review): the path is relative, so it resolves against the notebook's working directory.
ckpt_path = pathlib.Path("../logs/runs/dreamer_v3/MsPacmanNoFrameskip-v4/2024-07-17_12-01-29_dreamer_v3_MsPacmanNoFrameskip-v4_5/version_0/checkpoint/ckpt_100000_0.ckpt")
ckpt_path

PosixPath('../logs/runs/dreamer_v3/MsPacmanNoFrameskip-v4/2024-07-17_12-01-29_dreamer_v3_MsPacmanNoFrameskip-v4_5/version_0/checkpoint/ckpt_100000_0.ckpt')

In [212]:
# NOTE(review): `seed` is not read by any of the visible cells (they use `cfg.seed`) — confirm whether it is needed.
seed = 5
# Single-device CUDA Fabric, used here to load the checkpoint.
fabric = Fabric(accelerator="cuda", devices=1)
fabric.launch()
# Checkpoint content, indexed by key in later cells (e.g. state["world_model"], state["rb"]).
state = fabric.load(ckpt_path)
# The run configuration is read from `config.yaml`, two directory levels above the checkpoint file,
# resolved into a plain container and wrapped in a `dotdict` for attribute-style access.
cfg = dotdict(OmegaConf.to_container(OmegaConf.load(ckpt_path.parent.parent / "config.yaml"), resolve=True))

# The number of environments is set to 1
cfg.env.num_envs = 1

In [213]:
from torch.distributions import Distribution

# Apply the process-wide settings stored in the run configuration:
# thread count (via the OMP_NUM_THREADS environment variable), float32 matmul precision
# and the default `validate_args` flag of every torch distribution.
os.environ["OMP_NUM_THREADS"] = str(cfg.num_threads)
torch.set_float32_matmul_precision(cfg.float32_matmul_precision)
Distribution.set_default_validate_args(cfg.distribution.validate_args)

# Echo the three values that were just applied, one per line.
print(
    cfg.num_threads,
    cfg.float32_matmul_precision,
    cfg.distribution.validate_args,
    sep="\n",
)

1
high
False


In [214]:
print(os.environ["OMP_NUM_THREADS"])

1


In [215]:
# Placeholders filled in by the registry lookup in the next cell.
module = None
decoupled = False
entrypoint = None
# Name of the algorithm stored in the run configuration.
algo_name = cfg.algo.name
algo_name

'dreamer_v3'

In [216]:
from sheeprl.utils.registry import algorithm_registry, evaluation_registry
# Find the registry entry whose name matches `algo_name` and record the module that
# provides it, the name of its entrypoint function and its `decoupled` flag.
algo_exists = False
for _module, _algos in algorithm_registry.items():
    for _algo in _algos:
        if algo_name == _algo["name"]:
            module = _module
            entrypoint = _algo["entrypoint"]
            decoupled = _algo["decoupled"]
            algo_exists = True
            break
    if algo_exists:
        # The inner `break` only leaves the inner loop: stop scanning the other modules too,
        # so a later module cannot overwrite the first match.
        break
if not algo_exists:
    # Fail here with a clear message instead of later importing the module "None.<algo_name>".
    raise RuntimeError(f"The algorithm '{algo_name}' has not been found in `algorithm_registry`")

print(module)
print(entrypoint)
print(decoupled)

sheeprl.algos.dreamer_v3
main
False


In [217]:
import importlib
# Import the algorithm module (`<module>.<algo_name>`) and its sibling `utils` module by name.
task = importlib.import_module(f"{module}.{algo_name}")
utils = importlib.import_module(f"{module}.utils")
# Look up the entrypoint function among the algorithm module's attributes.
command = task.__dict__[entrypoint]
# NOTE(review): `kwargs` is not read by any of the visible cells.
kwargs = {}

In [218]:
strategy = cfg.fabric.get("strategy", "auto")
strategy

'auto'

In [219]:
import hydra
# Re-bind `fabric` to an instance built from the `fabric` section of the run configuration;
# this replaces the Fabric object created earlier to load the checkpoint.
# NOTE(review): `launch()` is not called on this new instance in the visible cells — confirm that is intended.
fabric: Fabric = hydra.utils.instantiate(cfg.fabric, strategy=strategy, _convert_="all")

In [220]:
cfg.metric

{'log_every': 5000,
 'disable_timer': False,
 'log_level': 1,
 'sync_on_compute': False,
 'aggregator': {'_target_': 'sheeprl.utils.metric.MetricAggregator',
  'raise_on_missing': False,
  'metrics': {'Rewards/rew_avg': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Game/ep_len_avg': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/world_model_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/value_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/policy_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/observation_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/reward_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/state_loss': {'_target_': 'torchmetrics.MeanMetric',
    'sync_on_compute': False},
   'Loss/continue_loss': {'_target_': 'torch

In [221]:
hasattr(utils, "AGGREGATOR_KEYS")

True

In [222]:
predefined_metric_keys = utils.AGGREGATOR_KEYS

In [223]:
predefined_metric_keys   # defined in the dreamer_v3 utils.py module

{'Game/ep_len_avg',
 'Grads/actor',
 'Grads/critic',
 'Grads/world_model',
 'Loss/continue_loss',
 'Loss/observation_loss',
 'Loss/policy_loss',
 'Loss/reward_loss',
 'Loss/state_loss',
 'Loss/value_loss',
 'Loss/world_model_loss',
 'Rewards/rew_avg',
 'State/kl',
 'State/post_entropy',
 'State/prior_entropy'}

In [224]:
cfg.metric.log_level == 0

False

In [225]:
cfg.metric.disable_timer

False

In [226]:
from sheeprl.utils.timer import timer
# The timer is switched off when logging is disabled (log_level 0) or when the config asks for it.
timer.disabled = cfg.metric.log_level == 0 or cfg.metric.disable_timer
timer.disabled

False

In [227]:
# Metrics configured for the aggregator that are not among the algorithm's predefined keys
# are dropped from the configuration; each dropped key is printed.
keys_to_remove = set(cfg.metric.aggregator.metrics).difference(predefined_metric_keys)
for k in keys_to_remove:
    print(k)
    cfg.metric.aggregator.metrics.pop(k, None)

In [228]:
from sheeprl.utils.metric import MetricAggregator
# Disable metric aggregation class-wide when logging is off or no metric is left in the config.
MetricAggregator.disabled = cfg.metric.log_level == 0 or len(cfg.metric.aggregator.metrics) == 0
MetricAggregator.disabled

False

In [229]:
cfg.model_manager

{'disabled': True,
 'models': {'world_model': {'model_name': 'dreamer_v3_MsPacmanNoFrameskip-v4_world_model',
   'description': 'DreamerV3 World Model used in MsPacmanNoFrameskip-v4 Environment',
   'tags': {}},
  'actor': {'model_name': 'dreamer_v3_MsPacmanNoFrameskip-v4_actor',
   'description': 'DreamerV3 Actor used in MsPacmanNoFrameskip-v4 Environment',
   'tags': {}},
  'critic': {'model_name': 'dreamer_v3_MsPacmanNoFrameskip-v4_critic',
   'description': 'DreamerV3 Critic used in MsPacmanNoFrameskip-v4 Environment',
   'tags': {}},
  'target_critic': {'model_name': 'dreamer_v3_MsPacmanNoFrameskip-v4_target_critic',
   'description': 'DreamerV3 Target Critic used in MsPacmanNoFrameskip-v4 Environment',
   'tags': {}},
  'moments': {'model_name': 'dreamer_v3_MsPacmanNoFrameskip-v4_moments',
   'description': 'DreamerV3 Moments used in MsPacmanNoFrameskip-v4 Environment',
   'tags': {}}}}

In [230]:
hasattr(cfg, "model_manager") and not cfg.model_manager.disabled and cfg.model_manager.models is not None

False

In [231]:
command

<function sheeprl.algos.dreamer_v3.dreamer_v3.main(fabric: 'Fabric', cfg: 'Dict[str, Any]')>

In [232]:
fabric

<lightning.fabric.fabric.Fabric at 0x7fd473e61760>

In [233]:
cfg

{'num_threads': 1,
 'float32_matmul_precision': 'high',
 'dry_run': False,
 'seed': 5,
 'torch_use_deterministic_algorithms': False,
 'torch_backends_cudnn_benchmark': True,
 'torch_backends_cudnn_deterministic': False,
 'cublas_workspace_config': None,
 'exp_name': 'dreamer_v3_MsPacmanNoFrameskip-v4',
 'run_name': '2024-07-17_12-01-29_dreamer_v3_MsPacmanNoFrameskip-v4_5',
 'root_dir': 'dreamer_v3/MsPacmanNoFrameskip-v4',
 'algo': {'name': 'dreamer_v3',
  'total_steps': 100000,
  'per_rank_batch_size': 16,
  'run_test': True,
  'cnn_keys': {'encoder': ['rgb'], 'decoder': ['rgb']},
  'mlp_keys': {'encoder': [], 'decoder': []},
  'world_model': {'optimizer': {'_target_': 'torch.optim.Adam',
    'lr': 0.0001,
    'eps': 1e-08,
    'weight_decay': 0,
    'betas': [0.9, 0.999]},
   'discrete_size': 32,
   'stochastic_size': 32,
   'kl_dynamic': 0.5,
   'kl_representation': 0.1,
   'kl_free_nats': 1.0,
   'kl_regularizer': 1.0,
   'continue_scale_factor': 1.0,
   'clip_gradients': 1000.0,
  

In [234]:
# Cache the Fabric device, this process' global rank and the total number of processes.
device = fabric.device
rank = fabric.global_rank
world_size = fabric.world_size
device, rank, world_size

(device(type='cuda', index=0), 0, 1)

In [235]:
cfg.env.frame_stack

-1

In [236]:
from sheeprl.utils.logger import get_log_dir, get_logger
# Logging directory for this session, derived from the run's root directory and run name.
# It is reused below for the environments, the saved config and the memory-mapped buffer.
log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
log_dir

'logs/runs/dreamer_v3/MsPacmanNoFrameskip-v4/2024-07-17_12-01-29_dreamer_v3_MsPacmanNoFrameskip-v4_5/version_1'

In [237]:
cfg.env.sync_env

False

In [238]:
# Choose the vector-env implementation from the configuration instead of hard-coding it:
# synchronous when `cfg.env.sync_env` is truthy, asynchronous (sub-processes) otherwise.
vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv

In [239]:
cfg.env.num_envs

1

In [240]:
from functools import partial
from sheeprl.envs.wrappers import RestartOnException
# Build the vectorized environment from one factory per environment index.
# Each factory wraps whatever `make_env` returns in `RestartOnException`.
envs = vectorized_env(
        [
            partial(
                RestartOnException,
                make_env(
                    cfg,
                    # Per-environment seed: the base seed offset by rank and environment index.
                    cfg.seed + rank * cfg.env.num_envs + i,
                    rank * cfg.env.num_envs,
                    # Only rank 0 passes a logging directory.
                    log_dir if rank == 0 else None,
                    "train",
                    vector_env_idx=i,
                ),
            )
            for i in range(cfg.env.num_envs)
        ]
    )

  logger.warn(


In [241]:
action_space = envs.single_action_space
action_space

Discrete(9)

In [242]:
observation_space = envs.single_observation_space
observation_space
# In Dict('rgb': Box(0, 255, (3, 64, 64), uint8)):
# 0 and 255 are the minimum and maximum values, (3, 64, 64) is the shape, and uint8 is the data type.

Dict('rgb': Box(0, 255, (3, 64, 64), uint8))

In [243]:
is_continuous = isinstance(action_space, gym.spaces.Box)
is_continuous

False

In [244]:
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)

In [245]:
# Size of each action component, as a tuple:
# - continuous (Box) space: the shape of the space;
# - multi-discrete space: the number of choices of every sub-action;
# - otherwise (single discrete space): a 1-tuple with the number of actions.
if is_continuous:
    actions_dim = tuple(action_space.shape)
elif is_multidiscrete:
    actions_dim = tuple(action_space.nvec.tolist())
else:
    actions_dim = (action_space.n,)
actions_dim

(9,)

In [246]:
cfg.env.clip_rewards

False

In [247]:
def clip_rewards_fn(r):
    """Return `tanh(r)` when `cfg.env.clip_rewards` is truthy, otherwise `r` unchanged.

    The flag is read from `cfg` on every call, not captured at definition time.
    """
    if cfg.env.clip_rewards:
        return np.tanh(r)
    return r

In [248]:
print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder)
print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder)
print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder)
print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder)

Encoder CNN keys: ['rgb']
Encoder MLP keys: []
Decoder CNN keys: ['rgb']
Decoder MLP keys: []


In [249]:
# All observation keys fed to the encoder: image (CNN) keys first, then vector (MLP) keys.
obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder
obs_keys

['rgb']

In [250]:
# world_model, actor, critic, target_critic, player = build_agent(
#         fabric,
#         actions_dim,
#         is_continuous,
#         cfg,
#         observation_space,
#         state["world_model"] if cfg.checkpoint.resume_from else None,
#         state["actor"] if cfg.checkpoint.resume_from else None,
#         state["critic"] if cfg.checkpoint.resume_from else None,
#         state["target_critic"] if cfg.checkpoint.resume_from else None,
#     )

In [251]:
world_model_cfg = cfg.algo.world_model
world_model_cfg

{'optimizer': {'_target_': 'torch.optim.Adam',
  'lr': 0.0001,
  'eps': 1e-08,
  'weight_decay': 0,
  'betas': [0.9, 0.999]},
 'discrete_size': 32,
 'stochastic_size': 32,
 'kl_dynamic': 0.5,
 'kl_representation': 0.1,
 'kl_free_nats': 1.0,
 'kl_regularizer': 1.0,
 'continue_scale_factor': 1.0,
 'clip_gradients': 1000.0,
 'decoupled_rssm': False,
 'learnable_initial_recurrent_state': True,
 'encoder': {'cnn_channels_multiplier': 32,
  'cnn_act': 'torch.nn.SiLU',
  'dense_act': 'torch.nn.SiLU',
  'mlp_layers': 2,
  'cnn_layer_norm': {'cls': 'sheeprl.models.models.LayerNormChannelLast',
   'kw': {'eps': 0.001}},
  'mlp_layer_norm': {'cls': 'sheeprl.models.models.LayerNorm',
   'kw': {'eps': 0.001}},
  'dense_units': 512},
 'recurrent_model': {'recurrent_state_size': 512,
  'layer_norm': {'cls': 'sheeprl.models.models.LayerNorm',
   'kw': {'eps': 0.001}},
  'dense_units': 512},
 'transition_model': {'hidden_size': 512,
  'dense_act': 'torch.nn.SiLU',
  'layer_norm': {'cls': 'sheeprl.model

In [252]:
actor_cfg = cfg.algo.actor
actor_cfg

{'optimizer': {'_target_': 'torch.optim.Adam',
  'lr': 8e-05,
  'eps': 1e-05,
  'weight_decay': 0,
  'betas': [0.9, 0.999]},
 'cls': 'sheeprl.algos.dreamer_v3.agent.Actor',
 'ent_coef': 0.0003,
 'min_std': 0.1,
 'max_std': 1.0,
 'init_std': 2.0,
 'dense_act': 'torch.nn.SiLU',
 'mlp_layers': 2,
 'layer_norm': {'cls': 'sheeprl.models.models.LayerNorm',
  'kw': {'eps': 0.001}},
 'dense_units': 512,
 'clip_gradients': 100.0,
 'unimix': 0.01,
 'action_clip': 1.0,
 'moments': {'decay': 0.99,
  'max': 1.0,
  'percentile': {'low': 0.05, 'high': 0.95}}}

In [253]:
critic_cfg = cfg.algo.critic
critic_cfg

{'optimizer': {'_target_': 'torch.optim.Adam',
  'lr': 8e-05,
  'eps': 1e-05,
  'weight_decay': 0,
  'betas': [0.9, 0.999]},
 'dense_act': 'torch.nn.SiLU',
 'mlp_layers': 2,
 'layer_norm': {'cls': 'sheeprl.models.models.LayerNorm',
  'kw': {'eps': 0.001}},
 'dense_units': 512,
 'per_rank_target_network_update_freq': 1,
 'tau': 0.02,
 'bins': 255,
 'clip_gradients': 100.0}

In [254]:
# Sizes of the pieces of the latent state.
recurrent_state_size = world_model_cfg.recurrent_model.recurrent_state_size
# The stochastic state is flattened: `stochastic_size` variables times `discrete_size` values each.
stochastic_size = world_model_cfg.stochastic_size * world_model_cfg.discrete_size
# The latent state size is the sum of the stochastic and recurrent sizes.
latent_state_size = stochastic_size + recurrent_state_size
recurrent_state_size, stochastic_size, latent_state_size

(512, 1024, 1536)

In [255]:
# Number of CNN stages: log2(screen_size) - log2(4), truncated to int
# (evaluates to 4 in this run, as shown in the cell output).
cnn_stages = int(np.log2(cfg.env.screen_size) - np.log2(4))
cnn_stages

4

In [256]:
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from torch import Tensor, nn
from sheeprl.models.models import (
    CNN,
    MLP,
    DeCNN,
    LayerNorm,
    LayerNormChannelLast,
    LayerNormGRUCell,
    MultiDecoder,
    MultiEncoder,
)
from sheeprl.utils.model import ModuleType, cnn_forward

class CNNEncoder(nn.Module):
    """Image encoder of Dreamer-V3: a `CNN` of strided convolutions followed by a flatten.

    Every stage uses an `nn.Conv2d` configured with kernel_size=4, stride=2 and padding=1.
    The convolution bias is enabled only when `layer_norm_cls` is `nn.Identity`, i.e. when
    no real normalization layer is used. With the default 4 stages, the model is meant
    for 64x64 images, which end at a 4x4 resolution. When several image observations are
    encoded, they are concatenated on the channel dimension before being fed to the model.

    Args:
        keys (Sequence[str]): the keys of the image observations to encode.
        input_channels (Sequence[int]): the number of channels of every image observation.
        image_size (Tuple[int, int]): the image size as (Height, Width).
        channels_multiplier (int): the multiplier of the output channels: stage `i` has
            `2**i * channels_multiplier` output channels.
        layer_norm_cls (Callable[..., nn.Module]): the normalization layer class given to every stage.
            Defaults to LayerNormChannelLast.
        layer_norm_kw (Dict[str, Any]): the kwargs of the normalization layer.
            Defaults to {"eps": 1e-3}.
        activation (ModuleType, optional): the activation function.
            Defaults to nn.SiLU.
        stages (int, optional): the number of stages of the CNN.
            Defaults to 4.
    """

    def __init__(
        self,
        keys: Sequence[str],
        input_channels: Sequence[int],
        image_size: Tuple[int, int],
        channels_multiplier: int,
        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
        layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
        activation: ModuleType = nn.SiLU,
        stages: int = 4,
    ) -> None:
        super().__init__()
        self.keys = keys

        # All the images are stacked on the channel axis, so the channels add up.
        total_channels = sum(input_channels)
        self.input_dim = (total_channels, *image_size)

        # Output channels of every stage: [1, 2, 4, ...] * channels_multiplier.
        stage_channels = (torch.tensor([2**i for i in range(stages)]) * channels_multiplier).tolist()
        conv_kwargs = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            # A bias is used only when there is no real normalization layer.
            "bias": layer_norm_cls == nn.Identity,
        }
        # One kwargs dict per stage: the shared kwargs plus the stage's channel count.
        per_stage_norm_kwargs = [
            {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
        ]

        backbone = CNN(
            input_channels=total_channels,
            hidden_channels=stage_channels,
            cnn_layer=nn.Conv2d,
            layer_args=conv_kwargs,
            activation=activation,
            norm_layer=[layer_norm_cls] * stages,
            norm_args=per_stage_norm_kwargs,
        )
        self.model = nn.Sequential(backbone, nn.Flatten(-3, -1))

        # Infer the flattened output size by running a dummy batch of one image.
        with torch.no_grad():
            dummy_input = torch.zeros(1, *self.input_dim)
            self.output_dim = self.model(dummy_input).shape[-1]

    def forward(self, obs: Dict[str, Tensor]) -> Tensor:
        """Encode the image observations selected by `self.keys`.

        The selected tensors are concatenated on the channel dimension (dim=-3) and passed,
        together with the model, to `cnn_forward`, using the last three dimensions as the
        input dimensions and `(-1,)` as the output dimensions.
        """
        images = [obs[key] for key in self.keys]
        stacked = torch.cat(images, dim=-3)
        return cnn_forward(self.model, stacked, stacked.shape[-3:], (-1,))

In [257]:
obs_space = observation_space

# Instantiate the image encoder only when at least one CNN key is configured; otherwise use None.
cnn_encoder = (
    CNNEncoder(
        keys=cfg.algo.cnn_keys.encoder,
        # Channels of each image: product of every dimension except the last two (height, width).
        input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.encoder],
        # The image size is taken from the first CNN key only.
        image_size=obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:],
        channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier,
        # Layer-norm and activation classes are stored as dotted paths in the config.
        layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.cnn_layer_norm.cls),
        layer_norm_kw=world_model_cfg.encoder.cnn_layer_norm.kw,
        activation=hydra.utils.get_class(world_model_cfg.encoder.cnn_act),
        stages=cnn_stages,
    )
    if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0
    else None
)

In [258]:
# import pytorch_model_summary as pms
# pms.summary(cnn_encoder, torch.zeros(1, 3, 64, 64))

In [259]:
cfg.algo.cnn_keys.encoder[0]

'rgb'

In [260]:
obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:]

(64, 64)

In [261]:
world_model_cfg.encoder.cnn_channels_multiplier

32

In [262]:
hydra.utils.get_class(world_model_cfg.encoder.cnn_layer_norm.cls)

sheeprl.models.models.LayerNormChannelLast

In [263]:
# `MLPEncoder` is not imported by any of the visible cells: without this import the cell only
# works while the MLP key list is empty (the name is never evaluated) and raises a NameError
# as soon as a vector observation key is configured.
from sheeprl.algos.dreamer_v3.agent import MLPEncoder

# Instantiate the vector encoder only when at least one MLP key is configured; otherwise use None.
mlp_encoder = (
    MLPEncoder(
        keys=cfg.algo.mlp_keys.encoder,
        # Input size of each vector observation: the first dimension of its space.
        input_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder],
        mlp_layers=world_model_cfg.encoder.mlp_layers,
        dense_units=world_model_cfg.encoder.dense_units,
        # Activation and layer-norm classes are stored as dotted paths in the config.
        activation=hydra.utils.get_class(world_model_cfg.encoder.dense_act),
        layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.mlp_layer_norm.cls),
        layer_norm_kw=world_model_cfg.encoder.mlp_layer_norm.kw,
    )
    if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0
    else None
)

In [264]:
encoder = MultiEncoder(cnn_encoder, mlp_encoder)

In [265]:
class RecurrentModel(nn.Module):
    """Recurrent model of the model-based Dreamer-V3 agent.

    The input goes through a one-hidden-layer `MLP` (with normalization and activation)
    and the resulting features are given, together with the previous recurrent state,
    to a `sheeprl.models.models.LayerNormGRUCell`.

    Args:
        input_size (int): the input size of the model.
        recurrent_state_size (int): the size of the recurrent state.
        dense_units (int): the number of units of the dense layer.
        activation_fn (nn.Module): the activation function.
            Defaults to SiLU.
        layer_norm_cls (Callable[..., nn.Module]): the normalization layer class, used both by
            the dense layer and by the recurrent cell.
            Defaults to LayerNorm.
        layer_norm_kw (Dict[str, Any]): the kwargs of the normalization layer.
            Defaults to {"eps": 1e-3}.
    """

    def __init__(
        self,
        input_size: int,
        recurrent_state_size: int,
        dense_units: int,
        activation_fn: nn.Module = nn.SiLU,
        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
        layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
    ) -> None:
        super().__init__()
        self.recurrent_state_size = recurrent_state_size

        # Input projection: a single dense layer without an output head.
        # The bias is used only when there is no real normalization layer.
        dense_norm_kwargs = {**layer_norm_kw, "normalized_shape": dense_units}
        self.mlp = MLP(
            input_dims=input_size,
            output_dim=None,
            hidden_sizes=[dense_units],
            activation=activation_fn,
            layer_args={"bias": layer_norm_cls == nn.Identity},
            norm_layer=[layer_norm_cls],
            norm_args=[dense_norm_kwargs],
        )

        # Recurrent cell mapping the projected features to the next recurrent state.
        self.rnn = LayerNormGRUCell(
            dense_units,
            recurrent_state_size,
            bias=False,
            batch_first=False,
            layer_norm_cls=layer_norm_cls,
            layer_norm_kw=layer_norm_kw,
        )

    def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor:
        """Project the input and advance the recurrent cell by one step.

        Args:
            input (Tensor): the input tensor; per the original documentation, the stochastic
                state and the actions concatenated together.
            recurrent_state (Tensor): the previous recurrent state.

        Returns:
            The output of the recurrent cell for the projected input and the given recurrent state.
        """
        return self.rnn(self.mlp(input), recurrent_state)

In [266]:
# Instantiate the recurrent model from the world-model configuration.
recurrent_model = RecurrentModel(
        # The input is as wide as all the action components plus the flattened stochastic state.
        input_size=int(sum(actions_dim) + stochastic_size),
        recurrent_state_size=world_model_cfg.recurrent_model.recurrent_state_size,
        dense_units=world_model_cfg.recurrent_model.dense_units,
        # The layer-norm class is stored as a dotted path in the config.
        layer_norm_cls=hydra.utils.get_class(world_model_cfg.recurrent_model.layer_norm.cls),
        layer_norm_kw=world_model_cfg.recurrent_model.layer_norm.kw,
    )

In [267]:
# The representation model input starts from the encoder's output size.
represention_model_input_size = encoder.output_dim
represention_model_input_size

4096

In [268]:
cfg.algo.world_model.decoupled_rssm

False

In [269]:
# Recompute the size from `encoder.output_dim` instead of incrementing the existing value:
# an in-place `+=` alone in its own cell adds the recurrent size again on every re-run.
represention_model_input_size = encoder.output_dim
# The recurrent state is part of the representation model input only when the RSSM is not
# decoupled (the flag inspected in the previous cell).
if not cfg.algo.world_model.decoupled_rssm:
    represention_model_input_size += recurrent_state_size
represention_model_input_size

4608

In [270]:
world_model_cfg.representation_model.layer_norm.cls

'sheeprl.models.models.LayerNorm'

In [271]:
cfg.algo.actor.cls

'sheeprl.algos.dreamer_v3.agent.Actor'

In [272]:
cfg.algo.hafner_initialization

True

In [273]:
# Build the world model, actor, critic, target critic and player with the library helper.
# The checkpointed weights in `state` are passed only when `cfg.checkpoint.resume_from` is set.
# NOTE(review): `resume_from` appears to be unset in this run (later cells evaluate
# `cfg.checkpoint.resume_from is None` as part of a True condition), so the checkpoint loaded
# at the top of the notebook is not used here — confirm this is intended.
world_model, actor, critic, target_critic, player = build_agent(
        fabric,
        actions_dim,
        is_continuous,
        cfg,
        observation_space,
        state["world_model"] if cfg.checkpoint.resume_from else None,
        state["actor"] if cfg.checkpoint.resume_from else None,
        state["critic"] if cfg.checkpoint.resume_from else None,
        state["target_critic"] if cfg.checkpoint.resume_from else None,
    )

In [274]:
cfg.algo.world_model.optimizer

{'_target_': 'torch.optim.Adam',
 'lr': 0.0001,
 'eps': 1e-08,
 'weight_decay': 0,
 'betas': [0.9, 0.999]}

In [275]:
# One optimizer per trained module, each instantiated from its own config section.
world_optimizer = hydra.utils.instantiate(
        cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all"
    )
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters(), _convert_="all")
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters(), _convert_="all")
# When resuming, restore the optimizer states saved in the checkpoint.
if cfg.checkpoint.resume_from:
    world_optimizer.load_state_dict(state["world_optimizer"])
    actor_optimizer.load_state_dict(state["actor_optimizer"])
    critic_optimizer.load_state_dict(state["critic_optimizer"])
# Hand the optimizers over to Fabric and keep the objects it returns.
world_optimizer, actor_optimizer, critic_optimizer = fabric.setup_optimizers(
    world_optimizer, actor_optimizer, critic_optimizer
)

In [276]:
cfg.algo.actor.moments.decay

0.99

In [277]:
cfg.algo.actor.moments.max

1.0

In [278]:
cfg.algo.actor.moments.percentile.low

0.05

In [279]:
cfg.algo.actor.moments.percentile.high

0.95

In [280]:
from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, prepare_obs, test
from sheeprl.utils.utils import Ratio, save_configs
# Build the `Moments` object from the actor's `moments` config
# (decay, max and the low/high percentiles, in this order).
moments = Moments(
        cfg.algo.actor.moments.decay,
        cfg.algo.actor.moments.max,
        cfg.algo.actor.moments.percentile.low,
        cfg.algo.actor.moments.percentile.high,
    )
# When resuming, restore its state from the checkpoint.
if cfg.checkpoint.resume_from:
    moments.load_state_dict(state["moments"])

# Only the global-zero process saves the configuration into the log directory.
if fabric.is_global_zero:
    save_configs(cfg, log_dir)

In [281]:
MetricAggregator.disabled

False

In [282]:
cfg.metric.aggregator

{'_target_': 'sheeprl.utils.metric.MetricAggregator',
 'raise_on_missing': False,
 'metrics': {'Rewards/rew_avg': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Game/ep_len_avg': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/world_model_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/value_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/policy_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/observation_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/reward_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/state_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'Loss/continue_loss': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compute': False},
  'State/kl': {'_target_': 'torchmetrics.MeanMetric',
   'sync_on_compu

In [283]:
# The aggregator is instantiated from the metric configuration and moved to the training
# device, unless metric aggregation is disabled, in which case it stays None.
aggregator = (
    None
    if MetricAggregator.disabled
    else hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)
)


In [284]:
# Buffer size passed to the replay buffer: the configured size divided by the total number of
# environments across all processes; a dry run uses a size of 2.
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2
buffer_size

100000

In [285]:
fabric.world_size

1

In [286]:
cfg.dry_run

False

In [287]:
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
# Replay buffer configured with `SequentialReplayBuffer` as its buffer class.
rb = EnvIndependentReplayBuffer(
        buffer_size,
        n_envs=cfg.env.num_envs,
        memmap=cfg.buffer.memmap,
        # Each rank gets its own sub-directory for the memory-mapped files inside the log directory.
        memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"),
        buffer_cls=SequentialReplayBuffer,
    )

In [288]:
cfg.buffer.memmap

True

In [289]:
isinstance(state["rb"], list)

False

In [290]:
cfg.checkpoint.resume_from

In [291]:
train_step = 0
last_train = 0
if cfg.checkpoint.resume_from:
    # The checkpoint is written at the end of an update step: when resuming, that update
    # is already done, so training restarts from the following iteration.
    start_iter = (state["iter_num"] // fabric.world_size) + 1
else:
    # Fresh run: start from the first iteration.
    start_iter = 1
start_iter

1

In [292]:
policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
policy_step

0

In [293]:
# Logging/checkpoint counters are restored from the checkpoint when resuming, else start at 0.
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
# Policy steps per iteration: the number of environments times the number of processes.
policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
policy_steps_per_iter

1

In [294]:
policy_steps_per_iter

1

In [295]:
total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
total_iters

100000

In [296]:
# Convert `learning_starts` from policy steps to iterations (0 on a dry run).
learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
# One less than `learning_starts` when it is positive, 0 otherwise.
prefill_steps = learning_starts - int(learning_starts > 0)
prefill_steps

1023

In [297]:
cfg.algo.replay_ratio

1

In [298]:
cfg.algo.per_rank_pretrain_steps

0

In [299]:
ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

In [300]:
# Container for the data of the current step, later added to the replay buffer.
step_data = {}
# Reset the environments with the configured seed and keep the first element of the
# returned tuple (the observations).
obs = envs.reset(seed=cfg.seed)[0]

In [301]:
# Store every observation in `step_data` with an extra leading axis.
for k in obs_keys:
    print(k)
    # `np.newaxis` prepends an axis: (1, 3, 64, 64) becomes (1, 1, 3, 64, 64) in this run.
    step_data[k] = obs[k][np.newaxis]
    print(obs[k])
    print(obs[k].shape)
    print(step_data[k].shape)
    # The shape prints above show the effect of the added axis; printing the whole
    # `step_data[k]` array is intentionally left out.

rgb
[[[[159 123  69 ...  69 123 159]
   [228 137   0 ...   0 137 228]
   [228 137   0 ...   0 137 228]
   ...
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]]

  [[ 77  64  45 ...  45  64  77]
   [111  78  28 ...  28  78 111]
   [111  78  28 ...  28  78 111]
   ...
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]]

  [[ 77  81  87 ...  87  81  77]
   [111 121 136 ... 136 121 111]
   [111 121 136 ... 136 121 111]
   ...
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]
   [  0   0   0 ...   0   0   0]]]]
(1, 3, 64, 64)
(1, 1, 3, 64, 64)


In [302]:
# Initial step: zero reward and no termination/truncation, with shape (1, num_envs, 1).
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1))
# Every environment is at its first step, so `is_first` is all ones with the same shape.
step_data["is_first"] = np.ones_like(step_data["terminated"])
# Initialize the player's internal states before the first action.
player.init_states()

In [303]:
step_data

{'rgb': array([[[[[159, 123,  69, ...,  69, 123, 159],
           [228, 137,   0, ...,   0, 137, 228],
           [228, 137,   0, ...,   0, 137, 228],
           ...,
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0]],
 
          [[ 77,  64,  45, ...,  45,  64,  77],
           [111,  78,  28, ...,  28,  78, 111],
           [111,  78,  28, ...,  28,  78, 111],
           ...,
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0]],
 
          [[ 77,  81,  87, ...,  87,  81,  77],
           [111, 121, 136, ..., 136, 121, 111],
           [111, 121, 136, ..., 136, 121, 111],
           ...,
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0],
           [  0,   0,   0, ...,   0,   0,   0]]]]], dtype=uint8),
 'rewards': array([[[0.]]]),
 'truncated': array([[[0.]]]

In [304]:
cumulative_per_rank_gradient_steps = 0

In [305]:
range(start_iter, total_iters + 1)

range(1, 100001)

In [306]:
iter_num = 1

In [307]:
policy_step += policy_steps_per_iter
policy_step

1

In [308]:
# Condition checked before sampling a random action in the next cell: still within the first
# `learning_starts` iterations, not resuming from a checkpoint, and not a MineDojo wrapper.
iter_num <= learning_starts \
and cfg.checkpoint.resume_from is None \
and "minedojo" not in cfg.env.wrapper._target_.lower()

True

In [309]:
actions = np.array(envs.action_space.sample())

In [310]:
# Keep a reference to the raw sampled actions: this is what is sent to the environment below,
# while `actions` is re-bound to its one-hot encoding in the next cell.
real_actions = actions
actions

array([7])

In [311]:
import torch.nn.functional as F

# One-hot encode each action branch (one row of the reshaped `actions` per entry
# of `actions_dim`) and join the encodings along the last axis.
# NOTE: rebinds `actions`, so this cell must not be re-run without re-sampling first.
actions = np.concatenate(
    [
        F.one_hot(torch.as_tensor(act), act_dim).numpy()
        for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim)
    ],
    axis=-1,
)
actions

array([[0, 0, 0, 0, 0, 0, 0, 1, 0]])

In [312]:
# Store the encoded actions with leading (1, num_envs) axes; the last axis is inferred.
step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
step_data["actions"]

array([[[0, 0, 0, 0, 0, 0, 0, 1, 0]]])

In [313]:
# Value passed as `validate_args` when adding to the replay buffer.
cfg.buffer.validate_args

False

In [314]:
# Append the current step_data entry to the replay buffer.
rb.add(step_data, validate_args=cfg.buffer.validate_args)

In [315]:
# ndarray.reshape returns a new array and leaves `real_actions` untouched, so a bare
# call followed by displaying `real_actions` shows nothing about the reshape.
# Display the reshaped array itself: this is exactly what is passed to envs.step.
real_actions.reshape(envs.action_space.shape)

array([7])

In [316]:
# Step the environments with the index-form actions, reshaped to the action
# space's shape, and unpack the five-tuple returned by envs.step.
next_obs, rewards, terminated, truncated, infos = envs.step(
    real_actions.reshape(envs.action_space.shape)
)

In [317]:
# Inspect the observation returned by the step.
next_obs

OrderedDict([('rgb',
              array([[[[159, 123,  69, ...,  69, 123, 159],
                       [228, 137,   0, ...,   0, 137, 228],
                       [228, 137,   0, ...,   0, 137, 228],
                       ...,
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0]],
              
                      [[ 77,  64,  45, ...,  45,  64,  77],
                       [111,  78,  28, ...,  28,  78, 111],
                       [111,  78,  28, ...,  28,  78, 111],
                       ...,
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0]],
              
                      [[ 77,  81,  87, ...,  87,  81,  77],
                       [111, 121, 136, ..., 136, 121, 111],
                       [111, 121, 136, ..., 136, 12

In [318]:
# Inspect the rewards returned by the step.
rewards

array([0.])

In [319]:
# Inspect the termination flags returned by the step.
terminated

array([False])

In [320]:
# Inspect the truncation flags returned by the step.
truncated

array([False])

In [321]:
# Inspect the info dict returned by the step.
infos

{'lives': array([3]),
 '_lives': array([ True]),
 'episode_frame_number': array([33]),
 '_episode_frame_number': array([ True]),
 'frame_number': array([33]),
 '_frame_number': array([ True])}

In [322]:
# An environment is "done" when it either terminated or was truncated; stored as uint8.
dones = np.logical_or(terminated, truncated).astype(np.uint8)
dones

array([0], dtype=uint8)

In [323]:
# Inspect the is_first flag before it is cleared.
step_data["is_first"]

array([[[1.]]])

In [324]:
# Clear the is_first flag (zeros shaped like the `terminated` entry) for the next stored entry.
step_data["is_first"] = np.zeros_like(step_data["terminated"])
step_data["is_first"]

array([[[0.]]])

In [325]:
# Deep-copy the observation so later changes to `next_obs` do not affect this copy.
real_next_obs = copy.deepcopy(next_obs)
real_next_obs

OrderedDict([('rgb',
              array([[[[159, 123,  69, ...,  69, 123, 159],
                       [228, 137,   0, ...,   0, 137, 228],
                       [228, 137,   0, ...,   0, 137, 228],
                       ...,
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0]],
              
                      [[ 77,  64,  45, ...,  45,  64,  77],
                       [111,  78,  28, ...,  28,  78, 111],
                       [111,  78,  28, ...,  28,  78, 111],
                       ...,
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0],
                       [  0,   0,   0, ...,   0,   0,   0]],
              
                      [[ 77,  81,  87, ...,  87,  81,  77],
                       [111, 121, 136, ..., 136, 121, 111],
                       [111, 121, 136, ..., 136, 12

In [326]:
# Write each observation key into step_data with a new leading axis of size 1.
# NOTE: the loop variable `k` remains defined in the notebook namespace afterwards.
for k in obs_keys:
    step_data[k] = next_obs[k][np.newaxis]
    print(step_data[k])

[[[[[159 123  69 ...  69 123 159]
    [228 137   0 ...   0 137 228]
    [228 137   0 ...   0 137 228]
    ...
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]]

   [[ 77  64  45 ...  45  64  77]
    [111  78  28 ...  28  78 111]
    [111  78  28 ...  28  78 111]
    ...
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]]

   [[ 77  81  87 ...  87  81  77]
    [111 121 136 ... 136 121 111]
    [111 121 136 ... 136 121 111]
    ...
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]
    [  0   0   0 ...   0   0   0]]]]]


In [327]:
# The observation just received becomes the current observation.
obs = next_obs

In [328]:
# Reshape rewards and flags to (1, num_envs, -1) and store them in step_data.
# Rewards pass through clip_rewards_fn before being stored.
# NOTE: rebinds `rewards` to its reshaped form.
rewards = rewards.reshape((1, cfg.env.num_envs, -1))
step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1))
step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1))
step_data["rewards"] = clip_rewards_fn(rewards)

In [329]:
# Indices of the environments whose `dones` entry is non-zero (none in this step).
dones_idxes = dones.nonzero()[0].tolist()
dones_idxes

[]

In [330]:
# Threshold compared against iter_num in the random-action condition.
learning_starts

1024

In [331]:
# Number of prefill steps subtracted when computing ratio_steps below.
prefill_steps

1023

In [332]:
# Policy steps added to policy_step on each iteration.
policy_steps_per_iter

1

In [333]:
# Current value of the policy-step counter.
policy_step

1

In [334]:
# Build the Ratio helper from the configured replay ratio; the cells below call it
# to turn a count of ratio steps into a number of gradient steps.
ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

In [335]:
# Pretend 1030 policy steps have elapsed (this overwrites the real counter) and
# subtract the prefill portion to get the steps that count towards training.
policy_step = 1030
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
ratio_steps

7

In [336]:
# Ask the Ratio helper how many gradient steps to run for these ratio steps.
# NOTE(review): `ratio` appears to be stateful (a later call with 19 returns 12),
# so re-running this cell may give a different value; confirm against Ratio.
per_rank_gradient_steps = ratio(ratio_steps / world_size)
per_rank_gradient_steps

7

In [337]:
# Repeat the experiment with 1042 elapsed policy steps (overwrites policy_step again).
policy_step = 1042
ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
ratio_steps

19

In [338]:
# Second call to the Ratio helper. The result (12) equals 19 minus the 7 returned
# earlier, which suggests the helper remembers the steps already accounted for.
# TODO confirm against Ratio's implementation.
per_rank_gradient_steps = ratio(ratio_steps / world_size)
per_rank_gradient_steps

12

In [339]:
# Batch size passed to rb.sample_tensors below.
cfg.algo.per_rank_batch_size

16

In [340]:
# Sequence length passed to rb.sample_tensors below.
cfg.algo.per_rank_sequence_length

64

In [341]:
# Value passed as `from_numpy` to rb.sample_tensors below.
cfg.buffer.from_numpy

False

In [342]:
# Prefill: take 1024 random actions and store each resulting transition in the
# replay buffer.
#
# After every env step, step_data is refreshed with the new observation, reward
# and flags (the same updates the single-step cells above perform by hand).
# Without this, every rb.add call would store the same stale observation and
# flags, with only the action changing.
for _prefill_step in range(1024):
    real_actions = actions = np.array(envs.action_space.sample())
    if not is_continuous:
        # One-hot encode each discrete action branch and join the encodings.
        actions = np.concatenate(
            [
                F.one_hot(torch.as_tensor(act), act_dim).numpy()
                for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim)
            ],
            axis=-1,
        )

    # Store the action taken from the current step_data entry, then push the entry.
    step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
    rb.add(step_data, validate_args=cfg.buffer.validate_args)

    next_obs, rewards, terminated, truncated, infos = envs.step(
        real_actions.reshape(envs.action_space.shape)
    )
    dones = np.logical_or(terminated, truncated).astype(np.uint8)

    # Carry the transition into step_data so the next rb.add stores fresh data.
    step_data["is_first"] = np.zeros_like(step_data["terminated"])
    for k in obs_keys:
        step_data[k] = next_obs[k][np.newaxis]
    obs = next_obs

    rewards = rewards.reshape((1, cfg.env.num_envs, -1))
    step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1))
    step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1))
    step_data["rewards"] = clip_rewards_fn(rewards)

    # NOTE(review): episode boundaries (non-zero entries in `dones`) are not given
    # special treatment here: the terminal observation is not stored separately and
    # is_first is not set for the first entry of a new episode. Add that handling if
    # episodes can end within the prefill window.

In [343]:
# Draw `per_rank_gradient_steps` batches of sequences from the replay buffer as
# tensors on the Fabric device.
local_data = rb.sample_tensors(
    cfg.algo.per_rank_batch_size,
    sequence_length=cfg.algo.per_rank_sequence_length,
    n_samples=per_rank_gradient_steps,
    dtype=None,
    device=fabric.device,
    from_numpy=cfg.buffer.from_numpy,
)
local_data

{'rgb': tensor([[[[[[159, 123,  69,  ...,  69, 123, 159],
             [228, 137,   0,  ...,   0, 137, 228],
             [228, 137,   0,  ...,   0, 137, 228],
             ...,
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0]],
 
            [[ 77,  64,  45,  ...,  45,  64,  77],
             [111,  78,  28,  ...,  28,  78, 111],
             [111,  78,  28,  ...,  28,  78, 111],
             ...,
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0]],
 
            [[ 77,  81,  87,  ...,  87,  81,  77],
             [111, 121, 136,  ..., 136, 121, 111],
             [111, 121, 136,  ..., 136, 121, 111],
             ...,
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0],
             [  0,   0,   0,  ...,   0,   0,   0]]],
 
 
         

In [344]:
# Sequence length used for the sample above.
cfg.algo.per_rank_sequence_length

64

In [345]:
# Shape of the sampled images. The first three axes match the n_samples (12),
# sequence_length (64) and batch size (16) passed to sample_tensors.
local_data['rgb'].shape

torch.Size([12, 64, 16, 3, 64, 64])

In [346]:
# Inspect the metric sync_on_compute setting.
cfg.metric.sync_on_compute

False

In [347]:
# Number of gradient steps computed by the last ratio(...) call.
per_rank_gradient_steps

12

In [348]:
# Still zero: no gradient step has been taken yet in this walkthrough.
cumulative_per_rank_gradient_steps

0

In [349]:
# Configured update frequency of the critic's target network.
cfg.algo.critic.per_rank_target_network_update_freq

1

In [350]:
# Before any gradient step has been taken, use tau = 1; afterwards use the
# configured critic tau.
if cumulative_per_rank_gradient_steps == 0:
    tau = 1
else:
    tau = cfg.algo.critic.tau
tau

1

In [351]:
# Configured tau used once at least one gradient step has been taken.
cfg.algo.critic.tau

0.02

In [352]:
# Inspect the (key, tensor) pairs of the sampled data.
local_data.items()

dict_items([('rgb', tensor([[[[[[159, 123,  69,  ...,  69, 123, 159],
            [228, 137,   0,  ...,   0, 137, 228],
            [228, 137,   0,  ...,   0, 137, 228],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]],

           [[ 77,  64,  45,  ...,  45,  64,  77],
            [111,  78,  28,  ...,  28,  78, 111],
            [111,  78,  28,  ...,  28,  78, 111],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]],

           [[ 77,  81,  87,  ...,  87,  81,  77],
            [111, 121, 136,  ..., 136, 121, 111],
            [111, 121, 136,  ..., 136, 121, 111],
            ...,
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0],
            [  0,   0,   0,  ...,   0,   0,   0]]],


          [[[159, 123

In [353]:
# Number of keys in the sampled data.
len(local_data)

6

In [354]:
# List the keys present in the sampled data (`v` is not used here).
for k, v in local_data.items():
    print(k)

rgb
rewards
truncated
terminated
is_first
actions


In [355]:
# Take index 0 along the first axis of every sampled tensor (the first of the
# n_samples batches) and cast it to float.
batch = {k: v[0].float() for k, v in local_data.items()}
batch

{'rgb': tensor([[[[[159., 123.,  69.,  ...,  69., 123., 159.],
            [228., 137.,   0.,  ...,   0., 137., 228.],
            [228., 137.,   0.,  ...,   0., 137., 228.],
            ...,
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.]],
 
           [[ 77.,  64.,  45.,  ...,  45.,  64.,  77.],
            [111.,  78.,  28.,  ...,  28.,  78., 111.],
            [111.,  78.,  28.,  ...,  28.,  78., 111.],
            ...,
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.]],
 
           [[ 77.,  81.,  87.,  ...,  87.,  81.,  77.],
            [111., 121., 136.,  ..., 136., 121., 111.],
            [111., 121., 136.,  ..., 136., 121., 111.],
            ...,
            [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
            [  0.,   0.,   0.,  ...,   0

In [356]:
# Policy steps per iteration (unchanged from above).
policy_steps_per_iter

1

In [357]:
# Batch size used for the state tensors allocated below.
cfg.algo.per_rank_batch_size

16

In [358]:
# Size of the recurrent state, used for the last axis of recurrent_state(s) below.
cfg.algo.world_model.recurrent_model.recurrent_state_size

512

In [359]:
# First factor of the stochastic state size (stochastic_size * discrete_size).
cfg.algo.world_model.stochastic_size

32

In [360]:
# Second factor of the stochastic state size (stochastic_size * discrete_size).
cfg.algo.world_model.discrete_size

32

In [361]:
# Local alias for the configured sequence length.
sequence_length = cfg.algo.per_rank_sequence_length
sequence_length

64

In [362]:
# Recurrent state size again (same value as shown above).
cfg.algo.world_model.recurrent_model.recurrent_state_size

512

In [363]:
# Device on which the tensors below are allocated.
fabric.device

device(type='cuda', index=0)

In [364]:
# `data` is an alias of `batch` (same dict, same tensors), not a copy.
data = batch
# Preview the image normalisation: divide by 255 and shift by -0.5.
{k: data[k] / 255.0 - 0.5 for k in cfg.algo.cnn_keys.encoder}

{'rgb': tensor([[[[[ 0.1235, -0.0176, -0.2294,  ..., -0.2294, -0.0176,  0.1235],
            [ 0.3941,  0.0373, -0.5000,  ..., -0.5000,  0.0373,  0.3941],
            [ 0.3941,  0.0373, -0.5000,  ..., -0.5000,  0.0373,  0.3941],
            ...,
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]],
 
           [[-0.1980, -0.2490, -0.3235,  ..., -0.3235, -0.2490, -0.1980],
            [-0.0647, -0.1941, -0.3902,  ..., -0.3902, -0.1941, -0.0647],
            [-0.0647, -0.1941, -0.3902,  ..., -0.3902, -0.1941, -0.0647],
            ...,
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]],
 
           [[-0.1980, -0.1824, -0.1588,  ..., -0.1588, -0.1824, -

In [365]:
# Build the encoder input: CNN keys are divided by 255 and shifted by -0.5,
# MLP keys are passed through unchanged.
batch_obs = {k: data[k] / 255.0 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
batch_obs

{'rgb': tensor([[[[[ 0.1235, -0.0176, -0.2294,  ..., -0.2294, -0.0176,  0.1235],
            [ 0.3941,  0.0373, -0.5000,  ..., -0.5000,  0.0373,  0.3941],
            [ 0.3941,  0.0373, -0.5000,  ..., -0.5000,  0.0373,  0.3941],
            ...,
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]],
 
           [[-0.1980, -0.2490, -0.3235,  ..., -0.3235, -0.2490, -0.1980],
            [-0.0647, -0.1941, -0.3902,  ..., -0.3902, -0.1941, -0.0647],
            [-0.0647, -0.1941, -0.3902,  ..., -0.3902, -0.1941, -0.0647],
            ...,
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
            [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]],
 
           [[-0.1980, -0.1824, -0.1588,  ..., -0.1588, -0.1824, -

In [366]:
# is_first flags at index 0 of the first axis, before they are forced to one.
data["is_first"][0, :]

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')

In [367]:
# Force the is_first flags at index 0 of the first axis to one.
# NOTE: in-place write; because `data` aliases `batch`, `batch` changes too.
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])
data["is_first"][0, :]

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], device='cuda:0')

In [368]:
# Shift actions by one position along the first axis: prepend a zero entry and
# drop the last one, so index i holds the action stored at index i - 1.
batch_actions = torch.cat((torch.zeros_like(data["actions"][:1]), data["actions"][:-1]), dim=0)
batch_actions.shape

torch.Size([64, 16, 9])

In [369]:
# Gather the sizes and device used by the following cells in one place.
batch_size = cfg.algo.per_rank_batch_size
sequence_length = cfg.algo.per_rank_sequence_length
recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size
stochastic_size = cfg.algo.world_model.stochastic_size
discrete_size = cfg.algo.world_model.discrete_size
device = fabric.device
# Encoder input: CNN keys divided by 255 and shifted by -0.5; MLP keys unchanged.
batch_obs = {k: data[k] / 255.0 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
# In-place write: forces is_first to one at index 0 of the first axis.
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])
# Flattened size of the stochastic state.
stoch_state_size = stochastic_size * discrete_size
stoch_state_size

1024

In [370]:
# Initial recurrent state: zeros with a leading axis of size 1.
recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
recurrent_state.shape

torch.Size([1, 16, 512])

In [371]:
# Uninitialised buffer; the unroll loop below fills one recurrent state per index.
recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device)
recurrent_states.shape

torch.Size([64, 16, 512])

In [372]:
# Uninitialised buffer; the unroll loop below fills one set of prior logits per index.
priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
priors_logits.shape

torch.Size([64, 16, 1024])

In [373]:
# Run the world model's encoder on the whole batch of observations in one call.
embedded_obs = world_model.encoder(batch_obs)
embedded_obs.shape

torch.Size([64, 16, 4096])

In [374]:
# Shape of the image input that produced the embedding above.
batch_obs['rgb'].shape

torch.Size([64, 16, 3, 64, 64])

In [375]:
# Initial posterior: zeros with a leading axis of size 1.
posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
posterior.shape

torch.Size([1, 16, 32, 32])

In [376]:
# Uninitialised buffer; the unroll loop below fills one posterior per index.
posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
posteriors.shape

torch.Size([64, 16, 32, 32])

In [377]:
# Uninitialised buffer; the unroll loop below fills one set of posterior logits per index.
posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
posteriors_logits.shape

torch.Size([64, 16, 1024])

In [378]:
# Unroll the RSSM over the sequence. Each iteration feeds the previous posterior
# and recurrent state, together with a length-1 slice of actions, embedded
# observations and is_first flags, into `dynamic`, then records the outputs at
# index i. The third returned value is discarded.
for i in range(sequence_length):
    recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
        posterior,
        recurrent_state,
        batch_actions[i : i + 1],
        embedded_obs[i : i + 1],
        data["is_first"][i : i + 1],
    )
    recurrent_states[i] = recurrent_state
    priors_logits[i] = prior_logits
    posteriors[i] = posterior
    posteriors_logits[i] = posterior_logits

In [379]:
# Latent state = posteriors with their last two axes flattened into one,
# concatenated with the recurrent states along the last axis.
latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)
latent_states.shape

torch.Size([64, 16, 1536])

In [380]:
# Decode the latent states back into observations and print each key's shape.
# NOTE(review): the `Dict` annotation is evaluated at run time in a notebook cell,
# so `Dict` must already be in scope (presumably `from typing import Dict` in an
# earlier cell). The loop variables `k` and `v` stay defined afterwards.
reconstructed_obs: Dict[str, torch.Tensor] = world_model.observation_model(latent_states)
for k, v in reconstructed_obs.items():
    print(k, v.shape)

rgb torch.Size([64, 16, 3, 64, 64])


In [381]:
# Number of axes after the first two.
# NOTE: relies on `k` left over from the last loop that ran; on a fresh kernel this
# cell only works after that loop has been executed.
len(reconstructed_obs[k].shape[2:])

3

In [382]:
# Inspect the observation-model section of the world-model config.
world_model_cfg.observation_model

{'cnn_channels_multiplier': 32,
 'cnn_act': 'torch.nn.SiLU',
 'dense_act': 'torch.nn.SiLU',
 'mlp_layers': 2,
 'cnn_layer_norm': {'cls': 'sheeprl.models.models.LayerNormChannelLast',
  'kw': {'eps': 0.001}},
 'mlp_layer_norm': {'cls': 'sheeprl.models.models.LayerNorm',
  'kw': {'eps': 0.001}},
 'dense_units': 512}

In [383]:
# Repeats the decoding done a few cells earlier: run the observation model on the
# latent states and print each key's shape.
# NOTE(review): `Dict` must already be in scope for this annotation to evaluate.
reconstructed_obs: Dict[str, torch.Tensor] = world_model.observation_model(latent_states)
for k, v in reconstructed_obs.items():
    print(k, v.shape)

rgb torch.Size([64, 16, 3, 64, 64])


In [384]:
from sheeprl.utils.distribution import (
    BernoulliSafeMode,
    MSEDistribution,
    SymlogDistribution,
    TwoHotEncodingDistribution,
)

# Wrap each reconstructed CNN observation in an MSEDistribution whose number of
# event dimensions equals the number of axes after the first two.
po = {}
for k in cfg.algo.cnn_keys.decoder:
    po[k] = MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:]))

# Show the event shape of every distribution.
for k, v in po.items():
    print(k, v._event_shape)

rgb torch.Size([3, 64, 64])


In [385]:
# Shape of the reward model's raw output for the latent states.
world_model.reward_model(latent_states).shape

torch.Size([64, 16, 255])

In [386]:
# Wrap the reward model's output in a TwoHotEncodingDistribution with dims=1.
pr = TwoHotEncodingDistribution(world_model.reward_model(latent_states), dims=1)
# `pr` is a single distribution rather than a dict of them, so its event shape
# is read directly instead of iterating over items.
pr._event_shape

torch.Size([1])

In [387]:
# Only `Independent` is used in this cell; `Distribution` and `OneHotCategorical`
# are imported but not referenced here.
from torch.distributions import Distribution, Independent, OneHotCategorical
# Continue predictor: Bernoulli over the continue model's logits, with the last
# batch axis reinterpreted as an event axis.
pc = Independent(BernoulliSafeMode(logits=world_model.continue_model(latent_states)), 1)
pc._event_shape

torch.Size([1])

In [388]:
# Target for the continue predictor: one minus the terminated flag
# (assumes the flag is 0/1 valued; TODO confirm).
continues_targets = 1 - data["terminated"]

In [389]:
# Split the last axis of the prior logits into (stochastic_size, discrete_size).
# NOTE: rebinds `priors_logits`; re-running this cell on the reshaped tensor
# would produce a different shape.
priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size)
priors_logits.shape

torch.Size([64, 16, 32, 32])

In [390]:
# Split the last axis of the posterior logits into (stochastic_size, discrete_size).
# NOTE: rebinds `posteriors_logits`; not idempotent for the same reason as above.
posteriors_logits = posteriors_logits.view(*posteriors_logits.shape[:-1], stochastic_size, discrete_size)
posteriors_logits.shape

torch.Size([64, 16, 32, 32])

In [391]:
# Clear the world-model optimizer's gradients, setting them to None rather than zero.
world_optimizer.zero_grad(set_to_none=True)

In [392]:
# Inspect the optimizer section of the world-model config.
cfg.algo.world_model.optimizer

{'_target_': 'torch.optim.Adam',
 'lr': 0.0001,
 'eps': 1e-08,
 'weight_decay': 0,
 'betas': [0.9, 0.999]}

In [393]:
# Inspect the world model's configured gradient-clipping value.
cfg.algo.world_model.clip_gradients

1000.0

In [394]:
# Starting stochastic states for imagination: the posteriors, detached from the
# graph, reshaped to (1, -1, stoch_state_size).
imagined_prior = posteriors.detach().reshape(1, -1, stoch_state_size)
imagined_prior.shape

torch.Size([1, 1024, 1024])

In [395]:
# Shape of the posteriors before the reshape above, for comparison.
posteriors.detach().shape

torch.Size([64, 16, 32, 32])

In [396]:
# Starting recurrent states for imagination: all recurrent states, detached and
# reshaped to (1, -1, recurrent_state_size).
# NOTE: rebinds `recurrent_state`, which previously held the unroll loop's last state.
recurrent_state = recurrent_states.detach().reshape(1, -1, recurrent_state_size)
recurrent_state.shape

torch.Size([1, 1024, 512])

In [397]:
# Starting latent state for imagination: stochastic part and recurrent part
# concatenated along the last axis.
imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
imagined_latent_state.shape

torch.Size([1, 1024, 1536])

In [398]:
# Uninitialised buffer for the imagined latent states: horizon + 1 entries, one
# row per (sequence, batch) element, each of latent-state width.
imagined_trajectories = torch.empty(
    (
        cfg.algo.horizon + 1,
        batch_size * sequence_length,
        stoch_state_size + recurrent_state_size,
    ),
    device=device,
)
imagined_trajectories.shape

torch.Size([16, 1024, 1536])

In [399]:
# Uninitialised buffer for the imagined actions; its last axis matches the width
# of the stored actions.
imagined_actions = torch.empty(
    (
        cfg.algo.horizon + 1,
        batch_size * sequence_length,
        data["actions"].shape[-1],
    ),
    device=device,
)
# The first imagined latent state is the starting latent state itself.
imagined_trajectories[0] = imagined_latent_state
imagined_actions.shape

torch.Size([16, 1024, 9])

In [400]:
# Query the actor on the detached starting latent state, take the first element
# of its output and concatenate it along the last axis.
# NOTE: rebinds `actions`, which earlier in the notebook held a NumPy array.
actions = torch.cat(actor(imagined_latent_state.detach())[0], dim=-1)
actions.shape

torch.Size([1, 1024, 9])

In [401]:
# Length of the first axis of the imagination buffers allocated above.
cfg.algo.horizon + 1

16

In [402]:
# Show the raw actor output: a tuple of action tensors and a tuple of distributions.
actor(imagined_latent_state.detach())

((tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
         grad_fn=<AddBackward0>),),
 (OneHotCategoricalStraightThrough(),))