In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os

# ==========================================
# 1. Setup Directories
# ==========================================

# Make the script's own folder the working directory so that the relative
# ../Data and ../Results paths below resolve the same way no matter where
# the interpreter was launched from. Inside a notebook `__file__` does not
# exist; in that case the current directory is used and left unchanged.
try:
    script_dir = Path(__file__).resolve().parent
except NameError:
    script_dir = Path.cwd()
else:
    os.chdir(script_dir)
    print(f"Working directory: {os.getcwd()}")

# Input data and output figures live in sibling folders of the script's folder.
DATA_DIR = Path("..") / "Data"
RESULTS_DIR = Path("..") / "Results"

# Make sure the output folder exists before anything tries to write into it.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# ==========================================
# 2. Helper Functions (Data Loading)
# ==========================================

def load_wide_format(filename, skip_rows, name, missing_val=None):
    """Load a climate index stored in 'wide' format (Year, Jan ... Dec).

    After skipping ``skip_rows`` leading lines, each line with at least 13
    whitespace-separated tokens is read as a year followed by twelve monthly
    values. Lines that are shorter, or whose first 13 tokens do not parse as
    numbers (headers, footers, notes), are skipped.

    Parameters
    ----------
    filename : str
        File name, resolved relative to ``DATA_DIR``.
    skip_rows : int
        Number of leading lines to ignore before parsing.
    name : str
        Name of the value column in the returned frame.
    missing_val : float or list of float, optional
        Sentinel value(s) to replace with NaN (exact equality match).

    Returns
    -------
    pandas.DataFrame
        One row per (year, month) with columns ``year``, ``month`` (1-12) and
        ``name``. If the file has no parsable rows, the frame is empty but
        still has those columns. If the file cannot be opened or decoded, a
        warning is printed and an empty ``pd.DataFrame()`` is returned.
    """
    filepath = DATA_DIR / filename
    try:
        with open(filepath, 'r') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort: a missing/unreadable file only disables this index.
        print(f"Warning: Could not load {name} from {filepath}. Error: {e}")
        return pd.DataFrame()

    records = []
    for line in lines[skip_rows:]:
        parts = line.split()
        if len(parts) < 13:
            continue
        try:
            year = int(parts[0])
            vals = [float(x) for x in parts[1:13]]
        except ValueError:
            continue  # non-numeric row (e.g. column titles)
        records.extend(
            {'year': year, 'month': month, name: val}
            for month, val in enumerate(vals, 1)
        )

    if not records:
        print(f"Warning: No data rows parsed for {name} from {filepath}.")

    # Passing explicit columns keeps `name` present even when nothing parsed,
    # so the sentinel replacement below cannot raise KeyError.
    df = pd.DataFrame(records, columns=['year', 'month', name])
    if missing_val is not None:
        df[name] = df[name].replace(missing_val, np.nan)
    return df

def load_long_format(filename, name, missing_val=None):
    """Load a climate index stored in 'long' format (Year, Month, Value).

    Parsing starts at the first line that looks like a data row: at least
    three tokens, a 4-digit year and an all-digit month. If no such line
    exists, parsing starts at the top of the file. From there, a line becomes
    one record when its first three tokens parse as int/int/float and the
    month is in 1-12. Every other line (notes, footers) is ignored.

    Parameters
    ----------
    filename : str
        File name, resolved relative to ``DATA_DIR``.
    name : str
        Name of the value column in the returned frame.
    missing_val : float or list of float, optional
        Sentinel value(s) to replace with NaN (exact equality match).

    Returns
    -------
    pandas.DataFrame
        One row per parsed line with columns ``year``, ``month`` and ``name``.
        If the file has no parsable rows, the frame is empty but still has
        those columns. If the file cannot be opened or decoded, a warning is
        printed and an empty ``pd.DataFrame()`` is returned.
    """
    filepath = DATA_DIR / filename
    try:
        with open(filepath, 'r') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort: a missing/unreadable file only disables this index.
        print(f"Warning: Could not load {name} from {filepath}. Error: {e}")
        return pd.DataFrame()

    # Skip any header block: start at the first "YYYY M ..." looking line.
    start_idx = 0
    for i, line in enumerate(lines):
        parts = line.split()
        if (len(parts) >= 3 and len(parts[0]) == 4
                and parts[0].isdigit() and parts[1].isdigit()):
            start_idx = i
            break

    records = []
    for line in lines[start_idx:]:
        parts = line.split()
        if len(parts) < 3:
            continue
        try:
            year = int(parts[0])
            month = int(parts[1])
            val = float(parts[2])
        except ValueError:
            continue
        # Numeric trailer lines can parse as int/int/float; only keep rows
        # that name a real calendar month.
        if not 1 <= month <= 12:
            continue
        records.append({'year': year, 'month': month, name: val})

    if not records:
        print(f"Warning: No data rows parsed for {name} from {filepath}.")

    # Explicit columns keep `name` present even when nothing parsed, so the
    # sentinel replacement below cannot raise KeyError.
    df = pd.DataFrame(records, columns=['year', 'month', name])
    if missing_val is not None:
        df[name] = df[name].replace(missing_val, np.nan)
    return df

def calculate_seasonal_means(df_wide, index_name):
    """Average a monthly index into seasonal means (JJA, MAM, SON, DJF).

    Parameters
    ----------
    df_wide : pandas.DataFrame
        Monthly data in *long* layout (despite the parameter name): columns
        ``year``, ``month`` (1-12) and ``index_name``, one row per
        (year, month). Duplicate (year, month) pairs make
        ``DataFrame.pivot`` raise ValueError.
    index_name : str
        Name of the value column; also used as the prefix of the output columns.

    Returns
    -------
    pandas.DataFrame
        Indexed by year, with columns ``{index_name}_JJA``, ``_MAM``, ``_SON``
        and ``_DJF`` in that order. A season is NaN unless all three of its
        months are present. DJF for year Y combines December of Y-1 with
        January and February of Y.
    """
    # Year x month grid. Reindexing the columns guarantees months 1-12 exist
    # even if the source never reported one, so the selections below cannot
    # raise KeyError (absent months simply become NaN).
    monthly = (
        df_wide.pivot(index='year', columns='month', values=index_name)
        .reindex(columns=range(1, 13))
    )
    seasonal_data = pd.DataFrame(index=monthly.index)

    # skipna=False: a season with a missing month (e.g. an incomplete final
    # year) is NaN rather than an average over fewer months. This matches
    # how DJF below already behaves.
    for season, months in (("JJA", [6, 7, 8]), ("MAM", [3, 4, 5]), ("SON", [9, 10, 11])):
        seasonal_data[f"{index_name}_{season}"] = monthly[months].mean(axis=1, skipna=False)

    # DJF: Dec (prev year) + Jan (curr) + Feb (curr). Shift December by year
    # *label* (Y -> Y+1) rather than by row position, so a gap in the year
    # index cannot pair December with the wrong winter.
    dec_prev = monthly[12].rename(index=lambda year: year + 1).reindex(monthly.index)
    seasonal_data[f"{index_name}_DJF"] = (dec_prev + monthly[1] + monthly[2]) / 3

    return seasonal_data

# ==========================================
# 3. Data Processing
# ==========================================

print("--- 1. Loading Cluster Data ---")
cluster_file = DATA_DIR / "u850_cluster_counts.csv"

try:
    df_clusters = pd.read_csv(cluster_file, index_col="year")
except Exception as e:
    print(f"Critical Error loading {cluster_file}: {e}")
    # `exit()` is a REPL convenience and reports success (status 0); raise
    # SystemExit with a non-zero code so callers/CI see the failure.
    raise SystemExit(1) from e

# Normalise header names (e.g. " 5" -> "5") so the selected columns match the
# target labels exactly and yield clean plot labels and file names.
df_clusters.columns = df_clusters.columns.astype(str).str.strip()

# Filter for target clusters (1, 5, 8)
target_clusters = ['1', '5', '8']
available_cols = [c for c in df_clusters.columns if c in target_clusters]
df_clusters = df_clusters[available_cols]

missing_clusters = [c for c in target_clusters if c not in available_cols]
if missing_clusters:
    print(f"Warning: clusters not found in {cluster_file}: {missing_clusters}")
print(f"Loaded clusters: {df_clusters.columns.tolist()}")

print("\n--- 2. Loading & Processing Climate Indices ---")
# index name -> (file name, layout, extra argument). The third field means
# different things per layout: for "wide" files it is the number of header
# lines to skip; for "long" files it is the missing-value sentinel (None = no
# sentinel).
# NOTE(review): wide files are loaded without a missing-value sentinel, so any
# fill values they contain are treated as real data — confirm against each
# file and pass missing_val to load_wide_format where needed.
files = {
    "DMI": ("dmi.had.long.data", "wide", 1),
    "PDO": ("ersst.v5.pdo.dat", "wide", 2),
    "ONI": ("oni.data", "wide", 1),
    "NAO": ("norm.nao.monthly.b5001.current.ascii", "long", None),
    "PNA": ("norm.pna.monthly.b5001.current.ascii", "long", None),
    "WP":  ("wp_index.txt", "long", -99.90)
}

seasonal_frames = []
for name, (fname, fmt, param) in files.items():
    if fmt == "wide":
        df_raw = load_wide_format(fname, param, name)
    elif fmt == "long":
        df_raw = load_long_format(fname, name, missing_val=param)
    else:
        raise ValueError(f"Unknown format {fmt!r} for index {name}")

    if df_raw.empty:
        continue  # index unavailable; skip it rather than abort the run
    df_seas = calculate_seasonal_means(df_raw, name)
    seasonal_frames.append(df_seas)
    print(f"Processed {name}")

if not seasonal_frames:
    print("Critical Error: none of the climate indices could be loaded.")
    raise SystemExit(1)

# Outer-join all indices on year (union of years, sorted), equivalent to
# successive DataFrame.join(how='outer') calls.
df_seasonal_all = pd.concat(seasonal_frames, axis=1, join='outer').sort_index()

# ==========================================
# 4. Correlation & Plotting
# ==========================================


def _plot_cluster_vs_index(cluster_series, index_series, clus, ind, r, save_path):
    """Save a dual-axis time-series plot of one cluster count vs one index.

    Cluster counts are drawn on the left (blue) axis. The seasonal index is
    drawn on the right (red, dashed) axis with a zero reference line. The
    figure is closed even if saving fails, so the plotting loop never
    accumulates open figures.
    """
    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        # Plot Cluster Frequency (Left Axis)
        color1 = 'tab:blue'
        ax1.set_xlabel('Year', fontsize=12)
        ax1.set_ylabel(f'Cluster {clus} Count', color=color1, fontsize=12, fontweight='bold')
        ax1.plot(cluster_series.index, cluster_series, color=color1, marker='o', linewidth=2, label=f'Cluster {clus}')
        ax1.tick_params(axis='y', labelcolor=color1, labelsize=10)
        ax1.grid(axis='x', linestyle='--', alpha=0.5)

        # Plot Climate Index (Right Axis)
        ax2 = ax1.twinx()
        color2 = 'tab:red'
        ax2.set_ylabel(f'{ind} Index', color=color2, fontsize=12, fontweight='bold')
        ax2.plot(index_series.index, index_series, color=color2, linestyle='--', marker='x', linewidth=1.5, label=ind)
        ax2.tick_params(axis='y', labelcolor=color2, labelsize=10)
        ax2.axhline(0, color='gray', linewidth=0.8, alpha=0.5)

        # Set the title on an explicit Axes instead of relying on plt's
        # "current axes" (which is ax2 after twinx()).
        ax1.set_title(f"Cluster {clus} vs {ind} (Correlation: r={r:.3f})", fontsize=14, pad=15)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)  # free memory even if savefig raised


print("\n--- 3. Analyzing & Plotting ---")
# Sort the shared years so the line plots run chronologically even if the
# cluster CSV is not stored in year order.
common_years = df_clusters.index.intersection(df_seasonal_all.index).sort_values()
df_clus_common = df_clusters.loc[common_years]
df_clim_common = df_seasonal_all.loc[common_years]

# Keep cluster/index pairs whose Pearson |r| exceeds a fixed magnitude cut-off.
# NOTE: this is an effect-size threshold, not a statistical significance test
# (no p-value is computed).
correlations = {}
threshold = 0.30

for clus in df_clus_common.columns:
    for ind in df_clim_common.columns:
        # Series.corr drops NaN pairs; a NaN result (e.g. a constant series)
        # fails the comparison below and is skipped.
        r = df_clus_common[clus].corr(df_clim_common[ind])
        if abs(r) > threshold:
            correlations[(clus, ind)] = r

sorted_corrs = sorted(correlations.items(), key=lambda x: abs(x[1]), reverse=True)

if sorted_corrs:
    print(f"Found {len(sorted_corrs)} correlations with |r| > {threshold:.2f}. Saving plots...")

    for (clus, ind), r in sorted_corrs:
        # Sanitize the index name so it is safe inside a file name.
        safe_ind = ind.replace("/", "_")
        filename = f"timeseries_Cluster{clus}_{safe_ind}.png"
        _plot_cluster_vs_index(
            df_clus_common[clus], df_clim_common[ind], clus, ind, r, RESULTS_DIR / filename
        )
        print(f"  Saved: {filename}")

    print(f"\nAll {len(sorted_corrs)} plots saved to {RESULTS_DIR.resolve()}")

else:
    print(f"No correlations with |r| > {threshold:.2f} found to plot.")

--- 1. Loading Cluster Data ---
Loaded clusters: ['1', '5', '8']

--- 2. Loading & Processing Climate Indices ---
Processed DMI
Processed PDO
Processed ONI
Processed NAO
Processed PNA
Processed WP

--- 3. Analyzing & Plotting ---
Found 8 significant correlations. Saving plots...
  Saved: timeseries_Cluster8_ONI_DJF.png
  Saved: timeseries_Cluster1_DMI_SON.png
  Saved: timeseries_Cluster8_ONI_MAM.png
  Saved: timeseries_Cluster1_ONI_SON.png
  Saved: timeseries_Cluster8_WP_MAM.png
  Saved: timeseries_Cluster5_PNA_SON.png
  Saved: timeseries_Cluster1_PDO_MAM.png
  Saved: timeseries_Cluster1_ONI_DJF.png

All 8 plots saved to D:\Research\Projects\NWP_SOM\Results
