In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import arkouda as ak
from arkouda import array_api as xp
# Connect to the Arkouda server. No host/port arguments are passed, so arkouda's
# defaults are used (presumably a locally running server — confirm for your setup).
# The Arkouda-backed arrays created in later cells presumably need this connection.
ak.connect()

In [None]:
# Load the RASM tutorial dataset fully into memory, then re-chunk it with
# chunked_array_type="arkouda" so its variables are backed by Arkouda arrays.
ds = (
    xr.tutorial.open_dataset("rasm")
    .load()
    .chunk(chunked_array_type="arkouda")
)
ds

In [None]:
# Number of days in each month along the time axis; these become the
# averaging weights in the next cell.
month_length = ds["time"].dt.days_in_month
month_length

In [None]:
# Calculate the weights by grouping by 'time.season': each month's day count is
# divided by the total number of days falling in that season over the whole
# record (all years are pooled into one group per season).
weights = (
    month_length.groupby("time.season") / month_length.groupby("time.season").sum()
)

# Test that the sum of the weights for each season is 1.0
np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Convert weights to use an Arkouda Array (see: https://github.com/pydata/xarray/issues/9040).
# `dims` is passed explicitly rather than inferred from `coords`. xarray only infers
# dims from coords when their count matches the array's ndim. The groupby
# arithmetic above may attach an extra non-dimension coordinate (e.g. `season`
# along `time`), which would break that inference.
weights = xr.DataArray(
    xp.asarray(weights.data),
    coords=weights.coords,
    dims=weights.dims,
    name=weights.name,
)

# Day-weighted seasonal mean: per season, sum of weight * value over time.
ds_weighted = (ds * weights).groupby("time.season").sum(dim="time")
ds_weighted


In [None]:
# Equal-weight (plain) seasonal mean over time; only used as a baseline to
# compare against the day-weighted result.
ds_unweighted = ds.groupby("time.season").mean(dim="time")
# Positive wherever weighting by month length raises the seasonal mean.
ds_diff = ds_weighted - ds_unweighted

In [None]:
# Quick plot to show the results
notnull = xp.logical_not(xp.isnan(ds_unweighted["Tair"][0].data))

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))
for i, season in enumerate(("DJF", "MAM", "JJA", "SON")):
    ds_weighted["Tair"].sel(season=season).where(notnull).plot.pcolormesh(
        ax=axes[i, 0],
        vmin=-30,
        vmax=30,
        cmap="Spectral_r",
        add_colorbar=True,
        extend="both",
    )

    ds_unweighted["Tair"].sel(season=season).where(notnull).plot.pcolormesh(
        ax=axes[i, 1],
        vmin=-30,
        vmax=30,
        cmap="Spectral_r",
        add_colorbar=True,
        extend="both",
    )

    ds_diff["Tair"].sel(season=season).where(notnull).plot.pcolormesh(
        ax=axes[i, 2],
        vmin=-0.1,
        vmax=0.1,
        cmap="RdBu_r",
        add_colorbar=True,
        extend="both",
    )

    axes[i, 0].set_ylabel(season)
    axes[i, 1].set_ylabel("")
    axes[i, 2].set_ylabel("")

for ax in axes.flat:
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticklabels([])
    ax.axes.axis("tight")
    ax.set_xlabel("")

axes[0, 0].set_title("Weighted by DPM")
axes[0, 1].set_title("Equal Weighting")
axes[0, 2].set_title("Difference")

plt.tight_layout()

fig.suptitle("Seasonal Surface Air Temperature", fontsize=16, y=1.02)