In [23]:
from general_tamsat_alert import *
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import general_tamsat_alert.weighting_functions as wfs

In [24]:
FILENAME = "../data/drought-model-driving-data_pakistan_19820101-present_0.05.nc"
FIELD = "ndvi"

TIME_LABEL = "time"
LON_LABEL = "lon"
LAT_LABEL = "lat"

In [25]:
ds = xr.load_dataset(FILENAME)
ds

In [26]:
period = get_periodicity(ds, FIELD, time_label=TIME_LABEL)
period

24

In [27]:
def plot_ensembles(
    da: xr.DataArray,
    weights: xr.DataArray = None,
    quantiles = (0.25, 0.5, 0.75),
    data_min: float = None,
    data_max: float = None,
    robust: bool = True,
    plot_value: bool = True,
    plot_abs_bias: bool = True,
    plot_rel_bias: bool = True,
    only_mean: bool = False,
    subplot_index = None,):
    if weights is not None:
        toqarray = da.weighted(weights)
    else:
        toqarray = da

    if only_mean:
        to_plot, bias_array, rel_bias_array = ensembles.get_mean_data(da, weights)
        if subplot_index is not None:
            plt.subplot(*subplot_index)
        if plot_value:
            to_plot.plot(
                x=LON_LABEL,
                y=LAT_LABEL,
                cmap="viridis",
                vmin=data_min,
                vmax=data_max,
                robust=robust,
            )
            plt.suptitle("Values")
        if plot_abs_bias:
            bias_array.plot(x=LON_LABEL, y=LAT_LABEL, robust=robust)
            plt.suptitle("Absolute bias")
        if plot_rel_bias:
            rel_bias_array.plot(x=LON_LABEL, y=LAT_LABEL, robust=robust)
            plt.suptitle("Relative bias")
        return to_plot, bias_array, rel_bias_array

    blank_data = np.empty(
        (len(da.coords[LON_LABEL]), len(da.coords[LAT_LABEL]), len(quantiles))
    )
    quantile_array = xr.DataArray(
        blank_data,
        coords=[da.coords[LAT_LABEL], da.coords[LON_LABEL], list(quantiles)],
        dims=[LAT_LABEL, LON_LABEL, "quantile"],
    )

    for i, q in enumerate(quantiles):
        quantile_array[:, :, i] = toqarray.quantile(q, dim="ensemble")

    if plot_value:
        quantile_array.plot(
            x=LON_LABEL,
            y=LAT_LABEL,
            col="quantile",
            cmap="viridis",
            vmin=data_min,
            vmax=data_max,
            robust=robust,
        )
        plt.suptitle("Values")
    bias_array = quantile_array - da.sel(ensemble=0)
    if plot_abs_bias:
        bias_array.plot(x=LON_LABEL, y=LAT_LABEL, col="quantile", robust=robust)
        plt.suptitle("Absolute bias")
    rel_bias_array = bias_array / da.sel(ensemble=0)
    if plot_rel_bias:
        rel_bias_array.plot(x=LON_LABEL, y=LAT_LABEL, col="quantile",
                            robust=robust)
        plt.suptitle("Relative bias")
    return quantile_array, bias_array, rel_bias_array

In [28]:
import cartopy.crs as ccrs
from cartopy.feature import BORDERS
def plot_changing_lookback(
    da: xr.DataArray,
    period: int,
    ensemble_length: int,
    ensemble_start: int,
    wf: wfs.WeightingFunctionType = wfs.no_weights,
    point = (150, 150),
    plot_mean: bool = True,
    plot_bias: bool = True,
    plot_rel_bias: bool = True,
    projection: ccrs.Projection = ccrs.PlateCarree(),
) -> None:
    offset = np.arange(0, ensemble_length)
    mean_array = xr.DataArray(
        np.empty((len(da[LAT_LABEL]), len(da[LON_LABEL]), len(offset), 3)),
        [da[LAT_LABEL], da[LON_LABEL], offset,
         ["mean", "bias", "relative bias"]],
        [LAT_LABEL, LON_LABEL, "offset", "variable"],
    )
    ensembles, _ = get_ensembles(da, period, ensemble_length, ensemble_start)

    spaghetti_array = xr.DataArray(
        np.empty(
            (
                len(ensembles.coords[TIME_LABEL]),
                len(ensembles.coords["ensemble"]),
                len(offset),
            )
        ),
        [ensembles.coords[TIME_LABEL], ensembles.coords["ensemble"], offset],
        [TIME_LABEL, "ensemble", "offset"],
    )
    # plt.figure()
    for i in offset:
        ensembles, weights = get_ensembles(da, period, ensemble_length - i,
                                           ensemble_start + i, look_back=i,
                                           wf=wf, time_label=TIME_LABEL)
        mean, bias, rel_bias = get_mean_data(ensembles[-1, :, :, :], weights)

        mean_array[:, :, i, 0] = mean.values
        mean_array[:, :, i, 1] = bias.values
        mean_array[:, :, i, 2] = rel_bias.values
        if ensembles.shape[-1] > spaghetti_array.shape[1]:
            spaghetti_array[:, :, i] = ensembles.isel(
                lat=point[1], lon=point[0]
            ).values[:, : spaghetti_array.shape[1]]
        else:
            spaghetti_array[:, :, i] = ensembles.isel(lat=point[1],
                                                      lon=point[0]).values

    proj_args = dict(projection=projection)

    if plot_mean:
        i = mean_array[:, :, :, 0].plot.imshow(
            transform=ccrs.PlateCarree(),
            x=LON_LABEL,
            y=LAT_LABEL,
            col="offset",
            robust=True,
            col_wrap=3,
            subplot_kws=proj_args,
        )
        for ax in plt.gcf().axes[:-1]:
            ax.coastlines()
    if plot_bias:
        i = mean_array[:, :, :, 1].plot.imshow(
            transform=ccrs.PlateCarree(),
            x=LON_LABEL,
            y=LAT_LABEL,
            col="offset",
            robust=True,
            col_wrap=3,
            center=0,
            subplot_kws=proj_args,
        )
        for ax in plt.gcf().axes[:-1]:
            ax.coastlines()
    if plot_rel_bias:
        i = mean_array[:, :, :, 2].plot.imshow(
            transform=ccrs.PlateCarree(),
            x=LON_LABEL,
            y=LAT_LABEL,
            col="offset",
            robust=True,
            col_wrap=3,
            center=0,
            subplot_kws=proj_args,
        )
        for ax in plt.gcf().axes[:-1]:
            ax.coastlines()

    spaghetti_array.plot.line(x=TIME_LABEL, col="offset", add_legend=False,
                              col_wrap=3)
def plot_predictions(
    da: xr.DataArray,
    prediction_date: str,
    start_dates,
    period: int,
    wf: wfs.WeightingFunctionIntermediateType = wfs.no_weights_intermediate,
    data_label: str = "data",
    mean_kwargs=None,
    bias_kwargs=None,
    mean_robust=True,
    bias_robust=True,
):
    if bias_kwargs is None:
        bias_kwargs = {}
    if mean_kwargs is None:
        mean_kwargs = {}
    start_indices, ensemble_lengths, end_index = get_ensemble_indices(
        da, prediction_date, start_dates, time_label=TIME_LABEL
    )

    mean_data = xr.DataArray(
        np.empty((len(da[LAT_LABEL]), len(da[LON_LABEL]), len(start_dates))),
        [da[LAT_LABEL], da[LON_LABEL], da[TIME_LABEL][start_indices]],
        [LAT_LABEL, LON_LABEL, "start date"],
    )
    bias_data = xr.DataArray(
        np.empty((len(da[LAT_LABEL]), len(da[LON_LABEL]), len(start_dates))),
        [da[LAT_LABEL], da[LON_LABEL], da[TIME_LABEL][start_indices]],
        [LAT_LABEL, LON_LABEL, "start date"],
    )
    climate_mean = (
        da[end_index % period :: period, :, :].mean(TIME_LABEL).values[:, :, np.newaxis]
    )
    for i, _ in enumerate(start_dates):
        ensembles, weights = get_ensembles(da, period, ensemble_lengths[i],
                                           start_indices[i], 0,
                                           wf(start_indices[i]),
                                           time_label=TIME_LABEL)
        mean, bias, _ = get_mean_data(ensembles[-1, :, :, :], weights)
        mean_data[:, :, i] = mean.values
        bias_data[:, :, i] = bias.values
    proj_args = dict(projection=ccrs.PlateCarree())
    mean_data.plot.imshow(
        x=LON_LABEL,
        y=LAT_LABEL,
        col="start date",
        robust=mean_robust,
        transform=ccrs.PlateCarree(),
        subplot_kws=proj_args,
        vmin=0,
        cmap="cividis",
        **mean_kwargs,
    )
    plt.suptitle(f"Mean {data_label} for "
                 f"{da[TIME_LABEL][end_index].values}", y=1)
    for ax in plt.gcf().axes[:-1]:
        ax.coastlines()
        ax.add_feature(BORDERS)

    (mean_data - climate_mean).plot.imshow(
        x=LON_LABEL,
        y=LAT_LABEL,
        col="start date",
        robust=mean_robust,
        transform=ccrs.PlateCarree(),
        subplot_kws=proj_args,
        cmap="BrBG",
        **bias_kwargs,
    )
    plt.suptitle(
        f"Anomaly {data_label} from climate mean for "
        f"{da[TIME_LABEL][end_index].values}",
        y=1,
    )
    for ax in plt.gcf().axes[:-1]:
        ax.coastlines()
        ax.add_feature(BORDERS)

    bias_data.plot.imshow(
        x=LON_LABEL,
        y=LAT_LABEL,
        col="start date",
        robust=bias_robust,
        cmap="BrBG",
        transform=ccrs.PlateCarree(),
        subplot_kws=proj_args,
        **bias_kwargs,
    )
    plt.suptitle(
        f"Anomaly {data_label} from observed for "
        f"{da[TIME_LABEL][end_index].values}",
        y=1,
    )
    for ax in plt.gcf().axes[:-1]:
        ax.coastlines()
        ax.add_feature(BORDERS)

In [29]:
def plot_ppmcc(
    da: xr.DataArray,
    prediction_date: str,
    start_dates,
    period: int,
    wf: wfs.WeightingFunctionIntermediateType = wfs.no_weights_intermediate,
):
    start_indices, ensemble_lengths, _ = get_ensemble_indices(
        da, prediction_date, start_dates, time_label=TIME_LABEL
    )
    ppmcc = xr.DataArray(
        np.empty((len(da[LAT_LABEL]), len(da[LON_LABEL]), len(start_dates))),
        [da[LAT_LABEL], da[LON_LABEL], da[TIME_LABEL][start_indices]],
        [LAT_LABEL, LON_LABEL, "start dates"],
    )
    rmse = xr.DataArray(
        np.empty((len(da[LAT_LABEL]), len(da[LON_LABEL]), len(start_dates))),
        [da[LAT_LABEL], da[LON_LABEL], da[TIME_LABEL][start_indices]],
        [LAT_LABEL, LON_LABEL, "start dates"],
    )
    hindcasts, _, observed = get_hindcasts_observed(
        da,
        ensemble_lengths,
        start_indices,
        period,
        wf,
        TIME_LABEL,
    )
    for i, start_index in enumerate(start_indices):
        ppmcc[:, :, i] = xr.corr(hindcasts[i], observed[i], "hindcast")
        square_error = (hindcasts[i] - observed[i]) ** 2
        rmse[:, :, i] = square_error.mean(dim="hindcast") ** 0.5
    (ppmcc ** 2).plot.imshow(
        x=LON_LABEL, y=LAT_LABEL, col="start dates", vmin=0, vmax=1,
        cmap="viridis"
    )
    rmse.plot.imshow(
        x=LON_LABEL, y=LAT_LABEL, col="start dates", robust=True, cmap="viridis"
    )

In [30]:
from general_tamsat_alert.roc_auc import get_roc_auc
def plot_roc_auc(
    da: xr.DataArray,
    prediction_date: str,
    start_dates,
    period: int,
    threshold_value: float = 0.2,
    wf: wfs.WeightingFunctionIntermediateType = wfs.no_weights_intermediate,
):
    roc_auc = get_roc_auc(
        da,
        prediction_date,
        start_dates,
        period,
        threshold_value,
        wf,
        50,
        TIME_LABEL,
    )
    roc_auc.plot(x=LAT_LABEL, y=LON_LABEL, col="start dates", cmap="plasma",
                 vmin=0.5)

In [31]:
def read_noaa_data_file(
    fname: str,
    time_axis: xr.DataArray = None,
    time_label: str = "time",
    replace_given_nan_value=True,
):
    """

    Data format (BNF for anyone that can read it):
    <ws-char> ::= " " | "\t"
    <ws> ::= <ws-char> | <ws_char> <ws>
    <ws-opt> ::= "" | <ws>
    <digit> ::= "0"|"1"|"2"|"3"|"4"|"5"|"6"|"7"|"8"|"9"
    <year> ::= <digit> <digit> <digit> <digit>
    <natural> ::= <digit> | <digit> <natural>
    <integer> ::= "-" <natural> | "+" <natural>
    <real> ::= <integer> "." <natural>
             | <real> "E" <integer>
             | <real> "e" <integer>
    <line-end> ::= <ws-opt> "\\n" | <ws-opt> "\\r\\n"

    <real-3> ::= <real> <ws> <real> <ws> <real>
    <real-12> ::= <real-3> <ws> <real-3> <ws> <real-3> <ws> <real-3>

    <any-str> ::= "" | <any-str> <*>

    <header> ::= <ws-opt> <year> <ws> <year>
    <line> ::= <ws-opt> <year> <ws> <real-12>
             | <ws-opt> <year> <ws> <real-12> <ws> <any-str>
    <data-matrix> ::= <line> | <line> <line-end> <data>

    <nan-value> ::= <real>

    <footer> ::= <any-str>

    <file-format> ::= <header> <line-end> <data-matrix> <line-end>
                      <nan-value> <line-end> <footer>

    Additionally:
     --  The <year> at the start of each <line> must increase
         sequentially from the first <year> in <header> to the last
         <year> in <header> inclusive.
     --  <*> indicates the wildcard character that matches any singular
         ASCII character
     --  All characters in the file *must* be valid ASCII characters

    :param fname:
    :param time_axis:
    :param time_label:
    :param replace_given_nan_value:
    :return:
    """
    with open(fname, "rt") as f:
        try:
            miny, maxy = f.readline().strip().split()
            miny = int(miny)
            maxy = int(maxy)
        except ValueError:
            raise ValueError("File does not contain start/end year on first "
                             "line")
        data = []
        for index, year in enumerate(range(miny, maxy + 1)):
            try:
                line = f.readline()
                line = line.strip().split()
                assert line[0] == str(year)
                line = [np.float64(i) for i in line[1:13]]
                data.extend(line)
            except ValueError:
                raise ValueError(f"Line {index+2} contains invalid number(s)")
            except AssertionError:
                raise ValueError(
                    f"Unexpected value {line[0]} at start of line {index+2}"
                )
        nan_value = np.float64(f.readline().strip())
        data = np.array(data)
        if replace_given_nan_value:
            data[(data <= nan_value + 0.000001)] = np.nan

        # Whether the time axis is start or end of month
        # (it is usually end of month)
        freq = "MS"
        start = f"{miny}-01-01"
        end = f"{maxy}-12-01"
        da = xr.DataArray(data,
                          [xr.date_range(start, end, freq=freq)],
                          [time_label])

    if time_axis is None:
        return da
    else:
        return da.interp({time_label: time_axis},
                         kwargs={"fill_value": "extrapolate"})

In [32]:
def get_weighting_roc_improvement(
    da: xr.DataArray,
    prediction_date: str,
    start_dates,
    period: int,
    threshold_value: float = 0.2,
    wf: wfs.WeightingFunctionIntermediateType = wfs.no_weights_intermediate,
):
    unweighted = get_roc_auc(
        da,
        prediction_date,
        start_dates,
        period,
        threshold_value,
        time_label=TIME_LABEL,
        integration_steps=5000,
    )
    weighted = get_roc_auc(
        da,
        prediction_date,
        start_dates,
        period,
        threshold_value,
        wf,
        integration_steps=5000,
        time_label=TIME_LABEL,
    )
    unweighted.plot.imshow(x=LON_LABEL, y=LAT_LABEL, col="start dates")
    weighted.plot.imshow(x=LON_LABEL, y=LAT_LABEL, col="start dates")
    anomaly = weighted - unweighted
    anomaly.plot.imshow(
        x=LON_LABEL, y=LAT_LABEL, col="start dates", cmap="BrBG", vmin=-0.5,
        vmax=0.5
    )
    rel_anomaly = anomaly / (1 - unweighted)
    rel_anomaly.plot.imshow(
        x=LON_LABEL, y=LAT_LABEL, col="start dates", cmap="BrBG", robust=True
    )

def process_data(
    ensemble_length: int,
    field: str,
    weighting_data_file: str = None,
    ensemble_start: int = None,
) -> None:
    if weighting_data_file is not None:
        weighting_data = read_noaa_data_file(
            weighting_data_file, ds.coords[TIME_LABEL], TIME_LABEL
        )
    else:
        weighting_data = np.ones(ds.coords[TIME_LABEL].shape)

    print(weighting_data)
    
    get_weighting_roc_improvement(
        ds[field],
        "2012-07-31",
        ["2012-04-30", "2012-05-31", "2012-06-30", "2012-07-16", "2012-07-31"],
        period,
        wf=wfs.weight_value_builder(weighting_data),
    )
    get_weighting_roc_improvement(
        ds[field],
        "2012-07-31",
        ["2012-04-30", "2012-05-31", "2012-06-30", "2012-07-16", "2012-07-31"],
        period,
        wf=wfs.weight_time_builder(period),
    )
    plot_predictions(
        ds[field],
        "2001-03-01",
        ["2000-12-01", "2001-01-01", "2001-02-01", "2001-02-15", "2001-03-01"],
        period,
        wf=wfs.weight_time_builder(period),
        mean_kwargs={"vmin": 0},
    )
    plot_ppmcc(
        ds[field],
        "2012-07-31",
        ["2012-04-30", "2012-05-31", "2012-06-30", "2012-07-16", "2012-07-31"],
        period,
    )
    
    plot_roc_auc(
        ds[field][:, :, :],
        "2012-07-31",
        ["2012-04-30", "2012-05-31", "2012-06-30", "2012-07-16", "2012-07-31"],
        period,
    )

    plot_roc_auc(
        ds[field][:, :, :],
        "2012-07-31",
        ["2011-07-31"],
        period,
        wf=wfs.weight_time_builder(period),
    )
    plt.show()


In [33]:
process_data(6, 'ndvi', '../data/oni.data')

<xarray.DataArray (time: 962)>
array([ 4.19354839e-03,  6.61290323e-02,  1.30000000e-01,  1.85714286e-01,
        3.16451613e-01,  4.60967742e-01,  5.58666667e-01,  6.53666667e-01,
        6.87096774e-01,  7.18064516e-01,  7.52666667e-01,  7.87666667e-01,
        9.16451613e-01,  1.06096774e+00,  1.30032258e+00,  1.56354839e+00,
        1.76200000e+00,  1.95700000e+00,  2.06483871e+00,  2.17322581e+00,
        2.20333333e+00,  2.22833333e+00,  2.20741935e+00,  2.18161290e+00,
        2.06258065e+00,  1.92838710e+00,  1.73000000e+00,  1.55357143e+00,
        1.42709677e+00,  1.29806452e+00,  1.18266667e+00,  1.06766667e+00,
        9.06451613e-01,  7.30967742e-01,  5.28666667e-01,  3.23666667e-01,
        1.33870968e-01, -6.74193548e-02, -2.51612903e-01, -4.47741935e-01,
       -6.23333333e-01, -7.98333333e-01, -8.95806452e-01, -9.93870968e-01,
       -9.58000000e-01, -9.13000000e-01, -7.70000000e-01, -6.10000000e-01,
       -5.18709677e-01, -4.25806452e-01, -3.81379310e-01, -3.42758621

NameError: name 'get_mean_data' is not defined