In [1]:
import xmitgcm
import xarray as xr
import numpy as np
import xgcm
import datetime
import os
import scipy
from matplotlib import pyplot as plt 
import cartopy as cart
import pyresample
import pandas as pd



from mpl_toolkits.mplot3d import Axes3D 
from scipy import interpolate
import scipy.io as sio

import warnings; warnings.simplefilter('ignore')

class LLCMapper:
    """Regrid 2-D model fields onto a regular lon/lat grid and map them.

    The source grid is described by the ``XC``/``YC`` cell-centre coordinates
    of the dataset passed to the constructor.  The target grid is a regular
    ``dx`` x ``dy`` degree grid spanning 180W-180E and 90S-54S.  Calling the
    instance resamples a field with pyresample's nearest-neighbour search,
    draws it with ``pcolormesh`` on a cartopy axes, overlays the bathymetry
    contour(s) listed in ``ISOBATH_LEVELS`` and adds land, gridlines and a
    colorbar.

    Parameters
    ----------
    ds : dataset-like
        Must expose ``XC`` and ``YC`` (each with ``.values``).  If it also
        exposes ``Depth``, that field is used for the bathymetry contour;
        otherwise no contour is drawn.
    dx, dy : float
        Longitude / latitude spacing of the target grid in degrees.
    """

    # Search radius handed to pyresample's nearest-neighbour lookup
    # (pyresample documents this argument in metres).
    RADIUS_OF_INFLUENCE = 100000
    # Level(s) of the bathymetry contour drawn in black on every map
    # (same units as ds.Depth -- presumably metres; verify against the model).
    ISOBATH_LEVELS = [1000]

    def __init__(self, ds, dx=0.25, dy=0.25):
        # Flattened cell-centre coordinates of the source (model) grid
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # Cell centres of the regular target grid
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, -54, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

        # Keep the bathymetry that belongs to THIS grid, so plotting does not
        # depend on whatever a global variable happens to hold at call time.
        depth = getattr(ds, 'Depth', None)
        self.depth = None if depth is None else depth.values
        self._depth_on_grid = None  # filled lazily on the first plot

    @staticmethod
    def _split_index(ncols, lon_0):
        """Column of the target grid at which the field is cut in two.

        ``lon_0`` is mapped to 0-360 (non-positive values get +360) and
        converted to the matching fraction of ``ncols``.
        """
        return round(ncols/(360/(lon_0 if lon_0>0 else lon_0+360)))

    def _resample(self, values):
        """Nearest-neighbour resample ``values`` from the model grid to the target grid."""
        return pyresample.kd_tree.resample_nearest(self.orig_grid, values,
                                                   self.new_grid,
                                                   radius_of_influence=self.RADIUS_OF_INFLUENCE,
                                                   fill_value=None)

    def _gridded_depth(self):
        """Return the bathymetry on the target grid (cached), or None if unavailable."""
        if self.depth is None:
            return None
        if self._depth_on_grid is None:
            # The bathymetry never changes, so regrid it only once.
            self._depth_on_grid = self._resample(self.depth)
        return self._depth_on_grid

    def __call__(self, da, ax=None, projection=None, lon_0=-60, **plt_kwargs):
        """Regrid ``da`` and draw it on a map.

        Parameters
        ----------
        da : xarray.DataArray
            Field with exactly the dimensions ``j`` and ``i``; NaNs are
            replaced by 0 before resampling.
        ax : cartopy GeoAxes, optional
            Axes to draw into.  When omitted a new 12x6 figure is created.
        projection : cartopy CRS, optional
            Projection of the axes created when ``ax`` is None; defaults to
            ``SouthPolarStereo()``.  Ignored when ``ax`` is given.
        lon_0 : float
            Longitude at which the regridded field is split into two
            ``pcolormesh`` pieces.
        **plt_kwargs
            Forwarded to ``pcolormesh``.  ``vmin``/``vmax`` default to the
            extrema of the regridded field.

        Returns
        -------
        ax
            The axes that were drawn into.
        """
        assert set(da.dims) == set(['j', 'i']), "da must have dimensions ['j', 'i']"
        # Line the values up with the flattened XC/YC used for the source grid.
        # NOTE(review): assumes XC/YC are stored as (j, i) -- confirm.
        da = da.transpose('j', 'i')

        if ax is None:
            if projection is None:
                projection = cart.crs.SouthPolarStereo()
            _, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': projection})
        # When the caller supplies `ax` it is used as is; no extra axes are created.

        field = self._resample(da.fillna(0).values)
        depth_on_grid = self._gridded_depth()

        vmax = plt_kwargs.pop('vmax', field.max())
        vmin = plt_kwargs.pop('vmin', field.min())

        x, y = self.new_grid_lon, self.new_grid_lat
        split_lon_idx = self._split_index(x.shape[1], lon_0)
        data_crs = cart.crs.PlateCarree()

        # Draw the two longitude segments separately; the second one sits on top.
        p = None
        segments = (slice(None, split_lon_idx), slice(split_lon_idx, None))
        for zorder, cols in enumerate(segments, start=1):
            ncols = x[:, cols].shape[1]
            if ncols == 0:
                continue  # nothing on this side of the split
            p = ax.pcolormesh(x[:, cols], y[:, cols], field[:, cols],
                              vmax=vmax, vmin=vmin, transform=data_crs,
                              zorder=zorder, **plt_kwargs)
            if depth_on_grid is not None and ncols >= 2:
                ax.contour(x[:, cols], y[:, cols], depth_on_grid[:, cols],
                           self.ISOBATH_LEVELS, colors='black',
                           transform=data_crs, zorder=zorder)

        ax.add_feature(cart.feature.LAND, facecolor='0.5', zorder=3)
        ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False,
                     color='black', linestyle=':')

        label = ''
        if da.name is not None:
            label = str(da.name)
        if 'units' in da.attrs:
            label += ' [%s]' % da.attrs['units']
        if p is not None:
            # Attach the colorbar to the axes we drew on, not to whatever
            # axes pyplot currently considers "current".
            ax.figure.colorbar(p, ax=ax, shrink=0.4, label=label)
        return ax
    
    # coefficients nonlinear equation of state in pressure coordinates for
# Coefficient tables of the equation of state used by densjmd95 and
# bulkmodjmd95 below.  In the comments t is temperature, s is salinity and
# p is pressure, as those two functions name them; each row is grouped by
# the term the coefficients multiply there.

# 1. density of fresh water at p = 0: polynomial in t
eosJMDCFw = [
    999.842594,       # t^0
    6.793952e-02,     # t^1
    -9.095290e-03,    # t^2
    1.001685e-04,     # t^3
    -1.120083e-06,    # t^4
    6.536332e-09,     # t^5
]

# 2. density of sea water at p = 0: salinity corrections
eosJMDCSw = [
    # s * (t^0 .. t^4)
    8.244930e-01, -4.089900e-03, 7.643800e-05, -8.246700e-07, 5.387500e-09,
    # s^(3/2) * (t^0 .. t^2)
    -5.724660e-03, 1.022700e-04, -1.654600e-06,
    # s^2
    4.831400e-04,
]

# 3. secant bulk modulus K of fresh water at p = 0: polynomial in t
eosJMDCKFw = [
    1.965933e+04,     # t^0
    1.444304e+02,     # t^1
    -1.706103e+00,    # t^2
    9.648704e-03,     # t^3
    -4.190253e-05,    # t^4
]

# 4. secant bulk modulus K of sea water at p = 0: salinity corrections
eosJMDCKSw = [
    # s * (t^0 .. t^3)
    5.284855e+01, -3.101089e-01, 6.283263e-03, -5.084188e-05,
    # s^(3/2) * (t^0 .. t^2)
    3.886640e-01, 9.085835e-03, -4.619924e-04,
]

# 5. secant bulk modulus K of sea water at p: pressure corrections
eosJMDCKP = [
    # p * (t^0 .. t^3)
    3.186519e+00, 2.212276e-02, -2.984642e-04, 1.956415e-06,
    # p * s * (t^0 .. t^2)
    6.704388e-03, -1.847318e-04, 2.059331e-07,
    # p * s^(3/2)
    1.480266e-04,
    # p^2 * (t^0 .. t^2)
    2.102898e-04, -1.202016e-05, 1.394680e-07,
    # p^2 * s * (t^0 .. t^2)
    -2.040237e-06, 6.128773e-08, 6.207323e-10,
]

def densjmd95(s,theta,p):
    """
    Computes in-situ density of sea water
    Density of Sea Water using Jackett and McDougall 1995 (JAOT 12)
    polynomial (modified UNESCO polynomial).
    Parameters
    ----------
    s : array_like
        salinity [psu (PSS-78)]
    theta : array_like
        potential temperature [degree C (IPTS-68)];
        same shape as s
    p : array_like
        pressure [dbar]; broadcastable to shape of s
    Returns
    -------
    dens : array
        density [kg/m^3]; NaN wherever s is negative (a message is written
        to stderr in that case)
    Example
    -------
    >>> densjmd95(35.5, 3., 3000.)
    1041.83267
    Notes
    -----
    AUTHOR:  Martin Losch 2002-08-09  (mlosch@mit.edu)
    Jackett and McDougall, 1995, JAOT 12(4), pp. 381-388
    """
    # Imported here because the notebook's import cell does not import `sys`;
    # without it the negative-salinity branch below raised NameError.
    import sys

    # make sure arguments are floating point
    # (np.asfarray was removed in NumPy 2.0; this is its documented replacement)
    s = np.asarray(s, dtype=np.float64)
    t = np.asarray(theta, dtype=np.float64)
    p = np.asarray(p, dtype=np.float64)

    # convert pressure to bar
    p = .1*p

    t2 = t*t
    t3 = t2*t
    t4 = t3*t

    if np.any(s<0):
        # Only report it: np.sqrt below turns negative salinities into NaN.
        sys.stderr.write('negative salinity values! setting to nan\n')

    s3o2 = s*np.sqrt(s)

    # density of freshwater at the surface
    rho = ( eosJMDCFw[0]
          + eosJMDCFw[1]*t
          + eosJMDCFw[2]*t2
          + eosJMDCFw[3]*t3
          + eosJMDCFw[4]*t4
          + eosJMDCFw[5]*t4*t
          )
    # density of sea water at the surface
    rho = ( rho
           + s*(
                 eosJMDCSw[0]
               + eosJMDCSw[1]*t
               + eosJMDCSw[2]*t2
               + eosJMDCSw[3]*t3
               + eosJMDCSw[4]*t4
               )
           + s3o2*(
                 eosJMDCSw[5]
               + eosJMDCSw[6]*t
               + eosJMDCSw[7]*t2
               )
           + eosJMDCSw[8]*s*s
          )

    # compress the surface density with the secant bulk modulus (p is in bar here)
    rho = rho / (1. - p/bulkmodjmd95(s,t,p))

    return rho

def bulkmodjmd95(s,theta,p):
    """ Compute the secant bulk modulus used by densjmd95.

    Parameters
    ----------
    s : array_like
        salinity
    theta : array_like
        temperature, in the same convention as densjmd95
    p : array_like
        pressure, used as given with no unit conversion (densjmd95 calls
        this function after converting its pressure to bar)

    Returns
    -------
    bulkmod : array
        secant bulk modulus; NaN wherever s is negative
    """
    # make sure arguments are floating point
    # (np.asfarray was removed in NumPy 2.0; this is its documented replacement)
    s = np.asarray(s, dtype=np.float64)
    t = np.asarray(theta, dtype=np.float64)
    p = np.asarray(p, dtype=np.float64)

    t2 = t*t
    t3 = t2*t
    t4 = t3*t

    # negative salinities are not trapped here: the sqrt turns them into NaN
    s3o2 = s*np.sqrt(s)

    p2 = p*p
    # secant bulk modulus of fresh water at the surface
    bulkmod = ( eosJMDCKFw[0]
              + eosJMDCKFw[1]*t
              + eosJMDCKFw[2]*t2
              + eosJMDCKFw[3]*t3
              + eosJMDCKFw[4]*t4
              )
    # secant bulk modulus of sea water at the surface
    bulkmod = ( bulkmod
              + s*(      eosJMDCKSw[0]
                       + eosJMDCKSw[1]*t
                       + eosJMDCKSw[2]*t2
                       + eosJMDCKSw[3]*t3
                       )
              + s3o2*(   eosJMDCKSw[4]
                       + eosJMDCKSw[5]*t
                       + eosJMDCKSw[6]*t2
                       )
               )
    # secant bulk modulus of sea water at pressure p
    bulkmod = ( bulkmod
              + p*(   eosJMDCKP[0]
                    + eosJMDCKP[1]*t
                    + eosJMDCKP[2]*t2
                    + eosJMDCKP[3]*t3
                  )
              + p*s*(   eosJMDCKP[4]
                      + eosJMDCKP[5]*t
                      + eosJMDCKP[6]*t2
                    )
              + p*s3o2*eosJMDCKP[7]
              + p2*(   eosJMDCKP[8]
                     + eosJMDCKP[9]*t
                     + eosJMDCKP[10]*t2
                   )
              + p2*s*(  eosJMDCKP[11]
                      + eosJMDCKP[12]*t
                      + eosJMDCKP[13]*t2
                     )
               )

    return bulkmod



# Short alias: dens(s, theta, p) calls densjmd95(s, theta, p).
dens = densjmd95

In [None]:
from dask.distributed import Client
import dask.config
import distributed
from dask_jobqueue import SGECluster

# Describe one SGE worker job: a single core with 5GB of memory, talking over
# the ib0 interface, submitted to the Analysis3.q queue for at most two hours.
cluster = SGECluster(
    cores=1,
    memory="5GB",
    interface="ib0",
    queue="Analysis3.q",
    walltime="02:00:00",
)

# Ask for ten workers (original author's note: "62 gb total").
cluster.scale(10)

# Connect this notebook to the cluster; the client is shown as the cell output.
client = Client(cluster)
client

In [None]:
# Settings shared by the three experiments opened in this cell.
EXPERIMENT_ROOT = '/data2/antarctic_model/MITgcm_SO/experiments_mky/'
DIAG_PREFIXES = ['state_2d_set1','layers_3d_set2','fluxes_3d_set1','trsp_3d_set1','state_3d_set1']
TIME_ORIGIN = np.datetime64("1992-12-31")  # added to the model time axis
P_REF_LEVEL = 69                           # index into ds.Z used for p_ref


def experiment_paths(name):
    """Return (diagnostics directory, grid directory) of experiment `name`."""
    results_dir = EXPERIMENT_ROOT + name + '/results/'
    return results_dir + 'diags/', results_dir


def open_experiment(datadir, griddir, delta_t):
    """Open one experiment's diagnostics lazily and shift its time axis.

    `delta_t` is forwarded to xmitgcm.open_mdsdataset; TIME_ORIGIN is added to
    the resulting `time` coordinate.
    """
    dset = xmitgcm.open_mdsdataset(datadir, grid_dir=griddir, prefix=DIAG_PREFIXES,
                                   geometry='curvilinear', delta_t=delta_t)
    return dset.assign_coords(time=(dset.time + TIME_ORIGIN))


# --- llc270_icb_CM4pI_cycle3 -> ds, grid, coords, mapper ----------------------
datadir1, griddir = experiment_paths('llc270_icb_CM4pI_cycle3')
ds = open_experiment(datadir1, griddir, delta_t=1200)
ds['drW'] = ds.hFacW * ds.drF #vertical cell size at u point
ds['drS'] = ds.hFacS * ds.drF #vertical cell size at v point
ds['drC'] = ds.hFacC * ds.drF #vertical cell size at tracer point
metrics = {
    ('X',): ['dxC', 'dxG'], # X distances
    ('Y',): ['dyC', 'dyG'], # Y distances
    ('Z',): ['drW', 'drS', 'drC'], # Z distances
    ('X', 'Y'): ['rA', 'rAz', 'rAs', 'rAw'] # Areas
}
grid = xgcm.Grid(ds, periodic=False, metrics=metrics)
# Grid coordinates as an ordinary dataset of data variables.
# NOTE(review): unlike dsp below, ds keeps its coordinates attached (the
# reset_coords call for ds was commented out in the original cell).
coords = ds.coords.to_dataset().reset_coords()
mapper = LLCMapper(coords)

# Reference pressure from the depth of level P_REF_LEVEL: rho*g*|Z| in Pa,
# divided by 1e4 to give dbar.  Computed once; the original recomputed this
# same expression from `ds` after opening each experiment.
p_ref = -1029*9.81*ds.Z[P_REF_LEVEL]/1e4

# --- llc270_icb_CM4SSP -> dsp, gridp, coordsp, mapperp ------------------------
datadir1, griddir = experiment_paths('llc270_icb_CM4SSP')
dsp = open_experiment(datadir1, griddir, delta_t=1200)
gridp = xgcm.Grid(dsp, periodic=False)
coordsp = dsp.coords.to_dataset().reset_coords()
dsp = dsp.reset_coords(drop=True)
mapperp = LLCMapper(coordsp)

# --- llc270_icb_CM4SSP_fixstrat -> dsfix --------------------------------------
# NOTE(review): this run is opened with delta_t=900 while the other two use
# 1200 -- confirm this matches the model time step of that experiment.
datadir1, griddir = experiment_paths('llc270_icb_CM4SSP_fixstrat')
dsfix = open_experiment(datadir1, griddir, delta_t=900)

In [None]:
# Smallest Z among the cells of each column that have hFacC > 0
# (the deepest wet tracer cell, assuming Z is negative downward -- confirm).
bottomd = coords.Z.where(coords.hFacC>0).min(dim='k')

# True only at the cell whose Z equals that bottom value
_at_bottom = coords.Z == bottomd


def _bottom_time_mean(field):
    """Bottom-cell value of `field`, averaged over its last 120 records.

    The slice is positional on the leading dimension (presumably time).
    Masking everything but the bottom cell and taking the max over k picks
    out that single remaining value in each column.
    """
    return field[-120:].where(_at_bottom).max(dim='k').mean(dim='time')


# Bottom temperature / salinity for the three experiments: ds, dsp, dsfix
thetamean = _bottom_time_mean(ds.THETA)
smean = _bottom_time_mean(ds.SALT)

thetameanp = _bottom_time_mean(dsp.THETA)
smeanp = _bottom_time_mean(dsp.SALT)

thetameanf = _bottom_time_mean(dsfix.THETA)
smeanf = _bottom_time_mean(dsfix.SALT)

In [None]:
# Bottom-temperature differences between experiments.
# mapper() creates its own figure when no axes are passed, so a preceding
# plt.figure() only left an empty figure behind; the title is set on the axes
# that mapper() returns instead of on pyplot's "current" axes.
ax = mapper((thetameanf-thetamean).fillna(0),cmap='RdBu_r',vmin=-1.8,vmax=1.8)
ax.set_title('Forced Response')

ax = mapper((thetameanp-thetameanf).fillna(0),cmap='RdBu_r',vmin=-1.8,vmax=1.8)
ax.set_title('Melt Feedback');

In [None]:
# Bottom-salinity differences between experiments.
# mapper() creates its own figure when no axes are passed, so a preceding
# plt.figure() only left an empty figure behind; the title is set on the axes
# that mapper() returns instead of on pyplot's "current" axes.
ax = mapper((smeanf-smean).fillna(0),cmap='RdBu_r',vmin=-0.1, vmax =0.1)
ax.set_title('Forced Response')

ax = mapper((smeanp-smeanf).fillna(0),cmap='RdBu_r',vmin=-0.1, vmax =0.1)
ax.set_title('Melt Feedback');

In [None]:
# Bottom-density differences between experiments.
# densmean / densmeanp / densmeanf were only ever assigned in commented-out
# code, so this cell raised NameError on a fresh kernel.  They are computed
# here from the salinity and temperature means defined in the cells above,
# at the reference pressure p_ref.
# NOTE(review): the commented-out version averaged S and T over the whole
# record before taking the bottom value; this uses whatever window smean* and
# thetamean* were built with -- confirm that is the intended density.
_p_ref_dbar = float(p_ref)  # plain scalar, so no extra coordinates are carried along
densmean = xr.apply_ufunc(densjmd95, smean, thetamean, _p_ref_dbar, dask='allowed')
densmeanp = xr.apply_ufunc(densjmd95, smeanp, thetameanp, _p_ref_dbar, dask='allowed')
densmeanf = xr.apply_ufunc(densjmd95, smeanf, thetameanf, _p_ref_dbar, dask='allowed')

# mapper() creates its own figure when no axes are passed, so no plt.figure()
# is needed; the title is set on the axes that mapper() returns.
ax = mapper((densmeanf-densmean).fillna(0),cmap='RdBu_r',vmin=-0.2,vmax=0.2)
ax.set_title('Forced Response')

ax = mapper((densmeanp-densmeanf).fillna(0),cmap='RdBu_r',vmin=-0.2,vmax=0.2)
ax.set_title('Melt Feedback');

In [None]:
# Depth-averaged, time-mean temperature and salinity for the three experiments.
# NOTE: this rebinds thetamean/smean (and the p/f variants), which held the
# bottom values computed earlier in the notebook.

# Total thickness of each column's tracer cells.  drC is a data variable
# assigned on `ds`; it is not part of `coords`, which was built from ds.coords
# only, so `coords.drC` raised AttributeError.
_column_thickness = ds.drC.sum(dim='k')


def _depth_time_mean(field):
    """Time-mean of `field`, integrated along Z with `grid` and divided by the column thickness."""
    return grid.integrate(field.mean(dim='time'),axis='Z')/_column_thickness


thetamean = _depth_time_mean(ds.THETA)
smean = _depth_time_mean(ds.SALT)
thetameanp = _depth_time_mean(dsp.THETA)
smeanp = _depth_time_mean(dsp.SALT)
thetameanf = _depth_time_mean(dsfix.THETA)
smeanf = _depth_time_mean(dsfix.SALT)