In [None]:
%load_ext autoreload
%autoreload 2

from rejax import PPO
import jax
from env import NavigationEnv, NavigationEnvParams
from rooms import RoomParams, generate_rooms

# Seed used only for procedural room generation. It is deliberately a separate
# constant from the PPO training seed (defined in the training cell) so the
# room layouts and the training run can be varied independently.
ROOM_SEED = 11

# Initialize our environment
env = NavigationEnv()

# Generate a batch of rooms (layout obstacles + free positions) from ROOM_SEED.
room_params = RoomParams(
    size=4.0,
    grid_size=8,
    target_carved_percent=0.5,
    num_rooms=256,
)
room_key = jax.random.PRNGKey(ROOM_SEED)
obstacles, free_positions = generate_rooms(room_key, room_params)

# Environment parameters: lidar sensor settings plus the generated rooms.
env_params = NavigationEnvParams(
    lidar_max_distance=2.0,
    lidar_fov=120,
    lidar_num_beams=16,
    rooms=room_params,
    obstacles=obstacles,
    free_positions=free_positions,
)

# Initialize the training algorithm parameters
config = {
    # Pass our environment to the agent
    "env": env,
    "env_params": env_params,
    # Number of timesteps during which the agent will be trained
    "total_timesteps": 300_000,
}

# Create the training algorithm agent from `rejax` library
ppo = PPO.create(**config)

# Look at the whole configuration (we can experiment with all these parameters!)
ppo.config

## Training

In [None]:
import time

import jax

# Seed for the PPO training run (independent of the seed used to generate rooms).
TRAIN_SEED = 13

rng = jax.random.PRNGKey(TRAIN_SEED)

# Compile ahead of time so that the training timing below measures only the
# training run itself, not XLA tracing/compilation of `ppo.train`.
print("Compiling train function")
compile_start = time.time()
train_fn = jax.jit(ppo.train).lower(rng).compile()
compile_elapsed = time.time() - compile_start
print(f"Compiled in {compile_elapsed:g} seconds.")

print("Starting to train")

# Train!
start = time.time()
# JAX dispatches work asynchronously: block until the outputs are actually
# computed before stopping the clock, otherwise the elapsed time is too small.
train_state, evaluation = jax.block_until_ready(train_fn(rng))
time_elapsed = time.time() - start

sps = ppo.total_timesteps / time_elapsed
print(f"Finished training in {time_elapsed:g} seconds ({sps:g} steps/second).")

## Evaluation

In [None]:
from matplotlib import pyplot as plt

# `evaluation` is the pair returned by the training cell. Returns are averaged
# over axis 1 — presumably the evaluation episodes run at each checkpoint;
# TODO confirm against rejax's evaluation output layout.
episode_lengths, episode_returns = evaluation
mean_return = episode_returns.mean(axis=1)

# Spread the evaluation checkpoints evenly over the training horizon.
eval_steps = jax.numpy.linspace(0, ppo.total_timesteps, len(mean_return))

fig, ax = plt.subplots()
ax.plot(eval_steps, mean_return)
ax.set_xlabel("Environment step")
ax.set_ylabel("Episodic return")
ax.set_title(f"Training agent for {env.name} using {ppo.__class__.__name__}")
plt.show()

In [None]:
from pathlib import Path

from IPython.display import Image as IPImage, display

from env_vis import save_gif
from eval import evaluate_model

# Seed for the test rollouts; chosen to differ from the training seed
# (presumably so evaluation is not a replay of training episodes — verify in evaluate_model).
TEST_SEED = 77
GIF_PATH = Path("temp/intro_policy.gif")

# Store the result under its own name: `evaluation` still holds the training
# curve unpacked by the plotting cell, and overwriting it would break re-running that cell.
test_evaluation = evaluate_model(
    agent=ppo,
    train_state=train_state,
    seed=TEST_SEED,
    render=True,
    n_eval_episodes=4,
)

if test_evaluation.rendered_frames is not None:
    # Make sure the output directory exists before writing the GIF.
    GIF_PATH.parent.mkdir(parents=True, exist_ok=True)
    save_gif(test_evaluation.rendered_frames, GIF_PATH)

    display(IPImage(filename=str(GIF_PATH)))
