# 自定义 policy 模块中的 features_extractor 和 mlp_extractor

In [1]:
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

import EvnOneStock

import pandas as pd
import numpy as np
from datetime import datetime

In [2]:
# 1. Custom features extractor (the original MLP, generalised to any Box observation shape).
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Two-layer MLP features extractor for a single ``Box`` observation space.

    Each observation is flattened and mapped ``n_inputs -> 128 -> features_dim``,
    with a ReLU after each linear layer. (The name resembles SB3's
    ``CombinedExtractor`` for Dict spaces, but this class takes one Box space.)

    :param observation_space: the env's Box observation space.
    :param features_dim: size of the extracted feature vector (default 256).
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # Total number of scalars per observation: equals shape[0] for 1-D spaces,
        # and also supports multi-dimensional Box observations.
        n_input_channels = int(np.prod(observation_space.shape))
        self.layers = nn.Sequential(
            nn.Linear(n_input_channels, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to ``(batch, features_dim)`` features."""
        # Collapse trailing observation dims into one; (batch, n) and unbatched
        # (n,) inputs pass through unchanged.
        if observations.dim() > 2:
            observations = observations.flatten(start_dim=1)
        return self.layers(observations)

# 2. policy_kwargs: custom features extractor + separate actor/critic MLP heads.
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(features_dim=256),
    # For continuous (Box) actions SB3's PPO samples from a diagonal Gaussian.
    # No activation is applied after the action output; the sampled action is
    # clipped to the Box bounds before it is passed to env.step().
    net_arch=dict(
        pi=[128, 64],  # actor hidden layers: 256 -> 128 -> 64
        vf=[128, 64],  # critic hidden layers: 256 -> 128 -> 64
    ),
    # activation_fn applies to the pi/vf layers built from net_arch; the custom
    # features extractor above keeps its own ReLU layers. Tanh is often smoother
    # than ReLU for continuous control.
    activation_fn=th.nn.Tanh,
)

DATA_CSV = "ohlcv_000001.SZ.csv"
# PPO always collects at least one full rollout (n_steps=2048 by default), so a
# smaller value (e.g. 10) still trains on 2048 steps; state that explicitly.
TOTAL_TIMESTEPS = 2048
SEED = 42

raw_df = pd.read_csv(DATA_CSV)
# Missing volume -> 0 (no trades). Missing prices are forward-filled rather than
# zeroed, because a 0 price would likely break any return/ratio the env computes.
df = raw_df.fillna({"volume": 0}).ffill()
assert not df.isnull().any().any(), "Data still contains NaN after cleaning"

env = EvnOneStock.SingleStockTradingEnv(df)
# Fail fast with a readable message if the env emits NaN/Inf observations.
initial_obs, _ = env.reset()
assert np.isfinite(initial_obs).all(), "Initial observation contains NaN/Inf"

# 3. Create the model.
# SB3 picks a continuous or discrete action distribution from env.action_space.
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1, seed=SEED)

model.learn(total_timesteps=TOTAL_TIMESTEPS)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




-----------------------------
| time/              |      |
|    fps             | 197  |
|    iterations      | 1    |
|    time_elapsed    | 10   |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x2469c644ec0>

In [26]:
# Return every row of the RAW csv that has at least one missing value.
# `df` in the training cells has already been passed through fillna(...), so
# checking it would always report 0 rows on a fresh kernel.
raw_ohlcv = pd.read_csv("ohlcv_000001.SZ.csv")
null_rows = raw_ohlcv[raw_ohlcv.isnull().any(axis=1)]

print(f"总共有 {len(null_rows)} 行包含空值。")
null_rows

总共有 1 行包含空值。
         date     open     high      low    close  volume
156  19911111  45.5264  45.5264  45.5264  45.5264     NaN


In [6]:
# 1. Custom features extractor (re-defines the class from the first experiment cell
#    with the same generalised MLP); the mlp_extractor customisation is in policy_kwargs.
class CustomCombinedExtractor(BaseFeaturesExtractor):
    """Two-layer MLP features extractor for a single ``Box`` observation space.

    Each observation is flattened and mapped ``n_inputs -> 128 -> features_dim``,
    with a ReLU after each linear layer. (The name resembles SB3's
    ``CombinedExtractor`` for Dict spaces, but this class takes one Box space.)

    :param observation_space: the env's Box observation space.
    :param features_dim: size of the extracted feature vector (default 256).
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # Total number of scalars per observation: equals shape[0] for 1-D spaces,
        # and also supports multi-dimensional Box observations.
        n_input_channels = int(np.prod(observation_space.shape))
        self.layers = nn.Sequential(
            nn.Linear(n_input_channels, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to ``(batch, features_dim)`` features."""
        # Collapse trailing observation dims into one; (batch, n) and unbatched
        # (n,) inputs pass through unchanged.
        if observations.dim() > 2:
            observations = observations.flatten(start_dim=1)
        return self.layers(observations)

# 2. policy_kwargs: only the features extractor is customised here, so SB3 builds
#    its default mlp_extractor on top of the 256-dim features.
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(features_dim=256),
    # To customise the default mlp_extractor, pass net_arch here, e.g.:
    # net_arch=dict(
    #     pi=[128, 64],  # actor hidden layers: 256 -> 128 -> 64
    #     vf=[64, 64]    # critic hidden layers: 256 -> 64 -> 64
    # )
)

# 3. Create the env and the model.
DATA_CSV = "ohlcv_000001.SZ.csv"
# PPO always collects at least one full rollout (n_steps=2048 by default), so a
# smaller value (e.g. 10) still trains on 2048 steps; state that explicitly.
TOTAL_TIMESTEPS = 2048
SEED = 42

raw_df = pd.read_csv(DATA_CSV)
# Missing volume -> 0 (no trades). Missing prices are forward-filled rather than
# zeroed, because a 0 price would likely break any return/ratio the env computes.
df = raw_df.fillna({"volume": 0}).ffill()
assert not df.isnull().any().any(), "Data still contains NaN after cleaning"

env = EvnOneStock.SingleStockTradingEnv(df)
# Fail fast with a readable message instead of the opaque
# `Expected parameter loc ... Normal(loc: nan)` error raised once a NaN
# observation reaches the policy.
initial_obs, _ = env.reset()
assert np.isfinite(initial_obs).all(), "Initial observation contains NaN/Inf"

model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1, seed=SEED)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




ValueError: Expected parameter loc (Tensor of shape (1, 1)) of distribution Normal(loc: tensor([[nan]], device='cuda:0'), scale: tensor([[1.]], device='cuda:0')) to satisfy the constraint Real(), but found invalid values:
tensor([[nan]], device='cuda:0')

In [3]:
import torch as th
import numpy as np

# 1. Sample a random observation from the env's observation space.
random_obs = env.observation_space.sample()

# 2. Add a batch dimension and move the tensor to the policy's device
#    (model.device is 'cuda' or 'cpu').
obs_tensor = th.as_tensor(random_obs).float().unsqueeze(0).to(model.device)

# 3. Inspect the action distribution for this observation.
# Switch to eval mode first: after learn() the policy may still be in train mode,
# which would keep Dropout layers (e.g. in MyCustomMlpExtractor) active.
model.policy.set_training_mode(False)
with th.no_grad():
    dist = model.policy.get_distribution(obs_tensor)
    torch_dist = dist.distribution  # the underlying torch.distributions object

    print(f"--- 状态测试 ---")
    print(f"输入状态: {random_obs}")

    if hasattr(torch_dist, "probs"):
        # Discrete action space (Categorical): one probability per action index.
        probs = torch_dist.probs.cpu().numpy()[0]
        print("动作概率: " + ", ".join(f"{idx}: {p:.2%}" for idx, p in enumerate(probs)))
    else:
        # Continuous (Box) action space: a Gaussian has no `probs`, so report
        # its mean and standard deviation instead.
        mean = torch_dist.mean.cpu().numpy()[0]
        std = torch_dist.stddev.cpu().numpy()[0]
        print(f"动作均值: {mean}, 标准差: {std}")

    # Deterministic action = the distribution's mode (argmax for discrete,
    # mean for a plain Gaussian; not yet clipped to the Box bounds).
    action = dist.mode().cpu().numpy()[0]
    print(f"最终决策: {action}")

--- 状态测试 ---
输入状态: [-1.3300587  0.816699   0.0824696  1.2880276]
动作概率: 向左(0): 14.47%, 向右(1): 85.53%
最终决策: 1


In [4]:
# Display the observation sampled in the distribution-inspection cell.
random_obs

array([-1.3300587,  0.816699 ,  0.0824696,  1.2880276], dtype=float32)

In [6]:
# Print the policy's module tree (features extractors, mlp_extractor, output heads).
model.policy

ActorCriticPolicy(
  (features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (pi_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (vf_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    

In [8]:
# Compare object ids: identical ids mean the actor and critic share one
# features-extractor instance (the recorded output shows all three are equal).
id(model.policy.features_extractor), id(model.policy.vf_features_extractor), id(model.policy.pi_features_extractor)

(2315767588112, 2315767588112, 2315767588112)

In [5]:
# 1. 定义自定义特征提取器 和 mlp_extractor 自定义

import torch as th
import torch.nn as nn
from stable_baselines3.common.policies import ActorCriticPolicy

# 1. A fully custom replacement for SB3's MlpExtractor.
class MyCustomMlpExtractor(nn.Module):
    """Two-branch (actor / critic) network applied on top of the extracted features.

    SB3's ``ActorCriticPolicy`` reads ``latent_dim_pi`` / ``latent_dim_vf`` to size
    the final ``action_net`` / ``value_net``, and calls ``forward_actor`` /
    ``forward_critic`` when only one branch is needed.

    :param feature_dim: size of the features produced by the features extractor.
    :param last_layer_dim_pi: output size of the policy branch (default 64).
    :param last_layer_dim_vf: output size of the value branch (default 64).
    :param dropout_p: dropout probability in the policy branch (default 0.2).
        NOTE(review): SB3's PPO collects rollouts with the policy in eval mode
        and optimises it in train mode (``set_training_mode``). With dropout > 0
        the training-time log-probs then differ from the rollout ones, so the PPO
        ratio is not 1 at the first epoch. Use ``dropout_p=0.0`` if that matters.
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
        dropout_p: float = 0.2,
    ):
        super().__init__()
        # Output sizes the policy uses to build action_net / value_net.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy branch. The Dropout layer is always present (p=0 is an identity)
        # so the Sequential indices, and hence state_dict keys, stay stable.
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(128, self.latent_dim_pi),
            nn.ReLU(),
        )

        # Value branch: may use a different activation than the policy branch.
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, self.latent_dim_vf),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor):
        """Return ``(policy_latent, value_latent)`` for a batch of features."""
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Policy-branch latent, shape ``(batch, latent_dim_pi)``."""
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Value-branch latent, shape ``(batch, latent_dim_vf)``."""
        return self.value_net(features)

# 2. A Policy subclass that plugs the custom extractor into SB3.
class CustomPolicy(ActorCriticPolicy):
    """``ActorCriticPolicy`` that uses ``MyCustomMlpExtractor`` instead of the default ``MlpExtractor``.

    Because this hook is overridden, any ``net_arch`` passed in ``policy_kwargs``
    is not used to build the mlp_extractor.
    """

    def _build_mlp_extractor(self) -> None:
        # SB3 calls this hook while constructing the policy; `self.features_dim`
        # is the output size of the features extractor.
        custom_extractor = MyCustomMlpExtractor(self.features_dim)
        self.mlp_extractor = custom_extractor

# 3. Build PPO with the custom policy class. `policy_kwargs` comes from an earlier
#    cell (custom features extractor); a `net_arch` in it would be ignored because
#    CustomPolicy builds its own mlp_extractor.
model = PPO(
    policy=CustomPolicy,
    env=env,
    policy_kwargs=policy_kwargs,
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [10]:
model.policy

CustomPolicy(
  (features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (pi_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (vf_features_extractor): CustomCombinedExtractor(
    (layers): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): ReLU()
    )
  )
  (mlp_extractor): MyCustomMlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=128, out_featur

In [20]:
# Smoke-test the raw env for NaN/Inf. Reset first: the env may have been left
# mid- or post-episode by PPO's rollouts, and stepping a finished episode without
# reset() is undefined in Gymnasium. Then hold a fixed action and check every
# observation and reward until the episode ends (bounded by the data length).
action = np.array([0.5], dtype=np.float32)
state, _ = env.reset()
assert not np.any(np.isnan(state)), "Observation contains NaN!"
assert not np.any(np.isinf(state)), "Observation contains Inf!"
for _step in range(len(df)):
    state, reward, terminated, truncated, info = env.step(action)
    assert not np.any(np.isnan(state)), "Observation contains NaN!"
    assert not np.any(np.isinf(state)), "Observation contains Inf!"
    assert np.isfinite(reward), "Reward is not finite!"
    if terminated or truncated:
        break

In [1]:
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

import EvnOneStock

import pandas as pd
import numpy as np

In [15]:
# Manual rollout: step the env with a fixed action cycle and log reward, flags and
# selected `info` fields per step into a DataFrame.
raw_df = pd.read_csv("ohlcv_000001.SZ.csv")
# The raw file has a row with missing volume (date 19911111); fill it with 0 and
# forward-fill any missing prices so no NaN reaches the observation window.
df = raw_df.fillna({"volume": 0}).ffill()

env = EvnOneStock.SingleStockTradingEnv(df=df, lookback_n=5)
env.reset()

# Action at step i is ACTION_CYCLE[i % 3]: flat -> half position -> full position.
ACTION_CYCLE = (0.0, 0.5, 1.0)
# Fields copied from the `info` dict returned by env.step(), in column order.
# An explicit list avoids positional slicing of dict keys, which breaks as soon
# as a key is added or re-ordered.
INFO_KEYS = (
    "portfolio_value",
    "position",
    "drawdown",
    "current_step",
    "prev_close",
    "open_price",
    "close_price",
    "prev_position",
    "delta_position",
)
N_STEPS = 6

records = []
for i in range(N_STEPS):
    action = np.array([ACTION_CYCLE[i % len(ACTION_CYCLE)]], dtype=np.float32)
    state, reward, terminated, truncated, info = env.step(action)
    record = {"reward": reward, "terminated": terminated, "truncated": truncated}
    record.update({key: info[key] for key in INFO_KEYS})
    records.append(record)
    if terminated or truncated:
        # Do not keep stepping a finished episode without reset().
        break

info_pd = pd.DataFrame(records, columns=["reward", "terminated", "truncated", *INFO_KEYS])

In [16]:
# Per-step log (reward, terminated/truncated flags, env info) from the manual rollout.
info_pd

Unnamed: 0,reward,terminated,truncated,portfolio_value,position,drawdown,current_step,prev_close,open_price,close_price,prev_position,delta_position
0,0.0,False,False,1000000.0,0.0,0.0,5,47.8,47.56,47.56,0.0,0.0
1,0.0,False,False,1000000.0,0.5,0.0,6,47.56,47.56,47.56,0.0,0.5
2,-0.010093,False,False,994953.742641,1.0,0.005046,7,47.56,47.08,47.08,0.5,0.5
3,-0.034708,False,False,980160.462695,0.0,0.01984,8,47.08,46.38,46.38,1.0,-1.0
4,-0.01984,False,False,980160.462695,0.5,0.01984,9,46.38,46.15,46.15,0.0,0.5
5,-0.024774,False,False,977718.026006,1.0,0.022282,10,46.15,45.92,45.92,0.5,0.5


In [13]:
# First 10 rows of the OHLCV data, to line up with the rollout's prices by step.
df.head(10)

Unnamed: 0,date,open,high,low,close,volume
0,19910403,49.0,49.0,49.0,49.0,1.0
1,19910404,48.76,48.76,48.76,48.76,3.0
2,19910405,48.52,48.52,48.52,48.52,2.0
3,19910408,48.04,48.04,48.04,48.04,2.0
4,19910409,47.8,47.8,47.8,47.8,4.0
5,19910410,47.56,47.56,47.56,47.56,15.0
6,19910411,47.56,47.56,47.56,47.56,0.0
7,19910412,47.08,47.08,47.08,47.08,8.0
8,19910416,46.38,46.38,46.38,46.38,2.0
9,19910417,46.15,46.15,46.15,46.15,1.0


In [17]:
# Last observation returned by env.step(). With lookback_n=5 it appears to hold
# 5 rows of (open, high, low, close, volume) plus one trailing value.
# NOTE(review): the trailing value looks like the current position; confirm in EvnOneStock.
state

array([47.56, 47.56, 47.56, 47.56,  0.  , 47.08, 47.08, 47.08, 47.08,
        8.  , 46.38, 46.38, 46.38, 46.38,  2.  , 46.15, 46.15, 46.15,
       46.15,  1.  , 45.92, 45.92, 45.92, 45.92,  4.  ,  1.  ],
      dtype=float32)

In [5]:
# Smoke test: constructing the env with lookback_n=5 succeeds (the instance is discarded).
EvnOneStock.SingleStockTradingEnv(df=df, lookback_n=5)

<EvnOneStock.SingleStockTradingEnv at 0x202c5102d50>

In [7]:
# Rows with labels 0..4 (.loc slicing is label-based and end-inclusive);
# presumably the first lookback window before current_step=5 — confirm.
df.loc[0:4]

Unnamed: 0,date,open,high,low,close,volume
0,19910403,49.0,49.0,49.0,49.0,1.0
1,19910404,48.76,48.76,48.76,48.76,3.0
2,19910405,48.52,48.52,48.52,48.52,2.0
3,19910408,48.04,48.04,48.04,48.04,2.0
4,19910409,47.8,47.8,47.8,47.8,4.0


In [6]:
# Scratch check of Python's modulo operator (as used in the i % 3 action cycle); safe to delete.
9%2

1

In [4]:
# Load a per-step log CSV and display it. The columns match the env's step info;
# the filename suffix looks like a YYYYmmddHHMMSS timestamp — TODO confirm which run wrote it.
return_pd = pd.read_csv("EvnOneStock_20260108152545.csv")
return_pd

Unnamed: 0,reward,terminated,truncated,portfolio_value,position,drawdown,current_step,prev_close,open_price,close_price,prev_position,delta_position
0,0.000000,False,False,1.000000e+06,0.062328,0.000000,20,61.2600,60.9500,60.9500,0.000000,0.062328
1,-0.000634,False,False,9.996830e+05,0.000000,0.000317,21,60.9500,60.6400,60.6400,0.062328,-0.062328
2,-0.000317,False,False,9.996830e+05,0.000000,0.000317,22,60.6400,60.0300,60.0300,0.000000,0.000000
3,-0.000317,False,False,9.996830e+05,0.054134,0.000317,23,60.0300,59.7300,59.7300,0.000000,0.054134
4,-0.000843,False,False,9.994202e+05,0.081649,0.000580,24,59.7300,59.4400,59.4400,0.054134,0.027515
...,...,...,...,...,...,...,...,...,...,...,...,...
2043,-0.164601,False,False,4.299765e+06,0.000000,0.151222,2063,297.6043,293.4619,282.0704,1.000000,-1.000000
2044,-0.151222,False,False,4.299765e+06,0.000000,0.151222,2064,282.0704,278.3164,286.4717,0.000000,0.000000
2045,-0.151222,False,False,4.299765e+06,0.000000,0.151222,2065,286.4717,286.8600,284.7888,0.000000,0.000000
2046,-0.149534,False,False,4.303691e+06,1.000000,0.150447,2066,284.7888,283.4943,283.7532,0.000000,1.000000


In [7]:
# Portfolio value side by side with the position held at each logged step.
return_pd[["portfolio_value","position"]]

Unnamed: 0,portfolio_value,position
0,1.000000e+06,0.062328
1,9.996830e+05,0.000000
2,9.996830e+05,0.000000
3,9.996830e+05,0.054134
4,9.994202e+05,0.081649
...,...,...
2043,4.299765e+06,0.000000
2044,4.299765e+06,0.000000
2045,4.299765e+06,0.000000
2046,4.303691e+06,1.000000
