# Calculate pressure variable

In [1]:
import os

import cosima_cookbook as cc
from dask.distributed import Client

import xarray as xr
import numpy as np

from oceanpy import define_grid

from gsw import f, SA_from_SP, p_from_z, geo_strf_dyn_height, grav
from numbers import Number

In [2]:
# Directory where derived netCDF files (coordinates.nc, hydro.nc) are cached.
outdir = os.path.join(os.sep, 'g', 'data', 'v45', 'jm6603', 'checkouts', 'phd', 'src', 'cosima', '02_manuscript', 'output')
# exist_ok=True replaces a separate os.path.exists() test, avoiding the
# check-then-create race and making the cell safe to re-run.
os.makedirs(outdir, exist_ok=True)

In [3]:
def to_netcdf(ds, file_name):
    """Write ``ds`` to ``file_name`` as netCDF, stringifying unsupported attributes.

    The dataset is first written as-is. If that raises ``TypeError`` (the error
    xarray raises for attribute values of unsupported types, e.g. ``None`` or a
    ``DataArray``), the error is printed. The function then makes a shallow copy
    of the dataset. On that copy it replaces every offending attribute value with
    its ``str()``. This covers dataset-level attributes and the attributes of
    every variable/coordinate. The write is then retried. The caller's dataset is
    not modified.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to write.
    file_name : str
        Destination path.

    Raises
    ------
    TypeError
        If the retried write still fails.
    """
    valid_types = (str, Number, np.ndarray, np.number, list, tuple)

    def _stringify_invalid(attrs):
        # bool is a Number subclass but is deliberately stringified too.
        for key, value in list(attrs.items()):
            if not isinstance(value, valid_types) or isinstance(value, bool):
                attrs[key] = str(value)

    try:
        ds.to_netcdf(file_name)
    except TypeError as e:
        print(e.__class__.__name__, e)
        # Shallow copy: data is shared, but attrs dicts are independent, so the
        # caller's dataset keeps its original attribute values.
        ds = ds.copy()
        _stringify_invalid(ds.attrs)
        for variable in ds.variables.values():
            _stringify_invalid(variable.attrs)
        ds.to_netcdf(file_name)

## Load data

In [4]:
# Experiment name used for every database query in this notebook.
expt = '01deg_jra55v140_iaf'
# COSIMA Cookbook database session that the queries are run against.
session = cc.database.create_session()

In [5]:
# Start a local dask cluster for the dask-backed computations below.
# The bare `client` expression on the last line renders the cluster summary.
client = Client()
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/40419/status,

0,1
Dashboard: /proxy/40419/status,Workers: 7
Total threads: 14,Total memory: 63.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:41209,Workers: 7
Dashboard: /proxy/40419/status,Total threads: 14
Started: Just now,Total memory: 63.00 GiB

0,1
Comm: tcp://127.0.0.1:36165,Total threads: 2
Dashboard: /proxy/42767/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:39137,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-_tdbwvyj,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-_tdbwvyj

0,1
Comm: tcp://127.0.0.1:44179,Total threads: 2
Dashboard: /proxy/39517/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:39559,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-n6q920fh,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-n6q920fh

0,1
Comm: tcp://127.0.0.1:44797,Total threads: 2
Dashboard: /proxy/41457/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:37905,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-77vbiir0,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-77vbiir0

0,1
Comm: tcp://127.0.0.1:38331,Total threads: 2
Dashboard: /proxy/44181/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:41265,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-prearjj9,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-prearjj9

0,1
Comm: tcp://127.0.0.1:38367,Total threads: 2
Dashboard: /proxy/42199/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:43549,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-75ag3vig,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-75ag3vig

0,1
Comm: tcp://127.0.0.1:39925,Total threads: 2
Dashboard: /proxy/39693/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:43767,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-t2qhqm8i,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-t2qhqm8i

0,1
Comm: tcp://127.0.0.1:46103,Total threads: 2
Dashboard: /proxy/33605/status,Memory: 9.00 GiB
Nanny: tcp://127.0.0.1:45899,
Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-ijadjeds,Local directory: /jobfs/124996680.gadi-pbs/dask-worker-space/worker-ijadjeds


In [6]:
# Time window requested from the database (passed as start_time / end_time).
start = '1997-01-01'
end = '1997-06-30'

# Model output frequency to query.
freq = '1 daily'

In [7]:
# Spatial bounding box, in the model's native xt/yt (xu/yu) coordinate values.
lon_lim = slice(-225.2, -210.8)
lat_lim = slice(-53.7, -46.3)

# Analysis periods. The meander period is the whole loaded window, so it is
# derived from `start`/`end` rather than repeating the date literals. This way
# a change to the query window cannot silently desynchronise the selection.
meander_period = slice(start, end)
monthly_period = slice('1997-04-01', '1997-04-30')
flex_period = slice('1997-04-10', '1997-04-25')

### Load and select coordinates

In [8]:
def _query_one_file(variable, **extra):
    """Query `variable` for `expt` from the cookbook session, restricted with n=1."""
    return cc.querying.getvar(expt=expt, variable=variable, session=session, n=1, **extra)


# Cell widths on the tracer (t) grid; dzt is queried from monthly output.
dxt = _query_one_file('dxt', frequency='static')
dyt = _query_one_file('dyt', frequency='static')
dzt = _query_one_file('dzt', frequency='1 monthly')

# Cell widths on the velocity (u) grid.
dxu = _query_one_file('dxu', frequency='static')
dyu = _query_one_file('dyu', frequency='static')

# Cell areas on both grids.
area_t = _query_one_file('area_t', frequency='static')
area_u = _query_one_file('area_u', frequency='static')

# kmu / kmt fields on the u and t grids.
kmu = _query_one_file('kmu', frequency='static')
kmt = _query_one_file('kmt', frequency='static')

# Geographic latitude/longitude of tracer points (no frequency given).
geolat_t = _query_one_file('geolat_t')
geolon_t = _query_one_file('geolon_t')

In [9]:
# Selections of the study box on the tracer (t) and velocity (u) grids.
t_box = dict(xt_ocean=lon_lim, yt_ocean=lat_lim)
u_box = dict(xu_ocean=lon_lim, yu_ocean=lat_lim)

# Tracer-grid fields.
dxt_lim = dxt.sel(**t_box)
dyt_lim = dyt.sel(**t_box)
# dzt carries a time dimension; a single record (index 1) is kept.
# NOTE(review): index 1 is the second record, not the first -- confirm intended.
dzt_lim = dzt.sel(**t_box).isel(time=1)
dzt_lim.name = 'dst'  # matches the 'dst' entry in the grid distances below
areat_lim = area_t.sel(**t_box)
kmt_lim = kmt.sel(**t_box)
lat_t = geolat_t.sel(**t_box)
lon_t = geolon_t.sel(**t_box)

# Velocity-grid fields.
dxu_lim = dxu.sel(**u_box)
dyu_lim = dyu.sel(**u_box)
areau_lim = area_u.sel(**u_box)
kmu_lim = kmu.sel(**u_box)

### Load and select variables

In [10]:
# Shared query arguments for the time-varying fields (frequency `freq`,
# window `start`..`end`).
series_query = dict(expt=expt, session=session, frequency=freq, start_time=start, end_time=end)

# hydrography
sl = cc.querying.getvar(variable='sea_level', **series_query)
temp = cc.querying.getvar(variable='temp', **series_query)
salt = cc.querying.getvar(variable='salt', **series_query)

# velocities (not loaded at present)
# u = cc.querying.getvar(variable='u', **series_query)
# v = cc.querying.getvar(variable='v', **series_query)
# wt = cc.querying.getvar(variable='wt', **series_query)

# topography (static field, single file)
ht = cc.querying.getvar(expt=expt, variable='ht', session=session, frequency='static', n=1)

In [11]:
# Restrict the hydrography to the study box and the meander period;
# the bathymetry has no time axis, so only the box applies.
space = {'xt_ocean': lon_lim, 'yt_ocean': lat_lim}
space_time = {**space, 'time': meander_period}

sl_lim, temp_lim, salt_lim = (field.sel(space_time) for field in (sl, temp, salt))
ht_lim = ht.sel(space)

## Constants

In [12]:
# rho_0: reference density in kg/m^3.
# p_ref: reference pressure handed to gsw.geo_strf_dyn_height (gsw uses dbar);
# the author also left 1500 as an alternative value.
rho_0, p_ref = 1036, 0

## Define grid

In [13]:
# Grid metadata for define_grid. The per-coordinate values (None for tracer
# coordinates, 0.5 for the u-grid ones) are presumably written out as
# `c_grid_axis_shift` attributes -- see the serialization error captured below.
coords = dict(xt_ocean=None, yt_ocean=None, st_ocean=None, xu_ocean=0.5, yu_ocean=0.5)
distances = ('dxt', 'dyt', 'dst', 'dxu', 'dyu')
areas = ('area_u', 'area_t')
dims = ('X', 'Y', 'S')

# Merge all grid-related fields, then attach the sea-level field.
grid_fields = [
    dxt_lim, dyt_lim, dzt_lim, dxu_lim, dyu_lim,
    areat_lim, areau_lim,
    kmu_lim, kmt_lim,
    lon_t, lat_t, ht_lim,
]
coordinates = xr.merge(grid_fields)
sea_level = xr.merge([coordinates, sl_lim])  # velocity fields could be merged here once loaded

# Build the (non-periodic) grid object.
grid = define_grid(sea_level, dims, coords, distances, areas, periodic=False)

In [14]:
# Cache the merged grid fields; skip the write if the file already exists.
file_name = os.path.join(outdir, 'coordinates.nc')
if not os.path.exists(file_name):
    # Reuse file_name so the checked path and the written path cannot diverge.
    to_netcdf(coordinates, file_name)

TypeError Invalid value for attr 'time_bounds': <xarray.DataArray 'time_bounds' (time: 3, nv: 2)>
dask.array<open_dataset-58c27d34cd3c9dac8f214dabff64eef4time_bounds, shape=(3, 2), dtype=timedelta64[ns], chunksize=(1, 2), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 1958-01-16T12:00:00 ... 1958-03-16T12:00:00
  * nv       (nv) float64 1.0 2.0
Attributes:
    long_name:  time axis boundaries
    calendar:   GREGORIAN. For serialization to netCDF files, its value must be of one of the following types: str, Number, ndarray, number, list, tuple


## Calculate pressure

In [14]:
file_name = os.path.join(outdir, 'hydro.nc')
if os.path.exists(file_name):
    # Reuse the cached dataset. NOTE: a hydro.nc written before the Kelvin offset
    # below was corrected (273 -> 273.15) has CT 0.15 degC too warm. deltaD and
    # pressure inherit that error -- delete the file to regenerate it.
    hydro = xr.open_dataset(file_name)
else:
    # Pressure coordinate: sea pressure from height z = -st_ocean and latitude,
    # both broadcast to the shape of the cell-thickness field `dst`.
    lat_t_3d = lat_t.broadcast_like(coordinates.dst)
    z_3d = (-coordinates.st_ocean).broadcast_like(coordinates.dst)
    p_3d = xr.apply_ufunc(p_from_z, z_3d, lat_t_3d, dask='parallelized', output_dtypes=[z_3d.dtype])
    p_3d = p_3d.compute()

    # Absolute Salinity, treating the model salinity as practical salinity (SP)
    SA = xr.apply_ufunc(SA_from_SP, salt_lim, p_3d, lon_t, lat_t, dask='parallelized', output_dtypes=[z_3d.dtype])
    SA = SA.compute()
    SA.name = 'SA'
    SA.attrs = {'standard_name': 'sea_water_absolute_salinity', 'units': r'$\mathrm{gkg}^{-1}$'}

    # Conservative Temperature: the model field is in Kelvin and 0 degC = 273.15 K
    # (subtracting 273 would leave CT 0.15 degC too warm).
    CT = temp_lim - 273.15
    CT.name = 'CT'
    CT.attrs.update(units=r'$^\circ$C')

    # Hydrography dataset
    hydro = xr.merge([sea_level, SA, CT])

    # Dynamic Height anomaly: broadcast pressure/depth/latitude to SA's 4-D layout
    p_4d = p_3d.broadcast_like(hydro.SA)
    z_4d = z_3d.broadcast_like(hydro.SA)
    lat_t_4d = lat_t_3d.broadcast_like(hydro.SA)

    # NOTE: minus before geo_strf_dyn_height or deltaD is because of integrating downwards.
    # axis=1 is st_ocean, matching the (time, st_ocean, yt_ocean, xt_ocean) dims below.
    # NOTE(review): gsw.p_from_z expects geo_strf_dyn_height referenced to 0 dbar in
    # gsw's own sign convention -- confirm the negation is intended for that use.
    deltaD = - xr.DataArray(geo_strf_dyn_height(SA.values, CT.values, p_4d.values, p_ref=p_ref, axis=1),
                          coords = [hydro.time, hydro.st_ocean, hydro.yt_ocean, hydro.xt_ocean],
                          dims = ['time', 'st_ocean', 'yt_ocean', 'xt_ocean'],
                          name = 'deltaD',
                          attrs = {'standard_name': 'dynamic height anomaly', 'units': r'$\mathrm{m}^2\mathrm{s}^{-2}$'})

    # Pressure variable: sea pressure corrected by the dynamic height anomaly
    pressure = xr.apply_ufunc(p_from_z, z_4d, lat_t_4d, deltaD, dask='parallelized', output_dtypes=[deltaD.dtype])
    pressure = pressure.compute()
    pressure.name = 'pressure'
    pressure.attrs = {'standard_name': 'sea_water_pressure', 'units':'dbar'}

    hydro = xr.merge([hydro, deltaD, pressure])

    # Save hydrographic pressure dataset to file
    to_netcdf(hydro, file_name)

TypeError Invalid value for attr 'c_grid_axis_shift': None. For serialization to netCDF files, its value must be of one of the following types: str, Number, ndarray, number, list, tuple
