# Example using DQN and Flashbax in gym environments

### [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/gym_dqn_example.ipynb)

In [None]:
!git clone https://github.com/instadeepai/flashbax.git
!pip install ./flashbax[examples]

#### Imports

In [1]:
import random
import time
from typing import NamedTuple

import haiku as hk
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import chex


In [None]:
import flashbax as fbx

#### Define Network and data classes

In [2]:
# Define a simple network function using Haiku.
def get_network_fn(num_outputs: int):
    """Build a Haiku-transformed fully connected network.

    The network flattens its input, applies two hidden linear layers
    (256 and 128 units) each followed by a leaky ReLU, and finishes with a
    linear layer of `num_outputs` units.

    Args:
        num_outputs: Width of the final linear layer.

    Returns:
        The result of `hk.without_apply_rng(hk.transform(network_fn))`.
    """
    def network_fn(obs: chex.Array, rng: chex.PRNGKey) -> chex.Array:
        # `rng` is part of the signature (callers in this file pass `None`)
        # but is not used by the forward pass.
        layers = [
            hk.Flatten(),
            hk.Linear(256),
            jax.nn.leaky_relu,
            hk.Linear(128),
            jax.nn.leaky_relu,
            hk.Linear(num_outputs),
        ]
        mlp = hk.Sequential(layers)
        return mlp(obs)

    transformed = hk.transform(network_fn)
    return hk.without_apply_rng(transformed)

# Define a simple tuple to hold the state of the training.
class TrainState(NamedTuple):
    """Immutable container for the learner state passed between update steps."""
    # Parameters of the Q-network that is differentiated and optimised in `update`.
    params: hk.Params
    # Parameters used to evaluate the next-step Q-values; refreshed periodically
    # from `params` via `optax.incremental_update` in the training loop.
    target_params: hk.Params
    # Optimiser state returned by `optim.init` / `optim.update`.
    opt_state : optax.OptState


# Define a simple dataclass to hold a single environment timestep. This is the format we will use to store transitions in our buffer.
@chex.dataclass(frozen=True)
class TimeStep:
    """A single environment step, in the layout that is added to the replay buffer."""
    # Observation the agent acted on.
    observation: chex.Array
    # Action that was passed to `envs.step` for this observation.
    action: chex.Array
    # Written by the training loop as `1 - terminated` (float32), so it is 0
    # on a terminating step and 1 otherwise.
    discount: chex.Array
    # Reward returned by `envs.step` for this action.
    reward: chex.Array


#### Training Parameters

In [3]:
# Hyper-parameters for the whole experiment, grouped by what they control.

# --- Environment ---
env_id = "CartPole-v1"
num_envs = 1  # 1 -> a single env; any other value -> a SyncVectorEnv of that size
seed = 42

# --- Run length and update cadence (all measured in loop iterations) ---
total_timesteps = 50_000
learning_starts = 1_000  # no gradient updates until global_step exceeds this
train_frequency = 5  # run a gradient update every this many steps
target_network_frequency = 500  # refresh the target network every this many steps

# --- Replay buffer ---
buffer_size = 50_000
sample_batch_size = 128  # also used as the buffer's minimum length before sampling

# --- Optimisation ---
learning_rate = 1e-3
gamma = 0.99  # multiplied into the stored per-step discount inside the loss
tau = 1.0  # step size handed to optax.incremental_update for the target refresh

# --- Epsilon-greedy exploration ---
start_e = 1.0
end_e = 0.01
exploration_fraction = 0.5  # fraction of total_timesteps over which epsilon is annealed


#### Set up environment

In [4]:
# We then set up the environments
def make_env(env_id, seed):
    """Return a zero-argument factory that builds one wrapped environment.

    Args:
        env_id: Identifier passed to `gym.make`.
        seed: Seed applied to the created environment's action space.

    Returns:
        A callable that, when invoked, creates and returns the environment.
    """
    def thunk():
        # Wrapping order (inner to outer): base env -> AutoResetWrapper ->
        # RecordEpisodeStatistics.
        # The auto reset wrapper resets the environment as soon as an episode
        # is done, so every environment in a vectorised setup stays active and
        # transitions keep arriving in the order they are added to the buffer.
        env = gym.wrappers.RecordEpisodeStatistics(
            gym.wrappers.AutoResetWrapper(gym.make(env_id))
        )
        env.action_space.seed(seed)
        return env

    return thunk

# Build either a vectorised environment or a single one, and read the number of
# discrete actions from the matching action-space attribute.
if num_envs != 1:
    # One factory per sub-environment, each with its own seed offset.
    envs = gym.vector.SyncVectorEnv(
        [make_env(env_id, seed + i) for i in range(num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
    num_actions = envs.single_action_space.n
else:
    # make_env returns a factory; call it immediately to get the environment.
    envs = make_env(env_id, seed)()
    assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
    num_actions = envs.action_space.n

#### Train DQN agent

In [5]:
with jax.default_device(jax.devices('cpu')[0]):

    # Seed Python, NumPy and JAX so that runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    key, q_key = jax.random.split(jax.random.PRNGKey(seed), 2)

    # Q-network and its optimiser.
    q_network = get_network_fn(num_actions)
    optim = optax.adam(learning_rate=learning_rate)

    # A real observation is needed to initialise the network parameters.
    # With a vectorised environment, use the first environment's observation.
    dummy_obs, _ = envs.reset(seed=seed)
    if num_envs>1:
        dummy_obs = dummy_obs[0]
    params = q_network.init(q_key, dummy_obs, None)
    opt_state = optim.init(params)
    # The target network starts as an exact copy of the online network.
    q_state = TrainState(
        params=params,
        target_params=params,
        opt_state=opt_state,
    )

    # Replay buffer configuration:
    #   max_length        - maximum number of transitions that can be stored
    #   min_length        - transitions required before sampling is allowed
    #   sample_batch_size - number of transitions drawn per sample call
    #   add_sequences     - False: each add call supplies one timestep, not a sequence
    #   add_batch_size    - None for a single environment; otherwise one
    #                       transition per environment is added per call
    buffer = fbx.make_flat_buffer(
        max_length=buffer_size,
        min_length=sample_batch_size,
        sample_batch_size=sample_batch_size,
        add_sequences=False,
        add_batch_size=None if num_envs <= 1 else num_envs,
    )
    # Swap every buffer function for a jitted version; `add` additionally
    # donates its first argument (the buffer state).
    buffer = buffer.replace(
        init=jax.jit(buffer.init),
        add=jax.jit(buffer.add, donate_argnums=0),
        sample=jax.jit(buffer.sample),
        can_sample=jax.jit(buffer.can_sample),
    )
    # The buffer is initialised from an example timestep.
    dummy_timestep = TimeStep(
        observation=dummy_obs,
        action=jnp.int32(0),
        reward=jnp.float32(0.0),
        discount=jnp.float32(0.0),
    )
    buffer_state = buffer.init(dummy_timestep)

    # Create a linear schedule function for the epsilon greedy exploration
    # linear_schedule = jax.jit(optax.polynomial_schedule(start_e, end_e, 1.0 ,exploration_fraction * total_timesteps))
    # Faster to use custom than optax.polynomial_schedule due to jax conversions
    def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
        slope = (end_e - start_e) / duration
        return max(slope * t + start_e, end_e)

    # Create a function to update the network
    @jax.jit
    def update(q_state : TrainState, batch : TimeStep):
        """Run one gradient step on the online Q-network.

        Args:
            q_state: Current online/target parameters and optimiser state.
            batch: Sampled experience. NOTE(review): annotated as `TimeStep`,
                but it is accessed through `.first` and `.second`, i.e. it is
                expected to be a pair of timesteps - confirm against the
                buffer's sample type.

        Returns:
            A `(loss, new_q_state)` tuple; `target_params` is left unchanged.
        """

        def loss_fn(params, target_params,  batch):
            current, following = batch.first, batch.second
            # Online network scores the current observation; the target
            # network scores the following one.
            online_q = q_network.apply(params, current.observation, None)
            target_q = q_network.apply(target_params, following.observation, None)
            # Action, reward and discount are read from `first`: each stored
            # timestep carries the outcome of acting on its own observation.
            discounts = current.discount*gamma
            td_errors = jax.vmap(rlax.q_learning)(
                online_q, current.action, current.reward, discounts, target_q
            )
            return jnp.mean(jnp.square(td_errors))

        # Differentiates with respect to the first argument (the online params).
        loss, grads = jax.value_and_grad(loss_fn)(
            q_state.params, q_state.target_params, batch
        )
        updates, next_opt_state = optim.update(grads, q_state.opt_state)
        next_params = optax.apply_updates(q_state.params, updates)
        return loss, q_state._replace(params=next_params, opt_state=next_opt_state)

    # Create a function to select actions from the network
    @jax.jit
    def action_select_fn(q_state, obs):
        """Return the greedy action(s): the argmax of the online Q-values over the last axis."""
        return jnp.argmax(q_network.apply(q_state.params, obs, None), axis=-1)

    @jax.jit
    def perform_update(q_state, buffer_state, sample_key):
        """Sample a batch from the buffer and apply one `update` step.

        Args:
            q_state: Current training state.
            buffer_state: Replay buffer state to sample from.
            sample_key: PRNG key consumed by the sampling call.

        Returns:
            The `(loss, new_q_state)` tuple produced by `update`.
        """
        sampled = buffer.sample(buffer_state, sample_key)
        # Only the experience field of the sample is forwarded to `update`.
        return update(q_state, sampled.experience)

    start_time = time.time()

    # Build the jitted key-splitting function once instead of re-wrapping
    # `jax.random.split` with `jax.jit` on every training step.
    split_key = jax.jit(jax.random.split)

    # Run the training loop
    print("Starting training...")
    obs, _ = envs.reset(seed=seed) # obs = np.array
    for global_step in range(total_timesteps):
        # Epsilon-greedy action selection. A single random draw decides whether
        # this step is exploratory (for all environments at once when vectorised).
        epsilon = linear_schedule(start_e, end_e, exploration_fraction * total_timesteps, global_step)
        if random.random() < epsilon:
            if num_envs > 1:
                actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                actions = envs.action_space.sample()
        else:
            actions = action_select_fn(q_state, obs) # obs = np.array -> jnp.array
            actions = jax.device_get(actions) # actions = jnp.array -> np.array

        next_obs, rewards, terminated, truncated, infos = envs.step(actions)

        if global_step % 1000 == 0:
            print(f"Current Training Step: {global_step}")

        # Create a timestep. The discount is zeroed only on termination.
        # NOTE(review): `truncated` is not used; with the auto reset wrapper
        # the observation following a truncated episode is presumably the
        # reset observation, which the loss would then bootstrap from - confirm.
        timestep = TimeStep(
            observation=obs,
            action=actions,
            reward=rewards,
            discount=1-np.asarray(terminated).astype(np.float32),
        )

        # Add the timestep to the buffer
        buffer_state = buffer.add(buffer_state, timestep)

        # Update the observation
        obs = next_obs

        # Update the network. `loss` stays 0 on steps where no update runs.
        loss = 0
        if global_step > learning_starts:
            if global_step % train_frequency == 0:
                # Check if the buffer can sample
                if buffer.can_sample(buffer_state):
                    key, sample_key = split_key(key)
                    loss, q_state = perform_update(q_state, buffer_state, sample_key)

                # Progress is reported only on training steps.
                if global_step % 100 == 0:
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    print("Loss", loss)

            # Update the target network
            if global_step % target_network_frequency == 0:
                q_state = q_state._replace(
                    target_params=optax.incremental_update(q_state.params, q_state.target_params, tau)
                )

    print("Training complete.")


  chex.assert_axis_dimension_lteq(jax.tree_leaves(batch)[0], 1, max_length_time_axis)


Starting training...
Current Training Step: 0
Current Training Step: 1000
SPS: 1632
Loss 0.28701562
SPS: 1731
Loss 0.15886527
SPS: 1826
Loss 0.04291299
SPS: 1917
Loss 0.013063445
SPS: 2002
Loss 0.007482495
SPS: 1906
Loss 0.10549195
SPS: 1980
Loss 0.044599906
SPS: 2049
Loss 0.033348784
SPS: 2114
Loss 0.07498935
Current Training Step: 2000
SPS: 2176
Loss 0.079916604
SPS: 2234
Loss 0.099185415
SPS: 2293
Loss 0.09092561
SPS: 2350
Loss 0.16927493
SPS: 2405
Loss 0.097453564
SPS: 2458
Loss 0.10643174
SPS: 2506
Loss 0.28333086
SPS: 2555
Loss 0.1799488
SPS: 2601
Loss 0.1939533
SPS: 2645
Loss 0.17877991
Current Training Step: 3000
SPS: 2685
Loss 0.1255859
SPS: 2724
Loss 0.27233487
SPS: 2764
Loss 0.25333983
SPS: 2801
Loss 0.15252432
SPS: 2840
Loss 0.2792627
SPS: 2877
Loss 0.15868285
SPS: 2911
Loss 0.33359522
SPS: 2945
Loss 0.16947365
SPS: 2978
Loss 0.3194501
SPS: 3009
Loss 0.34927875
Current Training Step: 4000
SPS: 3037
Loss 0.44799635
SPS: 3063
Loss 0.35320124
SPS: 3088
Loss 0.28368106
SPS: 311

#### Performance Evaluation

In [6]:
# Roll out the greedy policy (no exploration) in a freshly built environment.
print("Evaluating...")
envs = make_env(env_id, seed)()
obs, _ = envs.reset(seed=seed) # obs = np.array
for global_step in range(10_000):
    # Greedy action from the trained network, copied back to the host.
    actions = jax.device_get(action_select_fn(q_state, obs))

    next_obs, rewards, terminated, truncated, infos = envs.step(actions)

    # Report the episodic return whenever the step info carries a "final_info" dict.
    if "final_info" in infos and isinstance(infos["final_info"], dict):
        print(f"Evaluating Step : {global_step}, episodic_return={infos['episode']['r'][0]}")

    # Carry the new observation into the next step.
    obs = next_obs

envs.close()

Evaluating...
Evaluating Step : 230, episodic_return=231.0
Evaluating Step : 598, episodic_return=368.0
Evaluating Step : 1008, episodic_return=410.0
Evaluating Step : 1282, episodic_return=274.0
Evaluating Step : 1533, episodic_return=251.0
Evaluating Step : 1747, episodic_return=214.0
Evaluating Step : 1993, episodic_return=246.0
Evaluating Step : 2341, episodic_return=348.0
Evaluating Step : 2698, episodic_return=357.0
Evaluating Step : 3097, episodic_return=399.0
Evaluating Step : 3359, episodic_return=262.0
Evaluating Step : 3606, episodic_return=247.0
Evaluating Step : 3994, episodic_return=388.0
Evaluating Step : 4228, episodic_return=234.0
Evaluating Step : 4552, episodic_return=324.0
Evaluating Step : 4795, episodic_return=243.0
Evaluating Step : 5062, episodic_return=267.0
Evaluating Step : 5355, episodic_return=293.0
Evaluating Step : 5622, episodic_return=267.0
Evaluating Step : 5896, episodic_return=274.0
Evaluating Step : 6131, episodic_return=235.0
Evaluating Step : 6392