In [23]:
import sys

# Put '../..' on the import path, presumably so the project package `academia`
# (imported below) resolves when the notebook is run from its own folder.
# NOTE(review): the path is relative to the kernel's working directory, not to
# this notebook file — confirm the kernel is started from the notebook's folder.
# The membership check keeps the cell idempotent: re-running it does not keep
# appending duplicate entries to sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

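# Loading data

Load the saved `LearningStats` of every curriculum and no-curriculum run for both agents (DQN and PPO).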

In [24]:
import os

from academia.curriculum import LearningStats

In [25]:
# Agent output folders holding the saved training runs (relative to the
# kernel's working directory).
DQN_OUTPUT_DIR, PPO_OUTPUT_DIR = './outputs/DQNAgent/', './outputs/PPOAgent/'
# Number of tasks per curriculum run; each run stores 1.stats.json ... N.stats.json.
CURRICULUM_N_TASKS = 3

In [26]:
def _list_run_dirs(output_dir: str, prefix: str) -> list[str]:
    """Return the paths of run directories in ``output_dir`` whose name starts with ``prefix``.

    Entries are sorted by name so the run order is reproducible (``os.listdir``
    returns entries in arbitrary, filesystem-dependent order). Plain files are
    skipped so a stray file that happens to match the prefix does not break loading.
    Note that the prefix 'curriculum' does not match 'nocurriculum...' names.
    """
    run_dirs = []
    for dir_name in sorted(os.listdir(output_dir)):
        dir_path = os.path.join(output_dir, dir_name)
        if dir_name.startswith(prefix) and os.path.isdir(dir_path):
            run_dirs.append(dir_path)
    return run_dirs


def load_curr_stats(output_dir: str, n_tasks: 'int | None' = None) -> list[dict[str, LearningStats]]:
    """Load the per-task stats of every curriculum run under ``output_dir``.

    Args:
        output_dir: agent output folder containing ``curriculum*`` run directories.
        n_tasks: number of curriculum tasks per run. Defaults to the notebook-level
            ``CURRICULUM_N_TASKS``, read at call time.

    Returns:
        One dict per run (sorted by directory name), mapping
        ``'Curriculum task i'`` (i = 1..n_tasks) to the ``LearningStats`` loaded
        from ``<run_dir>/<i>.stats.json``.
    """
    if n_tasks is None:
        n_tasks = CURRICULUM_N_TASKS
    result: list[dict[str, LearningStats]] = []
    for run_dir in _list_run_dirs(output_dir, 'curriculum'):
        # Task files are numbered from 1, so iterate 1..n_tasks inclusive.
        curr_stats: dict[str, LearningStats] = {
            f'Curriculum task {i}': LearningStats.load(os.path.join(run_dir, f'{i}.stats.json'))
            for i in range(1, n_tasks + 1)
        }
        result.append(curr_stats)
    return result


def load_nocurr_stats(output_dir: str) -> list[dict[str, LearningStats]]:
    """Load the stats of every no-curriculum run under ``output_dir``.

    Args:
        output_dir: agent output folder containing ``nocurriculum*`` run directories.

    Returns:
        One single-entry dict per run (sorted by directory name), mapping
        ``'No curriculum task'`` to the ``LearningStats`` loaded from
        ``<run_dir>/stats.stats.json``.
    """
    # Load only the no-curriculum results.
    return [
        {'No curriculum task': LearningStats.load(os.path.join(run_dir, 'stats.stats.json'))}
        for run_dir in _list_run_dirs(output_dir, 'nocurriculum')
    ]

In [27]:
# Load every run of both agents, with and without a curriculum.
dqn_stats_curr = load_curr_stats(DQN_OUTPUT_DIR)
ppo_stats_curr = load_curr_stats(PPO_OUTPUT_DIR)

dqn_stats_nocurr = load_nocurr_stats(DQN_OUTPUT_DIR)
ppo_stats_nocurr = load_nocurr_stats(PPO_OUTPUT_DIR)

# Fail fast when a run group is empty (e.g. the kernel was started from a
# different working directory, so the relative output paths point elsewhere)
# rather than passing empty run lists to the plots below.
empty_run_groups = [
    name
    for name, runs in [
        ('dqn_stats_curr', dqn_stats_curr),
        ('ppo_stats_curr', ppo_stats_curr),
        ('dqn_stats_nocurr', dqn_stats_nocurr),
        ('ppo_stats_nocurr', ppo_stats_nocurr),
    ]
    if not runs
]
assert not empty_run_groups, f'No runs found for: {empty_run_groups}'

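# Visualisations

For each agent (PPO first, then DQN), the no-curriculum runs are compared with the curriculum runs in the `steps` time domain: first with a standard-deviation band (`show_std=True`), then with the individual run traces (`show_run_traces=True`).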

In [28]:
from academia.tools import visualizations as vis

In [29]:
# PPO: no-curriculum vs. curriculum runs in the 'steps' time domain, with the std band enabled.
vis.plot_trajectories([ppo_stats_nocurr, ppo_stats_curr], show_std=True, time_domain='steps')

In [30]:
# PPO: no-curriculum vs. curriculum runs in the 'steps' time domain, showing individual run traces.
vis.plot_trajectories([ppo_stats_nocurr, ppo_stats_curr], show_run_traces=True, time_domain='steps')

In [31]:
# DQN: no-curriculum vs. curriculum runs in the 'steps' time domain, with the std band enabled.
vis.plot_trajectories([dqn_stats_nocurr, dqn_stats_curr], show_std=True, time_domain='steps')

In [32]:
# DQN: no-curriculum vs. curriculum runs in the 'steps' time domain, showing individual run traces.
vis.plot_trajectories([dqn_stats_nocurr, dqn_stats_curr], show_run_traces=True, time_domain='steps')