In [1]:
import argparse
import json
import logging

import dask
import numpy as np
import xarray as xr

from dask.distributed import Client
import dask.config

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
dask.config.set(
    {'distributed.worker.memory.target':False,
     'distributed.worker.memory.spill':False,
     'distributed.worker.memory.pause':False,
     'distributed.worker.memory.terminate':False,}
)

<dask.config.set at 0x7f93f0177b50>

In [4]:
# print(dask.config.config)

In [5]:
dask.config.set({'interface': 'lo'})

<dask.config.set at 0x7f93f0177c10>

In [6]:
import sys

In [7]:
sys.path.append('../src')

In [8]:
from bc_module_v2 import bc_module
import helper_modules

import delayed_module

In [9]:
with open("../src/conf/domain_config.json", "r") as j:
    domain_config = json.loads(j.read())

In [10]:
with open("../src/conf/attribute_config.json", "r") as j:
    attribute_config = json.loads(j.read())

In [11]:
with open("../src/conf/variable_config.json", "r") as j:
    variable_config = json.loads(j.read())

In [12]:
domain_config = domain_config['west_africa']

In [13]:
variable_config = {
    key: value
    for key, value in variable_config.items()
    if key in domain_config["variables"]
}

In [14]:
reg_dir_dict, glob_dir_dict = helper_modules.set_and_make_dirs(domain_config)

In [15]:
syr_calib = domain_config["syr_calib"]
eyr_calib = domain_config["eyr_calib"]

In [16]:
syr_calib = 1981
eyr_calib = 2011

In [32]:
client, cluster = helper_modules.getCluster('rome', 1, 30)
        
client.get_versions(check=True)
client.amm.start()
         
print(f"Dask dashboard available at {client.dashboard_link}")

Dask dashboard available at http://172.27.80.110:42077/status


Perhaps you already have a cluster running?
Hosting the HTTP server on port 42077 instead


In [None]:
client.run(lambda dask_worker: {a: getattr(dask_worker.memory_manager, a) for a in ["memory_limit", "memory_target_fraction", "memory_spill_fraction", "memory_pause_fraction"]})

In [None]:
# Setting: 
# 'fat', 1, 40 --> (lat:20, lon:20) one time step 10 seconds
# 'rome', 1, 40 --> (lat:20, lon:20) one time step too long (maybe beacuse of disk-write and read???????)
# 'rome', 1, 40 --> (lat:12, lon:12) one time step too long (maybe beacuse of disk-write and read???????)
# 'rome', 1, 35 --> (lat:12, lon:12) one time step too long (maybe beacuse of disk-write and read???????)
# 'rome', 1, 35 --> (lat:20, lon:20) one time step too long (maybe beacuse of disk-write and read???????)
# 'fat', 1, 40 --> (lat:20, lon:20) one time step 10 seconds (no writing on disk)
# 'haswell', 1, 20 --> (lat:10, lon:10) one time step too long (maybe beacuse of disk-write and read???????)
# 'haswell', 1, 10 --> (lat:20, lon:20) one time step too long (maybe beacuse of disk-write and read???????)
# 'rome', 1, 30 --> (lat:15, lon:15), without dask spilling on disk, but anyway slow (some worker gets red, because of out of memory)
# 'rome', 1, 30 --> (lat:10, lon:10), without dask spilling on disk, but anyway slow (some worker gets red, because of out of memory)

In [31]:
client.close()
cluster.close()



In [33]:
raw_full, pp_full, refrcst_full, ref_full = helper_modules.set_input_files(domain_config, reg_dir_dict, 4, 2016, 'tp')

In [34]:
coords = helper_modules.get_coords_from_frcst(raw_full)

In [35]:
global_attributes = helper_modules.update_global_attributes(
    attribute_config, domain_config["bc_params"], coords, 'west_africa'
)

In [36]:
encoding = helper_modules.set_encoding(variable_config, coords)

In [37]:
ds = helper_modules.create_4d_netcdf(
    pp_full,
    global_attributes,
    domain_config,
    variable_config,
    coords,
    'tp',
)

In [23]:
ref_full

'/bg/data/NCZarr/s2s_forecasts/west_africa/02_reference/zarr_stores/ERA5_Land_0.1_linechunks.zarr'

In [None]:
client.restart()

In [None]:
ds_pred.close()
ds_obs.close()
ds_mdl.close()

In [38]:
# best option until now 10 for each (about 18 seconds for each time step) --> 100 time steps about 30 min.
lat = 10
lon = 10
latlon = 200

In [None]:
np.arange(10)

In [39]:
ds_obs = xr.open_zarr(ref_full, consolidated=False)
ds_obs = xr.open_zarr(
    ref_full,
    chunks={"time": len(ds_obs.time), "lat": lat, "lon": lon},
    consolidated=False
    )
da_obs = ds_obs['tp']
# da_obs = da_obs.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_obs = ds_obs.persist()
# da_obs = dask.delayed(da_obs)

In [40]:
ds_mdl = xr.open_zarr(refrcst_full, consolidated=False)
ds_mdl = xr.open_zarr(
    refrcst_full,
    chunks={
       "time": len(ds_mdl.time),
       "ens": len(ds_mdl.ens),
       "lat": lat,
       "lon": lon
    },
    consolidated=False
    )
da_mdl = ds_mdl['tp']
# da_mdl = da_mdl.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_mdl = ds_mdl.persist()
#da_mdl = dask.delayed(da_mdl)

In [41]:
ds_pred = xr.open_dataset(raw_full)
ds_pred = xr.open_mfdataset(
    raw_full,
    chunks={
        "time": len(ds_pred.time),
        "ens": len(ds_pred.ens),
        "lat": lat,
        "lon": lon
     },
     parallel=False,
     engine="netcdf4",
)
da_pred = ds_pred['tp']
# da_pred = da_pred.rename({'time': 'pred_time'})
# da_pred = da_pred.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_pred = da_pred.persist()
# da_pred = dask.delayed(da_pred)

In [42]:
da_temp = xr.DataArray(
    None,
    dims=["time", "lat", "lon", "ens"],
    coords={
         "time": (
             "time",
             coords["time"],
             {"standard_name": "time", "long_name": "time"},
          ),
           "ens": (
              "ens",
              coords["ens"],
                {
            "standard_name": "realization",
                                "long_name": "ensemble_member",
                            },
                        ),
                        "lat": (
                            "lat",
                            coords["lat"],
                            {
                                "standard_name": "latitude",
                                "long_name": "latitude",
                                "units": "degrees_north",
                            },
                        ),
                        "lon": (
                            "lon",
                            coords["lon"],
                            {
                                "standard_name": "longitude",
                                "long_name": "longitude",
                                "units": "degrees_east",
                            },
                        ),
                    },
                ).persist()

In [None]:
timestep = 0

In [None]:
intersection_day_obs, intersection_day_mdl = delayed_module.get_intersect_days(timestep, domain_config, da_obs, da_mdl, da_pred)

In [None]:
%%time
for timestep in range(0,10):
    intersection_day_obs, intersection_day_mdl = delayed_module.get_intersect_days(timestep, domain_config, da_obs, da_mdl, da_pred)

    da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]
    # da_obs_sub = da_obs_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    # da_obs_sub= dask.delayed(da_obs_sub)
    da_obs_sub.to_netcdf("test_obs_"+ str(timestep) +".nc")

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]



    da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
    da_mdl_sub = da_mdl_sub.drop("time")
    # da_mdl_sub = da_mdl_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    da_mdl_sub.to_netcdf("test_mdl_"+ str(timestep) +".nc")
    # da_mdl_sub= dask.delayed(da_mdl_sub)

    da_pred_sub = da_pred.isel(time=timestep)
    # da_pred_sub = da_pred_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    da_pred_sub.to_netcdf("test_pred_"+ str(timestep) +".nc")
    # da_pred_sub= dask.delayed(da_pred_sub)

In [43]:
%%time
da_obs = da_obs.persist()
da_mdl = da_mdl.persist()
da_pred = da_pred.persist()

CPU times: user 217 ms, sys: 16.3 ms, total: 233 ms
Wall time: 225 ms


In [46]:
client.rebalance()

In [None]:
da_obs.close()

In [None]:
%%time
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    da_mdl_sub = da_mdl.where(da_mdl['time'] == act_dates, drop=True).persist()

In [None]:
%%time
da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]
# da_obs_sub = da_obs_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_obs_sub= dask.delayed(da_obs_sub)
# da_obs_sub.to_netcdf("test_obs_"+ str(timestep) +".nc")

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]
    

    
da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
da_mdl_sub = da_mdl_sub.drop("time")
# da_mdl_sub = da_mdl_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_mdl_sub.to_netcdf("test_mdl_"+ str(timestep) +".nc")
# da_mdl_sub= dask.delayed(da_mdl_sub)

da_pred_sub = da_pred.isel(time=timestep)
# da_pred_sub = da_pred_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
# da_pred_sub.to_netcdf("test_pred_"+ str(timestep) +".nc")
# da_pred_sub= dask.delayed(da_pred_sub)

In [None]:
%%time
da_temp[timestep, :, :, :] = xr.apply_ufunc(
            bc_module,
            da_pred_sub,
            da_obs_sub,
            da_mdl_sub,
            kwargs={
                "bc_params": domain_config["bc_params"],
                "precip": variable_config['tp']["isprecip"],
            },
            input_core_dims=[["ens"], ["time"], ["ens_time"]],
            output_core_dims=[["ens"]],
            vectorize=True,
            dask="parallelized",
            # dask='allowed',
            output_dtypes=[np.float64])

In [None]:
lat = 15
lon = 15

In [None]:
da_pred_sub = xr.open_mfdataset("test_pred_0.nc", chunks={'ens': len(ds_pred.ens), 'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
da_pred_sub = da_pred_sub.tp

In [None]:
da_mdl_sub = xr.open_mfdataset("test_mdl_0.nc", chunks={'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
da_mdl_sub = da_mdl_sub.tp

In [None]:
da_obs_sub = xr.open_mfdataset("test_obs_0.nc", chunks={'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
da_obs_sub = da_obs_sub.tp

In [None]:
%%time
lat = 20
lon = 20

for timestep in range(0,10):
    print("correct timestep: " + str(timestep))
    da_pred_sub = xr.open_mfdataset("test_pred_"+ str(timestep) +".nc", chunks={'ens': len(ds_pred.ens), 'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
    da_pred_sub = da_pred_sub.tp

    da_mdl_sub = xr.open_mfdataset("test_mdl_"+ str(timestep) +".nc", chunks={'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
    da_mdl_sub = da_mdl_sub.tp

    da_obs_sub = xr.open_mfdataset("test_obs_"+ str(timestep) +".nc", chunks={'lat': lat, 'lon': lon}, parallel=True, engine='netcdf4')
    da_obs_sub = da_obs_sub.tp


    da_temp[timestep, :, :, :] = xr.apply_ufunc(
            bc_module,
            da_pred_sub,
            da_obs_sub,
            da_mdl_sub,
            kwargs={
                "bc_params": domain_config["bc_params"],
                "precip": variable_config['tp']["isprecip"],
            },
            input_core_dims=[["ens"], ["time"], ["ens_time"]],
            output_core_dims=[["ens"]],
            vectorize=True,
            dask="parallelized",
            # dask='allowed',
            output_dtypes=[np.float64])

In [47]:
%%time

for timestep in range(0,214):
    print("correct timestep: " + str(timestep))
    intersection_day_obs, intersection_day_mdl = delayed_module.get_intersect_days(timestep, domain_config, da_obs, da_mdl, da_pred)
    
    
    da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]
    # da_obs_sub = da_obs_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    # da_obs_sub= dask.delayed(da_obs_sub)
    # da_obs_sub.to_netcdf("test_obs_"+ str(timestep) +".nc")

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]



    da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
    da_mdl_sub = da_mdl_sub.drop("time")
    # da_mdl_sub = da_mdl_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    # da_mdl_sub.to_netcdf("test_mdl_"+ str(timestep) +".nc")
    # da_mdl_sub= dask.delayed(da_mdl_sub)

    da_pred_sub = da_pred.isel(time=timestep)
    # da_pred_sub = da_pred_sub.isel(lat=np.arange(latlon), lon=np.arange(latlon))
    # da_pred_sub.to_netcdf("test_pred_"+ str(timestep) +".nc")
    # da_pred_sub= dask.delayed(da_pred_sub)


    da_temp[timestep, :, :, :] = xr.apply_ufunc(
            bc_module,
            da_pred_sub,
            da_obs_sub,
            da_mdl_sub,
            kwargs={
                "bc_params": domain_config["bc_params"],
                "precip": variable_config['tp']["isprecip"],
            },
            input_core_dims=[["ens"], ["time"], ["ens_time"]],
            output_core_dims=[["ens"]],
            vectorize=True,
            dask="parallelized",
            # dask='allowed',
            output_dtypes=[np.float64])



correct timestep: 0
correct timestep: 1
correct timestep: 2
correct timestep: 3
correct timestep: 4
correct timestep: 5
correct timestep: 6
correct timestep: 7
correct timestep: 8
correct timestep: 9
correct timestep: 10
correct timestep: 11
correct timestep: 12
correct timestep: 13
correct timestep: 14
correct timestep: 15
correct timestep: 16
correct timestep: 17
correct timestep: 18
correct timestep: 19
correct timestep: 20
correct timestep: 21
correct timestep: 22
correct timestep: 23
correct timestep: 24
correct timestep: 25
correct timestep: 26
correct timestep: 27
correct timestep: 28
correct timestep: 29
correct timestep: 30
correct timestep: 31
correct timestep: 32
correct timestep: 33
correct timestep: 34
correct timestep: 35
correct timestep: 36
correct timestep: 37
correct timestep: 38
correct timestep: 39
correct timestep: 40
correct timestep: 41
correct timestep: 42
correct timestep: 43
correct timestep: 44
correct timestep: 45
correct timestep: 46
correct timestep: 47
co

In [49]:
del ds_obs

Task exception was never retrieved
future: <Task finished name='Task-2287325' coro=<Client._gather.<locals>.wait() done, defined at /home/borkenhagen-c/miniconda3/envs/forecast_bias_correction/lib/python3.10/site-packages/distributed/client.py:2119> exception=AllExit()>
Traceback (most recent call last):
  File "/home/borkenhagen-c/miniconda3/envs/forecast_bias_correction/lib/python3.10/site-packages/distributed/client.py", line 2128, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-2287299' coro=<Client._gather.<locals>.wait() done, defined at /home/borkenhagen-c/miniconda3/envs/forecast_bias_correction/lib/python3.10/site-packages/distributed/client.py:2119> exception=AllExit()>
Traceback (most recent call last):
  File "/home/borkenhagen-c/miniconda3/envs/forecast_bias_correction/lib/python3.10/site-packages/distributed/client.py", line 2128, in wait
    raise AllExit()
distributed.client.AllExit
Task excepti

NameError: name 'ds_obs' is not defined

In [None]:
test = da_temp.mean(dim="ens")

In [None]:
test.isel(time=0).plot()

In [None]:
test.isel(time=1).plot()

In [None]:
%%time
test.compute()

In [None]:
test

In [None]:
test.values

In [None]:
mdl = da_mdl_sub.isel(lat=0, lon=0)
obs = da_obs_sub.isel(lat=0, lon=0)
pred = da_pred_sub.isel(lat=0, lon=0)

In [None]:
%%time
# bc_params = {
#    'dry_thresh': 0.01,
#    'precip': True,
#    'low_extrapol': "delta_additive",
#    'up_extrapol': "delta_additive",
#    'extremes': "weibull",
#    'intermittency': True,
#    'nquants': 2500
# }

# nts = len(pred.time.values)

# print(mdl)

ds_nan = pred.copy()
ds_nan[:] = np.nan
ds_mean = ds_nan

# only do the bc-calculation, if obs and mdl are not NAN
# print(np.any(~np.isnan(obs)))
# print(np.any(~np.isnan(mdl)))
# Check if we have sufficient values for the bcsd
if np.any(~np.isnan(obs)) and np.any(~np.isnan(mdl)):
    if len(np.unique(mdl)) > 10 and len(np.unique(obs)) > 10:

        # if np.any(~np.isnan(obs)) and np.any(~np.isnan(mdl)):
        # nmdl = mdl.shape[0]
        nmdl = bc_params["nquants"]
        p_min_mdl = 1 / (nmdl + 1)
        p_max_mdl = nmdl / (nmdl + 1)
        p_mdl = np.linspace(p_min_mdl, p_max_mdl, nmdl)
        q_mdl = np.nanquantile(mdl, p_mdl, interpolation="midpoint")

        # obs quantile
        nobs = obs.shape[0]
        p_min_obs = 1 / (nobs + 1)
        p_max_obs = nobs / (nobs + 1)
        p_obs = np.linspace(p_min_obs, p_max_obs, nobs)
        q_obs = np.nanquantile(obs, p_obs, interpolation="midpoint")

        # Interpolate
        # Remove the dublicate values
        q_mdl, ids_mdl = np.unique(q_mdl, return_index=True)
        p_mdl = p_mdl[ids_mdl]

        # print(q_mdl)

        pred = pred.copy()
        # pred_1 = pred.copy()

        pred[pred > max(q_mdl)] = max(q_mdl)
        pred[pred < min(q_mdl)] = min(q_mdl)

        # if len(q_mdl)>1 and ~np.isnan(q_mdl.item(0)):
        # Transform the predictions to the rank space
        # from scipy.interpolate import interp1d
        Y_pred = interp1d(q_mdl, p_mdl)(pred)
        # else:
        # create nan-array with size, that match pred and contains nan
        #     Y_pred = ds_nan

        q_obs, ids_obs = np.unique(q_obs, return_index=True)
        p_obs = p_obs[ids_obs]

        # if len(q_obs)>1 and ~np.isnan(q_obs.item(0)):
        # Transform the predictions to the rank space
        # from scipy.interpolate import interp1d
        # Y_pred = interp1d(q_obs,p_obs)(pred)
        # else:
        # Y_pred = pred

        # pred_corr = interp1d(p_obs, q_obs, fill_value='extrapolate')(Y_pred) #bounds_error=True
        pred_corr = np.interp(Y_pred, p_obs, q_obs, left=np.nan, right=np.nan)
        # else:
        # pred_corr = ds_nan

        if precip:
            # print("True")
            p_dry_obs = len(np.where(obs < bc_params["dry_thresh"])[0]) / len(obs)
            p_dry_mdl = len(np.where(mdl < bc_params["dry_thresh"])[0]) / len(mdl)
            # print(p_dry_obs, p_dry_mdl)

        # Check if any of the prediction probabilities are above or below the
        # maximum or minimum observation probabilities
        if precip:
            up = np.where((Y_pred > p_max_obs) & (pred > 0))[0]
            low = np.where((Y_pred < p_min_obs) & (pred > 0))[0]
            # print(low)
        else:
            up = np.where(Y_pred > p_max_obs)[0]
            low = np.where(Y_pred < p_min_obs)[0]
            # print(low)

        # pred_corr = pred_corr.copy()

        if up.size != 0:
            if bc_params["up_extrapol"] == "constant":
                pred_corr[up] = np.max(obs)
            elif bc_params["up_extrapol"] == "distribution":
                if precip:
                    # Fit an extreme-value distribution to the observations
                    # from scipy.stats import gumbel_l
                    pd = gumbel_l.fit(obs)
                    pred_corr[up] = gumbel_l.ppf(Y_pred[up], pd[0], pd[1])
                else:
                    # from scipy.stats import norm
                    [MUHAT, SIGMAHAT] = norm.fit(obs)
                    pred_corr[up] = norm.ppf(Y_pred[up], MUHAT, SIGMAHAT)

            elif bc_params["up_extrapol"] == "delta_additive":
                delta = np.quantile(
                    obs, p_max_obs, interpolation="midpoint"
                ) - np.quantile(mdl, p_max_obs, interpolation="midpoint")
                pred_corr[up] = pred[up] + delta

            elif bc_params["up_extrapol"] == "delta_scaling":
                delta = np.quantile(
                    obs, p_max_obs, interpolation="midpoint"
                ) / np.quantile(mdl, p_max_obs, interpolation="midpoint")
                pred_corr[up] = pred[up] * delta

        if up.size != 0:
            if bc_params["low_extrapol"] == "constant":
                pred_corr[low] = np.min(obs)
            elif bc_params["low_extrapol"] == "distribution":
                if precip:
                    # Fit an extreme-value distribution to the observations
                    # There is a huge problem with packages for Weibull-Distribution in Matlab.
                    # The scipy.stats.weibull_min performs poor, maybe due to a different optimizer.
                    # Use instead Packages like: surpyval, or reliability
                    # import surpyval as surv
                    # from surpyval import Weibull
                    model = surv.Weibull.fit(obs[obs > 0])
                    pd = [model.alpha, model.beta]
                    # pred_corr[low] = surv.Weibull.qf(Y_pred[low], alpha, beta)
                else:
                    # from scipy.stats import norm
                    [MUHAT, SIGMAHAT] = norm.fit(obs)
                    pred_corr[low] = norm.ppf(Y_pred[low], MUHAT, SIGMAHAT)
            elif bc_params["low_extrapol"] == "delta_additive":
                delta = np.quantile(
                    obs, p_min_obs, interpolation="midpoint"
                ) - np.quantile(mdl, p_min_obs, interpolation="midpoint")
                pred_corr[low] = pred[low] + delta
            elif bc_params["low_extrapol"] == "delta_scaling":
                delta = np.quantile(
                    obs, p_min_obs, interpolation="midpoint"
                ) / np.quantile(mdl, p_min_obs, interpolation="midpoint")
                pred_corr[low] = pred[low] * delta

                # Intermittency correction for precipitation
        if precip:

            # Set the precipitation values with lower probabilities than the
            #  dry-day probability of the observations to 0.
            pred_corr[Y_pred <= p_dry_obs] = 0

            if bc_params["intermittency"]:
                # Search for dry days in the predictions
                zero_pred = np.where(pred < bc_params["dry_thresh"])[0]

                if p_dry_obs >= p_dry_mdl:
                    # If the dry-day probability of the observations is higher than
                    # the model, set the corresponding forecast values to 0
                    pred_corr[zero_pred] = 0
                elif p_dry_obs < p_dry_mdl:
                    # If the dry-day probability of the model is higher than the
                    # observations, do some magic...
                    if p_dry_mdl > 0:
                        # First, draw some uniform random samples between 0 and the
                        # dry-day probability of the model
                        zero_smples = p_dry_mdl * np.random.rand(len(zero_pred))
                        # Transform these random samples to the data space
                        if bc_params["extremes"] == "weibull":
                            # if len(q_obs)>1 and ~np.isnan(q_obs.item(0)):
                            # zero_corr = interp1d(p_obs, q_obs, bounds_error=False)(zero_smples)
                            zero_corr = np.interp(
                                zero_smples, p_obs, q_obs, left=np.nan, right=np.nan
                            )
                            ######################
                            # Erstmal draußen lassen, brauchen wir erstmal nicht gibt auch kein Plug&Play für
                            # "icdf" in Python
                            ######################
                            # else:
                            # zero_corr   = icdf(Ofit, zero_smples);
                            # else:
                            #   zero_corr = zero_smples

                            # Now, set all transfomed random samples with probabilities
                            # lower than the dry day probability of the observations to
                            # zero.
                            zero_corr[zero_smples <= p_dry_obs] = 0
                            # Replace the elements in the predictions with the
                            # corresponding intermittency-corrected values.
                            pred_corr[zero_pred] = zero_corr
                    else:

                        pred_corr[zero_pred] = 0
                        # If the probability of a dry day is 0 (which might happen
                        # in some very ... cases), we simply set the probabilities,
                        # which correspond to the forecasted zero values, to the
                        # minimum probability of the observations.

    else:
        ds_mean[:] = np.nanmean(obs)
        pred_corr = ds_mean

else:
    pred_corr = ds_nan

In [None]:
%%time
da_temp[timestep, :, :, :] = xr.apply_ufunc(
        bc_module,
        da_pred_sub,
        da_obs_sub,
        da_mdl_sub,
        kwargs={
            "bc_params": domain_config["bc_params"],
            "precip": variable_config['tp']["isprecip"],
        },
        input_core_dims=[["ens"], ["time"], ["ens_time"]],
        output_core_dims=[["ens"]],

        vectorize=True,
        dask="parallelized",
        #dask='allowed',
        output_dtypes=[np.float64])

In [None]:
tst[:,:,0].plot()

In [None]:
                input_core_dims=[["ens", "pred_time"], ["time"], ["ens_time"]],
        output_core_dims=[["ens"]],
          
        join='outer',

In [None]:
da_mdl

In [None]:
test = xr.apply_ufunc(
        bc_module,
        da_pred,
        da_obs,
        da_mdl,
        kwargs={
            "domain_config": domain_config,
            "precip": variable_config['tp']["isprecip"],
        },
        input_core_dims=[["ens"], ["time"], ["time", "ens"]],
        output_core_dims=[["ens"]],
        exclude_dims = set(("time",)),  
        vectorize=False,
        dask="parallelized",
        #dask='allowed',
        output_dtypes=[np.float64])

In [None]:
test

In [None]:
test.compute()

In [None]:
test