# Write-up and code for Feb 15

## To Do:
- ~~Prove the Epsilon-Greedy Policy Improvement Theorem (we sketched the proof in class)~~
- ~~Provide (with clear mathematical notation) the definition of GLIE (Greedy in the Limit with Infinite Exploration)~~
- ~~Implement the tabular SARSA and tabular SARSA(Lambda) algorithms~~
- ~~Implement the tabular Q-Learning algorithm~~
- ~~Test the above algorithms on some example MDPs by using DP Policy Iteration/Value Iteration solutions as a benchmark~~

## $\epsilon$-Greedy Policy Improvement Theorem
__Theorem:__ For any $\epsilon$-greedy policy $\pi$, the $\epsilon$-greedy policy $\pi'$ with respect to $q_\pi$ is an improvement: $v_{\pi'}(s) \geq v_\pi(s)$ for all states $s$.

__Proof:__ 
$$
\begin{align}
q_\pi(s,\pi'(s)) & = \sum_{a\in\mathcal A}\pi'(a|s)q_\pi(s,a)\\
& = \epsilon/m\sum_{a\in\mathcal A}q_\pi(s,a) + (1-\epsilon) \max_{a \in\mathcal A}q_\pi(s,a)\\
& \geq \epsilon/m\sum_{a\in\mathcal A}q_\pi(s,a) + (1-\epsilon) \sum_{a \in\mathcal A}\frac{\pi(a|s) - \epsilon/m}{1-\epsilon}q_\pi(s,a)\\
& = \sum_{a\in\mathcal A}\pi(a|s)q_\pi(s,a)\\
& = v_\pi(s)
\end{align}
$$

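Here $m = |\mathcal A|$ is the number of actions. The inequality holds because the weights $\frac{\pi(a|s) - \epsilon/m}{1-\epsilon}$ are non-negative (an $\epsilon$-greedy $\pi$ gives every action probability at least $\epsilon/m$) and sum to 1, so the weighted average of $q_\pi(s,a)$ cannot exceed its maximum.

Therefore, by the policy improvement theorem, $v_{\pi'}(s) \geq v_\pi(s)$ for all $s$.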
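The key inequality in the proof can be checked numerically for a single state. The sketch below is only an illustration: the helper names (`epsilon_greedy_probs`, `check_improvement`) and the way the "old" policy is sampled (an $\epsilon/m$ floor plus arbitrary non-negative weights) are assumptions made for this example, not part of the course modules.

```python
import random

def epsilon_greedy_probs(q_s, eps):
    """Epsilon-greedy action probabilities for one state's action values q_s."""
    m = len(q_s)
    greedy = max(range(m), key=lambda a: q_s[a])
    probs = [eps / m] * m          # every action gets eps/m
    probs[greedy] += 1.0 - eps     # the greedy action gets the remaining mass
    return probs

def check_improvement(num_trials=1000, m=4, eps=0.3, seed=0):
    """Check sum_a pi'(a|s) q(s,a) >= sum_a pi(a|s) q(s,a) on random inputs."""
    rng = random.Random(seed)
    for _ in range(num_trials):
        q_s = [rng.uniform(-1, 1) for _ in range(m)]
        # hypothetical old policy: each action has probability >= eps/m
        w = [rng.random() for _ in range(m)]
        total = sum(w)
        pi = [eps / m + (1 - eps) * wi / total for wi in w]
        pi_new = epsilon_greedy_probs(q_s, eps)
        lhs = sum(p * q for p, q in zip(pi_new, q_s))
        rhs = sum(p * q for p, q in zip(pi, q_s))
        if lhs < rhs - 1e-12:      # small tolerance for floating-point error
            return False
    return True
```

Calling `check_improvement()` should return `True`, in line with the inequality derived above.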

## Greedy in the Limit with Infinite Exploration
__Definition:__ Greedy in the Limit with Infinite Exploration (GLIE)
- All State-Action pairs are explored infinitely many times,

$$
\lim_{k\to\infty}N_k(s,a) = \infty
$$

where $N_k(s,a)$ is the number of times action $a$ has been taken in state $s$ during the first $k$ time-steps.

- The policy converges on a greedy policy,

$$
\lim_{k\to\infty}\pi_k(a|s) = \mathbf 1\big(a=\arg\max_{a'\in\mathcal A}Q_k(s,a')\big)
$$
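A commonly cited way to satisfy both conditions is $\epsilon$-greedy exploration with a decaying rate $\epsilon_k = 1/k$, provided every state keeps being visited. The sketch below illustrates that schedule; the function names and the choice to index $k$ by episode are assumptions for this example and do not appear in the code cells of this write-up.

```python
import random
from typing import Dict, Hashable

def glie_epsilon(k: int) -> float:
    """Exploration rate for the k-th episode (k >= 1): eps_k = 1/k."""
    if k < 1:
        raise ValueError("episode index k must be >= 1")
    return 1.0 / k

def epsilon_greedy_action(q_s: Dict[Hashable, float], eps: float,
                          rng: random.Random) -> Hashable:
    """Pick a uniformly random action with probability eps, else the greedy one."""
    if rng.random() < eps:
        return rng.choice(list(q_s))
    return max(q_s, key=q_s.get)
```

Since $\epsilon_k \to 0$ the policy becomes greedy in the limit, while $\sum_k \epsilon_k$ diverges, so exploration never stops entirely.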

## Sarsa and Sarsa($\lambda$) Algorithms

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [26]:
from typing import Dict
from modules.MDP import MDP, Q, Policy
from modules.RL_interface import RL_interface
from modules.state_action_vars import S, A
import random

def sarsa(mdp: MDP, num_epi: int, num_steps: int, eps: float, alpha: float) -> Q:
    # implementation of Sarsa-learning
    
    # initialize Q(s,a) to zero for all state-action pairs
    q = {s: {a: 0. for a in mdp.Actions} for s in mdp.States}
    
    for i in range(num_epi):
        # sample a random starting state and an epsilon-greedy first action
        # (list(...) keeps this working if States/Actions are sets)
        s = random.choice(list(mdp.States))
        if random.random() > eps:
            _, a = find_max_q(q, s)
        else:
            a = random.choice(list(mdp.Actions))
        
        for j in range(num_steps):
            # observe a reward r and a next state sp
            sp, r = RL_interface(mdp, s, a)
        
            # choose the next action with an epsilon-greedy policy
            if random.random() > eps:
                _, ap = find_max_q(q, sp)
            else:
                ap = random.choice(list(mdp.Actions))
            
            # update the q-function
            q[s][a] += alpha*(r + mdp.gamma*q[sp][ap] - q[s][a])
            
            s = sp
            a = ap
    return q

In [38]:
def sarsa_lambda(mdp: MDP, num_epi: int, num_steps: int, eps: float, alpha: float, lambd: float) -> Q:
    # implementation of the Sarsa-lambda algorithm
    
    # initialize Q(s,a) to zero for all state-action pairs
    q = {s: {a: 0. for a in mdp.Actions} for s in mdp.States}
    
    for i in range(num_epi):
        # reset the eligibility traces N(s,a) at the start of every episode
        n = {s: {a: 0. for a in mdp.Actions} for s in mdp.States}
        # sample a random starting state and an epsilon-greedy first action
        # (list(...) keeps this working if States/Actions are sets)
        s = random.choice(list(mdp.States))
        if random.random() > eps:
            _, a = find_max_q(q, s)
        else:
            a = random.choice(list(mdp.Actions))
        
        for j in range(num_steps):
            # observe a reward r and a next state sp
            sp, r = RL_interface(mdp, s, a)
            # choose the next action with an epsilon-greedy policy
            if random.random() > eps:
                _, ap = find_max_q(q, sp)
            else:
                ap = random.choice(list(mdp.Actions))
            # increment the eligibility trace
            n[s][a] += 1
            # calculate the error
            delta = r + mdp.gamma*q[sp][ap] - q[s][a]
            
            for s_ in mdp.States:
                for a_ in mdp.Actions:
                    # update the q-function
                    q[s_][a_] += alpha*delta*n[s_][a_]
                    n[s_][a_] *= mdp.gamma*lambd
            
            s = sp
            a = ap
    return q

In [22]:
from typing import Tuple

def find_max_q(q: Q, s: S) -> Tuple[float, A]:
    # returns the best action value and the best action for a specific state
    best_value = float('-inf')
    best_a = None
    
    # loop over all the actions and store Q(s,a) and a if it is the current best action
    for a in q[s]:
        if q[s][a] > best_value:
            best_value = q[s][a]
            best_a = a
            
    return best_value, best_a

## Q-learning

In [39]:
def q_learning(mdp: MDP, num_epi: int, num_steps: int, eps: float, alpha: float) -> Q:
    # implementation of Q-learning
    
    # initialize Q(s,a) to zero for all state-action pairs
    q = {s: {a: 0. for a in mdp.Actions} for s in mdp.States}
    
    for i in range(num_epi):
        # sample a random starting state
        s = random.choice(list(mdp.States))
        
        for j in range(num_steps):
            # choose the action with an epsilon-greedy behaviour policy
            if random.random() > eps:
                _, a = find_max_q(q, s)
            else:
                a = random.choice(list(mdp.Actions))
            # observe a reward r and a next state sp
            sp, r = RL_interface(mdp, s, a)
            # find the max action value for the next state
            q_max, _ = find_max_q(q, sp)
            # update the q-function towards the greedy target
            q[s][a] += alpha*(r + mdp.gamma*q_max - q[s][a])
            
            s = sp
    return q

In [27]:
def q_to_policy(q: Q) -> Policy:
    # takes in Q(s,a) and returns a greedy policy
    policy = {s: {a: 0.0 for a in q[s]} for s in q}
    
    for s in q:
        _, a = find_max_q(q, s)
        policy[s][a] = 1.0
    
    return policy

## Gridworld Example
This continues the earlier Gridworld example. Recall that entering state (0,0) or (3,3) yields a positive reward, while entering state (1,2) yields a negative reward.

In [32]:
from modules.gridworld import gridworld
import numpy as np
gw = gridworld(0.9)

In [43]:
def print_policy_gridworld(policy: Policy):
    # function that prints out the grid
    last_s = (0,0)
    for s in sorted(policy.keys()):
        if s[0] != last_s[0]:
            print()
        for a in policy[s]:
            
            if np.abs(policy[s][a] - 1.0) < 1e-6:
                if a == 1:
                    string = '<-'
                elif a == 2:
                    string = '->'
                elif a == 3:
                    string = '/\\'
                else:
                    string = '\\/'
                print(s, ": {} \t".format(string), end='')
        last_s = s
    print('\n')

In [45]:
q_sarsa = sarsa(gw, 1000, 30, 0.3, 0.01)
q_sarsa_lambda = sarsa_lambda(gw, 10000, 30, 0.3, 0.01, 0.9)
q_learn = q_learning(gw, 1000, 30, 0.3, 0.01)

In [46]:
policy_sarsa = q_to_policy(q_sarsa)
policy_sarsa_lambda = q_to_policy(q_sarsa_lambda)
policy_qlearn = q_to_policy(q_learn)

In [47]:
print_policy_gridworld(policy_sarsa)
print_policy_gridworld(policy_sarsa_lambda)
print_policy_gridworld(policy_qlearn)

(0, 0) : <- 	(0, 1) : <- 	(0, 2) : <- 	(0, 3) : <- 	
(1, 0) : /\ 	(1, 1) : /\ 	(1, 2) : <- 	(1, 3) : \/ 	
(2, 0) : /\ 	(2, 1) : <- 	(2, 2) : \/ 	(2, 3) : <- 	
(3, 0) : /\ 	(3, 1) : -> 	(3, 2) : -> 	(3, 3) : <- 	

(0, 0) : <- 	(0, 1) : <- 	(0, 2) : <- 	(0, 3) : <- 	
(1, 0) : /\ 	(1, 1) : <- 	(1, 2) : <- 	(1, 3) : \/ 	
(2, 0) : /\ 	(2, 1) : /\ 	(2, 2) : \/ 	(2, 3) : \/ 	
(3, 0) : /\ 	(3, 1) : -> 	(3, 2) : -> 	(3, 3) : -> 	

(0, 0) : <- 	(0, 1) : <- 	(0, 2) : <- 	(0, 3) : <- 	
(1, 0) : /\ 	(1, 1) : /\ 	(1, 2) : /\ 	(1, 3) : \/ 	
(2, 0) : /\ 	(2, 1) : <- 	(2, 2) : \/ 	(2, 3) : \/ 	
(3, 0) : /\ 	(3, 1) : -> 	(3, 2) : -> 	(3, 3) : <- 	



## Example: Jack's Car Rental
The number of cars available in the afternoon at $t+1$ in location A is given by:

$$
A_{t+1} = A_t + N^A_{t+1} - M^A_{t+1} + D_t
$$

where $N^A_t$ is the number of cars returned to A, $M^A_t$ is the number of cars rented at A, and $D_t$ is the net number of cars moved from B to A (negative when cars are moved from A to B). The equation for location B is analogous.
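
The update above can be sketched in a few lines of Python. This is a minimal illustration only: the function name `next_cars` is hypothetical, location B is assumed to follow the same equation with $-D_t$ in place of $D_t$, and capacity limits or caps on rentals are deliberately left out.

```python
from typing import Tuple

def next_cars(cars: Tuple[int, int],
              returned: Tuple[int, int],
              rented: Tuple[int, int],
              moved_to_a: int) -> Tuple[int, int]:
    """Apply A_{t+1} = A_t + N^A - M^A + D to location A and the
    assumed mirror-image update (with -D) to location B."""
    a, b = cars
    a_next = a + returned[0] - rented[0] + moved_to_a
    b_next = b + returned[1] - rented[1] - moved_to_a
    return a_next, b_next

# Example: A has 10 cars and B has 8; 3 and 2 cars are returned,
# 4 and 5 are rented, and 2 cars are moved from B to A.
print(next_cars((10, 8), (3, 2), (4, 5), 2))  # prints (11, 3)
```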