In [1]:
import json
import numpy as np
import plotly.graph_objects as go

In [2]:
# Load the result records. JSON text is UTF-8, so say so explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("./results.json", encoding="utf-8") as results_file:
    results = json.load(results_file)

# Later cells index results[0] to discover the hyperparameter names.
assert results, "results.json contains no entries"
len(results)

256

In [3]:
def total_balance(result):
    """Return the ``"total_balance"`` value of a single result record.

    Used as a ``key=`` function and to extract the balance column for plotting.
    """
    return result["total_balance"]


# Entry with the highest total balance.
max(results, key=total_balance)

{'hyperparameters': {'exploration_level': 1,
  'discount_factor': 0.9987111413812587,
  'replay_buffer_capacity': 16,
  'batch_size': 4,
  'train_steps_per_turn': 256,
  'learning_rate': 0.00010473113593873374,
  'steps_per_target_update': 2048},
 'total_balance': 0.9056626857712563}

In [4]:
def hyperparameter_values(hyperparameter, records=None):
    """Collect one hyperparameter's value from every result record into a NumPy array.

    Parameters
    ----------
    hyperparameter : str
        Key looked up inside each record's ``"hyperparameters"`` dict.
    records : list of dict, optional
        Result records to read from. Defaults to the notebook-level ``results``.

    Returns
    -------
    numpy.ndarray
        The values, in the same order as ``records``.
    """
    if records is None:
        records = results
    return np.array([record["hyperparameters"][hyperparameter] for record in records])

def order(array):
    """Return the 0-based rank of each element of a 1-D sequence.

    For example, ``order([0.3, 0.1, 0.2])`` gives ``array([2, 0, 1])``.
    Ties are ranked by their original position (stable sort), so the result is
    deterministic. The output is always an integer array, whatever the input dtype.
    """
    values = np.asarray(array)
    # Allocate an integer result explicitly: np.zeros_like would inherit the
    # input dtype and corrupt the ranks for bool or string inputs.
    ranks = np.empty(len(values), dtype=np.intp)
    # argsort gives the indices that would sort `values`. Scattering 0..n-1 into
    # those positions inverts that permutation, which yields each element's rank.
    ranks[np.argsort(values, kind="stable")] = np.arange(len(values))
    return ranks

def line_params(
    array: np.ndarray,
    transform,
    inverse_transform,
    num_ticks: int = 3,
    tick_format: str = ".3g",
):
    """Build the plotly ``Parcoords`` dimension fields for one axis on a transformed scale.

    Values are positioned in transformed space (e.g. log space), while the tick
    labels show the untransformed values. A transformed axis therefore still
    reads in natural units.

    Parameters
    ----------
    array : numpy.ndarray
        Raw values for this axis, one per plotted line.
    transform : callable
        Elementwise map from raw values to axis positions (e.g. ``np.log``).
    inverse_transform : callable
        Inverse of ``transform``; converts tick positions back to raw values for the labels.
    num_ticks : int, default 3
        Number of ticks, evenly spaced in transformed space from the minimum to the maximum.
    tick_format : str, default ".3g"
        Format spec applied to each tick label.

    Returns
    -------
    dict
        ``tickvals``, ``ticktext`` and ``values`` entries, ready to be ``**``-splatted
        into a dimension dict.
    """
    transformed = transform(array)
    tick_values = np.linspace(transformed.min(), transformed.max(), num=num_ticks)
    true_tick_values = inverse_transform(tick_values)
    return dict(
        tickvals=tick_values,
        # Use a precision (".3g"), not a width ("2g" would still print up to 6
        # significant digits). Three significant digits keep values such as a
        # 0.999 discount factor from rounding up to "1".
        ticktext=[format(value, tick_format) for value in true_tick_values],
        # Reuse the already-transformed values rather than transforming twice.
        values=transformed,
    )

def _axis_transforms(hyperparameter):
    """Return the ``transform``/``inverse_transform`` pair for a hyperparameter's axis.

    ``discount_factor`` uses ``-log(1 - x)``, which spreads out values close to 1
    (e.g. 0.99 vs 0.999). Every other hyperparameter uses a plain log scale.
    """
    if hyperparameter == "discount_factor":
        return dict(
            transform=lambda x: -np.log(1 - x),
            inverse_transform=lambda y: 1 - np.exp(-y),
        )
    return dict(transform=np.log, inverse_transform=np.exp)


# Balance of every record, computed once and shared by the colour and the axis.
# NOTE(review): the log-scaled axes assume strictly positive values; confirm
# that balances and hyperparameters can never be <= 0.
balances = np.array([total_balance(result) for result in results])

go.Figure(
    data=[
        go.Parcoords(
            # Colour each line by its balance rank (0 = lowest) rather than the raw
            # balance, so colours are spread evenly across all records.
            line=dict(
                color=order(balances),
                showscale=True,
                colorbar=dict(title=dict(text="Balance rank")),
            ),
            dimensions=[
                dict(
                    label=hyperparameter.replace("_", " ").capitalize(),
                    **line_params(
                        hyperparameter_values(hyperparameter),
                        **_axis_transforms(hyperparameter),
                    ),
                )
                for hyperparameter in results[0]["hyperparameters"]
            ]
            + [
                dict(
                    label="Balance",
                    **line_params(balances, transform=np.log, inverse_transform=np.exp),
                )
            ],
        ),
    ],
    layout=dict(title="Hyperparameters vs. total balance"),
)
