In [None]:
import os
from pathlib import Path

import jax
import matplotlib.pyplot as plt
import orbax.checkpoint as ocp
import pandas as pd
import seaborn as sns
from jax import numpy as jnp

from blooms_ml.configs import regression
from blooms_ml.learning import apply_regression_model, create_train_state
from blooms_ml.utils import get_stats, labeling

# Apply seaborn's "whitegrid" style to the figures drawn in this notebook.
sns.set_style(style="whitegrid")

### Get & Prepare data

In [None]:
# Root folder holding the ROHO input files, located under the user's home directory.
datadir = os.path.join(Path.home(), "data_ROHO")

# get_stats yields eight values, unpacked as four *_mean names followed by four *_std names.
# NOTE(review): none of these names is referenced again in the visible cells — confirm
# whether the observations are expected to be normalised with them before inference.
(
    p1_c_mean,
    n1_p_mean,
    n3_n_mean,
    n5_s_mean,
    p1_c_std,
    n1_p_std,
    n3_n_std,
    n5_s_std,
) = get_stats(os.path.join(datadir, "cnps_mean_std.csv"))

In [None]:
# Load the stacked weekly averages and keep only the rows whose ocean_time is before 2013-01-01.
df = pd.read_parquet(
    os.path.join(datadir, "roho800_weekly_average_stacked.parquet")
).loc[lambda frame: frame['ocean_time'] < '2013-01-01']

In [None]:
# Display the loaded frame (already restricted to ocean_time before 2013-01-01).
df

In [None]:
# Run `labeling` on every (station, s_rho) group, flatten the group keys back into
# columns, and expose the resulting 'label' column under the name 'y'.
# 'level_2' is the column reset_index creates for the unnamed third index level.
df = (
    df.groupby(['station', 's_rho'])
    .apply(labeling, include_groups=False)
    .reset_index()
    .drop(columns='level_2')
    .rename(columns={'label': 'y'})
    # Rows without a target value are discarded.
    .loc[lambda frame: frame['y'].notna()]
)

In [None]:
# Distinct s_rho values in the labeled frame; one of them is selected when df_station is built.
df['s_rho'].unique()

In [None]:
# Distinct station identifiers in the labeled frame; station_number is matched against these.
df['station'].unique()

In [None]:
# Station to inspect: used to filter df, and referenced by the commented-out figure save path.
station_number = 0

In [None]:
# Select the rows of one station at one s_rho level and index them by time.
#
# Both conditions are combined into a single mask evaluated against `df`. The previous
# chained form `df[a][b]` applied a mask built on `df` to the already-filtered frame,
# which makes pandas warn that the boolean key is reindexed, and it left `df_station`
# as a slice of `df` that was then modified in place.
# NOTE(review): s_rho is compared for exact equality with the float literal -0.02 —
# confirm the stored values match it exactly (dtype / rounding).
station_mask = (df['station'] == station_number) & (df['s_rho'] == -0.02)

# set_index returns a new frame, so adding columns to df_station later does not touch df.
df_station = df.loc[station_mask].set_index('ocean_time')

In [None]:
# Drop the two selection keys (already fixed by the station / s_rho filter) so that
# only 'y' and the remaining columns are left.
df_obs = df_station.drop(['station', 's_rho'], axis='columns')

In [None]:
# Model inputs: 'y' holds the target column as-is, 'observations' holds every other
# column of df_obs converted to a float32 JAX array.
data = dict(
    y=df_obs['y'].values,
    observations=jnp.asarray(df_obs.drop(columns=['y']).values, dtype=jnp.float32),
)

### Inference

In [None]:
# PRNG keys derived from seed 0; init_rng is the key handed to create_train_state.
rng, init_rng = jax.random.split(jax.random.key(0))

# Checkpointer and the directory (under the user's home) the train state is restored from.
orbax_checkpointer = ocp.StandardCheckpointer()
checkpointdir = os.path.join(Path.home(), "blooms-ml_results/2tvek1q3/checkpoint")

In [None]:
# Create a train state from the regression config and the observations' shape.
state = create_train_state(
    init_rng,
    regression(),
    data['observations'].shape,
)

# Map every leaf of that state through to_shape_dtype_struct; the resulting tree is
# what the restore call is given as its target.
abstract_my_tree = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct,
    state,
)

In [None]:
# Replace the freshly created state with the one stored in the checkpoint directory,
# using the abstract tree as the restore target.
restore_args = ocp.args.StandardRestore(abstract_my_tree)
state = orbax_checkpointer.restore(checkpointdir, args=restore_args)

In [None]:
# Run the restored state's apply function on all observations of the selected station.
variables = {"params": state.params}
predictions = state.apply_fn(variables, data['observations'])

### Visualization

In [None]:
# Store the model output next to the targets so both share the ocean_time index.
# NOTE(review): only column 1 of `predictions` is used — confirm this is the intended
# output column for the regression model.
df_station['prediction'] = predictions[:, 1]

In [None]:
# Evaluate the restored state on this station's observations and targets.
# Only `loss` is used by the following cell; `grads` is not referenced in the visible cells.
grads, loss = apply_regression_model(state, data['observations'], data['y'])

In [None]:
# Display the loss returned by apply_regression_model for the selected station.
loss

In [None]:
# Plot the target series (clipped to [-1, 1]) against the model prediction for the
# selected station. Each series is labelled and the axes are titled so the figure
# can be read on its own.
_, ax1 = plt.subplots(figsize=(20, 5))
ax1.plot(
    df_station.index,
    df_station['y'].clip(lower=-1, upper=1),
    'b-',
    label='y (clipped to [-1, 1])',
)
ax1.plot(df_station.index, df_station['prediction'], 'r.', label='prediction')
ax1.set_xlabel('ocean_time')
ax1.set_ylabel('Carbon', color='b')
ax1.set_title(f'Station {station_number}')
# The trailing semicolon keeps the Legend repr out of the cell output.
ax1.legend();

# savepath = os.path.join(Path.home(), f"tmp/blooms-ml_results/station_{station_number:04}.png")
# plt.savefig(savepath, dpi=300)