In [None]:
from importlib import reload
from pathlib import Path
from time import time

import fff
import fff.evaluate.bg as bg_eval
import pandas as pd
import torch
import yaml
from IPython.display import display
from tqdm.auto import tqdm, trange
import numpy as np

In [None]:
# Lightning-log directory names of the runs to evaluate, grouped by data set.
# Replace the entries below with the directories of your own runs.
runs = dict(
    dw4=[
        "v24-dw4 beta=10 bs=256 lc=20 lr=0.001 gc=1",
        "v28-dw4 beta=10 bs=256 lc=20 lr=0.001 gc=1",
        "v29-dw4 beta=10 bs=256 lc=20 lr=0.001 gc=1",
    ],
    lj13=[
        "v12-lj13 beta=200 bs=256 lc=8 lr=0.001 gc=1",
        "v31-lj13 beta=200 bs=256 lc=8 lr=0.001 gc=1",
        "v38-lj13 beta=200 bs=256 lc=8 lr=0.001 gc=1",
        "v39-lj13 beta=200 bs=256 lc=8 lr=0.001 gc=1",
    ],
    lj55=[
        "v34-lj55 beta=500 bs=56 lc=8 lr=0.001 gc=0.1",
        "v47-lj55 beta=500 bs=56 lc=8 lr=0.001 gc=0.1",
        "v59-lj55 beta=500 bs=56 lc=8 lr=0.001 gc=0.1",
    ],
)

# Number of samples used per data set to compute the Boltzmann generator metrics.
n_samples = dict(dw4=10_000, lj13=1_000, lj55=100)

# Evaluation batch size per data set; adjust these to fit your GPU memory.
batch_size = dict(dw4=1_000, lj13=300, lj55=10)

In [None]:
# Turn off gradient computation; the cells below only run inference.
torch.set_grad_enabled(mode=False)

In [None]:
# Evaluate every selected run: load its checkpoint, compute the Boltzmann
# generator NLL and time raw sampling. One DataFrame per data set is stored
# in `dfs`.
dfs = {}

device = "cuda"
time_repetitions = 10
reload(bg_eval)

# Resolved once so that latent sampling and timing synchronisation target the
# same device the models are moved to.
torch_device = torch.device(device)
# Batch sizes tried when timing raw sampling: 1, 10, 100, 1000, 10000.
timing_batch_sizes = [10 ** exponent for exponent in range(5)]

for dataset, run_names in runs.items():
    # Collect one checkpoint per selected run directory of this data set.
    ckpt_files = []
    for base_path in sorted(Path("lightning_logs_selected").absolute().iterdir()):
        if not any(k in base_path.name for k in run_names):
            continue
        ckpt_dir = base_path / "checkpoints"
        if not ckpt_dir.exists():
            continue
        ckpt_file = ckpt_dir / "last.ckpt"
        if not ckpt_file.exists():
            # Fall back to the entry whose file name sorts last.
            ckpt_file = max(ckpt_dir.iterdir(), key=lambda p: p.name, default=None)
            if ckpt_file is None:
                # Empty checkpoint directory: skip it; the count check below
                # reports the missing run.
                continue
        hparams_file = base_path / "hparams.yaml"
        if not hparams_file.is_file():
            continue
        with hparams_file.open() as f:
            hparams = yaml.safe_load(f)
        if hparams["data_set"]["name"] != dataset:
            continue
        ckpt_files.append(ckpt_file)

    assert len(run_names) == len(ckpt_files), (
        f"{dataset}: expected {len(run_names)} checkpoints, found {len(ckpt_files)}"
    )

    models = []
    data = []
    for ckpt_file in tqdm(ckpt_files):
        # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
        # Loading onto the CPU first keeps this independent of the GPU the
        # checkpoint was saved from; the model is moved to `device` below.
        ckpt = torch.load(ckpt_file, map_location="cpu")
        ckpt["hyper_parameters"]["data_set"]["root"] = "../data"
        model = fff.FreeFormFlow(ckpt["hyper_parameters"])
        model.load_state_dict(ckpt["state_dict"])
        model.to(device)

        model.hparams.batch_size = batch_size[dataset]

        # This creates a cache file
        dim, n_dimensions, n_particles, target = bg_eval._tgt_info(model)
        bg, bg_samples = bg_eval.sample_boltzmann(model, ckpt_file, n_samples[dataset])
        nll = bg_eval.nll(model, ckpt_file, bg)

        # Raw sampling time: best per-sample wall-clock time (in seconds) over
        # all batch sizes that run without error.
        raw_sampling_time = float("inf")
        latent = model.get_latent(torch_device)
        with torch.no_grad():
            for bs in tqdm(timing_batch_sizes):
                try:
                    # CUDA kernels are launched asynchronously, so wait for
                    # pending work before starting and before stopping the
                    # clock; otherwise only the launch overhead is measured.
                    if torch_device.type == "cuda":
                        torch.cuda.synchronize(torch_device)
                    start = time()
                    for _ in trange(time_repetitions):
                        z = latent.sample((bs,))[0].reshape(bs, n_particles, n_dimensions)
                        conditioned = model.apply_conditions([z])
                        x = model.decode(conditioned.x0, conditioned.condition)
                    if torch_device.type == "cuda":
                        torch.cuda.synchronize(torch_device)
                    elapsed = time() - start
                except RuntimeError as e:
                    # Presumably out-of-memory at a large batch size: report it
                    # and keep the best time measured so far.
                    print(e)
                    break
                raw_sampling_time = min(
                    elapsed / bs / time_repetitions,
                    raw_sampling_time
                )
        data.append({
            "run": ckpt_file.parents[1].name,
            "model": model,
            "ckpt_file": ckpt_file,
            "raw_sample_time": raw_sampling_time,
            # "log_prob_sample_time": bg_samples["times_np"].mean(),
            "nll": nll
        })

    dfs[dataset] = pd.DataFrame(data)

In [None]:
# Report the results per data set: the raw table, then mean ± std of every
# numeric column. Columns with "time" in their name are scaled by 1000
# (seconds -> milliseconds) for the summary only.
for dataset, df in dfs.items():
    print(dataset)
    display(df)
    # Scale on a derived frame instead of in place, so the frames stored in
    # `dfs` stay untouched and re-running this cell gives the same numbers.
    time_cols = [col for col in df if "time" in col]
    df_ms = df.assign(**{col: df[col] * 1000 for col in time_cols})
    stats = df_ms.set_index("run").describe()
    for col in stats:
        print(f"{col}: {stats[col]['mean']:.3f} ± {stats[col]['std']:.3f}")
    print()