Reinforcement Learning


In this notebook, I work through OpenAI's Gym, focusing on the CartPole problem. By applying a Deep Q-Network (DQN), an agent can learn to "win" this game through a blend of exploring and exploiting the environment's state-action space.

I'll do my best to document what each piece of the code does, and provide sources for this work.

In [1]:
#Import dependencies 

#Python stuff 
import math
import random
import numpy as np
from collections import namedtuple
from itertools import count
from PIL import Image
from collections import deque


#DL Stuff
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#RL Stuff
import gym
from gym import wrappers

#visualization stuff 
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
#First define the Deep Q Network

class DQN(nn.Module):
    """
    Fully connected network that approximates the Q function.

    A batch of observations is passed through three hidden ReLU layers
    (24 -> 32 -> 24 units) and a final linear layer that emits one Q-value
    estimate per action.

    Args:
        obs_shape: Number of features in a single observation.
        num_actions: Number of discrete actions, i.e. the width of the output
            layer. Defaults to 2, the value that used to be hard-coded.

    Background reading on Q-learning with function approximation:
    https://danieltakeshi.github.io/2016/10/31/going-deeper-into-reinforcement-learning-understanding-q-learning-and-linear-function-approximation/
    """
    def __init__(self, obs_shape, num_actions=2):
        super().__init__()

        # Layer attribute names are part of the state_dict keys; keep them
        # stable so previously saved weights and load_state_dict() still work.
        self.fc1 = nn.Linear(in_features=obs_shape, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=24)
        self.out = nn.Linear(in_features=24, out_features=num_actions)

    def forward(self, t):
        """Return Q-value estimates with shape (..., num_actions) for input of shape (..., obs_shape)."""
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = F.relu(self.fc3(t))

        # No activation on the output: Q-values are unbounded real numbers.
        return self.out(t)
   

In [None]:
class EpsilonGreedyStrategy():
    """
    Exponentially decaying exploration schedule for epsilon-greedy action selection.

    The exploration rate begins at ``start`` and approaches ``end`` as the
    step count grows, with ``decay`` controlling how quickly it falls.

    https://jamesmccaffrey.wordpress.com/2017/11/30/the-epsilon-greedy-algorithm/
    """
    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return end + (start - end) * exp(-current_step * decay)."""
        span = self.start - self.end
        decay_factor = math.exp(-1. * current_step * self.decay)
        return self.end + span * decay_factor
    

   

In [None]:
class Agent():
    """
    Epsilon-greedy agent for an environment with a discrete action space.

    Tracks how many actions it has selected (``current_step``), knows the size
    of its action space, and on every call decides whether to explore or
    exploit by comparing the current exploration rate with ``random.random()``.

    Args:
        strategy: Object exposing ``get_exploration_rate(current_step)``.
        num_actions: Number of discrete actions; random actions are drawn
            from ``range(num_actions)``.
        device: torch device on which the network input and the returned
            action tensors are placed.
    """
    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        """
        Choose an action for a single observation.

        Args:
            state: One observation (array-like). It is flattened into a batch
                of size one before being fed to ``policy_net``.
            policy_net: Callable mapping a ``(1, obs_size)`` float tensor to
                per-action Q-values.

        Returns:
            A tensor on ``self.device`` holding the action index. The explore
            branch returns shape ``(1,)`` and the exploit branch a 0-dim
            tensor; both support ``.item()``.
        """
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            # Explore: uniformly random action.
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device)

        # Exploit: greedy action under the current policy network.
        with torch.no_grad():
            # Build the input directly on self.device so a network living on
            # an accelerator does not hit a CPU/GPU device mismatch, and use
            # reshape(1, -1) so the observation size is not hard-coded to 4.
            state = torch.as_tensor(
                state, dtype=torch.float32, device=self.device
            ).reshape(1, -1)
            return policy_net(state).argmax().to(self.device)

In [None]:
# utility functions for trajectories. 


Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward','done')
)

# One Experience is a single environment transition: the state the agent was
# in, the action it took, and the resulting next state, reward and done flag.


def extract_tensors(experiences, device=None):
    """
    Convert a batch of Experiences into one float32 tensor per field.

    Used by experience replay to turn sampled transitions into the batched
    inputs needed for the Q-value calculation.

    Args:
        experiences: Non-empty sequence of ``Experience`` tuples.
        device: Optional torch device to place the tensors on. When ``None``
            (the default) the tensors stay on the CPU.

    Returns:
        ``(states, actions, rewards, next_states, dones)``. Note this order
        differs from the field order of ``Experience``.

    Raises:
        ValueError: If ``experiences`` is empty.
    """
    if len(experiences) == 0:
        raise ValueError("extract_tensors needs at least one Experience")

    # Transpose: list of Experience -> a single Experience of per-field tuples.
    batch = Experience(*zip(*experiences))

    def _as_float_tensor(values):
        # Stack through one ndarray first: calling torch.tensor() on a
        # sequence of ndarrays is the slow path PyTorch warns about.
        tensor = torch.from_numpy(np.asarray(values, dtype=np.float32))
        return tensor if device is None else tensor.to(device)

    return (
        _as_float_tensor(batch.state),
        _as_float_tensor(batch.action),
        _as_float_tensor(batch.reward),
        _as_float_tensor(batch.next_state),
        _as_float_tensor(batch.done),
    )




In [None]:
class ReplayMemory():
    """
    Fixed-capacity, list-backed buffer of experiences.

    Experiences are sampled uniformly at random so that the transitions used
    to fit the DQN are not consecutive, correlated observations.

    Once ``capacity`` entries are stored, each new experience overwrites the
    oldest slot (ring-buffer behaviour). Combined with an epsilon-greedy
    schedule, this means the buffer gradually fills with later, mostly
    exploitative experiences.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.push_count = 0  # total number of pushes so far, never reset

    def push(self, experience):
        """Store one experience, overwriting the oldest slot once full."""
        buffer_is_full = len(self.memory) >= self.capacity
        if buffer_is_full:
            overwrite_at = self.push_count % self.capacity
            self.memory[overwrite_at] = experience
        else:
            self.memory.append(experience)
        self.push_count += 1

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True once at least ``batch_size`` experiences are stored."""
        return not len(self.memory) < batch_size
    
    
    

In [None]:
class QValues():
    """
    Static helpers that compute the two sides of the Q-learning update.

    ``get_current`` returns the policy network's predicted Q-value for the
    action that was actually taken in each state.

    ``get_next`` returns the target network's maximum Q-value over actions
    for each next state, with 0 substituted wherever the episode is done
    (``done`` equal to True / 1). This is the bootstrap term of the Bellman
    optimality target.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        """Return Q(s, a) for each (state, action) pair, shape (batch, 1)."""
        all_q = policy_net(states)
        # gather() needs integer indices with one column per selected entry.
        taken = actions.long().unsqueeze(-1)
        return all_q.gather(dim=1, index=taken)

    @staticmethod
    def get_next(target_net, next_states,dones):
        """Return max over actions of Q_target(s', a), zeroed for finished episodes; shape (batch,)."""
        best_next, _ = target_net(next_states).max(dim=1)
        best_next = best_next.detach()
        terminal = dones == 1
        # A finished episode has no future return to bootstrap from.
        best_next[terminal] = 0.
        return best_next

    
    

Now that all the building blocks are defined, we can set up the task at hand, starting with the tunable hyperparameters.



In [2]:
# Tunable hyperparameters for the training run.

# Number of experiences sampled from replay memory for each optimisation step;
# training only starts once the memory holds at least this many.
batch_size = 256
# Discount factor multiplied into the next-state Q-value when building the target.
gamma = 0.999
# Epsilon-greedy schedule: initial exploration rate, the floor it decays
# towards, and the exponential decay rate per action selected by the agent.
eps_start = 1
eps_end = 0.01
eps_decay = 1e-3
# The target network is refreshed from the policy network after every episode
# whose index is divisible by this value.
target_update = 10
# Capacity of the replay memory (number of stored experiences).
memory_size = 10000
# Learning rate passed to the Adam optimiser.
lr = 0.001
# Number of episodes the main loop runs.
num_episodes = 1000

In [None]:

# policy_net is the network that gets optimised; target_net is an occasionally
# updated copy of it that is used to compute the bootstrap targets (the
# estimate of Q*).

policy_net = DQN(obs_shape=4).to(device)
target_net = DQN(obs_shape=4).to(device)

# Start the target network as an exact copy of the policy network. Without
# this it holds independent random weights until the first periodic sync, so
# any update performed before that sync would bootstrap from an unrelated net.
target_net.load_state_dict(policy_net.state_dict())
# The target network is only ever used for inference, never trained directly.
target_net.eval()

In [None]:
# Initialise the environment and the remaining training objects.

# NOTE(review): no earlier cell creates `env`, so on a fresh kernel this cell
# failed with NameError. 'CartPole-v1' is an assumption; confirm which
# CartPole variant the notebook was originally run against.
env = gym.make('CartPole-v1')

strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
agent = Agent(strategy, env.action_space.n, device)
memory = ReplayMemory(memory_size)
optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)
# Rendering cadence: the main loop switches rendering on for the episode that
# follows every episode whose index is divisible by this value.
RENDER_TIME = 5

Main Loop:

Iterate over episodes, iterate over timesteps in episode, save that experience, train the DQN, and all the while take steps according to epsilon greedy strategy. 

In [None]:
# Number of timesteps each episode lasted, in episode order.
episode_durations = []
RENDER = True

# NOTE: this loop is written against the classic Gym API, where reset()
# returns only the observation and step() returns a 4-tuple.
try:
    for episode in range(num_episodes):
        state = env.reset()  # initial state
        done = False
        timestep = 0
        while not done:
            if RENDER:
                env.render(mode = 'human')
            timestep += 1
            action = agent.select_action(state, policy_net).item()
            next_state, reward, done, info = env.step(action)

            memory.push(Experience(state, action, next_state, reward, done))

            state = next_state

            # Learn only once enough experiences have been collected to
            # fill a whole batch.
            if memory.can_provide_sample(batch_size):
                experiences = memory.sample(batch_size)
                # Move the sampled batch onto the same device as the networks;
                # leaving it on the CPU fails as soon as `device` is a GPU.
                states, actions, rewards, next_states, dones = (
                    tensor.to(device) for tensor in extract_tensors(experiences)
                )
                current_q_values = QValues.get_current(policy_net, states, actions)
                next_q_values = QValues.get_next(target_net, next_states, dones)
                # Bellman target: r + gamma * max_a' Q_target(s', a').
                target_q_values = (next_q_values * gamma) + rewards
                # unsqueeze(1) matches the (batch, 1) shape of current_q_values.
                loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Periodically refresh the target network from the policy network.
        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        episode_durations.append(timestep)
        # Render the next episode only if this one's index is a multiple of RENDER_TIME.
        RENDER = (episode % RENDER_TIME == 0)
finally:
    # Always release the environment (and any render window), even when
    # training is interrupted or raises.
    env.close()
