In [6]:
import gymnasium as gym


# Define the environment
class MyWrapper(gym.Wrapper):
    """Thin pass-through wrapper that builds and wraps a Gymnasium environment.

    The wrapped environment is created here (``CartPole-v1`` by default), so
    callers can simply write ``MyWrapper()``. ``reset`` and ``step`` forward to
    the wrapped environment unchanged; they are the hook points for any future
    observation/reward shaping.
    """

    def __init__(self, env_id='CartPole-v1'):
        """Create the wrapped environment.

        Args:
            env_id: Gymnasium environment id passed to ``gym.make``.
                Defaults to ``'CartPole-v1'``, the environment this notebook uses.
        """
        env = gym.make(env_id)
        # gym.Wrapper.__init__ stores the wrapped env as ``self.env``.
        super().__init__(env)

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment.

        Args:
            seed: Optional RNG seed. Forwarded to the wrapped env so that
                seeding through the wrapper (e.g. by an SB3 VecEnv) takes effect
                instead of being silently dropped.
            options: Optional reset options, forwarded unchanged.

        Returns:
            ``(state, info)`` as returned by the wrapped environment.
        """
        state, info = self.env.reset(seed=seed, options=options)
        return state, info

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action passed straight to the wrapped environment.

        Returns:
            The Gymnasium 5-tuple ``(state, reward, terminated, truncated,
            info)`` from the wrapped environment, passed through unchanged.
        """
        state, reward, terminated, truncated, info = self.env.step(action)
        return state, reward, terminated, truncated, info


# Instantiate the wrapped environment used by every later cell.
env = MyWrapper()
# Last expression: displays the initial (observation, info) pair.
env.reset()

(array([-0.01998622,  0.03740019, -0.04420668, -0.00132413], dtype=float32),
 {})

In [7]:
# Inspect the shape of a single observation returned by reset().
state = env.reset()[0]
print(state.shape)

(4,)


In [3]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


# Custom feature-extraction layer
class CustomCNN(BaseFeaturesExtractor):
    """Feature extractor that treats a flat observation as a 1x1 "image".

    Each observation is reshaped to ``[b, c, 1, 1]``, where ``c`` is its total
    number of elements. It then passes through two 1x1 convolutions and a
    linear layer. Because the spatial size is 1x1 and ``kernel_size=1``, the
    convolutions act as per-sample dense layers over the ``c`` channels.

    Args:
        observation_space: Observation space whose ``shape`` gives the
            per-sample observation layout; all of its elements become channels.
        hidden_dim: Width of every hidden layer and of the output features
            (passed to ``BaseFeaturesExtractor`` as the features dimension).
    """

    def __init__(self, observation_space, hidden_dim):
        import math

        super().__init__(observation_space, hidden_dim)

        # forward() flattens every element of an observation into the channel
        # axis, so the input channel count must be the total element count,
        # not just shape[0]. For 1-D spaces such as CartPole's (4,) the two agree.
        in_channels = math.prod(observation_space.shape)

        self.sequential = torch.nn.Sequential(

            # [b, c, 1, 1] -> [b, h, 1, 1]
            torch.nn.Conv2d(in_channels=in_channels,
                            out_channels=hidden_dim,
                            kernel_size=1,
                            stride=1,
                            padding=0),
            torch.nn.ReLU(),

            # [b, h, 1, 1] -> [b, h, 1, 1]
            torch.nn.Conv2d(hidden_dim,
                            hidden_dim,
                            kernel_size=1,
                            stride=1,
                            padding=0),
            torch.nn.ReLU(),

            # [b, h, 1, 1] -> [b, h]
            torch.nn.Flatten(),

            # [b, h] -> [b, h]
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
        )

    def forward(self, state):
        """Map a batch of observations to ``[b, hidden_dim]`` features.

        Args:
            state: Tensor whose first dimension is the batch size ``b``;
                all remaining dimensions are flattened into channels.

        Returns:
            Tensor of shape ``[b, hidden_dim]``.
        """
        b = state.shape[0]
        # [b, ...] -> [b, c, 1, 1]
        state = state.reshape(b, -1, 1, 1)
        return self.sequential(state)


# Build PPO on the wrapped env, plugging in the custom feature extractor
# with an 8-unit hidden width.
model = PPO('CnnPolicy',
            env,
            policy_kwargs={
                'features_extractor_class': CustomCNN,
                'features_extractor_kwargs': {
                    'hidden_dim': 8
                },
            },
            # Seed SB3's pseudo-random generators so a Restart & Run All
            # repeats the same run (GPU kernels may still add nondeterminism).
            seed=0,
            verbose=0)

model

<stable_baselines3.ppo.ppo.PPO at 0x1e986579970>

In [4]:
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the untrained policy as a baseline: (mean, std) episode reward over 10 stochastic episodes.
evaluate_policy(model=model,
                env=env,
                n_eval_episodes=10,
                deterministic=False)



(24.5, 10.433120338613946)

In [5]:
# Train for 20,000 environment steps, then save the model to disk.
model.learn(progress_bar=True, total_timesteps=20_000)
model.save('models/自定义特征抽取层')

Output()

In [5]:
# Reload the saved model and evaluate it with the same protocol as the baseline.
model = PPO.load(path='models/自定义特征抽取层')
evaluate_policy(model, env,
                n_eval_episodes=10,
                deterministic=False)

  return self.fget.__get__(instance, owner)()


(364.7, 133.57997604431586)