In [4]:
import xarray as xr
import numpy as np
from numpy import pi, sin, cos, arccos, clip, deg2rad
import numpy.ma as ma
from datetime import datetime
import dask
import time


@dask.delayed
def load_and_compute(dsmapper, lat1, lon1, window_size, plevel):
    
    """
    Load dataset, reduce it to data inside window and do computation
    """

    dynht_ds = xr.open_zarr(dsmapper, consolidated=True, chunks=None)
    distance = xr.apply_ufunc(great_circle_distance, lat1, lon1, dynht_ds.LAT.load(), dynht_ds.LON.load())
    ds_radius = dynht_ds.where(distance < window_size, drop=True)
    data = ds_radius.DH[plevel,:].load().values
    ii = ~xr.apply_ufunc(np.isnan, data)
    
    # Converting from Matlab datenumber into days since 1970-01-01
    dt_1970_01_01 = datetime.toordinal(datetime(1970,1,1,0,0,0,0)) + 366.0  # https://stackoverflow.com/questions/32991934/equivalent-function-of-datenumdatestring-of-matlab-in-python
    time = ds_radius.DATENUM.load().values[ii] - dt_1970_01_01
    data, lat, lon = data[ii], ds_radius.LAT.values[ii], ds_radius.LON.values[ii]

    # Computation that uses data, lat, lon, time,
    # ...
    # ...
    # and returns the results.
    
    return data+lat+lon+time


@dask.delayed
def load_only(dsmapper, plevel):
    
    """
    Load data and the coordinates only, without calling distance-function. Just for reference.
    """

    dynht_ds = xr.open_zarr(dsmapper, consolidated=True, chunks=None)
    data = dynht_ds.DH.sel(level=plevel).load().values
    ii = ~xr.apply_ufunc(np.isnan, data)
    
    # Converting from Matlab datenumber to days since 1970-01-01
    dt_1970_01_01 = datetime.toordinal(datetime(1970,1,1,0,0,0,0)) + 366.0  # https://stackoverflow.com/questions/32991934/equivalent-function-of-datenumdatestring-of-matlab-in-python
    time = dynht_ds.DATENUM.load().values[ii] - dt_1970_01_01
    data, lat, lon = data[ii], dynht_ds.LAT.values[ii], dynht_ds.LON.values[ii]
    
    return data+lat+lon+time


def load_data_in_window_and_compute( plevel, window_size ):
    
    dspath='.../global_dynamic_height.zarr'
    dsmapper = gcs.get_mapper(dspath)
    
    maskmapper = gcs.get_mapper('.../global_1x1_parametermask.zarr')
    maskdata = xr.open_dataset(maskmapper, engine="zarr", consolidated=True)
    mask = maskdata.mask.values
    long = maskdata.longitude.values
    latg = maskdata.latitude.values
    lonB,latB = np.meshgrid(long,latg)
    lonB,latB = ma.MaskedArray(lonB, 1-mask),ma.MaskedArray(latB, 1-mask)

    list_of_delayed_objects = []
    for latmasked, lonmasked in zip(latB.compressed()[:1], lonB.compressed()[:1]):
        returned = load_and_compute(dsmapper, latmasked, lonmasked, window_size, plevel)
        list_of_delayed_objects.append( returned )
        
    return list_of_delayed_objects


def loads_data_only(plevel):
    
    dspath='...global_dynamic_height.zarr'
    dsmapper = gcs.get_mapper(dspath)
    
    maskmapper = gcs.get_mapper('.../global_1x1_parametermask.zarr')
    maskdata = xr.open_dataset(maskmapper, engine="zarr", consolidated=True)
    mask = maskdata.mask.values
    long = maskdata.longitude.values
    latg = maskdata.latitude.values
    lonB,latB = np.meshgrid(long,latg)
    lonB,latB = ma.MaskedArray(lonB, 1-mask),ma.MaskedArray(latB, 1-mask)

    list_of_delayed_objects = []
    for latmasked, lonmasked in zip(latB.compressed()[:1], lonB.compressed()[:1]):
        returned = load_only(dsmapper, plevel)
        list_of_delayed_objects.append( returned )
        
    # Returning list of computations
    return list_of_delayed_objects


def great_circle_distance(lat1, lon1, lat2, lon2):
        
    # from https://earth-env-data-science.github.io/lectures/core_python/organization_and_packaging.html#modules

    # approximate radius of Earth
    R_earth = 6.371e6

    # unpack and convert everything to radians

    phi1, lambda1, phi2, lambda2 = deg2rad(lat1), deg2rad(lon1) , deg2rad(lat2), deg2rad(lon2)

    # apply formula
    # https://en.wikipedia.org/wiki/Great-circle_distance

    return R_earth * arccos( clip( (sin(phi1) * sin(phi2) + cos(phi1) * cos(phi2) * cos(lambda2 - lambda1)), -1, 1 ) )



### Connect to a dask-cluster and set cluster-options

In [2]:
import json
import gcsfs
import pangeo_token

with open(pangeo_token.token) as f:
    token = json.load(f)
gcs = gcsfs.GCSFileSystem(token=token)

from dask_gateway import GatewayCluster, Gateway
from distributed import Client

g = Gateway()
g.list_clusters()

[]

In [None]:
# cluster = g.connect(g.list_clusters()[0].name)

In [3]:
options = g.cluster_options()
options.worker_cores = 2; options.worker_memory = 4
# Create a cluster with those options
cluster = g.new_cluster(options)

In [5]:
g.list_clusters()

[ClusterReport<name=prod.f8047fb21c314d7392a5b76d94f5f32a, status=RUNNING>]

In [6]:
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: /services/dask-gateway/clusters/prod.f8047fb21c314d7392a5b76d94f5f32a/status,


In [7]:
cluster.scale(1)

### Loading of zarr-files from cloud-storage

Loading dataset only:

In [15]:
%%time 
list_of_global_data_lat_lon_time = dask.compute( loads_data_only(0) )[0]

CPU times: user 40 ms, sys: 10 ms, total: 50 ms
Wall time: 4.41 s


Loading dataset and reduce to data within window:

In [16]:
window_size = 500e3

In [17]:
%%time 
list_of_data_lat_lon_time_in_window = dask.compute( load_data_in_window_and_compute(0,window_size) )[0]

CPU times: user 29.3 ms, sys: 4.02 ms, total: 33.3 ms
Wall time: 5.32 s


#### Scaling down and closing cluster

In [18]:
cluster.scale(0)

In [19]:
cluster.close()

In [20]:
cluster.shutdown()

2023-07-13 12:45:44,375 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
