# Deep Q-Learning

We're going to use Deep Q-Learning to train a [cartpole](https://gymnasium.farama.org/environments/classic_control/cart_pole/) agent.  Notice that the cartpole state space is continuous, so tabular Q-Learning won't work!

Let's start by setting up our environment.  Run the three cells below to install gymnasium on AWS, retrieve a `.npy` file of states, import everything relevant, and see what the observations look like.

After you run the first cell, I suggest commenting it out as you won't need to run it more than once.

In [2]:
'''
!pip install gymnasium gymnasium[classic_control]
!wget https://www.usna.edu/Users/cs/SD312/lab/13DeepQ/sampled_states.npy
'''


In [3]:
import gymnasium as gym
import numpy as np
import numpy.random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import plotly.graph_objects as go
from tqdm import tqdm

In [None]:
env = gym.make('CartPole-v1')
STATE_DIM = env.observation_space.shape[0]
N_ACTIONS = env.action_space.n
GAMMA = .99

print(f'State space is continuous in {STATE_DIM} dimensions, and there are {N_ACTIONS} actions.')

obs, info = env.reset()
print(f"For example, here's an observation: {obs}.")

## Hyperparameters

We have to decide a few things.
- What should $\epsilon$ be for our $\epsilon$-greedy exploration policy?
- How large should our replay be?
- How many datapoints should we pull from our replay to train on at a time (batch size)?
- What should our neural network look like? A good first step here is to make sure you understand what the dimensions of the input and output layers need to be. Those aren't up to us; they are prescribed by the problem.

Choose some values and design your network.
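As a hedged illustration of the dimension constraint only: the input width has to match the state dimension, and the output width has to match the number of actions (one Q-value per action). Everything else in this sketch is an assumption made for illustration, not the intended answer: the name `QNetwork`, the single hidden layer, and its width of 64. CartPole's sizes (4 and 2) are hard-coded so the sketch runs on its own, on the CPU.

```python
import torch
import torch.nn as nn

STATE_DIM = 4   # CartPole observation size (hard-coded so this runs standalone)
N_ACTIONS = 2   # push left or push right

class QNetwork(nn.Module):
    """Hypothetical Q-network: maps a state to one Q-value per action."""

    def __init__(self, state_dim, n_actions, hidden=64):  # hidden=64 is an arbitrary choice
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.layers(x)

net = QNetwork(STATE_DIM, N_ACTIONS)
batch_out = net(torch.zeros(32, STATE_DIM))   # a batch of 32 states
single_out = net(torch.zeros(STATE_DIM))      # a single state
print(batch_out.shape, single_out.shape)
```

A batch of states produces one row of Q-values per state, while a single state produces a 1-D tensor. That is the case `greedy_policy` checks for with `qs.dim() == 1`.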

Below we've also created a Replay - note that it's essentially a `deque` of limited size.  Make sure you understand that code!  In that replay, we are storing Transitions, each of which consists of a state, an action, a reward, and a next_state.  If the transition represents a failure (pole fell over or cart went off screen), the next_state will be `None`.
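To see the "limited size" behavior in isolation, here is a small sketch using only the standard library. The capacity of 3 and the integer "states" are made up for illustration; the real replay stores tensors.

```python
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

buffer = deque([], maxlen=3)          # capacity of 3, just for illustration
for step in range(5):
    # Fake transitions; the last one is terminal, so next_state is None
    next_state = None if step == 4 else step + 1
    buffer.append(Transition(step, 0, 1.0, next_state))

stored_states = [t.state for t in buffer]
print(stored_states)                  # the two oldest transitions were evicted
print(buffer[-1].next_state is None)  # the terminal transition stores None
```

Once the `deque` is full, each `append` silently drops the oldest entry, so the replay always holds the most recent transitions.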

In [None]:
EPSILON =
REPLAY_LENGTH =
BATCH_SIZE =
LEARNING_RATE =

In [None]:
# Define a network and create an instance of it

In [None]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state'))

class ReplayMemory:

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, s, a, r, sp):
        """Save a transition"""
        s = torch.tensor(s, dtype=torch.float32)
        a = torch.tensor([a], dtype=torch.int64)
        r = torch.tensor([r], dtype=torch.float32)
        if sp is not None:
            sp = torch.tensor(sp, dtype=torch.float32)
        self.memory.append(Transition(s,a,r,sp))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

memory=ReplayMemory(REPLAY_LENGTH)

## Policies!

The below functions define a greedy policy, an $\epsilon$-greedy policy, and a random policy.

In [None]:
def greedy_policy(network, states):
    '''
    Returns a tuple
    index 0 contains the Q-value of the best action for all states
    index 1 contains the index of the best action for all states
    '''
    with torch.no_grad():
        qs = network(states) # Get the q-values
        if qs.dim() == 1:   # If it's just a single state
            return torch.max(qs, dim=0) # Return the tuple of max info for that state
        return torch.max(qs, dim=1) # Return the tuple of max information for all states

def epsilon_greedy(network, state):
    '''
    Returns an action selected via epsilon-greedy
    '''
    if numpy.random.random() < EPSILON:
        return numpy.random.randint(N_ACTIONS)
    else:
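        # Make the dtype explicit so the input matches the float32 tensors stored in the replay
        return greedy_policy(network, torch.tensor(state, dtype=torch.float32).to('cuda'))[1].item()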

def random_policy():
    '''
    Chooses a random action.
    '''
    return numpy.random.randint(N_ACTIONS)
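One thing worth noticing about $\epsilon$-greedy: the random branch can still pick the greedy action, so with $N$ actions the chance of a non-greedy move is $\epsilon(1 - 1/N)$, not $\epsilon$. The sketch below checks this empirically. It uses a hypothetical stand-in, `epsilon_greedy_from_qs`, that takes a plain list of made-up Q-values instead of a network, and an illustrative $\epsilon$ of 0.1 (not a recommendation).

```python
import random

def epsilon_greedy_from_qs(q_values, epsilon, rng):
    """Hypothetical stand-in for epsilon_greedy that takes Q-values directly."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))                        # explore
    return max(range(len(q_values)), key=lambda a: q_values[a])    # exploit

rng = random.Random(0)
q_values = [0.2, 0.7]          # made-up Q-values; action 1 is greedy
epsilon = 0.1                  # illustrative value, not a recommendation
n = 100_000
non_greedy = sum(epsilon_greedy_from_qs(q_values, epsilon, rng) != 1 for _ in range(n))
print(non_greedy / n)          # close to epsilon * (1 - 1/2) = 0.05
```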

## Random performance

Reset the environment 100 times, running each episode until truncation or termination under a random policy.  Print the average number of steps that the random policy keeps the pole upright.
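The shape of one episode loop might look like the sketch below. To keep it runnable without Gymnasium installed, it uses a toy stand-in environment, `CoinFlipEnv`, which is entirely hypothetical and only mimics the return shapes of `reset()` and `step()`. The helper name `run_episode` is also an assumption. With the real environment you would pass the `env` created above and a policy built on `random_policy()`.

```python
import random

class CoinFlipEnv:
    """Toy stand-in that mimics the reset()/step() return shapes of a Gymnasium env."""

    def __init__(self, max_steps=500, seed=0):
        self.max_steps = max_steps
        self.rng = random.Random(seed)
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0], {}            # observation, info

    def step(self, action):
        self.t += 1
        terminated = self.rng.random() < 0.05      # "pole fell" with 5% chance per step
        truncated = self.t >= self.max_steps       # time limit reached
        return [0.0, 0.0, 0.0, 0.0], 1.0, terminated, truncated, {}

def run_episode(env, policy):
    """Run one episode and return how many steps it lasted."""
    obs, info = env.reset()
    steps = 0
    while True:
        obs, reward, terminated, truncated, info = env.step(policy(obs))
        steps += 1
        if terminated or truncated:
            return steps

env = CoinFlipEnv()
lengths = [run_episode(env, lambda obs: random.randint(0, 1)) for _ in range(100)]
print(sum(lengths) / len(lengths))
```

The loop stops on either flag. Termination means the episode failed, while truncation means the time limit was reached.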

## Evaluating training

We're going to reproduce the graphs on the right-hand side of Figure 2 in the paper in order to judge the smoothness of our training.  Create a function called `avg_qs`, which accepts as arguments your network and a tensor representing a group of states, which you should load from `sampled_states.npy`.  It should then do the following:

- in a `with torch.no_grad()` block, push the states through the network, producing some approximate Q-values
- calculate the maximum Q-value for each state
- average those maximum Q-values over all the states, resulting in a single scalar
- return that scalar, which should be a plain number rather than a tensor or numpy array (recall that you can pull the value out of a tensor with `.item()`)
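The max-then-average reduction in the steps above can be tried on a hand-made tensor first. This sketch skips the network entirely: the Q-values are made up, with one row per state and one column per action.

```python
import torch

# Made-up Q-values for 3 states and 2 actions (rows are states)
qs = torch.tensor([[1.0, 2.0],
                   [3.0, 0.0],
                   [0.5, 0.5]])

max_per_state = qs.max(dim=1).values      # best Q-value in each row: 2.0, 3.0, 0.5
avg_max_q = max_per_state.mean().item()   # .item() turns the 0-d tensor into a Python float
print(type(avg_max_q), avg_max_q)
```

Reducing over `dim=1` collapses the action dimension, leaving one value per state. The mean then collapses the state dimension.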

In [None]:
sampled_states = torch.tensor(np.load('sampled_states.npy'), dtype=torch.float32).to('cuda')

def avg_qs(network, states):


## Training your network

- Create an optimizer
- Choose a criterion
- Understand, then complete, this `train_model()` function, which implements the steps in Algorithm 1 from the word "Sample" to the mention of equation 3.

The initial steps in `train_model()` pull the states, actions, and rewards out of the batch and turn them into cuda torch tensors.  `actual_next_mask` holds the indices of the transitions in the batch that actually have a next state (i.e., they don't represent a failure state or, in the parlance of the paper, they are "non-terminal").  `next_states` is the torch tensor of all the next states for the non-terminal transitions.
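The sketch below shows how the index list can be used, assuming the usual Q-learning target: $y = r$ for terminal transitions and $y = r + \gamma \max_{a'} Q(s', a')$ otherwise. All numbers are made up, and `next_max_q` and `q_all` are stand-ins for values that would come from the network. The sketch runs on the CPU and is not the completed `train_model()`.

```python
import torch

GAMMA = 0.99

# Made-up minibatch of 4 transitions; transitions 1 and 3 are terminal
rewards = torch.tensor([[1.0], [1.0], [1.0], [1.0]])       # BATCH_SIZE x 1
actual_next_mask = [0, 2]                                   # indices with a real next state
next_max_q = torch.tensor([5.0, 7.0])                       # stand-in for max_a' Q(s', a')

# Terminal transitions keep y = r; the others get y = r + GAMMA * max_a' Q(s', a')
targets = rewards.clone()
targets[actual_next_mask] += GAMMA * next_max_q.unsqueeze(1)
print(targets.squeeze(1).tolist())

# Picking out Q(s, a) for the action actually taken in each transition
q_all = torch.tensor([[0.1, 0.9],
                      [0.4, 0.3],
                      [0.2, 0.8],
                      [0.6, 0.5]])                          # stand-in for network(states)
actions = torch.tensor([[0], [1], [1], [0]])                # BATCH_SIZE x 1, int64
predicted = q_all.gather(1, actions)                        # BATCH_SIZE x 1
print(predicted.squeeze(1).tolist())
```

Both `targets` and `predicted` end up with shape `BATCH_SIZE x 1`, which is the pair of tensors a criterion would compare.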

In [None]:
optimizer = 
criterion = 

def train_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)

    # Tensors of the states, actions, and rewards from the minibatch
    # states is BATCH_SIZEx4
    # actions is BATCH_SIZEx1
    # rewards is BATCH_SIZEx1
    states = torch.cat([transition.state.unsqueeze(0) for transition in transitions], dim=0).to('cuda')
    actions = torch.cat([transition.action.unsqueeze(0) for transition in transitions], dim=0).to('cuda')
    rewards = torch.cat([transition.reward.unsqueeze(0) for transition in transitions], dim=0).to('cuda')

    # actual_next_mask contains the indices of samples without None next_states
    actual_next_mask = [i for i in range(BATCH_SIZE) if transitions[i].next_state is not None]
    # next_states contains those actual next_states
    next_states = torch.cat([transition.next_state.unsqueeze(0) for transition in transitions if transition.next_state is not None], dim=0).to('cuda')
    # the number of rows in `next_states` is the number of non-terminal transitions,
    # which equals the number of elements in `actual_next_mask`
    
    




## Create your samples, and call the training function

Implement the rest of Algorithm 1, calling your `train_model()` function where appropriate.

If the transition is terminal, the next_state should be `None`.
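One way to express that rule is sketched below. The helper name `next_state_for_replay` is hypothetical. It also makes an assumption that you should check against your own reading of the lab: only true termination (a failure) is stored as `None`, while an episode cut off by the time limit keeps its real next state.

```python
def next_state_for_replay(next_obs, terminated):
    """Hypothetical helper: what to store as next_state in the replay."""
    # Assumption: only true termination (pole fell, cart left the screen) is stored as None.
    # A truncated episode (time limit) still has a valid next state.
    return None if terminated else next_obs

print(next_state_for_replay([0.1, 0.0, 0.2, 0.0], terminated=True))
print(next_state_for_replay([0.1, 0.0, 0.2, 0.0], terminated=False))
```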

## Evaluating the smoothness of your training

In the above training loop, keep track of the average maximum Q values for the `sampled_states` you loaded above.  Make a plot displaying the average Q values of the sampled states over time.

In [None]:
fig = go.Figure(data = go.Scatter(x=list(range(len(qs))), y=qs, mode='lines'))
fig.show()

## Evaluating your model's performance

Run 1000 episodes using a greedy policy based on your model, each until termination or truncation.  Keep track of the number of steps that pass in each run before it stops. Each trial runs for at most 500 steps before truncation, though it may terminate sooner if the pole falls or the cart goes off the screen.  Print the average number of steps.