In [1]:
import numpy as np

In [2]:
# World configuration: a 1D corridor of cells in which the agent steps left or right.
num_cells = 6      # Number of cells in the 1D world
start = 3          # Starting cell index for every episode
n_iters = 10000    # Number of training episodes
n_actions = 2      # Number of possible actions: 0 = move left, 1 = move right

# Reward for arriving in each cell: only the leftmost cell (index 0) pays 1.
reward = np.zeros(num_cells, dtype=int)
reward[0] = 1

# Action-value table: one row per cell, one column per action, all zeros.
Q = np.zeros((num_cells, n_actions))

In [3]:
epsilon = 0.1        # Epsilon for epsilon-greedy policy
discount = 0.95      # Discount factor for future rewards
learning_rate = 1.0  # TD step size (alpha); 1.0 suffices because transitions are deterministic
max_time = 10        # Max steps per episode (episodes also end at either boundary)
rng = np.random.default_rng(0)  # Seeded so Restart & Run All reproduces the same run

# Re-initialise here so re-running this cell trains from scratch instead of
# continuing from whatever Q-table a previous run left behind.
Q = np.zeros((num_cells, n_actions))
Qs = []  # Independent snapshot of the Q-table after every episode


def choose_action(q_values, eps, rng):
    """Pick an action index epsilon-greedily from one row of the Q-table.

    With probability ``eps`` a uniformly random action is returned; otherwise a
    greedy action is returned. Ties between equally valued actions are broken at
    random, because ``np.argmax`` always returns the first maximum and would bias
    an all-zero table toward action 0 (left).
    (UCB alternative: take the argmax of Q + C * sqrt(ln(t) / N(s, a)), where
    N(s, a) counts how often action a was taken in state s.)

    Args:
        q_values: 1D array of action values for the current state.
        eps: Exploration probability in [0, 1].
        rng: ``np.random.Generator`` supplying the randomness.

    Returns:
        The chosen action index as an ``int``.
    """
    if rng.random() < eps:
        return int(rng.integers(len(q_values)))
    greedy_actions = np.flatnonzero(q_values == q_values.max())
    return int(rng.choice(greedy_actions))


# Q-learning algorithm (one-step TD control)
for _ in range(n_iters):
    curr_state = start

    # Perform one episode, capped at max_time steps. The previous
    # `while True or time < max_time` never enforced the cap, because
    # `True or ...` is always True.
    for _step in range(max_time):
        action = choose_action(Q[curr_state], epsilon, rng)

        # Action 0 moves left, action 1 moves right; the reward is for the cell entered.
        next_state = curr_state - 1 if action == 0 else curr_state + 1
        immediate_reward = reward[next_state]
        terminal = next_state == 0 or next_state == num_cells - 1

        # Off-policy Q-learning target: bootstrap from the greedy value of the next
        # state, with no future value once a boundary (terminal) cell is reached.
        # SARSA would instead use Q[next_state, A_{t+1}] for the action actually
        # taken next; Expected SARSA uses sum_a pi(a | next_state) * Q[next_state, a].
        target = immediate_reward
        if not terminal:
            target += discount * Q[next_state].max()
        Q[curr_state, action] += learning_rate * (target - Q[curr_state, action])

        curr_state = next_state
        if terminal:
            break

    # Monte Carlo methods would wait until the episode ends and update from the full
    # return; the loop above applies one-step TD updates immediately after each step.

    # Store a copy: appending `Q` itself would store the same array n_iters times,
    # so every entry would show only the final table.
    Qs.append(Q.copy())

# After all iterations, print the learned Q-values
print("\nLearned Q-values:")
print(Q)



Learned Q-values:
[[0.         0.        ]
 [1.         0.9025    ]
 [0.95       0.857375  ]
 [0.9025     0.81450625]
 [0.857375   0.        ]
 [0.         0.        ]]


In [4]:
# Topics covered: Q-learning on a 1D line world, epsilon-greedy exploration (UCB as an alternative), SARSA, and one-step TD learning