<a href="https://colab.research.google.com/github/mcnica89/MATH4060/blob/main/Week_6_Temporal_Difference_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import itertools
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random as jrandom
from jax import nn as jnn
from jax import jit
import random
import time
import timeit
import math
import numpy as np

# Global matplotlib styling: size-15 tick labels on both axes and a
# size-20 default font for all other figure text.
plt.rc(('xtick', 'ytick'), labelsize=15)

font = dict(size=20)
plt.rc('font', **font)

# Gambler's Ruin (aka Drunkard's Walk)

Consider the state space $[[0,N_{target}]] \subset \mathbb{Z}$, which represents the possible values of a gambler's wealth. Let $X_n$ be a Markov chain representing the gambler's wealth after $n$ bets. If they have no money, $0$, then they cannot make any bets:

$$P( X_{n+1} = 0 | X_n = 0) =1$$

If they reach their target $N_{target}$, then they choose to stop gambling.


$$P( X_{n+1} = N_{target} | X_n = N_{target}) =1$$

Otherwise, they bet and their fortune goes up or down by exactly $1$:

$$P( X_{n+1} = x+1 | X_n = x) =p$$
$$P( X_{n+1} = x-1 | X_n = x) =1-p$$

The gambler keeps playing until they either go broke or reach their target. Assume the gambler gains a reward of 1 life point if they get to $N_{target}$ and gains a reward of 0 if they go broke. Determine

$$E[\text{Reward}|X_0 = x] = P(\text{Reach target}|X_0 = x)$$


# Monte Carlo Method

In [None]:
def monte_carlo_method(learning_rate, N_episodes, N_target):
  """Estimate P(reach N_target | X_0 = x) with first-visit Monte Carlo.

  Every episode starts from a uniformly random interior state in
  [1, N_target-1] and moves +1 or -1 with equal probability until it
  reaches 0 or N_target, or until 10*N_target**2 steps have been taken.
  The reward is 1 if the episode ends at N_target and 0 otherwise.

  Args:
    learning_rate: Constant step size used in the update
      V[s] += learning_rate*(G - V[s]).
    N_episodes: Number of episodes to simulate.
    N_target: Wealth at which the gambler stops; states are 0..N_target.

  Returns:
    Float numpy array of length N_target+1 with the estimated value of
    every state (including the two absorbing states).

  Raises:
    ValueError: Raised by random.randint when N_target < 2 (there is no
      interior state to start from) and at least one episode is run.
  """
  value_function = np.zeros(N_target+1)

  # Safety cap on the episode length. An episode that hits the cap without
  # reaching N_target gets return 0, exactly like an episode that went broke.
  max_time = 10*N_target**2

  for _ in range(N_episodes):
    #PART I
    #Simulate one episode. A first-visit update only needs to know WHICH
    #states were seen, so a set replaces the per-episode history arrays.
    current_state = random.randint(1,N_target-1)
    current_time = 0
    visited_states = {current_state}

    while current_state != 0 and current_state != N_target and current_time < max_time:
      #THIS IS THE MOST IMPORTANT LINE OF PART I: one fair +/-1 step
      current_state += 2*random.randint(0,1)-1
      current_time += 1
      visited_states.add(current_state)

    #The only non-zero reward is collected on the last step and there is no
    #discounting, so the return G is the same from every time step of the
    #episode: 1 if we ended on the target, 0 otherwise.
    G = 1.0 if current_state == N_target else 0.0

    #PART II
    #Now update the value function for states we saw this episode
    #We are using the first visit monte carlo rule: one update per state.
    for state in visited_states:
      #With a constant learning_rate this is an exponentially weighted
      #(recency-weighted) average of the observed G's. It would equal the
      #plain AVERAGE of G over all visits only if the step size were
      #1/(number of first visits to this state so far).
      value_function[state] += learning_rate*(G - value_function[state])

  return value_function

In [None]:
# Small sanity run: with N_target=10 the estimates should be close to x/10.
monte_carlo_method(learning_rate=0.05, N_episodes=100000, N_target=10)

array([0.  , 0.1 , 0.17, 0.19, 0.35, 0.42, 0.58, 0.72, 0.82, 0.91, 1.  ])

# TD0 Method (Temporal Difference Learning with 0 memory)

In [None]:
def TD0_method(learning_rate, N_samples, N_target):
  """Estimate P(reach N_target | X_0 = x) with tabular TD(0).

  Unlike the Monte Carlo version, no episode history is stored: after
  every single step the value of the state just left is nudged toward the
  current value of the state just entered.

  Args:
    learning_rate: Step size of each TD(0) update.
    N_samples: Number of episodes to simulate.
    N_target: Wealth at which the gambler stops; states are 0..N_target.

  Returns:
    Float numpy array of length N_target+1 with the estimated values.
  """
  # Start from random guesses in [0,1). Per the original author's notes, an
  # all-zero start performs worse, while a constant 0.5 start also works:
  # the initial values should be "on average" correct.
  estimates = np.random.rand(N_target+1)

  # The absorbing states carry the known answers; the loop below never
  # updates them because it only runs on interior states.
  estimates[N_target] = 1
  estimates[0] = 0

  for _ in range(N_samples):
    position = random.randint(1,N_target-1)

    while 0 < position < N_target:
      step = 2*random.randint(0,1) - 1
      successor = position + step
      # No intermediate reward and no discount, so the TD error is just the
      # difference between the two current estimates.
      td_error = estimates[successor] - estimates[position]
      estimates[position] += learning_rate*td_error
      position = successor

  return estimates

In [None]:
# Small sanity run: with N_target=10 the estimates should be close to x/10.
TD0_method(learning_rate=0.1, N_samples=100000, N_target=10)

array([0.  , 0.07, 0.17, 0.26, 0.37, 0.46, 0.58, 0.74, 0.82, 0.93, 1.  ])

# Time comparisons

In [None]:
def numpy_rms(x):
  """Return the root-mean-square of the sequence x: sqrt(mean(x**2))."""
  squared = np.square(x)
  return np.sqrt(squared.mean())

In [None]:
# Configuration for the Monte Carlo vs TD(0) comparison in the cells below.

# Both methods draw from the stdlib `random` module, and TD0_method also uses
# numpy's global RNG for its initial guess. Seed both so that the timings and
# RMS errors reported below are reproducible on Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

N_target = 100
N_samples = 5*10**3

# Reference solution used to score the estimates: the value x/N_target for
# every state x in 0..N_target.
exact_answer = np.arange(0,N_target+1)/N_target

# Print arrays with 2 decimals and without scientific notation.
np.set_printoptions(precision=2,suppress=True)

In [None]:
# Time one Monte Carlo run (step size 0.1) with the shared N_samples and
# N_target, then print the estimate and its RMS error against exact_answer.
%time val_MC = monte_carlo_method(0.1, N_samples, N_target)
print(val_MC)
print( numpy_rms(val_MC-exact_answer) )

CPU times: user 30.5 s, sys: 406 ms, total: 30.9 s
Wall time: 30.5 s
[0.   0.03 0.03 0.08 0.08 0.11 0.11 0.12 0.13 0.13 0.13 0.14 0.14 0.14
 0.14 0.17 0.17 0.17 0.17 0.17 0.21 0.21 0.21 0.23 0.23 0.23 0.23 0.25
 0.27 0.27 0.3  0.31 0.31 0.31 0.32 0.32 0.32 0.32 0.32 0.36 0.36 0.36
 0.36 0.36 0.36 0.39 0.39 0.43 0.43 0.43 0.43 0.43 0.44 0.45 0.45 0.45
 0.45 0.45 0.45 0.45 0.45 0.46 0.46 0.46 0.48 0.54 0.54 0.55 0.55 0.61
 0.68 0.68 0.7  0.72 0.72 0.73 0.77 0.77 0.77 0.77 0.78 0.78 0.78 0.78
 0.78 0.78 0.78 0.78 0.78 0.78 0.87 0.87 0.87 0.88 0.98 0.98 0.98 0.98
 0.98 0.98 1.  ]
0.06435515430893131


In [None]:
# Time one TD(0) run (step size 0.1) with the same N_samples and N_target,
# then print the estimate and its RMS error against exact_answer.
%time val_TD0 = TD0_method(0.1, N_samples, N_target)
print(val_TD0)
print( numpy_rms(val_TD0-exact_answer) )

CPU times: user 17.6 s, sys: 23.8 ms, total: 17.6 s
Wall time: 17.6 s
[0.   0.   0.01 0.02 0.03 0.03 0.04 0.04 0.05 0.06 0.07 0.07 0.08 0.09
 0.1  0.11 0.12 0.13 0.14 0.14 0.15 0.16 0.17 0.18 0.2  0.21 0.22 0.22
 0.23 0.24 0.26 0.27 0.29 0.3  0.31 0.32 0.33 0.34 0.35 0.37 0.38 0.4
 0.41 0.42 0.43 0.44 0.46 0.47 0.48 0.49 0.5  0.51 0.52 0.53 0.55 0.56
 0.57 0.57 0.59 0.6  0.62 0.62 0.64 0.64 0.65 0.66 0.67 0.68 0.69 0.7
 0.71 0.72 0.74 0.75 0.76 0.78 0.79 0.8  0.82 0.84 0.85 0.86 0.86 0.87
 0.88 0.89 0.9  0.91 0.92 0.92 0.93 0.94 0.95 0.96 0.96 0.97 0.98 0.98
 0.99 0.99 1.  ]
0.028429865765165735


In [None]:
# Time TD(0) again on N_target=100, this time with a 10x smaller step size
# (0.01) and 3000 episodes; the returned estimate is displayed as the cell output.
%time TD0_method(0.01, 3000, 100)

CPU times: user 11.1 s, sys: 25.7 ms, total: 11.2 s
Wall time: 11.2 s


array([0.  , 0.03, 0.06, 0.09, 0.12, 0.14, 0.17, 0.19, 0.21, 0.23, 0.25,
       0.27, 0.29, 0.3 , 0.32, 0.33, 0.34, 0.36, 0.37, 0.38, 0.39, 0.39,
       0.4 , 0.41, 0.42, 0.42, 0.43, 0.43, 0.44, 0.44, 0.45, 0.45, 0.45,
       0.46, 0.46, 0.46, 0.46, 0.47, 0.47, 0.47, 0.47, 0.47, 0.48, 0.48,
       0.48, 0.48, 0.48, 0.48, 0.49, 0.49, 0.49, 0.49, 0.49, 0.49, 0.5 ,
       0.5 , 0.5 , 0.5 , 0.5 , 0.51, 0.51, 0.51, 0.51, 0.52, 0.52, 0.52,
       0.53, 0.53, 0.53, 0.54, 0.54, 0.55, 0.55, 0.56, 0.57, 0.57, 0.58,
       0.59, 0.59, 0.6 , 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.68, 0.69,
       0.7 , 0.72, 0.74, 0.76, 0.78, 0.8 , 0.82, 0.85, 0.87, 0.9 , 0.93,
       0.97, 1.  ])