# **Temporal-Difference Learning**

Temporal Difference (TD) Learning is similar to the Monte-Carlo method. However, despite it learning from episodes of experience and being model-free, TD has a unique feature. TD will not wait for the episode to end to update the maximum possible reward.  The reward will be updated during the episode based on gained experience.

$V(s)\leftarrow V(S_t)+\alpha[R_{t+1}+\gamma(S_{t_1}-V(S_t))]$

There are 2 cmmon temporal-differnce learning method: Q-learning and SARSA.

## **Q-learning**

Q-learning is a **model-free**, **value-based** and **off-policy** reinforcement learning algorithm to find the best action by updating a table of values, called q-table, which is a data structure of sets of actions and states. We use the Q-learning algorithm to update the values in the table and find the best action for each state, the updating function used by Q-learning is called Q-function.

#### **Q-function**

Q-function is the same as action value function in Bellman equation.  

$Q_{\pi}(s,a)=E_{\pi}[R_{t+1}+\gamma v_{\pi}(S_{t+1},A_{t+1})|S+t=s,A_t=a]$


  
In each step, Q-learning update Q-value by,

<img src="img/q_learning.png" width="200">  

$Q(S_t,A_t)\leftarrow Q(S_t,A_t)+\alpha[R_{t+1}+\gamma\max_a Q(S_{t+1},a)-Q(S_t,A_t)]$  

where $\alpha$ is learning rate, 

$R_{t+1}+\gamma\max_a Q(S_{t+1},a)$ is TD target,  

$R_{t+1}+\gamma\max_a Q(S_{t+1},a)-Q(S_t,A_t)$ is TD error.

***Q-learning algorithm***

---
**Input** MDP $M=<S,s_0,A,P_a(s^{\prime},s),r(s,a,s^{\prime}),\gamma>$  
**Output** Q-table

Initialize $Q(s,a)$, for all $s\in S$, $a\in A(s)$, arbitrarily except that $Q(terminal,.)=0$

**repeat**  
    $\quad$ Initialize $s$      
    $\quad$**for each step of episode**        
        $\quad\quad$**Choose** $a\in A(s)$ **by Q-table** (e.g., $\epsilon$-greedy)      
            $\quad\quad\quad Q(s,a)\leftarrow Q(s,a)+\alpha[r(s,a,s^{\prime})+\gamma\max_{a^{\prime}\in A(s^{\prime})} Q(s^{\prime},a^{\prime})-Q(s,a)]$      
            $\quad\quad\quad s \leftarrow s^{\prime}$     
    $\quad\quad$**end unitl terminal**       
**end** 

---


## **SARSA**

SARSA is also a **model-free**, **value-based** but **on-policy** reinforcement learning algorithm based on Q-table, the difference between Q-learning and SARSA is the updating equation:

<img src="img/sarsa.png" width="100">  

$Q(S_t,A_t)\leftarrow Q(S_t,A_t)+\alpha[R_{t+1}+\gamma Q(S_{t+1},A^{\prime})-Q(S_t,A_t)]$      

$A^{\prime}$ is the action chosen from $S_{t+1}$ based on Q-table. 

$R_{t+1}+\gamma\max_a Q(S_{t+1},a)$ is TD target,  

$R_{t+1}+\gamma\max_a Q(S_{t+1},a)-Q(S_t,A_t)$ is TD error.


Not like Q-learning to choose next action by updated Q-value, SARSA is based on Q-table to choose the next action to update Q-table.

***SARSA algorithm***

---
**Input** MDP $M=<S,s_0,A,P_a(s^{\prime},s),r(s,a,s^{\prime}),\gamma>$  
**Output** Q-table


Initialize $Q(s,a)$, for all $s\in S$, $a\in A(s)$, arbitrarily except that $Q(terminal,.)=0$

**repeat**  
    $\quad$ Initialize $s, a$      
    $\quad$**for each step of episode**        
        $\quad\quad$**Choose** $a^{\prime}\in A(s)$ **by Q-table** (e.g., $\epsilon$-greedy)      
            $\quad\quad\quad Q(s,a)\leftarrow Q(s,a)+\alpha[r(s,a,s^{\prime})+\gamma Q(s^{\prime},a^{\prime})-Q(s,a)]$      
            $\quad\quad\quad s \leftarrow s^{\prime}$     
            $\quad\quad\quad a \leftarrow a^{\prime}$     
    $\quad\quad$**end unitl terminal**       
**end** 

---

In this 4X4 grid, we need to go to the end point with less steps, the upper left corner case and the lower right corner case are the end points. 

<img src="img/policy_iteration_example.png" width="400">  

In [None]:
import numpy as np



In [6]:
def transition(state, action):
    if state in [0, 15]:
        return state
    if action == 0: # up
        if state in [1, 2, 3]:
            return state
        else:
            return state - 4
    if action == 1: # down
        if state in [12, 13, 14]:
            return state
        else:
            return state + 4
    if action == 2: # left
        if state in [4, 8, 12]:
            return state
        else:
            return state - 1
    if action == 3: # right
        if state in [3, 7, 11]:
            return state
        else:
            return state + 1

def get_action(state, q, epsilon):
    # epsilon greedy policy
    if np.random.rand() < epsilon:
        action = np.random.choice([0, 1, 2, 3])
    else:
        # greedy policy
        action = np.argmax(q[state])
        # if there are multiple actions with the same max value, choose one randomly
        if list(q[state]).count(q[state][action]) > 1:
            action = np.random.choice(np.where(q[state] == q[state][action])[0])
    return action


def update_q(q, state, action, reward, next_state, alpha, gamma, method='q-learning'):
    # Q-learning update rule
    if method == 'q-learning':
        q[state, action] += alpha * (reward + gamma * np.max(q[next_state]) - q[state, action])
    elif method == 'sarsa':
        action_next = get_action(next_state, q, 0)
        q[state, action] += alpha * (reward + gamma * q[next_state, action_next] - q[state, action])
    return q

def q_learning(num_episodes=10000, alpha=0.1, gamma=0.9, epsilon=0.1):
    # Initialize variables
    q = np.zeros((16, 4))  # Q-table with 16 states and 4 actions

    for episode in range(num_episodes):
        state = np.random.randint(1, 16)  # Random initial state
        done = False

        while not done:
            action = get_action(state, q, epsilon)
            next_state = transition(state, action)

            if next_state in (0,15):  # Goal state
                reward = 1
                done = True
            else:
                reward = -0.01

            q = update_q(q, state, action, reward, next_state, alpha, gamma)
            state = next_state

    return q



def sarsa(num_episodes=10000, alpha=0.1, gamma=0.9, epsilon=0.1):
    # Initialize variables
    q = np.zeros((16, 4))  # Q-table with 16 states and 4 actions

    for episode in range(num_episodes):
        state = np.random.randint(1, 16)  # Random initial state
        done = False

        while not done:
            action = get_action(state, q, epsilon)
            next_state = transition(state, action)

            if next_state in (0,15):  # Goal state
                reward = 1
                done = True
            else:
                reward = -0.01

            q = update_q(q, state, action, reward, next_state, alpha, gamma, )
            state = next_state

    return q

def mapping_policy(pi):
    policy = dict()
    for k in range(1, 15):
        idx = np.where(pi[k] == np.max(pi[k]))[0]
        policy[k] = []
        if 0 in idx:
            policy[k].append('up')
        if 1 in idx:
            policy[k].append('down')
        if 2 in idx:
            policy[k].append('left')
        if 3 in idx:
            policy[k].append('right')
    return policy

In [7]:
print('---Q-learning---')
q = q_learning()
policy = mapping_policy(q)
for k, v in policy.items():
    print(f"State {k}: {v}")

print('---SARSA---')
q = sarsa()
policy = mapping_policy(q)
for k, v in policy.items():
    print(f"State {k}: {v}")

---Q-learning---
State 1: ['down']
State 2: ['down']
State 3: ['down']
State 4: ['right']
State 5: ['right']
State 6: ['right']
State 7: ['down']
State 8: ['down']
State 9: ['right']
State 10: ['down']
State 11: ['down']
State 12: ['right']
State 13: ['right']
State 14: ['right']
---SARSA---
State 1: ['down']
State 2: ['down']
State 3: ['down']
State 4: ['right']
State 5: ['right']
State 6: ['right']
State 7: ['down']
State 8: ['right']
State 9: ['right']
State 10: ['down']
State 11: ['down']
State 12: ['right']
State 13: ['right']
State 14: ['right']
