In [None]:
import os
import json
# Quiet TensorFlow's native (C++) logging. This is set here, before `import tensorflow`
# in the next cell, so it is already in the environment when TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
import holoviews as hv
import hvplot.xarray
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr

# Render HoloViews / hvPlot output with the Bokeh backend, without the HoloViews logo.
hv.extension('bokeh', logo=False)

In [None]:
from re_nobm_pcc.kit import DATA_DIR, TAXA
# from pathlib import Path
# DATA_DIR = Path('../data')

In [None]:
def hexbin(ds):
    """Hex-binned true-vs-predicted plot for every taxon in ``TAXA``.

    Parameters
    ----------
    ds : xarray.Dataset
        For each taxon, needs a truth variable ``<taxon>`` and a prediction
        variable ``<taxon>_hat``.

    Returns
    -------
    holoviews.Layout
        Two-column layout with one panel per taxon. Each panel shows log-scaled
        hex counts of (y, y_hat) with a 1:1 reference line.

    Notes
    -----
    Pairs where either value is non-finite (NaN or +/-inf, e.g. produced by
    ``np.log10`` of zeros) are dropped before binning, because they cannot be
    placed on the hex grid.
    As a side effect, this sets the global HoloViews output size via ``hv.output``.
    """
    hv.output(size=120)
    plots = {}
    for item in TAXA:
        y = np.asarray(ds[item]).ravel()
        y_hat = np.asarray(ds[f'{item}_hat']).ravel()
        finite = np.isfinite(y) & np.isfinite(y_hat)
        tiles = hv.HexTiles(
            data=(y[finite], y_hat[finite]),
            kdims=['y', 'y_hat'],
        ).options(
            logz=True,
            cmap='greens',
            bgcolor='lightskyblue',
            tools=['hover'],
            padding=0.001,
            aspect='square',
        )
        # Perfect predictions fall on this line.
        one_to_one = hv.Slope(1, 0).options(
            color='darkorange',
            line_width=1,
        )
        plots[item] = tiles * one_to_one
    return (
        hv.HoloMap(plots, kdims='group').layout().cols(2).options(shared_axes=False)
    )


def roc(ds):
    """Overlay one ROC curve per taxon in ``TAXA``.

    Parameters
    ----------
    ds : xarray.Dataset
        For each taxon, needs a truth variable ``<taxon>`` and a score variable
        ``<taxon>_hat``. Pixels whose truth is > 0 are the positive class;
        ``<taxon>_hat`` is the score used to rank pixels.

    Returns
    -------
    holoviews.Overlay
        True-positive rate vs. false-positive rate, keyed by ``group`` (the
        taxon). Taxa whose pixels are all positive or all negative are skipped,
        because their ROC curve is undefined (division by zero).
    """
    plots = {}
    for item in TAXA:
        score = np.asarray(ds[f'{item}_hat']).ravel()
        # Binarize: abundances are continuous, but ROC needs 0/1 labels.
        label = np.asarray(ds[item]).ravel() > 0
        order = np.argsort(score, kind='stable')
        score = score[order]
        n = score.size
        # false_neg[k] = positives among the k lowest scores, i.e. the positives
        # rejected by a threshold placed just above those k pixels.
        false_neg = np.concatenate(([0], np.cumsum(label[order])))
        pos = false_neg[-1]
        neg = n - pos
        if pos == 0 or neg == 0:
            continue
        true_neg = np.arange(n + 1) - false_neg
        # A threshold cannot separate tied scores, so keep only the two end
        # points and the positions where the sorted score actually changes.
        cut = np.concatenate(([True], score[1:] != score[:-1], [True]))
        true_pos_rate = 1 - false_neg[cut] / pos
        false_pos_rate = 1 - true_neg[cut] / neg
        plots[item] = hv.Curve(
            (false_pos_rate, true_pos_rate), 'False Positive Rate', 'True Positive Rate'
        )
    return hv.HoloMap(plots, kdims='group').overlay().opts(legend_position='bottom_right')

## Model

In [None]:
# Trained network, loaded for inference only: compile=False skips restoring the
# optimizer/loss configuration, which prediction does not need.
model = tf.keras.models.load_model(filepath=DATA_DIR / 'model', compile=False)

In [None]:
# Layer-by-layer overview of the loaded model.
model.summary()

## Loss by Epoch

In [None]:
# Training history: one variable per saved array, each indexed by epoch.
# np.load returns an NpzFile that keeps the archive open, so use it as a context
# manager. Arrays are read lazily, which is why the Dataset is built inside the
# `with` block.
with np.load(DATA_DIR/'fit.npz') as history:
    fit = xr.Dataset({
        k: ('epoch', v) for k, v in history.items()
    })
# Top panel: total train/validation loss. Bottom panel: per-taxon abundance
# losses, with train and validation of one taxon sharing a colour.
plt = fit.hvplot.line(
    x='epoch', y=['loss', 'val_loss'], logy=True,
).options('Curve', color='black') + hv.Overlay(tuple(
    fit.hvplot.line(
        x='epoch',
        y=[f'abundance_{taxon}_loss', f'val_abundance_{taxon}_loss'],
        logy=True,
    )
    for taxon in TAXA
)).options('Curve', color=hv.Cycle(
    [colour for colour in ('blue', 'orange', 'red', 'green') for _ in range(2)]
))
# Solid = training, dashed = validation (the dash cycle alternates per curve).
plt.options(shared_axes=False).options(
    'Curve', line_dash=hv.Cycle(['solid', 'dashed'])
).cols(1)

## Test: Metrics

In [None]:
# Test-set evaluation results, stored as a flat JSON object of name -> value.
metrics = json.loads((DATA_DIR / 'metrics.json').read_text())

In [None]:
# Overall test loss, across all model outputs.
'Test loss: {}'.format(metrics['loss'])

In [None]:
# Per-taxon metrics table: one row per taxon, one column per metric.
# Keys are expected to look like '<output>_<taxon>_<metric>' (e.g. the
# 'abundance_<taxon>_loss' keys plotted above). Keys with fewer than three parts,
# such as the overall 'loss' shown in the previous cell, carry no taxon and are
# skipped.
columns = ['loss', 'AUC', 'ME', 'MAE', 'RMSE', 'R2']
rows = {}
for key, value in metrics.items():
    parts = key.split('_')
    if len(parts) < 3:
        continue
    output, taxon, metric = '_'.join(parts[:-2]), parts[-2], parts[-1]
    row = rows.setdefault(taxon, {})
    # 'presence_<t>_loss' and 'abundance_<t>_loss' map to the same cell. Keep the
    # abundance value deterministically, instead of whichever key came last.
    if metric in row and output != 'abundance':
        continue
    row[metric] = value
# reindex (not concat with an empty frame, which is deprecated and forces object
# dtype) guarantees every expected column exists, filling missing ones with NaN.
table = pd.DataFrame.from_dict(rows, orient='index').reindex(columns=columns)
table[columns]

## Test: True vs. Predicted

In [None]:
# Held-out split, loaded and collapsed into a single batch holding every example.
test = tf.data.Dataset.load(str(DATA_DIR / 'test'))
test = test.batch(test.cardinality())
# Only the targets are kept here; the batched dataset itself is passed to model.predict next.
_, y = next(iter(test.as_numpy_iterator()))

In [None]:
# Predictions keyed by the name of the layer that produced each model output,
# so they can be looked up by name (e.g. 'abundance_<taxon>') rather than by position.
y_hat = {
    model.outputs[k].node.layer.name: prediction
    for k, prediction in enumerate(model.predict(test, verbose=0))
}

In [None]:
# Ground truth: one variable per taxon along a flat 'pxl' dimension.
test = xr.Dataset(
    {i: (('pxl',), y[f'abundance_{i}']) for i in TAXA}
)
# Gate on the outputs actually indexed below, not on a metric that merely
# correlates with them.
has_presence = all(f'presence_{i}' in y_hat for i in TAXA)
if has_presence:
    # Abundance predictions are zeroed wherever the presence score is <= 0.
    test = xr.merge(
        (
            xr.Dataset(
                {f'{i}_presence_hat': (('pxl',), y_hat[f'presence_{i}'].flatten()) for i in TAXA}
            ),
            xr.Dataset(
                {f'{i}_hat': (('pxl',), (
                    (y_hat[f'abundance_{i}'][:, :1] * (y_hat[f'presence_{i}'] > 0)).flatten()
                )) for i in TAXA}
            ),
            test
        )
    )
else:
    test = xr.merge(
        (
            xr.Dataset(
                {f'{i}_hat': (('pxl',), y_hat[f'abundance_{i}'][:, :1].flatten()) for i in TAXA}
            ),
            test
        )
    )
# Jupyter renders only a cell's final top-level expression; a call nested inside
# an `if` is silently discarded. A None result displays nothing.
roc_plot = roc(test) if has_presence else None
roc_plot

In [None]:
# True vs. predicted values, on a linear scale.
hexbin(test)

In [None]:
# Same comparison on a log10 scale. Values <= 0 (e.g. zero abundances, or
# presence scores <= 0) have no logarithm, so mask them to NaN first. Otherwise
# log10 raises divide-by-zero/invalid RuntimeWarnings and yields -inf.
hexbin(np.log10(test.where(test > 0)))