In [None]:
#!/usr/bin/env python3
"""
Perform EOF analysis on all ensemble members for a chosen CMIP6 model.
Works for annual-mean CMIP6 data with 2D latitude/longitude fields (dims j,i).

Author: Freja Klejnstrup
Date: 2025-10-28
"""

import xarray as xr
import numpy as np
import os
from eofs.xarray import Eof

# -----------------------------
# User settings
# -----------------------------
# CMIP6 model whose ensemble is analysed.
model = 'EC-Earth3'

# Ensemble members r1i1p1f1 ... r25i1p1f1 and the variables processed for each.
members = [f"r{i}i1p1f1" for i in range(1, 26)]
variables = ['thetao', 'so']  # 3D temp and salinity

# Root directory of the input files.
base_dir = "/data/projects/nckf/frekle/CMIP6_data"

# Per-model output directory for the EOF results; created if missing.
output_dir = f"/data/projects/nckf/frekle/EOF_results/{model}"
os.makedirs(output_dir, exist_ok=True)


# -----------------------------
# Helper functions
# -----------------------------

def load_surface_data(var, member):
    """Load one member's CMIP6 field and return its surface layer in memory.

    Parameters
    ----------
    var : str
        Variable name; used in the file path and as the key in the dataset.
    member : str
        Ensemble member label (e.g. ``"r1i1p1f1"``).

    Returns
    -------
    xarray.DataArray
        ``ds[var]`` restricted to the first ``lev`` index (when a ``lev``
        dimension exists) and to the years 1840-2014 (when a ``time``
        dimension exists), with ``lat``/``lon`` coordinates attached when the
        file provides ``latitude``/``longitude``. The array is fully loaded,
        so it stays usable after the file has been closed.

    Raises
    ------
    FileNotFoundError
        If the expected input file does not exist.
    """
    file_path = f"{base_dir}/{var}/{var}_masked_{member}.nc"
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Missing file: {file_path}")

    # The context manager releases the file handle once the data is read;
    # otherwise one handle per member/variable stays open for the whole run.
    with xr.open_dataset(file_path) as ds:
        da = ds[var]

        # Extract surface (first depth level)
        if 'lev' in da.dims:
            da = da.isel(lev=0)

        # Subset time range if needed
        if 'time' in da.dims:
            da = da.sel(time=slice("1840", "2014"))

        # Attach coordinates under the short names used by calculate_eof
        if 'latitude' in ds and 'longitude' in ds:
            da = da.assign_coords(lat=ds['latitude'], lon=ds['longitude'])

        # Only the selected subset is read; it must happen before the file
        # is closed because open_dataset loads lazily.
        return da.load()


def detrend_mean(da):
    """Return anomalies of ``da`` relative to its mean over ``time``.

    Only the time mean is subtracted; no trend is fitted or removed.
    """
    time_mean = da.mean('time')
    anomalies = da - time_mean
    return anomalies


def calculate_eof(da, n_modes=10):
    """Perform a latitude-weighted EOF analysis using ``eofs.xarray``.

    Parameters
    ----------
    da : xarray.DataArray
        Anomaly field carrying a ``lat`` coordinate that is either 1D or 2D.
    n_modes : int, optional
        Number of leading modes to return (default 10).

    Returns
    -------
    eofs : xarray.DataArray
        Leading EOFs from ``solver.eofsAsCorrelation`` (correlation maps,
        dimensionless).
    pcs : xarray.DataArray
        Leading principal components, requested with ``pcscaling=1``.
    varfrac : xarray.DataArray
        Fraction of total variance associated with each returned mode.
    """
    lat = da.lat
    if lat.ndim == 2:
        # Some CMIP6 models use 2D lat/lon fields. Order the axes like the
        # data's own spatial axes so the weights line up element-wise.
        spatial_dims = [dim for dim in da.dims if dim in lat.dims]
        lat_values = lat.transpose(*spatial_dims).values
    else:
        # Plain NumPy is used here because xarray objects cannot be indexed
        # with np.newaxis.
        # NOTE(review): assumes lat is the first of two spatial dimensions,
        # i.e. da is laid out as (time, lat, lon) - confirm for new inputs.
        lat_values = lat.values[:, np.newaxis]

    # Clip before the square root: rounding can push cos(lat) marginally
    # below zero at the poles, which would turn the weight into NaN.
    cos_lat = np.cos(np.deg2rad(lat_values)).clip(0.0, 1.0)
    lat_weights = np.sqrt(cos_lat)

    solver = Eof(da, weights=lat_weights)

    eofs = solver.eofsAsCorrelation(neofs=n_modes)
    pcs = solver.pcs(npcs=n_modes, pcscaling=1)
    varfrac = solver.varianceFraction(neigs=n_modes)
    return eofs, pcs, varfrac


def save_eof_results(member, varname, eofs, pcs, varfrac):
    """Write one member's EOFs, PCs and variance fractions to a NetCDF file.

    The three arrays are stored as the variables ``EOFs``, ``PCs`` and
    ``variance_fraction`` in ``{output_dir}/EOF_{varname}_{member}.nc``, and a
    confirmation line is printed afterwards.
    """
    outfile = f"{output_dir}/EOF_{varname}_{member}.nc"
    data_vars = {"EOFs": eofs, "PCs": pcs, "variance_fraction": varfrac}
    xr.Dataset(data_vars).to_netcdf(outfile)
    print(f"✅ Saved EOF results for {varname} {member} → {outfile}")


# -----------------------------
# Main loop
# -----------------------------
# Run the pipeline (load surface field -> anomalies -> EOF -> save) for every
# member/variable combination.
for member in members:
    print(f"\nProcessing ensemble member: {member}")

    for var in variables:
        try:
            da = load_surface_data(var, member)
            da_anom = detrend_mean(da)
            eofs, pcs, varfrac = calculate_eof(da_anom, n_modes=10)
            save_eof_results(member, var, eofs, pcs, varfrac)

        except Exception as e:
            # Best effort: a failure in any step (missing file, solver error,
            # write error) is reported and the loop moves on to the next
            # variable/member, so no result file is written for this pair.
            print(f"⚠️ Skipped {var} {member} due to: {e}")
print("\nAll done!")


Processing ensemble member: r1i1p1f1
✅ Saved EOF results for thetao r1i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r1i1p1f1.nc
✅ Saved EOF results for so r1i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_so_r1i1p1f1.nc

Processing ensemble member: r2i1p1f1
✅ Saved EOF results for thetao r2i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r2i1p1f1.nc
✅ Saved EOF results for so r2i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_so_r2i1p1f1.nc

Processing ensemble member: r3i1p1f1
✅ Saved EOF results for thetao r3i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r3i1p1f1.nc
✅ Saved EOF results for so r3i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_so_r3i1p1f1.nc

Processing ensemble member: r4i1p1f1
✅ Saved EOF results for thetao r4i1p1f1 → /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r4i1p1f1.nc
✅ Saved EOF results for so r4i1p1f1 → /data/projects/nckf/frekle/EOF_resul

In [13]:
#!/usr/bin/env python3
"""
Compare EOF modes across ensemble members for a given CMIP6 model.

Creates one large figure:
 ├── one row per member listed in `plot_members` (currently 4)
 └── 4 columns: EOF1 | EOF2 | EOF3 | PCs (PC1–3 time series)

Also includes:
 - PC1 correlation matrix across all members
 - Explained variance summary
 - Decodes time to calendar years
 - Full Atlantic extent & geographic labels

Author: Freja Klejnstrup
Date: 2025-10-29
"""

import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cftime

# -----------------------------
# User settings
# -----------------------------
# Model and variable whose saved EOF results are compared.
model = "EC-Earth3"
varname = "thetao"  # change to "so" for salinity
n_modes = 3

# Location of the per-member EOF files and the members to look for.
base_dir = f"/data/projects/nckf/frekle/EOF_results/{model}"
members = [f"r{i}i1p1f1" for i in range(1, 26)]

# Members drawn in the combined EOF/PC figure.
plot_members = ["r1i1p1f1", "r6i1p1f1", "r10i1p1f1", "r17i1p1f1"]

# Figures are written to a per-variable sub-directory, created if missing.
fig_root = "/data/users/frekle/EOF/Figures"
fig_dir = os.path.join(fig_root, varname)
os.makedirs(fig_dir, exist_ok=True)
print(f"📁 Figures will be saved in: {fig_dir}")

# -----------------------------
# Helper functions
# -----------------------------
def load_eof_results(member, var):
    """Read the saved EOF results of one member into memory.

    Parameters
    ----------
    member : str
        Ensemble member label (e.g. ``"r1i1p1f1"``).
    var : str
        Variable name used in the file name.

    Returns
    -------
    xarray.Dataset
        Contents of ``{base_dir}/EOF_{var}_{member}.nc``, fully loaded so the
        underlying file is already closed when the dataset is returned.

    Raises
    ------
    FileNotFoundError
        If the result file does not exist; the message is the missing path.
    """
    path = f"{base_dir}/EOF_{var}_{member}.nc"
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    # Load eagerly and close: callers keep the data for the whole session,
    # and a lazily opened dataset would hold one file handle per member.
    with xr.open_dataset(path) as ds:
        return ds.load()


def decode_time_to_years(time):
    """Convert a time coordinate to decimal years (year + day-of-year / 365.25).

    Numeric values carrying a ``units`` attribute are decoded with
    ``cftime.num2date`` (calendar taken from the ``calendar`` attribute,
    defaulting to ``"gregorian"``). Anything else is wrapped in a DataArray
    and read through its ``.dt`` accessor. If the conversion raises, a plain
    ``0..len(time)-1`` index is returned instead.
    """
    # Evaluated outside the try block, so errors here still propagate.
    is_encoded = np.issubdtype(time.dtype, np.number) and "units" in time.attrs

    try:
        if is_encoded:
            dates = cftime.num2date(
                time.values,
                units=time.attrs["units"],
                calendar=time.attrs.get("calendar", "gregorian"),
            )
            return np.array([d.year + d.timetuple().tm_yday / 365.25 for d in dates])

        wrapped = xr.DataArray(time)
        return (wrapped.dt.year + wrapped.dt.dayofyear / 365.25).values
    except Exception:
        # Fall back to a simple sample index when decoding is not possible.
        return np.arange(len(time))


def correlate_pc_time_series(pc_dict, mode=0):
    """Correlate one PC mode between every pair of members.

    Parameters
    ----------
    pc_dict : dict
        Maps a member name to a 2D array indexed as ``[time, mode]``.
    mode : int, optional
        Column (mode index) to correlate; default 0.

    Returns
    -------
    corr : numpy.ndarray
        Symmetric ``(n, n)`` matrix of Pearson correlations. Each pair is
        compared over its common leading record length, using only samples
        that are finite in both series. Pairs with fewer than two such
        samples are NaN.
    names : list
        Member names in the row/column order of ``corr``.
    """
    names = list(pc_dict)
    n = len(names)
    series = [np.asarray(pc_dict[name])[:, mode] for name in names]

    corr = np.full((n, n), np.nan)
    for i, a in enumerate(series):
        # The matrix is symmetric, so each pair is computed once and mirrored.
        for j in range(i, n):
            b = series[j]
            # Members may differ in record length; compare the shared span.
            n_common = min(len(a), len(b))
            a_common, b_common = a[:n_common], b[:n_common]
            valid = np.isfinite(a_common) & np.isfinite(b_common)
            if valid.sum() < 2:
                continue  # correlation is undefined; entry stays NaN
            r = np.corrcoef(a_common[valid], b_common[valid])[0, 1]
            corr[i, j] = corr[j, i] = r
    return corr, names


# -----------------------------
# Load EOFs, PCs, variance
# -----------------------------
# Per-member results, keyed by member label. A member appears in all four
# dictionaries or in none of them.
eof_patterns, pcs, varfrac, times = {}, {}, {}, {}

for m in members:
    try:
        # Read everything into memory inside the `with` block so the file is
        # closed again before the next member is opened.
        with load_eof_results(m, varname) as ds:
            member_eofs = ds["EOFs"].isel(mode=slice(0, n_modes)).load()
            member_pcs = ds["PCs"].values
            member_varfrac = ds["variance_fraction"].values
            if "time" in ds["PCs"].coords:
                years = decode_time_to_years(ds["PCs"]["time"])
            else:
                years = np.arange(member_pcs.shape[0])
    except Exception as e:
        print(f"⚠️ Skipping {m}: {e}")
        continue

    # Commit only after every field was read successfully; a partial entry
    # would surface later as a KeyError in the plotting sections.
    eof_patterns[m] = member_eofs
    pcs[m] = member_pcs
    varfrac[m] = member_varfrac
    times[m] = years

print(f"✅ Loaded EOF results for {len(eof_patterns)} members.")

# -----------------------------
# -----------------------------
# 1️⃣ Big figure: EOF maps + PC plots (one row per plot member × 4 columns)
# -----------------------------
import matplotlib.gridspec as gridspec

# One row per plotted member; three map columns plus a slightly wider column
# for the PC time series.
fig = plt.figure(figsize=(18, 12))
gs = gridspec.GridSpec(
    nrows=len(plot_members),
    ncols=4,
    width_ratios=[1, 1, 1, 1.2],
)

def draw_map(ax, field, lon, lat, title, vmin=-1.0, vmax=1.0):
    """Draw one EOF map on a Cartopy axes and return the mesh artist.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axes to draw on.
    field : 2D array-like
        Values to plot; ``lon`` and ``lat`` give the matching coordinates.
    lon, lat : array-like
        Longitudes and latitudes in degrees (interpreted as PlateCarree).
    title : str
        Panel title.
    vmin, vmax : float, optional
        Colour limits. They are fixed (default -1..1, the range of the
        correlation-type EOF maps plotted by this script) rather than
        autoscaled, so that every panel uses the same colour scale and a
        single shared colorbar describes all of them.

    Returns
    -------
    The ``pcolormesh`` artist, suitable for ``fig.colorbar``.
    """
    pcm = ax.pcolormesh(lon, lat, field, transform=ccrs.PlateCarree(),
                        cmap="RdBu_r", shading="auto", vmin=vmin, vmax=vmax)
    ax.coastlines()
    ax.add_feature(cfeature.LAND, facecolor="lightgray", zorder=2)
    # Extent is given in lon/lat degrees; state the CRS explicitly.
    ax.set_extent([-100, 20, -60, 80], crs=ccrs.PlateCarree())
    ax.set_xticks(np.arange(-100, 40, 20), crs=ccrs.PlateCarree())
    ax.set_yticks(np.arange(-60, 90, 20), crs=ccrs.PlateCarree())
    ax.tick_params(labelsize=8)
    ax.set_xlabel("Longitude (°E)", fontsize=9)
    ax.set_ylabel("Latitude (°N)", fontsize=9)
    ax.set_title(title, fontsize=9)
    return pcm

# Fill the grid: one row per plotted member, EOF1-3 maps followed by the PCs.
last_pcm = None
for i, member in enumerate(plot_members):
    if member not in eof_patterns:
        # Member was not loaded; its row in the grid stays empty.
        continue
    ds = eof_patterns[member]
    lon, lat = ds["longitude"], ds["latitude"]

    # --- Three Cartopy maps
    for mode in range(3):
        ax_map = fig.add_subplot(gs[i, mode], projection=ccrs.PlateCarree())
        title = f"EOF{mode+1} ({varfrac[member][mode]*100:.1f}%)"
        last_pcm = draw_map(ax_map, ds.isel(mode=mode), lon, lat, title)
        # The stored EOFs are correlation maps (range -1..1). Pin every panel
        # to that range so the single shared colorbar is valid for all of
        # them, not only for the last panel drawn.
        last_pcm.set_clim(-1.0, 1.0)

    # --- One normal plot for PCs
    ax_pc = fig.add_subplot(gs[i, 3])
    for mode in range(3):
        ax_pc.plot(times[member], pcs[member][:, mode],
                   label=f"PC{mode+1}", lw=1.2)
    ax_pc.set_xlim(times[member][0], times[member][-1])
    ax_pc.set_xticks(np.arange(1850, 2020, 20))
    ax_pc.set_xlabel("Year")
    ax_pc.set_ylabel("Amplitude (a.u.)")
    ax_pc.set_title(f"{member} PCs", fontsize=9)
    ax_pc.grid(True, linestyle="--", alpha=0.4)
    if i == 0:
        ax_pc.legend(fontsize=8, loc="upper right")

# Shared colorbar (only possible when at least one map was drawn)
if last_pcm is not None:
    cax = fig.add_axes([0.91, 0.15, 0.015, 0.7])
    # Correlation maps are dimensionless, whatever the variable's units are.
    fig.colorbar(last_pcm, cax=cax, label="EOF (correlation with PC)")

# Use the figure object explicitly rather than pyplot's "current figure".
fig.suptitle(
    f"{model} {varname.upper()} | EOF Modes 1–3 + PCs ({len(plot_members)} members)",
    fontsize=15,
)
fig.subplots_adjust(wspace=0.3, hspace=0.25, right=0.9, top=0.93, bottom=0.05)
fig.savefig(f"{fig_dir}/EOF_PCs_{varname}_4members_combined.png", dpi=300)
plt.close(fig)
print(f"✅ Saved combined EOF+PC figure ({len(plot_members)}×4 grid).")


# -----------------------------
# 2️⃣ PC1 time series (all members + ensemble mean)
# -----------------------------
if not pcs:
    # With no loaded members there is nothing to average, and min() below
    # would raise on an empty sequence.
    print("⚠️ No PCs loaded; skipping PC1 time series plot.")
else:
    plt.figure(figsize=(12, 6))

    # Truncate every member to the shortest record for safe averaging.
    min_len = min(pc.shape[0] for pc in pcs.values())
    pc1_matrix = [pc[:min_len, 0] for pc in pcs.values()]
    # Time axis of the first member is used for the ensemble-mean curve.
    common_time = times[next(iter(pcs))][:min_len]

    for member, pc1 in zip(pcs, pc1_matrix):
        plt.plot(times[member][:min_len], pc1, color='gray', alpha=0.4, lw=1)

    # Ensemble mean
    # NOTE(review): the sign of an EOF/PC pair is arbitrary, so members with
    # flipped signs would partly cancel in this mean - confirm that signs are
    # consistent across members before interpreting it.
    pc1_mean = np.mean(pc1_matrix, axis=0)
    plt.plot(common_time, pc1_mean, color='black', lw=2.5, label='Ensemble mean')

    # Plot styling
    plt.title(f"{model} {varname.upper()} | PC1 Time Series Across Members")
    plt.xlabel("Year")
    plt.ylabel("PC1 (a.u.)")
    plt.xticks(np.arange(1850, 2020, 20))
    plt.legend(frameon=False)
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.tight_layout()
    plt.savefig(f"{fig_dir}/PC1_timeseries_{varname}.png", dpi=300)
    plt.close()
    print("✅ Saved PC1 time series comparison plot (with ensemble mean).")

# -----------------------------
# 3️⃣ Explained variance (mean ± std)
# -----------------------------
if not varfrac:
    # No member was loaded, so there is no variance to summarise.
    print("⚠️ No variance fractions loaded; skipping explained variance plot.")
else:
    # Rows: members, columns: the leading n_modes variance fractions.
    # Iterating varfrac directly avoids looking it up with another
    # dictionary's keys.
    var_table = np.array([vf[:n_modes] for vf in varfrac.values()])
    mean_var = var_table.mean(axis=0)
    std_var = var_table.std(axis=0)

    plt.figure(figsize=(8, 5))
    x = np.arange(1, n_modes + 1)
    plt.bar(x, mean_var*100, yerr=std_var*100, capsize=5)
    plt.xticks(x, [f"EOF{i}" for i in x])
    plt.ylabel("Explained Variance (%)")
    plt.title(f"{model} {varname.upper()} | Mean Explained Variance (±1σ)")
    plt.tight_layout()
    plt.savefig(f"{fig_dir}/explained_variance_{varname}.png", dpi=300)
    plt.close()
    print("✅ Saved explained variance summary.")






📁 Figures will be saved in: /data/users/frekle/EOF/Figures/thetao
⚠️ Skipping r22i1p1f1: /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r22i1p1f1.nc
⚠️ Skipping r23i1p1f1: /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r23i1p1f1.nc
⚠️ Skipping r25i1p1f1: /data/projects/nckf/frekle/EOF_results/EC-Earth3/EOF_thetao_r25i1p1f1.nc
✅ Loaded EOF results for 22 members.
✅ Saved combined EOF+PC figure (4×4 grid).
✅ Saved PC1 time series comparison plot (with ensemble mean).
✅ Saved explained variance summary.
