<a href="https://colab.research.google.com/github/mcnica89/MATH4060/blob/main/Week6_Temporal_Difference_Learning_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import itertools
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random as jrandom
from jax import nn as jnn
from jax import jit
import random
import time
import timeit
import math
import numpy as np

# Notebook-wide matplotlib styling: tick labels at size 15, base font at size 20.
# `font` stays a module-level dict so later cells can still refer to it.
font = dict(size=20)

plt.rc(('xtick', 'ytick'), labelsize=15)
plt.rc('font', **font)

# Gambler's Ruin (aka Drunkard's Walk)

Consider the state space $[[0,N_{target}]] \subset \mathbb{Z}$, which represents the possible values of a gambler's wealth. Let $X_n$ be a Markov chain representing the gambler's wealth after $n$ bets. If they have no money, $0$, then they cannot make any bets:

$$P( X_{n+1} = 0 | X_n = 0) =1$$

If they reach their target $N_{target}$, then they choose to stop gambling.


$$P( X_{n+1} = N_{target} | X_n = N_{target}) =1$$

Otherwise, they bet and their fortune goes up or down by exactly $1$:

$$P( X_{n+1} = x+1 | X_n = x) =p$$
$$P( X_{n+1} = x-1 | X_n = x) =1-p$$

The gambler keeps playing until they either go broke or reach their target. Assume the gambler gains a reward of 1 life point if they reach $N_{target}$ and a reward of 0 if they go broke. Determine

$$E[\text{Reward}|X_0 = x] = P(\text{Reach target}|X_0 = x)$$


# Monte Carlo Method

In [45]:
def monte_carlo_method(learning_rate, N_episodes, N_target, max_time=1000):
  """Estimate P(reach N_target | X_0 = x) for the fair +/-1 walk by first-visit Monte Carlo.

  Each episode starts from a uniformly random non-terminal state in
  1..N_target-1 and moves +1 or -1 with equal probability until it is
  absorbed at 0 or at N_target, or until `max_time` steps have been taken.
  The only non-zero reward is 1, received on arrival at N_target, so an
  episode that is cut off by `max_time` is scored with return 0.

  Args:
    learning_rate: constant step size in V[s] += learning_rate*(G - V[s]).
    N_episodes: number of episodes to simulate.
    N_target: target wealth; the states are 0..N_target. Must be at least 2.
    max_time: maximum number of steps per episode (non-negative).

  Returns:
    NumPy array of length N_target+1; entry x is the estimated value of state x.

  Raises:
    ValueError: if N_target < 2, because no non-terminal start state exists.
  """
  if N_target < 2:
    raise ValueError("N_target must be at least 2 so that a non-terminal start state exists")

  value_function = np.zeros(N_target+1)

  for _ in range(N_episodes):
    # PART I: simulate one episode, recording the state and reward at each time.
    current_state = random.randint(1,N_target-1)
    state_history = []
    reward_history = []

    while True:
      state_history.append(current_state)
      absorbed = current_state==0 or current_state==N_target
      timed_out = len(state_history) > max_time
      if absorbed or timed_out:
        # Reward 1 only if the episode ended on the target; 0 for ruin or timeout.
        reward_history.append(1.0 if current_state==N_target else 0.0)
        break
      reward_history.append(0.0)

      #THIS IS THE MOST IMPORTANT LINE OF PART I: one fair +/-1 step
      current_state += 2*random.randint(0,1)-1

    # PART II: update the value function for the states seen this episode,
    # using the first-visit Monte Carlo rule.

    # returns_to_go[t] = sum of rewards from time t to the end of the episode,
    # computed once with a reversed cumulative sum instead of one sum per t.
    returns_to_go = np.cumsum(reward_history[::-1])[::-1]

    # Sized from N_target so that every state 0..N_target has a slot.
    visited_already = np.zeros(N_target+1,dtype=bool)
    for state, G in zip(state_history, returns_to_go):
      if visited_already[state]:
        continue
      visited_already[state] = True

      # Constant-step-size update: value_function[state] is an exponentially
      # weighted average of the returns G observed on first visits to `state`.
      # Using 1/(number of first visits to `state`) as the step size instead
      # would make it the exact sample average of those returns.
      value_function[state] += learning_rate*(G - value_function[state])

  return value_function

In [50]:
# First-visit Monte Carlo estimate: step size 0.05, 100000 episodes, target wealth 10.
monte_carlo_method(0.05,100000, 10)

array([0.        , 0.12626279, 0.18175374, 0.23327819, 0.32603477,
       0.53419425, 0.60018212, 0.75812825, 0.78688499, 0.89392174,
       1.        ])

# TD0 Method (Temporal Difference Learning with 0 memory)

In [49]:
def TD0_method(learning_rate, N_samples, N_target):
  """Estimate state values for the fair +/-1 walk with the TD(0) update.

  Values are updated after every single step rather than at the end of an
  episode: the value of the state just left is nudged towards the current
  value of the state just entered. No reward term appears in the update;
  reaching the target is encoded by fixing the value of N_target at 1, and
  the value of state 0 is never changed from 0.

  Args:
    learning_rate: step size applied to each one-step value difference.
    N_samples: number of walks to simulate, each from a uniformly random
      start state in 1..N_target-1.
    N_target: target wealth; the states are 0..N_target.

  Returns:
    NumPy array of length N_target+1 holding the value estimates.
  """
  V = np.zeros(N_target+1)
  V[N_target] = 1

  for _ in range(N_samples):
    state = random.randint(1,N_target-1)

    # Walk until absorbed at 0 or at N_target.
    while 0 < state < N_target:
      step = 1 if random.randint(0,1) else -1
      successor = state + step
      # Move V[state] a fraction learning_rate of the way towards V[successor].
      V[state] += learning_rate*( V[successor] - V[state] )
      state = successor

  return V

In [51]:
# TD0 estimate: step size 0.1, 100000 sampled walks, target wealth 10.
TD0_method(0.1, 100000, 10)

array([0.        , 0.1209138 , 0.21034023, 0.28481609, 0.35772368,
       0.45067095, 0.51319773, 0.69971957, 0.80599985, 0.89852383,
       1.        ])

# Time comparisons

In [57]:
# Time the Monte Carlo estimator: step size 0.01, 100000 episodes, target wealth 10.
%time monte_carlo_method(0.01, 100000, 10)

CPU times: user 14.2 s, sys: 1.1 s, total: 15.3 s
Wall time: 14.1 s


array([0.        , 0.07245991, 0.1765066 , 0.2687439 , 0.3971212 ,
       0.49495066, 0.60143898, 0.68979404, 0.78990505, 0.8927356 ,
       1.        ])

In [58]:
# Time the TD0 estimator: step size 0.01, target wealth 10. Note that it runs 300000
# samples, three times the 100000 episodes used in the monte_carlo_method timing cell.
%time TD0_method(0.01, 300000, 10)

CPU times: user 12.2 s, sys: 1.48 ms, total: 12.2 s
Wall time: 12.2 s


array([0.        , 0.08949829, 0.20449964, 0.2966078 , 0.40519875,
       0.51030072, 0.62275848, 0.7215919 , 0.81953907, 0.90139822,
       1.        ])