In [1]:
#!/usr/bin/env python3
from collections import namedtuple

import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter


class DiscreteOneHotWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super(DiscreteOneHotWrapper, self).__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        shape = (env.observation_space.n,)
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape, dtype=np.float32)

    def observation(self, observation):
        res = np.copy(self.observation_space.low)
        res[observation] = 1.0
        return res


class Net(nn.Module):
    def __init__(self, obs_size, hidden_size, n_actions):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        return self.net(x)

In [2]:
Episode = namedtuple("Episode", field_names=["reward", "steps"])
EpisodeStep = namedtuple("EpisodeStep", field_names=["observation", "action"])
EpisodeStepReward = namedtuple("EpisodeStepReward", field_names=["reward"])


# 批处理
def iterate_batches(env, net, batch_size):
    batch = []
    episode_reward = []
    episode_steps = []
    # 该状态的状态值
    obs, info = env.reset()
    sm = nn.Softmax(dim=1)
    # while True:
    while True:
        obs_v = torch.FloatTensor(obs.reshape(1, -1))
        act_probs_y = sm(net(obs_v))
        act_probs = act_probs_y.detach().numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        ## 更新 step
        step = EpisodeStep(observation=obs, action=action)
        episode_steps.append(step)

        ## 更新 reward
        episode_reward.append(reward)

        if terminated:
            e = Episode(reward=episode_reward, steps=episode_steps)
            batch.append(e)
            episode_reward = []
            episode_steps = []
            next_obs, info = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
                env

        obs = next_obs


def discount_reward(r_history, gamma):
    n = len(r_history)
    dr = 0
    for i in range(n):
        dr += gamma**i * r_history[i]
    return dr


# 筛选批
def filter_batch(batch, percentile):
    filter_fun = lambda s: discount_reward(s.reward, GAMMA)
    disc_rewards = list(map(filter_fun, batch))
    reward_bound = np.percentile(disc_rewards, percentile)

    train_obs = []
    train_act = []
    elite_batch = []
    for example, discounted_reward in zip(batch, disc_rewards):
        if discounted_reward > reward_bound:
            train_obs.extend(map(lambda step: step.observation, example.steps))
            train_act.extend(map(lambda step: step.action, example.steps))
            elite_batch.append(example)
    return elite_batch, train_obs, train_act, reward_bound

In [3]:
import random

random.seed(12345)

HIDDEN_SIZE = 128
BATCH_SIZE = 200
PERCENTILE = 30
GAMMA = 0.9

env = frozen_lake.FrozenLakeEnv(is_slippery=False, render_mode="human")
env.spec = gym.spec("FrozenLake-v1")
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
env = DiscreteOneHotWrapper(env)
# env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1"))
# env = gym.wrappers.RecordVideo(
#     env, video_folder="video", name_prefix="mario", video_length=20
# )
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

net = Net(obs_size, HIDDEN_SIZE, n_actions)
optimizer = optim.Adam(params=net.parameters(), lr=0.001)
objective = nn.CrossEntropyLoss()

In [None]:
full_batch = []

for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
    reward_mean = float(np.mean(list(map(lambda s: np.mean(s.reward), batch))))
    # 保存了好的情况下的批次
    full_batch, obs, acts, reward_bound = filter_batch(full_batch + batch, PERCENTILE)
    if not full_batch:
        continue
    obs_v = torch.FloatTensor(np.array(obs))
    acts_v = torch.LongTensor(np.array(acts))
    full_batch = full_batch[-500:]
    optimizer.zero_grad()
    action_scores_v = net(obs_v)
    loss_v = objective(action_scores_v, acts_v)
    loss_v.backward()
    optimizer.step()

    # if iter_no % 50 == 0 and iter_no != 0:
    print(
        "batch %d: loss=%.3f, rw_mean=%.3f, "
        "rw_bound=%.3f, batch=%d"
            % (iter_no, loss_v.item(), reward_mean, reward_bound, len(full_batch))
    )
    if reward_mean > 0.8:
        print("Solved!")
        break

batch 0: loss=1.385, rw_mean=0.002, rw_bound=0.000, batch=5
batch 1: loss=1.378, rw_mean=0.002, rw_bound=0.000, batch=9
batch 2: loss=1.375, rw_mean=0.001, rw_bound=0.000, batch=11
batch 3: loss=1.371, rw_mean=0.000, rw_bound=0.000, batch=11
batch 4: loss=1.367, rw_mean=0.002, rw_bound=0.000, batch=15
batch 5: loss=1.357, rw_mean=0.003, rw_bound=0.000, batch=21
batch 6: loss=1.359, rw_mean=0.001, rw_bound=0.000, batch=24
batch 7: loss=1.354, rw_mean=0.001, rw_bound=0.000, batch=27
batch 8: loss=1.343, rw_mean=0.004, rw_bound=0.000, batch=36
batch 9: loss=1.337, rw_mean=0.002, rw_bound=0.000, batch=40
batch 10: loss=1.332, rw_mean=0.000, rw_bound=0.000, batch=41
batch 11: loss=1.330, rw_mean=0.004, rw_bound=0.000, batch=49
batch 12: loss=1.325, rw_mean=0.001, rw_bound=0.000, batch=51
batch 13: loss=1.318, rw_mean=0.003, rw_bound=0.000, batch=56
batch 14: loss=1.313, rw_mean=0.002, rw_bound=0.000, batch=59
batch 15: loss=1.306, rw_mean=0.006, rw_bound=0.000, batch=71
batch 16: loss=1.304