In [None]:
import os
import json
# Quiet TensorFlow's C++ log output. Keep this assignment above
# `import tensorflow` below: TF reads the variable when its runtime loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import holoviews as hv
import hvplot.xarray  # imported for its side effect: enables `.hvplot` on xarray objects (used below)
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr

# Project paths/constants: DATA_DIR is used as a path (joined with `/`), TAXA as
# the ordered list of taxon names the model predicts.
from re_nobm_pcc.kit import DATA_DIR, TAXA

# Render HoloViews/hvplot output with the Bokeh backend, without the logo banner.
hv.extension('bokeh', logo=False)

## Model

In [None]:
# Restore the trained Keras model stored in the 'model' entry under DATA_DIR.
model = tf.keras.models.load_model(DATA_DIR.joinpath('model'))

In [None]:
# Print the loaded model's architecture summary (Keras `Model.summary`).
model.summary()

## Loss by Epoch

In [None]:
# Training history saved by the fit step: each array in fit.npz (at least
# `loss` and `val_loss`, plotted below) becomes a variable along 'epoch'.
# `np.load` on an .npz returns a lazy NpzFile that keeps the file open; use it
# as a context manager so the handle is closed once the arrays are read.
with np.load(DATA_DIR/'fit.npz') as fit_npz:
    history = xr.Dataset({k: ('epoch', v) for k, v in fit_npz.items()})
(
    history
    .hvplot.line(
        x='epoch',
        y=['loss', 'val_loss'],
        logy=True,  # log-scale loss axis
        ylabel=model.loss.name,
    )
)

## Test: Metrics

In [None]:
# Per-taxon test metrics. metrics.json maps flat keys '<taxon>_<metric>'
# (plus an aggregate 'loss') to scalar values; reshape them into a table with
# one row per taxon and one column per metric.
with (DATA_DIR/'metrics.json').open() as stream:
    metrics = json.load(stream)
# The aggregate loss has no '<taxon>_' prefix; drop it, tolerating its absence.
metrics.pop('loss', None)
# Split on the LAST underscore only, so a taxon name that itself contains '_'
# stays intact (the selected metric names below contain no underscore).
# Tuple keys give the Series a (taxon, metric) MultiIndex; unstack() moves the
# metric level into columns.
table = pd.Series(
    {tuple(k.rsplit('_', 1)): v for k, v in metrics.items()}
).unstack()
table.loc[TAXA, ['ME', 'MAE', 'RMSE', 'R2']]

## Test: Correlations

In [None]:
# Load the saved test split and collapse it into a single batch holding every
# example, so one iteration yields the whole set.
# NOTE(review): relies on test.cardinality() being a known, finite count; an
# UNKNOWN (-2) or INFINITE (-1) cardinality is not a valid batch size — confirm
# for the saved dataset.
test = tf.data.Dataset.load(str(DATA_DIR/'test'))
test = test.batch(test.cardinality())
# Keep only the targets (element index 1 of the (inputs, targets) pair).
# Indexing instead of `_, y = ...` avoids binding `_`, which would stop IPython
# from updating its `_`/`__`/`___` output history, and lets the unused input
# array be freed immediately instead of staying referenced by `_`.
y = next(test.as_numpy_iterator())[1]

In [None]:
# Predict on the batched test dataset; verbose=0 suppresses the progress bar.
y_hat = model.predict(x=test, verbose=0)

In [None]:
# Collect true and predicted values per taxon into one xarray Dataset along a
# flat 'pxl' dimension: `<taxon>_hat` holds the prediction, `<taxon>` the truth.
# Both sides are raveled so targets and predictions share the same 1-D shape
# (previously only the predictions were flattened).
# NOTE(review): assumes y and y_hat are sequences indexed by taxon position in
# the same order as TAXA (e.g. one model output per taxon) — confirm.
# NOTE: this rebinds `test` (the tf.data.Dataset used by the predict cell), so
# that cell cannot be re-run after this one without re-running the load cell.
test = xr.Dataset({
    **{f'{item}_hat': (('pxl',), np.ravel(y_hat[i])) for i, item in enumerate(TAXA)},
    **{f'{item}': (('pxl',), np.ravel(y[i])) for i, item in enumerate(TAXA)},
})
# Per-variable minimum — presumably to check for non-positive values before
# the log10 transforms below.
test.min()

### True vs. Pred

In [None]:
# For each taxon: hex-binned density of log10(true) vs. log10(predicted), with
# the y = x line (slope 1, intercept 0) overlaid; panels laid out two per row.
logtest = np.log10(test)
plots = {
    item: (
        hv.HexTiles(
            data=(logtest[item], logtest[f'{item}_hat']),
            kdims=['y', 'y_hat'],
        ).opts(
            shared_axes=False,
            cmap='greens',
            bgcolor='lightskyblue',
            tools=['hover'],
            padding=0.001,
        )
        * hv.Slope(1, 0).opts(shared_axes=False, color='darkorange', line_width=1)
    )
    for item in TAXA
}
hv.output(size=150)
hv.HoloMap(plots, kdims='group').opts(shared_axes=False).layout().cols(2)

### True > 10e-5 (i.e. 1e-4) vs. Pred

In [None]:
# Same comparison as the previous section, restricted to pixels whose TRUE
# value exceeds ZOOM_MIN. The mask comes from the true variable only and is
# applied to both axes. Masking the whole Dataset by its own values (the
# previous `test.where(test > 10e-5)`) also discarded every prediction below
# the threshold, which hid under-prediction.
ZOOM_MIN = 10e-5  # == 1e-4
masked = {}
for item in TAXA:
    truth, pred = test[item], test[f'{item}_hat']
    keep = truth > ZOOM_MIN
    masked[item] = truth.where(keep)
    # Non-positive predictions have no finite log10 (0 -> -inf, <0 -> NaN);
    # mask them explicitly so they are left out of the log-scale plot.
    masked[f'{item}_hat'] = pred.where(keep & (pred > 0))
zoomtest = np.log10(xr.Dataset(masked))
plots = {}
for item in TAXA:
    plots[item] = (
        hv.HexTiles(
            data=(zoomtest[item], zoomtest[f'{item}_hat']),
            kdims=['y', 'y_hat'],
        ).opts(shared_axes=False, cmap='greens', bgcolor='lightskyblue', tools=['hover'], padding=0.001)
        *hv.Slope(1, 0).opts(shared_axes=False, color='darkorange', line_width=1)
    )
hv.output(size=150)
(
    hv.HoloMap(plots, kdims='group').opts(shared_axes=False).layout().cols(2)
)