In [1]:
from pathlib import Path
import math
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import cartopy.crs as ccrs
from scipy import stats  # [추가] 통계 검정용

In [2]:
# All input files live in one directory; keep it in a single constant so the
# notebook can be pointed at another copy of the data by editing one line.
DATA_DIR = Path("../Data")

# Cluster label per day; the date index and column name are assigned in the next cell.
idx = pd.read_csv(DATA_DIR / "neuron_indices.csv")
u850_anom = xr.open_dataset(DATA_DIR / "uwnd850_anom.nc")
olr_anom = xr.open_dataset(DATA_DIR / "olr_anom2.nc")
sst_anom = xr.open_dataset(DATA_DIR / "sst_anom.nc")

In [3]:
# Daily JJA (June-August) calendar for 1991-2023: 33 summers x 92 days = 3036 days.
full_range = pd.date_range('1991-06-01', '2023-08-31', freq='D')
summer_dates = full_range[(full_range.month >= 6) & (full_range.month <= 8)]

# Label each CSV row with its date and name the single column "cluster".
# Assumes the CSV rows are in chronological JJA order; set_axis raises a
# ValueError if the row count does not match the number of summer dates.
idx = idx.set_axis(summer_dates, axis=0).set_axis(["cluster"], axis=1)

# Quick visual check of the relabelled frame.
print("--- Result ---", idx.head(), sep="\n")
print(f"\nShape of idx: {idx.shape}")

--- Result ---
            cluster
1991-06-01        7
1991-06-02        7
1991-06-03        7
1991-06-04        7
1991-06-05        7

Shape of idx: (3036, 1)


In [4]:
# Check which cluster labels occur; `unique_clusters` is reused by the plotting cells.
unique_clusters = idx['cluster'].unique()
# .tolist() turns NumPy scalars into plain ints, so the list prints as [1, 2, ...]
# instead of [np.int64(1), np.int64(2), ...] under NumPy 2.
print(f"Unique clusters: {sorted(unique_clusters.tolist())}")
print(f"Number of clusters: {len(unique_clusters)}")

Unique clusters: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9)]
Number of clusters: 9


In [5]:
def _one_sample_ttest(selection, dim='time'):
    """Return the mean of ``selection`` along ``dim`` and its two-sided
    one-sample t-test p-value against zero.

    The sample size is counted per grid point with ``count`` (NaNs excluded),
    so points with missing values are tested with their own degrees of freedom
    rather than with the total number of selected dates.

    Returns
    -------
    comp_mean : xarray.DataArray
        Mean over ``dim``.
    pval_da : xarray.DataArray
        P-values with the same dims/coords as ``comp_mean``. NaN where fewer
        than two valid samples exist or where mean and spread are both zero;
        0 where the spread is zero but the mean is not.
    """
    comp_mean = selection.mean(dim=dim)
    comp_std = selection.std(dim=dim, ddof=1)
    n_valid = selection.count(dim=dim)

    # Zero spread or zero samples produce inf/NaN t-statistics; silence the
    # floating-point warnings and let them propagate to p = 0 / NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        t_stat = comp_mean / (comp_std / np.sqrt(n_valid))
        # Degrees of freedom must be >= 1; mask the rest so scipy returns NaN.
        dof = (n_valid - 1).where(n_valid > 1)
        pval = 2 * stats.t.sf(np.abs(t_stat.values), df=dof.values)

    # Wrap the p-values in a DataArray so they keep the composite's coordinates.
    pval_da = xr.DataArray(pval, coords=comp_mean.coords, dims=comp_mean.dims)
    return comp_mean, pval_da


def compute_lead_lag_composites(data_dict, idx, lag_periods):
    """
    Compute lead-lag composites AND t-test p-values for all clusters and lag periods.

    For every cluster, the dates labelled with it in ``idx`` are shifted by each
    offset in ``lag_periods``. Each variable is averaged over the shifted dates
    that exist in that variable's own time axis, and each grid point is tested
    against zero with a two-sided one-sample t-test.

    Parameters
    ----------
    data_dict : dict
        Variable name -> DataArray with a ``time`` coordinate. Variables may
        cover different periods; each one is checked against its own dates.
    idx : pandas.DataFrame
        Date-indexed frame with a ``cluster`` column.
    lag_periods : dict
        Lag name -> offset in days (negative values look before the cluster day).

    Returns
    -------
    composites : dict
        Mean fields {variable: {cluster: {lag_name: composite}}}
    p_values : dict
        P-value fields {variable: {cluster: {lag_name: p_value}}}
        Both hold None where fewer than 2 shifted dates fall inside that
        variable's record.

    Notes
    -----
    NOTE(review): the t-test treats every selected day as an independent
    sample. Cluster days often come in consecutive runs (the index begins with
    several straight days of the same cluster), so neighbouring samples are
    probably serially correlated and these p-values likely overstate
    significance. Consider an effective-sample-size correction.
    """
    import time

    unique_clusters = idx['cluster'].unique()

    # Initialize structures
    composites = {var: {} for var in data_dict}
    p_values = {var: {} for var in data_dict}

    # Keep each variable's own time axis. Checking every variable against the
    # first variable's axis makes .sel raise KeyError when another variable
    # lacks some of those dates.
    time_values = {var: data.time.values for var, data in data_dict.items()}

    start_time = time.time()

    for cluster in sorted(unique_clusters):
        print(f"\nProcessing Cluster {cluster}...")

        # Get dates for this cluster
        cluster_dates = idx[idx['cluster'] == cluster].index
        print(f"  Number of days: {len(cluster_dates)}")

        # Initialize cluster dictionaries
        for var in data_dict:
            composites[var][cluster] = {}
            p_values[var][cluster] = {}

        # Compute for each lag
        for lag_name, lag_days in lag_periods.items():
            lagged_dates = cluster_dates + pd.Timedelta(days=lag_days)
            lagged_dates_np = lagged_dates.to_numpy()

            for var_name, data in data_dict.items():
                # Keep only shifted dates that exist in this variable's record.
                valid_dates = lagged_dates[np.isin(lagged_dates_np, time_values[var_name])]

                # At least 2 samples are needed for a sample standard deviation.
                if len(valid_dates) > 1:
                    selection = data.sel(time=valid_dates)
                    comp_mean, pval_da = _one_sample_ttest(selection, dim='time')
                else:
                    comp_mean, pval_da = None, None

                composites[var_name][cluster][lag_name] = comp_mean
                p_values[var_name][cluster][lag_name] = pval_da

    elapsed_time = time.time() - start_time
    print(f"\n✓ All composites & stats computed in {elapsed_time:.2f}s ({elapsed_time/60:.2f}min)")

    return composites, p_values

print("Function defined: compute_lead_lag_composites (with t-test)")

Function defined: compute_lead_lag_composites (with t-test)


In [6]:
# Day offsets applied to each cluster date: negative = lead (before the
# cluster day), positive = lag (after it). Insertion order is chronological
# and is also the panel order used when plotting.
lag_periods = dict([
    ('lead_1y', -365),
    ('lead_6m', -180),
    ('lead_3m', -90),
    ('lead_2m', -60),
    ('lead_1m', -30),
    ('0d', 0),
    ('lag_1m', 30),
    ('lag_2m', 60),
    ('lag_3m', 90),
    ('lag_6m', 180),
    ('lag_1y', 365),
])

print("Lead-lag periods (time order):")
for name, days in lag_periods.items():
    print(f"  {name}: {days} days")

Lead-lag periods (time order):
  lead_1y: -365 days
  lead_6m: -180 days
  lead_3m: -90 days
  lead_2m: -60 days
  lead_1m: -30 days
  0d: 0 days
  lag_1m: 30 days
  lag_2m: 60 days
  lag_3m: 90 days
  lag_6m: 180 days
  lag_1y: 365 days


In [7]:
def _pretty_lag_label(lag_name):
    if lag_name == '0d':
        return '0d'
    direction = 'Lead' if lag_name.startswith('lead_') else 'Lag'
    suffix = lag_name.split('_', 1)[1] if '_' in lag_name else lag_name
    suffix_map = {
        '1y': '1-year', '6m': '6-month', '3m': '3-month',
        '2m': '2-month', '1m': '1-month'
    }
    return f"{direction} {suffix_map.get(suffix, suffix)}"

def plot_variable_lead_lag(composites, p_values, variable, unique_clusters, lag_periods, 
                           vmin, vmax, cbar_label, cmap='RdBu_r', sig_level=0.05,
                           results_dir=Path("../Results")):
    """
    Create lead-lag composite maps of one variable, one figure per cluster,
    with stippling where the composite is significant.

    Shading is the composite mean. Dots mark grid points whose p-value is at
    or below ``sig_level`` (matplotlib's contourf closes the lowest interval
    [0, sig_level] on both ends). Each figure is saved as
    ``results_dir / f"{variable}_lead_lag_stippled_cluster{cluster}.png"`` at
    300 dpi and then closed. Nothing is returned.

    Parameters
    ----------
    composites, p_values : dict
        Nested dicts {variable: {cluster: {lag_name: DataArray or None}}}.
        Composites are assumed to have ``lat``/``lon`` coordinates in
        (lat, lon) order (TODO confirm for every dataset).
    variable : str
        Key into ``composites``/``p_values``. Also used in the title and file name.
    unique_clusters : iterable
        Clusters to plot; one figure each, in sorted order.
    lag_periods : dict
        Lag name -> offset in days. Its order sets the panel order.
    vmin, vmax : float
        Shading range.
    cbar_label : str
        Colorbar label.
    cmap : str
        Colormap for the shading.
    sig_level : float
        Stippling threshold for the p-value.
    results_dir : path-like
        Output directory, created if missing.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    n_panels = len(lag_periods)
    ncols = 3
    nrows = math.ceil(n_panels / ncols)
    data_crs = ccrs.PlateCarree()

    for cluster in sorted(unique_clusters):
        fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6, nrows * 5), 
                                 subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)},
                                 layout="constrained")
        axes = np.atleast_1d(axes).ravel()
        used_axes = axes[:n_panels]
        fig.suptitle(f'{variable.upper()} Lead-Lag Composites - Cluster {cluster}', 
                     fontsize=16, fontweight='bold')

        im = None  # last shaded mappable; drives the shared colorbar

        for ax, (lag_name, lag_days) in zip(used_axes, lag_periods.items()):
            comp = composites[variable][cluster][lag_name]
            pval = p_values[variable][cluster][lag_name]
            title = f'{_pretty_lag_label(lag_name)} ({abs(lag_days)}d)'

            if comp is None:
                ax.text(0.5, 0.5, 'No data', ha='center', va='center', 
                        transform=ax.transAxes)
                ax.set_title(title)
                continue

            # 1. Shading (composite mean)
            im = comp.plot(ax=ax, cmap=cmap, add_colorbar=False, 
                           vmin=vmin, vmax=vmax, transform=data_crs)

            # 2. Stippling (significance): [0, sig_level] gets dots, the rest is left empty.
            if pval is not None:
                ax.contourf(comp.lon, comp.lat, pval, 
                            levels=[0, sig_level, 1],
                            hatches=['..', ''],
                            colors='none', 
                            transform=data_crs)

            ax.set_title(title, fontweight='bold')
            ax.set_xlabel('Longitude')
            ax.set_ylabel('Latitude')
            ax.coastlines()

        # Hide unused subplots
        for extra_ax in axes[n_panels:]:
            extra_ax.set_visible(False)

        # Anchor the colorbar to the visible panels only.
        if im is not None:
            fig.colorbar(im, ax=used_axes.tolist(), orientation='horizontal', 
                         pad=0.05, label=cbar_label, shrink=0.5)

        outfile = results_dir / f"{variable}_lead_lag_stippled_cluster{cluster}.png"
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(outfile, dpi=300)
        plt.close(fig)

# Note: plot_combined_composites is not provided because adding the stippling
# logic to it is complicated; use the per-variable plots (plot_variable_lead_lag) instead.

print("Plotting functions updated with Stippling")

Plotting functions updated with Stippling


In [8]:
# Anomaly fields to composite, keyed by the short names used in the plot
# titles and output file names.
data_dict = dict(
    u850=u850_anom['uwnd'],
    olr=olr_anom['olr'],
    sst=sst_anom['anom'],
)

# Returns two nested dicts: composite means and their t-test p-values.
composites, p_values = compute_lead_lag_composites(data_dict, idx, lag_periods)


Processing Cluster 1...
  Number of days: 381

Processing Cluster 2...
  Number of days: 375

Processing Cluster 3...
  Number of days: 318

Processing Cluster 4...
  Number of days: 311

Processing Cluster 5...
  Number of days: 312

Processing Cluster 6...
  Number of days: 281

Processing Cluster 7...
  Number of days: 343

Processing Cluster 8...
  Number of days: 375

Processing Cluster 9...
  Number of days: 340

✓ All composites & stats computed in 226.61s (3.78min)


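## Combined View: All Variables for Each Cluster

*Not generated in this version: the combined multi-variable plot was dropped because adding significance stippling to it was non-trivial. Use the per-variable sections below instead.*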

## Visualize SST Lead-Lag Composites

In [9]:
# SST composites: shading range +/-1 °C, dots where p <= the default sig_level.
plot_variable_lead_lag(
    composites,
    p_values,
    variable='sst',
    unique_clusters=unique_clusters,
    lag_periods=lag_periods,
    vmin=-1,
    vmax=1,
    cbar_label='SST Anomaly (°C)',
)

## Visualize OLR Lead-Lag Composites

In [10]:
# OLR composites: shading range +/-20 W/m², dots where p <= the default sig_level.
plot_variable_lead_lag(
    composites,
    p_values,
    variable='olr',
    unique_clusters=unique_clusters,
    lag_periods=lag_periods,
    vmin=-20,
    vmax=20,
    cbar_label='OLR Anomaly (W/m²)',
)

## Visualize U850 Lead-Lag Composites

In [11]:
# U850 composites: shading range +/-3 m/s, dots where p <= the default sig_level.
plot_variable_lead_lag(
    composites,
    p_values,
    variable='u850',
    unique_clusters=unique_clusters,
    lag_periods=lag_periods,
    vmin=-3,
    vmax=3,
    cbar_label='U850 Anomaly (m/s)',
)