In [1]:
# # 准备环境
# !wget https://raw.githubusercontent.com/lhiqwj173/dl_helper/master/envs/rl.py > /dev/null 2>&1
# !python rl.py not_install_dl_helper > /dev/null 2>&1
# !pip install /kaggle/working/3rd/dl_helper > /dev/null 2>&1

In [2]:
from pprint import pprint
import numpy as np
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
import ray
print("ray 版本:", ray.__version__)

ray 版本: 2.40.0


# 配置

- **PPO**
- DQN
- **Rainbow_DQN**
- Double_DQN
- Dueling_DQN
- DQN_PER
- Noisy_DQN
- DQN_C51
- **IMPALA**
- **APPO**


In [3]:
# 算法
algo = "PPO"
# Algorithm to train: the name of one of the helper classes defined in the next
# cell (PPO, IMPALA, APPO, DQN, Rainbow_DQN, Double_DQN, Dueling_DQN, DQN_PER,
# Noisy_DQN, DQN_C51). The config cell instantiates it via `globals()[algo]()`.
# NOTE: the training-setup cell later rebinds `algo` to the built Algorithm object,
# so re-run this cell before re-running the config cell.
# algo = "DQN"
# algo = "Rainbow_DQN"
# algo = "DQN_C51"
# algo = "IMPALA"
# algo = "APPO"

In [4]:
class algo_base:
    """Holds the keyword arguments passed to ``AlgorithmConfig.training(**...)``.

    Subclasses supply algorithm-specific defaults by overriding
    :meth:`_default_training_kwargs`. Keyword arguments given to the constructor
    override those defaults key by key. Only top-level keys are merged: a nested
    dict such as ``replay_buffer_config`` is replaced as a whole.

    Subclasses also expose an ``algo`` property: the RLlib trainable name that is
    passed to ``get_trainable_cls``.
    """

    def __init__(self, **kwargs):
        # Build a fresh defaults dict per instance, so mutating one instance's
        # kwargs (including nested dicts) never leaks into another instance.
        self._training_kwargs = self._default_training_kwargs()
        self._update_kwargs(kwargs)

    def _default_training_kwargs(self) -> dict:
        """Return a new dict of default training kwargs (none for the base class)."""
        return {}

    @property
    def training_kwargs(self):
        """The merged training kwargs (the live dict, not a copy)."""
        return self._training_kwargs

    def _update_kwargs(self, kwargs):
        """Add or override top-level entries of the training kwargs from ``kwargs``."""
        self._training_kwargs.update(kwargs)


class PPO(algo_base):
    @property
    def algo(self):
        return "PPO"


class IMPALA(algo_base):
    @property
    def algo(self):
        return "IMPALA"


class APPO(algo_base):
    def _default_training_kwargs(self) -> dict:
        return {
            'grad_clip': 30.0,
        }

    @property
    def algo(self):
        return "APPO"


class DQN(algo_base):
    # All DQN variants below share the "DQN" trainable and differ only in their
    # default training kwargs.
    @property
    def algo(self):
        return "DQN"


class Rainbow_DQN(DQN):
    """DQN with prioritized replay, noisy nets, dueling, double-Q, n-step and C51 atoms."""

    def _default_training_kwargs(self) -> dict:
        return {
            "target_network_update_freq": 500,
            'replay_buffer_config': {
                "type": "PrioritizedEpisodeReplayBuffer",
                "capacity": 60000,
                "alpha": 0.5,
                "beta": 0.5,
            },
            # "replay_buffer_config": {
            #     "_enable_replay_buffer_api": False,
            #     "type": "ReplayBuffer",
            #     "type": "PrioritizedReplayBuffer",
            #     "capacity": 50000,
            #     "prioritized_replay_alpha": 0.6,
            #     "prioritized_replay_beta": 0.4,
            #     "prioritized_replay_eps": 1e-6,
            #     "replay_sequence_length": 1
            # },
            "epsilon": [[0, 1.0], [1000000, 0.1]],
            "adam_epsilon": 1e-8,
            "grad_clip": 40.0,
            "num_steps_sampled_before_learning_starts": 10000,
            "tau": 1,
            "num_atoms": 51,
            "v_min": -10.0,
            "v_max": 10.0,
            "noisy": True,
            "sigma0": 0.5,
            "dueling": True,
            "hiddens": [512],
            "double_q": True,
            "n_step": 3,
        }


class Double_DQN(DQN):
    def _default_training_kwargs(self) -> dict:
        return {
            'double_q': True,
        }


class Dueling_DQN(DQN):
    def _default_training_kwargs(self) -> dict:
        return {
            'dueling': True,
        }


class DQN_PER(DQN):
    def _default_training_kwargs(self) -> dict:
        return {
            'replay_buffer_config': {
                "type": "PrioritizedEpisodeReplayBuffer",
                "capacity": 60000,
                "alpha": 0.5,
                "beta": 0.5,
            },
        }


class Noisy_DQN(DQN):
    def _default_training_kwargs(self) -> dict:
        return {
            'noisy': True,
        }


class DQN_C51(DQN):
    def _default_training_kwargs(self) -> dict:
        return {
            'num_atoms': 51,
            'v_min': -10.0,
            'v_max': 10.0,
        }


In [5]:
from ray.tune.registry import get_trainable_cls
from dl_helper.rl.rl_env.breakout_env import BreakoutEnv  # custom environment
from ray.tune.registry import register_env

# Register the custom env under the name used by `.environment("breakout")` below.
# NOTE(review): the original note said registration is optional; that only holds if
# "breakout" is also registered elsewhere (e.g. by dl_helper) — confirm.
register_env("breakout", lambda config: BreakoutEnv())

# Instantiate the helper class selected by the `algo` name string.
# Validate the lookup instead of blindly indexing globals(): otherwise any global
# name would be accepted, and after the training-setup cell rebinds `algo` to a
# built Algorithm, re-running this cell would fail with an opaque error.
_selected_algo_cls = globals().get(algo) if isinstance(algo, str) else None
if not (isinstance(_selected_algo_cls, type) and issubclass(_selected_algo_cls, algo_base)):
    raise ValueError(
        "`algo` must name an algo_base subclass (e.g. 'PPO', 'DQN', 'Rainbow_DQN'), "
        f"got {algo!r}; re-run the algorithm selection cell."
    )
simple_algo = _selected_algo_cls()

config = (
    get_trainable_cls(simple_algo.algo)
    .get_default_config()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("breakout")
    # .environment("CartPole-v1")
    .env_runners(num_env_runners=0)
    .evaluation(
        evaluation_interval=10,
        evaluation_duration=3,
    )
    .rl_module(
        model_config={
            "conv_filters": [
                [32, [8, 8], 4],  # [out_channels, [kernel_h, kernel_w], stride]
                [64, [4, 4], 2],  # 64 channels, 4x4 kernel, stride 2
                [64, [3, 3], 1],  # 64 channels, 3x3 kernel, stride 1
            ],
        },
    )
    .training(**simple_algo.training_kwargs)
)

  gym.logger.warn(
  gym.logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


# 训练

In [6]:
import os
from dl_helper.rl.rl_utils import simplify_rllib_metrics

# Root directory for checkpoints. Set the RLLIB_CHECKPOINT_ROOT env var to avoid
# editing this hard-coded local default.
CHECKPOINT_ROOT = os.environ.get("RLLIB_CHECKPOINT_ROOT", "C:/Users/lh/Desktop/temp")

# Take the run name from the helper instance rather than from `algo`: this cell
# rebinds `algo` to the built Algorithm below, so re-running the cell would
# otherwise embed the Algorithm's repr in the directory name.
# `type(simple_algo).__name__` is exactly the name string the config cell looked up.
algo_name = type(simple_algo).__name__
# NOTE(review): the "_cartpole" suffix is kept for path compatibility, although the
# environment configured in the config cell is "breakout".
checkpoint_base_dir = f"{CHECKPOINT_ROOT}/RLlib_{algo_name}_cartpole"
os.makedirs(checkpoint_base_dir, exist_ok=True)

# Build the algorithm (rebinds `algo`; the training cell uses this object).
algo = config.build()

# Display the built RLModule (the model).
algo.get_module()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [None]:
# Training loop: train for `rounds` iterations, checkpoint periodically, save a
# final checkpoint, and always stop the algorithm afterwards.
rounds = 10 * 3
CHECKPOINT_EVERY = 5  # save a checkpoint every N iterations
checkpoint_root = os.path.abspath(checkpoint_base_dir)

try:
    for i in range(rounds):
        print(f"\nTraining iteration {i+1}/{rounds}")
        result = algo.train()
        simplify_rllib_metrics(result)

        if (i + 1) % CHECKPOINT_EVERY == 0:
            checkpoint_dir = algo.save_to_path(
                os.path.join(checkpoint_root, f"checkpoint_{i+1}")
            )
            print(f"Checkpoint saved in directory {checkpoint_dir}")

    # Save the final model.
    final_checkpoint = algo.save_to_path(os.path.join(checkpoint_root, "final"))
    print(f"Final checkpoint saved in directory {final_checkpoint}")
finally:
    # Clean up even if training raises or the cell is interrupted. Previously an
    # exception skipped `algo.stop()` and left the algorithm's resources allocated.
    algo.stop()

In [7]:
import numpy as np

# Report the NumPy version active in this kernel.
print("NumPy 版本: " + np.__version__)

NumPy 版本: 1.26.4


# DQN 默认配置


In [None]:
# 基类
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

# API 类
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, TargetNetworkAPI

In [None]:
# DQN 模型
from ray.rllib.algorithms.dqn.dqn_rainbow_catalog import DQNRainbowCatalog
"""
DQN Rainbow 目录（Catalog）类，用于构建模型。

`DQNRainbowCatalog` 提供以下模型：
    - 编码器（Encoder）：用于对观察结果进行编码。
    - 目标编码器（Target_Encoder）：用于对目标网络的观察结果进行编码。
    - Af Head：
        如果使用 dueling 架构，则是优势流的头部；如果使用 Q 函数，则是 Q 函数的头部。
        这是一个多节点头部：在期望值学习的情况下有 `action_space.n` 个节点；在分布型（distributional）Q 学习的情况下有 `action_space.n` × 支撑原子数（`num_atoms`）个节点。
    - Vf Head（可选）：
        在选择 dueling 架构时，值函数的头部。这是一个单一节点头部。如果不使用 dueling 架构，则不存在此头部。

所有网络都可以包括嘈杂层（noisy layers），如果 `noisy` 为 `True`。在这种情况下，不使用 epsilon 贪心探索。

可以通过重写 `build_af_head()` 和 `build_vf_head()` 来构建任何自定义头部。另外，可以重写 `AfHeadConfig` 或 `VfHeadConfig` 来在 `RLModule` 运行时构建自定义逻辑。

所有头部都可以选择使用分布型（distributional）学习。在这种情况下，输出神经元的数量等于离散分布的支撑原子数乘以动作数。

任何为探索或推理构建的模块都使用标志 `inference_only=True`，且不包含任何目标网络。可以通过 `SingleAgentModuleSpec` 中的 `inference_only` 布尔标志来设置此标志。
"""

from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_rl_module import (
    DQNRainbowTorchRLModule,
)


In [None]:
# Reference snapshot of a DQN AlgorithmConfig, kept as a plain dict for reading;
# no cell in this notebook reads `dqn_config`.
# NOTE(review): presumably dumped with something like `DQNConfig().to_dict()` —
# confirm. It is not pure defaults: "env" is "CartPole-v1", and
# evaluation_interval=10 / evaluation_duration=3 match the `.evaluation(...)` call
# in the config cell.
dqn_config = {
    "exploration_config": {},
    "algo_class": "ray.rllib.algorithms.dqn.dqn.DQN",
    "extra_python_environs_for_driver": {},
    "extra_python_environs_for_worker": {},
    "placement_strategy": "PACK",
    "num_gpus": 0,
    "_fake_gpus": False,
    "num_cpus_for_main_process": 1,
    "framework_str": "torch",
    "eager_tracing": True,
    "eager_max_retraces": 20,
    "tf_session_args": {
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {"allow_growth": True},
        "log_device_placement": False,
        "device_count": {"CPU": 1},
        "allow_soft_placement": True
    },
    "local_tf_session_args": {
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8
    },
    "torch_compile_learner": False,
    "torch_compile_learner_what_to_compile": "TorchCompileWhatToCompile.FORWARD_TRAIN",
    "torch_compile_learner_dynamo_backend": "inductor",
    "torch_compile_learner_dynamo_mode": None,
    "torch_compile_worker": False,
    "torch_compile_worker_dynamo_backend": "onnxrt",
    "torch_compile_worker_dynamo_mode": None,
    "torch_ddp_kwargs": {},
    "torch_skip_nan_gradients": False,
    "env": "CartPole-v1",
    "env_config": {},
    "observation_space": None,
    "action_space": None,
    "clip_rewards": None,
    "normalize_actions": True,
    "clip_actions": False,
    "_is_atari": False,
    "disable_env_checking": False,
    "env_task_fn": None,
    "render_env": False,
    "action_mask_key": "action_mask",
    "env_runner_cls": None,
    "num_env_runners": 0,
    "num_envs_per_env_runner": 1,
    "num_cpus_per_env_runner": 1,
    "num_gpus_per_env_runner": 0,
    "custom_resources_per_env_runner": {},
    "validate_env_runners_after_construction": True,
    "max_requests_in_flight_per_env_runner": 1,
    "sample_timeout_s": 60.0,
    "create_env_on_local_worker": False,
    "_env_to_module_connector": None,
    "add_default_connectors_to_env_to_module_pipeline": True,
    "_module_to_env_connector": None,
    "add_default_connectors_to_module_to_env_pipeline": True,
    "episode_lookback_horizon": 1,
    "rollout_fragment_length": "auto",
    "batch_mode": "truncate_episodes",
    "compress_observations": False,
    "remote_worker_envs": False,
    "remote_env_batch_wait_ms": 0,
    "enable_tf1_exec_eagerly": False,
    "sample_collector": "ray.rllib.evaluation.collectors.simple_list_collector.SimpleListCollector",
    "preprocessor_pref": "deepmind",
    "observation_filter": "NoFilter",
    "update_worker_filter_stats": True,
    "use_worker_filter_stats": True,
    "sampler_perf_stats_ema_coef": None,
    "num_learners": 0,
    "num_gpus_per_learner": 0,
    "num_cpus_per_learner": 1,
    "local_gpu_idx": 0,
    "max_requests_in_flight_per_learner": 3,
    "gamma": 0.99,
    "lr": 0.0005,
    "grad_clip": 40.0,
    "grad_clip_by": "global_norm",
    "train_batch_size_per_learner": None,
    "train_batch_size": 32,
    "num_epochs": 1,
    "minibatch_size": None,
    "shuffle_batch_per_epoch": False,
    # Old-API-stack `model` dict; the new API stack in this notebook configures the
    # model through `.rl_module(model_config=...)` instead.
    "model": {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "tanh",
        "fcnet_weights_initializer": None,
        "fcnet_weights_initializer_config": None,
        "fcnet_bias_initializer": None,
        "fcnet_bias_initializer_config": None,
        "conv_filters": None,
        "conv_activation": "relu",
        "conv_kernel_initializer": None,
        "conv_kernel_initializer_config": None,
        "conv_bias_initializer": None,
        "conv_bias_initializer_config": None,
        "conv_transpose_kernel_initializer": None,
        "conv_transpose_kernel_initializer_config": None,
        "conv_transpose_bias_initializer": None,
        "conv_transpose_bias_initializer_config": None,
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": "relu",
        "post_fcnet_weights_initializer": None,
        "post_fcnet_weights_initializer_config": None,
        "post_fcnet_bias_initializer": None,
        "post_fcnet_bias_initializer_config": None,
        "free_log_std": False,
        "log_std_clip_param": 20.0,
        "no_final_linear": False,
        "vf_share_layers": True,
        "use_lstm": False,
        "max_seq_len": 20,
        "lstm_cell_size": 256,
        "lstm_use_prev_action": False,
        "lstm_use_prev_reward": False,
        "lstm_weights_initializer": None,
        "lstm_weights_initializer_config": None,
        "lstm_bias_initializer": None,
        "lstm_bias_initializer_config": None,
        "_time_major": False,
        "use_attention": False,
        "attention_num_transformer_units": 1,
        "attention_dim": 64,
        "attention_num_heads": 1,
        "attention_head_dim": 32,
        "attention_memory_inference": 50,
        "attention_memory_training": 50,
        "attention_position_wise_mlp_dim": 32,
        "attention_init_gru_gate_bias": 2.0,
        "attention_use_n_prev_actions": 0,
        "attention_use_n_prev_rewards": 0,
        "framestack": True,
        "dim": 84,
        "grayscale": False,
        "zero_mean": True,
        "custom_model": None,
        "custom_model_config": {},
        "custom_action_dist": None,
        "custom_preprocessor": None,
        "encoder_latent_dim": None,
        "always_check_shapes": False,
        "lstm_use_prev_action_reward": -1,
        "_use_default_native_models": -1,
        "_disable_preprocessor_api": False,
        "_disable_action_flattening": False
    },
    "_learner_connector": None,
    "add_default_connectors_to_learner_pipeline": True,
    "learner_config_dict": {},
    "optimizer": {},
    "_learner_class": None,
    "callbacks_class": "ray.rllib.algorithms.callbacks.DefaultCallbacks",
    "explore": True,
    "enable_rl_module_and_learner": True,
    "enable_env_runner_and_connector_v2": True,
    "_prior_exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000
    },
    "count_steps_by": "env_steps",
    "policies": {"default_policy": [None, None, None, None]},
    "policy_map_capacity": 100,
    "policy_mapping_fn": "AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN",
    "policies_to_train": None,
    "policy_states_are_swappable": False,
    "observation_fn": None,
    "input_": "sampler",
    "input_read_method": "read_parquet",
    "input_read_method_kwargs": {},
    "input_read_schema": {},
    "input_read_episodes": False,
    "input_read_sample_batches": False,
    "input_read_batch_size": None,
    "input_filesystem": None,
    "input_filesystem_kwargs": {},
    "input_compress_columns": ["obs", "new_obs"],
    "input_spaces_jsonable": True,
    "materialize_data": False,
    "materialize_mapped_data": True,
    "map_batches_kwargs": {},
    "iter_batches_kwargs": {},
    "prelearner_class": None,
    "prelearner_buffer_class": None,
    "prelearner_buffer_kwargs": {},
    "prelearner_module_synch_period": 10,
    "dataset_num_iters_per_learner": None,
    "input_config": {},
    "actions_in_input_normalized": False,
    "postprocess_inputs": False,
    "shuffle_buffer_size": 0,
    "output": None,
    "output_config": {},
    "output_compress_columns": ["obs", "new_obs"],
    "output_max_file_size": 67108864,
    "output_max_rows_per_file": None,
    "output_write_method": "write_parquet",
    "output_write_method_kwargs": {},
    "output_filesystem": None,
    "output_filesystem_kwargs": {},
    "output_write_episodes": True,
    "offline_sampling": False,
    "evaluation_interval": 10,
    "evaluation_duration": 3,
    "evaluation_duration_unit": "episodes",
    "evaluation_sample_timeout_s": 120.0,
    "evaluation_parallel_to_training": False,
    "evaluation_force_reset_envs_before_iteration": True,
    "evaluation_config": {"explore": False},
    "off_policy_estimation_methods": {},
    "ope_split_batch_by_episode": True,
    "evaluation_num_env_runners": 0,
    "custom_evaluation_function": None,
    "in_evaluation": False,
    "sync_filters_on_rollout_workers_timeout_s": 10.0,
    "keep_per_episode_custom_metrics": False,
    "metrics_episode_collection_timeout_s": 60.0,
    "metrics_num_episodes_for_smoothing": 100,
    "min_time_s_per_iteration": None,
    "min_train_timesteps_per_iteration": 0,
    "min_sample_timesteps_per_iteration": 1000,
    "log_gradients": True,
    "export_native_model_files": False,
    "checkpoint_trainable_policies_only": False,
    "logger_creator": None,
    "logger_config": None,
    "log_level": "WARN",
    "log_sys_usage": True,
    "fake_sampler": False,
    "seed": None,
    "_run_training_always_in_thread": False,
    "_evaluation_parallel_to_training_wo_thread": False,
    "restart_failed_env_runners": True,
    "ignore_env_runner_failures": False,
    "max_num_env_runner_restarts": 1000,
    "delay_between_env_runner_restarts_s": 60.0,
    "restart_failed_sub_environments": False,
    "num_consecutive_env_runner_failures_tolerance": 100,
    "env_runner_health_probe_timeout_s": 30.0,
    "env_runner_restore_timeout_s": 1800.0,
    "_model_config": {},
    "_rl_module_spec": None,
    "algorithm_config_overrides_per_module": {},
    "_per_module_overrides": {},
    "_torch_grad_scaler_class": None,
    "_torch_lr_scheduler_classes": None,
    "_tf_policy_handles_more_than_one_loss": False,
    "_disable_preprocessor_api": False,
    "_disable_action_flattening": False,
    "_disable_initialize_loss_from_dummy_batch": False,
    "_dont_auto_sync_env_runner_states": False,
    "_is_frozen": False,
    # Keys with value -1 below appear to be placeholders for deprecated settings —
    # NOTE(review): confirm against RLlib's AlgorithmConfig.
    "enable_connectors": -1,
    "simple_optimizer": -1,
    "monitor": -1,
    "evaluation_num_episodes": -1,
    "metrics_smoothing_episodes": -1,
    "timesteps_per_iteration": -1,
    "min_iter_time_s": -1,
    "collect_metrics_timeout": -1,
    "min_time_s_per_reporting": -1,
    "min_train_timesteps_per_reporting": -1,
    "min_sample_timesteps_per_reporting": -1,
    "input_evaluation": -1,
    "policy_map_cache": -1,
    "worker_cls": -1,
    "synchronize_filters": -1,
    "enable_async_evaluation": -1,
    "custom_async_evaluation_function": -1,
    "_enable_rl_module_api": -1,
    "auto_wrap_old_gym_envs": -1,
    "always_attach_evaluation_results": -1,
    "buffer_size": -1,
    "prioritized_replay": -1,
    "learning_starts": -1,
    "replay_batch_size": -1,
    "replay_sequence_length": None,
    "replay_mode": -1,
    "prioritized_replay_alpha": -1,
    "prioritized_replay_beta": -1,
    "prioritized_replay_eps": -1,
    "_disable_execution_plan_api": -1,
    # DQN-specific training settings; compare with the overrides in the Rainbow_DQN
    # helper class (e.g. n_step, num_atoms, noisy, replay_buffer_config).
    "epsilon": [(0, 1.0), (10000, 0.05)],
    "target_network_update_freq": 500,
    "num_steps_sampled_before_learning_starts": 1000,
    "store_buffer_in_checkpoints": False,
    "adam_epsilon": 1e-08,
    "tau": 1.0,
    "num_atoms": 1,
    "v_min": -10.0,
    "v_max": 10.0,
    "noisy": False,
    "sigma0": 0.5,
    "dueling": True,
    "hiddens": [256],
    "double_q": True,
    "n_step": 1,
    "before_learn_on_batch": None,
    "training_intensity": None,
    "td_error_loss_fn": "huber",
    "categorical_distribution_temperature": 1.0,
    "replay_buffer_config": {
        "type": "PrioritizedEpisodeReplayBuffer",
        "capacity": 50000,
        "alpha": 0.6,
        "beta": 0.4
    },
    "lr_schedule": None
}

# PPO 默认配置（IMPALA/APPO）


In [15]:
# 基类
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule

# API 类
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI

In [29]:
# 默认的编码器
# Default encoders: build RLlib's CNN encoder on its own and run one forward pass.
import torch
from ray.rllib.core.models.configs import (
    CNNEncoderConfig,
    MLPEncoderConfig,
    RecurrentEncoderConfig,
)

# config = MLPEncoderConfig(
#     input_dims=[2],
#     hidden_layer_dims=[8, 8],
#     hidden_layer_activation="silu",
#     hidden_layer_use_layernorm=True,
#     hidden_layer_use_bias=False,
#     output_layer_dim=4,
#     output_layer_activation="tanh",
#     output_layer_use_bias=False,
# )
# model = config.build(framework="torch")
# print(model)

# Use dedicated names: rebinding the global `config` here would replace the
# AlgorithmConfig that the training-setup cell builds with `config.build()`.
cnn_encoder_config = CNNEncoderConfig(
    input_dims=[84, 84, 3],  # must be 3D tensor (image: w x h x C)
    cnn_filter_specifiers=[
        [16, [8, 8], 4],  # [out_channels, [kernel_h, kernel_w], stride]
        [32, [4, 4], 2],
    ],
    cnn_activation="relu",
    cnn_use_layernorm=False,
    cnn_use_bias=True,
)
cnn_encoder = cnn_encoder_config.build(framework="torch")
print(cnn_encoder)

# A batch of random channels-last inputs matching `input_dims`
# (the encoder's first Conv2d takes 3 input channels).
torch.manual_seed(0)  # reproducible demo output
batch_size = 32
input_tensor = torch.randn(batch_size, 84, 84, 3)
print(input_tensor.shape)

# Forward pass: the encoder takes a dict with the observation under 'obs' and
# returns a dict holding 'encoder_out'.
output = cnn_encoder({'obs': input_tensor})
print(output)


TorchCNNEncoder(
  (net): Sequential(
    (0): TorchCNN(
      (cnn): Sequential(
        (0): ZeroPad2d((2, 2, 2, 2))
        (1): Conv2d(3, 16, kernel_size=(8, 8), stride=(4, 4))
        (2): ReLU()
        (3): ZeroPad2d((1, 2, 1, 2))
        (4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
        (5): ReLU()
      )
    )
    (1): Flatten(start_dim=1, end_dim=-1)
  )
)
torch.Size([32, 84, 84, 3])
{'encoder_out': tensor([[0.0000, 0.0233, 0.0000,  ..., 0.0138, 0.0000, 0.0348],
        [0.0000, 0.0264, 0.0046,  ..., 0.0392, 0.0000, 0.3358],
        [0.0000, 0.0000, 0.0159,  ..., 0.2468, 0.0000, 0.1673],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.1377, 0.0000, 0.2146],
        [0.0000, 0.0000, 0.0000,  ..., 0.1896, 0.0000, 0.0087],
        [0.0000, 0.0000, 0.1356,  ..., 0.1044, 0.0000, 0.0186]],
       grad_fn=<ViewBackward0>)}


In [6]:
# 自定义 ModelConfig
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.torch.encoder import TorchModel, Encoder
from dataclasses import dataclass
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.columns import Columns
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvFCNet(TorchModel, Encoder):
    """Custom encoder: 3x3 conv (16 channels) -> ReLU -> flatten -> linear.

    Reads ``config.input_dims`` as ``[C, H, W]`` (channels-first, as consumed by
    ``nn.Conv2d``) and ``config.out_dim`` as the size of the output vector.
    """

    def __init__(self, config) -> None:
        TorchModel.__init__(self, config)
        Encoder.__init__(self, config)

        in_channels, height, width = config.input_dims
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        # kernel 3 / stride 1 / padding 1 keeps H and W unchanged, so the flattened
        # conv output has height * width * 16 features (64 * 64 * 16 for the demo).
        self.fc = nn.Linear(height * width * 16, config.out_dim)

    def _forward(self, inputs: dict, **kwargs) -> dict:
        x = F.relu(self.conv1(inputs[Columns.OBS]))
        x = torch.flatten(x, start_dim=1)  # flatten everything but the batch dim
        return {ENCODER_OUT: self.fc(x)}


@dataclass
class test_EncoderConfig(ModelConfig):
    """ModelConfig that builds :class:`ConvFCNet`.

    ``out_dim`` is an annotated dataclass field. As a bare class attribute it was
    not part of ``__init__``, so in ``test_EncoderConfig([3, 64, 64], 10)`` the
    ``10`` was bound to the next inherited ModelConfig field (presumably
    ``always_check_shapes`` — confirm). Pass the field by keyword.
    """

    out_dim: int = 10

    def build(self, framework: str = "torch"):
        if framework != "torch":
            raise ValueError('only torch ModelConfig')
        # One conv layer + one fully connected layer.
        return ConvFCNet(self)

    @property
    def output_dims(self):
        """Read-only `output_dims` are inferred automatically from other settings.

        Returned as a 1-tuple ``(out_dim,)``. A Catalog copies this into
        ``latent_dims`` to size its heads, and head configs take dims as a
        sequence, not a bare int.
        """
        return (self.out_dim,)


net = test_EncoderConfig(input_dims=[3, 64, 64], out_dim=10).build()
print(net)

batch_size = 64
input_tensor = torch.randn(batch_size, 3, 64, 64)  # channels-first (N, C, H, W)
print(input_tensor.shape)

# Forward pass.
# NOTE(review): the saved output below shows conv2d receiving a dict; re-run to
# confirm whether that still occurs with the corrected config.
output = net({Columns.OBS: input_tensor})
print(output)

ConvFCNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=65536, out_features=10, bias=True)
)
torch.Size([64, 3, 64, 64])


TypeError: conv2d() received an invalid combination of arguments - got (dict, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!dict!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!dict!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)


In [11]:
# 用于自定义 PPO 模型
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

"""
1. 自定义编码器a 
    - 继承 PPOCatalog 类
        - 重写 Catalog.build_encoder方法
        - 应该根据自定义的 encoder 模型 覆盖属性 `Catalog.latent_dims`，以便可以使用该信息来构建头部。
    - 继承重写算法配置类(PPOConfig) 
        @override(AlgorithmConfig)
        def get_default_rl_module_spec(self) -> RLModuleSpec:
            from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

            if self.framework_str == "torch":
                from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
                    PPOTorchRLModule,
                )

                return RLModuleSpec(module_class=PPOTorchRLModule, catalog_class=PPOCatalog)  # 改为 (replace with): return RLModuleSpec(module_class=PPOTorchRLModule, catalog_class=Custom_Catalog)
            elif self.framework_str == "tf2":
                from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule

                return RLModuleSpec(module_class=PPOTfRLModule, catalog_class=PPOCatalog)
            else:
                raise ValueError(
                    f"The framework {self.framework_str} is not supported. "
                    "Use either 'torch' or 'tf2'."
                ) 

2. 自定义编码器b
    - 继承 ModelConfig，自定义 ModelConfig
        from ray.rllib.core.models.configs import ModelConfig

3. 调整编码器
    CNNEncoderConfig
    MLPEncoderConfig
    RecurrentEncoderConfig(循环神经网络)

    通过配置调整
    config.rl_module(
        model_config={
            # MLPEncoderConfig
            "fcnet_hiddens": [5, 3, 3],
            "fcnet_kernel_initializer": None,
            "fcnet_kernel_initializer_kwargs": {},
            "fcnet_bias_initializer": None,
            "fcnet_bias_initializer_kwargs": {},

            # CNNEncoderConfig
            "conv_filters": [
                [32, [8, 8], 4],  # [输出通道数, [kernel_size_h, kernel_size_w], stride]
                [64, [4, 4], 2],  # [64个通道, 4x4卷积核, stride=2]
                [64, [3, 3], 1],  # [64个通道, 3x3卷积核, stride=1] 
            ],

            # ...
        },
    )

            "adam_epsilon": 1e-8,
            "grad_clip": 40.0,
            "num_steps_sampled_before_learning_starts": 10000,
            "tau": 1,
            "num_atoms": 51,
            "v_min": -10.0,
            "v_max": 10.0,
            "noisy": True,
            "sigma0": 0.5,

    RLlib's native RLModules get their Models from a Catalog object.
    By default, that Catalog builds the configs it has as attributes.
    This component was build to be hackable and extensible. You can inject custom
    components into RL Modules by overriding the `build_xxx` methods of this class.
    Note that it is recommended to write a custom RL Module for a single use-case.
    Modifications to Catalogs mostly make sense if you want to reuse the same
    Catalog for different RL Modules. For example if you have written a custom
    encoder and want to inject it into different RL Modules (e.g. for PPO, DQN, etc.).
    You can influence the decision tree that determines the sub-components by modifying
    `Catalog._determine_components_hook`.

    Usage example:

    # Define a custom catalog

    .. testcode::

        import torch
        import gymnasium as gym
        from ray.rllib.core.models.configs import MLPHeadConfig
        from ray.rllib.core.models.catalog import Catalog

        class MyCatalog(Catalog):
            def __init__(
                self,
                observation_space: gym.Space,
                action_space: gym.Space,
                model_config_dict: dict,
            ):
                super().__init__(observation_space, action_space, model_config_dict)
                self.my_model_config = MLPHeadConfig(
                    hidden_layer_dims=[64, 32],
                    input_dims=[self.observation_space.shape[0]],
                )

            def build_my_head(self, framework: str):
                return self.my_model_config.build(framework=framework)

        # With that, RLlib can build and use models from this catalog like this:
        catalog = MyCatalog(gym.spaces.Box(0, 1), gym.spaces.Box(0, 1), {})
        my_head = catalog.build_my_head(framework="torch")

        # Make a call to the built model.
        out = my_head(torch.Tensor([[1]]))
    """
    """描述用于 RL 模块的子模块架构。

    RLlib 的原生 RL 模块从 Catalog 对象获取其模型。
    默认情况下，该 Catalog 会构建其作为属性拥有的配置。
    此组件被构建为可hack和可扩展的。您可以通过重写此类的 `build_xxx` 方法，
    将自定义组件注入到 RL 模块中。
    请注意，建议为单个用例编写自定义 RL 模块。
    对 Catalog 的修改主要在您想要为不同的 RL 模块重用相同的 Catalog 时才有意义。
    例如，如果您编写了一个自定义编码器并希望将其注入到不同的 RL 模块
    （例如，PPO、DQN 等）。您可以通过修改
    `Catalog._determine_components_hook` 来影响决定子组件的决策树。

    使用示例：

    # 定义一个自定义的 catalog

    .. testcode::

        import torch
        import gymnasium as gym
        from ray.rllib.core.models.configs import MLPHeadConfig
        from ray.rllib.core.models.catalog import Catalog

        class MyCatalog(Catalog):
            def __init__(
                self,
                observation_space: gym.Space,
                action_space: gym.Space,
                model_config_dict: dict,
            ):
                super().__init__(observation_space, action_space, model_config_dict)
                self.my_model_config = MLPHeadConfig(
                    hidden_layer_dims=[64, 32],
                    input_dims=[self.observation_space.shape[0]],
                )

            def build_my_head(self, framework: str):
                return self.my_model_config.build(framework=framework)

        # 有了这个，RLlib 可以像这样从这个 catalog 构建和使用模型：
        catalog = MyCatalog(gym.spaces.Box(0, 1), gym.spaces.Box(0, 1), {})
        my_head = catalog.build_my_head(framework="torch")

        # 对构建的模型进行调用。
        out = my_head(torch.Tensor([[1]]))
    """

    # TODO (Sven): Add `framework` arg to c'tor and remove this arg from `build`
    #  methods. This way, we can already know in the c'tor of Catalog, what the exact
    #  action distibution objects are and thus what the output dims for e.g. a pi-head
    #  will be.
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
        # deprecated args.
        view_requirements=DEPRECATED_VALUE,
    ):
        """Initializes a Catalog with a default encoder config.

        NOTE(review): this body appears to be pasted from RLlib's ``Catalog`` for
        reference. The names it uses (``DEPRECATED_VALUE``, ``deprecation_warning``,
        ``dataclasses``, ``DefaultModelConfig``) are not imported in this cell.

        Args:
            observation_space: The observation space of the environment.
            action_space: The action space of the environment.
            model_config_dict: The model config that specifies things like hidden
                dimensions and activations functions to use in this Catalog. A
                dataclass instance is also accepted and converted to a dict.
            view_requirements: Deprecated. Passing anything other than the
                default sentinel calls ``deprecation_warning(..., error=True)``.
        """
        if view_requirements != DEPRECATED_VALUE:
            deprecation_warning(old="Catalog(view_requirements=..)", error=True)

        # TODO (sven): The following logic won't be needed anymore, once we get rid of
        #  Catalogs entirely. We will assert directly inside the algo's DefaultRLModule
        #  class that the `model_config` is a DefaultModelConfig. Thus users won't be
        #  able to pass in partial config dicts into a default model (alternatively, we
        #  could automatically augment the user provided dict by the default config
        #  dataclass object only(!) for default modules).
        if dataclasses.is_dataclass(model_config_dict):
            model_config_dict = dataclasses.asdict(model_config_dict)
        default_config = dataclasses.asdict(DefaultModelConfig())
        # end: TODO

        self.observation_space = observation_space
        self.action_space = action_space

        # Dict union (Python 3.9+): keys from `model_config_dict` override the defaults.
        self._model_config_dict = default_config | model_config_dict
        self._latent_dims = None

        # Let the (possibly overridden) hook decide which components to build.
        self._determine_components_hook()

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def _determine_components_hook(self):
        """Decision tree hook for subclasses to override.

        By default, this method executes the decision tree that determines the
        components that a Catalog builds. You can extend the components by overriding
        this or by adding to the constructor of your subclass.

        Override this method if you don't want to use the default components
        determined here. If you want to use them but add additional components, you
        should call `super()._determine_components_hook()` at the beginning of your
        implementation.

        This makes it so that subclasses are not forced to create an encoder config
        if the rest of their catalog is not dependent on it or if it breaks.
        At the end of this method, an attribute `Catalog.latent_dims`
        should be set so that heads can be built using that information.

        The default implementation sets ``self._encoder_config``,
        ``self._action_dist_class_fn`` and ``self.latent_dims``.
        """
        """子类可以重写的决策树钩子。

        默认情况下，此方法执行决策树，该决策树确定 Catalog 构建的组件。
        您可以通过重写此方法或在子类的构造函数中添加内容来扩展组件。

        如果您不想使用此处确定的默认组件，请重写此方法。
        如果您想使用它们但添加额外的组件，您应该在实现开始时调用
        `super()._determine_components()`。

        这使得子类不必创建编码器配置，如果它们 Catalog 的其余部分不依赖于它，或者它会破坏逻辑。
        在此方法的最后，应该设置属性 `Catalog.latent_dims`，以便可以使用该信息来构建头部。
        """
        # NOTE: the method to call via super() is `_determine_components_hook`.

        # Encoder config derived from the spaces and the merged model config dict.
        self._encoder_config = self._get_encoder_config(
            observation_space=self.observation_space,
            action_space=self.action_space,
            model_config_dict=self._model_config_dict,
        )

        # Create a function that can be called when framework is known to retrieve the
        # class type for action distributions
        self._action_dist_class_fn = functools.partial(
            self._get_dist_cls_from_action_space, action_space=self.action_space
        )

        # The dimensions of the latent vector that is output by the encoder and fed
        # to the heads. For a custom encoder config, this is whatever its
        # `output_dims` property returns.
        self.latent_dims = self._encoder_config.output_dims

    @property
    def latent_dims(self):
        """The latent dimensions produced by the encoder.

        This value is the contract between encoder and heads: encoders can be
        built to output latent tensors with `latent_dims` dimensions, and heads
        can be built to take tensors of `latent_dims` dimensions as inputs.
        Catalog modifications that don't need this contract may ignore it.

        Returns:
            The latent dimensions of the encoder (whatever was last assigned
            through the setter; the Catalog constructor initializes it to None).
        """
        dims = self._latent_dims
        return dims

    @latent_dims.setter
    def latent_dims(self, value):
        """Stores the latent dimensions that heads should be built with."""
        self._latent_dims = value

    @OverrideToImplementCustomLogic
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        By default, this method builds an encoder instance from
        `Catalog._encoder_config`.

        You should override this if you want to use RLlib's default RL Modules but
        only want to change the encoder. For example, if you want to use a custom
        encoder, but want to use RLlib's default heads, action distribution and how
        tensors are routed between them. If you want to have full control over the
        RL Module, we recommend writing your own RL Module by inheriting from one of
        RLlib's RL Modules instead.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The encoder.
        """
        # `_encoder_config` is set by the default `_determine_components_hook()`;
        # subclasses that override that hook must set it themselves.
        assert hasattr(self, "_encoder_config"), (
            "You must define a `Catalog._encoder_config` attribute in your Catalog "
            "subclass or override the `Catalog.build_encoder` method. By default, "
            "an encoder_config is created in the `_determine_components_hook` "
            "method."
        )
        return self._encoder_config.build(framework=framework)

    @OverrideToImplementCustomLogic
    def get_action_dist_cls(self, framework: str):
        """Get the action distribution class.

        The default behavior is to get the action distribution class from
        `Catalog._action_dist_class_fn`.

        You should override this to have RLlib build your custom action
        distribution instead of the default one. For example, if you don't want to
        use RLlib's default RLModules with their default models, but only want to
        change the distribution that Catalog returns.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The action distribution class (a class, not an instance).
        """
        # `_action_dist_class_fn` is set by the default
        # `_determine_components_hook()`; subclasses that override that hook must
        # set it themselves.
        assert hasattr(self, "_action_dist_class_fn"), (
            "You must define a `Catalog._action_dist_class_fn` attribute in your "
            "Catalog subclass or override the `Catalog.get_action_dist_cls` method. "
            "By default, an action_dist_class_fn is created in the "
            "`_determine_components_hook` method."
        )
        return self._action_dist_class_fn(framework=framework)

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        model_config_dict: dict,
        action_space: gym.Space = None,
    ) -> ModelConfig:
        """Returns an EncoderConfig for the given input_space and model_config_dict.

        Encoders are usually used in RLModules to transform the input space into a
        latent space that is then fed to the heads. The returned EncoderConfig
        objects correspond to the built-in Encoder classes in RLlib. You can
        overwrite this method to produce specific EncoderConfigs for your custom
        Models.

        Selection rules:
        - `use_lstm=True`: RecurrentEncoderConfig (a single LSTM layer) whose
          tokenizer config comes from `get_tokenizer_config()`.
        - 1D-Box: MLPEncoderConfig
        - 3D-Box: CNNEncoderConfig. If `conv_filters` is missing or empty, a
          default filter stack from `get_filter_config()` is first written into
          `model_config_dict`.
        - 2D-Box and every other space: ValueError.
        TODO (Artur): Support more spaces here.

        Args:
            observation_space: The observation space to use.
            model_config_dict: The model config to use. May be modified in place
                (see the 3D-Box rule above).
            action_space: The action space to use if actions are to be encoded.
                This is commonly the case for LSTM models. This default
                implementation does not read it.

        Returns:
            The encoder config.

        Raises:
            ValueError: If there is no default encoder config for the space.
        """
        cfg = model_config_dict
        # Read before branching so a missing key fails the same way on every path.
        fcnet_activation = cfg["fcnet_activation"]
        use_lstm = cfg["use_lstm"]

        if use_lstm:
            return RecurrentEncoderConfig(
                input_dims=observation_space.shape,
                recurrent_layer_type="lstm",
                hidden_dim=cfg["lstm_cell_size"],
                hidden_weights_initializer=cfg["lstm_kernel_initializer"],
                hidden_weights_initializer_config=cfg["lstm_kernel_initializer_kwargs"],
                hidden_bias_initializer=cfg["lstm_bias_initializer"],
                hidden_bias_initializer_config=cfg["lstm_bias_initializer_kwargs"],
                batch_major=True,
                num_layers=1,
                tokenizer_config=cls.get_tokenizer_config(observation_space, cfg),
            )

        # TODO (Artur): Maybe check for original spaces here
        # Rank of a Box observation space; None for any non-Box space.
        box_rank = (
            len(observation_space.shape)
            if isinstance(observation_space, Box)
            else None
        )

        if box_rank == 1:
            # For backward compatibility with old configs (which had no separate
            # latent-dim setting), the last `fcnet_hiddens` entry is reused as the
            # encoder's output (latent) dim; the others become hidden layers.
            hiddens = cfg["fcnet_hiddens"]
            hidden_dims, latent_dim = hiddens[:-1], hiddens[-1]
            kernel_init = cfg["fcnet_kernel_initializer"]
            kernel_init_kwargs = cfg["fcnet_kernel_initializer_kwargs"]
            bias_init = cfg["fcnet_bias_initializer"]
            bias_init_kwargs = cfg["fcnet_bias_initializer_kwargs"]
            return MLPEncoderConfig(
                input_dims=observation_space.shape,
                hidden_layer_dims=hidden_dims,
                hidden_layer_activation=fcnet_activation,
                hidden_layer_weights_initializer=kernel_init,
                hidden_layer_weights_initializer_config=kernel_init_kwargs,
                hidden_layer_bias_initializer=bias_init,
                hidden_layer_bias_initializer_config=bias_init_kwargs,
                output_layer_dim=latent_dim,
                output_layer_activation=fcnet_activation,
                output_layer_weights_initializer=kernel_init,
                output_layer_weights_initializer_config=kernel_init_kwargs,
                output_layer_bias_initializer=bias_init,
                output_layer_bias_initializer_config=bias_init_kwargs,
            )

        if box_rank == 3:
            # Fall back to a default filter stack for this image shape. This
            # writes into the passed-in config dict.
            if not cfg.get("conv_filters"):
                cfg["conv_filters"] = get_filter_config(observation_space.shape)
            return CNNEncoderConfig(
                input_dims=observation_space.shape,
                cnn_filter_specifiers=cfg["conv_filters"],
                cnn_activation=cfg["conv_activation"],
                cnn_kernel_initializer=cfg["conv_kernel_initializer"],
                cnn_kernel_initializer_config=cfg["conv_kernel_initializer_kwargs"],
                cnn_bias_initializer=cfg["conv_bias_initializer"],
                cnn_bias_initializer_config=cfg["conv_bias_initializer_kwargs"],
            )

        if box_rank == 2:
            # RLlib used to support 2D Box spaces by silently flattening them.
            raise ValueError(
                f"No default encoder config for obs space={observation_space},"
                f" lstm={use_lstm} found. 2D Box "
                f"spaces are not supported. They should be either flattened to a "
                f"1D Box space or enhanced to be a 3D box space."
            )

        # Possibly nested structure of spaces (NestedModelConfig) or other ranks.
        raise ValueError(
            f"No default encoder config for obs space={observation_space},"
            f" lstm={use_lstm} found."
        )

    @classmethod
    @OverrideToImplementCustomLogic
    def get_tokenizer_config(
        cls,
        observation_space: gym.Space,
        model_config_dict: dict,
        # deprecated args.
        view_requirements=DEPRECATED_VALUE,
    ) -> ModelConfig:
        """Returns a tokenizer config for the given space.

        Tokenizers are used by recurrent / transformer models to tokenize their
        inputs. By default, RLlib tokenizes with the same models that Catalog
        supports out of the box (via `_get_encoder_config()`).

        Override this method if you want to change the tokenizer inside the
        encoders that Catalog returns without providing the whole recurrent
        network yourself. For example, to define some custom CNN layers as the
        tokenizer of a recurrent encoder that already includes the recurrent
        layers and handles the state.

        Args:
            observation_space: The observation space to use.
            model_config_dict: The model config to use. It is copied, not
                modified.
            view_requirements: Deprecated. Any non-default value is rejected via
                `deprecation_warning(..., error=True)`.

        Returns:
            The tokenizer (encoder) config.
        """
        if view_requirements != DEPRECATED_VALUE:
            deprecation_warning(old="Catalog(view_requirements=..)", error=True)

        # Copy the config with the flags that would lead to complex models
        # switched off, so the tokenizer itself is a plain (non-recurrent) encoder.
        plain_model_config = dict(
            model_config_dict, use_lstm=False, use_attention=False
        )
        return cls._get_encoder_config(
            observation_space=observation_space,
            model_config_dict=plain_model_config,
        )

    @classmethod
    def _get_dist_cls_from_action_space(
        cls,
        action_space: gym.Space,
        *,
        framework: Optional[str] = None,
    ) -> Distribution:
        """Returns a distribution class for the given action space.

        You can get the required input dimension for the distribution by calling
        `action_dict_cls.required_input_dim(action_space)` on the retrieved class.
        This is useful because the Catalog needs to know the distribution's
        required input dimension before the model producing those inputs is
        configured.

        Mapping (action space -> distribution):
        - float Box (at most 1D) -> diagonal Gaussian
        - Discrete -> Categorical
        - Tuple / Dict -> multi-action distribution built from the sub-spaces
        - MultiDiscrete -> multi-categorical distribution
        - integer Box, multi-dimensional Box, Simplex, anything else -> error

        Args:
            action_space: Action space of the target gym env.
            framework: The framework to use ("torch" or "tf2"). If None, no
                distribution class is returned.

        Returns:
            The distribution class for the given action space, or None when
            `framework` is None.
        """
        if framework is None:
            return None

        # Phase 1: collect the framework-specific distribution classes.
        if framework == "torch":
            from ray.rllib.models.torch.torch_distributions import (
                TorchCategorical,
                TorchDeterministic,
                TorchDiagGaussian,
            )

            available = {
                "Deterministic": TorchDeterministic,
                "Gaussian": TorchDiagGaussian,
                "Categorical": TorchCategorical,
            }
        elif framework == "tf2":
            from ray.rllib.models.tf.tf_distributions import (
                TfCategorical,
                TfDeterministic,
                TfDiagGaussian,
            )

            available = {
                "Deterministic": TfDeterministic,
                "Gaussian": TfDiagGaussian,
                "Categorical": TfCategorical,
            }
        else:
            raise ValueError(
                f"Unknown framework: {framework}. Only 'torch' and 'tf2' are "
                "supported for RLModule Catalogs."
            )

        # Composite distributions can only be assembled when the space provides
        # their components (Tuple/Dict sub-spaces, or MultiDiscrete sizes).
        if isinstance(action_space, (Tuple, Dict)):
            available["MultiDistribution"] = _multi_action_dist_partial_helper(
                catalog_cls=cls,
                action_space=action_space,
                framework=framework,
            )
        if isinstance(action_space, MultiDiscrete):
            available["MultiCategorical"] = _multi_categorical_dist_partial_helper(
                action_space=action_space,
                framework=framework,
            )

        # Phase 2: pick the class matching the action space.
        if isinstance(action_space, Box):
            if action_space.dtype.char in np.typecodes["AllInteger"]:
                raise ValueError(
                    "Box(..., `int`) action spaces are not supported. "
                    "Use MultiDiscrete  or Box(..., `float`)."
                )
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    f"Action space has multiple dimensions {action_space.shape}. "
                    f"Consider reshaping this into a single dimension, using a "
                    f"custom action distribution, using a Tuple action space, "
                    f"or the multi-agent API."
                )
            return available["Gaussian"]

        if isinstance(action_space, Discrete):
            return available["Categorical"]

        if isinstance(action_space, (Tuple, Dict)):
            return available["MultiDistribution"]

        if isinstance(action_space, Simplex):
            # TODO(Artur): Supported Simplex (in torch).
            raise NotImplementedError("Simplex action space not yet supported.")

        if isinstance(action_space, MultiDiscrete):
            return available["MultiCategorical"]

        raise NotImplementedError(f"Unsupported action space: `{action_space}`")

    @staticmethod
    def get_preprocessor(observation_space: gym.Space, **kwargs) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space: The input observation space.
            **kwargs: Forward-compatible kwargs. Passing a non-empty `options`
                dict is deprecated and triggers a (non-fatal) deprecation
                warning; `options=None` is treated like an empty dict. An
                `options["custom_preprocessor"]` entry is rejected via
                `deprecation_warning(..., error=True)`.

        Returns:
            preprocessor: Preprocessor for the observations.
        """
        # TODO(Artur): Since preprocessors have long been @PublicAPI with the options
        #  kwarg as part of their constructor, we fade out support for this,
        #  beginning with this entrypoint.
        # Next, we should deprecate the `options` kwarg from the Preprocessor itself,
        # after deprecating the old catalog and other components that still pass this.
        # `or {}` also covers an explicit `options=None`, which would otherwise
        # fail on the `options.get(...)` call below.
        options = kwargs.get("options") or {}
        if options:
            deprecation_warning(
                old="get_preprocessor_for_space(..., options={...})",
                help="Override `Catalog.get_preprocessor()` "
                "in order to implement custom behaviour.",
                error=False,
            )

        if options.get("custom_preprocessor"):
            deprecation_warning(
                old="model_config['custom_preprocessor']",
                help="Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead.",
                error=True,
            )
        else:
            # TODO(Artur): Inline the get_preprocessor() call here once we have
            #  deprecated the old model catalog.
            # Inside this method, `get_preprocessor` resolves to the module-level
            # function, not to this staticmethod.
            preprocessor_cls = get_preprocessor(observation_space)
            return preprocessor_cls(observation_space, options)

class PPOCatalog(Catalog):
    """The Catalog class used to build models for PPO.

    PPOCatalog provides the following models:
        - ActorCriticEncoder: The encoder used to encode the observations.
        - Pi Head: The head used to compute the policy logits.
        - Value Function Head: The head used to compute the value function.

    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
    for the policy and value function. Whether the underlying encoder is shared
    between the two is controlled by the `vf_share_layers` model config key. See
    implementations of PPORLModule for more details.

    Any custom ActorCriticEncoder can be built by overriding the
    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
    at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
    ActorCriticEncoder during RLModule runtime.

    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
    methods. Alternatively, the `pi_head_config` and `vf_head_config` attributes can
    be overridden to build custom heads during RLModule runtime.

    Any module built for exploration or inference is built with the flag
    `inference_only=True` and does not contain a value network. This flag can be set
    in the `RLModuleSpec` through the `inference_only` boolean flag.
    In case that the actor-critic-encoder is not shared between the policy and value
    function (`vf_share_layers=False`), the inference-only module will contain only
    the actor encoder network.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes the PPOCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Pi Head.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )
        # Wrap the base EncoderConfig (set up by the parent constructor) into an
        # ActorCriticEncoderConfig; `vf_share_layers` decides whether the policy
        # and value function share one encoder.
        self.actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=self._encoder_config,
            shared=self._model_config_dict["vf_share_layers"],
        )

        # Both heads use the same hidden-layer sizes and activation.
        self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"]
        self.pi_and_vf_head_activation = self._model_config_dict[
            "head_fcnet_activation"
        ]

        # We don't have the exact (framework specific) action dist class yet and thus
        # cannot determine the exact number of output nodes (action space) required.
        # -> Build pi config only in the `self.build_pi_head` method.
        self.pi_head_config = None

        # The value head outputs a single linear value.
        self.vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_activation="linear",
            output_layer_dim=1,
        )

    @OverrideToImplementCustomLogic
    def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder:
        """Builds the ActorCriticEncoder.

        The default behavior is to build the encoder from
        `self.actor_critic_encoder_config`. This can be overridden to build a
        custom ActorCriticEncoder as a means of configuring the behavior of a
        PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The ActorCriticEncoder.
        """
        return self.actor_critic_encoder_config.build(framework=framework)

    @override(Catalog)
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        Since PPO uses an ActorCriticEncoder, this method should not be implemented.

        Raises:
            NotImplementedError: Always; use `build_actor_critic_encoder()` instead.
        """
        raise NotImplementedError(
            "Use PPOCatalog.build_actor_critic_encoder() instead for PPO."
        )

    @OverrideToImplementCustomLogic
    def build_pi_head(self, framework: str) -> Model:
        """Builds the policy head.

        The default behavior is to create `self.pi_head_config` (sized from the
        action distribution's required input dimension) and build the head from
        it. This can be overridden to build a custom policy head as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The policy head.
        """
        # Get action_distribution_cls to find out about the output dimension for pi_head
        action_distribution_cls = self.get_action_dist_cls(framework=framework)
        if self._model_config_dict["free_log_std"]:
            # `free_log_std` needs a diagonal Gaussian; this call is made without
            # `no_error`, presumably so it fails otherwise.
            _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls, framework=framework
            )
            is_diag_gaussian = True
        else:
            is_diag_gaussian = _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls,
                framework=framework,
                no_error=True,
            )
        required_output_dim = action_distribution_cls.required_input_dim(
            space=self.action_space, model_config=self._model_config_dict
        )
        # Now that we have the action dist class and number of outputs, we can define
        # our pi-config and build the pi head.
        pi_head_config_class = (
            FreeLogStdMLPHeadConfig
            if self._model_config_dict["free_log_std"]
            else MLPHeadConfig
        )
        self.pi_head_config = pi_head_config_class(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_dim=required_output_dim,
            output_layer_activation="linear",
            # Only clip the log-std when the distribution actually has one.
            clip_log_std=is_diag_gaussian,
            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
        )

        return self.pi_head_config.build(framework=framework)

    @OverrideToImplementCustomLogic
    def build_vf_head(self, framework: str) -> Model:
        """Builds the value function head.

        The default behavior is to build the head from `self.vf_head_config`.
        This can be overridden to build a custom value function head as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The value function head.
        """
        return self.vf_head_config.build(framework=framework)

In [None]:
@PublicAPI(stability="alpha")
class RLModule(Checkpointable, abc.ABC):
    """Base class for RLlib modules.

    Subclasses should NOT override `__init__` (passing the deprecated `config`
    argument raises an error). Instead, override `setup()`, which the constructor
    calls exactly once to build the module's components (e.g. NN layers).
    The public `forward_inference()`, `forward_exploration()` and `forward_train()`
    methods must not be overridden either; override `_forward()` (or the
    phase-specific `_forward_inference()`, `_forward_exploration()`,
    `_forward_train()`) instead.

    The following examples show how the forward methods are called.

    Example for creating a sampling loop:

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
        from ray.rllib.core.rl_module.rl_module import RLModuleSpec
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # Create a single agent RL module spec.
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()
        action_dist_class = module.get_exploration_action_dist_cls()
        obs, info = env.reset()
        terminated = truncated = False

        # Stop on termination AND on truncation (e.g. the env's time limit).
        while not (terminated or truncated):
            fwd_ins = {"obs": torch.Tensor([obs])}
            fwd_outputs = module.forward_exploration(fwd_ins)
            # this can be either deterministic or stochastic distribution
            action_dist = action_dist_class.from_logits(
                fwd_outputs["action_dist_inputs"]
            )
            action = action_dist.sample()[0].numpy()
            obs, reward, terminated, truncated, info = env.step(action)


    Example for training:

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
        from ray.rllib.core.rl_module.rl_module import RLModuleSpec
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # Create a single agent RL module spec.
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()

        obs, info = env.reset()
        fwd_ins = {"obs": torch.Tensor([obs])}
        fwd_outputs = module.forward_train(fwd_ins)
        # loss = compute_loss(fwd_outputs, fwd_ins)
        # update_params(module, loss)

    Example for inference:

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
        from ray.rllib.core.rl_module.rl_module import RLModuleSpec
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # Create a single agent RL module spec.
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()
        action_dist_class = module.get_inference_action_dist_cls()
        obs, info = env.reset()
        terminated = truncated = False

        while not (terminated or truncated):
            fwd_ins = {"obs": torch.Tensor([obs])}
            fwd_outputs = module.forward_inference(fwd_ins)
            # this can be either deterministic or stochastic distribution
            action_dist = action_dist_class.from_logits(
                fwd_outputs["action_dist_inputs"]
            )
            action = action_dist.sample()[0].numpy()
            obs, reward, terminated, truncated, info = env.step(action)


    Args:
        observation_space: The observation space of the module.
        action_space: The action space of the module.
        inference_only: Whether this module is inference-only. If True,
            `forward_train()` raises a RuntimeError.
        learner_only: Stored as `self.learner_only` (and in `self.config`).
        model_config: The model config, passed to the Catalog as
            `model_config_dict`.
        catalog_class: Optional Catalog class used to build sub-components.

    Methods to override (all have defaults that delegate to `_forward()`):
        ``~_forward_train``: Forward pass during training.

        ``~_forward_exploration``: Forward pass during training for exploration.

        ``~_forward_inference``: Forward pass during inference.


    Note:
        There is a reason that the specs are not written as abstract properties.
        The reason is that torch overrides `__getattr__` and `__setattr__`. This means
        that if we define the specs as properties, then any error in the property will
        be interpreted as a failure to retrieve the attribute and will invoke
        `__getattr__` which will give a confusing error about the attribute not found.
        More details here: https://github.com/pytorch/pytorch/issues/49726.
    """

    """RLlib模块的基础类。

    子类应在它们的__init__方法中调用super().__init__(config)。
    以下是如何调用forward方法的伪代码：

    创建采样循环的示例：

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # 创建单个智能体的RL模块规范。
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()
        action_dist_class = module.get_inference_action_dist_cls()
        obs, info = env.reset()
        terminated = False

        while not terminated:
            fwd_ins = {"obs": torch.Tensor([obs])}
            fwd_outputs = module.forward_exploration(fwd_ins)
            # 这可以是确定性或随机分布
            action_dist = action_dist_class.from_logits(
                fwd_outputs["action_dist_inputs"]
            )
            action = action_dist.sample()[0].numpy()
            obs, reward, terminated, truncated, info = env.step(action)


    训练的示例：

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # 创建单个智能体的RL模块规范。
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()

        fwd_ins = {"obs": torch.Tensor([obs])}
        fwd_outputs = module.forward_train(fwd_ins)
        # loss = compute_loss(fwd_outputs, fwd_ins)
        # update_params(module, loss)

    推理的示例：

    .. testcode::

        from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
            PPOTorchRLModule
        )
        from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
        import gymnasium as gym
        import torch

        env = gym.make("CartPole-v1")

        # 创建单个智能体的RL模块规范。
        module_spec = RLModuleSpec(
            module_class=PPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
            catalog_class=PPOCatalog,
        )
        module = module_spec.build()

        while not terminated:
            fwd_ins = {"obs": torch.Tensor([obs])}
            fwd_outputs = module.forward_inference(fwd_ins)
            # 这可以是确定性或随机分布
            action_dist = action_dist_class.from_logits(
                fwd_outputs["action_dist_inputs"]
            )
            action = action_dist.sample()[0].numpy()
            obs, reward, terminated, truncated, info = env.step(action)


    参数:
        config: RLModule的配置。

    抽象方法:
        ``~_forward_train``: 训练期间的前向传递。

        ``~_forward_exploration``: 训练期间用于探索的前向传递。

        ``~_forward_inference``: 推理期间的前向传递。

    注意:
        规格没有被写成抽象属性的原因是torch重载了`__getattr__`和`__setattr__`。
        这意味着如果我们把规格定义为属性，那么属性中的任何错误都会被解释为获取属性的失败，并会调用`__getattr__`，
        这将给出一个关于找不到属性的令人困惑的错误。
        更多细节请参见：https://github.com/pytorch/pytorch/issues/49726。
    """

    # Name of the deep-learning framework; set by framework-specific subclasses.
    framework: str = None

    # File name under which the module state is written/read.
    STATE_FILE_NAME = "module_state.pkl"

    def __init__(
        self,
        config=DEPRECATED_VALUE,
        *,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        inference_only: Optional[bool] = None,
        learner_only: bool = False,
        model_config: Optional[Union[dict, DefaultModelConfig]] = None,
        catalog_class=None,
    ):
        """Initializes this RLModule and calls `setup()` exactly once.

        Args:
            config: Deprecated. Passing anything other than the default raises an
                error (`deprecation_warning(..., error=True)`).
            observation_space: The observation space of the module.
            action_space: The action space of the module.
            inference_only: Whether this module is inference-only.
            learner_only: Stored as `self.learner_only`.
            model_config: The model config, passed to `catalog_class` as
                `model_config_dict`.
            catalog_class: Optional Catalog class. If constructing it fails, a
                warning is logged and the exception is kept in
                `self._catalog_ctor_error` (e.g. for `setup()` to re-raise).

        Raises:
            RuntimeError: If `setup()` was already run on this instance.
        """
        # TODO (sven): Deprecate Catalog and replace with utility functions to create
        #  primitive components based on obs- and action spaces.
        self.catalog = None
        self._catalog_ctor_error = None

        # Deprecated
        self.config = config
        if self.config != DEPRECATED_VALUE:
            deprecation_warning(
                old="RLModule(config=[RLModuleConfig])",
                new="RLModule(observation_space=.., action_space=.., inference_only=..,"
                " learner_only=.., model_config=..)",
                help="See https://github.com/ray-project/ray/blob/master/rllib/examples/rl_modules/custom_cnn_rl_module.py "  # noqa
                "for how to write a custom RLModule.",
                error=True,
            )
        else:
            self.observation_space = observation_space
            self.action_space = action_space
            self.inference_only = inference_only
            self.learner_only = learner_only
            self.model_config = model_config
            try:
                self.catalog = catalog_class(
                    observation_space=self.observation_space,
                    action_space=self.action_space,
                    model_config_dict=self.model_config,
                )
            except Exception as e:
                # Deliberately best-effort: many RLModules do not use a Catalog at
                # all (`catalog_class=None` also ends up here). The error is kept
                # so that Catalog-dependent subclasses can re-raise it.
                logger.warning(
                    "Could not create a Catalog object for your RLModule! If you are "
                    "not using the new API stack yet, make sure to switch it off in "
                    "your config: `config.api_stack(enable_rl_module_and_learner=False"
                    ", enable_env_runner_and_connector_v2=False)`. Some algos already "
                    "use the new stack by default. Ignore this message, if your "
                    "RLModule does not use a Catalog to build its sub-components."
                )
                self._catalog_ctor_error = e

        # TODO (sven): Deprecate this. We keep it here for now in case users
        #  still have custom models (or subclasses of RLlib default models)
        #  into which they pass in a `config` argument.
        self.config = RLModuleConfig(
            observation_space=self.observation_space,
            action_space=self.action_space,
            inference_only=self.inference_only,
            learner_only=self.learner_only,
            model_config_dict=self.model_config,
            catalog_class=catalog_class,
        )

        self.action_dist_cls = None
        if self.catalog is not None:
            self.action_dist_cls = self.catalog.get_action_dist_cls(
                framework=self.framework
            )

        # Make sure, `setup()` is only called once, no matter what.
        if hasattr(self, "_is_setup") and self._is_setup:
            raise RuntimeError(
                "`RLModule.setup()` called twice within your RLModule implementation "
                f"{self}! Make sure you are using the proper inheritance order "
                "(TorchRLModule before [Algo]RLModule) or (TfRLModule before "
                "[Algo]RLModule) and that you are NOT overriding the constructor, but "
                "only the `setup()` method of your subclass."
            )
        self.setup()
        self._is_setup = True

    @OverrideToImplementCustomLogic
    def setup(self):
        """Sets up the components of the module.

        This is called exactly once, automatically, at the end of this class'
        `__init__`. Subclasses should build all components their RLModule needs
        (e.g. NN layers) here, instead of overriding the constructor. The base
        implementation does nothing.
        """
        """设置模块的组件。

        这个方法会在该类的__init__方法中自动调用，
        因此，子类应该在其构造函数中调用super().__init__()。这个抽象方法可以用来创建你
        的RLModule所需的任何组件（例如，NN层）。
        """
        return None

    @OverrideToImplementCustomLogic
    def get_exploration_action_dist_cls(self) -> Type[Distribution]:
        """Returns the action distribution class for this RLModule used for exploration.

        This class is used to create action distributions from outputs of the
        forward_exploration method. In case no action distribution class is
        needed, this method can return None.

        Note that RLlib's distribution classes all implement the `Distribution`
        interface. This requires two special methods: `Distribution.from_logits()` and
        `Distribution.to_deterministic()`. See the documentation of the
        :py:class:`~ray.rllib.models.distributions.Distribution` class for more details.

        Raises:
            NotImplementedError: In this base class; subclasses must override.
        """
        """返回用于探索的此RLModule的动作分布类。

        这个类用于从forward_exploration方法的输出创建动作分布。如果不需要动作分布类，
        此方法可以返回None。

        注意，RLlib的所有分布类都实现了`Distribution`接口。这需要两个特殊方法：
        `Distribution.from_logits()`和`Distribution.to_deterministic()`。更多详情请参见
        :py:class:`~ray.rllib.models.distributions.Distribution`类的文档。
        """
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    def get_inference_action_dist_cls(self) -> Type[Distribution]:
        """Returns the action distribution class for this RLModule used for inference.

        This class is used to create action distributions from outputs of the forward
        inference method. In case no action distribution class is needed,
        this method can return None.

        Note that RLlib's distribution classes all implement the `Distribution`
        interface. This requires two special methods: `Distribution.from_logits()` and
        `Distribution.to_deterministic()`. See the documentation of the
        :py:class:`~ray.rllib.models.distributions.Distribution` class for more details.

        Raises:
            NotImplementedError: In this base class; subclasses must override.
        """
        """返回用于推理的此RLModule的动作分布类。

        此类用于从forward inference方法的输出创建动作分布。如果不需要动作分布类，
        此方法可以返回None。

        请注意，RLlib的所有分布类都实现了Distribution接口。这需要两个特殊方法：
        Distribution.from_logits()和Distribution.to_deterministic()。有关更多详细信息，
        请参阅:py:class:~ray.rllib.models.distributions.Distribution类的文档。
        """
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    def get_train_action_dist_cls(self) -> Type[Distribution]:
        """Returns the action distribution class for this RLModule used for training.

        This class is used to get the correct action distribution class to be used by
        the training components. In case that no action distribution class is needed,
        this method can return None.

        Note that RLlib's distribution classes all implement the `Distribution`
        interface. This requires two special methods: `Distribution.from_logits()` and
        `Distribution.to_deterministic()`. See the documentation of the
        :py:class:`~ray.rllib.models.distributions.Distribution` class for more details.

        Raises:
            NotImplementedError: In this base class; subclasses must override.
        """
        """返回用于训练的此RLModule的动作分布类。

        此类用于获取训练组件使用的正确动作分布类。如果不需要动作分布类，
        此方法可以返回None。

        请注意，RLlib的所有分布类都实现了Distribution接口。这需要两个特殊方法：
        Distribution.from_logits()和Distribution.to_deterministic()。有关更多详细信息，
        请参阅:py:class:~ray.rllib.models.distributions.Distribution类的文档。
        """
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Generic forward pass method, used in all phases of training and evaluation.

        If you need a more nuanced distinction between forward passes in the different
        phases of training and evaluation, override the following methods instead:
        For distinct action computation logic w/o exploration, override the
        `self._forward_inference()` method.
        For distinct action computation logic with exploration, override the
        `self._forward_exploration()` method.
        For distinct forward pass logic before loss computation, override the
        `self._forward_train()` method.

        Args:
            batch: The input batch.
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the forward pass (an empty dict in this base class).
        """
        """通用的前向传递方法，用于训练和评估的所有阶段。

        如果您需要在训练和评估的不同阶段之间进行更细致的前向传递区分，请改写以下方法：
        对于无探索的不同动作计算逻辑，请重写self._forward_inference()方法。
        对于带有探索的不同动作计算逻辑，请重写self._forward_exploration()方法。
        对于损失计算前的不同前向传递逻辑，请重写self._forward_train()方法。

        参数:
            batch: 输入批次。
            **kwargs: 额外的关键字参数。

        返回:
            前向传递的输出。
        """

        return {}

    def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.

        This method should not be overridden. Override the `self._forward_inference()`
        method instead.

        Args:
            batch: The input batch. This input batch should comply with
                input_specs_inference().
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the forward pass. This output should comply with the
            output_specs_inference().
        """
        """不要重写！在评估期间的前向传递，由采样器调用。

        此方法不应被重写。请改写self._forward_inference()方法。

        参数:
            batch: 输入批次。该输入批次应符合input_specs_inference()。
            **kwargs: 额外的关键字参数。

        返回:
            前向传递的输出。该输出应符合output_specs_inference()。
        """
        return self._forward_inference(batch, **kwargs)

    @OverrideToImplementCustomLogic
    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Forward-pass used for action computation without exploration behavior.

        Override this method only, if you need specific behavior for non-exploratory
        action computation behavior. If you have only one generic behavior for all
        phases of training and evaluation, override `self._forward()` instead.

        By default, this calls the generic `self._forward()` method.
        """
        """用于无探索行为的动作计算的前向传递。

        仅当您需要特定于非探索动作计算行为的特定行为时才重写此方法。如果您对训练和评估的所有阶段只有一个通用行为，请改写self._forward()方法。
        默认情况下，此方法调用通用的self._forward()方法。
        """
        return self._forward(batch, **kwargs)

    def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.

        This method should not be overridden. Override the `self._forward_exploration()`
        method instead.

        Args:
            batch: The input batch. This input batch should comply with
                input_specs_exploration().
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the forward pass. This output should comply with the
            output_specs_exploration().
        """
        return self._forward_exploration(batch, **kwargs)

    @OverrideToImplementCustomLogic
    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Forward-pass used for action computation with exploration behavior.

        Override this method only, if you need specific behavior for exploratory
        action computation behavior. If you have only one generic behavior for all
        phases of training and evaluation, override `self._forward()` instead.

        By default, this calls the generic `self._forward()` method.
        """
        return self._forward(batch, **kwargs)

    def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """DO NOT OVERRIDE! Forward-pass during training called from the learner.

        This method should not be overridden. Override the `self._forward_train()`
        method instead.

        Args:
            batch: The input batch. This input batch should comply with
                input_specs_train().
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the forward pass. This output should comply with the
            output_specs_train().

        Raises:
            RuntimeError: If this module is `inference_only`.
        """
        if self.inference_only:
            raise RuntimeError(
                "Calling `forward_train` on an inference_only module is not allowed! "
                "Set the `inference_only=False` flag in the RLModule's config when "
                "building the module."
            )
        return self._forward_train(batch, **kwargs)

    @OverrideToImplementCustomLogic
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Forward-pass used before the loss computation (training).

        Override this method only, if you need specific behavior and outputs for your
        loss computations. If you have only one generic behavior for all
        phases of training and evaluation, override `self._forward()` instead.

        By default, this calls the generic `self._forward()` method.
        """
        return self._forward(batch, **kwargs)

    @OverrideToImplementCustomLogic
    def get_initial_state(self) -> Any:
        """Returns the initial state of the RLModule, in case this is a stateful module.

        Returns:
            A tensor or any nested struct of tensors, representing an initial state for
            this (stateful) RLModule. The base implementation returns an empty dict
            (i.e. the module is stateless).
        """
        """返回RLModule的初始状态，如果这是一个有状态的模块。

        返回:
            一个张量或任何嵌套的张量结构，表示此（有状态的）RLModule的初始状态。
        """
        return {}

    @OverrideToImplementCustomLogic
    def is_stateful(self) -> bool:
        """Returns whether this RLModule is stateful (recurrent).

        By default, RLlib assumes that the module is non-recurrent if the initial
        state returned by `get_initial_state()` is an empty dict (or None), and
        recurrent otherwise.
        This behavior can be customized by overriding this method.

        Raises:
            AssertionError: If the initial state is neither None nor a dict.
        """
        """默认情况下，如果初始状态是一个空字典（或None），则返回False。
        
        默认情况下，如果初始状态是一个空字典，RLlib假定此模块是非循环的，它们没有像循环神经网络（RNN）、长短期记忆网络（LSTM）或门控循环单元（GRU）那样的反馈回路或记忆单元。
        否则为循环的。
        这种行为可以通过重写此方法来定制。
        """
        initial_state = self.get_initial_state()
        # Honor the documented contract: `None` means "no state" (previously this
        # tripped the dict assertion below).
        if initial_state is None:
            return False
        assert isinstance(initial_state, dict), (
            "The initial state of an RLModule must be a dict, but is "
            f"{type(initial_state)} instead."
        )
        return bool(initial_state)

    @OverrideToImplementCustomLogic
    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        inference_only: bool = False,
        **kwargs,
    ) -> StateDict:
        """Returns the state dict of the module.

        Args:
            components: Not supported by this base implementation; must be None.
            not_components: Not supported by this base implementation; must be None.
            inference_only: Whether the returned state should be an inference-only
                state (w/o those model components that are not needed for action
                computations, such as a value function or a target network).
                Note that setting this to `False` might raise an error if
                `self.inference_only` is True. Ignored by this base implementation.
            **kwargs: Ignored by this base implementation.

        Returns:
            This RLModule's state dict (an empty dict in this base class).

        Raises:
            ValueError: If `components` or `not_components` is given.
        """
        """返回模块的状态字典。

        参数:
            inference_only: 是否返回仅用于推理状态的状态字典（不包括那些对动作计算不必要的模型组件，如值函数或目标网络）。
                请注意，如果self.inference_only为True，将其设置为False可能会引发错误。

        返回:
            此RLModule的状态字典。
        """
        if components is not None or not_components is not None:
            raise ValueError(
                "`component` arg and `not_component` arg not supported in "
                "`RLModule.get_state()` base implementation! Override this method in "
                "your custom RLModule subclass."
            )
        return {}

    @OverrideToImplementCustomLogic
    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        """Sets the module's state from `state`. A no-op in this base class."""
        pass

    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self):
        """Returns the (args, kwargs) needed to re-construct this RLModule.

        The catalog is re-created from its class (`type(self.catalog)`), or None if
        this module has no catalog.
        """
        return (
            (),  # *args
            {
                "observation_space": self.observation_space,
                "action_space": self.action_space,
                "inference_only": self.inference_only,
                "learner_only": self.learner_only,
                "model_config": self.model_config,
                "catalog_class": (
                    type(self.catalog) if self.catalog is not None else None
                ),
            },  # **kwargs
        )

    def as_multi_rl_module(self) -> "MultiRLModule":
        """Returns a multi-agent wrapper around this module.

        The wrapper holds a single sub-module spec (built from `self`) under
        `DEFAULT_MODULE_ID`.
        """
        from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule

        multi_rl_module = MultiRLModule(
            rl_module_specs={DEFAULT_MODULE_ID: RLModuleSpec.from_module(self)}
        )
        return multi_rl_module

    def unwrapped(self) -> "RLModule":
        """Returns the underlying module if this module is a wrapper.

        An example of a wrapped is the TorchDDPRLModule class, which wraps
        a TorchRLModule.

        Returns:
            The underlying module (`self` in this base class).
        """
        return self

    @Deprecated(new="RLModule.as_multi_rl_module()", error=True)
    def as_multi_agent(self, *args, **kwargs):
        pass

    @Deprecated(new="RLModule.save_to_path(...)", error=True)
    def save_state(self, *args, **kwargs):
        pass

    @Deprecated(new="RLModule.restore_from_path(...)", error=True)
    def load_state(self, *args, **kwargs):
        pass

    @Deprecated(new="RLModule.save_to_path(...)", error=True)
    def save_to_checkpoint(self, *args, **kwargs):
        pass

    def output_specs_inference(self) -> SpecType:
        """Returns the output specs of the forward_inference method."""
        return [Columns.ACTION_DIST_INPUTS]

    def output_specs_exploration(self) -> SpecType:
        """Returns the output specs of the forward_exploration method."""
        return [Columns.ACTION_DIST_INPUTS]

    def output_specs_train(self) -> SpecType:
        """Returns the output specs of the forward_train method."""
        return {}

    def input_specs_inference(self) -> SpecType:
        """Returns the input specs of the forward_inference method."""
        return self._default_input_specs()

    def input_specs_exploration(self) -> SpecType:
        """Returns the input specs of the forward_exploration method."""
        return self._default_input_specs()

    def input_specs_train(self) -> SpecType:
        """Returns the input specs of the forward_train method."""
        return self._default_input_specs()

    def _default_input_specs(self) -> SpecType:
        """Returns the default input specs (only the observation column)."""
        return [Columns.OBS]

class TorchRLModule(nn.Module, RLModule):
    """A base class for RLlib PyTorch RLModules.

    Note that the `_forward` methods of this class can be 'torch.compiled' individually:
        - `TorchRLModule._forward_train()`
        - `TorchRLModule._forward_inference()`
        - `TorchRLModule._forward_exploration()`

    As a rule of thumb, they should only contain torch-native tensor manipulations,
    or otherwise they may yield wrong outputs. In particular, the creation of RLlib
    distributions inside these methods should be avoided when using `torch.compile`.
    When in doubt, you can use `torch._dynamo.explain()` to check whether a compiled
    method has broken up into multiple sub-graphs.

    Compiling these methods can bring speedups under certain conditions.
    """
    """
    RLlib PyTorch RL模块的基类。

    注意这个类的 `_forward` 方法可以单独进行 'torch.compile' 编译:
        - `TorchRLModule._forward_train()`
        - `TorchRLModule._forward_inference()`
        - `TorchRLModule._forward_exploration()`

    作为经验法则，这些方法应该只包含torch原生的张量操作，
    否则可能会产生错误的输出。特别是在使用 `torch.compile` 时，
    应当避免在这些方法内创建 RLlib 分布。当有疑问时，可以使用
    `torch.dynamo.explain()` 来检查一个编译过的方法是否分解成了多个子图。

    在某些条件下，编译这些方法可以带来速度提升。
    """

    framework: str = "torch"

    # Stick with torch default.
    STATE_FILE_NAME = "module_state.pt"

    def __init__(self, *args, **kwargs) -> None:
        """Initializes the `nn.Module` base first, then the `RLModule` base.

        `nn.Module.__init__` has to run before `RLModule.__init__` (which calls
        `setup()`), so that sub-modules assigned in `setup()` can be registered.

        If this module is `inference_only` AND implements `InferenceOnlyAPI`, all
        attributes named by `self.get_non_inference_attributes()` are removed
        afterwards. Dotted names (e.g. "encoder.critic_encoder") are resolved
        against nested attributes; missing or None-valued entries are skipped.
        """
        nn.Module.__init__(self)
        RLModule.__init__(self, *args, **kwargs)

        if self.inference_only and isinstance(self, InferenceOnlyAPI):
            for attr in self.get_non_inference_attributes():
                *owner_path, leaf = attr.split(".")
                # Walk down to the object that owns the last path component.
                owner = self
                for part in owner_path:
                    owner = getattr(owner, part, None)
                    if owner is None:
                        break
                # The previous `del target` only unbound a local variable and left
                # the attribute in place. `delattr` on the owner actually removes
                # it (for `nn.Module` owners, this also unregisters the sub-module).
                if owner is not None and getattr(owner, leaf, None) is not None:
                    delattr(owner, leaf)

    def compile(self, compile_config: TorchCompileConfig):
        """Compile the forward methods of this module.

        This is a convenience method that calls `compile_wrapper` with the given
        compile_config.

        Args:
            compile_config: The compile config to use.
        """
        return compile_wrapper(self, compile_config)

    @OverrideToImplementCustomLogic
    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Inference forward pass; by default `_forward()` under `torch.no_grad()`."""
        # By default, calls the generic `_forward()` method, but with a no-grad context
        # for performance reasons.
        with torch.no_grad():
            return self._forward(batch, **kwargs)

    @OverrideToImplementCustomLogic
    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Exploration forward pass; by default `_forward()` under `torch.no_grad()`."""
        # By default, calls the generic `_forward()` method, but with a no-grad context
        # for performance reasons.
        with torch.no_grad():
            return self._forward(batch, **kwargs)

    @OverrideToImplementCustomLogic
    @override(RLModule)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        inference_only: bool = False,
        **kwargs,
    ) -> StateDict:
        """Returns this module's torch `state_dict()`, converted to numpy.

        Args:
            components: Ignored by this implementation.
            not_components: Ignored by this implementation.
            inference_only: If True, and this module is NOT itself inference-only
                but implements `InferenceOnlyAPI`, drop all keys that belong to
                `self.get_non_inference_attributes()` (the exact attribute name or
                anything nested under `<name>.`).
            **kwargs: Ignored by this implementation.

        Returns:
            The (possibly filtered) state dict with numpy values.
        """
        state_dict = self.state_dict()
        # Filter out `inference_only` keys from the state dict if `inference_only` and
        # this RLModule is NOT `inference_only` (but does implement the
        # InferenceOnlyAPI).
        if (
            inference_only
            and not self.inference_only
            and isinstance(self, InferenceOnlyAPI)
        ):
            attr = self.get_non_inference_attributes()
            for key in list(state_dict.keys()):
                # Match "vf" and "vf.xyz", but not "vf_other".
                if any(
                    key.startswith(a) and (len(key) == len(a) or key[len(a)] == ".")
                    for a in attr
                ):
                    del state_dict[key]
        return convert_to_numpy(state_dict)

    @OverrideToImplementCustomLogic
    @override(RLModule)
    def set_state(self, state: StateDict) -> None:
        """Loads `state` into this module (non-strict, converted to torch tensors)."""
        # If state contains more keys than `self.state_dict()`, then we simply ignore
        # these keys (strict=False). This is most likely due to `state` coming from
        # an `inference_only=False` RLModule, while `self` is an `inference_only=True`
        # RLModule.
        self.load_state_dict(convert_to_torch_tensor(state), strict=False)

    @OverrideToImplementCustomLogic
    @override(RLModule)
    def get_inference_action_dist_cls(self) -> Type[TorchDistribution]:
        """Returns the inference action distribution class.

        Uses `self.action_dist_cls` if set; otherwise falls back to
        `TorchCategorical` for Discrete and `TorchDiagGaussian` for Box action
        spaces.

        Raises:
            ValueError: For any other action space without `self.action_dist_cls`.
        """
        if self.action_dist_cls is not None:
            return self.action_dist_cls
        elif isinstance(self.action_space, gym.spaces.Discrete):
            return TorchCategorical
        elif isinstance(self.action_space, gym.spaces.Box):
            return TorchDiagGaussian
        else:
            raise ValueError(
                f"Default action distribution for action space "
                f"{self.action_space} not supported! Either set the "
                f"`self.action_dist_cls` property in your RLModule's `setup()` method "
                f"to a subclass of `ray.rllib.models.torch.torch_distributions."
                f"TorchDistribution` or - if you need different distributions for "
                f"inference and training - override the three methods: "
                f"`get_inference_action_dist_cls`, `get_exploration_action_dist_cls`, "
                f"and `get_train_action_dist_cls` in your RLModule."
            )

    @OverrideToImplementCustomLogic
    @override(RLModule)
    def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]:
        """Returns the same class as `get_inference_action_dist_cls()`."""
        return self.get_inference_action_dist_cls()

    @OverrideToImplementCustomLogic
    @override(RLModule)
    def get_train_action_dist_cls(self) -> Type[TorchDistribution]:
        """Returns the same class as `get_inference_action_dist_cls()`."""
        return self.get_inference_action_dist_cls()

    @override(nn.Module)
    def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """DO NOT OVERRIDE!

        This is aliased to `self.forward_train` because Torch DDP requires a forward
        method to be implemented for backpropagation to work.

        Instead, override:
        `_forward()` to define a generic forward pass for all phases (exploration,
        inference, training)
        `_forward_inference()` to define the forward pass for action inference in
        deployment/production (no exploration).
        `_forward_exploration()` to define the forward pass for action inference during
        training sample collection (w/ exploration behavior).
        `_forward_train()` to define the forward pass prior to loss computation.
        """
        return self.forward_train(batch, **kwargs)

# NOTE(review): The bare string below is a no-op study note, not executable code.
# In Chinese, it says that PPORLModule needs a Catalog to build its models. It
# quotes the three build calls (encoder, pi head, vf head) that
# `PPORLModule.setup()` performs.
"""
PPORLModule 需要使用一个 catalog 来构建模型

# Build models from catalog.
self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)
self.vf = self.catalog.build_vf_head(framework=self.framework)
"""
@DeveloperAPI(stability="alpha")
class PPORLModule(RLModule, InferenceOnlyAPI, ValueFunctionAPI, abc.ABC):
    """Framework-agnostic PPO RLModule base.

    `setup()` builds three sub-networks from `self.catalog`:
    - an actor-critic encoder,
    - a policy head `pi`,
    - a value head `vf`.

    Framework-specific subclasses provide the forward passes.
    """

    @override(RLModule)
    def setup(self):
        """Build the encoder, policy head and value head from the catalog.

        Raises:
            The exception stored in `self._catalog_ctor_error` when no catalog is
            available and its construction previously failed.
        """
        # Re-raise a deferred catalog-construction error before using the catalog.
        if self.catalog is None and hasattr(self, "_catalog_ctor_error"):
            raise self._catalog_ctor_error

        # __sphinx_doc_begin__
        encoder_config = self.catalog.actor_critic_encoder_config
        # The encoder is not built yet, so `is_stateful()` cannot be used here;
        # inspect the encoder config instead. A recurrent (stateful) model needs
        # its critic states collected during sampling, so it cannot be an
        # inference-only module.
        if isinstance(encoder_config.base_encoder_config, RecurrentEncoderConfig):
            self.inference_only = False
        # An inference-only torch module propagates that flag to its encoder config.
        if self.inference_only and self.framework == "torch":
            encoder_config.inference_only = True

        # Build models from catalog.
        self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = self.catalog.build_pi_head(framework=self.framework)
        self.vf = self.catalog.build_vf_head(framework=self.framework)
        # __sphinx_doc_end__

    @override(RLModule)
    def get_initial_state(self) -> dict:
        """Return the encoder's initial state, or `{}` if the encoder has none."""
        if not hasattr(self.encoder, "get_initial_state"):
            return {}
        return self.encoder.get_initial_state()

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Return attributes, which are NOT inference-only (only used for training)."""
        training_only = ["vf"]
        # Without shared layers, the separate critic encoder is training-only too.
        if not self.model_config.get("vf_share_layers"):
            training_only.append("encoder.critic_encoder")
        return training_only

class Columns:
    """Canonical column names for RL data, e.g. "obs" or "rewards".

    These names supersede the identically named `SampleBatch` and
    `Postprocessing` columns.
    """

    # ---- Environment interaction -------------------------------------------
    # Observation returned by the env's `reset()` or `step()`.
    OBS = "obs"
    # Info dicts returned by the env's `reset()` or `step()`.
    INFOS = "infos"
    # Action that the RLModule computed or sampled.
    ACTIONS = "actions"
    # Action that is actually passed to the (gymnasium) `Env.step()` call.
    ACTIONS_FOR_ENV = "actions_for_env"
    # Reward returned by `env.step()`.
    REWARDS = "rewards"
    # Termination flag returned by the env after `step()`.
    TERMINATEDS = "terminateds"
    # Truncation flag returned by the env after `step()`, e.g. when a time
    # limit was reached.
    TRUNCATEDS = "truncateds"
    # Next observation. Only algorithms that train on TD-style data (such as
    # off-policy/DQN algorithms) need it.
    NEXT_OBS = "new_obs"

    # ---- Identifiers and bookkeeping ---------------------------------------
    # Unique episode identifier.
    EPS_ID = "eps_id"
    AGENT_ID = "agent_id"
    MODULE_ID = "module_id"
    # Length of the non-padded portion of each sequence in a zero-padded
    # (B, T, ...)-style train batch (e.g. for LSTMs).
    SEQ_LENS = "seq_lens"
    # Per-episode timestep counter.
    T = "t"

    # ---- Common extra RLModule output keys ---------------------------------
    STATE_IN = "state_in"
    STATE_OUT = "state_out"
    EMBEDDINGS = "embeddings"
    ACTION_DIST_INPUTS = "action_dist_inputs"
    ACTION_PROB = "action_prob"
    ACTION_LOGP = "action_logp"

    # ---- Value function ----------------------------------------------------
    # Value-function predictions.
    VF_PREDS = "vf_preds"
    # Value predicted one timestep past the last timestep taken. It is usually
    # computed by the value network from the final observation (plus, for
    # RNNs, the last returned internal state).
    VALUES_BOOTSTRAPPED = "values_bootstrapped"

    # ---- Postprocessing ----------------------------------------------------
    ADVANTAGES = "advantages"
    VALUE_TARGETS = "value_targets"

    # ---- Curiosity-driven learning -----------------------------------------
    INTRINSIC_REWARDS = "intrinsic_rewards"

    # ---- Loss masking ------------------------------------------------------
    # When present in a train batch, a Learner's `compute_loss_for_module`
    # should honor the False entries and mask those items out of the loss.
    LOSS_MASK = "loss_mask"

class PPOTorchRLModule(TorchRLModule, PPORLModule):
    """Torch PPO RLModule: policy forward passes plus value-function computation."""

    @override(RLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Default forward pass (used for inference and exploration).

        Returns a dict with the policy-head output under
        `Columns.ACTION_DIST_INPUTS`. If the encoder emits `Columns.STATE_OUT`,
        that entry is passed through as well.
        """
        encoded = self.encoder(batch)
        result = {}
        # Stateful encoders hand back their next recurrent state.
        if Columns.STATE_OUT in encoded:
            result[Columns.STATE_OUT] = encoded[Columns.STATE_OUT]
        actor_features = encoded[ENCODER_OUT][ACTOR]
        result[Columns.ACTION_DIST_INPUTS] = self.pi(actor_features)
        return result

    @override(RLModule)
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Training forward pass.

        Like `_forward`, but also returns the critic embeddings under
        `Columns.EMBEDDINGS`. A later value-function call can then reuse them
        instead of re-encoding the batch.
        """
        encoded = self.encoder(batch)
        result = {Columns.EMBEDDINGS: encoded[ENCODER_OUT][CRITIC]}
        if Columns.STATE_OUT in encoded:
            result[Columns.STATE_OUT] = encoded[Columns.STATE_OUT]
        actor_features = encoded[ENCODER_OUT][ACTOR]
        result[Columns.ACTION_DIST_INPUTS] = self.pi(actor_features)
        return result

    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        """Return value predictions with the value head's last dimension squeezed.

        If `embeddings` is None, the critic embeddings are computed from `batch`.
        The separate critic encoder is used when one exists; otherwise the critic
        branch of the shared encoder is used.
        """
        if embeddings is None:
            if not hasattr(self.encoder, "critic_encoder"):
                # Shared encoder: take the critic branch of its output.
                embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]
            else:
                # Separate value-function encoder.
                critic_batch = batch
                if self.is_stateful():
                    # Recurrent encoders expect a `(state_in, h)` key, but the
                    # batch carries `(state_in, critic, h)`; select the critic
                    # state on a shallow copy so the caller's batch is untouched.
                    critic_batch = batch.copy()
                    critic_batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
                embeddings = self.encoder.critic_encoder(critic_batch)[ENCODER_OUT]

        # Value head has a single output node; drop that trailing dimension.
        return self.vf(embeddings).squeeze(-1)


In [None]:
# Reference snapshot of a PPO algorithm configuration (CartPole-v1, torch).
# NOTE(review): this looks like a raw attribute dump of an RLlib PPOConfig. It
# contains private keys such as "_is_frozen" and "framework_str" rather than
# being a hand-written config — confirm before passing it to
# `AlgorithmConfig.from_dict()` or similar.
# Caveats visible in the values:
#   * Some entries are string reprs of Python objects, not the objects
#     themselves ("sample_collector", "callbacks_class", "policy_mapping_fn",
#     "torch_compile_learner_what_to_compile"). The function repr even embeds a
#     process-specific memory address, so these entries are not usable as-is.
#   * Many keys hold -1. This is presumably RLlib's DEPRECATED_VALUE sentinel
#     for deprecated options — verify.
ppo_config = {
    # --- Driver process, resources and deep-learning framework ---
    "extra_python_environs_for_driver": {},
    "extra_python_environs_for_worker": {},
    "placement_strategy": "PACK",
    "num_gpus": 0,
    "_fake_gpus": False,
    "num_cpus_for_main_process": 1,
    "framework_str": "torch",
    "eager_tracing": True,
    "eager_max_retraces": 20,
    "tf_session_args": {
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {"allow_growth": True},
        "log_device_placement": False,
        "device_count": {"CPU": 1},
        "allow_soft_placement": True
    },
    "local_tf_session_args": {
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8
    },
    # --- torch.compile / DDP settings ---
    "torch_compile_learner": False,
    "torch_compile_learner_what_to_compile": "TorchCompileWhatToCompile.FORWARD_TRAIN",  # enum repr string, not the enum member
    "torch_compile_learner_dynamo_backend": "inductor",
    "torch_compile_learner_dynamo_mode": None,
    "torch_compile_worker": False,
    "torch_compile_worker_dynamo_backend": "onnxrt",
    "torch_compile_worker_dynamo_mode": None,
    "torch_ddp_kwargs": {},
    "torch_skip_nan_gradients": False,
    # --- Environment ---
    "env": "CartPole-v1",
    "env_config": {},
    "observation_space": None,
    "action_space": None,
    "clip_rewards": None,
    "normalize_actions": True,
    "clip_actions": False,
    "_is_atari": None,
    "disable_env_checking": False,
    "env_task_fn": None,
    "render_env": False,
    "action_mask_key": "action_mask",
    # --- Env runners (sample collection) ---
    "env_runner_cls": None,
    "num_env_runners": 0,
    "num_envs_per_env_runner": 1,
    "num_cpus_per_env_runner": 1,
    "num_gpus_per_env_runner": 0,
    "custom_resources_per_env_runner": {},
    "validate_env_runners_after_construction": True,
    "max_requests_in_flight_per_env_runner": 1,
    "sample_timeout_s": 60.0,
    "create_env_on_local_worker": False,
    "_env_to_module_connector": None,
    "add_default_connectors_to_env_to_module_pipeline": True,
    "_module_to_env_connector": None,
    "add_default_connectors_to_module_to_env_pipeline": True,
    "episode_lookback_horizon": 1,
    "rollout_fragment_length": "auto",
    "batch_mode": "truncate_episodes",
    "compress_observations": False,
    "remote_worker_envs": False,
    "remote_env_batch_wait_ms": 0,
    "enable_tf1_exec_eagerly": False,
    "sample_collector": "<class 'ray.rllib.evaluation.collectors.simple_list_collector.SimpleListCollector'>",  # repr string, not the class object
    "preprocessor_pref": "deepmind",
    "observation_filter": "NoFilter",
    "update_worker_filter_stats": True,
    "use_worker_filter_stats": True,
    "sampler_perf_stats_ema_coef": None,
    # --- Learners and training ---
    "num_learners": 0,
    "num_gpus_per_learner": 0,
    "num_cpus_per_learner": 1,
    "local_gpu_idx": 0,
    "max_requests_in_flight_per_learner": 3,
    "gamma": 0.99,
    "lr": 5e-05,
    "grad_clip": None,
    "grad_clip_by": "global_norm",
    "train_batch_size_per_learner": None,
    "train_batch_size": 4000,
    "num_epochs": 30,
    "minibatch_size": 128,
    "shuffle_batch_per_epoch": True,
    # --- Model options ---
    "model": {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "tanh",
        "fcnet_weights_initializer": None,
        "fcnet_weights_initializer_config": None,
        "fcnet_bias_initializer": None,
        "fcnet_bias_initializer_config": None,
        "conv_filters": None,
        "conv_activation": "relu",
        "conv_kernel_initializer": None,
        "conv_kernel_initializer_config": None,
        "conv_bias_initializer": None,
        "conv_bias_initializer_config": None,
        "conv_transpose_kernel_initializer": None,
        "conv_transpose_kernel_initializer_config": None,
        "conv_transpose_bias_initializer": None,
        "conv_transpose_bias_initializer_config": None,
        "post_fcnet_hiddens": [],
        "post_fcnet_activation": "relu",
        "post_fcnet_weights_initializer": None,
        "post_fcnet_weights_initializer_config": None,
        "post_fcnet_bias_initializer": None,
        "post_fcnet_bias_initializer_config": None,
        "free_log_std": False,
        "log_std_clip_param": 20.0,
        "no_final_linear": False,
        "vf_share_layers": False,
        "use_lstm": False,
        "max_seq_len": 20,
        "lstm_cell_size": 256,
        "lstm_use_prev_action": False,
        "lstm_use_prev_reward": False,
        "lstm_weights_initializer": None,
        "lstm_weights_initializer_config": None,
        "lstm_bias_initializer": None,
        "lstm_bias_initializer_config": None,
        "_time_major": False,
        "use_attention": False,
        "attention_num_transformer_units": 1,
        "attention_dim": 64,
        "attention_num_heads": 1,
        "attention_head_dim": 32,
        "attention_memory_inference": 50,
        "attention_memory_training": 50,
        "attention_position_wise_mlp_dim": 32,
        "attention_init_gru_gate_bias": 2.0,
        "attention_use_n_prev_actions": 0,
        "attention_use_n_prev_rewards": 0,
        "framestack": True,
        "dim": 84,
        "grayscale": False,
        "zero_mean": True,
        "custom_model": None,
        "custom_model_config": {},
        "custom_action_dist": None,
        "custom_preprocessor": None,
        "encoder_latent_dim": None,
        "always_check_shapes": False,
        "lstm_use_prev_action_reward": -1,
        "_use_default_native_models": -1,
        "_disable_preprocessor_api": False,
        "_disable_action_flattening": False
    },
    # --- Learner connector, callbacks, exploration and policies ---
    "_learner_connector": None,
    "add_default_connectors_to_learner_pipeline": True,
    "learner_config_dict": {},
    "optimizer": {},
    "_learner_class": None,
    "callbacks_class": "<class 'ray.rllib.algorithms.callbacks.DefaultCallbacks'>",  # repr string, not the class object
    "explore": True,
    "enable_rl_module_and_learner": True,
    "enable_env_runner_and_connector_v2": True,
    "_prior_exploration_config": {"type": "StochasticSampling"},
    "count_steps_by": "env_steps",
    "policies": {"default_policy": [None, None, None, None]},
    "policy_map_capacity": 100,
    "policy_mapping_fn": "<function AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN at 0x0000000036439DA0>",  # repr string with a memory address; not a callable
    "policies_to_train": None,
    "policy_states_are_swappable": False,
    "observation_fn": None,
    # --- Offline data input/output ---
    "input_": "sampler",
    "input_read_method": "read_parquet",
    "input_read_method_kwargs": {},
    "input_read_schema": {},
    "input_read_episodes": False,
    "input_read_sample_batches": False,
    "input_read_batch_size": None,
    "input_filesystem": None,
    "input_filesystem_kwargs": {},
    "input_compress_columns": ["obs", "new_obs"],
    "input_spaces_jsonable": True,
    "materialize_data": False,
    "materialize_mapped_data": True,
    "map_batches_kwargs": {},
    "iter_batches_kwargs": {},
    "prelearner_class": None,
    "prelearner_buffer_class": None,
    "prelearner_buffer_kwargs": {},
    "prelearner_module_synch_period": 10,
    "dataset_num_iters_per_learner": None,
    "input_config": {},
    "actions_in_input_normalized": False,
    "postprocess_inputs": False,
    "shuffle_buffer_size": 0,
    "output": None,
    "output_config": {},
    "output_compress_columns": ["obs", "new_obs"],
    "output_max_file_size": 67108864,
    "output_max_rows_per_file": None,
    "output_write_method": "write_parquet",
    "output_write_method_kwargs": {},
    "output_filesystem": None,
    "output_filesystem_kwargs": {},
    "output_write_episodes": True,
    "offline_sampling": False,
    # --- Evaluation ---
    "evaluation_interval": 10,
    "evaluation_duration": 3,
    "evaluation_duration_unit": "episodes",
    "evaluation_sample_timeout_s": 120.0,
    "evaluation_parallel_to_training": False,
    "evaluation_force_reset_envs_before_iteration": True,
    "evaluation_config": None,
    "off_policy_estimation_methods": {},
    "ope_split_batch_by_episode": True,
    "evaluation_num_env_runners": 0,
    "custom_evaluation_function": None,
    "in_evaluation": False,
    # --- Reporting, logging and checkpointing ---
    "sync_filters_on_rollout_workers_timeout_s": 10.0,
    "keep_per_episode_custom_metrics": False,
    "metrics_episode_collection_timeout_s": 60.0,
    "metrics_num_episodes_for_smoothing": 100,
    "min_time_s_per_iteration": None,
    "min_train_timesteps_per_iteration": 0,
    "min_sample_timesteps_per_iteration": 0,
    "log_gradients": True,
    "export_native_model_files": False,
    "checkpoint_trainable_policies_only": False,
    "logger_creator": None,
    "logger_config": None,
    "log_level": "WARN",
    "log_sys_usage": True,
    "fake_sampler": False,
    "seed": None,
    "_run_training_always_in_thread": False,
    "_evaluation_parallel_to_training_wo_thread": False,
    # --- Fault tolerance ---
    "restart_failed_env_runners": True,
    "ignore_env_runner_failures": False,
    "max_num_env_runner_restarts": 1000,
    "delay_between_env_runner_restarts_s": 60.0,
    "restart_failed_sub_environments": False,
    "num_consecutive_env_runner_failures_tolerance": 100,
    "env_runner_health_probe_timeout_s": 30.0,
    "env_runner_restore_timeout_s": 1800.0,
    # --- RLModule spec, per-module overrides and internal flags ---
    "_model_config": {},
    "_rl_module_spec": None,
    "algorithm_config_overrides_per_module": {},
    "_per_module_overrides": {},
    "_torch_grad_scaler_class": None,
    "_torch_lr_scheduler_classes": None,
    "_tf_policy_handles_more_than_one_loss": False,
    "_disable_preprocessor_api": False,
    "_disable_action_flattening": False,
    "_disable_initialize_loss_from_dummy_batch": False,
    "_dont_auto_sync_env_runner_states": False,
    "_is_frozen": True,
    # --- Keys that appear to be legacy/deprecated (most hold the -1 placeholder) ---
    "enable_connectors": -1,
    "simple_optimizer": False,
    "monitor": -1,
    "evaluation_num_episodes": -1,
    "metrics_smoothing_episodes": -1,
    "timesteps_per_iteration": -1,
    "min_iter_time_s": -1,
    "collect_metrics_timeout": -1,
    "min_time_s_per_reporting": -1,
    "min_train_timesteps_per_reporting": -1,
    "min_sample_timesteps_per_reporting": -1,
    "input_evaluation": -1,
    "policy_map_cache": -1,
    "worker_cls": -1,
    "synchronize_filters": -1,
    "enable_async_evaluation": -1,
    "custom_async_evaluation_function": -1,
    "_enable_rl_module_api": -1,
    "auto_wrap_old_gym_envs": -1,
    "always_attach_evaluation_results": -1,
    "buffer_size": -1,
    "prioritized_replay": -1,
    "learning_starts": -1,
    "replay_batch_size": -1,
    "replay_sequence_length": None,
    "replay_mode": -1,
    "prioritized_replay_alpha": -1,
    "prioritized_replay_beta": -1,
    "prioritized_replay_eps": -1,
    "_disable_execution_plan_api": -1,
    # --- PPO-specific settings ---
    "use_critic": True,
    "use_gae": True,
    "lambda_": 1.0,
    "use_kl_loss": True,
    "kl_coeff": 0.2,
    "kl_target": 0.01,
    "vf_loss_coeff": 1.0,
    "entropy_coeff": 0.0,
    "clip_param": 0.3,
    "vf_clip_param": 10.0,
    "entropy_coeff_schedule": None,
    "lr_schedule": None,
    "sgd_minibatch_size": -1,
    "vf_share_layers": -1
}
