In [None]:
import os
from pathlib import Path

import jax
import orbax.checkpoint as ocp
from jax import numpy as jnp
import seaborn as sns

from blooms_ml.configs import default
from blooms_ml.learning import create_train_state
from blooms_ml.utils import (
    get_dataframe,
    get_stats,
)

In [None]:
# Use seaborn's "whitegrid" axes style (applied globally via matplotlib rcParams) for every plot below.
sns.set_style(style="whitegrid")

### Get & Prepare data

In [None]:
# Root of the ROHO dataset, located in the user's home directory.
datadir = os.path.join(Path.home(), "data_ROHO")
stats_csv = os.path.join(datadir, "cnps_mean_std.csv")
# get_stats yields eight values, unpacked as the four means followed by the four
# standard deviations for P1_c, N1_p, N3_n, N5_s (order must match get_stats — TODO confirm).
(
    p1_c_mean, n1_p_mean, n3_n_mean, n5_s_mean,
    p1_c_std, n1_p_std, n3_n_std, n5_s_std,
) = get_stats(stats_csv)

In [None]:
df = get_dataframe(datadir)
# Keep only the test period (ocean_time after 2013-01-01) and rows whose 'y' is present.
# Both masks are built on the same frame, so they align; this selects the same rows
# (same order, same index) as filtering twice in sequence.
in_test_period = df['ocean_time'] > '2013-01-01'
has_y = df['y'].notna()
df = df[in_test_period & has_y]

In [None]:
# Select a single time series: station 0 at the s_rho == -0.02 level.
# NOTE(review): exact float equality on s_rho assumes the stored value is exactly -0.02;
# switch to a tolerance if this selection ever comes back empty.
station_mask = (df['station'] == 0) & (df['s_rho'] == -0.02)
# One combined mask evaluated on `df` itself. Chaining `df[a][b]` applies a mask built on
# the unfiltered frame to the filtered one, and pandas reindexes it with a warning.
# A non-inplace set_index returns a new frame, so df_station is no longer a flagged
# slice of `df`, and later column assignments do not raise SettingWithCopyWarning.
df_station = df.loc[station_mask].set_index('ocean_time')

In [None]:
# Inspect the selected single-station, single-level time series (indexed by ocean_time).
df_station

In [None]:
# Model-input frame: drop the station/level columns plus P1_c, rho and y.
# Every remaining column other than 'label' is later used as a feature.
df_obs = df_station.drop(columns=['station', 's_rho', 'P1_c', 'rho', 'y'])
# "normalize": z-score each nutrient with the mean/std read from cnps_mean_std.csv,
# round to 2 decimals and cast to float32 (presumably mirroring the training-time
# preprocessing — TODO confirm).
_nutrient_stats = {
    'N1_p': (n1_p_mean, n1_p_std),
    'N3_n': (n3_n_mean, n3_n_std),
    'N5_s': (n5_s_mean, n5_s_std),
}
for col_name, (mu, sigma) in _nutrient_stats.items():
    df_obs[col_name] = ((df_obs[col_name] - float(mu)) / float(sigma)).round(2).astype('float32')

In [None]:
# Inspect the model-input frame after normalization (feature columns plus 'label').
df_obs

In [None]:
# Package the test set for the model.
# 'label' is kept as the raw column values.
# 'observations' is every other df_obs column, in column order, as a float32 JAX array.
feature_matrix = df_obs.drop(columns=['label']).values
data = dict(
    label=df_obs['label'].values,
    observations=jnp.asarray(feature_matrix, dtype=jnp.float32),
)

### Inference

In [None]:
# Restore setup: model config, trained-checkpoint location and an Orbax checkpointer.
config = default.get_config()
checkpointdir = os.path.join(Path.home(), "blooms-ml_results/q2nm_vgs/checkpoint")
orbax_checkpointer = ocp.StandardCheckpointer()
# Fixed seed 0. init_rng seeds create_train_state, whose result is only used as the restore template.
rng, init_rng = jax.random.split(jax.random.key(0))

In [None]:
# Build a freshly initialised train state sized for the observations. It is used only
# as a structural template: its pytree is mapped to shape/dtype structs for the
# Orbax restore (assumed to match the saved checkpoint's structure — TODO confirm).
obs_shape = data['observations'].shape
state = create_train_state(init_rng, config, obs_shape)
abstract_my_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)

In [None]:
# Replace the template state with the trained one stored in checkpointdir.
restore_args = ocp.args.StandardRestore(abstract_my_tree)
state = orbax_checkpointer.restore(checkpointdir, args=restore_args)

In [None]:
# Forward pass of the restored model over all test observations
# (presumably a Flax-style TrainState, hence the {"params": ...} variables dict).
model_variables = {"params": state.params}
logits = state.apply_fn(model_variables, data['observations'])

### Visualization

In [None]:
# Probability assigned to class index 1 (softmax over the last axis of logits).
# Rows line up with df_station: the observations were built from df_obs, which is
# df_station with some columns dropped and no rows reordered.
# jax.device_get turns the device array into a host NumPy array before it reaches pandas.
# assign() returns a new frame, which avoids SettingWithCopyWarning when df_station
# came from chained indexing; re-running the cell simply overwrites the column.
class1_prob = jax.device_get(jax.nn.softmax(logits)[:, 1])
df_station = df_station.assign(prediction=class1_prob)

In [None]:
# Model output over the test period: softmax probability of class index 1.
ax = df_station['prediction'].plot(figsize=(14, 5))
ax.set(title="Predicted probability of class 1", ylabel="prediction");

In [None]:
# Raw (un-normalized) P1_c for the selected station/level, for comparison with the prediction.
ax = df_station['P1_c'].plot(figsize=(14, 5))
ax.set(title="P1_c over time", ylabel="P1_c");

In [None]:
# Reference labels for the selected station/level, for comparison with the prediction.
ax = df_station['label'].plot(figsize=(14, 5))
ax.set(title="label over time", ylabel="label");