<div style="text-align: center; line-height: 0; padding-top: 9px;">
  <img src="https://databricks.com/wp-content/uploads/2018/03/db-academy-rgb-1200px.png" alt="Databricks Learning" style="width: 1200px">
</div>

# Model-free Prediction (first-visit MC) 

## ![Spark Logo Tiny](https://files.training.databricks.com/images/105/logo_spark_tiny.png) In this lab you will learn:<br>
 - How to apply the first-visit Monte Carlo (MC) update

### Problem statement ###

Consider a random walk on a straight line (1D) with 7 states: t1, a, b, c, d, e, and t2. The two end states, t1 and t2, are terminal: reaching either one ends the episode. Reaching t2 earns a reward of $100; reaching t1 earns a reward of $0. From each of the states a, b, c, d, and e, you move one step left or one step right with a 50% chance each.
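
For reference, the exact state values of this walk can be obtained by solving the Bellman equations directly, which gives a target to compare sampled estimates against. The sketch below is not part of the original lab: the helper name `exact_values` and its parameters are illustrative assumptions. It assumes \\(\gamma = 1\\), so each non-terminal state satisfies V(s) = 0.5 V(left) + 0.5 V(right), where a step into t2 contributes the reward of 100 and a step into t1 contributes 0.

```python
import numpy as np


def exact_values(n_states=5, reward_right=100.0, reward_left=0.0):
  """Solve the Bellman equations for the symmetric random walk (gamma = 1).

  Hypothetical helper, not part of the lab: n_states counts the
  non-terminal states (a..e), ordered left to right.
  """
  # Build A v = b from V(s) - 0.5 V(left) - 0.5 V(right) = expected terminal reward
  A = np.eye(n_states)
  b = np.zeros(n_states)
  for s in range(n_states):
    if s == 0:
      b[s] += 0.5 * reward_left   # stepping left from 'a' ends in t1
    else:
      A[s, s - 1] -= 0.5
    if s == n_states - 1:
      b[s] += 0.5 * reward_right  # stepping right from 'e' ends in t2
    else:
      A[s, s + 1] -= 0.5
  return np.linalg.solve(A, b)


print(np.round(exact_values(), 2))
```

The solution is 100·k/6 for the k-th non-terminal state, that is roughly 16.67, 33.33, 50, 66.67, and 83.33 for a through e. Estimates from a finite number of sampled episodes will scatter around these values.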

### Exercise ###
Assume you start at state c and \\(\gamma = 1\\). Use first-visit MC to estimate the value of each state.
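
First-visit MC averages, for each state, the return that follows the first time that state is visited in each episode. The sketch below shows this update in its incremental form, V(s) ← V(s) + (G − V(s)) / N(s). It is an illustration under stated assumptions, not the lab's reference answer: the function name `first_visit_mc` and the episode format (a list of `(state, reward)` pairs, where the reward is the one received on leaving that state) are hypothetical.

```python
from collections import defaultdict


def first_visit_mc(episodes, gamma=1.0):
  """Estimate state values with incremental first-visit Monte Carlo.

  Assumed (hypothetical) episode format: a list of (state, reward) pairs,
  where reward is received on the transition out of that state.
  """
  values = defaultdict(float)  # V(s), starts at 0
  counts = defaultdict(int)    # N(s), number of first visits seen so far

  for episode in episodes:
    # Compute the return G_t for every time step by walking backwards
    returns = [0.0] * len(episode)
    g = 0.0
    for t in range(len(episode) - 1, -1, -1):
      _, reward = episode[t]
      g = reward + gamma * g
      returns[t] = g

    # Update each state only at its first occurrence in the episode
    seen = set()
    for t, (state, _) in enumerate(episode):
      if state in seen:
        continue
      seen.add(state)
      counts[state] += 1
      values[state] += (returns[t] - values[state]) / counts[state]

  return dict(values)


# Two hand-written episodes from the random walk, both starting at c
episodes = [
    [('c', 0), ('d', 0), ('e', 100)],                    # c -> d -> e -> t2
    [('c', 0), ('b', 0), ('c', 0), ('b', 0), ('a', 0)],  # ends in t1
]
print(first_visit_mc(episodes))
```

With \\(\gamma = 1\\) and a single reward paid at termination, every return within an episode equals that terminal reward, so for this problem it is enough to record the terminal reward once for each distinct state the episode visited.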

In [5]:
# ANSWER
import numpy as np
from statistics import mean


def monte_carlo(number_episodes, start_state='c'):
  """Generate random-walk episodes as lists of visited state indices."""

  # Fixed seed so that the sampled episodes are reproducible
  np.random.seed(1234)
  episodes = []

  # States in left-to-right order; t1 and t2 are terminal
  states = ['t1', 'a', 'b', 'c', 'd', 'e', 't2']

  # Sample one episode per iteration
  for _ in range(number_episodes):
    index = states.index(start_state)
    local_transitions = [index]
    while True:
      # 0 means step left, 1 means step right, each with probability 0.5
      pick_action = np.random.randint(2, size=1)[0]
      if pick_action == 0:
        index -= 1
      else:
        index += 1
      local_transitions.append(index)
      # Stop once a terminal state is reached
      if states[index] in ('t1', 't2'):
        episodes.append(local_transitions)
        break

  return episodes


In [6]:
# ANSWER
def mc_first_visit(episodes):
  """Estimate the values of states a..e with first-visit MC."""

  # Reward paid on reaching each terminal state (index 0 = t1, index 6 = t2)
  terminal_reward = {0: 0, 6: 100}

  # Initial value estimates, kept for any state that is never visited
  values = {state: 0.5 for state in range(1, 6)}
  # Returns observed after the first visit to each non-terminal state
  returns = {state: [] for state in range(1, 6)}

  for i, episode in enumerate(episodes):
    if i % 1000 == 0:
      print(f"This is episode {i+1}")
    # With gamma = 1 and a single reward at termination, the return that follows
    # the first visit to any state equals the episode's terminal reward
    episode_return = terminal_reward[episode[-1]]
    # set() keeps each visited state once, so it is counted once per episode
    for state in set(episode):
      if state in returns:
        returns[state].append(episode_return)

  # The value estimate is the average of the recorded returns
  for state, observed in returns.items():
    if observed:
      values[state] = mean(observed)

  return [values[state] for state in range(1, 6)]

In [7]:
# Test your code
episodes = monte_carlo(2000)
value = mc_first_visit(episodes)
# decimal=0 accepts absolute differences below 1.5
value_expected = [15.0, 32.0, 49.0, 66.0, 82.0]
np.testing.assert_array_almost_equal(value, value_expected, decimal=0, err_msg="The values are incorrect")


&copy; 2020 Databricks, Inc. All rights reserved.<br/>
Apache, Apache Spark, Spark and the Spark logo are trademarks of the <a href="http://www.apache.org/">Apache Software Foundation</a>.<br/>
<br/>
<a href="https://databricks.com/privacy-policy">Privacy Policy</a> | <a href="https://databricks.com/terms-of-use">Terms of Use</a> | <a href="http://help.databricks.com/">Support</a>