# Figures - Data

In [None]:
import sys, os

import ml_collections
from pyprojroot import here

# locate the project root: the directory that holds a ".root" marker file
root = here(project_files=[".root"])


# add the project root to sys.path (presumably so the `inr4ssh` imports below resolve)
sys.path.append(str(root))

In [None]:
from pathlib import Path
import numpy as np
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
import scienceplots

# plt.style.use("science")

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_longitude_domain
from inr4ssh._src.preprocess.obs import bin_observations_xr, bin_observations_coords
from inr4ssh._src.preprocess.grid import create_spatiotemporal_grid
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)
from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360


def post_process(ds, variable):
    """Standardise a gridded SSH dataset for the figures in this notebook.

    Steps, in order: normalise the coordinate labels with
    ``correct_coordinate_labels``, rename ``variable`` to ``"ssh"``, pass the
    longitude coordinate through ``convert_lon_360_180`` and reorder the
    dimensions to ``(time, latitude, longitude)``.

    Args:
        ds: dataset that contains ``variable``; after the label correction it
            must expose ``time``, ``latitude`` and ``longitude``.
        variable: name of the data variable that becomes ``"ssh"``.

    Returns:
        The processed dataset. No temporal or spatial subsetting and no
        regridding happens here.
    """
    from inr4ssh._src.preprocess.spatial import convert_lon_360_180

    logger.info("Fixing coordinate labels...")
    relabelled = correct_coordinate_labels(ds)

    logger.info("Fixing labels")
    renamed = relabelled.rename({variable: "ssh"})

    logger.info("Fixing longitude domain")
    renamed["longitude"] = convert_lon_360_180(renamed.longitude)

    # fixed dimension order expected by the plotting / regridding cells
    return renamed.transpose("time", "latitude", "longitude")

In [None]:
# directory handed to create_movie as `file_path` in the cells below
fig_path = Path(root) / "figures/dc21a"

## Reference Grid

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid

# Reference field: DUACS SSH mapping, read from a hard-coded local path.
logger.info("Dataset I - DUACS")
url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc"
ds_field = xr.open_dataset(url)

# standardise labels / longitude / dimension order ("ssh" is renamed onto itself)
ds_field = post_process(ds_field, "ssh")
ds_field

In [None]:
# Colour limits over the whole reference field so every snapshot shares one scale.
# nanmin/nanmax skip missing values (np.min/np.max return NaN as soon as one NaN
# is present) and do not wrap the array in a list, which forced a copy.
vmin = np.nanmin(ds_field.ssh.values)
vmax = np.nanmax(ds_field.ssh.values)

# with plt.style.context('science'):
fig, ax = plt.subplots()

# snapshot of the reference field on 2017-02-01
ds_field.ssh.sel(time="2017-02-01").plot(
    ax=ax, cmap="viridis", robust=True, cbar_kwargs={"label": ""}, vmin=vmin, vmax=vmax
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

### Regridding

There are a few operations that we are interested in doing:
* Bounds <---> Coordinates
* Coordinates <---> Grid

In [None]:
from ml_collections import config_dict

# create configuration
def get_lowres_config():
    """Build the low-resolution study-grid configuration.

    Returns:
        A ``ConfigDict`` with longitude bounds -65.0 to -55.0 and latitude
        bounds 33.0 to 43.0 at 0.1 spacing, and a time range from 2017-02-01
        to 2017-03-01 with ``dtime="1_D"``.
    """
    config = config_dict.ConfigDict()

    # Every spatial bound is a float: ConfigDict fixes a field's type on first
    # assignment, so an int bound would reject a later float override.
    config.lon_min = -65.0  # wider domain: -75.0
    config.lon_max = -55.0  # wider domain: -45.0
    config.dlon = 0.1
    config.lat_min = 33.0
    config.lat_max = 43.0  # wider domain: 53.0
    config.dlat = 0.1
    config.time_min = np.datetime64("2017-02-01")
    config.time_max = np.datetime64("2017-03-01")
    config.dt_freq = 1
    config.dt_unit = "D"
    config.dtime = "1_D"  # np.timedelta64(1, "D")
    config.time_buffer = np.timedelta64(1, "D")
    return config


def get_hires_config():
    """Return the low-resolution config with 0.05 spacing and ``dtime="12_h"``."""
    cfg = get_lowres_config()
    cfg.dlat = 0.05
    cfg.dlon = 0.05
    cfg.dtime = "12_h"
    return cfg


def get_superres_config():
    """Return the low-resolution config with 0.01 spacing and ``dtime="6_h"``."""
    cfg = get_lowres_config()
    cfg.dlat = 0.01
    cfg.dlon = 0.01
    cfg.dtime = "6_h"
    return cfg

In [None]:
from inr4ssh._src.preprocess.coords import Bounds2DT, Coordinates2DT, Grid2DT

# Round-trip sanity check: config -> bounds -> coordinates -> grid and back.
# create bounds class from config
config = get_lowres_config()
bounds = Bounds2DT.init_from_config(config)

# create coordinates from bounds
coords = bounds.create_coordinates()
# create bounds from coordinates; they should match the originals to one decimal
# NOTE(review): only the longitude bounds/step are compared, not latitude or time
bounds_ = coords.create_bounds()
np.testing.assert_almost_equal(bounds.lon_min, bounds_.lon_min, decimal=1)
np.testing.assert_almost_equal(bounds.lon_max, bounds_.lon_max, decimal=1)
np.testing.assert_almost_equal(bounds.dlon, bounds_.dlon, decimal=1)

# create grid from coordinates
grid = coords.create_grid()

# create coordinates from the grid; these must equal the originals exactly
coords_ = grid.create_coords()
np.testing.assert_array_equal(coords.lon_coords, coords_.lon_coords)
np.testing.assert_array_equal(coords.lat_coords, coords_.lat_coords)
np.testing.assert_array_equal(coords.time_coords, coords_.time_coords)

In [None]:
# high-resolution configuration (0.05 spacing, dtime "12_h")
config = get_hires_config()

# target grid for the regridding below: bounds -> coordinates -> grid
grid_target = Bounds2DT.init_from_config(config).create_coordinates().create_grid()

In [None]:
import pyinterp
from einops import rearrange
from inr4ssh._src.preprocess.regrid import (
    create_pyinterp_grid_2dt,
    regrid_2dt_from_grid,
    regrid_2dt_from_da,
)
from inr4ssh._src.interp import interp_2dt

In [None]:
# regrid the reference SSH onto the high-resolution target grid
# NOTE(review): is_circle=False presumably marks the longitude axis as non-periodic — confirm
ds_field_hires = regrid_2dt_from_grid(
    ds_field.ssh,
    grid_target,
    is_circle=False,
)

# fill gaps (around edges) left by the regridding, using the "loess" method
ds_field_hires = interp_2dt(ds_field_hires, is_circle=False, method="loess")

# compare the array shapes before and after regridding
ds_field.ssh.shape, ds_field_hires.shape

### Figure

In [None]:
# Colour limits come from the original (pre-regridding) reference field so this
# figure shares its scale with the reference plot. nanmin/nanmax skip missing
# values, whereas np.min/np.max return NaN as soon as one NaN is present.
vmin = np.nanmin(ds_field.ssh.values)
vmax = np.nanmax(ds_field.ssh.values)

# with plt.style.context('science'):
fig, ax = plt.subplots()

# the date selects every time step on that day; [0] keeps the first one
ds_field_hires.sel(time="2017-02-01")[0].plot(
    ax=ax, cmap="viridis", robust=True, cbar_kwargs={"label": ""}, vmin=vmin, vmax=vmax
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.tight_layout()
plt.show()

### Movie (GIF)

In [None]:
# NOTE(review): vmin/vmax are computed here but not passed to create_movie below,
# so they do not affect the movie's colour scale.
vmin = np.min([ds_field.ssh.values])
vmax = np.max([ds_field.ssh.values])

# animate the regridded reference field over the study period, one frame per time step
create_movie(
    ds_field_hires.sel(time=slice("2017-02-01", "2017-03-01")),
    name="dc21a_ssh_duacs",
    file_path=fig_path,
    framedim="time",
    cmap="viridis",
    cbar_kwargs={"label": ""},
    robust=True,
)

## Study Grid

In [None]:
# Study field: BASELINE SSH mapping, read from a hard-coded local path.
logger.info("Dataset II - BASELINE")
url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BASELINE.nc"
ds_predict = xr.open_dataset(url)

# same standardisation as the reference field
ds_predict = post_process(ds_predict, "ssh")
ds_predict
# ds_field["ssh_oi"] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])
#
# ds_field

### Regridding

In [None]:
# regrid the prediction using the regridded reference field as the target,
# so both fields end up on the same grid
ds_predict_hires = regrid_2dt_from_da(
    ds_predict.ssh,
    ds_field_hires,
    is_circle=False,
)

# fill gaps (around edges) left by the regridding, using the "loess" method
ds_predict_hires = interp_2dt(ds_predict_hires, is_circle=False, method="loess")

# compare the array shapes before and after regridding
ds_predict.ssh.shape, ds_predict_hires.shape

### Image (Static)

In [None]:
# shared colour limits across the regridded prediction and reference fields
# NOTE(review): .min()/.max() on .values propagate NaN if any value is still missing
vmin = np.min([ds_predict_hires.values.min(), ds_field_hires.values.min()])
vmax = np.max([ds_predict_hires.values.max(), ds_field_hires.values.max()])

# with plt.style.context('science'):
fig, ax = plt.subplots()

# the date selects every time step on that day; [0] keeps the first one
ds_predict_hires.sel(time="2017-02-01")[0].plot(
    ax=ax, cmap="viridis", robust=True, cbar_kwargs={"label": ""}, vmin=vmin, vmax=vmax
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

### Movie (GIF)

In [None]:
# NOTE(review): vmin/vmax are computed here but not passed to create_movie below
vmin = np.min([ds_predict_hires.values.min(), ds_field_hires.values.min()])
vmax = np.max([ds_predict_hires.values.max(), ds_field_hires.values.max()])


# animate the regridded prediction over the study period, one frame per time step
create_movie(
    ds_predict_hires.sel(time=slice("2017-02-01", "2017-03-01")),
    name="dc21a_ssh_oi",
    file_path=fig_path,
    framedim="time",
    cmap="viridis",
    cbar_kwargs={"label": ""},
    robust=True,
)

## AlongTrack Observations

In [None]:
# Along-track training observations, read from a hard-coded local path.
filename = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/ml_ready/train.nc"
ds_alongtrack = xr.open_dataset(filename, engine="netcdf4")
ds_alongtrack = correct_coordinate_labels(ds_alongtrack)
# same longitude conversion as post_process applies to the gridded fields
ds_alongtrack["longitude"] = convert_lon_360_180(ds_alongtrack.longitude)

# derive the "ssh" variable from the stored components: sla_unfiltered + mdt - lwe
# NOTE(review): presumably sea level anomaly + mean dynamic topography - long-wavelength error — confirm
ds_alongtrack["ssh"] = (
    ds_alongtrack["sla_unfiltered"] + ds_alongtrack["mdt"] - ds_alongtrack["lwe"]
)
ds_alongtrack

In [None]:
ds_alongtrack.longitude.min(), ds_alongtrack.longitude.max()

### Binning (with Reference)

In [None]:
from inr4ssh._src.preprocess.alongtrack import (
    alongtrack_bin_from_da,
    alongtrack_bin_from_coords,
)

# bin the along-track "ssh" values using the regridded reference field as the
# reference, with a 12-hour time buffer
ds_alongtrack_hires = alongtrack_bin_from_da(
    ds_alongtrack,
    variable="ssh",
    ds_ref=ds_field_hires,
    time_buffer=np.timedelta64(12, "h"),
)

In [None]:
# same binning, but driven by coordinates built from the high-resolution config;
# this overwrites the `ds_alongtrack_hires` produced by the previous cell
config = get_hires_config()
coords = Bounds2DT.init_from_config(config).create_coordinates()
ds_alongtrack_hires = alongtrack_bin_from_coords(
    ds_alongtrack, variable="ssh", coords=coords, time_buffer=np.timedelta64(12, "h")
)

#### Figure

In [None]:
# with plt.style.context('science'):
fig, ax = plt.subplots()

# binned observations on 2017-02-15; [0] keeps the first time step of that day
ds_alongtrack_hires.ssh.sel(time="2017-02-15")[0].plot.pcolormesh(
    ax=ax, cmap="viridis", robust=True, cbar_kwargs={"label": ""}
)

ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

In [None]:
import cartopy.crs as ccrs

fig = plt.figure()

ax = plt.subplot(
    # projection=ccrs.Orthographic(-80,35)
    projection=ccrs.PlateCarree()
)

# binned along-track SSH on 2017-02-15; [0] keeps the first time step of that day
obs_snapshot = ds_alongtrack_hires.ssh.sel(time="2017-02-15")[0]

obs_snapshot.plot(
    ax=ax,
    cmap="viridis",
    robust=True,
    # transform=ccrs.PlateCarree(),
    # infer_intervals=True,
    cbar_kwargs={"label": ""},
)

# Overlay on the same map axes. On a PlateCarree map x is longitude and y is
# latitude (they were swapped), and `ax` is passed explicitly so the points
# cannot land on a different current axes.
obs_snapshot.plot.scatter(
    x="longitude", y="latitude", ax=ax, transform=ccrs.PlateCarree()
)
ax.coastlines()
ax.set_global()
ax.gridlines(draw_labels=True)
plt.tight_layout()
plt.show()

In [None]:
# with plt.style.context('science'):
# The projection must be chosen when the axes are created: xarray refuses
# `subplot_kws` when an existing `ax` is passed as well.
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.Orthographic(-80, 35)})

ds_alongtrack_hires.ssh.sel(time="2017-02-15")[0].plot.pcolormesh(
    ax=ax,
    cmap="viridis",
    robust=True,
    # the data are on a lon/lat grid, so declare that for the orthographic axes
    transform=ccrs.PlateCarree(),
    cbar_kwargs={"label": ""},
)

ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

### Movie (GIF)

In [None]:
# common colour limits over prediction, reference and binned observations
# NOTE(review): these limits are not passed to create_movie below; also
# .values.min()/.max() return NaN if the binned observations contain missing values
vmin = np.min(
    [
        ds_predict_hires.values.min(),
        ds_field_hires.values.min(),
        ds_alongtrack_hires.ssh.values.min(),
    ]
)
vmax = np.max(
    [
        ds_predict_hires.values.max(),
        ds_field_hires.values.max(),
        ds_alongtrack_hires.ssh.values.max(),
    ]
)


# animate the binned observations over the study period, one frame per time step
create_movie(
    ds_alongtrack_hires.ssh.sel(time=slice("2017-02-01", "2017-03-01")),
    name="dc21a_ssh_obs",
    file_path=fig_path,
    framedim="time",
    cmap="viridis",
    cbar_kwargs={"label": ""},
    robust=True,
)

### Binning (Without Reference)

In [None]:
from ml_collections import config_dict
from inr4ssh._src.preprocess.grid import create_spatiotemporal_coords
from inr4ssh._src.preprocess.obs import bin_observations_coords, bin_observations_xr

# create configuration
config = config_dict.ConfigDict()

# Longitude bounds follow the convention of `ds_alongtrack`, whose longitude was
# passed through convert_lon_360_180 when it was loaded. The previous bounds
# (285.0 / 315.0) are the 0-360 equivalent of this box and would not overlap it.
# NOTE(review): assumes convert_lon_360_180 maps 0..360 onto -180..180 — confirm
config.lon_min = -75.0
config.lon_max = -45.0
config.dlon = 0.1
config.lat_min = 23.0
config.lat_max = 53.0
config.dlat = 0.1
config.time_min = np.datetime64("2017-02-01")
config.time_max = np.datetime64("2017-03-01")
config.dt_freq = 1
config.dt_unit = "D"
config.dtime = np.timedelta64(1, "D")
config.time_buffer = np.timedelta64(12, "h")

# create the spatiotemporal coordinate vectors from the bounds and steps
lon_coords, lat_coords, time_coords = create_spatiotemporal_coords(
    lon_min=config.lon_min,
    lon_max=config.lon_max,
    lon_dx=config.dlon,
    lat_min=config.lat_min,
    lat_max=config.lat_max,
    lat_dy=config.dlat,
    time_min=config.time_min,
    time_max=config.time_max,
    time_dt=config.dtime,
)

# binning with coordinates (overwrites the earlier `ds_alongtrack_hires`)
ds_alongtrack_hires = bin_observations_coords(
    ds_alongtrack,
    variable="ssh",
    lon_coords=lon_coords,
    lat_coords=lat_coords,
    time_coords=time_coords,
    time_buffer=config.time_buffer,
)

In [None]:
ds_alongtrack_hires

In [None]:
# Bin the along-track observations onto the reference field's grid and store
# them as ds_field["obs"]. The figure cell that follows reads `ds_field.obs`,
# which was never created while the result was discarded into `_`.
ds_field["obs"] = bin_observations_xr(
    ds_obs=ds_alongtrack,
    ds_ref=ds_field,
    variable="ssh",
    time_buffer=np.timedelta64(12, "h"),
)["ssh"]

### Figure

In [None]:
# with plt.style.context('science'):
fig, ax = plt.subplots()

# NOTE(review): this reads the "obs" variable of ds_field, so an earlier cell
# must have assigned ds_field["obs"] for the lookup to succeed
ds_field.obs.sel(time="2017-02-01").plot(
    ax=ax, cmap="viridis", robust=True, cbar_kwargs={"label": ""}
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

### Video

In [None]:
def count_num_obs(ds, central_date, delta_t):
    """Count the values of ``ds`` inside a time window around ``central_date``.

    Args:
        ds: array-like object with a ``time`` dimension that supports ``sel``
            and ``drop_duplicates``.
        central_date: centre of the window.
        delta_t: half-width of the window; the selection covers
            ``central_date - delta_t`` to ``central_date + delta_t``.

    Returns:
        Total number of elements left after selecting the window and dropping
        duplicated time stamps. Every element is counted, whatever its value.
    """
    window = slice(central_date - delta_t, central_date + delta_t)
    selected = ds.sel(time=window).drop_duplicates(dim="time")
    return selected.values.size

In [None]:
# window of +/- 10,000 days around the central date
central_date = np.datetime64("2012-10-22")
num_days = 10_000
delta_t = np.timedelta64(num_days, "D")

# number of reference-field values that fall inside the window
count_num_obs(ds_field.ssh, central_date, delta_t)

In [None]:
ds_field.isel(time=0).ssh.plot.imshow()

### Density

In [None]:
# fig, ax = plt.subplots()
# sns.kdeplot(
#     # data=ds_field.ssh.values.flatten(),
#     # data=np.log(ds_field.ssh_grad.values.flatten()),
#     data=np.log(ds_field.ssh_lap.values.flatten()),
#     cumulative=True, common_norm=False, common_grid=True,
#     ax=ax
# )
# # ax.set_xlabel("SSH [m]")
# # ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstropy [s$^{-1}$]")
# ax.set_ylabel("Cumulative Density")
# plt.show()

In [None]:
# fig, ax = plt.subplots()
# sns.kdeplot(
#     # data=ds_field.ssh.values.flatten(),
#     # data=np.log(ds_field.ssh_grad.values.flatten()),
#     data=np.log(ds_field.ssh_lap.values.flatten()),
#     cumulative=False, common_norm=False, common_grid=True,
#     ax=ax
# )
# # ax.set_xlabel("SSH [m]")
# # ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstropy [s$^{-1}$]")
# ax.set_ylabel("Density")
# plt.show()

#### Movie (GIF)

In [None]:
# create_movie(ds_field.ssh, "ssh_field", framedim="time", cmap="viridis")

#### Gradients/Laplacian

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# gradient of SSH with respect to the longitude / latitude coordinates
ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")
# NOTE(review): "ssh_lap" is produced by applying calculate_gradient to ssh_grad,
# not by the imported calculate_laplacian — confirm this is the intended operator
ds_field["ssh_lap"] = calculate_gradient(ds_field["ssh_grad"], "longitude", "latitude")


# create_movie(ds_field.ssh_grad, "ssh_field_grad", framedim="time", cmap="Spectral_r")
# NOTE(review): np.log gives NaN for negative values and -inf for zeros in ssh_lap
create_movie(np.log(ds_field.ssh_lap), "ssh_field_lap", framedim="time", cmap="RdBu_r")

### PSD

In [None]:
# start from the reference dataset with normalised coordinate labels
ds_field_psd = correct_coordinate_labels(ds_field)

# grab ssh
ds_field_psd = ds_field_psd.ssh

# correct units, degrees -> meters, using a constant 111e3 factor on both axes
# NOTE(review): the same factor is applied to longitude regardless of latitude
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# calculate the isotropic power spectral density
ds_field_psd = psd_isotropic(ds_field_psd)

In [None]:
fig, ax = plot_psd_isotropic(ds_field_psd.freq_r.values * 1e3, ds_field_psd.values)
plt.tight_layout()
plt.show()

### PSD - Spatial-Temporal

In [None]:
# start from the reference dataset with normalised coordinate labels
ds_field_psd = correct_coordinate_labels(ds_field)

# grab the SSH gradient (the "ssh_grad" variable must already exist on ds_field)
ds_field_psd = ds_field_psd.ssh_grad

# load the values into memory
ds_field_psd = ds_field_psd.compute()

# correct units, degrees -> meters, using a constant 111e3 factor on both axes
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

time_norm = np.timedelta64(1, "D")
# express time as the number of days elapsed since the first time step
ds_field_psd["time"] = (ds_field_psd.time - ds_field_psd.time[0]) / time_norm

# calculate the space-time power spectral density
ds_field_psd = psd_spacetime_dask(ds_field_psd)

In [None]:
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd.freq_longitude * 1e3,
    ds_field_psd.freq_time,
    ds_field_psd,
)

plt.tight_layout()
plt.show()

In [None]:
# fig, ax, cbar = plot_psd_spacetime_wavenumber(
#     ds_field_psd.freq_longitude * 1e3,
#     ds_field_psd.freq_time,
#     ds_field_psd,
# )

# plt.tight_layout()
# plt.show()

## Observations

In [None]:
# # grab ssh
# ds_field_psd = ds_field.ssh_grad

# # correct units, degrees -> meters
# ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
# ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# # calculate
# ds_field_psd = psd_isotropic(ds_field_psd)

# fig, ax = plot_isotropic_psd(ds_field_psd, freq_scale=1e3)
# ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cyles/m")
# plt.tight_layout()
# plt.show()

### Missing Time

In [None]:
# full-field SSH observations, read from a hard-coded local path
ds_obs = xr.open_dataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_fullfields/ssh_obs_fullfields.nc"
)


ds_obs = correct_coordinate_labels(ds_obs)

# rename so the variable does not clash with "ssh" on ds_field when merged
ds_obs = ds_obs.rename({"ssh": "ssh_obs"})

ds_obs

In [None]:
central_date = np.datetime64("2012-10-22")
num_days = 100
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

In [None]:
# combine the reference field and the observations into one dataset
# NOTE(review): ds_field is sliced to 2017 dates in this notebook while the cells
# here use 2012 dates — confirm both datasets share a grid and time axis
ds_obs = xr.merge([ds_field, ds_obs])

#### Movie (GIF)

In [None]:
# create_movie(ds_obs.ssh_obs, "ssh_missing_time", framedim="time", cmap="viridis")

In [None]:
# !ls /Users/eman/code_projects/data/osse_2022b/dc_qg_obs_nadirlike/

### Jason-Like

In [None]:
# Jason-like SSH observations, read from a hard-coded local path
ds_obs = xr.open_dataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_jasonlike/ssh_obs_jasonlike.nc"
)

# order the samples chronologically
ds_obs = ds_obs.sortby("time")

ds_obs = correct_coordinate_labels(ds_obs)

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# window of +/- 100 days around the central date
central_date = np.datetime64("2012-10-22")
num_days = 100
delta_t = np.timedelta64(num_days, "D")

# NOTE(review): the rename to "ssh_obs" is commented out when this dataset is
# loaded, so this relies on the file already providing an "ssh_obs" variable
count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

In [None]:
central_date = np.datetime64("2012-10-22")
num_days = 1
delta_t = np.timedelta64(num_days, "D")
variable = "ssh_obs"

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# `bin_observations` is never imported or defined in this notebook; the imported
# helper with the same argument order is `bin_observations_xr`.
ds_obs_binned = bin_observations_xr(
    ds_obs=ds_obs,
    ds_ref=ds_field,
    variable="ssh_obs",
    time_buffer=np.timedelta64(12, "h"),
)

In [None]:
ds_obs_binned.isel(time=10).ssh_obs.plot()

#### Movie (GIF)

In [None]:
# create_movie(ds_obs_binned.ssh_obs, "ssh_jasonlike", framedim="time", cmap="viridis")

### NADIR-Like

In [None]:
# NADIR-like SSH observations: every file matching the glob is opened and
# concatenated along "time"
ds_obs = xr.open_mfdataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_nadirlike/ssh_obs*.nc",
    combine="nested",
    concat_dim="time",
    parallel=True,
    preprocess=None,
    engine="netcdf4",
)


# order the concatenated samples chronologically
ds_obs = ds_obs.sortby("time")

ds_obs = correct_coordinate_labels(ds_obs)

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
central_date = np.datetime64("2012-10-22")
num_days = 1

delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

#### Demo

In [None]:
central_date = np.datetime64("2012-10-22")
num_days = 10
delta_t = np.timedelta64(num_days, "D")
variable = "ssh_obs"

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# `bin_observations` is never imported or defined in this notebook; the imported
# helper with the same argument order is `bin_observations_xr`.
ds_obs_binned = bin_observations_xr(
    ds_obs=ds_obs,
    ds_ref=ds_field,
    variable="ssh_obs",
    time_buffer=np.timedelta64(12, "h"),
)

In [None]:
ds_obs_binned.isel(time=10).ssh_obs.plot()

#### Movie (GIF)

In [None]:
# create_movie(ds_obs_binned.ssh_obs, "ssh_nadirlike", framedim="time", cmap="viridis")