In [None]:
import os
from pathlib import Path

import jax
import matplotlib.pyplot as plt
import orbax.checkpoint as ocp
import pandas as pd
from jax import numpy as jnp

from blooms_ml.configs import classification
from blooms_ml.learning import create_train_state
from blooms_ml.utils import (
    labeling_binary_incremented,
)

### Get & Prepare data

In [None]:
# Load the weekly-averaged, stacked ROHO800 data and attach the `label` column
# computed by labeling_binary_incremented separately for every (station, s_rho) series.
datadir = Path.home() / "data_ROHO"
df = pd.read_parquet(datadir / "roho800_weekly_average_stacked.parquet")
df = df.groupby(["station", "s_rho"]).apply(labeling_binary_incremented, include_groups=False)
# groupby.apply prepends the group keys to the original row index. Turn only the
# key levels back into columns and drop the remaining row index explicitly, instead
# of relying on pandas' auto-generated "level_2" column name (which changes if the
# raw frame's index is named).
df = df.reset_index(level=["station", "s_rho"]).reset_index(drop=True)

In [None]:
# Inspect the labeled frame (all stations and s_rho levels, before the test-period filter).
df

In [None]:
# Keep only the test period (ocean_time after 2013-01-01) and only rows that carry a label,
# using one combined mask instead of two successive filters.
df = df.loc[(df['ocean_time'] > '2013-01-01') & df['label'].notna()]

In [None]:
# Vertical s_rho levels available in the test data (used to pick the level below).
df.s_rho.unique()

In [None]:
# Station ids available in the test data (choose one for station_number below).
df.station.unique()

In [None]:
# Station to evaluate and plot; must be one of the ids listed by df['station'].unique().
station_number = 0

In [None]:
# Select one station at the s_rho == -0.02 level (presumably the layer nearest the
# surface — TODO confirm) and index it by time.
# A single combined mask avoids chained boolean indexing (df[m1][m2]), which reindexes
# the second mask and warns. .copy() makes df_station an independent frame, so later
# column assignments do not raise SettingWithCopyWarning.
df_station = (
    df.loc[(df['station'] == station_number) & (df['s_rho'] == -0.02)]
    .copy()
    .set_index('ocean_time')
)

In [None]:
# Inspect the selected station/level time series, indexed by ocean_time.
df_station

In [None]:
# Model input frame: drop the grouping keys. Also drop 'prediction' if present: the
# Visualization section adds that column to df_station. Re-running this cell afterwards
# would otherwise feed the model's own output back in as an extra feature.
df_obs = (
    df_station
    .drop(columns=['station', 's_rho'])
    .drop(columns=['prediction'], errors='ignore')
)

In [None]:
# Inspect the model input frame: the 'label' column plus the feature columns.
df_obs

In [None]:
# Package the model inputs: the raw labels, plus every remaining column as a float32
# feature matrix. Rows stay in df_station order, so predictions can be written back onto it.
data = dict(
    label=df_obs['label'].values,
    observations=jnp.float32(df_obs.drop(columns=['label']).values),
)

### Inference

In [None]:
# Restore target: the hard-coded checkpoint directory chkpt_epoch_030 under
# ~/blooms-ml_results/b96ucher (presumably a training-run id).
# A fixed seed makes the template initialisation deterministic.
# jax.random.key(0) would be the new-style typed-key alternative to PRNGKey(0).
orbax_checkpointer = ocp.StandardCheckpointer()
checkpointdir = os.path.join(Path.home(), "blooms-ml_results/b96ucher/chkpt_epoch_030")
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

In [None]:
# Build a freshly initialised train state with the classification config. It serves only
# as a template: its tree is mapped to shape/dtype structs that tell orbax what to restore.
# NOTE(review): the shape passed is the full (rows, features) shape of the observation
# matrix — presumably create_train_state only uses it to trace the model; confirm.
state = create_train_state(init_rng, classification(), data['observations'].shape)
abstract_my_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)

In [None]:
# Restore the checkpointed train state into the abstract template. The result replaces
# the randomly initialised `state`.
state = orbax_checkpointer.restore(checkpointdir, args=ocp.args.StandardRestore(abstract_my_tree))

In [None]:
# Forward pass of the restored model over every observation row.
# `logits` presumably has shape (n_samples, n_classes) — confirm.
logits = state.apply_fn(dict(params=state.params), data['observations'])

### Visualization

In [None]:
# Softmax over the class axis; column 1 is taken as the probability of label == 1
# (assumes a two-class output — confirm). jax.device_get pulls the result back to host
# as a NumPy array before it goes into pandas. .assign returns a new frame, so this never
# writes into a view derived from `df` (no SettingWithCopyWarning).
df_station = df_station.assign(prediction=jax.device_get(jax.nn.softmax(logits)[:, 1]))
# Rows labeled 1, highlighted in the plot.
df_label = df_station[df_station['label'] == 1]

In [None]:
# Overlay the P1_c carbon series (blue line, red dots where label == 1) with the model's
# predicted probability on a secondary y-axis.
SAVE_FIGURE = False  # set True to write the PNG under ~/tmp/blooms-ml_results/

fig, ax1 = plt.subplots(figsize=(20, 5))
ax1.plot(df_station.index, df_station['P1_c'], 'b-', label='P1_c')
ax1.plot(df_label.index, df_label['P1_c'], 'r.', label='label == 1')
ax1.set_xlabel('ocean_time')
ax1.set_ylabel('Carbon', color='b')
ax1.set_title(f'Station {station_number}: carbon and predicted probability')

ax2 = ax1.twinx()
ax2.plot(df_station.index, df_station['prediction'], 'c-', label='prediction')
ax2.set_ylabel('Prediction probability', color='c')

# Figure.legend collects the labeled lines from both axes into one box,
# anchored at the upper-left corner of the plotting area.
fig.legend(loc='upper left', bbox_to_anchor=(0, 1), bbox_transform=ax1.transAxes)

if SAVE_FIGURE:
    savepath = Path.home() / "tmp" / "blooms-ml_results" / f"station_{station_number:04}.png"
    savepath.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    fig.savefig(savepath, dpi=300)
plt.show()