# Demo - Siren

In [None]:
import os
import sys

from pyprojroot import here

# Switch on PyTorch's MPS fallback before torch gets imported further down.
os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")

# Walk up the directory tree to locate the project root (marked by ".root"),
# then locate the experiment folder (marked by ".local") relative to it.
root = here(project_files=[".root"])
exp = here(
    project_files=[".local"],
    relative_project_path=root / "experiments/expv2",
)

# Make both locations importable.
sys.path.extend([str(root), str(exp)])

In [None]:
import imageio
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from inr4ssh._src.data.ssh_obs import (
    load_ssh_altimetry_data_test,
    load_ssh_altimetry_data_train,
    load_ssh_correction,
)
from inr4ssh._src.datamodules.ssh_obs import SSHAltimetry
from inr4ssh._src.features.data_struct import df_2_xr
from inr4ssh._src.interp import interp_on_alongtrack
from inr4ssh._src.metrics.psd import compute_psd_scores, select_track_segments
from inr4ssh._src.metrics.stats import (
    calculate_nrmse,
    calculate_nrmse_elementwise,
    calculate_rmse_elementwise,
)
from inr4ssh._src.models.activations import get_activation
from inr4ssh._src.models.siren import ModulatedSirenNet, Modulator, Siren, SirenNet
from inr4ssh._src.postprocess.ssh_obs import postprocess
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.preprocess.subset import spatial_subset, temporal_subset
from inr4ssh._src.viz.psd import plot_psd_score, plot_psd_spectrum
from loguru import logger
from skorch import NeuralNetRegressor
from skorch.callbacks import EarlyStopping, LRScheduler, WandbLogger
from skorch.dataset import ValidSplit
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.notebook import tqdm as tqdm

# seed the random number generators (through pytorch-lightning) with a fixed value
pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
from inr4ssh._src.viz.movie import create_movie

# reset seaborn to its defaults, then use the "talk" context with a reduced font scale
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# notebook magics: inline figures + automatic reloading of imported modules
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import argparse

import config

In [None]:
# Assemble the experiment parser. Every helper in `config` takes the parser and
# hands it back, so they are applied one after another in a fixed order.
parser = argparse.ArgumentParser()

for register_args in (
    config.add_logging_args,
    config.add_data_dir_args,
    config.add_data_preprocess_args,
    config.add_feature_transform_args,
    config.add_train_split_args,
    config.add_dataloader_args,
    config.add_model_args,
    config.add_loss_args,
    config.add_optimizer_args,
    config.add_eval_data_args,
    config.add_eval_metrics_args,
    config.add_viz_data_args,
):
    parser = register_args(parser)

# Parse an empty argument list: inside a notebook only the defaults are wanted.
args = parser.parse_args([])

# Data locations on the PERSONAL machine (raw train / ref / test folders).
vars(args).update(
    train_data_dir="/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/train",
    ref_data_dir="/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/ref",
    test_data_dir="/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/test",
)

# Alternative settings for MEOMCAL1 (disabled):
# args.train_data_dir = "/home/johnsonj/data/dc_2021/raw/train"
# args.ref_data_dir = "/home/johnsonj/data/dc_2021/raw/ref"
# args.test_data_dir = "/home/johnsonj/data/dc_2021/raw/test"
#
# args.time_min = "2017-01-01"
# args.time_max = "2017-02-01"
# args.eval_time_min = "2017-01-01"
# args.eval_time_max = "2017-02-01"
# args.eval_dtime = "12_h"

# Weights & Biases settings. Known run ids:
#   ige/inr4ssh/2avm7u7m
#   ige/inr4ssh/3rzy4mbv - genial-deluge-8
#   ige/inr4ssh/nlbgt9aq - dazzling tree 12
#   ige/inr4ssh/ymwqgoj7 - leafy moon
vars(args).update(
    wandb_resume=True,
    wandb_mode="online",
    wandb_project="inr4ssh",
    wandb_entity="ige",
    wandb_log_dir="/Users/eman/code_projects/logs",
    wandb_id="ymwqgoj7",
)

In [None]:
# Open the W&B run described by the logging arguments: which run (entity /
# project / id), whether to resume it, and where and how to log.
wandb_run = wandb.init(
    entity=args.wandb_entity,
    project=args.wandb_project,
    id=args.wandb_id,
    resume=args.wandb_resume,
    mode=args.wandb_mode,
    dir=args.wandb_log_dir,
    # config=args,
)

In [None]:
# from here on `args` is the config object of the W&B run, no longer the
# argparse namespace built earlier
args = wandb_run.config
args

In [None]:
# modify args (PERSONAL)
# the data directories are assigned again because `args` was rebound to the
# run config in the previous cell
# NOTE(review): these are attribute writes on the W&B config object; confirm
# that changing an already stored value is permitted for a resumed run.
args.train_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/train"
args.ref_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/ref"
args.test_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/test"

To read an argument and fall back to a default value when it is missing from the namespace:
`lr = getattr(args, "lr", 1e-4)`

In [None]:
logger.info("Initializing data module...")
# build the SSH altimetry data module from the experiment arguments
dm = SSHAltimetry(args)

In [None]:
# run the data module's setup step
# (presumably what creates ds_train / ds_valid / ds_predict — TODO confirm)
dm.setup()

In [None]:
# one dataloader per split: training, validation and prediction
dl_train = dm.train_dataloader()
dl_valid = dm.val_dataloader()
dl_predict = dm.predict_dataloader()

In [None]:
# number of samples in each of the three datasets
len(dm.ds_train), len(dm.ds_valid), len(dm.ds_predict)

## Skorch Training

In [None]:
# Materialise each dataset in full by slicing it with [:].
X_train, y_train = dm.ds_train[:]
X_valid, y_valid = dm.ds_valid[:]
[X_test] = dm.ds_predict[:]

# Append the validation samples to the training samples (inputs and targets).
X_train, y_train = (
    torch.cat((X_train, X_valid)),
    torch.cat((y_train, y_valid)),
)

In [None]:
# shapes of the training inputs and of the prediction inputs
X_train.shape, X_test.shape

In [None]:
logger.info("Creating neural network...")
# input / output widths come from the second axis of the training tensors
dim_in = X_train.shape[1]
dim_hidden = args.hidden_dim
dim_out = y_train.shape[1]
num_layers = args.n_hidden
# SIREN-specific settings read from the run config
w0 = args.siren_w0
w0_initial = args.siren_w0_initial
# NOTE(review): `c` is read here but is not passed to SirenNet below —
# confirm whether it should be.
c = args.siren_c
# resolve the configured final activation through the project helper
final_activation = get_activation(args.final_activation)

In [None]:
# build the SIREN network from the settings gathered above
siren_net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    final_activation=final_activation,
)

In [None]:
# Choose the compute device at runtime instead of hard-coding "mps", so the
# notebook also runs on machines without Apple's Metal backend.
# Preference order: MPS, then CUDA, then CPU.
# `getattr` guards against PyTorch builds that do not ship `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# args.num_epochs = 500
# args.learning_rate = 1e-3

In [None]:
# learning rate scheduler: "ReduceLROnPlateau" policy monitoring the validation
# loss (mode="min", factor=0.1, patience=10)
lr_scheduler = LRScheduler(
    policy="ReduceLROnPlateau",
    monitor="valid_loss",
    mode="min",
    factor=0.1,
    patience=10,
)
# alternative learning rate scheduler (cosine annealing), currently disabled

# lr_scheduler = LRScheduler(
#     policy=CosineAnnealingLR,
#     T_max=args.num_epochs
# )

# early stopping, also monitoring the validation loss (patience=50)
estop_callback = EarlyStopping(
    monitor="valid_loss",
    patience=50,
)

# skorch's W&B logger attached to the open run, with save_model=True
wandb_callback = WandbLogger(wandb_run, save_model=True)

# (name, callback) pairs passed to the skorch regressor
callbacks = [
    ("earlystopping", estop_callback),
    ("lrscheduler", lr_scheduler),
    ("wandb_logger", wandb_callback),
]

In [None]:
# train split percentage
# non-stratified split with a ratio of 0.1
# (presumably the fraction held out for validation — TODO confirm)
train_split = ValidSplit(0.1, stratified=False)

In [None]:
# wrap the SIREN module in a skorch regressor: Adam optimizer, training
# settings (epochs, learning rate, batch size) taken from the run config
skorch_net = NeuralNetRegressor(
    module=siren_net,
    max_epochs=args.num_epochs,
    lr=args.learning_rate,
    batch_size=args.batch_size,
    device=device,
    optimizer=torch.optim.Adam,
    train_split=train_split,
    callbacks=callbacks,
)

In [None]:
# if args.server == "jz":
# Download the stored model parameters of the current W&B run.

# get id (from this run or a run you can set)
run_id = wandb_run.id

# initialize api
api = wandb.Api()

# get run, addressed as "<entity>/<project>/<run id>"
run = api.run(f"{args.wandb_entity}/{args.wandb_project}/{run_id}")

# files to download; every entry ends with a comma so that enabling one of the
# commented entries cannot silently concatenate two adjacent string literals
files = [
    # "scaler.pickle",
    "checkpoints/params.ckpt",
    # "checkpoints/last.ckpt",
]

# fetch each file, replacing any local copy
for ifile in files:
    run.file(ifile).download(replace=True)

In [None]:
from inr4ssh._src.io import load_object

In [None]:
# NOTE(review): "scaler.pickle" is commented out of the download list in the
# download cell, so this file must already exist locally — confirm. The loaded
# `scaler` is not used in the cells visible here (`dm.scaler` is used instead).
scaler = load_object("./scaler.pickle")

In [None]:
skorch_net.initialize()  # This is important!
# skorch_net.load_params(f_params='./best_model.pth')
# restore the network parameters from the downloaded checkpoint file
skorch_net.load_params(f_params="checkpoints/params.ckpt")

In [None]:
# skorch_net.fit(X_train, y_train)

In [None]:
# fig, ax = plt.subplots()

# ax.plot(skorch_net.history[:, "train_loss"], label="Train Loss")
# ax.plot(skorch_net.history[:, "valid_loss"], label="Validation Loss")

# ax.set(yscale="log", xlabel="Epochs", ylabel="Mean Squared Error")

# plt.legend()
# plt.show()

## Predictions

### SSH Along Track

In [None]:
%%time

# Load the along-track test observations, restrict them to the evaluation
# region, and turn them into model inputs (X_test) and targets (y_test).

# open along track dataset
ds_alongtrack = load_ssh_altimetry_data_test(args.test_data_dir)

# correct labels
ds_alongtrack = correct_coordinate_labels(ds_alongtrack)

# correct longitude domain
ds_alongtrack = correct_longitude_domain(ds_alongtrack)

# temporal subset
# NOTE(review): this uses args.time_min / args.time_max, while the spatial
# subset below uses the eval_* bounds — confirm the training window is intended.
ds_alongtrack = temporal_subset(
    ds_alongtrack,
    time_min=np.datetime64(args.time_min),
    time_max=np.datetime64(args.time_max),
    time_buffer=args.time_buffer,
)

# spatial subset
ds_alongtrack = spatial_subset(
    ds_alongtrack,
    lon_min=args.eval_lon_min,
    lon_max=args.eval_lon_max,
    lon_buffer=args.eval_lon_buffer,
    lat_min=args.eval_lat_min,
    lat_max=args.eval_lat_max,
    lat_buffer=args.eval_lat_buffer,
)

# convert to dataframe (index reset, rows with missing values dropped);
# from here on `ds_alongtrack` is a dataframe
ds_alongtrack = ds_alongtrack.to_dataframe().reset_index().dropna()

# model inputs: the dataframe passed through the data module's scaler
X_test = dm.scaler.transform(ds_alongtrack)
# targets: the "sla_unfiltered" column
y_test = ds_alongtrack["sla_unfiltered"]

In [None]:
%%time
import time

# Predict on the along-track inputs and record how long it takes.
# A monotonic clock is used for the interval: `time.time()` follows the wall
# clock and can jump when the system time is adjusted.
t0 = time.perf_counter()
predictions = skorch_net.predict(torch.Tensor(X_test))
t1 = time.perf_counter() - t0  # elapsed seconds

wandb_run.log(
    {
        "time_predict_alongtrack": t1,
    }
)

#### Stats

In [None]:
# RMSE between the along-track targets and the predictions; the helper returns
# a mean and a standard deviation, both logged to the W&B run.
# NOTE(review): `y_test` is a dataframe column and `predictions` comes straight
# from `predict` — confirm the helper handles their shapes consistently.
rmse_mean, rmse_std = calculate_rmse_elementwise(y_test, predictions)

# plain string keys: there is nothing to interpolate, so no f-string is needed
wandb_run.log(
    {
        "rmse_mean_alongtrack": rmse_mean,
        "rmse_std_alongtrack": rmse_std,
    }
)

In [None]:
# show the along-track RMSE mean and standard deviation
print(f"RMSE: {rmse_mean}\nRMSE (stddev): {rmse_std}")

In [None]:
# NRMSE between the along-track targets and the predictions, computed once per
# entry of `metrics` (passed as the third argument; presumably it selects the
# normalisation — TODO confirm). Each result is printed and logged to W&B.
metrics = ["custom", "std", "mean", "minmax", "iqr"]

for imetric in metrics:
    nrmse_mean, nrmse_std = calculate_nrmse_elementwise(y_test, predictions, imetric)

    # label corrected: the values printed here are NRMSE, not RMSE
    print(f"NRMSE ({imetric}): mean - {nrmse_mean:.4f}, stddev - {nrmse_std:.4f}")

    wandb_run.log(
        {
            f"nrmse_mean_alongtrack_{imetric}": nrmse_mean,
            f"nrmse_std_alongtrack_{imetric}": nrmse_std,
        }
    )

#### PSD Score

In [None]:
# PSD scores between the along-track targets and the predictions; both are
# squeezed first (presumably to 1D — TODO confirm)
psd_metrics = compute_psd_scores(
    ssh_true=y_test.squeeze(),
    ssh_pred=predictions.squeeze(),
    # sample spacing: configured velocity times configured time step
    delta_x=args.eval_psd_velocity * args.eval_psd_delta_t,
    npt=None,
    scaling="density",
    noverlap=0,
)

In [None]:
# show the computed PSD metrics
print(psd_metrics)

In [None]:
# log the resolved scale of the along-track evaluation
wandb_run.log(
    {
        "resolved_scale_alongtrack": psd_metrics.resolved_scale,
    }
)

#### Viz - PSD Score

In [None]:
# plot the PSD score from the along-track PSD metrics
fig, ax = plot_psd_score(
    psd_diff=psd_metrics.psd_diff,
    psd_ref=psd_metrics.psd_ref,
    wavenumber=psd_metrics.wavenumber,
    resolved_scale=psd_metrics.resolved_scale,
)

# log the figure to the W&B run as an image
wandb_run.log(
    {
        "psd_score_alongtrack": wandb.Image(fig),
    }
)

#### Viz - PSD Spectrum

In [None]:
# plot the PSD spectrum (study vs. reference) from the along-track PSD metrics
fig, ax = plot_psd_spectrum(
    psd_study=psd_metrics.psd_study,
    psd_ref=psd_metrics.psd_ref,
    wavenumber=psd_metrics.wavenumber,
)

# log the figure to the W&B run as an image
wandb_run.log(
    {
        "psd_spectrum_alongtrack": wandb.Image(fig),
    }
)

### SSH Grid

In [None]:
%%time
# extract grid variables (the prediction dataset yields a single item: the inputs)
(X_test,) = dm.ds_predict[:]

# TESTING
# Predict on the grid inputs and record how long it takes. A monotonic clock
# is used for the interval: `time.time()` follows the wall clock and can jump
# when the system time is adjusted. (`time` is imported in an earlier cell.)
logger.info("Making predictions...")
t0 = time.perf_counter()
predictions = skorch_net.predict(torch.FloatTensor(X_test))
t1 = time.perf_counter() - t0  # elapsed seconds

In [None]:
# log how long the grid prediction took
wandb_run.log(
    {
        "time_predict_grid": t1,
    }
)

In [None]:
# convert to da
logger.info("Convert data to xarray ds...")
# NOTE(review): `ds_oi` refers to the same object as `dm.X_pred_index`, so the
# "ssh" entry is added to the data module's index in place — confirm intended.
ds_oi = dm.X_pred_index
ds_oi["ssh"] = predictions
# convert the indexed predictions with the project's df_2_xr helper
ds_oi = df_2_xr(ds_oi)

In [None]:
# open correction dataset (read from the reference data directory)
logger.info("Loading SSH corrections...")
ds_correct = load_ssh_correction(args.ref_data_dir)

In [None]:
# correct predictions: post-process the gridded predictions with the
# correction dataset
logger.info("Correcting SSH predictions...")
ds_oi = postprocess(ds_oi, ds_correct)

In [None]:
# open along track dataset (reloaded from disk; this replaces the dataframe
# that was bound to the same name in the along-track section)
logger.info("Loading test dataset...")
ds_alongtrack = load_ssh_altimetry_data_test(args.test_data_dir)

# interpolate along track: bring the gridded predictions and the along-track
# observations together within the evaluation bounds (lon / lat / time)
logger.info("Interpolating alongtrack obs...")
alongtracks = interp_on_alongtrack(
    gridded_dataset=ds_oi,
    ds_alongtrack=ds_alongtrack,
    lon_min=args.eval_lon_min,
    lon_max=args.eval_lon_max,
    lat_min=args.eval_lat_min,
    lat_max=args.eval_lat_max,
    time_min=args.eval_time_min,
    time_max=args.eval_time_max,
)

#### Stats

In [None]:
# RMSE
logger.info("Getting RMSE Metrics...")


# compare the along-track observations with the interpolated map values;
# `dt_freq` and `min_obs` come from the evaluation settings
# (presumably the time-bin size and the minimum observations per bin — TODO confirm)
rmse_metrics = calculate_nrmse(
    true=alongtracks.ssh_alongtrack,
    pred=alongtracks.ssh_map,
    time_vector=alongtracks.time,
    dt_freq=args.eval_bin_time_step,
    min_obs=args.eval_min_obs,
)

# print the metrics and log the RMSE / NRMSE means and standard deviations
print(rmse_metrics)
wandb_run.log(
    {
        "rmse_mean_grid": rmse_metrics.rmse_mean,
        "rmse_std_grid": rmse_metrics.rmse_std,
        "nrmse_mean_grid": rmse_metrics.nrmse_mean,
        "nrmse_std_grid": rmse_metrics.nrmse_std,
    }
)

#### PSD

In [None]:
logger.info("Selecting track segments...")
# select track segments from the interpolated along-track data
tracks = select_track_segments(
    time_alongtrack=alongtracks.time,
    lat_alongtrack=alongtracks.lat,
    lon_alongtrack=alongtracks.lon,
    ssh_alongtrack=alongtracks.ssh_alongtrack,
    ssh_map_interp=alongtracks.ssh_map,
)

# sample spacing: configured velocity times configured time step
delta_x = args.eval_psd_velocity * args.eval_psd_delta_t

In [None]:
# compute scores (on the selected segments, using the segment length `tracks.npt`)
logger.info("Computing PSD Scores...")
psd_metrics = compute_psd_scores(
    ssh_true=tracks.ssh_alongtrack,
    ssh_pred=tracks.ssh_map,
    delta_x=delta_x,
    npt=tracks.npt,
    scaling="density",
    noverlap=0,
)

In [None]:
# show the PSD metrics of the gridded evaluation
print(psd_metrics)

In [None]:
# log the resolved scale of the gridded evaluation
wandb_run.log(
    {
        "resolved_scale_grid": psd_metrics.resolved_scale,
    }
)

#### Viz - PSD Spectrum

In [None]:
# Plot the PSD spectrum (study vs. reference) of the gridded evaluation and
# log the figure to the W&B run.
# Log message corrected: this cell plots the spectrum, not the score.
logger.info("Plotting PSD Spectrum...")
fig, ax = plot_psd_spectrum(
    psd_study=psd_metrics.psd_study,
    psd_ref=psd_metrics.psd_ref,
    wavenumber=psd_metrics.wavenumber,
)

wandb_run.log(
    {
        "psd_spectrum_grid": wandb.Image(fig),
    }
)

#### Viz - PSD Score

In [None]:
logger.info("Plotting PSD Score...")
# plot the PSD score of the gridded evaluation
fig, ax = plot_psd_score(
    psd_diff=psd_metrics.psd_diff,
    psd_ref=psd_metrics.psd_ref,
    wavenumber=psd_metrics.wavenumber,
    resolved_scale=psd_metrics.resolved_scale,
)

# log the figure to the W&B run as an image
wandb_run.log(
    {
        "psd_score_grid": wandb.Image(fig),
    }
)

logger.info("Finished Script...!")

In [None]:
# close the W&B run; no logging to it happens after this point
wandb_run.finish()

## Visualization

#### SSH Field

In [None]:
import hvplot.xarray

In [None]:
# interactive image of the predicted SSH field over longitude / latitude
ds_oi.ssh.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="viridis",
)

In [None]:
# Directory of the W&B run, used as output location for the movies.
# The run was already finished in an earlier cell, after which the module-level
# `wandb.run` is no longer set; read the directory from the run handle instead.
save_path = wandb_run.dir

In [None]:
# Render a movie of the predicted SSH along the "time" dimension into
# `save_path`. The name is a plain string: nothing is interpolated, so the
# f-prefix was dropped.
create_movie(ds_oi.ssh, "pred", "time", cmap="viridis", file_path=save_path)

#### Gradient (Norm)

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# store the gradient of the SSH field, taken with respect to longitude and
# latitude by the project's finite-difference helper
# (presumably its norm, as the section title suggests — TODO confirm)
ds_oi["ssh_grad"] = calculate_gradient(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# create_movie(
#     ds_oi.ssh_grad, f"pred_grad", "time", cmap="Spectral_r", file_path=save_path
# )

In [None]:
# interactive image of the SSH gradient over longitude / latitude
ds_oi.ssh_grad.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="Spectral_r",
)

#### Laplacian (Norm)

In [None]:
# store the Laplacian of the SSH field, taken with respect to longitude and
# latitude by the project's finite-difference helper
ds_oi["ssh_lap"] = calculate_laplacian(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# interactive image of the SSH Laplacian over longitude / latitude
ds_oi.ssh_lap.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="RdBu_r",
)

In [None]:
# create_movie(ds_oi.ssh_lap, f"pred_lap", "time", cmap="RdBu_r", file_path=save_path)

In [None]:
# def custom_plotfunc(ds, fig, tt, *args, **kwargs):

#     # find indices
#     indx = np.where(
#         (np.abs(ds_test_obs_summer.time.values - ds.time.values[tt]) <= dt)
#     )[0]

#     # subset data
#     lat = ds_test_obs_summer.latitude.values[indx]
#     lon = ds_test_obs_summer.longitude.values[indx]
#     data = ds_test_obs_summer.sla_unfiltered.values[indx]

#     vmin = ds_test_obs_summer.sla_unfiltered.min()
#     vmax = ds_test_obs_summer.sla_unfiltered.max()

#     # do scatter plot
#     ax = fig.add_subplot(111, aspect="equal")

#     pts = ax.scatter(
#         lon,
#         lat,
#         c=data,
#         cmap="RdBu_r",
#         vmin=ds_test_obs_summer.sla_unfiltered.min(),
#         vmax=ds_test_obs_summer.sla_unfiltered.max(),
#     )
#     ax.set_title("")
#     ax.set_facecolor("0.5")
#     ax.set_aspect(0.75)
#     ax.set(
#         xlim=[ds.longitude.values.min() - 0.5, ds.longitude.values.max() + 0.5],
#         ylim=[
#             ds.latitude.values.min() - 0.5,
#             ds.latitude.values.max() + 0.5,
#         ],
#     )
#     plt.colorbar(pts)
#     # plt.tight_layout()

#     return None, None

In [None]:
# mov_custom = Movie(ds_summer.sla, custom_plotfunc)

In [None]:
# mov_custom.preview(9)

In [None]:
# mov_custom.save(
#     "plots/movie_ssh_gulf_jja_test.gif",
#     remove_movie=True,
#     progress=True,
#     framerate=5,
#     gif_framerate=5,
#     overwrite_existing=True,
# )

In [None]:
# def make_customplotfunc(ds_obs):
#     def f(ds, fig, tt, *args, **kwargs):

#         # find indices
#         indx = np.where((np.abs(ds_obs.time.values - ds.time.values[tt]) <= dt))[0]

#         # subset data
#         lat = ds_obs.latitude.values[indx]
#         lon = ds_obs.longitude.values[indx]
#         data = ds_obs.sla_unfiltered.values[indx]

#         vmin = ds_obs.sla_unfiltered.min()
#         vmax = ds_obs.sla_unfiltered.max()

#         # do scatter plot
#         ax = fig.add_subplot(111, aspect="equal")

#         pts = ax.scatter(
#             lon,
#             lat,
#             c=data,
#             cmap="RdBu_r",
#             vmin=ds_obs.sla_unfiltered.min(),
#             vmax=ds_obs.sla_unfiltered.max(),
#         )
#         ax.set_title(f"{ds.time.values[tt]:.10}")
#         ax.set_facecolor("0.5")
#         ax.set_aspect(0.75)
#         ax.set(
#             xlim=[ds.longitude.values.min() - 0.5, ds.longitude.values.max() + 0.5],
#             ylim=[
#                 ds.latitude.values.min() - 0.5,
#                 ds.latitude.values.max() + 0.5,
#             ],
#             xlabel="Longitudes [degrees_east]",
#             ylabel="Latitudes [degrees_north]",
#         )
#         plt.colorbar(pts)
#         plt.tight_layout()

#         return None, None

#     return f

In [None]:
# f_cust = make_customplotfunc(ds_test_obs_winter)

# mov_custom = Movie(ds_winter.sla, f_cust)

In [None]:
# mov_custom.save(
#     "plots/movie_ssh_gulf_djf_test.gif",
#     remove_movie=True,
#     progress=True,
#     framerate=5,
#     gif_framerate=5,
#     overwrite_existing=True,
# )