In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Optional
import random
from types import SimpleNamespace

from rl_final_project.control import AbstractControl
from rl_final_project.q_functions import QLinear, QTabular
from rl_final_project.control import MonteCarloControl, QLearningControl, \
    SarsaLambdaControl
from rl_final_project.agent import Agent
from rl_final_project.dqn import DQNFunction, DQNControl

In [None]:



# Registry of control algorithms: the `method` string accepted by
# `build_control_experiment` is looked up here to get the class to construct.
cmap = dict(
    [
        ("sarsa-lambda", SarsaLambdaControl),
        ("q-learning", QLearningControl),
        ("dqn", DQNControl),
        ("monte-carlo", MonteCarloControl),
    ]
)

# Registry of Q-function approximators: the `function` string accepted by
# `build_control_experiment` is looked up here for every method except "dqn".
fmap = dict(
    [
        ("linear", QLinear),
        ("tabular", QTabular),
    ]
)

def build_control_experiment(
    env: gym.Env,
    method: str,
    gamma: float,
    function: Optional[str] = None,
    num_episodes: int = 10_000,
    replay_capacity: int = 10_000,
    n0: int = 10,
    discrete_scale: int = 1,
    batch_size: int = 128,
    eps_func: str = "dqn",
    stochasticity_factor: float = 0.4,
    method_args: Optional[dict] = None,
) -> AbstractControl:
    """Assemble a Q-function, an agent and a control algorithm for ``env``.

    Args:
        env: Environment to train on. Its ``action_space.n`` and
            ``observation_space.shape[0]`` are read to size the Q-function.
        method: Key into ``cmap`` selecting the control class.
        gamma: Forwarded to the control constructor.
        function: Key into ``fmap`` selecting the Q-function class. Required
            unless ``method`` is ``"dqn"``, in which case ``DQNFunction`` is
            used instead (a supplied value is still validated).
        num_episodes: Forwarded to the control constructor.
        replay_capacity: Forwarded to the control constructor.
        n0: Forwarded to ``Agent``.
        discrete_scale: Forwarded to the Q-function constructor.
        batch_size: Forwarded to ``DQNFunction`` (dqn only) and to the
            control constructor.
        eps_func: Forwarded to ``Agent`` as ``eps_greedy_function``.
        stochasticity_factor: Forwarded to ``Agent``.
        method_args: Extra keyword arguments for the control constructor;
            ``None`` means no extra arguments.

    Returns:
        The constructed control object, an instance of ``cmap[method]``.

    Raises:
        ValueError: If ``function`` is missing for a non-dqn method, if
            ``function`` is not a key of ``fmap``, or if ``method`` is not a
            key of ``cmap``.
    """
    extra_control_kwargs = {} if method_args is None else method_args
    uses_dqn = method == "dqn"

    # Validate the function name first, then the method name, so callers see
    # the same error for the same bad input combination.
    if function is None:
        if not uses_dqn:
            raise ValueError("function must be specified for all methods except dqn")
    elif function not in fmap:
        raise ValueError(f"Unknown function approximation {function}")

    if method not in cmap:
        raise ValueError(f"Unknown control method {method}")

    action_count = env.action_space.n
    feature_count = env.observation_space.shape[0]

    # Keyword arguments shared by every Q-function constructor.
    approximator_kwargs = {
        "n_actions": action_count,
        "n_feat": feature_count,
        "discrete_scale": discrete_scale,
    }
    if uses_dqn:
        # The DQN function additionally takes the batch size.
        q_function = DQNFunction(batch_size=batch_size, **approximator_kwargs)
    else:
        q_function = fmap[function](**approximator_kwargs)

    agent = Agent(
        q_function,
        n0=n0,
        n_actions=action_count,
        eps_greedy_function=eps_func,
        stochasticity_factor=stochasticity_factor,
    )

    return cmap[method](
        env=env,
        agent=agent,
        num_episodes=num_episodes,
        gamma=gamma,
        batch_size=batch_size,
        replay_capacity=replay_capacity,
        **extra_control_kwargs,
    )
        
# Default experiment configuration. `run` reads `seed` and `env_name`; the
# other field names mirror parameters of `build_control_experiment`
# (presumably forwarded to it -- confirm against the rest of `run`).
exp_config = SimpleNamespace(
    # Environment id handed to `gym.make`.
    env_name="CartPole-v0",
    num_episodes=10_000,
    control_algorithm=SimpleNamespace(
        # Must be one of the keys of `cmap`; plain "sarsa" is not registered
        # there and is rejected with "Unknown control method".
        method="sarsa-lambda",
        # Extra settings for the chosen control algorithm
        # (presumably passed on as `method_args` -- TODO confirm).
        args=SimpleNamespace(
            λ=0.8
        ),
        # Must be one of the keys of `fmap`.
        function="linear",
    ),
    gamma=0.9,
    discrete_scale=10,
    n0=10,
    eps_func="dqn",
    stochasticity_factor=0.4,
    # NOTE(review): not consumed in the visible part of the notebook.
    reward_mode="normal",
    # Seeds both NumPy's global RNG and the stdlib `random` module in `run`.
    seed=0
)

In [None]:
def run(config: SimpleNamespace) -> tuple[dict, list[float]]:
    # Initialize random seed
    np.random.seed(config.seed)
    random.seed(config.seed)
    
    # Initialize environment
    env = gym.make(config.env_name)
    
    # Initialize control algorithm
    control_algorithm: AbstractControl 
    
    