In [None]:
import pickle

import numpy as np

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.lines as lines
import seaborn as sns

import scipy.stats as ss

import jax
import jax.numpy as jnp

from identifiability import signals
from identifiability.model.cascade_base import softlog as cascade_softlog
import identifiability.model.cascade_k4p9_fb as cascade_k4p9_fb
import identifiability.model.cascade_k4p11_fb as cascade_k4p11_fb

In [None]:
import os

# Output folders for the figures; exist_ok makes this a no-op on re-runs.
# Pure Python instead of `!mkdir -p`, so it also works without a POSIX shell.
for figures_dir in ('./figures-manuscript', './figures-internal'):
    os.makedirs(figures_dir, exist_ok=True)

# Plot utils

## settings

In [None]:
# Global matplotlib style for every figure in this notebook:
# text kept as text (not paths) in SVG output, Arial, black titles,
# and small tick / label / title font sizes.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'font.family': 'Arial',
    'axes.titlecolor': 'black',
    'figure.dpi': 100,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 9,
    'axes.titlesize': 11,
})

In [None]:
# Axes sizes (width, height) in inches, as consumed by subplots_from_axsize().
axsize_cascade_12h = (1.0, 0.5)
axsize_signal_12h = (1.0, 0.33)

axsize_hist = (1.0, 0.66)
axsize_pca = (3.0, 1.0)

axsize_springs_60s = (1.0, 1.0)
axsize_springs_signal_60s = (1.0, 0.5)


# Horizontal time scales are derived from the reference widths so that stacked
# panels line up: 12 h of cascade time (60 s of springs time) span exactly one
# reference axes width.
assert axsize_cascade_12h[0] == axsize_signal_12h[0]
inch_per_h = axsize_cascade_12h[0] / 12

assert axsize_springs_60s[0] == axsize_springs_signal_60s[0]
inch_per_s = axsize_springs_60s[0] / 60

In [None]:
# Axis label texts shared by all figures.
text_time_hour = "Time [h]"
text_time_seconds = "Time [s]"
text_pos_m = "Position [m]"
text_softlog = "softlog"  # "log_p"
text_signal = "signal"
text_springs_signal = "forcing [N]"

# Mathtext signal labels. Raw strings, so the LaTeX backslashes are not parsed
# as (invalid) Python escape sequences.
signal_type_to_label = {
    'test': r"$S_{\mathrm{test}}$",
    'train': r"$S_{\mathrm{train}}$",
}

signal_type_to_spring_label = {
    'test': r"$F_{\mathrm{test}}$ [N]",
    'train': r"$F_{\mathrm{train}}$ [N]",
}

In [None]:
# Colors of cascade steps 1-4, given as HSV rows (hue, saturation, value)
# and converted to RGB.
cascade_colors = colors.hsv_to_rgb([
    # hue  sat  value
    [0.62, 1.0, 0.90],
    [0.30, 1.0, 0.77],
    [0.13, 1.0, 0.86],
    [0.00, 1.0, 0.90],
])
# Springs plots reuse the colors of cascade steps 1, 2 and 4.
springs_colors = [cascade_colors[k] for k in (0, 1, 3)]

# One color per overlaid posterior sampling (histograms, PCA bars).
sampling_colors = ['#4fc6ff', '#ff8a19', '#a551b6']

# Neutral colors: prior shading, input signal, annotations.
color_prior = '#c9c9c9'
color_signal = 'grey'
color_note = '#404040'

# Shared transparency levels.
alpha_mid, alpha_light, alpha_very_light = 0.5, 0.3, 0.15

## general utils

In [None]:
from collections.abc import Iterable
import mpl_toolkits.axes_grid1 as ag


def _normalize_axsizes_and_axcounts(axs, axc):
    # if ax sizes are iterable, compute new ax count
    if isinstance(axs, Iterable):
        axs = list(axs)
        if axc is None:
            axc = len(axs)
        else:
            assert axc == len(axs), "axes count doesn't agree with number of sizes"
    
    # if not, make sure they are a list with correct count
    else:
        if axc is None:
            axc = 1
        axs = [axs for _ in range(axc)]

    return axs, axc


def _make_sizes(start, axs, end, spaces):
    """
    Build the fixed divider sizes `start, ax0, space0, ax1, ..., ax_last, end`.

    `spaces` must hold exactly one gap fewer than there are axes.
    Returns the list of `ag.Size.Fixed` sizes and the sum of all lengths,
    which is the total figure extent along this direction.
    """
    n_axes = len(axs)
    assert len(spaces) == n_axes - 1

    lengths = [start]
    for ax_length, gap in zip(axs, spaces):
        lengths.extend((ax_length, gap))
    lengths.extend((axs[-1], end))

    return [ag.Size.Fixed(length) for length in lengths], sum(lengths)


def subplots_from_axsize(
    axsize=(3, 2),
    nrows=None, ncols=None,
    top=0.1, bottom=0.5, left=0.5, right=0.1,
    hspace=0.5, wspace=0.5,
):
    """
    Similar to plt.subplots(), but every length is a fixed size in inches.

    The axes sizes, margins (top/bottom/left/right) and gaps (hspace/wspace)
    are absolute, and the figure size is derived from them. This gives more
    control over the final axes sizes.

    Each of axsize[0] and axsize[1] may be a scalar, which applies to all
    columns or rows. It may also be a list with one size per column or row;
    the list length then sets ncols or nrows. hspace and wspace may likewise
    be scalars or per-gap lists.

    Examples:
    subplots_from_axsize(axsize=(3, 2), nrows=2) creates a figure with two axes of size (3, 2)
    subplots_from_axsize(axsize=(3, [2, 1])) creates a figure with two axes: (3, 2) and (3, 1)

    Returns (fig, axs), where axs is a 2D array of shape (nrows, ncols) and
    axs[0, 0] is the top-left axes.
    """
    width_spec, height_spec = axsize

    widths, ncols = _normalize_axsizes_and_axcounts(width_spec, ncols)
    heights, nrows = _normalize_axsizes_and_axcounts(height_spec, nrows)

    row_gaps, _ = _normalize_axsizes_and_axcounts(hspace, nrows - 1)
    col_gaps, _ = _normalize_axsizes_and_axcounts(wspace, ncols - 1)

    w_sizes, fig_width = _make_sizes(left, widths, right, col_gaps)
    h_sizes, fig_height = _make_sizes(top, heights, bottom, row_gaps)

    fig = plt.figure(figsize=(fig_width, fig_height))

    # The divider counts rows from the bottom, so the top-to-bottom height list
    # is reversed. Even divider indices are margins/gaps; odd ones are axes.
    divider = ag.Divider(fig, (0, 0, 1, 1), w_sizes, h_sizes[::-1], aspect=False)

    rows = []
    for row_from_top in range(nrows):
        ny = 2 * (nrows - 1 - row_from_top) + 1
        rows.append([
            fig.add_axes(divider.get_position(), axes_locator=divider.new_locator(nx=2 * col + 1, ny=ny))
            for col in range(ncols)
        ])

    return fig, np.array(rows)

In [None]:
def save_and_show(fig, path):
    """
    Save `fig` to `path` with a transparent background, then display it.

    The light fuchsia background is applied only after saving. It is meant
    for the notebook preview, so the figure's true extent is visible there.
    """
    fig.savefig(path, transparent=True)

    # preview only: this does not affect the file written above
    background = fig.patch
    background.set_facecolor('fuchsia')
    background.set_alpha(0.1)

    plt.show()

In [None]:
def axis_time_ticker(ax, xaxis=True, yaxis=False, major=3, minor=1):
    """Place minor ticks every `minor` and major ticks every `major` units on the selected axes (x by default)."""
    selected = [axis for axis, enabled in ((ax.xaxis, xaxis), (ax.yaxis, yaxis)) if enabled]
    for axis in selected:
        axis.set_minor_locator(ticker.MultipleLocator(minor))
        axis.set_major_locator(ticker.MultipleLocator(major))

def log10_formatter(x, pos):
    """
    Tick formatter for axes holding log10 values: label the tick at `x` as 10**x.

    Non-negative integer exponents up to 3 are written out ("1", "10", "100",
    "1000"). Larger integer exponents become mathtext powers ("$10^{6}$").
    Any other value falls back to the plain number 10**x.

    `pos` (the tick index) is unused; matplotlib's FuncFormatter passes it.
    """
    exponent = round(x)
    # tick locations are floats and may carry rounding noise (e.g. 2.9999999999999996)
    if abs(x - exponent) < 1e-9 and exponent >= 0:
        if exponent <= 3:
            return str(round(10**exponent))
        # braces keep multi-digit exponents entirely in the superscript
        return f"$10^{{{exponent}}}$"

    return f"${10**x}$"

def latify_parameter_name(param_name):
    """Render a two-character parameter name such as 'a1' as the mathtext string '$a_1$'."""
    assert len(param_name) == 2
    symbol, index = param_name
    return "".join(["$", symbol, "_", index, "$"])

## axis formatters

In [None]:
def format_signal_ax(ax, t_min=0, t_max=None, y_min=0, y_max=1, signal_type='test'):
    """
    Style an axes that shows an input signal over time.

    - Hides the top and right spines.
    - Ticks: x every 1 (minor) / 3 (major) units; y every 0.5 (minor) / 1 (major).
    - Grid: dotted, on x (major and minor) and on y (major only).
    - y range: padded below `y_min` by 24 % of the range, with the left spine
      trimmed to [y_min, y_max].
    - y label: the mathtext symbol for `signal_type` ('test' or 'train').
    - x range: only when `t_max` is given, padded by 0.5 on both sides, with
      the bottom spine trimmed to [t_min, t_max].
    """
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    axis_time_ticker(ax)

    # y axis
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.set_ylim(y_min - 0.24 * (y_max - y_min), y_max)  # leave some space
    ax.spines.left.set_bounds(y_min, y_max)  # pretty
    # set before the early return below, so the label does not depend on t_max
    ax.set_ylabel(signal_type_to_label[signal_type])

    # grid
    ax.grid(True, color='black', ls=':', which='both', axis='x', alpha=alpha_mid)
    ax.grid(True, color='black', ls=':', which='major', axis='y', alpha=alpha_mid)

    # x axis (only when the time range is known)
    if t_max is None:
        return

    ax.set_xlim(t_min - 0.5, t_max + 0.5)  # leave some space
    ax.spines.bottom.set_bounds(t_min, t_max)  # pretty

In [None]:
def format_plot_ax(ax, t_min=0, t_max=None, y_min=None, y_max=None, y_xaxis_pos=None):
    """
    Shared styling for time-course axes.

    - Hides the top and right spines.
    - y ticks every 0.5 (minor) / 1 (major).
    - Dotted grid on x (major and minor) and on y (major only).
    - y range: starts at `y_xaxis_pos` (default y_min - 0.3, leaving room
      below the data); the left spine spans [y_min, y_max].
    - x range: only when `t_max` is given, padded by 0.5 on both sides, with
      the bottom spine spanning [t_min, t_max].
    """
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    if y_xaxis_pos is None:
        y_xaxis_pos = y_min - 0.3
    ax.set_ylim(y_xaxis_pos, y_max)
    ax.spines.left.set_bounds(y_min, y_max)

    grid_style = dict(color='black', ls=':', alpha=alpha_mid)
    ax.grid(True, which='both', axis='x', **grid_style)
    ax.grid(True, which='major', axis='y', **grid_style)

    if t_max is not None:
        ax.set_xlim(t_min - 0.5, t_max + 0.5)
        ax.spines.bottom.set_bounds(t_min, t_max)
    
def format_cascade_ax(
    ax, t_min=0, t_max=None,
    y_min=cascade_softlog(0.0),
    y_max=cascade_softlog(1.0).round(2),
    y_xaxis_pos=cascade_softlog(0.0) - 0.45,
):
    """
    Style an axes that shows one cascade step on the softlog scale.

    The y defaults are evaluated once, at definition time. The y range runs
    from softlog(0) up to softlog(1) rounded to 2 decimals, and the x axis
    sits 0.45 below softlog(0). Time ticks use axis_time_ticker's defaults.
    """
    limits = dict(t_min=t_min, t_max=t_max, y_min=y_min, y_max=y_max, y_xaxis_pos=y_xaxis_pos)
    format_plot_ax(ax, **limits)
    axis_time_ticker(ax)
    
def format_springs_ax(
    ax, t_min=0, t_max=None,
    y_min=-3,
    y_max=+2,
    y_xaxis_pos=None,
):
    """Style an axes with springs positions: y spans [-3, 2] by default; time ticks every 5 (minor) / 15 (major) units."""
    limits = dict(t_min=t_min, t_max=t_max, y_min=y_min, y_max=y_max, y_xaxis_pos=y_xaxis_pos)
    format_plot_ax(ax, **limits)
    axis_time_ticker(ax, minor=5, major=15)

In [None]:
def format_hist_ax(ax, hist_range, ylim_max=2.0):
    """
    Minimal styling for a histogram axes.

    Only the bottom spine is kept, and y runs from 0 to `ylim_max` without
    ticks. When `hist_range` is given, x is fixed to it.
    """
    for side in ('left', 'right', 'top'):
        getattr(ax.spines, side).set_visible(False)

    ax.set_ylim(0, ylim_max)
    ax.set_yticks([])

    if hist_range is not None:
        ax.set_xlim(*hist_range)

## plotters

In [None]:
def plot_signal(ax, ts, signal):
    """Draw `signal(ts)` as a thick line in `color_signal` with `alpha_mid` transparency, not clipped to the axes."""
    values = signal(ts)
    ax.plot(ts, values, lw=3, color=color_signal, alpha=alpha_mid, clip_on=False)

In [None]:
def plot_hist_samples(
    ax,
    samples,
    color='black',
    ylim_max=2.0,
    hist_range=None,
    parameter_name=None,
):
    """
    Draw a density-normalized step histogram (60 bins) of `samples` and apply histogram styling.

    If `parameter_name` is given, it is shown as a horizontal mathtext y label.
    """
    hist_style = dict(histtype='step', linewidth=1.5, color=color, clip_on=False)
    ax.hist(samples, bins=60, range=hist_range, density=True, **hist_style)

    format_hist_ax(ax, hist_range=hist_range, ylim_max=ylim_max)

    if parameter_name is None:
        return
    ax.set_ylabel(latify_parameter_name(parameter_name), rotation=0)

In [None]:
def lognormal_pdf_in_log10(
    mi=1.0,
    scale=3.0,
    mult=1.0,  # 1000 or 10 for feedbacks
):
    """
    Density, on a log10 axis, of mult * X where ln X ~ Normal(mi, scale).

    Since log10(mult * X) = (ln X + ln mult) / ln 10, this quantity is Normal
    with mean (mi + ln mult) / ln 10 and standard deviation scale / ln 10.

    Returns a function xs -> density values.
    """
    ln10 = np.log(10)
    loc_log10 = (mi + np.log(mult)) / ln10
    sd_log10 = scale / ln10

    def pdf(xs):
        return ss.norm.pdf(xs, loc=loc_log10, scale=sd_log10)

    return pdf

    
def plot_prior_pdf(
    ax,
    prior_pdf,
    plot_range,
):
    """
    Shade `prior_pdf` between 0 and its value over `plot_range` (301 points) in `color_prior`.

    Does nothing when `prior_pdf` is None (there is no prior to show) or when
    `plot_range` is None. figure_histograms_raw's default
    `param_name_to_prior_pdf` returns None.
    """
    if prior_pdf is None or plot_range is None:
        return

    ls = jnp.linspace(*plot_range, 301)
    ax.fill_between(
        ls,
        0,
        prior_pdf(ls),
        color=color_prior,
    )

## figure templates

In [None]:
def prep_signal_and_n_steps(hours, n_steps=4, scale_figure=1, hspace=0.35):
    """
    Create a one-column figure with a signal axes on top and `n_steps` cascade axes below it.

    Every axes is `hours` wide on the shared inch-per-hour scale, and all
    sizes are multiplied by `scale_figure`. Returns (fig, axs).
    """
    width = inch_per_h * float(hours) * scale_figure
    heights = [axsize_signal_12h[1] * scale_figure]
    heights += [axsize_cascade_12h[1] * scale_figure] * n_steps

    return subplots_from_axsize(axsize=(width, heights), hspace=hspace)

## standard figures

In [None]:
# Hourly training time points 0..11.
ts_train_default = jnp.linspace(0, 11, num=12)

# Dense grids (100 points per hour, endpoints included) for smooth curves.
ts_detailed_11h = jnp.linspace(0, 11, num=11 * 100 + 1)
ts_detailed_12h = jnp.linspace(0, 12, num=12 * 100 + 1)

In [None]:
def figure_cascade_training(
    cascade, parameters,
    ts_signal=ts_detailed_12h,
    measurement_error=0.3,
    signal_type='train',
):
    """
    Figure of the training data: the input signal, then each cascade step at the training time points.

    Parameters:
    cascade: model providing `ts` (training time points), `signal` (callable
        of time) and `run(parameters)`, whose output is indexed [..., step].
    parameters: parameters passed to `cascade.run`.
    ts_signal: time grid on which the signal is drawn. Its span also sets the
        figure width.
    measurement_error: error-bar size (symmetric +/-) around each training point.
    signal_type: key into `signal_type_to_label` for the signal y label.

    Returns (fig, axs).
    """
    ts_train = cascade.ts
    ys = cascade.run(parameters)

    # one extra hour of width for the x padding added by format_cascade_ax()
    fig, axs = prep_signal_and_n_steps(ts_signal.max() - ts_signal.min() + 1)

    ax = axs[0, 0]
    format_signal_ax(ax, t_max=ts_train.max(), signal_type=signal_type)
    # draw the signal on the requested grid, the same one that sized the figure
    plot_signal(ax, ts_signal, cascade.signal)

    for i, ax in enumerate(axs[1:, 0]):
        format_cascade_ax(ax, t_max=ts_train.max())
        ax.set_xticks(ts_train, minor=True)
        ax.set_ylabel(f"{text_softlog}$(K_{i+1})$")

        ax.errorbar(
            ts_train, ys[..., i],
            ls=':', lw=1,
            yerr=measurement_error,
            color=cascade_colors[i],
            capsize=2,
        )

    ax.set_xlabel(text_time_hour)

    return fig, axs

In [None]:
def figure_cascade_prediction(
    cascade_fit,
    samples,
    cascade_true=None,
    parameters_true=None,
    signal_y_max=2.0,
    show_steps=(1, 2, 3, 4),
    annotate_measured_steps=(),
    signal_type='test',
    scale_figure=1.0,
):
    """
    Posterior predictive figure: the input signal on top, then one panel per cascade step.

    Each step panel shows the 10-90 % quantile band and the median of the
    trajectories predicted from `samples` with `cascade_fit.run`. When
    `parameters_true` is given, it also shows the reference trajectory as a
    dotted black line.

    Parameters:
    cascade_fit: fitted model; its `ts`, `signal` and `run` are used.
    samples: parameter samples, mapped over with jax.lax.map(cascade_fit.run, ...).
    cascade_true: model for the reference trajectory. Defaults to
        `cascade_fit` when `parameters_true` is given.
    parameters_true: parameters of the reference trajectory (optional).
    signal_y_max: upper y limit of the signal panel.
    show_steps: 1-based cascade steps to plot, one panel each, top to bottom.
    annotate_measured_steps: currently unused (the 'measured' annotation is
        disabled); kept so existing calls keep working.
    signal_type: key into `signal_type_to_label` for the signal y label.
    scale_figure: factor applied to all axes sizes.

    Returns (fig, axs).
    """
    if parameters_true is not None and cascade_true is None:
        cascade_true = cascade_fit

    ts = cascade_fit.ts

    # one extra hour of width for the x padding added by format_cascade_ax()
    fig, axs = prep_signal_and_n_steps(ts.max() - ts.min() + 1, n_steps=len(show_steps), scale_figure=scale_figure)

    if parameters_true is not None:
        ys_true = cascade_true.run(parameters_true)
    ys_pred = jax.lax.map(cascade_fit.run, samples)

    ys_q10, ys_med, ys_q90 = jnp.quantile(ys_pred, jnp.array([0.1, 0.5, 0.9]), axis=0)

    # signal panel
    ax = axs[0, 0]
    format_signal_ax(ax, t_max=ts.max(), y_max=signal_y_max, signal_type=signal_type)
    plot_signal(ax, ts, cascade_fit.signal)

    # one panel per requested step
    for ax, step in zip(axs[1:, 0], show_steps):
        step_idx = step - 1  # steps are 1-based, trajectory columns 0-based
        color = cascade_colors[step_idx]
        format_cascade_ax(ax, t_max=ts.max())

        # prediction: 10-90 % band and median
        ax.fill_between(ts, ys_q10[..., step_idx], ys_q90[..., step_idx], color=color, alpha=0.3)
        ax.plot(ts, ys_med[..., step_idx], color=color)

        # reference ("true") trajectory
        if parameters_true is not None:
            ax.plot(ts, ys_true[..., step_idx], ':', color='black')

        ax.set_ylabel(f"{text_softlog}$(K_{step})$")

    ax.set_xlabel(text_time_hour)

    return fig, axs

In [None]:
def figure_changes_prediction(
    cascade_fit,
    samples,
    cascade_true=None,
    parameters_true=None,
    step_idx=3,
    strength=10.0,
    signal_y_max=2.0,
    signal_type='test',
    scale_figure=1.0,
    hspace=0.35,
):
    """
    Effect of changing one parameter at a time on a single cascade step.

    Below the signal panel there is one panel per parameter i. In it,
    parameter i of every posterior sample is divided (blue) or multiplied
    (red) by `strength`. The resulting predictions of step `step_idx`
    (0-based) are drawn as a 10-90 % quantile band plus the median. When
    `parameters_true` is given, the reference model under the same change is
    drawn as a dotted black line.

    Parameters:
    cascade_fit: fitted model; its `ts`, `signal`, `run` and `parameters_names` are used.
    samples: parameter samples whose last axis indexes the parameters.
    cascade_true: reference model. Defaults to `cascade_fit` when
        `parameters_true` is given; otherwise the reference is optional.
    parameters_true: reference parameters (optional).
    step_idx: 0-based index of the cascade step shown in every panel.
    strength: multiplicative change applied to one parameter at a time.
    signal_y_max, signal_type: signal panel y limit and label key.
    scale_figure, hspace: figure scaling and vertical gap between panels.

    Returns (fig, axs).
    """
    if parameters_true is not None and cascade_true is None:
        cascade_true = cascade_fit

    if cascade_true is not None:
        # changing parameter i must hit the same parameter in both models;
        # list() keeps the comparison unambiguous for array-valued name lists
        assert list(cascade_fit.parameters_names) == list(cascade_true.parameters_names)

    ts = cascade_fit.ts
    n_params = samples.shape[-1]

    # one extra hour of width for the x padding added by format_cascade_ax()
    fig, axs = prep_signal_and_n_steps(ts.max() - ts.min() + 1, n_steps=n_params, scale_figure=scale_figure, hspace=hspace)

    # signal panel
    ax = axs[0, 0]
    format_signal_ax(ax, t_max=ts.max(), y_max=signal_y_max, signal_type=signal_type)
    plot_signal(ax, ts, cascade_fit.signal)

    # one panel per changed parameter
    for i, ax in enumerate(axs[1:, 0]):
        format_cascade_ax(ax, t_max=ts.max())
        for mul, color in [(1 / strength, 'blue'), (strength, 'red')]:
            # scale only parameter i
            factors = jnp.ones(n_params).at[i].set(mul)

            ys_pred = jax.lax.map(cascade_fit.run, samples * factors[None])
            ys_q10, ys_med, ys_q90 = jnp.quantile(ys_pred, jnp.array([0.1, 0.5, 0.9]), axis=0)
            if parameters_true is not None:
                ys_true = cascade_true.run(parameters_true * factors)

            # prediction: 10-90 % band and median
            ax.fill_between(ts, ys_q10[..., step_idx], ys_q90[..., step_idx], color=color, alpha=0.3)
            ax.plot(ts, ys_med[..., step_idx], color=color)
            # reference
            if parameters_true is not None:
                ax.plot(ts, ys_true[..., step_idx], ':', color='black')

        parameter_name = cascade_fit.parameters_names[i]
        ax.set_title("[" + latify_parameter_name(parameter_name) + "]")
        ax.set_ylabel(f"{text_softlog}$(K_{step_idx + 1})$")

    ax.set_xlabel(text_time_hour)

    return fig, axs

In [None]:
def figure_springs_prediction(
    springs_fit,
    samples,
    springs_true=None,
    parameters_true=None,
    signal_y_min=-0.2, signal_y_max=0.2,
    signal_type='test',
):
    """
    Posterior predictive figure for the springs model.

    The top panel shows the forcing signal of `springs_fit`. The bottom panel
    shows, for each output trajectory (the last axis of `springs_fit.run(...)`),
    the 10-90 % quantile band and median of the predictions from `samples`.
    When `parameters_true` is given, it also shows the reference trajectory
    (dotted black). `springs_true` defaults to `springs_fit` for the reference.

    Returns (fig, axs).
    """
    if springs_true is None and parameters_true is not None:
        springs_true = springs_fit

    ts = springs_fit.ts

    # one extra second of width for the x padding added by format_springs_ax()
    duration_s = float(ts.max() - ts.min() + 1)
    fig, axs = subplots_from_axsize(
        axsize=(
            inch_per_s * duration_s,
            [axsize_springs_signal_60s[1], axsize_springs_60s[1]],
        ),
        hspace=0.35,
    )
    ax_signal, ax_springs = axs.reshape(-1)

    ys_true = springs_true.run(parameters_true) if parameters_true is not None else None
    ys_pred = jax.lax.map(springs_fit.run, samples)
    ys_q10, ys_med, ys_q90 = jnp.quantile(ys_pred, jnp.array([0.1, 0.5, 0.9]), axis=0)

    # forcing signal: generic signal styling, then springs-specific ticks and label
    format_signal_ax(ax_signal, t_max=ts.max(), y_min=signal_y_min, y_max=signal_y_max)
    plot_signal(ax_signal, ts, springs_fit.signal)
    axis_time_ticker(ax_signal, minor=5, major=15)
    ax_signal.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
    ax_signal.yaxis.set_major_locator(ticker.MultipleLocator(0.2))
    ax_signal.set_ylabel(signal_type_to_spring_label[signal_type])

    # predicted trajectories
    format_springs_ax(ax_springs, t_max=ts.max())
    for i in range(ys_pred.shape[-1]):
        color = springs_colors[i]
        ax_springs.fill_between(ts, ys_q10[..., i], ys_q90[..., i], color=color, alpha=0.3)
        ax_springs.plot(ts, ys_med[..., i], color=color)
        if parameters_true is not None:
            ax_springs.plot(ts, ys_true[..., i], ':', color='black')

    ax_springs.set_xlabel(text_time_seconds)
    ax_springs.set_ylabel(text_pos_m)

    return fig, axs

In [None]:
def figure_histograms_raw(
    parameters_names,
    samplings,
    sampling_colors=sampling_colors,
    param_name_to_hist_range=lambda _: (-5, 5),
    param_name_to_prior_pdf=lambda _: None,
    parameters_true=None,
    ylim_max=2.0,
    horizontal=False,
    hspace=0.3,
    wspace=0.4,
):
    """
    One histogram axes per parameter, overlaying several samplings.

    Parameters:
    parameters_names: parameter names, in the column order of each sampling.
    samplings: iterable of sample arrays; column i (last axis) belongs to
        parameters_names[i].
    sampling_colors: one color per sampling (zip stops at the shorter of the two).
    param_name_to_hist_range: name -> (lo, hi) histogram and x range.
    param_name_to_prior_pdf: name -> prior density function, or None to skip
        the prior shading.
    parameters_true: optional reference values, drawn as dotted vertical lines.
    ylim_max: upper density limit of every axes.
    horizontal: lay the axes out in one row instead of one column.
    hspace, wspace: gaps between axes.

    Returns (fig, axs).
    """
    nrows_or_ncols = {
        ('ncols' if horizontal else 'nrows'): len(parameters_names),
    }
    fig, axs = subplots_from_axsize(
        axsize=axsize_hist,
        **nrows_or_ncols,
        hspace=hspace,
        wspace=wspace,
    )

    for i, (param_name, ax) in enumerate(zip(parameters_names, axs.reshape(-1))):
        hist_range = param_name_to_hist_range(param_name)
        prior_pdf = param_name_to_prior_pdf(param_name)

        # the default mapping returns None ("no prior"), which cannot be evaluated
        if prior_pdf is not None:
            plot_prior_pdf(ax, prior_pdf, hist_range)

        if parameters_true is not None:
            ax.axvline(parameters_true[i], color='black', ls=':')

        for samples, color in zip(samplings, sampling_colors):
            plot_hist_samples(
                ax,
                samples[..., i],
                hist_range=hist_range,
                color=color, ylim_max=ylim_max,
                parameter_name=param_name
            )

    return fig, axs

In [None]:
def figure_histograms_cascade(
    parameters_names,
    samplings,
    sampling_colors=sampling_colors,
    select=None,
    feedback_param_shift=1000.0,
    parameters_true=None,
    **kwargs,
):
    """
    Parameter histograms for cascade models, on a log10 axis labelled in natural units.

    Parameters whose name starts with 'f' (feedbacks) use a histogram range
    and log-normal prior shifted by `feedback_param_shift`. All other
    parameters use the log10 range (-3, 4) and the default log-normal prior
    (see lognormal_pdf_in_log10).

    Parameters:
    select: optional function applied to the names, to every sampling and to
        `parameters_true`, e.g. to keep a subset of columns.
    **kwargs: forwarded to figure_histograms_raw.

    Returns (fig, axs).
    """
    # useful for masking parameters
    if select is not None:
        parameters_names = select(parameters_names)
        samplings = [select(s) for s in samplings]
        if parameters_true is not None:
            parameters_true = select(parameters_true)

    # log10 per sampling, so samplings with different numbers of samples are allowed
    samplings = [np.log10(samples) for samples in samplings]
    if parameters_true is not None:
        parameters_true = np.log10(parameters_true)

    log10_shift = np.log10(feedback_param_shift)

    def param_name_to_hist_range(param_name):
        if param_name[0] == 'f':
            return (-3 + log10_shift, 4 + log10_shift)
        return (-3, 4)

    def param_name_to_prior_pdf(param_name):
        if param_name[0] == 'f':
            return lognormal_pdf_in_log10(mult=feedback_param_shift)
        return lognormal_pdf_in_log10()

    fig, axs = figure_histograms_raw(
        parameters_names,
        samplings,
        sampling_colors=sampling_colors,
        param_name_to_hist_range=param_name_to_hist_range,
        param_name_to_prior_pdf=param_name_to_prior_pdf,
        parameters_true=parameters_true,
        **kwargs,
    )

    # label ticks in natural units: minor every decade, major every 3 decades
    for ax in axs.reshape(-1):
        ax.xaxis.set_major_formatter(log10_formatter)
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(1))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(3))

    return fig, axs

In [None]:
def figure_histograms_springs(
    parameters_names,
    samplings,
    sampling_colors=sampling_colors,
    select=None,
    parameters_true=None,
    **kwargs,
):
    """
    Parameter histograms for the springs model, on a linear axis.

    Parameters whose name starts with 'm' use the range (0, 20); all others
    use (0, 5). The prior is shaded as the constant density
    1 / (upper - 0.1), presumably a uniform prior on [0.1, upper] (to confirm).

    Parameters:
    select: optional function applied to the names, to every sampling and to
        `parameters_true`.
    **kwargs: forwarded to figure_histograms_raw.

    Returns (fig, axs).
    """
    if select is not None:
        parameters_names = select(parameters_names)
        samplings = [select(s) for s in samplings]
        if parameters_true is not None:
            parameters_true = select(parameters_true)

    def param_name_to_hist_range(param_name):
        return (0, 20) if param_name[0] == 'm' else (0, 5)

    def param_name_to_prior_pdf(param_name):
        upper = 20 if param_name[0] == 'm' else 5
        density = 1/(upper - 0.1)
        return lambda _: density

    fig, axs = figure_histograms_raw(
        parameters_names,
        samplings,
        sampling_colors=sampling_colors,
        param_name_to_hist_range=param_name_to_hist_range,
        param_name_to_prior_pdf=param_name_to_prior_pdf,
        parameters_true=parameters_true,
        **kwargs,
    )

    # tick spacing per parameter kind: (minor, major)
    for ax, param_name in zip(axs.reshape(-1), parameters_names):
        minor, major = (1, 5) if param_name[0] == 'm' else (0.5, 1)
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minor))
        ax.xaxis.set_major_locator(ticker.MultipleLocator(major))

    return fig, axs

In [None]:
from sklearn.decomposition import PCA

def compute_stds(samples):
    """Standard deviations along the principal axes of 2D `samples` (rows = samples), in PCA order."""
    assert len(samples.shape) == 2
    centered = samples - samples.mean(axis=0, keepdims=True)
    pca = PCA().fit(centered)
    return np.sqrt(pca.explained_variance_)


def figure_pca(
    samplings,
    sampling_colors=sampling_colors,
    log=True,
):
    """
    Bar chart of principal-axis spreads for several posterior samplings.

    Each sampling (2D array, rows = samples) is log-transformed when `log`
    is True. Its standard deviations along the principal axes are computed
    with compute_stds, and the bars show exp(sd). Each group of bars starts
    with the prior, drawn with a fixed sd of 3 in every direction (presumably
    the log-normal prior scale).

    Returns (fig, axs).
    """
    if log:
        # per sampling, so samplings with different numbers of samples are allowed
        samplings = [np.log(samples) for samples in samplings]

    sds = np.array([
        compute_stds(samples) for samples in samplings
    ])

    fig, axs = subplots_from_axsize(axsize=axsize_pca)
    ax = axs[0, 0]

    n_dims = sds.shape[-1]
    n_bars = sds.shape[0] + 1  # prior + one bar per sampling

    # bar placement: groups 1.2 apart, bars of width 0.2 centered on each group
    ls = 1.2 * np.arange(n_dims)
    offsets = np.arange(n_bars) * 0.2
    offsets = offsets - offsets.mean()

    # prior
    ax.bar(ls + offsets[0], np.exp(np.full_like(sds[0], 3)), 0.2, color=color_prior)

    # samplings
    for offset, sd, color in zip(offsets[1:], sds, sampling_colors):
        ax.bar(ls + offset, np.exp(sd), 0.2, color=color)

    # make nice (raw string: "\d" is not a valid Python escape)
    ax.set_xticks(ls, [r"$\delta_{" + str(i + 1) + "}$" for i in range(len(ls))])
    ax.set_yscale('log')
    ax.set_yticks([1, 1.5, 3, 10, 20])
    ax.set_yticks([], minor=True)
    ax.yaxis.set_major_formatter('{x:.1f}')
    ax.set_ylim(1.0, np.exp(3))
    ax.grid(True, color='black', ls=':', which='major', axis='y', alpha=0.5, clip_on=False)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)

    return fig, axs

# Eigenvectors

In [None]:
def samples_pca(samples):
    """
    Principal axes of the log-transformed 2D `samples` (rows = samples).

    Returns the fitted PCA's `components_`, one principal axis per row.
    The input is not modified.
    """
    assert len(samples.shape) == 2

    log_samples = np.log(samples)  # np.log returns a new array, so no explicit copy is needed
    centered = log_samples - log_samples.mean(keepdims=True, axis=0)
    return PCA().fit(centered).components_

In [None]:
def plot_eigenvector(ax, v):
    """
    Show a 9-component vector as an annotated 2 x 5 heatmap on `ax`.

    The sign is normalized first. If |v[0]| > 0.1, v[0] is made positive.
    Otherwise the sum of the even-indexed components is made positive, when
    that sum is non-zero.

    Components 0..7 fill the first four columns: even indices in row 0,
    odd indices in row 1. Component 8 goes to the bottom-right cell, and the
    top-right cell stays empty (NaN).
    """
    assert len(v) == 9

    if np.abs(v[0]) > 0.1:
        v = v * np.sign(v[0])
    else:
        even_sum = np.sum(v[::2])
        if even_sum != 0:
            v = v * np.sign(even_sum)

    table = np.full((2, 5), np.nan)
    table[0, :4] = v[0:8:2]
    table[1, :4] = v[1:8:2]
    table[1, 4] = v[8]

    sns.heatmap(
        table, ax=ax, annot=True,
        vmin=-1, vmax=+1, fmt='.2f',
        square=True, cmap='coolwarm', linewidth=1,
        cbar=False, xticklabels=False, yticklabels=False,
    )

    return ax

In [None]:
def disentangle_vectors(vs, mask, seed=None, n_steps=5000, learning_rate=0.01):
    """
    Rotate the rows of `vs` among themselves so that their entries under `mask` become small.

    The search is over matrices M = expm(A - A^T); A - A^T is antisymmetric,
    so M is a rotation. Plain gradient descent minimizes
    sum((M @ vs * mask)**2), and the function returns M @ vs. The result's
    rows span the same subspace as the rows of `vs`.

    Parameters:
    vs: (n, d) array whose n rows are rotated.
    mask: array broadcastable to (n, d); masked entries are pushed toward zero.
    seed: optional seed for the random initial A. None draws from numpy's
        global random state.
    n_steps, learning_rate: gradient-descent schedule.

    Prints the initial and final loss.
    """
    n, _ = vs.shape

    def rotation(a):
        # expm of an antisymmetric matrix is orthogonal
        return jax.scipy.linalg.expm(a - a.transpose())

    def loss(a):
        ws = rotation(a).dot(vs)
        return jnp.sum(jnp.square(ws * mask))

    grad = jax.jit(jax.grad(loss))
    loss = jax.jit(loss)

    # find optimal rotation
    if seed is None:
        a0 = np.random.normal(size=(n, n))
    else:
        a0 = np.random.default_rng(seed).normal(size=(n, n))
    a = jnp.array(a0)

    print(f"Initial loss: {loss(a):.3f}")
    for _ in range(n_steps):
        a -= learning_rate * grad(a)
    print(f"Final loss: {loss(a):.3f}")

    # disentangle
    return rotation(a).dot(vs)

# Data utils

In [None]:
def load_samples(cache_path):
    """
    Load cached MCMC samples and merge the chain axis into the sample axis.

    The pickle holds a dict whose 'samples_by_chain' entry has the chain on
    its first axis and the parameters on its last axis. The result is a 2D
    array of shape (total_samples, n_parameters).

    NOTE: pickle.load can execute arbitrary code; only use on trusted cache files.
    """
    with open(cache_path, 'rb') as handle:
        by_chain = pickle.load(handle)['samples_by_chain']

    n_params = by_chain.shape[-1]
    return by_chain.reshape(-1, n_params)  # drop chain axis


def subsample_for_pred(samples, goal=1000):
    """
    Pick `goal` evenly spread rows of the 2D array `samples` for predictions.

    When the number of samples is a multiple of `goal`, this is plain
    thinning: every k-th sample, starting with the first. Otherwise the
    picked indices are spread evenly over the whole array, instead of
    returning more than `goal` rows. If there are at most `goal` samples,
    all of them are returned.
    """
    num_samples, _ = samples.shape
    if num_samples <= goal:
        return samples

    # floor(i * num_samples / goal) in exact integer arithmetic; equals i * k when num_samples == k * goal
    idx = (np.arange(goal) * num_samples) // goal
    samples_pred = samples[idx]

    assert len(samples_pred) == goal

    return samples_pred

# Figures

# Figure 1. Introduction to the model

## Figure 1A. Model scheme (external)

## Figure 1B. Model equations (external)

## Figure 1C. Model parameters table (external)

## Figure 1D. Model default training trajectory (signal 4h)

In [None]:
# Default 9-parameter cascade on its training grid (ts_train_default + 1, i.e. 1..12 h),
# driven by signals.make_pulse_delayed(4).
fig, axs = figure_cascade_training(
    cascade=cascade_k4p9_fb.CascadeK4P9Fb(
        ts=ts_train_default + 1,
        signal=signals.make_pulse_delayed(4),
    ),
    parameters=cascade_k4p9_fb.parameters_default,
)
save_and_show(fig, './figures-manuscript/fig-1d-training-data.svg')

# Figure 2. Prediction

In [None]:
# Cached posteriors for three measurement setups (step 4; steps 2+4; steps 1-4),
# plus 1000-sample subsets used for the predictions.
samples_4, samples_24, samples_1234 = [
    load_samples(cache_path)
    for cache_path in (
        './cache/cascade_k4p9_fb_self_4h_measure_4.pkl',
        './cache/cascade_k4p9_fb_self_4h_measure_2+4.pkl',
        './cache/cascade_k4p9_fb_self_4h_measure_1+2+3+4.pkl',
    )
]

samples_4_pred, samples_24_pred, samples_1234_pred = [
    subsample_for_pred(samples) for samples in (samples_4, samples_24, samples_1234)
]

## Figure 2A. Prediction on blocky, measure 4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_4_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/fig-2a-K4-pred-blocky.svg')

## Figure 2B. Prediction on wiggly, measure 4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_wiggly),
    samples_4_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/fig-2b-K4-pred-wiggly.svg')

## Figure 2C. Prediction on blocky, measure 2+4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_24_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[2, 4],
)

save_and_show(fig, './figures-manuscript/fig-2c-K24-pred-blocky.svg')

## Figure 2D. Prediction on blocky, measure 1+2+3+4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_1234_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[1, 2, 3, 4],
)

save_and_show(fig, './figures-manuscript/fig-2d-K1234-pred-blocky.svg')

## Figure 2E. Compare histograms

In [None]:
# Arguments shared by the three histogram panels.
kwargs = dict(
    parameters_names=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=None).parameters_names,
    samplings=[samples_4, samples_24, samples_1234],
    parameters_true=cascade_k4p9_fb.parameters_default,
    hspace=0.5,
)

# Column selections, per the output file names:
# [0:-1:2] -> the a's, [1:-1:2] -> the d's, [-1:] -> the feedback f.
fig, axs = figure_histograms_cascade(
    select=lambda x: np.array(x)[..., 0:-1:2],
    ylim_max=3.0,
    **kwargs,
)
save_and_show(fig, './figures-manuscript/fig-2e1-hist-as.svg')

fig, axs = figure_histograms_cascade(
    select=lambda x: np.array(x)[..., 1:-1:2],
    ylim_max=3.0,
    **kwargs,
)
save_and_show(fig, './figures-manuscript/fig-2e2-hist-ds.svg')

fig, axs = figure_histograms_cascade(
    select=lambda x: np.array(x)[..., -1:],
    ylim_max=3.0,
    **kwargs,
)
save_and_show(fig, './figures-manuscript/fig-2e3-hist-f.svg')

## Figure 2F. PCA dimensionality analysis

In [None]:
# Principal-axis spreads (log scale) of the three posteriors, next to the prior.
fig, axs = figure_pca(samplings=[samples_4, samples_24, samples_1234])
save_and_show(fig, './figures-manuscript/fig-2f-pca.svg')

# Figure 3. KOs (single-parameter changes: × 10 and ÷ 10)

## Figure 3A. Prediction on blocky, measure 4

In [None]:
%%time

fig, axs = figure_changes_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_4_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    strength=10.0,
    scale_figure=1.0,
    hspace=0.5,
)

fig.legend(
    [lines.Line2D([0], [0], color='red', lw=2), lines.Line2D([0], [0], color='blue', lw=2)],
    ['× 10', '÷ 10'],
    loc='outside lower center', frameon=False,
)

save_and_show(fig, './figures-manuscript/fig-3a-blocky.svg')

## Figure 3B. Prediction on wiggly, measure 4

In [None]:
%%time

fig, axs = figure_changes_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_wiggly),
    samples_4_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    strength=10.0,
    scale_figure=1.0,
    hspace=0.5,
)

save_and_show(fig, './figures-manuscript/fig-3b-wiggly.svg')

# Figure 4. Relaxed model

In [None]:
# Cached posteriors of the relaxed (k4p11) model for data from the k4p9 model
# (per the file names), with step 4 or steps 1-4 measured, plus 1000-sample subsets.
samples_relaxed_4, samples_relaxed_1234 = [
    load_samples(cache_path)
    for cache_path in (
        './cache/cascade_k4p11_fb_to_cascade_k4p9_fb_4h_measure_4.pkl',
        './cache/cascade_k4p11_fb_to_cascade_k4p9_fb_4h_measure_1+2+3+4.pkl',
    )
]

samples_relaxed_4_pred, samples_relaxed_1234_pred = [
    subsample_for_pred(samples) for samples in (samples_relaxed_4, samples_relaxed_1234)
]

## Figure 4A. Prediction on blocky, measure 4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p11_fb.CascadeK4P11Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_relaxed_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/fig-4a-K4-pred-blocky.svg')

## Figure 4B. Prediction on blocky, measure 1+2+3+4

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p11_fb.CascadeK4P11Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_relaxed_1234_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[1, 2, 3, 4],
)

save_and_show(fig, './figures-manuscript/fig-4b-K1234-pred-blocky.svg')

## Figure 4C. Histograms of feedbacks

In [None]:
# Last three (feedback) parameters of the relaxed model for both posteriors;
# the feedback prior is shifted by a factor of 10.
fig, axs = figure_histograms_cascade(
    parameters_names=cascade_k4p11_fb.CascadeK4P11Fb(ts=ts_detailed_12h, signal=None).parameters_names,
    samplings=[samples_relaxed_4, samples_relaxed_1234],
    sampling_colors=[sampling_colors[0], sampling_colors[-1]],
    select=lambda x: np.array(x)[..., -3:],
    feedback_param_shift=10.0,
    ylim_max=1.0,
)

save_and_show(fig, './figures-manuscript/fig-4c-hist-fs.svg')

# Figure 5. Simplified model (shorter cascade)

In [None]:
import identifiability.model.cascade_k2p5_fb as cascade_k2p5_fb

In [None]:
# Samples of the shorter-cascade K2P5 model fitted to K4P9 data (measuring step 4),
# plus the subsampled version used for prediction plots.
samples_k2p5_fb_4 = load_samples(
    './cache/cascade_k2p5_fb_to_cascade_k4p9_fb_4h_measure_4.pkl'
)
samples_k2p5_fb_4_pred = subsample_for_pred(samples_k2p5_fb_4)

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    samples_k2p5_fb_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    parameters_true=cascade_k4p9_fb.parameters_default,
    show_steps=[1, 4],
    scale_figure=2.0,
    signal_type='train',
    #annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/fig-5a-K4-simplified-train.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_k2p5_fb_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    show_steps=[1, 4],
    scale_figure=2.0,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/fig-5b-K4-simplified-pred-blocky.svg')

In [None]:
# Parameter histograms of the simplified K2P5 model (single sampling, horizontal layout).
simplified_parameters_names = cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=None).parameters_names

fig, axs = figure_histograms_cascade(
    parameters_names=simplified_parameters_names,
    samplings=[samples_k2p5_fb_4],
    horizontal=True,
)

save_and_show(fig, './figures-manuscript/fig-5c-simplified-hists.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=signals.test_wiggly),
    samples_k2p5_fb_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_wiggly),
    parameters_true=cascade_k4p9_fb.parameters_default,
    show_steps=[4],
    scale_figure=2.0,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-internal/fig-5d-K4-simplified-pred-wiggly.svg')

# Appendix I. PCA eigenvectors of the parameter samples

# Figure S1. Eigenvectors

In [None]:
fig, axs = subplots_from_axsize(
    axsize=(2.0, 0.8),
    nrows=4, ncols=2,
    hspace=[0.4, 0.4, 0.2],
    wspace=0.2,
)

# Run each PCA once and index into the result, instead of recomputing the same
# decomposition for every panel.
# NOTE(review): the negative indices presumably select the leading components
# (eigenvectors ordered by ascending eigenvalue) — confirm in samples_pca.
pca_4 = samples_pca(samples_4)
pca_24 = samples_pca(samples_24)
pca_1234 = samples_pca(samples_1234)

# Row 0: samples_4 — one component; the right-hand panel is left empty.
plot_eigenvector(axs[0, 0], pca_4[-1])
axs[0, 1].set_axis_off()

# Row 1: samples_24 — two components.
plot_eigenvector(axs[1, 0], pca_24[-1])
plot_eigenvector(axs[1, 1], pca_24[-2])

# Rows 2-3: samples_1234 — four components.
plot_eigenvector(axs[2, 0], pca_1234[-1])
plot_eigenvector(axs[2, 1], pca_1234[-2])
plot_eigenvector(axs[3, 0], pca_1234[-3])
plot_eigenvector(axs[3, 1], pca_1234[-4])

save_and_show(fig, './figures-manuscript/supp-fig-s1a-eigenvectors.svg')

In [None]:
# Disentangle the top PCA components using one 0/1 mask row per output vector.
# NOTE(review): each mask row spans the 9 parameter entries; its exact meaning
# (allowed vs. suppressed entries) is defined by disentangle_vectors — confirm there.
mask_24 = jnp.array([
    [0., 0., 0., 0., 1., 1., 1., 1., 0.],
    [1., 1., 1., 1., 0., 0., 0., 0., 0.],
])
ws_24 = disentangle_vectors(
    vs=jnp.array(samples_pca(samples_24)[-2:]),
    mask=mask_24,
)

mask_1234 = jnp.array([
    [0., 0., 1., 1., 1., 1., 1., 1., 0.],
    [1., 1., 0., 0., 1., 1., 1., 1., 0.],
    [1., 1., 1., 1., 0., 0., 1., 1., 0.],
    [1., 1., 1., 1., 1., 1., 0., 0., 0.],
])
ws_1234 = disentangle_vectors(
    vs=jnp.array(samples_pca(samples_1234)[-4:]),
    mask=mask_1234,
)

In [None]:
fig, axs = subplots_from_axsize(
    axsize=(2.0, 0.8),
    nrows=4, ncols=2,
    hspace=[0.4, 0.4, 0.2],
    wspace=0.2,
)

v1 = samples_pca(samples_4)[-1]

# (row, col) -> vector to draw; panel (0, 1) has no vector and is blanked.
eigenvector_panels = {
    (0, 0): v1,
    (1, 0): ws_24[0],
    (1, 1): ws_24[1],
    (2, 0): ws_1234[0],
    (2, 1): ws_1234[1],
    (3, 0): ws_1234[2],
    (3, 1): ws_1234[3],
}
axs[0, 1].set_axis_off()
for (panel_row, panel_col), vector in eigenvector_panels.items():
    plot_eigenvector(axs[panel_row, panel_col], vector)

save_and_show(fig, './figures-manuscript/supp-fig-s1b-disentangled.svg')

# Appendix II. Pitfalls

## Wrong training (pulse too short)

In [None]:
# Samples from self-fitting K4P9 on the short (1h-tagged) training pulse,
# measuring step 4 only vs. steps 1+2+3+4.
samples_short_4, samples_short_1234 = [
    load_samples(pkl_path)
    for pkl_path in (
        './cache/cascade_k4p9_fb_self_1h_measure_4.pkl',
        './cache/cascade_k4p9_fb_self_1h_measure_1+2+3+4.pkl',
    )
]

# Subsampled versions used for the prediction plots.
samples_short_4_pred, samples_short_1234_pred = (
    subsample_for_pred(samples) for samples in (samples_short_4, samples_short_1234)
)

In [None]:
# Training data for the short-pulse scenario (signals.make_pulse_delayed(1)).
cascade_short_training = cascade_k4p9_fb.CascadeK4P9Fb(
    ts=ts_train_default + 1,
    signal=signals.make_pulse_delayed(1),
)
fig, axs = figure_cascade_training(cascade_short_training, cascade_k4p9_fb.parameters_default)
axs[0, 0].set_title("training data", color=color_note, pad=12)

save_and_show(fig, './figures-manuscript/supp-fig-s2a-training-data-short.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_short_4_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/supp-fig-s2b-K4-short-pred-blocky.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_short_1234_pred,
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[1, 2, 3, 4],
)

save_and_show(fig, './figures-manuscript/supp-fig-s2c-K1234-short-pred-blocky.svg')

In [None]:
# PCA comparison: short-pulse sampling (green) vs. the regular steps-1+2+3+4 sampling.
pca_sampling_colors = ['limegreen', sampling_colors[-1]]
fig, axs = figure_pca([samples_short_1234, samples_1234], sampling_colors=pca_sampling_colors)
save_and_show(fig, './figures-manuscript/supp-fig-s2d-pca.svg')

## Wrong training (pulse too short) and simplified model

In [None]:
# Samples of the simplified K2P5 model fitted to K4P9 data from the short
# (1h-tagged) training pulse, measuring step 4; plus the subsampled version.
samples_short_simplified_4 = load_samples(
    './cache/cascade_k2p5_fb_to_cascade_k4p9_fb_1h_measure_4.pkl'
)
samples_short_simplified_4_pred = subsample_for_pred(samples_short_simplified_4)

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(1)),
    samples_short_simplified_4,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(1)),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
    show_steps=[4],
    scale_figure=2.0,
    signal_type='train',
)

save_and_show(fig, './figures-internal/supp-fig-s?a-K4-short-simplified-train.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k2p5_fb.CascadeK2P5Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_short_simplified_4,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
    show_steps=[4],
    scale_figure=2.0,
)

save_and_show(fig, './figures-internal/supp-fig-s?b-K4-short-simplified-pred-blocky.svg')

## Wrong model (no feedback)

In [None]:
import identifiability.model.cascade_k4p8 as cascade_k4p8

In [None]:
# Samples of the feedback-free K4P8 model fitted to K4P9 data,
# measuring step 4 only vs. steps 1+2+3+4.
samples_k4p8_4, samples_k4p8_1234 = [
    load_samples(pkl_path)
    for pkl_path in (
        './cache/cascade_k4p8_to_cascade_k4p9_fb_4h_measure_4.pkl',
        './cache/cascade_k4p8_to_cascade_k4p9_fb_4h_measure_1+2+3+4.pkl',
    )
]

# Subsampled versions used for the prediction plots.
samples_k4p8_4_pred, samples_k4p8_1234_pred = (
    subsample_for_pred(samples) for samples in (samples_k4p8_4, samples_k4p8_1234)
)

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p8.CascadeK4P8(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    samples_k4p8_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
    signal_type='train',
)

save_and_show(fig, './figures-manuscript/supp-fig-s3a-K4-wrong-train.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p8.CascadeK4P8(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_k4p8_4_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[4],
)

save_and_show(fig, './figures-manuscript/supp-fig-s3b-K4-wrong-pred-blocky.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p8.CascadeK4P8(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    samples_k4p8_1234_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.make_pulse_delayed(4)),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[1, 2, 3, 4],
    signal_type='train',
)

save_and_show(fig, './figures-manuscript/supp-fig-s3c-K1234-wrong-train.svg')

In [None]:
%%time
fig, axs = figure_cascade_prediction(
    cascade_k4p8.CascadeK4P8(ts=ts_detailed_12h, signal=signals.test_blocky),
    samples_k4p8_1234_pred,
    cascade_true=cascade_k4p9_fb.CascadeK4P9Fb(ts=ts_detailed_12h, signal=signals.test_blocky),
    parameters_true=cascade_k4p9_fb.parameters_default,
    annotate_measured_steps=[1, 2, 3, 4],
)

save_and_show(fig, './figures-manuscript/supp-fig-s3b-K1234-wrong-pred-blocky.svg')

In [None]:
# Histograms for the feedback-free K4P8 model: even-indexed parameters go to the
# "as" figure, odd-indexed ones to the "ds" figure.
# NOTE(review): prefix 's3c' is also used by supp-fig-s3c-K1234-wrong-train.svg —
# confirm intended panel lettering.
k4p8_hist_kwargs = dict(
    parameters_names=cascade_k4p8.CascadeK4P8(ts=ts_detailed_12h, signal=None).parameters_names,
    samplings=[samples_k4p8_4, samples_k4p8_1234],
    sampling_colors=[sampling_colors[0], sampling_colors[-1]],
    horizontal=True,
    wspace=0.25,
)

for param_cols, svg_path in [
    (slice(0, None, 2), './figures-manuscript/supp-fig-s3c1-hist-as.svg'),
    (slice(1, None, 2), './figures-manuscript/supp-fig-s3c2-hist-ds.svg'),
]:
    # Bind the slice as a default so each selector keeps its own columns.
    fig, axs = figure_histograms_cascade(
        **k4p8_hist_kwargs, select=lambda x, cols=param_cols: np.array(x)[..., cols],
    )
    save_and_show(fig, svg_path)

# Springs

In [None]:
# Springs model: the module's default parameters serve as the true parameter set
# (`parameters_true`) in the springs figures below.
import identifiability.model.springs as springs

springs_parameters_true = springs.parameters_default

# Dense time grid for plotting the training experiment.
springs_ts_train_detailed = jnp.linspace(0, 50, 201)


def springs_signal_train(t):
    """Training forcing: -0.2 on the window 10 <= t < 20, zero elsewhere."""
    return -0.2 * (t >= 10) * (t < 20)


# Time grid for the test (prediction) experiment.
springs_ts_test = jnp.linspace(0, 190, 301)


def springs_signal_test(t):
    """Test forcing: -0.2 * sin((t - 10) * pi / 20), switched on from t = 10."""
    return -0.2 * jnp.sin((t - 10) * jnp.pi / 20) * (t >= 10)

## Figure 6. Springs model: measuring 3 vs. measuring 1+2+3

In [None]:
# Springs-model samples from measuring 3 only vs. 1+2+3.
samples_springs_3, samples_springs_123 = [
    load_samples(pkl_path)
    for pkl_path in (
        './cache/springs_measure_3.pkl',
        './cache/springs_measure_1+2+3.pkl',
    )
]

# Subsampled versions used for the prediction plots.
samples_springs_3_pred, samples_springs_123_pred = (
    subsample_for_pred(samples) for samples in (samples_springs_3, samples_springs_123)
)

In [None]:
%%time
fig, axs = figure_springs_prediction(
    springs.SpringsModel(ts=springs_ts_train_detailed, signal=springs_signal_train),
    samples_springs_3_pred,
    parameters_true=springs_parameters_true,
    signal_type='train',
)

save_and_show(fig, './figures-manuscript/fig-6a-M3-train.svg')

In [None]:
%%time
fig, axs = figure_springs_prediction(
    springs.SpringsModel(ts=springs_ts_test, signal=springs_signal_test),
    samples_springs_3_pred,
    parameters_true=springs_parameters_true,
)

save_and_show(fig, './figures-manuscript/fig-6b-M3-pred.svg')

In [None]:
%%time
fig, axs = figure_springs_prediction(
    springs.SpringsModel(ts=springs_ts_train_detailed, signal=springs_signal_train),
    samples_springs_123_pred,
    parameters_true=springs_parameters_true,
    signal_type='train',
)

save_and_show(fig, './figures-manuscript/fig-6c-M123-train.svg')

In [None]:
%%time
fig, axs = figure_springs_prediction(
    springs.SpringsModel(ts=springs_ts_test, signal=springs_signal_test),
    samples_springs_123_pred,
    parameters_true=springs_parameters_true,
)

save_and_show(fig, './figures-manuscript/fig-6d-M123-pred.svg')

In [None]:
# Shared histogram settings: measuring 3 vs. measuring 1+2+3.
kwargs = dict(
    parameters_names=springs.SpringsModel(ts=springs_ts_test, signal=springs_signal_test).parameters_names,
    samplings=[samples_springs_3, samples_springs_123],
    sampling_colors=[cascade_colors[-1], 'black'],
    horizontal=True,
    wspace=0.35,
)

# One figure per parameter group: columns 0-2 ("ms"), 3-5 ("bs"), 6-7 ("ks").
for param_cols, hist_ylim, svg_path in [
    (slice(0, 3), 1.0, './figures-manuscript/fig-6e1-hist-ms.svg'),
    (slice(3, 6), 4.0, './figures-manuscript/fig-6e2-hist-bs.svg'),
    (slice(6, 8), 4.0, './figures-manuscript/fig-6e3-hist-ks.svg'),
]:
    # Bind the slice as a default so each selector keeps its own columns.
    fig, axs = figure_histograms_springs(
        **kwargs, select=lambda x, cols=param_cols: np.array(x)[..., cols], ylim_max=hist_ylim,
    )
    save_and_show(fig, svg_path)