JM: 09 Apr 23

Tidy up of some of the data for partial res investigation.

In [None]:
# prelim loading

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from numba import jit
import copy

import cmocean
import sys
scivis_path = "./scivis_cm/KeyColormaps/"
sys.path.append(scivis_path)
import cm_xml_to_matplotlib as scivis_cm
wave4_cm = scivis_cm.make_cmap(scivis_path + "3Wbgy5.xml")
# scivis_cm.plot_cmap(wave4_cm)

from matplotlib.colors import Normalize, LinearSegmentedColormap
from matplotlib import cm
from matplotlib.colorbar import ColorbarBase
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colorbar import make_axes
import matplotlib.ticker as mticker

from pyCDFTOOLS.cdfcurl import *
from pyCDFTOOLS.draw_clock import *

# plotting defaults shared by every figure below
plt.rcParams.update({
    "font.family": "DejaVu Serif",
    "mathtext.fontset": "cm",
    "mathtext.rm": "serif",
    "image.cmap": "RdBu_r",  # a "*_r" name is the reverse of the standard colour map
    "axes.formatter.limits": [-4, 4],
    "font.size": 12.0,
})

# letters used to build the "(a)", "(b)", ... panel labels
label_str = "abcdefghijklmnopqrstuvwxyz"

In [None]:
# quick and dirty circumpolar transport subroutine
def utrans_tot(mesh_mask, uoce):
    """Return the total ACC transport in Sv.

    The zonal velocity is multiplied by the u-point mask and cell face area,
    summed over depth and y, and averaged over x. Rows whose ``gphiu`` value
    (read at x-index 5) exceeds 2000 are masked out, which excludes the
    recirculation zone north of 2000 km.

    Parameters
    ----------
    mesh_mask : mapping
        Must provide "e3u_0" and "umask" indexable as [0, z, y, x], and
        "e2u" and "gphiu" indexable as [0, y, x]. An xarray Dataset or a
        dict of numpy arrays both work. It is not modified.
    uoce : array-like
        Zonal velocity indexed as [z, y, x].

    Returns
    -------
    numpy.float64
        The x-averaged transport divided by 1e6.
    """
    e3u = np.asarray(mesh_mask["e3u_0"][0, :, :, :])
    e2u = np.asarray(mesh_mask["e2u"][0, :, :])
    dydz = e3u * e2u

    # np.array always copies, so zeroing rows below cannot leak back into
    # the caller's mesh mask
    umask = np.array(mesh_mask["umask"][0, :, :, :])
    y_ind = np.asarray(mesh_mask["gphiu"][0, :, 5]) > 2000
    umask[:, y_ind, :] = 0

    # nansum keeps the NaN-skipping behaviour that the xarray reductions
    # used previously have by default
    flux = np.asarray(uoce) * umask * dydz
    return np.mean(np.nansum(np.nansum(flux, axis=0), axis=0)) / 1e6

def utrans_bot(mesh_mask, uoce):
    """Return the bottom-flow depth-integrated ("barotropic") transport in Sv.

    The velocity at level ``mbathy - 1`` of every column is multiplied by the
    masked column thickness and the cell width, summed over y and averaged
    over x. Rows whose ``gphiu`` value (read at x-index 5) exceeds 2000 are
    masked out, which excludes the recirculation zone north of 2000 km.

    Parameters
    ----------
    mesh_mask : mapping
        Must provide "tmask", "umask" and "e3u_0" indexable as [0, z, y, x],
        and "mbathy", "e2u" and "gphiu" indexable as [0, y, x]. An xarray
        Dataset or a dict of numpy arrays both work. It is not modified.
    uoce : array-like
        Zonal velocity indexed as [z, y, x].

    Returns
    -------
    numpy.float64
        The x-averaged transport divided by 1e6.
    """
    uoce = np.asarray(uoce)
    tmask = np.asarray(mesh_mask["tmask"][0, 0, :, :])

    # some mesh masks have odd level indices where the surface is masked, so
    # zero those; the product is a new array, leaving mesh_mask untouched,
    # and the cast makes it usable as an index whatever the stored dtype
    mbathy = (np.asarray(mesh_mask["mbathy"][0, :, :]) * tmask).astype(int)

    umask = np.array(mesh_mask["umask"][0, :, :, :])  # private copy, edited below
    e3u = np.asarray(mesh_mask["e3u_0"][0, :, :, :])
    e2u = np.asarray(mesh_mask["e2u"][0, :, :])
    # uses the full umask (before the northern rows are zeroed)
    dydz_zint = np.nansum(e3u * umask, axis=0) * e2u

    # gather uoce[mbathy - 1, j, i] for every column in one indexing step;
    # where mbathy is 0 the index -1 wraps to the last level, as before
    ny, nx = uoce.shape[1:]
    jj, ii = np.indices((ny, nx))
    ubt_zint = uoce[mbathy - 1, jj, ii] * dydz_zint

    y_ind = np.asarray(mesh_mask["gphiu"][0, :, 5]) > 2000
    umask[:, y_ind, :] = 0

    return np.mean(np.sum(ubt_zint * umask[0, :, :], axis=0), axis=0) / 1e6

-------------------------
## 1) vorticity snapshots

In [None]:
# R025 sample vorticity

# one panel per experiment, with the matching panel titles
folder_list = [
    "no_split/gm0000/tau100x/OUTPUTS/",
    "no_split/gm0500/tau100x/OUTPUTS/",
    "no_split/alp0060_lam80/tau100x/OUTPUTS/",
    "split_100km/alp0060_lam80/tau100x/OUTPUTS/",
]
title_list = [
    r"$\kappa_{\rm gm} = 0$",
    r"$\kappa_{\rm gm} = 500$, as is",
    r"GEOM $\alpha = 0.06$, as is",
    r"GEOM $\alpha = 0.06$, new way",
]

# keyword options forwarded to cdfcurl
kwargs = dict(kz=0, kt=-1, lprint=False, lperio=True, loverf=False)

# for vorticity seismic is good, RdBu_r is ok too; with cmocean, curl and balance are probably ok
vmin, vmax = -2e-5, 2e-5
misc_opts = dict(
    levels=np.linspace(vmin, vmax, 41),
    cmap=cmocean.cm.curl,
    extend="both",
)

# colour map and normalisation for the stand-alone colourbar drawn at the end
colors = misc_opts["cmap"](np.linspace(0, 1, misc_opts["cmap"].N))
cmap2 = LinearSegmentedColormap.from_list('dummy', colors)
norm = Normalize(vmin=vmin, vmax=vmax)

# fig properties
nrows = 2
ncols = 2
fig, axes = plt.subplots(figsize=(14, 4.5), nrows=nrows, ncols=ncols)

# same pair of file names for every experiment
# (alternative period: "UNAGI_1y_05010101_05101230_surf_U.nc")
fileU = "UNAGI_1y_08010101_08101230_surf_U.nc"
fileV = fileU.replace("U.nc", "V.nc")

for ind, ax in enumerate(axes.flat, start=1):
    i, j = divmod(ind - 1, ncols)  # row/column of this panel
    data_dir = f"../../UNAGI/EXP_R025/{folder_list[ind-1]}"
    lonT, latT, curlu, opt_dic = cdfcurl(data_dir, fileU, "ssu_inst", fileV, "ssv_inst", **kwargs)

    # first and last columns are left out of the plot
    mesh = ax.contourf(lonT[:, 1:-1], latT[:, 1:-1], curlu[:, 1:-1], **misc_opts, zorder=-5)
    ax.set_rasterization_zorder(-1)  # raster the contourf to reduce size
    ax.set_ylim(0, 2400)
    ax.set_xlim(0, 9000)
    ax.set_title(f"{title_list[ind-1]}")
    ax.text(100, 2550, f"({label_str[ind-1]})")

# axes label formatting: tick labels only on the outer left/bottom panels
for ax in axes[0, :]:
    ax.set_xticklabels([])
for ax in axes[:, 1]:
    ax.set_yticklabels([])
for ax in axes[:, 0]:
    ax.set_ylabel(r"$y$ ($\mathrm{km}$)")
for ax in axes[1, :]:
    ax.set_xlabel(r"$x$ ($\mathrm{km}$)")

# create colourbar axes by grabbing the bounding boxes of the individual plot axes
pos_bot = axes[nrows - 1, ncols - 1].get_position()
pos_top = axes[0, ncols - 1].get_position()
cb_box = [pos_bot.x1 + 0.01, pos_bot.y0, 0.005, pos_top.y1 - pos_bot.y0]
ax_cb = fig.add_axes(cb_box)
cb = ColorbarBase(ax_cb, cmap=cmap2, norm=norm)
cb.set_ticks(np.arange(vmin, vmax+1e-6, 1e-5))

fig.savefig("./partial_res_figs/unagi_R025_xi_snap.pdf", dpi=150, bbox_inches="tight")

-----------------------

## 2) Tests on filtering routine

In [None]:
# splitting subroutines

# compiled loop testing diffusion loops

@jit(nopython=True)
def diffusion_loop_old(tmask, zaheeu, zaheev, e1e2t, z_out, cycle=1):
    """Apply `cycle` explicit Laplacian updates to `z_out` and return it.

    `z_out` is modified in place (the returned array is the same object), so
    pass a copy if the input field must be preserved.

    tmask          : 2D mask multiplying the increment at each point
    zaheeu, zaheev : 2D coefficients multiplying the forward differences
                     along the last and first axis respectively
    e1e2t          : 2D array dividing the flux divergence
    z_out          : 2D field to update
    cycle          : number of update sweeps
    """
    
    npj, npi = tmask.shape
    
    # work arrays: the two scaled forward differences and the resulting Laplacian
    ztu, ztv, zlap = np.zeros((npj, npi)), np.zeros((npj, npi)), np.zeros((npj, npi))
    
    for jn in range(cycle):
        # scaled forward differences, computed for the whole field before any update
        for jj in range(npj-1):
            for ji in range(npi-1):
                ztu[jj, ji] = zaheeu[jj, ji] * (z_out[jj  , ji+1] - z_out[jj, ji])
                ztv[jj, ji] = zaheev[jj, ji] * (z_out[jj+1, ji  ] - z_out[jj, ji])

        # divergence of the differences on interior points, added to the field where tmask is set
        for jj in range(1, npj-1):
            for ji in range(1, npi-1):
                zlap[jj, ji] = (  ztu[jj, ji] - ztu[jj  , ji-1]
                                + ztv[jj, ji] - ztv[jj-1, ji  ]
                               ) / e1e2t[jj, ji]
                z_out[jj, ji] = z_out[jj, ji] + zlap[jj, ji] * tmask[jj, ji]
                
        
        # the first and last columns mirror the opposite interior columns
        z_out[:, -1] = z_out[:,  1]  # halo
        z_out[:,  0] = z_out[:, -2]  # halo
    
    return z_out

@jit(nopython=True)
def diffusion_loop(tmask, zaheeu, zaheev, e1e2t, z_in, power=2, gamma=100, diag_freq=50):
    """Filter `z_in` by iterating a relaxed update `power` times in succession.

    Each outer pass iterates
        z_now = z_orig / gamma + tmask * ((gamma - 1) * z_prev + zlap(z_prev)) / gamma
    until the largest change between iterates is at most 1e-3 or 500
    iterations have been done. Where tmask == 1 a fixed point of this update
    satisfies z - zlap(z) = z_orig. The result of one pass is the `z_orig` of
    the next pass.

    `z_in` is copied and not modified. Rows 0 and -1 of the returned array
    are never written by the update and stay at 0. With power=0 the returned
    array is all zeros.

    tmask          : 2D mask, see the update above
    zaheeu, zaheev : 2D coefficients multiplying the forward differences
                     along the last and first axis respectively
    e1e2t          : 2D array dividing the flux divergence
    z_in           : 2D input field
    power          : number of successive passes
    gamma          : relaxation factor in the update above
    diag_freq      : print the residual every `diag_freq` iterations
    """
    
    npj, npi = tmask.shape
    
    # work arrays for the differences/Laplacian and for the iterates
    ztu, ztv, zlap = np.zeros((npj, npi)), np.zeros((npj, npi)), np.zeros((npj, npi))
    z_orig, z_prev, z_now = np.zeros((npj, npi)), np.zeros((npj, npi)), np.zeros((npj, npi))
    
    # make copy of "original" field to be held fixed during the iteration stage
    
    z_orig[:, :] = z_in[:, :]
    
    for n in range(power):
        # start each pass from the current right-hand side
        z_prev[:, :] = z_orig[:, :]
        z_now[:, :] = 0.0
        
        glob_sup_res = 1e15  # large initial value so the loop is entered
        jn = 0
        print("iteration at power = ", n+1)

        while (glob_sup_res > 1e-3) and (jn < 500):
            
            # compute first derivative
            for jj in range(npj-1):
                for ji in range(npi-1):
                    ztu[jj, ji] = zaheeu[jj, ji] * (z_prev[jj  , ji+1] - z_prev[jj, ji])
                    ztv[jj, ji] = zaheev[jj, ji] * (z_prev[jj+1, ji  ] - z_prev[jj, ji])

            # Laplacian of the previous iterate and the relaxed update, interior points only
            for jj in range(1, npj-1):
                for ji in range(1, npi-1):
                    zlap[jj, ji] = (  ztu[jj, ji] - ztu[jj  , ji-1]
                                    + ztv[jj, ji] - ztv[jj-1, ji  ]
                                   ) / e1e2t[jj, ji]
        #             z_now[jj, ji] = z_orig[jj, ji] + zlap[jj, ji] * tmask[jj, ji]
                    z_now[jj, ji] = (1.0 / gamma) * z_orig[jj, ji] + (1.0 / gamma) * tmask[jj, ji] * (
                        (gamma - 1.0) * z_prev[jj, ji] +  zlap[jj, ji]
                    )
            
            # diagnostics (evaluated before the halo columns of z_now are refreshed)
            glob_sup_res = np.max(np.abs(z_now - z_prev))
            
            if (jn % diag_freq == 0):
                print("  it = ", jn, " global sup res = ", glob_sup_res)
        
            # updates
            jn += 1
            z_now[:, -1] = z_now[:,  1]  # halo
            z_now[:,  0] = z_now[:, -2]  # halo
            z_prev[:, :] = z_now[:, :] # iteration

        # return T* of (I - nu * D2) T*(n+1) = T(n)
        z_orig[:, :] = z_now[:, :]
        
        # final iteration count and residual of this pass
        print("  it = ", jn, " global sup res = ", glob_sup_res)
    
    return z_now

@jit(nopython=True)
def diffusion_loop2(tmask, zaheeu, zaheev, e1e2t, z_in, power=2, gamma=100, diag_freq=50):
    """Variant of the relaxed filter iteration with a different update and tolerance.

    Each outer pass iterates
        z_now = z_orig / gamma + tmask * ((gamma - 1) * z_prev + zlap(z_prev)) / gamma**2
    until the largest change between iterates is at most 1e-6 or 500
    iterations have been done. The result of one pass is the `z_orig` of the
    next pass.

    NOTE(review): because of the 1 / gamma**2 factor a fixed point of this
    update does not satisfy z - zlap(z) = z_orig -- confirm this is intended.

    `z_in` is copied and not modified. Rows 0 and -1 of the returned array
    are never written by the update and stay at 0. With power=0 the returned
    array is all zeros.

    tmask          : 2D mask, see the update above
    zaheeu, zaheev : 2D coefficients multiplying the forward differences
                     along the last and first axis respectively
    e1e2t          : 2D array dividing the flux divergence
    z_in           : 2D input field
    power          : number of successive passes
    gamma          : relaxation factor in the update above
    diag_freq      : print the residual every `diag_freq` iterations
    """
    
    npj, npi = tmask.shape
    
    # work arrays for the differences/Laplacian and for the iterates
    ztu, ztv, zlap = np.zeros((npj, npi)), np.zeros((npj, npi)), np.zeros((npj, npi))
    z_orig, z_prev, z_now = np.zeros((npj, npi)), np.zeros((npj, npi)), np.zeros((npj, npi))
    
    # make copy of "original" field to be held fixed during the iteration stage
    
    z_orig[:, :] = z_in[:, :]
    
    for n in range(power):
        # start each pass from the current right-hand side
        z_prev[:, :] = z_orig[:, :]
        z_now[:, :] = 0.0
        
        glob_sup_res = 1e15  # large initial value so the loop is entered
        jn = 0
        print("iteration at power = ", n+1)

        while (glob_sup_res > 1e-6) and (jn < 500):
            
            # compute first derivative
            for jj in range(npj-1):
                for ji in range(npi-1):
                    ztu[jj, ji] = zaheeu[jj, ji] * (z_prev[jj  , ji+1] - z_prev[jj, ji])
                    ztv[jj, ji] = zaheev[jj, ji] * (z_prev[jj+1, ji  ] - z_prev[jj, ji])

            # Laplacian of the previous iterate and the relaxed update, interior points only
            for jj in range(1, npj-1):
                for ji in range(1, npi-1):
                    zlap[jj, ji] = (  ztu[jj, ji] - ztu[jj  , ji-1]
                                    + ztv[jj, ji] - ztv[jj-1, ji  ]
                                   ) / e1e2t[jj, ji]
        #             z_now[jj, ji] = z_orig[jj, ji] + zlap[jj, ji] * tmask[jj, ji]
                    z_now[jj, ji] = (1.0 / gamma) * z_orig[jj, ji] + (1.0 / gamma**2) * tmask[jj, ji] * (
                        (gamma - 1.0) * z_prev[jj, ji] +  zlap[jj, ji]
                    )
            
            # diagnostics (evaluated before the halo columns of z_now are refreshed)
            glob_sup_res = np.max(np.abs(z_now - z_prev))
            
            if (jn % diag_freq == 0):
                print("  it = ", jn, " global sup res = ", glob_sup_res)
        
            # updates
            jn += 1
            z_now[:, -1] = z_now[:,  1]  # halo
            z_now[:,  0] = z_now[:, -2]  # halo
            z_prev[:, :] = z_now[:, :] # iteration

        # return T* of (I - nu * D2) T*(n+1) = T(n)
        z_orig[:, :] = z_now[:, :]
        
        # final iteration count and residual of this pass
        print("  it = ", jn, " global sup res = ", glob_sup_res)
    
    return z_now

# spectrum analysis for the explicit (cycle-based) diffusion filter
def spectrum_old(tmask, zaheeu, zaheev, e1e2t, toce, cycle_list):
    """Plot the y-averaged zonal power spectrum of `toce` before and after filtering.

    One curve is drawn for the unfiltered field and one for each entry of
    `cycle_list`, obtained by running `diffusion_loop_old` with that many
    cycles. `diffusion_loop_old` updates its input in place, so every curve
    starts from a fresh copy of `toce`; otherwise successive curves would
    accumulate cycles and no longer match their "cycle=" labels.

    Parameters
    ----------
    tmask, zaheeu, zaheev, e1e2t : 2D arrays
        Passed straight through to `diffusion_loop_old`.
    toce : 2D array
        Field to analyse; the FFT is taken along the last axis. Not modified.
    cycle_list : iterable of int
        Numbers of diffusion cycles to show.

    Returns
    -------
    matplotlib Axes holding the log-log plot (a new figure is created).
    """
    Lx = 9000  # units in km
    scale_factor = 1.0 / Lx

    def zonal_power(field):
        # |rfft|^2 along the last (periodic) axis, averaged over the y axis
        return np.mean(np.abs(np.fft.rfft(field, axis=-1, norm="ortho")) ** 2, axis=0)

    z_init = copy.deepcopy(toce)
    power_avg = zonal_power(z_init)
    k_vec = np.arange(power_avg.shape[-1]) * scale_factor

    plt.figure(figsize=(8, 3))
    ax = plt.axes()
    # don't plot the mean (zero mode)
    ax.loglog(k_vec[1::], power_avg[1::], label=r"cycle=0")

    for cycle in cycle_list:
        # fresh copy: the diffusion routine overwrites the array it is given
        z_out = diffusion_loop_old(tmask, zaheeu, zaheev, e1e2t, copy.deepcopy(z_init), cycle=cycle)
        ax.loglog(k_vec[1::], zonal_power(z_out)[1::], label=f"cycle={cycle}")
    ax.grid()
    ax.legend()

    return ax

# catch all code for zonal spectrum analysis
def spectrum(tmask, e1e2t, e2_e1u, e2_e1v, umask, vmask, toce, L_list, gamma_list, ax=None):
    """Plot the y-averaged zonal power spectrum of `toce` for several filter lengths.

    One curve is drawn for the unfiltered field and one for each pair
    (L_list[k], gamma_list[k]), obtained from `diffusion_loop` with power=2.

    Parameters
    ----------
    tmask, e1e2t : 2D arrays
        Passed straight through to `diffusion_loop`.
    e2_e1u, e2_e1v, umask, vmask : 2D arrays
        Multiplied by the squared filter length to build the coefficients
        given to `diffusion_loop`.
    toce : 2D array
        Field to analyse; the FFT is taken along the last axis.
    L_list : sequence of float
        Filter lengths; the legend shows them divided by 1e3 as km.
    gamma_list : sequence
        Relaxation factor for each filter length (extra entries are ignored).
    ax : matplotlib Axes, optional
        Axes to draw into; a new (8, 3) figure is created when omitted.

    Returns
    -------
    The Axes that was drawn into.

    Raises
    ------
    ValueError
        If `gamma_list` has fewer entries than `L_list`.
    """
    if len(gamma_list) < len(L_list):
        raise ValueError("gamma_list needs one entry for every entry of L_list")

    Lx = 9000  # units in km
    scale_factor = 1.0 / Lx

    def zonal_power(field):
        # |rfft|^2 along the last (periodic) axis, averaged over the y axis
        return np.mean(np.abs(np.fft.rfft(field, axis=-1, norm="ortho")) ** 2, axis=0)

    power_avg = zonal_power(toce)
    k_vec = np.arange(power_avg.shape[-1]) * scale_factor

    if ax is None:
        plt.figure(figsize=(8, 3))
        ax = plt.axes()
    # don't plot the mean (zero mode)
    ax.loglog(k_vec[1::], power_avg[1::], label=r"$L=$0 km")

    for nu, gamma in zip(L_list, gamma_list):
        zaheeu = nu * nu * e2_e1u * umask
        zaheev = nu * nu * e2_e1v * vmask
        z_out = diffusion_loop(tmask, zaheeu, zaheev, e1e2t, toce, power=2, gamma=gamma)
        ax.loglog(k_vec[1::], zonal_power(z_out)[1::], label=f"$L=${int(nu / 1e3)} km")
    ax.grid()
    ax.legend(fontsize="10")

    return ax

In [None]:
# load R025 data and build the 2D metrics and masks used by the filter
data_dir = "../../UNAGI/EXP_R025/no_split/gm0000/tau100x/OUTPUTS/"
mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
data = xr.open_dataset(data_dir + "UNAGI_1y_08010101_08101230_surf_T.nc")

lon, lat = data["nav_lon"].values, data["nav_lat"].values
toce = data["sst_inst"][0, :, :].values

npj, npi = lon.shape

# overwrite lon because of the periodic wrapping
lon[:, -1] += lon[:, 0]
lon[:, 0] = 0.0

# only at the surface so no need for loop in z
jk = 0

# horizontal scale factors that are always needed
e1u, e2u, e1v, e2v, e1t, e2t = (
    mesh_mask[name][0, :, :].values
    for name in ("e1u", "e2u", "e1v", "e2v", "e1t", "e2t")
)

# ? not entirely sure why the mesh_mask file has some zeros in the e[12] at the periodicities even though
#   it looks like the ocean variables are correct...just overwriting the first and last columns here
for metric in (e1u, e2u, e1v, e2v, e1t, e2t):
    metric[:, [0, -1]] = 25.0e3

e1e2t = e1t * e2t
e2_e1u = e2u / e1u
e2_e1v = e2v / e1v
del e1u, e2u, e1v, e2v, e1t, e2t, metric

umask, vmask, tmask = (
    mesh_mask[name][0, jk, :, :].values
    for name in ("umask", "vmask", "tmask")
)

# ? not entirely sure why the mesh_mask file has some zeros in the masks at the periodicities even though
#   it looks like the ocean variables are correct...just overwriting them here
umask[1:-1, [0, -1]] = 1
vmask[1:-2, [0, -1]] = 1

In [None]:
# pre-mortem analysis (100 km filter, on kgm = 0 data)

nu = 1.0e5 # this is 100km filter
zaheeu = nu * nu * e2_e1u * umask
zaheev = nu * nu * e2_e1v * vmask

gamma = 75  # stabilises but makes the convergence slow
z_out = diffusion_loop(tmask, zaheeu, zaheev, e1e2t, toce, power=2, gamma=gamma)
res = toce - z_out

# fig properties
nrows = 2
ncols = 2
fig, axes = plt.subplots(figsize = (14, 4.5), nrows = nrows, ncols = ncols)

# the three map panels: (axes, field, columns to plot, colour limits);
# the residual leaves out the first and last columns
map_panels = [
    (axes[0, 0], toce,  slice(None),  0, 14),
    (axes[0, 1], z_out, slice(None),  0, 14),
    (axes[1, 0], res,   slice(1, -1), -2, 2),
]
for ind, (ax, field, cols, lo, hi) in enumerate(map_panels, start=1):
    mesh = ax.contourf(lon[:, cols], lat[:, cols], field[:, cols], 41,
                       vmin=lo, vmax=hi, zorder=-5)
    ax.set_ylim(0, 2400)
    ax.set_xlim(0, 9000)
    ax.text(100, 2550, f"({label_str[ind-1]})")

for ax in axes.flat:
    ax.set_rasterization_zorder(-1)  # raster the contourf to reduce size

# plotting of avg of zonal power spectrum, addressed explicitly as the
# bottom-right panel instead of through loop variables left over from above
L_list = [50e3, 100e3, 200e3]
gamma_list = [25, 75, 300]

ax_spec = axes[1, 1]
spectrum(tmask, e1e2t, e2_e1u, e2_e1v, umask, vmask, toce, L_list, gamma_list, ax=ax_spec)
ticks = 1.0 / np.asarray([2000, 500, 200, 100, 50, 25, 10])
ax_spec.set_xticks(ticks)
ax_spec.set_xticklabels([r"$1/2000$", r"$1/500$", r"$1/200$", r"$1/100$", r"$1/50$", r"$1/25$", r"$1/10$"])
ax_spec.set_xlim([1/5000, 1/40])
ax_spec.set_ylim([1e-7, 1e1])

# axes label formatting
axes[0, 0].set_xticklabels([])
axes[0, 0].set_ylabel(r"$y$ ($\mathrm{km}$)")
axes[0, 1].set_xticklabels([])
axes[0, 1].set_yticklabels([])
axes[1, 0].set_xlabel(r"$x$ ($\mathrm{km}$)")
axes[1, 0].set_ylabel(r"$y$ ($\mathrm{km}$)")
ax_spec.set_xlabel(r"(zonal wavelength)${}^{-1}$ ($\mathrm{km}^{-1}$)")
ax_spec.yaxis.set_label_position("right")
ax_spec.yaxis.set_ticks_position("right")
ax_spec.set_ylabel(r"$|\mathcal{F}(\Theta)|^2$")

axes[0, 0].set_title(r"full $\Theta(z=0)$")
axes[0, 1].set_title(r"filtered $\Theta(z=0)$ [$L =$ 100 km]")
axes[1, 0].set_title(r"residual $\Theta(z=0)$ [$L =$ 100 km]")

ax_spec.text(1/4750, 30, f"({label_str[3]})")  # fourth panel, "(d)"

# one colourbar per map panel, placed by grabbing the bounding box of that panel
for ax, _, _, lo, hi in map_panels:
    pos = ax.get_position()
    ax_cb = fig.add_axes([pos.x1 + 0.01,
                          pos.y0,
                          0.005,
                          pos.y1 - pos.y0])
    norm = Normalize(vmin = lo, vmax = hi)
    cb = ColorbarBase(ax_cb, cmap = plt.cm.RdBu_r, norm = norm)

fig.savefig("./partial_res_figs/pre_mortem_R025_SST_100km.pdf", dpi=150, bbox_inches="tight")

In [None]:
# post-mortem analysis (alp = 0.06, lam-1 = 80 days)

data_dir = "../../UNAGI/EXP_R025/split_100km/alp0060_lam80/tau100x/OUTPUTS/"
mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
data = xr.open_dataset(data_dir + "UNAGI_1y_08010101_08101230_surf_T.nc")


def _zonal_power(field):
    """Return |rfft along the last axis|^2 (norm="ortho") averaged over axis 0."""
    return np.mean(np.abs(np.fft.rfft(field, axis=-1, norm="ortho")) ** 2, axis=0)


# fig properties
nrows = 2
ncols = 2
fig, axes = plt.subplots(figsize = (14, 4.5), nrows = nrows, ncols = ncols)

# the three map panels: (axes, field from the last record, colour limits)
map_panels = [
    (axes[0, 0], data["sst_inst"][-1, :, :],     0, 14),
    (axes[0, 1], data["tem_large"][-1, 0, :, :], 0, 14),
    (axes[1, 0], data["tem_small"][-1, 0, :, :], -2, 2),
]
for ind, (ax, field, lo, hi) in enumerate(map_panels, start=1):
    mesh = ax.contourf(lon, lat, field, 41, vmin=lo, vmax=hi, zorder=-5)
    ax.set_rasterization_zorder(-1)  # raster the contourf to reduce size
    ax.set_ylim(0, 2400)
    ax.set_xlim(0, 9000)
    ax.text(100, 2550, f"({label_str[ind-1]})")

# bottom-right panel: y-averaged zonal power spectra, addressed explicitly
# instead of through loop variables left over from the loop above
ax_spec = axes[1, 1]

Lx = 9000  # units in km
scale_factor = 1.0 / Lx
power_avg = _zonal_power(data["sst_inst"][-1, :, :])  # average over the y
k_vec = np.arange(power_avg.shape[-1]) * scale_factor
ax_spec.loglog(k_vec[1::], power_avg[1::], label=r"full field")
power_avg = _zonal_power(data["tem_large"][-1, 0, :, :])
ax_spec.loglog(k_vec[1::], power_avg[1::], "C2", label=r"filtered field")

# comparison runs; as before this rebinds data_dir / mesh_mask / data, so after
# the cell they refer to the last run listed here
reference_runs = [
    ("../../UNAGI/EXP_R025/no_split/alp0060_lam80/tau100x/OUTPUTS/", "C1--", r"full field, no filter"),
    ("../../UNAGI/EXP_R025/no_split/gm0000/tau100x/OUTPUTS/", "k--", r"$\kappa_{\rm gm} = 0$"),
]
for data_dir, line_fmt, line_label in reference_runs:
    mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
    data = xr.open_dataset(data_dir + "UNAGI_1y_08010101_08101230_surf_T.nc")
    power_avg = _zonal_power(data["sst_inst"][-1, :, :])
    ax_spec.loglog(k_vec[1::], power_avg[1::], line_fmt, label=line_label, alpha=0.7)

ax_spec.grid()
ax_spec.legend()

ticks = 1.0 / np.asarray([2000, 500, 200, 100, 50, 25, 10])
ax_spec.set_xticks(ticks)
ax_spec.set_xticklabels([r"$1/2000$", r"$1/500$", r"$1/200$", r"$1/100$", r"$1/50$", r"$1/25$", r"$1/10$"])
ax_spec.set_xlim([1/5000, 1/40])
ax_spec.set_ylim([1e-7, 1e1])

# axes label formatting
axes[0, 0].set_xticklabels([])
axes[0, 0].set_ylabel(r"$y$ ($\mathrm{km}$)")
axes[0, 1].set_xticklabels([])
axes[0, 1].set_yticklabels([])
axes[1, 0].set_xlabel(r"$x$ ($\mathrm{km}$)")
axes[1, 0].set_ylabel(r"$y$ ($\mathrm{km}$)")
ax_spec.set_xlabel(r"(zonal wavelength)${}^{-1}$ ($\mathrm{km}^{-1}$)")
ax_spec.yaxis.set_label_position("right")
ax_spec.yaxis.set_ticks_position("right")
ax_spec.set_ylabel(r"$|\mathcal{F}(\Theta)|^2$")

axes[0, 0].set_title(r"full $\Theta(z=0)$")
axes[0, 1].set_title(r"filtered $\Theta(z=0)$ [$L =$ 100 km]")
axes[1, 0].set_title(r"residual $\Theta(z=0)$ [$L =$ 100 km]")

ax_spec.text(1/4750, 30, f"({label_str[3]})")  # fourth panel, "(d)"

# one colourbar per map panel, placed by grabbing the bounding box of that panel
for ax, _, lo, hi in map_panels:
    pos = ax.get_position()
    ax_cb = fig.add_axes([pos.x1 + 0.01,
                          pos.y0,
                          0.005,
                          pos.y1 - pos.y0])
    norm = Normalize(vmin = lo, vmax = hi)
    cb = ColorbarBase(ax_cb, cmap = plt.cm.RdBu_r, norm = norm)

fig.savefig("./partial_res_figs/post_mortem_R025_SST_100km_1.pdf", dpi=150, bbox_inches="tight")

-----------------------

## 3) Energy diagnostics

In [None]:
# energy with varying resolution (no split [no gm, alp>0] and split)


def _mean_energies(data_dir, mesh_mask, with_param):
    """Return the volume-averaged [eke, epe, param ee] of one run as a length-3 array.

    The depth-integrated fields are weighted by the T-cell area and divided
    by the masked ocean volume. The parameterised eddy energy is only read
    when `with_param` is true and is 0 otherwise. All metrics stay local to
    this function, so notebook-level arrays such as `tmask` and `e1e2t` are
    left alone.
    """
    e1t = mesh_mask["e1t"][0, :, :].values
    e2t = mesh_mask["e2t"][0, :, :].values
    e3t_0 = mesh_mask["e3t_0"][0, :, :, :].values
    tmask3d = mesh_mask["tmask"][0, :, :, :].values
    vol = np.sum(e3t_0 * e1t[np.newaxis, :, :] * e2t[np.newaxis, :, :] * tmask3d)
    area = e1t * e2t

    def area_weighted(filename, varname):
        ds = xr.open_dataset(data_dir + filename)
        return np.sum(ds[varname].values * area) / vol

    energies = np.zeros(3)
    energies[0] = area_weighted("eke_zint_tave.nc", "eke_zint")
    energies[1] = area_weighted("epe_zint_tave.nc", "epe_zint")
    if with_param:
        energies[2] = area_weighted("UNAGI_10y_08010101_08101230_grid_T.nc", "eke")
    return energies


def _print_transport(reso, data_dir, mesh_mask):
    """Print total, bottom and thermal-wind (total minus bottom) transport of one run."""
    data_U = xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc")
    uoce = data_U["uoce"][0, :, :, :].values
    ACC_tot = utrans_tot(mesh_mask, uoce)
    ACC_bot = utrans_bot(mesh_mask, uoce)
    print(f"res = {reso}: ACC tot = {ACC_tot:.2f}, bot = {ACC_bot:.2f}, therm = {ACC_tot-ACC_bot:.2f}")


res_list = ["100", "050", "025", "010"]
# res_list = ["100", "025", "010"]

exp_list = ["no_split/gm0000", "no_split/alp0060_lam80", "split_100km/alp0060_lam80"]

# 1st index: resolution, in the order of res_list
# 2nd index: no_split/gm0000, no_split/alp>0, split/alp>0
# 3rd index: eke, epe, and param ee
ene_diag = np.zeros((len(res_list), 3, 3))

for ind, reso in enumerate(res_list):

    # each entry: (output folder, mesh mask path relative to it,
    #              read the parameterised energy?, 2nd index of ene_diag to fill)
    if reso == "100":
        # a single R100 run; its values are copied into all three experiment slots
        runs = [("../../UNAGI/EXP_R100/alp0060_lam80/tau100x/OUTPUTS/",
                 "../../../mesh_mask.nc", True, slice(None))]
    else:
        runs = [(f"../../UNAGI/EXP_R{reso}/{exp}/tau100x/OUTPUTS/",
                 "../../../../mesh_mask.nc", exp != "no_split/gm0000", indd)
                for indd, exp in enumerate(exp_list)]

    for data_dir, mesh_rel, with_param, slot in runs:
        mesh_mask = xr.open_dataset(data_dir + mesh_rel)
        ene_diag[ind, slot, :] = _mean_energies(data_dir, mesh_mask, with_param)

        # transport diagnostics
        _print_transport(reso, data_dir, mesh_mask)

    print(" ")


In [None]:
# stacked bar graphs of the energy decomposition, one panel per experiment
reso_vec = ["100", "50", "25", "10"]
width = 0.35

fig = plt.figure(figsize=(14, 3))

# panel order follows the 2nd index of ene_diag:
# purely explicit energy, GEOM but no splitting, splitting
panel_titles = [
    r"$\kappa_{\rm gm} = 0$",
    r"$\alpha=0.06, L = 0\ \mathrm{km}$",
    r"$\alpha=0.06, L = 100\ \mathrm{km}$",
]

for col, panel_title in enumerate(panel_titles):
    ax = plt.subplot(1, 3, col + 1)
    epe = ene_diag[:, col, 1]
    eke = ene_diag[:, col, 0]
    param_e = ene_diag[:, col, 2]

    # EPE at the bottom, EKE stacked on it, parameterised energy on top of both
    ax.bar(reso_vec, epe, width, color="C0", label="EPE")
    ax.bar(reso_vec, eke, width, color="C3", bottom=epe, label="EKE")
    ax.bar(reso_vec, param_e, width, color="C2",
           bottom=np.sum(ene_diag[:, col, :2], axis=-1), label="param E")

    if col == 0:
        ax.set_ylabel(r"$\langle E\rangle$ ($\mathrm{m}^2\ \mathrm{s}^{-2}$)")
    ax.set_title(panel_title)
    if col == 1:
        ax.set_xlabel(r"$\Delta x = \Delta y\ (\mathrm{km})$")
    ax.set_ylim([0, 0.053])
    ax.grid()
    if col == 0:
        ax.legend()  # legend only on the first panel

# plt.savefig("./partial_res_figs/ene_decomp_100km.pdf", bbox_inches="tight")

In [None]:
# stacked bar graphs of the energy decomposition, grouped by resolution
reso_vec = ["100", "50", "25", "10"]
width = 0.25

fig = plt.figure(figsize=(10, 3))
ax = plt.axes()

# R100: a single stack; these three bars carry the legend labels
ax.bar(0, ene_diag[0, 0, 1], width, color="C0", label="EPE")
ax.bar(0, ene_diag[0, 0, 0], width, color="C3", bottom=ene_diag[0, 0, 1], label="EKE")
ax.bar(0, ene_diag[0, 0, 2], width, color="C2", bottom=np.sum(ene_diag[0, 0, :2], axis=-1), label="param E")

# same drawing style for every stack
alpha = 1
hatch = ""

for i in [1, 2, 3]:
    # bar centres within group i: kgm = 0 to the left, no filtering in the
    # middle, filtered to the right
    centres = {0: i-width-0.02, 1: i, 2: i+width+0.02}
    for j in [0, 1, 2]:
        x = centres[j]
        epe, eke, param_e = ene_diag[i, j, 1], ene_diag[i, j, 0], ene_diag[i, j, 2]

        ax.bar(x, epe, width, color="C0", alpha=alpha, hatch=hatch)
        ax.bar(x, eke, width, color="C3", bottom=epe, alpha=alpha, hatch=hatch)
        ax.bar(x, param_e, width, color="C2",
               bottom=np.sum(ene_diag[i, j, :2], axis=-1), alpha=alpha, hatch=hatch)

ax.set_ylabel(r"$\langle E\rangle$ ($\mathrm{m}^2\ \mathrm{s}^{-2}$)")
ax.set_ylim([0, 0.053])
ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels(reso_vec)
ax.set_xlabel(r"$\Delta x = \Delta y\ (\mathrm{km})$")
ax.grid()
ax.legend(loc="best", bbox_to_anchor=(0, 0.6, 0.5, 0.5))

In [None]:
# stacked bar graphs of the energy decomposition, grouped by experiment
reso_vec = ["100", "50", "25", "10"]
width = 0.25

fig = plt.figure(figsize=(10, 3))
ax = plt.axes()

# R100: a single stack; these three bars carry the legend labels
ax.bar(0, ene_diag[0, 0, 1], width, color="C0", label="EPE")
ax.bar(0, ene_diag[0, 0, 0], width, color="C3", bottom=ene_diag[0, 0, 1], label="EKE")
ax.bar(0, ene_diag[0, 0, 2], width, color="C2", bottom=np.sum(ene_diag[0, 0, :2], axis=-1), label="param E")

# Plot from a reordered COPY in which the no-filter and filtered experiments
# are swapped (2nd index becomes kgm = 0, filter, no filter). ene_diag itself
# is left untouched, so running this cell again gives the same figure.
ene_plot = ene_diag[:, [0, 2, 1], :]

# same drawing style for every stack
alpha = 1
hatch = ""

for j in [0, 1, 2]: # kgm = 0, filter, no filter
    # bar centres within group j + 1, one per resolution index i = 1, 2, 3
    centres = {1: j-width-0.02+1, 2: j+1, 3: j+width+0.02+1}
    for i in [1, 2, 3]:
        x = centres[i]

        ax.bar(x, ene_plot[i, j, 1], width,
               color="C0", alpha=alpha, hatch=hatch)
        ax.bar(x, ene_plot[i, j, 0], width,
               color="C3", bottom=ene_plot[i, j, 1], alpha=alpha, hatch=hatch)
        ax.bar(x, ene_plot[i, j, 2], width,
               color="C2", bottom=np.sum(ene_plot[i, j, :2], axis=-1), alpha=alpha, hatch=hatch)

ax.set_ylabel(r"$\langle E\rangle$ ($\mathrm{m}^2\ \mathrm{s}^{-2}$)")
ax.set_ylim([0, 0.06])
ax.set_xlim([-0.3, 3.5])
ax.set_xticks([0, 
               1-width-0.02, 1, 1+width+0.02, 
               2-width-0.02, 2, 2+width+0.02, 
               3-width-0.02, 3, 3+width+0.02])
ax.set_xticklabels(["100", "50", "25", "10", "50", "25", "10", "50", "25", "10"])
ax.set_xlabel(r"$\Delta x = \Delta y\ (\mathrm{km})$")
ax.yaxis.grid()
ax.legend(loc=2, bbox_to_anchor=(0.0, 0.75, 0.6, 0.45), ncol=3)
ax.text(0, 0.052, r"(1) R100", horizontalalignment="center")
ax.text(1, 0.052, r"(2) $\kappa_{\rm gm} = 0$", horizontalalignment="center")
ax.text(2, 0.052, r"(3) $L = 100$ km", horizontalalignment="center")
ax.text(3, 0.052, r"(4) $L = 0$ km", horizontalalignment="center")

# dotted reference line at the R100 total energy
R100_ene = np.sum(ene_diag[0, 0, :])
ax.plot([-1, 3.5], [R100_ene, R100_ene], "k:", alpha=0.8)

plt.savefig("./partial_res_figs/ene_decomp_100km.pdf", bbox_inches="tight")

In [None]:
# EPE share of the summed (eke + epe + param ee) energy for the first resolution entry
ene_diag[0, 0, 1] / np.sum(ene_diag[0, 0, :])

-----------------------

## 4) Mean state sensitivities

In [None]:
# varying wind transport for various runs [JM: R100 and both R025 need redoing]
# each entry is the tag that appears in the "tau<tag>x" run folder name
exp_list = ["050", "075", "100", "125", "150", "200", "300", "400"]
n_exp = len(exp_list)

print("reading R010, gm0000...")

ACC_tot_R010 = np.zeros(n_exp)
ACC_bot_R010 = np.zeros(n_exp)

for ind, tau_tag in enumerate(exp_list):
    data_dir = f"../../UNAGI/UNAGI/EXP_R010/no_split/gm0000/tau{tau_tag}x/OUTPUTS/"
    mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
    data_U = xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc")
    uoce = data_U["uoce"][0, :, :, :].values

    # total and bottom-flow transport of this run
    ACC_tot_R010[ind] = utrans_tot(mesh_mask, uoce)
    ACC_bot_R010[ind] = utrans_bot(mesh_mask, uoce)

print("reading R100, GEOM...")

ACC_tot_R100 = np.zeros(n_exp)
ACC_bot_R100 = np.zeros(n_exp)

for ind, tau_tag in enumerate(exp_list):
    data_dir = f"../../UNAGI/UNAGI/EXP_R100/alp0060_lam80/tau{tau_tag}x/OUTPUTS/"
    # the R100 mesh mask sits one directory level closer than for the other runs
    mesh_mask = xr.open_dataset(data_dir + "../../../mesh_mask.nc")
    data_U = xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc")
    uoce = data_U["uoce"][0, :, :, :].values

    ACC_tot_R100[ind] = utrans_tot(mesh_mask, uoce)
    ACC_bot_R100[ind] = utrans_bot(mesh_mask, uoce)

print("reading R025, gm0000...")

ACC_tot_R025_gm0 = np.zeros(n_exp)
ACC_bot_R025_gm0 = np.zeros(n_exp)

for ind, tau_tag in enumerate(exp_list):
    data_dir = f"../../UNAGI/UNAGI/EXP_R025/no_split/gm0000/tau{tau_tag}x/OUTPUTS/"
    mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
    data_U = xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc")
    uoce = data_U["uoce"][0, :, :, :].values

    ACC_tot_R025_gm0[ind] = utrans_tot(mesh_mask, uoce)
    ACC_bot_R025_gm0[ind] = utrans_bot(mesh_mask, uoce)
    
print("reading R025, GEOM split...")
    
ACC_tot_R025, ACC_bot_R025 = np.zeros(len(exp_list)), np.zeros(len(exp_list))

for ind in range(len(exp_list)):
    data_dir = f"../../UNAGI/UNAGI/EXP_R025/split_100km/alp0060_lam80/tau{exp_list[ind]}x/OUTPUTS/"
    mesh_mask = xr.open_dataset(data_dir + "../../../../mesh_mask.nc")
    data_U = xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc")
    uoce = data_U["uoce"][0, :, :, :].values
    
    ACC_tot_R025[ind] = utrans_tot(mesh_mask, uoce)
    ACC_bot_R025[ind] = utrans_bot(mesh_mask, uoce)

In [None]:
# transport varying wind forcing

tau_vec = np.asarray(exp_list, dtype=float)/100
fig = plt.figure(figsize=(10, 3))

# (total, bottom, line format, legend label) for each run that is drawn;
# R100 (GEOM, L = 0 km) is currently left out of both panels
runs = [
    (ACC_tot_R010, ACC_bot_R010, 'x-', r"R010, $\kappa_{\rm gm} = 0$"),
    (ACC_tot_R025_gm0, ACC_bot_R025_gm0, 'x-', r"R025, $\kappa_{\rm gm} = 0$"),
    (ACC_tot_R025, ACC_bot_R025, 'C3x-', r"R025, GEOM, $L=100$ km"),
]

# (a) total transport
ax = plt.subplot(1, 2, 1)
for tot, _, fmt, _ in runs:
    ax.plot(tau_vec, tot, fmt)
ax.set_ylim([50, 250])
ax.set_xlabel(r"$\times\tau_0$")
ax.set_ylabel(r"$T_{\rm tot}$ (Sv)")
ax.text(0.4, 230, r"(a)")
ax.grid()

# (b) total minus bottom transport; only this panel carries the legend
ax = plt.subplot(1, 2, 2)
for tot, bot, fmt, name in runs:
    ax.plot(tau_vec, tot-bot, fmt, label=name)
ax.set_ylim([30, 100])
ax.set_xlabel(r"$\times\tau_0$")
ax.set_ylabel(r"$T_{\rm tot} - T_{\rm bot}$ (Sv)")
ax.text(0.4, 93.5, r"(b)")
ax.grid()
ax.legend(loc=4, fontsize=11)

plt.savefig("./partial_res_figs/transport_vary_tau.pdf", bbox_inches="tight")

In [None]:
# take R010 kgm = 0 to be model truth

# NOTE(review): the wind-sweep transport cell reads from "../../UNAGI/UNAGI/..."
# while this cell uses "../../UNAGI/..." -- confirm which root is current.

t_levels = np.arange(0, 10+1, 2)
u_levels = np.linspace(-0.15, 0.15, 21)

fig = plt.figure(figsize=(16, 3))

# the first entry is the reference run: its temperature contours are overlaid
# (in green) on every other panel
exp_list = ["EXP_R010/no_split/gm0000/tau100x",
            "EXP_R100/alp0060_lam80/tau100x", 
            "EXP_R025/split_100km/alp0060_lam80/tau100x",
            "EXP_R025/no_split/gm0000/tau100x"]

for i, exp in enumerate(exp_list):
    
    # load data (files are closed again as soon as the arrays are in memory)
    data_dir = f"../../UNAGI/{exp}/OUTPUTS/"
    # mesh_mask.nc sits at the EXP_Rxxx root: one "../" per directory level of
    # `exp` plus one for OUTPUTS (replaces the former hard-wired i == 1 case)
    mesh_dir = data_dir + "../" * (exp.count("/") + 1)

    # define some metrics of relevance
    with xr.open_dataset(mesh_dir + "mesh_mask.nc") as mesh_mask:
        e1t = mesh_mask["e1t"][0, :, :].values
        tmask = mesh_mask["tmask"][0, :, :, :].values
        e1u = mesh_mask["e1u"][0, :, :].values
        umask = mesh_mask["umask"][0, :, :, :].values

    with xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_T.nc") as data_T:
        lat = data_T["nav_lat"].values
        z = -data_T["deptht"].values
        toce = data_T["toce"][0, :, :, :].values

    with xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc") as data_U:
        uoce = data_U["uoce"][0, :, :, :].values

    # masked zonal grid spacing and its mean along the last axis
    dx_T = e1t[np.newaxis, :, :] * tmask
    dx_U = e1u[np.newaxis, :, :] * umask
    Lx_grid_T = np.mean(dx_T, axis=-1)
    Lx_grid_U = np.mean(dx_U, axis=-1)

    # rows with no wet points: remember them explicitly instead of relying on
    # a 1.0 sentinel, so a genuine value of 1.0 can never be blanked
    wall_T = Lx_grid_T == 0
    wall_U = Lx_grid_U == 0
    Lx_grid_T[wall_T] = 1.0  # don't divide by zero
    Lx_grid_U[wall_U] = 1.0  # don't divide by zero

    t_zonal_avg = np.mean(toce * dx_T, axis=-1) / Lx_grid_T
    t_zonal_avg[wall_T] = np.nan  # don't plot out walls

    u_zonal_avg = np.mean(uoce * dx_U, axis=-1) / Lx_grid_U
    u_zonal_avg[wall_U] = np.nan  # don't plot out walls
    
    if i == 0:
        # keep the reference run's grid and temperature for the overlays
        lat_ref = lat
        z_ref = z
        t_zonal_avg_ref = t_zonal_avg.copy()

    # plots
    
    ax = plt.subplot(1, 4, i+1)

    ax.contourf(lat[1:-1, 0], z, u_zonal_avg[:, 1:-1], levels=u_levels, extend="both")
    lines = ax.contour(lat[1:-1, 0], z, t_zonal_avg[:, 1:-1], colors='k', levels=t_levels)
    if i != 0:
        ax.set_yticklabels([])
    ax.set_xlim([0, 2200])
    ax.set_ylim([-2000, 0])
    ax.clabel(lines)
    if i != 0:
        lines = ax.contour(lat_ref[1:-1, 0], z_ref, t_zonal_avg_ref[:, 1:-1], levels=t_levels, 
                           colors='g', alpha=0.7)
        ax.clabel(lines)
    
    # y = 50 is above the top of the y-range, so the tag sits over the panel
    ax.text(30, 50, f"({label_str[i]})")
    # raw string: "\ " and "\m" are not valid escapes in a normal literal
    ax.set_xlabel(r"$y\ (\mathrm{km})$")
    
    if i == 0:
        ax.set_ylabel(r"$z\ (\mathrm{m})$")
    
# create colourbar axes by grabbing the bounding boxes of the individual plot axes
# (here: a thin strip just right of the last panel)
pos = ax.get_position()
ax_cb = fig.add_axes([pos.x1 + 0.01,
                      pos.y0,
                      0.005,
                      pos.y1 - pos.y0])
norm = Normalize(vmin = -0.15, vmax = 0.15)
cb = ColorbarBase(ax_cb, cmap = plt.cm.RdBu_r, norm = norm)
cb.set_ticks([-0.15, 0.0, 0.15])

fig.savefig("./partial_res_figs/zonal_mean_states.pdf", bbox_inches="tight")

In [None]:
# take R010 kgm = 0 to be model truth (R100 -> R025 no splitting for comparison)

t_levels = np.arange(0, 10+1, 2)
u_levels = np.linspace(-0.15, 0.15, 21)

fig = plt.figure(figsize=(12, 7))

# the first entry is the reference run: its temperature contours are overlaid
# (in green) on every other panel
exp_list = ["EXP_R010/no_split/gm0000/tau100x",
            "EXP_R100/alp0060_lam80/tau100x",
            "EXP_R025/no_split/gm0000/tau100x",
            "EXP_R025/no_split/alp0060_lam80/tau100x",
            "EXP_R025/split_100km/alp0060_lam80/tau100x"
            ]

ax_list = []

for i, exp in enumerate(exp_list):
    
    # load data (files are closed again as soon as the arrays are in memory)
    data_dir = f"../../UNAGI/{exp}/OUTPUTS/"
    # mesh_mask.nc sits at the EXP_Rxxx root: one "../" per directory level of
    # `exp` plus one for OUTPUTS (replaces the former hard-wired i == 1 case)
    mesh_dir = data_dir + "../" * (exp.count("/") + 1)

    # define some metrics of relevance
    with xr.open_dataset(mesh_dir + "mesh_mask.nc") as mesh_mask:
        e1t = mesh_mask["e1t"][0, :, :].values
        tmask = mesh_mask["tmask"][0, :, :, :].values
        e1u = mesh_mask["e1u"][0, :, :].values
        umask = mesh_mask["umask"][0, :, :, :].values

    with xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_T.nc") as data_T:
        lat = data_T["nav_lat"].values
        z = -data_T["deptht"].values
        toce = data_T["toce"][0, :, :, :].values

    with xr.open_dataset(data_dir + "UNAGI_10y_08010101_08101230_grid_U.nc") as data_U:
        uoce = data_U["uoce"][0, :, :, :].values

    # masked zonal grid spacing and its mean along the last axis
    dx_T = e1t[np.newaxis, :, :] * tmask
    dx_U = e1u[np.newaxis, :, :] * umask
    Lx_grid_T = np.mean(dx_T, axis=-1)
    Lx_grid_U = np.mean(dx_U, axis=-1)

    # rows with no wet points: remember them explicitly instead of relying on
    # a 1.0 sentinel, so a genuine value of 1.0 can never be blanked
    wall_T = Lx_grid_T == 0
    wall_U = Lx_grid_U == 0
    Lx_grid_T[wall_T] = 1.0  # don't divide by zero
    Lx_grid_U[wall_U] = 1.0  # don't divide by zero

    t_zonal_avg = np.mean(toce * dx_T, axis=-1) / Lx_grid_T
    t_zonal_avg[wall_T] = np.nan  # don't plot out walls

    u_zonal_avg = np.mean(uoce * dx_U, axis=-1) / Lx_grid_U
    u_zonal_avg[wall_U] = np.nan  # don't plot out walls
    
    if i == 0:
        # keep the reference run's grid and temperature for the overlays
        lat_ref = lat
        z_ref = z
        t_zonal_avg_ref = t_zonal_avg.copy()

    # plots (slot 1 of the 2x3 grid is left empty; panels fill slots 2-6)
    
    ax = plt.subplot(2, 3, i+2)
    
    ax_list.append(ax)

    ax.contourf(lat[1:-1, 0], z, u_zonal_avg[:, 1:-1], levels=u_levels, extend="both")
    lines = ax.contour(lat[1:-1, 0], z, t_zonal_avg[:, 1:-1], colors='k', levels=t_levels)
    ax.set_xlim([0, 2200])
    ax.set_ylim([-2000, 0])
    ax.clabel(lines)
    if i != 0:
        lines = ax.contour(lat_ref[1:-1, 0], z_ref, t_zonal_avg_ref[:, 1:-1], levels=t_levels, 
                           colors='g', alpha=0.7)
        ax.clabel(lines)
    
    ax.text(30, -150, f"({label_str[i]})")
    
# create colourbar axes by grabbing the bounding boxes of the individual plot axes
# (spans from the bottom of the last panel to the top of the upper row; taken
# before the upper-row panels are shifted below)
pos_bot = ax_list[-1].get_position()
pos_top = ax_list[1].get_position()
ax_cb = fig.add_axes([pos_bot.x1 + 0.01,
                      pos_bot.y0,
                      0.005,
                      pos_top.y1 - pos_bot.y0 + 0.01])
norm = Normalize(vmin = -0.15, vmax = 0.15)
cb = ColorbarBase(ax_cb, cmap = plt.cm.RdBu_r, norm = norm)
cb.set_ticks([-0.15, 0.0, 0.15])

# axes label and position formatting

offset = 0.13

panel_titles = ["R010",
                "R100",
                r"R025, $\kappa_{\rm gm} = 0$",
                r"R025, GEOM, $L=0$ km",
                r"R025, GEOM, $L=100$ km"]
for panel, title in zip(ax_list, panel_titles):
    panel.set_title(title, fontsize=12)

# depth axis: the plotted range is [-2000, 0], i.e. metres (the label used to
# read km); only the left-most panel of each row keeps its label and ticks
for panel in (ax_list[0], ax_list[2]):
    panel.set_ylabel(r"$z\ (\mathrm{m})$")
for panel in (ax_list[1], ax_list[3], ax_list[4]):
    panel.set_yticklabels([])

# only the bottom-middle panel carries the x label
ax_list[3].set_xlabel(r"$y\ (\mathrm{km})$")

# shift the two upper-row panels to the left (and up by 0.01)
for panel in ax_list[:2]:
    pos = panel.get_position()
    panel.set_position([pos.x0-offset, pos.y0+0.01, pos.x1-pos.x0, pos.y1-pos.y0])

# NOTE(review): the 1x4 version of this figure earlier in the notebook saves to
# the same path, so whichever cell runs last overwrites the other -- confirm
# that is intended.
fig.savefig("./partial_res_figs/zonal_mean_states.pdf", bbox_inches="tight")

-----------------------

## 5) Misc schematic

In [None]:
# schematic of ideal energy decomposition

width = 0.2

fig = plt.figure(figsize=(10, 3))

# per panel: tag and the explicit-energy fraction at the three bar positions;
# a fraction of 1 is drawn as a single full "explicit" bar with nothing on top
bar_pos = (0, 0.5, 1.0)
panels = [("(a)", (0.1, 0.5, 1)),
          ("(b)", (0.1, 0.4, 0.6))]

for k, (tag, fractions) in enumerate(panels):

    ax = plt.subplot(1, 2, k+1)

    for xpos, percentage in zip(bar_pos, fractions):
        # only the first bar of a panel carries the legend labels
        if xpos == 0:
            lab_explicit = {"label": r"$E_{\rm explicit}$"}
            lab_param = {"label": r"$E_{\rm param}$"}
        else:
            lab_explicit, lab_param = {}, {}

        ax.bar(xpos, percentage, width, color="C1", **lab_explicit)
        if percentage != 1:
            ax.bar(xpos, 1-percentage, width, color="C2", bottom=percentage, **lab_param)

    # dotted line at the common total of 1
    ax.plot([-0.5, 1.5], [1.0, 1.0], 'k:', alpha=0.8)
    ax.set_xlim([-0.3, 1.3])
    ax.set_ylim([0, 1.15])

    ax.set_xticks([0, 0.5, 1.0])
    ax.set_xticklabels(["coarse\n resolution", "eddy\n permitting", "eddy\n resolving"])
    ax.set_yticks([])
    if k == 0:
        ax.set_ylabel(r"$E_{\rm tot} = E_{\rm explicit} + E_{\rm param}$")
    ax.text(-0.28, 1.05, tag)

# legend is attached to the right-hand panel and anchored to its left
ax.legend(loc=1, bbox_to_anchor=(-0.41, 0.4, 0.5, 0.5))

plt.savefig("./partial_res_figs/ene_decomp_schematic.pdf", bbox_inches="tight")