
<div style="text-align: center; line-height: 0; padding-top: 9px;">
  <img src="https://databricks.com/wp-content/uploads/2018/03/db-academy-rgb-1200px.png" alt="Databricks Learning" style="width: 1200px">
</div>

# Distributed SARSA Control - Gridworld Problem

## ![Spark Logo Tiny](https://files.training.databricks.com/images/105/logo_spark_tiny.png) In this lab you will learn:<br>
 - How to find an optimal policy using distributed SARSA
 
References: [Horovod](https://horovod.readthedocs.io/en/latest/)

### Problem Statement ###

In this lab we find an optimal policy in a **distributed** fashion using **Horovod**. Keep in mind that we do NOT know the dynamics of the environment, nor are we given the MDP; that is, this is the full RL control problem. We created an environment for this gridworld problem earlier, and we will use it to develop the SARSA algorithm and find the optimal policy.

![Prediction](https://files.training.databricks.com/images/rl/prediction.png)
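
For reference, the SARSA update that this lab implements can be written compactly. The notation is the usual textbook one rather than anything defined in this notebook: \(S_t, A_t\) are the current state and action, \(R_{t+1}\) is the observed reward, \(\alpha\) is the step size (`alpha` in the code) and \(\gamma\) is the discount factor (`gamma` in the code).

$$
Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma \, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]
$$

Both \(A_t\) and \(A_{t+1}\) are drawn from the same \(\varepsilon\)-greedy policy, which is what makes SARSA an on-policy method.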

## What is Horovod?
  
"Horovod is a distributed deep learning training framework for TensorFlow, Keras, PyTorch, and Apache MXNet. It was developed at Uber. The goal of Horovod is to make distributed deep learning fast and easy to use. The primary motivation for Horovod is to make it easy to take a single-GPU training script and successfully scale it to train across many GPUs in parallel". For more information, see [Horovod](https://github.com/uber/horovod).

In [5]:
%run "./Labs/helper/GridWorldEnvironment"

In [6]:
environment = GridWorldEnvironment()

In [7]:
import numpy as np
import random
np.random.seed(1234)


def pick_action(Q, na, epsilon=0.1):
  """Pick an action for one state using the epsilon-greedy rule.

  Q is the row of action values for the current state and na is the number of actions.
  """
  
  # Explore with probability epsilon: pick uniformly among all na actions
  if np.random.rand() < epsilon:
    return np.random.randint(na)
  # Otherwise exploit: pick the greedy action
  return int(np.argmax(Q))
      
  
def sarsa(Q, ns, na, gamma=1.0, epsilon=0.1, alpha=0.1, number_of_iterations=100000):
  """Run tabular SARSA on the gridworld and return the updated Q-table.

  Q is an (ns, na) table of action values; it is updated in place.
  """
 
  # Maximum number of time steps per episode
  max_steps = 1000
  
  # Generate one episode per iteration
  for i in range(number_of_iterations):
    if i % 10000 == 0:
      print(f"This is iteration {i+1}")
    # Start each episode from a random state between 1 and 14 (inclusive)
    state = random.randint(1, 14)
    action = pick_action(Q[state], na, epsilon)
    environment.set_state(state)
    
    for j in range(max_steps):
      # Take the action; observe the next state, the reward, and whether the episode has ended
      next_state, reward, is_done, _ = environment.step(action)
      
      # Choose the next action with the same epsilon-greedy policy (on-policy)
      next_action = pick_action(Q[next_state], na, epsilon)
      # SARSA update: move Q(s, a) toward the target r + gamma * Q(s', a')
      td_target = reward + gamma * Q[next_state][next_action]
      Q[state][action] += alpha * (td_target - Q[state][action])
      
      # Stop the episode at a terminal state
      if is_done:
        break
      state = next_state
      action = next_action
         
  return Q
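
To make the two ingredients of the cell above concrete, here is a minimal sketch that is independent of the gridworld environment. The helper names `epsilon_greedy_probs` and `sarsa_update` are hypothetical, the Q-values are illustrative, and the reward of -1 per step is an assumption. The sketch uses the textbook epsilon-greedy rule, where the exploration branch draws uniformly from all actions.

```python
import numpy as np


def epsilon_greedy_probs(q_row, epsilon=0.1):
  """Return the probability of each action under a textbook epsilon-greedy policy."""
  q_row = np.asarray(q_row, dtype=float)
  na = q_row.size
  # Every action receives epsilon / na from the exploration branch
  probs = np.full(na, epsilon / na)
  # The greedy action additionally receives the remaining 1 - epsilon
  probs[np.argmax(q_row)] += 1.0 - epsilon
  return probs


def sarsa_update(q_sa, reward, q_next, alpha=0.1, gamma=1.0):
  """Return the new value of Q(s, a) after one SARSA update."""
  return q_sa + alpha * (reward + gamma * q_next - q_sa)


# Action probabilities for one illustrative row of Q-values
probs = epsilon_greedy_probs([-2.0, -3.0, -3.0, -1.0], epsilon=0.1)
print(probs)

# Two consecutive updates of the same entry, assuming a reward of -1
# and a next state-action value that is still zero
first = sarsa_update(q_sa=0.0, reward=-1.0, q_next=0.0)
second = sarsa_update(q_sa=first, reward=-1.0, q_next=0.0)
print(first, second)
```

With four actions and `epsilon=0.1`, the greedy action is chosen most of the time while every other action keeps a small, equal probability, so all state-action pairs continue to be visited.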

In [8]:
import numpy as np
import horovod.tensorflow.keras as hvd

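def run(Q, na, ns):
  """Run SARSA on one Horovod worker, starting from the Q-table passed in."""
  
  hvd.init()
  # Work on a float copy of the incoming table so that earlier rounds are not discarded
  Q = np.array(Q, dtype=float)
  new_q = sarsa(Q, ns, na)
  return new_q
  

In [9]:
import horovod.spark

def main(iteration):
  """Invoke Horovod for several rounds and combine the workers' Q-tables after each round."""
  ns = 16
  na = 4
  Q = np.zeros([ns, na])
  for i in range(iteration):
    # horovod.spark.run returns a list with one result per worker
    results = horovod.spark.run(run, args=(Q, na, ns))
    # Average the per-worker tables (one simple way to combine them) for the next round
    Q = np.mean(np.array(results), axis=0)
  return Q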
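
`horovod.spark.run` returns a list containing one result per worker, so each round produces several Q-tables. One simple way to merge them is an element-wise mean, shown here as a sketch. The helper `average_q_tables` is hypothetical and the numbers are made up; this is not necessarily the only or best way to combine the tables.

```python
import numpy as np


def average_q_tables(q_tables):
  """Element-wise mean of a list of equally shaped Q-tables, one per worker."""
  stacked = np.stack([np.asarray(q, dtype=float) for q in q_tables])
  return stacked.mean(axis=0)


# Two made-up worker results for a tiny problem with 2 states and 2 actions
worker_0 = np.array([[0.0, -1.0], [-2.0, -3.0]])
worker_1 = np.array([[0.0, -3.0], [-4.0, -3.0]])

combined = average_q_tables([worker_0, worker_1])
print(combined)
```

Averaging keeps the table shape at `(ns, na)`, so the merged table can be passed straight back in as the starting point for another round.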

In [10]:
Q = main(10)

In [11]:
# Test your code
value_expected = [[ 0.0,  0.0,  0.0,  0.0],
       [-2.0, -3.0, -3.0, -1.0],
       [-3.0, -4.0, -4.0, -2.0],
       [-4.1, -4.0, -3.3, -3.1],
       [-1.0, -3.0, -3.0, -2.0],
       [-2.0, -4.0, -4.0, -2.3],
       [-3.3, -3.0, -3.0, -3.0],
       [-4.0, -3.0, -2.0, -4.0],
       [-2.0, -4.0, -4.0, -3.0],
       [-3.0, -3.0, -3.0, -3.0],
       [-4.0, -2.0, -2.0, -4.0],
       [-3.0, -2.0, -1.0, -3.0],
       [-3.0, -3.0, -4.0, -4.0],
       [-4.0, -2.0, -3.0, -4.0],
       [-3.1, -1.0, -2.0, -3.0],
       [ 0.0,  0.0,  0.0,  0.0]]  
np.testing.assert_array_almost_equal(Q, value_expected, err_msg = "The values are incorrect", decimal = 0)
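
Once a Q-table is available, a greedy policy can be read off by taking the best action in each row. The sketch below does this for a few rows copied from the expected table above. The variable names are illustrative, and because the mapping from action index to direction is defined by the environment helper (not shown here), only the indices are reported.

```python
import numpy as np

# Rows copied from the expected Q-table above (states 1, 4, 11 and 14)
q_rows = np.array([[-2.0, -3.0, -3.0, -1.0],
                   [-1.0, -3.0, -3.0, -2.0],
                   [-3.0, -2.0, -1.0, -3.0],
                   [-3.1, -1.0, -2.0, -3.0]])

# Greedy policy: index of the largest action value in each row
greedy_actions = np.argmax(q_rows, axis=1)
# Greedy state value: the largest action value in each row
state_values = np.max(q_rows, axis=1)

print(greedy_actions)
print(state_values)
```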


&copy; 2020 Databricks, Inc. All rights reserved.<br/>
Apache, Apache Spark, Spark and the Spark logo are trademarks of the <a href="http://www.apache.org/">Apache Software Foundation</a>.<br/>
<br/>
<a href="https://databricks.com/privacy-policy">Privacy Policy</a> | <a href="https://databricks.com/terms-of-use">Terms of Use</a> | <a href="http://help.databricks.com/">Support</a>