In [None]:
from lib.correlation_integral import LongSpatial02CI
from models.lib.firstorderode.base import RungeKutta
from models.lib.firstorderode.lorenz import Lorenz63
from models.lib.firstorderode.roessler import Roessler76
from models.lib.firstorderode.sprott import SprottAttractors
from models.lib.toy_models import LinaJL
from reservoir_helper import get_lorenz63_data_and_reservoir

# Experiment switches shared by every cell below:
#   dt_02   - select the "_dt_02" integrator settings and data/result files
#   n_nodes - reservoir size (passed to get_lorenz63_data_and_reservoir)
#   i_model - index into the model list (0-2: Sprott B/C/D, 3: Roessler76, 4: Lorenz63)
dt_02, n_nodes, i_model = True, 20, 4

#get_lorenz63_data_and_reservoir = lambda seed, n_nodes: get_lorenz63_data_and_reservoir(seed, n_nodes, Model)
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from lib.attractor_deviation import AttractorDeviation
from lib.attractor_inclusion import AttractorInclusion
from lib.correlation_integral import LongSpatialCI, SpatialCI
from lib.distribution_deviation import DistributionDeviation
from lib.local_step_ahead import LocalStepAheadPredictability
from lib.total_variation import TotalVariation

# Model factories in the order Sprott B/C/D, Roessler76, Lorenz63. Each entry is
# called as factory(seed=None, n_buffer=10000) and returns a RungeKutta
# integrator; the ODEClass default argument freezes the loop variable inside
# each lambda. dt_02 chooses the integrator step counts and the matching
# correlation-integral class.
models = []
if not dt_02:
    for ODEClass in [*SprottAttractors[1:4], Roessler76]:
        models.append(
            lambda seed=None, n_buffer=10000, ODEClass=ODEClass: RungeKutta(ODEClass(seed), 2e-2, 50, n_buffer)
        )
    models.append(
        lambda seed=None, n_buffer=10000: RungeKutta(
            Lorenz63(s=10, r=28, b=8 / 3, seed=seed), odeint_dt=1e-3, step_size=100, n_buffer=n_buffer
        )
    )
    CI = LongSpatialCI
else:
    for ODEClass in [*SprottAttractors[1:4], Roessler76]:
        models.append(
            lambda seed=None, n_buffer=10000, ODEClass=ODEClass: RungeKutta(ODEClass(seed), 2e-2, 10, n_buffer)
        )
    models.append(
        lambda seed=None, n_buffer=10000: RungeKutta(
            Lorenz63(s=10, r=28, b=8 / 3, seed=seed), odeint_dt=1e-3, step_size=20, n_buffer=n_buffer
        )
    )
    CI = LongSpatial02CI

In [None]:
from numpy import absolute

def bounded(pre):
    """Return True unless some entry of ``pre`` has absolute value above 2.

    NOTE(review): NaN entries compare False against 2, so an all-NaN
    prediction counts as bounded.
    """
    return not (absolute(pre) > 2).any()


# Module-level window length. `osc` does not read it: its own default is 3.
last = 5


def osc(pre, last=3):
    """Return False when the ``last - 1`` samples before the final one all sit
    within 5e-3 of it (in every component), i.e. the trajectory has settled;
    True otherwise."""
    return not (absolute(pre[-last:-1] - pre[-1]) < 5e-3).all()

In [None]:
from lib.attractor_deviation import AttractorDeviation
from lib.attractor_inclusion import AttractorInclusion
from lib.correlation_integral import LongSpatialCI, SpatialCI
from lib.distribution_deviation import DistributionDeviation
from lib.local_step_ahead import LocalStepAheadPredictability

# Construct each model factory once to read its ODE class name; the names key
# every per-model measure file below. (Building a RungeKutta model may be
# expensive, so this avoids constructing each model once per measure list.)
model_names = [make_model().ode.__class__.__name__ for make_model in models]
ci_dir = f"data/heikki_long{'_dt_02' if dt_02 else ''}"

# Correlation-integral ("Heikki" / GCI) measure, one per model.
lcis = [CI(f"{ci_dir}/{name}_{CI.__name__}_ndata_300_nmul_100_seed_0.npz") for name in model_names]

# Histogram-based measures, all loading the same per-model reference histogram.
adevs = [AttractorDeviation(f"data/hist/hist_{name}.npz") for name in model_names]
# Note: the "ddevs" list uses TotalVariation (plotted as "TVar").
ddevs = [TotalVariation(f"data/hist/hist_{name}.npz") for name in model_names]
aincs = [AttractorInclusion(f"data/hist/hist_{name}.npz") for name in model_names]
plot_hist_measure = [TotalVariation(f"data/hist/hist_{name}.npz", resolution=64) for name in model_names]

In [None]:
from multiprocessing import Pool
from numpy.random import default_rng
from tqdm import tqdm
from lib.rmse import rmse
from lib.valid_prediction_time import valid_prediction_time
from scipy.stats import chi2
from numpy.random import SeedSequence

# Same factory list and CI class as the earlier model-definition cell, rebuilt
# here (presumably so this cell still works after a later cell rebinds
# `models`). Each factory: (seed=None, n_buffer=10000) -> RungeKutta.
if dt_02:
    models = []
    for ODEClass in [*SprottAttractors[1:4], Roessler76]:
        models.append(lambda seed=None, n_buffer=10000, ODEClass=ODEClass:
                      RungeKutta(ODEClass(seed), 2e-2, 10, n_buffer))
    models.append(lambda seed=None, n_buffer=10000:
                  RungeKutta(Lorenz63(s=10, r=28, b=8 / 3, seed=seed),
                             odeint_dt=1e-3, step_size=20, n_buffer=n_buffer))
    CI = LongSpatial02CI
else:
    models = []
    for ODEClass in [*SprottAttractors[1:4], Roessler76]:
        models.append(lambda seed=None, n_buffer=10000, ODEClass=ODEClass:
                      RungeKutta(ODEClass(seed), 2e-2, 50, n_buffer))
    models.append(lambda seed=None, n_buffer=10000:
                  RungeKutta(Lorenz63(s=10, r=28, b=8 / 3, seed=seed),
                             odeint_dt=1e-3, step_size=100, n_buffer=n_buffer))
    CI = LongSpatialCI

def get_trajectory(inputs):
    """Run one reservoir configuration and score its closed-loop prediction.

    Parameters
    ----------
    inputs : tuple
        ``(spectral_radius, i_model, seed)`` packed into a single argument so
        the function can be passed straight to ``Pool.map``. ``i_model``
        indexes the notebook-level ``models``/``lcis``/``aincs``/``adevs``/
        ``ddevs`` lists; ``seed`` is passed to
        ``get_lorenz63_data_and_reservoir`` and also seeds the generator
        handed to the correlation-integral measure.

    Returns
    -------
    tuple
        ``(predictions_closed_loop, reference, measures, measures_ref)`` where
        ``reference`` is ``data[closed_loop]``;
        ``measures`` is ``[bounded, oscillating, valid_prediction_time, mse,
        1 - AIncs, Adev, Ddev, Heikki]`` for the closed-loop prediction; and
        ``measures_ref`` is ``[1 - AIncs, Adev, Ddev, Heikki]`` for the
        reference trajectory.
    """
    spectral_radius, i_model, seed = inputs
    rng = default_rng(seed)

    (predictions_closed_loop, _predictions_open_loop, data, normalizer,
     _model, _train, _open_loop, closed_loop) = get_lorenz63_data_and_reservoir(
        seed, n_nodes, models[i_model],
        spectral_radius=spectral_radius,
        regulizer=(1e-7 if dt_02 else 5e-6),
        weights_internal=np.eye(n_nodes),  # identity internal weight matrix
        g_in=(0.4 if dt_02 else 1),
    )
    reference = data[closed_loop]

    # Prediction step t is scored against reference step t + 1; the final
    # prediction has no target and is dropped.
    compare = predictions_closed_loop[:-1]
    target = reference[1:]

    measures = [
        bounded(predictions_closed_loop),
        osc(predictions_closed_loop),
        valid_prediction_time(target, compare),
        rmse(target, compare) ** 2,
    ]
    # lcis receives `rng`: keep the prediction call before the reference call
    # so the generator is consumed in the same order on every run.
    heikki = lcis[i_model](compare, rng, unnormalizer=normalizer.unnormalize)[-1][0]
    measures += [
        1 - aincs[i_model](compare, normalizer=normalizer),
        adevs[i_model](compare, normalizer=normalizer),
        ddevs[i_model](compare, normalizer=normalizer),
        heikki,
    ]

    # Same attractor measures on the true trajectory, as a baseline.
    measures_ref = [
        1 - aincs[i_model](reference, normalizer=normalizer),
        adevs[i_model](reference, normalizer=normalizer),
        ddevs[i_model](reference, normalizer=normalizer),
        lcis[i_model](reference, rng, unnormalizer=normalizer.unnormalize)[-1][0],
    ]

    return predictions_closed_loop, reference, measures, measures_ref



In [None]:
# Model classes in the same order as the factory list (indexed by i_model).
models2 = [*SprottAttractors[1:4], Roessler76, Lorenz63]

# Stored reservoir-run results (20 nodes, file suffix 0.4) for the selected model.
data20 = np.load(f"results/spectral_radius_n_nodes_{20}/uncoupled_single_0.4{"" if not dt_02 else "_dt_02"}_{models2[i_model].__name__}.npz")
data20["measures"].shape  # no-op: not the last expression of the cell, so nothing is displayed

ss = SeedSequence(0)
# NOTE(review): this spawn's result is discarded, but SeedSequence.spawn
# advances the sequence's child counter, so the next line returns children
# N..2N-1 rather than 0..N-1 (N = data20["measures"].shape[1]). Confirm this
# matches how the seeds behind `data20` were generated; if those came from
# SeedSequence(0).spawn(N) directly, `seeds` is misaligned with the stored runs.
ss.spawn(data20["measures"].shape[1])
seeds = np.array(ss.spawn(data20["measures"].shape[1]))

In [None]:
# Keep runs whose first two stored columns (bounded, oscillating flags) are both 1.
# Then keep columns 4 onward, reordered by [3, 1, 2, 0], and flip the first
# reordered column to 1 - x. `seeds` is filtered in step.
data = data20["measures"].squeeze()
print(data.shape)
idx = (data[:, :2] == 1).all(axis=-1)
data = data[idx, 4:][:, [3, 1, 2, 0]]
data[:, 0] = 1 - data[:, 0]
seeds = seeds[idx]
print(data.shape)

In [None]:
models_name = ["Sprott B", "Sprott C", "Sprott D", "Roessler76", "Lorenz63"]
# Per-measure acceptance thresholds for the selected model:
# - 95th percentile of the reference samples, with the first measure flipped to 1 - x;
# - except the last (GCI) threshold, which is the 95 % quantile of chi-squared with 10 dof.
# Relies on the module-level name `chi2` still being scipy.stats.chi2.
thresholds = np.load(f"results/reference_values_ranges{'_02' if dt_02 else ''}.npy")[:, i_model]
thresholds[:, 0] = 1 - thresholds[:, 0]
thresholds = np.quantile(thresholds, 0.95, axis=0)
thresholds[-1] = chi2.ppf(0.95, df=10)

In [None]:
# Keep only the runs that exceed the threshold on every measure.
idx = (data > thresholds).all(axis=-1)
seeds, data = seeds[idx], data[idx]
data.shape, np.mean(idx)

In [None]:
from mpl_toolkits.mplot3d import proj3d

ax = plt.subplot(111, projection="3d")


def proj(x, ax=ax):
    """Project 3-D points ``x`` (one point per row) to 2-D using ``ax``'s current view."""
    return proj3d.proj_transform(*x.T, ax.get_proj())[:2]

# Plot many lorenzes

In [None]:
# Number of filtered seeds to recompute (one Pool process each).
N_TRAJECTORIES = 25


def worker(i):
    """Recompute the trajectory and measures for ``seeds[i]`` at spectral radius 0.4.

    Uses the notebook-level ``i_model`` so the recomputed runs belong to the
    same model whose stored results selected ``seeds``.
    """
    return get_trajectory((0.4, i_model, seeds[i]))


assert len(seeds) >= N_TRAJECTORIES, (
    f"only {len(seeds)} seeds remain after filtering; need {N_TRAJECTORIES}"
)
with Pool(N_TRAJECTORIES) as p:
    # -> [predictions, true trajectories, measures, measures_ref], each stacked over runs
    all_rejected_measures = list(map(np.array, zip(*p.map(worker, range(N_TRAJECTORIES)))))

In [None]:

from matplotlib import colors
from matplotlib import gridspec
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import proj3d
from mpl_toolkits.axes_grid1 import ImageGrid


# "TUI" colormap: dark teal (RGB 0,116,122) through white at 0.5 to orange
# (RGB 255,121,0). Each channel lists (position, value_below, value_above).
cdict = dict(
    red=((0.0, 0.0, 0.0), (0.5, 1.0, 1.0), (1.0, 1.0, 1.0)),
    green=((0.0, 116 / 255, 116 / 255), (0.5, 1.0, 1.0), (1.0, 121 / 255, 121 / 255)),
    blue=((0.0, 122 / 255, 122 / 255), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)),
)

TUI = colors.LinearSegmentedColormap('TUI', cdict)
norm = colors.Normalize(vmin=0, vmax=1)

# Figure colours: the orange end of TUI for reference histograms, the teal end
# for reservoir histograms; black/grey for chi-squared curve, threshold, region.
color_hist = TUI(norm(1))
color_res = TUI(norm(0))
color_chi = "black"
color_threshold = "black"
color_region = "grey"

from models.lib.firstorderode.lorenz import Lorenz63
from models.lib.firstorderode.roessler import Roessler76
from models.lib.firstorderode.sprott import SprottAttractors


# NOTE: rebinds `models` from the list of RungeKutta factories (used by
# get_trajectory and the measure-construction cell) to bare ODE classes;
# redefine the factories before re-running those cells.
models = list(SprottAttractors[1:4]) + [Roessler76, Lorenz63]

from scipy.stats import chi2
import numpy as np
import matplotlib.pyplot as plt
from brokenaxes import brokenaxes

# Render text with matplotlib's built-in mathtext (no external LaTeX). The xcolor
# preamble only matters if text.usetex is switched back on.
plt.rcParams.update({
    "text.usetex": False,
    "text.latex.preamble": r"\usepackage{xcolor}",
})

models_name = ["Sprott B", "Sprott C", "Sprott D", "Roessler76", "Lorenz63"]
# Reference measure samples, transposed so data[i_model][i_measure] is the
# sample vector for one model and measure (on-disk order presumably
# (sample, model, measure)).
data = np.load(f"results/reference_values_ranges{'_02' if dt_02 else ''}.npy").transpose((1, 2, 0))

fig = plt.figure(figsize=(10, 6))

# Mini-histogram grid shape and the model/seed being shown.
n_rows, n_columns = 3, 6
i_model = 4
seed = 0

# One grid cell stays empty. The remaining cells are split between reservoir
# runs (8/11 of them, rounded down) and reference runs.
n_res_to_ref_nodes = 8 / 11
n_total = n_rows * n_columns - 1
n_res_nodes = int(n_total * n_res_to_ref_nodes)
n_ref_nodes = n_total - n_res_nodes

rng = default_rng(seed)

# all_rejected_measures = [predictions, true trajectories, measures, measures_ref],
# stacked over runs. Reference runs come first, then reservoir runs:
#   all_measure - one row per attractor measure (1 - AIncs, Adev, Ddev, Heikki), one column per run
#   res         - the matching trajectories
#   is_ref      - 1 for reference runs, 0 for reservoir runs
all_measure = np.concatenate(
    [all_rejected_measures[3][:n_ref_nodes].T, all_rejected_measures[2][:n_res_nodes, 4:].T],
    axis=-1,
)
res = np.concatenate(
    [all_rejected_measures[1][:n_ref_nodes], all_rejected_measures[0][:n_res_nodes]],
    axis=0,
)
is_ref = np.concatenate((np.ones(n_ref_nodes), np.zeros(n_res_nodes)))

# Mini panels are ordered by the third measure row (Ddev, plotted as TVar).
index_measure = np.argsort(all_measure[-2])
# Four histogram rows, a thin spacer row, then the mini-histogram block.
grid = GridSpec(6, 6, height_ratios=[0.2, 0.2, 0.2, 0.2, 0.05, 0.3 * n_rows])
fontsize = 12

# Panel labels: lower-case roman numerals for reference runs, capitals for reservoir runs.
roman_numerals = [rf"$\mathsf{{{numeral}}}$" for numeral in "i ii iii iv v vi vii viii ix x xi".split()]
latin_letter = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

# n_rows x n_columns grid of mini-histograms sharing one colourbar on the right.
miniplot_grid = ImageGrid(
    fig, grid[-1, :],
    nrows_ncols=(n_rows, n_columns),
    share_all=True,
    cbar_mode='single',
    cbar_location='right',
    cbar_size='5%',
    cbar_pad=0.4,
)
# 3-D axis whose projection matrix `proj` uses to flatten trajectories to 2-D.
ax = fig.add_subplot(grid[0, 0], projection="3d")


def proj(x, ax=ax):
    return proj3d.proj_transform(*x.T, ax.get_proj())[:2]


# The first mini-grid cell is left empty.
ax = miniplot_grid[0]
ax.axis("off")



# Mini 2-D histograms of the projected trajectories, one panel per run in
# `index_measure` order; reference runs get roman labels, reservoir runs letters.
cmap = plt.cm.Greys.copy()
cmap.set_over('#fc03f8')  # counts above vmax are drawn in magenta
# Every panel uses the bin grid of trajectory 1 and a fixed colour scale,
# so a single shared colourbar is valid for all of them.
_, x_edges, y_edges = np.histogram2d(*proj(res[1]), bins=64)
norm_colorbar = colors.Normalize(0, 13.2, clip=False)

index_with_star, index_without_star = 0, 0
for panel, ax in enumerate(miniplot_grid[1:]):
    i_run = index_measure[panel]
    _, _, _, hist = ax.hist2d(*proj(res[i_run]), rasterized=True, bins=[x_edges, y_edges], norm=norm_colorbar, cmap=cmap)
    if panel == 0:
        # Shared colourbar built from the first panel actually shown.
        cbar = miniplot_grid.cbar_axes[0].colorbar(hist, extend="max", cmap=cmap)
        print("vmin", cbar.vmin, "vmax", cbar.vmax)
        cbar.set_label("Counts", size=12)
    ax.axis("off")

    if is_ref[i_run] == 1:
        title = roman_numerals[index_with_star]
        index_with_star += 1
    else:
        title = latin_letter[index_without_star]
        index_without_star += 1
    # An empty legend is used only as a framed title box in the panel corner.
    legend = ax.legend(title=title, loc="upper left", fontsize=fontsize, bbox_to_anchor=(0, 1), borderaxespad=0)

    legend.get_frame().set_alpha(None)
    legend.get_title().set_color("black")
    legend.get_frame().set_facecolor((0.6, 0.6, 0.6, 0.2))
    legend.get_frame().set_edgecolor((1, 1, 1, 0.0))
    
# Frozen chi-squared distribution (10 dof) used for the GCI panel. It is bound to
# its own name so the module-level `chi2` (scipy.stats.chi2) is not shadowed;
# the thresholds cell calls `chi2.ppf(0.95, df=10)` and breaks on re-run if it is.
chi2_df10 = chi2(df=10)

print(f"Reference Samples: {data.shape[-1]}")

# Reference measure samples for the selected model, one row per measure.
dd = data[i_model]
# Stored reservoir runs (20 nodes) for the same model.
data20 = np.load(f"results/spectral_radius_n_nodes_{20}/uncoupled_single_0.4{'' if not dt_02 else '_dt_02'}_{models[i_model].__name__}.npz")
names = data20["names"][0, 0]
measures20 = data20["measures"][0].T
name2index = {name: i for i, name in enumerate(names)}

# Upper cut-off for the GCI histogram; None means "use the full range".
heikki_cut_off = None
if n_nodes == 20:
    heikki_cut_off = 40
elif n_nodes == 10:
    heikki_cut_off = 80

for i_measure, d in enumerate(dd):
    ax = fig.add_subplot(grid[i_measure, :])

    if i_model == 4:
        ax.set_ylabel(["AExc", "ADev", "TVar", "GCI"][i_measure] + "\n" + r"Counts", color=color_hist, fontsize=fontsize)
    if i_measure == 3:
        ax.set_xlabel("Measure", fontsize=fontsize)

    # Hand-tuned x-ranges for the 20-node figure (AExc, ADev, TVar panels).
    if n_nodes == 20:
        if i_measure == 0:
            ax.set_xlim(-0.005, 0.025)
        if i_measure == 1:
            ax.set_xlim(111_500, 112_500)
        if i_measure == 2:
            ax.set_xlim(0.865, 0.915)

    # First measure is plotted as 1 - x, matching the AIncs -> AExc flip below.
    if i_measure == 0:
        d = 1 - d

    # Acceptance threshold: chi-squared 95 % quantile for GCI, else the
    # empirical 95th percentile of the reference samples.
    if i_measure == 3:
        qmax = chi2_df10.ppf(0.95)
    else:
        qmax = np.quantile(d, 0.95)

    # Reservoir values for this measure, restricted to bounded and oscillating runs.
    name = ["AIncs", "Adev", "Ddev", "Heikki"][i_measure]
    idx_bnd_osc = np.logical_and(measures20[name2index["bounded"]], measures20[name2index["ozillating"]])
    measure = measures20[name2index[name]][idx_bnd_osc]
    if name == "AIncs":
        measure = 1 - measure

    # Common bins for reference and reservoir values (GCI optionally clipped).
    if i_measure == 3 and heikki_cut_off is not None:
        pooled = [*measure[measure < heikki_cut_off], *d[d < heikki_cut_off]]
        n_bins = 101
    else:
        pooled = [*measure, *d]
        n_bins = {0: 2001, 1: 401}.get(i_measure, 101)
    _, bins = np.histogram(pooled, bins=n_bins)
    bins -= (bins[1] - bins[0]) * 0.5  # shift all edges left by half a bin

    # Reference histogram on the main axis.
    _, _, hist1 = ax.hist(d, bins=bins, label="Reference", color=color_hist, density=False, histtype="stepfilled")
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

    # Reservoir histogram on a twin y-axis (own count scale, shared x).
    ax_twin = ax.twinx()
    _, _, hist = ax_twin.hist(measure, bins=bins, label="Reservoir", color=color_res, alpha=0.5, density=False, histtype="stepfilled")
    hist2 = [hist[0]]
    ax_twin.set_ylabel("Counts", color=color_res, fontsize=fontsize)
    ax_twin.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

    # Chi-squared density overlay, threshold line and shaded acceptance region.
    if i_measure == 3:
        x = np.linspace(-1, ax.get_xlim()[-1], 1000)
        ax.plot(x, chi2_df10.pdf(x), color=color_chi, label=r"$\chi^2 (df = 10)$")

    ax.set_xlim(ax.get_xlim())  # freeze the x-range before adding annotations
    ax.tick_params(labelbottom=True)
    region = ax.axvspan(ax.get_xlim()[0], qmax, alpha=0.5, color=color_region, label="acceptance \nregion", zorder=0)
    threshold_fmt = "Threshold:\n{:.0f}" if i_measure == 1 else "Threshold:\n{:.3f}"
    vline = ax.axvline(qmax, color=color_threshold, ls=(0, (3, 3, 1, 3)), label=threshold_fmt.format(qmax), lw=1)

    # Mark each recomputed run on the twin axis; runs beyond the visible range
    # are stacked at the right edge with an arrow.
    ax = ax_twin
    index_outside = 0
    index_with_star, index_without_star = 0, 0
    for i in index_measure:
        value = all_measure[i_measure][i]
        if is_ref[i] == 1:
            label = roman_numerals[index_with_star]
            index_with_star += 1
        else:
            label = latin_letter[index_without_star]
            index_without_star += 1

        if value > ax.get_xlim()[-1]:
            height = 0.2 * index_outside + 0.15
            ax.annotate(label, (1, height), xytext=(0.95 - 0.02 * index_outside, height), fontsize=fontsize, xycoords='axes fraction', ha='right', zorder=200)
            ax.annotate("", (1, height + 0.1), xytext=(0.952 - 0.02 * index_outside, height + 0.1), fontsize=fontsize, arrowprops=dict(arrowstyle='->'), xycoords='axes fraction', zorder=200)
            index_outside += 1
        else:
            ax.annotate(label, (value, ax.get_ylim()[-1] * 0.09), fontsize=fontsize, zorder=200)
            ax.scatter(value, 0, marker="o", zorder=200, clip_on=False, color="black", s=50)

from pathlib import Path

# Figure-level legend, attached to an invisible axis in the bottom-left grid
# cell and shifted into place with bbox_to_anchor.
legend_grid = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=grid[-1, 0])
ax = fig.add_subplot(legend_grid[-1, 0])
ax.axis("off")
ax.legend(
    [hist1[0], *hist2, vline, region],
    [
        f"Reference [i - {roman_numerals[index_with_star - 1]}]",
        f"Reservoir [A - {latin_letter[index_without_star - 1]}]",
        "Threshold",
        "Acceptance\nRegion",
    ],
    bbox_to_anchor=(1.05, 2.76),
)
fig.tight_layout()
fig.subplots_adjust(hspace=0.32, wspace=0.19)

# savefig does not create missing directories, so make sure the target exists.
out_dir = Path(f"pictures/spectral_radius_n_nodes_{n_nodes}")
out_dir.mkdir(parents=True, exist_ok=True)
suffix = "_02" if dt_02 else ""
for ext in ("pdf", "svg"):
    fig.savefig(out_dir / f"model_quantiles{suffix}.{ext}", dpi=200)
plt.show()