# Run notebooks

In [9]:
import time

import papermill as pm

## Constants

In [10]:
RL_NOTEBOOKS = ["a2c", "reinforce", "dqn"]
BASELINE_NOTEBOOKS = [
    "cnn1d",
    "cnn2d",
    "lstm",
    "bilstm",
    "lstm_attn",
    "bilstm_attn",
]

In [11]:
NB_DICT: dict[str, tuple[str, dict]] = {
    "a2c": ("3.0-a2c", {}),
    "reinforce": ("3.1-reinforce", {}),
    "dqn": ("3.2-dqn", {}),
    "cnn1d": ("4.1-cnn", {"CNN_DIM": 1}),
    "cnn2d": ("4.1-cnn", {"CNN_DIM": 2}),
    "lstm": ("4.0-lstm", {"BIDIRECTIONAL": False, "ATTENTION": False}),
    "bilstm": ("4.0-lstm", {"BIDIRECTIONAL": True, "ATTENTION": False}),
    "lstm_attn": ("4.0-lstm", {"BIDIRECTIONAL": False, "ATTENTION": True}),
    "bilstm_attn": ("4.0-lstm", {"BIDIRECTIONAL": True, "ATTENTION": True}),
}

## Parameters

In [12]:
basic_params = {
    "EXPERIMENT_NAME": "20_04",
    "EPOCHS": 5500,
    "EVAL_PERIOD": 10,
    "LOG_PERIOD": 10,
}

In [13]:
params_dict = {
    "a2c": {},
    "reinforce": {},
    "dqn": {},
    "cnn1d": {},
    "cnn2d": {},
    "lstm": {},
    "bilstm": {},
    "lstm_attn": {},
    "bilstm_attn": {},
}

## Utils

In [14]:
def build_exec_params(name: str) -> tuple[str, dict]:
    final_params = {}
    nb_name, specific_params = NB_DICT[name]
    final_params.update(basic_params)
    final_params.update(params_dict[name])
    final_params.update(specific_params)

    return nb_name, final_params

In [15]:
def exec_notebook(name: str, save_prefix: str = ""):
    print(f"Working on '{name}'...")

    nb_name, params = build_exec_params(name)

    start_time = time.time()

    save_name = f"{save_prefix}{nb_name}"

    pm.execute_notebook(f"./{nb_name}.ipynb", f"./{save_name}.ipynb", parameters=params)

    elapsed_time = time.time() - start_time
    print(f"Done with '{name}' in {elapsed_time:.2f} seconds! Saved to {save_name}\n")

## Execution

In [16]:
for nb in BASELINE_NOTEBOOKS:
    exec_notebook(nb)

Working on 'cnn1d'...


Executing:   0%|          | 0/33 [00:00<?, ?cell/s]

Done with 'cnn1d' in 57.03 seconds! Saved to 4.1-cnn

Working on 'cnn2d'...


Executing:   0%|          | 0/33 [00:00<?, ?cell/s]

Done with 'cnn2d' in 33.51 seconds! Saved to 4.1-cnn

Working on 'lstm'...


Executing:   0%|          | 0/31 [00:00<?, ?cell/s]

Done with 'lstm' in 543.16 seconds! Saved to 4.0-lstm

Working on 'bilstm'...


Executing:   0%|          | 0/31 [00:00<?, ?cell/s]

Done with 'bilstm' in 44.96 seconds! Saved to 4.0-lstm

Working on 'lstm_attn'...


Executing:   0%|          | 0/31 [00:00<?, ?cell/s]

Done with 'lstm_attn' in 13.54 seconds! Saved to 4.0-lstm

Working on 'bilstm_attn'...


Executing:   0%|          | 0/31 [00:00<?, ?cell/s]

Done with 'bilstm_attn' in 13.85 seconds! Saved to 4.0-lstm

