# SOSE SALT BUDGET

In [1]:
import xarray as xr
from matplotlib import pyplot as plt
import gcsfs
import dask
import dask.array as dsa
import numpy as np
import intake
%matplotlib inline



In [None]:
from dask_kubernetes import KubeCluster
from dask.distributed import Client
cluster = KubeCluster(n_workers=5, threads_per_worker=10)
cluster.adapt(minimum=1, maximum=10)
cluster

In [4]:
ocean_url = 'https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml'
ocean_cat = intake.Catalog(ocean_url)
ds = ocean_cat["SOSE"].to_dask()
ds

<xarray.Dataset>
Dimensions:   (XC: 2160, XG: 2160, YC: 320, YG: 320, Z: 42, Zl: 42, Zp1: 43, Zu: 42, time: 438)
Coordinates:
    Depth     (YC, XC) float32 dask.array<shape=(320, 2160), chunksize=(320, 2160)>
    PHrefC    (Z) float32 dask.array<shape=(42,), chunksize=(42,)>
    PHrefF    (Zp1) float32 dask.array<shape=(43,), chunksize=(43,)>
  * XC        (XC) float32 0.083333336 0.25 0.4166667 ... 359.75 359.9167
  * XG        (XG) float32 5.551115e-17 0.16666667 ... 359.6667 359.83334
  * YC        (YC) float32 -77.87497 -77.7083 -77.54163 ... -24.874966 -24.7083
  * YG        (YG) float32 -77.9583 -77.79163 -77.62497 ... -24.9583 -24.791632
  * Z         (Z) float32 -5.0 -15.5 -27.0 -39.5 ... -5075.0 -5325.0 -5575.0
  * Zl        (Zl) float32 0.0 -10.0 -21.0 -33.0 ... -4950.0 -5200.0 -5450.0
  * Zp1       (Zp1) float32 0.0 -10.0 -21.0 -33.0 ... -5200.0 -5450.0 -5700.0
  * Zu        (Zu) float32 -10.0 -21.0 -33.0 -46.0 ... -5200.0 -5450.0 -5700.0
    drC       (Zp1) float32 dask.ar

In [5]:
coords = ds.coords.to_dataset().reset_coords() #ds split into coords
dsr = ds.reset_coords(drop=True) #ds dropping coords so it just holds data variables
dsr

<xarray.Dataset>
Dimensions:   (XC: 2160, XG: 2160, YC: 320, YG: 320, Z: 42, Zl: 42, Zp1: 43, Zu: 42, time: 438)
Coordinates:
  * XC        (XC) float32 0.083333336 0.25 0.4166667 ... 359.75 359.9167
  * XG        (XG) float32 5.551115e-17 0.16666667 ... 359.6667 359.83334
  * YC        (YC) float32 -77.87497 -77.7083 -77.54163 ... -24.874966 -24.7083
  * YG        (YG) float32 -77.9583 -77.79163 -77.62497 ... -24.9583 -24.791632
  * Z         (Z) float32 -5.0 -15.5 -27.0 -39.5 ... -5075.0 -5325.0 -5575.0
  * Zl        (Zl) float32 0.0 -10.0 -21.0 -33.0 ... -4950.0 -5200.0 -5450.0
  * Zp1       (Zp1) float32 0.0 -10.0 -21.0 -33.0 ... -5200.0 -5450.0 -5700.0
  * Zu        (Zu) float32 -10.0 -21.0 -33.0 -46.0 ... -5200.0 -5450.0 -5700.0
  * time      (time) datetime64[ns] 2005-01-06 2005-01-11 ... 2010-12-31
Data variables:
    ADVr_SLT  (time, Zl, YC, XC) float32 dask.array<shape=(438, 42, 320, 2160), chunksize=(1, 42, 320, 2160)>
    ADVr_TH   (time, Zl, YC, XC) float32 dask.array<shap

In [6]:
import xgcm
grid = xgcm.Grid(ds, periodic=('XC', 'YC'))
grid

<xgcm.Grid>
Z Axis (not periodic):
  * center   Z --> left
  * left     Zl --> center
  * outer    Zp1 --> center
  * right    Zu --> center
T Axis (not periodic):
  * center   time
Y Axis (not periodic):
  * center   YC --> left
  * left     YG --> center
X Axis (not periodic):
  * center   XC --> left
  * left     XG --> center

## Salt Budget for Weddell Gyre

In [None]:
sss_mean = dsr.SALT.isel(Z=0).mean(dim='time').load()

In [None]:
surface_mask = (coords.hFacC[0]>0).load()

In [8]:
left_lon = 298
upper_lat = -65
right_lon = 360
lower_lat =-80

In [None]:
fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'facecolor': '0.5'})
(sss_mean.where(surface_mask)).plot(ax=ax, vmin=-2, vmax=10)
ax.plot([left_lon, right_lon], [upper_lat, upper_lat], color='w')
ax.plot([right_lon, right_lon], [lower_lat, upper_lat], color='w')
ax.plot([left_lon, left_lon], [lower_lat, upper_lat], color='w')
ax.plot([left_lon, right_lon], [lower_lat, lower_lat], color='w')
#ax.patch(xy=(-80, 290), width=70, height=20)
#ax.patch(xy=(295,-80), width = 1, height = 1)
ax.set_xlim([287, 362])
ax.set_ylim([-80, -55])
plt.title('Surface Mask @ surface\nLat[290,360]\nLon[-80,-63.5]');

In [None]:
#set the boundaries
lat_range = dict(YC=slice(lower_lat, upper_lat))
lon_range = dict(XC=slice(left_lon, right_lon))

In [9]:
dsr.ADVy_SLT.sel(YG=upper_lat, method='nearest')

<xarray.DataArray 'ADVy_SLT' (time: 438, Z: 42, XC: 2160)>
dask.array<shape=(438, 42, 2160), dtype=float32, chunksize=(1, 42, 2160)>
Coordinates:
  * XC       (XC) float32 0.083333336 0.25 0.4166667 ... 359.75 359.9167
    YG       float32 -64.9583
  * Z        (Z) float32 -5.0 -15.5 -27.0 -39.5 ... -5075.0 -5325.0 -5575.0
  * time     (time) datetime64[ns] 2005-01-06 2005-01-11 ... 2010-12-31
Attributes:
    long_name:      Meridional Advective Flux of Salinity
    mate:           ADVx_SLT
    standard_name:  ADVy_SLT
    units:          psu.m^3/s

In [None]:
#advective flux
adv_flux_y = dsr.ADVy_SLT.sel(**lon_range).sel(YG=upper_lat, method='nearest').sum(dim=['XC', 'Z']).load()
adv_flux_x = dsr.ADVx_SLT.sel(**lat_range).sel(XG=right_lon, method='nearest').sum(dim=['YC', 'Z']).load()
adv_flux = adv_flux_x + adv_flux_y
adv_flux.load()
#units: (˚C m^3)/s

In [None]:
yg_index = dsr.indexes['YG']
yg_index.get_loc(-60, method='nearest')

In [None]:
adv_flux_y

In [None]:
#diffusive flux
diff_flux_y = dsr.DFyE_SLT.sel(**lon_range).sel(YG=upper_lat, method='nearest').sum(dim=['XC', 'Z']).load()
diff_flux_x = dsr.DFxE_SLT.sel(**lat_range).sel(XG=360, method='nearest').sum(dim=['YC', 'Z']).load()
diff_flux = diff_flux_x + diff_flux_y
diff_flux.load()
#units: (˚C m^3)/s

In [None]:
#vertical flux
s_flux_z = (dsr.SFLUX.sel(**lat_range, **lon_range) * coords.rA).sum(dim=['XC','YC'])
s_flux_z.load()
#units: W = (kg*m^2)/s^3

In [None]:
coords.rA.sel(YC=-63.5, method='nearest')

In [13]:
dsr.WSLTMASS.sel(YC=-63.4583, method='nearest')

<xarray.DataArray 'WSLTMASS' (time: 438, Zl: 42, XC: 2160)>
dask.array<shape=(438, 42, 2160), dtype=float32, chunksize=(1, 42, 2160)>
Coordinates:
  * XC       (XC) float32 0.083333336 0.25 0.4166667 ... 359.75 359.9167
    YC       float32 -63.541634
  * Zl       (Zl) float32 0.0 -10.0 -21.0 -33.0 ... -4950.0 -5200.0 -5450.0
  * time     (time) datetime64[ns] 2005-01-06 2005-01-11 ... 2010-12-31
Attributes:
    long_name:      Vertical Mass-Weight Transp of Salinity
    standard_name:  WSLTMASS
    units:          psu.m/s

In [None]:
#linear free surface correction
heat_capacity_cp = 3.994e3 #J/kg*˚C
runit2mass = 1.035e3 #kg/m^3
surface_term = s_flux_z / (heat_capacity_cp * runit2mass)
#units: (˚C m^3)/s

lin_fs_correction = -(dsr.WSLTMASS.isel(Zl=0, drop=True) * coords.rA
                     ).sel(YC=slice(lower_lat, upper_lat)).sum(dim=['XC', 'YC'])
#lin_fs_correction longitude window is slightly different than other terms
#units: (˚C m^3)/s

In [None]:
surface_term.load()
lin_fs_correction.load()

In [None]:
surface_term.plot()
lin_fs_correction.plot()

In [None]:
#is it hFacC or hFacS
total_volume = (coords.rA + coords.drF + coords.hFacC).sum().load()
total_volume

In [15]:
dsr.TOTSTEND.sel(YC=upper_lat, method='nearest')

<xarray.DataArray 'TOTSTEND' (time: 438, Z: 42, XC: 2160)>
dask.array<shape=(438, 42, 2160), dtype=float32, chunksize=(1, 42, 2160)>
Coordinates:
  * XC       (XC) float32 0.083333336 0.25 0.4166667 ... 359.75 359.9167
    YC       float32 -65.04163
  * Z        (Z) float32 -5.0 -15.5 -27.0 -39.5 ... -5075.0 -5325.0 -5575.0
  * time     (time) datetime64[ns] 2005-01-06 2005-01-11 ... 2010-12-31
Attributes:
    long_name:      Tendency of Salinity
    standard_name:  TOTSTEND
    units:          psu/day

In [None]:
tot_s_tend_weddell = (dsr.TOTSTEND * coords.rA * coords.drF * coords.hFacC
          ).sel(**lon_range, YC=slice(lower_lat, upper_lat)).sum(dim=['XC', 'YC', 'Z']) / (24*60*60)
tot_s_tend_weddell.load()
#same longitudinal window as lin_fs_correction (63.541634)
#units: (˚C m^3)/s

In [None]:
adv_flux_y.plot(label='meridional')
adv_flux_x.plot(label='zonal')
plt.legend();

In [None]:
rhs = + adv_flux + diff_flux + lin_fs_correction + surface_term
lhs = tottend_weddell
rhs.plot()
lhs.plot()

In [None]:
(rhs - lhs).plot()

In [None]:
(rhs - lhs).mean().load()

In [None]:
all_terms = xr.merge([tot_s_tend_weddell.rename('tottend'),
                      adv_flux.rename('adv_flux'),
                      diff_flux.rename('diff_flux'),
                      lin_fs_correction.rename('lin_fs'),
                      surface_term.rename('surface')])
all_terms

In [None]:
all_terms.to_array().plot.line(x='time')

In [None]:
df = all_terms.mean(dim='time').reset_coords(drop=True).to_array().to_dataframe(name='budget')
df

In [None]:
rhs_minus_lhs = df.iloc[0] - df.iloc[1:].sum()
rhs_minus_lhs

In [None]:
budget_diff = rhs.sum() + lhs.sum()
budget_diff.load()

In [None]:
df.plot(kind='bar')

In [None]:
tot_s_tend_weddell.plot(label='dT/dt')
adv_flux.plot(label='advective term')
surface_term.plot(label='surface term')
lin_fs_correction.plot(label='lin-surf-correction')
diff_flux.plot(label='diffusive term')
plt.legend();

In [None]:
adv_flux_y.plot()

In [None]:
(surface_term - tot_s_tend_weddell).plot()
adv_flux_y.plot(label='adv flux y')

In [None]:
test = surface_term - tot_s_tend_weddell
test

In [None]:
dsr.TOTSTEND