In [None]:
import os
import json
# Set before TensorFlow is imported in the next cell; level '3' is presumably
# meant to quiet TensorFlow's native (C++) log output -- confirm against the TF docs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
import holoviews as hv
import hvplot.xarray
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr

# Activate the Bokeh plotting backend for HoloViews, with the logo display turned off.
hv.extension('bokeh', logo=False)

In [None]:
from re_nobm_pcc.kit import DATA_DIR, TAXA
from re_nobm_pcc.learn import Truncated_MAE

In [None]:
def hexbin(ds):
    """Plot true versus predicted values for every taxon as hex-binned panels.

    For each name in ``TAXA``, ``ds[name]`` is binned against
    ``ds[f'{name}_hat']`` (key dimensions labelled ``y`` and ``y_hat``) and
    overlaid with the identity line (slope 1, intercept 0).

    Parameters
    ----------
    ds
        Object indexable by variable name that provides both ``name`` and
        ``f'{name}_hat'`` for every entry of ``TAXA``.

    Returns
    -------
    A two-column HoloViews layout with one panel per taxon, keyed by a
    ``group`` dimension, with axes not shared between panels.

    Notes
    -----
    Calls ``hv.output(size=120)`` on every invocation; presumably this output
    setting also applies to plots rendered afterwards.
    """
    hv.output(size=120)

    tile_style = dict(
        logz=True,
        cmap='greens',
        bgcolor='lightskyblue',
        tools=['hover'],
        padding=0.001,
        aspect='square',
    )
    identity_style = dict(color='darkorange', line_width=1)

    def panel(taxon):
        # Hex-binned (true, predicted) pairs with the y == y_hat line drawn on top.
        tiles = hv.HexTiles(
            data=(ds[taxon], ds[f'{taxon}_hat']),
            kdims=['y', 'y_hat'],
        ).options(**tile_style)
        identity = hv.Slope(1, 0).options(**identity_style)
        return tiles * identity

    panels = {taxon: panel(taxon) for taxon in TAXA}
    grouped = hv.HoloMap(panels, kdims='group')
    return grouped.layout().cols(2).options(shared_axes=False)

## Model

In [None]:
# Restore the saved model stored under DATA_DIR/'model'. compile=False asks Keras
# to load it without compiling (presumably enough here, since it is only used
# for inference in this notebook).
model = tf.keras.models.load_model(DATA_DIR / 'model', compile=False)

In [None]:
# Print the summary of the loaded model.
model.summary()

## Loss by Epoch

In [None]:
# Training history saved in fit.npz: every stored array becomes a variable along
# an 'epoch' dimension. The archive is opened in a `with` block so its file
# handle is closed once the arrays have been read (a bare np.load on an .npz
# leaves the handle open).
with np.load(DATA_DIR/'fit.npz') as fit_archive:
    history = xr.Dataset({
        name: ('epoch', values) for name, values in fit_archive.items()
    })

# Training and validation loss per epoch on a logarithmic y-axis.
history.hvplot.line(
    x='epoch',
    y=['loss', 'val_loss'],
    logy=True,
)

## Test: Metrics

In [None]:
# Read the test-set metrics from metrics.json and split off the 'loss' entry;
# whatever remains in `metrics` is reshaped into a table further down.
metrics = json.loads((DATA_DIR/'metrics.json').read_text())
loss = metrics.pop('loss')

In [None]:
# `loss` is the 'loss' entry popped from the metrics file; it is indexed here, so
# it is expected to be a sequence -- presumably its first element is the overall
# test loss (TODO confirm what the remaining elements hold).
print('Test loss: {}'.format(loss[0]))

In [None]:
# Reshape the flat metrics mapping into a table. Each key is split on '_' into a
# tuple, which gives the one-row frame multi-level columns; stacking the first
# column level moves it into the row index, and the leftover positional level
# coming from the single row is then dropped.
table = (
    pd.DataFrame(
        {tuple(name.split('_')): [value] for name, value in metrics.items()}
    )
    .stack(level=0)
    .droplevel(0)
)
# Rows in TAXA order, columns in a fixed metric order.
table.loc[TAXA, ['ME', 'MAE', 'RMSE', 'R2']]

## Test: True vs. Predicted

In [None]:
# Load the saved test split and re-batch it with a batch size equal to its
# cardinality, i.e. one batch holding every element.
test = tf.data.Dataset.load(str(DATA_DIR/'test'))
test = test.batch(test.cardinality())
# Pull that single batch out as NumPy data: the first tuple member is discarded
# and the second is kept as `y` (looked up by 'abundance_<taxon>' keys later on).
# NOTE(review): the name `test` is rebound to an xarray Dataset further down, so
# this cell has to be re-run before re-running any cell that expects the
# TensorFlow dataset (e.g. the model.predict cell).
_, y = next(test.as_numpy_iterator())

In [None]:
# Map each prediction array to the name of the layer that produces the
# corresponding model output (outputs and predictions are paired by position).
y_hat = {
    output.node.layer.name: prediction
    for output, prediction in zip(model.outputs, model.predict(test, verbose=0))
}

In [None]:
# Gather everything along a shared 'pxl' dimension in one xarray Dataset
# (this rebinds `test`, which held the TensorFlow dataset until now):
#   '<taxon>_presence_hat'  flattened presence output of the model
#   '<taxon>_hat'           predicted abundance multiplied by the mask
#                           (presence output > 0), then flattened
#   '<taxon>'               true abundance taken from the test targets `y`
test = xr.merge([
    xr.Dataset({
        f'{taxon}_presence_hat': (('pxl',), y_hat[f'presence_{taxon}'].flatten())
        for taxon in TAXA
    }),
    xr.Dataset({
        f'{taxon}_hat': (
            ('pxl',),
            (y_hat[f'abundance_{taxon}'] * (y_hat[f'presence_{taxon}'] > 0)).flatten(),
        )
        for taxon in TAXA
    }),
    xr.Dataset({
        taxon: (('pxl',), y[f'abundance_{taxon}'])
        for taxon in TAXA
    }),
])
# test.max()

In [None]:
def roc_points(scores, labels):
    """Return the ROC curve ``(fpr, tpr)`` for binary labels ranked by score.

    The samples are sorted by ascending score. Entry ``k`` of each returned
    array corresponds to calling the ``k`` lowest-scored samples negative and
    all remaining samples positive, for ``k = 0 .. len(scores)``.

    Parameters
    ----------
    scores : array-like, 1-D
        Classifier scores; larger means "more likely positive".
    labels : array-like, 1-D
        0/1 (or boolean) ground-truth labels, same length as ``scores``.

    Returns
    -------
    fpr, tpr : numpy.ndarray
        False- and true-positive rates, each of length ``len(scores) + 1``,
        running from (1, 1) down to (0, 0).

    Raises
    ------
    ValueError
        If the inputs are not 1-D arrays of equal length, or if either class
        is absent (the corresponding rate would be a division by zero).
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    if scores.ndim != 1 or scores.shape != labels.shape:
        raise ValueError('scores and labels must be 1-D and of equal length')

    order = scores.argsort()
    # Positives among the k lowest-scored samples = false negatives at cut k.
    false_neg = np.insert(labels[order].cumsum(), 0, 0)
    positives = false_neg[-1]
    negatives = len(scores) - positives
    if positives == 0 or negatives == 0:
        raise ValueError('labels must contain both positive and negative samples')

    # The rest of the k lowest-scored samples are true negatives.
    true_neg = np.arange(len(false_neg)) - false_neg
    return 1 - true_neg / negatives, 1 - false_neg / positives


# Worked example on toy data. The inputs deliberately avoid the names `x`, `y`
# and `n`, which other cells of this notebook use for real data (reassigning `y`
# here would clobber the test targets read when the merged Dataset is built).
toy_scores = np.array([ 1, -4, 32, -2,  5])
toy_labels = np.array([ 0,  0,  1,  1,  1])
fpr, tpr = roc_points(toy_scores, toy_labels)
# To inspect the curve: hv.Curve((fpr, tpr)).opts(padding=0.05)

In [None]:
# Exploratory step for the 'dia' variables: order the pixels by the raw presence
# output (ascending) and accumulate the true-presence flags in that order.
n = test.sizes['pxl']  # total number of pixels
order = test['dia_presence_hat'].argsort()
result = (test['dia'] > 0)[order]  # True where the true 'dia' value is positive
# Running count of truly-present pixels among the lowest-scored ones, expressed
# as a fraction of all pixels.
# NOTE(review): `rate` is not used by any later cell visible here -- confirm
# whether this analysis is still meant to be finished.
rate = result.cumsum() / n

In [None]:
# True vs. predicted values for every taxon, on the original scale.
hexbin(test)

In [None]:
# Same comparison after a base-10 log transform of every variable.
# NOTE(review): predicted abundances are multiplied by a 0/1 presence mask, so
# zeros occur and map to -inf under log10 -- confirm the plot treats them as intended.
hexbin(np.log10(test))