In [1]:
import sys

sys.path.append('../..')

# Loading data

In [2]:
import os

from academia.curriculum import LearningStats

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Root directory holding one sub-directory per curriculum run (plots go in a sub-directory too).
OUTPUT_DIR: str = './outputs/'
# Number of task stats files ('1.stats.json', '2.stats.json', ...) saved per run of each curriculum.
N_TASKS_CURR_FULL: int = 3
N_TASKS_CURR_SKIP: int = 2

In [4]:
# Human-readable legend labels for the plots.
# For each curriculum, maps a task's file index ('1' -> '1.stats.json') to the
# environment difficulty that task was trained on. The 'skip' curriculum omits
# difficulty=1, so its second task is difficulty=2.
task_names_maps = {
    curriculum: {
        str(task_idx): f'difficulty={level}'
        for task_idx, level in enumerate(levels, start=1)
    }
    for curriculum, levels in (
        ('full', (0, 1, 2)),
        ('skip', (0, 2)),
    )
}

In [5]:
def _load_curr_stats(output_dir: str, curr_name: str, n_tasks: int) -> list[dict[str, LearningStats]]:
    """Load per-task learning stats for every run of a single curriculum.

    Each sub-directory of ``output_dir`` whose name starts with ``curr_name`` is
    treated as one run. Files ``1.stats.json`` .. ``{n_tasks}.stats.json`` are
    loaded from it.

    Args:
        output_dir: Directory containing one sub-directory per curriculum run.
        curr_name: Curriculum prefix (e.g. ``'full'``/``'skip'``); also the key
            into ``task_names_maps`` used to label the tasks.
        n_tasks: Number of task stats files expected in each run directory.

    Returns:
        One dict per run, mapping the task's legend label to its loaded
        ``LearningStats``. Runs are ordered by directory name so that the
        result (and any plot built from it) is reproducible.
    """
    task_names = task_names_maps[curr_name]
    result: list[dict[str, LearningStats]] = []
    # os.listdir returns entries in arbitrary order; sort for a deterministic run order.
    for dir_name in sorted(os.listdir(output_dir)):
        dir_path = os.path.join(output_dir, dir_name)
        # Skip other curricula, plus any stray file that merely shares the prefix.
        if not dir_name.startswith(curr_name) or not os.path.isdir(dir_path):
            continue
        curr_stats: dict[str, LearningStats] = {}
        # Task files are numbered from 1.
        for i in range(1, n_tasks + 1):
            stats_path = os.path.join(dir_path, f'{i}.stats.json')
            curr_stats[task_names[str(i)]] = LearningStats.load(stats_path)
        result.append(curr_stats)
    return result


def load_curr_full_stats(output_dir: str) -> list[dict[str, LearningStats]]:
    """Load stats of every 'full' curriculum run (N_TASKS_CURR_FULL tasks each) in ``output_dir``."""
    return _load_curr_stats(output_dir, 'full', N_TASKS_CURR_FULL)


def load_curr_skip_stats(output_dir: str) -> list[dict[str, LearningStats]]:
    """Load stats of every 'skip' curriculum run (N_TASKS_CURR_SKIP tasks each) in ``output_dir``."""
    return _load_curr_stats(output_dir, 'skip', N_TASKS_CURR_SKIP)

In [6]:
stats_curr_full: list[dict[str, LearningStats]] = load_curr_full_stats(OUTPUT_DIR)
stats_curr_skip: list[dict[str, LearningStats]] = load_curr_skip_stats(OUTPUT_DIR)

# Fail fast with a clear message instead of handing empty lists to the plotting cell.
assert stats_curr_full, f"no 'full*' curriculum runs found in {OUTPUT_DIR!r}"
assert stats_curr_skip, f"no 'skip*' curriculum runs found in {OUTPUT_DIR!r}"

# Visualisations

In [7]:
from academia.tools import visualizations as vis

In [8]:
# Plots are saved under OUTPUT_DIR/plots. Create that directory up front,
# in case plot_trajectories does not create missing directories itself.
plots_dir = os.path.join(OUTPUT_DIR, 'plots')
os.makedirs(plots_dir, exist_ok=True)

# Compare the two curricula; list order is [full, skip].
figs = vis.plot_trajectories(
    [stats_curr_full, stats_curr_skip],
    as_separate_figs=True,
    time_domain='steps',
    task_trace_start='mean',
    show_std=True,
    show_stop_time=True,
    show=True,
    save_format='svg',
    save_path=os.path.join(plots_dir, 'doorkey.svg'),
)