Notebook with functions for plotting sea level anomaly maps and time series.

In [1]:
import numpy as np
import xarray as xr
import xesmf as xe

# modules for plotting datetime data
import matplotlib.dates as mdates
from matplotlib.axis import Axis

# modules for using datetime variables
import datetime
from datetime import time

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm

from xgcm import Grid
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

import cartopy.crs as ccrs
import cmocean

import subprocess as sp

import matplotlib.ticker as mticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

from matplotlib.ticker import ScalarFormatter

from xclim import ensembles
from xoverturning import calcmoc
import cmip_basins
import momlevel

import cftime
from pandas.errors import OutOfBoundsDatetime  # Import the specific error

import os

In [2]:
%run /home/Kiera.Lowman/Kd-sensitivity-analysis/notebooks/plotting_functions.ipynb

In [3]:
from matplotlib import font_manager

# Location of the custom .otf font file used for every figure in this notebook
font_path = '/home/Kiera.Lowman/.fonts/HelveticaNeueRoman.otf'

# Register the font file with matplotlib, then read back the family name it declares
font_manager.fontManager.addfont(font_path)
prop = font_manager.FontProperties(fname=font_path)
font_name = prop.get_name()

# Global text defaults: the custom font family plus per-element font sizes
plt.rcParams.update({
    'font.family': font_name,
    'axes.labelsize': 12,      # Axis label size
    'xtick.labelsize': 10,     # X-axis tick label size
    'ytick.labelsize': 10,     # Y-axis tick label size
    'axes.titlesize': 14,      # Title size
    'legend.fontsize': 10,     # Legend font size
})

# Sea level maps

In [4]:
# def plot_SL_diff(title, ref_ds, exp_ds, start_yr, end_yr, variant="steric",
#                  cb_max=None, icon=None,
#                  savefig=False, fig_dir=None, prefix=None,
#                  verbose=False):

#     ref_ds["areacello"] = ref_ds["areacello"]*ref_ds["wet"]
#     ref_ds = ref_ds.rename({"temp":"thetao", "salt":"so"})

#     ref_state = momlevel.reference.setup_reference_state(ref_ds)

#     exp_ds["areacello"] = exp_ds["areacello"]*exp_ds["wet"]
#     exp_ds = exp_ds.rename({"temp":"thetao", "salt":"so"})

#     exp_result = momlevel.steric(exp_ds, reference = ref_state, variant = variant)

#     exp_slr = exp_result[0][variant]

#     global_result = momlevel.steric(exp_ds, reference = ref_state, variant = variant, domain = "global")
#     global_slr = global_result[0][variant]
    
#     # Step 1: Normalize geolon to [0, 360) to avoid wraparound issues
#     exp_slr = exp_slr.assign_coords(
#         geolon=((exp_slr.geolon + 360) % 360)
#     )
    
#     # Step 2: Define target lat/lon grid resolution
#     lat_res = 3 * 210  # e.g., 630 points from -76.75 to 89.75
#     lon_res = 3 * 360  # e.g., 1080 points from 0 to 360
    
#     target_lat = np.linspace(exp_slr.geolat.min(), exp_slr.geolat.max(), lat_res)
#     target_lon = np.linspace(0, 360, lon_res)
    
#     # Step 3: Build source and target grid datasets
#     ds_in = xr.Dataset({
#         "lat": (["yh", "xh"], exp_slr.geolat.values),
#         "lon": (["yh", "xh"], exp_slr.geolon.values),
#     })
    
#     ds_out = xr.Dataset({
#         "lat": (["lat"], target_lat),
#         "lon": (["lon"], target_lon),
#     })
    
#     # Step 4: Create the regridder (periodic=True for wrapping at 0/360)
#     regridder = xe.Regridder(ds_in, ds_out, method="bilinear", periodic=True, reuse_weights=False)
    
#     # Step 5: Apply the regridder to your data
#     exp_slr_interp = regridder(exp_slr)
    
#     # # Optional: Wrap longitudes to [-180, 180] for plotting
#     # exp_slr_interp = exp_slr_interp.assign_coords(
#     #     lon=(((exp_slr_interp.lon + 180) % 360) - 180)
#     # )

#     # # Sort coordinates to fix plotting error
#     # exp_slr_interp = exp_slr_interp.sortby(['lat', 'lon'])
    
#     # Done! exp_slr_interp is now interpolated on a regular (lat, lon) grid

#     # Diagnostic min and max (not for setting bounds)
#     min_val = float(np.nanmin(exp_slr.values))
#     max_val = float(np.nanmax(exp_slr.values))

#     # 0.5th and 99.5th percentiles
#     per0p5 = float(np.nanpercentile(exp_slr.values, 0.5))
#     per99p5 = float(np.nanpercentile(exp_slr.values, 99.5))
#     if verbose:
#         print(f"Full data min/max: {min_val:.3f}/{max_val:.3f}")
#         print(f"Percentile-based max magnitude: {max(abs(per0p5), abs(per99p5)):.3f}")

#     extra_tick_digits = False
    
#     # Decide an initial data-based max
#     if cb_max is not None:
#         if cb_max == 0.2:
#             chosen_n = 20
#         elif cb_max == 0.4:
#             chosen_n = 20
#         elif cb_max == 0.6:
#             chosen_n = 20
#         elif cb_max == 0.8:
#             chosen_n = 20
#         elif cb_max == 1:
#             chosen_n = 20
#         elif cb_max == 1.5:
#             chosen_n = 12
#         elif cb_max == 2:
#             chosen_n = 20
#         elif cb_max == 3:
#             chosen_n = 12
#         elif cb_max == 4:
#             chosen_n = 20
#         else:
#             raise ValueError("cb_max is not an acceptable value.")
            
#         data_max = cb_max
#         chosen_step = 2*data_max/chosen_n
        
#     else:
#         chosen_n, chosen_step = \
#         get_cb_spacing(per0p5,per99p5,min_bnd=0.2,min_spacing=0.02,min_n=10,max_n=20,verbose=verbose)

#     max_mag = 0.5 * chosen_n * chosen_step 

#     zero_step, disc_cmap, disc_norm, boundaries, extend, tick_positions \
#     = create_cb_params(max_mag, min_val, max_val, chosen_n, chosen_step, verbose=verbose)
    
#     # Create the figure and projection
#     fig, ax = plt.subplots(figsize=(7.5, 5), #used to be figsize=(12, 8)
#                            subplot_kw={'projection': ccrs.Robinson(central_longitude=209.5),
#                                        'facecolor': 'grey'}) # used to be '0.75'
    
#     # diff_plot = exp_slr.plot(#vmin=plot_min, vmax=plot_max,
#     #               x='geolon', y='geolat',
#     #               cmap=disc_cmap, norm=disc_norm,
#     #                   #You can pick any projection from the cartopy list but, whichever projection you use, you still have to set
#     #               transform=ccrs.PlateCarree(),
#     #               add_labels=False,
#     #               add_colorbar=False)

#     diff_plot = exp_slr_interp.plot(#vmin=plot_min, vmax=plot_max,
#                   x='lon', y='lat',
#                   cmap=disc_cmap, norm=disc_norm,
#                       #You can pick any projection from the cartopy list but, whichever projection you use, you still have to set
#                   transform=ccrs.PlateCarree(),
#                   add_labels=False,
#                   add_colorbar=False)
    
#     diff_plot.axes.set_title(f"{title}\n{variant.capitalize()} SLR: Year {start_yr}–{end_yr}")#,fontdict={'fontsize':14})
        
#     diff_cb = plt.colorbar(diff_plot, shrink=0.58, pad=0.04, extend=extend, 
#                            boundaries=boundaries, norm=disc_norm, spacing='proportional')
#     tick_labels = []
#     for val in tick_positions:
#         if (np.abs(val) == 0.125 or np.abs(val) == 1/3 or np.abs(val) == 2/3 or np.abs(val) == 5/8):
#             tick_labels.append(f"{val:.3f}")
#         elif chosen_step < 0.1:
#             tick_labels.append(f"{val:.2f}")
#         else:
#             tick_labels.append(f"{val:.1f}")
            
#     diff_cb.set_ticks(tick_positions)
#     diff_cb.ax.set_yticklabels(tick_labels)
#     diff_cb.ax.tick_params(labelsize=10)
#     diff_cb.set_label("Sea Level Anomaly (m)")
#     # for t in diff_cb.ax.get_yticklabels():
#     #     t.set_horizontalalignment('center')
#     #     t.set_x(2.0 if max_mag < 10 else 2.2)
#     if zero_step < 0.1 or max_mag > 10: #or extra_tick_digits:
#         plt.setp(diff_cb.ax.get_yticklabels(), horizontalalignment='center', x=2.2)
#     else:
#         plt.setp(diff_cb.ax.get_yticklabels(), horizontalalignment='center', x=2.0)

#     # print mean value in mm over Eurasia
#     mean_val = 10**3 * global_slr.isel(time=0).values
#     mean_str = f"{mean_val:.1f}"
#     # e.g. at lon=0°, lat=60°
#     ax.text(
#         0.17, 0.8,
#         f"{mean_str} mm",
#         transform=ax.transAxes,    # interpret (x,y) in axes‐fraction
#         fontsize=12,
#         va='top',                  # vertical alignment
#         ha='center',                 # horizontal alignment
#         # backgroundcolor='white',       # optional box to improve legibility
#         color='white', fontweight='bold', alpha=1
#     )

#     if icon is not None:
#         image_path = f"/home/Kiera.Lowman/profile_icons/{icon}_icon.png"  # Replace with your image path
#         img = mpimg.imread(image_path)
#         # Create an OffsetImage in upper right corner
#         imagebox = OffsetImage(img, zoom=0.09)  # Adjust zoom as needed
#         ab = AnnotationBbox(imagebox, (0.95, 1.00), xycoords="axes fraction", frameon=False) # Set the image position (e.g., top-right corner)
#         ax.add_artist(ab) # Add image to the figure

#     if savefig:
#         plt.savefig(fig_dir + f'{prefix}_{variant}_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png', dpi=600, bbox_inches='tight')
#         plt.close()

In [14]:
def _format_cb_tick_labels(tick_positions, chosen_step):
    """Return colorbar tick labels as strings, choosing the number of decimals per tick.

    Ticks whose magnitude is 1/8, 1/3, 2/3 or 5/8 get three decimals; otherwise two
    decimals are used when `chosen_step` is below 0.1 and one decimal when it is not.
    """
    three_decimal_magnitudes = {0.125, 1/3, 2/3, 5/8}
    tick_labels = []
    for val in tick_positions:
        if np.abs(val) in three_decimal_magnitudes:
            tick_labels.append(f"{val:.3f}")
        elif chosen_step < 0.1:
            tick_labels.append(f"{val:.2f}")
        else:
            tick_labels.append(f"{val:.1f}")
    return tick_labels


def plot_SL_diff(
    panel_title=None,                       # short title when used inside a grid
    ref_ds=None, exp_ds=None,
    start_yr=None, end_yr=None, variant="steric",
    cb_max=None, icon=None,
    savefig=False, fig_dir=None, prefix=None,
    verbose=False,
    # grid hooks:
    ax=None,                 # if provided, draw into this axes (no new fig)
    add_colorbar=True,       # grid sets this False (figure-level bar instead)
    return_cb_params=False,  # grid sets this True (to build figure-level bar)
    cb_label="Sea Level Anomaly (m)",
    # backwards-compat if you still call with `title=...`
    title=None
):
    """
    Plot a map of the sea level anomaly of `exp_ds` relative to `ref_ds`.

    The anomaly is computed with `momlevel.steric` against a reference state built from
    `ref_ds`, regridded bilinearly (xESMF, periodic in longitude) onto a regular
    630 x 1080 lat/lon grid, and drawn with a discrete colormap. The global value
    (from `momlevel.steric(..., domain="global")`, first time index, scaled by 1e3)
    is written on the panel in mm.

    The input datasets are not modified: the masked cell area and the renamed
    variables live on copies.

    Inputs:
    panel_title (str): title of the panel; `title` is used as a fallback
    ref_ds, exp_ds: datasets providing 'areacello', 'wet', 'temp' and 'salt'
                    (presumably xarray Datasets on the model grid with 'geolat'/'geolon'
                    coordinates on (yh, xh) -- TODO confirm against the loading code)
    start_yr, end_yr (int): period shown in the standalone title and in the saved file name
    variant (str): passed to momlevel.steric and used to select the result field
    cb_max (float or None): colorbar half-range; must be one of
                    0.2, 0.4, 0.6, 0.8, 1, 1.5, 2, 3, 4. If None the spacing comes from
                    `get_cb_spacing` applied to the 0.5th/99.5th percentiles.
    icon (str or None): if given, '/home/Kiera.Lowman/profile_icons/{icon}_icon.png'
                    is drawn in the upper-right corner
    savefig (bool), fig_dir (str), prefix (str): when `savefig` is True and this function
                    created the figure, the figure is written to
                    fig_dir + '{prefix}_{variant}_{start}_{end}.png' and closed
    verbose (bool): print diagnostics (also forwarded to the colorbar helpers)
    ax: existing axes to draw into; when None a new Robinson-projection figure is created
    add_colorbar (bool): attach a per-panel colorbar
    return_cb_params (bool): also return the settings needed to build a shared colorbar
    cb_label (str): colorbar label
    title (str): backwards-compatible alias for `panel_title`

    Returns:
    (ax, diff_plot, diff_cb, cb_params); `diff_cb` is None when `add_colorbar` is False
    and `cb_params` is None when `return_cb_params` is False.

    Raises:
    ValueError: if `ref_ds`/`exp_ds` is missing, if `cb_max` is not an accepted value,
                or if a figure must be saved but `fig_dir` is None.

    Note: `get_cb_spacing` and `create_cb_params` are not defined in this notebook;
    presumably they come from the plotting_functions notebook that is %run near the top.
    """
    if ref_ds is None or exp_ds is None:
        raise ValueError("Both ref_ds and exp_ds must be provided.")

    # ---- titles / labels ----
    panel_title = panel_title or title or ""

    # ---- SLA computation ----
    # assign() returns a new dataset, so the caller's ref_ds/exp_ds keep their
    # original (unmasked) areacello.
    ref_ds = ref_ds.assign(areacello=ref_ds["areacello"] * ref_ds["wet"])
    ref_ds = ref_ds.rename({"temp": "thetao", "salt": "so"})
    ref_state = momlevel.reference.setup_reference_state(ref_ds)

    exp_ds = exp_ds.assign(areacello=exp_ds["areacello"] * exp_ds["wet"])
    exp_ds = exp_ds.rename({"temp": "thetao", "salt": "so"})

    exp_result = momlevel.steric(exp_ds, reference=ref_state, variant=variant)
    exp_slr = exp_result[0][variant]

    global_result = momlevel.steric(exp_ds, reference=ref_state, variant=variant, domain="global")
    global_slr = global_result[0][variant]

    # ---- normalize & regrid to regular (lat, lon) ----
    # wrap longitudes into [0, 360) before building the source grid
    exp_slr = exp_slr.assign_coords(geolon=((exp_slr.geolon + 360) % 360))
    lat_res = 3 * 210
    lon_res = 3 * 360
    target_lat = np.linspace(exp_slr.geolat.min(), exp_slr.geolat.max(), lat_res)
    target_lon = np.linspace(0, 360, lon_res)

    ds_in = xr.Dataset({
        "lat": (["yh", "xh"], exp_slr.geolat.values),
        "lon": (["yh", "xh"], exp_slr.geolon.values),
    })
    ds_out = xr.Dataset({
        "lat": (["lat"], target_lat),
        "lon": (["lon"], target_lon),
    })
    regridder = xe.Regridder(ds_in, ds_out, method="bilinear", periodic=True, reuse_weights=False)
    exp_slr_interp = regridder(exp_slr)

    # ---- color scaling ----
    # statistics are taken from the native-grid field, not the regridded one
    min_val = float(np.nanmin(exp_slr.values))
    max_val = float(np.nanmax(exp_slr.values))
    per0p5  = float(np.nanpercentile(exp_slr.values, 0.5))
    per99p5 = float(np.nanpercentile(exp_slr.values, 99.5))
    if verbose:
        print(f"Full data min/max: {min_val:.3f}/{max_val:.3f}")
        print(f"Percentile-based max magnitude: {max(abs(per0p5), abs(per99p5)):.3f}")

    if cb_max is not None:
        if cb_max in (0.2, 0.4, 0.6, 0.8, 1, 2, 4):
            chosen_n = 20
        elif cb_max in (1.5, 3):
            chosen_n = 12
        else:
            raise ValueError("cb_max is not an acceptable value.")
        chosen_step = 2 * cb_max / chosen_n
    else:
        chosen_n, chosen_step = get_cb_spacing(
            per0p5, per99p5, min_bnd=0.2, min_spacing=0.02, min_n=10, max_n=20, verbose=verbose
        )

    max_mag = 0.5 * chosen_n * chosen_step

    zero_step, disc_cmap, disc_norm, boundaries, extend, tick_positions = create_cb_params(
        max_mag, min_val, max_val, chosen_n, chosen_step, verbose=verbose
    )

    # the same labels serve the per-panel colorbar and the returned cb_params
    tick_labels = _format_cb_tick_labels(tick_positions, chosen_step)

    # ---- figure/axes management (grid-compatible) ----
    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(
            figsize=(7.5, 5),
            subplot_kw={'projection': ccrs.Robinson(central_longitude=209.5), 'facecolor': 'grey'}
        )

    # ---- draw the panel ----
    diff_plot = exp_slr_interp.plot(
        x='lon', y='lat',
        cmap=disc_cmap, norm=disc_norm,
        transform=ccrs.PlateCarree(),
        add_labels=False, add_colorbar=False, ax=ax
    )

    # titles: short for subplots, longer when standalone
    if created_fig is None:
        ax.set_title(f"{panel_title}")
    else:
        ax.set_title(f"{panel_title}\n{variant.capitalize()} SLR: Year {start_yr}–{end_yr}")

    # ---- optional per-panel colorbar (suppressed by grid) ----
    diff_cb = None
    if add_colorbar:
        diff_cb = plt.colorbar(
            diff_plot, ax=ax, shrink=0.58, pad=0.04, extend=extend,
            boundaries=boundaries, norm=disc_norm, spacing='proportional'
        )
        diff_cb.set_ticks(tick_positions)
        diff_cb.ax.set_yticklabels(tick_labels)
        diff_cb.ax.tick_params(labelsize=10)
        diff_cb.set_label(cb_label)
        # centre the tick labels; shift them slightly further out for small steps
        # or large magnitudes
        label_x = 2.2 if (zero_step < 0.1 or max_mag > 10) else 2.0
        plt.setp(diff_cb.ax.get_yticklabels(), horizontalalignment='center', x=label_x)

    # ---- package params for a figure-level colorbar (used by the grid) ----
    cb_params = None
    if return_cb_params:
        cb_params = dict(
            mappable=diff_plot,      # carries cmap+norm
            cmap=disc_cmap,
            norm=disc_norm,
            boundaries=boundaries,
            extend=extend,
            spacing='proportional',
            ticks=tick_positions,
            ticklabels=tick_labels,
            label=cb_label
        )

    # ---- annotation: global mean (mm) ----
    mean_val = 1e3 * global_slr.isel(time=0).values
    ax.text(0.17, 0.8, f"{mean_val:.1f} mm", transform=ax.transAxes,
            fontsize=12, va='top', ha='center', color='white', fontweight='bold', alpha=1)

    # ---- optional corner icon ----
    if icon is not None:
        # These names are not imported in this notebook's import cell; importing them
        # here keeps the branch from depending on another notebook's namespace.
        import matplotlib.image as mpimg
        from matplotlib.offsetbox import AnnotationBbox, OffsetImage

        image_path = f"/home/Kiera.Lowman/profile_icons/{icon}_icon.png"
        img = mpimg.imread(image_path)
        imagebox = OffsetImage(img, zoom=0.09)
        ab = AnnotationBbox(imagebox, (0.95, 1.00), xycoords="axes fraction", frameon=False)
        ax.add_artist(ab)

    # ---- save if we created the figure ----
    if savefig and created_fig is not None:
        if fig_dir is None:
            raise ValueError("fig_dir must be provided when savefig=True.")
        out = fig_dir + f'{prefix}_{variant}_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        created_fig.savefig(out, dpi=600, bbox_inches='tight')
        plt.close(created_fig)

    # what the grid expects:
    return ax, diff_plot, diff_cb, cb_params

In [7]:
def create_SLR_plots(diff_type,start_year,end_year,
                     variant="steric",
                     profiles = ['surf','therm','mid','bot'],
                     prof_strings = ["Surf","Therm","Mid","Bot"],
                     power_var_suff = ['0p1TW', '0p2TW', '0p5TW'],
                     power_strings = ['0.1 TW', '0.2 TW', '0.5 TW'],
                     pref_addition=None,
                     plot_max=None,
                     savefig=False,fig_dir=None,
                     extra_verbose=False):
    """
    Draw one sea level anomaly map per (power, profile) combination with plot_SL_diff,
    plus one control-run map for the 'doub-1860exp' and 'quad-1860exp' comparisons.

    Datasets are looked up by name in the global mapping `myVars` (defined outside this
    cell), using names of the form '{run}_{profile}_{power}_{start_year}_{end_year}' and
    '{run}_ctrl_{start_year}_{end_year}'.

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    variant (str): type of SLR (either "steric", "thermosteric", or "halosteric")
    profiles (list of str): profile names used in dataset names, figure names and as icon
    prof_strings (list of str): display names of the profiles, parallel to `profiles`
    power_var_suff (list of str): power suffixes used in dataset and figure names
    power_strings (list of str): display names of the powers, parallel to `power_var_suff`
    pref_addition (str): optional suffix appended to the per-profile figure names
    plot_max (int/float): passed to plot_SL_diff as `cb_max`
    savefig (boolean): passed to plot_SL_diff; when True the output directories are created
    fig_dir (str): path of parent directory in which to save figures (required if savefig)
    extra_verbose (boolean): passed to plot_SL_diff as `verbose`

    Raises:
    ValueError: for an unknown `diff_type`, or if `savefig` is True without `fig_dir`.
    """
    # diff_type -> (experiment run, control reference run, output subdirectory,
    #               title template, figure-name stem)
    # A reference of None means "the Const CO2 run with the same profile and power".
    case_table = {
        'const-1860ctrl': ('const', 'const_ctrl', 'piControl',
                           "Const {prof} {power}", None),
        'doub-1860exp':   ('doub', None, '2xCO2',
                           "1pct2xCO2 — Const CO2: {prof} {power}", "2xCO2-const"),
        'doub-2xctrl':    ('doub', 'doub_ctrl', '2xCO2',
                           "1pct2xCO2 {prof} {power} — 1pct2xCO2 Control", "2xCO2-2xctrl"),
        'doub-1860ctrl':  ('doub', 'const_ctrl', '2xCO2',
                           "1pct2xCO2 {prof} {power} — Const CO2 Control", "2xCO2-const-ctrl"),
        'quad-1860exp':   ('quad', None, '4xCO2',
                           "1pct4xCO2 — Const CO2: {prof} {power}", "4xCO2-const"),
        'quad-4xctrl':    ('quad', 'quad_ctrl', '4xCO2',
                           "1pct4xCO2 {prof} {power} — 1pct4xCO2 Control", "4xCO2-4xctrl"),
        'quad-1860ctrl':  ('quad', 'const_ctrl', '4xCO2',
                           "1pct4xCO2 {prof} {power} — Const CO2 Control", "4xCO2-const-ctrl"),
    }
    # diff_type -> (control experiment run, title, figure name) for the extra control map
    control_table = {
        'doub-1860exp': ('doub_ctrl', "1pct2xCO2 Control", "2xCO2-const-ctrl"),
        'quad-1860exp': ('quad_ctrl', "1pct4xCO2 Control", "4xCO2-const-ctrl"),
    }

    if diff_type not in case_table:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    exp_run, ref_run, subdir, title_template, fig_stem = case_table[diff_type]
    period = f"{start_year}_{end_year}"

    # The output directory depends only on diff_type and the period, so build it once.
    # plot_SL_diff concatenates fig_dir and the file name, hence the trailing slash.
    if savefig:
        if fig_dir is None:
            raise ValueError("fig_dir must be provided when savefig=True.")
        fig_path = f"{fig_dir}/{subdir}/" + f"{str(start_year).zfill(4)}_{str(end_year).zfill(4)}/"
        os.makedirs(fig_path, exist_ok=True)
    else:
        fig_path = None

    # control cases
    if diff_type in control_table:
        ctrl_run, title_str, fig_name = control_table[diff_type]
        # There is no mixing profile for a control run, so no icon is drawn.
        plot_SL_diff(panel_title=title_str,
                     ref_ds=myVars[f"const_ctrl_{period}"],
                     exp_ds=myVars[f"{ctrl_run}_{period}"],
                     start_yr=start_year, end_yr=end_year, variant=variant,
                     cb_max=plot_max, icon=None,
                     savefig=savefig, fig_dir=fig_path, prefix=fig_name,
                     verbose=extra_verbose)
        print(f"Done {fig_name}.")

    # perturbation cases
    for i, power_str in enumerate(power_strings):
        power_suff = power_var_suff[i]
        for j, prof in enumerate(profiles):
            run_tag = f"{prof}_{power_suff}"

            ds_name = f"{exp_run}_{run_tag}_{period}"
            if ref_run is None:
                ref_name = f"const_{run_tag}_{period}"
            else:
                ref_name = f"{ref_run}_{period}"

            title_str = title_template.format(prof=prof_strings[j], power=power_str)
            fig_name = run_tag if fig_stem is None else f"{fig_stem}_{run_tag}"
            if pref_addition is not None:
                fig_name = fig_name + "_" + pref_addition

            plot_SL_diff(panel_title=title_str, ref_ds=myVars[ref_name], exp_ds=myVars[ds_name],
                         start_yr=start_year, end_yr=end_year, variant=variant,
                         cb_max=plot_max, icon=prof,
                         savefig=savefig, fig_dir=fig_path, prefix=fig_name,
                         verbose=extra_verbose)

            print(f"Done {fig_name}.")

# Sea level time series

In [8]:
def generate_all_needed_names(diff_type, start_year, end_year, profiles, power_var_suff):
    """
    Generate a sorted list of all dataset names (reference and experiment) needed
    for a given diff_type, time window, profiles, and power suffixes.

    Besides the per-(power, profile) reference/experiment pairs, the control runs used
    for the extra "control" anomaly are included: the Const CO2 control together with
    the 2xCO2 control for 'doub-1860exp'/'doub-1860ctrl', or with the 4xCO2 control for
    'quad-1860exp'/'quad-1860ctrl'. These extras are only added when `power_var_suff`
    is non-empty.

    Parameters:
    - diff_type (str): one of ['const-1860ctrl', 'doub-1860exp', 'doub-2xctrl',
                       'doub-1860ctrl', 'quad-1860exp', 'quad-4xctrl', 'quad-1860ctrl']
    - start_year (int): start of averaging period
    - end_year (int): end of averaging period
    - profiles (list of str): e.g. ['surf','therm','mid','bot']
    - power_var_suff (list of str): e.g. ['0p1TW','0p2TW','0p5TW']

    Returns:
    - List[str]: sorted unique dataset names to preprocess

    Raises:
    - ValueError: if diff_type is not one of the values listed above
    """
    period = f"{start_year}_{end_year}"
    ctrl_const = f"const_ctrl_{period}"
    ctrl_doub  = f"doub_ctrl_{period}"
    ctrl_quad  = f"quad_ctrl_{period}"

    # diff_type -> (experiment run, fixed reference name);
    # a reference of None means "the Const CO2 run with the same profile and power".
    case_table = {
        'const-1860ctrl': ('const', ctrl_const),
        'doub-1860exp':   ('doub',  None),
        'doub-2xctrl':    ('doub',  ctrl_doub),
        'doub-1860ctrl':  ('doub',  ctrl_const),
        'quad-1860exp':   ('quad',  None),
        'quad-4xctrl':    ('quad',  ctrl_quad),
        'quad-1860ctrl':  ('quad',  ctrl_const),
    }
    # Control runs needed in addition to the per-profile pairs.
    extra_controls = {
        'doub-1860exp':  {ctrl_const, ctrl_doub},
        'doub-1860ctrl': {ctrl_const, ctrl_doub},
        'quad-1860exp':  {ctrl_const, ctrl_quad},
        'quad-1860ctrl': {ctrl_const, ctrl_quad},
        # NOTE(review): nothing visible here needs the Const CO2 control for
        # 'doub-2xctrl'; it is kept only because earlier versions returned it.
        'doub-2xctrl':   {ctrl_const, ctrl_doub},
    }

    if diff_type not in case_table:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    exp_run, fixed_ref = case_table[diff_type]

    names = set()
    for power in power_var_suff:
        for prof in profiles:
            run_tag = f"{prof}_{power}_{period}"
            names.add(fixed_ref if fixed_ref is not None else f"const_{run_tag}")
            names.add(f"{exp_run}_{run_tag}")

    if len(power_var_suff) > 0:
        names.update(extra_controls.get(diff_type, ()))

    return sorted(names)

In [9]:
def preprocess(ds, time_chunk):
    """
    Prepare a model dataset for the momlevel calculations.

    Multiplies 'areacello' by 'wet', renames 'temp' -> 'thetao' and 'salt' -> 'so',
    and returns the persisted result. The input dataset is left unmodified: the
    masked area is assigned on a new dataset rather than written back into `ds`.

    Inputs:
    ds: dataset providing 'areacello', 'wet', 'temp' and 'salt'
        (presumably an xarray Dataset -- it must support assign/rename/persist)
    time_chunk (int): currently unused; it was the time chunk size of a re-chunking
        step ({'time': time_chunk, 'yh': 105, 'xh': 180}) that has been disabled

    Returns:
    the prepared dataset, after calling .persist() on it
    """
    masked = ds.assign(areacello=ds.areacello * ds.wet)
    renamed = masked.rename({'temp':'thetao','salt':'so'})
    return renamed.persist()

In [10]:
def calculate_GMSLR(diff_type,fig_dir,start_year,end_year,variant="steric",
                 profiles = ['surf','therm','mid','bot'],
                 power_var_suff = ['0p1TW', '0p2TW', '0p5TW']):
    """
    Compute global sea level anomaly time series with momlevel and store them in the
    global mapping `myVars` (defined outside this cell).

    For every (power, profile) pair the experiment dataset is compared against its
    reference and the result is stored under '{run}_{profile}_{power}_{start}_{end}_{suffix}',
    where suffix is 'const_ctrl_SLA', '2xctrl_SLA', '4xctrl_SLA' or '1860_SLA' depending
    on `diff_type`. For the 'doub-1860exp', 'doub-1860ctrl', 'quad-1860exp' and
    'quad-1860ctrl' comparisons the matching CO2 control run is also compared against the
    Const CO2 control and stored under '{run}_ctrl_{start}_{end}_const_ctrl_SLA'
    (for the 2xCO2 control the older key '{run}_ctrl_{start}_{end}_SLA' is written too).

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    fig_dir (str): not used by this function
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    variant (str): passed to momlevel.steric and used to select the result field
    profiles (list of str): profile names used in dataset names
    power_var_suff (list of str): power suffixes used in dataset names

    Raises:
    ValueError: if diff_type is not one of the values listed above
    """
    # diff_type -> (experiment run, control reference run, suffix of the stored key).
    # A reference of None means "the Const CO2 run with the same profile and power".
    case_table = {
        'const-1860ctrl': ('const', 'const_ctrl', 'const_ctrl_SLA'),
        'doub-2xctrl':    ('doub',  'doub_ctrl',  '2xctrl_SLA'),
        'doub-1860ctrl':  ('doub',  'const_ctrl', 'const_ctrl_SLA'),
        'quad-4xctrl':    ('quad',  'quad_ctrl',  '4xctrl_SLA'),
        'quad-1860ctrl':  ('quad',  'const_ctrl', 'const_ctrl_SLA'),
        'doub-1860exp':   ('doub',  None,         '1860_SLA'),
        'quad-1860exp':   ('quad',  None,         '1860_SLA'),
    }
    # comparisons that additionally need "CO2 control minus Const CO2 control"
    with_control = ('doub-1860exp', 'doub-1860ctrl', 'quad-1860exp', 'quad-1860ctrl')

    if diff_type not in case_table:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    exp_run, ref_run, sla_suffix = case_table[diff_type]
    period = f'{start_year}_{end_year}'

    n_yrs = (end_year - start_year + 1)
    new_chunk = int(n_yrs/5)

    prepped = {}

    def get_prepped(name):
        # Preprocess on first use, so a dataset missing from the generated name list
        # is still prepared instead of raising a KeyError later.
        if name not in prepped:
            print(f"Preprocessing {name}…")
            prepped[name] = preprocess(myVars[name], new_chunk)
            print(f"Finished {name}")
        return prepped[name]

    def global_sla(exp_name, ref_state):
        result = momlevel.steric(get_prepped(exp_name), reference = ref_state,
                                 variant = variant, domain = "global")
        return result[0][variant]

    # 1) preprocess all needed datasets
    all_needed_names = generate_all_needed_names(diff_type, start_year, end_year, profiles, power_var_suff)
    print(all_needed_names)
    for name in all_needed_names:
        get_prepped(name)
    print("All datasets preprocessed!")

    # 2) per-(power, profile) anomalies
    # When the reference is a control run it does not depend on power or profile,
    # so its reference state is built once, on first use.
    shared_ref_state = None
    for power in power_var_suff:
        ds_root = f'{power}_{period}'
        for prof in profiles:
            print(f'Starting {prof}')
            if ref_run is None:
                ref_state = momlevel.reference.setup_reference_state(
                    get_prepped(f'const_{prof}_{ds_root}'))
            else:
                if shared_ref_state is None:
                    ref_ds = get_prepped(f'{ref_run}_{period}')
                    if diff_type == 'doub-2xctrl':
                        # NOTE(review): only this case loaded the reference eagerly in
                        # the earlier code; kept as-is -- confirm whether it is needed.
                        ref_ds = ref_ds.load()
                    shared_ref_state = momlevel.reference.setup_reference_state(ref_ds)
                ref_state = shared_ref_state

            sla_name = f'{exp_run}_{prof}_{ds_root}_{sla_suffix}'
            myVars[sla_name] = global_sla(f'{exp_run}_{prof}_{ds_root}', ref_state)
            print("Done momlevel calc")
            print(f'Ending {prof}')
            print(f'Done {sla_name}')

    # 3) CO2 control relative to the Const CO2 control (independent of power, so
    #    computed once; skipped when no power was requested, as before)
    if diff_type in with_control and len(power_var_suff) > 0:
        print(f'Starting ctrl')
        if ref_run == 'const_ctrl' and shared_ref_state is not None:
            ref_state = shared_ref_state
        else:
            ref_state = momlevel.reference.setup_reference_state(
                get_prepped(f'const_ctrl_{period}'))

        # the experiment is the control run of the same CO2 scenario (doub or quad)
        ctrl_name = f'{exp_run}_ctrl_{period}'
        sla_name = f'{ctrl_name}_const_ctrl_SLA'
        global_slr = global_sla(ctrl_name, ref_state)
        myVars[sla_name] = global_slr
        if exp_run == 'doub':
            # older key written by previous versions of this function
            myVars[f'{ctrl_name}_SLA'] = global_slr
        print(f'Ending ctrl')
        print(f'Done {sla_name}')

In [11]:
def plot_GMSLR_ts(diff_type,fig_dir,start_year,end_year,variant="steric",
                 ylimits = [-0.1,0.5],
                 leg_loc = 'upper left',
                 leg_ncols = 1,
                 profiles = ['surf','therm','mid','bot'],
                 power_var_suff = ['0p1TW', '0p2TW', '0p5TW'], 
                 power_strings = ['0.1 TW', '0.2 TW', '0.5 TW'],
                 savefig=False):
    """
    Function to plot anomaly over time from time series data for a particular CO2 scenario (each power input on separate plot).

    The series are read from the global mapping `myVars` (defined outside this cell)
    under the keys '{run}_{profile}_{power}_{start}_{end}_{suffix}'. For the
    'doub-1860exp', 'doub-1860ctrl', 'quad-1860exp' and 'quad-1860ctrl' comparisons a
    black 'control' line is added from '{run}_ctrl_{start}_{end}_const_ctrl_SLA'
    (falling back to the older key '{run}_ctrl_{start}_{end}_SLA'); for the quad cases
    the line is skipped with a message if no such series exists.

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    fig_dir (str): directory path to save figure (the file name is appended directly,
                   so it should end with a path separator)
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    variant (str): used in the title and in the saved file name
    ylimits (list): y-axis limits
    leg_loc (str), leg_ncols (int): legend location and number of columns
    profiles (list of str): profiles to plot; each must be one of 'surf','therm','mid','bot'
    power_var_suff (list of str): power suffixes used in the series keys and file names
    power_strings (list of str): display names of the powers, parallel to `power_var_suff`
    savefig (boolean): if True each figure is saved as a PDF and closed

    Raises:
    ValueError: if diff_type is not one of the values listed above
    """
    # diff_type -> (experiment run, suffix of the series key, title label, figure-name stem)
    case_table = {
        'const-1860ctrl': ('const', 'const_ctrl_SLA', "Const CO2",           "const"),
        'doub-1860exp':   ('doub',  '1860_SLA',       "1pct2xCO2 Radiative", "2xCO2-const"),
        'doub-2xctrl':    ('doub',  '2xctrl_SLA',     "1pct2xCO2 Mixing",    "2xCO2-2xctrl"),
        'doub-1860ctrl':  ('doub',  'const_ctrl_SLA', "1pct2xCO2 Total",     "2xCO2-const-ctrl"),
        'quad-1860exp':   ('quad',  '1860_SLA',       "1pct4xCO2 Radiative", "4xCO2-const"),
        'quad-4xctrl':    ('quad',  '4xctrl_SLA',     "1pct4xCO2 Mixing",    "4xCO2-4xctrl"),
        'quad-1860ctrl':  ('quad',  'const_ctrl_SLA', "1pct4xCO2 Total",     "4xCO2-const-ctrl"),
    }
    # comparisons for which a "CO2 control minus Const CO2 control" line is drawn
    with_control = ('doub-1860exp', 'doub-1860ctrl', 'quad-1860exp', 'quad-1860ctrl')

    if diff_type not in case_table:
        raise ValueError(f"Unknown diff_type: {diff_type}")
    exp_run, sla_suffix, title_label, fig_stem = case_table[diff_type]
    period = f'{start_year}_{end_year}'

    prof_dict = {'surf': 'b',
                 'therm': 'm',
                 'mid': 'g',
                 'bot': 'r'}

    # one x value per year of the period
    n_yrs = (end_year - start_year + 1)
    years = np.linspace(start_year,end_year,num=n_yrs)

    # separate plot for each power input
    for pow_idx, power in enumerate(power_var_suff):
        fig, ax = plt.subplots(figsize=(6,3))

        # Add a horizontal line at y=0
        ax.axhline(y=0, color='black', linestyle='-', linewidth=1)

        ds_root = f'{power}_{period}'
        for prof in profiles:
            print(f'Starting {prof}')
            exp_ds_name = f'{exp_run}_{prof}_{ds_root}_{sla_suffix}'
            ax.plot(years,myVars[exp_ds_name].load(),label=prof,color=prof_dict[prof])
            print(f'Ending {prof}')

        if diff_type in with_control:
            print(f'Starting ctrl')
            ctrl_key = f'{exp_run}_ctrl_{period}_const_ctrl_SLA'
            legacy_key = f'{exp_run}_ctrl_{period}_SLA'
            # assumes myVars is a dict-like mapping that supports `in` -- TODO confirm
            if ctrl_key not in myVars and legacy_key in myVars:
                ctrl_key = legacy_key
            if exp_run == 'quad' and ctrl_key not in myVars:
                # earlier versions never drew a quad control line, so a missing
                # series is not treated as an error
                print(f'No control series {ctrl_key}; skipping control line')
            else:
                ax.plot(years,myVars[ctrl_key].load(),label='control',color='k')
            print(f'Ending ctrl')

        ax.set_xlabel("Time (Years)")
        ax.set_ylabel("Sea Level Anomaly (m)")
        ax.set_xlim(0,end_year)
        ax.set_ylim(ylimits)
        ax.legend(loc=leg_loc,ncols=leg_ncols)

        # minor ticks are shown, but grid lines are drawn at major ticks only
        ax.minorticks_on()
        ax.grid(True, which='major', linestyle='-', linewidth=0.5, color='gray')

        title_str = f"{title_label}: {power_strings[pow_idx]} Cases\n"
        fig_name = f"{fig_stem}_{power}"

        title_suff = f"Global Mean {variant.capitalize()} Sea Level Anomaly"
        ax.set_title(title_str+title_suff)

        if savefig:
            # save and close this specific figure rather than whatever is "current"
            fig.savefig(fig_dir+f'{diff_type}_GMSLR_{variant}_{fig_name}_{start_year}_{end_year}.pdf', dpi=600, bbox_inches='tight')
            plt.close(fig)