In [1]:
import gymnasium as gym


# Define the environment wrapper.
class MyWrapper(gym.Wrapper):
    """Pass-through Gymnasium wrapper.

    ``reset`` and ``step`` are forwarded to the wrapped env unchanged; the
    class is presumably a placeholder for later observation/reward shaping.
    """

    def __init__(self, env):
        super().__init__(env)

    def reset(self, **reset_kwargs):
        """Reset the wrapped env, forwarding every keyword argument as-is."""
        return self.env.reset(**reset_kwargs)

    def step(self, action):
        """Advance the wrapped env by one action and return its 5-tuple."""
        transition = self.env.step(action)
        # Unpacking insists on exactly five values (obs, reward, terminated,
        # truncated, info) before handing them back as a fresh tuple.
        obs, rew, done_terminal, done_truncated, extra = transition
        return obs, rew, done_terminal, done_truncated, extra



# Create CartPole, wrap it, and reset once before inspecting its spaces.
env = MyWrapper(gym.make('CartPole-v1'))
env.reset()

env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [2]:
import torch

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


# Custom policy network (installed as the policy's mlp_extractor below).
class CustomNetwork(torch.nn.Module):
    """Two independent Linear+ReLU heads over the same input features.

    Args:
        feature_dim: size of the incoming feature vector.
        last_layer_dim_pi: width of the actor (policy) latent output.
        last_layer_dim_vf: width of the critic (value) latent output.
    """

    def __init__(self,
                 feature_dim,
                 last_layer_dim_pi=64,
                 last_layer_dim_vf=64):
        super().__init__()

        # Latent sizes exposed for the SB3 policy, which sizes its action and
        # value output layers from them (per the SB3 custom-policy docs).
        self.latent_dim_pi, self.latent_dim_vf = last_layer_dim_pi, last_layer_dim_vf

        # Actor head is built before the critic head so parameter
        # initialisation consumes the RNG in the same order as before.
        self.policy_net = self._linear_relu(feature_dim, last_layer_dim_pi)
        self.value_net = self._linear_relu(feature_dim, last_layer_dim_vf)

    @staticmethod
    def _linear_relu(in_dim, out_dim):
        """Return ``Sequential(Linear(in_dim, out_dim), ReLU())``."""
        layers = [torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU()]
        return torch.nn.Sequential(*layers)

    def forward(self, features):
        """Return ``(actor_latent, critic_latent)`` for the given features."""
        actor_latent = self.forward_actor(features)
        critic_latent = self.forward_critic(features)
        return actor_latent, critic_latent

    def forward_actor(self, features):
        """Actor-only latent: ReLU(Linear(features))."""
        out = self.policy_net(features)
        return out

    def forward_critic(self, features):
        """Critic-only latent: ReLU(Linear(features))."""
        out = self.value_net(features)
        return out


# Use the custom policy network inside an SB3 actor-critic policy.
class CustomActorCriticPolicy(ActorCriticPolicy):
    """ActorCriticPolicy whose mlp_extractor is ``CustomNetwork``.

    Args:
        observation_space, action_space, lr_schedule: forwarded to
            ``ActorCriticPolicy`` unchanged.
        custom_param: demo extra constructor argument, supplied through
            ``policy_kwargs``; it is only printed.
        *args, **kwargs: remaining ``ActorCriticPolicy`` options.
    """

    def __init__(self, observation_space, action_space, lr_schedule,
                 custom_param, *args, **kwargs):
        # Orthogonal init must be disabled *before* the base constructor runs:
        # ActorCriticPolicy builds and initialises the network inside
        # __init__, so assigning self.ortho_init afterwards has no effect.
        # setdefault keeps an explicit policy_kwargs={'ortho_init': ...} working.
        kwargs.setdefault('ortho_init', False)
        super().__init__(observation_space, action_space, lr_schedule, *args,
                         **kwargs)
        print('custom_param=', custom_param)

    def _build_mlp_extractor(self) -> None:
        """Replace the default MLP extractor with ``CustomNetwork``."""
        self.mlp_extractor = CustomNetwork(self.features_dim)


# PPO with the custom policy; extra policy_kwargs are forwarded to
# CustomActorCriticPolicy.__init__ (here: custom_param).
model = PPO(
    policy=CustomActorCriticPolicy,
    env=env,
    policy_kwargs=dict(custom_param='lee'),
    verbose=0,
)

model

  from .autonotebook import tqdm as notebook_tqdm


custom_param= lee


<stable_baselines3.ppo.ppo.PPO at 0x23afecb1900>

In [6]:
# Show observation and action spaces side by side.
(env.observation_space, env.action_space)

(Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32),
 Discrete(2))

In [3]:
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the untrained model over 10 stochastic episodes.
evaluate_policy(model=model, env=env, n_eval_episodes=10, deterministic=False)



(17.4, 4.317406628984581)

In [7]:
# Train for 20,000 timesteps, then save.
# The directory name below means "custom policy network layer".
model.learn(total_timesteps=20_000, progress_bar=True)
model.save('models/自定义策略网络层')

In [8]:
# Reload the saved model (this re-creates CustomActorCriticPolicy, which
# prints custom_param again) and evaluate it like the untrained one.
model = PPO.load(path='models/自定义策略网络层')

evaluate_policy(model=model, env=env, n_eval_episodes=10, deterministic=False)

custom_param= lee




(195.3, 46.58980575190242)