In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
import definitive_dyn_indicators.scripts.data_manager as dm
import os
from tqdm.auto import tqdm
from numba import njit
import lmfit
import joblib
from joblib import Parallel, delayed

# Worker count for joblib: one job per logical CPU reported by the OS.
njobs = os.cpu_count()
print("Number of cores: {}".format(njobs))


In [None]:
@njit
def log_fit(x, a, k):
    """Power law written in log-log space: ``log10(y) = a - k * log10(x)``.

    Parameters
    ----------
    x : array_like
        Abscissa values; must be positive for a finite result.
    a : float
        Intercept in log space, i.e. log10 of the power-law prefactor.
    k : float
        Power-law exponent (the model decreases with x for k > 0).

    Returns
    -------
    The model prediction for ``log10(y)``.
    """
    log_x = np.log10(x)
    return a - k * log_x


def residual_log_fit(params, x, y):
    """lmfit residual for the log-log power law, measured in units of log10(y).

    Reads parameters ``a`` and ``k`` from the lmfit ``params`` and returns
    ``log_fit(x, a, k) - log10(y)``. ``y`` must be positive for a finite residual.
    """
    a, k = (params[name].value for name in ("a", "k"))
    return log_fit(x, a, k) - np.log10(y)


@njit
def fit_3(x, a, k, c):
    """Power law with a constant offset: ``y = a / x**k + c``.

    Parameters
    ----------
    x : array_like
        Abscissa values.
    a : float
        Prefactor of the power-law term.
    k : float
        Power-law exponent.
    c : float
        Constant offset the model tends to as x grows (for k > 0).

    Returns
    -------
    The model prediction for ``y``.
    """
    denominator = np.power(x, k)
    return a / denominator + c


def residual_3_fit(params, x, y):
    """lmfit residual for ``fit_3``, relative to the data.

    Reads ``a``, ``k`` and ``c`` from the lmfit ``params`` and returns
    ``(fit_3(x, a, k, c) - y) / y``. ``y`` must therefore be non-zero.
    """
    values = {name: params[name].value for name in ("a", "k", "c")}
    prediction = fit_3(x, values["a"], values["k"], values["c"])
    return (prediction - y) / y


In [None]:
def clean_data(x, y):
    """Drop the samples whose ``y`` is NaN, +/-inf or exactly zero.

    Parameters
    ----------
    x, y : numpy.ndarray
        Arrays of equal length; the filter is decided by ``y`` alone.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y)`` restricted to the entries where ``y`` is finite and non-zero.
    """
    keep = np.isfinite(y) & (y != 0)
    return x[keep], y[keep]


def fit(x, y, s, i, kind="scale_law", extra_log=False):
    """Fit one sample's curve ``y(x)`` with lmfit using the model selected by ``kind``.

    Parameters
    ----------
    x : numpy.ndarray
        Abscissa values. Only entries with ``x > 100`` count toward the
        minimum-points check.
    y : numpy.ndarray
        Ordinate values, same length as ``x``.
    s, i :
        Not used by the fit. The caller passes the sample's stability value and
        its index; both are kept for call-site compatibility.
    kind : str
        One of ``"log_fit"``, ``"fit_3"``, ``"fit_fix_k"``, ``"fit_fix_a"``,
        ``"fit_fix_c"``. The default ``"scale_law"`` is NOT a supported kind,
        so always pass ``kind`` explicitly.
    extra_log : bool
        If True, fit ``log10(y)`` instead of ``y``.

    Returns
    -------
    lmfit.minimizer.MinimizerResult
        The ``least_squares`` fit result on success.
    "discarded"
        If fewer than two points with ``x > 100`` survive ``clean_data``.
    "error"
        If data preparation or ``lmfit.minimize`` raises ``ValueError``.

    Raises
    ------
    ValueError
        If ``kind`` is not recognized. This is checked before the ``try`` block,
        so a typo fails loudly instead of being reported as ``"error"`` for
        every sample.
    """
    # For each kind: the residual to minimize, plus (name, initial value, vary)
    # for each parameter in insertion order. The table is built per call so
    # that re-executing the residual-definition cell is picked up.
    specs = {
        "log_fit": (residual_log_fit, [("a", 0, True), ("k", 1, True)]),
        "fit_3": (residual_3_fit, [("a", 1, True), ("k", 1, True), ("c", 0, True)]),
        "fit_fix_k": (residual_3_fit, [("a", 1, True), ("k", 1, False), ("c", 0, True)]),
        "fit_fix_a": (residual_3_fit, [("a", 1, False), ("k", 1, True), ("c", 0, True)]),
        "fit_fix_c": (residual_3_fit, [("a", 1, True), ("k", 1, True), ("c", 0, False)]),
    }
    if kind not in specs:
        raise ValueError(f"kind {kind} not recognized")
    residual, param_specs = specs[kind]

    try:
        if extra_log:
            y = np.log10(y)

        x, y = clean_data(x, y)
        # Fit the magnitude of y. Presumably this is because the residuals
        # take log10(y) or divide by y.
        y = np.absolute(y)

        if np.count_nonzero(x > 100) < 2:
            return "discarded"

        params = lmfit.Parameters()
        for name, initial, vary in param_specs:
            params.add(name, value=initial, vary=vary)
        return lmfit.minimize(residual, params, args=(x, y), method="least_squares")
    except ValueError:
        # For example, lmfit's default nan_policy raises ValueError on NaN
        # residuals. Report it as a failed fit so that one bad sample does not
        # abort the whole batch.
        return "error"


In [None]:
print("Initializing data manager...")
data = dm.data_manager(data_dir=".")

print("Setting up configuration...")
# Samples per axis (the classification map later reshapes to samples x samples).
data.henon_config["samples"] = 1000

# Empty the t_base_2 and t_base time grids (presumably so that no output is
# produced on them; confirm against data_manager).
data.henon_config["t_base_2"] = np.array([], dtype=int)
data.henon_config["t_base"] = np.array([], dtype=int)

# 16 log-spaced times from 1e3 to 1e8 (truncated to int), plus a linear grid
# from 1e5 to 1e8 with a step of 5e4 (1999 points).
data.henon_config["t_base_10"] = np.logspace(3, 8, 16, base=10, dtype=int)
data.henon_config["t_linear"] = np.linspace(
    100000, 100000000, 1999, dtype=int)

# Refresh the *instance* config edited above. Passing the module-level
# dm.henon_config would discard those edits unless both names happen to refer
# to the same dict.
data.henon_config = dm.refresh_henon_config(data.henon_config)


In [None]:
# Read back the effective configuration from the data manager and list its keys.
config = data.get_config()
print(list(config.keys()))
# Extent passed to imshow below. NOTE(review): this assumes x_extents and
# y_extents are lists/tuples of (min, max), so that `+` concatenates them into
# [x0, x1, y0, y1]. If they were numpy arrays, `+` would add them element-wise
# instead. Confirm.
extents = config["x_extents"] + config["y_extents"]
# Samples per axis; the classification map reshapes to (samples, samples).
samples = config["samples"]
print(f"Samples: {samples}")
# Time grid from the data manager as a numpy array (not used further in the cells shown).
times = np.asarray(data.get_times())


In [None]:
# Identifier of the simulation group to analyse. Positional order:
#   (omega_x, omega_y, modulation_kind, epsilon, mu, kick amplitude, omega_0)
# The last two are NaN, presumably because they do not apply to the "sps"
# modulation. TODO: confirm against the data manager.
group = (0.168, 0.201, "sps", 16.0, 0.01, np.nan, np.nan)

In [None]:
# Lyapunov-indicator table for every sample of `group`. The fitting cell
# consumes it with DataFrame.iterrows(): each row is one sample (zipped with
# `stability`), and the row's index (the column labels) is used as the fit
# abscissa. Presumably the columns are the time/turn grid. TODO: confirm.
lyapunov = data.better_lyapunov(group)

In [None]:
# Load the per-sample "steps" array from the group's "random"/"true_displacement" file.
# Later cells compare it against 100000000 and <= 10, so it presumably holds the
# number of turns each initial condition survived. Confirm against the data manager.
with data.get_file_from_group(group, "random", "true_displacement") as f:
    stability = f["steps"][:]

In [None]:
# Fit every sample's Lyapunov curve in parallel. iterrows() yields
# (label, row): the row values are the ordinates and the row's index (the
# column labels) is the abscissa. The sample's stability value and position
# are forwarded as `s` and `i`.
report_list = Parallel(n_jobs=njobs)(
    delayed(fit)(row.index.to_numpy(), row.to_numpy(), steps, idx, kind="fit_fix_k")
    for idx, ((_, row), steps) in enumerate(zip(lyapunov.iterrows(), stability))
)


def _fit_value(result, getter):
    """Return ``getter(result)`` for a successful fit, or NaN if ``fit`` returned a failure marker."""
    if result in ("error", "discarded"):
        return np.nan
    return getter(result)


par_lyap_a = np.array([_fit_value(r, lambda res: res.params["a"].value) for r in report_list])
par_lyap_c = np.array([_fit_value(r, lambda res: res.params["c"].value) for r in report_list])
# Chi-square per sample; computed for inspection only (it is not saved to file).
par_lyap_chi = np.array([_fit_value(r, lambda res: res.chisqr) for r in report_list])


In [None]:
# save the data
with data.get_file_from_group(group, "random", "true_displacement") as f:
    a_dataset = f.require_dataset(
        "fit_a", shape=stability.shape, dtype=float, compression="gzip", shuffle=True)
    c_dataset = f.require_dataset(
        "fit_c", shape=stability.shape, dtype=float, compression="gzip", shuffle=True)
    a_dataset[:] = par_lyap_a
    c_dataset[:] = par_lyap_c


In [None]:
# load the data
with data.get_file_from_group(group, "random", "true_displacement") as f:
    a_dataset = f.require_dataset(
        "fit_a", shape=stability.shape, dtype=float, compression="gzip", shuffle=True)
    c_dataset = f.require_dataset(
        "fit_c", shape=stability.shape, dtype=float, compression="gzip", shuffle=True)
    par_lyap_a = a_dataset[:]
    par_lyap_c = c_dataset[:]

In [None]:
# Keep the fitted c only where stability equals 100000000 (presumably: survived
# the whole run); every other sample becomes NaN.
f_par_lyap_c = np.where(stability == 100000000, par_lyap_c, np.nan)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))


def _finite_log10(values):
    """log10 of `values`, keeping only the finite results.

    NaN inputs (failed fits, non-stable samples), negative c (log is NaN) and
    c == 0 (log is -inf) are dropped. A -inf value would otherwise break
    automatic bin-range detection.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        logs = np.log10(values)
    return logs[np.isfinite(logs)]


# Overlay the distribution of log10(c) over all samples with the one restricted
# to stable samples. Transparency keeps the second histogram from hiding the first.
ax.hist(_finite_log10(par_lyap_c), bins=100, density=True, alpha=0.6,
        label="value of all $\\log_{{10}}c$")
ax.hist(_finite_log10(f_par_lyap_c), bins=100, density=True, alpha=0.6,
        label="value of stable $\\log_{{10}}c$")

ax.set_xlabel("$\\log_{10}c$")
ax.set_ylabel("probability density")
# The histogram labels are only rendered once a legend is drawn.
ax.legend()

plt.tight_layout()

In [None]:
# Median fitted c over the stable samples; the NaN placeholders are ignored.
threshold_c = np.nanmedian(f_par_lyap_c)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Stability value marking a sample as stable (presumably the full run length).
max_turns = 100_000_000

# Per-sample class encoded as a float for imshow:
#   NaN  stability <= 10 (left blank)
#   0.5  stability == max_turns and c >  threshold_c
#   1.0  stability == max_turns otherwise (including NaN c)
#   0.0  every other sample
mask = np.asarray(
    stability == max_turns, dtype=float
)
mask[stability <= 10] = np.nan

mask[np.logical_and(
    par_lyap_c > threshold_c,
    stability == max_turns
)] = 0.5

# Named `im` rather than `map`: binding `map` would shadow the builtin for the
# rest of the kernel session.
im = ax.imshow(
    mask.reshape(samples, samples),
    cmap="viridis", extent=extents, origin="lower"
)
ax.set_title(f"classification (N={max_turns})")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

# Colorbar doubling as a legend for the three classes.
cbar = fig.colorbar(im, ax=ax, ticks=[0.0, 0.5, 1.0])
cbar.ax.set_yticklabels(
    ["unstable", "stable, $c >$ threshold", "stable, $c \\leq$ threshold"])

plt.tight_layout()