In [None]:
import sys, os
from pyprojroot import here

# search upwards from the working directory for the folder holding a ".root" marker file
root = here(project_files=[".root"])


# append the project root to sys.path
# (presumably so the `inr4ssh` imports in the next cell resolve — confirm)
sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
!ls /Volumes/EMANS_HDD/data/dc21b/results/

In [None]:
url = "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_MIOST.nc"


def preprocess(ds):
    """Restrict ``ds`` to the 2012-10-22 .. 2012-12-02 time window.

    Args:
        ds: object exposing ``.sel(time=...)`` (e.g. an xarray Dataset).

    Returns:
        The result of ``ds.sel`` for that time slice.
    """
    window_start = np.datetime64("2012-10-22")
    window_end = np.datetime64("2012-12-02")
    return ds.sel(time=slice(window_start, window_end))


# open the evaluation field as stored on disk; the `preprocess` helper defined
# above is not applied here, so no time subsetting takes place
ds_field = xr.open_dataset(url)

# (disabled) manual time subset, variable renaming and daily resampling
# ds_field = ds_field.sel(
#     time=slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
# )
#
# ds_field = (
#     ds_field.rename({"lon": "longitude"})
#     .rename({"lat": "latitude"})
#     .rename({"sossheig": "ssh"})
# )
#
# ds_field = ds_field.resample(time="1D").mean()

# normalise coordinate labels with the project helper
# NOTE(review): later cells rely on "longitude"/"latitude"/"time" names — confirm the helper provides them
ds_field = correct_coordinate_labels(ds_field)

# show the dataset summary as the cell output
ds_field

In [None]:
# create_movie(ds_field.ssh, "ssh_dc21b_siren", framedim="time", cmap="viridis")

## Alongtrack Observations

In [None]:
url = "/Volumes/EMANS_HDD/data/dc21b/train/dt_gulfstream_*.nc"


def preprocess(ds):
    """Select the 2012-10-22 .. 2012-12-02 time range from ``ds``.

    Args:
        ds: object exposing ``.sel(time=...)`` (e.g. an xarray Dataset).

    Returns:
        The result of ``ds.sel`` for that time slice.
    """
    start, stop = (np.datetime64(day) for day in ("2012-10-22", "2012-12-02"))
    return ds.sel(time=slice(start, stop))


# open every file matching the glob and concatenate along "time";
# preprocess=None means the `preprocess` helper above is NOT applied
ds_obs = xr.open_mfdataset(url, preprocess=None, combine="nested", concat_dim="time")

# (disabled) leftover copied from the evaluation-field cell — note that it
# operates on ds_field, not on ds_obs
# ds_field = ds_field.sel(
#     time=slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
# )
#
# ds_field = (
#     ds_field.rename({"lon": "longitude"})
#     .rename({"lat": "latitude"})
#     .rename({"sossheig": "ssh"})
# )
#
# ds_field = ds_field.resample(time="1D").mean()

# normalise coordinate labels with the project helper
ds_obs = correct_coordinate_labels(ds_obs)

# order the concatenated samples along the time coordinate
ds_obs = ds_obs.sortby("time")

# show the dataset summary as the cell output
ds_obs

In [None]:
ds_obs

In [None]:
from inr4ssh._src.viz.obs import plot_obs_demo

In [None]:
ds_obs.time.min().values, ds_obs.time.max().values

In [None]:
# time window and variable handed to the demo plot
# NOTE(review): delta_t is presumably a half-width around central_date — confirm in plot_obs_demo
central_date = np.datetime64("2017-01-20")
num_days = 1
delta_t = np.timedelta64(num_days, "D")
variable = "sla_unfiltered"

# replace the time coordinate with its pandas datetime conversion;
# this rebinds ds_obs["time"], so later cells see the converted coordinate
ds_obs["time"] = pd.to_datetime(ds_obs["time"].values)

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

In [None]:
ds_obs.sel(time="2017-01-20")

In [None]:
ds_obs.sel(time="2017-1-20")

In [None]:
ds_obs.time.min().values, ds_obs.time.max().values

In [None]:
# bin the "sla_filtered" along-track variable, passing ds_field as the reference
# dataset and a 12-hour time step (the demo plot above used "sla_unfiltered")
# NOTE(review): the exact binning rule lives in bin_observations — confirm there
ds_obs_binned = bin_observations(
    ds_obs, ds_field, "sla_filtered", np.timedelta64(12, "h")
)

In [None]:
ds_obs_binned

In [None]:
ds_obs_binned = ds_obs_binned.rename({"sla_filtered": "ssh"})

In [None]:
ds_obs_binned

In [None]:
# binned SSH for 2017-01-20 on a fixed colour range, with axis labels and title blanked
fig, ax = plt.subplots()
ds_obs_binned["ssh"].sel(time="2017-01-20").plot(
    ax=ax, cmap="viridis", vmin=-1.3, vmax=1.3
)
ax.set(xlabel="", ylabel="", title="")
fig.tight_layout()
plt.show()

In [None]:
# create_movie(ds_obs_binned.ssh, "ssh_dc21b_obs", framedim="time", cmap="viridis")

In [None]:
def count_num_obs(ds, central_date, delta_t):
    """Count the array elements inside ``central_date ± delta_t``.

    Args:
        ds: array-like with a ``time`` dimension supporting ``.sel`` and
            ``.drop_duplicates`` (e.g. an xarray DataArray).
        central_date: centre of the time window.
        delta_t: half-width of the time window.

    Returns:
        int: number of elements over all dimensions (no NaN filtering is
        applied) left after slicing and dropping duplicated time stamps.
    """
    window = slice(central_date - delta_t, central_date + delta_t)
    unique_in_window = ds.sel(time=window).drop_duplicates(dim="time")
    # every element counts, so the size of the materialised array is the answer
    return unique_in_window.values.size

In [None]:
# count field elements inside a ±10,000-day window around the reference date
central_date = np.datetime64("2012-10-22")
num_days = 10_000
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_field["ssh"], central_date, delta_t)

In [None]:
ds_field.isel(time=0).ssh.plot.imshow()

### Density

In [None]:
# fig, ax = plt.subplots()
# sns.kdeplot(
#     # data=ds_field.ssh.values.flatten(),
#     # data=np.log(ds_field.ssh_grad.values.flatten()),
#     data=np.log(ds_field.ssh_lap.values.flatten()),
#     cumulative=True, common_norm=False, common_grid=True,
#     ax=ax
# )
# # ax.set_xlabel("SSH [m]")
# # ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
# ax.set_ylabel("Cumulative Density")
# plt.show()

In [None]:
# fig, ax = plt.subplots()
# sns.kdeplot(
#     # data=ds_field.ssh.values.flatten(),
#     # data=np.log(ds_field.ssh_grad.values.flatten()),
#     data=np.log(ds_field.ssh_lap.values.flatten()),
#     cumulative=False, common_norm=False, common_grid=True,
#     ax=ax
# )
# # ax.set_xlabel("SSH [m]")
# # ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
# ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
# ax.set_ylabel("Density")
# plt.show()

#### Movie (GIF)

In [None]:
# create_movie(ds_field.ssh, "ssh_field", framedim="time", cmap="viridis")

#### Gradients/Laplacian

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# store the result of calculate_gradient on SSH (over longitude/latitude) as "ssh_grad"
ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")
# NOTE(review): "ssh_lap" is produced by calling calculate_gradient again on
# "ssh_grad"; the imported calculate_laplacian is not called anywhere in this
# notebook. Confirm this is intended rather than calculate_laplacian on "ssh".
ds_field["ssh_lap"] = calculate_gradient(ds_field["ssh_grad"], "longitude", "latitude")


# create_movie(ds_field.ssh_grad, "ssh_field_grad", framedim="time", cmap="Spectral_r")
# pass log(ssh_lap) to create_movie with "time" as the frame dimension;
# np.log gives NaN for negative inputs and -inf for zeros
create_movie(np.log(ds_field.ssh_lap), "ssh_field_lap", framedim="time", cmap="RdBu_r")

### PSD

In [None]:
# re-apply the coordinate-label helper (ds_field already went through it when loaded)
ds_field_psd = correct_coordinate_labels(ds_field)

# keep only the ssh variable
ds_field_psd = ds_field_psd.ssh

# correct units, degrees -> meters, using a flat factor of 111e3 per degree on
# both axes (longitude is not scaled by latitude)
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# compute the spectrum with psd_isotropic; ds_field_psd now holds the PSD, not the field
ds_field_psd = psd_isotropic(ds_field_psd)

In [None]:
# plot the PSD against the radial frequency scaled by 1e3
# NOTE(review): presumably converts cycles/m to cycles/km — confirm against plot_psd_isotropic
fig, ax = plot_psd_isotropic(ds_field_psd.freq_r.values * 1e3, ds_field_psd.values)
plt.tight_layout()
plt.show()

### PSD - Spatial-Temporal

In [None]:
# re-apply the coordinate-label helper (ds_field already went through it when loaded)
ds_field_psd = correct_coordinate_labels(ds_field)

# grab the SSH gradient; requires the "ssh_grad" variable added in the
# Gradients/Laplacian cell
ds_field_psd = ds_field_psd.ssh_grad

# evaluate any lazy computation so the values are held in memory
ds_field_psd = ds_field_psd.compute()

# correct units, degrees -> meters, using a flat factor of 111e3 per degree on
# both axes (longitude is not scaled by latitude)
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

time_norm = np.timedelta64(1, "D")
# express the time coordinate as days elapsed since the first time step
ds_field_psd["time"] = (ds_field_psd.time - ds_field_psd.time[0]) / time_norm

# compute the spectrum with psd_spacetime_dask; ds_field_psd now holds the PSD
ds_field_psd = psd_spacetime_dask(ds_field_psd)

In [None]:
# plot the space-time PSD; the longitude frequency is scaled by 1e3
# NOTE(review): presumably cycles/m -> cycles/km — confirm against the plotting helper
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd.freq_longitude * 1e3,
    ds_field_psd.freq_time,
    ds_field_psd,
)

plt.tight_layout()
plt.show()

In [None]:
# fig, ax, cbar = plot_psd_spacetime_wavenumber(
#     ds_field_psd.freq_longitude * 1e3,
#     ds_field_psd.freq_time,
#     ds_field_psd,
# )

# plt.tight_layout()
# plt.show()

## Observations

In [None]:
# # grab ssh
# ds_field_psd = ds_field.ssh_grad

# # correct units, degrees -> meters
# ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
# ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# # calculate
# ds_field_psd = psd_isotropic(ds_field_psd)

# fig, ax = plot_isotropic_psd(ds_field_psd, freq_scale=1e3)
# ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
# plt.tight_layout()
# plt.show()

### Missing Time

In [None]:
# rebinds ds_obs: the along-track dataset loaded earlier is replaced from here on
# NOTE(review): hard-coded absolute path under /Users/eman — consider a configurable data directory
ds_obs = xr.open_dataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_fullfields/ssh_obs_fullfields.nc"
)


# normalise coordinate labels with the project helper
ds_obs = correct_coordinate_labels(ds_obs)

# rename the variable, presumably so it does not clash with ds_field's "ssh"
# when the two datasets are merged in a later cell
ds_obs = ds_obs.rename({"ssh": "ssh_obs"})

# show the dataset summary as the cell output
ds_obs

In [None]:
# count ssh_obs elements within ±100 days of 2012-10-22
# (count_num_obs counts every element and applies no NaN filtering)
central_date = np.datetime64("2012-10-22")
num_days = 100
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

In [None]:
# combine the evaluation field and the observations into one dataset;
# ds_obs is rebound, so re-running this cell merges ds_field into the merged result again
ds_obs = xr.merge([ds_field, ds_obs])

#### Movie (GIF)

In [None]:
# create_movie(ds_obs.ssh_obs, "ssh_missing_time", framedim="time", cmap="viridis")

In [None]:
# !ls /Users/eman/code_projects/data/osse_2022b/dc_qg_obs_nadirlike/

### Jason-Like

In [None]:
# rebinds ds_obs to the Jason-like observations (replaces the merged dataset above)
# NOTE(review): hard-coded absolute path under /Users/eman — consider a configurable data directory
ds_obs = xr.open_dataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_jasonlike/ssh_obs_jasonlike.nc"
)

# order the samples along the time coordinate
ds_obs = ds_obs.sortby("time")

# normalise coordinate labels with the project helper
ds_obs = correct_coordinate_labels(ds_obs)

# NOTE(review): the rename is disabled, yet later cells read `ssh_obs`;
# the file must already provide a variable with that name — confirm
# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# count ssh_obs elements within ±100 days of 2012-10-22
# (count_num_obs counts every element and applies no NaN filtering)
central_date = np.datetime64("2012-10-22")
num_days = 100
delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

In [None]:
# demo plot of "ssh_obs" around 2012-10-22 with a 1-day delta
# NOTE(review): how delta_t is applied around central_date is defined in plot_obs_demo — confirm
central_date = np.datetime64("2012-10-22")
num_days = 1
delta_t = np.timedelta64(num_days, "D")
variable = "ssh_obs"

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# bin "ssh_obs", passing ds_field as the reference dataset and a 12-hour time step;
# rebinds ds_obs_binned (previously the binned along-track data)
ds_obs_binned = bin_observations(ds_obs, ds_field, "ssh_obs", np.timedelta64(12, "h"))

In [None]:
ds_obs_binned.isel(time=10).ssh_obs.plot()

#### Movie (GIF)

In [None]:
# create_movie(ds_obs_binned.ssh_obs, "ssh_jasonlike", framedim="time", cmap="viridis")

### NADIR-Like

In [None]:
# rebinds ds_obs to the NADIR-like observations: every file matching the glob is
# opened with the netcdf4 engine and concatenated along "time" (no preprocess hook)
# NOTE(review): hard-coded absolute path under /Users/eman — consider a configurable data directory
ds_obs = xr.open_mfdataset(
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_nadirlike/ssh_obs*.nc",
    combine="nested",
    concat_dim="time",
    parallel=True,
    preprocess=None,
    engine="netcdf4",
)


# order the concatenated samples along the time coordinate
ds_obs = ds_obs.sortby("time")

# normalise coordinate labels with the project helper
ds_obs = correct_coordinate_labels(ds_obs)

# NOTE(review): the rename is disabled, yet later cells read `ssh_obs`;
# the files must already provide a variable with that name — confirm
# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# count ssh_obs elements within ±1 day of 2012-10-22
# (count_num_obs counts every element and applies no NaN filtering)
central_date = np.datetime64("2012-10-22")
num_days = 1

delta_t = np.timedelta64(num_days, "D")

count_num_obs(ds_obs.ssh_obs, central_date, delta_t)

#### Demo

In [None]:
# demo plot of "ssh_obs" around 2012-10-22 with a 10-day delta
# NOTE(review): how delta_t is applied around central_date is defined in plot_obs_demo — confirm
central_date = np.datetime64("2012-10-22")
num_days = 10
delta_t = np.timedelta64(num_days, "D")
variable = "ssh_obs"

plot_obs_demo(ds_obs, central_date, delta_t, variable, verbose=True)

#### Gridded Dataset

In [None]:
# bin "ssh_obs", passing ds_field as the reference dataset and a 12-hour time step;
# rebinds ds_obs_binned (previously the binned Jason-like data)
ds_obs_binned = bin_observations(ds_obs, ds_field, "ssh_obs", np.timedelta64(12, "h"))

In [None]:
ds_obs_binned.isel(time=10).ssh_obs.plot()

#### Movie (GIF)

In [None]:
# create_movie(ds_obs_binned.ssh_obs, "ssh_nadirlike", framedim="time", cmap="viridis")