In [55]:
import xarray as xr
import xmitgcm
import xgcm
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import ipywidgets
import dask
import cmocean
import pandas as pd

In [2]:
# Root directory of the model run. Later cells pass it as grid_dir and append
# '/Diags/' to it for data_dir; note the value already ends in '/'.
model_dir = '/g/data/jk72/ed7737/SO-channel_embayment/simulations/run/'

In [73]:
# Open the '2D_diags' diagnostics found under <model_dir>/Diags/, with the grid
# files read from model_dir. delta_t, calendar and ref_date are handed to
# xmitgcm for building the time coordinate.
ds_2d = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['2D_diags'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

In [74]:
# Open the 'seaIceDiag' diagnostics found under <model_dir>/Diags/, with the
# grid files read from model_dir. Time-axis arguments match the other datasets.
ds_seaice = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['seaIceDiag'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

In [75]:
# Open the 'Momentum_diags' diagnostics found under <model_dir>/Diags/, with the
# grid files read from model_dir. Time-axis arguments match the other datasets.
ds_momentum = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['Momentum_diags'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

In [76]:
# Open the 'EXF_diags' diagnostics found under <model_dir>/Diags/, with the
# grid files read from model_dir. Time-axis arguments match the other datasets.
ds_EXF = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['EXF_diags'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

In [77]:
# Open the 'layDiag' (layers) diagnostics found under <model_dir>/Diags/.
ds_layers = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['layDiag'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

# Edges of the layers (51 values, hence 50 layers).
# NOTE(review): presumably these mirror the layer bounds in the model's
# layers configuration -- confirm against the run's input files.
layer_bounds = np.array([
    33.0, 33.5, 34.0, 34.2,
    34.4, 34.6, 34.8, 35.0, 35.2,
    35.3, 35.4, 35.5, 35.6, 35.7,
    35.8, 35.9, 36.0, 36.1, 36.2,
    36.3, 36.4, 36.5, 36.6, 36.7,
    36.75,
    36.80, 36.84, 36.88, 36.92, 36.96,
    37.00, 37.04, 37.08, 37.12, 37.16,
    37.20, 37.24, 37.28, 37.32, 37.36,
    37.40, 37.44, 37.48, 37.52, 37.56,
    37.60, 37.64, 37.68, 37.72, 37.76,
    37.80,
])
# Centre of each layer: mean of its two neighbouring edges.
layer_midpoints = (layer_bounds[1:] + layer_bounds[:-1]) / 2.

# Give the dimension xmitgcm left as '_UNKNOWN_' a meaningful name and label
# it with the layer midpoints.
ds_layers = (
    ds_layers
    .rename_dims({'_UNKNOWN_': 'layer_pot_dens'})
    .assign_coords(layer_pot_dens=layer_midpoints)
)


In [None]:
# Open the 'state' diagnostics found under <model_dir>/Diags/, with the grid
# files read from model_dir. Time-axis arguments match the other datasets.
ds_state = xmitgcm.open_mdsdataset(
    data_dir=model_dir + '/Diags/',
    grid_dir=model_dir,
    prefix=['state'],
    delta_t=500,
    calendar='360_day',
    ref_date='2000-1-1 0:0:0',
)

In [None]:
# Vertical cell sizes: the nominal thickness drF scaled by the hFac factor at
# each kind of point.
ds_state['drW'] = ds_state['hFacW'] * ds_state['drF']  # at u points
ds_state['drS'] = ds_state['hFacS'] * ds_state['drF']  # at v points
ds_state['drC'] = ds_state['hFacC'] * ds_state['drF']  # at tracer points

# Metric variables xgcm should use for each axis (or pair of axes).
metrics = {
    ('X',): ['dxC', 'dxG'],                   # X distances
    ('Y',): ['dyC', 'dyG'],                   # Y distances
    ('Z',): ['drW', 'drS', 'drC'],            # Z distances
    ('X', 'Y'): ['rA', 'rAz', 'rAs', 'rAw'],  # Areas
}

# Grid object built from ds_state; only the X axis is declared periodic.
grid = xgcm.Grid(ds_state, periodic=['X'], metrics=metrics)
grid

In [None]:
# Display the 2D-diagnostics dataset (variables, dimensions, coordinates).
ds_2d

In [None]:
# TFLUX at the last time of the 2D diagnostics.
(
    ds_2d['TFLUX']
    .sel(time=ds_2d['time'][-1])
    .plot()
)

In [None]:
# THETA on the level nearest Z=0 at the last state time.
(
    ds_state['THETA']
    .sel(Z=0, method='nearest')
    .sel(time=ds_state['time'][-1], method='nearest')
    .plot()
)

In [None]:
def plot_temperature(i, xloc):
    """Draw a two-panel THETA figure for time index ``i``.

    Left panel: THETA on the section at the XC nearest ``xloc``, with the
    negated MXLDEPTH along that section overlaid as a black line.
    Right panel: MXLDEPTH contours, THETA on the level nearest Z=0, and
    SI_Fract wherever it exceeds 0.15 (drawn on top via ``zorder=3``).

    Parameters
    ----------
    i : int
        Position in the ``time`` coordinate; the same position is used for
        ds_state, ds_2d and ds_seaice.
    xloc : float
        XC value of the section; the nearest grid column is used.
    """
    fig, (section_ax, map_ax) = plt.subplots(1, 2, figsize=(13, 4))

    theta = ds_state['THETA']
    mld = ds_2d['MXLDEPTH']
    theta_style = dict(vmin=-2, vmax=15, cmap=cmocean.cm.thermal)
    state_time = ds_state['time'][i]
    diag_time = ds_2d['time'][i]

    # --- left panel: section at xloc ---
    theta_section = theta.sel(XC=xloc, method='nearest').sel(time=state_time, method='nearest')
    theta_section.plot(ax=section_ax, **theta_style)
    mld_section = mld.sel(XC=xloc, method='nearest').sel(time=diag_time)
    (-mld_section).plot(color='k', ax=section_ax)

    # --- right panel: plan view (drawing order matters for layering) ---
    mld_levels = [1, 10, 50,
                  100, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
                  1200, 1400, 1600, 1800, 2000]
    mld.sel(time=diag_time).plot.contour(ax=map_ax, levels=mld_levels)
    theta_surface = theta.sel(Z=0, method='nearest').sel(time=state_time, method='nearest')
    theta_surface.plot(ax=map_ax, **theta_style)
    ice_fraction = ds_seaice['SI_Fract'].sel(time=ds_seaice['time'][i])
    ice_fraction.where(ice_fraction > 0.15).plot(cmap='binary_r', ax=map_ax, vmin=0, vmax=1, zorder=3)


    


In [None]:
# Largest time index that is valid for every dataset plot_temperature indexes
# by position. The callback reads ds_state, ds_2d AND ds_seaice at index i, so
# all three must bound the slider; otherwise a shorter sea-ice record would
# raise an IndexError when the slider is moved to its upper end.
t_max = min(
    ds_state['time'].shape[0],
    ds_2d['time'].shape[0],
    ds_seaice['time'].shape[0],
) - 1

# Interactive browser: i picks the time index (starting at the latest common
# one), xloc picks the X position of the section in 10e3 steps.
ipywidgets.interactive(plot_temperature,
                       i=ipywidgets.IntSlider(value=t_max, min=0, max=t_max),
                       xloc=ipywidgets.FloatSlider(value=400e3, min=0, max=1e6, step=10e3))

In [None]:
# Average of THETA over the X and Y axes, keeping only cells where SALT > 0.
# NOTE(review): presumably SALT > 0 is a wet-cell mask -- confirm. Despite
# the name, `heat` holds an averaged temperature, not a heat content.
heat = grid.average(
    ds_state['THETA'].where(ds_state['SALT'] > 0),
    ['X', 'Y'],
)

In [None]:
# One curve per target depth, each taken at the model level nearest that Z.
for depth in (0, -1000, -2000, -3000, -4000):
    heat.sel(Z=depth, method='nearest').plot(label=f'Z={depth}')

plt.legend()

In [None]:
# Change in temperature: THETA at the last state time minus THETA at the first,
# on the section nearest xloc, with the negated final MXLDEPTH drawn in black.
xloc = 400e3
(
    ds_state['THETA'].sel(XC=xloc, method='nearest').sel(time=ds_state['time'][-1], method='nearest')
    - ds_state['THETA'].sel(XC=xloc, method='nearest').sel(time=ds_state['time'][0], method='nearest')
).plot(vmin=-3, vmax=3, cmap='RdBu_r')
(
    -ds_2d['MXLDEPTH']
    .sel(XC=xloc, method='nearest')
    .sel(time=ds_2d['time'][-1])
).plot(color='k')

In [None]:
def plot_salinity(i, xloc):
    """Draw a two-panel SALT figure for time index ``i``.

    Left panel: SALT on the section at the XC nearest ``xloc``, with the
    negated MXLDEPTH along that section overlaid as a black line.
    Right panel: MXLDEPTH contours, SALT on the level nearest Z=0, and
    SI_Fract wherever it exceeds 0.15 (drawn on top via ``zorder=3``).

    Parameters
    ----------
    i : int
        Position in the ``time`` coordinate; the same position is used for
        ds_state, ds_2d and ds_seaice.
    xloc : float
        XC value of the section; the nearest grid column is used.
    """
    fig, (section_ax, map_ax) = plt.subplots(1, 2, figsize=(13, 4))

    salt = ds_state['SALT']
    mld = ds_2d['MXLDEPTH']
    salt_style = dict(vmin=34, vmax=35, cmap=cmocean.cm.haline)
    state_time = ds_state['time'][i]
    diag_time = ds_2d['time'][i]

    # --- left panel: section at xloc ---
    salt_section = salt.sel(XC=xloc, method='nearest').sel(time=state_time, method='nearest')
    salt_section.plot(ax=section_ax, **salt_style)
    mld_section = mld.sel(XC=xloc, method='nearest').sel(time=diag_time)
    (-mld_section).plot(color='k', ax=section_ax)

    # --- right panel: plan view (drawing order matters for layering) ---
    mld_levels = [1, 10, 50,
                  100, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
                  1200, 1400, 1600, 1800, 2000]
    mld.sel(time=diag_time).plot.contour(ax=map_ax, levels=mld_levels)
    salt_surface = salt.sel(Z=0, method='nearest').sel(time=state_time, method='nearest')
    salt_surface.plot(ax=map_ax, **salt_style)
    ice_fraction = ds_seaice['SI_Fract'].sel(time=ds_seaice['time'][i])
    ice_fraction.where(ice_fraction > 0.15).plot(cmap='binary_r', ax=map_ax, vmin=0, vmax=1, zorder=3)


    



In [None]:
# Largest time index that is valid for every dataset plot_salinity indexes by
# position. The callback reads ds_state, ds_2d AND ds_seaice at index i, so
# all three must bound the slider; otherwise a shorter sea-ice record would
# raise an IndexError when the slider is moved to its upper end.
t_max = min(
    ds_state['time'].shape[0],
    ds_2d['time'].shape[0],
    ds_seaice['time'].shape[0],
) - 1

# Interactive browser: i picks the time index (starting at the latest common
# one), xloc picks the X position of the section in 10e3 steps.
ipywidgets.interactive(plot_salinity,
                       i=ipywidgets.IntSlider(value=t_max, min=0, max=t_max),
                       xloc=ipywidgets.FloatSlider(value=400e3, min=0, max=1e6, step=10e3))

In [None]:
# Average of SALT over the X and Y axes, keeping only cells where SALT > 0.
salinity = grid.average(
    ds_state['SALT'].where(ds_state['SALT'] > 0),
    ['X', 'Y'],
)

# One curve per target depth, each taken at the model level nearest that Z.
for depth in (0, -1000, -2000, -3000, -4000):
    salinity.sel(Z=depth, method='nearest').plot(label=f'Z={depth}')

plt.legend()

In [None]:
# Change in salinity: SALT at the last state time minus SALT at the first,
# on the section nearest xloc, with the negated final MXLDEPTH drawn in black.
xloc = 400e3
(
    ds_state['SALT'].sel(XC=xloc, method='nearest').sel(time=ds_state['time'][-1], method='nearest')
    - ds_state['SALT'].sel(XC=xloc, method='nearest').sel(time=ds_state['time'][0], method='nearest')
).plot(vmin=-.5, vmax=.5, cmap='RdBu_r')
(
    -ds_2d['MXLDEPTH']
    .sel(XC=xloc, method='nearest')
    .sel(time=ds_2d['time'][-1])
).plot(color='k')

In [None]:
# THETA section nearest xloc at the last state time, with the negated final
# MXLDEPTH along the same section drawn in black.
xloc = 400e3
(
    ds_state['THETA']
    .sel(XC=xloc, method='nearest')
    .sel(time=ds_state['time'][-1], method='nearest')
    .plot()
)
(
    -ds_2d['MXLDEPTH']
    .sel(XC=xloc, method='nearest')
    .sel(time=ds_2d['time'][-1])
).plot(color='k')

In [None]:
# MXLDEPTH at the last time of the 2D diagnostics.
(
    ds_2d['MXLDEPTH']
    .sel(time=ds_2d['time'][-1], method='nearest')
    .plot()
)

In [None]:
# SALT section nearest XC=400e3 at the last state time, colour range 34-35.
(
    ds_state['SALT']
    .sel(XC=400e3, method='nearest')
    .sel(time=ds_state['time'][-1], method='nearest')
    .plot(vmin=34, vmax=35)
)

In [None]:
# UVEL on the level nearest Z=0 at the last state time.
(
    ds_state['UVEL']
    .sel(Z=0, method='nearest')
    .sel(time=ds_state['time'][-1])
    .plot()
)

In [None]:
# UVEL averaged along XG, shown for the last state time.
(
    ds_state['UVEL']
    .mean(dim='XG')
    .sel(time=ds_state['time'][-1])
    .plot()
)

In [None]:
# VVEL averaged along XC, shown for the last state time.
(
    ds_state['VVEL']
    .mean(dim='XC')
    .sel(time=ds_state['time'][-1])
    .plot()
)

In [None]:
# UVEL integrated over the Y and Z axes using the metrics registered on `grid`.
transport = grid.integrate(
    ds_state['UVEL'],
    ['Y', 'Z'],
)

In [None]:
# Time series of the integrated transport at the XG nearest 400e3.
(
    transport
    .sel(XG=400e3, method='nearest')
    .plot()
)

In [None]:
# VVEL on the level nearest Z=0 at the last state time.
(
    ds_state['VVEL']
    .sel(Z=0, method='nearest')
    .sel(time=ds_state['time'][-1])
    .plot()
)

## Sea Ice

In [None]:
# SI_Fract at the last sea-ice diagnostics time.
(
    ds_seaice['SI_Fract']
    .sel(time=ds_seaice['time'][-1])
    .plot()
)

In [None]:
# SI_Fract integrated over X and Y, then divided by 1e12 so the result is in
# millions of square km.
sea_ice_time_series = (
    grid.integrate(ds_seaice['SI_Fract'], ['X', 'Y'])
    / 1e12
)

In [None]:
# Plot the integrated sea-ice series computed in the previous cell.
sea_ice_time_series.plot()

## Vorticity

In [None]:
# Vorticity: Y-difference of UVEL*dxC (negated) plus X-difference of VVEL*dyC,
# divided by the rAz cell area.
# NOTE(review): only X is declared periodic on `grid`; depending on the xgcm
# version the Y difference may need an explicit boundary option -- confirm.
zeta = (
    -grid.diff(ds_state['UVEL'] * ds_state['dxC'], 'Y')
    + grid.diff(ds_state['VVEL'] * ds_state['dyC'], 'X')
) / ds_state['rAz']


In [None]:
# zeta on the level nearest Z=0 at the last state time.
(
    zeta
    .sel(Z=0, method='nearest')
    .sel(time=ds_state['time'][-1])
    .plot()
)

## Layers


In [None]:
# Display the layers dataset (variables, dimensions, coordinates).
ds_layers

In [None]:
# LaVH3RHO averaged along XC and cumulatively summed over the layer dimension,
# shown for the last layers-diagnostics time.
# The layer dimension is 'layer_pot_dens' (the name given when '_UNKNOWN_' was
# renamed after opening ds_layers); 'layer_density' does not exist on this
# dataset and made cumsum raise.
(
    ds_layers['LaVH3RHO']
    .mean(dim='XC')
    .cumsum(dim='layer_pot_dens')
    .sel(time=ds_layers['time'][-1])
    .plot()
)