In [6]:
import gymnasium as gym


#定义环境
class MyWrapper(gym.Wrapper):
    """Pass-through wrapper around a Gymnasium environment.

    ``reset`` and ``step`` forward straight to the wrapped environment and
    return its results; the wrapper adds no behaviour of its own.
    """

    def __init__(self, env):
        """Wrap ``env`` by delegating to the ``gym.Wrapper`` constructor."""
        super().__init__(env)

    def reset(self, **kwargs):
        """Forward all keyword arguments to the wrapped env's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        outcome = self.env.reset(**kwargs)
        return outcome

    def step(self, action):
        """Forward ``action`` to the wrapped env's ``step``.

        Returns:
            The five values produced by the wrapped environment, as a tuple
            ``(state, reward, terminated, truncated, info)``.
        """
        # Unpack explicitly so a result that is not five elements fails here.
        obs, rew, is_terminated, is_truncated, extra = self.env.step(action)
        return (obs, rew, is_terminated, is_truncated, extra)



# Create CartPole-v1 and wrap it with the pass-through wrapper in one step.
env = MyWrapper(gym.make('CartPole-v1'))

# Reset once; the return value is not used.
env.reset()

# Last expression of the cell: show the observation space.
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [11]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


#自定义特征抽取层
class CustomCNN(BaseFeaturesExtractor):
    """Custom feature extractor: two 1x1 convolutions, then a linear layer.

    ``forward`` reshapes its input to ``[batch, C, 1, 1]`` and returns a
    ``[batch, hidden_dim]`` tensor.
    """

    def __init__(self, observation_space, hidden_dim):
        """Build the extractor network.

        Args:
            observation_space: space whose ``shape[0]`` gives the number of
                input channels of the first convolution.
            hidden_dim: width of every layer; also handed to the base class
                as its second constructor argument.
        """
        super().__init__(observation_space, hidden_dim)

        n_inputs = observation_space.shape[0]
        # All convolutions are 1x1 with unit stride and no padding.
        pointwise = dict(kernel_size=1, stride=1, padding=0)

        # Layers are created in the same order as before, and the attribute
        # name `sequential` is kept, so parameter names stay unchanged.
        stack = [
            # [batch, 4, 1, 1] -> [batch, h, 1, 1]
            torch.nn.Conv2d(n_inputs, hidden_dim, **pointwise),
            torch.nn.ReLU(),
            # [b, h, 1, 1] -> [b, h, 1, 1]
            torch.nn.Conv2d(hidden_dim, hidden_dim, **pointwise),
            torch.nn.ReLU(),
            # [b, h, 1, 1] -> [b, h]
            torch.nn.Flatten(),
            # [b, h] -> [b, h]
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
        ]
        self.sequential = torch.nn.Sequential(*stack)

    def forward(self, state):
        """Reshape ``state`` to ``[batch, -1, 1, 1]`` and run it through the network."""
        batch_size = state.shape[0]
        as_image = state.reshape(batch_size, -1, 1, 1)
        return self.sequential(as_image)


# PPO agent whose policy uses CustomCNN (hidden_dim=8) as features extractor.
model = PPO(
    'CnnPolicy',
    env,
    policy_kwargs=dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(hidden_dim=8),
    ),
    verbose=0,
)

# Last expression of the cell: show the model object.
model

<stable_baselines3.ppo.ppo.PPO at 0x1fa7c29fb80>

In [4]:
from stable_baselines3.common.evaluation import evaluate_policy

# Test: evaluate the model on env for 10 episodes with deterministic=False.
evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    deterministic=False,
)



(17.0, 5.215361924162119)

In [12]:
# Train for 20,000 timesteps with a progress bar.
model.learn(
    total_timesteps=20_000,
    progress_bar=True,
)

# Save the trained model to disk.
model.save('models/自定义特征抽取层')

In [13]:
# Reload the saved model, replacing the in-memory `model`.
model = PPO.load('models/自定义特征抽取层')

# Evaluate the reloaded model with the same settings as before training.
evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    deterministic=False,
)



(114.5, 35.06351379996021)