In [4]:
from rl.common.logger import ConsoleLogger, FigureLogger, Tracker
from rl.ppo.policies import ActorCriticNet
from rl.ppo.ppo import PPO

from torch import optim
from env import VanillaEnv
import matplotlib.pyplot as plt
import numpy as np
import itertools

In [19]:
def plot_evaluation_grid(grid, training_positions, min_obs_position,
                         min_floor_height):
    """Plot the evaluation grid and mark the training configurations.

    Args:
        grid: 2-D array indexed as ``grid[obstacle_index, floor_index]``. It is
            transposed for display so that the obstacle index runs along the
            x axis and the floor index along the y axis.
        training_positions: Iterable of ``(obstacle_pos, floor_height)`` pairs;
            each one is marked with a red 'T' in the centre of its cell.
        min_obs_position: Obstacle position that corresponds to index 0 on the
            x axis.
        min_floor_height: Floor height that corresponds to index 0 on the
            y axis.

    Returns:
        The matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(7, 9))
    grid_x, grid_y = grid.shape

    # extent is (left, right, bottom, top). With origin='lower' the array
    # element [0, 0] is drawn at (left, bottom), so bottom must be 0 for cell
    # (i, j) to cover x in [i, i + 1] and y in [j, j + 1]. The previous
    # (0, grid_x, grid_y, 0) flipped the image relative to the 'T' markers.
    extent = (0, grid_x, 0, grid_y)
    ax.imshow(grid.T, extent=extent, origin='lower', cmap='copper')

    # Major ticks sit on the cell centres and show the real obstacle
    # position / floor height instead of the raw array index.
    x_index = np.arange(grid_x)
    y_index = np.arange(grid_y)
    ax.set_xticks(x_index + 0.5)
    ax.set_xticklabels([str(value) for value in x_index + min_obs_position])
    ax.set_yticks(y_index + 0.5)
    ax.set_yticklabels([str(value) for value in y_index + min_floor_height])

    # Minor ticks sit on the cell edges; they only carry the white grid lines.
    ax.set_xticks(np.arange(grid_x + 1), minor=True)
    ax.set_yticks(np.arange(grid_y + 1), minor=True)
    ax.tick_params(which='minor', length=0)

    ax.set_ylabel("Floor height")
    ax.set_xlabel("Obstacle position")

    # Mark every training configuration in the centre of its grid cell.
    for (obstacle_pos, floor_height) in training_positions:
        pos_index = obstacle_pos - min_obs_position
        height_index = floor_height - min_floor_height
        ax.text(
            pos_index + 0.5,
            height_index + 0.5,
            'T',
            ha='center',
            va='center',
            color='r',
            fontsize='large',
            fontweight='bold')

    ax.grid(which='minor', color='w', linewidth=1)
    fig.tight_layout()
    return fig

In [20]:
# Evaluation ranges (the upper bound of np.arange is exclusive):
#   obstacle positions 14..47, floor heights 10..34.
# NOTE(review): the earlier comment listed floor_height as "min: 0, max: 40",
# which does not match the 10..34 range built here -- confirm which is meant.
obstacle_pos = np.arange(14, 48)
floor_height = np.arange(10, 35)

# Every (obstacle_pos, floor_height) pair inside the evaluation ranges.
ALL_CONFIGURATIONS = set(itertools.product(obstacle_pos, floor_height))

# One zero-initialised cell per configuration: rows follow obstacle_pos,
# columns follow floor_height.
grid = np.zeros((obstacle_pos.size, floor_height.size))

# Named sets of training configurations; each entry is
# (obstacle_pos, floor_height).
train_conf = {
    "narrow_grid": {
        (22, 18), (22, 24),
        (26, 18), (26, 24),
    },
    "wide_grid": {
        (18, 16), (18, 26),
        (28, 16), (28, 26),
    },
    "random": {
        (43, 29), (39, 33),
        (28, 19), (15, 17),
    },
}

# Sanity check: every training configuration must lie inside the evaluation
# ranges defined above.
for conf_name in train_conf:
    for conf in train_conf[conf_name]:
        assert conf in ALL_CONFIGURATIONS, f"Invalid configuration in {conf_name}"

# TODO: a held-out test set was sketched here as
# ALL_CONFIGURATIONS - TRAINING_CONFIGURATIONS, but TRAINING_CONFIGURATIONS is
# not defined in this cell.

In [22]:
# Train one PPO agent per named training-configuration set in `train_conf`.
# Each iteration builds a fresh environment, policy, optimizer and loggers, so
# nothing is carried over from one configuration set to the next.
# Value passed to ppo.learn() for every run.
episodes = 5000
for conf_name in train_conf.keys():
    print(f"====== Training on {conf_name} ======")
    # The configurations are stored as a set; VanillaEnv receives them as a list.
    env = VanillaEnv(list(train_conf[conf_name]))

    # Fresh network and Adam optimizer (learning rate 1e-3) for this run.
    policy: ActorCriticNet = ActorCriticNet()
    optimizer = optim.Adam(policy.parameters(), lr=0.001)

    # logger1 writes to the console (log_every=1000), logger2 supplies the
    # figure requested below; both are attached to a single Tracker.
    logger1 = ConsoleLogger(log_every=1000)
    logger2 = FigureLogger()
    tracker = Tracker(logger1, logger2)

    # The same fixed seed (31) is passed to PPO for every configuration set.
    ppo = PPO(policy, env, optimizer, seed=31, tracker=tracker)
    ppo.learn(episodes)
    # NOTE(review): presumably stores a checkpoint named after the
    # configuration set under ./ckpts -- confirm against PPO.save.
    ppo.save('./ckpts', conf_name)

    # Show the figure produced by the FigureLogger for this run.
    fig = logger2.get_figure(fig_size=(8, 4))
    fig.suptitle(f"Training on {conf_name} for {episodes} episodes")
    plt.show()
    


Episode:   1000, return: 24.0
Episode:   2000, return: 29.0
Episode:   3000, return: 29.0
