**Script to check basic features of UKESM G6solar precip field.**

G6sulfur SAI scheme:

Aerosol is injected in a 20-degree wide curtain along 0E longitude
*A high-warming scenario (SSP5‑8.5) is run, but SO₂ is continuously injected*
*so that global mean temperatures track the more moderate SSP2‑4.5 pathway*

In G6solar simulations the solar constant is tuned (mimicking G6sulfur SAI) to cool temperatures
from those of SSP5‑8.5 (very high warming) down toward those of SSP2‑4.5 (moderate warming).

In this script:

1) Compare annual-mean rainfall in G6solar (SAI simulation) with SSP5‑8.5 (baseline simulation)


In [None]:
import xarray as xr
from pathlib import Path

In [None]:
# ---------------------------------------------------------
# Dataset selection (edit these values to switch inputs)
# ---------------------------------------------------------
varname = "pr"               # variable directory name, also the key read from each dataset
ensemble_name = "r1i1p1f2"   # ensemble-member directory name in the archive path
model_name = "UKESM1-0-LL"   # model directory name in the archive path

In [None]:
# ---------------------------------------------------------
# Archive directories holding the "Amon" files per experiment
# ---------------------------------------------------------
CEDA_BASE = Path("/badc/cmip6/data/CMIP6")


def _experiment_dir(activity, experiment):
    """Return CEDA_BASE/<activity>/MOHC/<model>/<experiment>/<ensemble>/Amon/<var>/gn/latest."""
    return (
        CEDA_BASE / activity / "MOHC" / model_name / experiment
        / ensemble_name / "Amon" / varname / "gn" / "latest"
    )


# Path objects, one per experiment
SSP585_base = _experiment_dir("ScenarioMIP", "ssp585")
SSP245_base = _experiment_dir("ScenarioMIP", "ssp245")
Hist_base = _experiment_dir("CMIP", "historical")
SAI_base = _experiment_dir("GeoMIP", "G6solar")

# The same locations as plain strings (open_files builds its glob from these)
SSP585_path = str(SSP585_base)
SSP245_path = str(SSP245_base)
Hist_path = str(Hist_base)
SAI_path = str(SAI_base)

In [None]:
# =========================
# Functions
# =========================
def open_files(path):
    """
    Open every ``*.nc`` file found directly inside *path* as one dataset.

    The glob ``<path>/*.nc`` is handed to ``xr.open_mfdataset`` with
    ``combine="by_coords"`` and ``parallel=True``; whatever that call
    returns is returned unchanged.
    """
    nc_glob = f"{path}/*.nc"
    return xr.open_mfdataset(nc_glob, combine="by_coords", parallel=True)

def read_var(ds, var):
    """Return ``ds[var]``: the entry of *ds* stored under the name *var*."""
    selected = ds[var]
    return selected

In [None]:
# Open one dataset per experiment; targets and paths are listed in the same order
ds_ssp585, ds_ssp245, ds_Hist, ds_sai = (
    open_files(directory)
    for directory in (SSP585_path, SSP245_path, Hist_path, SAI_path)
)

In [None]:
# Pull the variable named by `varname` out of each experiment's dataset
pr_ssp585, pr_ssp245, pr_hist, pr_sai = (
    read_var(dataset, varname)
    for dataset in (ds_ssp585, ds_ssp245, ds_Hist, ds_sai)
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Quick-look horizontal-mean precipitation for each experiment.
# A plain mean over (lat, lon) gives every grid cell the same weight, which
# over-represents high latitudes on a lat-lon grid, so each field is weighted
# by cos(latitude) instead.
# NOTE(review): assumes "lat" is in degrees on a regular lat-lon grid - confirm.
for _pr in (pr_sai, pr_ssp585, pr_ssp245, pr_hist):
    _area_weights = np.cos(np.deg2rad(_pr["lat"]))
    plt.plot(_pr.weighted(_area_weights).mean(dim=("lat", "lon")))

In [None]:
def climatological_mean(da, start_month, end_month):
    """
    Compute the mean over all time steps that fall inside a month window.

    The window is inclusive at both ends. When ``start_month`` is greater
    than ``end_month`` the window wraps around the year end (e.g. 12 -> 2
    selects December, January and February). Every selected time step gets
    equal weight in the mean.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with a datetime-like ``time`` coordinate
    start_month : int
        First month of the window (1-12)
    end_month : int
        Last month of the window (1-12)

    Returns
    -------
    xarray.DataArray
        Mean of the selected time steps, with ``time`` reduced away

    Raises
    ------
    ValueError
        If either month lies outside 1-12. Without this check an invalid
        month would silently select the wrong set of time steps.
    """
    for name, value in (("start_month", start_month), ("end_month", end_month)):
        if not 1 <= value <= 12:
            raise ValueError(f"{name} must be between 1 and 12, got {value!r}")

    month = da["time"].dt.month

    if start_month <= end_month:
        # Normal window (e.g. JJAS): month must satisfy both bounds
        in_season = (month >= start_month) & (month <= end_month)
    else:
        # Wrapping window (e.g. DJF): month lies at either end of the year
        in_season = (month >= start_month) | (month <= end_month)

    # drop=True removes the unselected time steps instead of leaving NaNs
    return da.where(in_season, drop=True).mean("time")


In [None]:
# Climatological means per experiment for three month windows, in this order:
#   annual (Jan–Dec), JJAS (Jun–Sep), DJF (Dec–Feb, wrapping the year end)
# NOTE(review): each mean spans the full time axis of its own experiment, so
# the averaging periods may differ between experiments - confirm this is intended.
_MONTH_WINDOWS = ((1, 12), (6, 9), (12, 2))

pr_sai_ann, pr_sai_jjas, pr_sai_djf = (
    climatological_mean(pr_sai, first, last) for first, last in _MONTH_WINDOWS
)
pr_ssp585_ann, pr_ssp585_jjas, pr_ssp585_djf = (
    climatological_mean(pr_ssp585, first, last) for first, last in _MONTH_WINDOWS
)
pr_ssp245_ann, pr_ssp245_jjas, pr_ssp245_djf = (
    climatological_mean(pr_ssp245, first, last) for first, last in _MONTH_WINDOWS
)
pr_hist_ann, pr_hist_jjas, pr_hist_djf = (
    climatological_mean(pr_hist, first, last) for first, last in _MONTH_WINDOWS
)



In [None]:
def plot_precip_diff_3panel(
    ann_A, jjas_A, djf_A,
    ann_B, jjas_B, djf_B,
    label_A="EXP_A",
    label_B="EXP_B",
    lat_min=-50,
    lat_max=50,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r"
):
    """
    Draw three stacked maps of A − B: annual, JJAS and DJF.

    Each difference is multiplied by 86400 and cut to the latitude band
    ``lat_min``..``lat_max`` before plotting. All panels share the colour
    limits ``vmin``/``vmax``, the colormap ``cmap`` and a single horizontal
    colorbar labelled in mm day^-1. ``label_A`` and ``label_B`` only appear
    in the panel titles.

    The six fields are expected to be 2-D (lat, lon) climatological means.
    The figure is shown with ``plt.show()``; nothing is returned.
    """

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    seconds_per_day = 86400.0
    plate_carree = ccrs.PlateCarree()
    lat_band = slice(lat_min, lat_max)

    # (title, difference field) pairs in top-to-bottom panel order
    panels = [
        (
            f"{season} mean ({label_A} − {label_B})",
            ((field_a - field_b) * seconds_per_day).sel(lat=lat_band),
        )
        for season, field_a, field_b in (
            ("Annual", ann_A, ann_B),
            ("JJAS", jjas_A, jjas_B),
            ("DJF", djf_A, djf_B),
        )
    ]

    fig, axes = plt.subplots(
        nrows=3, ncols=1,
        figsize=(10, 13),
        subplot_kw={"projection": plate_carree},
        constrained_layout=True,
    )

    for panel_ax, (panel_title, difference) in zip(axes, panels):
        mesh = difference.plot(
            ax=panel_ax,
            transform=plate_carree,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            add_colorbar=False,
        )
        panel_ax.coastlines(linewidth=0.8)
        panel_ax.add_feature(cfeature.BORDERS, linewidth=0.4)
        panel_ax.set_title(panel_title)
        # A fixed map extent can be forced here with
        # panel_ax.set_extent([0, 360, lat_min, lat_max], crs=plate_carree)

    # One colorbar for all panels, built from the last mesh drawn
    # (every panel uses the same vmin/vmax/cmap)
    shared_bar = fig.colorbar(
        mesh,
        ax=axes,
        orientation="horizontal",
        fraction=0.05,
        pad=0.08,
    )
    shared_bar.set_label("Precipitation difference (mm day$^{-1}$)")

    plt.show()


In [None]:
# Difference maps: G6solar minus SSP585 (annual, JJAS, DJF)
_g6solar_fields = (pr_sai_ann, pr_sai_jjas, pr_sai_djf)
_ssp585_fields = (pr_ssp585_ann, pr_ssp585_jjas, pr_ssp585_djf)
plot_precip_diff_3panel(
    *_g6solar_fields,
    *_ssp585_fields,
    label_A="G6solar",
    label_B="SSP585",
)


In [None]:
# Difference maps: G6solar minus SSP245 (annual, JJAS, DJF)
_g6solar_fields = (pr_sai_ann, pr_sai_jjas, pr_sai_djf)
_ssp245_fields = (pr_ssp245_ann, pr_ssp245_jjas, pr_ssp245_djf)
plot_precip_diff_3panel(
    *_g6solar_fields,
    *_ssp245_fields,
    label_A="G6solar",
    label_B="SSP245",
)


In [None]:
# Difference maps: G6solar minus the historical run (annual, JJAS, DJF)
_g6solar_fields = (pr_sai_ann, pr_sai_jjas, pr_sai_djf)
_hist_fields = (pr_hist_ann, pr_hist_jjas, pr_hist_djf)
plot_precip_diff_3panel(
    *_g6solar_fields,
    *_hist_fields,
    label_A="G6solar",
    label_B="Hist",
)


In [None]:
import xarray as xr

def seasonal_mean_by_year(da, start_month, end_month):
    """
    Compute the mean over a month window separately for each year.

    The window is inclusive at both ends. When ``start_month`` is greater
    than ``end_month`` the window wraps around the year end; every month from
    ``start_month`` through December is then grouped with the January..
    ``end_month`` months that follow it, and the group is labelled with the
    year of those later months (DJF of Dec 2020-Feb 2021 is year 2021).

    No check is made that a group contains the whole window, so for a
    wrapping window the first and last years of the record can be means of
    fewer months than the others.

    Parameters
    ----------
    da : xarray.DataArray
        Input data with a datetime-like ``time`` coordinate
    start_month : int
        First month of the window (1-12)
    end_month : int
        Last month of the window (1-12)

    Returns
    -------
    xarray.DataArray
        Seasonal means with ``time`` replaced by a ``year`` dimension

    Raises
    ------
    ValueError
        If either month lies outside 1-12.
    """
    for name, value in (("start_month", start_month), ("end_month", end_month)):
        if not 1 <= value <= 12:
            raise ValueError(f"{name} must be between 1 and 12, got {value!r}")

    month = da["time"].dt.month

    if start_month <= end_month:
        # Normal window (e.g. JJAS): all selected months share a calendar year
        in_season = (month >= start_month) & (month <= end_month)
        return da.where(in_season, drop=True).groupby("time.year").mean("time")

    # Wrapping window (e.g. DJF)
    in_season = (month >= start_month) | (month <= end_month)
    da_sel = da.where(in_season, drop=True)

    # Month/year must be taken from the reduced time axis so they line up with da_sel
    month_sel = da_sel["time"].dt.month
    year_sel = da_sel["time"].dt.year

    # Shift every month in the start_month..December part of the window to the
    # following year. Comparing against December only would leave e.g. November
    # of a Nov-Feb window grouped with the previous winter.
    season_year = xr.where(month_sel >= start_month, year_sel + 1, year_sel)

    return (
        da_sel
        .assign_coords(season_year=("time", season_year.data))
        .groupby("season_year")
        .mean("time")
        .rename({"season_year": "year"})
    )


In [None]:

# #Compute mean By year
# #Annual mean (Jan–Dec)
# pr_sai_ann_by_year = seasonal_mean_by_year(pr_sai, 1, 12)
# pr_ssp585_ann_by_year = seasonal_mean_by_year(pr_ssp585, 1, 12)
# pr_ssp245_ann_by_year = seasonal_mean_by_year(pr_ssp245, 1, 12)
# pr_hist_ann_by_year = seasonal_mean_by_year(pr_hist, 1, 12)

# #JJAS (Jun–Sep)
# pr_sai_jjas_by_year = seasonal_mean_by_year(pr_sai, 6, 9)
# pr_ssp585_jjas_by_year = seasonal_mean_by_year(pr_ssp585, 6, 9)
# pr_ssp245_jjas_by_year = seasonal_mean_by_year(pr_ssp245, 6, 9)
# pr_hist_jjas_by_year = seasonal_mean_by_year(pr_hist, 6, 9)

# #DJF (Dec–Feb)
# pr_sai_djf_by_year = seasonal_mean_by_year(pr_sai, 12, 2)
# pr_ssp585_djf_by_year = seasonal_mean_by_year(pr_ssp585, 12, 2)
# pr_ssp245_djf_by_year = seasonal_mean_by_year(pr_ssp245, 12, 2)
# pr_hist_djf_by_year = seasonal_mean_by_year(pr_hist, 12, 2)

# #Compute interannual variability
# #Annual (Jan–Dec)
# pr_sai_ann_std  = pr_sai_ann_by_year.std("year")
# pr_ssp585_ann_std  = pr_ssp585_ann_by_year.std("year")
# pr_ssp245_ann_std  = pr_ssp245_ann_by_year.std("year")
# pr_hist_ann_std  = pr_hist_ann_by_year.std("year")

# #JJAS (Jun–Sep)
# pr_sai_jjas_std = pr_sai_jjas_by_year.std("year")
# pr_ssp585_jjas_std  = pr_ssp585_jjas_by_year.std("year")
# pr_ssp245_jjas_std  = pr_ssp245_jjas_by_year.std("year")
# pr_hist_jjas_std  = pr_hist_jjas_by_year.std("year")

# #DJF (Dec–Feb)
# pr_sai_djf_std  = pr_sai_djf_by_year.std("year")
# pr_ssp585_djf_std  = pr_ssp585_djf_by_year.std("year")
# pr_ssp245_djf_std  = pr_ssp245_djf_by_year.std("year")
# pr_hist_djf_std  = pr_hist_djf_by_year.std("year")



In [None]:
# Per-year seasonal means and their interannual spread, held in two nested
# mappings (experiment -> season -> field) instead of one variable per pair.

SEC_PER_DAY = 86400.0

# Experiment label -> precipitation field
experiments = {
    "HIST": pr_hist,
    "SSP245": pr_ssp245,
    "SSP585": pr_ssp585,
    "SAI": pr_sai,
}

# Season label -> (first month, last month); DJF wraps around the year end
seasons = {
    "ANN":  (1, 12),
    "JJAS": (6, 9),
    "DJF":  (12, 2),
}

# pr_by_year[experiment][season]: one mean field per year
pr_by_year = {
    exp_name: {
        season_name: seasonal_mean_by_year(field, first, last)
        for season_name, (first, last) in seasons.items()
    }
    for exp_name, field in experiments.items()
}

# pr_std[experiment][season]: standard deviation across years, scaled by SEC_PER_DAY
pr_std = {
    exp_name: {
        season_name: yearly.std("year") * SEC_PER_DAY
        for season_name, yearly in by_season.items()
    }
    for exp_name, by_season in pr_by_year.items()
}


In [None]:

# Dump the nested {experiment: {season: field}} mapping for a quick visual check
print(pr_by_year)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Horizontal mean of the per-year annual means, one line per experiment.
# A plain mean over (lat, lon) over-represents high latitudes on a lat-lon
# grid, so the fields are weighted by cos(latitude).
# NOTE(review): assumes "lat" is in degrees on a regular lat-lon grid - confirm.
for _exp in ("SAI", "SSP245", "SSP585", "HIST"):
    _annual = pr_by_year[_exp]["ANN"]
    _area_weights = np.cos(np.deg2rad(_annual["lat"]))
    plt.plot(_annual.weighted(_area_weights).mean(dim=("lat", "lon")))

In [None]:
# Display the "year" coordinate of the historical annual means
pr_by_year['HIST']['ANN'].year

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def _cos_lat_mean(field):
    """Return the (lat, lon) mean of *field* weighted by cos(latitude).

    A plain mean over (lat, lon) over-represents high latitudes on a
    lat-lon grid. Assumes "lat" is in degrees - confirm.
    """
    area_weights = np.cos(np.deg2rad(field["lat"]))
    return field.weighted(area_weights).mean(dim=("lat", "lon"))


fig, ax = plt.subplots(figsize=(10, 4))

# Historical run, drawn in black
hist = _cos_lat_mean(pr_by_year['HIST']['ANN'])
ax.plot(hist.year, hist, label="HIST", color="black")

# Scenario and G6solar runs, drawn with the default colour cycle
ssp585 = _cos_lat_mean(pr_by_year['SSP585']['ANN'])
ssp245 = _cos_lat_mean(pr_by_year['SSP245']['ANN'])
sai    = _cos_lat_mean(pr_by_year['SAI']['ANN'])

ax.plot(ssp585.year, ssp585, label="SSP585")
ax.plot(ssp245.year, ssp245, label="SSP245")
ax.plot(sai.year,    sai,    label="SAI")

ax.set_xlim(1850, 2100)
ax.set_xlabel("Year")
# pr_by_year is not scaled by SEC_PER_DAY, so the values keep the input units.
# NOTE(review): kg m-2 s-1 is taken from the unit-conversion comments elsewhere
# in this notebook - confirm against the file metadata.
ax.set_ylabel("Precipitation (kg m$^{-2}$ s$^{-1}$)")
ax.legend()
ax.grid(True)

plt.show()


In [None]:
# A 2×2 plotting function (STD maps)

def plot_std_2x2(
    std_dict,
    season="ANN",
    lat_min=-50,
    lat_max=50,
    lon_min=0,
    lon_max=360,
    vmin=-3,
    vmax=3,
    cmap="viridis"
):
    """
    Draw a 2x2 grid of maps, one per experiment, for a single season.

    Panels are filled row by row in the order HIST, SSP245, SSP585, SAI.
    Each map shows ``std_dict[experiment][season]`` cut to the box
    ``lat_min``..``lat_max`` / ``lon_min``..``lon_max``. All panels share
    the colour limits ``vmin``/``vmax``, the colormap ``cmap`` and one
    horizontal colorbar. The figure is shown with ``plt.show()``; nothing
    is returned.
    """

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    plate_carree = ccrs.PlateCarree()
    lat_band = slice(lat_min, lat_max)
    lon_band = slice(lon_min, lon_max)

    fig, axes = plt.subplots(
        nrows=2, ncols=2,
        figsize=(12, 8),
        subplot_kw={"projection": plate_carree},
        constrained_layout=True,
    )

    for panel_ax, exp in zip(axes.flat, ("HIST", "SSP245", "SSP585", "SAI")):
        region = std_dict[exp][season].sel(lat=lat_band, lon=lon_band)

        mesh = region.plot(
            ax=panel_ax,
            transform=plate_carree,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            add_colorbar=False,
        )

        panel_ax.coastlines(linewidth=0.8)
        panel_ax.add_feature(cfeature.BORDERS, linewidth=0.4)
        panel_ax.set_title(f"{exp} – {season}")

    # One colorbar for the whole grid, built from the last mesh drawn
    # (every panel uses the same vmin/vmax/cmap)
    shared_bar = fig.colorbar(
        mesh,
        ax=axes,
        orientation="horizontal",
        fraction=0.05,
        pad=0.07,
    )
    shared_bar.set_label("Interannual variability (mm day$^{-1}$)")

    plt.show()


In [None]:
# Year-to-year standard deviation of the annual (Jan–Dec) mean
plot_std_2x2(
    pr_std,
    season="ANN",
    vmin=0,
    vmax=5,
    cmap="jet",
)


In [None]:
# Year-to-year standard deviation of the JJAS (Jun–Sep) mean
plot_std_2x2(
    pr_std,
    season="JJAS",
    vmin=0,
    vmax=5,
    cmap="jet",
)


In [None]:
# Year-to-year standard deviation of the JJAS mean, zoomed to
# latitudes -10..40 and longitudes 50..110
plot_std_2x2(
    pr_std,
    season="JJAS",
    lat_min=-10,
    lat_max=40,
    lon_min=50,
    lon_max=110,
    vmin=0,
    vmax=5,
    cmap="jet",
)

In [None]:




# #Compute differences (SAI − SSP)
# diff_ann  = pr_sai_ann  - pr_ssp_ann
# diff_jjas = pr_sai_jjas - pr_ssp_jjas
# diff_djf  = pr_sai_djf  - pr_ssp_djf

# #Convert units from kg/m2/s to mm/day
# SEC_PER_DAY = 86400.0

# diff_ann_mmday  = diff_ann  * SEC_PER_DAY
# diff_jjas_mmday = diff_jjas * SEC_PER_DAY
# diff_djf_mmday  = diff_djf  * SEC_PER_DAY


# import matplotlib.pyplot as plt
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature

# proj = ccrs.PlateCarree()
# vmin, vmax = -2, 2   # adjust if needed
# cmap = "RdBu_r"

# fig, axes = plt.subplots(
#     nrows=3, ncols=1,
#     figsize=(10, 13),
#     subplot_kw={"projection": proj},
#     constrained_layout=True
# )

# def slice_tropics(da, lat_min=-50, lat_max=50):
#     return da.sel(lat=slice(lat_min, lat_max))


# data_list = [
#     slice_tropics(diff_ann_mmday),
#     slice_tropics(diff_jjas_mmday),
#     slice_tropics(diff_djf_mmday)
# ]

# titles = [
#     "Annual mean (G6solar − SSP585)",
#     "JJAS mean (G6solar − SSP585)",
#     "DJF mean (G6solar − SSP585)"
# ]

# for ax, da, title in zip(axes, data_list, titles):
#     im = da.plot(
#         ax=ax,
#         transform=ccrs.PlateCarree(),
#         cmap=cmap,
#         vmin=vmin,
#         vmax=vmax,
#         add_colorbar=False
#     )
    
#     ax.coastlines(linewidth=0.8)
#     ax.add_feature(cfeature.BORDERS, linewidth=0.4)
#     ax.set_title(title)

# # Shared colorbar
# cbar = fig.colorbar(
#     im,
#     ax=axes,
#     orientation="horizontal",
#     fraction=0.05,
#     pad=0.08
# )
# cbar.set_label("Precipitation difference (mm day$^{-1}$)")

# plt.show()


In [None]:
# #Compute differences (SAI − SSP)
# diff_ann  = pr_sai_ann  - pr_ssp_ann
# diff_jjas = pr_sai_jjas - pr_ssp_jjas
# diff_djf  = pr_sai_djf  - pr_ssp_djf

# #Convert units from kg/m2/s to mm/day
# SEC_PER_DAY = 86400.0

# diff_ann_mmday  = diff_ann  * SEC_PER_DAY
# diff_jjas_mmday = diff_jjas * SEC_PER_DAY
# diff_djf_mmday  = diff_djf  * SEC_PER_DAY


In [None]:
# import matplotlib.pyplot as plt
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature


In [None]:
# proj = ccrs.PlateCarree()
# vmin, vmax = -2, 2   # adjust if needed
# cmap = "RdBu_r"


In [None]:
# fig, axes = plt.subplots(
#     nrows=3, ncols=1,
#     figsize=(10, 13),
#     subplot_kw={"projection": proj},
#     constrained_layout=True
# )

# def slice_tropics(da, lat_min=-50, lat_max=50):
#     return da.sel(lat=slice(lat_min, lat_max))


# data_list = [
#     slice_tropics(diff_ann_mmday),
#     slice_tropics(diff_jjas_mmday),
#     slice_tropics(diff_djf_mmday)
# ]

# titles = [
#     "Annual mean (G6solar − SSP585)",
#     "JJAS mean (G6solar − SSP585)",
#     "DJF mean (G6solar − SSP585)"
# ]

# for ax, da, title in zip(axes, data_list, titles):
#     im = da.plot(
#         ax=ax,
#         transform=ccrs.PlateCarree(),
#         cmap=cmap,
#         vmin=vmin,
#         vmax=vmax,
#         add_colorbar=False
#     )
    
#     ax.coastlines(linewidth=0.8)
#     ax.add_feature(cfeature.BORDERS, linewidth=0.4)
#     ax.set_title(title)

# # Shared colorbar
# cbar = fig.colorbar(
#     im,
#     ax=axes,
#     orientation="horizontal",
#     fraction=0.05,
#     pad=0.08
# )
# cbar.set_label("Precipitation difference (mm day$^{-1}$)")

# plt.show()
