In [None]:
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from identifiability import signals
from identifiability import mcmc

import identifiability.model.cascade_k4p9_fb as cascade_k4p9_fb
import identifiability.model.cascade_k4p11_fb as cascade_k4p11_fb
import identifiability.model.springs as springs

In [None]:
# Create the cache directory in pure Python so the cell also works where a
# POSIX `mkdir -p` is unavailable; an already existing directory is fine.
Path('./cache').mkdir(parents=True, exist_ok=True)

# plot utils

In [None]:
def plot_fit(
    samples_by_chain,
    parameters_names,
    hist_range=None,
    hist_ylim=(0, 2),
):
    """Plot, for every parameter, a pooled histogram and the per-chain traces.

    One subplot row is drawn per entry of ``parameters_names``: the left panel
    is a density histogram of all samples of that parameter (all chains
    flattened together), the right panel overlays one line per chain.

    Args:
        samples_by_chain: array whose first axis enumerates chains and whose
            last axis enumerates parameters (it is iterated chain by chain and
            indexed with ``[..., i]``).
        parameters_names: sequence of labels, one per parameter; its length
            determines the number of rows.
        hist_range: ``(low, high)`` forwarded to ``Axes.hist`` as ``range``;
            ``None`` lets matplotlib derive it from the data.
        hist_ylim: y-limits applied to every histogram panel, or ``None`` to
            keep matplotlib's automatic limits.
    """
    nrows = len(parameters_names)
    # squeeze=False keeps `axs` two-dimensional even when there is a single
    # parameter; otherwise plt.subplots returns a 1-D array and the
    # [row, column] indexing below fails for nrows == 1.
    fig, axs = plt.subplots(
        nrows=nrows, ncols=2, figsize=(8, 2 * nrows), squeeze=False,
    )

    axs[0, 0].set_title('histograms')
    axs[0, 1].set_title('chains')

    for i, (param_name, row) in enumerate(zip(parameters_names, axs)):
        # left panel: histogram over the samples of parameter i from all chains
        ax = row[0]
        ax.set_ylabel(param_name)
        ax.hist(
            np.asarray(samples_by_chain[..., i]).reshape(-1),
            density=True,
            range=hist_range,
            bins=75,
            color='black',
        )
        if hist_ylim is not None:
            ax.set_ylim(hist_ylim)

        # right panel: trace of parameter i, one translucent line per chain
        ax = row[1]
        for chain in samples_by_chain:
            ax.plot(chain[..., i], alpha=0.3)

# mcmc utils

In [None]:
# Integer factor by which the chain lengths and the thinning interval below are
# divided; 1 is the full-length run, 10 or 4 give quicker trial runs.
speed_up = 1

# Keyword arguments forwarded unchanged to mcmc.run_mcmc by the fitting helper.
mcmc_kwargs = {
    'num_chains': 8,
    # 1_250_000 steps plus the 250_000 steps that are also listed as burn-in
    'num_steps': (1_250_000 + 250_000) // speed_up,
    'num_burn_in': 250_000 // speed_up,
    'thinning': 1000 // speed_up,
    'seed': 1,
}

In [None]:
def _fit_models(
    model_train, parameters_train,
    model_fit, parameters_init,
    measure=lambda ys: ys,
    measurement_error=0.3,
    num_measurements=3,
    measurement_seed=42,
    sigma=None,
    log_reparam=True,
):
    """Simulate noisy measurements of ``model_train`` and sample ``model_fit`` by MCMC.

    Args:
        model_train: model whose ``run(parameters_train)`` output is used as the
            noise-free data.
        parameters_train: parameters passed to ``model_train.run``.
        model_fit: model being fitted; must provide ``run(parameters)`` and
            ``prior_log_prob``.
        parameters_init: initial parameter vector handed to the sampler
            (log-transformed first when ``log_reparam`` is true).
        measure: callable applied to both the noisy data and the fitted model's
            output before they are compared; defaults to the identity.
        measurement_error: scale of the Gaussian noise added to the data; the
            same value is passed as ``sigma`` to ``mcmc.normal_error_log_prob``.
        num_measurements: number of independent noisy replicates, stacked along
            a new leading axis.
        measurement_seed: seed for the JAX PRNG key that draws the noise.
        sigma: forwarded to ``mcmc.run_mcmc`` (presumably per-parameter
            proposal scales — confirm in ``mcmc``); defaults to 0.2 for every
            entry of ``parameters_init``.
        log_reparam: if true the sampler works on ``log(parameters)`` and the
            returned samples are exponentiated back; otherwise parameters are
            sampled directly.

    Returns:
        ``(samples_by_chain, log_prob_by_chain)`` as returned by
        ``mcmc.run_mcmc``, with the samples mapped back to the original
        parameter space.

    The remaining sampler settings are read from the module-level
    ``mcmc_kwargs``.
    """
    # generate measurements: replicate the noise-free output and add
    # independent Gaussian noise to each replicate
    ys_train_raw = model_train.run(parameters_train)
    ys_train_shape = (num_measurements, *ys_train_raw.shape)
    key = jax.random.PRNGKey(measurement_seed)
    ys_train = ys_train_raw[None] + measurement_error * jax.random.normal(key, shape=ys_train_shape)

    # The data do not change while sampling, so apply `measure` to them once
    # here rather than on every log-probability evaluation.
    measured_train = measure(ys_train)

    if sigma is None:
        sigma = jnp.full_like(parameters_init, 0.2)

    # encode/decode pair mapping between model parameters and the space the
    # sampler moves in (identity when log_reparam is off)
    if log_reparam:
        enc = jnp.log
        dec = jnp.exp
    else:
        enc = lambda x: x
        dec = lambda x: x

    def cond_log_prob(parameters):
        # [None] adds a leading axis so the model output lines up with the
        # replicate axis of the measurements
        ys = model_fit.run(parameters)[None]
        return mcmc.normal_error_log_prob(
            measured_train,
            measure(ys),
            sigma=measurement_error,
        )

    log_prob_fn = mcmc.make_safe_log_prob_fn(model_fit.prior_log_prob, cond_log_prob)

    # the sampler sees encoded parameters; decode before evaluating the model
    # NOTE(review): no Jacobian term for the log transform is added here —
    # confirm that this is the intended target density.
    rep_log_prob_fn = lambda encoded_params: log_prob_fn(dec(encoded_params))
    rep_parameters_init = enc(parameters_init)

    print("Starting MCMC...")
    rep_samples_by_chain, log_prob_by_chain = mcmc.run_mcmc(
        parameters_init=rep_parameters_init,
        log_prob_fn=rep_log_prob_fn,
        sigma=sigma,
        **mcmc_kwargs,
    )

    # map the samples back to the original parameter space
    samples_by_chain = dec(rep_samples_by_chain)

    return samples_by_chain, log_prob_by_chain


def fit_models(
    model_train, parameters_train,
    model_fit, parameters_init,
    measure=lambda ys: ys,
    measurement_error=0.3,
    num_measurements=3,
    measurement_seed=42,
    sigma=None,
    log_reparam=True,
    cache_path=None,
    try_load=True,
    hist_range=(-10, 17),
    hist_ylim=(0.0, 2.0),
):
    """Run (or load from cache) an MCMC fit, then print a report and plot it.

    Args:
        model_train, parameters_train, model_fit, parameters_init, measure,
        measurement_error, num_measurements, measurement_seed, sigma,
        log_reparam: forwarded unchanged to ``_fit_models``.
        cache_path: optional pickle file. If it exists and ``try_load`` is
            true the samples are read from it instead of being recomputed;
            otherwise freshly computed samples are written to it. ``None``
            disables caching.
        try_load: set to false to recompute (and overwrite the cache) even
            when the cache file exists.
        hist_range: histogram range forwarded to ``plot_fit``; the default is
            the range previously hard-coded here.
        hist_ylim: histogram y-limits forwarded to ``plot_fit``; the default
            is the value previously hard-coded here.

    Returns:
        ``(samples_by_chain, log_prob_by_chain)`` in the original (not log)
        parameter space.

    Note:
        The cache is identified by its path only; changing any other argument
        without changing ``cache_path`` silently reuses the old samples.
    """
    if cache_path is not None:
        cache_path = Path(cache_path)
    else:
        try_load = False

    if try_load and cache_path.exists():
        print(f"\nLoading samples from: '{cache_path}'")
        # NOTE: unpickling can execute arbitrary code; only point cache_path
        # at files produced by this notebook.
        with open(cache_path, 'rb') as handle:
            data = pickle.load(handle)
        samples_by_chain = data['samples_by_chain']
        log_prob_by_chain = data['log_prob_by_chain']

    else:
        samples_by_chain, log_prob_by_chain = _fit_models(
            model_train, parameters_train,
            model_fit, parameters_init,
            measure=measure,
            measurement_error=measurement_error,
            num_measurements=num_measurements,
            measurement_seed=measurement_seed,
            sigma=sigma,
            log_reparam=log_reparam,
        )

        if cache_path is not None:
            print(f"\nSaving samples to: '{cache_path}'")
            # Make sure the target directory exists so a long sampling run is
            # not lost to a missing folder at write time.
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_path, 'wb') as handle:
                pickle.dump({
                    'samples_by_chain': samples_by_chain,
                    'log_prob_by_chain': log_prob_by_chain,
                }, handle)

    # report and plot in the space the sampler worked in (logs when log_reparam)
    rep_samples_by_chain = np.log(samples_by_chain) if log_reparam else samples_by_chain
    print("\nSample stats:")
    mcmc.mcmc_report(rep_samples_by_chain, model_fit.parameters_names)

    plot_fit(
        rep_samples_by_chain,
        [f"{'log ' if log_reparam else ''}{p}" for p in model_fit.parameters_names],
        hist_range=hist_range,
        hist_ylim=hist_ylim,
    )

    return samples_by_chain, log_prob_by_chain

# common for cascades

In [None]:
# 12 evenly spaced measurement times from 0 to 11 inclusive.
ts_train_default = jnp.linspace(0, 11, num=12)
# The short (1h) pulse experiments reuse exactly the same measurement times.
ts_train_short = ts_train_default

# Figures 1–2 (nominal cascade)
## fit k4p9_fb to k4p9_fb, 4h signal

## measure=4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_median,
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/cascade_k4p9_fb_self_4h_measure_4.pkl',
)

## measure=2,4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_median,
    # compare every second entry along the final axis, starting at index 1
    measure=lambda output: output[..., 1::2],
    cache_path='./cache/cascade_k4p9_fb_self_4h_measure_2+4.pkl',
)

## measure=1,2,3,4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_median,
    # compare the complete model output (identity measurement)
    measure=lambda output: output,
    cache_path='./cache/cascade_k4p9_fb_self_4h_measure_1+2+3+4.pkl',
)

# Figure 3 (relaxed cascade)
## fit k4p11_fb to k4p9_fb, 4h signal

## measure=4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model (k4p9_fb) and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (k4p11_fb) and the sampler's starting parameters
    cascade_k4p11_fb.CascadeK4P11Fb(ts=ts_train, signal=signal_train),
    cascade_k4p11_fb.parameters_median,
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/cascade_k4p11_fb_to_cascade_k4p9_fb_4h_measure_4.pkl',
)

## measure=1,2,3,4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model (k4p9_fb) and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (k4p11_fb) and the sampler's starting parameters
    cascade_k4p11_fb.CascadeK4P11Fb(ts=ts_train, signal=signal_train),
    cascade_k4p11_fb.parameters_median,
    # compare the complete model output (identity measurement)
    measure=lambda output: output,
    cache_path='./cache/cascade_k4p11_fb_to_cascade_k4p9_fb_4h_measure_1+2+3+4.pkl',
)

# Figure 4 (simplified cascade)
## fit k2p5_fb to k4p9_fb, 4h signal

In [None]:
import identifiability.model.cascade_k2p5_fb as cascade_k2p5_fb

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model (k4p9_fb) and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (k2p5_fb) and the sampler's starting parameters
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_train, signal=signal_train),
    cascade_k2p5_fb.parameters_median,
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/cascade_k2p5_fb_to_cascade_k4p9_fb_4h_measure_4.pkl',
)

# Appendix II. Pitfalls

In [None]:
import identifiability.model.cascade_k4p8 as cascade_k4p8

## cascade_k4p9_fb, 1h signal (pulse too short)

### measure=4

In [None]:
# Short-pulse measurement times; input signal built with make_pulse(1) (the "1h signal").
ts_train = ts_train_short
signal_train = signals.make_pulse(1)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_median,
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/cascade_k4p9_fb_self_1h_measure_4.pkl',
)

### measure=1,2,3,4

In [None]:
# Short-pulse measurement times; input signal built with make_pulse(1) (the "1h signal").
ts_train = ts_train_short
signal_train = signals.make_pulse(1)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_median,
    # compare the complete model output (identity measurement)
    measure=lambda output: output,
    cache_path='./cache/cascade_k4p9_fb_self_1h_measure_1+2+3+4.pkl',
)

## cascade_k4p8, 4h signal (wrong model)

### measure=4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model (k4p9_fb) and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (k4p8) and the sampler's starting parameters
    cascade_k4p8.CascadeK4P8(ts=ts_train, signal=signal_train),
    cascade_k4p8.parameters_median,
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/cascade_k4p8_to_cascade_k4p9_fb_4h_measure_4.pkl',
)

### measure=1,2,3,4

In [None]:
# Default measurement times; input signal built with make_pulse(4) (the "4h signal").
ts_train = ts_train_default
signal_train = signals.make_pulse(4)

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model (k4p9_fb) and the parameters used to simulate the data
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_train, signal=signal_train),
    cascade_k4p9_fb.parameters_default,
    # fitted model (k4p8) and the sampler's starting parameters
    cascade_k4p8.CascadeK4P8(ts=ts_train, signal=signal_train),
    cascade_k4p8.parameters_median,
    # compare the complete model output (identity measurement)
    measure=lambda output: output,
    cache_path='./cache/cascade_k4p8_to_cascade_k4p9_fb_4h_measure_1+2+3+4.pkl',
)

# Appendix I. Springs

## measure=3

In [None]:
# 41 evenly spaced measurement times from 0 to 40 inclusive.
ts_train = jnp.linspace(0, 40, num=41)
# Input signal: -0.2 while t < 10, zero afterwards.
signal_train = lambda t: -0.2 * (t < 10)
# One entry per parameter, forwarded to the sampler as `sigma`.
sigma = jnp.array([1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.4, 0.4])

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    springs.SpringsModel(ts=ts_train, signal=signal_train),
    springs.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    springs.SpringsModel(ts=ts_train, signal=signal_train),
    springs.parameters_median,  # springs.parameters_default
    # compare only the last entry along the final axis of the model output
    measure=lambda output: output[..., -1],
    cache_path='./cache/springs_measure_3.pkl',
    # sample the parameters directly rather than their logarithms
    log_reparam=False,
    sigma=sigma,
)

## measure=1,2,3

In [None]:
# 41 evenly spaced measurement times from 0 to 40 inclusive.
ts_train = jnp.linspace(0, 40, num=41)
# Input signal: -0.2 while t < 10, zero afterwards.
signal_train = lambda t: -0.2 * (t < 10)
# One entry per parameter, forwarded to the sampler as `sigma`; a quarter of the
# values used in the measure=3 springs fit.
sigma = jnp.array([1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.4, 0.4]) / 4

samples_by_chain, log_prob_by_chain = fit_models(
    # data-generating model and the parameters used to simulate the data
    springs.SpringsModel(ts=ts_train, signal=signal_train),
    springs.parameters_default,
    # fitted model (same class) and the sampler's starting parameters
    springs.SpringsModel(ts=ts_train, signal=signal_train),
    springs.parameters_median,  # springs.parameters_default
    # compare the complete model output (identity measurement)
    measure=lambda output: output,
    cache_path='./cache/springs_measure_1+2+3.pkl',
    # sample the parameters directly rather than their logarithms
    log_reparam=False,
    sigma=sigma,
)