In [1]:
import time
from pathlib import Path
import math
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import cartopy.crs as ccrs
from scipy import stats

In [2]:
# Load data
data_path = Path("../Data")
idx = pd.read_csv(data_path / "neuron_indices.csv")
u850_anom = xr.open_dataset(data_path / "uwnd850_anom.nc")
olr_anom = xr.open_dataset(data_path / "olr_anom2.nc")
sst_anom = xr.open_dataset(data_path / "sst_anom.nc")

# Create summer date index (1991-2023, JJA)
full_range = pd.date_range(start="1991-06-01", end="2023-08-31", freq="D")
summer_dates = full_range[full_range.month.isin([6, 7, 8])]

# The CSV carries no dates: row i is assumed to be the i-th JJA day of
# 1991-2023, and it must hold exactly one column (the cluster label).
# Check the shape up front so a mismatch names the cause instead of surfacing
# as pandas' generic "Length mismatch" error on the assignments below.
expected_shape = (len(summer_dates), 1)
if idx.shape != expected_shape:
    raise ValueError(
        f"neuron_indices.csv has shape {idx.shape}; expected {expected_shape} "
        "(one cluster label per JJA day, 1991-2023)"
    )
idx.index = summer_dates
idx.columns = ["cluster"]

# Get clusters info
unique_clusters = idx["cluster"].unique()

print(f"Data loaded: {idx.shape[0]} days, {len(unique_clusters)} clusters")
# .tolist() yields plain Python scalars, so the list prints as [1, 2, ...]
# instead of [np.int64(1), np.int64(2), ...].
print(f"Clusters: {sorted(unique_clusters.tolist())}")
print("\nFirst few cluster assignments:")
print(idx.head())

Data loaded: 3036 days, 9 clusters
Clusters: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9)]

First few cluster assignments:
            cluster
1991-06-01        7
1991-06-02        7
1991-06-03        7
1991-06-04        7
1991-06-05        7


In [3]:
# %%
# Configuration for plot settings
class PlotConfig:
    """Output settings shared by every figure this notebook saves."""

    # Folder (relative to the notebook) that receives the saved figures.
    RESULTS_DIR = Path("..") / "Results"
    # Resolution in dots per inch passed to savefig.
    DPI = 300


def _composite_and_pvalue(selection):
    """
    Return the time-mean composite of ``selection`` and its two-sided p-value.

    The p-value comes from a one-sample t-test of "mean == 0" at every grid
    point: ``t = mean / (std / sqrt(n))`` with ``n - 1`` degrees of freedom.
    ``n`` is counted per grid point from the non-NaN samples. xarray's
    ``mean``/``std`` skip NaNs by default for float data, so grid points with
    missing values use their true sample size. Points with fewer than two
    valid samples get a NaN p-value.

    Parameters
    ----------
    selection : xarray.DataArray
        Samples to composite, with a ``time`` dimension.

    Returns
    -------
    (xarray.DataArray, xarray.DataArray)
        The composite mean and the p-values, on the same non-time coordinates.
    """
    comp_mean = selection.mean(dim="time")
    comp_std = selection.std(dim="time", ddof=1)
    # Per-grid-point count of non-missing samples.
    n_valid = selection.count(dim="time")

    # Standard error and t-statistic. A zero standard error gives +/-inf
    # (p = 0) for a non-zero mean, or NaN for 0/0.
    se = comp_std / np.sqrt(n_valid)
    t_stat = (comp_mean / se).transpose(*comp_mean.dims)
    dof = (n_valid - 1).transpose(*comp_mean.dims)

    # Two-tailed p-value via the survival function (sf = 1 - cdf). scipy
    # returns NaN for dof <= 0, i.e. grid points with fewer than 2 samples.
    pval = 2 * stats.t.sf(np.abs(t_stat.values), df=dof.values)

    # Wrap back into a DataArray so the p-values keep the composite's coordinates.
    pval_da = xr.DataArray(pval, coords=comp_mean.coords, dims=comp_mean.dims)
    return comp_mean, pval_da


def compute_lead_lag_composites(data_dict, idx, lag_periods):
    """
    Compute lead-lag composites AND t-test p-values for all clusters and lags.

    For each cluster, the dates assigned to it in ``idx`` are shifted by each
    lag. The shifted dates present in a variable's own ``time`` axis are then
    composited, and the composite is tested against zero.

    Parameters
    ----------
    data_dict : dict
        Variable name -> anomaly DataArray with a ``time`` dimension. The
        variables do not need to share the same time axis.
    idx : pandas.DataFrame
        Date-indexed frame with a ``cluster`` column.
    lag_periods : dict
        Lag label -> offset in days. The field date used is
        ``cluster date + offset``, so negative offsets take the field before
        the cluster day.

    Returns
    -------
    composites, p_values : dict
        Both nested as ``[var_name][cluster][lag_name]``. An entry is ``None``
        when fewer than two shifted dates exist in that variable's time axis,
        because a sample variance needs at least two samples.

    Notes
    -----
    Every selected day is treated as an independent sample.
    NOTE(review): if the daily anomalies are serially correlated (common for
    daily climate fields), these p-values will be optimistic; consider an
    effective-sample-size correction.
    """
    unique_clusters = idx["cluster"].unique()
    composites = {var: {} for var in data_dict}
    p_values = {var: {} for var in data_dict}

    # Usable shifted dates are decided per variable. Selecting a date that is
    # missing from a variable's time axis would raise a KeyError in .sel.
    time_axes = {var: data.time.values for var, data in data_dict.items()}

    start_time = time.perf_counter()

    for cluster in sorted(unique_clusters):
        print(f"\nProcessing Cluster {cluster}...")
        cluster_dates = idx[idx["cluster"] == cluster].index
        print(f"  Days: {len(cluster_dates)}")

        for var in data_dict:
            composites[var][cluster] = {}
            p_values[var][cluster] = {}

        for lag_name, lag_days in lag_periods.items():
            lagged_dates = cluster_dates + pd.Timedelta(days=lag_days)
            lagged_values = lagged_dates.to_numpy()

            for var_name, data in data_dict.items():
                valid_dates = lagged_dates[
                    np.isin(lagged_values, time_axes[var_name])
                ]

                if len(valid_dates) > 1:
                    comp_mean, pval_da = _composite_and_pvalue(
                        data.sel(time=valid_dates)
                    )
                else:
                    comp_mean, pval_da = None, None

                composites[var_name][cluster][lag_name] = comp_mean
                p_values[var_name][cluster][lag_name] = pval_da

    elapsed = time.perf_counter() - start_time
    print(
        f"\n✓ Composites & Stats computed in {elapsed:.2f}s ({elapsed / 60:.2f}min)"
    )
    return composites, p_values


print("✓ Functions loaded (with significance test)")

✓ Functions loaded (with significance test)


In [4]:
# %%
# Lag-name suffix (the text after "lead_"/"lag_") -> human-readable period.
SUFFIX_MAP = dict(
    [
        ("1y", "1-year"),
        ("6m", "6-month"),
        ("3m", "3-month"),
        ("2m", "2-month"),
        ("1m", "1-month"),
        ("15d", "15-day"),
        ("10d", "10-day"),
        ("5d", "5-day"),
    ]
)


def pretty_lag_label(lag_name):
    """Turn a lag key such as ``"lead_15d"`` into a label like ``"Lead 15-day"``.

    ``"0d"`` is returned unchanged. Names starting with ``"lead_"`` are
    labelled "Lead"; every other name is labelled "Lag". The text after the
    first underscore (or the whole name when there is none) is translated via
    ``SUFFIX_MAP``, falling back to the raw text when it is not listed.
    """
    if lag_name == "0d":
        return lag_name
    _, underscore, remainder = lag_name.partition("_")
    period = remainder if underscore else lag_name
    direction = "Lead" if lag_name.startswith("lead_") else "Lag"
    return direction + " " + SUFFIX_MAP.get(period, period)

In [5]:
# %%
def _draw_lag_panel(ax, comp, pval, title, cmap, vmin, vmax, sig_level):
    """
    Draw one lag panel on a cartopy GeoAxes and return the shading artist.

    When ``comp`` is ``None`` the panel just says "No data" and ``None`` is
    returned.
    """
    if comp is None:
        ax.text(
            0.5,
            0.5,
            "No data",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        ax.set_title(title)
        return None

    # 1. Shading: the composite mean on fixed colour limits.
    im = comp.plot(
        ax=ax,
        cmap=cmap,
        add_colorbar=False,
        vmin=vmin,
        vmax=vmax,
        transform=ccrs.PlateCarree(),
    )

    # 2. Stippling: hatch the [0, sig_level] p-value band with dots. The
    #    (sig_level, 1] band gets no hatch, and colors="none" leaves both
    #    unfilled so the shading stays visible.
    # NOTE(review): assumes 1-D ``lon``/``lat`` coordinates with ``pval`` laid
    # out as (lat, lon); confirm for every variable plotted.
    if pval is not None:
        ax.contourf(
            comp.lon,
            comp.lat,
            pval,
            levels=[0, sig_level, 1],
            hatches=["..", ""],
            colors="none",
            transform=ccrs.PlateCarree(),
        )

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.coastlines()
    return im


def plot_variable_lead_lag(
    composites,
    p_values,
    variable,
    unique_clusters,
    lag_periods,
    vmin,
    vmax,
    cbar_label,
    cmap="RdBu_r",
    sig_level=0.05,
    ncols=3,
):
    """
    Save one multi-panel lead-lag figure per cluster, stippled where significant.

    Parameters
    ----------
    composites, p_values : dict
        Nested ``[variable][cluster][lag_name]`` dicts as returned by
        ``compute_lead_lag_composites``. Entries may be ``None``; those panels
        read "No data".
    variable : str
        Key into ``composites`` / ``p_values``; also used in the title and
        the file name.
    unique_clusters : iterable
        Cluster labels to plot; one figure is written per cluster, in sorted
        order.
    lag_periods : dict
        Lag label -> offset in days; one panel per entry, in insertion order.
    vmin, vmax : float
        Fixed colour limits shared by all panels, so the single colorbar
        applies to every panel.
    cbar_label : str
        Colorbar label.
    cmap : str, optional
        Colormap for the shading.
    sig_level : float, optional
        Grid points whose p-value is at or below ``sig_level`` are stippled.
        matplotlib's contourf intervals are closed at the top. Must lie in
        (0, 1).
    ncols : int, optional
        Number of panel columns (default 3, the original layout).

    Figures are saved to ``PlotConfig.RESULTS_DIR`` as
    ``{variable}_pentad_lead_lag_stippled_cluster{cluster}.png`` and then
    closed, even if drawing fails.

    Raises
    ------
    ValueError
        If ``lag_periods`` is empty, ``sig_level`` is not in (0, 1), or
        ``ncols`` is smaller than 1.
    """
    if not lag_periods:
        raise ValueError("lag_periods must contain at least one lag")
    if not 0 < sig_level < 1:
        raise ValueError(f"sig_level must be between 0 and 1, got {sig_level}")
    if ncols < 1:
        raise ValueError(f"ncols must be at least 1, got {ncols}")

    PlotConfig.RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    nrows = math.ceil(len(lag_periods) / ncols)

    for cluster in sorted(unique_clusters):
        fig, axes = plt.subplots(
            nrows,
            ncols,
            figsize=(ncols * 6, nrows * 5),
            subplot_kw={"projection": ccrs.PlateCarree(central_longitude=180)},
            layout="constrained",
        )
        try:
            axes = np.atleast_1d(axes).ravel()
            fig.suptitle(
                f"{variable.upper()} Lead-Lag Composites - Cluster {cluster}",
                fontsize=16,
                fontweight="bold",
            )

            # Keep the last drawn shading artist for the shared colorbar.
            im = None
            for ax, (lag_name, lag_days) in zip(axes, lag_periods.items()):
                title = f"{pretty_lag_label(lag_name)} ({abs(lag_days)}d)"
                panel_im = _draw_lag_panel(
                    ax,
                    composites[variable][cluster][lag_name],
                    p_values[variable][cluster][lag_name],
                    title,
                    cmap,
                    vmin,
                    vmax,
                    sig_level,
                )
                if panel_im is not None:
                    im = panel_im

            # Hide grid cells left over when len(lag_periods) % ncols != 0.
            for extra_ax in axes[len(lag_periods) :]:
                extra_ax.set_visible(False)

            if im is not None:
                fig.colorbar(
                    im,
                    ax=axes,
                    orientation="horizontal",
                    pad=0.05,
                    label=cbar_label,
                    shrink=0.5,
                )

            outfile = (
                PlotConfig.RESULTS_DIR
                / f"{variable}_pentad_lead_lag_stippled_cluster{cluster}.png"
            )
            fig.savefig(outfile, dpi=PlotConfig.DPI)
            print(f"Saved: {outfile}")
        finally:
            # Always release the figure so a loop over many clusters does not
            # accumulate open figures, even when drawing raises.
            plt.close(fig)

In [6]:
# %%
# Define analysis parameters
# Lag label -> day offset added to each cluster date (negative = lead,
# positive = lag), ordered from the longest lead to the longest lag.
short_lag_periods = {
    **{f"lead_{days}d": -days for days in (15, 10, 5)},
    "0d": 0,
    **{f"lag_{days}d": days for days in (5, 10, 15)},
}

In [7]:
# Prepare data dictionary
# Variable key -> anomaly DataArray; the keys are the names used when plotting.
data_dict = dict(
    u850=u850_anom["uwnd"],
    olr=olr_anom["olr"],
    sst=sst_anom["anom"],
)

n_lags, n_vars = len(short_lag_periods), len(data_dict)
print(f"Analysis period: {n_lags} lags, {n_vars} variables")

# %%
# Compute composites AND p-values.
# The function returns two nested dicts: (composites, p_values).
short_composites, short_pvalues = compute_lead_lag_composites(
    data_dict=data_dict,
    idx=idx,
    lag_periods=short_lag_periods,
)

# %% [markdown]
# ### U850 Lead-Lag (with Stippling)

# %%
# Plot U850; the p-values drive the significance stippling.
plot_variable_lead_lag(
    composites=short_composites,
    p_values=short_pvalues,
    variable="u850",
    unique_clusters=unique_clusters,
    lag_periods=short_lag_periods,
    vmin=-3,
    vmax=3,
    cbar_label="U850 Anomaly (m/s)",
)

# %% [markdown]
# ### OLR Lead-Lag (with Stippling)

# %%
# Plot OLR
plot_variable_lead_lag(
    composites=short_composites,
    p_values=short_pvalues,
    variable="olr",
    unique_clusters=unique_clusters,
    lag_periods=short_lag_periods,
    vmin=-20,
    vmax=20,
    cbar_label="OLR Anomaly (W/m²)",
)

Analysis period: 7 lags, 3 variables

Processing Cluster 1...
  Days: 381

Processing Cluster 2...
  Days: 375

Processing Cluster 3...
  Days: 318

Processing Cluster 4...
  Days: 311

Processing Cluster 5...
  Days: 312

Processing Cluster 6...
  Days: 281

Processing Cluster 7...
  Days: 343

Processing Cluster 8...
  Days: 375

Processing Cluster 9...
  Days: 340

✓ Composites & Stats computed in 177.89s (2.96min)
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster1.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster2.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster3.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster4.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster5.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster6.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster7.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster8.png
Saved: ..\Results\u850_pentad_lead_lag_stippled_cluster9.png
Saved: ..\Results\olr_pentad_