In [None]:
import sys
from pathlib import Path

# Make the parent of the notebook's working directory importable; presumably
# this is the project root containing the `experiments`, `environments`,
# `agents` and `metrics` packages imported below (TODO confirm layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
PROJECT_ROOT = str(Path.cwd().resolve().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from experiments.run_experiment import (
    TrainingConfig,
    EvaluateConfig,
)
from experiments.run_sweep import run_sweep, plot_sweep_training, plot_sweep_evaluation
from environments.fronzenlake import FrozenLakeConfig, get_frozenlake_env
from agents.sarsa_td0 import SarsaTD0Agent, SarsaTD0Config
from agents.sarsa_td0_variant import (
    SarsaTD0VariantAgent,
    SarsaTD0VariantConfig,
)   
from metrics.learning_mertrics import total_reward_per_episode, episode_won
from metrics.frustration_metrics import (
    total_td_error_per_episode,
    frustration_rate_per_episode,
    tail_frustration_per_episode,
    cvar_tail_frustration_per_episode,
)

In [None]:
# Disabled CliffWalking environment config.
# `env_config = CliffWalkingConfig()` raises NameError on a fresh kernel
# because CliffWalkingConfig is never imported in this notebook. It was also
# dead code: the FrozenLake cell below reassigns env_config unconditionally,
# and the sweep is run with get_frozenlake_env.
# To run CliffWalking instead, import its config class and env factory, build
# env_config from it, and pass that factory to run_sweep.

In [None]:
# Environment: slippery (stochastic) 4x4 FrozenLake with a custom reward
# schedule. For the deterministic variant, pass is_slippery=False instead.
# NOTE(review): the tuple order is presumably (goal, hole, step) — confirm
# against FrozenLakeConfig.
FROZENLAKE_REWARD_SCHEDULE = (1.0, -1.0, -0.1)

env_config = FrozenLakeConfig(
    is_slippery=True,
    map_name="4x4",
    reward_schedule=FROZENLAKE_REWARD_SCHEDULE,
)

In [None]:
# Baseline SARSA TD(0) agent, configured with a single learning rate `alpha`.
agent_factory = SarsaTD0Agent

# Metric callables handed to the config: one dict keyed for reward-based
# metrics, one for TD-error-based metrics.
td0_reward_metrics = {
    "total_reward_per_episode": total_reward_per_episode,
    "episode_won": episode_won,
}
td0_td_error_metrics = {
    "total_td_error_per_episode": total_td_error_per_episode,
    "frustration_rate_per_episode": frustration_rate_per_episode,
    # The tail metrics are wrapped so they are always called with percentile=0.90.
    "tail_frustration_per_episode": (
        lambda td: tail_frustration_per_episode(td, percentile=0.90)
    ),
    "cvar_tail_frustration_per_episode": (
        lambda td: cvar_tail_frustration_per_episode(td, percentile=0.90)
    ),
}

sarsa_td0_config = SarsaTD0Config(
    alpha=0.05,
    gamma=0.99,
    epsilon=0.3,
    reward_metrics=td0_reward_metrics,
    td_error_metrics=td0_td_error_metrics,
)

In [None]:
# SARSA TD(0) "lower expectations" agent — an alternative to the Variant agent.
# SarsaTD0LowerExpectationsAgent / SarsaTD0LowerExpectationsConfig are not
# imported in the import cell, so running this unconditionally raised NameError
# on a fresh kernel. Its results were also discarded, because the Variant-agent
# cell below reassigns agent_factory and sarsa_td0_config. It is now gated
# behind a flag. Before enabling it:
#   1. import both names in the import cell,
#   2. skip the Variant-agent cell,
#   3. change the sweep's alpha_positive/alpha_negative keys — this config is
#      built with `alpha`, so those keys presumably do not apply (TODO confirm).
USE_LOWER_EXPECTATIONS_AGENT = False

if USE_LOWER_EXPECTATIONS_AGENT:
    agent_factory = SarsaTD0LowerExpectationsAgent

    sarsa_td0_config = SarsaTD0LowerExpectationsConfig(
        alpha=0.05,
        gamma=0.99,
        lower_expectations=0.0,
        epsilon=0.3,
        reward_metrics={
            "total_reward_per_episode": total_reward_per_episode,
            "episode_won": episode_won,
        },
        td_error_metrics={
            "total_td_error_per_episode": total_td_error_per_episode,
            "frustration_rate_per_episode": frustration_rate_per_episode,
            "tail_frustration_per_episode": lambda td: tail_frustration_per_episode(
                td, percentile=0.90
            ),
            "cvar_tail_frustration_per_episode": lambda td: cvar_tail_frustration_per_episode(
                td, percentile=0.90
            ),
        },
    )

In [None]:
# SARSA TD(0) Variant agent with separate alpha_positive / alpha_negative
# learning rates (presumably applied to positive vs. negative TD errors —
# confirm in agents.sarsa_td0_variant). The sweep below overrides these two keys.
agent_factory = SarsaTD0VariantAgent

# Metric callables handed to the config: one dict keyed for reward-based
# metrics, one for TD-error-based metrics.
variant_reward_metrics = {
    "total_reward_per_episode": total_reward_per_episode,
    "episode_won": episode_won,
}
variant_td_error_metrics = {
    "total_td_error_per_episode": total_td_error_per_episode,
    "frustration_rate_per_episode": frustration_rate_per_episode,
    # The tail metrics are wrapped so they are always called with percentile=0.90.
    "tail_frustration_per_episode": (
        lambda td: tail_frustration_per_episode(td, percentile=0.90)
    ),
    "cvar_tail_frustration_per_episode": (
        lambda td: cvar_tail_frustration_per_episode(td, percentile=0.90)
    ),
}

sarsa_td0_config = SarsaTD0VariantConfig(
    alpha_positive=0.2,
    alpha_negative=0.2,
    gamma=0.99,
    epsilon=0.3,
    reward_metrics=variant_reward_metrics,
    td_error_metrics=variant_td_error_metrics,
)

In [None]:
# Base training / evaluation settings for the FrozenLake run. Both configs
# use the same run name and wrap env_config; training also wraps the agent
# config built in the agent cell above.
FROZENLAKE_RUN_NAME = "sarsa_frozenlake"

base_train = TrainingConfig(
    name=FROZENLAKE_RUN_NAME,
    num_train_episodes=12_000,
    env_kwargs=dict(config=env_config),
    agent_kwargs=dict(config=sarsa_td0_config),
)

base_eval = EvaluateConfig(
    name=FROZENLAKE_RUN_NAME,
    num_eval_episodes=2_000,
    env_kwargs=dict(config=env_config),
)

In [None]:
# Leftover CliffWalking run settings. When executed unconditionally, this cell
# silently replaced the 12000/2000-episode FrozenLake configs above with a
# 100/10-episode run labelled "sarsa_cliffwalking". That run still used the
# FrozenLake env_config and get_frozenlake_env, so the results were both
# mislabelled and undertrained. It is gated behind a flag so Restart & Run All
# performs the intended FrozenLake run. Enable it only together with a real
# CliffWalking env_config and env factory.
USE_CLIFFWALKING_CONFIGS = False

if USE_CLIFFWALKING_CONFIGS:
    base_train = TrainingConfig(
        name="sarsa_cliffwalking",
        num_train_episodes=100,
        env_kwargs={"config": env_config},
        agent_kwargs={"config": sarsa_td0_config},
    )

    base_eval = EvaluateConfig(
        name="sarsa_cliffwalking",
        num_eval_episodes=10,
        env_kwargs={"config": env_config},
    )

In [None]:
# (alpha_positive, alpha_negative) pairs to sweep. They cover symmetric rates
# plus asymmetric ones in both directions. Every run uses seed=1.
ALPHA_PAIRS = (
    (0.05, 0.05),
    (0.1, 0.1),
    (0.2, 0.2),
    (0.1, 0.05),
    (0.2, 0.05),
    (0.3, 0.1),
    (0.05, 0.1),
)

sweep = {
    "agent_kwargs": [
        dict(alpha_positive=alpha_pos, alpha_negative=alpha_neg, seed=1)
        for alpha_pos, alpha_neg in ALPHA_PAIRS
    ],
}

In [None]:
# Run the sweep from the base configs above. Presumably this produces one
# train+eval run per entry in sweep["agent_kwargs"] (confirm in
# experiments.run_sweep). The environment is always built with get_frozenlake_env.
results = run_sweep(
    sweep=sweep,
    base_training=base_train,
    base_evaluation=base_eval,
    agent_factory=agent_factory,
    env_factory=get_frozenlake_env,
)

In [None]:
# Training curves for the sweep results.
# NOTE(review): window_size is presumably a smoothing window measured in
# episodes — confirm in experiments.run_sweep.plot_sweep_training.
training_plot_window = 100
plot_sweep_training(results, window_size=training_plot_window)

In [None]:
# Visualize the evaluation-phase results returned by run_sweep above
# (the plotting itself lives in experiments.run_sweep.plot_sweep_evaluation).
plot_sweep_evaluation(results)