In [None]:
import sys, os
from pyprojroot import here


# spider up the directory tree to find the project roots
# (the directories holding the ".root" and ".local" marker files)
root = here(project_files=[".root"])
local = here(project_files=[".local"])

# make both locations importable
for search_dir in (root, local):
    sys.path.append(str(search_dir))

In [None]:
from pathlib import Path
import argparse
import wandb
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# import hvplot.xarray

import tensorflow as tf
import tensorflow_datasets as tfd
import jax
import jax.random as jrandom
import jax.numpy as jnp
import equinox as eqx
from ml4ssh._src.io import load_object, save_object
from ml4ssh._src.viz import create_movie, plot_psd_spectrum, plot_psd_score
from ml4ssh._src.utils import get_meshgrid, calculate_gradient, calculate_laplacian

# import parsers
from data import get_data_args, load_data
from preprocess import add_preprocess_args, preprocess_data
from features import add_feature_args, feature_transform
from split import add_split_args, split_data
from models.gp_tf import add_model_args, get_likelihood, get_kernel, get_inducing_points
from loss import add_loss_args, get_loss_fn
from logger import add_logger_args
from optimizer import add_optimizer_args, get_optimizer
from postprocess import add_postprocess_args, postprocess_data, generate_eval_data
from evaluation import add_eval_args, get_rmse_metrics, get_psd_metrics
from smoke_test import add_winter_smoke_test_args, add_january_smoke_test_args

%matplotlib inline
%load_ext autoreload
%autoreload 2

### Arguments

In [None]:
# Build the experiment configuration: every pipeline stage registers its own
# options on one shared parser.
parser = argparse.ArgumentParser()

# logger
parser = add_logger_args(parser)

# data
parser = get_data_args(parser)

# preprocessing, feature transform, split
parser = add_preprocess_args(parser)
parser = add_feature_args(parser)
parser = add_split_args(parser)

# model, optimizer, loss
parser = add_model_args(parser)
parser = add_optimizer_args(parser)
parser = add_loss_args(parser)

# postprocessing, metrics
parser = add_postprocess_args(parser)
parser = add_eval_args(parser)

# parse an empty argv so that only the registered defaults are used
args = parser.parse_args([])

# # jeanzay specific
# args.train_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/train/"
# args.ref_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/ref/"
# args.test_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/test/"
# args.log_dir = "/gpfswork/rech/cli/uvo53rl/logs"

# notebook-specific overrides of the parsed defaults
args.feature_scaler = "standard"
args.model = "svgp"
args.smoke_test = False
args.wandb_mode = "disabled"
args.wandb_resume = True
args.id = None  # "2uuq7tks"
args.batch_size = 2048
args.n_epochs = 10

# apply the January smoke-test settings on top of the overrides above
# NOTE(review): presumably restricts the data to a short time window - confirm in smoke_test.py
args = add_january_smoke_test_args(args)
# args = add_winter_smoke_test_args(args)
# args = add_winter_smoke_test_args(args)

### Logger


In [None]:
# init wandb logger; every setting (run id, mode, project, entity, log dir,
# resume flag) is taken from `args`, which is also stored as the run config
wandb.init(
    id=args.id,
    config=args,
    mode=args.wandb_mode,
    project=args.project,
    entity=args.entity,
    dir=args.log_dir,
    resume=args.wandb_resume,
)

### Load Data

In [None]:
%%time

# load data
data = load_data(args)

# preprocess data
data = preprocess_data(data, args)

# feature transformation; the returned scaler is kept because it is reused on
# the evaluation grid and saved alongside the model later in the notebook
data, scaler = feature_transform(data, args)

In [None]:
# peek at the first rows of the model input columns
data[data.attrs["input_cols"]].head()

In [None]:
# summary statistics of the model input columns after the feature transform
data[data.attrs["input_cols"]].describe()

In [None]:
%%time


# split data
xtrain, ytrain, xvalid, yvalid = split_data(data, args)

# record the sizes of the full split
args.in_dim = xtrain.shape[-1]
args.n_train = xtrain.shape[0]
args.n_valid = xvalid.shape[0]

print(args.n_train)
# if args.smoke_test:

# Train on a random subset of at most 2_000 points.
# Sample WITHOUT replacement: the default of `choice` draws with replacement,
# which silently duplicates training rows. The size is capped so that a split
# smaller than 2_000 points is used as-is instead of being over-sampled.
# NOTE: `args.n_train` keeps the size of the full split, not of this subset.
n_subset = min(2_000, args.n_train)
rng = np.random.RandomState(args.split_seed)
idx = rng.choice(np.arange(args.n_train), size=n_subset, replace=False)
xtrain = xtrain[idx]
ytrain = ytrain[idx]

# rng = np.random.RandomState(args.split_seed+10)
# idx = rng.choice(np.arange(args.n_valid), size=1_000)
# xvalid = xvalid[idx]
# yvalid = yvalid[idx]


wandb.config.update(
    {
        "in_dim": args.in_dim,
        "n_train": args.n_train,
        "n_valid": args.n_valid,
    }
)

In [None]:
# shape of the (subsampled) training inputs
xtrain.shape

## Model - Stochastic Variational GP

In [None]:
import gpflow
import gpflux
import numpy as np
from gpflow.utilities import print_summary

# Hide every GPU from TensorFlow so that it neither runs on the GPU nor
# reserves the GPU memory for itself.
import tensorflow as tf

tf.config.set_visible_devices([], device_type="GPU")

In [None]:
from scipy.cluster.vq import kmeans2

In [None]:
# (inputs, targets) pair cast to float64
data_train = (xtrain.astype(np.float64), ytrain.astype(np.float64))

In [None]:
# Build the GP components from the experiment configuration.
# NOTE(review): `kernel`, `likelihood` and `Z` are re-assigned by the
# hand-written cells further down, so the objects created here are not the
# ones the models below end up using.
# get kernel
kernel = get_kernel(args)
# get likelihood
likelihood = get_likelihood(args)
# get inducing points
Z = get_inducing_points(xtrain, args)

In [None]:
def make_svgp_model(
    n_inducing: int = 100,
    lengthscales=(1.0, 1.0, 7.0),
    noise_variance: float = 0.01,
):
    """Create a fresh SVGP model on the notebook's training data.

    The kernel hyper-parameters used to be hard-coded in the body; they are
    now keyword arguments whose defaults reproduce the previous behaviour.

    Args:
        n_inducing: number of inducing points (k-means centroids of `xtrain`).
        lengthscales: initial lengthscales of the squared-exponential kernel;
            the default holds three values (presumably one per input
            feature - confirm against `xtrain.shape[1]`).
        noise_variance: initial variance of the Gaussian likelihood.

    Returns:
        A `gpflow.models.SVGP` whose `num_data` is the number of rows of the
        module-level `xtrain`.
    """
    # kernel function
    kernel = gpflow.kernels.SquaredExponential(
        lengthscales=list(lengthscales),
    )

    # likelihood
    likelihood = gpflow.likelihoods.Gaussian(variance=noise_variance)

    # inducing points: k-means centroids, initialised from random data points
    Z = kmeans2(xtrain, n_inducing, minit="points")[0]

    # num_data lets the minibatch ELBO be rescaled to the full data set size
    num_data = xtrain.shape[0]

    # create gp model
    model = gpflow.models.SVGP(
        kernel, likelihood, Z.astype(np.float64), num_data=num_data
    )

    return model

In [None]:
# Step-by-step version of what `make_svgp_model` does, kept for inspection.
# kernel function: squared exponential with one lengthscale per listed value
lengthscales = [1.0, 1.0, 7.0]
kernel = gpflow.kernels.SquaredExponential(
    lengthscales=lengthscales,
)

# likelihood: Gaussian with initial noise variance `noise`
noise = 0.01
likelihood = gpflow.likelihoods.Gaussian(variance=noise)

In [None]:
%%time

# inducing points: centroids of a k-means clustering of the training inputs,
# initialised from randomly chosen data points (minit="points")
n_inducing = 100
Z = kmeans2(xtrain, n_inducing, minit="points")[0]

In [None]:
# total number of training points, passed to the model as `num_data`
num_data = xtrain.shape[0]

# create gp model (inducing inputs are cast to float64)
model = gpflow.models.SVGP(kernel, likelihood, Z.astype(np.float64), num_data=num_data)

In [None]:
# don't train the inducing inputs: they stay fixed at the k-means centroids
gpflow.set_trainable(model.inducing_variable, False)

In [None]:
# wrap the model's ELBO in a tf.function so repeated calls run as a graph
elbo = tf.function(model.elbo)

In [None]:
# tensor_data = tuple(map(tf.convert_to_tensor, data_train))
# elbo(tensor_data)  # run it once to trace & compile

In [None]:
# %%timeit
# elbo(tensor_data)

In [None]:
# Data-loader settings.
# `prefetch_buffer` is read as a global by `make_ds`; `batch_size` is not -
# `make_ds` takes its own `batch_size` parameter (default 100).
prefetch_buffer = 5
batch_size = 100

In [None]:
def make_ds(batch_size: int = 100, shuffle: bool = True):
    """Build an endless iterator of float64 minibatches over the training set.

    Args:
        batch_size: number of samples per minibatch.
        shuffle: if True, shuffle samples with a buffer of `10 * batch_size`.

    Returns:
        An iterator yielding `(x_batch, y_batch)` tensor pairs forever.
    """
    ds = tf.data.Dataset.from_tensor_slices(
        (xtrain.astype(np.float64), ytrain.astype(np.float64))
    )

    ds = ds.repeat()
    if shuffle:
        ds = ds.shuffle(buffer_size=10 * batch_size)
    ds = ds.batch(batch_size)
    # Prefetch LAST: placed at the start of the pipeline it only buffered
    # single unbatched samples; here it prepares `prefetch_buffer` complete
    # batches ahead of the consumer.
    ds = ds.prefetch(prefetch_buffer)

    return iter(ds)


ds_train = make_ds()

In [None]:
# evaluate the compiled ELBO on one minibatch (first call also traces the graph)
elbo(next(ds_train))

In [None]:
%%timeit
# time one ELBO evaluation including fetching the minibatch
elbo(next(ds_train))

In [None]:
from tqdm.notebook import trange

In [None]:
# training curves, keyed by optimisation scheme
losses = {}

In [None]:
# make dataset
ds_train = make_ds()

# make gp model
n_inducing = 100
model = make_svgp_model(n_inducing)
# don't train the inducing inputs
gpflow.set_trainable(model.inducing_variable, False)


# Adam (default learning rate) over all remaining trainable parameters
n_iterations = 10_000
losses["standard"] = []
optimizer = tf.optimizers.Adam()

# compiled minibatch training loss; each call consumes a batch from `ds_train`
training_loss = model.training_loss_closure(ds_train, compile=True)


@tf.function
def optimization_step():
    """Run one Adam update of the model's trainable variables."""
    optimizer.minimize(training_loss, model.trainable_variables)


with trange(n_iterations) as pbar:
    for step in pbar:
        optimization_step()
        # Store the value under its own name: binding it to `elbo` would
        # overwrite the compiled `elbo` function defined earlier in the
        # notebook and break re-running the cells that call it.
        elbo_value = -training_loss().numpy()
        losses["standard"].append(elbo_value)

        if step % 10 == 0:
            pbar.set_description(f"Loss (ELBO): {elbo_value:.4e}")

In [None]:
# training curve of the Adam-only run
fig, ax = plt.subplots()
ax.plot(losses["standard"], label="Training Loss (Standard)")
ax.set_xlabel("Iterations")
ax.set_ylabel("ELBO")
ax.legend()
plt.show()

In [None]:
# parameter table of the Adam-trained model
print_summary(model)

## Natural Gradients

In [None]:
from gpflow import set_trainable
from gpflow.optimizers import NaturalGradient

# make dataset
ds_train = make_ds()

# make gp model
n_inducing = 100
model = make_svgp_model(n_inducing)
# don't train the inducing inputs
gpflow.set_trainable(model.inducing_variable, False)


# training budget and Adam learning rate
n_iterations = 10_000
learning_rate = 1e-3

# NOTE(review): `ordinary_adam_opt` is not used by any visible cell
ordinary_adam_opt = tf.optimizers.Adam(learning_rate)


# NatGrads and Adam for SVGP
# Stop Adam from optimizing the variational parameters
set_trainable(model.q_mu, False)
set_trainable(model.q_sqrt, False)

# Adam handles the remaining trainable parameters
natgrad_adam_opt = tf.optimizers.Adam(learning_rate)

# the natural-gradient optimizer handles the variational parameters
natgrad_opt = NaturalGradient(gamma=0.1)
variational_params = [(model.q_mu, model.q_sqrt)]


# make training loss
training_loss = model.training_loss_closure(ds_train, compile=True)

In [None]:
losses["natgrad"] = []


@tf.function
def optimization_step():
    """Adam update of the trainable (non-variational) parameters."""
    natgrad_adam_opt.minimize(training_loss, var_list=model.trainable_variables)


@tf.function
def natgrad_optimization_step():
    """Natural-gradient update of the variational parameters (q_mu, q_sqrt)."""
    natgrad_opt.minimize(training_loss, var_list=variational_params)


with trange(n_iterations) as pbar:
    for step in pbar:
        optimization_step()
        natgrad_optimization_step()
        # Store the value under its own name: binding it to `elbo` would
        # overwrite the compiled `elbo` function defined earlier in the
        # notebook and break re-running the cells that call it.
        elbo_value = -training_loss().numpy()
        losses["natgrad"].append(elbo_value)

        if step % 10 == 0:
            pbar.set_description(f"Loss (ELBO): {elbo_value:.4e}")

In [None]:
# compare both optimisation schemes on a logarithmic iteration axis
fig, ax = plt.subplots()
for run_key, run_label in (
    ("standard", "Training Loss (Standard)"),
    ("natgrad", "Training Loss (NatGrad)"),
):
    ax.plot(losses[run_key], label=run_label)
ax.set_xlabel("Iterations")
ax.set_ylabel("ELBO")
ax.set_xscale("log")
ax.legend()
plt.show()

In [None]:
# parameter table of the model trained with natural gradients + Adam
print_summary(model)

In [None]:
import tensorflow_datasets as tfd


def predict_grid(gp_model, n_batches: int = 5_000):
    """Predict the GP latent mean and variance on the evaluation grid.

    Args:
        gp_model: model exposing `predict_f(x)` returning `(mean, variance)`.
        n_batches: number of grid points per prediction batch (despite its
            name it is used as the batch size, not as a number of batches).

    Returns:
        The evaluation grid dataframe with added "pred" and "variance"
        columns.
    """
    # generate grid
    df_grid = generate_eval_data(args)

    # Input columns, cast to float64 exactly like the training inputs.
    # NOTE(review): the grid is fed to the model as generated, whereas the
    # training inputs went through `feature_transform`; confirm whether the
    # grid has to be scaled with the same scaler first.
    x_grid = df_grid[df_grid.attrs["input_cols"]].values.astype(np.float64)

    # create dataloader
    ds_test = tf.data.Dataset.from_tensor_slices(x_grid).batch(n_batches)

    means, variances = [], []
    # the progress bar has one tick per batch; zip ends with the last batch
    for _, ix in zip(trange(len(ds_test)), ds_test):
        # predict using GP
        imean, ivar = gp_model.predict_f(ix)

        means.append(np.asarray(imean))
        variances.append(np.asarray(ivar))

    # stack the batches and drop the trailing singleton output axis so that
    # plain 1-D columns are stored (raises if there is more than one output)
    df_grid["pred"] = np.concatenate(means, axis=0).squeeze(axis=-1)
    df_grid["variance"] = np.concatenate(variances, axis=0).squeeze(axis=-1)

    return df_grid

In [None]:
# make predictions on the evaluation grid with the most recently trained model
df_grid = predict_grid(model)

# post-process the gridded predictions into the `ds_oi` dataset used below
ds_oi = postprocess_data(df_grid, args)

### Metrics

In [None]:
%%time

# RMSE-based scores of the gridded prediction
rmse_metrics = get_rmse_metrics(ds_oi, args)
print(rmse_metrics)

In [None]:
# spectral (PSD) scores of the gridded prediction
psd_metrics = get_psd_metrics(ds_oi, args)
print(psd_metrics)

In [None]:
# PSD spectrum of the prediction against the reference
fig, ax = plot_psd_spectrum(
    psd_metrics.psd_study, psd_metrics.psd_ref, psd_metrics.wavenumber
)
# PSD score plot (note: `fig` and `ax` are overwritten by this second figure)
fig, ax = plot_psd_score(
    psd_metrics.psd_diff,
    psd_metrics.psd_ref,
    psd_metrics.wavenumber,
    psd_metrics.resolved_scale,
)

### Viz

In [None]:
# ds_oi.ssh.hvplot.image(
#     x="longitude",
#     y="latitude",
#     groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="viridis")

In [None]:
# ds_oi.variance.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="viridis")

In [None]:
# gradient of the predicted SSH field along longitude / latitude
ds_oi["ssh_grad"] = calculate_gradient(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# ds_oi.ssh_grad.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="Spectral_r")

In [None]:
# Laplacian of the predicted SSH field along longitude / latitude
ds_oi["ssh_lap"] = calculate_laplacian(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# ds_oi.ssh_lap.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="RdBu_r")

In [None]:
from gpflux.helpers import construct_basic_inducing_variables, construct_basic_kernel

In [None]:
# Single-layer GPflux model built with the library's "basic" helpers.
num_data = xtrain.shape[0]
input_dim = xtrain.shape[1]
num_inducing = 100
output_dim = ytrain.shape[1]


# kernel: an RBF kernel wrapped by the helper for `output_dim` outputs
kernel = construct_basic_kernel(gpflow.kernels.RBF(), output_dim)

# inducing points layer
inducing_vars = construct_basic_inducing_variables(num_inducing, input_dim, output_dim)

# gp layer
gp_layer = gpflux.layers.GPLayer(
    kernel=kernel, inducing_variable=inducing_vars, num_data=num_data
)

# likelihood layer: Gaussian likelihood with initial variance 0.1
likelihood_layer = gpflux.layers.LikelihoodLayer(gpflow.likelihoods.Gaussian(0.1))


# create DEEPGP
gp_model = gpflux.models.DeepGP([gp_layer], likelihood_layer, input_dim=input_dim)

In [None]:
# library versions used for this experiment
gpflux.__version__, tf.__version__, gpflow.__version__

In [None]:
# Manual training of the GPflux model with Adam on the full training set.
optimizer = tf.optimizers.Adam()


@tf.function(autograph=False)
def objective_closure():
    """Negative ELBO of `gp_model` on the whole (xtrain, ytrain) set."""
    return -gp_model.elbo((xtrain, ytrain))


@tf.function
def step():
    """One Adam update of the GPflux model's trainable variables."""
    optimizer.minimize(objective_closure, gp_model.trainable_variables)


# runs `args.n_epochs` full-batch optimisation steps
tq = tqdm.tqdm(range(args.n_epochs))
for i in tq:
    step()

In [None]:
# NOTE(review): a later cell calls `as_training_model()` on the DeepGP before
# `compile`; confirm whether `gp_model` needs the same conversion here.
gp_model.compile(tf.optimizers.Adam(0.01))

In [None]:
# fit for 10 epochs, feeding inputs and targets by name
history = gp_model.fit({"inputs": xtrain, "targets": ytrain}, epochs=int(10), verbose=1)

In [None]:
model = single_layer_dgp.as_training_model()
model.compile(tf.optimizers.Adam(0.01))

In [None]:
num_data = xtrain.shape[0]
num_inducing = 100
output_dim = ytrain.shape[1]

kernel = gpflow.kernels.SquaredExponential()
z = np.linspace(xtrain.min(), xtrain.max(), num_inducing).reshape(-1, 1)
inducing_variable = gpflow.inducing_variables.InducingPoints(z)

# init gp layer
gp_layer = gpflux.layers.GPLayer(
    kernel, inducing_variable, num_data=num_data, num_latent_gps=output_dim
)

# likelihood layer
likelihood_layer = gpflux.layers.LikelihoodLayer(gpflow.likelihoods.Gaussian(0.1))


# create dgp
single_layer_dgp = gpflux.models.DeepGP([gp_layer], likelihood_layer)

In [None]:
# fit the training model for 10 epochs, feeding inputs and targets by name
history = model.fit({"inputs": xtrain, "targets": ytrain}, epochs=int(10), verbose=1)

In [None]:
# training loss per epoch recorded by `fit`
plt.plot(history.history["loss"])

In [None]:
%%time


# model
# NOTE(review): `get_model` is not imported in any visible cell - confirm
# where it comes from. This also rebinds `model`, replacing the GP model.
model = get_model(args)

# optimizer
optimizer = get_optimizer(args)

# loss: a training-step function and a validation-step function
make_step, val_step = get_loss_fn(args)

# initialise the optimizer state from the model
opt_state = optimizer.init(model)

In [None]:
# Total number of optimisation steps: a fixed 500 for smoke tests, otherwise
# (steps per epoch) x (number of epochs), truncated to an integer.
n_steps_per_epoch = args.n_train / args.batch_size
if args.smoke_test:
    steps = 500
else:
    steps = int(n_steps_per_epoch * args.n_epochs)


# record the schedule in the run configuration
wandb.config.update(dict(steps=steps, n_steps_per_epoch=n_steps_per_epoch))

### Training

In [None]:
# Minibatch iterators for training (shuffled) and validation (ordered).
# NOTE(review): `make_mini_batcher` is not imported in any visible cell.
train_ds = make_mini_batcher(xtrain, ytrain, args.batch_size, 5, shuffle=True)
valid_ds = make_mini_batcher(xvalid, yvalid, args.batch_size, 5, shuffle=False)


# note: this replaces the `losses` dict that held the GP training curves
losses = {}
losses["train"] = []
losses["valid"] = []


with tqdm.trange(steps) as pbar:
    for step in pbar:

        # training step: loss and gradients on the next training batch
        ix, iy = next(train_ds)
        loss, grads = make_step(model, jnp.asarray(ix), jnp.asarray(iy))

        # turn the gradients into parameter updates and apply them
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)

        losses["train"].append(loss)
        wandb.log({"train_loss": loss}, step=step)
        ix, iy = next(valid_ds)
        # validation step on one validation batch per training step
        vloss = val_step(model, jnp.asarray(ix), jnp.asarray(iy))
        losses["valid"].append(vloss)

        wandb.log({"val_loss": vloss}, step=step)

        # refresh the progress-bar text every 10 steps
        if step % 10 == 0:
            pbar.set_description(
                f"Step: {step:_} | Train Loss: {loss:.3e} | Valid Loss: {vloss:.3e}"
            )

### Save models

In [None]:
# destination files inside the wandb run directory
run_dir = Path(wandb.run.dir)
path_model = run_dir / "model.pickle"
path_scaler = run_dir / "scaler.pickle"

artifacts = ((model, path_model), (scaler, path_scaler))

# write both objects to disk first ...
for obj, target in artifacts:
    save_object(obj, target)

# ... then hand the files to wandb for immediate upload
for _, target in artifacts:
    wandb.save(str(target), policy="now")

### Load Models (Optional)

In [None]:
# # if args.server == "jz":
# # get id (from this run or a run you can set)
# run_id = wandb.run.id

# # initialize api
# api = wandb.Api()

# # get run
# run = api.run(f"{args.entity}/{args.project}/{run_id}")

# # download the files
# files = ["scaler.pickle", "model.pickle"]

# for ifile in files:

#     run.file(ifile).download(replace=True)

In [None]:
# model = load_object("./model.pickle")
# scaler = load_object("./scaler.pickle")

### PostProcessing

In [None]:
# evaluation grid, plus a copy transformed with the scaler fitted on the
# training data (the copy keeps `df_grid` itself unchanged)
df_grid = generate_eval_data(args)
df_pred = feature_transform(df_grid.copy(), args, scaler=scaler)

df_grid.describe()

In [None]:
# recompute the transformed grid (same call as in the previous cell) and show
# the transformed and raw summaries together
df_pred = feature_transform(df_grid.copy(), args, scaler=scaler)
df_pred.describe(), df_grid.describe()

In [None]:
# record the number of evaluation points in the run configuration
wandb.config.update(
    {
        "n_test": df_pred.shape[0],
    }
)

### Predictions


In [None]:
@jax.jit
def pred_step(model, data):
    """Evaluate `model` on every entry along the leading axis of `data`."""
    batched_model = jax.vmap(model)
    return batched_model(data)

In [None]:
from ml4ssh._src.model_utils import batch_predict
from functools import partial
import time

In [None]:
# Keep `df_pred` a DataFrame and store the model inputs under a new name:
# rebinding `df_pred` to an array made this cell fail when run a second time
# (an array has no `.columns`).
# NOTE(review): `Index.difference` returns the remaining columns in sorted
# order - confirm this matches the feature order the model was trained with.
x_pred = jnp.asarray(df_pred[df_pred.columns.difference(["time"])].values)

# prediction function with the trained model bound as first argument
fn = partial(pred_step, model)

# timed batched prediction over the whole grid
t0 = time.time()
df_grid["pred"] = batch_predict(x_pred, fn, args.eval_batch_size)
t1 = time.time() - t0

In [None]:
# record the wall-clock prediction time (seconds) in the run configuration
wandb.config.update(
    {
        "time_predict_batches": t1,
    }
)

In [None]:
# post-process the gridded predictions into the `ds_oi` dataset
ds_oi = postprocess_data(df_grid, args)

In [None]:
# display the post-processed dataset
ds_oi

In [None]:
%%time

# RMSE-based scores; the four returned entries are logged, in order, as
# rmse mean / rmse std / nrmse mean / nrmse std
rmse_metrics = get_rmse_metrics(ds_oi, args)

wandb.log(
    {
        "model_rmse_mean": rmse_metrics[0],
        "model_rmse_std": rmse_metrics[1],
        "model_nrmse_mean": rmse_metrics[2],
        "model_nrmse_std": rmse_metrics[3],
    }
)

In [None]:
# spectral (PSD) scores used by the figures below
psd_metrics = get_psd_metrics(ds_oi, args)

### Figures

In [None]:
# PSD spectrum of the prediction against the reference, logged as an image
fig, ax = plot_psd_spectrum(
    psd_metrics.psd_study, psd_metrics.psd_ref, psd_metrics.wavenumber
)


wandb.log(
    {
        "model_psd_spectrum": wandb.Image(fig),
    }
)

In [None]:
# PSD score plot, logged as an image
fig, ax = plot_psd_score(
    psd_metrics.psd_diff,
    psd_metrics.psd_ref,
    psd_metrics.wavenumber,
    psd_metrics.resolved_scale,
)

wandb.log(
    {
        "model_psd_score": wandb.Image(fig),
    }
)

### Movies

In [None]:
# output directory for the (currently commented-out) movie cells below
save_path = wandb.run.dir  # Path(root).joinpath("experiments/dc_2021b")

In [None]:
# import hvplot.xarray


# ds_oi.ssh.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="viridis")

In [None]:
# if args.smoke_test:
#     create_movie(ds_oi.ssh.isel(time=slice(50,60)), f"pred", "time", cmap="viridis", file_path=save_path)
# else:
#     create_movie(ds_oi.ssh, f"pred", "time", cmap="viridis", file_path=save_path)

In [None]:
# wandb.log(
#     {
#         "predictions_gif": wandb.Image(f"{save_path}/movie_pred.gif"),
#     }
# )

### Gradients

In [None]:
# gradient of the predicted SSH field along longitude / latitude
ds_oi["ssh_grad"] = calculate_gradient(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# same computation as in the previous cell; it feeds the (commented-out) plot below
ds_oi["ssh_grad"] = calculate_gradient(ds_oi["ssh"], "longitude", "latitude")
# ds_oi.ssh_grad.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="Spectral_r")

In [None]:
# if args.smoke_test:
#     create_movie(ds_oi.ssh_grad.isel(time=slice(50,60)), f"pred_grad", "time", cmap="Spectral_r", file_path=save_path)
# else:
#     create_movie(ds_oi.ssh_grad, f"pred_grad", "time", cmap="Spectral_r", file_path=save_path)

In [None]:
# wandb.log(
#     {
#         "predictions_grad_gif": wandb.Image(f"{save_path}/movie_pred_grad.gif"),
#     }
# )

### Laplacian

In [None]:
# Laplacian of the predicted SSH field along longitude / latitude
ds_oi["ssh_lap"] = calculate_laplacian(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# same computation as in the previous cell; it feeds the (commented-out) plot below
ds_oi["ssh_lap"] = calculate_laplacian(ds_oi["ssh"], "longitude", "latitude")
# ds_oi.ssh_lap.hvplot.image(
#     x="longitude",
#     y="latitude",
#     # groupby='time',
#     # rasterize=True,
#     width=500, height=400, cmap="RdBu_r")

In [None]:
# if args.smoke_test:
#     create_movie(ds_oi.ssh_lap.isel(time=slice(50,60)), f"pred_lap", "time", cmap="RdBu_r", file_path=save_path)
# else:
#     create_movie(ds_oi.ssh_lap, f"pred_lap", "time", cmap="RdBu_r", file_path=save_path)

In [None]:
# wandb.log(
#     {
#         "predictions_laplacian_gif": wandb.Image(f"{save_path}/movie_pred_lap.gif"),
#     }
# )