# Calculate gradient wind balance

In [1]:
import os

import cosima_cookbook as cc
from dask.distributed import Client

import numpy as np
import xarray as xr

import xgcm
from oceanpy import gradient_wind_from_ssh, define_grid, horizontal_divergence, horizontal_strain, relative_vorticity

from numbers import Number

In [2]:
# Directory for derived datasets written by this notebook. Set the
# GRADIENT_WIND_OUTDIR environment variable to redirect output when running
# outside the original cluster checkout; otherwise the original path is used.
outdir = os.environ.get(
    'GRADIENT_WIND_OUTDIR',
    os.path.join(os.sep, 'g', 'data', 'v45', 'jm6603', 'checkouts', 'phd', 'src', 'cosima', '02_manuscript', 'output'),
)
# exist_ok=True avoids the race between an os.path.exists check and makedirs.
os.makedirs(outdir, exist_ok=True)

In [3]:
def to_netcdf(ds, file_name):
    """Write ``ds`` to ``file_name``, stringifying attributes netCDF cannot store.

    A plain ``ds.to_netcdf(file_name)`` is tried first. If it raises
    ``TypeError`` (as xarray does for attribute values of an unsupported type),
    the error is printed and the write is retried on a shallow copy of ``ds``.
    In that copy, every offending attribute is replaced by its ``str()``
    representation, both in the dataset's own (global) ``attrs`` and in each
    variable's ``attrs``. The caller's ``ds`` is left unmodified.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to write.
    file_name : str or path-like
        Output path forwarded to ``Dataset.to_netcdf``.

    Raises
    ------
    TypeError
        If the retry on the sanitised copy still fails.
    """
    valid_types = (str, Number, np.ndarray, np.number, list, tuple)

    def _sanitise(attrs):
        """Return a new dict with values outside ``valid_types`` replaced by ``str(v)``."""
        # bool is a subclass of Number but is rejected as a netCDF attribute
        # value, so it is stringified as well.
        return {
            k: v if isinstance(v, valid_types) and not isinstance(v, bool) else str(v)
            for k, v in attrs.items()
        }

    try:
        ds.to_netcdf(file_name)
    except TypeError as e:
        print(e.__class__.__name__, e)
        # Work on a shallow copy and assign fresh attribute dicts so the
        # caller's dataset (and any attrs dicts it shares) is never mutated.
        clean = ds.copy(deep=False)
        # Global attributes are validated by to_netcdf too, not just variables'.
        clean.attrs = _sanitise(clean.attrs)
        for variable in clean.variables.values():
            variable.attrs = _sanitise(variable.attrs)
        clean.to_netcdf(file_name)

## Load data

In [4]:
# Experiment queried throughout the notebook, plus the COSIMA Cookbook
# database session that every getvar call below reuses.
expt = '01deg_jra55v140_iaf'
session = cc.database.create_session()

In [5]:
# Start a dask.distributed client; called with no arguments it creates a local
# cluster (the output below reports a LocalCluster and its workers).
client = Client()
# Bare last expression so the notebook renders the client/cluster summary.
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 2
Total threads: 2,Total memory: 9.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:38625,Workers: 2
Dashboard: /proxy/8787/status,Total threads: 2
Started: Just now,Total memory: 9.00 GiB

0,1
Comm: tcp://127.0.0.1:45821,Total threads: 1
Dashboard: /proxy/40087/status,Memory: 4.50 GiB
Nanny: tcp://127.0.0.1:43667,
Local directory: /jobfs/125319081.gadi-pbs/dask-worker-space/worker-sddbfg6q,Local directory: /jobfs/125319081.gadi-pbs/dask-worker-space/worker-sddbfg6q

0,1
Comm: tcp://127.0.0.1:46465,Total threads: 1
Dashboard: /proxy/38371/status,Memory: 4.50 GiB
Nanny: tcp://127.0.0.1:32799,
Local directory: /jobfs/125319081.gadi-pbs/dask-worker-space/worker-_3phk4_3,Local directory: /jobfs/125319081.gadi-pbs/dask-worker-space/worker-_3phk4_3


In [6]:
# Output frequency of the model fields to query.
freq = '1 daily'

# Time limits passed as start_time/end_time to the database queries.
start, end = '1997-04-01', '1997-04-30'

# Spatial box, applied to both t-grid (xt/yt_ocean) and u-grid (xu/yu_ocean)
# coordinates.
lon_lim = slice(-225.2, -210.8)  # NOTE(review): original note "#230" — presumably an alternative bound; confirm
lat_lim = slice(-53.7, -46.3)

# Analysis windows. The monthly window is derived from start/end so the two
# cannot drift apart when the dates are edited.
monthly_period = slice(start, end)
flex_period = slice('1997-04-10', '1997-04-25')

### Load and select coordinates

In [7]:
# Static grid metrics (cell widths and cell areas) on the t- and u-grids.
# dzt (queried at '1 monthly' frequency) is currently disabled:
# dzt = cc.querying.getvar(expt=expt, variable='dzt', session=session, frequency='1 monthly', n=1)

def _load_static(name):
    """Query the static-frequency variable ``name`` (n=1) for ``expt``."""
    return cc.querying.getvar(expt=expt, variable=name, session=session,
                              frequency='static', n=1)

dxt, dyt = _load_static('dxt'), _load_static('dyt')
dxu, dyu = _load_static('dxu'), _load_static('dyu')
area_t, area_u = _load_static('area_t'), _load_static('area_u')

In [8]:
# Restrict the static metrics to the study box. t-grid and u-grid fields use
# different coordinate names, so build one selector per grid.
_t_box = {'xt_ocean': lon_lim, 'yt_ocean': lat_lim}
_u_box = {'xu_ocean': lon_lim, 'yu_ocean': lat_lim}

dxt_lim, dyt_lim = dxt.sel(**_t_box), dyt.sel(**_t_box)
# dzt_lim = dzt.sel(**_t_box)
dxu_lim, dyu_lim = dxu.sel(**_u_box), dyu.sel(**_u_box)

areat_lim = area_t.sel(**_t_box)
areau_lim = area_u.sel(**_u_box)


### Load and select variables

In [9]:
def _load_period(name):
    """Query ``name`` at frequency ``freq`` between ``start`` and ``end``."""
    return cc.querying.getvar(expt=expt, variable=name, session=session,
                              frequency=freq, start_time=start, end_time=end)

# Sea level and horizontal velocity components for the requested period.
sl = _load_period('sea_level')
u = _load_period('u')
v = _load_period('v')
# wt = _load_period('wt')

# Trim to the study box: sea level lives on the t-grid, u and v on the u-grid.
sl_lim = sl.sel(xt_ocean=lon_lim, yt_ocean=lat_lim)
u_lim, v_lim = (field.sel(xu_ocean=lon_lim, yu_ocean=lat_lim) for field in (u, v))
# wt_lim = wt.sel(xt_ocean=lon_lim, yt_ocean=lat_lim)


## Define Grid

### Calculate finite differences with the `xgcm` package
With `xgcm`, we first create a `grid` object that holds all the information about our staggered grid. In our case, `grid` needs to know the positions of the `xt_ocean` and `xu_ocean` points (and likewise `yt_ocean` and `yu_ocean`) and how they are offset from one another. For example, `xu_ocean` is shifted to the right of `xt_ocean` by $\frac{1}{2}$ grid cell, and `yu_ocean` is likewise offset from `yt_ocean` by half a cell.

In [10]:
# Axis layout handed to define_grid: t-points form the centre axes (None) and
# u-points are offset from them by half a cell (0.5) in both x and y.
coords = {'xt_ocean': None, 'yt_ocean': None, 'xu_ocean': 0.5, 'yu_ocean': 0.5}
dims = ('X', 'Y')
distances = ('dxt', 'dyt', 'dxu', 'dyu')
areas = ('area_u', 'area_t')

# Grid metrics first, then the fields restricted to the analysis month.
coordinates = xr.merge([dxt_lim, dyt_lim, dxu_lim, dyu_lim, areat_lim, areau_lim])
_month_fields = [field.sel(time=monthly_period) for field in (sl_lim, u_lim, v_lim)]
vel = xr.merge([coordinates, *_month_fields])

grid = define_grid(vel, dims, coords, distances, areas, periodic=False)
grid

<xgcm.Grid>
Y Axis (not periodic, boundary='extend'):
  * center   yt_ocean --> outer
  * outer    yu_ocean --> center
X Axis (not periodic, boundary='extend'):
  * center   xt_ocean --> right
  * right    xu_ocean --> center

## Gradient wind, geostrophic and ageostrophic velocities

In [11]:
# Sea level over the analysis month: the input to the gradient-wind calculation.
sea_level = sl_lim.sel(time=monthly_period)

# Projection passed to gradient_wind_from_ssh (EPSG:32754, UTM zone 54S).
UTM54 = 'EPSG:32754'
# Boxcar width passed to gradient_wind_from_ssh's `smooth` option.
BOXCAR_WIDTH = 3

# Gradient-wind and geostrophic velocities from SMOOTHED sea level. Only the
# smoothed result is computed and saved. To compare against unsmoothed fields,
# call gradient_wind_from_ssh without `smooth=` and keep the result under its
# own name (do not reuse `gw`, which holds the saved dataset below).
gw_smooth = gradient_wind_from_ssh(sea_level, transform=UTM54,
                                   dimensions=('time', 'yt_ocean', 'xt_ocean'),
                                   smooth={'boxcar': BOXCAR_WIDTH})

# Save sea level together with the smoothed velocities. An existing file is
# kept as-is; delete it to regenerate after changing the parameters above.
gw = xr.merge([sea_level.to_dataset(), gw_smooth])
file_name = os.path.join(outdir, 'gw-vel.nc')
if not os.path.exists(file_name):
    to_netcdf(gw, file_name)

TypeError Invalid value for attr 'time_bounds': <xarray.DataArray 'time_bounds' (time: 181, nv: 2)>
dask.array<concatenate, shape=(181, 2), dtype=timedelta64[ns], chunksize=(1, 2), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 1997-01-01T12:00:00 ... 1997-06-30T12:00:00
  * nv       (nv) float64 1.0 2.0
Attributes:
    long_name:  time axis boundaries
    calendar:   GREGORIAN. For serialization to netCDF files, its value must be of one of the following types: str, Number, ndarray, number, list, tuple
