In [1]:
def get_eps_greedy(actions, epsilon, best_action):
    """Build an epsilon-greedy probability distribution over ``actions``.

    Every action receives ``epsilon / len(actions)`` probability mass, and
    ``best_action`` additionally receives the remaining ``1 - epsilon``.
    If ``best_action`` is not among ``actions`` the probabilities do not
    sum to 1.

    Args:
        actions: Sized, iterable collection of (hashable) actions.
        epsilon: Exploration rate.
        best_action: The action currently considered greedy.

    Returns:
        dict mapping each action to its selection probability, in the
        iteration order of ``actions``. An empty ``actions`` yields ``{}``.
    """
    count = len(actions)
    # The division stays inside the comprehension so an empty action set
    # never divides by zero.
    return {
        candidate: (1 - epsilon + (epsilon / count))
        if candidate == best_action
        else epsilon / count
        for candidate in actions
    }


In [2]:
def on_policy_mc_control(env, iterations, epsilon, gamma):
    """Learn an epsilon-greedy policy with on-policy first-visit Monte Carlo control.

    Args:
        env: Object exposing ``state_space`` and ``action_space`` collections;
            it is also handed to ``generate_episode`` (defined elsewhere).
        iterations: Number of episodes to sample and learn from.
        epsilon: Exploration rate used for every epsilon-greedy policy.
        gamma: Discount factor applied to future rewards.

    Returns:
        ``(policy, Q)`` where ``policy[s]`` maps each action to its selection
        probability in state ``s`` and ``Q[s][a]`` is the running mean of the
        first-visit returns observed for ``(s, a)``.
    """
    states = env.state_space
    actions = env.action_space
    Q = {s: {a: 0 for a in actions} for s in states}
    returns_count = {s: {a: 0 for a in actions} for s in states}
    # Before any learning, the largest action is arbitrarily treated as greedy.
    policy = {s: get_eps_greedy(actions, epsilon, max(actions)) for s in states}

    for _ in range(iterations):
        # Assumes generate_episode returns an indexable sequence of
        # (state, action, reward) steps sampled from `policy` -- TODO confirm.
        episode = generate_episode(env, policy)

        # Index of the first occurrence of each (state, action) pair; only that
        # occurrence contributes a return (first-visit MC). Built once per
        # episode instead of rescanning the episode at every step.
        first_visit = {}
        for t, (state, action, _reward) in enumerate(episode):
            first_visit.setdefault((state, action), t)

        G = 0
        # Walk backwards so G accumulates the discounted return from step t on.
        for t in range(len(episode) - 1, -1, -1):
            state, action, reward = episode[t]
            G = reward + gamma * G
            if first_visit[(state, action)] != t:
                continue  # a later (in time) visit of an already-seen pair
            returns_count[state][action] += 1
            # Incremental mean of the returns collected for (state, action).
            Q[state][action] += (G - Q[state][action]) / returns_count[state][action]
            best_action = max(Q[state], key=Q[state].get)
            policy[state] = get_eps_greedy(actions, epsilon, best_action)
    return policy, Q


In [3]:
def off_policy_mc_control(env, iterations, epsilon, gamma):
    """Learn a greedy policy with off-policy Monte Carlo control.

    Episodes are sampled from a fixed epsilon-greedy behaviour policy. The
    greedy target policy is learned with weighted importance sampling.

    Args:
        env: Object exposing ``state_space`` and ``action_space`` collections;
            it is also handed to ``generate_episode`` (defined elsewhere).
        iterations: Number of episodes to sample and learn from.
        epsilon: Exploration rate of the behaviour policy.
        gamma: Discount factor applied to future rewards.

    Returns:
        ``(target_policy, Q)`` where ``target_policy[s]`` is the greedy action
        for every state ``s`` and ``Q[s][a]`` is the weighted-importance-sampling
        estimate of the action value.
    """
    states = env.state_space
    actions = env.action_space
    Q = {s: {a: 0 for a in actions} for s in states}
    # Cumulative importance-sampling weights per (state, action).
    C = {s: {a: 0 for a in actions} for s in states}
    # Start greedy w.r.t. the initial Q for every state, so states that are
    # never updated still have an action instead of being absent.
    target_policy = {s: max(Q[s], key=Q[s].get) for s in states}
    # Fixed soft behaviour policy: epsilon-greedy around the largest action.
    behavior_policy = {s: get_eps_greedy(actions, epsilon, max(actions)) for s in states}

    for _ in range(iterations):
        # Assumes generate_episode returns a reversible sequence of
        # (state, action, reward) steps sampled from behavior_policy -- TODO confirm.
        episode = generate_episode(env, behavior_policy)
        G = 0
        W = 1  # importance-sampling ratio of the suffix processed so far
        for state, action, reward in reversed(episode):
            G = reward + gamma * G
            C[state][action] += W
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])
            best_action = max(Q[state], key=Q[state].get)
            target_policy[state] = best_action
            # The greedy target assigns probability 0 to non-greedy actions, so
            # every earlier step would get weight 0: stop processing.
            if action != best_action:
                break
            # Target probability is 1 for the greedy action, so the ratio is 1 / b(a|s).
            W /= behavior_policy[state][action]

    return target_policy, Q
