In [None]:
import numpy as np
import matplotlib.pyplot as plt

from lib.correlation_integral import LongSpatial02CI, LongSpatialCI
from lib.total_variation import TotalVariation
from models.lib.firstorderode.base import RungeKutta
from models.lib.firstorderode.lorenz import Lorenz63
from models.lib.firstorderode.roessler import Roessler76
from models.lib.firstorderode.sprott import SprottAttractors
from models.lib.toy_models import JAU24a, LinaJL

def Model(seed=None, n_buffer=10000):
    """Build the reference Lorenz-63 model (s=10, r=28, b=8/3) wrapped in ``RungeKutta``.

    Parameters
    ----------
    seed : optional
        Forwarded to ``Lorenz63``; presumably seeds the initial condition — confirm.
    n_buffer : int, default 10000
        Forwarded to ``RungeKutta``; presumably the number of transient steps that are
        discarded before recording — confirm.

    Notes
    -----
    ``odeint_dt=1e-3`` combined with ``step_size=20`` presumably stores one sample every
    0.02 time units, matching the ``_02`` result files — confirm against ``RungeKutta``.
    """
    return RungeKutta(
        Lorenz63(s=10, r=28, b=8 / 3, seed=seed),
        odeint_dt=1e-3,
        step_size=20,
        n_buffer=n_buffer,
    )

In [None]:
from matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np


# Anchor points of the "TUI" colormap as (position, (r, g, b)) with channels in [0, 1]:
# teal at 0, white at 0.5, orange at 1.
_TUI_ANCHORS = (
    (0.0, (0.0, 116. / 255., 122. / 255.)),
    (0.5, (1.0, 1.0, 1.0)),
    (1.0, (1.0, 121. / 255., 0.0)),
)

# LinearSegmentedColormap wants, per channel, rows of (x, value_below, value_above);
# the colormap is continuous, so both values are the anchor colour.
cdict = {
    channel: tuple((pos, rgb[k], rgb[k]) for pos, rgb in _TUI_ANCHORS)
    for k, channel in enumerate(("red", "green", "blue"))
}

TUI = colors.LinearSegmentedColormap('TUI', cdict)
norm = colors.Normalize(vmin=0, vmax=1)

# Plot colours: the two ends of the TUI colormap plus neutral black/grey.
color_hist = TUI(norm(1))  # previously "#ff885a"
color_res = TUI(norm(0))
color_chi = "black"
color_threshold = "black"
color_region = "grey"

In [None]:
dt_02 = True  # select the "_02" result files (presumably sampling step dt = 0.02 — confirm)

# Load the parameter-sweep results. The context manager closes the underlying .npz file;
# indexing the NpzFile copies each array into memory, so `data`/`thetas` stay valid.
# NOTE(review): the plotting cell below reloads `data`/`thetas` from results/low_res/
# and overwrites these arrays before they are used — confirm which sweep is intended.
with np.load(f"results/sweeped_thetas{'_02' if dt_02 else ''}.npz") as npz:
    data = npz["data"]
    thetas = npz["thetas"]

In [None]:
def middles(lower, upper, n_bins):
    """Return the centres of ``n_bins`` equal-width bins spanning ``[lower, upper]``.

    Parameters
    ----------
    lower, upper : float
        Outer edges of the full interval.
    n_bins : int
        Number of bins. ``0`` yields an empty array.

    Returns
    -------
    numpy.ndarray
        Array of length ``n_bins`` holding the midpoint of each bin.

    Raises
    ------
    ValueError
        If ``n_bins`` is negative.
    """
    if n_bins < 0:
        raise ValueError(f"n_bins must be non-negative, got {n_bins}")
    edges = np.linspace(lower, upper, n_bins + 1)
    # Average neighbouring edges. Unlike np.convolve(..., "valid"), which swaps its
    # operands when the input is shorter than the kernel, slicing yields an empty
    # result for n_bins == 0 instead of two spurious values.
    return 0.5 * (edges[:-1] + edges[1:])

In [None]:
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import proj3d

from scipy.stats import chi2

# For each of the three Lorenz-63 parameters (theta_ref = (10, 28, 8/3), the s, r, b
# values used by `Model`), plot how four comparison measures react when that single
# parameter is swept, and show thumbnail attractors for five of the swept values.

dt_02 = True
dt_suffix = "_02" if dt_02 else ""

# Acceptance thresholds: 95 % quantile of each measure over the reference runs, taken
# at the last entry of axis 1 (presumably the longest/latest evaluation — TODO confirm).
reference_values = np.load(f"results/reference_values_ranges{dt_suffix}.npy")[:, -1]
thresholds = np.quantile(reference_values, 0.95, axis=0)
# AExc (measure 0) is plotted as 1 - value below, so its threshold is mirrored too.
thresholds[0] = 1 - thresholds[0]
# Last measure (Log GCI): use the log of the chi^2(10) 95 % quantile instead.
thresholds[-1] = np.log(chi2(10).ppf(0.95))

# Sweep results: data[dim] holds the measures and thetas[dim] the swept values of
# parameter `dim`. The context manager closes the .npz file after the arrays are read.
with np.load(f"results/low_res/sweeped_thetas{dt_suffix}.npz") as npz:
    data = npz["data"]
    thetas = npz["thetas"]

MEASURE_NAMES = ("AExc", "ADev", "TVar", "Log GCI")
AXIS_NAMES = ("x", "y", "z")
theta_ref = np.array((10, 28, 8/3))
seed = 0

plt.rc('font', size=12)  # Change global font size

# The reference trajectory does not depend on the swept parameter: integrate it once.
timeseries = Model(seed).get_timeseries(5_000)[0]


def _first_index(mask):
    """Return the index of the first True entry of a 1-D boolean array, or None if none."""
    hits = np.flatnonzero(mask)
    return hits[0] if hits.size else None


for dim in range(3):
    data_dim = data[dim]
    thetas_dim = thetas[dim]

    fig = plt.figure(figsize=(7, 5))
    grid = GridSpec(5, 5)

    # Invisible 3-D plot of the reference trajectory: it only scales the 3-D axes whose
    # projection (ax.get_proj(), used by `proj`) maps the thumbnails onto 2-D.
    ax = fig.add_subplot(grid[-1, 0], projection="3d")
    ax.plot(*timeseries, color="None")
    ax.axis("off")
    proj = lambda x, ax=ax: proj3d.proj_transform(*x.T, ax.get_proj())[:2]

    # Bottom row: thumbnails at the swept values closest to the centres of five
    # equal-width bins over the sweep indices.
    ax_mini_plots = []
    idx = np.round(middles(0, len(thetas_dim) - 1, 5)).astype(int)
    for i, theta_dim in enumerate(thetas_dim[idx]):
        theta = theta_ref.copy()
        theta[dim] = theta_dim
        # NOTE(review): these positional integrator arguments (2.5e-3, 10, 5_000) differ
        # from Model's (odeint_dt=1e-3, step_size=20, n_buffer=10000) — presumably
        # acceptable for thumbnails; confirm.
        other_timeseries = RungeKutta(Lorenz63(*theta, seed), 2.5e-3, 10, 5_000).get_timeseries(5_000, 1)[0]
        ax = fig.add_subplot(grid[-1, i])
        ax_mini_plots.append(ax)
        ax.plot(*proj(other_timeseries), color=color_hist, rasterized=True, alpha=0.7)
        ax.axis("off")

    # Rows 0-3: one panel per measure; each panel shares its x axis with the one above.
    prev_ax = None
    for i_measure, (d, name) in enumerate(zip(data_dim.T, MEASURE_NAMES)):
        ax = fig.add_subplot(grid[i_measure, :], sharex=prev_ax)
        prev_ax = ax
        ax.set_ylabel(name)
        if i_measure == 0:
            d = 1 - d  # AExc is shown as its complement
        elif i_measure == 3:
            d = np.log(d)  # GCI is shown on a log scale

        thr = thresholds[i_measure]
        med = np.median(d, axis=-1)
        qmi = np.quantile(d, 0.25, axis=-1)
        qma = np.quantile(d, 0.75, axis=-1)
        ax.plot(thetas_dim, med, color=color_hist, label="Median")
        ax.fill_between(thetas_dim, qmi, qma, color=color_hist, alpha=0.3, label='25 - 75 Quantile')
        ax.yaxis.set_label_coords(-0.15, 0.5)  # Align all y-labels to the same x position

        xlim = ax.get_xlim()
        ylim = list(ax.get_ylim())
        # Extend the y range below the threshold so the threshold line has headroom.
        ylim[0] = min(ylim[0], thr - (ylim[1] - thr) * 0.2)

        # Shade the theta interval labelled "Less than 75% Rejection". It starts at the
        # first theta whose lower quartile is <= thr, and ends at the first theta (beyond
        # the first point where the upper quartile is <= thr) whose lower quartile is back
        # >= thr. Missing boundaries no longer raise: no start means no shading, and a
        # missing end clamps the interval to the right edge of the axis.
        i_lo = _first_index(qmi <= thr)
        if i_lo is not None:
            i_inner = _first_index(qma <= thr)
            i_start = i_lo if i_inner is None else i_inner
            i_hi = _first_index((qmi >= thr) & (thetas_dim > thetas_dim[i_start]))
            xmi = thetas_dim[i_lo]
            xma = xlim[1] if i_hi is None else thetas_dim[i_hi]
            ax.fill_betweenx([ylim[0], ylim[1]], [xmi]*2, [xma]*2, color="grey", alpha=0.5, label="Less than\n75% Rejection", zorder=0)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        ax.axvline(theta_ref[dim], color="black", label="Reference")
        ax.axhline(thr, color="black", ls=(0, (3, 3, 1, 3)), label="Threshold")
        ax.set_xlabel(fr"$\theta_{AXIS_NAMES[dim]}$")
        if i_measure == 1:
            ax.legend(loc="lower right")

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.0, wspace=0)

    # Push the thumbnail row down and shrink it slightly, away from the measure panels.
    for ax in ax_mini_plots:
        pos = ax.get_position()
        ax.set_position([pos.x0, pos.y0 - 0.09, pos.width, pos.height * 0.9])

    fig.savefig(f"pictures/sweep_theta{dt_suffix}_dim_{dim}.pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures on non-inline backends