In [48]:
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union

import numpy as np
import xarray as xr
from xarray.core.weighted import Weighted

import xcollection
from xcollection import Collection

In [160]:
# Open the two xarray tutorial datasets and wrap them in a Collection
# keyed by short labels; the later cells reuse all four names.
ds = xr.tutorial.open_dataset('rasm')
dsa = xr.tutorial.open_dataset('air_temperature')
ds_dict = dict(foo=ds, bar=dsa)
collection = Collection(ds_dict)

In [166]:
# Number of entries along the leading axis of the `Tair` variable.
len(ds["Tair"])

36

In [170]:
# Derive demo weights as the cosine of `Tair` and label the array "weights".
# NOTE(review): a later cell in this notebook applies np.deg2rad before the
# cosine; confirm which form is intended.
weights = np.cos(ds["Tair"]).rename("weights")
weights

In [6]:
# Build one weighted view per entry of the collection. NaNs in the weights
# are replaced by 0 before they are handed to `Dataset.weighted`.
# The original cell iterated over `c`, a name that is never defined in this
# notebook (leftover kernel state); `collection` is the object built above.
weighted = {
    key: dataset.weighted(weights.fillna(0)) for key, dataset in collection.items()
}
print(weighted)

{'foo': DatasetWeighted with weights along dimensions: time, y, x, 'bar': DatasetWeighted with weights along dimensions: time, y, x}


In [55]:
class CollectionWeighted(Weighted["Collection"]):
    """Weighted reductions applied to every dataset held by a ``Collection``.

    Subclasses xarray's ``Weighted`` and overrides the two hooks below so
    that dimension validation and the reduction itself run once per dataset
    in ``self.obj`` (iterated through ``self.obj.items()``).
    """

    def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]) -> None:
        """Raise an error if any requested dimension is missing.

        Parameters
        ----------
        dim : hashable, iterable of hashable, or None
            Dimension name(s) requested for the reduction. A string or a
            non-iterable value is treated as a single name; a falsy value
            means "no dimensions requested".

        Raises
        ------
        ValueError
            If a requested dimension is found neither in a dataset's
            dimensions nor in the weights' dimensions.
        """
        # Normalise `dim` once, before looping. Doing it inside the loop
        # would re-consume a one-shot iterable for every dataset, so only the
        # first dataset would be checked against the real dimension names.
        if isinstance(dim, str) or not isinstance(dim, Iterable):
            dims = [dim] if dim else []
        else:
            dims = list(dim)

        # Dimensions present on the weights are acceptable for every dataset.
        not_in_weights = set(dims) - set(self.weights.dims)

        for _, dataset in self.obj.items():
            missing_dims = not_in_weights - set(dataset.dims)
            if missing_dims:
                raise ValueError(
                    f"{dataset.__class__.__name__} does not contain the dimensions: {missing_dims}"
                )

    def _implementation(self, func, dim, **kwargs) -> "Collection":
        """Apply ``func`` to each dataset and collect the results.

        Parameters
        ----------
        func : callable
            Reduction passed to each dataset's ``map`` together with ``dim``
            and ``kwargs``.
        dim : hashable, iterable of hashable, or None
            Dimension name(s) to reduce over; validated with ``_check_dim``.
        **kwargs
            Forwarded unchanged to ``dataset.map``.

        Returns
        -------
        Collection
            A new collection with the same keys, holding the mapped datasets.
        """
        # A one-shot iterator would be exhausted by the validation below and
        # arrive empty at `func`; materialise it so every use sees the names.
        if isinstance(dim, Iterable) and not isinstance(dim, str) and iter(dim) is dim:
            dim = tuple(dim)

        self._check_dim(dim)

        return Collection(
            {
                key: dataset.map(func, dim=dim, **kwargs)
                for key, dataset in self.obj.items()
            }
        )

In [56]:
# Wrap the collection with NaN-free weights. The original cell passed `c`,
# which is never defined in this notebook and fails on a fresh kernel.
cw = CollectionWeighted(collection, weights.fillna(0))

In [155]:
# Reduce over "time" using the `mean` method that CollectionWeighted inherits
# from xarray's `Weighted` (the class itself does not define it).
cwm = cw.mean(dim="time")

In [156]:
# Self-contained check: reductions through CollectionWeighted must match the
# same reductions done dataset-by-dataset with `Dataset.weighted`.
ds = xr.tutorial.open_dataset('rasm')
dsa = xr.tutorial.open_dataset('air_temperature')
ds_dict = {'foo': ds, 'bar': dsa}
collection = xcollection.Collection(ds_dict)

weights = np.cos(np.deg2rad(ds.Tair))
weights.name = "weights"
# Fill NaNs once and reuse the same weights on both code paths.
filled_weights = weights.fillna(0)

collection_weighted = CollectionWeighted(collection, filled_weights)
collection_dict = {
    "mean": collection_weighted.mean(dim="time"),
    "sum": collection_weighted.sum(dim="time"),
    "sum_of_weights": collection_weighted.sum_of_weights(dim="time"),
}

dict_weighted = {key: dataset.weighted(filled_weights) for key, dataset in ds_dict.items()}
dict_dict = {
    "mean": {key: w.mean(dim="time") for key, w in dict_weighted.items()},
    "sum": {key: w.sum(dim="time") for key, w in dict_weighted.items()},
    "sum_of_weights": {key: w.sum_of_weights(dim="time") for key, w in dict_weighted.items()},
}

for k in collection_dict:
    # `result` (not `ds`) so the loop does not rebind the dataset loaded above.
    for j, result in collection_dict[k].items():
        expected = dict_dict[k][j]
        assert type(result) is type(expected)
        # `assert a == b` is vacuous for Datasets: `==` is elementwise and the
        # truth value of a Dataset only reflects whether it has data
        # variables. Compare the contents explicitly instead.
        xr.testing.assert_equal(result, expected)
        print(type(result))

<class 'xarray.core.dataset.Dataset'>
<class 'xarray.core.dataset.Dataset'>
<class 'xarray.core.dataset.Dataset'>
<class 'xarray.core.dataset.Dataset'>
<class 'xarray.core.dataset.Dataset'>
<class 'xarray.core.dataset.Dataset'>
