In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from typing import Optional
import random
from types import SimpleNamespace

from rl_final_project.control import AbstractControl
from rl_final_project.q_functions import QLinear, QTabular
from rl_final_project.control import MonteCarloControl, QLearningControl, \
    SarsaLambdaControl
from rl_final_project.agent import Agent
from rl_final_project.dqn import DQNFunction, DQNControl

import multiprocessing

n_cpu = multiprocessing.cpu_count()

In [2]:
from copy import copy

cmap = {
    "sarsa-lambda": SarsaLambdaControl,
    "q-learning": QLearningControl,
    "dqn": DQNControl,
    "monte-carlo": MonteCarloControl,
}

fmap = {
    "linear": QLinear,
    "tabular": QTabular,
}

def build_control_experiment(
        env: gym.Env,
        method: str, 
        gamma: float,
        function: Optional[str] = None,
        num_episodes: int = 10_000,
        replay_capacity: int = 10_000,
        n0: int = 10,
        discrete_scale: int = 1,
        batch_size: int = 128,
        eps_func: str = "dqn",
        stochasticity_factor: float = 0.4,
        method_args: Optional[dict] = None,
) -> AbstractControl:
    
    if method_args is None:
        method_args = {}
    
    if function is None and method != "dqn":
        raise ValueError("function must be specified for all methods except dqn")
    
    if function is not None and function not in fmap:
        raise ValueError(f"Unknown function approximation {function}")
    
    if method not in cmap:
        raise ValueError(f"Unknown control method {method}")
    
    n_actions = env.action_space.n
    n_states = env.observation_space.shape[0]
    
    if method == "dqn":       
        q_function = DQNFunction(
            batch_size=batch_size,
            n_actions=n_actions,
            n_feat=n_states,
            discrete_scale=discrete_scale,
        )
    else:
        q_function = fmap[function](
            n_actions=n_actions,
            n_feat=n_states,
            discrete_scale=discrete_scale
        )
        
    agent = Agent(
        q_function, 
        n0=n0, 
        n_actions=n_actions, 
        eps_greedy_function=eps_func,
        stochasticity_factor=stochasticity_factor,
    )
    
    for k in copy(method_args):
        if method_args[k] is None:
            del method_args[k]
    
    control = cmap[method](
        env=env,
        agent=agent,
        num_episodes=num_episodes,
        gamma=gamma,
        batch_size=batch_size,
        replay_capacity=replay_capacity,
        verbose=False,
        **method_args
    )
    
    return control
        
exp_config = SimpleNamespace(
    env_name="CartPole-v0",
    num_episodes=10_000,
    control_algorithm=SimpleNamespace(
        method="sarsa", 
        args=SimpleNamespace(
            λ=0.8
        ),
        function="linear",
    ),
    gamma=0.9,
    discrete_scale=10,
    n0=10,
    eps_func="dqn",
    stochasticity_factor=0.4,
    reward_mode="normal",
    seed=0
)

In [3]:
from rl_final_project.environment import EnvironmentNormalizer


class ExperimentExitCode:
    SUCCESS = 0
    FAILED = 1
    INVALID = 2

def run(config: SimpleNamespace) -> tuple[ExperimentExitCode, SimpleNamespace, list[float]]:
    # Initialize random seed
    np.random.seed(config.seed)
    random.seed(config.seed)
    
    # Initialize environment
    env = gym.make(config.env_name)
    env = EnvironmentNormalizer(env)
    
    try: 
        # Initialize control algorithm
        control_algorithm = build_control_experiment(
            env=env,
            method=config.control_algorithm.method,
            gamma=config.gamma,
            function=config.control_algorithm.function,
            num_episodes=config.num_episodes,
            replay_capacity=config.replay_capacity,
            batch_size=config.batch_size,
            n0=config.n0,
            discrete_scale=config.discrete_scale,
            eps_func=config.eps_func,
            stochasticity_factor=config.stochasticity_factor,
            method_args=config.control_algorithm.args.__dict__,
        )
    except Exception:
        return ExperimentExitCode.INVALID, config, []
    
    # Run control algorithm
    try:
        eps_rewards = control_algorithm.fit()
    except Exception:
        return ExperimentExitCode.FAILED, config, []
    
    return ExperimentExitCode.SUCCESS, config, eps_rewards

In [4]:
def experiment_generator_grid_search() -> SimpleNamespace:
    for env_name in ["CartPole-v0"]:
        for method in ["q-learning", "monte-carlo", "sarsa-lambda"]:
            for function in ["linear", "tabular", None]:
                for gamma in [0.5, 0.95]:
                    for λ in [None, 0.3, 0.6, 0.9]:
                        for n0 in [2, 5, 10, 25, None]:
                            for eps_func in ["s"]:
                                for discrete_scale in [2, 5, 20]:
                                #for stochasticity_factor in [0.0, 0.25, 0.5]:
                                #    for reward_mode in ["normal", "sparse"]:
                                        for seed in range(3): 
                                            yield SimpleNamespace(
                                                env_name=env_name,
                                                num_episodes=10_000 if method != "dqn" else 2_000,
                                                control_algorithm=SimpleNamespace(
                                                    method=method, 
                                                    args=SimpleNamespace(
                                                        lambda_factor=λ,
                                                    ),
                                                    function=function,
                                                ),
                                                gamma=gamma,
                                                replay_capacity=10_000,
                                                batch_size=1 if method != "dqn" else 128,
                                                discrete_scale=discrete_scale,
                                                n0=n0,
                                                eps_func=eps_func,
                                                stochasticity_factor=0.0,
                                                reward_mode="normal",
                                                seed=seed
                                            )

In [5]:
from multiprocessing import Pool

experiments = list(experiment_generator_grid_search())
results = []

with Pool(n_cpu) as p:
    for exit_code, config, eps_rewards in tqdm(p.imap_unordered(run, experiments), total=len(experiments)):
        if exit_code == ExperimentExitCode.SUCCESS:
            results.append((config, eps_rewards))
            

  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.deprecation(


  0%|          | 0/3240 [00:00<?, ?it/s]

Process ForkPoolWorker-14:
Process ForkPoolWorker-6:
Process ForkPoolWorker-18:
Process ForkPoolWorker-5:
Process ForkPoolWorker-12:
Process ForkPoolWorker-13:
Process ForkPoolWorker-10:
Process ForkPoolWorker-3:
Process ForkPoolWorker-4:
Process ForkPoolWorker-8:
Process ForkPoolWorker-17:
Process ForkPoolWorker-2:
Process ForkPoolWorker-16:
Process ForkPoolWorker-20:
Process ForkPoolWorker-9:
Process ForkPoolWorker-7:
Process ForkPoolWorker-19:
Process ForkPoolWorker-15:
Process ForkPoolWorker-11:
Process ForkPoolWorker-1:
Traceback (most recent call last):
TypeError: issubclass() arg 1 must be a class
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  F

KeyboardInterrupt: 