# **oceanliner**: observing system simulation experiments (OSSEs) for SWOT in situ campaigns

This notebook enables a user to select a region from one of the 15 [Adopt-a-Crossover](https://www.swot-adac.org/) sites and specify a sampling pattern (e.g., the path of an ocean glider, ship-based underway CTD, Wave Glider, Saildrone, or mooring), and then does the following:
* download llc4320 data to the local machine
* compute derived fields: steric height, buoyancy frequency squared (N2), and vertical (relative) vorticity
* interpolate model fields in space/time over the trajectory
* return a set of subsampled variables on a regular grid
* store and plot outputs, including "true" and "subsampled" steric height 

### User inputs:
* **RegionName**: name of the region (corresponding to the filename). Options are `WesternMed`, `ROAM_MIZ`, `NewCaledonia`, `NWPacific`, `BassStrait`, `RockallTrough`, `ACC_SMST`, `MarmaraSea`, `LabradorSea`, `CapeBasin`, `Boknis`, `GotlandBasin`, `NWAustralia`, `WestAtlantic`, `Yongala` ([link to PO.DAAC data listings](https://podaac.jpl.nasa.gov/datasetlist?ids=Processing+Levels&values=4+-+Gridded+Model+Output&search=Pre-SWOT+llc4320&view=list&provider=))
* **start_date**, **ndays**: specify date range as start date & number of days.
* **PLATFORM**: simulated platform with which to sample the model: glider track (`glider`), shipboard underway CTD (`uctd`), Wave Glider (`wave_glider`), Saildrone (`saildrone`), mooring (`mooring`), or a user-specified trajectory (`trajectory_file`), in which case a netCDF file **trajectory_file** must also be specified. Specifying a sampling platform applies realistic default values for platform speed and depth range.
* **PATTERN**: survey pattern, either `lawnmower` or `back-forth` (used only when no waypoints are specified)
* **datadir** : directory where data files are stored



### Future developments to include:
* implement other sampling platforms that interact with model current fields (e.g., drifters)
* adapt for the AWS cloud
* compute other variables of interest
* import simulated SWOT data at the same location
* pipe output into optimal interpolation software
* implement other models (including biogeochemical model)
* efficiency improvements






In [None]:
## Imports

# Native packages
from math import radians, degrees, sin, cos, asin, acos, sqrt
import datetime
import time
import sys
import os
import warnings

# Third-party packages for data manipulation
import numpy as np
import pandas as pd
import xarray as xr

# Other third-party packages
import netCDF4 as nc4

# Third-party packages for data interpolation
from scipy import interpolate
from scipy.interpolate import griddata
from xgcm import Grid
import xgcm.grid

# Third-party packages for data visualizations
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import axes3d
import matplotlib.dates as mdates

# osse tools package
# del sys.modules['oceanliner_functions']  # uncomment if troubleshooting oceanliner_functions
from oceanliner_functions import download_llc4320_data, compute_derived_fields, get_survey_track, survey_interp, rotate_vector_to_EN


In [None]:
# --------------------------------------------------------------------
# USER INPUTS: 
# --------------------------------------------------------------------
# *** users should only have to modify the variables in this cell ***
# --------------------------------------------------------------------


# ------------ region and time span

# RegionName: region with MITgcm llc4320 data on the PO.DAAC 
#   options are:
#     WesternMed  ROAM_MIZ  NewCaledonia  NWPacific  BassStrait  RockallTrough  ACC_SMST
#     MarmaraSea  LabradorSea  CapeBasin Boknis GotlandBasin NWAustralia WestAtlantic Yongala
RegionName = 'WesternMed' 

# start_date:  first date to subsample the model. 
#   MITgcm llc4320 data are from 2011-Sep-13 to 2012-Nov-15
start_date = datetime.date(2012,1,1)
# ndays: number of days over which to subsample the model (starting at start_date) 
#   note: ndays must be >1 
#   note: a large number of ndays may crash smaller machines
ndays = 5 
# --------------------------------------------------------------------


# ------------ directories -------------------------------------------

# directory where data files are stored (set to a local path on your machine)
# datadir = '/data3/llc4320/' + RegionName + '/'    # alternative location
datadir = '/Users/kdrushka/data/adac/mitgcm/netcdf/' + RegionName + '/'
# directory where generated output files will be stored
outputdir = '/data3/adac/osse_output/' + RegionName + '/'
# directory where figures will be stored
figdir = '/data2/Dropbox/projects/adac/figures/' + RegionName + '/'
# --------------------------------------------------------------------




# ------------ sampling parameters ------------------------------------

# PLATFORM: which instrument to sample with:
#   options are glider, uctd, mooring, wave_glider, saildrone, waypoints, or trajectory_file
PLATFORM = 'glider' 
# -----  if *only* PLATFORM is specified, reasonable defaults will be selected for the following parameters:

# WAYPOINTS:
#  specify either a list of lon,lat waypoints as {'x':[151, 153], 'y':[-55, -56]}
#    note: for a mooring, specify a single value for x and y
#    note: the x and y lists must have the same length
#  or, a filename (string) of a netCDF file specifying the waypoints as {'waypoint_file':'filename.nc'} (EXAMPLE NEEDED)
#  or, None - in which case a simple pattern within the domain will be selected based on "PATTERN" (useful for testing/demo)
#  ** note, if the waypoints are outside of the domain, get_survey_track will raise an error (TODO: add warning earlier)
WAYPOINTS = {'x':[3, 3], 'y':[38, 38.2]}

# PATTERN: if WAYPOINTS=None, generate waypoints using this pattern:
#   options are back-forth (repeated back+forth between two waypoints) or lawnmower (radiator survey)
#   if waypoints are set, the value of PATTERN is irrelevant 
PATTERN = 'lawnmower' # back-forth or lawnmower 

# AT_END: what to do after the end of the list of waypoints is reached.
#   options are reverse (follow the track back to the start), repeat (go straight back to start and then repeat) or terminate (stop) 
AT_END = 'reverse'

# zrange: depth range of measurements as a 2-element list. 
#   * note: if depth values are negative (as in MITgcm), zrange should be negative
zrange = [-1, -1000]

# z_res: vertical sampling resolution in m
z_res = 1

# zmooring: for PLATFORM=mooring, 
#   specify the depth of the instruments
#   (default x/y is the center of the domain)
zmooring = [-1, -10, -50, -500, -1000]
      
# hspeed: horizontal speed of the platform (in m/s) 
hspeed = 0.25

# vspeed: vertical speed of the platform through the water (in m/s)
#   for glider, this is the speed at which the glider dives
#   for uctd, this is the fall-rate of the instrument
#   for mooring, wave_glider and saildrone vspeed is irrelevant
vspeed = 0.1


# DERIVED_VARIABLES: which "derived variables" to compute: any of 'steric_height', 'N2', and 'vorticity'
#   deriving and saving these variables takes significant memory and time, so None is recommended
#   if memory is limited
DERIVED_VARIABLES = {'steric_height', 'N2', 'vorticity'} # or, None

# --------------------------------------------------------------------




# ------------ save flags --------------------------------------------

# SAVE_FIGURES: if True, save some basic figures
SAVE_FIGURES = False 

# SAVE_PRELIMINARY: if True, save preliminary along-track data. 
#  This takes more space but is faster and less prone to crashing than the gridding step
SAVE_PRELIMINARY = False 
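To get a feel for what `zrange`, `vspeed` and `hspeed` imply for a profiling platform, the arithmetic for one down-and-up cycle can be sketched as follows. This is a rough estimate that assumes the platform descends and ascends at the same `vspeed` over the full `zrange` while moving horizontally at `hspeed` the whole time; `dive_cycle` is a hypothetical helper, and the defaults that `get_survey_track` actually applies may differ (for example, a uCTD is usually recovered much faster than it falls).

```python
def dive_cycle(zrange, vspeed, hspeed):
    """Rough timing of one down-and-up profiling cycle (hypothetical helper).

    Assumes equal descent/ascent speed `vspeed` (m/s) over the full `zrange` (m)
    and constant horizontal speed `hspeed` (m/s) throughout.
    Returns (one-way time [s], full cycle time [s], horizontal distance per cycle [m]).
    """
    depth_span = abs(zrange[1] - zrange[0])   # vertical distance covered on one leg, m
    t_one_way = depth_span / vspeed           # time for one descent (or ascent), s
    t_cycle = 2.0 * t_one_way                 # descent + ascent
    return t_one_way, t_cycle, hspeed * t_cycle


# With the values used in the cell above (zrange=[-1, -1000], vspeed=0.1, hspeed=0.25):
t_one_way, t_cycle, dx_cycle = dive_cycle([-1, -1000], vspeed=0.1, hspeed=0.25)
print(t_one_way / 3600, t_cycle / 3600, dx_cycle / 1000)  # ≈ 2.8 h, 5.6 h, 5.0 km
```

Under these assumptions, consecutive profile pairs are roughly 5 km apart, which sets the horizontal resolution of the subsampled fields.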


In [None]:
# store the user-specified details in "sampling_details" variable
sampling_details = {
    'PLATFORM' : PLATFORM,
    'WAYPOINTS' : WAYPOINTS,
    'PATTERN' : PATTERN,
    'zrange' : zrange,  
    'z_res' : z_res,
    'hspeed' : hspeed,
    'vspeed' : vspeed,
    'AT_END' : AT_END,
    'zmooring' : zmooring,
    'DERIVED_VARIABLES' : DERIVED_VARIABLES,
    'SAVE_PRELIMINARY' : SAVE_PRELIMINARY,
    'start_date' : start_date,
    'ndays' : ndays
}


#### Download & load model data and derived fields

Model data come from [LLC4320](https://data.nas.nasa.gov/viz/vizdata/llc4320/index.html), the 1/48-degree global MITgcm simulation produced by the ECCO project. Regional cut-outs of the simulation for the regions listed above are available on the [PO.DAAC](https://podaac.jpl.nasa.gov/datasetlist?ids=Processing+Levels&values=4+-+Gridded+Model+Output&search=Pre-SWOT+llc4320&view=list&provider=); these 4x4-degree domains are small enough to enable fairly easy downloading and processing.

In [None]:
# download files:
download_llc4320_data(RegionName, datadir, start_date, ndays)

# derive & save new files with the requested derived fields (steric height, N2, and/or vorticity)
if sampling_details['DERIVED_VARIABLES']:
    compute_derived_fields(RegionName, datadir, start_date, ndays, sampling_details['DERIVED_VARIABLES'])
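Following the code comments and the `steric_height` attributes in the derived-fields code, steric height at model level $k$ is a cumulative sum from the surface down, $\eta_k = \sum_{m=1}^{k} \rho_{\mathrm{ref}}(z_m)\,\delta_m\,\Delta z_m$, where $\delta$ is the specific volume anomaly (from gsw `specvol_anom_standard`), $\rho_{\mathrm{ref}}$ is the Argo reference density interpolated to the model levels, and $\Delta z$ is the layer thickness. The full-depth value $\eta_K$ is what is stored as `steric_height_true`. A minimal single-profile sketch of that sum, with made-up input numbers:

```python
import numpy as np


def steric_height_profile(rho_ref, sva, dz):
    """Cumulative steric height (m) for one water column, surface first.

    rho_ref: reference density at each level (kg m-3)
    sva:     specific volume anomaly at each level (m3 kg-1)
    dz:      positive layer thickness at each level (m)
    """
    rho_ref, sva, dz = (np.asarray(a, dtype=float) for a in (rho_ref, sva, dz))
    contrib = rho_ref * sva * dz      # units: kg m-3 * m3 kg-1 * m = m
    sh = np.cumsum(contrib)           # sh[k] integrates levels 0..k (like .cumsum(dim='k'))
    return sh, sh[-1]                 # per-level profile and full-depth total


sh, total = steric_height_profile([1025.0, 1026.0, 1027.0], [2e-6, 1e-6, 0.0], [10.0, 20.0, 40.0])
```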

In [288]:
# def compute_derived_fields_TEST(RegionName, datadir, start_date, ndays, DERIVED_VARIABLES):
#     """ Check for derived files in {datadir}/derived and compute if the files don't exist


#     Args:
#         RegionName (str): It can be selected from the list of regions with pre-SWOT llc4320 data
#         datadir (str): Directory where input model files are stored
#         start_date (datetime): Starting date for computing derived fields
#         ndays (int): Number of days from the start date to compute derived fields
#         DERIVED_VARIABLES (str list): specifies which variables to derive (steric_height, N2 and/or vorticity)

#     Returns:
#         None
        
#     Raises: 
#         TBD: TBD

#     """
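# import only the names used below: `from datetime import datetime, time, ...` would shadow
# the `datetime` and `time` modules imported at the top of the notebook and break later cells
from datetime import timedelta
import requests
import gsw as sw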
# directory to save derived data to - create if doesn't exist
derivedir = datadir + 'derived/'
if not(os.path.isdir(derivedir)):
    os.mkdir(derivedir)

# files to load:
date_list = [start_date + timedelta(days=x) for x in range(ndays)]
target_files = [f'{datadir}LLC4320_pre-SWOT_{RegionName}_{date_list[n].strftime("%Y%m%d")}.nc' for n in range(ndays)] # list target files

# list of derived files:
derived_files = [f'{derivedir}LLC4320_pre-SWOT_{RegionName}_derived-fields_{date_list[n].strftime("%Y%m%d")}.nc' for n in range(ndays)] # list target files


# loop through input files, then do the following:
# - compute steric height
# - compute N2
# - interpolate vector quantities (velocity and wind) to the tracer grid
# - compute vorticity
fis = range(len(target_files))




DERIVED_VARIABLES = 'N2'  # TESTING ONLY: overrides the user input set above; remove when finished
cnt = 0 # count
for fi in fis:
    # input filename:
    thisf=target_files[fi]
    # output filename:
    fnout = thisf.replace(RegionName + '_' , RegionName + '_derived-fields_')
    fnout = fnout.replace(RegionName + '/' , RegionName + '/derived/')
    # check if output file already exists
    if (not(os.path.isfile(fnout))):   
        print(f'computing {DERIVED_VARIABLES} for {thisf}') 
        # load file:
        ds = xr.open_dataset(thisf)

        if (('steric_height' in DERIVED_VARIABLES) or ('N2' in DERIVED_VARIABLES)):
            # mean lat/lon of domain
            xav = ds.XC.isel(j=0).mean(dim='i')
            yav = ds.YC.isel(i=0).mean(dim='j')
            if 'steric_height' in DERIVED_VARIABLES:
                # Steric height calculation requires a reference profile. Get this from Argo.
                # -------
                # first time through the loop, load reference profile:
                # load a single file to get coordinates
                if cnt==0:

                    # for transforming U and V, and for the vorticity calculation, build the xgcm grid:
                    # see https://xgcm.readthedocs.io/en/latest/xgcm-examples/02_mitgcm.html
                    grid = xgcm.Grid(ds, coords={'X':{'center': 'i', 'left': 'i_g'}, 
                                 'Y':{'center': 'j', 'left': 'j_g'},
                                 'T':{'center': 'time'},
                                 'Z':{'center': 'k'}})


                    # --- load reference file of argo data
                    # here we use the 3x3-degree annual-mean Argo product on standard levels, produced by IPRC/APDRC and distributed through ERDDAP
                    # https://apdrc.soest.hawaii.edu/erddap/griddap/hawaii_soest_defb_b79c_cb17.html
                    # - download the profile closest to xav,yav once (quick), use it, then delete it.

                    # URL gets temp & salt at all levels
                    argofile = f'https://apdrc.soest.hawaii.edu/erddap/griddap/hawaii_soest_625d_3b64_cc4d.nc?temp[(0000-12-15T00:00:00Z):1:(0000-12-15T00:00:00Z)][(0.0):1:(2000.0)][({yav.data}):1:({yav.data})][({xav.data}):1:({xav.data})],salt[(0000-12-15T00:00:00Z):1:(0000-12-15T00:00:00Z)][(0.0):1:(2000.0)][({yav.data}):1:({yav.data})][({xav.data}):1:({xav.data})]'

                    # delete the argo file if it exists 
                    if os.path.isfile('argo_local.nc'):
                        os.remove('argo_local.nc')
                    # use requests to get the file, and write it locally:
                    r = requests.get(argofile)
                    r.raise_for_status()  # stop here if the ERDDAP request failed
                    with open('argo_local.nc', 'wb') as f:
                        f.write(r.content)
                    # open the argo file:
                    argods = xr.open_dataset('argo_local.nc',decode_times=False)
                    # get rid of time coord/dim/variable, which screws up the time in ds if it's loaded
                    argods = argods.squeeze().reset_coords(names = {'time'}, drop=True) 
                    # reference profiles: annual average Argo T/S using nearest neighbor
                    Tref = argods["temp"]
                    Sref = argods["salt"]
                    # SA and CT from gsw:
                    # see example from https://discourse.pangeo.io/t/wrapped-for-dask-teos-10-gibbs-seawater-gsw-oceanographic-toolbox/466
                    Pref = xr.apply_ufunc(sw.p_from_z, -argods.LEV, yav)
                    Pref.compute()
                    SAref = xr.apply_ufunc(sw.SA_from_SP, Sref, Pref, xav, yav,
                                           dask='parallelized', output_dtypes=[Sref.dtype])
                    SAref.compute()
                    CTref = xr.apply_ufunc(sw.CT_from_pt, SAref, Tref, # CT_from_pt expects absolute salinity (SAref), not practical salinity
                                           dask='parallelized', output_dtypes=[Sref.dtype])
                    CTref.compute()
                    Dref = xr.apply_ufunc(sw.density.rho, SAref, CTref, Pref,
                                        dask='parallelized', output_dtypes=[Sref.dtype])
                    Dref.compute()


                    cnt = cnt+1
                    print()
                    # end reference profile calculation


            # -------
            # --- COMPUTE STERIC HEIGHT and/or N2 IN STEPS ---
            # 0. create datasets for variables of interest:
            ss = ds.Salt
            tt = ds.Theta
            pp = xr.DataArray(sw.p_from_z(ds.Z,ds.YC))

            # 1. compute absolute salinity and conservative temperature
            sa = xr.apply_ufunc(sw.SA_from_SP, ss, pp, xav, yav, dask='parallelized', output_dtypes=[ss.dtype])
            sa.compute()
            ct = xr.apply_ufunc(sw.CT_from_pt, sa, tt, dask='parallelized', output_dtypes=[ss.dtype])
            ct.compute()

            # compute N2:
            if 'N2' in DERIVED_VARIABLES:
                N2, pmid = sw.Nsquared(sa, ct, pp, lat=None, axis=1) # axis 1 = depth
                # make into a dataset
                N2_ds = N2.to_dataset(name='N2')
                print(N2)
                print(pmid)
                print(N2_ds)


            # compute steric height:
            if 'steric_height' in DERIVED_VARIABLES:
                dd = xr.apply_ufunc(sw.density.rho, sa, ct, pp, dask='parallelized', output_dtypes=[ss.dtype])
                dd.compute()
                # 2. compute specific volume anomaly: gsw.density.specvol_anom_standard(SA, CT, p)
                sva = xr.apply_ufunc(sw.density.specvol_anom_standard, sa, ct, pp, dask='parallelized', output_dtypes=[ss.dtype])
                sva.compute()
                # 3. compute steric height = integral(0:z1) of Dref(z)*sva(z)*dz(z)
                # - first, interpolate Dref to the model pressure levels
                Drefi = Dref.interp(LEV=-ds.Z)
                dz = -ds.Z_bnds.diff(dim='nb').drop_vars('nb').squeeze() # distance between interfaces

                # steric height computation (summation/integral)
                # - increase the size of Drefi and dz to match the size of sva
                Db = Drefi.broadcast_like(sva)
                dzb = dz.broadcast_like(sva)
                dum = Db * sva * dzb
                sh = dum.cumsum(dim='k') 
                # this gives sh as a 3-d variable, (where the depth dimension 
                # represents the deepest level from which the specific volume anomaly was interpolated)
                # - but in reality we just want the SH that was determined by integrating over
                # the full survey depth, which gives a 2-d output:
                sh_true = dum.sum(dim='k') 

                # make into dataset:
                sh_ds = sh.to_dataset(name='steric_height')
                sh_true_ds = sh_true.to_dataset(name='steric_height_true')            
                # add/rename the Argo reference profile variables to dout:
                tref = Tref.to_dataset(name='Tref')
                tref = tref.merge(Sref).rename({'salt': 'Sref'}).\
                    rename({'LEV':'zref','latitude':'yav','longitude':'xav'})

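        if 'vorticity' in DERIVED_VARIABLES:
            # --- COMPUTE VORTICITY using xgcm and interpolate to X, Y
            # see https://xgcm.readthedocs.io/en/latest/xgcm-examples/02_mitgcm.html
            # build the xgcm grid here as well: above, it is only built in the steric_height branch,
            # so `grid` would otherwise be undefined when only vorticity is requested
            grid = xgcm.Grid(ds, coords={'X':{'center': 'i', 'left': 'i_g'},
                                         'Y':{'center': 'j', 'left': 'j_g'},
                                         'T':{'center': 'time'},
                                         'Z':{'center': 'k'}})
            vorticity = (grid.diff(ds.V*ds.DXG, 'X') - grid.diff(ds.U*ds.DYG, 'Y'))/ds.RAZ
            vorticity = grid.interp(grid.interp(vorticity, 'X', boundary='extend'), 'Y', boundary='extend')
            # make into dataset
            v_ds = vorticity.to_dataset(name='vorticity')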




        # ------------------
        # save derived fields in a new file
        if 'steric_height' in DERIVED_VARIABLES:
            dout = sh_ds
            dout = dout.merge(sh_true_ds)
            # - add ref profiles to dout and drop unneeded vars/coords
            dout = dout.merge(tref).drop_vars({'longitude','latitude','LEV'})
            # - add attributes for all variables
            dout.steric_height.attrs = {'long_name' : 'Steric height',
                                    'units' : 'm',
                                    'comments_1' : 'Computed by integrating the specific volume anomaly (SVA) multiplied by a reference density, where the reference density profile is calculated from temperature & salinity profiles from the APDRC 3x3deg gridded Argo climatology product (accessed through ERDDAP). The profile nearest to the center of the domain is selected, and T & S profiles are averaged over one year before computing ref density. SVA is computed from the model T & S profiles. The Gibbs Seawater Toolbox is used to compute reference density and SVA. steric_height is given at all depth levels (dep): steric_height at a given depth represents the steric height signal generated by the water column above that depth, so the deepest steric_height value represents total steric height (and is saved in steric_height_true).'
                                       }
            dout.steric_height_true.attrs = dout.steric_height.attrs
            dout.Tref.attrs = {'long_name' : f'Reference temperature profile at {yav.data}N,{xav.data}E',
                                'units' : 'degree_C',
                                'comments_1' : 'From Argo 3x3 climatology produced by APDRC'}
            dout.Sref.attrs = {'long_name' : f'Reference salinity profile at {yav.data}N,{xav.data}E',
                                    'units' : 'psu',
                                    'comments_1' : 'From Argo 3x3 climatology produced by APDRC'}

            dout.zref.attrs = {'long_name' : f'Reference depth for Tref and Sref',
                                    'units' : 'm',
                                    'comments_1' : 'From Argo 3x3 climatology produced by APDRC'}

            # merge vorticity 
            if 'vorticity' in DERIVED_VARIABLES:  
                dout = dout.merge(v_ds)

        # if we only computed vorticity, dout = v_ds
        elif 'vorticity' in DERIVED_VARIABLES:  
            dout = v_ds


        # if vorticity, add the attrs:
        if 'vorticity' in DERIVED_VARIABLES:  
            dout.vorticity.attrs = {'long_name' : 'Vertical component of the vorticity',
                                'units' : 's-1',
                                'comments_1' : 'computed on DXG,DYG then interpolated to X,Y'}         

        # - save netcdf file with derived fields
        netcdf_fill_value = nc4.default_fillvals['f4']
        dv_encoding = {}
        for dv in dout.data_vars:
            dv_encoding[dv]={'zlib':True,       # turns compression on
                        'complevel':1,          # 1 = fastest, lowest compression; 9 = slowest, highest compression
                        'shuffle':True,         # shuffle filter can significantly improve compression ratios, and is on by default
                        'dtype':'float32',
                        '_FillValue':netcdf_fill_value}
        # save to a new file
        print(' ... saving to ', fnout)
        # TROUBLESHOOTING::::: DELETE THE RETURN LINE
        #return dout, dv_encoding
        dout.to_netcdf(fnout,format='netcdf4',encoding=dv_encoding)



# release Argo file at the end of all files
if 'argods' in locals():
    argods.close()
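The vorticity step in the cell above evaluates $\zeta = \partial v/\partial x - \partial u/\partial y$ at cell corners of MITgcm's Arakawa C-grid with xgcm, then interpolates it to tracer points. As a check on the finite-difference stencil, here is a plain NumPy sketch on a uniform Cartesian C-grid. Constant `dx` and `dy` are an assumption: it ignores the DXG/DYG/RAZ metrics used in the notebook. For solid-body rotation ($u = -\Omega y$, $v = \Omega x$) it should return $2\Omega$ everywhere.

```python
import numpy as np


def corner_vorticity(u, v, dx, dy):
    """Relative vorticity at interior cell corners of a uniform C-grid.

    u[j, i] sits on the west face of tracer cell (j, i); v[j, i] on its south face.
    Corner (j, i) is the south-west corner of that cell. The result holds corners
    with j, i >= 1, so it is one point shorter than the inputs in each direction.
    """
    dvdx = (v[1:, 1:] - v[1:, :-1]) / dx   # v east of corner minus v west of corner
    dudy = (u[1:, 1:] - u[:-1, 1:]) / dy   # u north of corner minus u south of corner
    return dvdx - dudy


ny, nx, dx, dy, omega = 4, 5, 1000.0, 1000.0, 1e-5
j, i = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
u = -omega * (j + 0.5) * dy   # u points sit at y = (j + 1/2) dy
v = omega * (i + 0.5) * dx    # v points sit at x = (i + 1/2) dx
zeta = corner_vorticity(u, v, dx, dy)   # expected: 2 * omega at every corner
```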


In [None]:
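# NOTE: the `def compute_derived_fields_TEST(...)` line in the cell above is commented out, so that
# cell runs its body directly and this function does not exist; uncomment the def (and indent
# the body) before running the two lines below.
# sampling_details['DERIVED_VARIABLES'] = 'N2'
# compute_derived_fields_TEST(RegionName, datadir, start_date, ndays, sampling_details['DERIVED_VARIABLES'])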

In [293]:
N2, pmid = sw.Nsquared(sa, ct, pp, lat=None, axis=1)

In [None]:
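# wrap gsw.Nsquared with xr.apply_ufunc: 'k' is the core dimension (moved to the last axis, hence axis=-1),
# and the outputs live on the k-1 mid-levels, so 'k' is excluded and replaced by a new 'k_mid' dimension
N2, pmid = xr.apply_ufunc(sw.Nsquared, sa, ct, pp,
                          input_core_dims=[['k'], ['k'], ['k']],
                          output_core_dims=[['k_mid'], ['k_mid']],
                          exclude_dims={'k'},
                          kwargs={'lat': None, 'axis': -1})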

                

In [355]:
# can't figure out how to make xr.apply_ufunc(sw.Nsquared, sa, ct, pp, kwargs={"lat": None, "axis":1}, ... ) work, so do it this way
N2, pmid = sw.Nsquared(sa, ct, pp, lat=None, axis=1)

# interpolate from pmid to depth
pmid_av = np.mean(pmid, axis=(0,2,3))
zmid = xr.DataArray(sw.z_from_p(pmid_av,np.tile(yav,[len(pmid_av)])))




# N2i = grid.interp(N2, 'X', boundary='extend'), 'Y', boundary='extend')
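The comment above notes that wrapping `sw.Nsquared` in `xr.apply_ufunc` with `kwargs` was hard to get working. The sticking points are that the function works along one axis and returns arrays one level shorter than its inputs. A sketch of the general pattern, using a hypothetical stand-in `toy_nsquared` (NumPy only, not gsw) that returns a vertical derivative and mid-level pressures: declare `k` as an input core dimension (xarray moves it to the last axis, hence `axis=-1`), exclude it so its size may change, and name the new output dimension `k_mid`.

```python
import numpy as np
import xarray as xr


def toy_nsquared(a, p, axis=-1):
    """Hypothetical stand-in for gsw.Nsquared: returns (da/dp, p_mid) on the n-1 mid-levels."""
    a, p = np.broadcast_arrays(a, p)          # make inputs the same shape before differencing
    a = np.moveaxis(a, axis, -1)
    p = np.moveaxis(p, axis, -1)
    dadp = np.diff(a, axis=-1) / np.diff(p, axis=-1)
    p_mid = 0.5 * (p[..., 1:] + p[..., :-1])
    return np.moveaxis(dadp, -1, axis), np.moveaxis(p_mid, -1, axis)


a = xr.DataArray([[0.0, 1.0, 4.0], [2.0, 2.0, 2.0]], dims=("time", "k"))
p = xr.DataArray([0.0, 10.0, 30.0], dims=("k",))

dadp, p_mid = xr.apply_ufunc(
    toy_nsquared, a, p,
    input_core_dims=[["k"], ["k"]],           # 'k' is moved to the last axis of each input...
    output_core_dims=[["k_mid"], ["k_mid"]],  # ...and both outputs get a new 'k_mid' dimension
    exclude_dims={"k"},                       # lets the core dimension change size (n -> n-1)
    kwargs={"axis": -1},                      # so the wrapped function works along axis -1
)
```

With the notebook's variables, the same call shape would wrap `sw.Nsquared(sa, ct, pp)` with `kwargs={'lat': None, 'axis': -1}`; that exact call has not been verified here.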

In [362]:
# interpolate N2 to the vertical depth grid from zmid
# N2i = np.interp

In [361]:
# make into DataArrays. N2 lives on the k-1 mid-levels, so give that dimension its own
# name ('k_mid') to keep it distinct from the model's 'k'; zmid is a 1-D profile.
N2_da = xr.DataArray(np.asarray(N2), dims=['time', 'k_mid', 'j', 'i'])
zmid_da = xr.DataArray(np.asarray(zmid), dims=['k_mid'])

In [356]:
# N2

# xarray.DataArray(data=<NA>, coords=None, dims=None, name=None, attrs=None, indexes=None, fastpath=False)
sa
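To interpolate N2 from its mid-levels `zmid` back onto the model depths, as the earlier comment intends, `np.interp` can be used column by column. One catch: `np.interp` requires increasing sample points, while `zmid` runs from near zero to large negative values, so the profile must be sorted first. A minimal single-profile sketch; `interp_to_depths` is a hypothetical helper, and returning NaN outside the `zmid` range (rather than extrapolating) is a design choice, not something the notebook specifies.

```python
import numpy as np


def interp_to_depths(values_mid, z_mid, z_target):
    """Linearly interpolate one profile from mid-levels z_mid to z_target (m, negative down).

    np.interp needs increasing sample points, so the profile is sorted by z first.
    Targets outside the range of z_mid are returned as NaN.
    """
    z_mid = np.asarray(z_mid, dtype=float)
    values_mid = np.asarray(values_mid, dtype=float)
    order = np.argsort(z_mid)                     # increasing z, i.e. deepest level first
    return np.interp(z_target, z_mid[order], values_mid[order], left=np.nan, right=np.nan)


n2_on_model_levels = interp_to_depths([1e-4, 2e-4, 5e-4], [-5.0, -15.0, -30.0], [-10.0, -20.0, -1.0])
```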


#### Load all model data files. 


In [None]:
%%time
date_list = [start_date + datetime.timedelta(days=x) for x in range(ndays)]
target_files = [f'{datadir}LLC4320_pre-SWOT_{RegionName}_{date_list[n].strftime("%Y%m%d")}.nc' for n in range(ndays)] # list target files

# open the dataset
ds = xr.open_mfdataset(target_files)

# XC, YC and Z are the same at all times, so select a single time
# (note, this breaks for a single file - always load >1 file)
X = ds.XC.isel(time=0) 
Y = ds.YC.isel(time=0)

### Transform vector quantities to the tracer grid and rotate if needed

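At higher latitudes, vector quantities (U, V, oceTAUX, oceTAUY) are given in model-grid coordinates and must be rotated to earth (east, north) coordinates. They are first interpolated from the cell faces to the tracer grid; the rotation is then defined by the cosine and sine of the local grid angle, stored in `AngleCS` and `AngleSN`.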
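Under the usual MITgcm convention, `AngleCS` and `AngleSN` are the cosine and sine of the angle between the model's local x-axis and east, and the rotation is a standard 2-D rotation of each vector. A minimal sketch assuming that convention; `rotate_to_east_north` is a hypothetical name, and this is not necessarily how `rotate_vector_to_EN` in `oceanliner_functions` is implemented.

```python
import numpy as np


def rotate_to_east_north(u_model, v_model, angle_cs, angle_sn):
    """Rotate model-grid vector components to (east, north).

    Assumes angle_cs/angle_sn are cos/sin of the angle between the model x-axis
    and east (the MITgcm AngleCS/AngleSN convention).
    """
    u_east = angle_cs * u_model - angle_sn * v_model
    v_north = angle_sn * u_model + angle_cs * v_model
    return u_east, v_north


# Demo: on a grid whose x-axis points north (angle = 90 degrees), a purely
# model-x vector should come out as purely northward.
theta = np.deg2rad(90.0)
ue, vn = rotate_to_east_north(1.0, 0.0, np.cos(theta), np.sin(theta))
```

Because this is a pure rotation, it preserves vector magnitude, which is a quick sanity check to apply to the rotated U/V fields.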

In [None]:
# first, interpolate U,V and oceTAUX, oceTAUY to the tracer grid
print('interpolating to tracer grid')

# xgcm.Grid interp raises a DeprecationWarning
warnings.filterwarnings(action='ignore', category=DeprecationWarning)


# define grid
grid = xgcm.Grid(ds, coords={'X':{'center': 'i', 'left': 'i_g'}, 
                     'Y':{'center': 'j', 'left': 'j_g'},
                     'T':{'center': 'time'},
                     'Z':{'center': 'k'}})    
U_c = grid.interp(ds.U, 'X', boundary='extend')
V_c = grid.interp(ds.V, 'Y', boundary='extend')
# do the same for TAUX and TAUY:
oceTAUX_c = grid.interp(ds.oceTAUX, 'X', boundary='extend')
oceTAUY_c = grid.interp(ds.oceTAUY, 'Y', boundary='extend')


# second, rotate vectors to geophysical (east, north) coordinates instead of model ones 
# (if needed, i.e. if AngleSN exists)
if 'AngleSN' in ds.data_vars:
    print('Rotating vector quantities to east/north')
    U_c, V_c = rotate_vector_to_EN(U_c, V_c, ds['AngleCS'], ds['AngleSN'])
    oceTAUX_c, oceTAUY_c = rotate_vector_to_EN(oceTAUX_c, oceTAUY_c, ds['AngleCS'], ds['AngleSN'])

# replace the vector variables in ds 
print('replacing vectors with tracer-grid versions')
ds['U'] = U_c
ds['V'] = V_c
ds['oceTAUX'] = oceTAUX_c
ds['oceTAUY'] = oceTAUY_c
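On the C-grid, U sits on the `i_g` ("left") points and each tracer point lies halfway between two of them, so moving U to the tracer grid amounts to averaging neighbouring face values. A plain-NumPy sketch of that idea along the last axis, assuming that `boundary='extend'` pads the missing right-hand value by repeating the edge value; it illustrates the operation and is not guaranteed to reproduce xgcm's output exactly. `left_to_center_extend` is a hypothetical name.

```python
import numpy as np


def left_to_center_extend(u_left):
    """Average face ('left') values onto cell centres along the last axis.

    Centre i lies between left points i and i+1. The last centre has no
    right-hand face in the array, so the edge value is repeated ('extend' padding).
    """
    u_left = np.asarray(u_left, dtype=float)
    padded = np.concatenate([u_left, u_left[..., -1:]], axis=-1)
    return 0.5 * (padded[..., :-1] + padded[..., 1:])


u_c_1d = left_to_center_extend([1.0, 3.0, 5.0])
u_c_2d = left_to_center_extend([[0.0, 2.0], [4.0, 8.0]])
```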

In [None]:
%%time
# load the corresponding derived fields 
if sampling_details['DERIVED_VARIABLES']:
    derivedir = datadir + 'derived/'
    derived_files = [f'{derivedir}LLC4320_pre-SWOT_{RegionName}_derived-fields_{date_list[n].strftime("%Y%m%d")}.nc' for n in range(ndays)] # list target files
    dsd = xr.open_mfdataset(derived_files)
    
    # merge the derived and raw data
    ds = ds.merge(dsd)
    
# drop a bunch of other vars we don't actually use - can comment this out if these are wanted
ds = ds.drop_vars({'DXV','DYU', 'DXC','DXG', 'DYC','DYG', 'XC_bnds', 'YC_bnds', 'Zp1', 'Zu','Zl','Z_bnds', 'nb'})
ds

### Create & plot sampling track

Use the `get_survey_track` function to apply the sampling strategy specified in `sampling_details`

It returns:
* `survey_track`: lat/lon/time/depth of the platform's track
* `survey_indices`: indices of the survey track in `ds`
* `sampling_details`: the sampling details actually used, i.e., those that were specified plus any default values
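When `WAYPOINTS` is None, `get_survey_track` builds a pattern itself; its internals are not shown in this notebook. As an illustration only, a hypothetical helper `lawnmower_waypoints` could generate a lawnmower ("radiator") pattern over a lon/lat box and return it in the same `{'x': [...], 'y': [...]}` format that `WAYPOINTS` accepts. Evenly spaced south-north legs with alternating direction are an assumption here.

```python
import numpy as np


def lawnmower_waypoints(xmin, xmax, ymin, ymax, n_legs):
    """Hypothetical helper: waypoints for a lawnmower ('radiator') survey of a lon/lat box.

    The platform runs n_legs south-north legs, evenly spaced in longitude,
    reversing direction on every other leg. Output matches the WAYPOINTS format.
    """
    x, y = [], []
    for n, xn in enumerate(np.linspace(xmin, xmax, n_legs)):
        leg = (ymin, ymax) if n % 2 == 0 else (ymax, ymin)   # reverse every other leg
        x += [float(xn), float(xn)]
        y += [float(leg[0]), float(leg[1])]
    return {"x": x, "y": y}


wp = lawnmower_waypoints(2.0, 4.0, 37.0, 39.0, n_legs=3)
# wp == {'x': [2.0, 2.0, 3.0, 3.0, 4.0, 4.0], 'y': [37.0, 39.0, 39.0, 37.0, 37.0, 39.0]}
```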


In [None]:
# del sys.modules['oceanliner_functions']  # uncomment if troubleshooting oceanliner_functions
from oceanliner_functions import download_llc4320_data, compute_derived_fields, get_survey_track, survey_interp

survey_track, survey_indices, sampling_details = get_survey_track(ds, sampling_details)

# print:
sampling_details

In [None]:
# ---- generate name of file to save outputs in ---- 
filename_base = (f'OSSE_{RegionName}_{sampling_details["PLATFORM"]}_{start_date}_to_{start_date + datetime.timedelta(ndays)}_maxdepth{int(sampling_details["zrange"][1])}')
filename_out_base = (f'{outputdir}{filename_base}')
print(filename_base)
sampling_details['filename_out_base'] = filename_out_base

### Visualize the sampling track over a single model snapshot:

In [None]:
%matplotlib inline
plt.figure(figsize=(20,5))

# map of Theta at time zero
ax = plt.subplot(1,3,1)
ssto = plt.pcolormesh(X,Y,ds.Theta.isel(k=0, time=0).values, shading='auto')
if not (sampling_details['PLATFORM'] == 'mooring' or sampling_details['PLATFORM'] == 'sim_mooring'):
    tracko = plt.scatter(survey_track.lon, survey_track.lat, c=(survey_track.time-survey_track.time[0])/1e9/86400, cmap='Reds', s=0.75)
    plt.colorbar(ssto).set_label('SST, $^o$C')
    plt.colorbar(tracko).set_label('days from start')
    plt.title('SST and survey track: ' + RegionName)
else:
    plt.plot(survey_track.lon, survey_track.lat, marker='*', c='r')
    plt.title('SST and mooring location: ' + RegionName + ' region')


# depth/time plot of first few datapoints
ax = plt.subplot(1,3,2)
iplot = slice(0,20000)
if not (sampling_details['PLATFORM'] == 'mooring'):
    plt.plot(survey_track.time.isel(points=iplot), survey_track.dep.isel(points=iplot), marker='.')
else:
    # not quite right but good enough for now.
    # (times shouldn't increase with depth)
    plt.scatter((np.tile(survey_track['time'].isel(time=iplot), int(survey_track['dep'].data.size))),
         np.tile(survey_track['dep'], int(survey_track['time'].isel(time=iplot).data.size)),marker='.')             
# plt.xlim([start_date + datetime.timedelta(days=0), start_date + datetime.timedelta(days=2)])
plt.ylabel('Depth, m')
plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=1))
plt.gcf().autofmt_xdate()
plt.title('Sampling pattern')

# lon/time plot
ax = plt.subplot(1,3,3)
iplot = slice(0,20000)
if not (sampling_details['PLATFORM'] == 'mooring'):
    plt.plot(survey_track.time.isel(points=iplot), survey_track.lon.isel(points=iplot), marker='.')
else:
    # not quite right but good enough for now.
    # (times shouldn't increase with depth)
    plt.scatter((np.tile(survey_track['time'].isel(time=iplot), int(survey_track['lon'].data.size))),
         np.tile(survey_track['lon'], int(survey_track['time'].isel(time=iplot).data.size)),marker='.')             
# plt.xlim([start_date + datetime.timedelta(days=0), start_date + datetime.timedelta(days=2)])
plt.ylabel('Lon')
plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=1))
plt.gcf().autofmt_xdate()



# save
if SAVE_FIGURES:
    plt.savefig(figdir + filename_base + '_sampling.png', dpi=400, transparent=False, facecolor='white')

plt.show()

### MAIN FUNCTION OF OCEANLINER:
Interpolate the model fields along the survey track defined by `sampling_details` to create `subsampled_data`, then put the subsampled data on a regular grid and store it in `sgridded`.
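The internals of `survey_interp` are not shown here, but the core idea of sampling a gridded field along a track can be sketched with xarray's pointwise ("advanced") interpolation: wrapping each track coordinate in a DataArray that shares a `points` dimension makes `.interp` return one value per track point instead of an outer product. The sketch assumes a field on 1-D `lon`, `lat`, `depth` and `time` coordinates, which is a simplification: the llc4320 cut-outs use 2-D `XC`/`YC`. `sample_along_track` and the synthetic field are hypothetical.

```python
import numpy as np
import xarray as xr


def sample_along_track(field, lon, lat, dep, time):
    """Linearly interpolate `field` (dims: time, depth, lat, lon) at each track point."""
    pts = {
        "lon": xr.DataArray(np.asarray(lon, dtype=float), dims="points"),
        "lat": xr.DataArray(np.asarray(lat, dtype=float), dims="points"),
        "depth": xr.DataArray(np.asarray(dep, dtype=float), dims="points"),
        "time": xr.DataArray(np.asarray(time, dtype="datetime64[ns]"), dims="points"),
    }
    # all targets share the 'points' dimension, so interpolation is point-by-point
    return field.interp(pts, method="linear")


# Synthetic field that is linear in every coordinate, so linear interpolation is exact.
lon_c = np.linspace(2.0, 4.0, 5)
lat_c = np.linspace(37.0, 39.0, 5)
dep_c = np.array([-1000.0, -500.0, -100.0, -1.0])   # sorted ascending
times = (np.datetime64("2012-01-01T00:00") + np.arange(4) * np.timedelta64(6, "h")).astype("datetime64[ns]")
field = (
    xr.DataArray(np.arange(4) * 6.0, dims="time", coords={"time": times})   # hours since start
    + 0.01 * xr.DataArray(dep_c, dims="depth", coords={"depth": dep_c})
    + 3.0 * xr.DataArray(lat_c, dims="lat", coords={"lat": lat_c})
    + 2.0 * xr.DataArray(lon_c, dims="lon", coords={"lon": lon_c})
)

track = sample_along_track(
    field, lon=[2.5, 3.1], lat=[37.5, 38.2], dep=[-200.0, -50.0],
    time=np.array(["2012-01-01T03:00", "2012-01-01T10:00"], dtype="datetime64[ns]"),
)
print(track.values)  # ≈ [118.5, 130.3]
```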

In [None]:
%%time

# xarray raises a FutureWarning / deprecation warning...
warnings.filterwarnings(action='ignore', category=FutureWarning)

subsampled_data, sgridded = survey_interp(ds, survey_track, survey_indices, sampling_details)
sgridded
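How `survey_interp` builds `sgridded` from the along-track samples is not shown in this notebook. One simple way to put scattered (time, depth) samples on a regular grid is bin-averaging, sketched below with `np.histogram2d`; `bin_mean_2d` is a hypothetical helper, and empty bins are left as NaN rather than filled.

```python
import numpy as np


def bin_mean_2d(t, z, values, t_edges, z_edges):
    """Average scattered along-track samples into (time, depth) bins; empty bins are NaN."""
    sums, _, _ = np.histogram2d(t, z, bins=[t_edges, z_edges], weights=values)
    counts, _, _ = np.histogram2d(t, z, bins=[t_edges, z_edges])
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(counts > 0, sums / counts, np.nan)


# three samples: two in the same (time, depth) bin, one in another
gridded = bin_mean_2d(
    t=[0.5, 0.5, 1.5], z=[-5.0, -5.0, -15.0], values=[1.0, 3.0, 10.0],
    t_edges=[0.0, 1.0, 2.0], z_edges=[-20.0, -10.0, 0.0],
)
```

Bin-averaging never invents values between samples, at the cost of gaps wherever the platform did not pass; an interpolating gridder makes the opposite trade-off.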

### Visualizations

Basic plots to show the interpolated variables

In [None]:
# - base list of vbls:
vbls3d = ['Theta','Salt', 'U', 'V']
# if derived fields, add those to the list:
if sampling_details['DERIVED_VARIABLES']:
    if ('steric_height' in sampling_details['DERIVED_VARIABLES']):
        vbls3d.append('steric_height')
    if ('vorticity' in sampling_details['DERIVED_VARIABLES']):
        vbls3d.append('vorticity')

ylim = [min(sgridded['depth'].values), max(sgridded['depth'].values)]
# ylim = [-200, -1]

nr = len(vbls3d) # # of rows
fig,ax=plt.subplots(nr,figsize=(8,len(vbls3d)*2),constrained_layout=True)


for j in range(nr):
    sgridded[vbls3d[j]].plot(ax=ax[j], ylim=ylim)
    ax[j].plot(sgridded.time.data, -sgridded.KPPhbl.data, c='k')
    ax[j].set_title(vbls3d[j])

if SAVE_FIGURES:
    plt.savefig(figdir + filename_base + '_3D.png', dpi=400, transparent=False, facecolor='white')


In [None]:
## selected 2d fields
j=0
nr = 6 # # of rows
fig,ax=plt.subplots(nr,figsize=(10,8),constrained_layout=True)


# wind vectors
ax[j].quiver(sgridded.time.data,0,sgridded.oceTAUX.data, sgridded.oceTAUY.data)
ax[j].set_title('Wind stress')    
ax[j].set_ylabel('N m-2')
# SH 
if sampling_details['DERIVED_VARIABLES']:
    if ('steric_height' in sampling_details['DERIVED_VARIABLES']):
        j+=1
        ax[j].plot(sgridded.time,sgridded.steric_height-sgridded.steric_height.mean(), 
                     sgridded.time.data,sgridded.steric_height_true-sgridded.steric_height_true.mean())
        ax[j].set_title('Steric height')
        ax[j].legend(['subsampled','true'])
        ax[j].set_ylabel('m')

# SSH
j+=1
ax[j].plot(sgridded.time,sgridded.Eta)
ax[j].set_title('SSH')
ax[j].set_ylabel('m')

# MLD
j+=1
ax[j].plot(sgridded.time,sgridded.KPPhbl)
ax[j].set_title('MLD')
ax[j].set_ylabel('m')
ax[j].invert_yaxis()

# surface heat flux
j+=1
ax[j].plot(sgridded.time,sgridded.oceQnet, sgridded.time,sgridded.oceQsw)
ax[j].set_title('Surface heat flux into the ocean')
ax[j].legend(['total','shortwave'])
ax[j].set_ylabel('W m-2')

# surface FW flux
j+=1
ax[j].plot(sgridded.time,sgridded.oceFWflx)
ax[j].set_title('Surface freshwater flux into the ocean') 
ax[j].set_ylabel('kg m-2 s-1')

# horiz line:
for j in range(nr):
    ax[j].axhline(0, color='grey', linewidth=0.8)

if SAVE_FIGURES:
    plt.savefig(figdir + filename_base + '_2D.png', dpi=400, transparent=False, facecolor='white')
   

In [None]:
ds.steric_height.isel(k=-1, j=100, i=100).plot(x='time')

In [None]:
# in sgridded, the subsampled steric height looks very different from the true steric height... why?

# check: does ds.steric_height_true look like ds.steric_height? YES
plt.plot(ds.steric_height_true.isel(j=100, i=100).data - np.mean(ds.steric_height_true.isel(j=100, i=100).data))
plt.plot(ds.steric_height.isel(j=100, i=100, k=51).data - np.mean(ds.steric_height.isel(j=100, i=100, k=51).data))
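The check above compares time series after removing their means, so only the time-varying part of steric height matters. For orientation, one common definition integrates the density anomaly over the water column, $\eta_s = -\frac{1}{\rho_0}\int_{-H}^{0}(\rho - \rho_0)\,dz$. The sketch below evaluates it for one profile under loudly simplifying assumptions: a linear equation of state with illustrative coefficients `ALPHA`, `BETA`, `T0`, `S0` (all hypothetical). It is not necessarily how oceanliner computes `steric_height` or `steric_height_true`.

```python
import numpy as np

RHO0 = 1025.0        # reference density, kg m-3 (assumed)
ALPHA = 2.0e-4       # thermal expansion coefficient, K-1 (illustrative, linear EOS)
BETA = 7.6e-4        # haline contraction coefficient, (g/kg)-1 (illustrative)
T0, S0 = 10.0, 35.0  # reference temperature (degC) and salinity (g/kg), assumed

def steric_height(theta, salt, z):
    """Steric height (m) of one profile: -(1/RHO0) * integral of (rho - RHO0) dz.

    z is depth in m, negative down. The profile is sorted so that the
    integral runs from the deepest level up to the shallowest one.
    """
    order = np.argsort(z)
    z, theta, salt = z[order], theta[order], salt[order]
    rho = RHO0 * (1.0 - ALPHA * (theta - T0) + BETA * (salt - S0))
    integrand = -(rho - RHO0) / RHO0
    # trapezoidal rule over depth
    return float(np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(z)))

# Usage: a uniform 1 K warming of a 100 m column raises the surface by ALPHA * 1 K * 100 m.
z = np.linspace(-100.0, 0.0, 101)
warm = steric_height(np.full_like(z, T0 + 1.0), np.full_like(z, S0), z)
```

Because $\rho_0$ and the reference state only shift $\eta_s$ by a constant for a fixed column, removing the mean (as the check cell does) makes series computed with different references comparable.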


In [None]:
# check: does the subsampled steric height look right? NO, it seems to be rotated
plt.scatter(subsampled_data.time.data, subsampled_data.dep.data, c=subsampled_data.steric_height_true.data)
plt.colorbar()

In [None]:
plt.scatter(subsampled_data.time.data, subsampled_data.dep.data, c=subsampled_data.steric_height.data)
plt.colorbar()

In [None]:
ds

### Save interpolated data

Add attributes to both the raw and the gridded subsampled data, then save them to netCDF.


In [None]:
# add metadata to attributes (copy the dict so that popping keys below does not modify sampling_details)
attrs = dict(sampling_details)
attrs['start_date'] = start_date.strftime('%Y-%m-%d')
end_date = sgridded['time'].data[-1]
attrs['end_date'] = np.datetime_as_string(end_date,unit='D')
attrs['ndays'] = ndays
attrs.pop('DERIVED_VARIABLES', None)
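Assigning `attrs = sampling_details` would bind a second name to the same dictionary, so popping `DERIVED_VARIABLES` from it would also remove the key from `sampling_details`, which the plotting cells read; copying the dict avoids that. Separately, netCDF attribute values are limited to text and numbers (or arrays of them), so values such as `None`, booleans, or dates may need converting before `to_netcdf`. A minimal sketch, using a hypothetical helper `netcdf_safe_attrs` whose conversion rules are assumptions rather than part of oceanliner:

```python
import datetime
import numbers

def netcdf_safe_attrs(attrs):
    """Return a new dict of attributes converted to netCDF-friendly types.

    Hypothetical helper: None values are dropped, booleans become 0/1,
    dates become ISO strings, lists of strings are joined into one string,
    lists of numbers are kept, and anything else is stored as str(value).
    The input dict is not modified.
    """
    safe = {}
    for key, value in attrs.items():
        if value is None:
            continue                       # netCDF has no null attribute value
        if isinstance(value, bool):
            safe[key] = int(value)         # netCDF has no boolean type
        elif isinstance(value, (str, numbers.Real)):
            safe[key] = value
        elif isinstance(value, datetime.date):  # also matches datetime.datetime
            safe[key] = value.isoformat()
        elif isinstance(value, (list, tuple)) and all(isinstance(v, str) for v in value):
            safe[key] = ", ".join(value)
        elif isinstance(value, (list, tuple)) and all(
            isinstance(v, numbers.Real) and not isinstance(v, bool) for v in value
        ):
            safe[key] = list(value)
        else:
            safe[key] = str(value)
    return safe

# Example: keys from this notebook plus hypothetical "example_*" keys.
example = {
    "PLATFORM": "glider",
    "PATTERN": "lawnmower",
    "DERIVED_VARIABLES": ["steric_height", "vorticity"],
    "example_speed": 0.25,
    "example_flag": True,
    "example_none": None,
    "example_date": datetime.date(2012, 1, 1),
}
safe = netcdf_safe_attrs(example)
```

Applied in the notebook, this would look like `sgridded.attrs = netcdf_safe_attrs(attrs)`.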

In [None]:
# ------ subsampled data
# this is slow and generates a huge file, so consider skipping
if sampling_details['PLATFORM'] != 'mooring':
    filename_out = filename_out_base + '_subsampled.nc'
    print(f'saving to {filename_out}')
    subsampled_data.attrs = attrs
    netcdf_fill_value = nc4.default_fillvals['f4']
    dv_encoding = {'zlib': True,       # turns compression on
                   'complevel': 9,     # 1 = fastest, lowest compression; 9 = slowest, highest compression
                   'shuffle': True,    # shuffle filter can significantly improve compression ratios; on by default
                   'dtype': 'float32',
                   '_FillValue': netcdf_fill_value}
    # save to a new file (to compress, encoding must map each variable name to its settings):
    # subsampled_data.to_netcdf(filename_out, format='NETCDF4', encoding={v: dv_encoding for v in subsampled_data.data_vars})
    subsampled_data.to_netcdf(filename_out, format='NETCDF4')
    !ls -ltrh {filename_out}

In [None]:
# ------ gridded:
filename_out = filename_out_base + '_gridded.nc'
print(f'saving to {filename_out}')
sgridded.attrs = attrs
netcdf_fill_value = nc4.default_fillvals['f4']
dv_encoding = {'zlib': True,       # turns compression on
               'complevel': 9,     # 1 = fastest, lowest compression; 9 = slowest, highest compression
               'shuffle': True,    # shuffle filter can significantly improve compression ratios; on by default
               'dtype': 'float32',
               '_FillValue': netcdf_fill_value}
# save to a new file (to compress, encoding must map each variable name to its settings):
# sgridded.to_netcdf(filename_out, format='NETCDF4', encoding={v: dv_encoding for v in sgridded.data_vars})
sgridded.to_netcdf(filename_out, format='NETCDF4')
!ls -ltrh {filename_out}
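In the save cells, `dv_encoding` is a single dictionary of settings, whereas xarray's `to_netcdf(encoding=...)` expects a mapping from each variable name to its own settings dictionary. A minimal sketch, using a hypothetical helper `build_encoding` that applies the same zlib/shuffle/float32 settings only to floating-point variables (casting integer or time variables to float32 would alter them). The fill value below is assumed to match `nc4.default_fillvals['f4']`.

```python
import numpy as np

NC_FILL_FLOAT = 9.969209968386869e36  # same value as netCDF4.default_fillvals['f4']

def build_encoding(var_dtypes, complevel=9, fill_value=NC_FILL_FLOAT):
    """Map each floating-point variable name to compressed float32 encoding settings.

    var_dtypes: mapping of variable name -> numpy dtype (or dtype string).
    Non-float variables are skipped and keep xarray's default encoding.
    """
    encoding = {}
    for name, dtype in var_dtypes.items():
        if np.dtype(dtype).kind != "f":
            continue
        encoding[name] = {
            "zlib": True,            # deflate compression
            "complevel": complevel,  # 1 = fastest, 9 = smallest file
            "shuffle": True,         # byte shuffle usually improves compression
            "dtype": "float32",
            "_FillValue": fill_value,
        }
    return encoding

# Example with hypothetical variable names and dtypes:
enc = build_encoding({"Theta": "float64", "n_samples": "int32", "t_obs": "datetime64[ns]"})
```

In the notebook this could be passed as `sgridded.to_netcdf(filename_out, format='NETCDF4', encoding=build_encoding({name: var.dtype for name, var in sgridded.data_vars.items()}))`. Compression at `complevel=9` trades write time for file size, which matters most for the large raw `subsampled_data` file.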


### Visualize interpolated data in 3D

In [None]:
%matplotlib qt
# the Qt backend opens an interactive window, so the 3-D scatter can be rotated with the mouse

fig = plt.figure(figsize=(12, 12))
ax = plt.axes(projection='3d')
fig.subplots_adjust(left=0.25, bottom=0.25)

ax.set_xlabel('longitude', fontsize=15, rotation=150)
ax.set_ylabel('latitude',fontsize=15)
ax.set_zlabel('depth', fontsize=15, rotation=60)

p = ax.scatter3D(subsampled_data.lon.data, subsampled_data.lat.data, subsampled_data.dep.data, c=subsampled_data.Theta.data, s=1)
fig.colorbar(p).set_label(r'Temperature ($^\circ$C)')
ax.set_title('Temperature interpolated to the survey track')