In [1]:
import numpy as np
from collections import defaultdict


In [103]:
nA = 2
nS = 2
Q_table = {}

In [104]:
for state in range(nS):
        # Q_table[state] = np.random.random(nA)
        Q_table[state] = np.zeros(nA)

In [105]:
def q_learning(s, a, r, s_n, alpha=0.5, gamma=0.8):
    curr_q_val = Q_table[s][a]
    max_q_val = np.max(Q_table[s_n])
    update_q_val = (1-alpha)*curr_q_val + alpha*(r + gamma*max_q_val)
    Q_table[s][a] = update_q_val

In [106]:
def greedy_action(state, epsilon=0.5, no_preference=True):
    if np.random.rand() < epsilon:
        action = np.random.choice(nA) 
    else:
        if no_preference:
            action = np.argmax(Q_table[state])
            return action
        else:
            if Q_table[state][0] > Q_table[state][1]:
                return 0
            else:
                return 1
            
    return action

In [107]:
def step(current_state, action):
    # action is 0 means move, 1 means stay
    return (current_state+action+1)%2, action

In [108]:
step(0,0)

(1, 0)

In [109]:
for state in range(nS):
    for action in range(nA):
        print(f's = {state} a = {action} s_next, reward = {step(state,action)}')

s = 0 a = 0 s_next, reward = (1, 0)
s = 0 a = 1 s_next, reward = (0, 1)
s = 1 a = 0 s_next, reward = (0, 0)
s = 1 a = 1 s_next, reward = (1, 1)


In [125]:
#Training loop

def train(num_episodes=200, epsilon=0.5, gamma=0.8, alpha=0.5, nS=2, nA=2, no_pref=True):
    for state in range(nS):
        # Q_table[state] = np.random.random(nA)
        Q_table[state] = np.zeros(nA)
    for _ in range(num_episodes):
        # Randomly choose an initial state
        current_state = np.random.choice(nS)
        while True:
            action = greedy_action(current_state, epsilon, no_pref)

            next_state, reward = step(current_state, action)
            
            q_learning(current_state, action, reward, next_state, alpha, gamma)

            current_state = next_state

            # Check if the episode is finished
            if next_state == current_state:
                break
    print(Q_table)
    return Q_table


In [121]:
Q_table_part_2 = train(200, epsilon=0.5)

{0: array([3.99384751, 4.99632871]), 1: array([3.99616373, 4.99587546])}


In [122]:
Q_table_part_1 = train(200, epsilon=0.0, no_pref=False)

{0: array([0.50207421, 4.99988099]), 1: array([0.31254453, 4.99983755])}


In [123]:
print(f'State(s) & Action(a) & Q_value(s,a) \\\\')
for state in Q_table_part_2.keys():
    state_val = 'A' if state == 0 else 'B'
    for id in range(len(Q_table_part_2[state])):
        if id == 0:
            action_val = 'Move'
        else:
            action_val = 'Stay'
        print(f'{state_val} & {action_val} & {Q_table_part_2[state][id]} \\\\')

State(s) & Action(a) & Q_value(s,a) \\
A & Move & 0.502074212609357 \\
A & Stay & 4.999880987753604 \\
B & Move & 0.3125445271283821 \\
B & Stay & 4.999837549288073 \\


In [126]:
Q_table_part_2 = train(200, epsilon=0.0, no_pref=True)

{0: array([0., 0.]), 1: array([0., 0.])}
