In [1]:
import sys

# Make the project root importable (assumes this notebook lives two levels
# below it, so that `academia` can be resolved). Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

# Loading data

In [2]:
import os

from academia.curriculum import LearningStats

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Directory scanned for run subdirectories (names prefixed 'full' or 'skip').
OUTPUT_DIR = './outputs/'
# Number of per-task stats files loaded for each curriculum variant.
N_TASKS_CURR_FULL, N_TASKS_CURR_SKIP = 3, 2

In [11]:
# Maps each curriculum variant's stats-file index ('1', '2', ...) to the
# legend label used on plots. The 'skip' curriculum omits the original
# task 2, so its second stats file corresponds to the original "Task 3".
task_names_maps = {
    'full': {str(task): f'Task {task}' for task in (1, 2, 3)},
    'skip': {
        str(position): f'Task {original_task}'
        for position, original_task in enumerate((1, 3), start=1)
    },
}

In [12]:
def _load_curr_stats(
    output_dir: str,
    curr_name: str,
    n_tasks: int,
) -> list[dict[str, LearningStats]]:
    """Load the per-task learning stats of every run of one curriculum variant.

    A run is any subdirectory of ``output_dir`` whose name starts with
    ``curr_name``; it must contain ``1.stats.json`` .. ``{n_tasks}.stats.json``.

    Args:
        output_dir: Directory holding one subdirectory per run.
        curr_name: Curriculum variant; used both as the directory-name prefix
            and as the key into ``task_names_maps`` for legend labels.
        n_tasks: Number of stats files to load per run.

    Returns:
        One dict per run (ordered by directory name), mapping the legend
        label of each task to its loaded ``LearningStats``.
    """
    task_names = task_names_maps[curr_name]
    result: list[dict[str, LearningStats]] = []
    # os.listdir returns entries in arbitrary order; sort so the order of runs
    # (and therefore of plotted traces) is reproducible across machines.
    for dir_name in sorted(os.listdir(output_dir)):
        dir_path = os.path.join(output_dir, dir_name)
        # Only directories of the requested variant are runs; a stray file
        # with a matching prefix would otherwise break the loading below.
        if not dir_name.startswith(curr_name) or not os.path.isdir(dir_path):
            continue
        curr_stats: dict[str, LearningStats] = {}
        for i in range(1, n_tasks + 1):
            stats_path = os.path.join(dir_path, f'{i}.stats.json')
            curr_stats[task_names[str(i)]] = LearningStats.load(stats_path)
        result.append(curr_stats)
    return result


def load_curr_full_stats(output_dir: str) -> list[dict[str, LearningStats]]:
    """Load stats of all full-curriculum runs (subdirectories named ``full*``)."""
    return _load_curr_stats(output_dir, 'full', N_TASKS_CURR_FULL)


def load_curr_skip_stats(output_dir: str) -> list[dict[str, LearningStats]]:
    """Load stats of all skip-curriculum runs (subdirectories named ``skip*``)."""
    return _load_curr_stats(output_dir, 'skip', N_TASKS_CURR_SKIP)

In [13]:
# Load every run of both curriculum variants found in OUTPUT_DIR.
stats_curr_full: list[dict[str, LearningStats]] = load_curr_full_stats(OUTPUT_DIR)
stats_curr_skip: list[dict[str, LearningStats]] = load_curr_skip_stats(OUTPUT_DIR)

# Fail loudly here instead of silently producing empty figures further down.
assert stats_curr_full, f"no 'full*' runs found in {OUTPUT_DIR!r}"
assert stats_curr_skip, f"no 'skip*' runs found in {OUTPUT_DIR!r}"

# Visualisations

In [15]:
from academia.tools import visualizations as vis

In [18]:
# Both curriculum variants are passed in one call with separate figures
# requested; the result is indexed later as figs[0] (full) and figs[1] (skip).
curricula_stats = [stats_curr_full, stats_curr_skip]
figs = vis.plot_trajectories(
    curricula_stats,
    time_domain='steps',
    as_separate_figs=True,
)

In [19]:
# Title and render the full-curriculum figure.
fig_full = figs[0]
fig_full.update_layout(title={'text': 'PPO on DoorKey: full curriculum'})
fig_full.show()

In [20]:
# Title and render the figure for the curriculum that skips task 2.
fig_skip = figs[1]
fig_skip.update_layout(title={'text': 'PPO on DoorKey: curriculum without task 2'})
fig_skip.show()