In [1]:
import gymnasium as gym
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

import EvnOneStock

import pandas as pd
import numpy as np
from datetime import datetime

In [3]:
# 1. 定义自定义特征提取器 和 简单的 mlp_extractor 自定义
class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]
        
        self.layers = nn.Sequential(
            nn.Linear(n_input_channels, 128),
            nn.ReLU(),
            nn.Linear(128, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.layers(observations)

# 2. 配置 policy_kwargs
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(features_dim=256),
    # 自定义 mlp_extractor 就在这里：
    # net_arch=dict(
    #     pi=[128, 64],  # 策略网络 (Actor) 的隐藏层：256 -> 128 -> 64
    #     vf=[64, 64]    # 价值网络 (Critic) 的隐藏层：256 -> 64 -> 64
    # )
)

# 3. 创建模型

df = pd.read_csv("ohlcv_000001.SZ.csv").fillna(0)

env = EvnOneStock.SingleStockTradingEnv(df)
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=1e5)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




-----------------------------
| time/              |      |
|    fps             | 190  |
|    iterations      | 1    |
|    time_elapsed    | 10   |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 3.4e+03      |
|    ep_rew_mean          | -1.35e+03    |
| time/                   |              |
|    fps                  | 178          |
|    iterations           | 2            |
|    time_elapsed         | 22           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0010557764 |
|    clip_fraction        | 0.0161       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.43        |
|    explained_variance   | -0.00382     |
|    learning_rate        | 0.0003       |
|    loss                 | 3.98         |
|    n_updates            | 10           |
|    policy_grad

<stable_baselines3.ppo.ppo.PPO at 0x1c17240bd90>