In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import datetime
import dask
import matplotlib.pyplot as plt
import time as timing
from glob import glob
import seaborn as sns

In [None]:
from dask.distributed import Client, progress, LocalCluster

In [None]:
# Raise the scheduler's worker time-to-live (how long it waits to hear from a
# worker before removing it) so long-running tasks do not get workers dropped.
dask.config.set({"distributed.scheduler.worker-ttl": 6400})

N_WORKERS = 24
cluster = LocalCluster(
    name="dask_LE",
    processes=True,
    n_workers=N_WORKERS,
    threads_per_worker=1,
)
# Let the cluster scale between one worker and N_WORKERS depending on load.
cluster.adapt(minimum=1, maximum=N_WORKERS)
client = Client(cluster)


### Optional: watch the workers on the Dask dashboard

In [None]:
# Print the dashboard address so the workers' progress can be monitored.
print(f"Dashboard URL: {cluster.dashboard_link}")

### Set up domain and time information

In [None]:
# Open the domain configuration (in this case the mesh_mask file).
MESH_MASK_PATH = '/gws/nopw/j04/canari/shared/large-ensemble/ocean/mesh_mask.nc'
# vomecrty is later multiplied by e3v, so both need a common vertical dimension:
# rename the mesh_mask's 'z' dimension to 'depthv'.
domcfg = xr.open_dataset(MESH_MASK_PATH).rename_dims(z='depthv')

In [None]:
# Large-ensemble (LE) members to process; results are concatenated in this order.
LEM = [
    1, 10, 11, 12, 13, 18, 2, 20, 21, 22,
    24, 28, 3, 30, 31, 4, 5, 6, 7, 9,
]

In [None]:
# Gather time information from the first and last yearly files of the first
# ensemble member and build a monthly time axis with each stamp at mid-month.
# NOTE(review): the strptime format assumes time_counter decodes to cftime
# objects, whose str() is 'YYYY-MM-DD HH:MM:SS'; numpy datetime64 values
# would not parse with this format.
fs = sorted(glob('/gws/nopw/j04/canari/shared/large-ensemble/priority/HIST2/%s/OCN/yearly/*/*grid_V_vomecrty.nc'%(LEM[0])))
if not fs:
    raise FileNotFoundError('No grid_V_vomecrty files found for ensemble member %s' % LEM[0])
d0 = xr.open_dataset(fs[0])  # also used later for dtype and nav_lat/nav_lon/y/x
dn = xr.open_dataset(fs[-1])

STAMP_FORMAT = '%Y-%m-%d %H:%M:%S'
first_stamp = datetime.datetime.strptime(str(d0.time_counter.values[0]), STAMP_FORMAT)
last_stamp = datetime.datetime.strptime(str(dn.time_counter.values[-1]), STAMP_FORMAT)

# First day of the first month covered by the files.
tbeg = datetime.datetime(first_stamp.year, first_stamp.month, 1)

# Offset objects are used instead of frequency aliases: 'M' is deprecated in
# pandas >= 2.2 (replaced by 'ME', which older pandas does not recognise).
time_beg = pd.date_range(
    tbeg,
    datetime.datetime(last_stamp.year, last_stamp.month, 1),
    freq=pd.offsets.MonthBegin(),
)
# First day of the following month (= month end + 1 day).
time_end = time_beg + pd.offsets.MonthBegin()
# Last day of the final month.
tend = time_beg[-1] + pd.offsets.MonthEnd()

# Mid-point of each month.
time = time_beg + (time_end - time_beg) / 2

In [None]:
# Load the e3v_0 and e1t scale factors for the cropped region (first record of
# the mesh_mask) into memory once.
# NOTE(review): the transport through a v-face is normally v * e3v * e1v;
# e1t is the T-point zonal width — confirm e1t (not e1v) is intended.
rows, cols = slice(900, 1075), slice(900, 1025)
e3v = domcfg['e3v_0'][0, :, rows, cols].load()
e1t = domcfg['e1t'][0, rows, cols].load()

In [None]:
# Delayed, per-year computation of the barotropic streamfunction over the
# cropped region of the grid.

# Cropped region of the model grid (rows along y, columns along x).
# NOTE(review): keep in sync with the slices used for e3v/e1t and for the
# nav_lat/nav_lon/y/x coordinates attached in the ensemble loop.
Y_SLICE = slice(900, 1075)
X_SLICE = slice(900, 1025)
# Number of monthly records expected in each yearly file (checked per task).
N_MONTHS = 12

# Wrap the pre-loaded scale factors once, so every task refers to a single
# graph key instead of embedding its own copy of the arrays in the graph.
_E3V_DELAYED = dask.delayed(e3v)
_E1T_DELAYED = dask.delayed(e1t)
# dtype numpy produces for vomecrty * e3v * e1t; declared to dask for each chunk.
_PSI_DTYPE = np.result_type(d0['vomecrty'].dtype, e3v.dtype, e1t.dtype)


def _le_file_path(t0, lem):
    """Return the yearly grid_V_vomecrty file for ensemble member `lem`, year `t0`.

    Raises FileNotFoundError if no file matches; if several match, the first
    in sorted order is used so the choice is deterministic.
    """
    pattern = ('/gws/nopw/j04/canari/shared/large-ensemble/priority/HIST2/'
               '%s/OCN/yearly/%04d/*grid_V_vomecrty.nc' % (lem, t0))
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError('No file matches %s' % pattern)
    return matches[0]


# Delayed open_dataset of LE files
@dask.delayed
def open_LE_delayed(t0, lem):
    """Delayed: open the yearly V-velocity file for member `lem`, year `t0`.

    Opened without dask chunks, so values derived from it inside a task are
    in-memory numpy arrays rather than nested (still lazy) dask arrays.
    Not used by calc_bt_psi, which opens and closes the file within its task.
    """
    return xr.open_dataset(_le_file_path(t0, lem))


@dask.delayed
def _bt_psi_year(t0, lem, e3v_reg, e1t_reg):
    """Delayed: streamfunction for one member/year as a (1, N_MONTHS, ny, nx) numpy array.

    Raises ValueError if the file does not yield the expected number of months
    or region size, instead of silently disagreeing with the declared shape.
    """
    with xr.open_dataset(_le_file_path(t0, lem)) as ds:
        v = ds['vomecrty'][:, :, Y_SLICE, X_SLICE]
        # Depth-integrate v*e3v, multiply by the zonal width, scale by 1e6
        # and accumulate along increasing x.
        # NOTE(review): v-face transports normally use e1v; confirm e1t is intended.
        psi = ((v * e3v_reg).sum(dim='depthv') * e1t_reg / 1e6).cumsum(dim='x')
        values = psi.values  # materialise while the file is still open
    expected = (N_MONTHS, Y_SLICE.stop - Y_SLICE.start, X_SLICE.stop - X_SLICE.start)
    if values.shape != expected:
        raise ValueError('Expected psi of shape %s for member %s, year %s, got %s'
                         % (expected, lem, t0, values.shape))
    return values[np.newaxis]


def calc_bt_psi(t0, lem):
    """Return a lazy dask array with the barotropic streamfunction for one year.

    Parameters
    ----------
    t0 : int
        Year; selects the ``yearly/%04d`` directory.
    lem : int
        Large-ensemble member number.

    Returns
    -------
    dask.array.Array
        Shape ``(1, N_MONTHS, ny, nx)``: a length-1 ensemble axis, the months
        in the file, and the cropped y/x region. Values are divided by 1e6
        (Sv if velocities are m/s and scale factors are m — assumed).
        Each chunk is a plain numpy array, so a single ``.compute()`` suffices.
    """
    shape = (1, N_MONTHS, Y_SLICE.stop - Y_SLICE.start, X_SLICE.stop - X_SLICE.start)
    return dask.array.from_delayed(
        _bt_psi_year(t0, lem, _E3V_DELAYED, _E1T_DELAYED), shape, dtype=_PSI_DTYPE
    )

In [None]:
# NOTE(review): this cell redefines open_LE_delayed / calc_bt_psi from the
# preceding cell and shadows them; keep only one of the two cells.
# It uses the pre-loaded e3v/e1t rather than re-slicing domcfg, so the
# mesh_mask is not re-read in every task.

# Cropped region of the model grid (rows along y, columns along x).
# NOTE(review): keep in sync with the slices used for e3v/e1t and for the
# nav_lat/nav_lon/y/x coordinates attached in the ensemble loop.
Y_SLICE = slice(900, 1075)
X_SLICE = slice(900, 1025)
# Number of monthly records expected in each yearly file (checked per task).
N_MONTHS = 12

# Wrap the pre-loaded scale factors once, so every task refers to a single
# graph key instead of embedding its own copy of the arrays in the graph.
_E3V_DELAYED = dask.delayed(e3v)
_E1T_DELAYED = dask.delayed(e1t)
# dtype numpy produces for vomecrty * e3v * e1t; declared to dask for each chunk.
_PSI_DTYPE = np.result_type(d0['vomecrty'].dtype, e3v.dtype, e1t.dtype)


def _le_file_path(t0, lem):
    """Return the yearly grid_V_vomecrty file for ensemble member `lem`, year `t0`.

    Raises FileNotFoundError if no file matches; if several match, the first
    in sorted order is used so the choice is deterministic.
    """
    pattern = ('/gws/nopw/j04/canari/shared/large-ensemble/priority/HIST2/'
               '%s/OCN/yearly/%04d/*grid_V_vomecrty.nc' % (lem, t0))
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError('No file matches %s' % pattern)
    return matches[0]


# Delayed open_dataset of LE files
@dask.delayed
def open_LE_delayed(t0, lem):
    """Delayed: open the yearly V-velocity file for member `lem`, year `t0`.

    Opened without dask chunks, so values derived from it inside a task are
    in-memory numpy arrays rather than nested (still lazy) dask arrays.
    Not used by calc_bt_psi, which opens and closes the file within its task.
    """
    return xr.open_dataset(_le_file_path(t0, lem))


@dask.delayed
def _bt_psi_year(t0, lem, e3v_reg, e1t_reg):
    """Delayed: streamfunction for one member/year as a (1, N_MONTHS, ny, nx) numpy array.

    Raises ValueError if the file does not yield the expected number of months
    or region size, instead of silently disagreeing with the declared shape.
    """
    with xr.open_dataset(_le_file_path(t0, lem)) as ds:
        v = ds['vomecrty'][:, :, Y_SLICE, X_SLICE]
        # Depth-integrate v*e3v, multiply by the zonal width, scale by 1e6
        # and accumulate along increasing x.
        # NOTE(review): v-face transports normally use e1v; confirm e1t is intended.
        psi = ((v * e3v_reg).sum(dim='depthv') * e1t_reg / 1e6).cumsum(dim='x')
        values = psi.values  # materialise while the file is still open
    expected = (N_MONTHS, Y_SLICE.stop - Y_SLICE.start, X_SLICE.stop - X_SLICE.start)
    if values.shape != expected:
        raise ValueError('Expected psi of shape %s for member %s, year %s, got %s'
                         % (expected, lem, t0, values.shape))
    return values[np.newaxis]


def calc_bt_psi(t0, lem):
    """Return a lazy dask array with the barotropic streamfunction for one year.

    Parameters
    ----------
    t0 : int
        Year; selects the ``yearly/%04d`` directory.
    lem : int
        Large-ensemble member number.

    Returns
    -------
    dask.array.Array
        Shape ``(1, N_MONTHS, ny, nx)``: a length-1 ensemble axis, the months
        in the file, and the cropped y/x region. Values are divided by 1e6
        (Sv if velocities are m/s and scale factors are m — assumed).
        Each chunk is a plain numpy array, so a single ``.compute()`` suffices.
    """
    shape = (1, N_MONTHS, Y_SLICE.stop - Y_SLICE.start, X_SLICE.stop - X_SLICE.start)
    return dask.array.from_delayed(
        _bt_psi_year(t0, lem, _E3V_DELAYED, _E1T_DELAYED), shape, dtype=_PSI_DTYPE
    )

In [None]:
# Minimum of the barotropic streamfunction over the cropped region, per month,
# for each ensemble member. Members are computed one after another and the
# wall-clock time of each is printed.

# Years to process, derived from the time axis so the number of concatenated
# months always matches the 'time_counter' coordinate.
years = np.arange(time[0].year, time[-1].year + 1)

dA = []
for lem in LEM:

    delayed_psi = dask.array.concatenate([calc_bt_psi(t, lem) for t in years], axis=1)

    bt_psi = xr.DataArray(delayed_psi,
                 dims = ['ensemble member','time_counter','y','x'],
                 coords = {
                     'ensemble member': [lem,],
                     'time_counter': time,
                     'nav_lat': d0.nav_lat[900:1075,900:1025],
                     'nav_lon': d0.nav_lon[900:1075,900:1025],
                     'y': d0.y[900:1075],
                     'x': d0.x[900:1025],
                 },
                name = 'PSI',
                )
    start_time = timing.time()
    # The second .compute() is a no-op when calc_bt_psi yields numpy chunks;
    # it is needed only if the chunks are themselves lazy dask arrays
    # (e.g. when the input files are opened with chunks=...).
    dA.append(bt_psi.min(dim=["x", "y"]).compute().compute())
    end_time = timing.time()
    execution_time = end_time - start_time
    print(f"Execution time of {lem}: {execution_time}")

    del delayed_psi

psi_min = xr.concat(dA, dim='ensemble member')

In [None]:
# Shut down the Dask client first, then the local cluster it is connected to.
for dask_resource in (client, cluster):
    dask_resource.close()