In [6]:
import argparse
import json
import logging

import dask
import numpy as np
import xarray as xr

from dask.distributed import Client
import dask.config

In [7]:
# Turn off all four worker memory thresholds (target, spill, pause,
# terminate) by setting each of them to False in the dask configuration.
worker_memory_overrides = {
    "distributed.worker.memory.target": False,
    "distributed.worker.memory.spill": False,
    "distributed.worker.memory.pause": False,
    "distributed.worker.memory.terminate": False,
}
dask.config.set(worker_memory_overrides)

<dask.config.set at 0x7f57d93efd60>

In [8]:
import sys

In [9]:
# Make the project's ``src`` directory importable. The membership check keeps
# the cell idempotent: re-running it no longer appends duplicate entries.
if '../src' not in sys.path:
    sys.path.append('../src')

In [10]:
from bc_module_v2 import bc_module
import helper_modules

import delayed_module

In [11]:
# Load the domain configuration. ``json.load`` parses directly from the file
# handle, and the explicit UTF-8 encoding makes the read independent of the
# locale the kernel happens to run under.
with open("../src/conf/domain_config.json", "r", encoding="utf-8") as j:
    domain_config = json.load(j)

In [12]:
# Load the attribute configuration. ``json.load`` parses directly from the
# file handle, and the explicit UTF-8 encoding makes the read independent of
# the locale the kernel happens to run under.
with open("../src/conf/attribute_config.json", "r", encoding="utf-8") as j:
    attribute_config = json.load(j)

In [13]:
# Load the variable configuration. ``json.load`` parses directly from the
# file handle, and the explicit UTF-8 encoding makes the read independent of
# the locale the kernel happens to run under.
with open("../src/conf/variable_config.json", "r", encoding="utf-8") as j:
    variable_config = json.load(j)

In [14]:
# Narrow the full configuration down to the one domain used in this notebook.
# NOTE: this rebinds ``domain_config``; the config file has to be reloaded
# before this cell can be run a second time.
domain_name = 'west_africa'
domain_config = domain_config[domain_name]

In [15]:
# Keep only the variable entries that the selected domain lists under
# its "variables" key.
variable_config = dict(
    (name, settings)
    for name, settings in variable_config.items()
    if name in domain_config["variables"]
)

In [16]:
# The helper returns two directory dictionaries for this domain.
# presumably it also creates the directories on disk (going by its name) —
# TODO confirm against helper_modules.
dir_dicts = helper_modules.set_and_make_dirs(domain_config)
reg_dir_dict, glob_dir_dict = dir_dicts

In [17]:
# Bounds of the calibration period as given in the domain configuration
# (presumably start year / end year — confirm against the config file).
syr_calib, eyr_calib = domain_config["syr_calib"], domain_config["eyr_calib"]

In [18]:
# Start the dask cluster through the project helper.
# NOTE(review): the meaning of ('haswell', 2, 40) is defined by
# helper_modules.getCluster — presumably queue name, node count and cores;
# confirm there.
client, cluster = helper_modules.getCluster('haswell', 2, 40)

# Compare package versions across client, scheduler and workers; with
# check=True a mismatch is reported as an error instead of passing silently.
client.get_versions(check=True)

# Switch on the scheduler's active memory manager.
client.amm.start()

print(f"Dask dashboard available at {client.dashboard_link}")

Perhaps you already have a cluster running?
Hosting the HTTP server on port 39561 instead


Dask dashboard available at http://172.27.80.110:39561/status


In [None]:
# Cluster teardown. This cell sits *before* the computation cells, so an
# unconditional close would shut the cluster down under "Restart & Run All"
# before any work has been submitted. It is therefore kept behind an explicit
# switch: set SHUTDOWN_CLUSTER = True and run this cell once the analysis is
# finished.
SHUTDOWN_CLUSTER = False
if SHUTDOWN_CLUSTER:
    client.close()
    cluster.close()

In [19]:
# Resolve the four input/output locations used below: raw forecast,
# post-processed target, reference forecast and reference (observation) data.
# NOTE(review): 4 and 2016 are presumably the forecast month and year —
# confirm against helper_modules.set_input_files.
input_files = helper_modules.set_input_files(
    domain_config,
    reg_dir_dict,
    4,
    2016,
    'tp',
)
raw_full, pp_full, refrcst_full, ref_full = input_files

In [20]:
# Coordinates taken from the raw forecast; later cells read the "time",
# "ens", "lat" and "lon" entries of this mapping.
coords = helper_modules.get_coords_from_frcst(
    raw_full,
)

In [21]:
# Build the global attributes from the attribute template, the "bc_params"
# entry of the domain configuration, the forecast coordinates and the
# domain name.
global_attributes = helper_modules.update_global_attributes(
    attribute_config,
    domain_config["bc_params"],
    coords,
    'west_africa',
)

In [22]:
# Encoding settings derived from the variable configuration and coordinates.
encoding = helper_modules.set_encoding(
    variable_config,
    coords,
)

In [23]:
# Create the output dataset for 'tp' at the post-processed file location.
# NOTE(review): whether the helper writes the file to disk right away is not
# visible here — check helper_modules.create_4d_netcdf.
ds = helper_modules.create_4d_netcdf(
    pp_full, global_attributes, domain_config, variable_config, coords, 'tp'
)

In [24]:
# Open the reference store once just to read the length of the time axis,
# then reopen it with the whole time axis in a single chunk and automatic
# chunking along lat/lon.
ds_obs = xr.open_zarr(ref_full, consolidated=False)
obs_chunks = {"time": ds_obs.sizes["time"], "lat": 'auto', "lon": 'auto'}
ds_obs = xr.open_zarr(ref_full, chunks=obs_chunks, consolidated=False)

da_obs = ds_obs['tp']

In [25]:
# Open the reference-forecast store once to read the time and ensemble axis
# lengths, then reopen it with both axes as single chunks and automatic
# chunking along lat/lon.
ds_mdl = xr.open_zarr(refrcst_full, consolidated=False)
mdl_chunks = {
    "time": ds_mdl.sizes["time"],
    "ens": ds_mdl.sizes["ens"],
    "lat": 'auto',
    "lon": 'auto',
}
ds_mdl = xr.open_zarr(refrcst_full, chunks=mdl_chunks, consolidated=False)

da_mdl = ds_mdl['tp']

In [26]:
# Open the raw forecast briefly, only to read the ensemble size, and close it
# again right away. Previously this probing dataset was left open, holding an
# extra handle on the file next to the dask-backed dataset opened below.
with xr.open_dataset(raw_full) as ds_probe:
    n_ens = ds_probe.sizes["ens"]

# One time step per chunk with all ensemble members together; lat/lon chunk
# sizes are chosen automatically.
ds_pred = xr.open_mfdataset(
    raw_full,
    chunks={
        "time": 1,
        "ens": n_ens,
        "lat": 'auto',
        "lon": 'auto',
    },
    parallel=True,
    engine="netcdf4",
)

# Keep the forecast variable resident on the cluster for the following cells.
da_pred = ds_pred['tp'].persist()

In [27]:
# Container that the bc_module results are written into, one time step at a
# time (see the apply_ufunc cell further down).
#
# The scalar fill value is broadcast to the shape implied by ``coords``.
# np.nan gives a float64 array; the previous fill value ``None`` produced an
# object-dtype array, which is not suitable for numeric results or plotting.
# The array is NumPy-backed, so no ``.persist()`` is needed.
da_temp = xr.DataArray(
    np.nan,
    dims=["time", "lat", "lon", "ens"],
    coords={
        "time": (
            "time",
            coords["time"],
            {"standard_name": "time", "long_name": "time"},
        ),
        "ens": (
            "ens",
            coords["ens"],
            {
                "standard_name": "realization",
                "long_name": "ensemble_member",
            },
        ),
        "lat": (
            "lat",
            coords["lat"],
            {
                "standard_name": "latitude",
                "long_name": "latitude",
                "units": "degrees_north",
            },
        ),
        "lon": (
            "lon",
            coords["lon"],
            {
                "standard_name": "longitude",
                "long_name": "longitude",
                "units": "degrees_east",
            },
        ),
    },
)

In [28]:
# Position along the forecast time axis handled by the cells below; it is
# used for da_pred.isel(time=timestep) and as the first index into da_temp.
timestep = 0

In [29]:
# Time labels in the reference data and in the reference forecasts that
# belong to the current forecast step; both are used as ``time`` selectors
# in the subsetting cell below.
intersect_days = delayed_module.get_intersect_days(
    timestep,
    domain_config,
    da_obs,
    da_mdl,
    da_pred,
)
intersection_day_obs, intersection_day_mdl = intersect_days

In [None]:
%%time
# Timing experiment: select the model time steps with a boolean mask instead
# of label-based indexing.
# The original compared against ``act_dates``, which is not defined anywhere
# in this notebook and raised NameError on a fresh kernel. The mask now uses
# ``intersection_day_mdl`` from the get_intersect_days cell.
# NOTE(review): assumes intersection_day_mdl is an array-like of time labels
# (it is used the same way with .loc in the next cell) — confirm.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    time_mask = da_mdl['time'].isin(intersection_day_mdl)
    da_mdl_sub = da_mdl.where(time_mask, drop=True).persist()

In [46]:
%%time
# Reference values on the time labels matching the current forecast step.
da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]

# Same selection on the reference forecasts; chunk splitting during slicing
# is disabled for this selection only.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]

# Merge (ens, time) into one "ens_time" dimension, which is the core
# dimension handed to bc_module for the model data.
da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
# drop_vars replaces DataArray.drop, which is deprecated for removing a
# coordinate by name.
da_mdl_sub = da_mdl_sub.drop_vars("time")

# Forecast slice for the time step being processed.
da_pred_sub = da_pred.isel(time=timestep)

CPU times: user 56.2 ms, sys: 17.5 ms, total: 73.7 ms
Wall time: 68.9 ms


In [None]:
# Show the repr of the current model subset for inspection.
da_mdl_sub

In [47]:
%%time
# Run bc_module for every grid cell of the current time step and store the
# result in the output container.
#   forecast    -> core dim "ens"
#   reference   -> core dim "time"
#   model       -> core dim "ens_time" (stacked ens x time)
# Each call returns one value per ensemble member ("ens").
# NOTE(review): the recorded run of this cell ended in a KilledWorker error
# on a rechunk task — worker memory / chunking needs a look before relying
# on this cell.
da_temp[timestep, :, :, :] = xr.apply_ufunc(
    bc_module,
    da_pred_sub,
    da_obs_sub,
    da_mdl_sub,
    input_core_dims=[["ens"], ["time"], ["ens_time"]],
    output_core_dims=[["ens"]],
    kwargs={
        "bc_params": domain_config["bc_params"],
        "precip": variable_config['tp']["isprecip"],
    },
    vectorize=True,
    dask="parallelized",  # alternative: dask='allowed'
    output_dtypes=[np.float64],
)

KilledWorker: ("('rechunk-split-rechunk-merge-1cf032253a7bdab42cb251f63812df5e', 2, 4, 0)", <WorkerState 'tcp://172.27.80.131:36125', name: SLURMCluster-1-21, status: closed, memory: 0, processing: 94>)

In [None]:
# Quick-look plot of the first ensemble member for the processed time step.
# The original plotted ``tst``, which is not defined anywhere in this
# notebook. The slice is now taken from ``da_temp``, which receives the
# apply_ufunc result for ``timestep``.
# NOTE(review): presumably ``tst`` was that (lat, lon, ens) result — confirm.
# The astype(float) makes plotting work even if da_temp has object dtype.
da_temp.isel(time=timestep, ens=0).astype(float).plot();