# Map between two states and get free energy

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.6
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.6


In [3]:
import jax
import equinox as eqx
import logging
import git
import numpy as np
import matplotlib.pyplot as plt
from jax import numpy as jnp
from dataclasses import asdict
from typing import cast

from rigid_flows.data import DataWithAuxiliary
from rigid_flows.density import KeyArray, OpenMMDensity
from rigid_flows.flow import (
    RigidWithAuxiliary,
    build_flow,
    initialize_actnorm,
    toggle_layer_stack,
)
from rigid_flows.reporting import Reporter, pretty_json
from rigid_flows.specs import ExperimentSpecification
from rigid_flows.train import run_training_stage

from flox._src.flow.api import Transform
from flox.flow import Pipe
from flox.util import key_chain


2023-02-09 13:34:23.389026: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/mi/minvernizzi/.local/lib
2023-02-09 13:34:23.389191: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/mi/minvernizzi/.local/lib




In [4]:
# Let the root logger emit INFO records so the logging.info progress messages below are shown.
logging.getLogger().setLevel(logging.INFO)

def setup_model(key: KeyArray, specs: ExperimentSpecification):
    """Build the base density, the target density and the flow described by ``specs``.

    Args:
        key: PRNG key; it seeds the key chain used to build the flow and for the
            optional ActNorm initialisation.
        specs: experiment specification; ``specs.model`` and
            ``specs.act_norm_init_samples`` are read here.

    Returns:
        The tuple ``(base, target, flow)``.
    """
    keys = key_chain(key)

    def load_density(density_specs):
        # Use only the leading `num_samples` frames when a limit is configured,
        # otherwise pass the full slice.
        limit = density_specs.num_samples
        if limit is None:
            frames = np.s_[:]
        else:
            logging.info(f'  taking only {limit} samples from MD')
            frames = np.s_[:limit]
        return OpenMMDensity.from_specs(
            specs.model.auxiliary_shape, density_specs, frames
        )

    logging.info("Loading base density.")
    base = load_density(specs.model.base)

    logging.info("Loading target density.")
    target = load_density(specs.model.target)

    logging.info("Setting up flow model.")
    flow = build_flow(next(keys), specs.model.auxiliary_shape, specs.model.flow)

    if specs.act_norm_init_samples is not None:
        logging.info("Initializing ActNorm")

        @eqx.filter_jit
        def init_actnorm(flow, key):
            # One target sample per split key forms the initialisation batch.
            batched_sampler = jax.vmap(target.sample)
            actnorm_batch = batched_sampler(
                jax.random.split(key, specs.act_norm_init_samples)
            ).obj
            # toggle_layer_stack is applied with False before and True after the
            # initialisation (presumably switching the layer stack off and on again).
            initialized, _ = initialize_actnorm(
                toggle_layer_stack(flow, False), actnorm_batch
            )
            return toggle_layer_stack(initialized, True)

        flow = init_actnorm(flow, next(keys))

    checkpoint = specs.model.pretrained_model_path
    if checkpoint is not None:
        logging.info(
            f"Loading pre-trained model from {checkpoint}."
        )
        # The freshly built flow serves as the template for deserialisation.
        restored = eqx.tree_deserialise_leaves(checkpoint, flow)
        flow = cast(Pipe[DataWithAuxiliary, RigidWithAuxiliary], restored)

    return base, target, flow


def train(
    key: KeyArray,
    run_dir: str,
    specs: ExperimentSpecification,
    base: OpenMMDensity,
    target: OpenMMDensity,
    flow: Transform[DataWithAuxiliary, DataWithAuxiliary],
    tot_iter: int,
    loss_reporter: list | None = None,
) -> Transform[DataWithAuxiliary, DataWithAuxiliary]:
    """Run every training stage listed in ``specs.train`` and return the resulting flow.

    Args:
        key: PRNG key seeding the key chain used for reporting and for each stage.
        run_dir: directory handed to the ``Reporter``.
        specs: experiment specification (``reporting``, ``train`` and
            ``model.target`` are read here).
        base: base density.
        target: target density.
        flow: flow to start training from.
        tot_iter: iteration counter at the start; it is advanced by
            ``num_iterations`` after each stage.
        loss_reporter: optional list forwarded unchanged to ``run_training_stage``.

    Returns:
        The flow returned by the last training stage (the input flow if
        ``specs.train`` is empty).
    """
    keys = key_chain(key)

    # Git provenance of the enclosing repository.
    repository = git.Repo(search_parent_directories=True)
    git_info = {
        "branch": repository.active_branch.name,
        "sha": repository.head.object.hexsha,
    }
    # Run parameters plus provenance; currently not consumed because the summary
    # call below is commented out.
    run_params = asdict(specs)
    run_params["git"] = git_info
    # tf.summary.text("run_params", pretty_json(run_params), step=tot_iter)

    logging.info("Starting training.")
    reporter = Reporter(base, target, run_dir, specs.reporting, scope=None)
    reporter.with_scope("initial").report_model(next(keys), flow, tot_iter)

    step = tot_iter
    for stage_index, stage_spec in enumerate(specs.train):
        stage_key = next(keys)
        stage_reporter = reporter.with_scope(f"training_stage_{stage_index}")
        flow = run_training_stage(
            stage_key,
            base,
            target,
            flow,
            stage_spec,
            specs.model.target,
            stage_reporter,
            step,
            loss_reporter,
        )
        step += stage_spec.num_iterations
    return flow

def count_params(model):
    """Return the total number of elements over all array leaves of ``model``.

    Leaves are selected with ``eqx.is_array``; every other leaf is ignored.
    The total is accumulated as a Python ``int``, so it cannot wrap around the
    way an ``int32`` accumulator would for very large models, and no device
    array is created just for counting.

    Args:
        model: any pytree (e.g. the flow).

    Returns:
        The parameter count as a Python ``int`` (0 if there are no array leaves).
    """
    return sum(
        leaf.size
        for leaf in jax.tree_util.tree_leaves(model)
        if eqx.is_array(leaf)
    )

# Keyword arguments shared by the plt.hist calls further down (passed as **hist_kwargs).
hist_kwargs = {"bins": "auto", "density": True, "alpha": 0.5}

In [5]:
# Load the experiment specification from the YAML file.
specs_file = "testing.yaml"
specs = ExperimentSpecification.load_from_file(specs_file)

# Notebook-wide key chain seeded from the specification; every next(chain) below draws from it.
chain = key_chain(specs.seed)
base, target, flow = setup_model(next(chain), specs)
# Model object of the base density; used below for plot_2Dview / plot_rdf, n_atoms and box.
model = base.omm_model.model

# Start the iteration counter from specs.global_step when one is given, otherwise from 0.
tot_iter = specs.global_step if specs.global_step is not None else 0

print(f'tot flow parameters: {count_params(flow):_}')

INFO:root:Loading base density.
INFO:root:  taking only 10000 samples from MD
INFO:root:Loading OpenMM model specs from /group/ag_cmb/scratch/minvernizzi/so3-flow/ice_MDdata/model-tip4pice_iceXI_T250_N16.json
INFO:root:Loading data from /group/ag_cmb/scratch/minvernizzi/so3-flow/ice_MDdata/MDtraj-tip4pice_iceXI_T250_N16.npz
INFO:root:Loading target density.
INFO:root:  taking only 10000 samples from MD
INFO:root:Loading OpenMM model specs from /group/ag_cmb/scratch/minvernizzi/so3-flow/ice_MDdata/model-tip4pice_iceXI_T100_N16.json
INFO:root:Loading data from /group/ag_cmb/scratch/minvernizzi/so3-flow/ice_MDdata/MDtraj-tip4pice_iceXI_T100_N16.npz
INFO:root:Setting up flow model.


tot flow parameters: 319_568


In [6]:
# Load the reference deltaF and its std for this base/target pair, if a file for it exists.
try:
    # NOTE(review): the file name interpolates the base/target spec objects themselves, so it
    # depends on how those objects format as strings — confirm this yields the intended name.
    ref_file = f"data/water/DeltaF_estimates/DF-{specs.model.base}-{specs.model.target}.txt"
    reference_deltaF, reference_deltaF_std = np.loadtxt(ref_file, unpack=True)
except FileNotFoundError:
    # No reference available; later cells check for None before plotting or printing it.
    reference_deltaF, reference_deltaF_std = None, None

# Display the loaded values as the cell output.
reference_deltaF, reference_deltaF_std

(-750.69943, 0.01622)

In [14]:
# Inspect the `ldj` field of what base.sample_idx returns for index 0 (draws one key from `chain`).
base.sample_idx(next(chain), 0).ldj

Array(-469.68274, dtype=float32)

In [16]:
# Same expression as the previous cell, run again (draws another key from `chain`).
# NOTE(review): the two recorded outputs differ in type (float32 Array vs. Python float) —
# presumably the library code changed between the runs; confirm and drop one of the cells.
base.sample_idx(next(chain), 0).ldj

-469.68367418517096

In [17]:
num_samples = 1_000

# Draw num_samples base samples, one PRNG key per sample.
# NOTE(review): the recorded output of this cell is a TracerArrayConversionError raised from
# rigid_flows/density.py:93 (sample) under vmap — confirm that OpenMMDensity.sample can be
# vmapped before relying on the cells below, which all need base_tr / mapped_tr / target_tr.
keys = jax.random.split(next(chain), num_samples)
base_tr = jax.vmap(base.sample)(keys)

# Push the base samples through the flow as it is before the training cell.
mapped_tr = jax.vmap(flow.forward)(base_tr.obj)

# Separate batch of target samples (fresh keys), used for comparison in the plots below.
keys = jax.random.split(next(chain), num_samples)
target_tr = jax.vmap(target.sample)(keys)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([6455, 8330, 6881, 7498, 7877, 6631, 2943, 5241, 1918,   19, 3133,
       6154, 1814, 2136, 3861, 2053, 4501,  720, 8934, 4273, 5914, 1647,
       5659, 8362,  654, 2624, 5543, 8941, 4268, 8487, 1710, 2260, 9795,
       4846, 2578, 2055, 5568, 5107, 2798, 4783, 5684, 6577,  447, 7580,
       4406, 7223, 6586, 9761, 5693, 8539, 2593,  981, 7017, 2813, 6905,
       4521, 3100,  212, 7898, 8158,   95, 8367, 2225, 3631, 1888, 8052,
       8600, 8065, 1497, 5081, 8461, 9186, 2926, 3969, 2098, 7937, 7138,
       3976, 8404, 7465, 5515, 4020, 4135, 9202,  418, 2755, 4252, 7711,
       7332, 8726, 9017,  823, 5860, 6913, 5732, 9630, 5655, 9008, 5732,
       6556, 2389, 4582, 5740, 5173, 9342, 4665, 3271, 9720, 2858, 1759,
       1921, 5988, 3055, 5234, 1894, 4867, 7995, 2109, 9962, 3922, 6055,
       7787, 1114, 4181, 3957, 4426, 4396, 2712, 5208, 3036, 5439, 3391,
       3358, 9221,  326, 5794, 7174, 7655, 4423, 1806,  862, 4002, 9341,
       1781, 9783, 7564, 7328, 3747, 3850, 3474, 2632, 2570, 8232, 1500,
       6302, 9727, 2543, 5570, 7590, 8162, 3798, 7656, 6717,  595, 9482,
        513, 6124, 2339, 5989, 6442, 7150, 7456, 4192, 5665,  418, 8178,
       9602, 4933, 3859, 6971, 9768, 5704, 6148, 8743, 1372, 5199, 9163,
       8502, 5095, 8021, 4024, 7431, 7432,  637, 3816, 2521, 7966, 3601,
       2585,  805, 2083,  593, 7922,  589, 9785, 1600, 3989, 8589, 7120,
       1654, 3657, 8338,  756, 8014, 8766, 5324, 4680, 4553, 5412, 5886,
       7644, 8386, 4711, 5752, 3812, 7840, 4101, 2385, 9076, 1120, 6830,
       5072, 8914,  580, 4124, 8643, 1506, 9894, 5390, 4698, 6397, 8435,
       8755, 9808, 1290, 6010, 8433, 9090, 2856, 3281, 9624, 3963, 2940,
       8095,  453, 8968, 2056, 3659, 4965, 2999, 5383, 3284, 1078, 8785,
        693, 1685, 8126, 8080, 8352, 6620, 4228, 1019, 1373, 4764, 3258,
       5616, 5008, 5347,  453,  783, 9473, 5658,  851, 3523, 5461, 6800,
       9549, 6565,  117, 9933, 1309, 8410, 9896, 9526, 7479, 7883, 7351,
       3053,  272, 6164,  141, 1878, 4404,    5, 1755, 6919, 7683, 5861,
       3402, 9525, 3114, 9752, 9590, 5960, 1373, 3699, 4312, 3506,  402,
       7637, 1217, 9526, 6024, 3125, 1151, 3653,  897, 4481, 3112, 4058,
       8091, 4101, 8747, 5517, 6975, 7640, 3244, 6660, 2063, 5883, 2558,
       5300, 2803, 5204, 6497,  851, 8219, 1966, 5407, 8550, 5995, 6128,
       1857, 9284, 2179, 1512, 5486, 7391, 7408, 1783, 6716,  191,  224,
       1927,  396, 9148, 5320, 7682, 1383, 2781, 3479, 2221,  623, 5564,
       9029, 1782, 3425, 1964, 3760, 9402, 6618, 4089, 3787, 9076, 4125,
       3362, 8484, 5577,  694, 9050, 1340, 6818, 6794, 6213, 6411, 3468,
       5116, 1757, 5407, 1239, 2171, 9652,  212, 8620, 1278, 3584, 5435,
       5401, 8395, 8644, 7427, 6142, 1918, 3524, 8398, 5743, 3104, 7505,
       5957, 3184, 1106, 6357, 7199, 2043, 1735, 4155, 6700, 8014, 1647,
       3442, 9891, 7483, 2857, 2414, 6982, 7296, 5561, 4202, 9270, 5032,
       3214, 4285, 2718, 1270, 3466, 1515, 9862, 3748, 2908, 6560,  720,
       9532, 9570, 3018, 1453, 3417, 3562, 4755, 7163, 8742, 1478, 6439,
       6163,  683, 6660, 3636, 9207, 3420, 3757, 7165, 7235, 2937, 6386,
       8165, 3935, 6278, 7350, 8740,  218, 5654, 5939, 8900, 4972, 2256,
       5693, 9824, 6914, 5987, 7178, 4130, 1460, 4327,  707, 8839, 8861,
       4279, 4667, 4683, 2123, 8702, 8989, 1902,  677, 1938, 9470, 6550,
       6608, 6704, 3652, 9030, 8810, 8585, 2689, 8152, 6806, 3920, 6603,
       6448, 6071, 6558, 8863, 1132, 3859, 6202, 3302, 2581, 8157, 2374,
       6916, 2051, 3142, 4124, 3834, 1154,  909, 6567, 3591, 4203, 3440,
       9029, 9076, 2032,  996, 5022, 6515, 5406,  887, 3827, 3355, 2554,
       1236, 7757, 4916, 9177,  116, 5352, 4809, 6995, 1455, 2996, 4976,
       9521, 8443, 5103, 7564, 3724, 9966, 3387, 9327, 4821, 4666, 9871,
       4476, 6425, 6482, 8311, 8678, 1064, 1227, 2085, 7954, 6258, 6113,
       3644, 1363, 3831, 3270, 2910, 3322, 8201, 9379, 5277, 1076, 3949,
       9584, 8089, 9842, 6805, 2422, 9285, 5443, 9893, 8473, 3467, 8514,
         55, 4818, 1125, 1446, 2614, 4788, 8115, 1192, 8905, 1002, 3512,
       8476, 5133, 5973, 5637, 9607, 6666, 5081, 2732, 7675, 4786, 9339,
       7490, 5117, 5234, 3554, 9384, 7547, 4924,  432,    6,  487, 9106,
       8919, 3356, 5843,  284, 8444, 3823, 4135, 5743, 9903, 3639, 9149,
        294, 8841, 7521, 5905, 7694, 9272, 2503, 2514, 1275, 2195, 5259,
       3891, 6740, 4807, 7159, 2844, 6594, 7337, 5607, 9648, 8497, 9014,
       8784, 6734, 6086, 8183, 5787, 6227, 2037, 9212, 8965, 4410, 5242,
       2492, 8604, 4170, 7442, 8323, 1694, 7670, 1299, 1287, 5468, 1940,
       6238, 7926, 5773, 5320, 6309, 5454, 8872, 4625, 5660, 1047, 8567,
       2725, 2565, 6423, 7306,  657, 5956,  911, 6731, 8280, 5304, 3825,
       7591,    2, 7148, 5952, 7629,  222, 5699, 6616, 9389, 2970, 4907,
       5193, 1465, 2196, 5701, 3849, 6779, 5165, 6130,  294, 4496, 1851,
       1810, 3456, 4974, 1862, 8334, 2985, 4348, 8914, 8422, 5816, 6600,
       3479, 6730, 5312, 9925, 5975, 4190, 6521, 6870, 9844,  528, 2773,
       8466,  209, 3921, 7379,  619, 3062, 9153,  481, 6301, 7685, 5185,
       1100, 7121, 9735, 7145, 2733, 7207, 7641, 8972, 9847, 3202, 5407,
       7634,  256, 8153, 1900, 4374, 2581, 3924, 7267, 4719, 3252, 1360,
       7499, 2566, 3863, 1225, 9483, 5008, 2913, 7290, 3427, 1687, 1497,
       9901, 5875, 4396, 1405, 8681, 1287, 1196, 4346, 6643,  123, 3443,
        144, 9794, 3537, 2526, 2234, 5164, 3907, 4293,    7, 3295, 3746,
       6760, 9973, 1752, 4024,  816, 2233, 6218, 7757, 7568, 5922, 4726,
       6733, 4870, 7496, 9460, 9314, 9146, 6930, 5902, 4743, 3078, 5544,
       2264, 1556, 8164, 9388,  225, 6472, 1793, 1969, 7458, 6186, 5464,
       8026, 8698, 4763, 5135,  642, 7526, 2985, 1560, 4653,  368, 9345,
       5244, 2912, 5316, 5204, 9541, 7260, 4443, 6713, 9176, 3234, 6087,
       4603, 9169, 3800, 3920, 5207, 8219, 2660, 5036, 8121, 6055, 9371,
       5283, 3244, 2307, 9698, 6347, 9352, 8084, 1989, 8972, 7295, 4576,
       2884, 5131, 7860, 5205, 4499, 8914, 2261, 9188, 4455, 1989, 3329,
       9238, 3672, 8654, 2790, 6970, 5923, 3575, 6403, 1063, 6612, 4530,
       4921, 1388,  131, 7280, 5853,  418,  698, 8292, 2709,  989, 3279,
       4823, 3166, 3817, 4799, 8398, 2562, 4432, 2568, 4435, 4485, 4016,
       3440, 8106,  784, 6964, 7696, 3623, 4573, 4438, 5724,  793, 5164,
        842, 6295, 4022, 5328, 1056, 5367, 4442, 1475, 2016, 3483, 8158,
       2599, 8153, 9395, 2016, 9033, 1521,  792, 5710, 8750, 5899, 5393,
       8286, 2510, 5372, 7818, 7679, 8302, 2710, 5879, 8905, 6137,  632,
       1473, 5947, 4098, 7770, 2791, 5969, 6837, 4912,  188, 9908],      dtype=int32)
  batch_dim = 0
This BatchTracer with object id 139655000531856 was created on line:
  /srv/data/minvernizzi/notebook/water/rigid-flows/rigid_flows/density.py:93 (sample)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [None]:
# 2D views of the base, mapped and target samples; toPBC is reused by the
# corresponding cell after training.
toPBC = True
for view_title, view_samples in (
    ('base', base_tr),
    ('mapped', mapped_tr),
    ('target', target_tr),
):
    # flatten each sample to (n_atoms, 3) coordinates
    model.plot_2Dview(
        view_samples.obj.pos.reshape(-1, model.n_atoms, 3),
        toPBC=toPBC,
        title=view_title,
    )


In [None]:
## closer look at the center of mass
# Unweighted mean of the base positions over axes 1 and 2, one row per sample.
# NOTE(review): assumes pos is (batch, molecules, sites, 3), and the mean is not mass-weighted — confirm.
com_pos = base_tr.obj.pos.mean(axis=(1,2))
plt.plot(com_pos, '.')
plt.show()

In [None]:
%%time

# loss_reporter is handed to train(), which forwards it to run_training_stage
# (presumably filled with per-iteration losses there — confirm); it is plotted in the next cell.
# Note: `flow` is both input and output, so re-running this cell continues from the trained flow.
loss_reporter = []
flow = train(next(chain), "testing", specs, base, target, flow, tot_iter, loss_reporter)

In [None]:
# Values collected in loss_reporter during training, with the MD reference
# deltaF drawn as a dotted horizontal line when one was loaded.
plt.plot(loss_reporter, label='KL loss')
if reference_deltaF is not None:
    plt.axhline(reference_deltaF, c="k", ls=":", label='MD reference')
plt.xlim(0, len(loss_reporter))
plt.xlabel('iterations')
# Raw string: '\D' is not a valid escape sequence in a regular string literal
# and triggers a warning on current Python versions; the label text is unchanged.
plt.ylabel(r'$\Delta F$')
plt.legend()
plt.show()

In [None]:
# Re-map the same base samples, now through the flow returned by the training cell.
mapped_tr = jax.vmap(flow.forward)(base_tr.obj)

In [None]:
# Same three 2D views as before training; `toPBC` comes from the earlier plotting cell.
for view_title, view_samples in (
    ('base', base_tr),
    ('mapped', mapped_tr),
    ('target', target_tr),
):
    model.plot_2Dview(
        view_samples.obj.pos.reshape(-1, model.n_atoms, 3),
        toPBC=toPBC,
        title=view_title,
    )


In [None]:
def ess(logw):
    """Effective sample size ``(sum w)**2 / sum(w**2)`` for weights ``w = exp(logw)``.

    Evaluated in log space with ``logsumexp`` so large or very negative
    log-weights do not overflow or underflow.
    """
    logsumexp = jax.scipy.special.logsumexp
    log_square_of_sum = 2 * logsumexp(logw)
    log_sum_of_squares = logsumexp(2 * logw)
    return jnp.exp(log_square_of_sum - log_sum_of_squares)

## NB: base_tr.ldj = jax.vmap(base.potential)(base_tr.obj)
# Log-weights of the mapped samples: base_tr.ldj plus mapped_tr.ldj (presumably the flow's
# log-det-Jacobian — confirm) minus the target potential evaluated on the mapped samples.
logw = base_tr.ldj + mapped_tr.ldj - jax.vmap(target.potential)(mapped_tr.obj)

# Histogram of the weights relative to the largest one (so the maximum is 1), on a log count axis.
plt.hist(jnp.exp(logw-logw.max()), bins=100)
plt.yscale('log')
plt.show()

# Effective sample size, absolute and as a fraction of the number of samples.
print(f'ESS = {ess(logw):g}  ->  {ess(logw)/len(logw):.2%}')

In [None]:
# Evaluate the *target* potential on all three sample sets (base, mapped, target).
ene_fn = jax.vmap(target.potential)
base_ene = ene_fn(base_tr.obj)
mapped_ene = ene_fn(mapped_tr.obj)
target_ene = ene_fn(target_tr.obj)

In [None]:
# Histograms of the target potential for the three sample sets.
plt.hist(base_ene, **hist_kwargs, label='base')
plt.hist(mapped_ene, **hist_kwargs, label='mapped')
plt.hist(target_ene, **hist_kwargs, label='target')
# Mapped samples again, weighted by the normalised weights exp(logw - logsumexp(logw)) (they sum to 1).
plt.hist(mapped_ene, weights=np.exp(logw-jax.scipy.special.logsumexp(logw)), bins=75, histtype='step', density=True, label='reweighted')
plt.xlabel('total energy')
plt.legend()
plt.show()

In [None]:
# Pick one entry (ene_label) of the dict returned by target.compute_energies and rescale it.
ene_label = 'omm'
scaling = 1
if ene_label == 'omm':
    # Multiply by kbT of the target model — presumably converts to kJ/mol, as the axis label
    # of the following plotting cell states; confirm.
    scaling = target.omm_model.kbT
# NOTE(review): the meaning of the three positional True flags is not visible here — see compute_energies.
base_ene2 = target.compute_energies(base_tr.obj, True, True, True)[ene_label] * scaling
target_ene2 = target.compute_energies(target_tr.obj, True, True, True)[ene_label] * scaling
mapped_ene2 = target.compute_energies(mapped_tr.obj, True, True, True)[ene_label] * scaling

In [None]:
# Histograms of the selected energy term (ene_label) for the three sample sets.
plt.hist(base_ene2, **hist_kwargs, label='base')
plt.hist(mapped_ene2, **hist_kwargs, label='mapped')
plt.hist(target_ene2, **hist_kwargs, label='target')
# The values are in kJ/mol only for the 'omm' energies (the only case that is
# rescaled by kbT), so the unit is shown for that label only.
energy_unit = ' [kJ/mol]' if ene_label == 'omm' else ''
plt.xlabel(ene_label + ' energy' + energy_unit)
plt.legend()
plt.show()

In [None]:
# Radial distribution functions of the three sample sets on a common range and binning.
r_range = [0.2, np.diag(model.box).min()]
n_bins = 300
plt.title('Oxygen-Oxygen')
for rdf_label, rdf_samples in (
    ('base', base_tr),
    ('mapped', mapped_tr),
    ('target', target_tr),
):
    model.plot_rdf(
        rdf_samples.obj.pos.reshape(-1, model.n_atoms, 3),
        r_range=r_range,
        n_bins=n_bins,
        label=rdf_label,
    )
# dotted marker at half of the largest entry of model.box
plt.axvline(model.box.max()/2, c='k', ls=':', alpha=.5)
plt.legend()
plt.show()

In [None]:
## closer look at the center of mass
# Unweighted mean of the mapped positions over axes 1 and 2, one row per sample
# (same quantity as plotted for the base samples before training).
com_pos = mapped_tr.obj.pos.mean(axis=(1,2))
# alternative: difference between base and mapped means
# com_pos = base_tr.obj.pos.mean(axis=(1,2)) - mapped_tr.obj.pos.mean(axis=(1,2))
plt.plot(com_pos, '.')
plt.show()
# alternative view: per-component histograms
# for i in range(3):
#     plt.hist(com_pos[:,i], **hist_kwargs)
# plt.show()

In [None]:
## TFEP
# deltaF = log(N) - logsumexp(logw) = -log(mean(exp(logw))), using the logw computed above.
deltaF = (jnp.log(len(logw)) - jax.scipy.special.logsumexp(logw)).item()
print(f'Estimated deltaF from TFEP (below) = {deltaF:g}')
if reference_deltaF is not None:
    print(f'                  Reference deltaF = {reference_deltaF:g}')

In [None]:
def estimate_deltaF(key, num_samples, base):
    """Estimate deltaF from ``num_samples`` fresh samples of ``base``.

    Reads the notebook globals ``flow`` and ``target``.

    Args:
        key: PRNG key, split into one key per sample.
        num_samples: number of samples to draw.
        base: density to sample from (shadows the notebook-level ``base``).

    Returns:
        ``log(N) - logsumexp(logw)`` as a Python float, where ``logw`` is the
        base ``ldj`` plus the mapped ``ldj`` minus the target potential of the
        mapped samples.
    """

    def _batched(fn):
        # jit-compiled, vmapped version of a per-sample function
        return jax.jit(jax.vmap(fn))

    source = _batched(base.sample)(jax.random.split(key, num_samples))
    image = _batched(flow.forward)(source.obj)
    log_weights = source.ldj + image.ldj - _batched(target.potential)(image.obj)
    log_mean_weight = jax.scipy.special.logsumexp(log_weights) - jnp.log(len(log_weights))
    return (jnp.log(len(log_weights)) - jax.scipy.special.logsumexp(log_weights)).item() if log_mean_weight is not None else None

# Base density for evaluation: takes the LAST num_samples frames of the data,
# whereas setup_model takes the first num_samples.
# NOTE(review): the negation raises if specs.model.base.num_samples is None (setup_model allows
# None), and the slice would overlap the training frames if the data holds fewer than
# 2*num_samples frames — confirm.
eval_base = OpenMMDensity.from_specs(specs.model.auxiliary_shape, specs.model.base, np.s_[-specs.model.base.num_samples:])
# number of frames in the evaluation set (cell output)
eval_base.data.pos.shape[0]

In [None]:
from tqdm import trange

# Also estimate from the training frames? This gives a biased estimate, always lower.
estimate_from_training = False

# FIXME: it's maybe better to use all MD points instead of random samples
iterations = 10  # number of independent estimates
n_samples = 1000  # samples drawn per estimate

if estimate_from_training:
    # estimates from the density used for training
    deltaFs = np.array(
        [estimate_deltaF(next(chain), n_samples, base) for _ in trange(iterations)],
        dtype=float,
    )
    print(f'deltaF = {deltaFs.mean():g} +/- {deltaFs.std():g}')

# estimates from the evaluation density (eval_base)
eval_deltaFs = np.array(
    [estimate_deltaF(next(chain), n_samples, eval_base) for _ in trange(iterations)],
    dtype=float,
)
print(f'deltaF = {eval_deltaFs.mean():g} +/- {eval_deltaFs.std():g}')

In [None]:
def _plot_estimates(values, color, label, x_span):
    """Scatter the individual estimates and overlay their mean with a +/- one-std band."""
    center, spread = values.mean(), values.std()
    plt.plot(values, '.', c=color, label=label)
    # scalar bounds are broadcast over x_span by fill_between
    plt.fill_between(x_span, center - spread, center + spread, color=color, alpha=0.3)
    plt.axhline(center, c=color)


xlim = [0, len(eval_deltaFs)]
if estimate_from_training:
    # deltaFs only exists when the estimation cell ran with this flag set
    _plot_estimates(deltaFs, 'blue', 'training', xlim)
_plot_estimates(eval_deltaFs, 'orange', 'testing', xlim)
if reference_deltaF is not None:
    # reference value with its +/- one-std band
    plt.axhline(reference_deltaF, c='k', ls=":", label='reference')
    plt.fill_between(
        xlim,
        reference_deltaF - reference_deltaF_std,
        reference_deltaF + reference_deltaF_std,
        color='k',
        alpha=0.1,
    )
plt.xlim(xlim)
plt.legend()
plt.show()