# Zarr SOSE ds - cleaned version

In [3]:
from dask_gateway import GatewayCluster

# Start a Dask Gateway cluster that adapts between 2 and 10 workers, and
# connect a distributed client to it (used implicitly by later .load()/compute calls).
cluster = GatewayCluster()
cluster.adapt(minimum=2, maximum=10)  # or cluster.scale(n) to a fixed size.
client = cluster.get_client()

In [5]:
# Show the client summary (scheduler address, dashboard link, workers/cores/memory).
client

0,1
Client  Scheduler: gateway://traefik-gcp-uscentral1b-prod-dask-gateway.prod:80/prod.40d1a90c120b41fc84a4f6a8d9b8425d  Dashboard: /services/dask-gateway/clusters/prod.40d1a90c120b41fc84a4f6a8d9b8425d/status,Cluster  Workers: 2  Cores: 4  Memory: 8.59 GB


In [6]:
import xarray as xr
from matplotlib import pyplot as plt
import gcsfs
import dask
import dask.array as dsa
import numpy as np
import fsspec
%matplotlib inline

In [9]:
# Render xarray objects as plain text instead of the interactive HTML repr.
xr.set_options(display_style='text')

<xarray.core.options.set_options at 0x7fcd0d99b090>

## Open SOSE Dataset from the Cloud

In [10]:
# Open the SOSE output stored as a consolidated Zarr store on GCS.
# requester_pays=True is needed because the bucket bills reads to the requester.
fs = gcsfs.GCSFileSystem(requester_pays=True)

mapping = fs.get_mapper('gcs://pangeo-ecco-sose')

# Lazy (dask-backed) open; consolidated=True reads all metadata in one request.
ds = xr.open_zarr(mapping, consolidated=True)
ds

In [11]:
# Keep the non-index coordinates (grid metrics such as rA, drF, hFacC) as data
# variables of their own Dataset, and work on a copy of the data without them.
coords = ds.coords.to_dataset().reset_coords()
dsr = ds.reset_coords(drop=True)
dsr

## Create xgcm grid

[Xgcm](http://xgcm.readthedocs.io) is a package which helps with the analysis of GCM data.

In [12]:
import xgcm
# Build the xgcm Grid from the dataset's axis metadata; X and Y are periodic,
# Z and T are not (see the repr below).
grid = xgcm.Grid(ds, periodic=('X', 'Y'))
grid

<xgcm.Grid>
Z Axis (not periodic, boundary=None):
  * center   Z --> left
  * left     Zl --> center
  * outer    Zp1 --> center
  * right    Zu --> center
T Axis (not periodic, boundary=None):
  * center   time
Y Axis (periodic, boundary=None):
  * center   YC --> left
  * left     YG --> center
X Axis (periodic, boundary=None):
  * center   XC --> left
  * left     XG --> center

## Tracer Budgets

Here we will do the heat and salt budgets for SOSE. In integral form, these budgets can be written as

$$
\mathcal{V} \frac{\partial S}{\partial t} = G^S_{adv} + G^S_{diff} + G^S_{surf} + G^S_{linfs}
$$


$$
\mathcal{V} \frac{\partial \theta}{\partial t} = G^\theta_{adv} + G^\theta_{diff} + G^\theta_{surf} + G^\theta_{linfs} + G^\theta_{sw}
$$

where $\mathcal{V}$ is the volume of the grid cell. The terms on the right-hand side are called _tendencies_. They add up to the total tendency (the left hand side).

The first term is the convergence of advective fluxes. The second is the convergence of diffusive fluxes. The third is the explicit surface flux. The fourth is the correction due to the linear free-surface approximation. The fifth is shortwave penetration (only for temperature).

### Flux Divergence

First we define a function to calculate the convergence of the advective and diffusive fluxes, since this has to be repeated for both tracers.

In [13]:
def tracer_flux_budget(suffix, include_total=False):
    """Compute the advective and diffusive flux convergences of one tracer.

    Parameters
    ----------
    suffix : str
        MITgcm diagnostic suffix of the tracer, ``'TH'`` or ``'SLT'``. It selects
        ``ADVx_<suffix>``, ``DFxE_<suffix>``, ``ADVr_<suffix>``, ... from the
        notebook-level dataset ``dsr``.
    include_total : bool, optional
        If True, also return ``conv_total_flux_<suffix>``, the sum of the four
        convergence terms. Default False (previous behavior).

    Returns
    -------
    xarray.Dataset
        Variables ``conv_horiz_adv_flux_<suffix>``, ``conv_horiz_diff_flux_<suffix>``,
        ``conv_vert_adv_flux_<suffix>`` and ``conv_vert_diff_flux_<suffix>``
        (plus ``conv_total_flux_<suffix>`` when ``include_total`` is True).
        The individual terms are kept separate so each can be inspected.

    Notes
    -----
    Uses the notebook-level globals ``grid`` (an ``xgcm.Grid``) and ``dsr``.
    """
    def diag(prefix):
        # e.g. diag('ADVx') -> dsr['ADVx_TH'] when suffix == 'TH'
        return dsr[prefix + '_' + suffix]

    # Horizontal convergence = minus the divergence of the face fluxes.
    conv_horiz_adv_flux = -(grid.diff(diag('ADVx'), 'X') +
                            grid.diff(diag('ADVy'), 'Y')).rename('conv_horiz_adv_flux_' + suffix)
    conv_horiz_diff_flux = -(grid.diff(diag('DFxE'), 'X') +
                             grid.diff(diag('DFyE'), 'Y')).rename('conv_horiz_diff_flux_' + suffix)

    # Sign convention is opposite for vertical fluxes, hence +diff here
    # (presumably: vertical fluxes are positive upward while the Z index
    # increases with depth). boundary='fill' pads the face beyond the last
    # level with xgcm's fill value.
    conv_vert_adv_flux = grid.diff(diag('ADVr'), 'Z', boundary='fill').rename('conv_vert_adv_flux_' + suffix)
    # Vertical diffusion is split over three diagnostics (presumably explicit,
    # implicit and KPP non-local fluxes).
    conv_vert_diff_flux = (grid.diff(diag('DFrE'), 'Z', boundary='fill') +
                           grid.diff(diag('DFrI'), 'Z', boundary='fill') +
                           grid.diff(diag('KPPg'), 'Z', boundary='fill')).rename('conv_vert_diff_flux_' + suffix)

    all_fluxes = [conv_horiz_adv_flux, conv_horiz_diff_flux,
                  conv_vert_adv_flux, conv_vert_diff_flux]
    if include_total:
        all_fluxes.append(sum(all_fluxes).rename('conv_total_flux_' + suffix))
    return xr.merge(all_fluxes)

In [14]:
# sum of all converging adv/diff fluxes
budget_slt = tracer_flux_budget('SLT')
budget_slt

In [15]:
# Same flux-convergence terms for temperature.
budget_th = tracer_flux_budget('TH')
budget_th

In [16]:
# "True" total tendency from the model's own diagnostics, to compare against
# the sum of the individual budget terms.

# Grid-cell volume V from the budget equation: layer thickness x cell area x
# hFacC (presumably the open fraction of the cell for partial/dry cells).
volume = coords.drF * coords.rA * coords.hFacC
day2seconds = 24 * 60 * 60  # seconds in one day

# Multiply by the cell volume and divide by seconds per day (the diagnostics are
# presumably per-day rates) to get volume-integrated per-second tendencies.
budget_th['total_tendency_TH_truth'] = dsr.TOTTTEND * volume / day2seconds
budget_slt['total_tendency_SLT_truth'] = dsr.TOTSTEND * volume / day2seconds

In [17]:
# Heat budget so far: flux convergences + truth tendency.
budget_th

In [18]:
# Salt budget so far: flux convergences + truth tendency.
budget_slt

### Surface Fluxes

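The surface fluxes are only active in the top model layer. We need some constants to convert them to the proper units, and the cell area `rA` to convert them to integral form. Because these terms have no vertical dimension, they also need special handling when combined with the 3D terms. They must be assigned to the top model level only; otherwise xarray broadcasting adds them to every level.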

In [19]:
# Physical constants
heat_capacity_cp = 3.994e3  # heat capacity of seawater (presumably J/(kg K))
runit2mass = 1.035e3        # reference density rho (presumably kg/m^3)

# --- heat: volume-integrated temperature fluxes, units (˚C m^3)/s -----------
# The shortwave part (oceQsw) is removed from TFLUX because it penetrates below
# the surface and is distributed over depth separately (next section).
surf_flux_th = (dsr.TFLUX - dsr.oceQsw) * coords.rA / (heat_capacity_cp * runit2mass)
surf_flux_th_sw = dsr.oceQsw * coords.rA / (heat_capacity_cp * runit2mass)
# Linear free-surface correction, taken at the surface interface (Zl index 0).
lin_fs_correction_th = -(dsr.WTHMASS.isel(Zl=0, drop=True) * coords.rA)

# --- salt ---------------------------------------------------------------------
surf_flux_slt = dsr.SFLUX * coords.rA / runit2mass
lin_fs_correction_slt = -(dsr.WSLTMASS.isel(Zl=0, drop=True) * coords.rA)

# These are surface-only terms with no Z dimension; when they are combined
# with the 3D budget terms they must be assigned to the top model level only.

### Shortwave Flux

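Special treatment is needed for the shortwave flux, which penetrates into the interior of the water column. A depth-dependent fraction (`swfrac`) of the surface shortwave flux reaches each cell interface. The convergence of this flux between the top and bottom interfaces of a cell heats that cell.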

In [20]:
def swfrac(coords, fact=1., jwtype=2, max_depth=200.):
    """Fraction of the surface shortwave flux reaching each cell interface.

    Clone of the MITgcm routine for shortwave penetration: a two-band
    exponential decay ``rfac*exp(z/a1) + (1-rfac)*exp(z/a2)`` evaluated at the
    interface depths ``coords.Zl`` (z <= 0). The fraction equals 1 at z = 0.

    Parameters
    ----------
    coords : xarray.Dataset
        Grid coordinates; only the interface depths ``Zl`` are used.
    fact : float, optional
        Multiplier applied to the depths before evaluating the decay.
    jwtype : int, optional
        Water type, 1-5, selecting one of the coefficient sets below
        (presumably the Jerlov water types used by MITgcm). Default 2.
    max_depth : float, optional
        Only interfaces with ``0 >= Zl >= -max_depth`` are returned. Deeper
        interfaces are omitted, and callers treat them as receiving no shortwave.
        Default 200, the previously hard-coded cutoff.

    Returns
    -------
    xarray.DataArray
        Shortwave fraction named ``'swdk'`` on the selected ``Zl`` levels.

    Raises
    ------
    ValueError
        If ``jwtype`` is not between 1 and 5.
    """
    rfac = [0.58, 0.62, 0.67, 0.77, 0.78]
    a1 = [0.35, 0.6, 1.0, 1.5, 1.4]
    a2 = [23.0, 20.0, 17.0, 14.0, 7.9]

    # jwtype is 1-based. Without this check, jwtype=0 would silently pick the
    # last coefficient set (negative indexing) and jwtype > 5 would raise IndexError.
    if not 1 <= jwtype <= len(rfac):
        raise ValueError(f"jwtype must be between 1 and {len(rfac)}, got {jwtype!r}")
    j = jwtype - 1

    facz = fact * coords.Zl.sel(Zl=slice(0, -max_depth))
    swdk = rfac[j] * np.exp(facz / a1[j]) + (1 - rfac[j]) * np.exp(facz / a2[j])
    return swdk.rename('swdk')

# Downward shortwave flux at each interface. swfrac() only covers interfaces
# down to 200 m, so a left join onto the full Zl axis leaves NaN at deeper
# interfaces; those receive no shortwave, so they are zero-filled.
_, swdown = xr.align(
    dsr.Zl,
    surf_flux_th_sw * swfrac(coords),
    join='left',
)
swdown = swdown.fillna(0)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


In [21]:
# Add the surface-forcing terms to the heat budget, then rechunk so each chunk
# holds one complete time step (whole XC/YC/Z extent).
budget_th = xr.merge([
    budget_th,
    surf_flux_th.rename('surface_flux_conv_TH'),
    lin_fs_correction_th.rename('lin_fs_correction_TH'),
    # convergence of the penetrating shortwave flux between adjacent interfaces
    (-grid.diff(swdown, 'Z', boundary='fill').fillna(0.)).rename('sw_flux_conv_TH'),
]).chunk({'XC': -1, 'YC': -1, 'Z': -1, 'time': 1})
budget_th

In [22]:
# Add the surface-forcing terms to the salt budget (there is no shortwave term
# for salt), then rechunk to one complete time step per chunk.
budget_slt = xr.merge([
    budget_slt,
    surf_flux_slt.rename('surface_flux_conv_SLT'),
    lin_fs_correction_slt.rename('lin_fs_correction_SLT'),
]).chunk({'XC': -1, 'YC': -1, 'Z': -1, 'time': 1})
budget_slt

## Saving Data

In [16]:
# Interactive OAuth login (prints a URL and prompts for an authorization code),
# then build a second filesystem from the resulting token.
# NOTE(review): presumably the explicit token (unlike token='browser') is needed
# so the Dask workers can authenticate for the writes below; confirm.
# `tokens` is not part of the public gcsfs API in all versions, so this may break
# on upgrade. Never print or save the token or the authorization code.
gcfs_auth = gcsfs.GCSFileSystem(project='pangeo-181919', token='browser')
token = gcfs_auth.tokens[('pangeo-181919', 'full_control')]
gcfs_w_token = gcsfs.GCSFileSystem(project='pangeo-181919', token=token)

Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=586241054156-9kst7ltfj66svc342pcn43vp6ta3idin.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&state=z6dpw8OLC8MuUAWuNE0iuUw9iIQ7dA&prompt=consent&access_type=offline


Enter the authorization code:  [REDACTED]


In [17]:
gcsmap_slt = gcfs_w_token.get_mapper('pangeo-tmp/stb2145/SOSE/budget_slt.zarr')  # key-value mapper for the Zarr store the salt budget is written to
gcsmap_slt

<fsspec.mapping.FSMap at 0x7f9eec5f8f10>

In [18]:
# Save the salt budget; mode='w' overwrites any existing store at this path.
budget_slt.to_zarr(gcsmap_slt, mode='w')

<xarray.backends.zarr.ZarrStore at 0x7f9eeeb32a10>

In [19]:
# Key-value mapper for the Zarr store the heat budget is written to.
gcsmap_th = gcfs_w_token.get_mapper('pangeo-tmp/stb2145/SOSE/budget_th.zarr')
gcsmap_th

<fsspec.mapping.FSMap at 0x7f9eecb68810>

In [20]:
# Save the heat budget; mode='w' overwrites any existing store at this path.
budget_th.to_zarr(gcsmap_th, mode='w')

<xarray.backends.zarr.ZarrStore at 0x7f9ee4834050>

## WG entire column

In [23]:
# Positional (isel) index bounds of the WG box.
# NOTE(review): the lat/lon coordinates these indices correspond to are not shown
# in this notebook (the referenced cell seems to be missing); record them here.
lower_lat_lower = 0
upper_lat_lower = 71
left_lon_lower = 1794
right_lon_lower = 2159

# NOTE(review): `lower_cell_range` is not used below. As a list-based isel it
# would also require 42 to be a valid Zl index; confirm the number of Zl levels.
lower_cell_range = dict(Zl=[27, 42])
# Interior levels for the "WG Interior" section: Z indices 27..41 (slice end is exclusive).
lower_cell_range_z = dict(Z=slice(27, 42))

In [24]:
#set the boundaries
# isel keyword arguments (positional index slices, end-exclusive) bounding the WG box.
lat_range_lower = {'YC': slice(lower_lat_lower, upper_lat_lower)}
lon_range_lower = {'XC': slice(left_lon_lower, right_lon_lower)}

**Open question:** Since the budget is evaluated at the level of individual grid points, is a separate vertical range (`lower_cell_range_z`) still needed to define the interior?

In [25]:
# Restrict the salt budget to the WG box (all depths).
budget_slt_wg = budget_slt.isel(**lat_range_lower, **lon_range_lower)

In [26]:
# Restrict the heat budget to the WG box (all depths).
budget_th_wg = budget_th.isel(**lat_range_lower, **lon_range_lower)
budget_th_wg

In [27]:
# Right-hand side of the heat budget equation: advective + diffusive flux
# convergences (horizontal AND vertical) + surface flux + linear free-surface
# correction + shortwave penetration.
#
# The surface flux and free-surface correction have no Z dimension. Adding them
# directly to the 3D terms would broadcast them onto every level, so they would
# be counted once per level in a column sum and leak into interior sub-volumes.
# They are therefore placed in the top model level (Z index 0) only.
budget_th_wg['total_tendency_TH'] = (
    budget_th_wg.conv_horiz_adv_flux_TH
    + budget_th_wg.conv_horiz_diff_flux_TH
    + budget_th_wg.conv_vert_adv_flux_TH
    + budget_th_wg.conv_vert_diff_flux_TH
    + budget_th_wg.sw_flux_conv_TH
    + (budget_th_wg.surface_flux_conv_TH + budget_th_wg.lin_fs_correction_TH).where(
        budget_th_wg.Z == budget_th_wg.Z.values[0], 0.)
)

In [28]:
# Right-hand side of the salt budget equation: advective + diffusive flux
# convergences (horizontal AND vertical) + surface flux + linear free-surface
# correction.
#
# The surface flux and free-surface correction have no Z dimension. Adding them
# directly to the 3D terms would broadcast them onto every level, so they are
# placed in the top model level (Z index 0) only.
budget_slt_wg['total_tendency_SLT'] = (
    budget_slt_wg.conv_horiz_adv_flux_SLT
    + budget_slt_wg.conv_horiz_diff_flux_SLT
    + budget_slt_wg.conv_vert_adv_flux_SLT
    + budget_slt_wg.conv_vert_diff_flux_SLT
    + (budget_slt_wg.surface_flux_conv_SLT + budget_slt_wg.lin_fs_correction_SLT).where(
        budget_slt_wg.Z == budget_slt_wg.Z.values[0], 0.)
)

In [29]:
# Heat budget for the WG box, including the total (RHS) tendency.
budget_th_wg

In [30]:
# NOTE(review): leftover interactive help lookup; dask is already imported in
# the imports cell. Remove this cell before sharing the notebook.
import dask
dask.compute?

[0;31mSignature:[0m [0mdask[0m[0;34m.[0m[0mcompute[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Compute several dask collections at once.

Parameters
----------
args : object
    Any number of objects. If it is a dask object, it's computed and the
    result is returned. By default, python builtin collections are also
    traversed to look for dask objects (for more information see the
    ``traverse`` keyword). Non-dask arguments are passed through unchanged.
traverse : bool, optional
    By default dask traverses builtin python collections looking for dask
    objects passed to ``compute``. For large collections this can be
    expensive. If none of the arguments contain any dask objects, set
    ``traverse=False`` to avoid doing this traversal.
scheduler : string, optional
    Which scheduler to use like "threads", "synchronous" or "processes".
    If not provided, the default is to

In [None]:
# NOTE(review): debugging check that reduces only the horizontal diffusive heat term.
# dask.compute returns a tuple, so `tmp` is a 1-tuple. retries=5 is presumably
# forwarded to the distributed scheduler to retry failed tasks.
tmp = dask.compute(budget_th_wg.conv_horiz_diff_flux_TH.sum(dim=['XC', 'YC', 'Z']), retries=5)

In [31]:
# RHS (sum of budget terms) integrated over the WG box and all levels -> time series.
# Expensive: reduces the full 3D field; a previous run of this cell was interrupted.
rhs_th = budget_th_wg['total_tendency_TH'].sum(['XC', 'YC', 'Z']).load()

KeyboardInterrupt: 

In [None]:
# LHS (model truth tendency) integrated over the same volume.
lhs_th = budget_th_wg['total_tendency_TH_truth'].sum(['XC', 'YC', 'Z']).load()

In [None]:
# Closure check: the summed terms (rhs) should track the model truth (lhs).
fig, ax = plt.subplots(figsize=(12, 6))
rhs_th.plot(ax=ax, label='rhs')
lhs_th.plot(ax=ax, label='lhs')
ax.set_title('Heat budget, WG full column: sum of terms (rhs) vs. TOTTTEND (lhs)')
ax.set_ylabel('(˚C m$^3$)/s')
ax.legend();

In [None]:
# Salt budget for the WG box, including the total (RHS) tendency.
budget_slt_wg

In [None]:
# Salt: RHS (sum of terms) and LHS (model truth) integrated over the WG box and all levels.
rhs_slt, lhs_slt = (
    budget_slt_wg[name].sum(['XC', 'YC', 'Z']).load()
    for name in ('total_tendency_SLT', 'total_tendency_SLT_truth')
)

In [None]:
# Closure check for salt over the full column.
fig, ax = plt.subplots(figsize=(12, 6))
rhs_slt.plot(ax=ax, label='rhs')
lhs_slt.plot(ax=ax, label='lhs')
ax.set_title('Salt budget, WG full column: sum of terms (rhs) vs. TOTSTEND (lhs)')
ax.legend();

## WG Interior

In [None]:
# Interior sub-volume: keep only the Z levels in lower_cell_range_z, then integrate.
rhs_th_lower, lhs_th_lower = (
    budget_th_wg[name].isel(**lower_cell_range_z).sum(['XC', 'YC', 'Z']).load()
    for name in ('total_tendency_TH', 'total_tendency_TH_truth')
)

In [None]:
# Closure check for heat in the interior sub-volume.
fig, ax = plt.subplots(figsize=(12, 6))
rhs_th_lower.plot(ax=ax, label='rhs')
lhs_th_lower.plot(ax=ax, label='lhs')
ax.set_title('Heat budget, WG interior: sum of terms (rhs) vs. TOTTTEND (lhs)')
ax.set_ylabel('(˚C m$^3$)/s')
ax.legend();

In [None]:
# Interior sub-volume for salt: same Z levels as the heat case.
rhs_slt_lower, lhs_slt_lower = (
    budget_slt_wg[name].isel(**lower_cell_range_z).sum(['XC', 'YC', 'Z']).load()
    for name in ('total_tendency_SLT', 'total_tendency_SLT_truth')
)

In [None]:
# Closure check for salt in the interior sub-volume.
fig, ax = plt.subplots(figsize=(12, 6))
rhs_slt_lower.plot(ax=ax, label='rhs')
lhs_slt_lower.plot(ax=ax, label='lhs')
ax.set_title('Salt budget, WG interior: sum of terms (rhs) vs. TOTSTEND (lhs)')
ax.legend();