In [1]:
import argparse
import json
import logging

import dask
import numpy as np
import xarray as xr

from dask.distributed import Client
import dask.config

In [2]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell execution (useful while editing ../src).
%load_ext autoreload
%autoreload 2

In [None]:
# Set all four distributed worker memory thresholds to False.
# NOTE(review): with every threshold disabled, dask itself will not spill,
# pause or restart a worker under memory pressure - confirm this is intended.
dask.config.set(
    dict.fromkeys(
        (
            'distributed.worker.memory.target',
            'distributed.worker.memory.spill',
            'distributed.worker.memory.pause',
            'distributed.worker.memory.terminate',
        ),
        False,
    )
)

In [3]:
# Set the top-level dask config key 'interface' to 'lo'.
# NOTE(review): presumably meant to bind communication to the loopback
# interface - confirm that this key is actually read by the cluster helper.
dask.config.set({'interface': 'lo'})

<dask.config.set at 0x7fdbed3f6a10>

In [4]:
import sys

In [5]:
# Make the project modules in ../src importable. The membership check keeps the
# cell idempotent: re-running it does not pile up duplicate sys.path entries.
if '../src' not in sys.path:
    sys.path.append('../src')

In [6]:
from bc_module_v3 import bc_module
import helper_modules

import delayed_module

In [7]:
# Read the domain configuration from its JSON file.
with open("../src/conf/domain_config.json", "r") as config_file:
    domain_config = json.load(config_file)

In [8]:
# Read the attribute configuration from its JSON file.
with open("../src/conf/attribute_config.json", "r") as config_file:
    attribute_config = json.load(config_file)

In [9]:
# Read the variable configuration from its JSON file.
with open("../src/conf/variable_config.json", "r") as config_file:
    variable_config = json.load(config_file)

In [10]:
# Narrow the loaded configuration to the 'west_africa' entry.
# NOTE(review): this rebinds domain_config to its own sub-dict, so the cell is
# not idempotent - re-running it without reloading the JSON looks up
# 'west_africa' in the already-narrowed dict.
domain_config = domain_config['west_africa']

In [11]:
# Keep only those variable entries that are listed under the selected
# domain's "variables".
variable_config = {
    var_name: var_settings
    for var_name, var_settings in variable_config.items()
    if var_name in domain_config["variables"]
}

In [12]:
# Obtain reg_dir_dict and glob_dir_dict for this domain from the project helper.
# NOTE(review): judging by the helper's name it may also create the
# directories on disk - confirm in helper_modules.
reg_dir_dict, glob_dir_dict = helper_modules.set_and_make_dirs(domain_config)

In [13]:
# First and last calibration year as given in the domain configuration.
syr_calib, eyr_calib = domain_config["syr_calib"], domain_config["eyr_calib"]

In [None]:
# Manual override of the calibration years read from the domain configuration.
# NOTE(review): hard-coded values shadow the configured ones when this cell is
# run; neither name is used again in the cells visible here.
syr_calib = 1981
eyr_calib = 2011

In [14]:
# Obtain a dask client and cluster from the project helper.
# NOTE(review): the meaning of ('rome', 1, 40) is defined by
# helper_modules.getCluster - presumably queue/partition, nodes and cores; confirm.
client, cluster = helper_modules.getCluster('rome', 1, 40)

# check=True raises if package versions differ between client, scheduler and workers.
client.get_versions(check=True)
# Start the Active Memory Manager on the scheduler.
client.amm.start()

print(f"Dask dashboard available at {client.dashboard_link}")

Dask dashboard available at http://172.27.80.110:8787/status


In [None]:
# Shut down the dask client first, then the cluster it is connected to.
# NOTE(review): this cell sits before the cells that persist/compute on the
# cluster, so executing the notebook top to bottom would close the cluster
# before it is used - run it manually at the end of a session instead.
client.close()
cluster.close()

In [20]:
# Resolve the four inputs used below: the raw forecast (opened via netCDF4),
# pp_full (passed to create_4d_netcdf), and the reforecast and reference stores
# (both opened as zarr).
# NOTE(review): 4, 2016 and 'tp' are hard-coded - presumably start month, year
# and variable name; confirm against helper_modules.set_input_files.
raw_full, pp_full, refrcst_full, ref_full = helper_modules.set_input_files(domain_config, reg_dir_dict, 4, 2016, 'tp')

In [21]:
# Derive the coordinates from the raw forecast; later cells index the result
# with the keys "time", "ens", "lat" and "lon".
coords = helper_modules.get_coords_from_frcst(raw_full)

In [22]:
# Build the global attributes from the attribute configuration, the domain's
# bc_params and the forecast coordinates.
# NOTE(review): the domain name 'west_africa' is repeated here as a literal and
# must be kept in sync with the domain selected from domain_config.
global_attributes = helper_modules.update_global_attributes(
    attribute_config, domain_config["bc_params"], coords, 'west_africa'
)

In [23]:
# Build the encoding settings from the variable configuration and coordinates.
# NOTE(review): `encoding` is not used again in the cells visible here.
encoding = helper_modules.set_encoding(variable_config, coords)

In [24]:
# Create the output dataset for variable 'tp' at pp_full via the project helper.
# NOTE(review): `ds` is not used again in the cells visible here; whether the
# helper writes to disk cannot be told from this notebook - confirm.
ds = helper_modules.create_4d_netcdf(
    pp_full,
    global_attributes,
    domain_config,
    variable_config,
    coords,
    'tp',
)

In [28]:
# Open the reference store lazily with the whole time axis in one chunk per
# grid cell. A chunk size of -1 means "the full length of this dimension", so
# the store does not have to be opened a first time just to read len(time).
ds_obs = xr.open_zarr(
    ref_full,
    chunks={"time": -1, "lat": 1, "lon": 1},
    consolidated=False,
)
da_obs = ds_obs['tp']
# Development subset: only the first 3x3 grid cells are persisted on the cluster.
da_obs = da_obs.isel(lat=[0, 1, 2], lon=[0, 1, 2]).persist()

In [29]:
# Open the reforecast store lazily with the full time and ensemble axes in one
# chunk per grid cell. A chunk size of -1 means "the full length of this
# dimension", so no preliminary open is needed to read the dimension lengths.
ds_mdl = xr.open_zarr(
    refrcst_full,
    chunks={"time": -1, "ens": -1, "lat": 1, "lon": 1},
    consolidated=False,
)
da_mdl = ds_mdl['tp']
# Development subset: only the first 3x3 grid cells are persisted on the cluster.
da_mdl = da_mdl.isel(lat=[0, 1, 2], lon=[0, 1, 2]).persist()

In [31]:
# Open the raw forecast lazily with the full time and ensemble axes in one
# chunk per grid cell. A chunk size of -1 means "the full length of this
# dimension"; this avoids a throw-away xr.open_dataset() call whose file handle
# was never closed.
ds_pred = xr.open_mfdataset(
    raw_full,
    chunks={"time": -1, "ens": -1, "lat": 1, "lon": 1},
    parallel=False,
    engine="netcdf4",
)
da_pred = ds_pred['tp']
# Rename the forecast time axis so it cannot be confused with the "time"
# dimension of the observations and reforecasts.
da_pred = da_pred.rename({'time': 'pred_time'})
# Development subset: only the first 3x3 grid cells are persisted on the cluster.
da_pred = da_pred.isel(lat=[0, 1, 2], lon=[0, 1, 2]).persist()

In [None]:
# Pre-allocate the result array with dimensions (time, lat, lon, ens), using the
# entries of `coords` as coordinate values. It is filled with NaN so that it
# gets a float dtype; a fill value of None would produce an object-dtype array
# of None. The array is a plain in-memory array, so no dask persist is needed.
da_temp = xr.DataArray(
    np.nan,
    dims=["time", "lat", "lon", "ens"],
    coords={
        "time": (
            "time",
            coords["time"],
            {"standard_name": "time", "long_name": "time"},
        ),
        "ens": (
            "ens",
            coords["ens"],
            {
                "standard_name": "realization",
                "long_name": "ensemble_member",
            },
        ),
        "lat": (
            "lat",
            coords["lat"],
            {
                "standard_name": "latitude",
                "long_name": "latitude",
                "units": "degrees_north",
            },
        ),
        "lon": (
            "lon",
            coords["lon"],
            {
                "standard_name": "longitude",
                "long_name": "longitude",
                "units": "degrees_east",
            },
        ),
    },
)

In [None]:
# Index of the forecast time step handled by the following cells; it is used
# to pick the forecast slice and the matching slot in da_temp.
timestep = 0

In [None]:
# Time labels of the observations and of the reforecasts that are used for
# this time step; both are applied as selections on "time" further down.
intersection_day_obs, intersection_day_mdl = delayed_module.get_intersect_days(
    timestep,
    domain_config,
    da_obs,
    da_mdl,
    da_pred,
)

In [None]:
%%time
# Timing experiment: subset the reforecasts with a boolean mask on "time",
# with dask's splitting of large chunks disabled while slicing.
# NOTE(review): `act_dates` is not defined in any cell visible here, so this
# cell raises NameError on a fresh kernel; the later cell that selects with
# intersection_day_mdl looks like its replacement - confirm and remove.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    da_mdl_sub = da_mdl.where(da_mdl['time'] == act_dates, drop=True).persist()

In [None]:
%%time
da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]
da_obs_sub = da_obs_sub.isel(lat=[0,1,2], lon=[0,1,2])

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]
    
da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
da_mdl_sub = da_mdl_sub.drop("time")
da_mdl_sub = da_mdl_sub.isel(lat=[0,1,2], lon=[0,1,2])

da_pred_sub = da_pred.isel(time=timestep)
da_pred_sub = da_pred_sub.isel(lat=[0,1,2], lon=[0,1,2])

In [None]:
from bc_module_v3 import bc_module

In [None]:
%%time
# Run bc_module for every grid cell of the selected time step and store the
# result in the matching slot of da_temp. Per grid cell the function receives
# the forecast along "ens", the observations along "time" and the pooled
# reforecast samples along "ens_time", and returns values along "ens".
# NOTE(review): this call passes "bc_params" to bc_module while the full-array
# call further down passes "domain_config" - confirm which one bc_module expects.
da_temp[timestep, :, :, :] = xr.apply_ufunc(
    bc_module,
    da_pred_sub,
    da_obs_sub,
    da_mdl_sub,
    input_core_dims=[["ens"], ["time"], ["ens_time"]],
    output_core_dims=[["ens"]],
    kwargs={
        "bc_params": domain_config["bc_params"],
        "precip": variable_config['tp']["isprecip"],
    },
    # vectorize=True loops bc_module over all non-core dimensions.
    vectorize=True,
    dask="parallelized",
    # dask='allowed',
    output_dtypes=[np.float64],
)

In [None]:
# Plot the slice at index 0 of the third axis of `tst`.
# NOTE(review): `tst` is not defined in any cell visible here, so this cell
# raises NameError on a fresh kernel - presumably a leftover from an earlier
# experiment; confirm which array was meant.
tst[:,:,0].plot()

In [None]:
                # Scratch notes: alternative xr.apply_ufunc keyword arguments that were
                # tried out. They are kept as comments because the bare fragment is not
                # valid Python on its own and would abort a "Run All".
                #
                #     input_core_dims=[["ens", "pred_time"], ["time"], ["ens_time"]],
                #     output_core_dims=[["ens"]],
                #     join='outer',

In [None]:
# Display the reforecast array to inspect its dimensions and chunking.
da_mdl

In [37]:
# Lazily apply bc_module to the complete forecast, observation and reforecast
# arrays in one call. With vectorize=False the function receives whole blocks,
# with the listed core dimensions moved to the end of each input.
# NOTE(review): this call passes "domain_config" to bc_module while the
# per-time-step call above passes "bc_params" - confirm which one is expected.
test = xr.apply_ufunc(
    bc_module,
    da_pred,
    da_obs,
    da_mdl,
    input_core_dims=[["ens"], ["time"], ["time", "ens"]],
    output_core_dims=[["ens"]],
    # "time" is excluded from alignment, so the inputs may differ along it.
    exclude_dims={"time"},
    kwargs={
        "domain_config": domain_config,
        "precip": variable_config['tp']["isprecip"],
    },
    vectorize=False,
    dask="parallelized",
    # dask='allowed',
    output_dtypes=[np.float64],
)

In [39]:
# Display the lazy result; the recorded output shows a chunked array of shape
# (215, 3, 3, 25), i.e. nothing has been computed yet.
test

Unnamed: 0,Array,Chunk
Bytes,377.93 kiB,41.99 kiB
Shape,"(215, 3, 3, 25)","(215, 1, 1, 25)"
Count,81 Tasks,9 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 377.93 kiB 41.99 kiB Shape (215, 3, 3, 25) (215, 1, 1, 25) Count 81 Tasks 9 Chunks Type float64 numpy.ndarray",215  1  25  3  3,

Unnamed: 0,Array,Chunk
Bytes,377.93 kiB,41.99 kiB
Shape,"(215, 3, 3, 25)","(215, 1, 1, 25)"
Count,81 Tasks,9 Chunks
Type,float64,numpy.ndarray


In [38]:
# Trigger the actual computation on the cluster.
# NOTE(review): the recorded output of this cell is a KilledWorker error, so
# the computation did not complete in the saved run.
test.compute()

KilledWorker: ("('transpose-34ac0a88d7f70e58bc27af6e091404ab', 2, 1, 0, 0)", <WorkerState 'tcp://172.27.80.227:36681', name: SLURMCluster-0-35, status: closed, memory: 0, processing: 3>)

In [33]:
# Display the lazy result again.
# NOTE(review): the execution count of this cell is lower than that of the
# cells above it, i.e. the notebook was run out of order.
test

Unnamed: 0,Array,Chunk
Bytes,377.93 kiB,41.99 kiB
Shape,"(215, 3, 3, 25)","(215, 1, 1, 25)"
Count,81 Tasks,9 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 377.93 kiB 41.99 kiB Shape (215, 3, 3, 25) (215, 1, 1, 25) Count 81 Tasks 9 Chunks Type float64 numpy.ndarray",215  1  25  3  3,

Unnamed: 0,Array,Chunk
Bytes,377.93 kiB,41.99 kiB
Shape,"(215, 3, 3, 25)","(215, 1, 1, 25)"
Count,81 Tasks,9 Chunks
Type,float64,numpy.ndarray
