## Imports

In [None]:
import json
import os
from importlib import reload

import holoviews as hv
import numpy as np
import pandas as pd
import xarray as xr

from re_nobm_pcc import DATA_DIR, TAXA
from re_nobm_pcc import viz

# Quiet TensorFlow's native (C++) log output. TensorFlow reads this variable
# when it is first imported, which is why it is set here, ahead of the
# separate `import tensorflow` cell.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Render HoloViews output with the Bokeh plotting backend.
hv.extension('bokeh')

In [None]:
import tensorflow as tf

In [None]:
# Re-import the project's `viz` module so that edits to its source are picked
# up without restarting the kernel (importlib.reload re-executes the module).
viz = reload(viz)

## Model

In [None]:
# Restore the trained Keras network saved under DATA_DIR/'network'.
# compile=False: this notebook only calls `predict`, so the saved
# optimizer/loss configuration does not need to be rebuilt.
network_path = DATA_DIR / 'network'
network = tf.keras.models.load_model(network_path, compile=False)
network.summary()

## Loss by Epoch

In [None]:
# Load the training history saved in fit.npz and plot loss by epoch.
# Every array stored in the archive becomes a data variable along a shared
# 'epoch' dimension, so all arrays must have the same length.
# The NpzFile is used as a context manager so its file handle is closed
# deterministically; `.items()` reads each array into memory before the
# archive is closed.
with np.load(DATA_DIR / 'fit.npz') as history:
    fit = xr.Dataset({
        name: ('epoch', values) for name, values in history.items()
    })
viz.loss(fit)

In [None]:
# Show the per-epoch training-history Dataset loaded in the previous cell.
fit

## Test: Metrics

In [None]:
# Test-set evaluation metrics, stored as a flat JSON object of name -> value.
metrics = json.loads((DATA_DIR / 'metrics.json').read_text())

In [None]:
# Display the overall loss on the test split.
'Test loss: {}'.format(metrics['loss'])

In [None]:
# Reshape the flat `metrics` dict into a table.
#
# Key handling: the last two '_'-separated parts of each metric key form a
# tuple column key. Stacking column level 0 moves that part into the row index.
# `droplevel(0)` then discards the single original row label, leaving
# level-0 parts as rows and level-1 parts as columns.
table = (
    pd.DataFrame.from_dict(
        {tuple(k.split('_'))[-2:]: [v] for k, v in metrics.items()},
        orient='columns',
    )
    .stack(level=0).droplevel(0)
)
columns = ['loss', 'AUC', 'ME', 'MAE', 'RMSE', 'R2']
# Guarantee every expected metric column exists, in a fixed order, with NaN
# where a metric was not reported. `reindex` does this directly. The previous
# approach concatenated an empty DataFrame, which upcasts to object dtype and
# triggers pandas' FutureWarning about empty concat entries.
table = table.reindex(columns=columns)
table

## Test: True vs. Predicted

In [None]:
# Load the saved test split and fold the whole thing into one batch, so a
# single iteration returns every test example at once.
test = tf.data.Dataset.load(str(DATA_DIR / 'test'))
n_examples = test.cardinality()
test = test.batch(n_examples)
# Elements are unpacked as (inputs, targets); only the targets are kept.
# NOTE(review): targets are indexed by name later (y['abundance_<taxon>']),
# so they are presumably a dict — confirm against how the split was saved.
_, y = next(test.as_numpy_iterator())

In [None]:
# Predict on the single full-size test batch. Each output array is keyed by
# the name of the layer that produced it; the next cell reads keys of the
# form 'abundance_<taxon>' and 'presence_<taxon>'.
# The model loaded above is bound to `network`; there is no `model` name in
# this notebook.
predictions = network.predict(test, verbose=0)
y_hat = {
    output.node.layer.name: item
    for output, item in zip(network.outputs, predictions)
}

In [None]:
# Assemble true and predicted abundances into one xarray Dataset over a flat
# 'pxl' dimension.
# - Truth variables are named by taxon.
# - Predicted abundances are named '<taxon>_hat'.
# - When presence outputs are in use, '<taxon>_presence_hat' is added too.
truth = xr.Dataset(
    {i: (('pxl',), y[f'abundance_{i}']) for i in TAXA}
)
if table['AUC'].any():
    # An AUC value in the metrics table presumably means the network also has
    # 'presence_<taxon>' outputs (y_hat is indexed by them below).
    presence_hat = xr.Dataset(
        {f'{i}_presence_hat': (('pxl',), y_hat[f'presence_{i}'].flatten()) for i in TAXA}
    )
    # Zero the first abundance column wherever presence is not predicted.
    # NOTE(review): `> 0` is the right cut only if presence outputs are
    # logits; if they are sigmoid probabilities it is almost always True and
    # should likely be `> 0.5` — confirm against the network's final layer.
    abundance_hat = xr.Dataset(
        {
            f'{i}_hat': (
                ('pxl',),
                (y_hat[f'abundance_{i}'][:, :1] * (y_hat[f'presence_{i}'] > 0)).flatten(),
            )
            for i in TAXA
        }
    )
    test = xr.merge((presence_hat, abundance_hat, truth))
    # Plotting helpers come from the project's `viz` module (cf. viz.loss);
    # a bare `roc` is not defined anywhere in this notebook.
    # NOTE(review): assumes viz exposes `roc` — confirm.
    viz.roc(test)
else:
    abundance_hat = xr.Dataset(
        {f'{i}_hat': (('pxl',), y_hat[f'abundance_{i}'][:, :1].flatten()) for i in TAXA}
    )
    test = xr.merge((abundance_hat, truth))

In [None]:
# True vs. predicted abundance, linear scale. `hexbin` is not defined in this
# notebook; the plotting helpers live in the project's `viz` module.
# NOTE(review): assumes viz exposes `hexbin` — confirm.
viz.hexbin(test)

In [None]:
# Same comparison in log10 space. Zero values map to -inf, and numpy emits a
# divide-by-zero RuntimeWarning for them.
# NOTE(review): consider masking zeros, e.g. test.where(test > 0), if the
# hexbin helper does not tolerate infinities.
viz.hexbin(np.log10(test))