In [1]:
import numpy as np
import xarray as xr
import dask

import matplotlib.pyplot as plt

from dask import compute, delayed
from scipy.stats import rankdata

In [2]:
@delayed
def comp_histogram(obs_in,ensemble_in,mask_in):
    """Rank histogram of the observations within the ensemble.

    For every point selected by the mask, the observation is ranked among
    the ensemble members: rank 1 means it lies below every member, rank
    ``n_members + 1`` means it lies above every member. When the observation
    is exactly equal to one or more members, its rank is drawn uniformly at
    random from the tied positions so that ties do not pile up in the lowest
    tied bin.

    In this notebook the function is wrapped by ``dask.delayed`` (decorator
    on the line above), so calling it only builds a lazy task.

    Parameters
    ----------
    obs_in : object exposing ``.data``
        Observations; ``.data`` is an array with the same shape as the mask.
    ensemble_in : object exposing ``.data``
        Ensemble; ``.data`` has the member axis first, followed by the axes
        of ``obs_in``.
    mask_in : object exposing ``.data``
        Truthy entries mark the points that enter the histogram.

    Returns
    -------
    tuple
        ``(counts, bin_edges)`` as returned by ``np.histogram``. There is one
        bin per possible rank, i.e. ``n_members + 1`` bins with edges
        ``0.5, 1.5, ..., n_members + 1.5``. An all-False mask yields all-zero
        counts.

    Notes
    -----
    * Tie-breaking draws from the global ``np.random`` state; seed it with
      ``np.random.seed`` for repeatable results in a single process.
    * NaN values compare as neither smaller nor equal, so a NaN never
      contributes to a rank or a tie. Mask NaNs out beforehand.
    """
    # Index with the boolean array itself (not its raw buffer).
    mask = np.asarray(mask_in.data, dtype=bool)

    obs = obs_in.data[mask]              # shape: (n_points,)
    ensemble = ensemble_in.data[:, mask]  # shape: (n_members, n_points)
    n_ranks = ensemble.shape[0] + 1

    # Lowest possible rank of the observation: one plus the number of
    # members strictly below it. Members equal to it are the ties.
    ranks = 1 + np.sum(ensemble < obs, axis=0)
    ties = np.sum(ensemble == obs, axis=0)

    # Spread every tied observation uniformly over [rank, rank + ties].
    # All tie counts > 0 are handled, including the case where no point
    # is tie-free.
    tied = ties > 0
    if tied.any():
        ranks[tied] += np.random.randint(0, ties[tied] + 1)

    edges = np.linspace(0.5, n_ranks + 0.5, n_ranks + 1)
    return np.histogram(ranks, bins=edges)

In [3]:
# Synthetic test data. A seeded generator makes the notebook reproducible
# under Restart & Run All.
SEED = 42
N_ENS, N_FCT, N_TS, N_LAT, N_LON = 20, 16, 31, 40, 240
rng = np.random.default_rng(SEED)

grid_shape = (N_FCT, N_TS, N_LAT, N_LON)
obs = rng.standard_normal(grid_shape)
ensemble = rng.standard_normal((N_ENS,) + grid_shape)
mask = rng.integers(0, 2, size=grid_shape)  # 0/1 flags; 1 = point is used

# Coordinates shared by all three arrays; the ensemble adds a leading "ens" axis.
grid_dims = ("fct", "ts", "lat", "lon")
grid_coords = {
    "fct": range(N_FCT),
    "ts": range(N_TS),
    "lat": np.linspace(-90, 90, N_LAT),
    "lon": np.linspace(0, 360, N_LON),
}

da_obs = xr.DataArray(obs, name="obs", coords=grid_coords, dims=grid_dims)
da_ensemble = xr.DataArray(ensemble, name="ensemble", coords={"ens": range(N_ENS), **grid_coords}, dims=("ens",) + grid_dims)
da_mask = xr.DataArray(mask, name="mask", coords=grid_coords, dims=grid_dims)

In [4]:
# Build one lazy histogram task per forecast step; nothing is computed here.
delayed_results = []
# Derive the number of forecast steps from the data instead of hard-coding it.
for ifct in range(da_obs.sizes["fct"]):

    da_obs_isel = da_obs.isel(fct=ifct)
    da_ensemble_isel = da_ensemble.isel(fct=ifct)
    da_mask_isel = da_mask.isel(fct=ifct)

    # comp_histogram is dask.delayed-wrapped, so this returns a Delayed object.
    histogram = comp_histogram(da_obs_isel,da_ensemble_isel,da_mask_isel)

    delayed_results.append(histogram)

print(delayed_results)

[Delayed('comp_histogram-10ddd5c9-5754-4b07-ae3c-d90dd4a676d5'), Delayed('comp_histogram-5c5df0c8-0ccb-4d8f-87b8-b85dbf704c2c'), Delayed('comp_histogram-a016bcf2-4399-4d77-9000-d9d3f0ff35d8'), Delayed('comp_histogram-32b9b8b4-e3a3-4da6-b9e7-a95af5fadbff'), Delayed('comp_histogram-0a5f2a64-85fe-4254-8c7e-dfad27c5f562'), Delayed('comp_histogram-f17a22e4-dcef-4225-a650-97de08271ba9'), Delayed('comp_histogram-2590ddab-f949-4fe4-942e-720889dc96a8'), Delayed('comp_histogram-a4a2f39b-f971-4424-97ce-7b4824fdfd49'), Delayed('comp_histogram-1a158dbf-ad72-424e-8be2-9c13bfeca8ac'), Delayed('comp_histogram-dd8dcdff-be31-454d-bdf7-1cb4fcf1fa69'), Delayed('comp_histogram-4fb82033-b075-4127-82c7-671569fcedb6'), Delayed('comp_histogram-436211ff-20ce-49d1-865a-8fdc37822878'), Delayed('comp_histogram-29018099-ab39-4d1a-ada8-4507bdc42718'), Delayed('comp_histogram-93391ee1-88b4-4a1c-8df9-2259491bda34'), Delayed('comp_histogram-34f223b0-6d55-47f6-8250-416299409608'), Delayed('comp_histogram-48bb9fb4-6ba4-4

In [8]:
%time
results = dask.compute(*delayed_results, scheduler="processes")

CPU times: user 3 µs, sys: 1e+03 ns, total: 4 µs
Wall time: 5.72 µs


In [9]:
# Display the computed results: one (counts, bin_edges) tuple per delayed histogram.
results

((array([7044, 7110, 7108, 6974, 7180, 7138, 7027, 7016, 7089, 7081, 7073,
         7160, 7208, 7125, 7164, 7036, 7109, 6844, 7167, 7155, 7041]),
  array([ 0.5,  1.5,  2.5,  3.5,  4.5,  5.5,  6.5,  7.5,  8.5,  9.5, 10.5,
         11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5, 21.5])),
 (array([7052, 7068, 7193, 7189, 7151, 7203, 6968, 7279, 7080, 7067, 7025,
         7085, 7065, 7066, 7000, 7126, 6957, 7065, 7016, 7154, 7089]),
  array([ 0.5,  1.5,  2.5,  3.5,  4.5,  5.5,  6.5,  7.5,  8.5,  9.5, 10.5,
         11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5, 21.5])),
 (array([7190, 7177, 7130, 7001, 7104, 7130, 7068, 7034, 6959, 7124, 7115,
         7017, 7060, 7108, 7265, 7185, 7133, 7087, 6985, 7051, 7063]),
  array([ 0.5,  1.5,  2.5,  3.5,  4.5,  5.5,  6.5,  7.5,  8.5,  9.5, 10.5,
         11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5, 21.5])),
 (array([7000, 7037, 7161, 7097, 6997, 7090, 7089, 7048, 7197, 7172, 7012,
         7196, 7107, 7086, 7