# 自定义 policy 模块中的 features_extractor 和 mlp_extractor

In [1]:
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [2]:
# 1. Define the custom features extractor (a simple mlp_extractor customization follows below)
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Two-layer MLP features extractor for ``Box`` observations.

    Applies ``Linear -> ReLU -> Linear -> ReLU`` to the flattened observation
    and produces a vector of width ``features_dim``.

    :param observation_space: Observation space of the environment. Its
        flattened size determines the input width of the first linear layer.
    :param features_dim: Width of the feature vector returned by ``forward``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # Size the first layer from the flattened observation rather than
        # shape[0], so multi-dimensional Box spaces also get the right width.
        # For a 1-D space (e.g. CartPole) both values are the same.
        n_input_features = gym.spaces.flatdim(observation_space)
        # Number of trailing tensor dimensions that belong to one observation.
        self._obs_ndim = len(observation_space.shape)

        self.layers = nn.Sequential(
            nn.Linear(n_input_features, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map observations to features whose last dimension is ``features_dim``.

        :param observations: Observation tensor; any leading (batch) dimensions
            are preserved.
        :return: Feature tensor produced by ``self.layers``.
        """
        # Collapse only the observation's own dimensions into one axis; this is
        # a no-op when the observation space is 1-D.
        flat_obs = th.flatten(observations, start_dim=-self._obs_ndim)
        return self.layers(flat_obs)

# 2. Assemble the policy_kwargs that are handed to PPO
policy_kwargs = {
    "features_extractor_class": CustomCombinedExtractor,
    "features_extractor_kwargs": {"features_dim": 256},
    # The mlp_extractor can be customized right here by adding a net_arch entry:
    # "net_arch": {
    #     "pi": [128, 64],  # actor hidden layers: 256 -> 128 -> 64
    #     "vf": [64, 64],   # critic hidden layers: 256 -> 64 -> 64
    # },
}

# 3. Create the environment and the model, then train
env = gym.make("CartPole-v1")
model = PPO(policy="MlpPolicy", env=env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.1     |
|    ep_rew_mean     | 22.1     |
| time/              |          |
|    fps             | 368      |
|    iterations      | 1        |
|    time_elapsed    | 5        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 31.7        |
|    ep_rew_mean          | 31.7        |
| time/                   |             |
|    fps                  | 281         |
|    iterations           | 2           |
|    time_elapsed         | 14          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.015045412 |
|    clip_fraction        | 0.165       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.683      |
|    explained_variance   | 0.000748    |
|    learning_rate        | 0.

<stable_baselines3.ppo.ppo.PPO at 0x21b2e811a90>

In [3]:
import torch as th
import numpy as np

# 1. Sample an observation from the environment's observation space
random_obs = env.observation_space.sample()

# 2. Turn it into a float tensor with a leading batch dimension,
#    placed on the same device as the model
obs_tensor = th.as_tensor(random_obs).float().unsqueeze(0).to(model.device)

# 3. Query the action distribution by hand.
# Put the policy in evaluation mode first: layers such as Dropout behave
# randomly in training mode, which would make the printed probabilities
# change from run to run for the same observation.
model.policy.set_training_mode(False)
with th.no_grad():
    # Action distribution for the batched observation
    dist = model.policy.get_distribution(obs_tensor)

    # Move the probabilities back to the CPU so numpy/print can handle them;
    # [0] drops the batch dimension added above
    probs = dist.distribution.probs.cpu().numpy()[0]

print("--- 状态测试 ---")
print(f"输入状态: {random_obs}")
print(f"动作概率: 向左(0): {probs[0]:.2%}, 向右(1): {probs[1]:.2%}")

# Pick the action with the highest probability
action = np.argmax(probs)
print(f"最终决策: {action}")

--- 状态测试 ---
输入状态: [-1.3300587  0.816699   0.0824696  1.2880276]
动作概率: 向左(0): 14.47%, 向右(1): 85.53%
最终决策: 1


In [4]:
# Echo the observation sampled in the previous cell as the cell's output
random_obs

array([-1.3300587,  0.816699 ,  0.0824696,  1.2880276], dtype=float32)

In [6]:
# Display the policy's module tree to see which extractor modules it contains
model.policy

ActorCriticPolicy(
  (features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (pi_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (vf_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    

In [8]:
# Compare the identities of the three features-extractor attributes:
# equal ids mean all three names refer to one and the same module object
id(model.policy.features_extractor), id(model.policy.vf_features_extractor), id(model.policy.pi_features_extractor)

(2315767588112, 2315767588112, 2315767588112)

In [9]:
# 1. 定义自定义特征提取器 和 mlp_extractor 自定义

import torch as th
import torch.nn as nn
from stable_baselines3.common.policies import ActorCriticPolicy

# 1. Define a fully custom MlpExtractor class
class MyCustomMlpExtractor(nn.Module):
    """Separate actor and critic latent networks with independent layouts.

    The actor branch is ``Linear -> ReLU -> Dropout -> Linear -> ReLU`` and the
    critic branch is ``Linear -> Tanh -> Linear -> Tanh``. With the default
    arguments the architecture is 128 hidden units, 64-wide latents and a
    dropout probability of 0.2.

    :param feature_dim: Width of the incoming feature vector.
    :param latent_dim_pi: Output width of the actor branch.
    :param latent_dim_vf: Output width of the critic branch.
    :param hidden_dim: Width of the hidden layer in both branches.
    :param dropout: Dropout probability used in the actor branch.
    """

    def __init__(
        self,
        feature_dim: int,
        latent_dim_pi: int = 64,
        latent_dim_vf: int = 64,
        hidden_dim: int = 128,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Output widths of the two branches; the policy needs them to know the
        # dimension that is fed to its action/value heads.
        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf

        # Actor branch: a Dropout layer is included as an example.
        # NOTE(review): Dropout is only active in training mode, so actor
        # outputs differ between training and evaluation mode - confirm this
        # is acceptable for the algorithm in use.
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, self.latent_dim_pi),
            nn.ReLU(),
        )

        # Critic branch: may use a different activation function than the actor.
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, self.latent_dim_vf),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor):
        """Return the pair ``(policy_latent, value_latent)`` for ``features``."""
        return self.forward_actor(features), self.forward_critic(features)

    # SB3 needs the two methods below for compatibility.
    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return the actor latent computed by ``policy_net``."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return the critic latent computed by ``value_net``."""
        return self.value_net(features)

# 2. Define a new Policy class that uses this extractor
class CustomPolicy(ActorCriticPolicy):
    """ActorCriticPolicy subclass whose ``mlp_extractor`` is a ``MyCustomMlpExtractor``."""

    def _build_mlp_extractor(self) -> None:
        """Build ``self.mlp_extractor`` from the custom class instead of the default one."""
        # The extractor's input width is the policy's features_dim.
        custom_extractor = MyCustomMlpExtractor(self.features_dim)
        self.mlp_extractor = custom_extractor

# 3. Build a PPO model with the custom policy class.
# This rebinds `model`, so the later cells inspect this new, untrained model.
model = PPO(policy=CustomPolicy, env=env, policy_kwargs=policy_kwargs, verbose=1)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [10]:
# Display the module tree of the model built with CustomPolicy
model.policy

CustomPolicy(
  (features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (pi_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (vf_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (mlp_extractor): MyCustomMlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=128, out_featur