In [1]:
from operator import itemgetter
from rl.distribution import Categorical
from typing import Iterable, Iterator, TypeVar, Callable, Mapping, Tuple, Set
from rl.markov_decision_process import MarkovDecisionProcess, Policy, \
    TransitionStep, NonTerminal, FiniteMarkovDecisionProcess
from rl.approximate_dynamic_programming import NTStateDistribution
from rl.function_approx import learning_rate_schedule
from rl.distribution import Choose
import rl.iterate as iterate
import itertools
from pprint import pprint
from rl.function_approx import Tabular

In [2]:
# Generic type variables used in the annotations of the algorithms below:
# S is the MDP's state type and A is its action type.
S = TypeVar('S')
A = TypeVar('A')

### Tabular SARSA algorithm

In [3]:
def epsilon_greedy_action_tabular(
    q: Mapping[Tuple[S, A], float],
    nt_state: NonTerminal[S],
    actions: Set[A],
    epsilon: float
) -> A:
    """Sample an action epsilon-greedily with respect to a tabular Q-function.

    Every available action gets probability ``epsilon / len(actions)``; the
    action with the highest ``q[(nt_state, a)]`` (first one met on ties)
    additionally gets ``1 - epsilon``.

    :param q: tabular Q-values keyed by ``(state, action)``
    :param nt_state: the non-terminal state in which to act
    :param actions: the actions available in ``nt_state``
    :param epsilon: exploration probability
    :return: an action sampled from the resulting Categorical distribution
    """
    best_action: A = max(actions, key=lambda act: q[(nt_state, act)])
    explore_prob = epsilon / len(actions)
    action_probs = {}
    for act in actions:
        greedy_bonus = 1 - epsilon if act == best_action else 0.
        action_probs[act] = explore_prob + greedy_bonus
    return Categorical(action_probs).sample()

def glie_sarsa_tabular(
    mdp: MarkovDecisionProcess[S, A],
    states: NTStateDistribution[S],   # initial_state_distribution
    gamma: float,
    epsilon_as_func_of_episodes: Callable[[int], float],
    max_episode_length: int,
    count_to_weight_func: Callable[[int], float]
) -> Iterator[Mapping[Tuple[S, A], float]]:
    """GLIE SARSA control with a tabular (dict-based) Q-value function.

    Each episode starts from a state sampled from ``states``. It runs until a
    terminal state is reached or ``max_episode_length`` steps have been taken.
    Actions are chosen epsilon-greedily w.r.t. the current table. Within an
    episode, ``epsilon = epsilon_as_func_of_episodes(k)`` is held fixed, where
    ``k`` is the 1-based episode number.

    :param mdp: the MDP to control
    :param states: distribution of episode start states
    :param gamma: discount factor
    :param epsilon_as_func_of_episodes: exploration schedule, per episode
    :param max_episode_length: cap on the number of steps per episode
    :param count_to_weight_func: maps the visit count (1, 2, ...) of a
        (state, action) pair to the learning rate used for that update
    :return: an infinite iterator. It yields the Q-table once before any
        learning and again after every single-step update. The SAME dict
        object is yielded (and mutated in place) each time, so copy it if
        you need snapshots.
    """
    q = {(s, a): 0.0 for s in mdp.non_terminal_states for a in mdp.actions(s)}
    # Per-(state, action) visit counts; ints because count_to_weight_func takes an int.
    count = {sa: 0 for sa in q}
    yield q
    num_episode: int = 0

    while True:
        num_episode += 1
        epsilon = epsilon_as_func_of_episodes(num_episode)
        state = states.sample()
        action = epsilon_greedy_action_tabular(q, nt_state=state,
                                               actions=set(mdp.actions(state)),
                                               epsilon=epsilon)
        steps = 0
        while isinstance(state, NonTerminal) and steps < max_episode_length:
            next_state, reward = mdp.step(state, action).sample()
            # Capture the pair being updated before `action` is advanced below.
            sa = (state, action)
            count[sa] += 1
            learning_rate = count_to_weight_func(count[sa])
            # The successor decides whether we bootstrap: terminal states have no
            # actions and no entries in q, so their continuation value is zero.
            if isinstance(next_state, NonTerminal):
                next_action = epsilon_greedy_action_tabular(q, nt_state=next_state,
                                                            actions=set(mdp.actions(next_state)),
                                                            epsilon=epsilon)
                target = reward + gamma * q[(next_state, next_action)]
                action = next_action
            else:
                target = reward
            q[sa] = q[sa] + learning_rate * (target - q[sa])
            yield q
            steps += 1
            state = next_state

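#### Extending tabular SARSA to function approximation

The library function `rl.td.glie_sarsa` (imported below) plays the same role as `glie_sarsa_tabular`. The difference is that it takes an initial function approximation (`approx_0`) instead of building a Q-value dictionary internally.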

In [4]:
from rl.td import glie_sarsa

In [6]:
## test on simple inventory mdp
from rl.chapter3.simple_inventory_mdp_cap import SimpleInventoryMDPCap
from rl.chapter3.simple_inventory_mdp_cap import InventoryState

# --- MDP configuration ---
capacity: int = 2
poisson_lambda: float = 1.0
holding_cost: float = 1.0
stockout_cost: float = 10.0
# Discount factor, defined once and shared by every algorithm run below.
gamma: float = 0.9
si_mdp: SimpleInventoryMDPCap = SimpleInventoryMDPCap(
    capacity=capacity,
    poisson_lambda=poisson_lambda,
    holding_cost=holding_cost,
    stockout_cost=stockout_cost
)

# --- Learning configuration ---
num_episodes = 10000
max_episode_length: int = 100
# GLIE exploration schedule: epsilon_k = 1 / sqrt(k), which decays to 0 as k grows.
epsilon_as_func_of_episodes: Callable[[int], float] = lambda k: k ** -0.5
initial_learning_rate: float = 0.1
half_life: float = 10000.0
exponent: float = 1.0
learning_rate_func: Callable[[int], float] = learning_rate_schedule(
    initial_learning_rate=initial_learning_rate,
    half_life=half_life,
    exponent=exponent
)

# tabular sarsa
qvfs = glie_sarsa_tabular(
    mdp = si_mdp,
    states = Choose(si_mdp.non_terminal_states),
    gamma = gamma,
    epsilon_as_func_of_episodes = epsilon_as_func_of_episodes,
    max_episode_length = max_episode_length,
    count_to_weight_func=learning_rate_func
)
# The generator yields one table per update step (plus the initial table),
# so this caps the run at roughly num_episodes full-length episodes.
num_updates = num_episodes * max_episode_length
final_qvf = iterate.last(itertools.islice(qvfs, num_updates))
# Greedy value and greedy action per state, read off the learned Q-table.
optimal_q_value = {s: max(final_qvf[(s,a)] for a in si_mdp.actions(s)) 
                   for s in si_mdp.non_terminal_states}
optimal_policy = {s: max(((a, final_qvf[(s, a)]) for a in si_mdp.actions(s)), key=itemgetter(1))[0] 
                  for s in si_mdp.non_terminal_states}
print("Tabular_SARSA_control")
pprint(optimal_q_value)
pprint(optimal_policy)


# function approximation sarsa: same MDP and schedules as above, but the
# Q-function is a `Tabular` approximation handed to the library's glie_sarsa.
from rl.chapter11.control_utils import get_vf_and_policy_from_qvf
initial_qvf_dict = dict.fromkeys(
    ((s, a) for s in si_mdp.non_terminal_states for a in si_mdp.actions(s)),
    0.
)
approx_0 = Tabular(values_map=initial_qvf_dict,
                   count_to_weight_func=learning_rate_func)
qvfs = glie_sarsa(
    mdp=si_mdp,
    states=Choose(si_mdp.non_terminal_states),
    approx_0=approx_0,
    gamma=gamma,
    epsilon_as_func_of_episodes=epsilon_as_func_of_episodes,
    max_episode_length=max_episode_length
)
num_updates = num_episodes * max_episode_length
final_qvf = iterate.last(itertools.islice(qvfs, num_updates))
opt_vf, opt_policy = get_vf_and_policy_from_qvf(mdp=si_mdp, qvf=final_qvf)
print("Approximation_Function_SARSA_control")
pprint(opt_vf)
pprint(opt_policy)

Tabular_SARSA_control
{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.750359886683103,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -27.68758766212732,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.758661238766457,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.156320557298695,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.451493482298687,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -34.94751552147776}
{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): 1,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): 0,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): 1,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): 0,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): 0,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): 1}
Approximation_Function_SARSA_control
{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.64793797840082,
 NonTerminal

### Tabular Q-Learning algorithm

In [13]:
def q_learning_tabular(
    mdp: MarkovDecisionProcess[S, A],
    states: NTStateDistribution[S],   # initial_state_distribution
    gamma: float,
    max_episode_length: int,
    count_to_weight_func: Callable[[int], float],
    epsilon: float = 1.0
) -> Iterator[Mapping[Tuple[S, A], float]]:
    """Tabular (dict-based) Q-learning control.

    Each episode starts from a state sampled from ``states``. It runs until a
    terminal state is reached or ``max_episode_length`` steps have been taken.
    The behaviour policy is epsilon-greedy w.r.t. the current table. The
    default ``epsilon=1.0`` makes it uniformly random, which matches the
    original hard-coded behaviour. The update target bootstraps from the
    max over next-state actions.

    :param mdp: the MDP to control
    :param states: distribution of episode start states
    :param gamma: discount factor
    :param max_episode_length: cap on the number of steps per episode
    :param count_to_weight_func: maps the visit count (1, 2, ...) of a
        (state, action) pair to the learning rate used for that update
    :param epsilon: exploration probability of the behaviour policy
    :return: an infinite iterator. It yields the Q-table once before any
        learning and again after every single-step update. The SAME dict
        object is yielded (and mutated in place) each time.
    """
    q = {(s, a): 0.0 for s in mdp.non_terminal_states for a in mdp.actions(s)}
    # Per-(state, action) visit counts; ints because count_to_weight_func takes an int.
    count = {sa: 0 for sa in q}
    yield q

    while True:
        state: NonTerminal[S] = states.sample()
        steps = 0
        while isinstance(state, NonTerminal) and steps < max_episode_length:
            action = epsilon_greedy_action_tabular(q, nt_state=state,
                                                   actions=set(mdp.actions(state)),
                                                   epsilon=epsilon)
            next_state, reward = mdp.step(state, action).sample()
            sa = (state, action)
            count[sa] += 1
            learning_rate = count_to_weight_func(count[sa])
            # The successor decides whether we bootstrap: terminal states have no
            # actions and no entries in q, so their continuation value is zero.
            if isinstance(next_state, NonTerminal):
                next_return = max(q[(next_state, a)] for a in mdp.actions(next_state))
                target = reward + gamma * next_return
            else:
                target = reward
            q[sa] = q[sa] + learning_rate * (target - q[sa])
            yield q
            steps += 1
            state = next_state

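#### Extending tabular Q-Learning to function approximation

The next cell first checks the tabular version on the simple inventory MDP. The function-approximation version is `rl.td.q_learning`, imported afterwards.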

In [15]:
## test on simple inventory mdp
# Reuses si_mdp, gamma, max_episode_length, learning_rate_func and
# num_episodes from the SARSA test cell.

# tabular q-learning
qvfs = q_learning_tabular(
    mdp = si_mdp,
    states = Choose(si_mdp.non_terminal_states),
    gamma = gamma,
    max_episode_length = max_episode_length,
    count_to_weight_func = learning_rate_func
)
num_updates = num_episodes * max_episode_length
final_qvf = iterate.last(itertools.islice(qvfs, num_updates))
# Greedy value and greedy action per state, read off the learned Q-table.
optimal_q_value = {s: max(final_qvf[(s,a)] for a in si_mdp.actions(s)) 
                   for s in si_mdp.non_terminal_states}
optimal_policy = {s: max(((a, final_qvf[(s, a)]) for a in si_mdp.actions(s)), key=itemgetter(1))[0] 
                  for s in si_mdp.non_terminal_states}
# Label fixed: these are Q-learning results, not SARSA.
print("Tabular_Q_Learning_control")
pprint(optimal_q_value)
pprint(optimal_policy)

Tabular_SARSA_control
{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.793177374064413,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.20224921577488,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.858441189868383,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.244109908333325,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.110641751326625,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -34.95680896900531}
{NonTerminal(state=InventoryState(on_hand=0, on_order=1)): 1,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): 0,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): 1,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): 0,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): 0,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): 1}


In [None]:
from rl.td import q_learning