Notebook with functions for plotting sea level anomaly maps and time series.

In [1]:
import numpy as np
import xarray as xr
import xesmf as xe

# modules for plotting datetime data
import matplotlib.dates as mdates
from matplotlib.axis import Axis

# modules for using datetime variables
import datetime
from datetime import time

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm

from xgcm import Grid
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

import cartopy.crs as ccrs
import cmocean

import subprocess as sp

import matplotlib.ticker as mticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

from matplotlib.ticker import ScalarFormatter

from xclim import ensembles
from xoverturning import calcmoc
import cmip_basins
import momlevel

import cftime
from pandas.errors import OutOfBoundsDatetime  # Import the specific error

import os

In [2]:
# Execute the shared plotting-helper notebook into this kernel's namespace.
# NOTE(review): presumably this defines get_cb_spacing and create_cb_params
# (called by plot_SL_diff below but not defined in this notebook) — confirm.
# NOTE(review): absolute, user-specific path; this cell fails on any other
# machine. Consider a path relative to this notebook.
%run /home/Kiera.Lowman/Kd-sensitivity-analysis/notebooks/plotting_functions.ipynb

In [3]:
from matplotlib import font_manager

# Custom font used for every figure in this notebook. This is an absolute,
# user-specific path; if the file is missing (e.g. on another machine) we fall
# back to matplotlib's default font instead of failing with FileNotFoundError.
font_path = '/home/Kiera.Lowman/.fonts/HelveticaNeueRoman.otf'

if os.path.exists(font_path):
    # Register the .otf file with matplotlib and look up its family name
    font_manager.fontManager.addfont(font_path)
    prop = font_manager.FontProperties(fname=font_path)
    font_name = prop.get_name()
    # Make it the default font globally
    plt.rcParams['font.family'] = font_name
else:
    # Keep prop/font_name defined so later cells that read them still work
    prop = font_manager.FontProperties()
    font_name = prop.get_name()
    # print rather than warnings.warn: warnings are globally silenced above
    print(f"Font file not found: {font_path}; using matplotlib's default font ({font_name}).")

plt.rcParams.update({
    'axes.labelsize': 12,   # axis label size
    'xtick.labelsize': 10,  # x-axis tick label size
    'ytick.labelsize': 10,  # y-axis tick label size
    'axes.titlesize': 14,   # title size
    'legend.fontsize': 10,  # legend font size
})

# Sea level maps

In [4]:
def plot_SL_diff(title, ref_ds, exp_ds, start_yr, end_yr, variant="steric",
                 cb_max=None, icon=None,
                 savefig=False, fig_dir=None, prefix=None,
                 verbose=False):
    """
    Map the sea level anomaly of ``exp_ds`` relative to a reference state built from ``ref_ds``.

    The anomaly is computed with ``momlevel.steric``. The gridded field is drawn on a
    Robinson projection with a discrete colorbar. The global-mean anomaly (first time
    index, in mm) is written in the upper-left of the map.

    Inputs:
    title (str): first line of the plot title
    ref_ds (xr.Dataset): reference dataset with "temp", "salt", "areacello" and "wet".
        It is not modified; a masked/renamed copy is used.
    exp_ds (xr.Dataset): experiment dataset with the same variables as ref_ds (not modified)
    start_yr, end_yr (int): averaging period, used in the title and file name
    variant (str): "steric", "thermosteric" or "halosteric" (passed to momlevel.steric)
    cb_max (float or None): fixed colorbar half-range. Must be one of
        0.2, 0.4, 0.6, 0.8, 1, 1.5, 2, 3, 4. If None, the spacing is derived from the
        0.5th/99.5th percentiles via get_cb_spacing.
    icon (str or None): profile name; draws "<icon>_icon.png" in the upper-right corner
    savefig (bool): if True, save a 600-dpi PNG into fig_dir and close the figure
    fig_dir (str): directory for the saved figure (required when savefig is True)
    prefix (str): file-name prefix of the saved figure
    verbose (bool): print colorbar diagnostics

    Raises:
    ValueError: if cb_max is not an accepted value, or savefig is True without fig_dir.

    NOTE(review): get_cb_spacing and create_cb_params are not defined in this notebook;
    presumably they come from the %run plotting_functions notebook — confirm.
    """
    # Accepted fixed colorbar half-ranges -> number of colorbar intervals
    cb_intervals = {0.2: 20, 0.4: 20, 0.6: 20, 0.8: 20, 1: 20,
                    1.5: 12, 2: 20, 3: 12, 4: 20}

    # Validate cheap arguments before the expensive momlevel calculations
    if cb_max is not None and cb_max not in cb_intervals:
        raise ValueError("cb_max is not an acceptable value.")
    if savefig and fig_dir is None:
        raise ValueError("fig_dir must be given when savefig=True.")

    # Mask cell areas with "wet" (presumably a 0/1 ocean mask) and rename to the
    # variable names momlevel expects. assign/rename return new datasets, so the
    # caller's datasets are left untouched.
    ref_ds = ref_ds.assign(areacello=ref_ds["areacello"] * ref_ds["wet"]).rename(
        {"temp": "thetao", "salt": "so"})
    exp_ds = exp_ds.assign(areacello=exp_ds["areacello"] * exp_ds["wet"]).rename(
        {"temp": "thetao", "salt": "so"})

    ref_state = momlevel.reference.setup_reference_state(ref_ds)

    # Gridded anomaly for the map, global-mean anomaly for the annotation
    exp_slr = momlevel.steric(exp_ds, reference=ref_state, variant=variant)[0][variant]
    global_slr = momlevel.steric(exp_ds, reference=ref_state, variant=variant,
                                 domain="global")[0][variant]

    # Materialize once instead of re-evaluating .values for every statistic
    values = exp_slr.values

    # Full data range (diagnostic; also passed to create_cb_params)
    min_val = float(np.nanmin(values))
    max_val = float(np.nanmax(values))

    # 0.5th and 99.5th percentiles: robust range used for automatic colorbar spacing
    per0p5 = float(np.nanpercentile(values, 0.5))
    per99p5 = float(np.nanpercentile(values, 99.5))
    if verbose:
        print(f"Full data min/max: {min_val:.3f}/{max_val:.3f}")
        print(f"Percentile-based max magnitude: {max(abs(per0p5), abs(per99p5)):.3f}")

    if cb_max is not None:
        chosen_n = cb_intervals[cb_max]
        chosen_step = 2 * cb_max / chosen_n
    else:
        chosen_n, chosen_step = get_cb_spacing(per0p5, per99p5, min_bnd=0.2, min_spacing=0.02,
                                               min_n=10, max_n=20, verbose=verbose)

    # Colorbar half-range
    max_mag = 0.5 * chosen_n * chosen_step

    zero_step, disc_cmap, disc_norm, boundaries, extend, tick_positions = create_cb_params(
        max_mag, min_val, max_val, chosen_n, chosen_step, verbose=verbose)

    # Create the figure and projection
    fig, ax = plt.subplots(figsize=(7.5, 5),
                           subplot_kw={'projection': ccrs.Robinson(central_longitude=209.5),
                                       'facecolor': 'grey'})

    diff_plot = exp_slr.plot(ax=ax,
                             x='geolon', y='geolat',
                             cmap=disc_cmap, norm=disc_norm,
                             # data coordinates are lon/lat, whatever the map projection
                             transform=ccrs.PlateCarree(),
                             add_labels=False,
                             add_colorbar=False)

    ax.set_title(f"{title}\n{variant.capitalize()} SLR: Year {start_yr}–{end_yr}")

    diff_cb = fig.colorbar(diff_plot, ax=ax, shrink=0.58, pad=0.04, extend=extend,
                           boundaries=boundaries, norm=disc_norm, spacing='proportional')

    # Ticks at these magnitudes need three decimals. Compare with a tolerance
    # because tick positions come from float arithmetic.
    three_decimal_mags = (0.125, 1 / 3, 2 / 3, 5 / 8)

    def _tick_label(val):
        if np.isclose(abs(val), three_decimal_mags, rtol=0, atol=1e-9).any():
            return f"{val:.3f}"
        if chosen_step < 0.1:
            return f"{val:.2f}"
        return f"{val:.1f}"

    tick_labels = [_tick_label(val) for val in tick_positions]

    diff_cb.set_ticks(tick_positions)
    diff_cb.ax.set_yticklabels(tick_labels)
    diff_cb.ax.tick_params(labelsize=10)
    diff_cb.set_label("Sea Level Anomaly (m)")

    # Centre tick labels, shifting them further right for a fine zero step or a
    # colorbar range above 10. (This previously used an undefined name `vmax`;
    # max_mag is the colorbar half-range.)
    label_x = 2.2 if (zero_step < 0.1 or max_mag > 10) else 2.0
    plt.setp(diff_cb.ax.get_yticklabels(), horizontalalignment='center', x=label_x)

    # Global-mean anomaly at the first time index, converted from m to mm
    mean_mm = 1e3 * float(global_slr.isel(time=0).values)
    ax.text(
        0.17, 0.8,
        f"{mean_mm:.1f} mm",
        transform=ax.transAxes,    # (x, y) in axes-fraction coordinates
        fontsize=12,
        va='top',
        ha='center',
        color='white', fontweight='bold', alpha=1
    )

    if icon is not None:
        # Local imports: these names are not imported at the top of this notebook
        import matplotlib.image as mpimg
        from matplotlib.offsetbox import AnnotationBbox, OffsetImage

        img = mpimg.imread(f"/home/Kiera.Lowman/profile_icons/{icon}_icon.png")
        imagebox = OffsetImage(img, zoom=0.09)
        # Place the icon at the top-right corner of the map axes
        ab = AnnotationBbox(imagebox, (0.95, 1.00), xycoords="axes fraction", frameon=False)
        ax.add_artist(ab)

    if savefig:
        fname = f'{prefix}_{variant}_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        fig.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')
        plt.close(fig)

In [7]:
# diff_type -> (reference-name template, experiment-name template,
#               title prefix, figure-name prefix).
# Templates are filled with prof (profile), root ("<power>_<start>_<end>")
# and yrs ("<start>_<end>").
_GMSLR_DIFF_TYPES = {
    'const-1860ctrl': ('const_ctrl_{yrs}', 'const_{prof}_{root}', 'Const CO2', 'const'),
    'doub-1860exp': ('const_{prof}_{root}', 'doub_{prof}_{root}', '1pct2xCO2 Radiative', '2xCO2-const'),
    'doub-2xctrl': ('doub_ctrl_{yrs}', 'doub_{prof}_{root}', '1pct2xCO2 Mixing', '2xCO2-2xctrl'),
    'doub-1860ctrl': ('const_ctrl_{yrs}', 'doub_{prof}_{root}', '1pct2xCO2 Total', '2xCO2-const-ctrl'),
    'quad-1860exp': ('const_{prof}_{root}', 'quad_{prof}_{root}', '1pct4xCO2 Radiative', '4xCO2-const'),
    'quad-4xctrl': ('quad_ctrl_{yrs}', 'quad_{prof}_{root}', '1pct4xCO2 Mixing', '4xCO2-4xctrl'),
    'quad-1860ctrl': ('const_ctrl_{yrs}', 'quad_{prof}_{root}', '1pct4xCO2 Total', '4xCO2-const-ctrl'),
}

# diff_types that also get a black "control" curve:
# (reference-name template, experiment-name template).
# NOTE(review): the quad-* counterparts have no control curve — confirm this is intentional.
_GMSLR_CTRL_CURVES = {
    'doub-1860exp': ('const_ctrl_{yrs}', 'doub_ctrl_{yrs}'),
    'doub-1860ctrl': ('const_ctrl_{yrs}', 'doub_ctrl_{yrs}'),
}


def _gmslr_load(ds_name, time_chunk):
    """Rechunk myVars[ds_name] along time, mask areacello by wet and rename temp/salt for momlevel."""
    ds = myVars[ds_name].chunk({'time': time_chunk}).persist()
    print("Done rechunking")
    ds = ds.assign(areacello=ds["areacello"] * ds["wet"]).rename({"temp": "thetao", "salt": "so"})
    print("Done masking and rename")
    return ds


def _gmslr_series(ref_ds_name, exp_ds_name, variant, time_chunk):
    """Return the global-mean sea level anomaly of myVars[exp_ds_name] relative to myVars[ref_ds_name]."""
    ref_ds = _gmslr_load(ref_ds_name, time_chunk)
    exp_ds = _gmslr_load(exp_ds_name, time_chunk)
    ref_state = momlevel.reference.setup_reference_state(ref_ds)
    global_result = momlevel.steric(exp_ds, reference=ref_state, variant=variant, domain="global")
    print("Done momlevel calc")
    return global_result[0][variant]


def plot_GMSLR_ts(diff_type,fig_dir,start_year,end_year,variant="steric",
                 ylimits = (-0.1,0.5),
                 leg_loc = 'upper left',
                 leg_ncols = 1,
                 profiles = ('surf','therm','mid','bot'),
                 power_inputs = ('0.1TW', '0.2TW', '0.5TW'),
                 power_var_suff = ('0p1TW', '0p2TW', '0p5TW'),
                 power_strings = ('0.1 TW', '0.2 TW', '0.5 TW'),
                 savefig=False):
    """
    Plot global-mean sea level anomaly time series for a particular CO2 scenario.

    One figure is drawn per power input, with one curve per mixing profile. For
    'doub-1860exp' and 'doub-1860ctrl' a black control curve (2xCO2 control minus
    const control) is added. Datasets are read from the global dict ``myVars``,
    which is not defined in this notebook.

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    fig_dir (str): directory in which to save the figures
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    variant (str): "steric", "thermosteric" or "halosteric" (passed to momlevel.steric)
    ylimits (sequence of 2 floats): y-axis limits (m)
    leg_loc (str), leg_ncols (int): legend location and number of columns
    profiles (sequence of str): profiles to plot; each must be 'surf', 'therm', 'mid' or 'bot'
    power_inputs: unused; kept for backward compatibility
    power_var_suff (sequence of str): power suffixes used in dataset names
    power_strings (sequence of str): power labels used in titles (same order as power_var_suff)
    savefig (bool): if True, save each figure as a PDF in fig_dir and close it

    Raises:
    ValueError: if diff_type is not one of the values above.
    """
    if diff_type not in _GMSLR_DIFF_TYPES:
        raise ValueError(f"diff_type must be one of {list(_GMSLR_DIFF_TYPES)}, got {diff_type!r}")
    ref_tmpl, exp_tmpl, title_prefix, fig_prefix = _GMSLR_DIFF_TYPES[diff_type]

    prof_colors = {'surf': 'b', 'therm': 'm', 'mid': 'g', 'bot': 'r'}
    yrs = f'{start_year}_{end_year}'

    # dask needs an integer chunk length: about a tenth of the number of years, at least 1
    time_chunk = max(1, (end_year - start_year + 1) // 10)

    # Separate figure for each power input
    for pow_idx, power in enumerate(power_var_suff):
        fig, ax = plt.subplots(figsize=(6, 3))

        # Horizontal zero line
        ax.axhline(y=0, color='black', linestyle='-', linewidth=1)

        root = f'{power}_{yrs}'
        for prof in profiles:
            print(f'Starting {prof}')
            series = _gmslr_series(ref_tmpl.format(prof=prof, root=root, yrs=yrs),
                                   exp_tmpl.format(prof=prof, root=root, yrs=yrs),
                                   variant, time_chunk)
            # Spread the samples evenly over the year range
            years = np.linspace(start_year, end_year, num=len(series))
            ax.plot(years, series, label=prof, color=prof_colors[prof])
            print(f'Ending {prof}')

        if diff_type in _GMSLR_CTRL_CURVES:
            print('Starting ctrl')
            ctrl_ref, ctrl_exp = _GMSLR_CTRL_CURVES[diff_type]
            series = _gmslr_series(ctrl_ref.format(yrs=yrs), ctrl_exp.format(yrs=yrs),
                                   variant, time_chunk)
            years = np.linspace(start_year, end_year, num=len(series))
            ax.plot(years, series, label='control', color='k')
            print('Ending ctrl')

        ax.set_xlabel("Time (Years)")
        ax.set_ylabel("Sea Level Anomaly (m)")
        # NOTE(review): x-axis starts at 0, not start_year — confirm intended
        ax.set_xlim(0, end_year)
        ax.set_ylim(ylimits)
        ax.legend(loc=leg_loc, ncols=leg_ncols)
        ax.minorticks_on()
        ax.grid(True, which='major', axis='both', linestyle='-', linewidth=0.5, color='gray')

        ax.set_title(f"{title_prefix}: {power_strings[pow_idx]} Cases\n"
                     f"Global Mean {variant.capitalize()} Sea Level Anomaly")

        if savefig:
            fname = f'{diff_type}_GMSLR_{variant}_{fig_prefix}_{power}_{start_year}_{end_year}.pdf'
            fig.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')
            plt.close(fig)

In [6]:
# diff_type -> (experiment scenario prefix, figure subdirectory,
#               reference-name template, title template, figure-name template).
# Templates are filled with prof, prof_str, power (variable suffix), power_str
# and yrs ("<start>_<end>").
_SLR_DIFF_TYPES = {
    'const-1860ctrl': ('const', 'piControl', 'const_ctrl_{yrs}',
                       'Const {prof_str} {power_str}',
                       '{prof}_{power}'),
    'doub-1860exp': ('doub', '2xCO2', 'const_{prof}_{power}_{yrs}',
                     '1pct2xCO2 — Const CO2: {prof_str} {power_str}',
                     '2xCO2-const_{prof}_{power}'),
    'doub-2xctrl': ('doub', '2xCO2', 'doub_ctrl_{yrs}',
                    '1pct2xCO2 {prof_str} {power_str} — 1pct2xCO2 Control',
                    '2xCO2-2xctrl_{prof}_{power}'),
    'doub-1860ctrl': ('doub', '2xCO2', 'const_ctrl_{yrs}',
                      '1pct2xCO2 {prof_str} {power_str} — Const CO2 Control',
                      '2xCO2-const-ctrl_{prof}_{power}'),
    'quad-1860exp': ('quad', '4xCO2', 'const_{prof}_{power}_{yrs}',
                     '1pct4xCO2 — Const CO2: {prof_str} {power_str}',
                     '4xCO2-const_{prof}_{power}'),
    'quad-4xctrl': ('quad', '4xCO2', 'quad_ctrl_{yrs}',
                    '1pct4xCO2 {prof_str} {power_str} — 1pct4xCO2 Control',
                    '4xCO2-4xctrl_{prof}_{power}'),
    'quad-1860ctrl': ('quad', '4xCO2', 'const_ctrl_{yrs}',
                      '1pct4xCO2 {prof_str} {power_str} — Const CO2 Control',
                      '4xCO2-const-ctrl_{prof}_{power}'),
}

# diff_types that additionally map the control-run difference:
# (reference-name template, experiment-name template, title, figure name)
_SLR_CTRL_CASES = {
    'doub-1860exp': ('const_ctrl_{yrs}', 'doub_ctrl_{yrs}', '1pct2xCO2 Control', '2xCO2-const-ctrl'),
    'quad-1860exp': ('const_ctrl_{yrs}', 'quad_ctrl_{yrs}', '1pct4xCO2 Control', '4xCO2-const-ctrl'),
}


def _slr_fig_path(fig_dir, subdir, start_year, end_year):
    """Create (if needed) and return '<fig_dir>/<subdir>/<SSSS>_<EEEE>/'."""
    path = f"{fig_dir}/{subdir}/{str(start_year).zfill(4)}_{str(end_year).zfill(4)}/"
    os.makedirs(path, exist_ok=True)
    return path


def create_SLR_plots(diff_type,start_year,end_year,
                     variant="steric",
                     profiles = ('surf','therm','mid','bot'),
                     prof_strings = ("Surf","Therm","Mid","Bot"),
                     power_var_suff = ('0p1TW', '0p2TW', '0p5TW'),
                     power_strings = ('0.1 TW', '0.2 TW', '0.5 TW'),
                     pref_addition=None,
                     plot_max=None,
                     savefig=False,fig_dir=None,
                     extra_verbose=False):
    """
    Draw sea level anomaly maps (via plot_SL_diff) for every profile/power combination.

    For 'doub-1860exp' and 'quad-1860exp', the control-run difference (2x/4xCO2 control
    minus const control) is mapped first. Datasets are read from the global dict
    ``myVars``, which is not defined in this notebook.

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    variant (str): type of SLR (either "steric", "thermosteric", or "halosteric")
    profiles (sequence of str): profile keys used in dataset names and as icon names
    prof_strings (sequence of str): profile labels for titles (same order as profiles)
    power_var_suff (sequence of str): power suffixes used in dataset/figure names
    power_strings (sequence of str): power labels for titles (same order as power_var_suff)
    pref_addition (str or None): extra suffix appended as "_<pref_addition>" to every figure name
    plot_max (int/float): cb_max input for plot_SL_diff
    savefig (boolean): input for plot_SL_diff
    fig_dir (str): parent directory for figures; they go in
        <fig_dir>/<piControl|2xCO2|4xCO2>/<SSSS>_<EEEE>/ (required when savefig is True)
    extra_verbose (boolean): verbose input for plot_SL_diff

    Raises:
    ValueError: for an unknown diff_type, or savefig=True without fig_dir.
    """
    if diff_type not in _SLR_DIFF_TYPES:
        raise ValueError(f"diff_type must be one of {list(_SLR_DIFF_TYPES)}, got {diff_type!r}")
    if savefig and fig_dir is None:
        raise ValueError("fig_dir must be given when savefig=True.")

    scenario, subdir, ref_tmpl, title_tmpl, fig_tmpl = _SLR_DIFF_TYPES[diff_type]
    yrs = f"{start_year}_{end_year}"
    name_suffix = "" if pref_addition is None else "_" + pref_addition

    if savefig:
        os.makedirs(fig_dir, exist_ok=True)
        fig_path = _slr_fig_path(fig_dir, subdir, start_year, end_year)
    else:
        fig_path = None

    # control case (only for the *-1860exp differences)
    if diff_type in _SLR_CTRL_CASES:
        ctrl_ref, ctrl_exp, title_str, fig_name = _SLR_CTRL_CASES[diff_type]
        fig_name = fig_name + name_suffix
        plot_SL_diff(title_str, myVars[ctrl_ref.format(yrs=yrs)], myVars[ctrl_exp.format(yrs=yrs)],
                     start_year, end_year, variant=variant,
                     cb_max=plot_max, icon=None,  # no mixing profile -> no profile icon
                     savefig=savefig, fig_dir=fig_path, prefix=fig_name,
                     verbose=extra_verbose)
        print(f"Done {fig_name}.")

    # perturbation cases
    for i, power_str in enumerate(power_strings):
        power = power_var_suff[i]
        for j, prof in enumerate(profiles):
            fields = dict(prof=prof, prof_str=prof_strings[j],
                          power=power, power_str=power_str, yrs=yrs)
            ds_name = f"{scenario}_{prof}_{power}_{yrs}"
            ref_name = ref_tmpl.format(**fields)
            title_str = title_tmpl.format(**fields)
            fig_name = fig_tmpl.format(**fields) + name_suffix

            plot_SL_diff(title_str, myVars[ref_name], myVars[ds_name], start_year, end_year, variant=variant,
                         cb_max=plot_max, icon=prof,
                         savefig=savefig, fig_dir=fig_path, prefix=fig_name,
                         verbose=extra_verbose)

            print(f"Done {fig_name}.")