In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# ==========================================
# 1. Input file locations
# ==========================================
nc_file = "data/subset_50E_220E_-20_70-001.nc"
npz_file = "data/EOF_z_JJA_20to30N_profile_eddy.npz"

print("--- 開始執行優化版流程 ---")

# ==========================================
# 2. Loading and preprocessing
# ==========================================
# Gridded field (NetCDF) and the pre-computed index (NPZ).
ds = xr.open_dataset(nc_file)
data_npz = np.load(npz_file)
# First row of the stored index array; presumably the full (all-season)
# time series -- confirm against whatever produced the NPZ file.
full_index = data_npz['NPSH_index_total'][0, :]

# --- A) Spatial crop: latitude 0..40, longitude 0..180 ---
print("正在進行空間裁切...")
# Label-based slicing has to follow the stored coordinate direction, so the
# latitude bounds are reversed when the axis is stored north-to-south.
if ds.lat[0] < ds.lat[-1]:
    lat_slice = slice(0, 40)
else:
    lat_slice = slice(40, 0)
# Longitude is sliced as 0..180 (assumes an ascending longitude axis).
ds_sub = ds.sel(lat=lat_slice, lon=slice(0, 180))

# --- B) Keep only June, July and August time steps ---
print("正在篩選 JJA 資料...")
time_name = 'time'  # name of the time coordinate in the dataset
jja_mask = ds_sub[time_name].dt.month.isin([6, 7, 8])
ds_jja = ds_sub.sel({time_name: jja_mask})

# Align the index with the NetCDF time axis so that the same JJA mask can be
# applied to both.
dates_full = pd.to_datetime(ds[time_name].values)

# The boolean JJA mask below is built from the NetCDF time axis, so the index
# must have exactly one value per NetCDF time step. If it does not, stop here
# with an explicit message instead of failing later with an opaque
# boolean-index length error.
if len(dates_full) != len(full_index):
    print(f"警告: Index長度({len(full_index)}) 與 NC時間長度({len(dates_full)}) 不一致")
    print("full_index:",full_index.shape)
    print("dates_full:",dates_full.shape)
    raise ValueError(
        "Index length does not match the NetCDF time axis "
        f"({len(full_index)} vs {len(dates_full)}); "
        "the two series must be aligned before the JJA filter is applied."
    )

# Select the JJA part of the index with the calendar months of the time axis.
is_jja = np.isin(dates_full.month, [6, 7, 8])
index_jja = full_index[is_jja]
# Guarantee a 1-D (T,) series.
index_jja = np.squeeze(index_jja)

# --- C) Anomalies: remove the JJA time mean ---
z_jja = ds_jja['z']
z_mean = z_jja.mean(dim=time_name)
z_anom = z_jja - z_mean

# The index gets the same treatment (subtract its own JJA mean).
index_anom = index_jja - index_jja.mean()

# --- D) Latitude weighting ---
print("執行緯度加權 (Latitude Weighting)...")
lat_vals = z_anom.lat.values
# Weight each latitude row by sqrt(cos(lat)).
cos_lat = np.cos(np.radians(lat_vals))
weights = np.sqrt(cos_lat)
# Wrapping the weights in a DataArray keyed on 'lat' lets xarray broadcast
# them over the remaining dimensions.
weights_da = xr.DataArray(weights, dims='lat', coords={'lat': lat_vals})
z_weighted = z_anom * weights_da

# --- E) Flatten space and drop incomplete points ---
# (time, lat, lon) -> (time, space); then discard every grid point that has a
# NaN at ANY time step (how='any').
z_flat = (
    z_weighted
    .stack(space=('lat', 'lon'))
    .dropna(dim='space', how='any')
)
A = z_flat.values  # data matrix A with shape (time, space)

print(f"矩陣 A 形狀: {A.shape}")

# ==========================================
# 3. EOF analysis via SVD
# ==========================================
print("計算 SVD...")
U, s, Vt = np.linalg.svd(A, full_matrices=False)

# Keep the leading 10 modes.
n_modes = 10
# Principal components: left singular vectors scaled by singular values.
PCs = U[:, :n_modes] * s[np.newaxis, :n_modes]   # (time, modes)
# Spatial patterns: leading rows of Vt.
EOFs = Vt[:n_modes]                              # (modes, space)

# Explained variance: eigenvalues of the sample covariance are s^2 / (T - 1).
n_time = A.shape[0]
eigvals = (s ** 2) / (n_time - 1)
var_frac = eigvals[:n_modes] / eigvals.sum()
explained_pct = var_frac.sum() * 100
print(f"前 10 個 Mode 解釋變異度總和: {explained_pct:.2f}%")

# ==========================================
# 4. Propagator G (estimated year by year, then averaged)
# ==========================================
print("計算動力算子 G (Year-by-Year Average)...")

# Calendar year of every JJA sample; rows of PCs follow the same time axis.
years = np.asarray(pd.to_datetime(ds_jja[time_name].values).year)
unique_years = np.unique(years)
n_years = len(unique_years)
days_per_year = A.shape[0] // n_years

print(f"偵測到 {n_years} 年, 每年約 {days_per_year} 天 (JJA)")

# Small ridge term added to the covariance so the linear solve stays
# well-posed even when the covariance is (nearly) singular.
ridge = np.eye(n_modes) * 1e-6

G_list = []
for yr in unique_years:
    # Select this year's samples by calendar year instead of reshaping to a
    # fixed (year, day, mode) block: a reshape silently mixes years whenever
    # the per-year sample counts differ but the total happens to divide evenly.
    pc_year = PCs[years == yr]          # (samples in this year, modes)

    # At least two samples are needed to form one (t, t+1) pair.
    if pc_year.shape[0] < 2:
        print(f"略過 {yr} 年: 資料不足兩筆")
        continue

    # Lagged pairs never cross a year boundary.
    # NOTE(review): assumes consecutive rows within a year are one time step
    # apart (no gaps inside a season) -- confirm for the input data.
    x_t = pc_year[:-1, :]    # state at t
    x_tp1 = pc_year[1:, :]   # state at t+1

    cov = np.dot(x_t.T, x_t)
    lag_cov = np.dot(x_t.T, x_tp1)

    # Least-squares solution of X1 = X0 @ G_T (row-vector form):
    #   G_T = (X0^T X0 + ridge)^-1 (X0^T X1)
    # solved directly rather than through an explicit inverse.
    G_T_year = np.linalg.solve(cov + ridge, lag_cov)

    # Store the column-vector form, x(t+1) = G x(t).
    G_list.append(G_T_year.T)

if not G_list:
    raise ValueError("No year contains enough samples to estimate G.")

# Average the per-year operators.
G_ave = np.mean(np.array(G_list), axis=0)

# Stability check: the largest eigenvalue modulus of the averaged operator.
print(f"G_ave 最大特徵值: {np.max(np.abs(np.linalg.eigvals(G_ave))):.4f}")


# ==========================================
# 5. Projection matrix rho (regression of the index on the PCs)
# ==========================================
print("計算投影矩陣 Rho...")
# Least-squares fit index_anom ~= PCs @ a; only the coefficient vector is kept.
sol = np.linalg.lstsq(PCs, index_anom, rcond=None)
a_coeffs, *_ = sol
# Place the regression coefficients on the diagonal of a square matrix.
rho = np.diag(a_coeffs)

# ==========================================
# 6. Transient growth analysis
# ==========================================
print("分析瞬變增長...")

max_lag = 15
growth_curve = []
lags = np.arange(1, max_lag + 1)

for tau in lags:
    # Propagator over tau steps.
    G_tau = np.linalg.matrix_power(G_ave, int(tau))

    # Weighted propagator and its Gram matrix M = (rho G^tau)(rho G^tau)^T.
    A_mat = np.dot(rho, G_tau)
    M = np.dot(A_mat, A_mat.T)

    # M is symmetric positive semi-definite by construction, so use the
    # symmetric eigensolver: it returns real eigenvalues in ascending order,
    # whereas the general solver may return complex values with round-off
    # imaginary parts. The last entry is the largest eigenvalue.
    max_growth = np.linalg.eigvalsh(M)[-1]

    growth_curve.append(max_growth)

# ==========================================
# 7. Plot the growth curve
# ==========================================
growth_curve = np.asarray(growth_curve)

# The raw values are plotted (no normalisation by the first lag).
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(lags, growth_curve, 'o-', color='tab:red', linewidth=2)

# Mark the lag at which the curve peaks.
peak_idx = growth_curve.argmax()
peak_lag = lags[peak_idx]
ax.axvline(peak_lag, color='gray', linestyle='--', alpha=0.6, label=f'Peak at Day {peak_lag}')

ax.set_title("Non-Modal Growth of WNPSH (Replicating Teacher's Method)")
ax.set_xlabel("Lag (Days)")
ax.set_ylabel("Variance Growth")
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

print(f"峰值出現在第 {peak_lag} 天，這是否呈現先升後降？")

In [None]:
# Display the explained-variance fraction of each retained mode.
var_frac

In [None]:
# Display the diagonal matrix holding the index-on-PC regression coefficients.
rho

In [None]:
# Display the shape of the 'z' variable after the spatial crop and JJA filter.
ds_jja["z"].shape


In [None]:
# Quick-look plot of every retained EOF on the (lat, lon) grid.
# The grid shape is taken from the weighted field instead of being hard-coded.
# This relies on the flattening step not having dropped any grid point, which
# the original fixed-size reshape required as well.
grid_shape = (z_weighted['lat'].values.size, z_weighted['lon'].values.size)
for i, eof in enumerate(EOFs):
    cs = plt.contourf(eof.reshape(grid_shape))
    plt.colorbar(cs)
    plt.title(i)
    plt.show()


In [None]:
# Shape checks for the composite calculation: the flattened scaled EOFs, the
# diagonal of rho, and their product.
eof_matrix = EOF_normal.reshape(10, 80 * 208)
print(eof_matrix.shape)
rho_diagonal = np.diag(rho)
print(rho_diagonal.shape)
print(rho_diagonal.dot(eof_matrix).shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Real grid coordinates of the analysed field. The axes were previously
# fabricated with linspace(0, 40) / linspace(0, 180), which mislabels the map
# whenever the cropped data does not span exactly that box (the input file
# name suggests the longitudes start at 50E) or the latitude axis is stored
# north-to-south.
lat = z_weighted['lat'].values
lon = z_weighted['lon'].values
n_lat, n_lon = lat.size, lon.size
n_eof = EOFs.shape[0]

# Scale each EOF by the standard deviation of its PC, then restore the grid.
pc_std = np.std(PCs, axis=0)                                  # (modes,)
EOF_normal = np.asarray(EOFs).reshape(n_eof, n_lat * n_lon) * pc_std[:, np.newaxis]
EOF_normal = EOF_normal.reshape(n_eof, n_lat, n_lon)

# Composite pattern: EOFs combined with the regression coefficients stored on
# the diagonal of rho.
WNPSH_composite = (
    np.diag(rho).dot(EOF_normal.reshape(n_eof, n_lat * n_lon))
).reshape(n_lat, n_lon)

# 2-D coordinate grids matching the (lat, lon) layout of the fields.
Lon, Lat = np.meshgrid(lon, lat)

# Plot
plt.figure(figsize=(8, 4))
cf = plt.contourf(Lon, Lat, WNPSH_composite, cmap='RdBu_r', levels=21)

plt.colorbar(cf, label='Composite amplitude')
plt.title('WNPSH Composite (0–40°N, 0–180°E)')
plt.xlabel('Longitude (°E)')
plt.ylabel('Latitude (°N)')

# Frame the nominal analysis box; areas without data stay blank.
plt.xlim(0, 180)
plt.ylim(0, 40)
plt.tight_layout()
plt.show()


In [None]:
# Optimal initial patterns for a range of lead times.
#
# G_ave is stored in column-vector form, x(t+tau) = G^tau x(t), so the
# weighted response to an initial state x0 is R G^tau x0 and its squared norm
# is x0^T (R G^tau)^T (R G^tau) x0. The initial state that maximises it is the
# leading eigenvector of (R G^tau)^T (R G^tau).
import matplotlib.colors as mcolors

# Weight matrix R. rho is normally already the diagonal coefficient matrix;
# taking its first row (as was done before) keeps only the first coefficient
# and zeroes all the others. A plain coefficient vector / single row is still
# accepted and expanded to a diagonal matrix.
rho_arr = np.asarray(rho)
if rho_arr.ndim == 2 and rho_arr.shape[0] == rho_arr.shape[1]:
    R = rho_arr
else:
    R = np.diag(rho_arr.ravel())

# Scaled EOFs as a (modes, space) matrix; loop-invariant, so built once.
tmp = EOF_normal.reshape(EOF_normal.shape[0], -1)
grid_shape = WNPSH_composite.shape

for i in range(3,60):
    G_lag = np.linalg.matrix_power(G_ave,i)
    A_lag = R.dot(G_lag)
    M = A_lag.T.dot(A_lag)

    eigvals, eigvecs = np.linalg.eigh(M)

    # eigh returns ascending; flip to descending for convenience
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    # Leading eigenvector expressed on the grid through the scaled EOFs.
    initial_optimal = eigvecs[:,0][:,np.newaxis].T.dot(tmp)
    initial_optimal = initial_optimal.reshape(grid_shape)

    # TwoSlopeNorm requires vmin < vcenter < vmax. An eigenvector's sign is
    # arbitrary, so the field may be one-signed; fall back to symmetric limits
    # in that case instead of raising.
    lo = np.nanmin(initial_optimal)
    hi = np.nanmax(initial_optimal)
    if not (lo < 0 < hi):
        bound = max(abs(lo), abs(hi)) or 1.0
        lo, hi = -bound, bound
    norm = mcolors.TwoSlopeNorm(vmin=lo, vcenter=0, vmax=hi)

    plt.figure(figsize=(8, 4))
    cf = plt.contourf(Lon, Lat, initial_optimal, cmap='RdBu_r', norm=norm, levels=21)
    # Overlay the composite pattern as black contours for reference.
    plt.contour(Lon, Lat, WNPSH_composite, colors='k', linewidths=0.6)

    plt.colorbar(cf, label='Composite amplitude')
    plt.title(f'WNPSH Composite (0–40°N, 0–180°E) , day{i:2d}')
    plt.xlabel('Longitude (°E)')
    plt.ylabel('Latitude (°N)')
    plt.xlim(0, 180)
    plt.ylim(0, 40)
    plt.tight_layout()
    plt.show()


In [None]:
# Print every third integer from 2 up to (but not including) 60, one per line.
print(*range(2, 60, 3), sep="\n")