In [1]:
import json
import os
import numpy as np

In [2]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the parsed JSON
    contents (expected to be a Python dictionary). Returns None if that path is
    not an existing regular file, or if the file is empty or contains only
    whitespace. Content that is non-empty but not valid JSON still raises
    `json.JSONDecodeError`, so corrupted files are not silently hidden.
    Arguments:
        `models_path`: directory containing one subdirectory per run
        `run_num`: name/number of the run; converted to a string to build the
            path
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    # isfile (rather than exists) also rejects a directory at this path, which
    # would otherwise make open() fail
    if not os.path.isfile(path):
        return None
    with open(path, "r") as f:
        contents = f.read()
    if not contents.strip():
        # json.load raises on a blank file; treat it the same as a missing one
        return None
    return json.loads(contents)

In [3]:
def _is_nan(value):
    """
    Returns True if `value` is NaN. NaN is the only scalar that compares unequal
    to itself; anything that cannot be tested this way is treated as not NaN.
    """
    try:
        return bool(value != value)
    except Exception:
        return False


def get_best_metric(models_path, metric_extract_func, metric_compare_func):
    """
    Given the path to a set of runs, determines the run with the best metric value,
    where the metric value is fetched by `metric_extract_func`. This function must
    take the imported metrics JSON and return the (scalar) value to use for
    comparison. The best metric value is determined by `metric_compare_func`, which
    must take in two arguments, and return whether or not the _first_ one is better.
    Runs whose metric cannot be extracted are skipped with a printed warning. A
    NaN value is recorded, but never beats a non-NaN value; it is only reported as
    the best value if no run produced a non-NaN value.
    Returns the number of the run, the value associated with that run, and a
    dictionary mapping each run to the value used for comparison. The run and
    value are both None if no run yielded a value.
    """
    # Get the metrics, ignoring empty or nonexistent metrics.json files
    metrics = {}
    for run_num in os.listdir(models_path):
        run_metrics = import_metrics_json(models_path, run_num)
        if run_metrics:
            metrics[run_num] = run_metrics

    # Get the best value
    best_run, best_val, all_vals = None, None, {}
    for run_num, run_metrics in metrics.items():
        try:
            val = metric_extract_func(run_metrics)
        except Exception:
            print("Warning: Was not able to extract metric for run %s" % run_num)
            continue
        all_vals[run_num] = val

        if _is_nan(val):
            # Comparisons against NaN are always False, so a NaN that became the
            # incumbent could never be displaced by the usual compare function.
            # Only keep it as a placeholder when nothing has been chosen yet.
            if best_val is None:
                best_val, best_run = val, run_num
            continue

        if (
            best_val is None
            or _is_nan(best_val)  # Any real value beats a NaN placeholder
            or metric_compare_func(val, best_val)
        ):
            best_val, best_run = val, run_num
    return best_run, best_val, all_vals

In [4]:
# Pick the run whose mean of the first "summit_prof_nll" values entry is
# smallest (the comparison treats a lower value as better)
models_path = "/users/amtseng/tfmodisco/models/trained_models/SPI1/"
best_run, best_val, all_vals = get_best_metric(
    models_path,
    metric_extract_func=lambda metrics: np.mean(
        metrics["summit_prof_nll"]["values"][0]
    ),
    metric_compare_func=lambda x, y: x < y,
)
print("Best run: %s" % best_run)
print("Associated value: %s" % best_val)

Best run: 1
Associated value: 153.54793947399247


In [5]:
# List every run's value, ordered numerically by run number
for key in sorted(all_vals, key=int):
    print(key, all_vals[key])

1 153.54793947399247
2 226.16465030266193
3 154.67320674900674
4 226.16370786839664
5 226.16432397203155
6 166.8852087779198
7 153.92978690030122
8 226.16666338690862
9 156.80632273905965
10 162.54583435052035
11 226.1755197358771
12 157.01017634035634
14 154.1111828598224
15 226.16998388074
16 165.55251039109362
17 155.17937774691393
18 157.66106714824517
19 159.91977544091895
20 154.5057103706613
21 226.1637362810328
22 153.91011901812507
23 153.69814584858167
