In [None]:
import copy
import os
import pickle
from collections import Counter

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean as cmo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import salem
import seaborn as sns
from dmelon import plotting, utils
from matplotlib import colors as c
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from tqdm.notebook import tqdm

import xarray as xr

In [None]:
# Placeholder for the settings reference handed to utils.load_json in the next
# cell; it must be replaced with a real value before that cell runs
# (presumably injected as a notebook parameter — TODO confirm).
settings = None

In [None]:
# Rebind `settings` to whatever utils.load_json returns for it; later cells
# index the result by key (e.g. settings["MONTH"]).
settings = utils.load_json(settings)

In [None]:
# Month label of this run; also looked up in months_list further down.
MONTH = settings["MONTH"]
# Root directory that holds the per-run folders.
DATA_DIR = settings["DATA_DIR"]
# Folder of this run, named "<INIT_MONTH>.<MONTH>".
MONTH_DIR = os.path.join(DATA_DIR, f"{settings['INIT_MONTH']}.{MONTH}")
VALIDATION_DIR = os.path.join(MONTH_DIR, "validation")
# Inputs (NetCDF files and the model pickle) are read from here.
NC_DIR = os.path.join(MONTH_DIR, "Data")
# Every figure of this notebook is saved under this folder.
PLOTS_DIR = os.path.join(MONTH_DIR, "plots")

# presumably creates the folder when it is missing — confirm against dmelon.utils
utils.check_folder(PLOTS_DIR)

In [None]:
# Month numbers in the order the panels are drawn (October through April);
# each value is used both to select `month` and in the "Target Month" titles.
MONTHS_ORDER = [10, 11, 12, 1, 2, 3, 4]

In [None]:
# Each NetCDF file in NC_DIR is opened and the variable sharing the file's
# base name is extracted from it.
pred_data = xr.open_dataset(os.path.join(NC_DIR, "pred_data.nc")).pred_data

metric_data = xr.open_dataset(os.path.join(NC_DIR, "metric_data.nc")).metric_data

metric2_data = xr.open_dataset(os.path.join(NC_DIR, "metric2_data.nc")).metric2_data

nvar_data = xr.open_dataset(os.path.join(NC_DIR, "nvar_data.nc")).nvar_data

pred_data_val = xr.open_dataset(os.path.join(NC_DIR, "pred_data_val.nc")).pred_data_val

thresh_data = xr.open_dataset(os.path.join(NC_DIR, "thresh_data.nc")).thresh_data

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load a model_data.pickle that comes from a trusted source.
with open(os.path.join(NC_DIR, "model_data.pickle"), "rb") as handle:
    model_data = pickle.load(handle)

In [None]:
# Read the predictors spreadsheet named in the settings, using its first
# column as the index, and preview the first rows.
predictors = pd.read_excel(
    os.path.join(settings["MODEL_SRC"], settings["PREDICTORS"]), index_col=[0]
)
predictors.head()

In [None]:
# Open the PISCO file without decoding time so that the calendar attribute can
# be set before CF decoding; X/Y/T are renamed to lon/lat/time.
pisco = xr.open_dataset(settings["PISCO_DATA"], decode_times=False).rename(
    {"X": "lon", "Y": "lat", "T": "time"}
)
pisco.time.attrs["calendar"] = "360_day"
pisco = xr.decode_cf(pisco).Prec
# Overwrite the decoded time axis with monthly stamps (month start + 14 days)
# from 1981-01 to 2016-12. Assumes the file holds exactly that many monthly
# steps — TODO confirm.
pisco["time"] = pd.date_range("1981-01", "2016-12", freq="MS") + pd.DateOffset(days=14)
# Keep only the period that the later cells compare against.
pisco = pisco.sel(time=slice("1981-10-01", "2016-05-01"))
pisco

In [None]:
# 1 where the first PISCO time step is >= 0 and NaN elsewhere (a NaN value
# fails the comparison, so it maps to NaN as well).
mask = np.where(pisco.isel(time=0) >= 0, 1, np.nan)

In [None]:
# pred_data with every value that is not >= 0 replaced by NaN.
no_neg = pred_data.where(pred_data >= 0)
# pred_data_val with every value that is not >= 0 replaced by 0, then
# multiplied by `mask` so that cells where the mask is NaN become NaN.
no_neg_val = pred_data_val.where(pred_data_val >= 0, 0) * mask

In [None]:
# One map per target month of the square root of metric_data, saved as
# metric_correlation_plot.png.
fig, axs = plt.subplots(
    figsize=(6, 4),
    dpi=300,
    ncols=4,
    nrows=2,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)

# Colour boundaries chosen by MaxNLocator for the range 0.4-1.
levels = MaxNLocator(nbins=6).tick_values(0.4, 1)
# Work on a copy: set_under() mutates the colormap in place, and
# cmo.cm.thermal is a module-level object shared with every other user.
cmap = copy.copy(cmo.cm.thermal)
cmap.set_under(color="white")
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=False)

for num, month in enumerate(MONTHS_ORDER):
    ax = axs.ravel()[num]
    # The colour scaling comes from `norm` alone: matplotlib does not accept
    # vmin/vmax together with a norm instance.
    p = ax.pcolormesh(
        metric_data.lon.data,
        metric_data.lat.data,
        # Square root taken so the field matches the "R" title
        # (presumably metric_data stores R squared — TODO confirm).
        (metric_data.sel(month=month).data) ** (0.5),
        cmap=cmap,
        norm=norm,
        transform=ccrs.PlateCarree(),
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
fig.suptitle("Correlation Coefficient (R)", y=0.95)
# Seven months are drawn on a 2x4 grid, so the last panel is removed.
fig.delaxes(ax=axs.ravel().tolist()[-1])
fig.colorbar(p, ax=axs.ravel().tolist(), extend="min")
fig.savefig(
    os.path.join(PLOTS_DIR, "metric_correlation_plot.png"),
    bbox_inches="tight",
)

In [None]:
# One map per target month of metric2_data, saved as metric_r2adj_plot.png.
fig, axs = plt.subplots(
    figsize=(6, 4),
    dpi=300,
    ncols=4,
    nrows=2,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)

# Colour boundaries chosen by MaxNLocator for the range 0.4-1.
levels = MaxNLocator(nbins=6).tick_values(0.4, 1)
# Work on a copy: set_under() mutates the colormap in place, and
# cmo.cm.thermal is a module-level object shared with every other user.
cmap = copy.copy(cmo.cm.thermal)
cmap.set_under(color="white")
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=False)

for num, month in enumerate(MONTHS_ORDER):
    ax = axs.ravel()[num]
    # The colour scaling comes from `norm` alone: matplotlib does not accept
    # vmin/vmax together with a norm instance, and the boundaries of the
    # BoundaryNorm are what decide the colours.
    p = ax.pcolormesh(
        metric2_data.lon.data,
        metric2_data.lat.data,
        metric2_data.sel(month=month).data,
        cmap=cmap,
        norm=norm,
        transform=ccrs.PlateCarree(),
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
fig.suptitle("Adjusted Coefficient of Determination (R2_adj)", y=0.95)
# Seven months are drawn on a 2x4 grid, so the last panel is removed.
fig.delaxes(ax=axs.ravel().tolist()[-1])
fig.colorbar(p, ax=axs.ravel().tolist(), extend="min")
fig.savefig(
    os.path.join(PLOTS_DIR, "metric_r2adj_plot.png"),
    bbox_inches="tight",
)

In [None]:
# One map per target month of thresh_data, saved as metric_siglevel_plot.png.
fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    dpi=300,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)
panels = axs.ravel().tolist()

# Reversed thermal colormap with boundaries chosen by MaxNLocator for 0-0.1.
cmap = cmo.cm.thermal_r
levels = MaxNLocator(nbins=10).tick_values(0, 0.1)
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

for ax, month in zip(panels, MONTHS_ORDER):
    month_field = thresh_data.sel(month=month).data
    p = ax.pcolormesh(
        thresh_data.lon.data,
        thresh_data.lat.data,
        month_field,
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        norm=norm,
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

fig.suptitle("Significance level", y=0.95)
# The eighth panel has no month assigned to it.
fig.delaxes(ax=panels[-1])
cbar = fig.colorbar(p, ax=panels)
# Flip the colorbar axis direction.
cbar.ax.invert_yaxis()
fig.savefig(
    os.path.join(PLOTS_DIR, "metric_siglevel_plot.png"),
    bbox_inches="tight",
)

In [None]:
# Bar charts, one per target month, counting the values of nvar_data over the
# grid cells whose thresh_data is at most 0.05.
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(10, 4), dpi=300, sharey=True)

significant_nvar = nvar_data.where(thresh_data <= 0.05)

for ax, month in zip(axs.ravel(), MONTHS_ORDER):
    flat_data = significant_nvar.sel(month=month).data.flatten()
    # Drop the NaN cells before casting to integers.
    flat_data = flat_data[~np.isnan(flat_data)].astype(int)
    # Single-row table with one column per value 0..19.
    flat_df = pd.DataFrame(columns=range(0, 20), index=["count"])
    for n_vars, n_cells in Counter(flat_data).items():
        flat_df[n_vars] = n_cells
    # Only the columns 0..16 are drawn.
    sns.barplot(data=flat_df[list(range(17))], ax=ax)
    ax.tick_params(axis="both", labelsize=5)
    ax.set_title(f"Target Month {month:02d}", size=6)
    ax.grid(ls="--", alpha=0.5)


# Shared axis labels placed on the figure instead of on each panel.
fig.text(0.5, 0, "# Variables used", va="center", ha="center")
fig.text(0, 0.5, "Count", va="center", ha="center", rotation="vertical")

fig.suptitle("Number of variables used")
fig.delaxes(ax=axs.ravel().tolist()[-1])
fig.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, "nvar_plot.png"), bbox_inches="tight")

In [None]:
# Mean difference (validation predictions minus PISCO) per calendar month.
validation_error = no_neg_val.sel(time=slice(None, "2016-05-01")) - pisco
bias = validation_error.groupby("time.month").mean()

fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    dpi=300,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)
panels = axs.ravel().tolist()

for ax, month in zip(panels, MONTHS_ORDER):
    month_field = bias.sel(month=month).data
    # Symmetric colour limits of +/-10 around zero.
    p = ax.pcolormesh(
        bias.lon.data,
        bias.lat.data,
        month_field,
        cmap=cmo.cm.balance,
        transform=ccrs.PlateCarree(),
        vmin=-10,
        vmax=10,
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

fig.colorbar(p, ax=panels, extend="both")
fig.suptitle("Mean Bias (Validation Model - PISCOPrecv2p1) [mm]")
# The eighth panel has no month assigned to it.
fig.delaxes(ax=panels[-1])
fig.savefig(os.path.join(PLOTS_DIR, "val_pisco_bias_plot.png"), bbox_inches="tight")

In [None]:
# Mean monthly difference (validation predictions minus PISCO) expressed as a
# percentage of the PISCO monthly mean.
mean_error = (
    (no_neg_val.sel(time=slice(None, "2016-05-01")) - pisco)
    .groupby("time.month")
    .mean()
)
pisco_monthly_mean = pisco.groupby("time.month").mean()
bias = (mean_error * 100) / pisco_monthly_mean

fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    dpi=300,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)
panels = axs.ravel().tolist()

for ax, month in zip(panels, MONTHS_ORDER):
    month_field = bias.sel(month=month).data
    # Symmetric colour limits of +/-20 around zero.
    p = ax.pcolormesh(
        bias.lon.data,
        bias.lat.data,
        month_field,
        cmap=cmo.cm.balance,
        transform=ccrs.PlateCarree(),
        vmin=-20,
        vmax=20,
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

fig.colorbar(p, ax=panels, extend="both")
fig.suptitle("Mean Bias (Validation Model - PISCOPrecv2p1) [%]")
# The eighth panel has no month assigned to it.
fig.delaxes(ax=panels[-1])
fig.savefig(
    os.path.join(PLOTS_DIR, "val_pisco_bias_perc_plot.png"),
    bbox_inches="tight",
)

In [None]:
# Mean absolute difference between validation predictions and PISCO per
# calendar month.
absolute_error = np.abs(no_neg_val.sel(time=slice(None, "2016-05-01")) - pisco)
mae = absolute_error.groupby("time.month").mean()

fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    dpi=300,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)
panels = axs.ravel().tolist()

for ax, month in zip(panels, MONTHS_ORDER):
    month_field = mae.sel(month=month).data
    p = ax.pcolormesh(
        mae.lon.data,
        mae.lat.data,
        month_field,
        cmap=cmo.cm.tempo,
        transform=ccrs.PlateCarree(),
        vmin=0,
        vmax=100,
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

# Only the upper end of the colorbar is extended.
fig.colorbar(p, ax=panels, extend="max")
fig.suptitle("MAE (Validation Model - PISCOPrecv2p1) [mm]")
# The eighth panel has no month assigned to it.
fig.delaxes(ax=panels[-1])
fig.savefig(os.path.join(PLOTS_DIR, "val_pisco_mae_plot.png"), bbox_inches="tight")

In [None]:
# Mean absolute difference between validation predictions and PISCO per
# calendar month, expressed as a percentage of the PISCO monthly mean.
mean_abs_error = (
    np.abs(no_neg_val.sel(time=slice(None, "2016-05-01")) - pisco)
    .groupby("time.month")
    .mean()
)
pisco_monthly_mean = pisco.groupby("time.month").mean()
mae = (mean_abs_error * 100) / pisco_monthly_mean

fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    dpi=300,
    sharey=True,
    subplot_kw={"projection": ccrs.PlateCarree()},
)
panels = axs.ravel().tolist()

for ax, month in zip(panels, MONTHS_ORDER):
    month_field = mae.sel(month=month).data
    p = ax.pcolormesh(
        mae.lon.data,
        mae.lat.data,
        month_field,
        cmap=cmo.cm.tempo,
        transform=ccrs.PlateCarree(),
        vmin=0,
        vmax=100,
    )
    ax.set_title(f"Target Month {month:02d}", size=6)
    plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
    ax.add_feature(cfeature.BORDERS)
    ax.add_feature(cfeature.COASTLINE)
    ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
    ax.tick_params(axis="both", labelsize=5)
    ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

# Only the upper end of the colorbar is extended.
fig.colorbar(p, ax=panels, extend="max")
fig.suptitle("MAE (Validation Model - PISCOPrecv2p1) [%]")
# The eighth panel has no month assigned to it.
fig.delaxes(ax=panels[-1])
fig.savefig(
    os.path.join(PLOTS_DIR, "val_pisco_mae_perc_plot.png"),
    bbox_inches="tight",
)

In [None]:
# Drop the months for which every entry of model_data is missing. The name is
# rebound in place; re-running the cell leaves the result unchanged.
model_data = model_data.dropna(dim="month", how="all")

In [None]:
# For every month, gather the parameter names of the model stored at each
# grid cell into one flat list (duplicates kept, so they can be counted later).
month_vars = {}
for month in tqdm(model_data.month.data):
    # extend() appends in place; rebuilding the list with `+` for every grid
    # cell would copy the whole list each time.
    names = month_vars[month] = []
    for lat in model_data.lat.data:
        for lon in model_data.lon.data:
            test_model = model_data.sel(month=month, lat=lat, lon=lon).data.item()
            # Cells holding a plain float carry no model
            # (presumably NaN placeholders — confirm) and are skipped.
            if not isinstance(test_model, float):
                # The first parameter is left out
                # (presumably the intercept — confirm).
                names.extend(test_model.params.index.to_list()[1:])

In [None]:
# For each month keep the 13 most frequent names as (name, count) pairs.
month_vars_common = {}
for month, names in month_vars.items():
    month_vars_common[month] = Counter(names).most_common(13)

In [None]:
# Flat list of the names kept for every month (a name can appear once per month).
var_list = [
    name for ranked in month_vars_common.values() for name in dict(ranked).keys()
]

In [None]:
# One "hls" colour per distinct name, stored as a single-row table whose
# columns are the names and whose only row is labelled "color".
unique_vars = set(var_list)
colors = sns.color_palette("hls", n_colors=len(unique_vars))
color_mapping = pd.DataFrame(
    {name: [colour] for name, colour in zip(unique_vars, colors)},
    index=["color"],
)

In [None]:
# Bar chart per month of the most common names and their counts, coloured
# with the shared colour table.
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(10, 6), dpi=300, sharey=True)
panels = axs.ravel()

for ax, (month, data) in zip(panels, month_vars_common.items()):
    plot_data = pd.DataFrame(data, columns=["vars", "count"])
    # Look up the colour of every name in the single-row colour table.
    plot_data["color"] = color_mapping[plot_data["vars"]].values[0]
    ax.bar(
        x="vars",
        height="count",
        data=plot_data,
        color=plot_data["color"],
        edgecolor="k",
    )
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(90)
    ax.tick_params(axis="both", labelsize=5)
    ax.set_title(f"Target Month {month:02d}", size=6)


# Shared axis labels placed on the figure instead of on each panel.
fig.text(0.5, 0, "Variables", va="center", ha="center")
fig.text(0, 0.5, "Count", va="center", ha="center", rotation="vertical")

fig.suptitle("Most common variables per month")
fig.delaxes(ax=axs.ravel().tolist()[-1])
fig.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, "vars_freq.png"), bbox_inches="tight")

In [None]:
def create_dataArray(name, month):
    """Return a NaN-filled DataArray named ``name`` for a single ``month``.

    The array has the dimensions (month, lat, lon); lat and lon are taken from
    the notebook-level ``model_data``.
    """
    grid_coords = [
        ("month", [month]),
        ("lat", model_data.lat.data),
        ("lon", model_data.lon.data),
    ]
    return xr.DataArray(np.nan, coords=grid_coords, name=name)


# One empty map per month and per most-common name.
per_month_maps = {}
for month, data in month_vars_common.items():
    per_month_maps[month] = {
        var_name: create_dataArray(var_name, month) for var_name in dict(data).keys()
    }

# Mark with 1 every grid cell whose model lists the name among its parameters.
for month, data in tqdm(per_month_maps.items()):
    for lat in tqdm(model_data.lat.data):
        for lon in model_data.lon.data:
            test_model = model_data.sel(month=month, lat=lat, lon=lon).data.item()
            # Cells holding a plain float carry no model and are skipped.
            if isinstance(test_model, float):
                continue
            cell = dict(month=month, lat=lat, lon=lon)
            for variable, xr_container in data.items():
                if variable in test_model.params.index:
                    xr_container.loc[cell] = 1

In [None]:
# One figure per month: a map for each most-common name showing the grid cells
# marked in per_month_maps, drawn in that name's colour.
for month, data in per_month_maps.items():
    plot_data = pd.DataFrame(data.keys(), columns=["vars"])
    plot_data["color"] = color_mapping[plot_data["vars"]].values[0]

    fig, axs = plt.subplots(
        figsize=(8, 12),
        dpi=300,
        ncols=4,
        nrows=4,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    panels = axs.ravel().tolist()

    for num, (var_name, xr_data) in enumerate(data.items()):
        ax = panels[num]
        p = ax.pcolormesh(
            xr_data.lon.data,
            xr_data.lat.data,
            xr_data.sel(month=month).data,
            # Single-colour colormap built from this name's colour.
            cmap=c.ListedColormap(plot_data.query("vars==@var_name")["color"]),
            transform=ccrs.PlateCarree(),
        )

        ax.set_title(f"{var_name}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
    # Remove every panel that received no map. most_common(13) may return
    # fewer than 13 names, so the number of spare panels is not always three.
    for ax in panels[len(data):]:
        fig.delaxes(ax=ax)
    fig.suptitle(
        f"Spatial distribution of most common variables in target month {month:02d}",
        y=0.92,
    )
    fig.savefig(
        os.path.join(PLOTS_DIR, f"vars_dist_month{month:02d}.png"),
        bbox_inches="tight",
    )

# PERU - ONLY

In [None]:
# Sub-folder of PLOTS_DIR that receives the Peru-only figures.
PERU_DIR = os.path.join(PLOTS_DIR, "PERU")
# presumably creates the folder when it is missing — confirm against dmelon.utils
utils.check_folder(PERU_DIR)

In [None]:
# Country polygons; the "ADMIN" column is queried below to pick Peru.
# NOTE(review): absolute, user-specific path — the notebook only runs where
# this file exists; consider reading the location from the settings.
shsdf = salem.read_shapefile(
    "/data/users/grivera/Shapes/countries/ne_10m_admin_0_countries.shp"
)

In [None]:
# Restrict model_data to the Peru polygon, then gather for every month the
# parameter names of the model stored at each remaining grid cell.
subset_model = model_data.salem.roi(shape=shsdf.query("ADMIN == 'Peru'"))
month_vars = {}
for month in tqdm(subset_model.month.data):
    # extend() appends in place; rebuilding the list with `+` for every grid
    # cell would copy the whole list each time.
    names = month_vars[month] = []
    for lat in subset_model.lat.data:
        for lon in subset_model.lon.data:
            test_model = subset_model.sel(month=month, lat=lat, lon=lon).data.item()
            # Cells holding a plain float carry no model
            # (presumably NaN placeholders — confirm) and are skipped.
            if not isinstance(test_model, float):
                # The first parameter is left out
                # (presumably the intercept — confirm).
                names.extend(test_model.params.index.to_list()[1:])

In [None]:
# For each month keep the 13 most frequent names as (name, count) pairs.
month_vars_common = {}
for month, names in month_vars.items():
    month_vars_common[month] = Counter(names).most_common(13)

In [None]:
# Flat list of the names kept for every month (a name can appear once per month).
var_list = [
    name for ranked in month_vars_common.values() for name in dict(ranked).keys()
]

In [None]:
# One "hls" colour per distinct name, stored as a single-row table whose
# columns are the names and whose only row is labelled "color".
unique_vars = set(var_list)
colors = sns.color_palette("hls", n_colors=len(unique_vars))
color_mapping = pd.DataFrame(
    {name: [colour] for name, colour in zip(unique_vars, colors)},
    index=["color"],
)

In [None]:
# Bar chart per month of the most common names and their counts, saved under
# PERU_DIR.
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(10, 6), dpi=300, sharey=True)
panels = axs.ravel()

for ax, (month, data) in zip(panels, month_vars_common.items()):
    plot_data = pd.DataFrame(data, columns=["vars", "count"])
    # Look up the colour of every name in the single-row colour table.
    plot_data["color"] = color_mapping[plot_data["vars"]].values[0]
    ax.bar(
        x="vars",
        height="count",
        data=plot_data,
        color=plot_data["color"],
        edgecolor="k",
    )
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(90)
    ax.tick_params(axis="both", labelsize=5)
    ax.set_title(f"Target Month {month:02d}", size=6)


# Shared axis labels placed on the figure instead of on each panel.
fig.text(0.5, 0, "Variables", va="center", ha="center")
fig.text(0, 0.5, "Count", va="center", ha="center", rotation="vertical")

fig.suptitle("Most common variables per month")
fig.delaxes(ax=axs.ravel().tolist()[-1])
fig.tight_layout()
fig.savefig(os.path.join(PERU_DIR, "vars_freq_PERU.png"), bbox_inches="tight")

In [None]:
# One empty map per month and per most-common name, built with the
# create_dataArray helper defined in an earlier cell.
per_month_maps = {}


for month, data in month_vars_common.items():
    per_month_maps[month] = {
        var_name: create_dataArray(var_name, month) for var_name in dict(data).keys()
    }

# Mark with 1 every grid cell of subset_model whose model lists the name among
# its parameters.
for month, data in tqdm(per_month_maps.items()):
    for lat in tqdm(subset_model.lat.data):
        for lon in subset_model.lon.data:
            test_model = subset_model.sel(month=month, lat=lat, lon=lon).data.item()
            # Cells holding a plain float carry no model and are skipped.
            if isinstance(test_model, float):
                continue
            cell = dict(month=month, lat=lat, lon=lon)
            for variable, xr_container in data.items():
                if variable in test_model.params.index:
                    xr_container.loc[cell] = 1

In [None]:
# One figure per month: a map for each most-common name showing the grid cells
# marked in per_month_maps, drawn in that name's colour and saved under PERU_DIR.
for month, data in per_month_maps.items():
    plot_data = pd.DataFrame(data.keys(), columns=["vars"])
    plot_data["color"] = color_mapping[plot_data["vars"]].values[0]

    fig, axs = plt.subplots(
        figsize=(8, 12),
        dpi=300,
        ncols=4,
        nrows=4,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    panels = axs.ravel().tolist()

    for num, (var_name, xr_data) in enumerate(data.items()):
        ax = panels[num]
        p = ax.pcolormesh(
            xr_data.lon.data,
            xr_data.lat.data,
            xr_data.sel(month=month).data,
            # Single-colour colormap built from this name's colour.
            cmap=c.ListedColormap(plot_data.query("vars==@var_name")["color"]),
            transform=ccrs.PlateCarree(),
        )

        ax.set_title(f"{var_name}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
    # Remove every panel that received no map. most_common(13) may return
    # fewer than 13 names, so the number of spare panels is not always three.
    for ax in panels[len(data):]:
        fig.delaxes(ax=ax)
    fig.suptitle(
        f"Spatial distribution of most common variables in target month {month:02d}",
        y=0.92,
    )
    fig.savefig(
        os.path.join(PERU_DIR, f"vars_dist_month{month:02d}_PERU.png"),
        bbox_inches="tight",
    )

# MAE NEW - OLD

In [None]:
# Three-letter month labels in calendar order; MONTH is looked up in this list
# and neighbouring entries are reached by index arithmetic.
months_list = [
    "JAN", "FEB", "MAR", "APR",
    "MAY", "JUN", "JUL", "AUG",
    "SEP", "OCT", "NOV", "DEC",
]

In [None]:
def plot_metric_maps(metric_data, vmax=40, vmin=-40):
    """Draw one map of ``metric_data`` for each entry of ``MONTHS_ORDER``.

    Parameters
    ----------
    metric_data
        Object exposing ``lon`` and ``lat`` coordinates and selectable by
        ``month`` (presumably an xarray DataArray — confirm with callers).
    vmax, vmin
        Upper and lower colour limits handed to ``pcolormesh``.

    Returns
    -------
    tuple
        ``(fig, axs, p)``: the figure, the 2x4 array of axes and the mesh of
        the last month drawn, which callers pass to ``fig.colorbar``. The
        eighth panel is left in place for the caller to remove.
    """
    fig, axs = plt.subplots(
        nrows=2,
        ncols=4,
        figsize=(6, 4),
        dpi=300,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )

    for ax, month in zip(axs.ravel().tolist(), MONTHS_ORDER):
        month_field = metric_data.sel(month=month).data
        p = ax.pcolormesh(
            metric_data.lon.data,
            metric_data.lat.data,
            month_field,
            cmap=cmo.cm.balance,
            transform=ccrs.PlateCarree(),
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(f"Target Month {month:02d}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
    return fig, axs, p

In [None]:
# Output folders for the MAE-difference figures: DIFF_DIR for the comparisons
# against the configured month, DIFF_PREV_DIR for consecutive-month pairs.
DIFF_DIR = os.path.join(PLOTS_DIR, "Diff")
DIFF_PREV_DIR = os.path.join(DIFF_DIR, "Prev")
# presumably creates the folders when they are missing — confirm against dmelon.utils
utils.check_folder(DIFF_DIR)
utils.check_folder(DIFF_PREV_DIR)

In [None]:
# MAE difference between consecutive months: (MONTH, MONTH-1), (MONTH-1, MONTH-2),
# ... for five pairs. Negative indices wrap around to the end of months_list
# through Python's negative indexing.
base_index = months_list.index(MONTH)

for i in range(5):
    # Index of the "current" month of this pair
    month_index = base_index - i

    # Current and previous month names
    prev_month = months_list[month_index - 1]
    curr_month = months_list[month_index]

    # NOTE(review): these folders are named after the bare month label, whereas
    # MONTH_DIR is "<INIT_MONTH>.<MONTH>" — confirm the layout on disk.
    # Load previous data and blank out its negative values
    pred_data_prev = xr.open_dataset(
        os.path.join(DATA_DIR, f"{prev_month}/Data/pred_data_val.nc")
    ).pred_data_val
    no_neg_prev = pred_data_prev.where(pred_data_prev >= 0)

    # Load current data; it is masked with its OWN negative values so that no
    # negative value of the current run reaches the MAE.
    pred_data_curr = xr.open_dataset(
        os.path.join(DATA_DIR, f"{curr_month}/Data/pred_data_val.nc")
    ).pred_data_val
    no_neg_curr = pred_data_curr.where(pred_data_curr >= 0)

    # Mean absolute difference against PISCO per calendar month, for each run
    mae_current = (
        np.abs(no_neg_curr.sel(time=slice(None, "2016-05-01")) - pisco)
        .groupby("time.month")
        .mean()
    )

    mae_prev = (
        np.abs(no_neg_prev.sel(time=slice(None, "2016-05-01")) - pisco)
        .groupby("time.month")
        .mean()
    )

    # Absolute difference
    mae_diff = mae_current - mae_prev

    fig, axs, p = plot_metric_maps(mae_diff)
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="both")
    fig.suptitle(f"Diferencia de MAE ({curr_month} - {prev_month}) [mm]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(
            DIFF_PREV_DIR,
            f"val_pisco_mae_{curr_month.lower()}-{prev_month.lower()}.png",
        ),
        bbox_inches="tight",
    )

    # Difference relative to the previous month's MAE, in percent
    mae_diff_perc = (mae_current - mae_prev) * 100 / mae_prev

    fig, axs, p = plot_metric_maps(mae_diff_perc, vmax=100, vmin=-100)
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="both")
    fig.suptitle(f"Diferencia de MAE ({curr_month} - {prev_month}) [%]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(
            DIFF_PREV_DIR,
            f"val_pisco_mae_{curr_month.lower()}-{prev_month.lower()}_perc.png",
        ),
        bbox_inches="tight",
    )

In [None]:
# MAE difference between the configured month and each of the five months
# before it. Negative indices wrap around to the end of months_list through
# Python's negative indexing.
month_index = months_list.index(MONTH)
curr_month = months_list[month_index]

# The current run does not change inside the loop, so it is loaded and reduced
# once. It is masked with its OWN negative values so that no negative value of
# the current run reaches the MAE.
# NOTE(review): these folders are named after the bare month label, whereas
# MONTH_DIR is "<INIT_MONTH>.<MONTH>" — confirm the layout on disk.
pred_data_curr = xr.open_dataset(
    os.path.join(DATA_DIR, f"{curr_month}/Data/pred_data_val.nc")
).pred_data_val
no_neg_curr = pred_data_curr.where(pred_data_curr >= 0)

mae_current = (
    np.abs(no_neg_curr.sel(time=slice(None, "2016-05-01")) - pisco)
    .groupby("time.month")
    .mean()
)

for i in range(1, 6):
    # Month that lies i positions before the configured one
    prev_month = months_list[month_index - i]

    # Load previous data and blank out its negative values
    pred_data_prev = xr.open_dataset(
        os.path.join(DATA_DIR, f"{prev_month}/Data/pred_data_val.nc")
    ).pred_data_val
    no_neg_prev = pred_data_prev.where(pred_data_prev >= 0)

    # Mean absolute difference against PISCO per calendar month
    mae_prev = (
        np.abs(no_neg_prev.sel(time=slice(None, "2016-05-01")) - pisco)
        .groupby("time.month")
        .mean()
    )

    # Absolute difference
    mae_diff = mae_current - mae_prev

    fig, axs, p = plot_metric_maps(mae_diff)
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="both")
    fig.suptitle(f"Diferencia de MAE ({curr_month} - {prev_month}) [mm]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(
            DIFF_DIR, f"val_pisco_mae_{curr_month.lower()}-{prev_month.lower()}.png"
        ),
        bbox_inches="tight",
    )

    # Difference relative to the previous month's MAE, in percent
    mae_diff_perc = (mae_current - mae_prev) * 100 / mae_prev

    fig, axs, p = plot_metric_maps(mae_diff_perc, vmax=100, vmin=-100)
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="both")
    fig.suptitle(f"Diferencia de MAE ({curr_month} - {prev_month}) [%]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(
            DIFF_DIR,
            f"val_pisco_mae_{curr_month.lower()}-{prev_month.lower()}_perc.png",
        ),
        bbox_inches="tight",
    )

# Per Year

In [None]:
# Sub-folder of PLOTS_DIR that receives one figure per season and metric.
PER_YEAR = os.path.join(PLOTS_DIR, "per_year")
# presumably creates the folder when it is missing — confirm against dmelon.utils
utils.check_folder(PER_YEAR)

In [None]:
# Difference (validation predictions minus PISCO) for every time step; one
# figure is saved per season running from October of year-1 into year.
bias = no_neg_val - pisco

for year in range(1983, 2017):
    season = slice(f"{year-1}-10-01", f"{year}-05-01")
    bias_sel = bias.sel(time=season)

    fig, axs = plt.subplots(
        nrows=2,
        ncols=4,
        figsize=(6, 4),
        dpi=300,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )
    panels = axs.ravel().tolist()

    for ax, date in zip(panels, bias_sel.time.data):
        step_field = bias_sel.sel(time=date).data
        # Symmetric colour limits of +/-150 around zero.
        p = ax.pcolormesh(
            bias_sel.lon.data,
            bias_sel.lat.data,
            step_field,
            cmap=cmo.cm.balance,
            transform=ccrs.PlateCarree(),
            vmin=-150,
            vmax=150,
        )
        ax.set_title(f"{pd.to_datetime(date):%Y-%m}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)

    fig.colorbar(p, ax=panels, extend="both")
    fig.suptitle("Bias (Validation Model - PISCOPrecv2p1) [mm]")
    fig.delaxes(ax=panels[-1])
    fig.savefig(
        os.path.join(PER_YEAR, f"{year-1}-{year}.bias.plot.png"),
        bbox_inches="tight",
    )
    print(f"Done year {year}")
    # Release the figure so that the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Absolute difference between validation predictions and PISCO for every time
# step; one figure is saved per season running from October of year-1 into year.
mae = np.abs(no_neg_val - pisco)
for year in range(1983, 2017):
    mae_sel = mae.sel(time=slice(f"{year-1}-10-01", f"{year}-05-01"))

    fig, axs = plt.subplots(
        figsize=(6, 4),
        dpi=300,
        ncols=4,
        nrows=2,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )

    for num, date in enumerate(mae_sel.time.data):
        ax = axs.ravel().tolist()[num]
        p = ax.pcolormesh(
            mae_sel.lon.data,
            mae_sel.lat.data,
            mae_sel.sel(time=date).data,
            cmap=cmo.cm.tempo,
            transform=ccrs.PlateCarree(),
            vmax=150,
            vmin=0,
        )
        ax.set_title(f"{pd.to_datetime(date):%Y-%m}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
    # An absolute value cannot fall below the lower limit of 0, so only the
    # upper end of the colorbar is extended.
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="max")
    fig.suptitle("MAE (Validation Model - PISCOPrecv2p1) [mm]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(PER_YEAR, f"{year-1}-{year}.mae-mm.plot.png"),
        bbox_inches="tight",
    )
    print(f"Done year {year}")
    # Release the figure so that the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Absolute difference between validation predictions and PISCO for every time
# step, divided by the PISCO mean of the matching calendar month and expressed
# in percent; one figure is saved per season.
mae = (np.abs(no_neg_val - pisco).groupby("time.month")) / pisco.groupby(
    "time.month"
).mean()

mae = mae * 100
for year in range(1983, 2017):
    mae_sel = mae.sel(time=slice(f"{year-1}-10-01", f"{year}-05-01"))

    fig, axs = plt.subplots(
        figsize=(6, 4),
        dpi=300,
        ncols=4,
        nrows=2,
        sharey=True,
        subplot_kw={"projection": ccrs.PlateCarree()},
    )

    for num, date in enumerate(mae_sel.time.data):
        ax = axs.ravel().tolist()[num]
        p = ax.pcolormesh(
            mae_sel.lon.data,
            mae_sel.lat.data,
            mae_sel.sel(time=date).data,
            cmap=cmo.cm.tempo,
            transform=ccrs.PlateCarree(),
            vmax=100,
            vmin=0,
        )
        ax.set_title(f"{pd.to_datetime(date):%Y-%m}", size=6)
        plotting.format_latlon(ax, ccrs.PlateCarree(), lon_step=5, lat_step=5)
        ax.add_feature(cfeature.BORDERS)
        ax.add_feature(cfeature.COASTLINE)
        ax.set_extent((-81.25, -68.05, -18.75, 0.95), crs=ccrs.PlateCarree())
        ax.tick_params(axis="both", labelsize=5)
        ax.gridlines(linewidth=0.5, linestyle="--", alpha=0.5)
    # The colour scale starts at 0 and the numerator is an absolute value, so
    # only the upper end of the colorbar is extended.
    fig.colorbar(p, ax=axs.ravel().tolist(), extend="max")
    fig.suptitle("MAE (Validation Model - PISCOPrecv2p1) [%]")
    fig.delaxes(ax=axs.ravel().tolist()[-1])
    fig.savefig(
        os.path.join(PER_YEAR, f"{year-1}-{year}.mae-perc.plot.png"),
        bbox_inches="tight",
    )
    print(f"Done year {year}")
    # Release the figure so that the loop does not accumulate open figures.
    plt.close(fig)