In [None]:
import xarray as xr
from pathlib import Path
import myfunctions as mf

In [None]:
# ---------------------------------------------------------------
# Analysis settings
# ---------------------------------------------------------------
# CMIP6 variable name; used as a directory level in the archive path
# (under the "Omon" table) and passed to mf.read_var when loading.
varname = "tos"

# Root of the CMIP6 archive on the CEDA filesystem; every experiment
# path is built beneath this directory.
CEDA_BASE = Path("/badc", "cmip6", "data", "CMIP6")

In [None]:
# Models to process, keyed by model name. Each entry gives the institution
# directory, ensemble member and grid label used to build the archive path.
# IMPORTANT: these grid labels are NOT the same as those used for the
# atmospheric data of the same models.
# Uncomment a line to add that model to the analysis.
MODELS = {
    "UKESM1-0-LL": dict(institution="MOHC", ensemble="r1i1p1f2", grid="gn"),
    # "CNRM-ESM2-1": dict(institution="CNRM-CERFACS", ensemble="r1i1p1f2", grid="gn"),
    # "MPI-ESM1-2-LR": dict(institution="MPI-M", ensemble="r1i1p1f1", grid="gn"),
    # "CESM2-WACCM": dict(institution="NCAR", ensemble="r1i1p1f1", grid="gn"),
    # "IPSL-CM6A-LR": dict(institution="IPSL", ensemble="r1i1p1f1", grid="gr"),
}


In [None]:
# Experiments, keyed by the short label used throughout the notebook.
#   project  : activity directory under CEDA_BASE
#   scenario : experiment directory name
#   color    : only set for HIST
EXPERIMENTS = {
    "HIST":     dict(project="CMIP",        scenario="historical", color="black"),
    "SSP245":   dict(project="ScenarioMIP", scenario="ssp245"),
    "SSP585":   dict(project="ScenarioMIP", scenario="ssp585"),
    "G6solar":  dict(project="GeoMIP",      scenario="G6solar"),
    "G6sulfur": dict(project="GeoMIP",      scenario="G6sulfur"),
}

In [None]:
# --- Per-model 2081–2100 annual-mean fields and scenario differences ---
# For each model: load the scenarios below, take annual means, average over
# the period, form pairwise differences and store the resulting dict in
# MODELS[model]["ts_mean"], which the plotting cell reads.
SCENARIOS = ["SSP245", "G6sulfur", "G6solar"]
# xarray label-based slices include BOTH endpoints, so slice(2081, 2100)
# selects exactly 2081..2100 (a 2101 upper bound would pull in an extra,
# possibly partial, year for runs that extend past 2100).
years_slice = slice(2081, 2100)
n_expected_years = years_slice.stop - years_slice.start + 1

for model_name, model_meta in MODELS.items():

    var = {}
    var_by_year = {}

    # --- LOAD DATA FOR A MODEL ---
    for exp, meta in EXPERIMENTS.items():
        if exp not in SCENARIOS:
            continue

        # CESM2-WACCM's G6sulfur run uses a different ensemble member label.
        if model_name == "CESM2-WACCM":
            ensemble = "r1i1p1f2" if meta["scenario"] == "G6sulfur" else "r1i1p1f1"
        else:
            ensemble = model_meta["ensemble"]

        base = (
            CEDA_BASE
            / meta["project"]
            / model_meta["institution"]
            / model_name
            / meta["scenario"]
            / ensemble
            / "Omon"
            / varname
            / model_meta["grid"]
            / "latest"
        )
        print(str(base))

        # Some model/scenario combinations need a dedicated opener from myfunctions.
        if model_name == "CESM2-WACCM" and meta["scenario"] == "G6sulfur":
            ds = mf.open_files_CESM_G6sulfur(base)
        elif model_name == "CESM2-WACCM" and meta["scenario"] == "ssp585":
            ds = mf.open_files_CESM_ssp585(base)
        elif model_name == "IPSL-CM6A-LR" and meta["scenario"] == "ssp585":
            ds = mf.open_files_IPSL_ssp585(base)
        else:
            ds = mf.open_files(str(base))

        var[exp] = mf.read_var(ds, varname)

    # --- ANNUAL MEAN FOR THIS MODEL (months 1..12) ---
    for exp, da in var.items():
        var_by_year[exp] = {"ANN": mf.seasonal_mean_by_year(da, 1, 12)}

    # --- 2081–2100 MEAN FOR THIS MODEL ---
    ts_mean = {}
    for exp in SCENARIOS:
        period = var_by_year[exp]["ANN"].sel(year=years_slice)
        # Averaging over a short or missing period would silently bias the result.
        if period.sizes["year"] != n_expected_years:
            print(
                f"WARNING: {model_name} {exp} has {period.sizes['year']} years in "
                f"{years_slice.start}-{years_slice.stop} (expected {n_expected_years})"
            )
        ts_mean[exp] = period.mean(dim="year")

    # --- DIFFERENCES (MODEL-SPECIFIC) ---
    ts_mean["G6sulfur-SSP245"] = ts_mean["G6sulfur"] - ts_mean["SSP245"]
    ts_mean["G6solar-SSP245"]  = ts_mean["G6solar"]  - ts_mean["SSP245"]
    ts_mean["G6solar-G6sulfur"] = ts_mean["G6solar"] - ts_mean["G6sulfur"]

    # Store on the MODELS entry for the plotting cell.
    # NOTE: the relative-SST cell further down overwrites this same key.
    model_meta["ts_mean"] = ts_mean
    # Verification: unweighted mean of the SSP245 field over all grid points
    print(model_name, ts_mean["SSP245"].mean().values)

In [None]:
# Inspect the per-experiment results dict left over from the loop above:
# `ts_mean` is rebuilt for every model, so after the loop it holds the
# results for the last model in MODELS.
ts_mean

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np

# --- 2081–2100 maps: SSP245 SST and scenario differences ---
# One column per model, one row per entry in PANELS. Every row uses a fixed
# colour range shared across models and gets one colorbar on the right.

ssp245_vmin = 20
ssp245_vmax = 33
g6sulfur_vmax = g6solar_vmax = g6solar_g6sulfur_vmax = 1

# Rows: (ts_mean key, panel title, vmin, vmax, colormap)
PANELS = [
    ("SSP245",           "SSP245",             ssp245_vmin,            ssp245_vmax,           "jet"),
    ("G6sulfur-SSP245",  "G6sulfur − SSP245",  -g6sulfur_vmax,         g6sulfur_vmax,         "bwr"),
    ("G6solar-SSP245",   "G6solar − SSP245",   -g6solar_vmax,          g6solar_vmax,          "bwr"),
    ("G6solar-G6sulfur", "G6solar − G6sulfur", -g6solar_g6sulfur_vmax, g6solar_g6sulfur_vmax, "bwr"),
]

n_rows = len(PANELS)
n_cols = len(MODELS)

# squeeze=False always returns a 2-D (n_rows, n_cols) array of axes, so a
# single model works without a placeholder (empty) extra column.
fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(2 * 4 * n_cols, 3 * n_rows),
    squeeze=False,
    subplot_kw={"projection": ccrs.PlateCarree(central_longitude=180)},
)

# Last mappable drawn in each row; all panels in a row share vmin/vmax,
# so any of them can drive that row's colorbar.
row_images = [None] * n_rows

for col_idx, (model_name, model_meta) in enumerate(MODELS.items()):
    model_ts = model_meta["ts_mean"]
    for row_idx, (key, label, vmin, vmax, cmap) in enumerate(PANELS):
        ax = axes[row_idx, col_idx]
        field = model_ts[key]
        row_images[row_idx] = ax.pcolormesh(
            field.longitude,
            field.latitude,
            field,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            shading="auto",
            transform=ccrs.PlateCarree(),
        )
        # Paint land white above the data (zorder=10), then draw coastlines.
        ax.add_feature(cfeature.LAND, facecolor="white", zorder=10)
        ax.coastlines()
        # Only the top row carries the model name.
        ax.set_title(f"{model_name} {label}" if row_idx == 0 else label, fontsize=10)

# --- Shared colorbars per row, aligned with the rightmost panel ---
for row_idx, im in enumerate(row_images):
    pos = axes[row_idx, -1].get_position()  # Bbox in figure coordinates
    # [left, bottom, width, height]: a thin strip just right of the row
    cbar_ax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.015, pos.height])
    fig.colorbar(im, cax=cbar_ax, orientation="vertical")

plt.show()


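***NORMALIZED (RELATIVE) SST***
For each difference field, subtract its own spatial mean over the tropical band (±`trop_lat`, i.e. 30°S–30°N). The maps then show the pattern of SST change relative to the tropical-mean change. The mean is a simple, unweighted average over grid cells.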

In [None]:
import numpy as np
import xarray as xr

# --- CONFIG ---
# xarray label-based slices include BOTH endpoints: slice(2081, 2100) is
# exactly 2081..2100 (20 years).
years_slice = slice(2081, 2100)
SCENARIOS = ["SSP245", "G6sulfur", "G6solar"]
trop_lat = 30  # half-width of the tropical band, degrees latitude
n_expected_years = years_slice.stop - years_slice.start + 1

# --- LOOP OVER MODELS ---
for model_name, model_meta in MODELS.items():

    var = {}
    var_by_year = {}

    # --- LOAD DATA FOR EACH SCENARIO ---
    for exp, meta in EXPERIMENTS.items():
        if exp not in SCENARIOS:
            continue

        # CESM2-WACCM's G6sulfur run uses a different ensemble member label.
        if model_name == "CESM2-WACCM":
            ensemble = "r1i1p1f2" if meta["scenario"] == "G6sulfur" else "r1i1p1f1"
        else:
            ensemble = model_meta["ensemble"]

        base = (
            CEDA_BASE
            / meta["project"]
            / model_meta["institution"]
            / model_name
            / meta["scenario"]
            / ensemble
            / "Omon"
            / varname
            / model_meta["grid"]
            / "latest"
        )

        print(f"Loading {exp} for {model_name}: {base}")
        # Some model/scenario combinations need a dedicated opener from myfunctions.
        if model_name == "CESM2-WACCM" and meta["scenario"] == "G6sulfur":
            ds = mf.open_files_CESM_G6sulfur(base)
        elif model_name == "CESM2-WACCM" and meta["scenario"] == "ssp585":
            ds = mf.open_files_CESM_ssp585(base)
        elif model_name == "IPSL-CM6A-LR" and meta["scenario"] == "ssp585":
            ds = mf.open_files_IPSL_ssp585(base)
        else:
            ds = mf.open_files(str(base))

        var[exp] = mf.read_var(ds, varname)

    # --- ANNUAL MEAN (months 1..12) ---
    for exp, da in var.items():
        var_by_year[exp] = {"ANN": mf.seasonal_mean_by_year(da, 1, 12)}

    # --- TROPICAL MASK USING 1D LATITUDE ---
    # Collapse the 2-D (j, i) latitude to one value per j-row by averaging
    # over i, then keep rows inside ±trop_lat. The mask is built from the
    # first scenario and applied to all, assuming they share one grid.
    # Assumes "j"/"i" dims and a "latitude" coord as returned by
    # mf.read_var for this model — TODO confirm for other models.
    # NOTE(review): on a curvilinear ocean grid latitude is not constant
    # along a j-row, so rows near the band edge may include cells slightly
    # outside ±trop_lat.
    lat_2d = var_by_year[SCENARIOS[0]]["ANN"]["latitude"]
    lat_1d = lat_2d.mean(dim="i").compute()
    trop_mask = (lat_1d >= -trop_lat) & (lat_1d <= trop_lat)

    # --- 2081–2100 TIME MEAN, RESTRICTED TO THE TROPICAL ROWS ---
    ts_mean = {}
    for exp in SCENARIOS:
        period = var_by_year[exp]["ANN"].sel(year=years_slice)
        # Averaging over a short or missing period would silently bias the result.
        if period.sizes["year"] != n_expected_years:
            print(
                f"WARNING: {model_name} {exp} has {period.sizes['year']} years in "
                f"{years_slice.start}-{years_slice.stop} (expected {n_expected_years})"
            )
        # Mean over years only; the field stays 2-D over the tropical rows.
        ts_mean[exp] = period.sel(j=trop_mask).mean(dim="year")

    # --- DIFFERENCES ---
    ts_mean["G6sulfur-SSP245"] = ts_mean["G6sulfur"] - ts_mean["SSP245"]
    ts_mean["G6solar-SSP245"]  = ts_mean["G6solar"]  - ts_mean["SSP245"]
    ts_mean["G6solar-G6sulfur"] = ts_mean["G6solar"] - ts_mean["G6sulfur"]

    # --- NORMALIZE DIFFERENCES ---
    # Subtract each difference field's own tropical spatial mean so the maps
    # show the pattern of change relative to the tropical-mean change.
    # NOTE(review): this is an unweighted mean over grid cells; an
    # area-weighted mean would be more representative — to be decided.
    for key in ["G6sulfur-SSP245", "G6solar-SSP245", "G6solar-G6sulfur"]:
        ts_mean[key] = ts_mean[key] - ts_mean[key].mean(dim=("j", "i"))

    # --- STORE RESULTS ---
    # Overwrites the absolute-difference results stored by the earlier cell.
    model_meta["ts_mean"] = ts_mean

    # --- VERIFICATION ---
    print(f"{model_name} SSP245 tropical mean: {ts_mean['SSP245'].mean().values:.3f} °C")


In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np

# --- 2081–2100 maps: tropical SSP245 SST and relative-SST differences ---
# One column per model, one row per entry in PANELS. The difference fields
# have had their own tropical spatial mean removed in the previous cell.
# Every row uses a fixed colour range shared across models and gets one
# colorbar on the right.

ssp245_vmin = 20
ssp245_vmax = 33
g6sulfur_vmax = g6solar_vmax = g6solar_g6sulfur_vmax = 1

# Rows: (ts_mean key, panel title, vmin, vmax, colormap)
PANELS = [
    ("SSP245",           "SSP245",                            ssp245_vmin,            ssp245_vmax,           "jet"),
    ("G6sulfur-SSP245",  "G6sulfur − SSP245 (Relative SST)",  -g6sulfur_vmax,         g6sulfur_vmax,         "bwr"),
    ("G6solar-SSP245",   "G6solar − SSP245 (Relative SST)",   -g6solar_vmax,          g6solar_vmax,          "bwr"),
    ("G6solar-G6sulfur", "G6solar − G6sulfur (Relative SST)", -g6solar_g6sulfur_vmax, g6solar_g6sulfur_vmax, "bwr"),
]

n_rows = len(PANELS)
n_cols = len(MODELS)

# squeeze=False always returns a 2-D (n_rows, n_cols) array of axes, so a
# single model works without a placeholder (empty) extra column.
fig, axes = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(2 * 4 * n_cols, 3 * n_rows),
    squeeze=False,
    subplot_kw={"projection": ccrs.PlateCarree(central_longitude=180)},
)

# Last mappable drawn in each row; all panels in a row share vmin/vmax,
# so any of them can drive that row's colorbar.
row_images = [None] * n_rows

for col_idx, (model_name, model_meta) in enumerate(MODELS.items()):
    model_ts = model_meta["ts_mean"]
    for row_idx, (key, label, vmin, vmax, cmap) in enumerate(PANELS):
        ax = axes[row_idx, col_idx]
        field = model_ts[key]
        row_images[row_idx] = ax.pcolormesh(
            field.longitude,
            field.latitude,
            field,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            shading="auto",
            transform=ccrs.PlateCarree(),
        )
        # Paint land white above the data (zorder=10), then draw coastlines.
        ax.add_feature(cfeature.LAND, facecolor="white", zorder=10)
        ax.coastlines()
        # Only the top row carries the model name.
        ax.set_title(f"{model_name} {label}" if row_idx == 0 else label, fontsize=10)

# --- Shared colorbars per row, aligned with the rightmost panel ---
for row_idx, im in enumerate(row_images):
    pos = axes[row_idx, -1].get_position()  # Bbox in figure coordinates
    # [left, bottom, width, height]: a thin strip just right of the row
    cbar_ax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.015, pos.height])
    fig.colorbar(im, cax=cbar_ax, orientation="vertical")

plt.show()
