# DQN, Function Approximation, Performance tricks

In this lab we study the basics of Q-learning with function approximation by neural networks.

In [39]:
# Installing dependencies for visualization
!apt-get -qq -y install libcusparse8.0 libnvrtc8.0 libnvtoolsext1 > /dev/null
!ln -snf /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so.8.0 /usr/lib/x86_64-linux-gnu/libnvrtc-builtins.so
!apt-get -qq -y install xvfb freeglut3-dev ffmpeg> /dev/null
!pip -q install gymnasium[classic_control]
!pip -q install pyglet
!pip -q install pyopengl
!pip -q install pyvirtualdisplay

In [40]:
import glob
import random
import time
import collections
from typing import List, Tuple, Dict, Any

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model, clone_model

from base64 import b64encode
from IPython.display import HTML
from pyvirtualdisplay import Display

# Start virtual display
display = Display(visible=0, size=(1024, 768))
display.start()


def show_video(file_name: str):
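    # Read the video through a context manager so the file handle is closed
    with open(file_name, "rb") as video_file:
        mp4 = video_file.read()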
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(
        """
    <video width=480 controls>
        <source src="%s" type="video/mp4">
    </video>
    """
        % data_url
    )

We will start by defining a useful data structure:

In [41]:
Transition = collections.namedtuple(
    "Transition", ["state", "action", "reward", "done", "next_state"]
)

## CartPole
Debugging DQN is typically a complicated process, so we start with a simple environment that allows quick iteration. Let's first build a working DQN for the CartPole problem. We will use a small modification of the original CartPole environment with reshaped rewards (to make the problem easier for DQN):

In [42]:
class ModifiedCartPole:
    def __init__(self):
        self.env = gym.make("CartPole-v1")

    def reset(self) -> np.ndarray:
        return self.env.reset()[0]

    def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
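        obs, reward, terminated, truncated, _ = self.env.step(action)
        # Reward reshaping: penalize failure (the episode terminated)
        if terminated:
            reward = -10
        # Rewards are scaled by 1/10: +0.1 per step, -1.0 on failure.
        # Time-limit truncation is also reported as the end of the episode.
        return obs, reward / 10, terminated or truncated, {}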

## Q-network
First we must create a network to approximate Q(s, a). We have two natural design choices:
- Q-network takes two inputs: state s and action a and predicts one value Q(s,a)
- Q-network takes one input: state s, and predicts a vector of Q(s, a) for all possible actions.

We will follow the second design choice (one reason is that such a network yields the values of all actions, and hence the best action, from a single forward pass).
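To make the second design choice concrete, here is a minimal NumPy sketch of a state-in, Q-vector-out function. It is not the network used in this lab: the weights are random and untrained, and the names `q_values`, `w1`, `w2` and the hidden size are assumptions made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy "network": one hidden layer with random, untrained weights.
STATE_SIZE, HIDDEN, NUM_ACTIONS = 4, 8, 2
w1 = rng.normal(size=(STATE_SIZE, HIDDEN))
w2 = rng.normal(size=(HIDDEN, NUM_ACTIONS))


def q_values(state: np.ndarray) -> np.ndarray:
    """Maps one state to a vector with one Q-value per action."""
    hidden = np.maximum(0.0, state @ w1)  # ReLU activation
    return hidden @ w2


state = np.array([0.01, -0.02, 0.03, 0.04])
q = q_values(state)  # one forward pass gives Q(s, a) for every action
greedy_action = int(np.argmax(q))
```

With the first design choice, finding the greedy action would instead require one forward pass per action.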

**Exercise: fill in the code below to create the Q-network.** Create a simple fully connected network with `num_layers` layers, each with 64 neurons. The input is a vector of size `input_size`, and the output is a vector of size `num_action` (CartPole has 2 actions).

In [43]:
def make_cartpole_network(
    input_size: int = 4,
    num_action: int = 2,
    num_layers: int = 3,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.0,
) -> Model:
    input_state = Input(batch_shape=(None, input_size))
    #### TODO ####
    #### END ####
    model = Model(inputs=input_state, outputs=output)
    model.compile(loss="mse", optimizer=Adam(learning_rate=learning_rate))
    return model

## Building DQN

We will start with some utility functions:

**Exercise: read the following functions and implement the epsilon-greedy policy**

In [44]:
def predict_q_values(q_network: Model, state: np.ndarray) -> np.ndarray:
    """Makes a prediction for a single state and returns array of Q-values"""
    return q_network.predict(np.array([state]), verbose=0)[0]


def choose_best_action(q_network: Model, state: np.ndarray) -> int:
    """Chooses best action according to Q-network"""
    action_values = predict_q_values(q_network, state)
    # Convert the NumPy integer to a plain int to match the return annotation
    best_action = int(np.argmax(action_values))
    return best_action


def evaluate_state_batch(
    target_network: Model, state_batch: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluates a whole batch of states at once, which is very useful
    for speeding up training when we calculate targets.
    Arguments:
      - target_network: network used to evaluate the states
      - state_batch: list of states to evaluate
    Returns:
      - best_actions: best action for every state
      - best_vals: best state-action value for every state
      - action_values: all action values for each state

    The argument is named target_network instead of q_network because this
    function will be used with the target network.
    """
    action_values = target_network.predict(np.array(state_batch), verbose=0)
    best_actions = np.argmax(action_values, axis=-1)
    best_vals = np.max(action_values, axis=-1)
    return best_actions, best_vals, action_values


def choose_action(q_network: Model, state, epsilon: float) -> int:
    """Implement epsilon-greedy policy."""
    #### TODO ####
    #### END ####
    return action
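As a reference point for the exercise above, the following is a minimal sketch of the epsilon-greedy rule on a plain list of Q-values, without any network. The function name `epsilon_greedy` and the explicit `rng` argument are assumptions of this example; in the lab, the Q-values would come from `predict_q_values`.

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """With probability epsilon pick a uniformly random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy choice: index of the largest Q-value
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q = [0.2, 0.7]
always_greedy = [epsilon_greedy(q, 0.0, rng) for _ in range(100)]
always_random = [epsilon_greedy(q, 1.0, rng) for _ in range(1000)]
```

With `epsilon = 0.0` the policy always returns the greedy action (index 1 here); with `epsilon = 1.0` it ignores the Q-values entirely.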

While running episodes we will collect transitions and store them in `replay_buffer`, which is just a list of transitions. Before we write the code for running episodes, we must first prepare two functions: one that prepares the training data (it is used while the game is running) and one that performs the training itself.

**Exercise: the training protocol is the heart of DQN. Fill in the gaps in the following functions. Specific tasks are described in multiline comments.**

In [45]:
def sample_minibatch(
    replay_buffer: List[Transition], mini_batch_size: int
) -> List[Transition]:
    """Write code to choose random samples from replay_buffer.
    Choose mini_batch_size samples and collect them in the replay_batch variable.
    replay_batch must be a list of transitions.
    Hint: you can use the random.sample function."""
    replay_batch: list
    #### TODO ####
    #### END ####
    return replay_batch


def compute_target(
    transition: Transition, next_state_value: float, gamma: float
) -> float:
    """Compute the one-step TD target (TD(0)) based on the current transition
    and the next state value.
    Remember to treat the last state of the episode separately.
    """
    #### TODO ####
    #### END ####
    return target


def prepare_update_targets(
    target_network: Model,
    replay_buffer: List[Transition],
    mini_batch_size: int,
    gamma: float = 0.99,
) -> Tuple[np.ndarray, np.ndarray]:
    replay_batch = sample_minibatch(replay_buffer, mini_batch_size)

    # We will collect all states and next_states from the batch of transitions
    # to evaluate them at once.
    next_state_batch = [transition.next_state for transition in replay_batch]
    state_batch = [transition.state for transition in replay_batch]

    _, next_state_values, _ = evaluate_state_batch(target_network, next_state_batch)
    _, _, state_action_vals = evaluate_state_batch(target_network, state_batch)

    train_x, train_y = [], []
    for transition, state_vals, next_state_value in zip(
        replay_batch, state_action_vals, next_state_values
    ):
        # Prepare x, y training pairs for the supervised model update:
        # - x is a state,
        # - y is a vector of values for each action.
        # We only compute a new target for one action (the one taken in the
        # trajectory); values for the other actions remain unchanged.
        # copy() makes sure we do not modify data stored in the replay buffer.
        action = transition.action
        x = transition.state.copy()
        # Start from the current predictions for all actions
        y = state_vals.copy()

        y[action] = compute_target(transition, next_state_value, gamma)

        train_x.append(x)
        train_y.append(y)

    return np.array(train_x), np.array(train_y)
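For reference, the one-step Q-learning target that `compute_target` is expected to return, and the way `prepare_update_targets` places it into the training vector, can be written as below. The symbols are assumptions of this note: $(s, a, r, s')$ are the fields of a `Transition`, $Q_{\theta^-}$ denotes the target network, and $\mathbf{y}_b$ is the entry of the training vector for action $b$.

```latex
y =
\begin{cases}
r, & \text{if the episode ended at this transition (done)} \\
r + \gamma \max_{a'} Q_{\theta^-}(s', a'), & \text{otherwise}
\end{cases}
\qquad
\mathbf{y}_b =
\begin{cases}
y, & b = a \\
Q_{\theta^-}(s, b), & b \neq a
\end{cases}
```

Only the entry for the action actually taken carries new information; the other entries repeat the values predicted for the state, as in the loop above.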

**Exercise: fill the gaps in the update function**

In [46]:
def update(
    q_network: Model,
    target_network: Model,
    replay_buffer: List[Transition],
    mini_batch_size: int,
    gamma: float,
) -> float:
    """Prepare training batch (x and y) and update the model on it.
    For Keras models you can use the train_on_batch method. Check out its
    interface in the documentation."""
    #### TODO ####
    #### END ####
    return loss

Now, let us code the heart of the DQN algorithm: the function that runs an episode and trains the Q-network.

**Exercise: fill in the code in the run_one_episode function**

In [47]:
def run_one_episode(
    q_network: Model,
    target_network: Model,
    env: ModifiedCartPole,
    epsilon: float,
    steps_so_far: int,
    replay_buffer: List[Transition],
    mini_batch_size: int,
    update_target_every_n_steps: int,
    gamma: float,
) -> Tuple[int, List[float]]:
    done = False
    episode_steps = 0
    state = env.reset()
    ep_actions = []
    loss_history = []
    while not done:
        # Implement env interaction
        # 1. Select action with eps-greedy policy
        # 2. Advance the env
        # 3. Store recorded transition in the replay buffer
        # Remember to update the `state` variable
        ### TODO ###
        ### END ###
        ep_actions.append(action)
        episode_steps += 1
        steps_so_far += 1

        if len(replay_buffer) > mini_batch_size:
            # Update the model
            ### TODO ###
            ### END ###
            loss_history.append(loss)
            if steps_so_far % update_target_every_n_steps == 0:
                print("Updating target network")
                target_network.set_weights(q_network.get_weights())

    return episode_steps, loss_history

Finally, we can complete the full DQN algorithm.

In [48]:
def run_dqn(
    train_steps: int, n_checkpoints: int
) -> Tuple[List[int], List[float], List[List[np.ndarray]]]:
    # The environment to train on
    env = ModifiedCartPole()

    # We save several checkpoints to later visualize their performance
    q_checkpoints = []
    save_q_chepoint_every_n_steps = train_steps / n_checkpoints

    # Here is a set of default parameters (tested), you can try to find better values
    epsilon = 0.4
    min_epsilon = 0.1
    epsilon_decay = 0.99
    gamma = 0.975
    mini_batch_size = 128
    update_target_every_n_steps = 128

    replay_buffer = []

    q_network = make_cartpole_network()
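    target_network = make_cartpole_network()
    # Start the target network from the same weights as the Q-network
    target_network.set_weights(q_network.get_weights())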

    steps_so_far = 0

    episode_lengths, loss_history = [], []
    episode_num = 0

    while steps_so_far < train_steps:
        episode_length, loss = run_one_episode(
            q_network,
            target_network,
            env,
            epsilon,
            steps_so_far,
            replay_buffer,
            mini_batch_size,
            update_target_every_n_steps,
            gamma,
        )
        # Decay exploration, but never let epsilon drop below min_epsilon
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
        episode_num += 1
        episode_lengths.append(episode_length)
        if loss is not None:
            loss_history.extend(loss)
        steps_so_far += episode_length
        if (
            steps_so_far - len(q_checkpoints) * save_q_chepoint_every_n_steps
            >= save_q_chepoint_every_n_steps
        ):
            q_checkpoints.append(q_network.get_weights())
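        # No updates happen until the replay buffer is large enough,
        # so the loss list may be empty for the first episodes.
        mean_loss = np.mean(loss) if loss else float("nan")
        print(
            f"Episode = {episode_num} | steps = {steps_so_far} | "
            f"episode_length = {episode_length} | epsilon = {epsilon} | "
            f"loss = {mean_loss}"
        )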

    return episode_lengths, loss_history, q_checkpoints
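The exploration schedule implied by the default parameters above (start at 0.4, multiply by 0.99 after every episode, floor at 0.1) can be inspected in isolation. This is a minimal sketch; the helper name `epsilon_schedule` is hypothetical, and clamping exactly at the floor with `max` is an assumption of this example.

```python
from typing import List


def epsilon_schedule(
    epsilon: float, min_epsilon: float, decay: float, episodes: int
) -> List[float]:
    """Returns the epsilon used in each episode under multiplicative decay."""
    values = []
    for _ in range(episodes):
        values.append(epsilon)
        # Decay after the episode, clamped at the floor
        epsilon = max(min_epsilon, epsilon * decay)
    return values


schedule = epsilon_schedule(0.4, 0.1, 0.99, 200)
# Index of the first episode that runs with epsilon at the floor
first_floor_episode = next(i for i, e in enumerate(schedule) if e <= 0.1)
```

With these defaults the floor is reached after 138 episodes, so in a short run much of the training can happen while epsilon is still decaying.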

Let us now run the training (a run of 5000-8000 steps may take several minutes). Do not expect the return to grow monotonically. Training typically looks like a noisy process with some drift towards higher returns.

In [49]:
progress, loss_history, q_checkpoints = run_dqn(2000, 5)

In [50]:
def visualize_progress(progress: List[int], loss_history: List[float]):
    plt.clf()
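    plt.plot(progress, label="Episode length")
    # Exponential moving average, seeded with the first value so that the
    # curve is not biased towards 0 and stays aligned with the raw data
    smoothed_progress = progress[:1]
    for x in progress[1:]:
        smoothed_progress.append(0.8 * smoothed_progress[-1] + 0.2 * x)
    plt.plot(smoothed_progress, label="Episode length (smoothed)")
    plt.xlabel("Episode")
    plt.legend(loc="upper left")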
    plt.show()

    plt.clf()
    plt.plot(loss_history, label="Loss")
    plt.legend(loc="upper left")
    plt.show()

In [51]:
visualize_progress(progress, loss_history)

Let us see how the agent performs across the training:

In [63]:
def record_checkpoint(checkpoint: List[np.ndarray]):
    # Records one greedy episode of the agent using the given checkpoint weights
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    model = make_cartpole_network()
    model.set_weights(checkpoint)
    max_ep_len = 200
    # RecordVideo captures a frame on every step, so no explicit render() is needed
    envw = gym.wrappers.RecordVideo(env, "./", name_prefix="cartpole-video")
    (o, _), d, t, ep_len = envw.reset(), False, False, 0
    # Stop on termination, truncation, or after max_ep_len steps
    while not (d or t or ep_len == max_ep_len):
        action = choose_best_action(model, o)
        o, r, d, t, _ = envw.step(action)
        ep_len += 1
    # Close the wrapper (not only the inner env) so the video file is finalized
    envw.close()

Let's take a look at the first saved checkpoint:

In [64]:
record_checkpoint(q_checkpoints[0])
file_name = glob.glob("cartpole-video*.mp4")[0]
show_video(file_name)

And the last:

In [65]:
record_checkpoint(q_checkpoints[-1])
file_name = glob.glob("cartpole-video*.mp4")[0]
show_video(file_name)

# Ablation study
Let's see what happens to DQN performance after turning off some of its mechanisms:
- target network
- sampling from replay_buffer

**Exercise: turn off the target network.** You can, for example, modify the code of run_dqn() and set target_network = q_network. Compare the results with the previous run.

**Exercise: add a size limit to the replay buffer.** Add code to run_dqn() that clips its size to a given limit. What happens if the replay buffer is very small?
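One possible way to bound the buffer, shown here only as a sketch outside of run_dqn(), is a `collections.deque` with `maxlen`, which silently drops the oldest transition when a new one is appended. The limit `BUFFER_LIMIT` and the dummy transitions are assumptions made for the demonstration. The lab code annotates the buffer as a plain list, so the last lines also show an in-place trim of a list.

```python
import collections
import random

Transition = collections.namedtuple(
    "Transition", ["state", "action", "reward", "done", "next_state"]
)

BUFFER_LIMIT = 5  # hypothetical, deliberately tiny limit for the demonstration
replay_buffer = collections.deque(maxlen=BUFFER_LIMIT)

for step in range(8):
    # Dummy transitions: the state is just the step index here
    replay_buffer.append(Transition(step, 0, 0.1, False, step + 1))

# Only the 5 most recent transitions (states 3..7) remain in the buffer
oldest_state = replay_buffer[0].state
batch = random.sample(list(replay_buffer), 3)

# Alternative for a plain list: trim in place, keeping the newest items
plain_list = list(range(8))
del plain_list[:-BUFFER_LIMIT]
```

The in-place variant matters when the same list object is shared between functions, as `replay_buffer` is between run_dqn() and run_one_episode().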