# Demo - Inference (+Metrics) Pipeline

In [None]:
import os
import sys

from pyprojroot import here

# Let PyTorch fall back to the CPU for operations the MPS backend lacks.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Search upwards from the current location for the `.root` marker file.
root = here(project_files=[".root"])
# Resolve the experiment folder (experiments/dc21a under the project root).
exp = here(
    relative_project_path=root.joinpath("experiments/dc21a"),
    project_files=[".local"],
)

# Make the project root and the experiment folder importable.
sys.path.extend([str(root), str(exp)])

In [None]:
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from inr4ssh._src.datamodules.ssh_obs import SSHAltimetry
from inr4ssh._src.io import get_wandb_config, get_wandb_model
from inr4ssh._src.metrics.psd import compute_psd_scores
from loguru import logger
from ml_collections import config_dict
from models import model_factory
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, TensorDataset
from utils import (
    get_alongtrack_prediction_ds,
    get_alongtrack_stats,
    get_grid_stats,
    get_interpolation_alongtrack_prediction_ds,
    plot_psd_figs,
    postprocess_predictions,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Load Config

* `ige/inr4ssh/34behz0w` | `epoch=839-step=330960.ckpt` | `elated-galaxy`
* `ige/inr4ssh/25d69t9z` | `epoch=148-step=58706.ckpt` | `firm-salad`
* `ige/inr4ssh/2z8tsrfn` | `epoch=836-step=329778.ckpt` | `ruby-dew-62` | `siren`
* `ige/inr4ssh/1t0yk7rc` | `epoch=710-step=280134.ckpt` | `fearless-glade` | `fouriernet`
* `ige/inr4ssh/25d69t9z` | `epoch=959-step=378240.ckpt` | `firm-salad-58` | `gabornet`
* `ige/inr4ssh/11h89im3` | `epoch=502-step=198182.ckpt` | `ethereal-aardvark-73`
* `ige/inr4ssh/14s2md8s` | `epoch=739-step=291560.ckpt` | `eager-surf-74`

In [None]:
# W&B run to evaluate: `ruby-dew-62` (a `siren` model, per the run list above).
wandb_path = "ige/inr4ssh/2z8tsrfn"
# Checkpoint file (path inside the run) to restore.
checkpoint_name = "checkpoints/epoch=836-step=329778.ckpt"
# checkpoint_name = "checkpoints/last.ckpt"

In [None]:
# download wandb config (the run configuration used to rebuild data/model below)
config = get_wandb_config(wandb_path)

# download model checkpoint
# NOTE(review): assumes get_wandb_model returns a wandb File handle; with
# `replace=True` an existing local copy is overwritten.
best_model = get_wandb_model(wandb_path, checkpoint_name)
best_model.download(replace=True)

# convert to configdict (attribute-style access, e.g. `args.data.train_data_dir`)
args = config_dict.ConfigDict(config)

In [None]:
# Inspect the checkpoint handle; its `.name` is used as the local checkpoint path when loading.
best_model

In [None]:
# NOTE: machine-specific data locations; edit these to match your setup.
args.data.train_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/train"
args.data.ref_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/ref"
args.data.test_data_dir = "/Users/eman/.CMVolumes/cal1_workdir/data/dc_2021/raw/test"
# modify args (PERSONAL)
# args.data.train_data_dir = "/Volumes/EMANS_HDD/data/dc21b/train"
# args.data.ref_data_dir = "/Volumes/EMANS_HDD/data/dc21b/ref"
# args.data.test_data_dir = "/Volumes/EMANS_HDD/data/dc21b/test"
# NOTE(review): the network is built with `config=args.modsiren`, so this
# `args.siren` override may have no effect; confirm which config section
# model_factory actually reads.
args.siren.use_bias = True

## Load Data

In [None]:
# NOTE(review): presumably the batch size of the datamodule's predict dataloader;
# confirm the key name against SSHAltimetry.
args.dataloader.batchsize_predict = 1_000

In [None]:
# DATA MODULE
logger.info("Initializing data module...")
# Each config section is passed to the datamodule under a keyword of the same name.
dm = SSHAltimetry(
    **{
        section: getattr(args, section)
        for section in (
            "data",
            "preprocess",
            "traintest",
            "features",
            "dataloader",
            "eval",
        )
    }
)

In [None]:
# Prepare the datasets. NOTE(review): presumably this populates dm.dim_in,
# dm.dim_out and dm.ds_predict, which later cells use; confirm in SSHAltimetry.
dm.setup()

## Init Model

In [None]:
logger.info("extracting train/test dims")
# input/output feature dimensions reported by the data module
dim_in = dm.dim_in
dim_out = dm.dim_out

logger.info(f"Creating {args.model.model} neural network...")
# NOTE(review): the network config is always taken from `args.modsiren`,
# whatever `args.model.model` is; confirm this matches the architecture
# stored in the checkpoint that is loaded later.
net = model_factory(
    model=args.model.model, dim_in=dim_in, dim_out=dim_out, config=args.modsiren
)

In [None]:
logger.info("Initializing trainer class...")


class CoordinatesLearner(pl.LightningModule):
    """Lightning wrapper around a coordinate-based network.

    ``forward`` delegates directly to ``model``. This notebook only exercises
    ``predict_step`` (through ``trainer.predict``); the training, validation
    and optimizer hooks are defined but never called by the cells that follow.

    NOTE: hyperparameters (loss reduction, learning rate, LR-scheduler
    settings) come from the notebook-global ``args``, not from constructor
    arguments.

    Args:
        model: the network that maps inputs to predictions.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss(reduction=args.losses.reduction)

    def forward(self, x):
        """Run the wrapped network on a batch of inputs."""
        return self.model(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return predictions for an inputs-only batch given as a 1-tuple ``(x,)``."""
        (x,) = batch
        return self.forward(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for an ``(x, y)`` batch."""
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation MSE loss for an ``(x, y)`` batch."""
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)

        self.log("valid_loss", loss)

        return loss

    def configure_optimizers(self):
        """Adam optimizer plus ReduceLROnPlateau driven by the logged ``valid_loss``."""
        # The notebook never imports ReduceLROnPlateau. Without this local
        # import the method raises NameError as soon as Lightning calls it.
        from torch.optim.lr_scheduler import ReduceLROnPlateau

        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=args.optimizer.learning_rate
        )
        scheduler = ReduceLROnPlateau(
            optimizer,
            patience=args.lr_scheduler.patience,
            factor=args.lr_scheduler.factor,
            mode="min",  # a lower valid_loss counts as an improvement
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "valid_loss",
        }

## Load Model State

In [None]:
# Restore the trained weights into a new CoordinatesLearner wrapping `net`.
# `model=net` is forwarded to __init__, which requires a `model` argument.
learn = CoordinatesLearner.load_from_checkpoint(
    checkpoint_path=best_model.name, model=net
)

## Initialize Trainer

In [None]:
logger.info("Initializing trainer...")
# This trainer is only used for `trainer.predict` below. The epoch limits are
# copied from the training config and do not influence prediction.
# NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases (in
# favour of accelerator/devices); revisit before upgrading.
trainer = Trainer(
    min_epochs=args.optimizer.min_epochs,
    max_epochs=args.optimizer.num_epochs,
    accelerator="mps" if args.optimizer.device == "mps" else None,
    devices=1 if args.optimizer.device == "mps" else None,
    gpus=args.optimizer.gpus if torch.cuda.is_available() else 0,
    enable_progress_bar=True,
)

## Predictions (Grid)

In [None]:
logger.info("GRID STATS...")

# TESTING
logger.info("Making predictions (grid)...")
t0 = time.time()
with torch.inference_mode():
    # trainer.predict yields one tensor per batch; join them into a single NumPy array
    predictions = torch.cat(
        trainer.predict(learn, datamodule=dm, return_predictions=True)
    ).numpy()
t1 = time.time() - t0

In [None]:
# Report the wall-clock prediction time; the point count is the length of the
# first tensor in the prediction dataset's full slice.
# NOTE(review): `len(dm.ds_predict)` (or `predictions.shape[0]`) would likely give
# the same count without slicing the whole dataset; confirm the dataset type.
logger.info(f"Time Taken for {dm.ds_predict[:][0].shape[0]} points: {t1:.4f} secs")

### Post Processing

In [None]:
# Inspect the evaluation settings passed to the alongtrack interpolation below.
args.eval

In [None]:
# This cell only post-processes; the RMSE metrics are computed (and announced) in the next cell.
logger.info("Post-processing predictions (GRID)...")

# Turn the flat prediction array into the gridded dataset `ds_oi` (later cells
# use its `ssh` variable on time/latitude/longitude).
ds_oi = postprocess_predictions(predictions, dm, args.data.ref_data_dir, logger)

# NOTE(review): presumably interpolates `ds_oi` onto the test alongtrack
# positions; `tracks` feeds the PSD computation below.
alongtracks, tracks = get_interpolation_alongtrack_prediction_ds(
    ds_oi, args.data.test_data_dir, args.eval, logger
)

### RMSE Metrics

In [None]:
logger.info("Getting RMSE Metrics (GRID)...")
# The two trailing arguments stay None; the original hook was wandb_logger.log_metrics.
rmse_metrics = get_grid_stats(alongtracks, args.metrics, None, None)
logger.info(f"Grid Stats: {rmse_metrics}")

### PSD Metrics

In [None]:
# compute scores
logger.info("Computing PSD Scores (Grid)...")
# Compare the spectra of the reference alongtrack SSH and the mapped SSH
# (`tracks.ssh_map`, presumably the grid sampled along the tracks).
# The spatial step `delta_x` is velocity * delta_t from args.metrics.
psd_metrics = compute_psd_scores(
    ssh_true=tracks.ssh_alongtrack,
    ssh_pred=tracks.ssh_map,
    delta_x=args.metrics.velocity * args.metrics.delta_t,
    npt=tracks.npt,
    scaling="density",
    noverlap=0,
)

logger.info(f"Grid PSD: {psd_metrics}")

### Figures

In [None]:
logger.info("Plotting PSD Score and Spectrum (Grid)...")
# The figure-logging argument is None; the original hook was wandb_logger.experiment.log.
plot_psd_figs(psd_metrics, logger, None, method="grid")
logger.info("Finished GRID Script...!")

## AlongTrack Predictions

In [None]:
logger.info("ALONGTRACK STATS...")

# alongtrack test inputs and their reference targets
X_test, y_test = get_alongtrack_prediction_ds(dm, args, logger)

# Inputs only: predict_step unpacks each batch as a 1-tuple `(x,)`.
ds_test = TensorDataset(torch.FloatTensor(X_test))

# shuffle=False keeps the prediction order aligned with y_test.
# NOTE(review): uses `batch_size_eval`, while the grid cell set
# `batchsize_predict`; confirm both keys exist in the config.
dl_test = DataLoader(
    ds_test,
    shuffle=False,
    batch_size=args.dataloader.batch_size_eval,
    pin_memory=args.dataloader.pin_memory,
    num_workers=args.dataloader.num_workers,
)

In [None]:
logger.info("Predicting alongtrack data...")
t0 = time.time()
with torch.inference_mode():
    # one tensor per batch -> a single NumPy array
    predictions = trainer.predict(learn, dataloaders=dl_test, return_predictions=True)
    predictions = torch.cat(predictions)
    predictions = predictions.numpy()
t1 = time.time() - t0
# Report the timing the same way the grid cell does; before this, t1 was computed and never shown.
logger.info(f"Time Taken for {predictions.shape[0]} points: {t1:.4f} secs")

### RMSE Stats

In [None]:
logger.info("Calculating stats (alongtrack)...")
# RMSE statistics of the alongtrack predictions against y_test. The
# metric-logging callback is None (the original hook was wandb_logger.log_metrics).
get_alongtrack_stats(
    y_test,
    predictions,
    logger,
    None,  # wandb_logger.log_metrics
)

### PSD Stats

In [None]:
# PSD
logger.info("Getting PSD Scores (alongtrack)...")
# Compare the spectra of the reference alongtrack SSH and the direct alongtrack
# predictions. The spatial step `delta_x` is velocity * delta_t from args.metrics.
psd_metrics = compute_psd_scores(
    ssh_true=y_test.squeeze(),
    ssh_pred=predictions.squeeze(),
    delta_x=args.metrics.velocity * args.metrics.delta_t,
    npt=None,
    scaling="density",
    noverlap=0,
)

# These are the alongtrack scores; the old "Grid PSD" label mislabelled them.
logger.info(f"AlongTrack PSD: {psd_metrics}")

### Figures

In [None]:
logger.info("Plotting PSD Score and Spectrum (AlongTrack)...")
# The figure-logging argument is None; the original hook was wandb_logger.experiment.log.
plot_psd_figs(psd_metrics, logger, None, method="alongtrack")

In [None]:
# Save the gridded predictions to NetCDF.
# NOTE(review): hard-coded, machine-specific output path. The file name
# (`siren_136`) does not obviously match the loaded run (`wandb_path`); confirm
# before overwriting earlier results.
ds_oi.to_netcdf("/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc")

#### SSH Field

In [None]:
import hvplot.xarray

In [None]:
# interactive map of the predicted SSH
ds_oi.ssh.hvplot.image(
    x="longitude",
    y="latitude",
    cmap="viridis",
    width=500,
    height=400,
)

In [None]:
# One-month subset used for the movie below (label-based slice, both endpoints included).
ds_oi_sub = ds_oi.sel(time=slice("2017-01-01", "2017-02-01"))
ds_oi_sub

In [None]:
from inr4ssh._src.viz.movie import create_movie

# directory passed to create_movie as `file_path`
save_path = "./"

In [None]:
# static snapshot of the predicted SSH for a single day
ds_oi.sel(time="2017-01-20").ssh.plot(cmap="viridis")

In [None]:
# Animate the one-month SSH subset along `time`. NOTE(review): "pred" is
# presumably the output file-name stem; the output goes under save_path.
create_movie(ds_oi_sub.ssh, "pred", "time", cmap="viridis", file_path=save_path)

#### Gradient (Norm)

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# SSH gradient over longitude/latitude, stored as a new variable next to `ssh`.
# NOTE(review): per the section heading this is presumably the gradient norm.
ds_oi["ssh_grad"] = calculate_gradient(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# create_movie(
#     ds_oi.ssh_grad, f"pred_grad", "time", cmap="Spectral_r", file_path=save_path
# )

In [None]:
# interactive map of the SSH gradient field
ds_oi.ssh_grad.hvplot.image(
    x="longitude",
    y="latitude",
    cmap="Spectral_r",
    width=500,
    height=400,
)

In [None]:
# static snapshot of the SSH gradient for a single day
ds_oi.sel(time="2017-01-20").ssh_grad.plot(cmap="Spectral_r")

#### Laplacian (Norm)

In [None]:
# SSH Laplacian over longitude/latitude, stored as a new variable next to `ssh`.
ds_oi["ssh_lap"] = calculate_laplacian(ds_oi["ssh"], "longitude", "latitude")

In [None]:
# interactive map of the SSH Laplacian
ds_oi.ssh_lap.hvplot.image(
    x="longitude",
    y="latitude",
    cmap="RdBu_r",
    width=500,
    height=400,
)

In [None]:
# static snapshot of the SSH Laplacian for a single day
ds_oi.sel(time="2017-01-20").ssh_lap.plot(cmap="RdBu_r")

In [None]:
# create_movie(ds_oi.ssh_lap, f"pred_lap", "time", cmap="RdBu_r", file_path=save_path)