In [36]:
# Labels for the two states that occur in the recorded episodes.
STATE_A = 'A'
STATE_B = 'B'

# Each episode is a list of (state, reward) steps in the order they occurred:
# one A -> B episode with zero reward, six single-step B episodes that pay 1,
# and one single-step B episode that pays 0.
episodes = (
    [[(STATE_A, 0), (STATE_B, 0)]]
    # Build the repeated episodes with a comprehension so that every episode
    # is its own list object rather than six references to one list.
    + [[(STATE_B, 1)] for _ in range(6)]
    + [[(STATE_B, 0)]]
)

In [37]:
def update_state_values_td(state_values, episode, alpha, gamma=1):
    """Apply one TD(0) pass over a single episode, updating values in place.

    For every step the value of the visited state is moved a fraction
    ``alpha`` towards ``reward + gamma * V(next_state)``. The step after the
    last one is treated as terminal and contributes a value of 0.

    Args:
        state_values: Mutable mapping from state to its current value
            estimate. It is modified in place; every state in ``episode``
            must already be a key.
        episode: Sequence of ``(state, reward)`` steps in the order they
            occurred.
        alpha: Step size applied to each TD error.
        gamma: Discount factor applied to the next state's value. Defaults
            to 1 (undiscounted), which reproduces the original behaviour.

    Returns:
        The same ``state_values`` mapping that was passed in.
    """
    last_index = len(episode) - 1
    for i, (state, reward) in enumerate(episode):
        if i == last_index:
            # Nothing follows the final step, so bootstrap from 0.
            next_state_value = 0
        else:
            next_state, _ = episode[i + 1]
            # Read the *current* estimate: earlier steps of this same pass
            # may already have changed it.
            next_state_value = state_values[next_state]

        td_error = reward + gamma * next_state_value - state_values[state]
        state_values[state] += alpha * td_error
    return state_values

def update_state_values_mc(state_values, episode, alpha, gamma=1):
    """Apply one every-visit Monte Carlo pass over a single episode, in place.

    The episode is walked backwards so the return following each step can be
    accumulated incrementally; each visited state's value is then moved a
    fraction ``alpha`` towards that return.

    Args:
        state_values: Mutable mapping from state to its current value
            estimate. It is modified in place; every state in ``episode``
            must already be a key.
        episode: Sequence of ``(state, reward)`` steps in the order they
            occurred.
        alpha: Step size applied to each (return - estimate) error.
        gamma: Discount factor applied to the return accumulated from later
            steps. Defaults to 1 (undiscounted), which reproduces the
            original behaviour.

    Returns:
        The same ``state_values`` mapping that was passed in.
    """
    total_return = 0
    for state, reward in reversed(episode):
        # Return from this step = its reward + discounted return of the rest.
        total_return = reward + gamma * total_return
        state_values[state] += alpha * (total_return - state_values[state])

    return state_values

def get_total_divergence(previous_state_values, state_values):
    """Return the sum of squared per-state changes between two estimates.

    Only the keys of ``state_values`` are visited; each must also be present
    in ``previous_state_values``. Extra keys in ``previous_state_values`` are
    ignored.
    """
    squared_change = 0
    for state in state_values:
        delta = state_values[state] - previous_state_values[state]
        squared_change += delta ** 2
    return squared_change

def run_batch_learning(alpha, update_function, episodes, epsilon=0.00000000001):
    """Sweep the episodes repeatedly until the value estimates stop changing.

    Every state that appears in ``episodes`` starts at value 0. One sweep
    feeds each episode, in order, to ``update_function``; sweeping stops once
    ``get_total_divergence`` between the estimates before and after a sweep
    is no larger than ``epsilon``.

    Note that updates take effect as soon as ``update_function`` applies
    them, rather than being accumulated and applied once per sweep.

    Args:
        alpha: Step size forwarded unchanged to ``update_function``.
        update_function: Callable invoked as
            ``update_function(state_values, episode, alpha)``; its return
            value is used as the new state-value mapping.
        episodes: Iterable of episodes, each a sequence of
            ``(state, reward)`` steps.
        epsilon: Convergence threshold on the total squared change of one
            sweep. Defaults to the previously hard-coded 1e-11.

    Returns:
        Mapping from each state seen in ``episodes`` to its converged value
        estimate.
    """
    # Materialise once: the data is traversed on every sweep, which a
    # one-shot iterator could not support.
    episodes = list(episodes)

    # Derive the state set from the data (in order of first appearance)
    # instead of assuming the states are exactly 'A' and 'B'.
    state_values = {}
    for episode in episodes:
        for state, _ in episode:
            state_values.setdefault(state, 0)

    # Infinite "previous" values guarantee that the first sweep always runs.
    previous_state_values = dict.fromkeys(state_values, float('inf'))

    while get_total_divergence(previous_state_values, state_values) > epsilon:
        # Snapshot before the sweep; the update functions mutate in place.
        previous_state_values = state_values.copy()
        for episode in episodes:
            state_values = update_function(state_values, episode, alpha)

    return state_values

In [38]:
# Fit the same episodes with both update rules, using an identical step size,
# so the two resulting sets of state values can be compared side by side.
result_td = run_batch_learning(
    alpha=0.01,
    update_function=update_state_values_td,
    episodes=episodes,
)
result_mc = run_batch_learning(
    alpha=0.01,
    update_function=update_state_values_mc,
    episodes=episodes,
)

# TD estimates first, then MC, one mapping per line.
print(result_td, result_mc, sep='\n')

{'A': 0.7496016162048674, 'B': 0.749911629550885}
{'A': 0.0, 'B': 0.7498765444177047}
