In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from jax import random, vmap

from jax_tqdm import loop_tqdm

import sys

from cartpole import CartPole
from agents import DQN
from replay_buffer import UniformReplayBuffer
from rollout import deep_rl_rollout

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Environment ---
RANDOM_SEED = 1  # base seed the notebook derives its PRNG keys from
N_ACTIONS = 2  # number of discrete actions; also the width of the network's output layer
STATE_SHAPE = (4,)  # shape of one state as stored in the buffer and fed to the network

# --- Network and optimiser ---
NEURONS_PER_LAYER = [64] * 3 + [N_ACTIONS]  # three 64-unit layers, then one output per action
LEARNING_RATE = 1e-3
DISCOUNT = 0.9

# --- Training loop and replay buffer ---
TIMESTEPS = 20_000
TARGET_NET_UPDATE_FREQ = 10
BUFFER_SIZE = 512
BATCH_SIZE = 32

# --- Epsilon schedule (consumed by inverse_scaling_decay) ---
EPSILON_START = 0.3
EPSILON_END = 0
DECAY_RATE = 1e-3

In [3]:
# Pre-allocated replay-buffer storage: one fixed-size array per transition
# field, each with BUFFER_SIZE slots along its leading axis.
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    # float32, not int32: filling the buffer scattered float32 values into an
    # int32 array (most likely the rewards, which the rollout reports as float32),
    # which JAX flags as an unsafe cast that will become an error in future releases.
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
    "next_states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated.
print(jax.tree_util.tree_map(lambda x: x.shape, buffer_state))

{'actions': (512,), 'dones': (512,), 'next_states': (512, 4), 'rewards': (512,), 'states': (512, 4)}


In [4]:
@hk.transform
def model(x):
    """Apply an MLP whose layer widths are given by NEURONS_PER_LAYER to ``x``."""
    return hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)(x)


def inverse_scaling_decay(epsilon_start, epsilon_end, current_step, decay_rate):
    """Return the exploration rate at ``current_step`` under inverse-scaling decay.

    Evaluates ``epsilon_end + (epsilon_start - epsilon_end) / (1 + decay_rate * current_step)``.
    At step 0 the denominator is 1, so the result is ``epsilon_start``; for a
    positive ``decay_rate`` the result moves towards ``epsilon_end`` as the step grows.
    """
    span = epsilon_start - epsilon_end
    damping = 1 + decay_rate * current_step
    return epsilon_end + span / damping


# Two PRNG keys, seeded RANDOM_SEED and RANDOM_SEED + 1: one per network.
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)
env = CartPole()

replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)

# Parameter initialisation only needs an example input of the right shape and
# dtype, so pass zeros instead of drawing a random input with the very key
# that also seeds the parameters (PRNG key reuse).
dummy_state = jnp.zeros(STATE_SHAPE)
online_net_params = model.init(online_key, dummy_state)
target_net_params = model.init(target_key, dummy_state)

# The optimiser state is built from (and so mirrors) the online parameters.
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(online_net_params)
agent = DQN(model, DISCOUNT, N_ACTIONS)

In [5]:
# Preview the exploration schedule: epsilon evaluated at every training step.
px.line(
    [
        inverse_scaling_decay(EPSILON_START, EPSILON_END, step, DECAY_RATE)
        for step in range(TIMESTEPS)
    ],
    title="Epsilon Decay",
)

In [6]:
# Shape of every parameter array in the online network.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
jax.tree_util.tree_map(lambda x: x.shape, online_net_params)

{'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)},
 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)},
 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)},
 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}

In [7]:
# Shape of every array held in the optimiser state.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
jax.tree_util.tree_map(lambda x: x.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}, nu={'mlp/~/linear_0': {'b': (64,), 'w': (4, 64)}, 'mlp/~/linear_1': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_2': {'b': (64,), 'w': (64, 64)}, 'mlp/~/linear_3': {'b': (2,), 'w': (64, 2)}}),
 EmptyState())

In [8]:
# (optional) initialize the replay buffer with random samples
# Three PRNG keys built from seeds RANDOM_SEED, RANDOM_SEED + 1 and RANDOM_SEED + 2.
# NOTE(review): the first two seeds are the same ones used to build online_key and
# target_key, so init_key and action_key duplicate those keys; buffer_key is not
# used anywhere in this cell.
init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + RANDOM_SEED)
env_state, _ = env.reset(init_key)

# One transition per iteration, BUFFER_SIZE iterations in total.
# NOTE(review): `done` is stored but never triggers an env.reset here — confirm
# that env.step handles episode termination on its own.
for i in range(BUFFER_SIZE):
    # The first element of env_state is what the agent acts on and what gets stored.
    state, _ = env_state
    # set epsilon to 1 for exploration
    # act() also hands back a key, which is fed into the next iteration's call.
    action, action_key = agent.act(action_key, online_net_params, state, 1)
    env_state, new_state, reward, done = env.step(env_state, action)
    experience = (state, action, reward, new_state, done)

    # `i` is passed as the third argument — presumably the slot to write;
    # confirm against UniformReplayBuffer.add.
    buffer_state = replay_buffer.add(buffer_state, experience, i)


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [9]:
# Everything the training rollout consumes, gathered as keyword arguments.
rollout_params = dict(
    timesteps=TIMESTEPS,
    random_seed=RANDOM_SEED,
    target_net_update_freq=TARGET_NET_UPDATE_FREQ,
    model=model,
    optimizer=optimizer,
    buffer_state=buffer_state,
    agent=agent,
    env=env,
    replay_buffer=replay_buffer,
    state_shape=STATE_SHAPE,
    buffer_size=BUFFER_SIZE,
    epsilon_decay_fn=inverse_scaling_decay,
    epsilon_start=EPSILON_START,
    epsilon_end=EPSILON_END,
    decay_rate=DECAY_RATE,
)

# `out` is indexed by later cells with the keys "losses", "all_done" and "all_rewards".
out = deep_rl_rollout(**rollout_params)

Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:00<00:00, 22364.10it/s]


In [10]:
# Plot the recorded losses, leaving out the final 10,000 entries.
px.line(out["losses"][:-10_000], title="Loss during training")

In [11]:
# Summed reward at or above which an episode is classed as "over 200 steps".
# Used for the hover text, the colour column and both bar traces so that the
# three can never disagree about an episode sitting exactly on the boundary.
REWARD_THRESHOLD = 200

colors = px.colors.qualitative.Plotly

# One row per timestep; the running count of `done` flags serves as an episode id.
df = pd.DataFrame(
    data={
        "episode": out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
)
# The running count already includes the `done` of the current row, so shift it
# down by one row to keep each terminal step with the episode it ends; the
# first row has no predecessor and is assigned episode 0.
df["episode"] = df["episode"].shift().fillna(0)

# Total reward per episode.
episodes_df = df.groupby("episode").agg("sum")

# Single source of truth for the over/under split.
is_over = episodes_df["reward"] >= REWARD_THRESHOLD
is_under = ~is_over
reward_text = episodes_df["reward"].astype(str)

# Define hover text based on the reward value
episodes_df["hover_text"] = np.where(
    is_over,
    "Over 200 steps: " + reward_text,
    "Under 200 steps: " + reward_text,
)

# Define colors based on the reward value
episodes_df["color"] = np.where(is_over, colors[2], colors[0])

# Create the figure
fig = go.Figure()

# Add bars for "under 200 steps"
fig.add_trace(
    go.Bar(
        x=episodes_df.index[is_under],
        y=episodes_df["reward"][is_under],
        hovertext=episodes_df["hover_text"][is_under],
        marker=dict(color=colors[0]),
        legendgroup="under200",
        name="Under 200 steps",
        showlegend=True,
    )
)

# Add bars for "over 200 steps" as a second trace in the same figure
fig.add_trace(
    go.Bar(
        x=episodes_df.index[is_over],
        y=episodes_df["reward"][is_over],
        hovertext=episodes_df["hover_text"][is_over],
        marker=dict(color=colors[2]),
        legendgroup="over200",
        name="Over 200 steps",
        showlegend=True,
    )
)

# Update the layout
fig.update_layout(
    title="Performances of DQN on the CartPole Environment",
    xaxis_title="Episode",
    yaxis_title="Sum of rewards",
    # The y axis stops at 200, so taller bars are cut off at the top of the plot.
    yaxis_range=[0, 200],
    # xaxis_range=[0, 500],
    legend_title_text="Reward Categories",
)

fig.show()

In [12]:
# Largest total reward collected in a single episode.
df.groupby("episode").sum().max()

reward    393.0
dtype: float32