In [None]:
# Reload edited modules automatically before each cell executes, so changes to
# the project packages are picked up without restarting the kernel.
%reload_ext autoreload

%autoreload 2

# Load environment variables from a .env file (python-dotenv IPython extension).
%reload_ext dotenv
%dotenv

import gc

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_utils import dataset_split_path, workdir_path
from mlde_utils.transforms import build_target_transform
from mlde_notebooks import create_map_fig, plot_map
from mlde_notebooks.data import open_dataset_split, attach_derived_variables, attach_eval_coords
from mlde_notebooks.display import pretty_table
from mlde_notebooks.distribution import plot_freq_density

In [None]:
# Which dataset, and which split of it, this notebook inspects.
dataset = "bham64_ccpm-4x_12em_mv"
split = "val"

# Name of the variable (inside the opened dataset) that the cells below analyse.
var = "target_swbgt"

# Variables that are attached to the dataset after loading rather than read
# from it. Each entry is: name -> [dotted path string, {argument: variable}].
# NOTE(review): the meaning of the two list items is inferred from their
# values and from being handed to attach_derived_variables -- confirm against
# mlde_notebooks.data.
derived_variables_config = dict(
    swbgt=[
        "mlde_notebooks.derived_variables.swbgt",
        dict(temp="tmean150cm", rh="relhum150cm"),
    ],
)

In [None]:
# Open the configured dataset split, pass it through attach_eval_coords, then
# through attach_derived_variables with the derived-variable config, limited
# to the "target" prefix.
# NOTE(review): what the two attach_* helpers add is not visible here -- see
# mlde_notebooks.data.
ds = attach_derived_variables(
    attach_eval_coords(open_dataset_split(dataset, split)),
    derived_variables_config,
    prefixes=["target"],
)
# Bare expression: lets the notebook render the dataset's repr as cell output.
ds

In [None]:
# Pull out the single variable under study; the cells below summarise and plot it.
da = ds[var]

In [None]:
# Summary statistics over every value of the variable (all dimensions reduced),
# in the same layout as pandas' describe(): mean, std, min, quartiles, max.

# Compute the three quartiles with one quantile() call instead of three
# separate passes over the data; the result has a length-3 "quantile" dimension.
quartiles = da.quantile([0.25, 0.5, 0.75])

xr.merge(
    [
        da.mean().rename("mean"),
        da.std().rename("std"),
        da.min().rename("min"),
        # Selecting one quartile leaves a scalar "quantile" coordinate behind;
        # remove it before merging (presumably so the three entries do not
        # carry differing coordinate values). drop_vars() is the supported
        # replacement for the deprecated DataArray.drop().
        quartiles.isel(quantile=0).drop_vars("quantile").rename("25%"),
        quartiles.isel(quantile=1).drop_vars("quantile").rename("50%"),
        quartiles.isel(quantile=2).drop_vars("quantile").rename("75%"),
        da.max().rename("max"),
    ]
).to_pandas()

In [None]:
# Frequency-density histogram of all values of the variable.
# `bins` is forwarded to the histogram call; an explicit array of bin edges
# (e.g. from np.histogram_bin_edges) can be used here instead of a count.
bins = 50

# da.plot() draws a histogram for arrays with more than two dimensions, which
# is why the histogram-only options (bins, density) are accepted here.
fig, ax = plt.subplots()
da.plot(ax=ax, bins=bins, density=True, label=var)
ax.legend()

plt.show()

In [None]:
# Per-grid-cell statistics: reduce over both the ensemble-member and time
# dimensions, leaving one value per remaining (spatial) position.
time_mean = da.mean(dim=["ensemble_member", "time"])
time_std = da.std(dim=["ensemble_member", "time"])

# Three panels keyed by name: mean and std on the top row, the styled mean below.
# NOTE(review): "." looks like the matplotlib subplot_mosaic marker for an
# empty slot -- confirm in create_map_fig.
fig, axd = create_map_fig([["mean", "std"], ["mean_style", "."]])

# Mean drawn with a generic colour map (no predefined style).
plot_map(
    time_mean,
    ax=axd["mean"],
    style=None,
    cmap="turbo",
    title="Time mean",
    add_colorbar=True,
)
# Same mean, drawn with the style named after the variable minus its
# "target_" part (e.g. "target_swbgt" -> "swbgt").
plot_map(
    time_mean,
    ax=axd["mean_style"],
    style=var.replace("target_", ""),
    title="Styled Time mean",
    add_colorbar=True,
)
# Standard deviation, again with the generic colour map.
plot_map(
    time_std,
    ax=axd["std"],
    style=None,
    cmap="turbo",
    title="Time std",
    add_colorbar=True,
)

In [None]:
# Spatial average: reduce over both grid dimensions so only the remaining
# dimensions (time and ensemble member are the ones used below) are left.
domain_mean = da.mean(dim=["grid_longitude", "grid_latitude"])

# Rolling mean over a window of 90 time steps, one panel per ensemble member.
# NOTE(review): 90 steps is ~3 months only if the data is daily -- confirm.
domain_mean.rolling(time=90).mean().plot(col="ensemble_member", col_wrap=4)
plt.show()

# Monthly means ("MS" = month-start bins). Built once and reused by the three
# plots below instead of repeating the resample for each of them.
monthly_mean = domain_mean.resample(time="MS").mean()

# One panel per ensemble member.
monthly_mean.plot(col="ensemble_member", col_wrap=4)
plt.show()

# All members overlaid as faint lines on a single axis.
monthly_mean.plot(alpha=0.2, hue="ensemble_member", add_legend=False)
plt.show()
# Average across ensemble members of the monthly means.
monthly_mean.mean("ensemble_member").plot()
plt.show()