In [1]:
# Set up a cluster using dask
from dask_gateway import Gateway
from dask.distributed import Client

gateway = Gateway()
cluster = gateway.new_cluster()

cluster.scale(20)
cluster

KeyboardInterrupt: 

In [None]:
#Make a client so you can see the progress of tasks (click the link that appears below)
client = Client(cluster)
client

In [None]:
import intake
import pandas as pd
import pprint
import fsspec

import numpy as np
import xarray as xr
import xgcm
import xesmf as xe
import s3fs

import matplotlib.pyplot as plt
%matplotlib inline

from fastjmd95 import jmd95numba

In [None]:
col = intake.open_esm_datastore(
    "https://raw.githubusercontent.com/NCAR/cesm-lens-aws/main/intake-catalogs/aws-cesm1-le.json")
col_subset = col.search(experiment="20C",variable=["TEMP", "SALT","UVEL","VVEL","WVEL","SHF","VNT","WTT"])
#col_subset = col.search(experiment="20C",variable=["TEMP"])
dsets = col_subset.to_dataset_dict(zarr_kwargs={"consolidated": True}, storage_options={"anon": True})

In [None]:
ds = dsets['ocn.20C.monthly']
theta = ds.TEMP
salt = ds.SALT
uvel = ds.UVEL
vvel = ds.VVEL
wvel = ds.WVEL
vnt = ds.VNT
wtt = ds.WTT

shf = ds.SHF

url = "s3://ncar-cesm-lens/ocn/static/grid.zarr"
fs = s3fs.S3FileSystem(anon=True)
grid_ds = xr.open_zarr(s3fs.S3Map(url, s3=fs))

In [None]:
th0 = theta.roll(nlon=60).sel(nlon=slice(0,120))
slt0 = salt.roll(nlon=60).sel(nlon=slice(0,120))
uvel0 = uvel.roll(nlon=60).sel(nlon=slice(0,120))
vel0 = vvel.roll(nlon=60).sel(nlon=slice(0,120))
wvel0 = wvel.roll(nlon=60).sel(nlon=slice(0,120))
#vnt0 = vnt.roll(nlon=60)
#wtt0 = wtt.roll(nlon=60)
shf0 = shf.roll(nlon=60).sel(nlon=slice(0,120))
grid0 = grid_ds.roll(nlon=60,roll_coords=True).sel(nlon=slice(0,120))

uvel0=uvel0.rename({'nlon':'vlon','nlat':'vlat'})
vel0=vel0.rename({'nlon':'vlon','nlat':'vlat'})

In [None]:
atl_mask = grid0.REGION_MASK
atl_mask = atl_mask.where((atl_mask==6) | (atl_mask==8)| (atl_mask==-12))
atl_mask = atl_mask.where(np.isnan(atl_mask),1)
atl_mask = atl_mask.where(atl_mask.nlon>5)
atl_mask = atl_mask.where(atl_mask.nlon<111)
atl_mask = atl_mask.where((atl_mask.nlon<108) | (atl_mask.nlat>331))

In [None]:
#create grid object

#first merge v and w 

ds = vel0.to_dataset().merge(wvel0)
ds = ds.merge(th0)
ds = ds.merge(uvel0)
ds = ds.merge(grid0.DXU.rename({'nlon':'vlon','nlat':'vlat'}).rename('gridS'))
ds = ds.merge(grid0.DYU.rename({'nlon':'vlon','nlat':'vlat'}).rename('gridW'))
z_outer = np.concatenate((ds.z_w_top.values,np.asarray(550000).reshape((-1,))))
ds = ds.assign_coords(z_w_outer=z_outer)

from xgcm.autogenerate import generate_grid_ds
from xgcm import Grid

grid = Grid(ds, coords={'X':{'center': 'nlon','right': 'vlon'},'Y':{'center': 'nlat','right': 'vlat'},'Z': {'center': 'z_t','outer': 'z_w_outer'}},periodic=False)

#calculate dz
dz = grid0.dz#grid.diff(ds.z_w_outer,axis='Z',boundary='extend')

In [None]:
vs = grid.interp(ds.VVEL*ds.gridS,'X',boundary='fill')
ts = grid.interp(ds.TEMP,'Y',boundary='fill')

vw = grid.interp(ds.UVEL*ds.gridW,'Y',boundary='fill')
tw = grid.interp(ds.TEMP,'X',boundary='fill')
deltaxS = grid.interp(ds.gridS,'X',boundary='fill')
deltayW = grid.interp(grid.interp(ds.gridW,'Y',boundary='fill'),'X',boundary='fill')
maskS = atl_mask.rename({'nlat':'vlat'})

In [None]:
mid=slice(1,37)
timeslice = slice('1940-01-01','2016-02-01')
df = pd.DataFrame({'date1':pd.date_range('1941-01-01','1941-12-01',freq='MS'),
                   'date2':pd.date_range('1941-02-01','1942-01-01', freq='MS')})
df['diff'] = df['date2']-df['date1']
dt = df['diff'].dt.total_seconds().to_numpy()
ds_dt = xr.DataArray(np.repeat(dt,66),coords={'time':ds.TEMP.sel(time=timeslice).time.values},dims=['time'])

In [11]:
#heat transport
heat_transport = vs*ts

In [15]:
ht_42 = ((heat_transport*dz*maskS/100**3).sel(member_id=mid,time=timeslice
                                                                  )).sel(vlat=298).sum(['nlon','z_t']).rolling(time=24).mean().load().dropna("time",how='all')
ht_42_towrite = ht_42.to_dataset(name='ht_42').load()
ht_42_towrite.to_netcdf('/home/jovyan/amoc_heat_transport/ht_42.nc')



In [12]:
ht_34 = ((heat_transport*dz*maskS/100**3).sel(member_id=mid,time=timeslice
                                                                  )).sel(vlat=286).sum(['nlon','z_t']).rolling(time=24).mean().load().dropna("time",how='all')
ht_34_towrite = ht_34.to_dataset(name='ht_34').load()
ht_34_towrite.to_netcdf('/home/jovyan/amoc_heat_transport/ht_34.nc')

In [20]:
ht_26 = ((heat_transport*dz*maskS/100**3).sel(member_id=mid,time=timeslice
                                                                  )).sel(nlon=slice(22,120)).sel(vlat=270).sum(['nlon','z_t']).rolling(time=24).mean().dropna("time",how='all').load()

ht_26_towrite = ht_26.to_dataset(name='ht_26').load()
ht_26_towrite.to_netcdf('/home/jovyan/amoc_heat_transport/ht_26.nc')

In [21]:
ht_5S = ((heat_transport*dz*maskS/100**3).sel(member_id=mid,time=timeslice
                                                                  )).sel(nlon=slice(30,120)).sel(vlat=163).sum(['nlon','z_t']).rolling(time=24).mean().dropna("time",how='all').load()
ht_5S_towrite = ht_5S.to_dataset(name='ht_5S').load()
ht_5S_towrite.to_netcdf('/home/jovyan/amoc_heat_transport/ht_5S.nc')

In [13]:
cluster.shutdown()

2023-06-21 17:48:51,534 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
Exception in callback None()
handle: <Handle cancelled>
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/iostream.py", line 1389, in _do_ssl_handshake
    self.socket.do_handshake()
  File "/srv/conda/envs/notebook/lib/python3.10/ssl.py", line 1342, in do_handshake
    self._sslobj.do_handshake()
ssl.SSLZeroReturnError: TLS/SSL connection has been closed (EOF) (_ssl.c:997)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 206, in _handle_events
    handler_func(fileobj, events)
  File "/srv/conda/envs/notebook/lib/python3.10/s