## Imports

In [None]:
import os
# Set before `tensorflow` is imported in the next cell so the setting is
# already in the environment when the library loads.
# NOTE(review): '1' is expected to hide INFO-level native log messages --
# confirm against the TensorFlow documentation.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

In [None]:
import json
from importlib import reload

import holoviews as hv
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import xarray as xr

from re_nobm_pcc import DATA_DIR, TAXA
from re_nobm_pcc import learn, viz

hv.extension('bokeh')

## Model

In [None]:
# Redirect DATA_DIR to the data directory of what looks like a standalone DVC
# experiment checkout (path under .dvc/tmp/exps/standalone).
#
# The path is derived from the package-level constant rather than from the
# notebook's current DATA_DIR variable, so running this cell more than once
# (or after another cell that performs the same redirect) never appends the
# experiment suffix twice.
import re_nobm_pcc

DATA_DIR = re_nobm_pcc.DATA_DIR / '../.dvc/tmp/exps/standalone/tmpq_rtdiex/data'
DATA_DIR

In [None]:
# Reload the `learn` module so that edits made since the first import are
# picked up, rebuild the network and load the weights stored at
# DATA_DIR/fit/epoch-210.
learn = reload(learn)
checkpoint = DATA_DIR / 'fit' / 'epoch-210'
network = learn.make_network(6)  # presumably 6 == len(TAXA); verify
network.load_weights(checkpoint)

In [None]:
# Load the complete saved model from the experiment's data directory.
#
# DATA_DIR is rebuilt from the package-level constant instead of the current
# notebook variable: an earlier cell applies the same redirect, and chaining
# the suffix onto an already redirected DATA_DIR would yield a wrong path.
import re_nobm_pcc

DATA_DIR = re_nobm_pcc.DATA_DIR / '../.dvc/tmp/exps/standalone/tmpq_rtdiex/data'
# compile=False: the model is loaded without compiling it.
network = tf.keras.models.load_model(DATA_DIR / 'network', compile=False)
network.summary()

## Loss by Epoch

In [None]:
# Load the training history: every array stored in fit.npz becomes a variable
# along the 'epoch' dimension. The archive is opened in a `with` block so the
# underlying file handle is closed once the arrays have been read.
with np.load(DATA_DIR / 'fit.npz') as history:
    fit = xr.Dataset({
        name: ('epoch', values) for name, values in history.items()
    })
# Offset handed to viz.loss: 1 minus the smallest value over all variables.
# A generator is passed to min() so this also works when the history holds a
# single variable (min(*one_value) would try to iterate over a float).
# NOTE(review): presumably the offset keeps shifted losses >= 1 for a log
# axis -- confirm in viz.loss.
offset = 1 - min(v.item() for v in fit.min().values())
viz.loss(fit, offset)

In [None]:
# network.load_weights(DATA_DIR / 'fit' / 'epoch-90')

## Test: Metrics

In [None]:
# Map each TAXA entry (the Series index) to a short label: 'mean' for the
# first entry, then '1' .. '5'. The next cell uses these labels to select and
# rename the rows of the metrics table.
labels = ['mean', *map(str, range(1, 6))]
fix_index = pd.Series(labels, index=TAXA)

# Read the evaluation metrics written alongside the experiment data.
metrics_path = DATA_DIR / 'metrics.json'
with metrics_path.open() as stream:
    metrics = json.load(stream)

In [None]:
# Build a one-row frame whose two-level column labels are the last two
# '_'-separated parts of each metric key.
wide = pd.DataFrame.from_dict(
    {tuple(key.split('_')[-2:]): [value] for key, value in metrics.items()},
    orient='columns',
)
# Move the first column level into the rows, then drop the leftover row level
# that came from the single original row.
table = wide.stack(level=0).droplevel(0)

columns = ['ME', 'MAE', 'RMSE', 'R2']
# Concatenating onto an empty frame guarantees the four metric columns exist
# (and come first) even if some are missing from the metrics file.
table = pd.concat((pd.DataFrame(columns=columns), table))
# Order the rows by the labels in fix_index, then relabel them with the
# corresponding TAXA names.
table = table.loc[fix_index]
table.index = fix_index.index
table[columns]

## Test: True vs. Predicted

In [None]:
# Evaluate the network on the 'split[9%:10%]' slice of the 'rrs_day_tfds'
# dataset and collect, per batch, the targets plus the mean, standard
# deviation and one random sample of the predicted distribution.
dataset = tfds.builder('rrs_day_tfds', data_dir=DATA_DIR)
test = dataset.as_dataset(split='split[9%:10%]', as_supervised=True)
test = test.batch(2 ** 16)

y_true, y_mu, y_sd, y_sample = [], [], [], []
for features, target in test:
    y_true.append(target.numpy())
    y_pred = network(features)
    y_mu.append(y_pred.mean().numpy())
    y_sd.append(y_pred.stddev().numpy())
    # A fresh draw on every run; results that use y_sample are not
    # reproducible unless the random state is fixed elsewhere.
    y_sample.append(y_pred.sample().numpy())

# Stitch the per-batch arrays back together along the first axis.
y_true, y_mu, y_sd, y_sample = (
    np.concatenate(parts) for parts in (y_true, y_mu, y_sd, y_sample)
)

In [None]:
# Threshold on the coefficient of variation (predicted sd / predicted mean).
MAXCV = 1

dims = ('pxl', 'phy')
# The tiny constant keeps the denominator away from an exact zero.
cv = y_sd / (y_mu + 1e-32)
ds = xr.Dataset({
    'y_true': (dims, y_true),
    # "Predictions" here are the random samples, not the distribution means.
    'y_pred': (dims, y_sample),
    # NOTE: despite its name, this mask is True where the CV is BELOW MAXCV,
    # i.e. it marks the values that are kept by `.where(mask)` later on.
    'mask_high_cv': (dims, cv < MAXCV),
    'phy': ('phy', list(TAXA)),
})
# Fraction of pixels, per 'phy' entry, for which the mask is True.
ds['mask_high_cv'].sum('pxl') / ds.sizes['pxl']

In [None]:
# Hexagonal-bin plot of sampled prediction vs. truth for a single taxon,
# restricted to the entries where 'mask_high_cv' is True, with a red 1:1
# reference line on top.
i = 4
taxon = TAXA[i]
phy = ds.sel({'phy': taxon})
# Entries where the mask is False become NaN.
phy = phy[['y_true', 'y_pred']].where(phy['mask_high_cv'])

tiles = hv.HexTiles(
    data=(phy['y_pred'], phy['y_true']),
    kdims=['prediction', 'truth'],
).options(
    title=taxon,
    logz=True,
    tools=['hover'],
    # padding=0.001,
    aspect=1,
    fontscale=1.4,
    colorbar=True,
)
one_to_one = hv.Slope(1, 0).options(color='red', line_width=1.5)
tiles * one_to_one

In [None]:
# Pass the log10-transformed dataset to the project's hexbin helper.
# NOTE(review): np.log10 is applied to every variable in `ds`, including the
# boolean 'mask_high_cv' -- confirm viz.hexbin only reads 'y_true'/'y_pred',
# or select those two variables before taking the logarithm.
viz.hexbin(np.log10(ds))