In [1]:
import json
import os
import numpy as np

In [2]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the contents as a
    Python dictionary. Returns None if the path does not exist.
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    if not os.path.exists(path):
        return None
    with open(path, "r") as f:
        return json.load(f)

In [3]:
def get_best_metric(models_path, metric_extract_func, metric_compare_func):
    """
    Given the path to a set of runs, determines the run with the best metric value,
    where the metric value is fetched by `metric_extract_func`. This function must
    take the imported metrics JSON and return the (scalar) value to use for
    comparison. The best metric value is determiend by `metric_compare_func`, which
    must take in two arguments, and return whether or not the _first_ one is better.
    Returns the number of the run, the value associated with that run, and a list of
    all the values used for comparison.
    """
    # Get the metrics, ignoring empty or nonexistent metrics.json files
    metrics = {run_num : import_metrics_json(models_path, run_num) for run_num in os.listdir(models_path)}
    metrics = {key : val for key, val in metrics.items() if val}  # Remove empties
    
    # Get the best value
    best_run, best_val, all_vals = None, None, {}
    for run_num in metrics.keys():
        try:
            val = metric_extract_func(metrics[run_num])
        except Exception:
            print("Warning: Was not able to extract metric for run %s" % run_num)
            continue
        all_vals[run_num] = val
        if best_val is None or metric_compare_func(val, best_val):
            best_val, best_run = val, run_num
    return best_run, best_val, all_vals

In [9]:
models_path = "/users/amtseng/tfmodisco/models/trained_models/E2F6/"
best_run, best_val, all_vals = get_best_metric(
    models_path,
    lambda metrics: np.mean(metrics["summit_prof_nll"]["values"][0]),
    lambda x, y: x < y
)
print("Best run: %s" % best_run)
print("Associated value: %s" % best_val)
for run in sorted(all_vals.keys(), key=lambda x: int(x)):
    print("%s: %d" % (run, all_vals[run]))

Best run: 24
Associated value: 316.1255480083218
1: 319
2: 317
3: 375
4: 318
5: 375
6: 335
7: 317
8: 353
9: 375
10: 375
11: 375
12: 318
13: 375
14: 339
15: 317
16: 375
17: 317
18: 317
19: 375
20: 320
21: 322
22: 317
23: 327
24: 316
25: 318
26: 357
27: 316
28: 375
29: 347
30: 320
31: 375
32: 352
33: 317
34: 375
35: 375
36: 375
37: 375
38: 321
39: 375
40: 316
41: 375
42: 320
43: 375
44: 375
45: 318
46: 375
47: 319
48: 375
49: 375
50: 316
51: 319
52: 339
53: 325
54: 323
55: 340
56: 355
57: 318
58: 318
59: 375
60: 340
61: 319
63: 320
64: 329
65: 318
66: 324
67: 317
68: 318
69: 317
70: 319
71: 317
72: 375
73: 345
74: 318
75: 318
76: 345
77: 375
78: 346
79: 317
80: 318
81: 339
82: 375
83: 344
84: 375
85: 375
86: 316
87: 375
88: 375
89: 317
90: 375
91: 316
92: 336
93: 319
94: 375
95: 318
96: 342
97: 375
98: 321
99: 375
100: 326
101: 330
102: 329
103: 375
104: 319
105: 375
106: 375
107: 375
108: 375
109: 375
110: 318
111: 346
112: 318
