In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import ipywidgets
from ipywidgets import interact
import os
import re
import json

In [None]:
# All experiment outputs live under the directory named by $SCRATCH
# (a KeyError is raised right here if that variable is not set).
SCRATCH_DIR = pathlib.Path(os.environ["SCRATCH"])
BASE_DIR = SCRATCH_DIR.joinpath("spring_mesh_runs")

In [None]:
# Reference time-step size; the loading cell keeps only runs whose dataset uses this step.
BASE_STEP_SIZE = 0.3 / 100.0

In [None]:
STEP = 500  # index of the error sample to plot, or "mean" to average over all samples
MEASURE = "raw_l2"

series = {}
# Raw string: "\d" in a plain literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12).
num_traj_re = re.compile(r"-n(?P<numtraj>\d+)-t(?P<numstep>\d+)-")
# sorted() so the order of points inside each series is reproducible across runs
for eval_dir in sorted((BASE_DIR / "run" / "eval").glob("*")):
    # Ensure the evaluation finished
    if not (eval_dir / "launch" / "done_token.txt").is_file():
        continue
    # Keep only evaluation directories whose name contains "outdist".
    # NOTE(review): the original comment read "Ensure not non-standard configuration"
    # and the plot cell titles its figure "in-distribution"; confirm this condition
    # is not inverted.
    if "outdist" not in eval_dir.name:
        continue

    # Read the small metadata files first so runs dropped by the step-size filter
    # never open their trajectory archive.
    run_desc_path = eval_dir / "launch" / "run_description.json"
    with open(run_desc_path, "r", encoding="utf8") as run_desc_file:
        run_desc = json.load(run_desc_file)
    eval_set_path = BASE_DIR / run_desc["phase_args"]["eval_data"]["data_dir"]
    integrator = run_desc["phase_args"]["eval"]["integrator"]
    eval_type = run_desc["phase_args"]["eval"]["eval_type"]
    data_meta_path = eval_set_path / "system_meta.json"
    with open(data_meta_path, "r", encoding="utf8") as data_meta_file:
        data_meta = json.load(data_meta_file)
    time_step_size = data_meta["system_args"]["trajectory_defs"][0]["time_step_size"]
    # Trajectory count parsed from the directory name; None when the name has no
    # "-n<N>-t<M>-" tag (previously this crashed the whole loop).
    name_match = num_traj_re.search(eval_dir.name)
    num_traj = int(name_match["numtraj"]) if name_match else None
    # TODO: REMOVE FILTER!
    # Tolerant comparison: the step size arrives as a float parsed from JSON
    if not np.isclose(time_step_size, BASE_STEP_SIZE):
        continue

    # Series key: (eval method, time-step size, integrator)
    key = (eval_type, time_step_size, integrator)
    runs = series.setdefault(key, [])
    data_path = eval_dir / "integrated_trajectories.npz"
    meta_path = eval_dir / "results_meta.json"
    with open(meta_path, "r", encoding="utf8") as meta_file:
        meta = json.load(meta_file)
    with np.load(data_path) as data_file:
        for traj_num, traj_meta in enumerate(meta["integration_stats"]):
            # Elapsed integration time recorded for this trajectory (plot x-axis)
            traj_time = traj_meta["timing"]["integrate_elapsed"]
            errors = data_file[traj_meta["file_names"][MEASURE]]
            err = np.mean(errors) if STEP == "mean" else errors[STEP]
            # (eval_dir, traj_num) lets later cells reload the exact trajectory
            runs.append((traj_time, err, (eval_dir, traj_num)))

In [None]:
%matplotlib notebook
# Build lookups for colors and markers
selected_point = []
time_sizes = sorted({t for (_, t, _) in series.keys()})

integrators = sorted({it for (_, _, it) in series.keys()})
marker_idxs = {k: i for i, k in enumerate(integrators)}

markers = ["<", ".", "x", "*", "1", "+", "^"]

methods = sorted({m for (m, _, _) in series.keys()})
colors = {k: f"C{i}" for i, k in enumerate(methods)}

defined_legend = set()
fig = plt.figure(figsize=(15, 8), facecolor="white")
for key in sorted(series.keys()):
    data = series[key]
    method, time_size, integrator = key
    color = colors[method]
    marker = markers[marker_idxs[integrator]]
    times = [t for (t, _e, _) in data]
    errs = [e for (_t, e, _) in data]
    legend_key = f"{method} {integrator}"
    if integrator == "scipy-RK45":
        continue
    if legend_key not in defined_legend:
        legend_args = {"label": legend_key}
        defined_legend.add(legend_key)
    else:
        legend_args = {}
    plt_group = plt.scatter(times, errs, color=color, marker=marker, **legend_args)
    plt_group.series_key = key
    plt_group.set_picker(5)
plt.grid(True)
plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.xlabel("Total int time (sec)")
plt.ylabel("Error")
    
plt.yscale("log")
plt.xscale("log")

plt.tight_layout()

plt.scatter([0], [0], color="black", marker="o")

def onpick(event):
    thisline = event.artist
    ind = event.ind
    selected_point.clear()
    selected_point.append((ind, event))
    
fig.canvas.mpl_connect('pick_event', onpick)

plt.title(f"Spring mesh step {STEP} {MEASURE} in-distribution")
plt.savefig(f"errtime-springmesh-step{STEP}-{MEASURE}-indist.png", transparent=False, bbox_inches="tight")

In [None]:
%matplotlib inline
# Retrieve key to plot
key = selected_point[0][1].artist.series_key
method, time_size, integrator = key
idxs = selected_point[0][0].tolist()
selected_series = series[key]
# Extract point
time, error, (eval_dir, traj_num) = selected_series[idxs[0]]
data_path = eval_dir / "integrated_trajectories.npz"
meta_path = eval_dir / "results_meta.json"

traj_file = np.load(data_path)
meta = json.load(meta_path.open())
traj = f"traj_{traj_num:05}"

run_desc_path = eval_dir / "launch" / "run_description.json"
with open(run_desc_path, "r", encoding="utf8") as run_desc_file:
    run_desc = json.load(run_desc_file)
    eval_set_path = BASE_DIR / run_desc["phase_args"]["eval_data"]["data_dir"]
    eval_data = np.load(eval_set_path / "trajectories.npz")
    
# Finally do the plot
p_true = eval_data[f"{traj}_p"]
q_true = eval_data[f"{traj}_q"]
edge_idx = eval_data[f"{traj}_edge_indices"].T
p = traj_file[f"{traj}_p"].reshape(p_true.shape)
q = traj_file[f"{traj}_q"].reshape(q_true.shape)

@interact(t=(0, p.shape[0] - 1))
def plot_traj(t):
    fig = plt.figure(figsize=(8,6), facecolor="white", dpi=100)
    plt.xlim(-1, 3)
    plt.ylim(-1, 3)
    pos = q[t]
    pos_true = q_true[t]
    # Plot points
    plt.scatter(pos[:, 0], pos[:, 1], marker='o', color="C0")
    plt.scatter(pos_true[:, 0], pos_true[:, 1], marker='x', color="C2")
    # Plot edges
    for edge_i in range(edge_idx.shape[0]):
        a = edge_idx[edge_i, 0]
        b = edge_idx[edge_i, 1]
        if a <= b:
            plt.plot(pos[(a, b), 0], pos[(a, b), 1], color="C1", linestyle=':')
    plt.grid(True)
    plt.title(f"{method}, int: {integrator}, Trajectory {traj_num}, step={t}")

In [None]:
methods = sorted({s for s, ts, i in series.keys()})
time_sizes = sorted({ts for s, ts, i in series.keys()})
integrators = sorted({i for s, ts, i in series.keys()})


@interact(method=methods, time_size=time_sizes, integrator=integrators)
def plot_traj(method, time_size, integrator):
    """Browse raw p/q time series of one trajectory, predicted (solid) vs. true (dotted).

    Widgets cascade: (method, time_size, integrator) -> evaluation directory -> trajectory.
    NOTE(review): this reuses the name of the mesh viewer ``plot_traj`` from the
    previous cell and shadows it once this cell runs.
    """
    key = (method, time_size, integrator)
    runs = series.get(key)
    if not runs:
        # The dropdowns form a cross product; not every combination was evaluated
        print(f"No runs for {key}")
        return
    # Entries are (time, err, (eval_dir, traj_num)); list each directory once.
    # (Stringifying the whole tuple produced paths that could not be loaded.)
    files = sorted({str(run_dir) for _, _, (run_dir, _) in runs})

    @interact(file=files)
    def select_file(file):
        eval_dir = pathlib.Path(file)
        data_path = eval_dir / "integrated_trajectories.npz"
        meta_path = eval_dir / "results_meta.json"
        # Left open on purpose: select_traj reads from it on every dropdown change
        traj_file = np.load(data_path)
        with open(meta_path, "r", encoding="utf8") as meta_file:
            meta = json.load(meta_file)
        trajs = sorted({tn["name"] for tn in meta["integration_stats"]})

        run_desc_path = eval_dir / "launch" / "run_description.json"
        with open(run_desc_path, "r", encoding="utf8") as run_desc_file:
            run_desc = json.load(run_desc_file)
        eval_set_path = BASE_DIR / run_desc["phase_args"]["eval_data"]["data_dir"]
        # Left open on purpose, like traj_file
        eval_data = np.load(eval_set_path / "trajectories.npz")

        def _plot_group(values, label, color, linestyle):
            """Plot every column of ``values`` but give only the first line a legend label."""
            lines = plt.plot(values, color=color, linestyle=linestyle)
            if lines:
                lines[0].set_label(label)

        @interact(traj=trajs)
        def select_traj(traj):
            # Finally do the plot
            p_true = eval_data[f"{traj}_p"]
            q_true = eval_data[f"{traj}_q"]
            num_steps = q_true.shape[0]
            # Flatten per-node axes to (steps, -1): plt.plot rejects arrays with more
            # than 2 dimensions, and the integrated output may be stored in a different
            # layout than the ground truth (the mesh viewer cell reshapes it as well).
            q = traj_file[f"{traj}_q"].reshape(num_steps, -1)
            p = traj_file[f"{traj}_p"].reshape(p_true.shape[0], -1)
            _plot_group(q, "q", "C0", "-")
            _plot_group(q_true.reshape(num_steps, -1), "q-true", "C0", ':')
            _plot_group(p, "p", "C1", "-")
            _plot_group(p_true.reshape(p_true.shape[0], -1), "p-true", "C1", ':')
            plt.legend()
            plt.grid(True)