In [None]:
import dataclasses
import glob
import json
import os
import re
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

from totypes import json_utils

from invrs_gym import challenges

In [None]:
# Launch an experiment via the experiment script, writing its outputs under
# `../experiments` -- the same directory the next cell reads results back from.
# NOTE(review): the loader expects one subdirectory per work unit containing
# `checkpoint_*.json` files and a `setup.json`; confirm against experiment.py.
!python ../scripts/experiment.py --path="../experiments"

In [None]:
# Recover logged scalars and parameters.
experiment_path = "../experiments"
wid_paths = glob.glob(experiment_path + "/*")
wid_paths.sort()

scalars = {}
hparams = {}
params = {}

for path in wid_paths:
    print(path)
    name = path.split("/")[-1]
    checkpoint_fname = glob.glob(path + "/checkpoint_*.json")
    if not checkpoint_fname:
        continue
    checkpoint_fname.sort()
    with open(checkpoint_fname[-1], "r") as f:
        checkpoint = json_utils.pytree_from_json(f.read())
    scalars[name] = checkpoint["scalars"]
    params[name] = checkpoint["params"]
    with open(path + "/setup.json", "r") as f:
        hparams[name] = json.load(f)

In [None]:
# Plot, for each experiment, the efficiency trajectory (left) and the final
# optimized design (right).
#
# The number of rows comes from `scalars` (experiments that actually produced a
# checkpoint) rather than `wid_paths`; otherwise directories skipped by the
# loader would leave blank rows at the bottom of the figure.
rows = len(scalars)
if rows == 0:
    print("No checkpoints found; nothing to plot.")
else:
    fig, axes = plt.subplots(rows, 2, figsize=(6, 3 * rows), squeeze=False)
    for (ax_eff, ax_design), wid in zip(axes, scalars):
        # Assumes the logged scalars are per-step arrays (they are scaled and
        # boolean-indexed below).
        efficiency = scalars[wid]["average_efficiency"] * 100
        # Steps where the logged "distance" is non-positive get a marker.
        # NOTE(review): presumably these are steps where the design satisfies
        # the fabrication constraints -- confirm.
        mask = scalars[wid]["distance"] <= 0
        step = onp.arange(1, len(efficiency) + 1)
        (line,) = ax_eff.plot(step, efficiency)
        ax_eff.plot(step[mask], efficiency[mask], "o", color=line.get_color())
        ax_eff.set_xlabel("step")
        ax_eff.set_ylabel("Efficiency (%)")
        ax_eff.set_title(wid)

        # Design densities are displayed on a fixed [0, 1] grayscale range.
        im = ax_design.imshow(params[wid].array, cmap="gray")
        im.set_clim([0, 1])
        ax_design.axis(False)

    fig.tight_layout()