In [2]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
from statsmodels.stats.multitest import multipletests

from utils import read_metrics_file, z_score
from utils.constants import EWM_ALPHA, METRICS, WINDOW_SIZE, Events, datetimes

In [3]:
# SciencePlots "science" + "nature" base style, then every text element at 12 pt.
plt.style.use(["science", "nature"])
plt.rcParams.update(
    dict.fromkeys(
        (
            "font.size",
            "xtick.labelsize",
            "ytick.labelsize",
            "axes.labelsize",
            "legend.fontsize",
        ),
        12,
    )
)

In [None]:
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="calc_stats.log",
    filemode="w",
)

console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console.setFormatter(formatter)
logging.getLogger().addHandler(console)

logger = logging.getLogger(__name__)

### Constants

In [5]:
# Constants: event under study and derived file-name pieces.
EVENT: Events = "Forbush Decrease"
event_replace = EVENT.replace(" ", "")  # space-free name used in the data folder path

# EWM smoothing toggle; selects the "-ewm_alpha_<alpha>" file-name suffix
# only when both the toggle and a configured alpha are truthy.
EWM: bool = True
if EWM_ALPHA and EWM:
    ewm_suffix = f"-ewm_alpha_{EWM_ALPHA}"
else:
    ewm_suffix = ""

# Winsorization tail quantile: derivatives are clipped to the
# [winsorize, 1 - winsorize] quantile range.
winsorize: float = 0.01

In [6]:
summary_df_path = Path(
    f"./data/{event_replace}/summary_derivatives{ewm_suffix}.csv"
)
# Fail with the offending path; a bare `assert` gives no context and is
# stripped entirely under `python -O`.
if not summary_df_path.exists():
    raise FileNotFoundError(f"Summary file not found: {summary_df_path}")

summary_df = pd.read_csv(summary_df_path, parse_dates=["index"])
summary_df

Unnamed: 0,date,station,metric,index,lag_hours
0,2023-04-23,SOPO,entropy,2023-04-23 17:59:00,1.150000
1,2023-04-23,SOPO,sampen,2023-04-23 01:06:00,-15.733333
2,2023-04-23,SOPO,permutation_entropy,2023-04-23 12:50:00,-4.000000
3,2023-04-23,SOPO,shannon_entropy,2023-04-23 02:06:00,-14.733333
4,2023-04-23,SOPO,spectral_entropy,2023-04-24 00:50:00,8.000000
...,...,...,...,...,...
475,2024-05-10,APTY,katz_fd,2024-05-10 01:07:00,-16.900000
476,2024-05-10,APTY,petrosian_fd,2024-05-10 13:36:00,-4.416667
477,2024-05-10,APTY,lepel_ziv,2024-05-10 01:25:00,-16.600000
478,2024-05-10,APTY,corr_dim,2024-05-10 01:21:00,-16.666667


## Calculate stats

The statistics below follow the calculations in [tools/generate_results.py](./tools/generate_results.py).

In [7]:
def valid_interval(
    event: Events,
    date: str,
    station: str,
    data: pd.DataFrame = None,
) -> pd.DataFrame:
    """Keep only the rows computed with the configured window size.

    Args:
        event: Event name forwarded to ``read_metrics_file``.
        date: Event date key forwarded to ``read_metrics_file``.
        station: Station code forwarded to ``read_metrics_file``.
        data: Metrics frame to filter. When ``None`` it is loaded with
            ``read_metrics_file`` (with the EWM suffix when both ``EWM_ALPHA``
            and ``EWM`` are truthy) and indexed by its ``datetime`` column.

    Returns:
        The rows whose ``window_shape`` equals ``WINDOW_SIZE``.
    """
    frame = data
    if frame is None:
        use_ewm = EWM_ALPHA and EWM
        frame = read_metrics_file(
            event,
            date,
            station,
            WINDOW_SIZE,
            datetime_cols={"datetime": None},
            suffix=f"-ewm_alpha_{EWM_ALPHA}" if use_ewm else "",
        ).set_index("datetime")

    matches_window = frame["window_shape"] == WINDOW_SIZE
    return frame[matches_window]


def process_derivatives(
    event: Events,
    date: str,
    station: str,
    winsorize_p: float = 0.01,
) -> dict[str, pd.Series]:
    """First differences of every metric column (plus ``value``), winsorized.

    Loads the station's metrics file (with the notebook-level ``ewm_suffix``),
    keeps the rows produced with ``WINDOW_SIZE`` via ``valid_interval`` and
    differentiates each column with ``DataFrame.diff``. The first sample of
    every series is therefore NaN (it has no predecessor).

    Args:
        event, date, station: Identify the metrics file to read.
        winsorize_p: Tail quantile in (0, 0.5); each derivative series is
            clipped to its ``[winsorize_p, 1 - winsorize_p]`` quantile range.

    Returns:
        Mapping column name -> winsorized derivative series (indexed by
        ``datetime``). ``"value"`` is always included; metric columns without
        a single non-NaN derivative are omitted.

    Raises:
        AssertionError: if ``winsorize_p`` is outside (0, 0.5).
    """
    assert 0 < winsorize_p < 0.5, "Percentile must be between 0.0 and 0.5"

    data = read_metrics_file(
        event,
        date,
        station,
        WINDOW_SIZE,
        datetime_cols={"datetime": None},
        suffix=ewm_suffix,
    ).set_index("datetime")

    # First filter by valid interval and then differentiate.
    metrics_columns = [col for col in data.columns if col in METRICS]
    metrics_columns.append("value")

    valid_df = valid_interval(event, date, station, data)
    interest_df = valid_df[metrics_columns].diff()

    results: dict[str, pd.Series] = {}
    for col in metrics_columns:
        points = interest_df[col]
        # An all-NaN metric has no quantiles and no defined maximum, so it
        # cannot be analysed; "value" is kept because callers rely on it.
        if col != "value" and not points.notna().any():
            continue
        low, high = points.quantile([winsorize_p, 1 - winsorize_p])
        results[col] = np.clip(points, low, high)  # Winsorize
    return results

### Calculate new metrics

In [8]:
B = 1000  # Bootstrap replicates
block = 30  # Bootstrap block size in samples (~30 min assuming 1-min cadence — TODO confirm)
fdr_alpha = 0.05  # Benjamini–Hochberg FDR level

# Seed NumPy's global RNG: the block bootstrap draws from it, so without a
# seed every p-value changes on each Restart & Run All.
SEED = 42
np.random.seed(SEED)

# Output locations, created if missing.
out_dir = Path("outputs")
fig_out_dir = Path("figures")
table_out_dir = Path("tables")
for directory in (out_dir, fig_out_dir, table_out_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [9]:
def moving_block_bootstrap(x, block, B, rng=None):
    """Circular moving-block bootstrap of a 1-D series.

    Each replicate concatenates blocks of ``block`` consecutive samples whose
    start positions are drawn uniformly from ``[0, n)``; a block running past
    the end wraps around to the beginning. The result is truncated to ``n``.

    Args:
        x: 1-D array-like of length ``n``.
        block: Block length in samples (must be >= 1).
        B: Number of bootstrap replicates.
        rng: Optional ``numpy.random.Generator`` or ``RandomState``. When
            ``None`` the legacy global NumPy RNG is used, so ``np.random.seed``
            controls reproducibility.

    Returns:
        Float array of shape ``(B, n)``; ``(B, 0)`` when ``x`` is empty.

    Raises:
        ValueError: if ``block < 1`` (no sample could ever be drawn).
    """
    values = np.asarray(x, dtype=float)
    n = len(values)
    if n == 0:
        return np.empty((B, 0))
    if block < 1:
        raise ValueError(f"block must be >= 1, got {block}")

    if rng is None:
        draw = np.random.randint
    elif hasattr(rng, "integers"):  # numpy.random.Generator
        draw = rng.integers
    else:  # numpy.random.RandomState
        draw = rng.randint

    # Only the first n samples are kept, so a block longer than n behaves like
    # a single wrapped block of length n.
    seg_len = min(block, n)
    n_blocks = -(-n // seg_len)  # ceil(n / seg_len)
    starts = draw(0, n, size=(B, n_blocks))
    positions = (starts[:, :, None] + np.arange(seg_len)) % n
    positions = positions.reshape(B, -1)[:, :n]
    return values[positions]


def safe_corr_at_lag(a, b, lag, normalize=True):
    """Pearson correlation of ``a[t]`` with ``b[t + lag]`` over their overlap.

    Args:
        a, b: Equal-length 1-D array-likes.
        lag: Shift in samples (cast to ``int``); NaN yields NaN.
        normalize: Pass both overlapping windows through ``z_score`` before
            correlating.

    Returns:
        The correlation as a float, or NaN when ``lag`` is NaN, when fewer
        than 5 pairs with finite values on both sides overlap, or when either
        side is constant over those pairs.
    """
    if np.isnan(lag):
        return np.nan
    lag = int(lag)

    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    if lag > 0:
        a2, b2 = a[:-lag], b[lag:]
    elif lag < 0:
        a2, b2 = a[-lag:], b[:lag]
    else:
        a2, b2 = a, b

    # Drop pairs with NaN/inf on either side (e.g. the leading NaN produced
    # by `DataFrame.diff`) instead of letting one value poison the result.
    finite = np.isfinite(a2) & np.isfinite(b2)
    a2, b2 = a2[finite], b2[finite]
    if len(a2) < 5:
        return np.nan
    # Correlation is undefined for a constant side.
    if np.ptp(a2) == 0 or np.ptp(b2) == 0:
        return np.nan

    if normalize:
        # Normalized after masking so NaNs cannot leak into the statistics.
        # NOTE(review): assumes z_score is an affine standardization, which
        # leaves Pearson r unchanged — confirm in utils.
        a2 = np.asarray(z_score(a2), dtype=float)
        b2 = np.asarray(z_score(b2), dtype=float)

    return float(np.corrcoef(a2, b2)[0, 1])

In [10]:
def _bootstrap_pvalue(dy_values, dx_values, lag, corr_obs):
    """One-sided moving-block-bootstrap p-value for an observed lagged correlation.

    ``dy_values`` is resampled with ``moving_block_bootstrap`` (notebook-level
    ``block`` and ``B``). The lagged correlation against ``dx_values`` is
    recomputed for every replicate. The p-value is the fraction of finite
    replicate correlations at least as extreme as ``corr_obs`` in its own
    direction: ``>=`` when ``corr_obs >= 0``, ``<=`` otherwise.

    Returns 1.0 when ``corr_obs`` is NaN or no replicate correlation is finite.
    """
    # Resample unconditionally so the global RNG advances identically for
    # every metric, keeping later p-values reproducible.
    replicates = moving_block_bootstrap(dy_values, block, B)
    if np.isnan(corr_obs):
        return 1.0

    bs_corrs = np.array(
        [safe_corr_at_lag(row, dx_values, lag) for row in replicates],
        dtype=float,
    )
    # A NaN replicate carries no evidence; a plain comparison would count it
    # as "not extreme" and bias the p-value downwards.
    bs_corrs = bs_corrs[np.isfinite(bs_corrs)]
    if bs_corrs.size == 0:
        return 1.0
    if corr_obs >= 0:
        return float(np.mean(bs_corrs >= corr_obs))
    return float(np.mean(bs_corrs <= corr_obs))


def _apply_fdr(res: pd.DataFrame) -> pd.DataFrame:
    """Add Benjamini–Hochberg ``pval_adj`` / ``significant`` columns to ``res``.

    Each invariant is treated as its own family of tests (all stations of the
    event). ``res`` is modified in place and returned.
    """
    res["pval_adj"] = np.nan
    res["significant"] = False
    for inv in res["invariant"].unique():
        mask = res["invariant"] == inv
        rej, p_adj, _, _ = multipletests(
            res.loc[mask, "pval"].values, alpha=fdr_alpha, method="fdr_bh"
        )
        res.loc[mask, "pval_adj"] = p_adj
        res.loc[mask, "significant"] = rej.astype(bool)
    return res


def process_event(date: str) -> pd.DataFrame:
    """Per-station, per-invariant lead time, lagged correlation and p-value.

    For every station listed for ``date`` in ``summary_df``, the peak of each
    metric's derivative is located. Its lead relative to the catalogued onset
    is recorded, along with the lagged correlation with the ``value``
    derivative and a block-bootstrap p-value.

    Returns:
        One row per (station, invariant) with columns ``date``, ``station``,
        ``invariant``, ``lag_minutes``, ``lag_start_idx``, ``lag_start_dt``,
        ``corr_at_lag`` and ``pval``, plus ``pval_adj``/``significant``
        (BH-FDR within each invariant). An empty DataFrame when there are no
        rows.
    """

    def get_index_of_datetime(series: pd.Series, dt: pd.Timestamp) -> int:
        """Position of the sample whose timestamp is closest to ``dt`` (first on ties)."""
        time_diffs = pd.Series(series.index - dt).abs()
        return int(time_diffs.idxmin())

    stations = (
        summary_df.loc[summary_df["date"] == date, "station"].unique().tolist()
    )

    rows = []
    for station in stations:
        logger.info("  Processing station %s...", station)
        derivatives = process_derivatives(EVENT, date, station, winsorize)

        dx = derivatives.pop("value")
        original_onset_dt = pd.to_datetime(
            datetimes[date]["stations"][station][0]
        )
        # NOTE(review): this index is not used below. The lag given to
        # `safe_corr_at_lag` is the absolute position of the metric peak, not
        # an offset from the onset. Confirm against tools/generate_results.py
        # whether an onset-relative lag was intended.
        original_onset_idx = get_index_of_datetime(dx, original_onset_dt)

        for metric, dy in derivatives.items():
            logger.info("    Processing metric %s...", metric)
            # Lead of the derivative peak relative to the onset (same values
            # as in summary_df).
            lag_dt = dy.idxmax()
            lag_minutes = (lag_dt - original_onset_dt).total_seconds() / 60.0
            lag_idx = get_index_of_datetime(dy, lag_dt)

            # Same bootstrap as `tools/generate_results.py`.
            corr_obs = safe_corr_at_lag(dy.values, dx.values, lag_idx)
            pval = _bootstrap_pvalue(dy.values, dx.values, lag_idx, corr_obs)

            rows.append(
                {
                    "date": date,
                    "station": station.upper(),
                    "invariant": metric,
                    "lag_minutes": lag_minutes,
                    "lag_start_idx": lag_idx,
                    "lag_start_dt": lag_dt,
                    "corr_at_lag": corr_obs,
                    "pval": pval,
                }
            )

    res = pd.DataFrame(rows)
    if res.empty:
        return res
    # FDR per invariant (global across the event's stations).
    return _apply_fdr(res)

### Summarize results

In [11]:
def summarize_global(df, lref=120.0, weights=(0.4, 0.3, 0.3)):
    """Aggregate station-level results per invariant across all events and rank them.

    Args:
        df: Station-level results with ``invariant``, ``lag_minutes``,
            ``significant`` and ``station`` columns.
        lref: Reference lag in minutes used to normalize the lag terms.
        weights: ``(w1, w2, w3)`` weights of the Rk terms: lag magnitude
            ``|median_lag| / lref``, lag consistency ``1 - iqr_lag / lref``
            (both clipped to [0, 1]) and station coverage
            ``n_stations / total stations``.

    Returns:
        One row per invariant with ``median_lag``, ``iqr_lag``, ``sig_pct``,
        ``n_stations`` and ``Rk``, sorted by ``Rk`` descending.
    """
    # Aggregate by invariant across events/stations.
    g = (
        df.groupby("invariant")
        .agg(
            median_lag=("lag_minutes", "median"),
            iqr_lag=(
                "lag_minutes",
                lambda x: np.subtract(*np.nanpercentile(x, [75, 25])),
            ),
            sig_pct=("significant", lambda x: 100.0 * np.mean(x)),
            n_stations=("station", "nunique"),
        )
        .reset_index()
    )

    # Robustness score Rk.
    w1, w2, w3 = weights
    n_total_stations = df["station"].nunique()
    g["Rk"] = (
        w1 * (np.abs(g["median_lag"]) / lref).clip(0, 1)
        + w2 * (1 - (g["iqr_lag"] / lref).clip(0, 1))
        + w3 * (g["n_stations"] / max(n_total_stations, 1))
    )
    return g.sort_values("Rk", ascending=False).reset_index(drop=True)


def summarize_by_event(df):
    """Per-event, per-invariant median lag and percentage of significant stations.

    Args:
        df: Station-level results with ``date``, ``invariant``,
            ``lag_minutes`` and ``significant`` columns.

    Returns:
        DataFrame with columns ``date``, ``invariant``, ``median_lag`` and
        ``sig_pct``, one row per (date, invariant) group.
    """

    def percent_true(flags):
        return 100.0 * np.mean(flags)

    grouped = df.groupby(["date", "invariant"])
    summary = grouped.agg(
        median_lag=pd.NamedAgg(column="lag_minutes", aggfunc="median"),
        sig_pct=pd.NamedAgg(column="significant", aggfunc=percent_true),
    )
    return summary.reset_index()


def tex_rank_table(g):
    """Render the global ranking (output of ``summarize_global``) as a LaTeX table.

    Args:
        g: DataFrame with ``invariant``, ``median_lag``, ``iqr_lag``,
            ``sig_pct`` and ``Rk`` columns; rows are emitted in their current
            order.

    Returns:
        LaTeX source for a booktabs ``table`` environment. Missing numeric
        values are rendered as ``--``.
    """

    def fmt(value, spec):
        # NaN would otherwise print as the literal text "nan".
        return "--" if pd.isna(value) else format(value, spec)

    rows = []
    for _, r in g.iterrows():
        # Escape outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        name = r["invariant"].replace("_", r"\_")
        rows.append(
            f"{name} & {fmt(r['median_lag'], '.1f')} & {fmt(r['iqr_lag'], '.1f')}"
            f" & {fmt(r['sig_pct'], '.0f')} & {fmt(r['Rk'], '.2f')} \\\\"
        )
    body = "\n".join(rows)

    # This string is not %-formatted, so the percent sign is written as `\%`
    # (a literal `\%%` would start a LaTeX comment and drop the row's end).
    header = r"""
\begin{table}[t]
\centering
\small
\caption{Global ranking of invariants by robustness score $R_k$ and median lead $\widetilde{\ell}_k$ (min; negative = precedes).}
\label{tab:rank_global}
\begin{tabular}{@{}l r r r r@{}}
\toprule
\textbf{Invariant} & $\widetilde{\ell}_k$ & IQR & Sig.\ stations [\%] & $R_k$ \\
\midrule
"""
    footer = r"""
\bottomrule
\end{tabular}
\end{table}
"""
    return header + body + footer


def tex_event_table(e, events):
    """Render per-event medians (output of ``summarize_by_event``) as a LaTeX table.

    One row per invariant and one column pair (median lead, % significant
    stations) per event, in ``events`` order. Missing values are shown as
    ``--``.

    Args:
        e: DataFrame with ``date``, ``invariant``, ``median_lag`` and
            ``sig_pct`` columns.
        events: Event dates to include, in column order.

    Returns:
        LaTeX source for a booktabs ``table`` environment, or a "No data"
        stub when ``events`` is empty.
    """

    def fmt(value, spec):
        # NaN would otherwise print as the literal text "nan".
        return "--" if pd.isna(value) else format(value, spec)

    # Pivot invariant x event into (lag, sig) column pairs.
    frames = []
    for ev in events:
        sub = e[e["date"] == ev][["invariant", "median_lag", "sig_pct"]].copy()
        sub.columns = ["invariant", f"lag_{ev}", f"sig_{ev}"]
        frames.append(sub)
    if not frames:
        return r"\begin{table}[t]\centering\small\caption{No data}\label{tab:event_summary}\begin{tabular}{@{}l@{}}\toprule No data\\ \bottomrule\end{tabular}\end{table}"

    M = frames[0]
    for sub in frames[1:]:
        M = M.merge(sub, on="invariant", how="outer")

    # Table rows.
    rows_list = []
    for _, r in M.iterrows():
        parts = [r["invariant"].replace("_", r"\_")]
        for ev in events:
            lag = fmt(r.get(f"lag_{ev}", np.nan), ".1f")
            sig = fmt(r.get(f"sig_{ev}", np.nan), ".0f")
            parts.append(f"{lag} & {sig}")
        rows_list.append(" & ".join(parts) + r" \\")
    rows = "\n".join(rows_list)

    # Headers and column specification.
    ev_heads = " & ".join(
        [r"\multicolumn{2}{c}{\textbf{%s}}" % ev for ev in events]
    )
    # Inserted via `%s` below, so it is NOT unescaped by the %-formatting:
    # it must contain `\%` (a `\%%` would start a LaTeX comment).
    ev_sub = " & ".join([r"$\widetilde{\ell}_k$ & Sig.\ [\%]"] * len(events))
    colspec = "@{}l " + " ".join(["r r"] * len(events)) + "@{}"
    endcol = 2 * len(events) + 1  # last column for \cmidrule(lr){2-endcol}

    # %-style template (does not clash with LaTeX braces).
    tex = r"""\begin{table}[t]
\centering
\small
\caption{Per–event summary: median lead $\widetilde{\ell}_k$ (min; negative = precedes) and percent of stations with significant pre–onset change.}
\label{tab:event_summary}
\begin{tabular}{%s}
\toprule
 & %s \\
\cmidrule(lr){2-%d}
\textbf{Invariant} & %s \\
\midrule
%s
\bottomrule
\end{tabular}
\end{table}
""" % (colspec, ev_heads, endcol, ev_sub, rows)
    return tex

In [12]:
def heatmap_median_lead(e, events, outpath):
    """Heatmap of per-event median lead (in hours) for every invariant.

    Args:
        e: Output of ``summarize_by_event`` (``median_lag`` in minutes).
        events: Event dates, in column order.
        outpath: Destination image file; parent directories are created.

    Nothing is written when there is no data to plot.
    """
    piv = e.pivot(
        index="invariant", columns="date", values="median_lag"
    ).reindex(columns=events)
    if piv.empty:
        return

    # Minutes -> hours for readability.
    piv = piv / 60.0
    values = piv.values

    # Signed quantity: use a diverging colormap with limits symmetric about
    # zero so leads (negative) and lags (positive) are distinguishable.
    magnitudes = np.abs(values[np.isfinite(values)])
    vmax = magnitudes.max() if magnitudes.size and magnitudes.max() > 0 else 1.0

    fig, ax = plt.subplots(figsize=(10, max(4, 0.35 * len(piv))))
    im = ax.imshow(values, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
    ax.set_yticks(range(len(piv)))
    ax.set_yticklabels([s.replace("_", " ") for s in piv.index])
    ax.set_xticks(range(len(events)))
    ax.set_xticklabels(events)
    ax.set_xlabel("Event date")
    ax.set_title(
        "Median lead by invariant and event (hours; negative = precedes)"
    )
    for i, j in np.ndindex(values.shape):
        v = values[i, j]
        if pd.notna(v):
            # One decimal: whole hours would round most leads to 0.
            ax.text(j, i, f"{v:.1f}", ha="center", va="center")
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Median lead (hours)")
    fig.tight_layout()
    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(outpath, dpi=180)
    plt.close(fig)


def violin_lags(df, date, outpath):
    """Horizontal violin plot of station-wise lags (hours) per invariant for one event.

    Args:
        df: Station-level results with ``date``, ``invariant`` and
            ``lag_minutes`` columns.
        date: Event date to plot.
        outpath: Destination file; parent directories are created.

    Invariants without any non-NaN lag are left out, because matplotlib cannot
    draw an empty violin. Nothing is written when no invariant remains.
    """
    sub = df[df["date"] == date]
    if sub.empty:
        return

    labels, data = [], []
    for inv in sub["invariant"].unique():
        lags = sub.loc[sub["invariant"] == inv, "lag_minutes"].dropna().to_numpy()
        if lags.size == 0:
            continue
        labels.append(inv.replace("_", " "))
        data.append(lags / 60.0)  # minutes -> hours for readability
    if not data:
        return

    fig, ax = plt.subplots(figsize=(10, max(4, 0.35 * len(labels))))
    _ = ax.violinplot(data, showmedians=True, vert=False)
    # lag 0 = catalogued onset; violins left of it precede the onset.
    ax.axvline(0.0, color="grey", linewidth=0.8, linestyle="--")
    ax.set_yticks(np.arange(1, len(labels) + 1))
    ax.set_yticklabels(labels)
    ax.set_xlabel("lag* (hours; negative = precedes)")
    ax.set_title(f"Station-wise lag distributions by invariant — {date}")
    fig.tight_layout()
    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(outpath, dpi=180)
    plt.close(fig)

## Main

In [13]:
def main():
    """Run the full pipeline for every event in ``datetimes`` and write all artifacts.

    Per event, writes a station-level CSV to ``out_dir`` and a violin plot
    under ``fig_out_dir / "violins"``. Across events, writes:
      - global-ranking and per-event summary CSVs;
      - two LaTeX tables in ``table_out_dir``;
      - a heatmap under ``fig_out_dir / "heatmaps"``.

    Returns:
        0 on success, 1 when no event produced any results.
    """
    events = list(datetimes.keys())
    per_event_results = []

    for date in events:
        logger.info(f"Processing event {date}")
        station_results = process_event(date)
        if station_results.empty:
            logger.warning(f"No results for {date}")
            continue

        station_results.to_csv(
            out_dir / f"station_results_{date}.csv", index=False
        )
        per_event_results.append(station_results)

        # Per-event figure: lag distribution per invariant.
        violin_path = fig_out_dir / "violins" / f"lag_violin_{date}.pdf"
        violin_lags(station_results, date, violin_path)

    if not per_event_results:
        logger.error("No events processed. Check column names in your CSVs.")
        return 1

    combined = pd.concat(per_event_results, ignore_index=True)
    global_rank = summarize_global(combined)
    event_summary = summarize_by_event(combined)

    # CSV summaries.
    global_rank.to_csv(out_dir / "global_rank.csv", index=False)
    event_summary.to_csv(out_dir / "event_summary.csv", index=False)

    # LaTeX tables.
    rank_tex = tex_rank_table(global_rank)
    (table_out_dir / "rank_global.tex").write_text(rank_tex)
    event_tex = tex_event_table(event_summary, events)
    (table_out_dir / "event_summary.tex").write_text(event_tex)

    # Global heatmap of median leads.
    heatmap_path = fig_out_dir / "heatmaps" / "median_lead_heatmap.png"
    heatmap_median_lead(event_summary, events, heatmap_path)

    logger.info("[OK] Results written to outputs/, tables/, figures/")
    return 0

In [None]:
# Run the whole pipeline. A non-zero status is surfaced as a visible cell
# error instead of being silently discarded; details are in the log file.
exit_code = main()
if exit_code != 0:
    raise RuntimeError(
        f"main() returned exit code {exit_code}; see calc_stats.log"
    )
exit_code