# Introduction to Reinforcement Learning


<center>
<img src="https://raw.githubusercontent.com/jcformanek/jcformanek.github.io/master/docs/assets/images/rl_in_space.png" width="100%" />
</center>

<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/introduction_to_reinforcement_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:**
Claude Formanek, Kale-ab Tessera, Sicelukwanda Zwane, Sebastian Bodenstein

**Reviewers:**
Ruan van der Merwe, Avishkar Bhoopchand

**Introduction:** 

In this tutorial, we will be learning about Reinforcement Learning, a type of Machine Learning where an agent learns to choose actions in an environment that lead to maximal reward in the long run. RL has seen tremendous success on a wide range of challenging problems such as learning to play complex video games like [Atari](https://www.deepmind.com/blog/agent57-outperforming-the-human-atari-benchmark), [StarCraft II](https://www.deepmind.com/blog/alphastar-mastering-the-real-time-strategy-game-starcraft-ii) and [Dota 2](https://openai.com/five/).

In this introductory tutorial we will solve the classic [CartPole](https://www.gymlibrary.ml/environments/classic_control/cart_pole/) environment, where an agent must learn to balance a pole on a cart, using several different RL approaches. Along the way you will be introduced to some of the most important concepts and terminology in RL.

**Topics:** 
* Reinforcement Learning
* Random Policy Search
* Policy Gradient
* Q-Learning

**Level:** 

Beginner

**Aims/Learning Objectives:**

* Understand the basic theory behind RL.
* Implement a simple random policy search algorithm.
* Implement a simple policy gradient RL algorithm.
* Implement a simple Q-learning algorithm.

**Prerequisites:**

* Some familiarity with [JAX](https://github.com/google/jax).
* Neural network basics.

**Outline:** 

* Section 1: Key Concepts in Reinforcement Learning
* Section 2: Random Policy Search
* Section 3: Policy Gradient
* Section 4: Deep Q-Learning


## Setup

In [None]:
# @title Install required packages (run me) { display-mode: "form" }
# @markdown This may take a minute or two to complete.
%%capture
!pip install jaxlib
!pip install jax
!pip install git+https://github.com/deepmind/dm-haiku
!pip install gym==0.25
!pip install gym[box2d]
!pip install optax
!pip install matplotlib
!pip install chex

In [None]:
# @title Import required packages (run me) { display-mode: "form" }
%%capture
import copy
from shutil import rmtree # deleting directories
import random
import collections # useful data structures
import numpy as np
import gym # reinforcement learning environments
from gym.wrappers import RecordVideo
import jax
import jax.numpy as jnp # jax numpy
import haiku as hk # jax neural network library
import optax # jax optimizer library
import matplotlib.pyplot as plt # graph plotting library
from IPython.display import HTML
from base64 import b64encode
import chex

# Hide warnings
import warnings
warnings.filterwarnings('ignore')

## Section 1: Key Concepts in Reinforcement Learning

Reinforcement Learning (RL) is a subfield of Machine Learning (ML). Unlike supervised learning, where we give our models examples of the expected behaviour, RL focuses on *goal-orientated* learning from interaction, through trial and error. RL algorithms learn which actions to take in an environment in order to maximise a reward signal. In a video game, for example, the reward signal could be the game's score, so an RL algorithm would try to maximise the score by choosing the best actions.

<center>
<img src="https://miro.medium.com/max/1400/1*Ews7HaMiSn2l8r70eeIszQ.png" width="40%" />
</center>

[*Image Source*](https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b)

More precisely, in RL we have an **agent** which perceives an **observation** $o_t$ of the current state $s_t$ of the **environment** and must choose an **action** $a_t$ to take. The environment then transitions to a new state $s_{t+1}$ in response to the agent's action and also gives the agent a scalar reward $r_t$ to indicate how good or bad the chosen action was given the environment's state. The goal in RL is for the agent to maximise the amount of reward it receives from the environment over time. The subscript $t$ indicates the timestep, i.e., $s_0$ is the state of the environment at the initial timestep and $a_{99}$ is the agent's action at timestep $99$.

### Environment - OpenAI Gym
As mentioned above, an environment receives an action $a_t$ from the agent and returns a reward $r_t$ and the next observation $o_{t+1}$.

OpenAI has provided a Python package called **Gym** that includes implementations of popular environments and a simple interface for an RL agent to interact with. To use a supported [gym environment](https://www.gymlibrary.ml/), all you need to do is pass the name of the environment to the function `gym.make(<environment_name>)`.

In this tutorial, we will be using a simple environment called **CartPole**. In CartPole the task is for the agent to learn to balance a pole for as long as possible by moving a cart *left* or *right*.

<img src="https://miro.medium.com/max/600/1*v8KcdjfVGf39yvTpXDTCGQ.gif" width="30%" />

In [None]:
# Create the environment
env_name = "CartPole-v0"
env = gym.make(env_name)

### States and Observations - $s_t$ and $o_t$

In RL, an agent perceives an observation of the environment's state. In some settings, the observation may include all the information underlying the environment's state. Such an environment is called **fully observed**. In other settings, the agent may only receive partial information about the environment's state in its observation. Such an environment is called **partially observed**. 
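To make the distinction concrete, here is a hypothetical sketch using a CartPole-style state vector, ordered as in Gym (cart position, cart velocity, pole angle, pole angular velocity). The function names are made up for illustration; a partially observed variant could, for example, hide the two velocities from the agent.

```python
import numpy as np

# A CartPole-style state vector, in Gym's order:
# cart position, cart velocity, pole angle, pole angular velocity.
state = np.array([0.05, -0.3, 0.02, 0.4])


def full_observation(state):
    # Fully observed: the agent sees the whole state.
    return state.copy()


def positions_only_observation(state):
    # Hypothetical partially observed variant: the velocities are hidden,
    # so the agent only sees the cart position and the pole angle.
    return state[[0, 2]]


o_full = full_observation(state)
o_partial = positions_only_observation(state)
print(o_full)     # all four numbers
print(o_partial)  # only the two positions
```

From a single partial observation the agent cannot tell whether the pole is currently falling or recovering, which is what makes partially observed problems harder.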

For the rest of this tutorial, we will assume the environment is fully observed and so we will use state $s_t$ and observation $o_t$ interchangeably. In Gym we get the initial observation from the environment by calling the function `env.reset()`.

In [None]:
# Reset the environment
s_0 = env.reset()
print("Initial State:", s_0)

# Get environment obs space
obs_shape = env.observation_space.shape
print("Environment Obs Space Shape:", obs_shape)

In CartPole, the state of the environment is represented by four numbers, in this order: the *position of the cart, velocity of the cart, angle of the pole and angular velocity of the pole*.
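To keep the four numbers straight, you can attach names to them. `describe_obs` and `CARTPOLE_OBS_NAMES` are hypothetical helpers, not part of Gym; the order follows Gym's CartPole documentation. In the notebook you could, for example, call `describe_obs(s_0)`.

```python
import numpy as np

# Names of the four CartPole observation components, in the order Gym uses.
CARTPOLE_OBS_NAMES = (
    "cart_position",
    "cart_velocity",
    "pole_angle",             # radians, 0 means upright
    "pole_angular_velocity",
)


def describe_obs(obs):
    """Return a dict mapping each component name to its value."""
    values = np.asarray(obs, dtype=np.float64).tolist()
    return dict(zip(CARTPOLE_OBS_NAMES, values))


example = describe_obs([0.01, -0.02, 0.03, 0.04])
print(example["pole_angle"])  # → 0.03
```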

### Actions - $a_t$

In RL, actions are usually either **discrete** or **continuous**. Continuous actions are given by a vector of real numbers, whereas discrete actions are given by an integer value. When the agent can only choose from a finite set of actions that we can enumerate, we usually use discrete actions.
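A small NumPy sketch of the difference: a discrete action is one integer from a finite set, while a continuous action is a real-valued vector, here drawn within hypothetical per-dimension bounds `low` and `high` (CartPole itself only uses the discrete kind).

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Discrete: a finite set of actions, each identified by an integer.
num_actions = 2                                    # e.g. CartPole: 0 = push left, 1 = push right
discrete_action = int(rng.integers(num_actions))   # an integer in {0, 1}

# Continuous: a vector of real numbers, here bounded per dimension.
low = np.array([-1.0, -1.0])                       # hypothetical lower bounds
high = np.array([1.0, 1.0])                        # hypothetical upper bounds
continuous_action = rng.uniform(low, high)         # float vector of shape (2,)

print(discrete_action, continuous_action)
```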

In CartPole there are only two actions: push the cart *left* or push it *right*. As such, the actions can be represented by the integers $0$ (left) and $1$ (right). In Gym we can inspect the action space and the number of available actions as follows:

In [None]:
# Get action space - e.g. discrete or continuous
print(f"Environment action space: {env.action_space}")

# Get num actions
num_actions = env.action_space.n
print(f"Number of actions: {num_actions}")

### The Agent's Policy - $\pi$

The agent chooses actions based on the observations it receives. We can think of its action-selection process as a function that takes an observation as input and returns an action as output. In RL we call this function the agent's **policy** and denote it $\pi(s_t)=a_t$. We usually parametrise the policy in some way and then try to learn the optimal parameters. A parametrised policy is denoted $\pi_\theta$, where $\theta$ is the set of parameters.

**Exercise 1:** As an exercise, let's implement a simple policy that takes a set of parameters and an observation as input and returns an action. Assume the observation is a vector of four numbers, like the CartPole observation, and that the action should be either $0$ or $1$. Assume also that the parameters are a vector of four real numbers. The action should then be computed as follows:

1. Compute the [vector dot product](https://www.mathsisfun.com/algebra/vectors-dot-product.html) between the observation and the parameters.
2. If the result is greater than zero, return action $1$.
3. Otherwise, return action $0$.

In this prac we will try to use JAX as much as possible. So, try to use JAX methods for this task. Below are some useful methods you can use. You will need to complete the code in the block below by replacing the `...` with the correct code.

We have also provided a code cell to check your solution. Finally, we also provide the solution to the coding task which you can check after you have given the task a try.

**Useful methods:** 
* Compute the vector dot product with `jax.numpy.dot` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html)).
* To conditionally assign a value of $0$ or $1$ to the action based on the result of the dot product, use `jax.lax.select` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html#jax.lax.select)). `jax.lax.select` takes three arguments. The first is a boolean predicate, i.e. an expression that evaluates to either `True` or `False`. If the predicate is `True`, `jax.lax.select` returns its second argument; if it is `False`, it returns its third argument. The second and third arguments must have the same shape and dtype.
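A short illustration of `jax.lax.select` on values unrelated to the exercise: picking the larger of two scalars, and an elementwise selection on arrays where each output entry is chosen independently.

```python
import jax
import jax.numpy as jnp

x = jnp.float32(3.0)
y = jnp.float32(-1.0)

# Scalar case: return x if x > y, otherwise y (i.e. the larger of the two).
larger = jax.lax.select(x > y, x, y)

# Elementwise case: the predicate and both branches have the same shape,
# and both branches have the same dtype.
pred = jnp.array([True, False, True])
on_true = jnp.array([10, 20, 30])
on_false = jnp.array([-1, -2, -3])
picked = jax.lax.select(pred, on_true, on_false)

print(larger)  # → 3.0
print(picked)  # → [10 -2 30]
```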



In [None]:
def linear_policy(params, obs):
  """A simple linear policy
  
  Args:
    params: a vector of four real-numbers that give the parameters of the policy
    obs: a vector of four real-numbers that give the agent's observation

  Returns:
    a discrete action given by a 0 or 1
  """
  # YOUR CODE
  dot_product_result = ...

  action = jax.lax.select(
      ..., # boolean statement goes here
      ..., # result when the statement is True goes here
      ..., # result when the statement is False goes here
  )
  # END YOUR CODE
  return action

In [None]:
# @title Check exercise 1 (run me) { display-mode: "form" }

def check_linear_policy(linear_policy):
  fixed_obs = jnp.array([1,1,2,4])
  
  # params1 gives a positive dot product with fixed_obs (expected action 1),
  # params2 gives a negative one (expected action 0).
  params1 = jnp.array([1,1,1,1])
  params2 = jnp.array([-1,-1,-1,-1])

  hint1 = f"Incorrect answer, your linear policy is incorrect. The action when \
  obs={fixed_obs} and params={params1} should be 1"

  hint2 = f"Incorrect answer, your linear policy is incorrect. The action when \
  obs={fixed_obs} and params={params2} should be 0"

  hint = None
  if linear_policy(params1, fixed_obs) != 1:
    hint = hint1
  elif linear_policy(params2, fixed_obs) != 0:
    hint = hint2

  if hint is not None:
    print(hint)
  else:
    print("Your function is correct!")

try:
  check_linear_policy(linear_policy)
except Exception as e:
  print("An error occurred: {}".format(e))

In [None]:
# @title  Solution exercise 1 { display-mode: "form" }

def linear_policy(params, obs):
  """A simple linear policy
  
  Args:
    params: a vector of four real-numbers that give the parameters of the policy
    obs: a vector of four real-numbers that give the agent's observation

  Returns:
    a discrete action given by a 0 or 1
  """
  # YOUR CODE
  dot_product_result = jax.numpy.dot(params, obs)

  action = jax.lax.select(
      dot_product_result > 0, # boolean statement goes here
      1, # result when the statement is True goes here
      0, # result when the statement is False goes here
  )
  # END YOUR CODE
  return action 

### The Environment Transition Function - $P$

Now that we have a policy we can pass actions from the agent to the environment. The environment will then transition to a new state in response to the agent's action. 

In RL we model this process by using a **state transition function** $P$ which takes the current state $s_t$ and an action $a_t$ as input and returns the next state $s_{t+1}$ as output:

<center>
 $s_{t+1}=P(s_t, a_t)$
</center> 
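CartPole happens to be deterministic once the episode has started, so its transition function really is a plain function of $(s_t, a_t)$. The sketch below is modelled on the cart-pole equations in Gym's CartPole source; the constants are assumptions taken from that implementation, and `cartpole_transition` is a hypothetical helper. Gym's real `step` additionally checks for termination and computes the reward.

```python
import math

# Physical constants, assumed to match Gym's CartPole implementation.
GRAVITY = 9.8
MASS_CART = 1.0
MASS_POLE = 0.1
TOTAL_MASS = MASS_CART + MASS_POLE
HALF_POLE_LENGTH = 0.5                  # Gym's `length` is half the pole's length
POLE_MASS_LENGTH = MASS_POLE * HALF_POLE_LENGTH
FORCE_MAG = 10.0                        # magnitude of the push applied to the cart
TAU = 0.02                              # seconds between state updates


def cartpole_transition(state, action):
    """Deterministic s_{t+1} = P(s_t, a_t): one Euler step of the cart-pole dynamics."""
    x, x_dot, theta, theta_dot = state
    force = FORCE_MAG if action == 1 else -FORCE_MAG  # 1 pushes right, 0 pushes left
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    temp = (force + POLE_MASS_LENGTH * theta_dot**2 * sin_t) / TOTAL_MASS
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        HALF_POLE_LENGTH * (4.0 / 3.0 - MASS_POLE * cos_t**2 / TOTAL_MASS)
    )
    x_acc = temp - POLE_MASS_LENGTH * theta_acc * cos_t / TOTAL_MASS
    # Euler integration: the new positions use the old velocities.
    return (
        x + TAU * x_dot,
        x_dot + TAU * x_acc,
        theta + TAU * theta_dot,
        theta_dot + TAU * theta_acc,
    )


state_0 = (0.0, 0.0, 0.0, 0.0)             # cart centred, pole upright, at rest
state_1 = cartpole_transition(state_0, 1)  # push the cart to the right
print(state_1)
```

Calling `cartpole_transition` twice with the same state and action always gives the same next state; in environments with stochastic dynamics, $P$ would instead describe a distribution over next states.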

In gym, we can pass actions to the environment by calling the `env.step(<action>)` function. The function will then return four values:
- the **next observation**
- the **reward** for the action taken
- a boolean flag to indicate if the game is **done** 
- some **extra** information.




In [None]:
# Get the initial obs by resetting the env
initial_obs = env.reset()

# Randomly sample an action from the action space
action = env.action_space.sample()

# Step the environment
next_obs, reward, done, info = env.step(action)

print("Observation:", initial_obs)
print("Action:", action)
print("Next observation:", next_obs)
print("Reward:", reward)
print("Game is done:", done)

### Episode Return - $R_t$

In RL we usually break an agent's interactions with the environment up into **episodes**. The sum of the rewards collected from timestep $t$ until the end of the episode is called the **return** $R_t$:

<center>
$R_t=\sum_{k=t}^{T}r_k$,
</center>

where $r_k$ is the reward at timestep $k$ and $T$ is the last timestep. The episode's return is therefore $R_0$, the sum of all rewards collected during the episode. The goal in RL is for the agent to choose actions which maximise the expected future return $R_t$.
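A short sketch of the arithmetic: `episode_return` sums all rewards of an episode, and the hypothetical helper `returns_from_each_timestep` computes, for every timestep $t$, the sum of the rewards from $t$ until the end of the episode (the future return the agent tries to maximise). Its first entry equals the episode return.

```python
def episode_return(rewards):
    """Sum of all rewards collected in one episode."""
    return sum(rewards)


def returns_from_each_timestep(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T for every timestep t."""
    returns = []
    running_total = 0.0
    for reward in reversed(rewards):  # accumulate from the last timestep backwards
        running_total += reward
        returns.append(running_total)
    return returns[::-1]              # restore timestep order


rewards = [1.0, 1.0, 1.0, 1.0]        # e.g. four CartPole steps before the episode ended
print(episode_return(rewards))              # → 4.0
print(returns_from_each_timestep(rewards))  # → [4.0, 3.0, 2.0, 1.0]
```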


### Agent-environment Loop
Now that we know what a policy is and we know how to step the environment, let's close the agent-environment loop.

**Exercise 2:** Write a function that runs one episode of CartPole by sequentially choosing actions and stepping the environment. You should use the linear policy we defined earlier to choose actions. The function should keep track of the rewards received and output the return at the end of the episode.

In CartPole the agent receives a reward of `1` for every timestep the pole stays upright. If the pole falls over (tilts more than 12 degrees from vertical) or the cart moves too far from the centre of the track, the game is over and the agent receives no more reward. The game is also over after `200` timesteps, so the maximum episode return the agent can collect is `200`.
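The reward and termination rule can be sketched as a small pure-Python function. The thresholds (2.4 for the cart position, 12 degrees for the pole angle, 200 steps) are assumptions based on Gym's CartPole-v0 implementation, and `cartpole_step_outcome` is a hypothetical helper, not part of Gym's API.

```python
import math

# Assumed CartPole-v0 limits, based on Gym's implementation.
X_THRESHOLD = 2.4                           # cart position limit
THETA_THRESHOLD = 12 * 2 * math.pi / 360    # 12 degrees, in radians
MAX_EPISODE_STEPS = 200                     # CartPole-v0 time limit


def cartpole_step_outcome(obs, timestep):
    """Hypothetical helper: (reward, done) after `timestep` steps, ending in `obs`."""
    x, _, theta, _ = obs
    fell_over = abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD
    out_of_time = timestep >= MAX_EPISODE_STEPS
    # Every step, including the final one, earns a reward of 1.
    return 1.0, fell_over or out_of_time


print(cartpole_step_outcome((0.0, 0.0, 0.0, 0.0), 1))  # → (1.0, False)
print(cartpole_step_outcome((0.0, 0.0, 0.3, 0.0), 5))  # → (1.0, True)
```

Since each step earns exactly `1`, the episode return simply counts how many steps the agent survived, capped at `200`.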

In [None]:
def run_episode(env):
  episode_return = 0 # counter to keep track of rewards
  done = False # initially set to False
  params = jnp.array([1,-2,2,-1]) # fixed policy parameters

  ## YOUR CODE
  
  obs = ... # TODO: get the initial obs from the env

  while not done: # loop until episode is done

    action = ... # TODO: compute action using linear policy
    action = np.array(action) # We need to convert the action from the policy to a np.array

    obs, reward, done, info = ... # TODO: step the environment

    episode_return = ... # TODO: add reward to episode return

  return episode_return

In [None]:
# @title Check exercise 2 (run me) { display-mode: "form" }

try:
  env.seed(42)
  if run_episode(env) == 31:
    print("Looks correct!")
  else:
    print("Looks like your implementation might be wrong.")
except Exception as e:
  print("An error occurred: {}".format(e))

In [None]:
#@title Solution exercise 2 { display-mode: "form" }
def run_episode(env):
  episode_return = 0 # counter to keep track of rewards
  done = False # initially set to False
  params = jnp.array([1,-2,2,-1]) # fixed policy parameters

  ## YOUR CODE
  
  obs = env.reset() # TODO: get the initial obs from the env

  while not done: # loop until episode is done

    # HINT: You might need to convert the action from your policy to a np.array
    action = linear_policy(params, obs) # TODO: compute action using linear policy
    action = np.array(action) # We need to convert the action from the policy to a np.array

    obs, reward, done, info = env.step(action) # TODO: step the environment

    episode_return = episode_return + reward # TODO: add reward to episode return

  return episode_return


In CartPole-v0 an episode lasts at most 200 timesteps, so the best possible episode return is 200, and we consider the environment solved when the agent can reliably achieve it. As you can see, our current policy is far from optimal.

One way to find a good policy is to randomly try out different policies and keep the one that performs best. This strategy is called Random Policy Search, and it can be surprisingly effective.

Before we implement Random Policy Search, let's quickly cover the general RL training loop we will use to implement the algorithms in the rest of this tutorial.

### A General Purpose RL Training Loop
We have implemented a general-purpose RL training loop for you. The training loop takes several arguments, but the three most important for you to understand are `agent_select_action_func`, `agent_learn_func` and `agent_memory`.

* `agent_select_action_func` is a function we define and pass to the training loop. It is called as `agent_select_action_func(key, agent_params, agent_actor_state, obs)` (with `evaluation=True` during evaluation episodes) and should return the chosen action together with the, possibly updated, actor state.
* `agent_learn_func` is another function we define and pass to the training loop. It is called as `agent_learn_func(key, agent_params, agent_learner_state, memory)`, where `memory` comes from `agent_memory.sample()`, and should return the agent's updated parameters together with the updated learner state.
* `agent_memory` is a general-purpose module we define to store relevant information about the agent's experiences in the environment, which the `agent_learn_func` can then use. It must provide `push`, `is_ready` and `sample` methods.
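To make these calling conventions concrete, here is a hypothetical sketch of the three pieces with deliberately trivial behaviour: `always_right_select_action`, `LastTransitionMemory` and `counting_learn` are made-up names, not part of the prac. The last few lines hand-roll one step of what the training loop does with them.

```python
import collections

Transition = collections.namedtuple(
    "Transition", ["obs", "action", "reward", "next_obs", "done"]
)


def always_right_select_action(key, params, actor_state, obs, evaluation=False):
    """Hypothetical select-action function: ignores obs and returns action 1."""
    return 1, actor_state  # (action, possibly updated actor state)


class LastTransitionMemory:
    """Hypothetical memory that only remembers the most recent transition."""

    def __init__(self):
        self.last = None

    def push(self, transition):
        self.last = transition

    def is_ready(self):
        return self.last is not None

    def sample(self):
        return self.last


def counting_learn(key, params, learner_state, memory):
    """Hypothetical learn function: leaves params alone, counts learn calls."""
    return params, learner_state + 1  # (new params, new learner state)


# One hand-rolled step, mirroring the order used by the training loop.
memory = LastTransitionMemory()
action, actor_state = always_right_select_action(None, None, None, [0.0] * 4)
memory.push(Transition([0.0] * 4, action, 1.0, [0.1] * 4, False))
example_params, learner_state = None, 0
if memory.is_ready():
    example_params, learner_state = counting_learn(
        None, example_params, learner_state, memory.sample()
    )
print(action, learner_state)
```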


Below is the training loop function we have implemented for you. You are welcome to read through the code, but understanding it is not required, so we have hidden it by default. Just make sure that you run the code cell before moving on, because the training loop is used throughout this prac.

In [None]:
#@title Training loop (run me) { display-mode: "form" }

# NamedTuple to store transitions
Transition = collections.namedtuple("Transition", ["obs", "action", "reward", "next_obs", "done"])

# Training Loop
def run_training_loop(env_name, agent_params, agent_select_action_func, 
    agent_actor_state=None, agent_learn_func=None, agent_learner_state=None, 
    agent_memory=None, num_episodes=1000, evaluator_period=100, 
    evaluation_episodes=8, learn_steps_per_episode=1, 
    train_every_timestep=False, video_subdir="",):
    """
    This function runs several episodes in an environment and periodically does 
    some agent learning and evaluation.
    
    Args:
        env: a gym environment.
        agent_params: an object to store parameters that the agent uses.
        agent_select_func: a function that does action selection for the agent.
        agent_actor_state (optional): an object that stores the internal state 
            of the agents action selection function.
        agent_learn_func (optional): a function that does some learning for the 
            agent by updating the agent parameters.
        agent_learn_state (optional): an object that stores the internal state 
            of the agent learn function.
        agent_memory (optional): an object for storing an retrieving historical 
            experience.
        num_episodes: how many episodes to run.
        evaluator_period: how often to run evaluation.
        evaluation_episodes: how many evaluation episodes to run.
        train_every_timestep: whether to train every timestep rather than at the end 
            of the episode.
        video_subdir: subdirectory to store epsiode recordings.

    Returns:
        episode_returns: list of all the episode returns.
        evaluator_episode_returns: list of all the evaluator episode returns.
    """

    # Setup Cartpole environment and recorder
    env = gym.make(env_name, render_mode="rgb_array") # training environment
    eval_env = gym.make(env_name, render_mode="rgb_array") # evaluation environment

    # Video dir
    video_dir = "./video"+"/"+video_subdir

    # Clear video dir (ignore errors, e.g. if it does not exist yet)
    rmtree(video_dir, ignore_errors=True)

    # Wrap in recorder
    env = RecordVideo(env, video_dir+"/train", episode_trigger=lambda x: (x % evaluator_period) == 0)
    eval_env = RecordVideo(eval_env, video_dir+"/eval", episode_trigger=lambda x: (x % evaluation_episodes) == 0)

    # JAX random number generator
    rng = hk.PRNGSequence(jax.random.PRNGKey(0))
    env.seed(0) # seed environment for reproducibility
    random.seed(0)

    episode_returns = [] # List to store history of episode returns.
    evaluator_episode_returns = [] # List to store history of evaluator returns.
    timesteps = 0
    for episode in range(num_episodes):

        # Reset environment.
        obs = env.reset()
        episode_return = 0
        done = False

        while not done:

            # Agent select action.
            action, agent_actor_state = agent_select_action_func(
                                            next(rng), 
                                            agent_params, 
                                            agent_actor_state, 
                                            np.array(obs)
                                        )

            # Step environment.
            next_obs, reward, done, _ = env.step(int(action))

            # Pack into transition.
            transition = Transition(obs, action, reward, next_obs, done)

            # Add transition to memory.
            if agent_memory: # check if agent has memory
              agent_memory.push(transition)

            # Add reward to episode return.
            episode_return += reward

            # Set obs to next obs before next environment step. CRITICAL!!!
            obs = next_obs

            # Increment timestep counter
            timesteps += 1

            # Maybe learn every timestep
            if train_every_timestep and (timesteps % 4 == 0) and agent_memory and agent_memory.is_ready(): # Make sure memory is ready
                # First sample memory and then pass the result to the learn function
                memory = agent_memory.sample()
                agent_params, agent_learner_state = agent_learn_func(
                                                        next(rng), 
                                                        agent_params, 
                                                        agent_learner_state, 
                                                        memory
                                                    )

        episode_returns.append(episode_return)

        # At the end of every episode we do a learn step.
        if agent_memory and agent_memory.is_ready(): # Make sure memory is ready

            for _ in range(learn_steps_per_episode):
                # First sample memory and then pass the result to the learn function
                memory = agent_memory.sample()
                agent_params, agent_learner_state = agent_learn_func(
                                                        next(rng), 
                                                        agent_params, 
                                                        agent_learner_state, 
                                                        memory
                                                    )

        if (episode % evaluator_period) == 0: # Do evaluation

            evaluator_episode_return = 0
            for eval_episode in range(evaluation_episodes):
                obs = eval_env.reset()
                done = False
                while not done:
                    action, _ = agent_select_action_func(
                                    next(rng), 
                                    agent_params, 
                                    agent_actor_state, 
                                    np.array(obs), 
                                    evaluation=True
                                )

                    obs, reward, done, _ = eval_env.step(int(action))

                    evaluator_episode_return += reward

            evaluator_episode_return /= evaluation_episodes

            evaluator_episode_returns.append(evaluator_episode_return)

            logs = [
                    f"Episode: {episode}",
                    f"Episode Return: {episode_return}",
                    f"Average Episode Return: {np.mean(episode_returns[-20:])}",
                    f"Evaluator Episode Return: {evaluator_episode_return}"
            ]

            print(*logs, sep="\t") # Print the logs

    env.close()
    eval_env.close()

    return episode_returns, evaluator_episode_returns

## Section 2: Random Policy Search (RPS)
In Section 1 we used a fixed set of parameters for our policy. That is to say, we didn't learn $\pi$'s parameters $\theta$; we simply kept them fixed (`params = [1,-2,2,-1]`).

We will now implement Random Policy Search (RPS), an algorithm that randomly tries different policy parameters and keeps track of the best parameters found so far. We will say that policy parameters $\theta_A$ are better than parameters $\theta_B$ if the average episode return achieved by the policy with parameters $\theta_A$, over a fixed number of episodes, is greater than that of the policy with parameters $\theta_B$.
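Before building RPS into the training loop, here is a toy sketch of the same keep-the-best idea on a made-up problem. The hypothetical `score` function stands in for the average episode return (a quadratic that peaks at an arbitrary `TARGET`) so that the loop runs instantly; in real RPS each score would come from running episodes with the candidate parameters. Candidates are drawn uniformly from [-2, 2), the same range used in Exercise 4.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical stand-in for "average episode return": higher is better,
# with a maximum of 0 at TARGET.
TARGET = np.array([0.5, -1.0, 1.5, 0.0])


def score(theta):
    return -np.sum((theta - TARGET) ** 2)


best_theta, best_score = None, -np.inf
for _ in range(500):
    candidate = rng.uniform(-2.0, 2.0, size=4)  # try a random parameter vector
    candidate_score = score(candidate)
    if candidate_score > best_score:            # keep it only if it beats the best so far
        best_theta, best_score = candidate, candidate_score

print(best_theta, best_score)
```

Random search needs no gradients, only the ability to compare scores, which is why it works with any black-box policy; the price is that it scales poorly as the number of parameters grows.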


To keep track of the "current" parameters as well as the "best" parameters, we will use a [NamedTuple](https://www.geeksforgeeks.org/namedtuple-in-python/).

In [None]:
# Parameter container for Random Policy Search
RandomPolicySearchParams = collections.namedtuple("RandomPolicySearchParams", ["current", "best"])

# TEST: store two different sets of parameters
current_params = np.ones(obs_shape) * -1
best_params = np.zeros(obs_shape)
rps_params = RandomPolicySearchParams(current_params, best_params)

# How to access the best or current params.
print(f"Best params: {rps_params.best}")
print(f"Current params: {rps_params.current}")
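Because NamedTuples are immutable, you cannot assign to a field of `rps_params` directly; instead, `_replace` (a standard method of `collections.namedtuple`) builds a new tuple with some fields changed. The sketch below, with made-up values and the hypothetical name `example_params`, shows how the "current" parameters could be promoted to "best"; this is one possible approach, not necessarily how the solutions in this prac do it.

```python
import collections

import numpy as np

RandomPolicySearchParams = collections.namedtuple(
    "RandomPolicySearchParams", ["current", "best"]
)

example_params = RandomPolicySearchParams(current=np.ones(4), best=np.zeros(4))

# `example_params.best = ...` would raise AttributeError because the tuple is immutable.
# `_replace` returns a *new* tuple with the chosen fields swapped out.
promoted = example_params._replace(best=example_params.current)

print(promoted.best)        # the old "current" parameters
print(example_params.best)  # the original tuple is unchanged
```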

Next we will implement the following:
  - **RPS choose action function** - define how we choose actions given a set of parameters.
  - **RPS memory module** - define what experiences to store from the environment interactions.
  - **RPS learn function** - define how we update and improve our policy parameters.

### RPS choose action function
Let's implement a function called `random_policy_search_choose_action` which we can pass to the training loop. The function needs to take several arguments in order to interface nicely with our generalised training loop, but you will only need to use three of them: `params`, `obs` and `evaluation`.

- `params` is an instance of `RandomPolicySearchParams` with the fields `current` and `best`.
- `obs` is the latest observation from the environment.
- `evaluation` is a boolean value that indicates whether we should use the "current" or "best" parameters. When `evaluation==True` we should use the "best" parameters; otherwise we should use the "current" parameters.

**Exercise 3:** Implement the `random_policy_search_choose_action` function as described above. You should make use of the `linear_policy` method we defined earlier. You will also want to use `jax.lax.select()` to conditionally return the "best" action or the "current" action. 

In [None]:
def random_policy_search_choose_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):
  """Random policy search select action method.

  Args:
    key: a random number (seed). Not used in this function.
    params: the agent's parameters. In this case an instance of `RandomPolicySearchParams`
    actor_state: some extra information about the actor. Not used in this function.
    obs: the latest observation.
    evaluation: a boolean indicating whether to use the best "parameters" or the "current" ones.

  Returns:
    The chosen action and the updated actor_state. In this function the actor_state is not updated.
  """

  # YOUR CODE

  best_action = ...

  current_action = ...

  action = jax.lax.select(
      ... ,
      ... ,
      ... 
  )

  # END YOUR CODE

  return action, actor_state

In [None]:
# @title Check exercise 3 (run me) {display-mode: "form"}

def check_random_policy_search_choose_action(choose_action):
  key = None # not used
  actor_state = None # not used

  # obs
  obs = np.ones(obs_shape)

  evaluation=False
  current_params = np.ones(obs_shape) * -1
  best_params = np.ones(obs_shape)
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  if action != 0:
    return False

  evaluation=True
  current_params = np.ones(obs_shape) * -1
  best_params = np.ones(obs_shape)
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  if action != 1:
    return False

  return True

try:
  if check_random_policy_search_choose_action(random_policy_search_choose_action):
    print("Your function looks correct.")
  else:
    print("Your function looks incorrect.")
except Exception as e:
  print("An error occurred: {}".format(e))

In [None]:
#@title Solution exercise 3 {display-mode: "form"}

def random_policy_search_choose_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):

  best_action = linear_policy(params.best, obs)

  current_action = linear_policy(params.current, obs)

  action = jax.lax.select(
      evaluation,
      best_action,
      current_action,
  )

  return action, actor_state

### RPS agent memory

For the Random Policy Search algorithm we need to keep track of the average episode return achieved with the current parameters. Remember that the "current" parameters become the new "best" parameters if their average episode return is greater than the previous best average episode return.

We will use a general-purpose memory interface that is fairly simple. The memory module should have three methods: `memory.push(<transition>)` adds information about the latest environment transition to the memory; `memory.is_ready()` checks whether the memory holds enough information to do some learning; and `memory.sample()` returns the latest set of memories, which can be passed to the `agent_learn_func`.
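One way to write this interface down explicitly is with a `typing.Protocol`. The sketch below is an assumption for illustration: `AgentMemory` and `RewardSumMemory` are hypothetical names, and the training loop does not check types at all; it only calls `push`, `is_ready` and `sample`.

```python
import collections
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class AgentMemory(Protocol):
    """Hypothetical type describing the memory interface the training loop uses."""

    def push(self, transition: Any) -> None: ...
    def is_ready(self) -> bool: ...
    def sample(self) -> Any: ...


class RewardSumMemory:
    """Hypothetical memory: sums rewards and is ready after `n` transitions."""

    def __init__(self, n: int = 3):
        self.n = n
        self.rewards = []

    def push(self, transition) -> None:
        self.rewards.append(transition.reward)

    def is_ready(self) -> bool:
        return len(self.rewards) >= self.n

    def sample(self) -> float:
        total = sum(self.rewards)
        self.rewards = []  # start afresh after each sample
        return total


Step = collections.namedtuple("Step", ["reward"])  # minimal stand-in for a transition
memory = RewardSumMemory(n=3)
for r in (1.0, 1.0, 0.0):
    memory.push(Step(r))
print(isinstance(memory, AgentMemory), memory.is_ready())  # → True True
```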

#### Average Episode Return Memory
We have built a simple agent memory module for you. It collects the returns of `num_episodes_to_store` episodes (50 by default), and `memory.is_ready()` becomes `True` once it has that many. The `memory.sample()` method then returns their average and clears the buffer, so each average is computed over a fresh set of episodes. Read through our implementation below and see if you can understand it.

In [None]:
class AverageEpisodeReturnBuffer:

    def __init__(self, num_episodes_to_store=50):
        """
        This class implements an agent memory that collects the returns of 
        `num_episodes_to_store` episodes and returns their average when sampled.
        """
        self.num_episodes_to_store = num_episodes_to_store
        self.episode_return_buffer = []
        self.current_episode_return = 0

    def push(self, transition):
        self.current_episode_return += transition.reward

        if transition.done: # If the episode is done
            # Add episode return to buffer
            self.episode_return_buffer.append(self.current_episode_return)

            # Reset episode return
            self.current_episode_return = 0


    def is_ready(self):
        return len(self.episode_return_buffer) == self.num_episodes_to_store

    def sample(self):
        average_episode_return = np.mean(self.episode_return_buffer)

        # Clear episode return buffer
        self.episode_return_buffer = []

        return average_episode_return

### RPS learn function
Finally, we need to implement the `random_policy_search_learn` function for our Random Policy Search algorithm. The learn function is quite simple: all we need to do is check whether the current parameters are better than the best parameters. If they are, set the best parameters to be the current parameters. Then randomly generate a new set of current parameters to try.

**Exercise 4:** Write a function that randomly generates new weights using JAX. The weights should be sampled uniformly from the interval `[-2, 2)`.

**Useful functions:** 
*   `jax.random.uniform` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.uniform.html#jax.random.uniform))



In [None]:
def get_new_random_weights(random_key, old_weights, minval=-2.0, maxval=2.0):
    new_weights_shape = old_weights.shape # you will need to use these values
    new_weights_dtype = old_weights.dtype # you will need to use these values

    # YOUR CODE

    new_params = ...

    # END YOUR CODE

    return new_params

In [None]:
# @title Check exercise 4 (run me) {display-mode: "form"}

def check_get_new_random_weights(get_new_random_weights):
  old_weights = np.ones(obs_shape, "float32")
  random_key = jax.random.PRNGKey(42)

  # Case 1
  new_weights = get_new_random_weights(random_key, old_weights, minval=-2.0, maxval=2.0)
  
  if jnp.array_equal(new_weights, jnp.array([ 0.29657745,1.4265499, -1.7621555, -1.7505779 ])):
    print("Function is correct!")
  else:
    print("Something is wrong.")

try:
  check_get_new_random_weights(get_new_random_weights)
except Exception as e:
    print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 4 {display-mode: "form"}

def get_new_random_weights(random_key, old_weights, minval=-2.0, maxval=2.0):
    new_weights_shape = old_weights.shape
    new_weights_dtype = old_weights.dtype
    # Sample new weights uniformly from [minval, maxval)
    new_weights = jax.random.uniform(random_key, new_weights_shape, new_weights_dtype,
                                     minval=minval, maxval=maxval)
    return new_weights

Our learn function receives a memory, in the form of the average episode return, from the `AverageEpisodeReturnBuffer` we implemented earlier. We can use this to compare the current parameters to the best parameters. But we also need to keep track of the best average episode return for the learn function. For that, we can use the `learn_state` argument, which is passed to the `agent_learn_func` in our training loop. As with the `RandomPolicySearchParams`, we will use a NamedTuple to store the `best_average_episode_return` in the `learn_state`.

In [None]:
# A NamedTuple to store the best average episode return so far
RandomPolicyLearnState = collections.namedtuple(
  "RandomPolicyLearnState", 
  ["best_average_episode_return"]
)

# Test
initial_learn_state = RandomPolicyLearnState(best_average_episode_return=-float("inf"))
print("Initial best average episode return:", initial_learn_state.best_average_episode_return)

Now we have everything we need to implement the `random_policy_search_learn` function.

**Exercise 5:** Implement the `random_policy_search_learn` function. The function should check if the "current" parameters are better than the "best" parameters by comparing the `current_average_episode_return` to the `best_average_episode_return`. The function should also update the `learn_state`.

In [None]:
def random_policy_search_learn(key, params, learn_state, memory):
    best_params = params.best 
    current_params = params.current

    current_average_episode_return = memory # the memory contains the average episode return
    best_average_episode_return = learn_state.best_average_episode_return


    # YOUR CODE

    best_params = jax.lax.select(
        ... ,
        ... ,
        ...
    )
        
    best_average_episode_return = jax.lax.select(
        ... ,
        ... ,
        ...
    )
    
    # END YOUR CODE

    # Generate new random parameters
    new_params = get_new_random_weights(key, current_params)

    # Bundle weights in RandomPolicySearchParams NamedTuple
    params = RandomPolicySearchParams(current=new_params, best=best_params)

    return params, RandomPolicyLearnState(best_average_episode_return)

In [None]:
#@title Check exercise 5 {display-mode: "form"}

params = RandomPolicySearchParams(np.ones(obs_shape, "float32"), np.ones(obs_shape, "float32") * -1)
learn_state = RandomPolicyLearnState(10)
memory = 11
key = jax.random.PRNGKey(42)

try:
  new_params, new_learn_state = random_policy_search_learn(key, params, learn_state, memory)

  if not jnp.array_equal(new_params.current, jnp.array([ 0.29657745,  1.4265499 , -1.7621555 , -1.7505779 ])):
    print("Your function is incorrect.")

  elif not jnp.array_equal(new_params.best, jnp.array([1., 1., 1., 1.])):
    print("Your function is incorrect.")

  elif new_learn_state.best_average_episode_return != 11:
    print("Your function is incorrect.")

  else:
    print("Your function looks correct.")
except Exception as e:
    print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 5 {display-mode: "form"}

def random_policy_search_learn(key, params, learn_state, memory):
    best_params = params.best 
    current_params = params.current

    current_average_episode_return = memory # the memory contains the average episode return
    best_average_episode_return = learn_state.best_average_episode_return


    # YOUR CODE

    best_params = jax.lax.select(
        current_average_episode_return > best_average_episode_return,
        current_params,
        best_params,
    )
        
    best_average_episode_return = jax.lax.select(
        current_average_episode_return > best_average_episode_return,
        current_average_episode_return,
        best_average_episode_return
    )
    
    # END YOUR CODE

    # Generate new random parameters
    new_params = get_new_random_weights(key, current_params)

    # Bundle weights in RandomPolicySearchParams NamedTuple
    params = RandomPolicySearchParams(current=new_params, best=best_params)

    # Update learn_state
    learn_state = RandomPolicyLearnState(best_average_episode_return)

    return params, learn_state

### RPS training loop
We can now put everything together by passing the `memory` module, the `learn` function and `choose_action` function to the training loop. To help speed up our algorithm we will use `jax.jit` on the `learn` function and the `choose_action` function.

In [None]:
# JIT the learn and choose action functions
random_policy_search_learn_jit = jax.jit(random_policy_search_learn)
random_policy_search_choose_action_jit = jax.jit(random_policy_search_choose_action)

# Initialise the parameters
initial_weights = np.ones(obs_shape, "float32")
initial_params = RandomPolicySearchParams(initial_weights, initial_weights)

# Initialise the learn state
initial_learn_state = RandomPolicyLearnState(best_average_episode_return=-float("inf"))

# Initialise memory
memory = AverageEpisodeReturnBuffer(num_episodes_to_store=50)

# Run the training loop
print("Starting training. This may take up to 5 minutes to complete.")
chex.clear_trace_counter()
episode_return, evaluator_episode_returns = run_training_loop(
                                        env_name,
                                        initial_params, 
                                        random_policy_search_choose_action_jit, 
                                        None, # no actor state
                                        random_policy_search_learn_jit, 
                                        initial_learn_state, 
                                        memory, 
                                        num_episodes=1001,
                                        video_subdir="rps"
                                    )

# Plot graph of evaluator episode returns
plt.plot(np.linspace(0, 1000, len(evaluator_episode_returns)), evaluator_episode_returns)
plt.title("Random Policy Search")
plt.xlabel("Episodes")
plt.ylabel("Episode Return")
plt.show()

Hopefully, you found a set of optimal parameters on CartPole (episode return reaches `200`). In the cell below you can watch some videos of the agent doing the task.

In [None]:
#@title Visualise Policy {display-mode: "form"}
#@markdown Choose an episode number that is a multiple of 100 and less than or equal to 1000, and **run** this cell.
episode_number = 0 #@param {type:"number"}

assert (episode_number % 100) == 0, "Episode number must be a multiple of 100 since we only record every 100th episode."
assert episode_number < 1001, "Episode number must be less than or equal to 1000"

eval_episode_number = int(episode_number / 100 * 8)
video_path = f"./video/rps/eval/rl-video-episode-{eval_episode_number}.mp4"

mp4 = open(video_path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

So, Random Policy Search did pretty well on this task. However, there is very little (if any) real *learning* going on here. Next, let's implement a simple RL algorithm that uses its experience to guide the search for an optimal policy, rather than searching at random.

## Section 3: Policy Gradients (PG)
As discussed, the goal in RL is to find a policy that maximises the expected cumulative reward (return) the agent receives from the environment. We can write the expected return of a policy as:

$J(\pi_\theta)=\mathrm{E}_{\tau\sim\pi_\theta}\ [R(\tau)]$,

where $\pi_\theta$ is a policy parametrised by $\theta$, $\mathrm{E}$ means *expectation*, $\tau$ is shorthand for "*episode*", $\tau\sim\pi_\theta$ is shorthand for "*episodes sampled using the policy* $\pi_\theta$", and $R(\tau)$ is the return of episode $\tau$.

Then, the goal in RL is to find the parameters $\theta$ that maximise the function $J(\pi_\theta)$. One way to find these parameters is to perform gradient ascent on $J(\pi_\theta)$ with respect to the parameters $\theta$: 

$\theta_{k+1}=\theta_k + \alpha \nabla_\theta J(\pi_\theta)|_{\theta_{k}}$,

where $\nabla_\theta J(\pi_\theta)|_{\theta_{k}}$ is the gradient of the expected return with respect to the policy parameters, evaluated at $\theta_k$, and $\alpha$ is the step size. This quantity, $\nabla_\theta J(\pi_\theta)$, is called the **policy gradient** and is very important in RL. If we can compute the policy gradient, then we have a means of optimising our policy directly.
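
The update rule above is ordinary gradient ascent. As a minimal sketch, assume a toy objective $J(\theta) = -(\theta - 3)^2$ whose gradient, $-2(\theta - 3)$, we can write by hand. In RL the true $J$ and its gradient are unknown and must be estimated from episodes, which is what the rest of this section is about.

```python
def grad_J(theta):
    # Gradient of the toy objective J(theta) = -(theta - 3)^2
    return -2.0 * (theta - 3.0)


theta = 0.0   # initial parameter theta_0
alpha = 0.1   # step size
for k in range(100):
    # theta_{k+1} = theta_k + alpha * grad J evaluated at theta_k
    theta = theta + alpha * grad_J(theta)

print(round(theta, 4))  # 3.0, the maximiser of J
```

Each step shrinks the distance to the maximiser by a factor of $1 - 2\alpha = 0.8$, so 100 steps bring $\theta$ within about $10^{-9}$ of 3.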

As it turns out, the policy gradient can be computed; the mathematical derivation can be found [here](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html). For this tutorial we will omit the derivation and just give you the result:


$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) R(\tau)]$

Informally, the policy gradient is the expected value, over episodes sampled with $\pi_\theta$, of the sum over timesteps of the gradient of the log-probability of each chosen action, multiplied by the return of the episode in which that action was taken.
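
To make the formula concrete, consider a hypothetical one-step "episode" (a two-armed bandit), so the sum over $t$ has a single term. For a softmax policy over logits $\theta$, the gradient of $\log \pi_\theta(a)$ with respect to the logits is $\mathbf{1}_a - \pi_\theta$, the one-hot vector of $a$ minus the probability vector. The pure-Python sketch below averages $\nabla_\theta \log \pi_\theta(a)\,R$ over sampled actions and compares the result with the exact gradient of $J(\theta) = \pi_\theta(0)$, which is the expected return when only action 0 pays a reward of 1. The logits, rewards and sample count are arbitrary choices.

```python
import math
import random


def softmax(logits):
    max_logit = max(logits)
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


theta = [0.5, -0.5]     # policy logits (toy values)
reward = [1.0, 0.0]     # action 0 pays 1, action 1 pays 0
probs = softmax(theta)

# Exact gradient of J(theta) = probs[0] with respect to the logits
exact = [probs[0] * (1 - probs[0]), -probs[0] * probs[1]]

# Monte Carlo policy-gradient estimate: mean of grad log pi(a) * R
rng = random.Random(0)
num_samples = 50_000
estimate = [0.0, 0.0]
for _ in range(num_samples):
    a = 0 if rng.random() < probs[0] else 1           # sample a ~ pi_theta
    grad_log_pi = [(1.0 if i == a else 0.0) - probs[i] for i in range(2)]
    for i in range(2):
        estimate[i] += grad_log_pi[i] * reward[a] / num_samples

print("exact:   ", exact)
print("estimate:", estimate)
```

Samples of the unrewarded action contribute nothing, while samples of the rewarded action push its logit up and the other logit down, so on average the estimate matches the true gradient.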


### REINFORCE
REINFORCE is a simple RL algorithm that uses the policy gradient to find the optimal policy by increasing the probability of choosing actions (reinforcing actions) that tend to lead to high return episodes.

**Exercise 6:** Implement a function that takes the probability of an action and the return of the episode the action was taken in and computes the log of the probability, multiplied by the return. Make sure you use JAX.

**Useful functions:**
*   `jax.numpy.log`([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html))

In [None]:
def compute_weighted_log_prob(action_prob, episode_return):

    # YOUR CODE

    log_prob = ...

    weighted_log_prob = ... 

    # END YOUR CODE

    return weighted_log_prob

In [None]:
#@title Check exercise 6 {display-mode: "form"}

try:
  action_prob = 0.8
  episode_return = 100
  result = compute_weighted_log_prob(action_prob, episode_return)
  if result != -22.314354:
    print("Your implementation looks incorrect.")
  else:
    print("Looks correct.")
except Exception as e:
    print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 6 {display-mode: "form"}

def compute_weighted_log_prob(action_prob, episode_return):

    # YOUR CODE

    log_prob = jax.numpy.log(action_prob)

    weighted_log_prob = log_prob * episode_return

    # END YOUR CODE

    return weighted_log_prob


### Rewards-to-go
Performing gradient ascent on the log-probability of each action, weighted by the return of the whole episode, tends to push up the probability of actions that appeared in high-return episodes, regardless of *where* in the episode each action was taken. This does not really make sense: an action near the end of an episode may be reinforced simply because lots of reward was collected earlier in the episode, *before* the action was taken. RL agents should only reinforce actions on the basis of their *consequences*. Rewards obtained before taking an action have no bearing on how good that action was; only the rewards that come after it do. The cumulative reward received from the timestep an action is taken until the end of the episode is called the **rewards-to-go** and can be computed as:

$\hat{R}_i=\sum_{t=i}^Tr_t$

Compare this to the episode return:

$R(\tau)=\sum_{t=0}^Tr_t$

Substituting the rewards-to-go for the episode return reduces the variance of the policy gradient estimate without changing its expected value, which tends to make learning more reliable. The policy gradient with rewards-to-go is given by:

$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) \hat{R}_t]$

**Exercise 7:** Implement a function that takes a list of all the rewards obtained in an episode and computes the rewards-to-go. Don't worry about using JAX in this function. You can use regular Python operations like `for-loops`.

In [None]:
def compute_rewards_to_go(rewards):
    """
    This function should take a list of rewards as input and 
    compute the rewards-to-go for each timestep.

    EXAMPLE: compute_rewards_to_go([1,2,3,4]) = [10, 9, 7, 4]
    
    Arguments:
        rewards[t] is the reward at time step t.

    Returns:
        rewards_to_go[t] should be the reward-to-go at timestep t.
    """

    rewards_to_go = []

    # YOUR CODE


    # END YOUR CODE

    return rewards_to_go

In [None]:
#@title Check exercise 7 {display-mode: "form"}

try: 
  result = compute_rewards_to_go([1,2,3,4])

  if result != [10, 9, 7, 4]:
    print("There is a problem with your implementation.")
  else:
    print("Looks correct.")
except Exception as e:
    print("An Error Occurred: {}".format(e))


In [None]:
#@title Solution exercise 7 {display-mode: "form"}

def compute_rewards_to_go(rewards):
    rewards_to_go = []
    for i in range(len(rewards)):
        r2g = 0
        for j in range(i, len(rewards)):
            r2g += rewards[j]
        rewards_to_go.append(r2g)
    return rewards_to_go
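
The solution above uses two nested loops, so it costs $O(T^2)$ for an episode of length $T$. Because $\hat{R}_i = r_i + \hat{R}_{i+1}$ (with nothing left after the last step), the same values can be computed in a single backwards pass. Here is a minimal sketch; the name `compute_rewards_to_go_linear` is ours and not part of the notebook.

```python
def compute_rewards_to_go_linear(rewards):
    """Rewards-to-go in O(T) using R_hat[i] = r[i] + R_hat[i + 1]."""
    rewards_to_go = [0] * len(rewards)
    running_total = 0
    # Walk backwards so each entry reuses the sum of all later rewards.
    for i in reversed(range(len(rewards))):
        running_total += rewards[i]
        rewards_to_go[i] = running_total
    return rewards_to_go


print(compute_rewards_to_go_linear([1, 2, 3, 4]))  # [10, 9, 7, 4]
```

It returns a plain list, so it should also pass the Exercise 7 check.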

### REINFORCE memory
Next we need a new agent memory that stores the rewards-to-go $\hat{R}_t$ along with the observation $o_t$ and action $a_t$ at every timestep. We have implemented such a memory module for you below. It keeps the most recent 512 timesteps, and `memory.sample()` returns a random batch of 256 of them. You are welcome to read through the code, but it is not required, so the code is hidden by default.

In [None]:
# @title Memory implementation (run me) {display-mode: "form"}

# NamedTuple to store memory
EpisodeRewardsToGoMemory = collections.namedtuple("EpisodeRewardsToGoMemory", ["obs", "action", "reward_to_go"])

class EpisodeRewardsToGoBuffer:

    def __init__(self, num_transitions_to_store=512, batch_size=256):
        self.batch_size = batch_size
        self.memory_buffer = collections.deque(maxlen=num_transitions_to_store)
        self.current_episode_transition_buffer = []

    def push(self, transition):
        self.current_episode_transition_buffer.append(transition)

        if transition.done:

            episode_rewards = []
            for t in self.current_episode_transition_buffer:
                episode_rewards.append(t.reward)

            r2g = compute_rewards_to_go(episode_rewards)

            for i, t in enumerate(self.current_episode_transition_buffer):
                memory = EpisodeRewardsToGoMemory(t.obs, t.action, r2g[i])
                self.memory_buffer.append(memory)

            # Reset episode buffer
            self.current_episode_transition_buffer = []


    def is_ready(self):
        return len(self.memory_buffer) >= self.batch_size

    def sample(self):
        random_memory_sample = random.sample(self.memory_buffer, self.batch_size)

        obs_batch, action_batch, reward_to_go_batch = zip(*random_memory_sample)

        return EpisodeRewardsToGoMemory(
            np.stack(obs_batch).astype("float32"), 
            np.asarray(action_batch).astype("int32"), 
            np.asarray(reward_to_go_batch).astype("float32")
        )


# Instantiate Memory
REINFORCE_memory = EpisodeRewardsToGoBuffer(num_transitions_to_store=512, batch_size=256)
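
`sample()` relies on the `zip(*...)` idiom to turn a list of memories (one namedtuple per timestep) into one tuple per field, which can then be stacked into arrays. Below is a minimal pure-Python sketch of that step. It uses made-up two-dimensional observations rather than CartPole's four-dimensional ones and skips the NumPy stacking.

```python
import collections
import random

EpisodeRewardsToGoMemory = collections.namedtuple(
    "EpisodeRewardsToGoMemory", ["obs", "action", "reward_to_go"]
)

# One made-up episode of three timesteps; rewards [1, 1, 1] give rewards-to-go [3, 2, 1].
memories = [
    EpisodeRewardsToGoMemory(obs=(0.0, 0.1), action=0, reward_to_go=3.0),
    EpisodeRewardsToGoMemory(obs=(0.2, 0.3), action=1, reward_to_go=2.0),
    EpisodeRewardsToGoMemory(obs=(0.4, 0.5), action=1, reward_to_go=1.0),
]

rng = random.Random(0)
batch = rng.sample(memories, 2)  # sample without replacement, as the buffer does

# Transpose: [(obs, action, r2g), ...] -> (all obs), (all actions), (all r2g)
obs_batch, action_batch, reward_to_go_batch = zip(*batch)
print(obs_batch)
print(action_batch)
print(reward_to_go_batch)
```

Because `random.sample` draws without replacement, a batch never contains the same timestep twice.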

### Policy neural network
Next, we will use a simple neural network to approximate the policy. Our policy network takes the observation as input, passes it through two hidden layers, and outputs one scalar value for each possible action. So, in CartPole the output layer has size `2`.

[Haiku](https://github.com/deepmind/dm-haiku) is a library for implementing neural networks in JAX. Below we have implemented a simple function that makes the policy network for you.


In [None]:
def make_policy_network(num_actions: int, layers=[20, 20]) -> hk.Transformed:
  """Factory for a simple MLP network for the policy."""

  def policy_network(obs):
    network = hk.Sequential(
        [
            hk.Flatten(),
            hk.nets.MLP(layers + [num_actions])
        ]
    )
    return network(obs)

  return hk.without_apply_rng(hk.transform(policy_network))

Haiku networks have two important methods you need to know about. The first is `network.init(<random_key>, <input>)`, which returns a set of random initial parameters. The second is `network.apply(<params>, <input>)`, which passes an input through the network using the provided parameters.
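
The split between `init` and `apply` is what lets Haiku networks work smoothly with JAX: the parameters live outside the network, and `apply` is a pure function of `(params, input)`. As a rough pure-Python analogy, here is a single linear layer written in that style. This is not Haiku's implementation; the functions `init` and `apply` are hypothetical and only mirror the names of the Haiku methods.

```python
import random


def init(seed, obs, num_outputs=2):
    """Return fresh random parameters sized to match `obs`."""
    rng = random.Random(seed)
    # One weight per (input, output) pair, plus one bias per output.
    w = [[rng.gauss(0.0, 0.1) for _ in range(num_outputs)] for _ in obs]
    b = [0.0] * num_outputs
    return {"linear": {"w": w, "b": b}}


def apply(params, obs):
    """Pure function: the output depends only on `params` and `obs`."""
    w, b = params["linear"]["w"], params["linear"]["b"]
    return [b[j] + sum(x * w[i][j] for i, x in enumerate(obs)) for j in range(len(b))]


dummy_obs = [1.0, 1.0, 1.0, 1.0]
params = init(42, dummy_obs)
print(params.keys())             # dict_keys(['linear'])
print(apply(params, dummy_obs))  # two outputs, one per action
```

Because `apply` keeps no hidden state, an optimiser can compute new parameters and simply pass them in on the next call. The learn functions in this notebook update the networks in the same way.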

In [None]:
# Example
POLICY_NETWORK = make_policy_network(num_actions=num_actions, layers=[20,20])
random_key = jax.random.PRNGKey(42) # random key
dummy_obs = np.ones(obs_shape, "float32")

# Initialise parameters
REINFORCE_params = POLICY_NETWORK.init(random_key, dummy_obs)
print("Initial params:", REINFORCE_params.keys())

# Pass input through the network
output = POLICY_NETWORK.apply(REINFORCE_params, dummy_obs)
print("Policy network output:", output)


The outputs of our policy network are [logits](https://qr.ae/pv4YTe). To convert them into a probability distribution over actions, we pass the logits through the [softmax](https://en.wikipedia.org/wiki/Softmax_function) function.
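
The softmax function turns logits $z$ into probabilities $p_i = e^{z_i} / \sum_j e^{z_j}$. Here is a small pure-Python sketch. Subtracting the largest logit first leaves the result unchanged but prevents `math.exp` from overflowing on large logits. Library implementations such as `jax.nn.softmax`, used later, are expected to handle this for you.

```python
import math


def softmax(logits):
    # Shift by the max logit for numerical stability; the ratios are unchanged.
    max_logit = max(logits)
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(softmax([1.0, 2.0]))        # approximately [0.269, 0.731]
print(softmax([1000.0, 1001.0]))  # same probabilities, no overflow
```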

### REINFORCE choose action function

**Exercise 8:** Complete the function below, which takes a vector of logits and randomly samples an action from the categorical distribution given by the logits.

**Useful functions:**
*   `jax.random.categorical` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.categorical.html))

In [None]:
def sample_action(random_key, logits):
    
  # YOUR CODE HERE
  action = ...

  # END YOUR code

  return action

In [None]:
#@title Check exercise 8 {display-mode: "form"}

try:
  random_key = jax.random.PRNGKey(42) # random key
  action = sample_action(random_key, np.array([1,2], "float32"))
  if action != 1:
    print("Your function is incorrect.")
  else:
    print("Seems correct.")
except Exception as e:
    print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 8 {display-mode: "form"}

def sample_action(random_key, logits):
    
    # YOUR CODE HERE

    action = jax.random.categorical(random_key, logits)

    # END YOUR code

    return action
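
Sampling from "a categorical distribution given by logits" means drawing action $i$ with probability $\mathrm{softmax}(z)_i$. One standard way to do this directly from logits, without normalising, is the Gumbel-max trick. You add independent Gumbel noise $-\log(-\log u)$, with $u \sim \mathrm{Uniform}(0,1)$, to each logit and take the argmax. The pure-Python sketch below checks empirically that the resulting action frequencies match the softmax probabilities. It illustrates the idea only and is not a description of how `jax.random.categorical` is implemented.

```python
import math
import random


def gumbel_max_sample(rng, logits):
    # argmax_i (z_i + g_i) with g_i ~ Gumbel(0, 1) is distributed as softmax(z)
    noisy = []
    for z in logits:
        u = rng.random()
        while u == 0.0:  # guard against log(0); random() can return exactly 0.0
            u = rng.random()
        noisy.append(z - math.log(-math.log(u)))
    return max(range(len(logits)), key=lambda i: noisy[i])


rng = random.Random(0)
logits = [1.0, 2.0]
num_samples = 100_000
counts = [0, 0]
for _ in range(num_samples):
    counts[gumbel_max_sample(rng, logits)] += 1

print([c / num_samples for c in counts])  # close to softmax([1, 2]), about [0.269, 0.731]
```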


Now we can implement the `REINFORCE_choose_action` function. We pass the observation through the policy network to compute the logits, and then pass the logits to the `sample_action` function to choose an action.

In [None]:
def REINFORCE_choose_action(key, params, actor_state, obs, evaluation=False):
  obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim before passing through network

  # Pass obs through policy network to compute logits
  logits = POLICY_NETWORK.apply(params, obs)
  logits = logits[0] # remove batch dim

  # Randomly sample action
  sampled_action = sample_action(key, logits)
  
  return sampled_action, actor_state

Now that we have implemented the `REINFORCE_choose_action` function, all we have left to do is make a `REINFORCE_learn` function. The learn function should use the `compute_weighted_log_prob` function we made earlier to compute the policy gradient loss, and then apply the gradient updates to our neural network.

### Policy gradient loss

**Exercise 9:** Complete the `policy_gradient_loss` function below. The function should compute the action probabilities by passing the `logits` through the softmax function. Then extract the probability of the given `action` (using array indexing) and compute the weighted log probability using the `compute_weighted_log_prob` function we made earlier.

**Useful methods:**
*   `jax.nn.softmax` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html))

In [None]:
def policy_gradient_loss(action, logits, reward_to_go):

  # YOUR CODE

  all_action_probs = ... # convert logits into probs

  action_prob = ... # using array indexing to get prob of action

  weighted_log_prob = ...

  # END YOUR CODE
  
  loss = - weighted_log_prob # negative because we want gradient `ascent`
  
  return loss

In [None]:
#@title Check exercise 9 {display-mode: "form"}

try:
  result = policy_gradient_loss(1, np.array([1,2], "float32"), 10)
  if result != 3.1326165:
    print("Your implementation looks wrong.")
  else:
    print("Looks correct.")
except Exception as e:
  print("An Error Occurred: {}".format(e))


In [None]:
#@title Solution exercise 9 {display-mode: "form"}

def policy_gradient_loss(action, logits, reward_to_go):

  # YOUR CODE

  all_action_probs = jax.nn.softmax(logits) # convert logits into probs

  action_prob = all_action_probs[action]

  weighted_log_prob = compute_weighted_log_prob(action_prob, reward_to_go)

  # END YOUR CODE
  
  loss = - weighted_log_prob # negative because we want gradient `ascent`
  
  return loss
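
It can help to know what gradient `jax.grad` should produce for this loss. For a softmax policy, the derivative of $-\hat{R}\log\mathrm{softmax}(z)_a$ with respect to the logits $z$ is $-\hat{R}\,(\mathbf{1}_a - \mathrm{softmax}(z))$. Gradient descent on the loss therefore raises the logit of the chosen action in proportion to a positive reward-to-go and lowers the others. The pure-Python sketch below checks this closed form against a central finite difference, using the same inputs as the Exercise 9 check. The helper names are hypothetical.

```python
import math


def softmax(logits):
    max_logit = max(logits)
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pg_loss(action, logits, reward_to_go):
    # Same quantity as policy_gradient_loss: -R_hat * log(pi(action))
    return -reward_to_go * math.log(softmax(logits)[action])


def pg_loss_grad(action, logits, reward_to_go):
    # Closed form: -R_hat * (one_hot(action) - softmax(logits))
    probs = softmax(logits)
    return [-reward_to_go * ((1.0 if i == action else 0.0) - p)
            for i, p in enumerate(probs)]


def finite_difference_grad(action, logits, reward_to_go, eps=1e-6):
    grad = []
    for i in range(len(logits)):
        plus, minus = list(logits), list(logits)
        plus[i] += eps
        minus[i] -= eps
        grad.append((pg_loss(action, plus, reward_to_go)
                     - pg_loss(action, minus, reward_to_go)) / (2 * eps))
    return grad


action, logits, reward_to_go = 1, [1.0, 2.0], 10.0
loss = pg_loss(action, logits, reward_to_go)
analytic = pg_loss_grad(action, logits, reward_to_go)
numeric = finite_difference_grad(action, logits, reward_to_go)
print(loss)      # about 3.1326, the value used in the Exercise 9 check
print(analytic)  # about [2.689, -2.689]
print(numeric)   # agrees with the analytic gradient to several decimals
```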

When we do a policy gradient update step we are going to want to do it using a batch of experience, rather than just a single experience like above. We can use JAX's [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) function to easily make our `policy_gradient_loss` function work on a batch of experience.

In [None]:
def batched_policy_gradient_loss(params, obs_batch, action_batch, reward_to_go_batch):
    # Get logits by passing observation through network
    logits_batch = POLICY_NETWORK.apply(params, obs_batch)

    policy_gradient_loss_batch = jax.vmap(policy_gradient_loss)(action_batch, logits_batch, reward_to_go_batch) # map over the batch dimension

    # Compute mean loss over batch
    mean_policy_gradient_loss = jnp.mean(policy_gradient_loss_batch)

    return mean_policy_gradient_loss

# TEST
obs_batch = np.ones((3, *obs_shape), "float32")
actions_batch = np.array([1,0,0])
rew2go_batch = np.array([2.3, 4.3, 2.1])

loss = batched_policy_gradient_loss(REINFORCE_params, obs_batch, actions_batch, rew2go_batch)

print("Policy gradient loss on batch:", loss)

### Network Optimiser

To apply policy gradient updates to our neural network we will use a JAX library called [Optax](https://github.com/deepmind/optax). Optax has an implementation of the [Adam optimizer](https://www.geeksforgeeks.org/intuition-of-adam-optimizer/) which we can use.

In [None]:
REINFORCE_OPTIMIZER = optax.adam(1e-3)

# Initialise the optimiser
REINFORCE_optim_state = REINFORCE_OPTIMIZER.init(REINFORCE_params)

Now we have everything we need to make the `REINFORCE_learn` function. We will store the state of the optimiser in the `learn_state`, and compute the gradient of the policy gradient loss using `jax.grad` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html)).

In [None]:
# A NamedTuple to store the state of the optimiser
REINFORCELearnState = collections.namedtuple("REINFORCELearnState", ["optim_state"])


def REINFORCE_learn(key, params, learner_state, memory):
    
  # Get the policy gradient by using `jax.grad()` on `batched_policy_gradient_loss`
  grad_loss = jax.grad(batched_policy_gradient_loss)(params, memory.obs, memory.action, memory.reward_to_go)

  # Get param updates using gradient and optimizer
  updates, new_optim_state = REINFORCE_OPTIMIZER.update(grad_loss, learner_state.optim_state)

  # Apply updates to params
  params = optax.apply_updates(params, updates)

  return params, REINFORCELearnState(new_optim_state) # update learner state

### REINFORCE training loop
Now we can train our REINFORCE agent by putting everything together using the training loop. 

In [None]:
# JIT the choose_action and learn functions for more speed
REINFORCE_learn_jit = jax.jit(REINFORCE_learn)
REINFORCE_choose_action_jit = jax.jit(REINFORCE_choose_action)

# Initial learn state
REINFORCE_learn_state = REINFORCELearnState(REINFORCE_optim_state)

# Run training loop
print("Starting training. This may take up to 10 minutes to complete.")
episode_returns, evaluator_returns = run_training_loop(
                                        env_name,
                                        REINFORCE_params,
                                        REINFORCE_choose_action_jit, 
                                        None, # actor state not used
                                        REINFORCE_learn_jit, 
                                        REINFORCE_learn_state, 
                                        REINFORCE_memory,
                                        num_episodes=10_001,
                                        learn_steps_per_episode=2,
                                        video_subdir="reinforce"
                                      )

# Plot the episode returns
plt.plot(episode_returns)
plt.xlabel("Episode")
plt.ylabel("Episode Return")
plt.title("REINFORCE")
plt.show()


In [None]:
#@title Visualise Policy {display-mode: "form"}
#@markdown Choose an episode number that is a multiple of 100 and less than or equal to 1000, and **run this cell**.

episode_number = 100 #@param {type:"number"}

assert (episode_number % 100) == 0, "Episode number must be a multiple of 100 since we only record every 100th episode."
assert episode_number < 1001, "Episode number must be less than or equal to 1000"

eval_episode_number = int(episode_number / 100 * 8)
video_path = f"./video/reinforce/eval/rl-video-episode-{eval_episode_number}.mp4"

mp4 = open(video_path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

## Section 4: Q-Learning
Another common approach to finding an optimal policy with RL is Q-learning.

### State-Action Value function
In Q-learning the agent learns a function that approximates the **value** of state-action pairs. By *value* we mean the return you expect to receive if you start in a particular state $s_t$, take a particular action $a_t$, and then act according to a particular policy $\pi$ forever after. The state-action value function of policy $\pi$ is given by

$Q_\pi(s,a)=\mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$.

We say that the value function $Q_\pi(s,a)$ is the **optimal** value function if the policy $\pi$ is an optimal policy. We denote the optimal value function as follows:

$Q_\ast(s,a)=\max \limits_\pi \  \mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$

There is an important relationship between the optimal action $a_\ast$ in a state $s$ and the optimal state-action value function $Q_\ast$. Namely, the optimal action $a_\ast$ in state $s$ is equal to the action that maximises the optimal state-action value function. This relationship naturally induces an optimal policy:

$\pi_\ast(s)=\arg \max \limits_a\ Q_\ast(s, a)$

### Greedy action selection

**Exercise 10:** Let's implement a function that, given a vector of Q-values, returns the action with the largest Q-value (i.e. the greedy action).

**Useful methods:**
*   `jax.numpy.argmax` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html))

In [None]:
# Implement a function that takes Q-values as input and returns the greedy action
def select_greedy_action(q_values):

  # YOUR CODE
  action = ...
  # END YOUR CODE

  return action

In [None]:
# @title Check exercise 10 (run me) {display-mode: "form"}

try:
  q_values = jnp.array([1,1,3,4])
  action = select_greedy_action(q_values)

  if action != 3:
    print("Incorrect answer, your greedy action selector looks wrong")
  else:
    print("Looks good.")
except Exception as e:
  print("An Error Occurred: {}".format(e))


In [None]:
#@title Solution exercise 10 {display-mode: "form"}

def select_greedy_action(q_values):
  
  # YOUR CODE
  action = jnp.argmax(q_values)
  # END YOUR CODE

  return action

### Q-Network
Unlike the policy gradient approach from the previous section, Q-learning and other value-based RL methods don't parametrise the policy directly; instead, we parametrise the Q-function using a neural network $Q_\theta$. We obtain a policy from the Q-network by always choosing the action with the *greatest* value:

$\hat{\pi}_\theta(s)=\arg \max \limits_a\ Q_{\theta}(s, a)$

As before, we will use Haiku to make a neural network that approximates this Q-function. The network takes an observation as input and outputs a Q-value for each of the available actions. So, in the case of CartPole, the output of the network has size $2$.

In [None]:
def build_network(num_actions: int, layers=[20, 20]) -> hk.Transformed:
  """Factory for a simple MLP network for approximating Q-values."""

  def q_network(obs):
    network = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP(layers + [num_actions])])
    return network(obs)

  return hk.without_apply_rng(hk.transform(q_network))

Let's initialise our Q-network and get the initial parameters.

In [None]:
# Initialise Q-network
Q_NETWORK = build_network(num_actions=num_actions, layers=[20, 20]) # two actions

dummy_obs = jnp.zeros((1,*obs_shape), jnp.float32) # a dummy observation like the one in CartPole

random_key = jax.random.PRNGKey(42) # random key
Q_NETWORK_PARAMS = Q_NETWORK.init(random_key, dummy_obs) # Get initial params

print("Q-Learning params:", Q_NETWORK_PARAMS.keys())

Before we implement the loss function required for training our Q-network, let's first discuss the intuition behind it.

### The Bellman Equations
The value function can be written recursively as:

$Q_{\pi}(s, a) =\underset{s^{\prime} \sim P}{\mathrm{E}}\left[r(s, a)+ \underset{a^{\prime} \sim \pi}{\mathrm{E}}\left[Q_{\pi}\left(s^{\prime}, a^{\prime}\right)\right]\right]$,

where $s' \sim P$ is shorthand for saying that the next state $s'$ is sampled from the environment’s transition function $P(s'\mid s,a)$. Intuitively, this equation says that the value of taking action $a$ in state $s$ is equal to the reward $r$ you expect to get, plus the value you expect to get in the next state $s'$ you land in, given that you will choose your next action $a'$ with the policy $\pi$. The Bellman equation for the optimal value function is:

$Q_{*}(s, a) =\underset{s^{\prime} \sim P}{\mathrm{E}}\left[r(s, a)+\ \underset{a^{\prime}}{\max}\ Q_{*}(s^{\prime}, a^{\prime})\right]$

Notice that instead of choosing the next action $a'$ with the policy $\pi$, we now choose the action with the greatest Q-value.


For a more in-depth discussion of the Bellman Equations, see the [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html) website.
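
When the transition function is known and small, the Bellman optimality equation can be solved by repeatedly applying its right-hand side until nothing changes. This is called value iteration. The pure-Python sketch below uses a hypothetical four-state corridor, not CartPole. It has states 0 to 3, actions 0 (left) and 1 (right), a reward of -1 per step so that shorter paths are worth more, and state 3 is terminal. After solving, it reads off the greedy policy $\arg\max_a Q_\ast(s,a)$.

```python
# Hypothetical deterministic corridor: states 0..3, state 3 is terminal.
GOAL = 3
LEFT, RIGHT = 0, 1


def step(state, action):
    """Deterministic transition: returns (next_state, reward, done)."""
    next_state = max(state - 1, 0) if action == LEFT else state + 1
    return next_state, -1.0, next_state == GOAL


# Q[s][a] for the non-terminal states 0..2, initialised to zero.
Q = [[0.0, 0.0] for _ in range(GOAL)]

for _ in range(100):
    new_Q = [[0.0, 0.0] for _ in range(GOAL)]
    for s in range(GOAL):
        for a in (LEFT, RIGHT):
            s_next, r, done = step(s, a)
            # Bellman optimality backup; no future value after termination.
            new_Q[s][a] = r if done else r + max(Q[s_next])
    if new_Q == Q:  # stop once a full sweep changes nothing
        break
    Q = new_Q

greedy_policy = [max((LEFT, RIGHT), key=lambda a: Q[s][a]) for s in range(GOAL)]
print(Q)              # [[-4.0, -3.0], [-4.0, -2.0], [-3.0, -1.0]]
print(greedy_policy)  # [1, 1, 1], i.e. always move right
```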

### The Bellman Backup
To learn to approximate the optimal Q-value function, we can use the right-hand side of the Bellman equation as an update rule. In other words, given a Q-network $Q_\theta$ with parameters $\theta$, we can iteratively update the parameters such that

$Q_\theta(s,a)\leftarrow r(s, a) + \underset{a'}{\max}\ Q_\theta(s', a')$.

Intuitively, this says that the approximate Q-value of action $a$ in state $s$ should be updated to be closer to the reward received from the environment, $r(s, a)$, plus the value of the best possible action in the next state $s'$. We can perform this optimisation by minimising the difference between the left- and right-hand sides with respect to the parameters $\theta$ using gradient descent. We can measure the difference between the two values using the [squared error](https://en.wikipedia.org/wiki/Mean_squared_error#Loss_function).
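
Consider a lookup table instead of a network. A gradient-descent step on the squared error $(Q(s,a) - y)^2$ with respect to the entry $Q(s,a)$, holding the Bellman target $y$ fixed, is $Q(s,a) \leftarrow Q(s,a) - 2\alpha\,(Q(s,a) - y)$. Absorbing the 2 into the step size gives $Q(s,a) \leftarrow Q(s,a) + \alpha\,(y - Q(s,a))$, which is tabular Q-learning. The pure-Python sketch below reuses the hypothetical corridor setup from the value-iteration example above (states 0 to 3, -1 per step, state 3 terminal). It learns only from sampled transitions generated with uniformly random actions, and should approach the same $Q_\ast$. The step size, episode count and step cap are arbitrary illustrative choices.

```python
import random

# Hypothetical corridor: states 0..3, -1 reward per step, state 3 terminal.
GOAL = 3
LEFT, RIGHT = 0, 1


def step(state, action):
    next_state = max(state - 1, 0) if action == LEFT else state + 1
    return next_state, -1.0, next_state == GOAL


def bellman_target(reward, done, next_q_values):
    # No bootstrapped value after a terminal transition.
    return reward if done else reward + max(next_q_values)


rng = random.Random(0)
alpha = 0.5                              # step size (arbitrary choice)
Q = [[0.0, 0.0] for _ in range(GOAL)]

for _episode in range(2000):
    state = 0
    for _t in range(50):                 # cap on episode length
        action = rng.choice((LEFT, RIGHT))  # uniformly random exploration
        next_state, reward, done = step(state, action)
        next_q = None if done else Q[next_state]
        target = bellman_target(reward, done, next_q)
        # Gradient step on (Q - target)^2 for this table entry, target held fixed
        Q[state][action] += alpha * (target - Q[state][action])
        if done:
            break
        state = next_state

print([[round(q, 3) for q in row] for row in Q])  # should be close to Q* of the corridor
```

With a network, the same idea applies, except that the gradient step is taken on the parameters $\theta$ rather than on a single table entry.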

**Exercise 11:** Implement the squared-error function.

**Useful functions**
* `jax.numpy.square` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.square.html))

In [None]:
def compute_squared_error(pred, target):
  # YOUR CODE
  squared_error = ...
  # END YOUR CODE

  return squared_error

In [None]:
#@title Check exercise 11 {display-mode: "form"}

try:
  result = compute_squared_error(1, 4)

  if result != 9:
    print("Your implementation looks wrong.")
  else:
    print("Looks good.")
except Exception as e:
  print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 11 {display-mode: "form"}
def compute_squared_error(pred, target):

  # YOUR CODE
  squared_error = jax.numpy.square(pred - target)
  # END YOUR CODE
  return squared_error

**Exercise 12:** Implement a function that computes the **Bellman target** (right-hand side of the Bellman equation). If the episode is at the last timestep (i.e. done==1.0), then the Bellman target should be equal to the reward, with no extra value at the end.

**Useful functions**
* `jax.numpy.max` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.max.html))

In [None]:
# Bellman target
def compute_bellman_target(reward, done, next_q_values):
  """A function to compute the bellman target.
  
  Args:
      reward: a scalar reward.
      done: a scalar of value either 1.0 or 0.0, indicating if the transition is a terminal one.
      next_q_values: a vector of q_values for the next state. One for each action.
  Returns:
      A scalar equal to the bellman target.
  
  """
  # YOUR CODE
  bellman_target = ...
  # END YOUR CODE

  return bellman_target

In [None]:
#@title Check exercise 12 {display-mode: "form"}

try:
  # not done
  result1 = compute_bellman_target(1, 0.0, np.array([3,2], "float32"))

  # done
  result2 = compute_bellman_target(1, 1.0, np.array([3,2], "float32"))

  if result1 != 4 or result2 != 1:
    print("Your implementation looks wrong.")
  else:
    print("Looks good.")
except Exception as e:
  print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 12 {display-mode: "form"}

# Bellman target
def compute_bellman_target(reward, done, next_q_values):
    """A function to compute the bellman target.
    
    Args:
        reward: a scalar reward.
        done: a scalar of value either 1.0 or 0.0, indicating if the transition is a terminal one.
        next_q_values: a vector of q_values for the next state. One for each action.
    Returns:
        A scalar equal to the bellman target.
    
    """
    # YOUR CODE
    bellman_target = reward + (1.0 - done) * jax.numpy.max(next_q_values)
    # END YOUR CODE

    return bellman_target


We can now combine these two functions to compute the loss for Q-learning. The Q-learning loss is equal to the squared difference between the predicted Q-value of an action and its corresponding Bellman target.

**Exercise 13:** Implement the Q-learning loss.

In [None]:
def q_learning_loss(q_values, action, reward, done, next_q_values):
    """Implementation of the Q-learning loss.
    
    Args:
        q_values: a vector of Q-values, one for each action.
        action: an integer, giving the action that was chosen. q_values[action] is the value of the chosen action.
        reward: a scalar reward.
        done: a scalar that indicates if this is a terminal transition.
        next_q_values: a vector of Q-values in the next state.
    Returns:
        The squared difference between the q_value of the chosen action and the bellman target.
    """
    # YOUR CODE
    chosen_action_q_value = ... # q_value of action, use array indexing
    bellman_target = ...
    squared_error = ...
    # END YOUR CODE
    
    return squared_error

In [None]:
#@title Check exercise 13 {display-mode: "form"}

try:
  result = q_learning_loss(np.array([3,2], "float32"), 1, 2, 0.0, np.array([3,2], "float32"))

  if result != 9.0:
    print("Your implementation looks wrong.")
  else:
    print("Looks good.")
except Exception as e:
  print("An Error Occurred: {}".format(e))

In [None]:
#@title Solution exercise 13 {display-mode: "form"}

def q_learning_loss(q_values, action, reward, done, next_q_values):
    """Implementation of the Q-learning loss.
    
    Args:
        q_values: a vector of Q-values, one for each action.
        action: an integer, giving the action that was chosen. q_values[action] is the value of the chosen action.
        reward: a scalar reward.
        done: a scalar that indicates if this is a terminal transition.
        next_q_values: a vector of Q-values in the next state.
    Returns:
        The squared difference between the q_value of the chosen action and the bellman target.
    """
    # YOUR CODE
    chosen_action_q_value = q_values[action]
    bellman_target = compute_bellman_target(reward, done, next_q_values)
    squared_error = compute_squared_error(chosen_action_q_value, bellman_target)
    # END YOUR CODE
    
    return squared_error

### Target Q-network
Notice that when we compute the Bellman target we are using our Q-network $Q_\theta$ to compute the value of the next state $s_{t+1}$. In other words, we are using our latest approximation of the Q-function to compute the target for our next approximation. Using an approximation to compute the target for the next approximation is called bootstrapping. Unfortunately, bootstrapping naively like this can make training a neural network very unstable. To mitigate this, we can instead use a separate set of parameters $\hat{\theta}$ to compute the values at state $s_{t+1}$. We keep the parameters $\hat{\theta}$ fixed and only periodically set them equal to the latest online parameters $\theta$, every fixed number of training steps *(say 100)*. This keeps the Bellman targets fixed for many training steps at a time, which helps reduce the instability caused by bootstrapping.


We will need to keep track of the latest (online) parameters, as well as the target network's parameters. Let's make a `namedtuple` to store these two values. We will also need to keep track of the number of learner steps we have taken, so that we know when to update the target network. Let's store a `count` of the learner steps in the `learn_state`.

In [None]:
# Store online and target parameters
QLearnParams = collections.namedtuple("Params", ["online", "target"])

# Q-learn-state
QLearnState = collections.namedtuple("LearnerState", ["count", "optim_state"])

We will once again be using Optax to optimize our neural network in JAX. We store the state of the optimizer in the `learn_state` above. Let's now instantiate the optimizer and add the initial Q-network parameters to a `QLearnParams` object.

In [None]:
# Initialise Q-network optimizer
Q_LEARN_OPTIMIZER = optax.adam(3e-4) # learning rate

Q_LEARN_OPTIM_STATE = Q_LEARN_OPTIMIZER.init(Q_NETWORK_PARAMS) # initial optim state

# Create Learn State
Q_LEARNING_LEARN_STATE = QLearnState(0, Q_LEARN_OPTIM_STATE) # count set to zero initially

# Add initial Q-network weights to QLearnParams object
Q_LEARNING_PARAMS = QLearnParams(online=Q_NETWORK_PARAMS, target=Q_NETWORK_PARAMS) # target equal to online

Now we can implement a simple function that updates the target network's parameters to equal the latest online parameters every 100 training steps.

In [None]:
def update_target_params(learn_state, online_weights, target_weights):
  """A function to update target params every 100 training steps"""

  target = jax.lax.cond(
      jax.numpy.mod(learn_state.count, 100) == 0,
      lambda x, y: x,
      lambda x, y: y,
      online_weights, 
      target_weights
  )

  params = QLearnParams(online_weights, target)

  return params

### Q-learning loss
We now have everything we need to implement the `q_learn` function, which takes a batch of transitions and does a step of Q-learning to update the network parameters. But first we use `jax.vmap` to modify the `q_learning_loss` function so that it accepts batches of transitions. In addition, we compute the Q-values by passing the observations through the `Q_NETWORK` with the online parameters, and the next-state Q-values using the target parameters of the `Q_NETWORK`.

In [None]:
def batched_q_learning_loss(online_params, target_params, obs, actions, rewards, next_obs, dones):
    q_values = Q_NETWORK.apply(online_params, obs) # use the online parameters
    next_q_values = Q_NETWORK.apply(target_params, next_obs) # use the target parameters
    squared_error = jax.vmap(q_learning_loss)(q_values, actions, rewards, dones, next_q_values) # vmap q_learning_loss
    mean_squared_error = jnp.mean(squared_error) # mean squared error over batch
    return mean_squared_error
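`jax.vmap` maps the per-transition `q_learning_loss` over the leading batch axis of its inputs. The NumPy sketch below is our own illustration of that behaviour, not part of the notebook's pipeline: it checks that an explicit Python loop over transitions and a vectorised version using array indexing (`np.take_along_axis`) give the same per-transition squared errors. The random Q-values and batch are made up.

```python
import numpy as np

def q_learning_loss_single(q_values, action, reward, done, next_q_values):
    # Same computation as Exercise 13, written with NumPy for one transition.
    target = reward + (1.0 - done) * np.max(next_q_values)
    return (q_values[action] - target) ** 2

def q_learning_loss_batched(q_values, actions, rewards, dones, next_q_values):
    # q_values, next_q_values: (batch, num_actions); actions, rewards, dones: (batch,)
    chosen = np.take_along_axis(q_values, actions[:, None], axis=1)[:, 0]
    targets = rewards + (1.0 - dones) * next_q_values.max(axis=1)
    return (chosen - targets) ** 2

rng = np.random.default_rng(0)
batch, num_actions = 4, 2
q = rng.normal(size=(batch, num_actions))        # made-up online Q-values
next_q = rng.normal(size=(batch, num_actions))   # made-up target Q-values
actions = rng.integers(0, num_actions, size=batch)
rewards = np.ones(batch)
dones = np.array([0.0, 0.0, 1.0, 0.0])

looped = np.array([q_learning_loss_single(q[i], actions[i], rewards[i], dones[i], next_q[i])
                   for i in range(batch)])
vectorised = q_learning_loss_batched(q, actions, rewards, dones, next_q)
print(np.allclose(looped, vectorised))   # True
print(vectorised.mean())                 # the mean squared error over the batch
```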

Now we can create the `q_learn` function, which computes the gradient of the `batched_q_learning_loss`, uses an Optax optimizer to update the network weights, and finally (maybe) updates the target parameters.

In [None]:
def q_learn(rng, params, learner_state, memory):
  # Compute gradients
  grad_loss = jax.grad(batched_q_learning_loss)(params.online, params.target, memory.obs, 
                                          memory.action, memory.reward, 
                                          memory.next_obs, memory.done,
                                          ) # jax.grad

  # Get updates
  updates, opt_state = Q_LEARN_OPTIMIZER.update(grad_loss, learner_state.optim_state)

  # Apply them
  new_weights = optax.apply_updates(params.online, updates)

  # Maybe update target network
  params = update_target_params(learner_state, new_weights, params.target)

  # Increment learner step counter
  learner_state = QLearnState(learner_state.count + 1, opt_state)

  return params, learner_state
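The interaction between `update_target_params` and `q_learn` is easy to misread, so here is a plain-Python simulation of the schedule they implement. It needs no JAX: integers stand in for parameter pytrees, and `TARGET_UPDATE_PERIOD` is a hypothetical name for the hard-coded 100. Because `q_learn` passes `learner_state` to `update_target_params` before incrementing the count, the target is also refreshed on the very first learner step.

```python
# Plain-Python simulation of the target-update schedule (no JAX needed).
# Integers stand in for parameter pytrees: each learner step creates a new "version".
TARGET_UPDATE_PERIOD = 100   # hypothetical name for the 100 hard-coded in update_target_params

online_version, target_version = 0, 0
refresh_counts = []
for count in range(350):                  # `count` is learner_state.count *before* the increment
    online_version += 1                   # the optimizer step produces new online weights
    if count % TARGET_UPDATE_PERIOD == 0:
        target_version = online_version   # hard copy: target <- online
        refresh_counts.append(count)

print(refresh_counts)                     # [0, 100, 200, 300]
print(online_version - target_version)    # 49: how many updates the target lags behind
```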

### Replay Buffer
For Q-learning we will need an agent memory that stores entire transitions: `obs`, `action`, `reward`, `next_obs`, `done`. When we retrieve transitions from the memory, they should be chosen randomly from all of the transitions stored so far (once the memory reaches its maximum size, the oldest transitions are discarded). In RL we often call such a module a **replay buffer**. One benefit of a replay buffer is that experiences can be re-used several times for training, unlike in the policy gradient algorithm REINFORCE, where we discarded memories after using them for learning.

In [None]:
class TransitionMemory(object):
  """A simple Python replay buffer."""

  def __init__(self, max_size=10_000, batch_size=256):
    self.batch_size = batch_size
    self.buffer = collections.deque(maxlen=max_size)

  def push(self, transition):

    # add transition to the replay buffer
    self.buffer.append(
        (transition.obs, transition.action, transition.reward, 
          transition.next_obs, transition.done)
    )

  
  def is_ready(self):
    return self.batch_size <= len(self.buffer)

  def sample(self):
    # Randomly sample a batch of transitions from the buffer
    random_replay_sample = random.sample(self.buffer, self.batch_size)

    # Batch the transitions together
    obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*random_replay_sample)

    return Transition(
        np.stack(obs_batch).astype("float32"), 
        np.asarray(action_batch).astype("int32"), 
        np.asarray(reward_batch).astype("float32"), 
        np.stack(next_obs_batch).astype("float32"), 
        np.asarray(done_batch).astype("float32")
    )

# Instantiate the memory
Q_LEARNING_MEMORY = TransitionMemory(max_size=50_000, batch_size=256)
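`TransitionMemory` relies on two standard-library behaviours: a `collections.deque` with `maxlen` silently evicts its oldest entry when full, and `random.sample` draws a batch without replacement. The small sketch below demonstrates both with hypothetical placeholder transitions, and shows that across many batches every stored transition is sampled repeatedly, which is the re-use that REINFORCE lacked.

```python
import collections
import random

# Hypothetical transitions: placeholder tuples tagged with the timestep they were collected at.
buffer = collections.deque(maxlen=5)   # the same container TransitionMemory uses
for t in range(8):
    buffer.append(("transition", t))

# Once full, each append evicts the oldest entry (first in, first out).
print([t for _, t in buffer])   # [3, 4, 5, 6, 7]

# random.sample draws a batch *without* replacement, but later batches can draw
# the same transitions again, so each stored transition is re-used for training.
random.seed(0)
counts = collections.Counter()
for _ in range(100):
    batch = random.sample(buffer, 3)
    counts.update(t for _, t in batch)
print(counts)   # how often each stored timestep was sampled across 100 batches
```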

### Random exploration
We almost have everything we need for a functioning Q-learning agent. But one problem is that if we always choose the action with the highest Q-value, then the agent's policy will be completely [deterministic](https://www.quora.com/What-is-the-intuitive-difference-between-a-stochastic-model-and-a-deterministic-model). This means the agent will always choose the same strategy. This can pose a problem because at the start of training the Q-network will be very inaccurate (i.e. a bad approximation of the true Q-function). As such, the agent will consistently choose suboptimal actions. Moreover, the agent will never deviate from its suboptimal strategy and will never discover new, potentially more rewarding actions. As a result, the Q-network remains inaccurate. Ideally, the agent should try out many different strategies so that it can observe the outcomes (rewards) of its actions in different states and so improve its approximation of the Q-function.

One easy way to ensure that the agent tries out many different actions is to let it periodically choose some random actions, instead of the greedy (best) action all the time.

**Exercise 14:** Implement a function that, given the number of possible (discrete) actions, returns a random action.

**Useful methods:**

*  `jax.random.randint` ([docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.randint.html))

In [None]:
def select_random_action(key, num_actions):
    
    # YOUR CODE
    action = ...
    # END YOUR CODE

    return action

In [None]:
#@title Check exercise 14 {display-mode: "form"}

try:
  random_key1 = jax.random.PRNGKey(6) # random key
  random_key2 = jax.random.PRNGKey(1000) # random key
  result1 = select_random_action(random_key1, 2)
  result2 = select_random_action(random_key2, 2)

  if result1 != 1 or result2 != 0:
    print("Your implementation looks wrong.")
  else:
    print("Looks good.")
except:
  print("Your implementation looks wrong.")

In [None]:
#@title Solution exercise 14 {display-mode: "form"}

def select_random_action(key, num_actions):
    # YOUR CODE
    action = jax.random.randint(
        key, 
        shape=(), 
        minval=0, 
        maxval=num_actions
    )
    # END YOUR CODE

    return action

### $\varepsilon$-greedy action selection
At the start of training, when the Q-network is inaccurate, it is worthwhile for the agent to mostly take random actions so that it can learn how good or bad different actions are; this is referred to as **exploration.** However, as the accuracy of the Q-network improves, the agent should take fewer random actions and instead choose the greedy actions with respect to the Q-values. Choosing the best actions given the current Q-network is referred to as **exploitation.** In RL we often call the probability of taking a random action **epsilon** $\varepsilon$. Epsilon is a value in the interval $[0,1]$; for example, $\varepsilon=0.4$ means that the agent chooses a random action 40% of the time and the greedy action 60% of the time. It is common in RL to linearly decrease the value of epsilon over time so that the agent becomes increasingly greedy as the accuracy of its Q-network improves through learning.
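A quick NumPy simulation with hypothetical Q-values makes the definition concrete. One subtlety it reveals: a random action can happen to coincide with the greedy one, so with $|A|$ actions the greedy action is actually taken with probability $1-\varepsilon+\varepsilon/|A|$, which is $0.8$ for $\varepsilon=0.4$ and two actions.

```python
import numpy as np

rng = np.random.default_rng(0)
q_values = np.array([0.0, 1.0])   # hypothetical Q-values; action 1 is greedy
epsilon, num_actions, num_steps = 0.4, len(q_values), 100_000

explore = rng.uniform(size=num_steps) < epsilon                 # explore with probability epsilon
random_actions = rng.integers(0, num_actions, size=num_steps)   # uniform over all actions
actions = np.where(explore, random_actions, np.argmax(q_values))

print(explore.mean())          # about 0.4: fraction of exploratory steps
print((actions == 1).mean())   # about 0.8 = 1 - epsilon + epsilon / num_actions
```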


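**Exercise 15:** Implement a function that takes the number of timesteps as input and returns the current epsilon value. Epsilon should start at 1.0 and decrease linearly, reaching 0 after `EPSILON_DECAY_TIMESTEPS` timesteps, but it should never fall below `EPSILON_MIN`.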

In [None]:
EPSILON_DECAY_TIMESTEPS = 3000 # decay epsilon over 3000 timesteps
EPSILON_MIN = 0.1 # 10% exploration

In [None]:
def get_epsilon(num_timesteps):
  # YOUR CODE
  epsilon = ... # decay epsilon

  epsilon = jax.lax.select(
      epsilon < EPSILON_MIN,
      ..., # if less than min then set to min
      ... # else don't change epsilon
  )
  # END YOUR CODE

  return epsilon

In [None]:
#@title Check exercise 15 {display-mode: "form"}
def check_get_epsilon(get_epsilon):
  try:
    result1 = get_epsilon(10)
    result2 = get_epsilon(5_010)

    if result1 != 0.99666667 or result2 != 0.1:
      print("Your function looks wrong.")
    else:
      print("Your function looks correct.")
  except:
    print("Your function looks wrong.")

check_get_epsilon(get_epsilon)


In [None]:
#@title Solution exercise 15 {display-mode: "form"}

def get_epsilon(num_timesteps):

  # YOUR CODE
  epsilon = 1.0 - num_timesteps / EPSILON_DECAY_TIMESTEPS

  epsilon = jax.lax.select(
      epsilon < EPSILON_MIN,
      EPSILON_MIN,
      epsilon
  )
  # END YOUR CODE

  return epsilon

# CHECK
check_get_epsilon(get_epsilon)


**Exercise 16:** Now let's put these functions together to do epsilon-greedy action selection.

In [None]:
def select_epsilon_greedy_action(key, q_values, num_timesteps):  
    num_actions = len(q_values) # number of available actions

    # YOUR CODE HERE
    epsilon = ... # get epsilon value

    should_explore = ... # hint: a boolean expression to check if some random number is less than epsilon

    action = jax.lax.select(
        should_explore,
        ..., # if should explore
        ... # if should be greedy
    )
    # END YOUR CODE

    return action

In [None]:
#@title Check exercise 16 {display-mode: "form"}

try:
  rng = hk.PRNGSequence(jax.random.PRNGKey(42))
  dummy_q_values = jnp.array([0,1], jnp.float32)
  num_timesteps = 5010 # very greedy
  actions1 = []
  for i in range(10):
      actions1.append(int(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps)))

  num_timesteps = 0 # completely random
  actions2 = []
  for i in range(10):
      actions2.append(int(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps)))

  if actions1 != [1, 1, 0, 1, 1, 0, 1, 1, 1, 1] or actions2 != [0, 0, 0, 1, 1, 1, 1, 0, 0, 0]:
    print("Looks like something might be incorrect!")
  else:
    print("Looks correct!")
except:
  print("Looks like something might be incorrect!")

In [None]:
#@title Solution exercise 16 {display-mode: "form"}

# Now make a function that takes an epsilon-greedy action

def select_epsilon_greedy_action(key, q_values, num_timesteps):
    num_actions = len(q_values) # number of available actions

    # YOUR CODE
    epsilon = get_epsilon(num_timesteps)

    should_explore = jax.random.uniform(key, (1,))[0] < epsilon

    action = jax.lax.select(
        should_explore,
        select_random_action(key, num_actions), 
        select_greedy_action(q_values)
    )
    # END YOUR CODE

    return action
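The solution above draws both the explore/exploit coin flip and the random action from the same `key`. JAX's documentation recommends never reusing a key for more than one random draw; the usual pattern is to `jax.random.split` it first. Here is a minimal sketch of that pattern. `select_epsilon_greedy_action_split` is a hypothetical helper that takes `epsilon` directly rather than the timestep count, and it uses `jnp.argmax` in place of the notebook's `select_greedy_action`. Because it consumes the key differently, dropping it into the notebook would change the specific actions that Check exercise 16 expects.

```python
import jax
import jax.numpy as jnp

def select_epsilon_greedy_action_split(key, q_values, epsilon):
    """Epsilon-greedy action selection using an independent key for each random draw."""
    explore_key, action_key = jax.random.split(key)   # never reuse one key for two draws
    num_actions = q_values.shape[-1]
    should_explore = jax.random.uniform(explore_key) < epsilon
    random_action = jax.random.randint(action_key, shape=(), minval=0, maxval=num_actions)
    greedy_action = jnp.argmax(q_values)
    return jnp.where(should_explore, random_action, greedy_action)

q_values = jnp.array([0.0, 1.0])   # hypothetical Q-values; action 1 is greedy
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
greedy_only = jax.vmap(lambda k: select_epsilon_greedy_action_split(k, q_values, 0.0))(keys)
random_only = jax.vmap(lambda k: select_epsilon_greedy_action_split(k, q_values, 1.0))(keys)
print(bool(jnp.all(greedy_only == 1)))   # True: epsilon = 0 never explores
print(float(random_only.mean()))         # roughly 0.5 with two actions
```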

### Q-learning select action

We now have everything we need to write the `q_learning_select_action` function. We will use the `actor_state` to store a counter that keeps track of the current number of timesteps, which we use to decay our `epsilon` value.

In [None]:
# Actor state stores the current number of timesteps
QActorState = collections.namedtuple("ActorState", ["count"])

def q_learning_select_action(key, params, actor_state, obs, evaluation=False):
    obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim
    q_values = Q_NETWORK.apply(params.online, obs)[0] # remove batch dim

    action = select_epsilon_greedy_action(key, q_values, actor_state.count)
    greedy_action = select_greedy_action(q_values)

    action = jax.lax.select(
        evaluation,
        greedy_action,
        action
    )

    next_actor_state = QActorState(actor_state.count + 1) # increment timestep counter

    return action, next_actor_state

Q_LEARNING_ACTOR_STATE = QActorState(0) # counter set to zero

### Training
We can now put everything together using the agent-environment loop. But first, let's jit the select-action function and the learn function for some extra speed.

In [None]:
# Jit functions
q_learning_select_action_jit = jax.jit(q_learning_select_action)
q_learn_jit = jax.jit(q_learn)

# Run environment loop
print("Starting training. This may take up to 8 minutes to complete.")
episode_returns, evaluator_returns = run_training_loop(
                                        env_name,
                                        Q_LEARNING_PARAMS, 
                                        q_learning_select_action_jit, 
                                        Q_LEARNING_ACTOR_STATE,
                                        q_learn_jit,
                                        Q_LEARNING_LEARN_STATE,
                                        Q_LEARNING_MEMORY,
                                        num_episodes=1001,
                                        train_every_timestep=True, # do learning after every timestep
                                        video_subdir="q_learning"
                                    )

plt.plot(episode_returns)
plt.xlabel("Episodes")
plt.ylabel("Episode Return")
plt.title("Deep Q-Learning")
plt.show()

At this stage, the approximated Q-function has hopefully converged to one that gives a decent policy for balancing the pole in CartPole.

In [None]:
#@title Visualise Policy
#@markdown Choose an episode number that is a multiple of 100 and less than or equal to 1000, and **run this cell**.

episode_number = 0 #@param {type:"number"}

assert (episode_number % 100) == 0, "Episode number must be a multiple of 100 since we only record every 100th episode."
assert episode_number < 1001, "Episode number must be less than or equal to 1000"

eval_episode_number = int(episode_number / 100 * 32)
video_path = f"./video/q_learning/eval/rl-video-episode-{eval_episode_number}.mp4"

mp4 = open(video_path,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

This section attempts to summarise [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602), the research paper where Deep Q-Learning was first introduced. To understand the concepts covered in this section better, we recommend reading the original paper.

## Conclusion
**Summary:**

In this practical we learnt the basics of reinforcement learning (RL).

In the first section we learnt some basic concepts such as environment observations, action selection strategies, rewards, and episodes. We learnt that the goal in RL is to learn a policy which maximises some notion of cumulative reward that the agent receives from the environment (the return).

In the second section we searched for an optimal policy in CartPole using an algorithm called RandomSearch. Basically, we tried out different policies until we happened to find one that worked well. This method did not yield consistent results and success required immense luck.

In the third section we learnt about policy gradients and how we can use gradient ascent to adjust the parameters of our agent's policy in the direction which maximises the expected cumulative reward (return).

Finally, in the fourth section we learnt about the state-action value function and how it is related to an optimal policy. We implemented an algorithm called Q-learning to learn the optimal state-action value function in CartPole. We learnt about the importance of using a target network and epsilon-greedy exploration.

**Next Steps:** 

Now that you have successfully solved CartPole with two different RL algorithms, REINFORCE and Deep Q-Learning, we encourage you to use what you have learnt to try to solve some more challenging environments. OpenAI Gym is a great place to find RL environments, and [LunarLander](https://www.gymlibrary.dev/environments/box2d/lunar_lander/) is a great next step. You can go to the start of this notebook and replace the environment with LunarLander by replacing `env = gym.make("CartPole-v1")` with `env = gym.make("LunarLander-v2")`. Note that you will need to increase the number of training episodes in order to learn a good policy in LunarLander, because it is a significantly more challenging environment than CartPole.

<center>
<img src="https://miro.medium.com/max/1194/1*Dj2fkRjrMA0w9E-PuyETdg.gif" width="60%" />
</center>

In addition, there are many RL algorithms out there that make significant improvements to REINFORCE and Deep Q-Learning. See these resources:
* [REINFORCE with baseline](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#baselines-in-policy-gradients)
* [Double Deep Q-Network](https://arxiv.org/pdf/1509.06461.pdf)
* [Proximal Policy Optimisation (PPO)](https://arxiv.org/pdf/1707.06347.pdf)

If you are looking for a more in-depth online course on RL, you can check out these courses:
* [Reinforcement Learning Foundations on LinkedIn Learning](https://www.linkedin.com/learning/reinforcement-learning-foundations) (made by one of our very own tutors, Khaulat Abdulhakeem)
* [An introduction to Reinforcement Learning on FreeCodeCamp](https://www.freecodecamp.org/news/an-introduction-to-reinforcement-learning-4339519de419/)
* [Reinforcement Learning Specialization on Coursera](https://www.coursera.org/specializations/reinforcement-learning)

Finally, the most influential textbook on RL is available for free online:
* [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book-2nd.html) by Richard S. Sutton and Andrew G. Barto

**Appendix:** 

N/a

**References:** 

* [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/)
* [Deep Q-Network]()

For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe
  src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%"
  height="1200px">
  Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />