This notebook defines functions for plotting various quantities. It is designed to be used together with the notebook read_and_calculate.ipynb.

In [1]:
import numpy as np
import xarray as xr

# modules for plotting datetime data
import matplotlib.dates as mdates
from matplotlib.axis import Axis

# modules for using datetime variables
import datetime
from datetime import time

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm

# for custom legend
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from xgcm import Grid
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

import cartopy.crs as ccrs
import cmocean
import colorcet

import subprocess as sp

import matplotlib.ticker as mticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

from matplotlib.ticker import ScalarFormatter

from xclim import ensembles

import cftime
from pandas.errors import OutOfBoundsDatetime  # Import the specific error

import os

# Temperature and temperature-anomaly plotting functions

## Temperature maps

In [2]:
# revised function with better colorbar
def _choose_colorbar_levels(data_max, verbose=False):
    """Pick a symmetric discrete colorbar layout (n bins of equal width) covering ±data_max.

    Candidate bin counts are n = 20, 18, ..., 10, tried largest first. For each n the
    width needed to cover ±data_max, ``2 * data_max / n``, is rounded UP to a multiple
    of 0.2, so ``n * step >= 2 * data_max`` always holds for an accepted n. An n is
    skipped when rounding up lands farther from the needed width than rounding down
    would, which avoids a range much wider than the data.

    Parameters
    ----------
    data_max : float
        Half-width of the range that must be covered. Must be positive.
    verbose : bool
        Print the intermediate candidates.

    Returns
    -------
    tuple of (int, float)
        ``(n_bins, bin_width)``. The width is a multiple of 0.2 and at least 0.2.
    """
    step_unit = 0.2
    for n in range(20, 9, -2):
        step_needed = (2 * data_max) / n
        ceil_step = np.ceil(step_needed / step_unit) * step_unit
        floor_step = np.floor(step_needed / step_unit) * step_unit
        if verbose:
            print(f"ceil_step, floor_step: {ceil_step}, {floor_step}")
        # Reject this n when rounding up overshoots by more than rounding down would undershoot.
        if abs(step_needed - ceil_step) > abs(step_needed - floor_step):
            continue
        if ceil_step < step_unit:
            continue
        if verbose:
            print(f"Found suitable n={n}, step={ceil_step}, max_mag={0.5 * n * ceil_step}")
        return n, ceil_step

    # Every n failed the closeness test. Fall back to 10 bins, but with the smallest
    # 0.2-multiple width that still covers ±data_max. (The old fixed step of 0.2 only
    # covered ±1.0.)
    n = 10
    step = max(step_unit, np.ceil((2 * data_max) / n / step_unit) * step_unit)
    if verbose:
        print(f"No feasible n found in [10..20]! Using fallback: n={n}, step={step}")
    return n, step


def plot_pp_temp_diff(prefix,title,pp_diff_da,z_idx,start_yr,end_yr,cb_max=None,hatching=False,icon=None,verbose=False,
                      icon_dir="/home/Kiera.Lowman/profile_icons"):
    """Plot a temperature-anomaly map at one depth level and save it as a PNG.

    The map uses a Robinson projection centred at 209.5°E. It has a discrete,
    symmetric ``cmocean.cm.balance`` colorbar whose range covers the 0.5th-99.5th
    percentile magnitude of the data (or ``cb_max``), and never less than ±1.0.

    Parameters
    ----------
    prefix : str
        Output path prefix. The figure is written to
        ``{prefix}_dT_{start_yr:0>4}_{end_yr:0>4}_z_{depth:.0f}.png`` at 600 dpi.
    title : str
        First line of the plot title. The year range and depth are appended.
    pp_diff_da : xarray.Dataset
        Must provide a ``temp`` variable with a ``z_l`` coordinate and
        ``geolon``/``geolat`` coordinates. When ``hatching`` is set, it must also
        provide ``temp_disagree`` with ``z_l`` and ``time`` dimensions.
    z_idx : int
        Index into ``z_l`` of the level to plot.
    start_yr, end_yr : int
        Averaging period. Used only in the title and the file name.
    cb_max : float, optional
        Overrides the percentile-based colorbar half-range. It is still floored at 1.0.
    hatching : bool
        Overlay '////' hatching where ``temp_disagree`` (at time index 0) is in [0.5, 1.5].
    icon : str, optional
        If given, ``{icon_dir}/{icon}_icon.png`` is drawn above the top-left of the map.
    verbose : bool
        Print diagnostics about the data range and the chosen colorbar levels.
    icon_dir : str
        Directory containing the icon images.
    """
    depth = pp_diff_da.coords['z_l'].values[z_idx]
    diff_da = pp_diff_da.temp.isel(z_l=z_idx)
    if hatching:
        hatch_mask = pp_diff_da.temp_disagree.isel(z_l=z_idx).isel(time=0)

    values = diff_da.values
    # The full data range only decides the colorbar extension arrows, not the colour bounds.
    min_val = float(np.nanmin(values))
    max_val = float(np.nanmax(values))
    # Percentile-based magnitude, so a few outliers do not stretch the colour range.
    per0p5 = float(np.nanpercentile(values, 0.5))
    per99p5 = float(np.nanpercentile(values, 99.5))
    pct_max = max(abs(per0p5), abs(per99p5))
    if verbose:
        print(f"Full data min/max: {min_val:.3f}/{max_val:.3f}")
        print(f"Percentile-based max magnitude: {pct_max:.3f}")

    # Half-range the colorbar must cover: the user override or the percentile value, never below 1.0.
    data_max = max(cb_max if cb_max is not None else pct_max, 1.0)
    if verbose:
        print(f"data_max: {data_max}")

    chosen_n, chosen_step = _choose_colorbar_levels(data_max, verbose=verbose)
    max_mag = 0.5 * chosen_n * chosen_step  # final ± range
    vmin, vmax = -max_mag, max_mag
    # chosen_n bins need chosen_n + 1 boundaries, evenly spaced from vmin to vmax.
    boundaries = vmin + np.arange(chosen_n + 1) * chosen_step

    if verbose:
        print(f"Final chosen_n: {chosen_n}, step: {chosen_step}, vmin={vmin}, vmax={vmax}")
        print(f"Boundaries: {boundaries}")

    fig, ax = plt.subplots(figsize=(12, 8),
                           subplot_kw={'projection': ccrs.Robinson(central_longitude=209.5),
                                       'facecolor': '0.75'})

    # Discretise the cmocean "balance" colormap into one colour per bin.
    newcolors = cmocean.cm.balance(np.linspace(0, 1, chosen_n))
    disc_cmap = mcolors.LinearSegmentedColormap.from_list('discrete_balance', newcolors, N=chosen_n)
    norm = mcolors.BoundaryNorm(boundaries, ncolors=chosen_n)

    # The data are on a lon/lat grid, so PlateCarree is the source transform
    # regardless of the display projection.
    diff_plot = diff_da.plot(ax=ax,
                             x='geolon', y='geolat',
                             cmap=disc_cmap, norm=norm,
                             transform=ccrs.PlateCarree(),
                             add_labels=False,
                             add_colorbar=False)

    if hatching:
        # NOTE(review): the original comment says "where all members agree on the sign",
        # but the mask variable is named temp_disagree. Confirm which is meant.
        ax.contourf(
            diff_da['geolon'], diff_da['geolat'], hatch_mask,
            levels=[0.5, 1.5],  # select mask values equal to 1
            colors='none',      # hatching only, no fill colour
            hatches=['////'],
            transform=ccrs.PlateCarree()
        )

    diff_plot.axes.set_title(f"{title}\nYear {start_yr}–{end_yr}, z = {depth:,.2f} m",fontdict={'fontsize':16})

    # Add colorbar arrows on whichever side the full data range exceeds the colour range.
    below, above = min_val < vmin, max_val > vmax
    if below and above:
        extend = 'both'
    elif below:
        extend = 'min'
    elif above:
        extend = 'max'
    else:
        extend = 'neither'
    diff_cb = plt.colorbar(diff_plot, shrink=0.6, extend=extend)

    if chosen_n < 14:
        tick_arr = boundaries
    else:
        # Label every other boundary. Start at the parity of the zero boundary
        # (index chosen_n // 2) so that 0 is always labelled.
        first = 0 if (chosen_n // 2) % 2 == 0 else 1
        tick_arr = boundaries[first::2]

    diff_cb.set_ticks(tick_arr)
    diff_cb.ax.set_yticklabels([f"{x:.1f}" for x in tick_arr])
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label(r"Temperature Anomaly ($\degree$C)", fontdict={'fontsize': 14})
    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.0 if vmax < 10 else 2.2)

    if icon is not None:
        img = mpimg.imread(os.path.join(icon_dir, f"{icon}_icon.png"))
        imagebox = OffsetImage(img, zoom=0.15)
        # Place the icon just above the top-left corner of the map axes.
        ab = AnnotationBbox(imagebox, (0.05, 1.02), xycoords="axes fraction", frameon=False)
        ax.add_artist(ab)

    fig.savefig(f'{prefix}_dT_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}_z_{depth:.0f}.png', dpi=600, bbox_inches='tight')

In [3]:
# original function
# def plot_pp_temp_diff(prefix,title,pp_diff_da,z_idx,start_yr,end_yr,cb_max=None,hatching=False,icon=None,non_linear_cb=False,verbose=False):

#     depth = pp_diff_da.coords['z_l'].values[z_idx]
#     diff_da = pp_diff_da.temp.isel(z_l=z_idx)
    
#     if hatching:
#         hatch_mask = pp_diff_da.temp_disagree.isel(z_l=z_idx).isel(time=0)

#     # NOT used for plot bounds but used for colorbar arrows
#     min_val = np.nanmin(diff_da.values)
#     max_val = np.nanmax(diff_da.values)

#     # used for plot bounds
#     per0p5 = np.nanpercentile(diff_da.values,0.5)
#     per99p5 = np.nanpercentile(diff_da.values,99.5)
    
#     if verbose:
#         if np.abs(per0p5) > np.abs(per99p5):
#             print(f"0.5 to 99.5 percentile data max mag: {np.abs(per0p5):.3f}")
#         else:
#             print(f"0.5 to 99.5 percentile data max mag: {np.abs(per99p5):.3f}")

#     if verbose:
#         # print(f"Data min: {min_val:.3f}\t Data max: {max_val:.3f}")
#         if np.abs(min_val) > np.abs(max_val):
#             print(f"Full data max mag: {np.abs(min_val):.3f}")
#         else:
#             print(f"Full data max mag: {np.abs(max_val):.3f}")

#     # set plot bounds based on 0.5 and 99.5 percentile
#     if cb_max != None:
#         max_mag = cb_max
#     elif np.abs(per0p5) > np.abs(per99p5):
#         max_mag = np.abs(per0p5)
#     else:
#         max_mag = np.abs(per99p5)

#     # # set plot bounds based on min and max
#     # if cb_max != None:
#     #     max_mag = cb_max
#     # elif np.abs(min_val) > np.abs(max_val):
#     #     max_mag = np.abs(min_val)
#     # else:
#     #     max_mag = np.abs(max_val)
        
#     # setting plot min and max
#     if max_mag < 0.4:
#         plot_min = -0.4
#         plot_max = 0.4
#         num_ticks = int((plot_max-plot_min)/0.1) + 1
#         num_colors = int((plot_max-plot_min)/0.05)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#         for i in range(0,len(tick_arr)):
#             tick_arr[i] = round(tick_arr[i]/0.1)*0.1
#     elif max_mag <= 0.5:
#         plot_min = -round(max_mag/0.1)*0.1
#         plot_max = round(max_mag/0.1)*0.1
#         num_ticks = int((plot_max-plot_min)/0.1) + 1
#         num_colors = int((plot_max-plot_min)/0.05)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#         for i in range(0,len(tick_arr)):
#             tick_arr[i] = round(tick_arr[i]/0.1)*0.1
#     elif max_mag <= 1.1:
#         plot_min = -round(max_mag/0.2)*0.2
#         plot_max = round(max_mag/0.2)*0.2
#         num_ticks = int((plot_max-plot_min)/0.2) + 1
#         num_colors = int((plot_max-plot_min)/0.1)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#         for i in range(0,len(tick_arr)):
#             tick_arr[i] = round(tick_arr[i]/0.1)*0.1
#     elif max_mag <= 2:
#         plot_min = -round(max_mag/0.4)*0.4
#         plot_max = round(max_mag/0.4)*0.4
#         num_ticks = int((plot_max-plot_min)/0.4) + 1
#         num_colors = int((plot_max-plot_min)/0.2)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#         for i in range(0,len(tick_arr)):
#             tick_arr[i] = round(tick_arr[i]/0.1)*0.1
#     elif max_mag <= 3:
#         plot_min = -round(max_mag/0.5)*0.5
#         plot_max = round(max_mag/0.5)*0.5
#         num_ticks = int((plot_max-plot_min)/0.5) + 1
#         num_colors = int((plot_max-plot_min)/0.25)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#         for i in range(0,len(tick_arr)):
#             tick_arr[i] = round(tick_arr[i]/0.1)*0.1
#     # elif max_mag <= 4.1:
#     #     plot_min = -round(max_mag)
#     #     plot_max = round(max_mag)
#     #     num_ticks = int(plot_max-plot_min) + 1
#     #     num_colors = int((plot_max-plot_min)/0.25)
#     #     tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     elif max_mag <= 6.1:
#         plot_min = -round(max_mag)
#         plot_max = round(max_mag)
#         num_ticks = int(plot_max-plot_min) + 1
#         num_colors = int((plot_max-plot_min)/0.5)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     elif max_mag <= 6.1:
#         plot_min = -round(max_mag)
#         plot_max = round(max_mag)
#         num_ticks = int(plot_max-plot_min) + 1
#         num_colors = int((plot_max-plot_min)/0.5)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     # elif max_mag <= 7.49:
#     #     plot_min = -round(max_mag)
#     #     plot_max = round(max_mag)
#     #     num_ticks = int((plot_max-plot_min)/2) + 1
#     #     num_colors = int((plot_max-plot_min)/0.5)
#     #     tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     elif max_mag <= 10:
#         plot_min = -np.ceil(max_mag/2)*2
#         plot_max = np.ceil(max_mag/2)*2
#         num_ticks = int((plot_max-plot_min)/2) + 1
#         num_colors = int((plot_max-plot_min)/1)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     elif max_mag <= 13:
#         plot_min = -np.ceil(max_mag/3)*3
#         plot_max = np.ceil(max_mag/3)*3
#         num_ticks = int((plot_max-plot_min)/3) + 1
#         num_colors = int((plot_max-plot_min)/1.5)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     elif max_mag <= 15:
#         plot_min = -np.ceil(max_mag/4)*4
#         plot_max = np.ceil(max_mag/4)*4
#         num_ticks = int((plot_max-plot_min)/4) + 1
#         num_colors = int((plot_max-plot_min)/2)
#         tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
#     else:
#         print("Warning: plot bounds more than +/- 15")

#     if verbose:
#         print(f"num_colors = {num_colors}") 
#         print(f"Plot min: {plot_min:.3f}\t Plot max: {plot_max:.3f}\n")

#     # plt.figure(figsize=[12, 8])
#     fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': ccrs.Robinson(central_longitude=209.5), 'facecolor': '0.75'})

#     cmap = cmocean.cm.balance  # define the colormap
#     cmaplist = [cmap(i) for i in range(cmap.N)] # extract all colors from the balance map
#     # force the first color entry to be grey
#     # cmaplist[0] = (.5, .5, .5, 1.0)
    
#     # create the new map
#     disc_bal_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    
#     # define the bins and normalize
#     norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
#     disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)
        
#     subplot_kws=dict(projection=ccrs.Robinson(central_longitude=209.5), facecolor='0.75') #projection=ccrs.PlateCarree(),facecolor='gray'
#     # projection=ccrs.Robinson(central_longitude=180)
    
#     if not non_linear_cb:
#         diff_plot = diff_da.plot(#vmin=plot_min, vmax=plot_max,
#                       x='geolon', y='geolat',
#                       cmap=disc_bal_cmap, norm=disc_norm,
#                       # subplot_kws=subplot_kws,
#                           #You can pick any projection from the cartopy list but, whichever projection you use, you still have to set
#                       transform=ccrs.PlateCarree(),
#                       add_labels=False,
#                       add_colorbar=False)

#     else:
#         norm = mcolors.SymLogNorm(linthresh=cb_max/2, linscale = 0.6, vmin=plot_min, vmax=plot_max, base=10)
        
#         diff_plot = diff_da.plot(vmin=plot_min, vmax=plot_max,
#                       x='geolon', y='geolat',
#                       cmap=cmocean.cm.balance, norm=norm,
#                       subplot_kws=subplot_kws,
#                           #You can pick any projection from the cartopy list but, whichever projection you use, you still have to set
#                       transform=ccrs.PlateCarree(),
#                       add_labels=False,
#                       add_colorbar=False)

#     if hatching:
#         # Add hatching where all members agree on the sign
#         plt.contourf(
#             diff_da['geolon'], diff_da['geolat'], hatch_mask,
#             levels=[0.5, 1.5],  # Define binary levels
#             colors='none',  # No color, just hatching
#             hatches=['////'],  # Hatching pattern
#             transform=ccrs.PlateCarree()
#         )
    
#     diff_plot.axes.set_title(f"{title}\nYear {start_yr}–{end_yr}, z = {depth:,.2f} m",fontdict={'fontsize':16})
    
#     # diff_cb = plt.colorbar(diff_plot, fraction=0.046, pad=0.04)

#     if min_val < plot_min and max_val > plot_max:
#         extend = 'both'
#     elif min_val < plot_min:
#         extend = 'min'
#     elif max_val > plot_max:
#         extend = 'max'
#     else:
#         extend = 'neither'
    
#     diff_cb = plt.colorbar(diff_plot, shrink=0.6, extend=extend)

#     tick_labels = [f"{x:.1f}" for x in tick_arr]
    
#     diff_cb.set_ticks(tick_arr)
#     diff_cb.ax.set_yticklabels(tick_labels)
#     diff_cb.ax.tick_params(labelsize=14)
#     diff_cb.set_label("Temperature Anomaly ($\degree$C)",fontdict={'fontsize':14})

#     for t in diff_cb.ax.get_yticklabels():
#         t.set_horizontalalignment('center')
#         if plot_max < 10:
#             t.set_x(2.0)
#         else:
#             t.set_x(2.2)

#     if icon is not None:
#         image_path = f"/home/Kiera.Lowman/profile_icons/{icon}_icon.png"  # Replace with your image path
#         img = mpimg.imread(image_path)
#         # Create an OffsetImage
#         imagebox = OffsetImage(img, zoom=0.15)  # Adjust zoom as needed
#         ab = AnnotationBbox(imagebox, (0.05, 1.02), xycoords="axes fraction", frameon=False) # Set the image position (e.g., top-right corner)
#         ax.add_artist(ab) # Add image to the figure

#     plt.savefig(f'{prefix}_dT_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}_z_{depth:.0f}.png', dpi=600, bbox_inches='tight')

In [4]:
def create_temp_diff_plots(diff_type,fig_dir,start_year,end_year,z_idx,
                           profiles=('surf','therm','mid','bot'),
                           prof_strings=("Surf","Therm","Mid","Bot"),
                           power_var_suff=('0p1TW', '0p2TW', '0p5TW'),
                           power_strings=('0.1 TW', '0.2 TW', '0.5 TW'),
                           dT_max=None,
                           hatching=False,
                           extra_verbose=False):
    """
    Make one temperature-anomaly map (via plot_pp_temp_diff) per (power, profile) pair.

    Datasets are looked up by name in the global ``myVars`` mapping, e.g.
    ``doub_surf_0p1TW_{start_year}_{end_year}_diff_1860``.

    Inputs:
    diff_type (str): one of
                    ['const-1860ctrl',
                    'doub-1860exp','doub-2xctrl','doub-1860ctrl',
                    'quad-1860exp','quad-4xctrl','quad-1860ctrl']
    fig_dir (str): directory to save figures in (created if missing)
    start_year (int): start year of avg period
    end_year (int): end year of avg period
    z_idx (int): z-index for depth of temp anomalies to plot
    profiles (sequence of str): profile ids used in dataset/figure names
    prof_strings (sequence of str): display names for `profiles`, same order (at least as many)
    power_var_suff (sequence of str): power ids used in dataset/figure names
    power_strings (sequence of str): display names for powers; one plot row per entry,
                                     paired by position with `power_var_suff`
    dT_max (int/float): passed to plot_pp_temp_diff as cb_max
    hatching (boolean): input for plot_pp_temp_diff
    extra_verbose (boolean): passed to plot_pp_temp_diff as verbose

    Raises:
    ValueError: unknown diff_type, or a display-name list shorter than its id list.
                Raised before any directory or figure is created.
    """
    # diff_type -> (dataset prefix, dataset suffix, title template, figure-name prefix)
    diff_specs = {
        'const-1860ctrl': ('const', '', "Const {prof} {power}", ""),
        'doub-1860exp': ('doub', '_1860', "1pct2xCO2 — Const CO2: {prof} {power}", "2xCO2-const_"),
        'doub-2xctrl': ('doub', '_2xctrl', "1pct2xCO2 {prof} {power} — 1pct2xCO2 Control", "2xCO2-2xctrl_"),
        'doub-1860ctrl': ('doub', '_const_ctrl', "1pct2xCO2 {prof} {power} — Const CO2 Control", "2xCO2-const-ctrl_"),
        'quad-1860exp': ('quad', '_1860', "1pct4xCO2 — Const CO2: {prof} {power}", "4xCO2-const_"),
        'quad-4xctrl': ('quad', '_4xctrl', "1pct4xCO2 {prof} {power} — 1pct4xCO2 Control", "4xCO2-4xctrl_"),
        'quad-1860ctrl': ('quad', '_const_ctrl', "1pct4xCO2 {prof} {power} — Const CO2 Control", "4xCO2-const-ctrl_"),
    }
    if diff_type not in diff_specs:
        raise ValueError(f"Unknown diff_type {diff_type!r}; expected one of {list(diff_specs)}.")
    # Validate up front so a bad call fails before producing a partial set of figures.
    if len(prof_strings) < len(profiles):
        raise ValueError("prof_strings must have at least as many entries as profiles.")
    if len(power_var_suff) < len(power_strings):
        raise ValueError("power_var_suff must have at least as many entries as power_strings.")

    ds_prefix, ds_suffix, title_tmpl, fig_name_prefix = diff_specs[diff_type]

    os.makedirs(fig_dir, exist_ok=True)

    for power_id, power_str in zip(power_var_suff, power_strings):
        for prof, prof_str in zip(profiles, prof_strings):
            ds_name = f'{ds_prefix}_{prof}_{power_id}_{start_year}_{end_year}_diff{ds_suffix}'
            title_str = title_tmpl.format(prof=prof_str, power=power_str)
            fig_name = f"{fig_name_prefix}{prof}_{power_id}"
            # os.path.join keeps the file inside fig_dir even without a trailing slash.
            fig_prefix = os.path.join(fig_dir, fig_name)

            plot_pp_temp_diff(fig_prefix, title_str, myVars[ds_name], z_idx,
                              start_year, end_year, cb_max=dT_max, hatching=hatching, icon=prof, verbose=extra_verbose)

            print(f"Done {fig_name}.")

In [5]:
def plot_pp_temp_mean(prefix,title,pp_temp_da,z_idx,start_yr,end_yr,verbose=False,
                      plot_min=-2,plot_max=30,tick_step=4):
    """Plot a temperature map at one depth level, save it as a PNG, then display it.

    Parameters
    ----------
    prefix : str
        Output path prefix. The figure is written to
        ``{prefix}_temp_{start_yr:0>4}_{end_yr:0>4}_z_{depth:.0f}.png`` at 600 dpi.
    title : str
        First line of the plot title. The year range and depth are appended.
    pp_temp_da : xarray.DataArray
        Temperature with a ``z_l`` coordinate and ``geolon``/``geolat`` coordinates.
    z_idx : int
        Index into ``z_l`` of the level to plot.
    start_yr, end_yr : int
        Averaging period. Used only in the title and the file name.
    verbose : bool
        Print the data min/max.
    plot_min, plot_max : float
        Colour range in °C (default -2 to 30).
    tick_step : float
        Colorbar tick spacing. Each tick interval is split into 4 colour bins.
    """
    depth = pp_temp_da.coords['z_l'].values[z_idx]
    run_da = pp_temp_da.isel(z_l=z_idx)

    if verbose:
        min_val = np.nanmin(run_da.values)
        max_val = np.nanmax(run_da.values)
        print(f"Data min: {min_val:.3f}\t Data max: {max_val:.3f}")

    num = int((plot_max-plot_min)/tick_step) + 1
    tick_arr = np.linspace(plot_min,plot_max,num=num)
    num_colors = 4 * (num - 1)

    plt.figure(figsize=[12, 8])

    cmap = cmocean.cm.thermal
    # Rebuild the thermal colormap from its full colour list (same N as the original).
    disc_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', [cmap(i) for i in range(cmap.N)], cmap.N)

    # num_colors equal-width bins between plot_min and plot_max.
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    subplot_kws=dict(projection=ccrs.Robinson(central_longitude=209.5), facecolor='0.75')

    # The data are on a lon/lat grid, so PlateCarree is the source transform
    # regardless of the display projection.
    run_plot = run_da.plot(x='geolon', y='geolat',
                           cmap=disc_cmap, norm=disc_norm,
                           subplot_kws=subplot_kws,
                           transform=ccrs.PlateCarree(),
                           add_labels=False,
                           add_colorbar=False)

    run_plot.axes.set_title(f"{title}\nYear {start_yr}–{end_yr}, z = {depth:,.2f} m",fontdict={'fontsize':16})

    run_cb = plt.colorbar(run_plot, ticks=tick_arr, shrink=0.6, extend='both')
    run_cb.ax.tick_params(labelsize=14)
    run_cb.set_label(r"Temperature ($\degree$C)",fontdict={'fontsize':12})

    # Save BEFORE show: with the inline backend, show() closes the figure, and a
    # later savefig would write a blank image.
    plt.savefig(f'{prefix}_temp_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}_z_{depth:.0f}.png', dpi=600, bbox_inches='tight')

    plt.show()

In [6]:
# def transform_depth(z):
#     """
#     Custom transformation for depth axis:
#     - Expands top 1000 m
#     - Compresses lower depths
#     """
#     if z <= 1000:
#         return z  # No transformation for shallow depths
#     else:
#         return 1000 + (z - 1000) * 0.2  # Compress deeper depths

## Basin mean anomaly functions (T, N2, rhopot2)

In [7]:
def transform_depth(z, max_depth, axis_split):
    """
    Piecewise-linear depth transform for section plots with a stretched upper ocean.

    Depths in [0, axis_split] are returned unchanged. Deeper values are compressed
    linearly by ``axis_split / (max_depth - axis_split)``, so ``max_depth`` maps to
    ``2 * axis_split``. The ranges [0, axis_split] and [axis_split, max_depth]
    therefore occupy equal lengths of the transformed axis.

    Parameters
    ----------
    z : scalar or array-like
        Depth(s) to transform.
    max_depth : float
        Depth mapped to ``2 * axis_split``. Must be greater than ``axis_split``.
    axis_split : float
        Depth above which values are left unchanged.

    Returns
    -------
    numpy.ndarray
        Transformed depth(s); a 0-d array for scalar input.

    Raises
    ------
    ValueError
        If ``max_depth <= axis_split``. The compression factor would otherwise be
        undefined (division by zero) or negative (flipped axis).
    """
    if max_depth <= axis_split:
        raise ValueError(f"max_depth ({max_depth}) must be greater than axis_split ({axis_split}).")
    compress_factor = axis_split / (max_depth - axis_split)
    z = np.asarray(z)
    return np.where(z <= axis_split, z, axis_split + (z - axis_split) * compress_factor)

In [8]:
def plot_temp_diff_basin(title,diff_ds,basin_name,max_depth,axis_split,start_yr,end_yr,
                         check_nn=False,nn_threshold=0.05,cb_max=None,mask_dataset=None,
                         run_ds=None, # must be passed to plot density overlays
                         savefig=False,fig_dir=None,prefix=None,
                         verbose=False):

    if verbose and mask_dataset is None:
        print("mask_ds is none")

    if len(diff_ds.time.values) > 1:
        raise ValueError("diff_ds cannot be a time series.")
        
    diff_ds = diff_ds.isel(time=0)
    diff_dat = get_pp_basin_dat(diff_ds, "temp", basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                                mask_ds=mask_dataset)#, verbose=verbose)
    
    diff_dat = diff_dat.sel(z_l=slice(0,max_depth))

    # Apply transformation to depth coordinates
    transformed_z = xr.apply_ufunc(transform_depth, diff_dat.z_l, 
                               kwargs={"max_depth": max_depth, "axis_split": axis_split})
    diff_dat = diff_dat.assign_coords(z_l=transformed_z)

    if run_ds is not None:
        if len(run_ds.time.values) > 1:
            raise ValueError("run_ds cannot be a time series.")
        run_ds = run_ds.isel(time=0)
        density_dat = get_pp_basin_dat(run_ds, "rhopot2", basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                                   mask_ds=mask_dataset)#, verbose=verbose)
        density_dat = density_dat.sel(z_l=slice(0,max_depth))
        dens_transformed_z = xr.apply_ufunc(transform_depth, density_dat.z_l, 
                               kwargs={"max_depth": max_depth, "axis_split": axis_split})
        density_dat = density_dat.assign_coords(z_l=dens_transformed_z)

    # used for colorbar arrows
    min_val = np.nanmin(diff_dat.values)
    max_val = np.nanmax(diff_dat.values)

    # used for plot bounds
    p0p5 = np.nanpercentile(diff_dat.values,0.5)
    p99p5 = np.nanpercentile(diff_dat.values,99.5)
    
    if verbose:
        if np.abs(p0p5) > np.abs(p99p5):
            print(f"0.5 to 99.5th percentile data max mag: {np.abs(p0p5):.3f}")
        else:
            print(f"0.5 to 99.5th percentile data max mag: {np.abs(p99p5):.3f}")

    if cb_max != None:
        max_mag = cb_max
    elif np.abs(p0p5) > np.abs(p99p5):
        max_mag = np.abs(p0p5)
    else:
        max_mag = np.abs(p99p5)
        
    # setting plot min and max
    if max_mag <= 0.2:
        plot_min = -0.2
        plot_max = 0.2
        num_ticks = int((plot_max-plot_min)/0.05) + 1
        num_colors = int((plot_max-plot_min)/0.025)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.01)*0.01
    elif max_mag <= 0.6:
        plot_min = -round(max_mag/0.1)*0.1
        plot_max = round(max_mag/0.1)*0.1
        num_ticks = int((plot_max-plot_min)/0.1) + 1
        num_colors = int((plot_max-plot_min)/0.05)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.1)*0.1
    elif max_mag <= 1:
        plot_min = -round(max_mag/0.2)*0.2
        plot_max = round(max_mag/0.2)*0.2
        num_ticks = int((plot_max-plot_min)/0.2) + 1
        num_colors = int((plot_max-plot_min)/0.1)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.1)*0.1
    elif max_mag <= 2:
        plot_min = -round(max_mag/0.4)*0.4
        plot_max = round(max_mag/0.4)*0.4
        num_ticks = int((plot_max-plot_min)/0.4) + 1
        num_colors = int((plot_max-plot_min)/0.2)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.1)*0.1
    elif max_mag <= 2.5:
        plot_min = -round(max_mag/0.5)*0.5
        plot_max = round(max_mag/0.5)*0.5
        num_ticks = int((plot_max-plot_min)/0.5) + 1
        num_colors = int((plot_max-plot_min)/0.25)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.1)*0.1
    elif max_mag <= 3.3:
        plot_min = -round(max_mag)
        plot_max = round(max_mag)
        num_ticks = int((plot_max-plot_min)) + 1
        num_colors = int((plot_max-plot_min)/0.25)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
        for i in range(0,len(tick_arr)):
            tick_arr[i] = round(tick_arr[i]/0.1)*0.1
    elif max_mag <= 6.1:
        plot_min = -round(max_mag)
        plot_max = round(max_mag)
        num_ticks = int(plot_max-plot_min) + 1
        num_colors = int((plot_max-plot_min)/0.5)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
    elif max_mag <= 10:
        plot_min = -np.ceil(max_mag/2)*2
        plot_max = np.ceil(max_mag/2)*2
        num_ticks = int((plot_max-plot_min)/2) + 1
        num_colors = int((plot_max-plot_min)/1)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
    elif max_mag <= 13:
        plot_min = -np.ceil(max_mag/3)*3
        plot_max = np.ceil(max_mag/3)*3
        num_ticks = int((plot_max-plot_min)/3) + 1
        num_colors = int((plot_max-plot_min)/1.5)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
    elif max_mag <= 15:
        plot_min = -np.ceil(max_mag/4)*4
        plot_max = np.ceil(max_mag/4)*4
        num_ticks = int((plot_max-plot_min)/4) + 1
        num_colors = int((plot_max-plot_min)/2)
        tick_arr = np.linspace(plot_min,plot_max,num=num_ticks)
    else:
        print("Warning: plot bounds more than +/- 15")
        
    if verbose:
        print(f"num_colors = {num_colors}")  
        print(f"Plot bounds: {plot_min:.3f} to {plot_max:.3f}\n")

    cmap = cmocean.cm.balance  # define the colormap
    cmaplist = [cmap(i) for i in range(cmap.N)] # extract all colors from the balance map

    # create the new map
    disc_bal_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    
    # define the bins and normalize
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    plt.figure(figsize=[11, 3.8])
    
    subplot_kws=dict(facecolor='grey')
    
    diff_p = diff_dat.plot(x='true_lat', y='z_l',
              cmap=disc_bal_cmap,
              norm=disc_norm,
              subplot_kws=subplot_kws,
                  #You can pick any projection from the cartopy list but, whichever projection you use, you still have to set
              # transform=ccrs.PlateCarree(),
              add_labels=False,
              add_colorbar=False)

    # diff_p.axes.set_facecolor('grey')

    # Define original depth values and their transformed positions
    if axis_split == 1000:
        if max_depth >= 6000:
            depth_labels = [0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000]
        elif max_depth >= 5000:
            depth_labels = [0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000]
        else:
            raise ValueError(f"{max_depth} not an acceptable value.")
    elif axis_split == 1500:
        if max_depth >= 6000:
            depth_labels = [0, 500, 1000, 1500, 2500, 3500, 4500, 5500]
        elif max_depth >= 5000:
            depth_labels = [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]
        else:
            raise ValueError(f"{max_depth} not an acceptable value.")
    elif axis_split == 2000:
        if max_depth >= 6000:
            depth_labels = [0, 500, 1000, 1500, 2000, 3000, 4000, 5000, 6000]
        elif max_depth >= 5000:
            depth_labels = [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]
        else:
            raise ValueError(f"{max_depth} not an acceptable value.")
    # elif axis_split == 2500:
    #     if max_depth >= 6000:
    #         depth_labels = [0, 500, 1000, 1500, 2000, 2500, 3000, 4000, 5000, 6000]
    #     elif max_depth >= 5000:
    #         depth_labels = [0, 1000, 2000, 3000, 4000, 5000]
    #     else:
    #         raise ValueError(f"{max_depth} not an acceptable value.")
    elif axis_split == 3000:
        if max_depth >= 6000:
            depth_labels = [0, 1000, 2000, 3000, 4000, 5000, 6000]
        elif max_depth >= 5000:
            depth_labels = [0, 1000, 2000, 3000, 4000, 5000]
        else:
            raise ValueError(f"{max_depth} not an acceptable value.")
    else:
        raise ValueError(f"{axis_split} not an acceptable value. Must be in list [1000, 1500, 2000, 3000]")

    depth_positions = [transform_depth(d, max_depth, axis_split) for d in depth_labels]
    
    # Apply these ticks
    diff_p.axes.set_yticks(ticks=depth_positions, labels=[str(d) for d in depth_labels], fontsize=14)

    if run_ds is not None:
        # Overlay contour lines on top of the filled contour plot
        ax = diff_p.axes  # Get the existing plot axis
        # contour_levels = np.linspace(1010, 1042, 10)  # Define contour levels
        min_rho = np.nanmin(density_dat.values)
        max_rho = np.nanmax(density_dat.values)
        print(f"Min and max density: {min_rho:.1f}, {max_rho:.1f}")
        print("Density dat:",density_dat)
        density_p = ax.contour(density_dat["true_lat"], density_dat["z_l"], density_dat, #levels=contour_levels, 
                               colors="k", linewidths=0.8)  # Black contour lines
        # Add contour labels
        ax.clabel(density_p, inline=True, fontsize=10)#fmt="%.2f"
    
    diff_p.axes.invert_yaxis()
    
    if min_val < plot_min and max_val > plot_max:
        extend = 'both'
    elif min_val < plot_min:
        extend = 'min'
    elif max_val > plot_max:
        extend = 'max'
    else:
        extend = 'neither'
        
    diff_cb = plt.colorbar(diff_p, ticks=tick_arr, fraction=0.046, pad=0.04, extend=extend)
    
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label("Temperature Anomaly ($\degree$C)",fontdict={'fontsize':14})

    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        if max_mag <= 2.5:
            t.set_x(2.8)
        else:
            t.set_x(2.3)

    # if basin_name == "global":
    #     diff_p.axes.set_xlim(-60,60)
    #     diff_p.axes.set_xticks(ticks=[-60,-40,-20,0,20,40,60],
    #                           labels=['60$\degree$S','40$\degree$S','20$\degree$S','0$\degree$',
    #                                   '20$\degree$N','40$\degree$N','60$\degree$N'], fontsize=14)
    # # north atlantic focus
    # elif basin_name == "atl":
    #     diff_p.axes.set_xlim(-40,65)
    #     diff_p.axes.set_xticks(ticks=[-40,-20,0,20,40,60],
    #                           labels=['40$\degree$S','20$\degree$S','0$\degree$',
    #                                   '20$\degree$N','40$\degree$N','60$\degree$N'], fontsize=14)
    # else:
    #     diff_p.axes.set_xlim(-80,83)
    #     diff_p.axes.set_xticks(ticks=[-80,-60,-40,-20,0,20,40,60,80],
    #                           labels=['80$\degree$S','60$\degree$S','40$\degree$S','20$\degree$S','0$\degree$',
    #                                   '20$\degree$N','40$\degree$N','60$\degree$N','80$\degree$N'], fontsize=14)

    diff_p.axes.set_xlim(-80,70)
    diff_p.axes.set_xticks(ticks=[-80,-60,-40,-20,0,20,40,60],
                          labels=['80$\degree$S','60$\degree$S','40$\degree$S','20$\degree$S','0$\degree$',
                                  '20$\degree$N','40$\degree$N','60$\degree$N'], fontsize=14)
    
    diff_p.axes.tick_params(axis='y', labelsize=14)
        
    # diff_p.axes.set_xlabel('Latitude', fontsize=18)
    diff_p.axes.set_ylabel('Depth (m)', fontsize=14)
    diff_p.axes.set_title(f"{title}\nYear {start_yr}–{end_yr}",fontdict={'fontsize':14})

    if savefig is True:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")
            
        if not os.path.exists(fig_dir):
            os.makedirs(fig_dir)

        plt.savefig(fig_dir + f'{prefix}_dT_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png', dpi=600, bbox_inches='tight')

In [9]:
def plot_salt_diff_basin(title,diff_ds,basin_name,max_depth,axis_split,start_yr,end_yr,
                         check_nn=False,nn_threshold=0.05,cb_max=None,mask_dataset=None,
                         run_ds=None, # must be passed to plot density overlays
                         savefig=False,fig_dir=None,prefix=None,
                         verbose=False):
    """Plot a latitude-depth section of a salinity anomaly for one basin.

    The section is extracted with ``get_pp_basin_dat`` and truncated at
    ``max_depth``. The depth axis is stretched with ``transform_depth`` around
    ``axis_split``. Colours use a discretised ``cmocean.cm.balance`` map with
    symmetric bounds chosen from the 0.5th/99.5th percentile magnitude of the
    data (or from ``cb_max``).

    Parameters
    ----------
    title : str
        First line of the panel title; "Year <start_yr>-<end_yr>" is appended.
    diff_ds : xarray.Dataset
        Anomaly dataset with a length-1 ``time`` dimension and a ``salt``
        variable.
    basin_name : str
        Basin identifier forwarded to ``get_pp_basin_dat``.
    max_depth : int
        Deepest level plotted (m). Must be >= 5000 for the supported tick sets.
    axis_split : int
        Split depth passed to ``transform_depth``. One of 1000, 1500, 2000, 3000.
    start_yr, end_yr : int
        Years used in the title and the output file name.
    check_nn, nn_threshold, mask_dataset
        Forwarded to ``get_pp_basin_dat``.
    cb_max : float, optional
        Overrides the percentile-based colourbar half-range.
    run_ds : xarray.Dataset, optional
        Dataset with ``rhopot2`` and a length-1 ``time`` dimension. When
        given, density contours are overlaid.
    savefig : bool
        Save the figure to ``<fig_dir>/<prefix>_salt_<start>_<end>.png``.
    fig_dir, prefix : str, optional
        Required when ``savefig`` is true.
    verbose : bool
        Print diagnostics.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``diff_ds`` or ``run_ds`` has more than one time step.
        - ``axis_split`` or ``max_depth`` is unsupported.
        - The colourbar magnitude is outside the supported range (> 6.1, or
          NaN for all-NaN data).
        - ``savefig`` is set without ``fig_dir``/``prefix``.
    """
    if verbose and mask_dataset is None:
        print("mask_ds is none")

    # Validate cheap-to-check arguments before any expensive data extraction/plotting.
    if len(diff_ds.time.values) > 1:
        raise ValueError("diff_ds cannot be a time series.")
    if run_ds is not None and len(run_ds.time.values) > 1:
        raise ValueError("run_ds cannot be a time series.")
    if savefig:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")

    # Depth tick labels (m) per axis_split: (labels for max_depth >= 6000, labels for max_depth >= 5000).
    depth_label_options = {
        1000: ([0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000]),
        1500: ([0, 500, 1000, 1500, 2500, 3500, 4500, 5500],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        2000: ([0, 500, 1000, 1500, 2000, 3000, 4000, 5000, 6000],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        3000: ([0, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 1000, 2000, 3000, 4000, 5000]),
    }
    if axis_split not in depth_label_options:
        raise ValueError(f"{axis_split} not an acceptable value. Must be in list [1000, 1500, 2000, 3000]")
    labels_6000, labels_5000 = depth_label_options[axis_split]
    if max_depth >= 6000:
        depth_labels = labels_6000
    elif max_depth >= 5000:
        depth_labels = labels_5000
    else:
        raise ValueError(f"{max_depth} not an acceptable value.")

    def load_section(ds, var_name):
        """Return the basin section of ``var_name`` above ``max_depth`` with z_l on the stretched axis."""
        dat = get_pp_basin_dat(ds, var_name, basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                               mask_ds=mask_dataset)
        dat = dat.sel(z_l=slice(0, max_depth))
        stretched_z = xr.apply_ufunc(transform_depth, dat.z_l,
                                     kwargs={"max_depth": max_depth, "axis_split": axis_split})
        return dat.assign_coords(z_l=stretched_z)

    diff_dat = load_section(diff_ds.isel(time=0), "salt")
    density_dat = None if run_ds is None else load_section(run_ds.isel(time=0), "rhopot2")

    # Full data range decides which colourbar extension arrows are drawn.
    min_val = np.nanmin(diff_dat.values)
    max_val = np.nanmax(diff_dat.values)

    # The 0.5th/99.5th percentiles set the colour bounds so a few outliers don't flatten the scale.
    p0p5, p99p5 = np.nanpercentile(diff_dat.values, [0.5, 99.5])
    pct_mag = max(np.abs(p0p5), np.abs(p99p5))
    if verbose:
        print(f"0.5 to 99.5th percentile data max mag: {pct_mag:.3e}")
    max_mag = cb_max if cb_max is not None else pct_mag

    # Colourbar layouts, tried in order. Each row is:
    # (largest max_mag handled, fixed half-range or None to round max_mag to the tick step,
    #  tick step, colour-band width, tick rounding step or None).
    bound_table = (
        (0.05, 0.05, 0.01, 0.005, 0.01),
        (0.12, 0.1,  0.02, 0.01,  0.01),
        (0.25, None, 0.05, 0.025, 0.01),
        (0.5,  0.5,  0.1,  0.05,  0.1),
        (1.0,  None, 0.2,  0.1,   0.1),
        (2.0,  None, 0.4,  0.2,   0.1),
        (2.5,  None, 0.5,  0.25,  0.1),
        (3.3,  None, 1.0,  0.25,  0.1),
        (6.1,  None, 1.0,  0.5,   None),
    )
    for upper, fixed_max, tick_step, color_step, tick_round in bound_table:
        if max_mag <= upper:
            break
    else:
        # Also reached when max_mag is NaN (all-NaN data), since NaN <= x is always False.
        raise ValueError(f"Cannot choose colourbar bounds for max_mag={max_mag}; "
                         "supported magnitudes are <= 6.1 (pass a smaller cb_max or check the data).")

    plot_max = fixed_max if fixed_max is not None else round(max_mag / tick_step) * tick_step
    plot_min = -plot_max
    # round() before int() guards against float quotients such as 9.999999999999998.
    num_ticks = int(round((plot_max - plot_min) / tick_step)) + 1
    num_colors = int(round((plot_max - plot_min) / color_step))
    tick_arr = np.linspace(plot_min, plot_max, num=num_ticks)
    if tick_round is not None:
        # Snap tick values to clean decimals so the labels don't show float noise.
        tick_arr = np.round(tick_arr / tick_round) * tick_round

    if verbose:
        print(f"num_colors = {num_colors}")
        print(f"Plot bounds: {plot_min:.3e} to {plot_max:.3e}\n")

    # Discretise cmocean's balance map into num_colors bands spanning [plot_min, plot_max].
    cmap = cmocean.cm.balance
    disc_bal_cmap = mcolors.LinearSegmentedColormap.from_list(
        'Custom cmap', [cmap(i) for i in range(cmap.N)], cmap.N)
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    plt.figure(figsize=[11, 3.8])
    diff_p = diff_dat.plot(x='true_lat', y='z_l',
                           cmap=disc_bal_cmap,
                           norm=disc_norm,
                           subplot_kws=dict(facecolor='grey'),  # grey shows wherever the data are NaN
                           add_labels=False,
                           add_colorbar=False)
    ax = diff_p.axes

    # Label the stretched axis with the original (untransformed) depths.
    depth_positions = [transform_depth(d, max_depth, axis_split) for d in depth_labels]
    ax.set_yticks(ticks=depth_positions, labels=[str(d) for d in depth_labels], fontsize=14)

    if density_dat is not None:
        if verbose:
            print(f"Min and max density: {np.nanmin(density_dat.values):.1f}, {np.nanmax(density_dat.values):.1f}")
            print("Density dat:", density_dat)
        density_p = ax.contour(density_dat["true_lat"], density_dat["z_l"], density_dat,
                               colors="k", linewidths=0.8)
        ax.clabel(density_p, inline=True, fontsize=10)

    ax.invert_yaxis()

    if min_val < plot_min and max_val > plot_max:
        extend = 'both'
    elif min_val < plot_min:
        extend = 'min'
    elif max_val > plot_max:
        extend = 'max'
    else:
        extend = 'neither'

    diff_cb = plt.colorbar(diff_p, ticks=tick_arr, fraction=0.046, pad=0.04, extend=extend)
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label("Salinity Anomaly (psu)", fontdict={'fontsize': 14})

    # Centre the colourbar tick labels; longer decimal labels need a larger x offset.
    label_x = 2.8 if max_mag <= 2.5 else 2.3
    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(label_x)

    ax.set_xlim(-80, 70)
    ax.set_xticks(ticks=[-80, -60, -40, -20, 0, 20, 40, 60],
                  labels=[r'80$\degree$S', r'60$\degree$S', r'40$\degree$S', r'20$\degree$S', r'0$\degree$',
                          r'20$\degree$N', r'40$\degree$N', r'60$\degree$N'], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.set_ylabel('Depth (m)', fontsize=14)
    ax.set_title(f"{title}\nYear {start_yr}–{end_yr}", fontdict={'fontsize': 14})

    if savefig:
        os.makedirs(fig_dir, exist_ok=True)
        fname = f'{prefix}_salt_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        # os.path.join works whether or not fig_dir ends with a path separator.
        plt.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')

In [10]:
# adapted from plot_temp_diff_basin

def plot_N2_diff_basin(title,diff_ds,basin_name,max_depth,axis_split,start_yr,end_yr,
                       check_nn=False,nn_threshold=0.05,cb_max=None,mask_dataset=None,
                       run_ds=None, # must be passed to plot density overlays
                       savefig=False,fig_dir=None,prefix=None,
                       verbose=False):
    """Plot a latitude-depth section of an N^2 (buoyancy frequency squared) anomaly for one basin.

    N2 is sliced on the interface depth axis ``z_i``, truncated at
    ``max_depth``, and drawn with ``transform_depth``'s stretched depth axis.
    A continuous ``cmocean.cm.curl`` map is used with symmetric limits
    +/- ``max_mag``. ``max_mag`` is the 0.5th/99.5th percentile magnitude, or
    ``cb_max`` if given.

    Parameters
    ----------
    title : str
        First line of the panel title; "Year <start_yr>-<end_yr>" is appended.
    diff_ds : xarray.Dataset
        Anomaly dataset with a length-1 ``time`` dimension and an ``N2`` variable.
    basin_name : str
        Basin identifier forwarded to ``get_pp_basin_dat``.
    max_depth : int
        Deepest level plotted (m). Must be >= 5000 for the supported tick sets.
    axis_split : int
        Split depth passed to ``transform_depth``. One of 1000, 1500, 2000, 3000.
    start_yr, end_yr : int
        Years used in the title and the output file name.
    check_nn, nn_threshold, mask_dataset
        Forwarded to ``get_pp_basin_dat``.
    cb_max : float, optional
        Overrides the percentile-based colour limit.
    run_ds : xarray.Dataset, optional
        Dataset with ``rhopot2`` and a length-1 ``time`` dimension. When
        given, density contours are overlaid.
    savefig : bool
        Save the figure to ``<fig_dir>/<prefix>_N2_anom_<start>_<end>.png``.
    fig_dir, prefix : str, optional
        Required when ``savefig`` is true.
    verbose : bool
        Print diagnostics.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``diff_ds`` or ``run_ds`` has more than one time step.
        - ``axis_split`` or ``max_depth`` is unsupported.
        - ``savefig`` is set without ``fig_dir``/``prefix``.
    """
    if verbose and mask_dataset is None:
        print("mask_ds is none")

    # Validate cheap-to-check arguments before any expensive data extraction/plotting.
    if len(diff_ds.time.values) > 1:
        raise ValueError("diff_ds cannot be a time series.")
    if run_ds is not None and len(run_ds.time.values) > 1:
        raise ValueError("run_ds cannot be a time series.")
    if savefig:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")

    # Depth tick labels (m) per axis_split: (labels for max_depth >= 6000, labels for max_depth >= 5000).
    depth_label_options = {
        1000: ([0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000]),
        1500: ([0, 500, 1000, 1500, 2500, 3500, 4500, 5500],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        2000: ([0, 500, 1000, 1500, 2000, 3000, 4000, 5000, 6000],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        3000: ([0, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 1000, 2000, 3000, 4000, 5000]),
    }
    if axis_split not in depth_label_options:
        raise ValueError(f"{axis_split} not an acceptable value. Must be in list [1000, 1500, 2000, 3000]")
    labels_6000, labels_5000 = depth_label_options[axis_split]
    if max_depth >= 6000:
        depth_labels = labels_6000
    elif max_depth >= 5000:
        depth_labels = labels_5000
    else:
        raise ValueError(f"{max_depth} not an acceptable value.")

    def load_section(ds, var_name, z_dim):
        """Return (section above max_depth on the stretched axis, name of its vertical dimension).

        If ``z_dim`` is None, the vertical dimension is detected from the extracted data.
        """
        dat = get_pp_basin_dat(ds, var_name, basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                               mask_ds=mask_dataset)
        if z_dim is None:
            # rhopot2 is a layer (z_l) field in the other section plots of this notebook;
            # fall back to the interface axis only if that is what was returned.
            z_dim = "z_l" if "z_l" in dat.dims else "z_i"
        dat = dat.sel({z_dim: slice(0, max_depth)})
        stretched_z = xr.apply_ufunc(transform_depth, dat[z_dim],
                                     kwargs={"max_depth": max_depth, "axis_split": axis_split})
        return dat.assign_coords({z_dim: stretched_z}), z_dim

    diff_ds = diff_ds.isel(time=0)
    if verbose:
        print(f"Min and max N2: {np.nanmin(diff_ds.N2.values):.3e}, {np.nanmax(diff_ds.N2.values):.3e}")

    # N2 lives on layer interfaces (z_i).
    diff_dat, _ = load_section(diff_ds, "N2", "z_i")
    if run_ds is not None:
        density_dat, dens_z = load_section(run_ds.isel(time=0), "rhopot2", None)
    else:
        density_dat, dens_z = None, None

    # The 0.5th/99.5th percentiles set the colour limits so a few outliers don't flatten the scale.
    p0p5, p99p5 = np.nanpercentile(diff_dat.values, [0.5, 99.5])
    pct_mag = max(np.abs(p0p5), np.abs(p99p5))
    if verbose:
        # N2 anomalies are O(1e-5) or smaller, so scientific notation is needed to be readable.
        print(f"0.5 to 99.5th percentile data max mag: {pct_mag:.3e}")
        print(f"Basin mean min and max N2: {np.nanmin(diff_dat.values):.3e}, {np.nanmax(diff_dat.values):.3e}")
    max_mag = cb_max if cb_max is not None else pct_mag

    plt.figure(figsize=[11, 3.8])
    diff_p = diff_dat.plot(x='true_lat', y='z_i', vmin=-max_mag, vmax=max_mag,
                           cmap=cmocean.cm.curl,
                           subplot_kws=dict(facecolor='grey'),  # grey shows wherever the data are NaN
                           add_labels=False,
                           add_colorbar=False)
    ax = diff_p.axes

    # Label the stretched axis with the original (untransformed) depths.
    depth_positions = [transform_depth(d, max_depth, axis_split) for d in depth_labels]
    ax.set_yticks(ticks=depth_positions, labels=[str(d) for d in depth_labels], fontsize=14)

    if density_dat is not None:
        if verbose:
            print(f"Min and max density: {np.nanmin(density_dat.values):.1f}, {np.nanmax(density_dat.values):.1f}")
            print("Density dat:", density_dat)
        density_p = ax.contour(density_dat["true_lat"], density_dat[dens_z], density_dat,
                               colors="k", linewidths=0.8)
        ax.clabel(density_p, inline=True, fontsize=10)

    ax.invert_yaxis()

    diff_cb = plt.colorbar(diff_p, fraction=0.046, pad=0.04)
    # Always use scientific notation (shared offset exponent) for the small N2 values.
    diff_cb.formatter.set_powerlimits((0, 0))
    diff_cb.ax.tick_params(labelsize=14)
    # N^2 = -(g/rho0) d(rho)/dz has units of s^-2.
    diff_cb.set_label("N$^2$ Anomaly (s$^{-2}$)", fontdict={'fontsize': 14})

    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.6)

    ax.set_xlim(-80, 70)
    ax.set_xticks(ticks=[-80, -60, -40, -20, 0, 20, 40, 60],
                  labels=[r'80$\degree$S', r'60$\degree$S', r'40$\degree$S', r'20$\degree$S', r'0$\degree$',
                          r'20$\degree$N', r'40$\degree$N', r'60$\degree$N'], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.set_ylabel('Depth (m)', fontsize=14)
    ax.set_title(f"{title}\nYear {start_yr}–{end_yr}", fontdict={'fontsize': 14})

    if savefig:
        os.makedirs(fig_dir, exist_ok=True)
        fname = f'{prefix}_N2_anom_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        # os.path.join works whether or not fig_dir ends with a path separator.
        plt.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')

In [11]:
# adapted from plot_temp_diff_basin

def plot_rhopot2_diff_basin(title,diff_ds,basin_name,max_depth,axis_split,start_yr,end_yr,
                       check_nn=False,nn_threshold=0.05,cb_max=None,mask_dataset=None,
                       savefig=False,fig_dir=None,prefix=None,
                       verbose=False):
    """Plot a latitude-depth section of a potential-density (sigma_2, ``rhopot2``) anomaly for one basin.

    The section is sliced on the layer depth axis ``z_l``, truncated at
    ``max_depth``, and drawn with ``transform_depth``'s stretched depth axis.
    A continuous ``cmocean.cm.delta`` map is used with symmetric limits
    +/- ``max_mag``. ``max_mag`` is the 0.5th/99.5th percentile magnitude, or
    ``cb_max`` if given.

    Parameters
    ----------
    title : str
        First line of the panel title; "Year <start_yr>-<end_yr>" is appended.
    diff_ds : xarray.Dataset
        Anomaly dataset with a length-1 ``time`` dimension and a ``rhopot2`` variable.
    basin_name : str
        Basin identifier forwarded to ``get_pp_basin_dat``.
    max_depth : int
        Deepest level plotted (m). Must be >= 5000 for the supported tick sets.
    axis_split : int
        Split depth passed to ``transform_depth``. One of 1000, 1500, 2000, 3000.
    start_yr, end_yr : int
        Years used in the title and the output file name.
    check_nn, nn_threshold, mask_dataset
        Forwarded to ``get_pp_basin_dat``.
    cb_max : float, optional
        Overrides the percentile-based colour limit.
    savefig : bool
        Save the figure to ``<fig_dir>/<prefix>_rhopot2_anom_<start>_<end>.png``.
    fig_dir, prefix : str, optional
        Required when ``savefig`` is true.
    verbose : bool
        Print diagnostics.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``diff_ds`` has more than one time step.
        - ``axis_split`` or ``max_depth`` is unsupported.
        - ``savefig`` is set without ``fig_dir``/``prefix``.
    """
    if verbose and mask_dataset is None:
        print("mask_ds is none")

    # Validate cheap-to-check arguments before any expensive data extraction/plotting.
    if len(diff_ds.time.values) > 1:
        raise ValueError("diff_ds cannot be a time series.")
    if savefig:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")

    # Depth tick labels (m) per axis_split: (labels for max_depth >= 6000, labels for max_depth >= 5000).
    depth_label_options = {
        1000: ([0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 250, 500, 750, 1000, 2000, 3000, 4000, 5000]),
        1500: ([0, 500, 1000, 1500, 2500, 3500, 4500, 5500],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        2000: ([0, 500, 1000, 1500, 2000, 3000, 4000, 5000, 6000],
               [0, 500, 1000, 1500, 2000, 3000, 4000, 5000]),
        3000: ([0, 1000, 2000, 3000, 4000, 5000, 6000],
               [0, 1000, 2000, 3000, 4000, 5000]),
    }
    if axis_split not in depth_label_options:
        raise ValueError(f"{axis_split} not an acceptable value. Must be in list [1000, 1500, 2000, 3000]")
    labels_6000, labels_5000 = depth_label_options[axis_split]
    if max_depth >= 6000:
        depth_labels = labels_6000
    elif max_depth >= 5000:
        depth_labels = labels_5000
    else:
        raise ValueError(f"{max_depth} not an acceptable value.")

    diff_ds = diff_ds.isel(time=0)
    if verbose:
        print(f"Min and max rhopot2: {np.nanmin(diff_ds.rhopot2.values):.3e}, {np.nanmax(diff_ds.rhopot2.values):.3e}")

    diff_dat = get_pp_basin_dat(diff_ds, "rhopot2", basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                                mask_ds=mask_dataset)
    diff_dat = diff_dat.sel(z_l=slice(0, max_depth))
    # Move z_l onto the stretched depth axis so the upper ocean gets more room.
    stretched_z = xr.apply_ufunc(transform_depth, diff_dat.z_l,
                                 kwargs={"max_depth": max_depth, "axis_split": axis_split})
    diff_dat = diff_dat.assign_coords(z_l=stretched_z)

    # The 0.5th/99.5th percentiles set the colour limits so a few outliers don't flatten the scale.
    p0p5, p99p5 = np.nanpercentile(diff_dat.values, [0.5, 99.5])
    pct_mag = max(np.abs(p0p5), np.abs(p99p5))
    if verbose:
        print(f"0.5 to 99.5th percentile data max mag: {pct_mag:.3f}")
        print(f"Basin mean min and max rhopot2: {np.nanmin(diff_dat.values):.3e}, {np.nanmax(diff_dat.values):.3e}")
    max_mag = cb_max if cb_max is not None else pct_mag

    plt.figure(figsize=[11, 3.8])
    diff_p = diff_dat.plot(x='true_lat', y='z_l', vmin=-max_mag, vmax=max_mag,
                           cmap=cmocean.cm.delta,
                           subplot_kws=dict(facecolor='grey'),  # grey shows wherever the data are NaN
                           add_labels=False,
                           add_colorbar=False)
    ax = diff_p.axes

    # Label the stretched axis with the original (untransformed) depths.
    depth_positions = [transform_depth(d, max_depth, axis_split) for d in depth_labels]
    ax.set_yticks(ticks=depth_positions, labels=[str(d) for d in depth_labels], fontsize=14)
    ax.invert_yaxis()

    diff_cb = plt.colorbar(diff_p, fraction=0.046, pad=0.04)
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label(r"$\sigma_2$ Anomaly (kg/m$^{3}$)", fontdict={'fontsize': 14})

    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.6)

    ax.set_xlim(-80, 70)
    ax.set_xticks(ticks=[-80, -60, -40, -20, 0, 20, 40, 60],
                  labels=[r'80$\degree$S', r'60$\degree$S', r'40$\degree$S', r'20$\degree$S', r'0$\degree$',
                          r'20$\degree$N', r'40$\degree$N', r'60$\degree$N'], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.set_ylabel('Depth (m)', fontsize=14)
    ax.set_title(f"{title}\nYear {start_yr}–{end_yr}", fontdict={'fontsize': 14})

    if savefig:
        os.makedirs(fig_dir, exist_ok=True)
        fname = f'{prefix}_rhopot2_anom_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        # os.path.join works whether or not fig_dir ends with a path separator.
        plt.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')

### Basin-mean temperature anomaly remapped from density (rho) space to depth (z) space

In [12]:
# function for plotting temp anomaly in z-space by remapping from rho-space data
# NOTE: this works when calc_zrho_dat() is called with x_mean=True inside
# get_pp_zrho_basin_dat(), but the result looked very different from the z-space output.

def _zrho_temp_anom_bounds(max_mag):
    """Choose symmetric, discrete colorbar bounds for plot_zrho_temp_diff_basin.

    Parameters
    ----------
    max_mag : float
        Non-negative anomaly magnitude the colorbar should (approximately) span.

    Returns
    -------
    plot_min, plot_max : float
        Colorbar limits, with ``plot_min == -plot_max`` and ``plot_max > 0``.
    tick_arr : numpy.ndarray
        Tick locations from ``plot_min`` to ``plot_max``, rounded to 0.1.
    num_colors : int
        Number of discrete colour bins between ``plot_min`` and ``plot_max``.

    Raises
    ------
    ValueError
        If ``max_mag`` is not finite (e.g. the anomaly field is all NaN).
    """
    if not np.isfinite(max_mag):
        raise ValueError(f"Cannot choose colorbar bounds for max_mag = {max_mag}.")

    # Rows: (largest max_mag handled, tick spacing, colour bins per tick,
    # round the limit up with ceil rather than to the nearest tick).
    # The final row is open-ended so magnitudes above 15 still get valid bounds.
    scales = (
        (1.0, 0.2, 2, False),
        (2.0, 0.4, 2, False),
        (2.5, 0.5, 2, False),
        (3.3, 1.0, 4, False),
        (6.1, 1.0, 2, False),
        (10.0, 2.0, 2, True),
        (13.0, 3.0, 2, True),
        (np.inf, 4.0, 2, True),
    )
    for upper, step, bins_per_tick, use_ceil in scales:
        if max_mag <= upper:
            break

    if max_mag > 15:
        print("Warning: plot bounds more than +/- 15")

    # Work in whole tick steps so tick/bin counts are exact integers, and keep at
    # least one step so a near-zero anomaly can't collapse the range to [0, 0].
    n_steps = np.ceil(max_mag / step) if use_ceil else round(max_mag / step)
    n_steps = max(int(n_steps), 1)

    plot_max = n_steps * step
    plot_min = -plot_max
    tick_arr = np.round(np.linspace(plot_min, plot_max, num=2 * n_steps + 1), 1)
    num_colors = 2 * n_steps * bins_per_tick
    return plot_min, plot_max, tick_arr, num_colors


def plot_zrho_temp_diff_basin(title,diff_ds,static_ds,basin_name,
                              start_yr,end_yr,
                              check_nn=False,nn_threshold=0.05,cb_max=None,mask_dataset=None,
                              run_ds=None,
                              savefig=False,fig_dir=None,prefix=None,
                              verbose=False):
    """Plot a basin temperature-anomaly section (latitude vs depth) remapped to z-space.

    The anomaly is obtained from ``get_pp_zrho_basin_dat`` and drawn with a
    discrete, symmetric ``cmocean.cm.balance`` colormap.

    Parameters
    ----------
    title : str
        First line of the plot title; the year range is added on a second line.
    diff_ds : xarray.Dataset
        Anomaly dataset with a ``time`` dimension of length 1.
    static_ds : xarray.Dataset
        Static grid data, passed through to ``get_pp_zrho_basin_dat``.
    basin_name : str
        Basin selector, passed through to ``get_pp_zrho_basin_dat``.
    start_yr, end_yr : int
        Period shown in the title and used in the figure file name.
    check_nn, nn_threshold : optional
        Passed through to ``get_pp_zrho_basin_dat``.
    cb_max : float, optional
        Colorbar magnitude. Defaults to the larger of |0.5th| and |99.5th|
        percentiles of the anomaly.
    mask_dataset : optional
        Passed to ``get_pp_zrho_basin_dat`` as ``mask_ds``.
    run_ds : optional
        Unused; kept for call compatibility.
    savefig : bool
        If true, save ``{prefix}_dT_{start_yr}_{end_yr}.png`` (years
        zero-padded to 4 digits) into ``fig_dir``, creating it if needed.
    fig_dir, prefix : str, optional
        Output directory and file-name prefix; both required when ``savefig``.
    verbose : bool
        Print diagnostics about the data range and chosen colorbar bounds.

    Raises
    ------
    ValueError
        If ``savefig`` is set without ``fig_dir``/``prefix``, if ``diff_ds`` has
        more than one time step, or if no finite colorbar magnitude exists.
    """
    # Validate output options up front so no plotting work is wasted.
    if savefig:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")

    if verbose and mask_dataset is None:
        print("mask_ds is none")

    if len(diff_ds.time.values) > 1:
        raise ValueError("diff_ds cannot be a time series.")
    diff_ds = diff_ds.isel(time=0)

    diff_dat, _depth_field = get_pp_zrho_basin_dat(
        diff_ds, static_ds, 'cent', ["temp", "volcello", "dxt", "dyt", "wet"], basin_name,
        check_nn=check_nn, nn_threshold=nn_threshold, mask_ds=mask_dataset)
    temp_vals = diff_dat["temp"].values

    # The full data range decides the colorbar extension arrows...
    min_val = np.nanmin(temp_vals)
    max_val = np.nanmax(temp_vals)

    # ...while the 0.5-99.5th percentile range sets the default bounds, so a few
    # extreme cells do not wash out the colour scale.
    pct_mag = max(np.abs(np.nanpercentile(temp_vals, 99.5)),
                  np.abs(np.nanpercentile(temp_vals, 0.5)))
    if verbose:
        print(f"0.5 to 99.5th percentile data max mag: {pct_mag:.3f}")

    max_mag = cb_max if cb_max is not None else pct_mag
    plot_min, plot_max, tick_arr, num_colors = _zrho_temp_anom_bounds(max_mag)

    if verbose:
        print(f"num_colors = {num_colors}")
        print(f"Plot bounds: {plot_min:.3f} to {plot_max:.3f}\n")

    # Rebuild the balance map from its colour list; BoundaryNorm then bins it
    # into num_colors discrete levels spanning [plot_min, plot_max].
    cmap = cmocean.cm.balance
    cmaplist = [cmap(i) for i in range(cmap.N)]
    disc_bal_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    plt.figure(figsize=[11, 3.8])
    diff_p = diff_dat["temp"].plot(x='true_lat', y='depth',
                                   cmap=disc_bal_cmap,
                                   norm=disc_norm,
                                   subplot_kws=dict(facecolor='black'),
                                   add_labels=False,
                                   add_colorbar=False)
    diff_p.axes.invert_yaxis()

    if min_val < plot_min and max_val > plot_max:
        extend = 'both'
    elif min_val < plot_min:
        extend = 'min'
    elif max_val > plot_max:
        extend = 'max'
    else:
        extend = 'neither'

    diff_cb = plt.colorbar(diff_p, ticks=tick_arr, fraction=0.046, pad=0.04, extend=extend)
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label(r"Temperature Anomaly ($\degree$C)", fontdict={'fontsize': 14})

    # Scales up to 2.5 use decimal ticks (wider labels), so centre them further right.
    label_x = 2.8 if max_mag <= 2.5 else 2.3
    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(label_x)

    lat_ticks = [-80, -60, -40, -20, 0, 20, 40, 60, 80]
    lat_labels = [r"0$\degree$" if lat == 0
                  else rf"{abs(lat)}$\degree${'S' if lat < 0 else 'N'}"
                  for lat in lat_ticks]
    diff_p.axes.set_xticks(ticks=lat_ticks, labels=lat_labels, fontsize=14)
    diff_p.axes.tick_params(axis='y', labelsize=14)
    diff_p.axes.set_ylabel('Depth (m)', fontsize=14)
    diff_p.axes.set_title(f"{title}\nYear {start_yr}–{end_yr}", fontdict={'fontsize': 14})

    if savefig:
        os.makedirs(fig_dir, exist_ok=True)
        fname = f'{prefix}_dT_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        plt.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')

## Linearity of temp difference plotting

In [13]:
def plot_pp_linearity_temp_diff(prefix,title,pp_diff_small,pp_diff_big,linear_factor,z_idx,start_yr,end_yr,cb_max=None,non_linear_cb=False,verbose=False):
    """Map the departure from linear scaling between two temperature anomalies.

    Plots ``pp_diff_big - linear_factor * pp_diff_small`` at vertical level
    ``z_idx`` on a Robinson projection and saves the figure to
    ``{prefix}_dT_{start_yr}_{end_yr}_z_{depth:.0f}.png`` (years zero-padded
    to 4 digits).

    Parameters
    ----------
    prefix : str
        Leading part of the output file path.
    title : str
        Title text; the year range and depth are appended.
    pp_diff_small, pp_diff_big : xarray.DataArray
        Anomaly fields with a ``z_l`` dimension and ``geolon``/``geolat`` coordinates.
    linear_factor : float
        Factor the small anomaly is scaled by before subtraction.
    z_idx : int
        Index into ``z_l`` of the level to plot.
    start_yr, end_yr : int
        Period shown in the title and used in the file name.
    cb_max : float, optional
        Colorbar magnitude; defaults to the largest absolute data value.
    non_linear_cb : bool
        If true, use a continuous ``SymLogNorm`` scale (linear within
        ``+/- max_mag/2``) instead of discrete bins.
    verbose : bool
        Print diagnostics about the data range and chosen bounds.
    """
    depth = pp_diff_small.coords['z_l'].values[z_idx]

    small_diff = pp_diff_small.isel(z_l=z_idx)
    big_diff = pp_diff_big.isel(z_l=z_idx)
    lin_diff = big_diff - linear_factor * small_diff

    min_val = np.nanmin(lin_diff.values)
    max_val = np.nanmax(lin_diff.values)
    if verbose:
        print(f"Data min: {min_val:.3f}\t Data max: {max_val:.3f}")

    if cb_max is not None:
        max_mag = cb_max
    elif np.abs(min_val) > np.abs(max_val):
        max_mag = np.abs(min_val)
    else:
        max_mag = np.abs(max_val)

    # Choose a tick spacing for the magnitude and the number of whole tick steps
    # on each side of zero (integer, so tick/bin counts are exact).
    if max_mag <= 0.9:
        step, n_steps = 0.1, round(max_mag / 0.1)
    elif max_mag <= 1.4:
        step, n_steps = 0.2, round(max_mag / 0.2)
    elif max_mag <= 3.3:
        step, n_steps = 0.4, round(max_mag / 0.4)
    elif max_mag < 5:
        step, n_steps = 1, round(max_mag)
    elif max_mag < 8:
        step, n_steps = 2, np.ceil(max_mag / 2)
    elif max_mag < 12:
        step, n_steps = 3, np.ceil(max_mag / 3)
    elif max_mag < 15:
        step, n_steps = 4, np.ceil(max_mag / 4)
    else:
        # Fixed +/-16 scale for very large (or undefined) magnitudes.
        step, n_steps = 4, 4
    # Keep at least one step so a near-zero field can't collapse the range to [0, 0].
    n_steps = max(int(n_steps), 1)

    plot_max = n_steps * step
    plot_min = -plot_max
    num = 2 * n_steps + 1
    tick_arr = np.round(np.linspace(plot_min, plot_max, num=num), 1)
    num_colors = 2 * (num - 1)

    if verbose:
        print(f"num = {num}\t num_colors = {num_colors}")
        print(f"Plot min: {plot_min:.3f}\t Plot max: {plot_max:.3f}")

    # Show extension arrows only where the data actually exceed the bounds.
    if min_val < plot_min and max_val > plot_max:
        extend = 'both'
    elif min_val < plot_min:
        extend = 'min'
    elif max_val > plot_max:
        extend = 'max'
    else:
        extend = 'neither'

    plt.figure(figsize=[12, 8])

    # Rebuild the balance map from its colour list; BoundaryNorm bins it into
    # num_colors discrete levels spanning [plot_min, plot_max].
    cmap = cmocean.cm.balance
    cmaplist = [cmap(i) for i in range(cmap.N)]
    disc_bal_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    subplot_kws = dict(projection=ccrs.Robinson(central_longitude=209.5), facecolor='0.75')

    if non_linear_cb:
        # max_mag (not cb_max, which may be None) sets the linear region; fall back
        # to the plot bound when the field is identically zero.
        linthresh = max_mag / 2 if max_mag > 0 else plot_max / 2
        norm = mcolors.SymLogNorm(linthresh=linthresh, linscale=0.6, vmin=plot_min, vmax=plot_max, base=10)
        diff_plot = lin_diff.plot(vmin=plot_min, vmax=plot_max,
                                  x='geolon', y='geolat',
                                  cmap=cmocean.cm.balance, norm=norm,
                                  subplot_kws=subplot_kws,
                                  transform=ccrs.PlateCarree(),
                                  add_labels=False,
                                  add_colorbar=False)
    else:
        diff_plot = lin_diff.plot(x='geolon', y='geolat',
                                  cmap=disc_bal_cmap, norm=disc_norm,
                                  subplot_kws=subplot_kws,
                                  transform=ccrs.PlateCarree(),
                                  add_labels=False,
                                  add_colorbar=False)

    diff_plot.axes.set_title(f"{title}: Year {start_yr}–{end_yr}, z = {depth:,.2f} m", fontdict={'fontsize': 18})

    diff_cb = plt.colorbar(diff_plot, shrink=0.6, extend=extend)
    diff_cb.set_ticks(tick_arr)
    diff_cb.ax.set_yticklabels([f"{x:.1f}" for x in tick_arr])
    diff_cb.ax.tick_params(labelsize=14)
    diff_cb.set_label(r"Temperature Anomaly ($\degree$C)", fontdict={'fontsize': 14})

    for t in diff_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.0)

    plt.savefig(f'{prefix}_dT_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}_z_{depth:.0f}.png', dpi=600, bbox_inches='tight')

# Basin mean N2 change plotting

# Diffusivity plotting functions

## Global mean profiles

In [14]:
# plotting Kd variable with continuous y-axis

def plot_Kd_cont_yaxis(co2_scen,fig_dir,start_yr,end_yr,Kd_var,max_Kd,
                       max_z = 6250,
                       profiles = ('surf','therm','mid','bot'),
                       power_inputs = ('0.1TW', '0.2TW', '0.5TW'),
                       power_var_suff = ('0p1TW', '0p2TW', '0p5TW'),
                       power_strings = ('0.1 TW', '0.2 TW', '0.5 TW')):
    """Plot global-mean diffusivity profiles against depth on one continuous axis.

    Saves one figure per power input, ``{Kd_var}_{co2_scen}_{suffix}_{start}_{end}.png``,
    plus a summary figure with every profile/power combination,
    ``{Kd_var}_{co2_scen}_{start}_{end}.png`` (years zero-padded to 4 digits),
    all into ``fig_dir``.

    Data are read from the notebook-global ``myVars`` under keys
    ``f"{co2_scen}_{run}_{start_yr}_{end_yr}_mean"``, where ``run`` is ``"ctrl"``
    or ``f"{profile}_{suffix}"``. Index 0 of the leading dimension of ``Kd_var``
    is plotted (presumably the single time record — confirm) against the
    control run's ``z_i`` coordinate.

    Parameters
    ----------
    co2_scen : {"const", "doub", "quad"}
        CO2 scenario key; also selects the title label.
    fig_dir : str
        Output directory, created if missing.
    start_yr, end_yr : int
        Averaging period used in keys, titles and file names.
    Kd_var : {"Kd_int_tuned", "Kd_interface", "Kd_int_base"}
        Variable to plot. For ``Kd_interface``/``Kd_int_base`` the control
        run's ``Kd_interface`` is also drawn in black.
    max_Kd : float
        Upper x-axis limit.
    max_z : float
        Deepest depth shown (m).
    profiles : sequence of str
        Profile names; coloured b, m, g, r (cycling if there are more).
    power_inputs : sequence of str
        Unused; kept for call compatibility.
    power_var_suff, power_strings : sequence of str
        Per-power key suffixes and the matching legend/title labels.

    Raises
    ------
    ValueError
        If ``Kd_var`` or ``co2_scen`` is not recognised.
    """
    title_prefixes = {
        "Kd_int_tuned": r"Global Mean $\kappa_{\mathregular{add}}$",
        "Kd_interface": r"Global Mean $\kappa_{\mathregular{tot}}$",
        "Kd_int_base": r"Global Mean $\kappa_{\mathregular{base}}$",
    }
    co2_labels = {"const": "Const CO2", "doub": "1pct2xCO2", "quad": "1pct4xCO2"}
    if Kd_var not in title_prefixes:
        raise ValueError(f"Unknown Kd_var {Kd_var!r}; expected one of {list(title_prefixes)}.")
    if co2_scen not in co2_labels:
        raise ValueError(f"Unknown co2_scen {co2_scen!r}; expected one of {list(co2_labels)}.")
    title_pref = title_prefixes[Kd_var]
    co2_str = co2_labels[co2_scen]

    os.makedirs(fig_dir, exist_ok=True)

    prof_colors = ['b', 'm', 'g', 'r']
    line_list = ['solid', 'dashed', 'dotted']
    # Total and background diffusivity are shown against the control total.
    show_ctrl = Kd_var in ("Kd_interface", "Kd_int_base")
    period = f"{start_yr}_{end_yr}"
    period_tag = f"{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}"

    ctrl_ds = myVars[f"{co2_scen}_ctrl_{period}_mean"]
    depth = ctrl_ds['z_i']

    def colour(i):
        return prof_colors[i % len(prof_colors)]

    def linestyle(j):
        return line_list[j % len(line_list)]

    def run_profile(profile, suffix):
        """Return the plotted slice of Kd_var for one profile/power run."""
        return myVars[f"{co2_scen}_{profile}_{suffix}_{period}_mean"][Kd_var][0, :]

    def finish_axes(ax, title):
        """Apply the shared labels, limits, grid and title."""
        # A fresh formatter per axis: matplotlib formatters must not be shared.
        sci_formatter = ScalarFormatter(useMathText=True)
        sci_formatter.set_scientific(True)
        sci_formatter.set_powerlimits((0, 0))
        # Diffusivity has units of m^2/s.
        ax.set_xlabel(r"$\kappa_d$ (m$^2$/s)")
        ax.set_ylabel("Depth (m)")
        ax.xaxis.set_major_formatter(sci_formatter)
        ax.set_xlim(0, max_Kd)
        ax.set_ylim(0, max_z)
        ax.invert_yaxis()
        ax.minorticks_on()
        ax.grid(which='major', linestyle='-', linewidth=0.5, color='gray')
        ax.set_title(title)

    # One figure per power input.
    for suffix, power_str in zip(power_var_suff, power_strings):
        fig, ax = plt.subplots(figsize=(5, 6))
        if show_ctrl:
            ax.plot(ctrl_ds["Kd_interface"][0, :], depth, label='control', color='k')
        for i, profile in enumerate(profiles):
            ax.plot(run_profile(profile, suffix), depth, label=f'{profile}', color=colour(i))
        ax.legend(loc='lower right' if show_ctrl else 'best', fontsize=10, labelspacing=0.1)
        finish_axes(ax, title_pref + f" {power_str} {co2_str}:\nYear {start_yr} to {end_yr}")
        plt.savefig(os.path.join(fig_dir, f'{Kd_var}_{co2_scen}_{suffix}_{period_tag}.png'),
                    dpi=600, bbox_inches='tight')

    # Summary figure: colour = profile, line style = power input.
    fig, ax = plt.subplots(figsize=(5, 6))
    if show_ctrl:
        ax.plot(ctrl_ds["Kd_interface"][0, :], depth, label='control', color='k')
    for j, (suffix, power_str) in enumerate(zip(power_var_suff, power_strings)):
        for i, profile in enumerate(profiles):
            ax.plot(run_profile(profile, suffix), depth, label=f'{power_str} {profile}',
                    linestyle=linestyle(j), color=colour(i))

    # Proxy handles are built from the actual inputs so each label pairs with its style.
    colour_handles = [Line2D([0], [0], color=colour(i), lw=2) for i in range(len(profiles))]
    style_handles = [Line2D([0], [0], linestyle=linestyle(j), lw=2, color='k')
                     for j in range(len(power_strings))]
    if show_ctrl:
        legend1 = ax.legend([Line2D([0], [0], color='k', lw=2)] + colour_handles,
                            ['control'] + list(profiles),
                            loc='lower right', fontsize=10, labelspacing=0.1,
                            bbox_to_anchor=(0.7, 0.0), frameon=True)
        ax.legend(style_handles, list(power_strings),
                  loc='lower right', fontsize=10, labelspacing=0.1,
                  bbox_to_anchor=(1.0, 0.0), frameon=True)
        # ax.legend replaces the previous legend, so re-attach the first one.
        ax.add_artist(legend1)
    else:
        ax.legend(colour_handles + style_handles, list(profiles) + list(power_strings),
                  loc='best', fontsize=10, ncol=2, labelspacing=0.1)

    finish_axes(ax, title_pref + f" {co2_str}:\nYear {start_yr} to {end_yr}")
    plt.savefig(os.path.join(fig_dir, f'{Kd_var}_{co2_scen}_{period_tag}.png'),
                dpi=600, bbox_inches='tight')

In [15]:
# plotting Kd variable with split axis plot (abrupt change in y-axis)

def plot_Kd_split_yaxis(co2_scen,fig_dir,start_yr,end_yr,Kd_var,max_Kd,
                        axis_break=850,
                        max_z = 6250,
                        profiles = ('surf','therm','mid','bot'),
                        power_inputs = ('0.1TW', '0.2TW', '0.5TW'),
                        power_var_suff = ('0p1TW', '0p2TW', '0p5TW'),
                        power_strings = ('0.1 TW', '0.2 TW', '0.5 TW')):
    """Plot global-mean diffusivity profiles on a depth axis split at ``axis_break``.

    Each figure has two equal-height panels: 0 to ``axis_break`` on top and
    ``axis_break`` to ``max_z`` below, so the upper ocean is stretched.
    Saves one figure per power input, ``{Kd_var}_{co2_scen}_{suffix}_{start}_{end}.png``,
    plus a summary figure, ``{Kd_var}_{co2_scen}_{start}_{end}.png`` (years
    zero-padded to 4 digits), all into ``fig_dir``.

    Data are read from the notebook-global ``myVars`` under keys
    ``f"{co2_scen}_{run}_{start_yr}_{end_yr}_mean"``, where ``run`` is ``"ctrl"``
    or ``f"{profile}_{suffix}"``. Index 0 of the leading dimension of ``Kd_var``
    is plotted (presumably the single time record — confirm), selected by
    ``z_i`` against the control run's ``z_i`` coordinate.

    Parameters
    ----------
    co2_scen : {"const", "doub", "quad"}
        CO2 scenario key; also selects the title label.
    fig_dir : str
        Output directory, created if missing.
    start_yr, end_yr : int
        Averaging period used in keys, titles and file names.
    Kd_var : {"Kd_int_tuned", "Kd_interface", "Kd_int_base"}
        Variable to plot. For ``Kd_interface``/``Kd_int_base`` the control
        run's ``Kd_interface`` is also drawn in black.
    max_Kd : float
        Upper x-axis limit.
    axis_break : float
        Depth (m) where the two panels meet.
    max_z : float
        Deepest depth shown (m).
    profiles : sequence of str
        Profile names; coloured b, m, g, r (cycling if there are more).
    power_inputs : sequence of str
        Unused; kept for call compatibility.
    power_var_suff, power_strings : sequence of str
        Per-power key suffixes and the matching legend/title labels.

    Raises
    ------
    ValueError
        If ``Kd_var`` or ``co2_scen`` is not recognised.
    """
    title_prefixes = {
        "Kd_int_tuned": r"Global Mean $\kappa_{\mathregular{add}}$",
        "Kd_interface": r"Global Mean $\kappa_{\mathregular{tot}}$",
        "Kd_int_base": r"Global Mean $\kappa_{\mathregular{base}}$",
    }
    co2_labels = {"const": "Const CO2", "doub": "1pct2xCO2", "quad": "1pct4xCO2"}
    if Kd_var not in title_prefixes:
        raise ValueError(f"Unknown Kd_var {Kd_var!r}; expected one of {list(title_prefixes)}.")
    if co2_scen not in co2_labels:
        raise ValueError(f"Unknown co2_scen {co2_scen!r}; expected one of {list(co2_labels)}.")
    title_pref = title_prefixes[Kd_var]
    co2_str = co2_labels[co2_scen]

    os.makedirs(fig_dir, exist_ok=True)

    prof_colors = ['b', 'm', 'g', 'r']
    line_list = ['solid', 'dashed', 'dotted']
    # Total and background diffusivity are shown against the control total.
    show_ctrl = Kd_var in ("Kd_interface", "Kd_int_base")
    period = f"{start_yr}_{end_yr}"
    period_tag = f"{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}"
    upper = slice(0, axis_break)
    lower = slice(axis_break, None)

    ctrl_ds = myVars[f"{co2_scen}_ctrl_{period}_mean"]
    depth = ctrl_ds['z_i']

    def colour(i):
        return prof_colors[i % len(prof_colors)]

    def linestyle(j):
        return line_list[j % len(line_list)]

    def plot_panel(ax, z_range, runs):
        """Plot the control (if shown) and every (suffix, label_prefix, style) run over z_range."""
        z = depth.sel(z_i=z_range)
        if show_ctrl:
            ax.plot(ctrl_ds["Kd_interface"][0, :].sel(z_i=z_range), z, label='control', color='k')
        for suffix, label_prefix, style in runs:
            for i, profile in enumerate(profiles):
                kd = myVars[f"{co2_scen}_{profile}_{suffix}_{period}_mean"][Kd_var][0, :]
                ax.plot(kd.sel(z_i=z_range), z, label=f'{label_prefix}{profile}',
                        linestyle=style, color=colour(i))

    def split_figure(runs):
        """Build the two-panel figure for ``runs`` and apply the shared styling."""
        fig = plt.figure(figsize=(5, 6))
        gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0)

        # Top panel: surface to axis_break; x labels hidden so the panels join.
        ax1 = fig.add_subplot(gs[0])
        plot_panel(ax1, upper, runs)
        ax1.spines['bottom'].set_visible(False)
        ax1.tick_params(bottom=True, labelbottom=False)
        ax1.set_ylim(0, axis_break)
        ax1.invert_yaxis()

        # Bottom panel: axis_break to max_z.
        ax2 = fig.add_subplot(gs[1])
        plot_panel(ax2, lower, runs)
        ax2.set_ylim(axis_break, max_z)
        ax2.invert_yaxis()

        for ax in (ax1, ax2):
            # A fresh formatter per axis: matplotlib formatters must not be shared.
            sci_formatter = ScalarFormatter(useMathText=True)
            sci_formatter.set_scientific(True)
            sci_formatter.set_powerlimits((0, 0))
            ax.xaxis.set_major_formatter(sci_formatter)
            ax.set_xlim(0, max_Kd)
            ax.grid(which='major', linestyle='-', linewidth=0.5, color='gray')

        # Diffusivity has units of m^2/s.
        ax2.set_xlabel(r"$\kappa_d$ (m$^2$/s)")
        fig.text(0, 0.5, "Depth (m)", va='center', rotation='vertical')
        return fig, ax1, ax2

    # One figure per power input (default line style).
    for suffix, power_str in zip(power_var_suff, power_strings):
        fig, ax1, ax2 = split_figure([(suffix, '', None)])
        ax2.legend(loc='lower right' if show_ctrl else 'best', fontsize=10, labelspacing=0.1)
        ax1.set_title(title_pref + f" {power_str} {co2_str}:\nYear {start_yr} to {end_yr}")
        plt.savefig(os.path.join(fig_dir, f'{Kd_var}_{co2_scen}_{suffix}_{period_tag}.png'),
                    dpi=600, bbox_inches='tight')

    # Summary figure: colour = profile, line style = power input.
    all_runs = [(suffix, f'{power_str} ', linestyle(j))
                for j, (suffix, power_str) in enumerate(zip(power_var_suff, power_strings))]
    fig, ax1, ax2 = split_figure(all_runs)

    # Proxy handles are built from the actual inputs so each label pairs with its style.
    colour_handles = [Line2D([0], [0], color=colour(i), lw=2) for i in range(len(profiles))]
    style_handles = [Line2D([0], [0], linestyle=linestyle(j), lw=2, color='k')
                     for j in range(len(power_strings))]
    if show_ctrl:
        legend1 = ax2.legend([Line2D([0], [0], color='k', lw=2)] + colour_handles,
                             ['control'] + list(profiles),
                             loc='lower right', fontsize=10, labelspacing=0.1,
                             bbox_to_anchor=(0.7, 0.0), frameon=True)
        ax2.legend(style_handles, list(power_strings),
                   loc='lower right', fontsize=10, labelspacing=0.1,
                   bbox_to_anchor=(1.0, 0.0), frameon=True)
        # ax.legend replaces the previous legend, so re-attach the first one.
        ax2.add_artist(legend1)
    else:
        ax2.legend(colour_handles + style_handles, list(profiles) + list(power_strings),
                   loc='best', fontsize=10, ncol=2, labelspacing=0.1)

    ax1.set_title(title_pref + f" {co2_str}:\nYear {start_yr} to {end_yr}")
    plt.savefig(os.path.join(fig_dir, f'{Kd_var}_{co2_scen}_{period_tag}.png'),
                dpi=600, bbox_inches='tight')

## Diffusivity maps and basin means

In [16]:
def plot_pp_Kd_map(title, pp_ds, Kd_var, z_idx, start_yr, end_yr, layer_var=False, savefig=False, cb_min=-10,
                   cb_max=None, prefix=None, verbose=False):
    """Plot a global map of log10(Kd) at a single vertical level.

    The map uses a Robinson projection centred at 209.5 degrees E and a discrete
    ``cmocean.cm.dense`` colormap with two colour bins per colorbar tick interval.

    Parameters
    ----------
    title : str
        Leading part of the plot title; the year range and level depth are appended.
    pp_ds : xarray.Dataset
        Dataset holding ``Kd_var`` with ``geolon``/``geolat`` coordinates and a
        ``z_i`` or ``z_l`` vertical coordinate.
    Kd_var : str
        Name of the diffusivity variable to plot.
    z_idx : int
        Index of the vertical level to plot.
    start_yr, end_yr : int
        Year range used in the title and in the output file name.
    layer_var : bool
        If False, ``Kd_var`` is indexed along ``z_i``; if True, along ``z_l``.
    savefig : bool
        If True, save the figure (600 dpi PNG) under a relative file name
        built from ``prefix``, ``Kd_var``, the years and the depth.
    cb_min : float
        Lower colorbar bound in log10 units.
    cb_max : float or None
        Upper colorbar bound in log10 units (rounded up to the next integer);
        defaults to the maximum of the plotted log10 data.
    prefix : str or None
        File-name prefix; required when ``savefig`` is True.
    verbose : bool
        Print data ranges and colorbar settings.

    Raises
    ------
    ValueError
        If the colorbar upper bound does not exceed ``cb_min``, or if
        ``savefig`` is True and ``prefix`` is None.
    """
    # The variable is stored on either the z_i or the z_l vertical axis.
    z_dim = 'z_l' if layer_var else 'z_i'
    Kd_dat = pp_ds[Kd_var].isel({z_dim: z_idx})
    depth = pp_ds[Kd_var].coords[z_dim].values[z_idx]

    if verbose:
        print(f"Kd min: {np.nanmin(Kd_dat.values):.3e}\t Kd max: {np.nanmax(Kd_dat.values):.3e}")

    # log10(0) is -inf; replace it with -50 so those cells sit far below the colour range.
    log_Kd_dat = np.log10(Kd_dat)
    log_Kd_dat = log_Kd_dat.where(log_Kd_dat != -np.inf, -50)

    dat_min = np.nanmin(log_Kd_dat.values)
    dat_max = np.nanmax(log_Kd_dat.values)

    if verbose:
        print(f"Log(Kd) min: {dat_min:.3e}\t Log(Kd) max: {dat_max:.3e}")

    max_val = dat_max if cb_max is None else cb_max

    plot_min = cb_min
    plot_max = np.ceil(max_val)
    if plot_max <= plot_min:
        raise ValueError(f"Colorbar upper bound ({plot_max}) must exceed cb_min ({plot_min}).")

    # `num` evenly spaced ticks from plot_min to plot_max (one per decade when cb_min
    # is an integer), with two colour bins between consecutive ticks.
    num = int(plot_max - plot_min) + 1
    tick_arr = np.linspace(plot_min, plot_max, num=num)
    num_colors = 2 * (num - 1)

    if verbose:
        print(f"num = {num}\t num_colors = {num_colors}")
        print(f"Plot min: {plot_min:.3f}\t Plot max: {plot_max:.3f}")

    plt.figure(figsize=[12, 8])

    # Rebuild the dense colormap from its sampled colours so it can be discretised below.
    cmap = cmocean.cm.dense
    cmaplist = [cmap(i) for i in range(cmap.N)]
    disc_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)

    # Discrete colour bins between plot_min and plot_max.
    norm_bounds = np.linspace(plot_min, plot_max, num_colors + 1)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    subplot_kws = dict(projection=ccrs.Robinson(central_longitude=209.5), facecolor='0.75')

    # The data are on geographic lon/lat, so tell cartopy via `transform`,
    # independent of the map projection chosen above.
    Kd_plot = log_Kd_dat.plot(vmin=plot_min, vmax=plot_max,
                              x='geolon', y='geolat',
                              cmap=disc_cmap, norm=disc_norm,
                              subplot_kws=subplot_kws,
                              transform=ccrs.PlateCarree(),
                              add_labels=False,
                              add_colorbar=False)

    Kd_plot.axes.set_title(f"{title}: Year {start_yr}–{end_yr}, z = {depth:,.2f} m", fontdict={'fontsize': 18})

    Kd_cb = plt.colorbar(Kd_plot, ticks=tick_arr, shrink=0.6, extend='both')

    # Labels must be built here: the original code referenced an undefined `tick_labels`.
    # ':g' prints integer decades as "-7" while keeping non-integer ticks readable.
    tick_labels = [f"{x:g}" for x in tick_arr]
    Kd_cb.set_ticks(tick_arr)
    Kd_cb.ax.set_yticklabels(tick_labels)
    Kd_cb.ax.tick_params(labelsize=14)
    Kd_cb.set_label("log$_{10}$ ($m^2/s$)", fontdict={'fontsize': 14})

    # Centre the tick labels in a column to the right of the colorbar.
    for t in Kd_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.0)

    if savefig:
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")
        plt.savefig(f'{prefix}_{Kd_var}_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}_z_{depth:.0f}.png', dpi=600, bbox_inches='tight')

In [17]:
def plot_Kd_basin(title_p1, title_p2, pp_ds, Kd_var, basin_name, max_depth, start_yr, end_yr, layer_var=False,
                  cb_min=-7, cb_max=None, cb_spacing=0.25, non_lin_cb_val=None,
                  savefig=False, fig_dir=None, prefix=None,
                  check_nn=True, nn_threshold=0.00, full_field_var=None, verbose=False):
    """Plot a latitude-depth section of log10(Kd) for one basin.

    Parameters
    ----------
    title_p1, title_p2 : str
        First and second title lines; the year range is appended to ``title_p2``.
    pp_ds : xarray.Dataset
        Post-processed dataset, forwarded to ``get_pp_basin_dat``.
    Kd_var : str
        Name of the diffusivity variable.
    basin_name : str
        Basin selector, forwarded to ``get_pp_basin_dat``.
    max_depth : float
        Deepest vertical-coordinate value kept (the section spans 0 to ``max_depth``).
    start_yr, end_yr : int
        Year range used in the title and the output file name.
    layer_var : bool
        If False the vertical coordinate is ``z_i``; if True it is ``z_l``.
    cb_min : float
        Lower colorbar bound in log10 units.
    cb_max : float or None
        Upper colorbar bound in log10 units (rounded up to an integer);
        defaults to the maximum of the plotted log10 data.
    cb_spacing : float
        Colour-bin width in log10 units; ticks are placed every ``2 * cb_spacing``.
        When ``non_lin_cb_val`` is given this applies only above that value.
    non_lin_cb_val : float or None
        If given, the colorbar uses 0.5-wide bins with unit tick spacing from
        ``cb_min`` up to this value, and ``cb_spacing`` bins above it.
    savefig : bool
        If True, save a 600 dpi PNG into ``fig_dir``.
    fig_dir : str or None
        Output directory (created if missing); required when ``savefig`` is True.
    prefix : str or None
        File-name prefix; required when ``savefig`` is True.
    check_nn, nn_threshold, full_field_var :
        Forwarded unchanged to ``get_pp_basin_dat``.
    verbose : bool
        Print data ranges and colorbar bin edges/ticks.

    Raises
    ------
    ValueError
        If the colorbar range is empty, ``non_lin_cb_val`` lies outside it, or
        ``savefig`` is True without ``fig_dir`` or ``prefix``.
    """
    # NOTE(review): the returned array must carry a `true_lat` coordinate (used as x below).
    Kd_dat = get_pp_basin_dat(pp_ds, Kd_var, basin_name, check_nn=check_nn, nn_threshold=nn_threshold,
                              full_field_var=full_field_var, verbose=verbose)

    # The variable is stored on either the z_i or the z_l vertical axis.
    z_dim = 'z_l' if layer_var else 'z_i'
    Kd_dat = Kd_dat.sel({z_dim: slice(0, max_depth)})

    if verbose:
        print(f"Kd min: {np.nanmin(Kd_dat.values):.3e}\t Kd max: {np.nanmax(Kd_dat.values):.3e}")

    # log10(0) is -inf; replace it with -50 so those cells sit far below the colour range.
    log_Kd_dat = np.log10(Kd_dat)
    log_Kd_dat = log_Kd_dat.where(log_Kd_dat != -np.inf, -50)

    dat_min = np.nanmin(log_Kd_dat.values)
    dat_max = np.nanmax(log_Kd_dat.values)

    if verbose:
        print(f"Log(Kd) min: {dat_min:.3f}\t Log(Kd) max: {dat_max:.3f}")

    plot_min = cb_min
    plot_max = np.ceil(dat_max if cb_max is None else cb_max)
    if plot_max <= plot_min:
        raise ValueError(f"Colorbar upper bound ({plot_max}) must exceed cb_min ({plot_min}).")

    def _count(span, width):
        # floor(span / width), tolerant of float noise such as 0.3 / 0.1 == 2.9999999999999996,
        # which plain int() would truncate to one bin too few.
        return int(np.floor(span / width + 1e-9))

    if non_lin_cb_val is not None:
        if not plot_min <= non_lin_cb_val <= plot_max:
            raise ValueError(f"non_lin_cb_val ({non_lin_cb_val}) must lie within "
                             f"[{plot_min}, {plot_max}].")
        lower_span = non_lin_cb_val - plot_min
        upper_span = plot_max - non_lin_cb_val

        # Lower segment: 0.5-wide bins, unit tick spacing.
        num_col_lower = 2 * _count(lower_span, 1)
        num_ticks_lower = _count(lower_span, 1)
        # Upper segment: cb_spacing-wide bins, ticks every 2 * cb_spacing.
        num_col_upper = _count(upper_span, cb_spacing)
        num_ticks_upper = _count(upper_span, 2 * cb_spacing)

        # endpoint=False on the lower pieces avoids duplicating non_lin_cb_val,
        # which is the first value of the upper pieces.
        lower_bounds = np.linspace(plot_min, non_lin_cb_val, num_col_lower, endpoint=False)
        lower_ticks = np.linspace(plot_min, non_lin_cb_val, num_ticks_lower, endpoint=False)
        upper_bounds = np.linspace(non_lin_cb_val, plot_max, num_col_upper + 1)
        upper_ticks = np.linspace(non_lin_cb_val, plot_max, num_ticks_upper + 1)

        norm_bounds = np.concatenate((lower_bounds, upper_bounds))
        tick_arr = np.concatenate((lower_ticks, upper_ticks))
    else:
        span = plot_max - plot_min
        num_col = _count(span, cb_spacing)
        num_ticks = _count(span, 2 * cb_spacing)

        norm_bounds = np.linspace(plot_min, plot_max, num_col + 1)
        tick_arr = np.linspace(plot_min, plot_max, num_ticks + 1)

    if verbose:
        print(f"norm_bounds: {norm_bounds}")
        print(f"tick_arr: {tick_arr}")

    # Rebuild the dense colormap from its sampled colours and discretise it on norm_bounds.
    cmap = cmocean.cm.dense
    cmaplist = [cmap(i) for i in range(cmap.N)]
    disc_cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    disc_norm = mcolors.BoundaryNorm(norm_bounds, cmap.N)

    plt.figure(figsize=[11, 3.8])

    Kd_p = log_Kd_dat.plot(x='true_lat', y=z_dim,
                           cmap=disc_cmap,
                           norm=disc_norm,
                           subplot_kws=dict(facecolor='grey'),
                           add_labels=False,
                           add_colorbar=False)

    # Depth increases downward.
    Kd_p.axes.invert_yaxis()
    Kd_p.axes.minorticks_on()

    Kd_cb = plt.colorbar(Kd_p, ticks=tick_arr, fraction=0.046, pad=0.04, extend='both')

    tick_labels = [f"{x:.2f}" for x in tick_arr]
    Kd_cb.set_ticks(tick_arr)
    Kd_cb.ax.set_yticklabels(tick_labels)
    Kd_cb.ax.tick_params(labelsize=14)
    Kd_cb.set_label("log$_{10}$ ($m^2/s$)", fontdict={'fontsize': 14})

    # Centre the tick labels in a column to the right of the colorbar.
    for t in Kd_cb.ax.get_yticklabels():
        t.set_horizontalalignment('center')
        t.set_x(2.8)

    # TODO (Mar 11): restrict the x-axis to 60S-60N with degree-formatted tick labels.

    Kd_p.axes.tick_params(axis='y', labelsize=14)

    Kd_p.axes.set_ylabel('Depth (m)', fontsize=14)
    Kd_p.axes.set_title(f"{title_p1}\n{title_p2}: Year {start_yr}–{end_yr}", fontdict={'fontsize': 16})

    if savefig:
        if fig_dir is None:
            raise ValueError("Must specify 'fig_dir' = <directory>.")
        if prefix is None:
            raise ValueError("Must specify prefix for figure file name.")

        # exist_ok avoids the check-then-create race; os.path.join works with or
        # without a trailing separator on fig_dir.
        os.makedirs(fig_dir, exist_ok=True)
        fname = f'{prefix}_{Kd_var}_{str(start_yr).zfill(4)}_{str(end_yr).zfill(4)}.png'
        plt.savefig(os.path.join(fig_dir, fname), dpi=600, bbox_inches='tight')