# NOTEBOOK: TIME SERIES ANALYSIS
## By Lea PONS, Morgan VIROLAN, Lucas SAVONA

In [None]:

#==============================================================================
#LIBRARY IMPORTS
#==============================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, filtfilt, freqz, welch
from numpy.fft import rfft, rfftfreq
from statsmodels.tsa.stattools import acf, pacf,adfuller
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.ar_model import AutoReg
from statsmodels.stats.diagnostic import acorr_ljungbox
from scipy.signal import buttord, butter, detrend
import warnings

#==================================================================================
# LOAD DATA, PROCESS BIA, CREATE WINDOWS
#==================================================================================

#---------------
# CONFIG
#---------------

#-------Data-------
# Pickled DataFrame exported from the BIA device (raw data, loadable with pandas)
BIA_PKL_PATH = "../data/LEA_BIA_RAW.pkl"
# Complex impedance column measured at ~48.8 kHz
FREQ_COL = "f_48800"

# Analysis windows on the BIA clock: 3 minutes BEFORE the fatigue-inducing
# protocol, and 3 minutes AFTER it.
PRE_START_TIME  = "2025-11-28 14:57:02.563"
PRE_END_TIME    = "2025-11-28 15:00:02.563"
POST_START_TIME = "2025-11-28 15:05:45.278"
POST_END_TIME   = "2025-11-28 15:08:44.534"

#---------------
# LOAD + DERIVED COLUMNS
#---------------
# The untouched export is kept for reference; all processing happens on a copy.
data_bia_raw = pd.read_pickle(BIA_PKL_PATH)
data_bia = data_bia_raw.copy()

# Timestamps -> datetime. Unparseable rows are discarded, then rows are put in
# chronological order with a fresh 0..N-1 index.
data_bia["time"] = pd.to_datetime(data_bia["timestamp"], errors="coerce")
data_bia = (
    data_bia
    .dropna(subset=["time"])
    .sort_values("time")
    .reset_index(drop=True)
)

# Impedance at ~48.8 kHz and its decomposition:
#   R   = Re(Z)
#   Xc  = -Im(Z)          (sign convention used here: Xc = -imag(Z))
#   PhA = atan2(Xc, R), in degrees
data_bia["Z_48.8k"] = data_bia[FREQ_COL].astype(np.complex128)
data_bia["R_48.8k_ohm"] = np.real(data_bia["Z_48.8k"])
data_bia["Xc_48.8k_ohm"] = -np.imag(data_bia["Z_48.8k"])
data_bia["PhA_48.8k_deg"] = np.degrees(
    np.arctan2(data_bia["Xc_48.8k_ohm"], data_bia["R_48.8k_ohm"])
)

# Narrow view used by every later step
analysis_cols = [
    "time", "Z_48.8k", "R_48.8k_ohm", "Xc_48.8k_ohm", "PhA_48.8k_deg",
    "sat", "min", "max",
]
data_bia_analysis = data_bia[analysis_cols].copy()

#----------------Windows---------------
# Window bounds as Timestamps (PRE / POST)
pre_start, pre_end, post_start, post_end = (
    pd.to_datetime(stamp)
    for stamp in (PRE_START_TIME, PRE_END_TIME, POST_START_TIME, POST_END_TIME)
)

def slice_window(df, t0, t1):
    """Return a copy of the rows of *df* whose "time" lies in [t0, t1].

    Both bounds are inclusive. Rows with a missing time are never selected.
    """
    inside = df["time"].between(t0, t1)
    return df[inside].copy()

# Slice the two analysis windows (bounds inclusive)
bia_pre  = slice_window(data_bia_analysis, pre_start, pre_end)
bia_post = slice_window(data_bia_analysis, post_start, post_end)

#----------------------------------------------------------------
# POST_1 / POST_2 (hard segmentation, variance change-point MLE)
#----------------------------------------------------------------
# The POST window is split at the single index k that best separates two
# regimes. Only finite PhA samples take part in the search; the winning index
# is then mapped back onto a row position of bia_post.

s_post = pd.to_numeric(bia_post["PhA_48.8k_deg"], errors="coerce").to_numpy()
finite_mask = np.isfinite(s_post)
y = s_post[finite_mask]
n = len(y)

# Candidate splits restricted to [15 %, 85 %) of the series to avoid the edges
K_MIN, K_MAX = int(0.15 * n), int(0.85 * n)

# score(k) = n1*log(var1) + n2*log(var2): (up to constants) twice the negative
# log-likelihood of two Gaussian segments with their own mean and variance.
# Lower is better. Variances are sample variances (ddof=1).
# For every k in the loop, n1 = k and n2 = n - k.
scores = []
for k in range(K_MIN, K_MAX):
    v1, v2 = np.var(y[:k], ddof=1), np.var(y[k:], ddof=1)
    scores.append(k * np.log(v1) + (n - k) * np.log(v2))

k_f = int(np.argmin(scores) + K_MIN)       # split position within finite-only y
finite_idx = np.flatnonzero(finite_mask)   # finite-only position -> bia_post row position
k_df = int(finite_idx[k_f])                # split position within bia_post rows
split_time = bia_post["time"].iloc[k_df]

# POST_1 = rows strictly before the split, POST_2 = split row onwards
bia_post_1 = bia_post.iloc[:k_df].copy()
bia_post_2 = bia_post.iloc[k_df:].copy()

print("POST split_time =", split_time)
print("var POST_1 =", float(pd.to_numeric(bia_post_1["PhA_48.8k_deg"], errors="coerce").var(ddof=1)))
print("var POST_2 =", float(pd.to_numeric(bia_post_2["PhA_48.8k_deg"], errors="coerce").var(ddof=1)))

#==================================================================================
#  PLOTS (just for verification)
#==================================================================================

def plot_window(df_win, title, y_col="PhA_48.8k_deg", smooth_n=9):
    """Quick visual check of one window: raw signal plus a centred rolling mean.

    Args:
        df_win: DataFrame with a "time" column and the *y_col* column.
        title: figure title.
        y_col: column to plot (coerced to numeric, invalid values -> NaN).
        smooth_n: rolling-mean window length, in samples.

    Windows with fewer than 5 rows are not plotted; a message is printed instead.
    """
    too_short = len(df_win) < 5
    if too_short:
        print("Not enough points to plot:", title)
        return None

    frame = df_win.copy()
    frame["time"] = pd.to_datetime(frame["time"])
    frame = frame.sort_values("time")

    raw = pd.to_numeric(frame[y_col], errors="coerce")
    smooth = raw.rolling(smooth_n, center=True, min_periods=1).mean()

    teal, pink, black = "#028A88", "#F39EC7", "#000000"
    plt.figure(figsize=(10, 3))
    plt.plot(frame["time"], raw, alpha=0.7, label="raw", color=teal)
    plt.plot(frame["time"], smooth, label=f"rolling mean (n={smooth_n})", color=pink)
    plt.title(title, color=black)
    plt.xlabel("time", color=black)
    plt.ylabel(y_col, color=black)
    plt.gca().set_facecolor("#FDFDFD")  # light background
    plt.tight_layout()
    plt.legend()
    plt.show()

# Visual check of every window
plot_window(bia_pre,  "BIA PRE (PhA_48.8k_deg)")
plot_window(bia_post, "BIA POST (PhA_48.8k_deg)")
plot_window(bia_post_1, "BIA POST_1 (PhA_48.8k_deg)")
plot_window(bia_post_2, "BIA POST_2 (PhA_48.8k_deg)")

#==================================================================================
# OUTPUT (variables to use)
#==================================================================================

def _pha_and_rel_time(df_win, y_col="PhA_48.8k_deg"):
    """Return (values, t_rel) for the non-missing samples of ``df_win[y_col]``.

    values : float numpy array with NaN rows removed.
    t_rel  : numpy array of seconds since the FIRST row of the window, taken on
             the same rows as ``values``. Both arrays therefore stay
             index-aligned. Dropping NaN from the values alone would shift
             every later sample against its timestamp.
    """
    vals = df_win[y_col].astype(float)
    keep = vals.notna()
    t_rel = (df_win["time"] - df_win["time"].iloc[0]).dt.total_seconds()
    return vals[keep].to_numpy(), t_rel[keep].to_numpy()


pha_pre, t_pre = _pha_and_rel_time(bia_pre)
pha_post, t_post = _pha_and_rel_time(bia_post)
pha_post_1, t_post_1 = _pha_and_rel_time(bia_post_1)
pha_post_2, t_post_2 = _pha_and_rel_time(bia_post_2)

# Sampling-rate estimates from the median spacing of all rows of each window
dt_pre_med  = float(bia_pre["time"].diff().dt.total_seconds().median())
dt_post_med = float(bia_post["time"].diff().dt.total_seconds().median())
dt_post1_med = float(bia_post_1["time"].diff().dt.total_seconds().median())
dt_post2_med = float(bia_post_2["time"].diff().dt.total_seconds().median())

fs_pre_est   = 1.0 / dt_pre_med
fs_post_est  = 1.0 / dt_post_med
fs_post1_est = 1.0 / dt_post1_med
fs_post2_est = 1.0 / dt_post2_med

print("\nVARIABLES À UTILISER")
print("- bia_pre      : DataFrame PRE")
print("- bia_post     : DataFrame POST (global)")
print("- bia_post_1   : DataFrame POST_1 (avant split)")
print("- bia_post_2   : DataFrame POST_2 (après split)")
print("- pha_pre      : numpy array, PhA_48.8k_deg sur PRE")
print("- pha_post     : numpy array, PhA_48.8k_deg sur POST global")
print("- pha_post_1   : numpy array, PhA_48.8k_deg sur POST_1")
print("- pha_post_2   : numpy array, PhA_48.8k_deg sur POST_2")
print("- t_pre        : numpy array, temps (s) relatif au début de PRE")
print("- t_post       : numpy array, temps (s) relatif au début de POST (global)")
print("- t_post_1     : numpy array, temps (s) relatif au début de POST_1")
print("- t_post_2     : numpy array, temps (s) relatif au début de POST_2")
print("- fs_pre_est    : float, fs approx PRE (= 1 / dt_médian)")
print("- fs_post_est   : float, fs approx POST global")
print("- fs_post1_est  : float, fs approx POST_1")
print("- fs_post2_est  : float, fs approx POST_2")

# 1. STOCHASTIC PROCESS

## 1.1 FUNCTION

In [None]:
#==========================
# Global config
#==========================

# Column analysed in every window
Y_COL = "PhA_48.8k_deg"

# ACF and AR models assume equally spaced samples, so every window is
# resampled onto a regular grid. 500 ms is close to the median raw spacing
# observed in the windows (~0.513 s).
RESAMPLE_RULE = "500ms"

# Horizons below are counted in samples of the 500 ms grid.
# ACF/PACF plots: 40 lags ~ 20 s, enough to see medium-term dependence.
PLOT_LAGS = 40
# Ljung-Box residual whiteness test: 20 lags ~ 10 s (short-to-medium memory).
LB_LAG = 20

# Upper bound on candidate AR orders (the fitting code further clips it to n // 10).
P_MAX_CAP = 20

# Significance level shared by the unit-root and residual-whiteness decisions.
ALPHA = 0.05

# Hide UserWarning / RuntimeWarning noise to keep the notebook output readable.
# This only affects what is displayed, not any computed value.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

def prep_series(df_win, y_col=Y_COL, resample_rule=RESAMPLE_RULE):
    """
    Turn one window into a clean, time-indexed 1-D series for ACF/AR tools.

    Steps: parse and sort the timestamps (unparseable ones dropped), coerce the
    signal to numeric, then either
      - resample_rule is None : keep the raw sampling, NaN values dropped, or
      - otherwise             : average into regular bins of `resample_rule` and
                                time-interpolate empty bins (equal spacing).

    Returns:
      y_out : pd.Series indexed by time
      info  : dict with start/end, sample counts and median steps (seconds)
              before and after preparation, plus the rule used (traceability)
    """
    frame = df_win[["time", y_col]].copy()
    frame["time"] = pd.to_datetime(frame["time"], errors="coerce")
    frame = frame.dropna(subset=["time"]).sort_values("time").set_index("time")

    series = pd.to_numeric(frame[y_col], errors="coerce")

    def _median_step_s(index):
        # Median spacing (s) of a DatetimeIndex; NaN if fewer than 2 stamps.
        steps = index.to_series().diff().dt.total_seconds().dropna()
        return float(steps.median()) if len(steps) else np.nan

    if resample_rule is None:
        y_out = series.dropna()
    else:
        y_out = series.resample(resample_rule).mean().interpolate("time")

    return y_out, {
        "start": frame.index.min(),
        "end": frame.index.max(),
        "n_raw": int(series.notna().sum()),
        "n_out": int(y_out.notna().sum()),
        "dt_raw_med_s": _median_step_s(frame.index),
        "dt_out_med_s": _median_step_s(y_out.index),
        "resample_rule": resample_rule,
    }

def plot_series_acf_pacf(y, name, plot_lags=PLOT_LAGS):
    """Plot a prepared series, then its ACF and PACF (lag 0 excluded).

    Args:
        y: pd.Series indexed by time (as returned by prep_series).
        name: window label used in the plot titles.
        plot_lags: requested number of lags. It is clipped to
            ``len(y.dropna()) // 2 - 1``, because statsmodels' ``pacf`` only
            supports lags up to half the sample size. A short window, e.g. a
            POST_1/POST_2 split close to the 15 % edge, would otherwise raise.

    The shaded band is the approximate 95 % white-noise interval
    +/- 1.96 / sqrt(n), where n is the number of non-missing samples actually
    passed to acf/pacf.
    """
    y_clean = y.dropna()
    n_obs = len(y_clean)

    # 1) Series
    plt.figure(figsize=(10, 3))
    plt.plot(y.index, y.values, color="#028A88")  # teal
    plt.title(f"{name} | {Y_COL} (prepared)")
    plt.xlabel("time")
    plt.ylabel(Y_COL)
    plt.tight_layout()
    plt.show()

    n_lags = min(plot_lags, n_obs // 2 - 1)
    if n_lags < 1:
        print(f"{name} | not enough samples (n={n_obs}) for ACF/PACF plots")
        return
    if n_lags < plot_lags:
        print(f"{name} | ACF/PACF lags clipped from {plot_lags} to {n_lags} (n={n_obs})")

    # Compute ACF / PACF explicitly so lag 0 (trivially 1) can be dropped
    acf_vals = acf(y_clean, nlags=n_lags, fft=True)
    pacf_vals = pacf(y_clean, nlags=n_lags, method="ywm")
    lags = np.arange(1, n_lags + 1)

    # 2) ACF and 3) PACF
    _stem_with_ci(lags, acf_vals[1:], n_obs, f"{name} | ACF (lags 1–{n_lags})")
    _stem_with_ci(lags, pacf_vals[1:], n_obs, f"{name} | PACF (lags 1–{n_lags})")


def _stem_with_ci(lags, values, n_obs, title, color="#028A88"):
    """Stem plot of correlation values with the +/- 1.96/sqrt(n_obs) band shaded."""
    plt.figure(figsize=(8, 3))
    markerline, stemlines, _ = plt.stem(lags, values, basefmt=" ")
    plt.setp(markerline, color=color)
    plt.setp(stemlines, color=color)
    plt.axhline(0, color="k", linewidth=0.8)
    ci = 1.96 / np.sqrt(n_obs)
    plt.axhspan(-ci, ci, alpha=0.2, color="#F39EC7")
    plt.title(title)
    plt.xlabel("lag")
    plt.tight_layout()
    plt.show()
    
def stationarity_tests(y):
    """Augmented Dickey-Fuller unit-root test on a series (NaN removed first).

    ADF H0: the series has a unit root (non-stationary); a small p-value rejects
    non-stationarity. The lag length is chosen internally by AIC (autolag="AIC").

    Returns:
        dict with "adf_stat" (float), "adf_p" (float) and "adf_crit"
        (critical values exactly as returned by statsmodels' adfuller).
    """
    values = pd.Series(y).dropna().to_numpy()
    stat, pvalue, _usedlag, _nobs, crit, _icbest = adfuller(values, autolag="AIC")
    return {"adf_stat": float(stat), "adf_p": float(pvalue), "adf_crit": crit}

def need_diff_by_adf(adf_p, alpha=ALPHA):
    """Return True when ADF fails to reject the unit root at level *alpha*.

    In that case the series should be first-differenced before AR modelling.
    A NaN p-value yields False (the comparison is False).
    """
    fails_to_reject = alpha < adf_p
    return bool(fails_to_reject)

def print_stationarity(st, name):
    """Print the ADF statistic and p-value from *st* under a header for *name*."""
    header = f"\n{name} | stationarity tests"
    adf_line = f"ADF : stat={st['adf_stat']:.3f} | p={st['adf_p']:.4f}"
    for line in (header, adf_line):
        print(line)
    

def ar_grid_search(y_fit, p_max_cap=P_MAX_CAP):
    """
    Fit AR(p) for p = 1..p_max and return the best model by AIC and by BIC.

    p_max = max(1, min(p_max_cap, n // 10)) caps the order to avoid overfitting
    short windows. Orders whose fit raises are skipped (best-effort scan).

    Returns:
        dict with "p_max_used", "model_aic"/"p_aic" and "model_bic"/"p_bic".
        The orders are tracked explicitly rather than read back from the
        private ``model._maxlag`` attribute.

    Raises:
        RuntimeError: if no order in 1..p_max could be fitted.

    NOTE(review): AutoReg's default hold_back drops the first p observations, so
    each order's AIC/BIC is computed on a slightly different sample. Passing
    ``hold_back=p_max`` (as statsmodels' ar_select_order does) would make them
    strictly comparable, at the cost of changing the selected orders.
    """
    yv = pd.Series(y_fit).dropna()
    p_max = max(1, min(p_max_cap, len(yv) // 10))  # conservative rule

    best_aic = None  # (p, fitted results)
    best_bic = None  # (p, fitted results)

    for p in range(1, p_max + 1):
        try:
            m = AutoReg(yv, lags=p, old_names=False).fit()
        except Exception:
            # best-effort scan: orders that fail to fit are simply skipped
            continue

        if best_aic is None or m.aic < best_aic[1].aic:
            best_aic = (p, m)
        if best_bic is None or m.bic < best_bic[1].bic:
            best_bic = (p, m)

    if best_aic is None or best_bic is None:
        raise RuntimeError("AR grid search failed (try smaller p_max_cap).")

    return {
        "p_max_used": p_max,
        "model_aic": best_aic[1],
        "p_aic": int(best_aic[0]),
        "model_bic": best_bic[1],
        "p_bic": int(best_bic[0]),
    }


def choose_p_by_whiteness(y_fit, p_max_cap=P_MAX_CAP, lb_lag=LB_LAG, alpha=ALPHA):
    """
    Choose the smallest AR order whose residuals look white (Ljung-Box p > alpha).

    Orders p = 1..p_max are scanned, with p_max = max(1, min(p_max_cap, n // 10)).
    If no order passes, fall back to the best-BIC order among the fitted ones
    ("passed" is then False so the caller can report the limitation).

    Returns:
        dict with "p", "model", "passed", "p_max_used" and "scan". "scan" is a
        DataFrame of p / AIC / BIC / Ljung-Box p-value per fitted order.

    Raises:
        RuntimeError: if no AR(p) model could be fitted and tested at all.

    NOTE(review): acorr_ljungbox is called without model_df, so the chi-square
    uses lb_lag degrees of freedom instead of lb_lag - p. The p-values are
    therefore somewhat optimistic, i.e. it is easier to "pass".
    """
    yv = pd.Series(y_fit).dropna()
    p_max = max(1, min(p_max_cap, len(yv) // 10))
    lb_col = f"LB_p_lag{lb_lag}"

    rows = []
    models = {}

    for p in range(1, p_max + 1):
        try:
            m = AutoReg(yv, lags=p, old_names=False).fit()
            resid = pd.Series(m.resid).dropna()
            lb = acorr_ljungbox(resid, lags=[lb_lag], return_df=True)
            lb_p = float(lb["lb_pvalue"].iloc[0])
        except Exception:
            # best-effort scan: orders that fail to fit/test are skipped
            continue
        rows.append((p, float(m.aic), float(m.bic), lb_p))
        models[p] = m

    if not rows:
        # Without this guard an empty scan fails later with an obscure idxmin error.
        raise RuntimeError("AR whiteness scan failed: no AR(p) model could be fitted (try smaller p_max_cap).")

    scan = pd.DataFrame(rows, columns=["p", "AIC", "BIC", lb_col]).sort_values("p")

    ok = scan[scan[lb_col] > alpha]
    if len(ok):
        p_star = int(ok.iloc[0]["p"])
        passed = True
    else:
        # fallback: best BIC among scanned models
        p_star = int(scan.loc[scan["BIC"].idxmin(), "p"])
        passed = False

    return {"p": p_star, "model": models[p_star], "passed": passed,
            "p_max_used": p_max, "scan": scan}


from statsmodels.tsa.stattools import acf

def residual_diagnostics(model, name, lb_lag=LB_LAG, plot_lags=PLOT_LAGS):
    """
    Check the whiteness of a fitted model's residuals, visually and statistically.

      - Stem plot of the residual ACF for lags 1..plot_lags (lag 0, always 1, is
        left out), with the approximate 95 % band +/- 1.96 / sqrt(n_resid).
      - Ljung-Box test at the single horizon lb_lag.

    Returns:
        dict with "lb_p" (Ljung-Box p-value) and "resid_var" (sample variance, ddof=1).
    """
    residuals = pd.Series(model.resid).dropna()
    band = 1.96 / np.sqrt(len(residuals))
    rho = acf(residuals.to_numpy(), nlags=plot_lags, fft=True)

    plt.figure(figsize=(8, 3))
    marker, stems, _ = plt.stem(np.arange(1, plot_lags + 1), rho[1:], basefmt=" ")
    for artist in (marker, stems):
        plt.setp(artist, color="#028A88")
    plt.axhline(0, color="k", linewidth=0.8)
    plt.axhspan(-band, band, alpha=0.2, color="#F39EC7")
    plt.title(f"{name} | Residual ACF (lags 1–{plot_lags})")
    plt.xlabel("lag")
    plt.tight_layout()
    plt.show()

    lb_table = acorr_ljungbox(residuals, lags=[lb_lag], return_df=True)
    return {
        "lb_p": float(lb_table["lb_pvalue"].iloc[0]),
        "resid_var": float(np.var(residuals, ddof=1)),
    }



def run_ar_block(y_fit, name, diff_used, force_diff=False):
    """
    Run the full AR workflow on one window and print a report along the way.

      1. optionally first-difference the input (force_diff=True),
      2. grid search -> best AR order by AIC and by BIC,
      3. final order = smallest p whose residuals pass Ljung-Box
         (best-BIC fallback when none does),
      4. print the leading parameters and run the residual diagnostics.

    Returns:
        dict summarising the window: diff flag, p by AIC/BIC/final, the final
        model's AIC/BIC, Ljung-Box p-value and residual variance.
    """
    if force_diff:
        y_fit, diff_used = pd.Series(y_fit).dropna().diff().dropna(), True

    print(f"\n {name} | AR modeling")
    print(f"Series used: diff1={diff_used} | n={len(y_fit)}")

    grid = ar_grid_search(y_fit)
    print(f"AR grid search: p_max_used={grid['p_max_used']}")
    for crit in ("AIC", "BIC"):
        best = grid[f"model_{crit.lower()}"]
        p_best = grid[f"p_{crit.lower()}"]
        print(f"Best by {crit}: p={p_best} | AIC={best.aic:.3f} | BIC={best.bic:.3f}")

    chosen = choose_p_by_whiteness(y_fit)
    final_p, final_model = chosen["p"], chosen["model"]

    if not chosen["passed"]:
        print(f"Final AR(p): no p<=p_max achieved Ljung-Box p>{ALPHA} at lag={LB_LAG}.")
        print(f"Using best-BIC fallback: p={final_p}")
        print("Note: remaining autocorrelation suggests ARMA may be more appropriate.")
    else:
        print(f"Final AR(p) chosen by whiteness: p={final_p}")

    print("Params (head):")
    print(final_model.params.head(6))

    diag = residual_diagnostics(final_model, name)
    print(f"Ljung-Box (lag={LB_LAG}): p={diag['lb_p']:.4f}")

    return {
        "name": name,
        "diff1": diff_used,
        "p_aic": grid["p_aic"],
        "p_bic": grid["p_bic"],
        "p_final": int(final_p),
        "aic_final": float(final_model.aic),
        "bic_final": float(final_model.bic),
        "lb_p": diag["lb_p"],
        "resid_var": diag["resid_var"],
    }


## 1.2 PLOTTING

In [None]:
# =============================================================================
# 1. STOCHASTIC PROCESS ANALYSIS
# =============================================================================

# Prepare every window on the regular grid (input for all later steps).
# POST_1 / POST_2 are alternative POST windows used as a sensitivity check.
y_pre, info_pre = prep_series(bia_pre)
y_post, info_post = prep_series(bia_post)
y_post1, info_post1 = prep_series(bia_post_1)
y_post2, info_post2 = prep_series(bia_post_2)


def _print_prep_info(label, info):
    """Print sample counts, median sampling steps and time range of one prepared window."""
    print(label, f"n_raw={info['n_raw']}, n_out={info['n_out']}, "
                 f"dt_raw_med={info['dt_raw_med_s']:.3f}s, dt_out_med={info['dt_out_med_s']:.3f}s, "
                 f"range={info['start']} -> {info['end']}, resample={info['resample_rule']}")


_print_prep_info("PRE :", info_pre)
_print_prep_info("POST:", info_post)
_print_prep_info("POST1:", info_post1)
_print_prep_info("POST2:", info_post2)

# Plots for narrative: first PRE, then POST, then the two POST sub-windows
plot_series_acf_pacf(y_pre, "PRE")
plot_series_acf_pacf(y_post, "POST")
plot_series_acf_pacf(y_post1, "POST1")
plot_series_acf_pacf(y_post2, "POST2")

# Stationarity (ADF) on the prepared series
st_pre = stationarity_tests(y_pre)
st_post = stationarity_tests(y_post)
st_post1 = stationarity_tests(y_post1)
st_post2 = stationarity_tests(y_post2)

print_stationarity(st_pre, "PRE")
print_stationarity(st_post, "POST")
print_stationarity(st_post1, "POST1")
print_stationarity(st_post2, "POST2")


# Decision step for AR modeling.
# AR models assume stationarity. If ADF does not reject the unit root at level
# ALPHA, the first difference (diff1) is modelled instead. The decision goes
# through need_diff_by_adf, so changing ALPHA in the config cell is honoured.
def _stationarize(y_prepared, st):
    """Return (series to model, diff1 flag) for one window given its ADF result."""
    y_fit = y_prepared.dropna()
    diff1 = need_diff_by_adf(st["adf_p"])
    if diff1:
        y_fit = y_fit.diff().dropna()
    return y_fit, diff1


y_pre_fit, diff_pre = _stationarize(y_pre, st_pre)
y_post_fit, diff_post = _stationarize(y_post, st_post)
y_post1_fit, diff_post1 = _stationarize(y_post1, st_post1)
y_post2_fit, diff_post2 = _stationarize(y_post2, st_post2)

print("\nModeling series (after stationarity step)")
print(f"PRE : diff1={diff_pre} | n={len(y_pre_fit)}")
print(f"POST: diff1={diff_post} | n={len(y_post_fit)}")
print(f"POST1: diff1={diff_post1} | n={len(y_post1_fit)}")
print(f"POST2: diff1={diff_post2} | n={len(y_post2_fit)}")

# Run all windows
out_pre = run_ar_block(y_pre_fit, "PRE", diff_pre)
out_post = run_ar_block(y_post_fit, "POST", diff_post)
out_post1 = run_ar_block(y_post1_fit, "POST1", diff_post1)
out_post2 = run_ar_block(y_post2_fit, "POST2", diff_post2)


def _summary_row(window, info, st, diff1, out):
    """One row of the PRE/POST comparison table (LB column named after LB_LAG)."""
    return {
        "Window": window,
        "n_prep": info["n_out"],
        "dt_out_med_s": info["dt_out_med_s"],
        "ADF_p": st["adf_p"],
        "diff1": diff1,
        "p_AIC": out["p_aic"],
        "p_BIC": out["p_bic"],
        "p_final": out["p_final"],
        "AIC_final": out["aic_final"],
        "BIC_final": out["bic_final"],
        f"LB_p_lag{LB_LAG}": out["lb_p"],
        "resid_var": out["resid_var"],
    }


summary = pd.DataFrame([
    _summary_row("PRE", info_pre, st_pre, diff_pre, out_pre),
    _summary_row("POST", info_post, st_post, diff_post, out_post),
    _summary_row("POST1", info_post1, st_post1, diff_post1, out_post1),
    _summary_row("POST2", info_post2, st_post2, diff_post2, out_post2),
])

# Compact display (rounded)
summary_print = summary.copy()
for c in ["dt_out_med_s", "ADF_p", "AIC_final", "BIC_final", f"LB_p_lag{LB_LAG}", "resid_var"]:
    summary_print[c] = summary_print[c].astype(float).round(6)

print("\nPRE vs POST vs POST1 vs POST2 (summary)")
print(summary_print.to_string(index=False))


# 2. SPECTRAL ANALYSIS

## 2.1 FUNCTION

In [None]:
#==============================================================
# SPECTRAL ANALYSIS
#==============================================================

class SpectralAnalysis:
    """Namespace of spectral-analysis helpers for the PhA windows.

    Every helper is a ``@staticmethod``. It can be called through the class,
    ``SpectralAnalysis.<name>(...)``, or on an instance.
    """

    #-------------------------
    # Filter functions
    #-------------------------
    @staticmethod
    def butter_lowpass_filter(data, cutoff, fs, order=2):
        """Zero-phase low-pass Butterworth filtering.

        Args:
            data: 1-D signal.
            cutoff: cutoff frequency in Hz (must be strictly between 0 and fs / 2,
                otherwise scipy's ``butter`` raises).
            fs: sampling frequency in Hz.
            order: order of the Butterworth design. ``filtfilt`` applies it
                forward and backward, which removes the phase shift and squares
                the magnitude response.

        Returns:
            Filtered signal, same length as ``data``.
        """
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        return filtfilt(b, a, data)

    @staticmethod
    def remove_mean_and_detrend(signal):
        """Subtract the mean, then remove the least-squares linear trend."""
        x = signal - np.mean(signal)
        return detrend(x)

    @staticmethod
    def cosine_taper(signal, p=0.1):
        """
        Apply cosine tapering (Tukey window) to the signal.

        p : fraction of the window inside the cosine tapered region
            (0 = rectangular, 1 = Hann).
        """
        from scipy.signal.windows import tukey
        w = tukey(len(signal), alpha=p)
        return signal * w

    #------------------------
    # FFT function
    #------------------------
    @staticmethod
    def compute_fft_power_spectrum(signal, fs, window='hann'):
        """One-sided FFT power spectrum, ``|rfft(x)|**2 / n``.

        Args:
            signal: 1-D real signal. Any detrending or tapering must already
                be applied.
            fs: sampling frequency in Hz (sets the frequency axis).
            window: accepted for interface compatibility; it is NOT applied here.

        Returns:
            (freqs, power): frequencies in Hz (0 .. fs/2) and the power values.
        """
        n = len(signal)
        fft_vals = np.fft.rfft(signal)
        freqs = np.fft.rfftfreq(n, d=1.0 / fs)
        power = np.abs(fft_vals) ** 2 / n
        return freqs, power

    #------------------------
    # Plotting functions
    #------------------------
    @staticmethod
    def plot_compare_power_spectra(list_freqs, list_power, labels, title, xlabel="Frequency (Hz)", ylabel="Power", logy=True):
        """Overlay several spectra on one axis (log y-axis by default)."""
        plt.figure(figsize=(14, 5))
        colors = ["#028A88", "#F39EC7", "#d62728", "#1f77b4", "#FFA500"]
        for idx, (freqs, power, label) in enumerate(zip(list_freqs, list_power, labels)):
            color = colors[idx % len(colors)]
            if logy:
                plt.semilogy(freqs, power, label=label, color=color)
            else:
                plt.plot(freqs, power, label=label, color=color)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_raw_and_filtered_panels(
        data_list,
        titles=None,
        xlabels=None,
        ylabels=None,
        suptitle=None,
        figsize=(16, 5)
    ):
        """One panel per dict in *data_list*, each with keys 'time', 'raw', 'filtered'."""
        n = len(data_list)
        fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)
        if suptitle:
            fig.suptitle(suptitle, fontsize=16, fontweight='bold')
        for i, data in enumerate(data_list):
            ax = axes[0, i]
            ax.plot(data['time'], data['raw'], alpha=0.5, linewidth=0.5, label='Raw', color="#028A88")
            ax.plot(data['time'], data['filtered'], linewidth=2, label='Filtered', color="#F39EC7")
            if titles and i < len(titles):
                ax.set_title(titles[i])
            if xlabels and i < len(xlabels):
                ax.set_xlabel(xlabels[i])
            else:
                ax.set_xlabel('Time (s)')
            if ylabels and i < len(ylabels):
                ax.set_ylabel(ylabels[i])
            else:
                ax.set_ylabel('Phase Angle (°)')
            ax.legend()
            ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_welch_panels(
        list_freqs,
        list_psd,
        titles=None,
        xlabels=None,
        ylabels=None,
        suptitle=None,
        figsize=(14, 4)
    ):
        """One semilog-y panel per (freqs, psd) pair."""
        n = len(list_freqs)
        colors = ["#028A88", "#F39EC7", "#d62728", "#1f77b4", "#FFA500"]
        fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)
        if suptitle:
            fig.suptitle(suptitle, fontsize=16, fontweight='bold')
        for i in range(n):
            color = colors[i % len(colors)]
            ax = axes[0, i]
            ax.semilogy(list_freqs[i], list_psd[i], color=color)
            if titles and i < len(titles):
                ax.set_title(titles[i])
            if xlabels and i < len(xlabels):
                ax.set_xlabel(xlabels[i])
            else:
                ax.set_xlabel('Frequency (Hz)')
            if ylabels and i < len(ylabels):
                ax.set_ylabel(ylabels[i])
            else:
                ax.set_ylabel('PSD')
            ax.grid(True)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_log_psd_with_fit(freqs, psd, label=None, fmin=None, fmax=None, title="log(PSD) vs log(f) avec fit", xlabel="log10(Frequency [Hz])", ylabel="log10(PSD)"):
        """Plot log10(PSD) vs log10(f) with a straight-line least-squares fit.

        Only points with f > 0 and PSD > 0 (and inside [fmin, fmax] when given)
        are used.

        Returns:
            (slope, intercept) of the fit, or None when fewer than 3 points remain.
        """
        # Keep strictly positive frequencies/PSD values to avoid log(0)
        mask = (freqs > 0) & (psd > 0)
        if fmin is not None:
            mask &= freqs >= fmin
        if fmax is not None:
            mask &= freqs <= fmax

        log_freqs = np.log10(freqs[mask])
        log_psd = np.log10(psd[mask])
        if len(log_freqs) < 3:
            print("Pas assez de points pour un fit fiable.")
            return None

        # Linear regression in log-log space (computed once)
        slope, intercept = np.polyfit(log_freqs, log_psd, 1)
        fit_line = np.polyval([slope, intercept], log_freqs)

        plt.figure()
        plt.plot(log_freqs, log_psd, label=label)
        # `:+.2f` prints the intercept sign correctly (no "+-" for negative values)
        plt.plot(log_freqs, fit_line, 'r--', label=f'Fit: y={slope:.2f}x{intercept:+.2f}')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True, which="both", ls="--")
        plt.show()
        return slope, intercept

  


## 2.2 DEFINING BUTTERWORTH ORDER / CUTOFF

In [None]:
#--------------------------------------------------------------------------------
# TEST BUTTERWORTH DESIGN
#--------------------------------------------------------------------------------
# Minimum-order low-pass Butterworth meeting the passband/stopband specs below.
# NOTE: `Fs` is the sampling rate, while lower-case `fs` is the STOPBAND EDGE (Hz).
# NOTE(review): the filtering cell later hard-codes cutoff 0.06 Hz and order=2,
# rather than reusing `Wn` / `n` from here.

Fs, Nyq = 2, 2 / 2            # sampling rate (Hz) and its Nyquist frequency (Hz)

fp, fs = 0.05, 0.2            # passband edge / stopband edge (Hz)
wp, ws = fp / Nyq, fs / Nyq   # edges normalised to Nyquist, as buttord expects

Ap, As = 1, 40                # max passband loss / min stopband attenuation (dB)

n, Wn = buttord(wp, ws, Ap, As)

print("Ordre :", n)
print("Cutoff :", Wn)

## 2.3 PLOTTING

In [None]:
# =============================================================================
#  2. SPECTRAL ANALYSIS
# =============================================================================

#--------------------------------------------------
# Create stationary signals and time vectors
#--------------------------------------------------

#-------- Signal preprocessing: detrend ---------
# Detrend signals
pha_post_1_detrend = remove_mean_and_detrend(np.asarray(pha_post_1))
pha_pre_detrend = remove_mean_and_detrend(np.asarray(pha_pre))
pha_post_detrend = remove_mean_and_detrend(np.asarray(pha_post))

# time vectors for plotting (match lengths after detrending)
t_post_1_detrend_plot = t_post_1[:len(pha_post_1_detrend)]
pha_post_1_detrend_plot = pha_post_1_detrend[:len(t_post_1_detrend_plot)]
pha_post_1_filtered_plot = pha_post_1_filtered[:len(t_post_1_detrend_plot)]

t_pre_detrend_plot = t_pre[:len(pha_pre_detrend)]
pha_pre_detrend_plot = pha_pre_detrend[:len(t_pre_detrend_plot)]
pha_pre_filtered_plot = pha_pre_filtered[:len(t_pre_detrend_plot)]

t_post_detrend_plot = t_post[:len(pha_post_detrend)]
pha_post_detrend_plot = pha_post_detrend[:len(t_post_detrend_plot)]
pha_post_filtered_plot = pha_post_filtered[:len(t_post_detrend_plot)]

#-------- Signal preprocessing: tapering ---------
# Apply cosine tapering
pha_pre_detrend_tapered = cosine_taper(pha_pre_detrend, p=0.1)
pha_post_1_detrend_tapered = cosine_taper(pha_post_1_detrend, p=0.1)
pha_post_detrend_tapered = cosine_taper(pha_post_detrend, p=0.1)

#-------------------------------------------------
# 1. Raw and Butterworth Filtered Signals
#-------------------------------------------------
cutoff_freq = 0.06 # Hz (determined from Butterworth design test)
#----------- Def Fc and filter order -------------------
# Apply Butterworth lowpass filter

pha_pre_filtered = SpectralAnalysis.butter_lowpass_filter(pha_pre_detrend_tapered, cutoff=cutoff_freq, fs=fs_pre_est, order=2)
pha_post_filtered = SpectralAnalysis.butter_lowpass_filter(pha_post_detrend_tapered, cutoff=cutoff_freq, fs=fs_post_est, order=2)
pha_post_1_filtered = SpectralAnalysis.butter_lowpass_filter(pha_post_1_detrend_tapered, cutoff=cutoff_freq, fs=fs_post1_est, order=2)
pha_post_2_filtered = SpectralAnalysis.butter_lowpass_filter(pha_post_2, cutoff=cutoff_freq, fs=fs_post2_est, order=2)

#----------- Plot raw and filtered signals -------------
# POST
SpectralAnalysis.plot_raw_and_filtered_panels(
    [{'time': t_post_detrend_plot, 'raw': pha_post_detrend_plot, 'filtered': pha_post_filtered_plot}],
    titles=['BIA Phase Angle - POST Fatigue']
)

# PRE vs POST1 vs POST2
SpectralAnalysis.plot_raw_and_filtered_panels(
    [
        {'time': t_pre_detrend_plot, 'raw': pha_pre_detrend_plot, 'filtered': pha_pre_filtered_plot},
        {'time': t_post_1_detrend_plot, 'raw': pha_post_1_detrend_plot, 'filtered': pha_post_1_filtered_plot},
        {'time': t_post_2, 'raw': pha_post_2, 'filtered': pha_post_2_filtered},
    ],
    titles=['PRE', 'POST_1', 'POST_2'],
    suptitle='BIA Phase Angle (Raw & Filtered) — All Windows'
)


#-------------------------------------------------
# 2. FFT Power Spectrum Comparison
#-------------------------------------------------

# ----------- Def the FFT signal --------------
# FFT power spectrum of each (unfiltered) tapered, detrended window, each at
# its own estimated sampling rate.
# NOTE(review): POST_2 uses `pha_post_2` directly instead of a *_detrend_tapered
# series — confirm it received the same preprocessing as the other windows.
freqs_fft_pre, power_fft_pre = SpectralAnalysis.compute_fft_power_spectrum(pha_pre_detrend_tapered, fs_pre_est)
freqs_fft_post, power_fft_post = SpectralAnalysis.compute_fft_power_spectrum(pha_post_detrend_tapered, fs_post_est)
freqs_fft_post1, power_fft_post1 = SpectralAnalysis.compute_fft_power_spectrum(pha_post_1_detrend_tapered, fs_post1_est)
freqs_fft_post2, power_fft_post2 = SpectralAnalysis.compute_fft_power_spectrum(pha_post_2, fs_post2_est)

# ----------- Plot comparisons --------------
# PRE vs POST
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_fft_pre, freqs_fft_post],
    [power_fft_pre, power_fft_post],
    [f'PRE-fatigue (fs={fs_pre_est:.2f} Hz)', f'POST-fatigue (fs={fs_post_est:.2f} Hz)'],
    title='FFT Power Spectrum Comparison: PRE vs POST Muscle Fatigue',
    ylabel='Power (°²)'
)

# PRE vs POST1 vs POST2
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_fft_pre, freqs_fft_post1, freqs_fft_post2],
    [power_fft_pre, power_fft_post1, power_fft_post2],
    [
        f'PRE (fs={fs_pre_est:.2f} Hz)',
        f'POST_1 (fs={fs_post1_est:.2f} Hz)',
        f'POST_2 (fs={fs_post2_est:.2f} Hz)'
    ],
    title='FFT Power Spectrum Comparison: PRE vs POST_1 vs POST_2',
    ylabel='Power (°²)'
)
#-------------------------------------------------
# 3. Smoothed FFT Power Spectrum Comparison (Butterworth Filtered Signals)
#-------------------------------------------------

# ----------- Def the FFT filtered signal --------------
# Same FFT power spectrum as section 2, computed on the Butterworth-filtered
# series (pha_*_filtered) instead of the unfiltered ones.
freqs_fft_pre_filt, power_fft_pre_filt = SpectralAnalysis.compute_fft_power_spectrum(pha_pre_filtered, fs_pre_est)
freqs_fft_post_filt, power_fft_post_filt = SpectralAnalysis.compute_fft_power_spectrum(pha_post_filtered, fs_post_est)

# ----------- Plot comparisons --------------
# PRE vs POST
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_fft_pre_filt, freqs_fft_post_filt],
    [power_fft_pre_filt, power_fft_post_filt],
    [f'PRE Butterworth (fs={fs_pre_est:.2f} Hz)', f'POST Butterworth (fs={fs_post_est:.2f} Hz)'],
    title='FFT Power Spectrum (Butterworth filtered): PRE vs POST',
    ylabel='Power (°²)'
)
# PRE vs POST1 vs POST2 (the filtered POST_1 / POST_2 spectra are computed here)
freqs_fft_post1_filt, power_fft_post1_filt = SpectralAnalysis.compute_fft_power_spectrum(pha_post_1_filtered, fs_post1_est)
freqs_fft_post2_filt, power_fft_post2_filt = SpectralAnalysis.compute_fft_power_spectrum(pha_post_2_filtered, fs_post2_est)
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_fft_pre_filt, freqs_fft_post1_filt, freqs_fft_post2_filt],
    [power_fft_pre_filt, power_fft_post1_filt, power_fft_post2_filt],
    [
        f'PRE Butterworth (fs={fs_pre_est:.2f} Hz)',
        f'POST_1 Butterworth (fs={fs_post1_est:.2f} Hz)',
        f'POST_2 Butterworth (fs={fs_post2_est:.2f} Hz)'
    ],
    title='FFT Power Spectrum (Butterworth filtered): PRE vs POST_1 vs POST_2',
    ylabel='Power (°²)'
)


#-------------------------------------------------
# 4. Welch Periodogram (Individual)
#-------------------------------------------------

# ----------- Def the Welch signal --------------
# Every window is analysed with nperseg equal to its OWN length and
# noverlap=0, i.e. one windowed segment spanning the whole recording.
# Using another window's length would either drop this window's trailing
# samples / average shorter segments (nperseg < len) or make scipy warn and
# fall back to len(x) (nperseg > len), so the windows would not be comparable.
freqs_pre, psd_pre = welch(pha_pre_detrend, fs=fs_pre_est, nperseg=len(pha_pre_detrend), noverlap=0)
freqs_post1, psd_post1 = welch(np.asarray(pha_post_1_detrend), fs=fs_post1_est, nperseg=len(pha_post_1_detrend), noverlap=0)
freqs_post2, psd_post2 = welch(pha_post_2, fs=fs_post2_est, nperseg=len(pha_post_2), noverlap=0)
freqs_post, psd_post = welch(pha_post_detrend, fs=fs_post_est, nperseg=len(pha_post_detrend), noverlap=0)

# ----------- Plot comparisons --------------
# PRE vs POST
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_pre, freqs_post],
    [psd_pre, psd_post],
    [f'PRE (fs={fs_pre_est:.2f} Hz)', f'POST (fs={fs_post_est:.2f} Hz)'],
    title='Welch Periodogram Comparison: PRE vs POST Muscle Fatigue',
    ylabel='PSD'
)

# PRE vs POST1 vs POST2
SpectralAnalysis.plot_compare_power_spectra(
    [freqs_pre, freqs_post1, freqs_post2],
    [psd_pre, psd_post1, psd_post2],
    [
        f'PRE (fs={fs_pre_est:.2f} Hz)',
        f'POST_1 (fs={fs_post1_est:.2f} Hz)',
        f'POST_2 (fs={fs_post2_est:.2f} Hz)'
    ],
    title='Welch Periodogram Comparison: PRE vs POST_1 vs POST_2',
    ylabel='PSD'
)

# Log-log PSD with a fit restricted to the 0.03-0.3 Hz band, one plot per window.
SpectralAnalysis.plot_log_psd_with_fit(freqs_pre, psd_pre, fmin=0.03, fmax=0.3, label="PRE")
SpectralAnalysis.plot_log_psd_with_fit(freqs_post, psd_post, fmin=0.03, fmax=0.3, label="POST")
SpectralAnalysis.plot_log_psd_with_fit(freqs_post1, psd_post1, fmin=0.03, fmax=0.3, label="POST1")
SpectralAnalysis.plot_log_psd_with_fit(freqs_post2, psd_post2, fmin=0.03, fmax=0.3, label="POST2")

#-------------------------------------------------
# 5. Dominant Frequency
#-------------------------------------------------

def _dominant_peak(freqs, psd):
    """Return ``(index, frequency, psd_value)`` of the largest PSD bin.

    The search covers the whole spectrum, including the 0 Hz bin.
    """
    idx = np.argmax(psd)
    return idx, freqs[idx], psd[idx]


#----------- Find dominant frequencies --------------
idx_dom_pre, f_dom_pre, P_dom_pre = _dominant_peak(freqs_pre, psd_pre)
idx_dom_post, f_dom_post, P_dom_post = _dominant_peak(freqs_post, psd_post)
idx_dom_post1, f_dom_post1, P_dom_post1 = _dominant_peak(freqs_post1, psd_post1)
idx_dom_post2, f_dom_post2, P_dom_post2 = _dominant_peak(freqs_post2, psd_post2)

# ----------- Print dominant frequencies --------------
print(f"Dominant Frequency PRE: f={f_dom_pre:.3f} Hz | PSD={P_dom_pre:.3f} °²/Hz")
print(f"Dominant Frequency POST: f={f_dom_post:.3f} Hz | PSD={P_dom_post:.3f} °²/Hz")
print(f"Dominant Frequency POST1: f={f_dom_post1:.3f} Hz | PSD={P_dom_post1:.3f} °²/Hz")
print(f"Dominant Frequency POST2: f={f_dom_post2:.3f} Hz | PSD={P_dom_post2:.3f} °²/Hz")

#----------- Find dominant frequencies (Butterworth filtered) --------------
# Single-segment Welch estimate per window: nperseg is each signal's OWN
# length (noverlap=0), so every window is analysed over its full duration
# rather than being truncated / padded to another window's length.
freqs_pre_filt, psd_pre_filt = welch(pha_pre_filtered, fs=fs_pre_est, nperseg=len(pha_pre_filtered), noverlap=0)
freqs_post_filt, psd_post_filt = welch(pha_post_filtered, fs=fs_post_est, nperseg=len(pha_post_filtered), noverlap=0)
freqs_post1_filt, psd_post1_filt = welch(pha_post_1_filtered, fs=fs_post1_est, nperseg=len(pha_post_1_filtered), noverlap=0)
freqs_post2_filt, psd_post2_filt = welch(pha_post_2_filtered, fs=fs_post2_est, nperseg=len(pha_post_2_filtered), noverlap=0)

idx_dom_pre_filt, f_dom_pre_filt, P_dom_pre_filt = _dominant_peak(freqs_pre_filt, psd_pre_filt)
idx_dom_post_filt, f_dom_post_filt, P_dom_post_filt = _dominant_peak(freqs_post_filt, psd_post_filt)
idx_dom_post1_filt, f_dom_post1_filt, P_dom_post1_filt = _dominant_peak(freqs_post1_filt, psd_post1_filt)
idx_dom_post2_filt, f_dom_post2_filt, P_dom_post2_filt = _dominant_peak(freqs_post2_filt, psd_post2_filt)

# ----------- Print dominant frequencies (filtered) --------------
print(f"Dominant Frequency PRE (filtered):   f={f_dom_pre_filt:.3f} Hz | PSD={P_dom_pre_filt:.3f} °²/Hz")
print(f"Dominant Frequency POST (filtered):  f={f_dom_post_filt:.3f} Hz | PSD={P_dom_post_filt:.3f} °²/Hz")
print(f"Dominant Frequency POST1 (filtered): f={f_dom_post1_filt:.3f} Hz | PSD={P_dom_post1_filt:.3f} °²/Hz")
print(f"Dominant Frequency POST2 (filtered): f={f_dom_post2_filt:.3f} Hz | PSD={P_dom_post2_filt:.3f} °²/Hz")


# 3. FRACTAL ANALYSIS

## 3.1 FUNCTION

In [None]:
#====================================================
# FRACTAL ANALYSIS FUNCTION
#====================================================


class fractal_analysis:
    """Namespace of fractal / scaling-analysis helpers for 1-D signals.

    Every method is a ``staticmethod``: call it on the class
    (``fractal_analysis.dfa_analysis(x)``) or on an instance; no state is kept.
    """

    #-------------------------------
    # Signal processing functions
    #-------------------------------
    @staticmethod
    def fractal_dimension_boxcount(signal, time_vector=None):
        """Estimate the box-counting dimension of a signal's graph.

        Time and amplitude are both min-max normalised to [0, 1), the unit
        square is divided into ``n_div x n_div`` boxes for up to 20 log-spaced
        integer ``n_div`` values between 4 and ``min(len(signal) // 2, 1000)``,
        and the boxes containing at least one sample are counted.

        Note: only boxes that contain a sample are counted; the path between
        consecutive samples is not traced, so at fine grids the count is
        bounded by the number of samples.

        Parameters
        ----------
        signal : array-like
            1-D signal.
        time_vector : array-like, optional
            Numeric sample times, same length as ``signal``. Defaults to the
            sample index.

        Returns
        -------
        D : float
            Negative slope of log(count) versus log(box size).
        box_sizes : ndarray
            Box side lengths ``1 / n_div``.
        box_counts : ndarray
            Number of occupied boxes for each size.

        Raises
        ------
        ValueError
            If ``signal`` has fewer than 2 samples, or ``time_vector`` does
            not have the same length as ``signal``.
        """
        signal = np.asarray(signal, dtype=float)
        n_points = len(signal)
        if n_points < 2:
            raise ValueError("box counting needs at least 2 samples")
        if time_vector is None:
            time_vector = np.arange(n_points)
        time_vector = np.asarray(time_vector, dtype=float)
        if len(time_vector) != n_points:
            raise ValueError("time_vector must have the same length as signal")

        # The 1e-10 keeps a constant input from dividing by zero.
        signal_norm = (signal - np.min(signal)) / (np.max(signal) - np.min(signal) + 1e-10)
        time_norm = (time_vector - np.min(time_vector)) / (np.max(time_vector) - np.min(time_vector) + 1e-10)

        min_boxes = 4
        max_boxes = min(n_points // 2, 1000)
        n_divisions = np.unique(
            np.logspace(np.log10(min_boxes), np.log10(max_boxes), num=20, dtype=int)
        )
        box_sizes = 1.0 / n_divisions
        box_counts = np.empty(n_divisions.size, dtype=int)
        for k, (n_div, box_size) in enumerate(zip(n_divisions, box_sizes)):
            box_x = np.clip((time_norm / box_size).astype(int), 0, n_div - 1)
            box_y = np.clip((signal_norm / box_size).astype(int), 0, n_div - 1)
            # Encode each (x, y) box as one integer so the occupied boxes can be
            # counted with np.unique instead of a Python set of tuples.
            box_counts[k] = np.unique(box_x * n_div + box_y).size

        coeffs = np.polyfit(np.log(box_sizes), np.log(box_counts), 1)
        D = -coeffs[0]
        return D, box_sizes, box_counts

    @staticmethod
    def dfa_analysis(signal, min_scale=4, max_scale=None, num_scales=20):
        """First-order Detrended Fluctuation Analysis (DFA) of a 1-D signal.

        The mean-removed signal is integrated (cumulative sum) and split into
        non-overlapping windows of each scale n. A least-squares line is
        removed from every window, and the RMS of the residuals gives F(n).
        The DFA exponent ``alpha`` is the slope of log10 F(n) vs log10 n.

        Parameters
        ----------
        signal : array-like
            1-D signal.
        min_scale, max_scale : int
            Smallest / largest window length; ``max_scale`` defaults to
            ``len(signal) // 4``.
        num_scales : int
            Number of log-spaced candidate scales (duplicates produced by the
            integer conversion are removed).

        Returns
        -------
        alpha : float
            DFA scaling exponent.
        scales : ndarray of int
            Window lengths actually used.
        fluctuations : ndarray
            F(n) for each entry of ``scales``.

        Raises
        ------
        ValueError
            If a scale bound is < 1, or fewer than two usable scales remain.
            Usable scales are >= 3, because a line fits 1-2 points exactly
            (zero residual, log10(0) = -inf), and <= len(signal).
        """
        signal = np.asarray(signal, dtype=float)
        N = len(signal)
        if max_scale is None:
            max_scale = N // 4
        if min_scale < 1 or max_scale < 1:
            raise ValueError(
                f"scale bounds must be >= 1 (min_scale={min_scale}, "
                f"max_scale={max_scale}, len(signal)={N})"
            )
        y = np.cumsum(signal - np.mean(signal))
        scales = np.unique(np.logspace(np.log10(min_scale), np.log10(max_scale), num=num_scales, dtype=int))
        scales = scales[(scales >= 3) & (scales <= N)]
        if scales.size < 2:
            raise ValueError(f"signal too short for DFA: only {scales.size} usable scale(s) for len(signal)={N}")

        fluctuations = np.empty(scales.size)
        for k, scale in enumerate(scales):
            num_segments = N // scale
            segments = y[:num_segments * scale].reshape((num_segments, scale))
            x_seg = np.arange(scale)
            # polyfit fits every column of the (scale, num_segments) matrix in
            # one call: one local linear trend per segment.
            slope, intercept = np.polyfit(x_seg, segments.T, 1)
            trend = slope[:, None] * x_seg + intercept[:, None]
            fluctuations[k] = np.sqrt(np.mean((segments - trend) ** 2))

        coeffs = np.polyfit(np.log10(scales), np.log10(fluctuations), 1)
        alpha = coeffs[0]
        return alpha, scales, fluctuations

    #-------------------------------
    # Plotting functions
    #-------------------------------
    @staticmethod
    def plot_box_counting_panel(signal_list, labels=None):
        """Plot log(count) vs log(box size) with its linear fit, one panel per signal."""
        n = len(signal_list)
        fig, axes = plt.subplots(1, n, figsize=(7*n, 5), squeeze=False)
        for i, signal in enumerate(signal_list):
            D, box_sizes, box_counts = fractal_analysis.fractal_dimension_boxcount(signal)
            log_box_sizes = np.log(box_sizes)
            log_counts = np.log(box_counts)
            coeffs = np.polyfit(log_box_sizes, log_counts, 1)
            ax = axes[0, i]
            ax.scatter(log_box_sizes, log_counts, color="#028A88", label='Data points')
            ax.plot(log_box_sizes, np.polyval(coeffs, log_box_sizes), color="#F39EC7", lw=2, label=f'D={D:.3f}')
            title = labels[i] if labels else f"Signal {i+1}"
            ax.set_title(f'{title} — Box Counting')
            ax.set_xlabel('log(Box Size)')
            ax.set_ylabel('log(Count)')
            ax.legend()
            ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_dfa_panel(signal_list, labels=None, min_scale=4, max_scale=100):
        """Plot log10 F(n) vs log10 n with its linear fit, one panel per signal.

        Note the default ``max_scale=100`` here differs from
        ``dfa_analysis``'s own default of ``len(signal) // 4``.
        """
        n = len(signal_list)
        fig, axes = plt.subplots(1, n, figsize=(7*n, 5), squeeze=False)
        for i, signal in enumerate(signal_list):
            alpha, scales, fluctuations = fractal_analysis.dfa_analysis(signal, min_scale=min_scale, max_scale=max_scale)
            log_scales = np.log10(scales)
            log_fluctuations = np.log10(fluctuations)
            coeffs = np.polyfit(log_scales, log_fluctuations, 1)
            ax = axes[0, i]
            ax.scatter(log_scales, log_fluctuations, color="#028A88", label='Data points')
            ax.plot(log_scales, np.polyval(coeffs, log_scales), color="#F39EC7", lw=2, label=f'α={alpha:.3f}')
            title = labels[i] if labels else f"Signal {i+1}"
            ax.set_title(f'{title} — DFA')
            ax.set_xlabel('log(Window Size)')
            ax.set_ylabel('log(Fluctuation)')
            ax.legend()
            ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

## 3.2 PLOTTING

In [None]:
# =============================================================================
# 3. FRACTAL ANALYSIS
# =============================================================================

# -------- Box Counting Dimension ----------
dim_pre, box_sizes_pre, box_counts_pre = fractal_analysis.fractal_dimension_boxcount(pha_pre_detrend_tapered)
dim_post, box_sizes_post, box_counts_post = fractal_analysis.fractal_dimension_boxcount(pha_post_detrend_tapered)

fractal_analysis.plot_box_counting_panel([pha_pre_detrend_tapered, pha_post_detrend_tapered], labels=["PRE", "POST"])
# POST_2 is passed un-filtered: in the spectral section `pha_post_2` is the
# series paired with the *_detrend_tapered PRE / POST_1 signals, whereas
# `pha_post_2_filtered` is the low-passed version. Comparing a smoothed POST_2
# trace with unfiltered PRE / POST_1 would bias the dimension comparison.
fractal_analysis.plot_box_counting_panel(
    [pha_pre_detrend_tapered, pha_post_1_detrend_tapered, pha_post_2],
    labels=["PRE", "POST_1", "POST_2"]
)

# ------- DFA --------
# Step-by-step DFA walkthrough for each window: integration, choice of
# window sizes, local detrending of a few segments, then fluctuation
# function and log-log slope.

signals = [
    (pha_pre_detrend_tapered, "PRE"),
    (pha_post_1_detrend_tapered, "POST_1"),
    # Un-filtered POST_2, matching how `pha_post_2` is paired with the
    # *_detrend_tapered series in the spectral section; the low-passed trace
    # is smoother and would bias alpha relative to PRE / POST_1.
    (pha_post_2, "POST_2"),
]

# The loop variable is `sig`, not `signal`: at notebook top level, `signal`
# would rebind the `scipy.signal` module imported in the first cell.
for sig, label in signals:
    N = len(sig)
    signal_mean = np.mean(sig)

    # 1. Integration (cumulative sum of the mean-removed signal)
    y = np.cumsum(sig - signal_mean)
    plt.figure(figsize=(10,3))
    plt.plot(y, color="#028A88")
    plt.title(f"DFA Step 1: Integrated signal (cumsum) — {label}")
    plt.xlabel("Sample")
    plt.ylabel("Integrated value")
    plt.tight_layout()
    plt.show()

    # 2. Window sizes (scales): same defaults as fractal_analysis.dfa_analysis
    min_scale = 4
    max_scale = N // 4
    num_scales = 20
    scales = np.unique(np.logspace(np.log10(min_scale), np.log10(max_scale), num=num_scales, dtype=int))

    # 3. For the middle scale, show the local linear trend of the first segments
    chosen_scale = scales[len(scales)//2]
    num_segments = N // chosen_scale
    y_trimmed = y[:num_segments * chosen_scale]
    segments = y_trimmed.reshape((num_segments, chosen_scale))

    plt.figure(figsize=(12,4))
    for i in range(min(3, num_segments)):
        x_seg = np.arange(chosen_scale)
        seg = segments[i]
        coeffs = np.polyfit(x_seg, seg, 1)
        trend = np.polyval(coeffs, x_seg)
        plt.subplot(1,3,i+1)
        plt.plot(x_seg, seg, label="Segment")
        plt.plot(x_seg, trend, '--', label="Local trend")
        plt.title(f"{label} — Segment {i+1} (scale={chosen_scale})")
        plt.legend()
    plt.tight_layout()
    plt.show()

    # 4-5. RMS fluctuation per scale and log-log regression. Delegated to
    # fractal_analysis.dfa_analysis (same algorithm, same parameters) so this
    # walkthrough and the DFA panel plots cannot drift apart.
    alpha, scales, fluctuations = fractal_analysis.dfa_analysis(
        sig, min_scale=min_scale, max_scale=max_scale, num_scales=num_scales
    )
    log_scales = np.log10(scales)
    log_fluctuations = np.log10(fluctuations)
    coeffs = np.polyfit(log_scales, log_fluctuations, 1)
    fit_line = np.polyval(coeffs, log_scales)

    plt.figure(figsize=(7,5))
    plt.scatter(log_scales, log_fluctuations, color="#028A88", label="Data points")
    plt.plot(log_scales, fit_line, color="#F39EC7", lw=2, label=f'α={alpha:.3f}')
    plt.title(f"DFA Step 5: log(Fluctuation) vs log(Window size) — {label}")
    plt.xlabel("log10(Window size)")
    plt.ylabel("log10(Fluctuation)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print(f"Pente alpha estimée ({label}): {alpha:.3f}")

# APPENDIX

## ASSESSING REGULARITY

In [None]:

# Visual NaN check for the POST_2 window: step plot over time of
# bia_post_2["PhA_48.8k_deg"].isna(), where 1 ("Oui") marks a missing phase-angle
# sample and 0 ("Non") a present one.
plt.figure(figsize=(12, 2))
is_nan = bia_post_2["PhA_48.8k_deg"].isna()
plt.plot(bia_post_2["time"], is_nan, drawstyle="steps-post", color="#d62728")
plt.title("Présence de NaN dans bia_post_2['PhA_48.8k_deg']")
plt.xlabel("Time")
plt.ylabel("is NaN")
plt.yticks([0, 1], ["Non", "Oui"])
plt.tight_layout()
plt.show()


def check_sampling_regular(df, time_col="time", plot=True, tol=0.01):
    """Report sampling-interval statistics and decide whether to resample.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing a timestamp column.
    time_col : str
        Name of that column; values that cannot be parsed are dropped.
    plot : bool
        If True, show the interval sequence and its histogram.
    tol : float
        Relative tolerance (e.g. 0.01 = 1%) on std(dt) / median(dt) below which
        the sampling is considered regular.

    Returns
    -------
    dt : pandas.Series
        Successive sampling intervals in seconds.
    need_resample : bool
        True when the relative std is >= ``tol`` or undefined (median
        interval of 0, or not enough timestamps).
    """
    stamps = pd.to_datetime(df[time_col], errors="coerce").dropna()
    intervals = stamps.diff().dt.total_seconds().dropna()
    median_dt = intervals.median()
    std_dt = intervals.std()

    print(f"Nombre d'intervalles: {len(intervals)}")
    print(f"dt min: {intervals.min():.6f} s, dt max: {intervals.max():.6f} s, dt médian: {median_dt:.6f} s")
    print(f"Écart-type des dt: {std_dt:.6f} s")

    if plot:
        # Interval sequence, in sample order
        plt.figure(figsize=(8,3))
        plt.plot(intervals.values, '.-', alpha=0.7)
        plt.title("Intervalles d'échantillonnage (dt)")
        plt.xlabel("Index")
        plt.ylabel("dt (s)")
        plt.tight_layout()
        plt.show()
        # Interval distribution
        plt.figure(figsize=(6,3))
        plt.hist(intervals, bins=30, color="#028A88", alpha=0.7)
        plt.title("Distribution des dt")
        plt.xlabel("dt (s)")
        plt.tight_layout()
        plt.show()

    # A zero median makes the ratio undefined (NaN). NaN < tol is False, so
    # an undefined ratio falls through to "resample".
    rel_std = np.nan if median_dt == 0 else std_dt / median_dt
    need_resample = not (rel_std < tol)
    if need_resample:
        print(f"⚠️ Sampling irrégulier (écart-type relatif = {rel_std:.2%} ≥ tolérance {tol:.2%}) → Il est conseillé de resampler.")
    else:
        print(f"✅ Sampling régulier (écart-type relatif = {rel_std:.2%} < tolérance {tol:.2%}) → Pas besoin de resampler.")
    return intervals, need_resample

# Example usage: run the sampling-regularity check on every analysis window
# (each call prints its diagnostics and shows the dt plots).
dt_post2, need_resample_post2 = check_sampling_regular(bia_post_2)
dt_post1,need_resample_post1 = check_sampling_regular(bia_post_1)
dt_pre, need_resample_pre = check_sampling_regular(bia_pre)
dt_post, need_resample_post = check_sampling_regular(bia_post)