# OSSE 2022b - QG Simulations

In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree to find the project root (identified by a ".root" marker file)
root = here(project_files=[".root"])

# make the project root importable; the guard keeps re-runs of this cell from
# appending the same entry to sys.path again and again
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# restore seaborn's default settings, then use the "talk" plotting context with fonts scaled by 0.7
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# Evaluation field of the QG simulation: a multi-file glob opened with open_mfdataset.
# Alternative location (previously assigned and immediately overwritten):
# url = "/Users/eman/code_projects/data/osse_2022b/dc_qg_eval/dc_qg_eval_*.nc"
url = "/Volumes/EMANS_HDD/data/dc22b_osse/raw/dc_qg_eval/dc_qg_eval_*.nc"
ds_field = xr.open_mfdataset(url)

# rename the nav_* coordinates in a single call
ds_field = ds_field.rename({"nav_lon": "lon", "nav_lat": "lat"})

ds_field = correct_coordinate_labels(ds_field)

ds_field = correct_longitude_domain(ds_field)

ds_field

In [None]:
def count_num_obs(ds, central_date, delta_t, drop_duplicates=False, skipna=False):
    """Count the values of ``ds`` inside a time window centred on ``central_date``.

    Parameters
    ----------
    ds : xr.DataArray
        Variable with a ``time`` dimension. Label slicing is used, so ``time``
        is presumably expected to be sorted (callers sort with ``sortby``).
    central_date : np.datetime64
        Centre of the time window.
    delta_t : np.timedelta64
        Half-width of the window; values with
        ``central_date - delta_t <= time <= central_date + delta_t`` are kept.
    drop_duplicates : bool, default False
        If True, drop entries with repeated ``time`` labels before counting.
    skipna : bool, default False
        If True, count only non-missing values. The default counts every
        element in the window, including NaNs (the original behaviour).

    Returns
    -------
    int
        Number of elements (or non-missing elements) in the window.
    """
    tmin = central_date - delta_t
    tmax = central_date + delta_t

    ds = ds.sel(time=slice(tmin, tmax))

    if drop_duplicates:
        ds = ds.drop_duplicates(dim="time")

    if skipna:
        # notnull() stays lazy for dask-backed data; only the scalar sum is computed
        return int(ds.notnull().sum())

    # ``size`` is shape metadata, so counting no longer loads and copies the
    # whole window into memory (``ds.values.flatten()`` did)
    return int(ds.size)

In [None]:
# Number of SSH values in the evaluation field within +/- 100 days of 2012-10-22.
num_days = 100
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_field["ssh"], central_date, delta_t)

In [None]:
ds_field["ssh"].isel(time=0).plot.imshow()  # first SSH snapshot

#### Movie (GIF)

In [None]:
# create_movie(ds_field.ssh_lap, "ssh_field_lap", framedim="time", cmap="RdBu_r")

### Density

In [None]:
# ``ssh_grad`` is only created in the "Gradients/Laplacian" section further down,
# so running the notebook top-to-bottom failed here with an AttributeError.
# Compute it on demand if it does not exist yet.
if "ssh_grad" not in ds_field:
    ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")

# log of ssh_grad: zeros give -inf and missing/negative values give NaN, which
# the KDE cannot use, so keep finite values only
with np.errstate(divide="ignore", invalid="ignore"):
    log_ssh_grad = np.log(ds_field.ssh_grad.values.flatten())
log_ssh_grad = log_ssh_grad[np.isfinite(log_ssh_grad)]

fig, ax = plt.subplots()
sns.kdeplot(
    data=log_ssh_grad,
    cumulative=True,
    common_norm=False,
    common_grid=True,
    ax=ax,
)
# NOTE(review): the plotted quantity is log(ssh_grad); the kinetic-energy label
# below may be left over from another variable — confirm.
# ax.set_xlabel("SSH [m]")
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Cumulative Density")
plt.show()

In [None]:
# ``ssh_grad`` is only created in the "Gradients/Laplacian" section further down,
# so running the notebook top-to-bottom failed here with an AttributeError.
# Compute it on demand if it does not exist yet.
if "ssh_grad" not in ds_field:
    ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")

# log of ssh_grad: zeros give -inf and missing/negative values give NaN, which
# the KDE cannot use, so keep finite values only
with np.errstate(divide="ignore", invalid="ignore"):
    log_ssh_grad = np.log(ds_field.ssh_grad.values.flatten())
log_ssh_grad = log_ssh_grad[np.isfinite(log_ssh_grad)]

fig, ax = plt.subplots()
sns.kdeplot(
    data=log_ssh_grad,
    cumulative=False,
    common_norm=False,
    common_grid=True,
    ax=ax,
)
# NOTE(review): the plotted quantity is log(ssh_grad); the kinetic-energy label
# below may be left over from another variable — confirm.
# ax.set_xlabel("SSH [m]")
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
plt.show()

#### Gradients/Laplacian

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# gradient field of SSH with respect to longitude/latitude
ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")

# Laplacian of SSH. This previously called ``calculate_gradient(ssh_grad)``,
# which is the gradient of the gradient field rather than a Laplacian, and the
# imported ``calculate_laplacian`` was never used.
# NOTE(review): assumes calculate_laplacian takes the same
# (field, x_coord, y_coord) arguments as calculate_gradient — confirm.
ds_field["ssh_lap"] = calculate_laplacian(ds_field["ssh"], "longitude", "latitude")

# create_movie(ds_field.ssh_grad, "ssh_field_grad", framedim="time", cmap="Spectral_r")
# A Laplacian takes both signs, so plot it directly with the diverging colormap;
# np.log would turn every negative value into NaN.
create_movie(ds_field.ssh_lap, "ssh_field_lap", framedim="time", cmap="RdBu_r")

### PSD - Isotropic

In [None]:
# Isotropic PSD of the SSH Laplacian field (``ssh_lap``), not of raw SSH.
ds_field_psd = correct_coordinate_labels(ds_field)["ssh_lap"]

# approximate degrees -> metres with 1 degree ~ 111 km (no cos(latitude) factor)
for coord_name in ("longitude", "latitude"):
    ds_field_psd[coord_name] = ds_field_psd[coord_name] * 111e3

ds_field_psd = psd_isotropic(ds_field_psd)

In [None]:
# Isotropic PSD against radial frequency; the 1e3 factor presumably converts
# cycles/m to cycles/km (coordinates were converted to metres above).
fig, ax = plot_psd_isotropic(
    ds_field_psd["freq_r"].values * 1e3,
    ds_field_psd.values,
)
# ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m")
# ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m")
plt.tight_layout()
plt.show()

### PSD - Spatial-Temporal

In [None]:
# Space-time PSD of the raw SSH field: extract SSH and load it into memory.
ds_field_psd = correct_coordinate_labels(ds_field).ssh.compute()

# approximate degrees -> metres with 1 degree ~ 111 km (no cos(latitude) factor)
for coord_name in ("longitude", "latitude"):
    ds_field_psd[coord_name] = ds_field_psd[coord_name] * 111e3

# express time as days elapsed since the first snapshot
time_norm = np.timedelta64(1, "D")
ds_field_psd["time"] = (ds_field_psd.time - ds_field_psd.time[0]) / time_norm

ds_field_psd = psd_spacetime_dask(ds_field_psd)

In [None]:
# Space-time PSD vs wavelength; spatial frequencies scaled by 1e3
# (presumably cycles/m -> cycles/km).
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd["freq_longitude"] * 1e3, ds_field_psd["freq_time"], ds_field_psd
)
# Optional tweaks:
# ax.set_xlim((1000, 10))
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m")
# cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m")

plt.tight_layout()
plt.show()

In [None]:
# Same space-time PSD, plotted against wavenumber instead of wavelength.
fig, ax, cbar = plot_psd_spacetime_wavenumber(
    ds_field_psd["freq_longitude"] * 1e3, ds_field_psd["freq_time"], ds_field_psd
)

plt.tight_layout()
plt.show()

## Observations

In [None]:
# # grab ssh
# ds_field_psd = ds_field.ssh_grad

# # correct units, degrees -> meters
# ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
# ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# # calculate
# ds_field_psd = psd_isotropic(ds_field_psd)

# fig, ax = plot_isotropic_psd(ds_field_psd, freq_scale=1e3)
# ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cyles/m")
# plt.tight_layout()
# plt.show()

### Jason-Like

In [None]:
# Single nadir track (Jason-1) from the 2020a NATL60 OSSE, time-sorted and relabelled.
ds_obs = correct_coordinate_labels(
    xr.open_dataset(
        "/Volumes/EMANS_HDD/data/dc20a_osse/test/preprocess/osse_2020a_natl60/2020a_SSH_mapping_NATL60_jason1.nc"
    ).sortby("time")
)

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# Number of Jason-1 values within +/- 100 days of 2012-10-22.
num_days = 100
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs["ssh_obs"], central_date, delta_t)

In [None]:
# Show the Jason-1 observations within +/- 1 day of 2012-10-22.
variable = "ssh_obs"
num_days = 1
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# Grid the Jason-1 observations onto the evaluation grid (12-hour time buffer).
ds_obs_binned = bin_observations(
    ds_obs,
    ds_field,
    "ssh_obs",
    np.timedelta64(12, "h"),
)

In [None]:
ds_obs_binned["ssh_obs"].isel(time=10).plot()  # 11th gridded time step

#### Movie (GIF)

In [None]:
create_movie(ds_obs_binned["ssh_obs"], "ssh_jason1", framedim="time", cmap="viridis")  # one frame per time step

### NADIR-Like (4 Nadir Altimeters)

In [None]:
# Four nadir altimeters (Jason-1, Envisat, Geosat-2, Topex/Poseidon interleaved)
# from the 2020a NATL60 OSSE, concatenated along time.
ds_files = [
    "/Volumes/EMANS_HDD/data/dc20a_osse/test/preprocess/osse_2020a_natl60/2020a_SSH_mapping_NATL60_jason1.nc",
    "/Volumes/EMANS_HDD/data/dc20a_osse/test/preprocess/osse_2020a_natl60/2020a_SSH_mapping_NATL60_envisat.nc",
    "/Volumes/EMANS_HDD/data/dc20a_osse/test/preprocess/osse_2020a_natl60/2020a_SSH_mapping_NATL60_geosat2.nc",
    "/Volumes/EMANS_HDD/data/dc20a_osse/test/preprocess/osse_2020a_natl60/2020a_SSH_mapping_NATL60_topex-poseidon_interleaved.nc",
]

ds_obs = xr.open_mfdataset(
    ds_files,
    combine="nested",
    concat_dim="time",
    parallel=True,
    preprocess=None,
    engine="netcdf4",
)
ds_obs = correct_coordinate_labels(ds_obs.sortby("time"))

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# Number of nadir values within +/- 1 day of 2012-10-22.
num_days = 1
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs["ssh_obs"], central_date, delta_t)

#### Demo

In [None]:
# Show the nadir observations within +/- 10 days of 2012-10-22.
variable = "ssh_obs"
num_days = 10
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# Grid the nadir observations onto the evaluation grid (12-hour time buffer).
ds_obs_binned = bin_observations(
    ds_obs,
    ds_field,
    "ssh_obs",
    np.timedelta64(12, "h"),
)

In [None]:
ds_obs_binned["ssh_obs"].isel(time=10).plot()  # 11th gridded time step

#### Movie (GIF)

In [None]:
create_movie(ds_obs_binned["ssh_obs"], "ssh_nadir4", framedim="time", cmap="viridis")  # one frame per time step

### SWOT 1 + NADIR 5

In [None]:
# SWOT-1 + 5-nadir observation file; the other configurations are kept
# commented out for quick switching.
ds_files = [
    # "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/nadir1.nc",
    # "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/nadir4.nc",
    # "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/swot1nadir1.nc",
    "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/swot1nadir5.nc",
]

ds_obs = xr.open_mfdataset(
    ds_files,
    combine="nested",
    concat_dim="time",
    parallel=True,
    preprocess=None,
    engine="netcdf4",
)
ds_obs = correct_coordinate_labels(ds_obs.sortby("time"))

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# Number of SWOT + nadir values within +/- 1 day of 2012-10-22.
num_days = 1
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs["ssh_model"], central_date, delta_t)

#### Demo

In [None]:
# Show the SWOT + nadir observations within +/- 5 days of 2012-10-22.
variable = "ssh_model"
num_days = 5
central_date = np.datetime64("2012-10-22")
delta_t = np.timedelta64(num_days, "D")

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
import pyinterp
from tqdm.notebook import tqdm


def bin_observations_swot(
    ds_obs: xr.Dataset, ds_ref: xr.Dataset, variable: str, time_buffer: np.timedelta64
) -> xr.Dataset:
    """Grid scattered observations onto ``ds_ref``, one daily-mean map per time step.

    For each time stamp ``t`` of ``ds_ref``, the observations whose time falls
    on the same calendar day as ``t`` are averaged into the longitude/latitude
    cells of ``ds_ref`` using ``pyinterp.Binning2D``.

    Parameters
    ----------
    ds_obs : xr.Dataset
        Observations with a ``time`` dimension and ``longitude``, ``latitude``
        and ``variable`` entries.
    ds_ref : xr.Dataset
        Reference grid supplying the ``time``, ``longitude`` and ``latitude`` axes.
    variable : str
        Name of the observed variable to bin.
    time_buffer : np.timedelta64
        Currently unused: observations are selected by calendar day. It is kept
        so the call mirrors ``bin_observations``.

    Returns
    -------
    xr.Dataset
        ``variable`` on ``(time, latitude, longitude)`` as float32, holding the
        per-cell mean of the day's observations (presumably NaN in empty cells).
    """
    # binning grid built from the reference field's coordinate axes
    binning = pyinterp.Binning2D(
        pyinterp.Axis(ds_ref.longitude.values), pyinterp.Axis(ds_ref.latitude.values)
    )

    # calendar day of every observation; loop-invariant, so converted once
    obs_dates = pd.to_datetime(ds_obs.time.values).date

    # one single-time-step dataset per reference time, concatenated at the end
    ds_obs_binned = []

    for t in tqdm(ds_ref.time):
        binning.clear()

        # observations on the same calendar day as this reference time step
        # ids = np.where((np.abs(ds_obs.time.values - t.values) < 2.0 * time_buffer))[0]
        tds = ds_obs.isel(time=obs_dates == pd.to_datetime(t.values).date())

        # Read values and positions from the daily subset. The previous version
        # read them from the full ``ds_obs``, so every frame binned all
        # observations and ``tds`` went unused.
        values = np.ravel(tds[variable].values)
        # NOTE(review): the -360 shift assumes observation longitudes are in
        # [0, 360) while ds_ref uses [-180, 180) — confirm.
        lons = np.ravel(tds.longitude.values) - 360
        lats = np.ravel(tds.latitude.values)

        # keep finite values only
        msk = np.isfinite(values)

        # days without valid data push nothing and leave the cleared grid empty
        if msk.any():
            binning.push(lons[msk], lats[msk], values[msk])

        gridded = (
            ("time", "latitude", "longitude"),
            # transpose to (latitude, longitude) and add a leading time axis
            binning.variable("mean").T[None, ...],
        )

        # create gridded dataset for this time step
        ds_obs_binned.append(
            xr.Dataset(
                {variable: gridded},
                {
                    "time": [t.values],
                    "latitude": np.array(binning.y),
                    "longitude": np.array(binning.x),
                },
            ).astype("float32", casting="same_kind")
        )

    # concatenate final dataset
    ds_obs_binned = xr.concat(ds_obs_binned, dim="time")
    return ds_obs_binned

In [None]:
# Grid the SWOT + nadir observations onto the evaluation grid (daily bins).
ds_obs_binned = bin_observations_swot(
    ds_obs,
    ds_field,
    "ssh_model",
    np.timedelta64(12, "h"),
)

In [None]:
# ds_obs_binned.isel(time=0).values

In [None]:
ds_obs_binned["ssh_model"].isel(time=0).plot()  # first gridded time step

#### Movie (GIF)

In [None]:
# one frame per gridded time step
create_movie(ds_obs_binned["ssh_model"], "ssh_swot1nadir5", framedim="time", cmap="viridis")