# Viz + Data Challenge 2021a

In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree to find the project root (marked by a `.root` file)
root = here(project_files=[".root"])


# append to path
sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train
from inr4ssh._src.viz.movie import create_movie

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Observations

In [None]:
!ls $train_data_dir

In [None]:
train_data_dir = f"/Volumes/EMANS_HDD/data/dc21b/train"
# train_data_dir =

ds_obs = load_ssh_altimetry_data_train(train_data_dir)

In [None]:
variable = "sla_unfiltered"

In [None]:
# temporal subset
ds_obs = temporal_subset(
    ds_obs,
    time_min=np.datetime64("2017-01-01"),
    time_max=np.datetime64("2018-01-01"),
    # time_min=np.datetime64("2017-01-01"),
    # time_max=np.datetime64("2017-02-01"),
    time_buffer=0.0,
    time_buffer_order="D",
)

In [None]:
ds_obs = ds_obs[["latitude", "longitude", variable]].reset_coords().astype("f4").load()

In [None]:
ds_obs

In [None]:
df_obs = ds_obs.to_dataframe()  # .reset_index()

In [None]:
df_obs

In [None]:
df_sla_mean = (
    df_obs.groupby(["latitude", "longitude", pd.Grouper(freq="D", level="time")])[
        variable
    ]
    .mean()
    .reset_index()
)

In [None]:
# df_sla_mean.hvplot.scatter(
#     x='longitude', y='latitude', groupby='time',
#     datashade=True, #coastline=True
#     # tiles=True
# )

In [None]:
import pyinterp

lon_min = 285.0
lon_max = 315.0
lon_buffer = 1.0
lat_min = 23.0
lat_max = 53.0
bin_lon_step = 0.1
bin_lat_step = 0.1

In [None]:
binning = pyinterp.Histogram2D(
    pyinterp.Axis(np.arange(lon_min, lon_max, bin_lon_step), is_circle=True),
    pyinterp.Axis(np.arange(lat_min, lat_max + bin_lat_step, bin_lat_step)),
)

In [None]:
binning

In [None]:
binning.x[:].shape, binning.y[:].shape, binning.variable("mean").shape,

In [None]:
def create_xarray(grid, lon_coord, lat_coord, time_coord):
    """Wrap a gridded field and its coordinate vectors in an ``xr.Dataset``.

    ``grid`` is stored as the variable ``"ssh"`` with dimensions
    ``("time", "latitude", "longitude")``; ``time_coord``, ``lat_coord`` and
    ``lon_coord`` are stored as the variables named after those dimensions.
    """
    variables = dict(
        ssh=(("time", "latitude", "longitude"), grid),
        time=("time", time_coord),
        latitude=("latitude", lat_coord),
        longitude=("longitude", lon_coord),
    )
    return xr.Dataset(variables)

In [None]:
from tqdm.notebook import tqdm

# Push every `time` group of `df_sla_mean` through the 2D histogram and keep
# one single-time-step Dataset per group.
ds_xr = []

for time_key, group_obs in tqdm(df_sla_mean.groupby("time")):
    # accumulate this group's samples in the histogram
    binning.push(group_obs.longitude, group_obs.latitude, group_obs[variable])

    # binned mean, transposed and given a leading length-1 axis for time
    group_grid = binning.variable("mean").T[None, :]
    ids = create_xarray(group_grid, binning.x[:], binning.y[:], [time_key])

    # reset the accumulator so the next group starts from empty bins
    binning.clear()

    ds_xr.append(ids)

In [None]:
ds_xr = xr.concat(ds_xr, dim="time")

In [None]:
# correct longitude domain
ds_xr = correct_longitude_domain(ds_xr)

In [None]:
ds_xr

In [None]:
# save_path = "./"
# create_movie(ds_xr.ssh, f"obs", "time", cmap="viridis", file_path=save_path)

In [None]:
ds_xr.ssh.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="viridis",
)

In [None]:
ds_xr.sel(time="2017-01-01").ssh.plot()

In [None]:
# import powerspec as ps

## Previous Work

In [None]:
model = "DUACS"  # "DYMOST"  # "MIOST" #  "BASELINE" # "BFN" # "4DVARNET" #

In [None]:
!ls /Volumes/EMANS_HDD/data/dc21b/results/

In [None]:
data_dir = f"/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_{model}.nc"

In [None]:
ds = xr.open_dataset(data_dir)

#### Corrections

In [None]:
# correct labels
ds = correct_coordinate_labels(ds)

# correct longitude domain
ds = correct_longitude_domain(ds)

#### Time Period

**Daily Mean**

In [None]:
from inr4ssh._src.preprocess.coords import correct_longitude_domain

# Directory holding the mapping results compared in this section.
results_dir = "/Volumes/EMANS_HDD/data/dc21b/results"


def fn(x):
    """Apply the shared temporal subset (2017-01-01 to 2017-02-01, 7.0 "D" buffer)."""
    return temporal_subset(
        x,
        time_min=np.datetime64("2017-01-01"),
        time_max=np.datetime64("2017-02-01"),
        time_buffer=7.0,
        time_buffer_order="D",
    )


def _prepare_daily_field(filename):
    """Open one mapping result and run the preprocessing chain shared by all products.

    Steps, in order: open the file, correct coordinate labels, correct the
    longitude domain, resample to daily means, temporal subset, then add the
    ``ssh_grad`` and ``ssh_lap`` variables computed from ``ssh``.
    """
    ds_field = xr.open_dataset(f"{results_dir}/{filename}")

    # correct labels
    ds_field = correct_coordinate_labels(ds_field)

    # correct longitude domain, then daily mean
    ds_field = correct_longitude_domain(ds_field).resample(time="1D").mean()

    # temporal subset
    ds_field = fn(ds_field)

    # NOTE(review): the longitude correction is applied a second time after the
    # subset, exactly as the original cell did; it looks redundant -- confirm
    # `correct_longitude_domain` is idempotent before removing it.
    ds_field = correct_longitude_domain(ds_field)

    # calculate gradients and laplacian
    ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")
    ds_field["ssh_lap"] = calculate_laplacian(ds_field["ssh"], "longitude", "latitude")
    return ds_field


# Every product goes through the same chain. Previously MIOST was opened and
# resampled but skipped the temporal subset and the derived fields.
ds_baseline = _prepare_daily_field("OSE_ssh_mapping_BASELINE.nc")
ds_duacs = _prepare_daily_field("OSE_ssh_mapping_DUACS.nc")
ds_miost = _prepare_daily_field("OSE_ssh_mapping_MIOST.nc")
ds_siren = _prepare_daily_field("siren_136.nc")

In [None]:
ds_duacs

In [None]:
import xarray as xr
import numpy
import pyinterp
import pyinterp.fill
import logging


def oi_regrid(ds_source, ds_target):
    """Interpolate the ``ssh`` field of ``ds_source`` onto the grid of ``ds_target``.

    The source field is interpolated in longitude, latitude and time with
    ``pyinterp.trivariate`` (``bounds_error=False``, so out-of-domain target
    points do not raise). If the interpolated field contains NaN, it is
    filled with ``pyinterp.fill.gauss_seidel`` on the target grid; a warning
    is logged when that fill reports non-convergence.

    Parameters
    ----------
    ds_source : xarray.Dataset
        Must provide ``longitude``, ``latitude``, ``time`` and ``ssh``.
        ``ssh`` is transposed with ``.T`` before being handed to pyinterp, so
        it is assumed to be ordered (time, latitude, longitude) -- TODO confirm
        for every caller.
    ds_target : xarray.Dataset
        Provides the ``longitude``, ``latitude`` and ``time`` values to
        interpolate onto.

    Returns
    -------
    xarray.Dataset
        A dataset with a single ``ssh`` variable on
        ``("time", "latitude", "longitude")`` carrying the target coordinates.
    """
    logging.info("     Regridding...")

    # Define source grid (axes ordered longitude, latitude, time)
    x_source_axis = pyinterp.Axis(ds_source["longitude"].values, is_circle=False)
    y_source_axis = pyinterp.Axis(ds_source["latitude"].values)
    z_source_axis = pyinterp.TemporalAxis(ds_source["time"].values)
    ssh_source = ds_source["ssh"].T
    grid_source = pyinterp.Grid3D(
        x_source_axis, y_source_axis, z_source_axis, ssh_source.data
    )

    # Define target grid; target times are cast to the source temporal axis
    mx_target, my_target, mz_target = numpy.meshgrid(
        ds_target["longitude"].values,
        ds_target["latitude"].values,
        z_source_axis.safe_cast(ds_target["time"].values),
        indexing="ij",
    )

    # Spatio-temporal interpolation, reshaped back to the target mesh and
    # transposed so the leading axis is time
    ssh_flat = pyinterp.trivariate(
        grid_source,
        mx_target.flatten(),
        my_target.flatten(),
        mz_target.flatten(),
        bounds_error=False,
    )
    ssh_interp = ssh_flat.reshape(mx_target.shape).T

    # Extrapolate into NaN values if needed
    if numpy.isnan(ssh_interp).any():
        logging.info("     NaN found in ssh_interp, starting extrapolation...")
        x_target_axis = pyinterp.Axis(ds_target["longitude"].values, is_circle=False)
        y_target_axis = pyinterp.Axis(ds_target["latitude"].values)
        z_target_axis = pyinterp.TemporalAxis(ds_target["time"].values)
        grid_target = pyinterp.Grid3D(
            x_target_axis, y_target_axis, z_target_axis, ssh_interp.T
        )
        has_converged, filled = pyinterp.fill.gauss_seidel(grid_target)
        if not has_converged:
            # Surface the solver status instead of silently returning a
            # partially relaxed field.
            logging.warning(
                "     Gauss-Seidel extrapolation did not converge; "
                "filled values may be inaccurate."
            )
    else:
        filled = ssh_interp.T

    # Save to dataset
    ds_ssh_interp = xr.Dataset(
        {"ssh": (("time", "latitude", "longitude"), filled.T)},
        coords={
            "time": ds_target["time"].values,
            "longitude": ds_target["longitude"].values,
            "latitude": ds_target["latitude"].values,
        },
    )

    return ds_ssh_interp

In [None]:
from inr4ssh._src.preprocess.coords import correct_longitude_domain

# Re-open the four mapping products for the evaluation cells that follow.
# NOTE(review): this rebinds ds_baseline / ds_duacs / ds_siren without the
# `ssh_grad` / `ssh_lap` variables that the earlier preparation cell added and
# that the gradient / Laplacian figure cells further down read -- confirm the
# intended run order.
def fn(x):
    """Apply the shared temporal subset (2017-01-01 to 2017-02-01, 7.0 "D" buffer)."""
    return temporal_subset(
        x,
        time_min=np.datetime64("2017-01-01"),
        time_max=np.datetime64("2017-02-01"),
        time_buffer=7.0,
        time_buffer_order="D",
    )


# open -> correct labels -> temporal subset, one product at a time
ds_baseline, ds_duacs, ds_miost, ds_siren = (
    fn(correct_coordinate_labels(xr.open_dataset(result_path)))
    for result_path in (
        "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_BASELINE.nc",
        "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_DUACS.nc",
        "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_MIOST.nc",
        "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc",
    )
)

# ds_siren_interp = oi_regrid(ds_siren, ds_duacs)
# ds_baseline_interp = oi_regrid(ds_baseline, ds_duacs)
# ds_miost_interp = oi_regrid(ds_miost, ds_duacs)

In [None]:
ds_miost

In [None]:
ds_siren

In [None]:
ds_duacs.ssh

In [None]:
# Regrid onto the DUACS grid so the two fields can be compared point by point.
# NOTE(review): the result is named `ds_siren_interp`, but the source passed in
# is `ds_miost`; the commented-out call in the cell above used `ds_siren`.
# Confirm which product is meant before reading the scores below as SIREN's.
ds_siren_interp = oi_regrid(ds_miost, ds_duacs)

In [None]:
from inr4ssh._src.metrics.field.stats import nrmse_spacetime, rmse_space, nrmse_time

In [None]:
nrmse_xyt = nrmse_spacetime(ds_siren_interp["ssh"], ds_duacs["ssh"]).values
print(f"Leaderboard SSH RMSE score =  {nrmse_xyt:.2f}")

In [None]:
rmse_t = nrmse_time(ds_siren_interp["ssh"], ds_duacs["ssh"])

err_var_time = rmse_t.std().values
print(f"Error Variability =  {err_var_time:.2f}")

In [None]:
from inr4ssh._src.metrics.psd import (
    psd_isotropic_score,
    psd_spacetime_score,
    wavelength_resolved_spacetime,
    wavelength_resolved_isotropic,
)

In [None]:
time_norm = np.timedelta64(1, "D")
# mean psd of signal
ds_siren_interp["time"] = (ds_siren_interp.time - ds_siren_interp.time[0]) / time_norm
ds_duacs["time"] = (ds_duacs.time - ds_duacs.time[0]) / time_norm

In [None]:
# Time-Longitude (Lat avg) PSD Score
psd_score = psd_spacetime_score(ds_siren_interp["ssh"], ds_duacs["ssh"])

In [None]:
psd_score

In [None]:
spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score.T)
print(f"Shortest Spatial Wavelength Resolved = {spatial_resolved:.2f} (degree lon)")
print(f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)")

In [None]:
ds_duacs

In [None]:
# resample to daily mean
ds = ds.resample(time="1D").mean()

#### Spatio-Temporal Subset

In [None]:
# temporal subset
ds = temporal_subset(
    ds,
    time_min=np.datetime64("2017-01-01"),
    time_max=np.datetime64("2018-01-01"),
    time_buffer=7.0,
    time_buffer_order="D",
)

In [None]:
# spatial subset
if model != "4DVARNET":
    ds = spatial_subset(
        ds,
        lon_min=285.0,
        lon_max=315.0,
        lon_buffer=1.0,
        lat_min=23.0,
        lat_max=53.0,
        lat_buffer=1.0,
    )

In [None]:
ds

In [None]:
# save_path = "./"
# create_movie(ds.ssh, f"field_{model.lower()}", "time", cmap="viridis", file_path=save_path)

#### SSH

In [None]:
ds.ssh.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="viridis",
)

#### Gradient

In [None]:
ds["ssh_grad"] = calculate_gradient(ds["ssh"], "longitude", "latitude")

In [None]:
# create_movie(ds.ssh_grad, f"field_{model.lower()}_grad", "time", cmap="Spectral_r", file_path=save_path)
#

In [None]:
ds.ssh_grad.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="Spectral_r",
)

#### (Norm) Laplacian

In [None]:
ds["ssh_lap"] = calculate_laplacian(ds["ssh"], "longitude", "latitude")

In [None]:
# create_movie(ds.ssh_lap, f"field_{model.lower()}_lap", "time", cmap="RdBu_r", file_path=save_path)

In [None]:
ds.ssh_lap.hvplot.image(
    x="longitude",
    y="latitude",
    # groupby='time',
    # rasterize=True,
    width=500,
    height=400,
    cmap="RdBu_r",
)

In [None]:
# SSH snapshot on 2017-01-20 for each product, all on the same colour range.
# Raw strings: "\c" is not a valid escape sequence, so the plain literal used
# before triggers an invalid-escape warning on recent Python versions.
lon_label = r"Longitude ($^\circ$)"
lat_label = r"Latitude ($^\circ$)"

# (dataset, figure size, product-specific plot kwargs, (xlabel, ylabel), output file)
ssh_panels = [
    # BASELINE: no colourbar, blank axis labels
    (ds_baseline, (8, 7), {"add_colorbar": False}, ("", ""), "baseline_ssh.png"),
    # DUACS: no colourbar
    (ds_duacs, (8, 7), {"add_colorbar": False}, (lon_label, lat_label), "duacs_ssh.png"),
    # SIREN: default colourbar, 10x7 figure
    (ds_siren, (10, 7), {}, (lon_label, lat_label), "siren_ssh.png"),
]

for ds_panel, figsize, panel_kwargs, (xlabel, ylabel), out_file in ssh_panels:
    fig, ax = plt.subplots(figsize=figsize)
    ds_panel.sel(time="2017-01-20").ssh.plot(
        cmap="viridis", ax=ax, label="", vmin=-1.2, vmax=1.2, **panel_kwargs
    )
    ax.set(xlabel=xlabel, ylabel=ylabel, title="")
    fig.savefig(out_file)
    plt.show()

In [None]:
# Gradient snapshot on 2017-01-20 for each product, all on the same colour range.
# Raw strings: "\c" is not a valid escape sequence, so the plain literal used
# before triggers an invalid-escape warning on recent Python versions.
lon_label = r"Longitude ($^\circ$)"
lat_label = r"Latitude ($^\circ$)"

# (dataset, figure size, product-specific plot kwargs, output file)
grad_panels = [
    # BASELINE and DUACS: no colourbar
    (ds_baseline, (8, 7), {"add_colorbar": False}, "baseline_grad.png"),
    (ds_duacs, (8, 7), {"add_colorbar": False}, "duacs_grad.png"),
    # SIREN: colourbar drawn with an empty label, 10x7 figure
    (ds_siren, (10, 7), {"cbar_kwargs": {"label": ""}}, "siren_grad.png"),
]

for ds_panel, figsize, panel_kwargs, out_file in grad_panels:
    fig, ax = plt.subplots(figsize=figsize)
    # vmin is 0.0 for every panel (the SIREN call previously spelled it -0.0,
    # which is numerically the same value)
    ds_panel.sel(time="2017-01-20").ssh_grad.plot(
        cmap="Spectral_r", ax=ax, label="", vmin=0.0, vmax=2.2, **panel_kwargs
    )
    ax.set(xlabel=lon_label, ylabel=lat_label, title="")
    fig.savefig(out_file)
    plt.show()

In [None]:
# Laplacian snapshot on 2017-01-20 for each product; axis labels and titles
# are left blank on every panel.
# (dataset, figure size, product-specific plot kwargs, output file)
lap_panels = (
    # BASELINE
    (ds_baseline, (8, 7), dict(vmin=0.0, add_colorbar=False), "baseline_lap.png"),
    # DUACS
    (ds_duacs, (8, 7), dict(vmin=0.0, add_colorbar=False), "duacs_lap.png"),
    # SIREN
    (ds_siren, (10, 7), dict(vmin=-0.0, cbar_kwargs={"label": ""}), "siren_lap.png"),
)

for ds_panel, figsize, panel_kwargs, out_file in lap_panels:
    fig, ax = plt.subplots(figsize=figsize)
    snapshot = ds_panel.sel(time="2017-01-20")
    snapshot.ssh_lap.plot(cmap="RdBu_r", ax=ax, label="", vmax=5.0, **panel_kwargs)
    ax.set(xlabel="", ylabel="", title="")
    fig.savefig(out_file)
    plt.show()