In [None]:
import os

import holoviews as hv
import hvplot.xarray
import numpy as np
import tensorflow as tf
import xarray as xr

from re_nobm_pcc.kit import DATA_DIR

# Switch to the shell working directory recorded in $PWD (typically where Jupyter
# was started), so relative paths resolve from there instead of the kernel's cwd.
launch_dir = os.environ['PWD']
os.chdir(launch_dir)

## Model summary

In [None]:
# Restore the trained Keras model stored under DATA_DIR/'model'.
model_dir = DATA_DIR / 'model'
model = tf.keras.models.load_model(model_dir)

In [None]:
# Display Keras's text summary of the model restored above.
model.summary()

## Training

In [None]:
# Plot every per-epoch series stored in fit.npz (presumably the History recorded
# by `model.fit`) as one line each, on a log-scaled y axis.
# `np.load` on an .npz returns an NpzFile that keeps the file open until closed;
# the `with` block releases it once every array has been read into the Dataset.
with np.load(DATA_DIR/'fit.npz') as fit_history:
    history = xr.Dataset({
        k: ('epoch', v) for k, v in fit_history.items()
    })
history.hvplot.line(
    logy=True,
    ylabel=model.loss.name,
)

## Testing

In [None]:
# Load the tf.data dataset saved under DATA_DIR/'test' and batch it (256 elements
# per batch) for inference.
# NOTE(review): `tf.data.experimental.load` is deprecated in newer TensorFlow in
# favour of `tf.data.Dataset.load` (TF 2.10+); switch once the environment's TF
# version is confirmed.
test_dir = str(DATA_DIR / 'test')
test = tf.data.experimental.load(test_dir).batch(256)

In [None]:
# Run inference over every batch of the test dataset; Keras returns the
# predictions for all batches as a single array.
y_hat = model.predict(x=test)

In [None]:
# NOTE(review): this cell repeats the training-history plot from the Training
# section verbatim; it shows nothing test-specific. Consider removing it or
# replacing it with a test-set diagnostic.
# `np.load` on an .npz returns an NpzFile that keeps the file open until closed;
# the `with` block releases it once every array has been read into the Dataset.
with np.load(DATA_DIR/'fit.npz') as fit_history:
    history = xr.Dataset({
        k: ('epoch', v) for k, v in fit_history.items()
    })
history.hvplot.line(
    logy=True,
    ylabel=model.loss.name,
)

In [None]:
# Gather test-set estimates and targets into one xarray Dataset for plotting.
# `test` is a batched tf.data.Dataset, which cannot be indexed by key
# (`test['x']` raises TypeError on a fresh kernel). Instead, iterate it once and
# run the model on each batch; doing both in the same pass keeps every estimate
# aligned with its target regardless of dataset iteration order.
# NOTE(review): assumes each dataset element is an (x, y) pair, the structure
# `model.predict(test)` consumes above; confirm against how 'test' was saved.
# NOTE(review): component labels carried by the former xarray targets are not
# present here; 'component' is indexed 0..n-1 unless labels are attached.
estimates, targets = [], []
for x_batch, y_batch in test:
    estimates.append(model(x_batch, training=False).numpy())
    targets.append(y_batch.numpy())
predict = np.concatenate(estimates)
target = np.concatenate(targets)
full = xr.Dataset({
    'estimate': (('pxl', 'component'), predict),
    'target': (('pxl', 'component'), target),
    # Loss over the whole test set, using the model's compiled loss.
    'loss': (tuple(), model.loss(target, predict).numpy()),
})

In [None]:
# Show the component values that the hexbin plots below iterate over.
component_labels = full['component'].data
[*component_labels]

In [None]:
# One target-vs-estimate hexbin per component, laid out two per row; each panel
# keeps its own axes (shared_axes=False).
hv.extension('bokeh')
elements = [
    full
    .sel(component=item)
    .hvplot
    .hexbin(
        x='target',
        y='estimate',
        aspect='equal',
        frame_height=300,
        frame_width=400,
        # HoloViews element labels must be strings; component values may be
        # numpy integers when the 'component' dimension has no coordinate.
        label=str(item),
    )
    for item in full['component'].data
]
hv.Layout(elements).opts(hv.opts.Layout(shared_axes=False)).cols(2)