# Decompose the elevation errors into phase and amplitude components

In [20]:
%matplotlib inline

import datetime, time
import numpy as np
import xesmf as xe
import xarray as xr
import netCDF4 as nc
import cmocean as cm
import matplotlib.ticker
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from matplotlib.colors import LogNorm
import matplotlib.gridspec as gridspec

import warnings
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides the RuntimeWarnings numpy emits for
# np.nanmean over all-NaN (masked) cells and any library deprecation
# warnings; consider narrowing with category=... if something is expected.
warnings.filterwarnings('ignore')

In [21]:
# Inputs, in order:
#   mom6 - MOM6 amplitude/phase fields (read below as .amp / .phase)
#   mmgd - MOM6 static grid file (geolon, geolat, deptho, areacello)
#   data - MOM6 model output (interface heights .e)
#   tpxo - TPXO9 elevation constituents (ha, hp)
#   tpgd - TPXO9 grid (lon_z, lat_z, hz)
# NOTE(review): the static file comes from a different run directory
# (tides_025_SAL_JSL_masked) than the model output (tides_025_SAL_JSL_x6);
# this assumes both runs share the same horizontal grid - confirm.
mom6, mmgd, data, tpxo, tpgd = (
    xr.open_dataset(path)
    for path in (
        '/g/data/nm03/lxy581/evaluate/amp_phase/tides_025_JSL_x6_global.nc',
        '/g/data/nm03/lxy581/archive/tides_025_SAL_JSL_masked/output003/ocean_static.nc',
        '/g/data/nm03/lxy581/archive/tides_025_SAL_JSL_x6/output002/ocean_interior.nc',
        '/g/data/nm03/TPXO/h_tpxo9.v1.nc',
        '/g/data/nm03/TPXO/grid_tpxo9.nc',
    )
)

### Read tidal phase and amplitude (MOM6 and TPXO)

In [22]:
# MOM6 harmonic fields, used as read.
amp_mom6 = mom6.amp
pha_mom6 = mom6.phase

# TPXO: constituent index 0 along 'nc' (treated as M2 later via omega_m2),
# transposed so the dimension order is reversed to (ny, nx).
amp_tpxo = tpxo.ha.isel(nc=0).transpose()
# hp is converted from degrees to radians and shifted by -pi, then offset by 13*pi/16.
# NOTE(review): 13*pi/16 is an unexplained constant, presumably aligning TPXO's
# phase reference with the model's time origin - confirm and document its source.
pha_tpxo = (tpxo.hp.isel(nc=0).transpose() - 180) / 180 * np.pi + 13 * np.pi / 16

### Read grid

In [23]:
# MOM6 static grid: cell-centre lon/lat, ocean depth, cell area and index coords.
geolon, geolat = mmgd.geolon, mmgd.geolat
depth, area = mmgd.deptho, mmgd.areacello
yh, xh = mmgd.yh, mmgd.xh

# TPXO grid fields, transposed (dimension order reversed) to match the
# ('ny', 'nx') layout used when building the TPXO dataset for regridding.
lon_tpxo, lat_tpxo, hz_tpxo = (tpgd[key].transpose() for key in ('lon_z', 'lat_z', 'hz'))
ny, nx = tpgd.ny, tpgd.nx

# Analysis mask: 1 where depth > 1000 (presumably metres) and |lat| < 75, NaN elsewhere.
fac_dep = xr.where(depth > 1000, 1, np.nan)
fac_lat = xr.where(abs(geolat) < 75, 1, np.nan)
fac = np.asarray(fac_dep) * np.asarray(fac_lat)

In [24]:
# Model bathymetry from the run's input topography; zero depth is treated as land.
# NOTE(review): this rebinds `depth` (previously mmgd.deptho); fac_dep was
# already computed from the static-file depth in the previous cell.
topog = xr.open_dataset('/home/581/lxy581/tidal_param/MOM6-examples/ocean_only/tides_025/INPUT/ocean_topog.nc')
depth = topog.depth
depth_z = np.array(depth)  # independent copy, safe to modify in place
depth_z[depth_z == 0] = np.nan

grid_coords = {'lon': (('yh', 'xh'), np.array(geolon)),
               'lat': (('yh', 'xh'), np.array(geolat))}
depth_da = xr.Dataset(data_vars={'depth_xr': (('yh', 'xh'), depth_z)}, coords=grid_coords)

# land is 1 on land cells (NaN depth) and NaN over the ocean; used for the
# grey land fill and the coastline contour in the maps.
land = xr.where(np.isnan(depth_da.depth_xr.rename('land')), 1, np.nan)

### Interpolate the TPXO fields onto the MOM6 grid

In [None]:
# Regrid TPXO amplitude, phase and depth onto the MOM6 (yh, xh) grid.
#
# Bilinearly interpolating a phase field directly is wrong wherever the phase
# jumps by 2*pi between neighbouring source points (every tidal phase map has
# such wrap lines, e.g. radiating from amphidromic points). For example,
# averaging phases just below +pi and just above -pi gives ~0 instead of ~pi.
# Instead, interpolate the smooth quadrature components A*cos(phi) and
# A*sin(phi), then rebuild amplitude and phase on the MOM6 grid.
_amp_src = np.array(amp_tpxo)
_pha_src = np.array(pha_tpxo)
ds_tpxo_data = xr.Dataset(data_vars={'hre_tpxo': (('ny','nx'), _amp_src * np.cos(_pha_src)),
                                     'him_tpxo': (('ny','nx'), _amp_src * np.sin(_pha_src)),
                                     'hz_tpxo': (('ny','nx'), np.array(hz_tpxo))},
                          coords={'lon': (('ny', 'nx'), np.array(lon_tpxo)),
                                  'lat': (('ny', 'nx'), np.array(lat_tpxo))})
ds_mom6_grid = xr.Dataset({"lat": (["yh","xh"], np.array(geolat)),
                           "lon": (["yh","xh"], np.array(geolon))})
# periodic=True lets ESMF wrap around in longitude. Without it, MOM6 points
# between TPXO's last and first longitude are unmapped and only filled by the
# inverse-distance extrapolation.
# NOTE(review): assumes the TPXO grid spans all longitudes (TPXO9 is global).
regridder = xe.Regridder(ds_tpxo_data, ds_mom6_grid, "bilinear",
                         periodic=True, extrap_method="inverse_dist")
ds_tpxo_mom6_grid = regridder(ds_tpxo_data)
# Rebuild amplitude and phase under the names used downstream. The phase is
# now in [-pi, pi]; it is only used through cos(), so this agrees with the
# original convention modulo 2*pi.
ds_tpxo_mom6_grid['amp_tpxo'] = np.hypot(ds_tpxo_mom6_grid['hre_tpxo'], ds_tpxo_mom6_grid['him_tpxo'])
ds_tpxo_mom6_grid['pha_tpxo'] = np.arctan2(ds_tpxo_mom6_grid['him_tpxo'], ds_tpxo_mom6_grid['hre_tpxo'])

### Calculate TPXO tidal elevation

In [None]:
# Days in each month of a non-leap year (Jan..Dec), with a leading 0 so that
# days_accum[k] is the number of days elapsed before month k+1 begins
# (e.g. days_accum[2] = 59 -> 1 March).
days = np.array([0] + [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
days_accum = days.cumsum()
print(days_accum)

In [None]:
# Hourly time axis covering ~19 M2 cycles (19 * 12.42 h ~= 236 h).
t_19cyc = np.arange(236)
nt = t_19cyc.size
# M2 angular frequency in rad/hour. A scalar broadcasts against the
# (nt, ny, nx) expressions exactly like a constant-filled (nt, ny, nx) array
# would, without allocating nt*ny*nx floats (~3 GB on a 1/4-degree grid).
omega_m2 = 2*np.pi/12.4206014
# Time offset in hours: days_accum[2] = 59 days, i.e. 1 March of a non-leap year.
# NOTE(review): presumably the start time of the model output window - confirm.
t_offset_hours = days_accum[2]*24
# TPXO elevation time series A*cos(omega*t - phi) on the MOM6 grid, shape (nt, ny, nx).
elev_tpxo = np.array(ds_tpxo_mom6_grid['amp_tpxo'])[None,:,:] * np.cos(omega_m2 * (t_19cyc[:,None,None] + t_offset_hours) - np.array(ds_tpxo_mom6_grid['pha_tpxo'])[None,:,:])

In [None]:
# Model elevation (first interface of `e`, presumably the free surface) over the
# same nt hourly steps as t_19cyc, masked with fac (deep ocean, |lat| < 75).
# Using nt instead of a second hard-coded 236 keeps both series the same length.
elev_mom6 = np.array(data.e.isel({'time': np.arange(nt), 'zi': 0})) * fac[None,:,:]
# The same signal rebuilt from the MOM6 harmonic amplitude/phase, as a check on the fit.
elev_mom6_recon = np.array(amp_mom6)[None,:,:] * np.cos(omega_m2 * (t_19cyc[:,None,None]+days_accum[2]*24) - np.array(pha_mom6)[None,:,:]) * fac[None,:,:]

In [None]:
# Sanity check: NaN-ignoring means of the raw and reconstructed elevation.
for _series in (elev_mom6, elev_mom6_recon):
    print(np.nanmean(_series))

In [None]:
# Spot check at one grid point: raw model elevation (black dashed) against the
# harmonic reconstruction (translucent red).
yind, xind = 600, 1000
point_raw = elev_mom6[:, yind, xind]
point_fit = elev_mom6_recon[:, yind, xind]
plt.plot(point_raw, 'k--', linewidth=2)
plt.plot(point_fit, 'r', linewidth=2, alpha=0.3)

### Calculate the RMS elevation error over the selected 236-hour (~19 M2 cycles) period

In [None]:
# Per-cell RMS (over time) of model minus TPXO elevation, masked with fac.
# fac is exactly 1 or NaN, so masking the difference equals masking each term.
elev_err = np.sqrt(np.nanmean(((elev_mom6 - elev_tpxo) * fac) ** 2, axis=0))

In [None]:
# For two harmonics A1*cos(wt - p1) and A2*cos(wt - p2), the mean squared
# difference over a full cycle splits into two non-negative terms:
#     RMS^2 = 0.5*(A1 - A2)^2 + A1*A2*(1 - cos(p1 - p2))
# ele_err_amp is the square root of the amplitude term and ele_err_pha of the
# phase term (MOM6 vs TPXO), both masked with fac.
_amp_m = np.array(amp_mom6)
_amp_t = np.array(ds_tpxo_mom6_grid['amp_tpxo'])
_dpha = np.array(pha_mom6) - np.array(ds_tpxo_mom6_grid['pha_tpxo'])
ele_err_amp = np.sqrt(0.5 * (_amp_m - _amp_t) ** 2) * fac
ele_err_pha = np.sqrt(_amp_m * _amp_t * (1 - np.cos(_dpha))) * fac

### Plot the total, amplitude-induced, and phase-induced elevation errors

In [None]:
# Area-weighted RMS of each error map over the analysed ocean, in cm (fields are in m).
# Only cells where all three error maps and the cell area are finite enter the
# numerator or the denominator. Dividing by np.nansum(area) over *all* cells
# would also count the NaN-masked cells (depth <= 1000, |lat| >= 75, land),
# biasing every RMSE low. A single common mask also lets
# a**2 - b**2 - c**2 compare like with like.
_area = np.asarray(area, dtype=float)
_valid = (np.isfinite(elev_err) & np.isfinite(ele_err_amp)
          & np.isfinite(ele_err_pha) & np.isfinite(_area))
_weights = np.where(_valid, _area, 0.0)


def _area_rms_cm(err):
    """Return the area-weighted RMS of *err* (m) over the common valid mask, in cm."""
    err_sq = np.where(_valid, err, 0.0) ** 2
    return np.sqrt(np.sum(err_sq * _weights) / np.sum(_weights)) * 100


a = _area_rms_cm(elev_err)     # total
b = _area_rms_cm(ele_err_amp)  # amplitude-induced
c = _area_rms_cm(ele_err_pha)  # phase-induced
print(a)
# Part of the total not explained by the amplitude/phase terms (cm^2). elev_err
# uses the raw model elevation while b and c use the harmonic amp/phase, so this
# presumably reflects non-M2 variability and harmonic-fit error.
print(a**2 - b**2 - c**2)

In [None]:
# Three stacked Robinson maps: total, amplitude-induced and phase-induced M2
# elevation error, each with its own colourbar and area-weighted RMSE in the title.
cb_text = ['elevation error (cm)', 'elevation error (cm)', 'elevation error (cm)']
title   = ['Total \n Area-weighted RMSE: %.1f cm' % a,'Amplitude \n Area-weighted RMSE: %.4f cm' % b, 'Phase \n Area-weighted RMSE: %.4f cm' % c]

# Error fields are in metres; colourbar ticks are labelled in centimetres.
tick_locs = np.array([0, 0.05, 0.10, 0.15, 0.20, 0.25])
tick_labels = ["0", "5", "10", "15", "20", "25"]

# One entry per panel: (field, progress message, map-axes bottom, colourbar bottom),
# with bottoms given as figure fractions. This replaces three copy-pasted
# branches and the globals()[name] lookup.
panels = [
    (elev_err,    'Plotting error...', 0.7, 0.75),
    (ele_err_amp, 'Plotting amp...',   0.4, 0.45),
    (ele_err_pha, 'Plotting phase...', 0.1, 0.15),
]

data_crs = ccrs.PlateCarree()
coast = land.fillna(0)  # 0/1 field whose contour traces the model coastline

fig = plt.figure(figsize=(8, 14))
axs = []

print('Start plotting...')
for I, (field, msg, ax_bottom, cb_bottom) in enumerate(panels):
    print(I)
    # Create each map axes directly at its final position. The original also
    # created placeholder gridspec axes that were never drawn on.
    ax = fig.add_axes([0.1, ax_bottom, 0.75, 0.25],
                      projection=ccrs.Robinson(central_longitude=-100))
    axs.append(ax)
    # Model land mask (grey fill) and model coastline (black contour).
    land.plot.contourf(ax=ax, x='lon', y='lat', colors='darkgrey', zorder=2,
                       transform=data_crs, add_colorbar=False)
    coast.plot.contour(ax=ax, x='lon', y='lat', colors='k', levels=[0, 1],
                       transform=data_crs, add_colorbar=False, linewidths=2)

    print(msg)
    p1 = ax.pcolormesh(geolon, geolat, field, transform=data_crs,
                       cmap=cm.cm.dense, vmin=0, vmax=0.25)
    ax_cb = fig.add_axes([0.9, cb_bottom, 0.008, 0.15])
    # The errors are square roots and so non-negative; only the top can overflow.
    cb1 = fig.colorbar(p1, cax=ax_cb, orientation='vertical', extend='max')
    cb1.ax.set_ylabel(cb_text[I], fontsize=20, rotation=270, labelpad=25)
    cb1.ax.tick_params(labelsize=16)
    # set_ticks/set_ticklabels take effect immediately on all matplotlib
    # versions. Assigning .locator/.formatter needs update_ticks() before 3.5.
    cb1.set_ticks(tick_locs)
    cb1.set_ticklabels(tick_labels)
    ax.set_title(title[I], fontsize=24)

print('Saving...')
plt.savefig('/g/data/nm03/lxy581/evaluate/amp_phase/MOM6_JSL_x6_M2_decompose_amp_pha.png', dpi=300, bbox_inches='tight')
# plt.savefig('/g/data/nm03/lxy581/evaluate/amp_phase/MOM6_JSL_M2_masked_x8_006_decompose_amp_pha.png', dpi=300, bbox_inches='tight')