![Cross-entropy method (CEM)](https://pic.imgdb.cn/item/649a59261ddac507cc728dbe.png)

**Reference** (adapted from): https://github.com/ucla-rlcourse/RLexample/blob/master/my_learning_agent.py

In [1]:
import gymnasium as gym
import numpy as np

In [6]:
class BinaryActionLinearPolicy(object):
    """Deterministic linear policy for an environment with two discrete actions.

    ``theta`` has length ``obs_dim + 1``: every entry except the last forms the
    weight vector ``w``, and the last entry is the scalar bias ``b``.
    """

    def __init__(self, theta):
        self.w, self.b = theta[:-1], theta[-1]

    def act(self, ob):
        """Return 1 if the linear score ``ob . w + b`` is negative, else 0."""
        # Thresholding the raw score at 0 gives the same decision as
        # thresholding sigmoid(score) at 0.5, without computing the sigmoid.
        # The sign convention (negative -> action 1) is arbitrary; CEM simply
        # learns theta with whichever sign makes it work.
        return int(ob.dot(self.w) + self.b < 0)
    
    
class ContinuousActionLinearPolicy(object):
    """Deterministic affine policy ``a = ob @ W + b`` for continuous actions.

    ``theta`` is a flat vector of length ``(obs_dim + 1) * action_dim``. The
    first ``obs_dim * action_dim`` entries are reshaped (row-major) into ``W``
    with shape ``(obs_dim, action_dim)``. The remaining ``action_dim`` entries
    form the bias ``b``.

    Raises:
        ValueError: if ``len(theta) != (obs_dim + 1) * action_dim``.
    """

    def __init__(self, theta, obs_dim, action_dim):
        theta = np.asarray(theta)
        expected = (obs_dim + 1) * action_dim
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(theta) != expected:
            raise ValueError(
                f"theta must have length (obs_dim + 1) * action_dim = {expected}, "
                f"got {len(theta)}"
            )
        split = obs_dim * action_dim
        self.W = theta[:split].reshape(obs_dim, action_dim)
        # Keep the bias 1-D so that a single observation of shape (obs_dim,)
        # yields an action of shape (action_dim,). A (1, action_dim) bias would
        # silently add a leading axis. Batched observations of shape
        # (batch, obs_dim) still broadcast to (batch, action_dim).
        self.b = theta[split:].reshape(action_dim)

    def act(self, ob):
        """Return the unbounded affine action ``ob @ W + b``."""
        # If the action space is bounded, squash the output (e.g. with tanh)
        # and rescale it to the space's range.
        return ob.dot(self.W) + self.b

In [7]:
def cem(J, mu, N, elite_frac, th_std):
    """Run one iteration of the cross-entropy method (CEM) to maximize ``J``.

    Samples ``N`` parameter vectors from the diagonal Gaussian
    ``N(mu, diag(th_std**2))``, scores each with the black-box objective ``J``,
    and refits the Gaussian to the best-scoring ("elite") fraction of samples.
    Call it repeatedly, feeding the returned ``mu``/``th_std`` back in.

    Args:
        J: objective mapping a 1-D parameter vector to a scalar (higher is better).
        mu: current mean of the sampling distribution (1-D array). For the
            linear policies in this notebook its length is
            ``(obs_dim + 1) * action_dim``.
        N: number of parameter vectors to sample and evaluate.
        elite_frac: fraction of the ``N`` samples kept as elites. At least one
            sample is always kept, so the refit never averages an empty set.
        th_std: per-dimension standard deviation of the sampling distribution,
            with the same shape as ``mu``.

    Returns:
        ``(mean_J, new_mu, new_std)``: the mean score over all ``N`` samples,
        plus the per-dimension mean and standard deviation of the elites.
    """
    # Round to the nearest count but never to 0: an empty elite set makes
    # mean/std return NaN, which would poison every later iteration.
    n_elite = max(1, int(np.round(N * elite_frac)))

    # theta_i = mu + th_std * eps_i with eps_i ~ N(0, I). Broadcasting applies
    # mu and th_std to every row of the (N, mu.size) noise matrix.
    thetas = mu + th_std * np.random.randn(N, mu.size)
    # Evaluate each candidate with the (possibly noisy) objective.
    Js = np.array([J(theta) for theta in thetas])
    # argsort is ascending; reverse it to take the n_elite highest scores.
    elite_inds = Js.argsort()[::-1][:n_elite]
    elite_thetas = thetas[elite_inds]

    return Js.mean(), elite_thetas.mean(axis=0), elite_thetas.std(axis=0)

In [8]:
def do_rollout(agent, env, max_steps=1000):
    """Run one episode of ``agent`` in ``env`` and return its total reward.

    The episode ends when the env reports ``terminated`` or ``truncated``, or
    after ``max_steps`` steps, whichever comes first.

    Args:
        agent: object with an ``act(obs)`` method returning an action for ``env``.
        env: Gymnasium-style environment: ``reset()`` returns ``(obs, info)``
            and ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        max_steps: hard cap on the number of steps per episode (default 1000,
            the previously hard-coded value).

    Returns:
        The sum of rewards collected during the episode.
    """
    total_rewards = 0
    obs, _ = env.reset()
    for _ in range(max_steps):
        obs, reward, terminated, truncated, _ = env.step(agent.act(obs))
        total_rewards += reward
        if terminated or truncated:
            break
    return total_rewards

def noisy_evaluation(theta):
    """Objective J(theta) used by CEM: the return of one episode.

    Wraps ``theta`` in a ``BinaryActionLinearPolicy`` and rolls it out once in
    the module-level ``env`` (looked up at call time). Scores vary between
    calls because each call starts a new episode via ``env.reset()``.
    """
    return do_rollout(BinaryActionLinearPolicy(theta), env)

# Training

In [34]:
# CEM hyperparameters: maximum number of iterations, parameter samples
# evaluated per iteration, and fraction of top samples kept as elites.
n_iters, N, elite_frac = 200, 30, 0.4

In [52]:
# Episode cap. CartPole gives reward 1 per step, so this is also the best
# achievable return. Keep it <= the step cap used in do_rollout (1000).
max_episode_steps = 1000
env = gym.make('CartPole-v1', max_episode_steps=max_episode_steps, render_mode=None)
obs_dim = env.observation_space.shape[0]
# Discrete action spaces have an empty shape; the binary policy emits one scalar.
action_dim = 1 if len(env.action_space.shape) == 0 else env.action_space.shape[0]

# Initial search distribution: zero mean and unit std for every weight and the bias.
mu = np.zeros((obs_dim + 1) * action_dim)
th_std = np.ones_like(mu)
for step in range(n_iters):
    J_mean, mu, th_std = cem(noisy_evaluation, mu, N, elite_frac, th_std)
    print('Step: ', step, 'J means: ', J_mean)

    # Stop once the sampled policies average the maximum return, i.e. every
    # sample survived the whole episode. The threshold is derived from the
    # episode cap instead of a hard-coded 999.999.
    if J_mean > max_episode_steps - 1e-3:
        print('Training is done')
        break

Step:  0 J means:  51.43333333333333
Step:  1 J means:  37.7
Step:  2 J means:  69.33333333333333
Step:  3 J means:  88.03333333333333
Step:  4 J means:  92.5
Step:  5 J means:  118.76666666666667
Step:  6 J means:  156.5
Step:  7 J means:  198.76666666666668
Step:  8 J means:  370.06666666666666
Step:  9 J means:  380.1666666666667
Step:  10 J means:  508.5
Step:  11 J means:  614.3333333333334
Step:  12 J means:  864.8
Step:  13 J means:  912.2666666666667
Step:  14 J means:  997.4
Step:  15 J means:  971.6333333333333
Step:  16 J means:  998.2
Step:  17 J means:  985.1333333333333
Step:  18 J means:  1000.0
Training is done


In [53]:
# Render one episode with a policy drawn from the final CEM distribution.
# NOTE: to evaluate the distribution mean itself, use `theta = mu` instead.
np.random.seed(42)
theta = np.random.normal(mu, th_std)
agent = BinaryActionLinearPolicy(theta)
env = gym.make('CartPole-v1', max_episode_steps=1000, render_mode="human")
total_rewards = 0
try:
    obs, info = env.reset(seed=42)
    for t in range(1000):
        action = agent.act(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        total_rewards += reward
        if terminated or truncated:
            break
finally:
    # Always release the render window, even if reset/step raises.
    env.close()
print('Total rewards: ', total_rewards, 'Total steps: ', t+1)

Total rewards:  1000.0 Total steps:  1000


In [54]:
(mu, th_std)  # final CEM search distribution: per-parameter mean and std

(array([-0.01422972, -0.80893479, -1.12050025, -1.47640291,  0.00489943]),
 array([0.02448865, 0.1431876 , 0.05402537, 0.03403916, 0.01739198]))

In [55]:
(terminated, truncated)  # end-of-episode flags from the last env.step() above

(False, True)