# problem 2

In [None]:
from typing import Tuple, Sequence, Set, Mapping, Dict, Callable, Optional
from dataclasses import dataclass
from operator import itemgetter
from rl.distribution import Categorical, Choose, Constant
from rl.markov_decision_process import FiniteMarkovDecisionProcess
from rl.markov_decision_process import StateActionMapping
from rl.markov_decision_process import FinitePolicy
from rl.dynamic_programming import value_iteration_result, V


def eps_greedy_action(
    nt_state: Cell,
    q: Mapping[Cell, Mapping[Move, float]],
    epsilon: float) -> Move:
    """Sample an action for ``nt_state`` from the epsilon-greedy policy on ``q``.

    Every available action receives weight ``epsilon / n`` (n = number of
    actions in ``q[nt_state]``); the greedy action -- the one with the highest
    Q-value, the first one in iteration order on ties -- additionally receives
    ``1 - epsilon``.

    :param nt_state: non-terminal state whose actions are considered
        (must be a key of ``q``)
    :param q: action-value table, state -> (action -> value)
    :param epsilon: exploration rate; presumably in [0, 1] so that the
        weights form a probability distribution
    :return: one action sampled via ``Categorical(...).sample()``
    """
    action_values: Mapping[Move, float] = q[nt_state]
    # Compute the argmax once rather than once per action (previously O(n^2)).
    greedy_action: Move = max(action_values.items(), key=itemgetter(1))[0]
    explore_weight: float = epsilon / len(action_values)
    return Categorical({
        a: explore_weight + (1 - epsilon if a == greedy_action else 0.)
        for a in action_values
    }).sample()

def sarsa(
    states_actions_dict: Mapping[S, Optional[Set[A]]],
    sample_func: Callable[[S, A], Tuple[S, float]],
    episodes: int = 10000,
    step_size: float = 0.01,
    gamma: float = 1.0) -> Tuple[V[S], FinitePolicy[S, A]]:
    """Tabular SARSA control with a decaying epsilon-greedy behaviour policy.

    :param states_actions_dict: state -> set of available actions; a value
        of ``None`` marks a terminal state (such states get no Q entries)
    :param sample_func: environment model, ``(state, action) -> (next_state,
        reward)``
    :param episodes: number of episodes to run; each starts from a
        non-terminal state drawn uniformly at random
    :param step_size: learning rate (alpha) of the SARSA update
    :param gamma: discount factor; the default 1.0 reproduces the original
        undiscounted update
    :return: ``(vf, policy)`` where ``vf[s] = max_a Q(s, a)`` and ``policy``
        deterministically picks the greedy action in each non-terminal state
    """
    # Q-table over non-terminal states only (terminal states map to None).
    q: Dict[S, Dict[A, float]] = {
        s: {a: 0. for a in actions}
        for s, actions in states_actions_dict.items()
        if actions is not None
    }
    uniform_states: Choose[S] = Choose(set(q))
    for episode_num in range(episodes):
        # epsilon = 1/k: exploration decays to zero as episodes accumulate.
        epsilon: float = 1.0 / (episode_num + 1)
        state: S = uniform_states.sample()
        action: A = eps_greedy_action(state, q, epsilon)
        while True:  # loop until the episode reaches a terminal state
            next_state, reward = sample_func(state, action)
            if next_state not in q:
                # Terminal next state: q only holds non-terminal states, so
                # the bootstrap term Q(terminal, .) is taken as 0.
                # NOTE(review): a next_state absent from states_actions_dict
                # is also treated as terminal here.
                q[state][action] += step_size * (reward - q[state][action])
                break
            next_action: A = eps_greedy_action(next_state, q, epsilon)
            q[state][action] += step_size * (
                reward + gamma * q[next_state][next_action] - q[state][action]
            )
            state, action = next_state, next_action

    vf_dict: V[S] = {s: max(d.values()) for s, d in q.items()}
    policy: FinitePolicy[S, A] = FinitePolicy(
        {s: Constant(max(d.items(), key=itemgetter(1))[0])
         for s, d in q.items()}
    )
    return (vf_dict, policy)

 