In [None]:
import xarray as xr
from pathlib import Path
import myfunctions as mf

In [None]:
# =========================
# User-defined metadata
# =========================

# CMIP6 variable name: used as a directory component of every CEDA path
# and as the variable read from each opened dataset.
varname = "ta"

# =========================
# Base CEDA paths
# =========================

# Root of the CMIP6 tree in the CEDA archive; experiment paths are built below it.
CEDA_BASE = Path("/badc", "cmip6", "data", "CMIP6")


# =========================
# Analysis period
# =========================
# Both bounds are used as labels in .sel(year=slice(year_start, year_end)),
# so the final year is included.
year_start, year_end = 2071, 2100

In [None]:
# Models to analyse. Each entry gives the institution, ensemble member and
# grid label that are used as components of the CEDA directory path.
# Other candidate models are kept commented out for reference.
MODELS = {
    "UKESM1-0-LL": dict(institution="MOHC", ensemble="r1i1p1f2", grid="gn"),
    # "CNRM-ESM2-1":   dict(institution="CNRM-CERFACS", ensemble="r1i1p1f2", grid="gr"),
    # "MPI-ESM1-2-LR": dict(institution="MPI-M",        ensemble="r1i1p1f1", grid="gn"),
    # "CESM2-WACCM":   dict(institution="NCAR",         ensemble="r1i1p1f1", grid="gn"),
    # "IPSL-CM6A-LR":  dict(institution="IPSL",         ensemble="r1i1p1f1", grid="gr"),
}


In [None]:
# Experiments: "project" is the first directory level under CEDA_BASE and
# "scenario" is the experiment directory name.
# NOTE(review): only HIST has a "color" entry; it is not read in this notebook's cells.
EXPERIMENTS = {
    "HIST": dict(project="CMIP", scenario="historical", color="black"),
    "SSP245": dict(project="ScenarioMIP", scenario="ssp245"),
    "SSP585": dict(project="ScenarioMIP", scenario="ssp585"),
    "G6solar": dict(project="GeoMIP", scenario="G6solar"),
    "G6sulfur": dict(project="GeoMIP", scenario="G6sulfur"),
}

In [None]:
# Experiments loaded and analysed for every model (differences are formed below).
EXPS_TO_ANALYSE = ("SSP245", "G6sulfur", "G6solar")


def _select_ensemble(model_name, model_meta, scenario):
    """Return the ensemble member directory to load for one model/scenario.

    CESM2-WACCM overrides its MODELS entry: r1i1p1f2 for G6sulfur and
    r1i1p1f1 otherwise (NOTE(review): presumably what is available on
    CEDA for that model — confirm). All other models use model_meta["ensemble"].
    """
    if model_name == "CESM2-WACCM":
        return "r1i1p1f2" if scenario == "G6sulfur" else "r1i1p1f1"
    return model_meta["ensemble"]


# (model, scenario) pairs that need a dedicated reader from myfunctions.
# Stored as attribute *names* so mf is only consulted for pairs actually loaded.
_SPECIAL_OPENERS = {
    ("CESM2-WACCM", "G6sulfur"): "open_files_CESM_G6sulfur",
    ("CESM2-WACCM", "ssp585"): "open_files_CESM_ssp585",
    ("IPSL-CM6A-LR", "ssp585"): "open_files_IPSL_ssp585",
}


def _open_experiment(model_name, scenario, base):
    """Open the dataset stored in directory `base`.

    Special readers are called with the Path object; the generic
    mf.open_files is called with str(base), as in the original cell.
    """
    opener_name = _SPECIAL_OPENERS.get((model_name, scenario))
    if opener_name is not None:
        return getattr(mf, opener_name)(base)
    return mf.open_files(str(base))


for model_name, model_meta in MODELS.items():

    var = {}          # raw DataArray per experiment, as returned by mf.read_var
    var_by_year = {}  # {"ANN": annual mean} per experiment

    # --- LOAD DATA FOR A MODEL ---
    for exp, meta in EXPERIMENTS.items():
        if exp not in EXPS_TO_ANALYSE:
            continue

        ensemble = _select_ensemble(model_name, model_meta, meta["scenario"])

        # CEDA layout: project/institution/model/scenario/ensemble/Amon/varname/grid/latest
        base = (
            CEDA_BASE
            / meta["project"]
            / model_meta["institution"]
            / model_name
            / meta["scenario"]
            / ensemble
            / "Amon"
            / varname
            / model_meta["grid"]
            / "latest"
        )
        print(base)

        ds = _open_experiment(model_name, meta["scenario"], base)
        var[exp] = mf.read_var(ds, varname)

    # --- ANNUAL MEAN, THEN year_start–year_end (inclusive) AND LONGITUDE MEAN ---
    ts_mean = {}
    for exp in EXPS_TO_ANALYSE:
        da = var[exp]
        var_by_year[exp] = {"ANN": mf.seasonal_mean_by_year(da, 1, 12)}
        # Resolve the longitude dimension per experiment instead of reusing
        # whichever experiment happened to be processed last.
        londim = mf.get_lon_dim(da)
        ts_mean[exp] = (
            var_by_year[exp]["ANN"]
            .sel(year=slice(year_start, year_end))
            .mean(dim=("year", londim))
        )

    # --- DIFFERENCES (MODEL-SPECIFIC) ---
    ts_mean["G6sulfur-SSP245"] = ts_mean["G6sulfur"] - ts_mean["SSP245"]
    ts_mean["G6solar-SSP245"] = ts_mean["G6solar"] - ts_mean["SSP245"]
    ts_mean["G6solar-G6sulfur"] = ts_mean["G6solar"] - ts_mean["G6sulfur"]

    # Store on this model's MODELS entry; the plotting cell reads model_meta["ts_mean"].
    model_meta["ts_mean"] = ts_mean
    # Verification: overall mean of the SSP245 field for this model
    print(model_name, ts_mean["SSP245"].mean().values)

In [None]:
# Per-model results stored by the loading loop (the loop variable `ts_mean`
# alone would only hold the last model's results).
{name: meta["ts_mean"] for name, meta in MODELS.items()}

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --- Settings ---
SCENARIOS = ["SSP245", "G6sulfur", "G6solar"]
DIFFS = ["G6sulfur-SSP245", "G6solar-SSP245", "G6solar-G6sulfur"]

# --- Colorbar limits ---
ssp245_vmin, ssp245_vmax = 220, 310  # absolute field range (presumably K for `ta`)
diff_vmax = 5  # symmetric limit shared by all difference rows

# One figure row per panel: (ts_mean key, title template, colormap, vmin, vmax).
PANELS = [("SSP245", "{model} SSP245", "coolwarm", ssp245_vmin, ssp245_vmax)] + [
    (key, key.replace("-", " − "), "bwr", -diff_vmax, diff_vmax) for key in DIFFS
]

n_models = len(MODELS)
if n_models == 0:
    raise ValueError("MODELS is empty - nothing to plot")
n_rows = len(PANELS)  # SSP245 + 3 differences
n_cols = n_models     # exactly one column per model (no blank extra column)

fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(4 * n_cols, 3 * n_rows),
    squeeze=False  # ensures axes is 2D
)

row_images = [None] * n_rows  # last mesh drawn in each row, used for the shared colorbars

# --- Plotting ---
for col_idx, (model_name, model_meta) in enumerate(MODELS.items()):
    results = model_meta["ts_mean"]
    ref = results["SSP245"]

    # Use lev if it exists, else plev
    lev_var = "lev" if "lev" in ref.dims else "plev"
    lev = ref[lev_var].values
    # Convert Pa → hPa only when necessary (values above 2000 cannot be hPa).
    if lev.max() > 2000:
        lev = lev / 100
    lat = ref["lat"].values

    for row_idx, (key, title, cmap, vmin, vmax) in enumerate(PANELS):
        ax = axes[row_idx, col_idx]
        row_images[row_idx] = ax.pcolormesh(
            lat, lev, results[key].values,
            vmin=vmin, vmax=vmax,
            cmap=cmap,
            shading="auto"
        )
        ax.invert_yaxis()  # surface (high pressure) at the bottom
        if col_idx == 0:
            ax.set_ylabel("Pressure (hPa)")
        ax.set_title(title.format(model=model_name), fontsize=10)

    axes[-1, col_idx].set_xlabel("Latitude")

# --- Shared colorbars per row, aligned with the rightmost model panel ---
for row_idx, im in enumerate(row_images):
    pos = axes[row_idx, -1].get_position()  # Bbox in figure coordinates
    # left, bottom, width, height: a thin strip just right of the panel
    cbar_ax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.015, pos.height])
    fig.colorbar(im, cax=cbar_ax, orientation="vertical")

# tight_layout is not used: the colorbar axes are positioned manually above.
plt.show()


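**Note:** none of the other candidate models provides air temperature (`ta`) on vertical levels. Only UKESM1-0-LL is enabled in `MODELS`, so it is the only model shown in the figure above.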