In [None]:
import os
from pathlib import Path

import jax
import matplotlib.pyplot as plt
import numpy as np
import orbax.checkpoint as ocp
import pandas as pd
from jax import numpy as jnp

from blooms_ml.configs import classification_model
from blooms_ml.learning import create_train_state
from blooms_ml.utils import normalize_columns
from blooms_ml.utils_ferrybox import add_previous, get_dataframe_ferrybox2002to2018, get_ferrytracks, get_rivers

### Get & Prepare data

In [None]:
# Input data lives in "blooms-ml_data" under the user's home directory.
datadir = os.path.join(Path.home(), "blooms-ml_data")
# Build the ferrybox frame without normalization; normalize_columns is applied
# further down, after the river data has been merged in.
dfs = get_ferrytracks(datadir)
df = get_dataframe_ferrybox2002to2018(dfs, normalize=False)
df_rivers = get_rivers(datadir)
# direction="forward": each ferrybox row is matched with the first river row
# whose timestamp is greater than or equal to its own.
df_merged = pd.merge_asof(df, df_rivers, on="timestamps", direction="forward")
# Remove rows containing any missing value and renumber the index from 0.
df_merged = df_merged.dropna().reset_index(drop=True)
# NOTE(review): slice(3, None) presumably selects the columns from position 3
# onward for normalization - confirm against normalize_columns.
df_merged = normalize_columns(df_merged, slice(3, None))
# NOTE(review): add_previous presumably appends features of preceding samples -
# confirm against blooms_ml.utils_ferrybox.
df_stacked = add_previous(df_merged)
# Chronological split. The boundary instant belongs to the test set (>=), so a
# row stamped exactly 2015-01-01 00:00 is no longer dropped from both subsets.
split_date = "2015-01-01"
df_train = df_stacked[df_stacked["timestamps"] < split_date]
df_test = df_stacked[df_stacked["timestamps"] >= split_date]

In [30]:
# Keep only the 2016 rows of the test period. The explicit copy makes df_year an
# independent frame, so adding a column to it later in the notebook does not
# trigger pandas' SettingWithCopyWarning.
df_year = df_test[df_test["timestamps"].dt.year == 2016].copy()

In [31]:
# Model inputs: every column except the timestamp, the raw fluorescence signal
# and the target labels.
feature_frame = df_year.drop(columns=["timestamps", "fluorescence", "labels"])
data = dict(
    label=df_year['labels'].values,
    observations=jnp.asarray(feature_frame.values, dtype=jnp.float32),
)

### Inference

In [32]:
# Location of the saved checkpoint (run "rhd0vg41", epoch 10) under the home directory.
checkpointdir = os.path.join(
    Path.home(),
    "blooms-ml_results/rhd0vg41/chkpt_epoch_010",
)
orbax_checkpointer = ocp.StandardCheckpointer()
# Fixed seed; the key is split so that init_rng is used for state creation
# while rng stays available for later use.
seed_key = jax.random.PRNGKey(0)  # jax.random.key(0)
rng, init_rng = jax.random.split(seed_key)

In [33]:
# Build a freshly initialised train state whose structure matches the
# checkpoint; its values are only used as a template for restoring.
input_shape = data['observations'].shape
state = create_train_state(init_rng, classification_model(), input_shape)
# Replace every leaf by its shape/dtype description to obtain the abstract
# target tree for the checkpoint restore.
abstract_my_tree = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct,
    state,
)

In [34]:
# Load the checkpoint into the structure described by abstract_my_tree,
# replacing the freshly initialised state.
restore_args = ocp.args.StandardRestore(abstract_my_tree)
state = orbax_checkpointer.restore(checkpointdir, args=restore_args)

In [35]:
# Forward pass of the restored model over all observations of the selected year.
model_variables = {"params": state.params}
logits = state.apply_fn(model_variables, data['observations'])

In [36]:
# Turn the logits into class probabilities and keep the column at index 1.
class_probabilities = jax.nn.softmax(logits)
prediction = class_probabilities[:, 1]

### Visualization

In [None]:
# Make the cell re-runnable: if a previous run already moved "timestamps" into
# the index, turn it back into a column before rebuilding the frame.
if df_year.index.name == "timestamps":
    df_year = df_year.reset_index()
# Threshold the probability at 0.5 to get a 0/1 prediction. assign/set_index
# return new frames, so no slice is mutated in place (this avoids
# SettingWithCopyWarning and the KeyError that inplace set_index caused on re-run).
df_year = df_year.assign(
    prediction=np.where(prediction > 0.5, 1, 0),
).set_index("timestamps")
# Rows with a positive label and rows with a positive prediction, used by the
# plotting cell as marker positions.
df_label = df_year[df_year["labels"] == 1]
df_prediction = df_year[df_year["prediction"] == 1]

In [None]:
# Left axis: fluorescence over the year as a blue line, with the positively
# labelled samples overlaid as red upward triangles.
_, ax1 = plt.subplots(figsize=(20, 5))
ax1.plot(df_year.index, df_year['fluorescence'], color='b', linestyle='-')
ax1.plot(df_label.index, df_label['fluorescence'], color='r', marker='^', linestyle='None')
ax1.set_ylabel('Fluorescence', color='b')

# First right-hand axis: cyan downward triangles at the positively predicted samples.
ax2 = ax1.twinx()
ax2.plot(df_prediction.index, df_prediction['prediction'], color='c', marker='v', linestyle='None')
ax2.set_ylabel('Prediction', color='c')

# Second right-hand axis, shifted 60 points outward so that it does not
# overlap the prediction axis: the 'Solbergfoss' column as a green line.
ax3 = ax1.twinx()
right_spine = ax3.spines['right']
right_spine.set_position(('outward', 60))
ax3.plot(df_year.index, df_year['Solbergfoss'], color='g', linestyle='-')
ax3.set_ylabel('Glomma discharge', color='g')
