In [1]:
import os
import sys
from pathlib import Path

# Make the directory two levels above the current working directory (presumably the
# repository root, since the notebook lives two folders deep) importable, so that the
# `evariste` imports in the cells below resolve.
sys.path.append(str(Path.cwd().parent.parent))

## Grabbing results from a grid search

Use `evariste.trainer.utils.get_experiments`, which collects all the experiments of a grid, given a list of experiment names and the dump paths to search.

In [8]:
from evariste.trainer.utils import get_experiments

# Directories to search for dumped experiments. Set the EVARISTE_DUMP_PATH
# environment variable (or edit the default below) instead of hardcoding your path.
dump_paths = [
    os.environ.get('EVARISTE_DUMP_PATH', 'YOUR_PATH/dumped/'),
]
# Names of the experiments to collect from `dump_paths`.
# NOTE(review): [''] presumably matches every experiment found — confirm in get_experiments.
exps = ['']
experiments = get_experiments(exps, dump_paths)

print(f"Loaded {len(experiments)} experiments")

Feel free to explore an experiment object. Its params are stored in a flat-dict format; the most interesting part of each experiment is its `logs` key.

Let's manually plot the results from one of the experiments we grabbed.

In [6]:
from matplotlib import pyplot as plt

# Log fields plotted on the x and y axes of the single-experiment plot in this cell.
x_key = 'hours'
y_key = 'valid-eq_equality_equiv_seq2seq-proven'

def grab_data_from_logs(logs, x_name=None, y_name=None):
    """Extract the (x, y) series to plot from an experiment's logs.

    Args:
        logs: iterable of per-epoch log entries, each indexable by field name
            (e.g. a dict). It is consumed in a single pass, so a generator works too.
        x_name: log field used for the x axis. Defaults to the module-level
            `x_key`, looked up at call time.
        y_name: log field used for the y axis. Defaults to the module-level
            `y_key`, looked up at call time.

    Returns:
        A tuple `(x, y)` of lists holding one value per log entry, in log order.
    """
    # Resolve the defaults lazily so that re-assigning x_key / y_key in the
    # notebook is still honoured by callers that omit the arguments.
    if x_name is None:
        x_name = x_key
    if y_name is None:
        y_name = y_key
    x, y = [], []
    for entry in logs:
        x.append(entry[x_name])
        y.append(entry[y_name])
    return x, y

# '36560935' is the job id this notebook was written against; fall back to the
# first loaded experiment so the cell also works on other grids.
EXP_ID = '36560935'
if not experiments:
    raise ValueError("No experiments were loaded: check `exps` and `dump_paths` above.")
exp_id = EXP_ID if EXP_ID in experiments else next(iter(experiments))

fig, ax = plt.subplots(dpi=100)
ax.plot(*grab_data_from_logs(experiments[exp_id]['logs']))
ax.set_xlabel("Hours")
ax.set_ylabel("% valid theorem proven")
ax.set_title(f"Experiment {exp_id}")
plt.show()

To make life easier, `plot_experiments` is a rather complex function that lets you explore the results of a grid search. Take a look at the cell below and experiment with it.

In [7]:
from evariste.trainer.plot_utils import plot_experiments

def filter_data(data, metric_key=None):
    """
    Select the points whose metric value was actually computed.

    Since the backward eval is finicky, sometimes it's interrupted and the value logged is -1.
    Those sentinel values are ignored in the plot, while genuine results (including 0) are kept.

    Args:
        data: the data passed in by `plot_experiments`, indexed by metric name.
            NOTE(review): presumably per-epoch values or arrays — confirm in plot_utils.
        metric_key: metric to check. Defaults to the module-level `key`, looked up at call time.

    Returns:
        `True & (data[metric_key] >= 0)`: a bool for scalar values, or an element-wise
        mask for numpy arrays / pandas Series.
    """
    if metric_key is None:
        metric_key = key
    select = True
    # `>= 0` rather than `> 0`: a 0 (nothing proven yet) is a real data point,
    # only the negative sentinel marks an interrupted eval.
    select &= data[metric_key] >= 0
    return select


def exp_filter(exp):
    """
    Decide whether an experiment of the grid should be plotted.

    As written, every experiment is kept. To plot only part of the grid, add
    conditions on `exp` before returning, for example on the optimizer:

        keep &= '0.0003' not in exp['params']['optimizer']
    """
    keep = True
    return keep


# Grid-search params starting with these prefixes are not printed by plot_experiments.
IGNORE_PREFIX = [
    'slurm_conf',
    'master_port',
    'command',
]
# Metric plotted by plot_experiments and checked by filter_data.
key = 'valid-eq_equality_equiv_seq2seq-proven'



# Use mpld3 for nice interactive plots you can hover on
import mpld3
from mpld3 import plugins

mpld3.enable_notebook()
# IPython.display is the public location of display/HTML; importing `display`
# from IPython.core.display is deprecated (since IPython 7.14).
from IPython.display import display, HTML
# Widen the notebook container, presumably so the wide interactive figures fit.
# NOTE(review): 600% forces horizontal scrolling; 100% is the usual value — confirm intent.
display(HTML("<style>.container { width:600% !important; }</style>"))



plot_experiments(
    'epoch',                      # log field used as the x coordinate
    [key],                        # log field(s) plotted on the y axis
    experiments.values(),         # every experiment loaded from the grid
    dump_paths=dump_paths,
    ignore_prefix=IGNORE_PREFIX,  # grid params with these prefixes are not printed
    exp_filter=exp_filter,        # keeps or drops whole experiments
    filter_data=filter_data,      # keeps or drops individual points
    higher_better=1,              # how the results printed below are sorted
    with_legend=False,
    repeat_print=False,
)