In [None]:
import os
from frozen_lake_env import FrozenLakeEnv
from monte_carlo_control import MonteCarloControl
from sarsa_control import SarsaControl
from q_learning_control import QLearningControl
from grid_inputs import grid_input_4x4, grid_input_10x10
import shutil

# Root output folder for the control runs below; each run is given its own subfolder of it.
plots_dir = os.path.join(os.getcwd(), "plots")

# Wipe output left over from an earlier execution so every Run All starts from an empty folder.
stale_plots_present = os.path.exists(plots_dir)
if stale_plots_present:
    shutil.rmtree(plots_dir)

In [None]:
# ---- 4x4 lake: configuration ------------------------------------------------
env_4x4 = FrozenLakeEnv(grid_input=grid_input_4x4)
num_of_episodes_4x4 = 5000
max_steps_4x4 = 100
epsilon_4x4 = 0.15  # exploration rate shared by SARSA and Q-learning
# Monte Carlo is run with its own, higher exploration rate instead of
# `epsilon_4x4`; it is named here rather than inlined so the override is visible.
# NOTE(review): confirm this difference from the shared epsilon is intentional.
mc_epsilon_4x4 = 0.3

# ---- Monte Carlo Control ----------------------------------------------------
# `plots_dir` is the root output folder defined in the first cell.
mc_4x4 = MonteCarloControl(env=env_4x4,
                           num_of_episodes=num_of_episodes_4x4,
                           max_steps=max_steps_4x4,
                           epsilon=mc_epsilon_4x4,
                           plots_dir=os.path.join(plots_dir, "mc_4x4"))
# presumably runs/uses the learned Q-values to build a greedy policy — TODO confirm
mc_policy = mc_4x4.extract_optimal_policy()

# ---- SARSA Control ----------------------------------------------------------
sarsa = SarsaControl(env=env_4x4,
                     num_of_episodes=num_of_episodes_4x4,
                     max_steps=max_steps_4x4,
                     epsilon=epsilon_4x4,
                     plots_dir=os.path.join(plots_dir, "sarsa_4x4"))
sarsa_policy = sarsa.extract_optimal_policy()

# ---- Q-Learning Control -----------------------------------------------------
q_learning = QLearningControl(env=env_4x4,
                              num_of_episodes=num_of_episodes_4x4,
                              max_steps=max_steps_4x4,
                              epsilon=epsilon_4x4,
                              plots_dir=os.path.join(plots_dir, "q_learning_4x4"))
q_learning_policy = q_learning.extract_optimal_policy()

In [None]:
# ---- 10x10 lake: configuration ----------------------------------------------
env_10x10 = FrozenLakeEnv(grid_input=grid_input_10x10)
num_of_episodes_10x10 = 10000
max_steps_10x10 = 1000
epsilon_10x10 = 0.15  # exploration rate shared by SARSA and Q-learning

# Monte Carlo does not use the shared settings above: it gets a larger episode
# budget, a decaying exploration schedule and an explicit discount factor.
# These are named here instead of being inlined as magic numbers in the call.
mc_num_of_episodes_10x10 = 100000
# NOTE(review): a per-episode cap of 1,000,000 steps is effectively unbounded
# compared with `max_steps_10x10`; confirm this is intended.
mc_max_steps_10x10 = 1000000
mc_epsilon_10x10 = 1.0          # starting exploration rate
mc_epsilon_decay_10x10 = 0.9999  # presumably multiplicative per-episode decay — TODO confirm
mc_min_epsilon_10x10 = 0.7       # lower bound for the decayed epsilon
mc_gamma_10x10 = 0.98

# ---- Monte Carlo Control ----------------------------------------------------
# `plots_dir` is the root output folder defined in the first cell.
mc_10x10 = MonteCarloControl(env=env_10x10,
                             num_of_episodes=mc_num_of_episodes_10x10,
                             max_steps=mc_max_steps_10x10,
                             epsilon=mc_epsilon_10x10,
                             epsilon_decay=mc_epsilon_decay_10x10,
                             min_epsilon=mc_min_epsilon_10x10,
                             gamma=mc_gamma_10x10,
                             plots_dir=os.path.join(plots_dir, "mc_10x10"))
mc_policy_10x10 = mc_10x10.extract_optimal_policy()

# ---- SARSA Control ----------------------------------------------------------
sarsa_10x10 = SarsaControl(env=env_10x10,
                           num_of_episodes=num_of_episodes_10x10,
                           max_steps=max_steps_10x10,
                           epsilon=epsilon_10x10,
                           plots_dir=os.path.join(plots_dir, "sarsa_10x10"))
sarsa_policy_10x10 = sarsa_10x10.extract_optimal_policy()

# ---- Q-Learning Control -----------------------------------------------------
q_learning_10x10 = QLearningControl(env=env_10x10,
                                    num_of_episodes=num_of_episodes_10x10,
                                    max_steps=max_steps_10x10,
                                    epsilon=epsilon_10x10,
                                    plots_dir=os.path.join(plots_dir, "q_learning_10x10"))
q_learning_policy_10x10 = q_learning_10x10.extract_optimal_policy()