In [1]:
import sys

# Make the code two directories up importable (presumably the repository root that
# contains the local `academia` package — TODO confirm). The path is relative to the
# kernel's current working directory. The guard keeps this cell idempotent: re-running
# it no longer appends a duplicate '../..' entry to sys.path each time.
if '../..' not in sys.path:
    sys.path.append('../..')

# Loading data

In [2]:
import os

from academia.curriculum import LearningStats

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Number of tasks stored per curriculum run: task 1 is reported as "X", task 2 as "Y".
CURRICULUM_N_TASKS = 2

# Experiment root folders; each is expected to contain one `episodes_<n>` sub-folder
# per episode budget that gets loaded below.
OUTPUT_DIR_EVAL = './outputs/eval/'
OUTPUT_DIR_TIME = './outputs/time/'

In [4]:
def load_stats_for_n_episodes(output_dir: str, n_episodes_x: int) -> list[dict[str, LearningStats]]:
    """Load the saved curriculum stats of every run for one task-X episode budget.

    Reads ``<output_dir>/episodes_<n_episodes_x>/<run_dir>/task<i>.stats.json`` for
    ``i`` in ``1..CURRICULUM_N_TASKS`` with ``LearningStats.load``.

    :param output_dir: experiment root folder, e.g. ``OUTPUT_DIR_EVAL`` or ``OUTPUT_DIR_TIME``.
    :param n_episodes_x: selects the ``episodes_<n_episodes_x>`` sub-folder to read.
    :return: one dict per run directory, ordered by directory name. Keys are
        ``'Curriculum task X'`` (task 1) and ``'Curriculum task Y'`` (task 2); if
        ``CURRICULUM_N_TASKS`` is raised above 2, later tasks are keyed
        ``'Curriculum task <i>'``.
    :raises FileNotFoundError: if the ``episodes_<n_episodes_x>`` folder does not exist
        (raised by ``os.listdir``).
    """
    # Tasks beyond the second get their number as label, so they can no longer
    # silently overwrite the 'Curriculum task Y' entry.
    task_labels = {1: 'X', 2: 'Y'}
    output_dir_n_episodes = os.path.join(output_dir, f'episodes_{n_episodes_x}')
    # os.listdir returns entries in arbitrary order: sort so runs are always loaded in
    # the same order, and skip stray files (e.g. .DS_Store) that are not run folders.
    run_dir_names = sorted(
        name for name in os.listdir(output_dir_n_episodes)
        if os.path.isdir(os.path.join(output_dir_n_episodes, name))
    )
    result: list[dict[str, LearningStats]] = []
    for run_dir_name in run_dir_names:
        run_dir_path = os.path.join(output_dir_n_episodes, run_dir_name)
        run_stats: dict[str, LearningStats] = {}
        # load each task's stats object
        for i in range(1, CURRICULUM_N_TASKS + 1):
            stats_path = os.path.join(run_dir_path, f'task{i}.stats.json')
            task_id = task_labels.get(i, str(i))
            run_stats[f'Curriculum task {task_id}'] = LearningStats.load(stats_path)
        result.append(run_stats)
    return result

In [5]:
def get_task_x_stats(curriculum_stats: list[dict[str, LearningStats]]) -> list[LearningStats]:
    """Extract the task-X stats from every curriculum run.

    :param curriculum_stats: one stats dict per run, as returned by ``load_stats_for_n_episodes``.
    :return: each run's ``'Curriculum task X'`` entry, in the same order as the input.
    :raises KeyError: if a run dict has no ``'Curriculum task X'`` entry.
    """
    task_x_runs = []
    for run_stats in curriculum_stats:
        task_x_runs.append(run_stats['Curriculum task X'])
    return task_x_runs

In [6]:
def get_task_y_stats(curriculum_stats: list[dict[str, LearningStats]]) -> list[LearningStats]:
    """Extract the task-Y stats from every curriculum run.

    :param curriculum_stats: one stats dict per run, as returned by ``load_stats_for_n_episodes``.
    :return: each run's ``'Curriculum task Y'`` entry, in the same order as the input.
    :raises KeyError: if a run dict has no ``'Curriculum task Y'`` entry.
    """
    task_y_runs = []
    for run_stats in curriculum_stats:
        task_y_runs.append(run_stats['Curriculum task Y'])
    return task_y_runs

In [7]:
# Episode budgets to compare; each needs an `episodes_<n>` folder under both
# OUTPUT_DIR_EVAL and OUTPUT_DIR_TIME.
n_episodes_x = [1500, 2500, 3500, 4500, 5500]

# One entry per budget: the list of per-run stats dicts for that budget.
eval_stats = [load_stats_for_n_episodes(OUTPUT_DIR_EVAL, budget) for budget in n_episodes_x]
time_stats = [load_stats_for_n_episodes(OUTPUT_DIR_TIME, budget) for budget in n_episodes_x]

# The evaluation-impact plot below only consumes task-Y stats.
eval_task_y_stats = list(map(get_task_y_stats, eval_stats))

# The time-impact plots below consume both tasks.
time_task_x_stats = list(map(get_task_x_stats, time_stats))
time_task_y_stats = list(map(get_task_y_stats, time_stats))

# Visualisations

In [8]:
from academia.tools import visualizations as vis

In [9]:
# Evaluation-impact plot: task-Y runs grouped by the budgets in `n_episodes_x`.
# NOTE(review): unlike the two plot_time_impact calls, this save_path has no '.svg'
# suffix — confirm whether `vis` appends the extension itself and make the calls consistent.
vis.plot_evaluation_impact(
    n_episodes_x=n_episodes_x,
    task_runs_y=eval_task_y_stats,
    save_path='./outputs/plots/lava',
    save_format='svg',
)

In [10]:
# Time-impact plot with task Y's time domain measured in episodes
# (same as the next cell except for `time_domain_y` and the output file name).
vis.plot_time_impact(
    task_runs_x=time_task_x_stats,
    task_runs_y=time_task_y_stats,
    time_domain_y='episodes',
    save_path='./outputs/plots/lava_xepisodes.svg',
    save_format='svg',
)

In [11]:
# Time-impact plot with task Y's time domain measured in steps
# (same as the previous cell except for `time_domain_y` and the output file name).
vis.plot_time_impact(
    task_runs_x=time_task_x_stats,
    task_runs_y=time_task_y_stats,
    time_domain_y='steps',
    save_path='./outputs/plots/lava_xsteps.svg',
    save_format='svg',
)