In [None]:
import sys, os
from pyprojroot import here

# search upwards for the directory that contains the ".root" marker file
root = here(project_files=[".root"])

# make the project root importable; the membership check keeps sys.path free
# of duplicate entries when this cell is executed more than once
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# restore seaborn's default settings, then switch to the "talk" plotting
# context with fonts scaled by 0.7
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd import plot_isotropic_psd

# notebook setup: inline matplotlib figures, and the autoreload extension in
# mode 2 so edited modules are re-imported before each cell executes
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# NOTE(review): absolute, machine-specific path; consider a configurable data directory
eval_field_glob = "/Users/eman/code_projects/data/osse_2022b/dc_qg_eval/dc_qg_eval_*.nc"

# open every file matching the glob as a single dataset and pass it through
# the project helper that corrects the coordinate labels
ds_field = correct_coordinate_labels(xr.open_mfdataset(eval_field_glob))

In [None]:
# NOTE(review): presumably (grid points per side)^2 x number of time steps — confirm
200 * 200 * 43

In [None]:
# image plot of the `ssh` variable at the first time index
ds_field.isel(time=0).ssh.plot.imshow()

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# gradient of the SSH field with respect to the two horizontal coordinates
ds_field["ssh_grad"] = calculate_gradient(ds_field["ssh"], "longitude", "latitude")

# Laplacian of the SSH field, computed with the dedicated operator on the raw
# field rather than by applying calculate_gradient to `ssh_grad`.
# NOTE(review): assumes calculate_laplacian accepts the same
# (data_array, x_coord, y_coord) arguments as calculate_gradient — confirm
ds_field["ssh_lap"] = calculate_laplacian(ds_field["ssh"], "longitude", "latitude")

In [None]:
# movies are framed along the "time" dimension; only the Laplacian movie is
# enabled, the SSH and gradient variants are kept commented out
# create_movie(ds_field.ssh, "ssh_field", framedim="time", cmap="viridis")
# create_movie(ds_field.ssh_grad, "ssh_field_grad", framedim="time", cmap="Spectral_r")
create_movie(ds_field.ssh_lap, "ssh_field_lap", framedim="time", cmap="RdBu_r")

In [None]:
# factor used to turn the coordinate values from degrees into metres
DEG_TO_M = 111e3

# take the SSH variable and rescale its two horizontal coordinates, longitude
# first and latitude second
ds_field_psd = ds_field.ssh
ds_field_psd = ds_field_psd.assign_coords(longitude=ds_field_psd.longitude * DEG_TO_M)
ds_field_psd = ds_field_psd.assign_coords(latitude=ds_field_psd.latitude * DEG_TO_M)

# isotropic power spectral density of the rescaled field
ds_field_psd = psd_isotropic(ds_field_psd)

# draw the spectrum
fig, ax = plot_isotropic_psd(ds_field_psd, freq_scale=1e3)
plt.tight_layout()
plt.show()

In [None]:
# # grab ssh
# ds_field_psd = ds_field.ssh_grad

# # correct units, degrees -> meters
# ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
# ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3

# # calculate
# ds_field_psd = psd_isotropic(ds_field_psd)

# fig, ax = plot_isotropic_psd(ds_field_psd, freq_scale=1e3)
# ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
# plt.tight_layout()
# plt.show()

### Obs: Missing Time

In [None]:
# NOTE(review): absolute, machine-specific path. `ds_field` is rebound here,
# so from this cell on it no longer refers to the evaluation field.
obs_fullfields_path = (
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_fullfields/ssh_obs_fullfields.nc"
)

# open the file and run it through the coordinate-label correction helper
ds_field = correct_coordinate_labels(xr.open_dataset(obs_fullfields_path))

ds_field

In [None]:
# movie of the `ssh` variable of the full-field observations, framed along "time"
create_movie(ds_field.ssh, "ssh_field", framedim="time", cmap="viridis")

### NADIR-Like

In [None]:
# NOTE(review): absolute, machine-specific path
jasonlike_path = (
    "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_jasonlike/ssh_obs_jasonlike.nc"
)

# alternative inputs, kept for reference:
# ds_obs = xr.open_mfdataset(
#     "/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_nadirlike/ssh_obs*.nc",
#     combine="nested",
#     concat_dim="time",
#     parallel=True,
#     preprocess=None,
#     engine="netcdf4",
# )

# ds_obs = xr.open_dataset("/Users/eman/code_projects/data/osse_2022b/dc_qg_obs_fullfields/ssh_obs_fullfields.nc")

# open the file, order it by time, then correct the coordinate labels
ds_obs = correct_coordinate_labels(xr.open_dataset(jasonlike_path).sortby("time"))

# ds_obs = ds_obs.rename({"ssh": "ssh_obs"})
ds_obs

In [None]:
# movie of the `ssh_obs` variable, framed along "time"
create_movie(ds_obs.ssh_obs, "ssh_missing_space", framedim="time", cmap="viridis")

In [None]:
# merge the full-field dataset and the observation dataset; the result is
# bound back to `ds_obs`, replacing the observation-only dataset
ds_obs = xr.merge([ds_field, ds_obs])

In [None]:
ds_obs

In [None]:
# pyinterp is imported here because this cell runs before the later
# `import pyinterp` cell; without it the cell fails on a fresh kernel
import pyinterp
from tqdm.notebook import tqdm

variable = "ssh_obs"

# 2D binning whose axes are built from the nav_lon / nav_lat values of ds_field
binning = pyinterp.Binning2D(
    pyinterp.Axis(ds_field.nav_lon.values), pyinterp.Axis(ds_field.nav_lat.values)
)
grid_day_dses = []

# calendar date of every entry along ds_obs.time; it does not change inside
# the loop, so it is computed once up front
obs_dates = pd.to_datetime(ds_obs.time.values).date

for t in tqdm(ds_field.time):
    # start every day from an empty set of bins
    binning.clear()
    # TODO: add some buffers
    # keep only the entries that share t's calendar date
    tds = ds_obs.isel(time=obs_dates == pd.to_datetime(t.values).date())

    values = np.ravel(tds[variable].values)
    lons = np.ravel(tds.longitude.values)
    lats = np.ravel(tds.latitude.values)
    # only finite samples are pushed into the bins
    msk = np.isfinite(values)
    binning.push(lons[msk], lats[msk], values[msk])
    # per-bin mean, transposed and given a leading length-1 time axis
    gridded = (("time", "latitude", "longitude"), binning.variable("mean").T[None, ...])
    grid_day_dses.append(
        xr.Dataset(
            {"gridded": gridded},
            {
                "time": [t.values],
                "latitude": np.array(binning.y),
                "longitude": np.array(binning.x),
            },
        )  # .astype('float32', casting='same_kind')
    )

# stack the per-day grids along time
tgt_ds = xr.concat(grid_day_dses, dim="time")

In [None]:
# NOTE(review): reuses the name "ssh_missing_space" already passed to
# create_movie for ds_obs.ssh_obs — confirm the earlier output may be replaced
create_movie(tgt_ds.gridded, "ssh_missing_space", framedim="time", cmap="viridis")

In [None]:
def custom_plotfunc(ds, fig, tt, *args, **kwargs):
    """Scatter-plot the non-NaN values of one time step onto the given figure.

    Parameters
    ----------
    ds
        Object supporting ``isel(time=...)``, ``min()`` and ``max()`` whose
        time slice converts via ``to_dataframe()`` into a table with
        ``longitude``, ``latitude`` and ``gridded`` columns.
    fig
        Figure to draw on; an axes is created on it with ``fig.subplots()``.
    tt
        Integer position along ``time`` of the frame to draw.
    *args, **kwargs
        Accepted for signature compatibility and ignored.

    Returns
    -------
    tuple
        Always ``(None, None)``.
    """
    # draw on the figure handed in by the caller; opening a new figure here
    # would leave `fig` empty and create one extra figure per frame
    ax = fig.subplots()

    # one time step as a flat table, with the rows that contain NaN removed
    frame = ds.isel(time=tt).to_dataframe().reset_index().dropna()

    # colour limits come from the whole series so every frame shares one scale
    ax.scatter(
        frame["longitude"],
        frame["latitude"],
        c=frame["gridded"],
        vmin=ds.min().values,
        vmax=ds.max().values,
    )

    return None, None

In [None]:
from xmovie import Movie

# build a Movie from the gridded series using custom_plotfunc as the frame
# function, then call preview with index 3
mov_custom = Movie(tgt_ds.gridded, custom_plotfunc)
mov_custom.preview(3)

In [None]:
# NOTE(review): fills NaN with NaN, which looks like a no-op — confirm whether
# this is relied on for a side effect or can be dropped
tgt_ds = tgt_ds.fillna(np.nan)

In [None]:
tgt_ds.isel(time=1).gridded

In [None]:
# cast the second time step to float64 before plotting it
tgt_ds.isel(time=1).gridded.astype(np.float64).plot(cmap="viridis")

In [None]:
create_movie(tgt_ds.gridded, "ssh_nadir", framedim="time", cmap="viridis")

In [None]:
# NOTE(review): this cell is the loop body of the next cell run at top level.
# It relies on `tds`, `t`, `binning` and `grid_day_dses` left in the kernel by
# other cells and reads `pred` / `lon` / `lat` from `tds` — confirm whether it
# is still needed, since it also overwrites `tgt_ds`.
values = np.ravel(tds.pred.values)
lons = np.ravel(tds.lon.values) - 360
lats = np.ravel(tds.lat.values)
msk = np.isfinite(values)
binning.push(lons[msk], lats[msk], values[msk])
gridded = (("time", "lat", "lon"), binning.variable("mean").T[None, ...])
grid_day_dses.append(
    xr.Dataset(
        {"gridded": gridded},
        {"time": [t.values], "lat": np.array(binning.y), "lon": np.array(binning.x)},
    ).astype("float32", casting="same_kind")
)
tgt_ds = xr.concat(grid_day_dses, dim="time")

In [None]:
import pyinterp

# NOTE(review): `swath_data` and `tgt_grid` are not defined in the visible
# part of this notebook — confirm where they come from
ds = swath_data[["pred", "lat", "lon", "time"]]
binning = pyinterp.Binning2D(
    pyinterp.Axis(tgt_grid.lon.values),
    pyinterp.Axis(tgt_grid.lat.values),
)
grid_day_dses = []

for t in tgt_grid.time:
    binning.clear()

    # boolean mask along ds.time: True where the entry shares t's calendar date
    same_day = pd.to_datetime(ds.time.values).date == pd.to_datetime(t.values).date()
    tds = ds.isel(time=same_day)

    # flatten the day's samples; 360 is subtracted from the longitudes
    values = tds.pred.values.ravel()
    lons = tds.lon.values.ravel() - 360
    lats = tds.lat.values.ravel()

    # push only the finite samples into the bins
    msk = np.isfinite(values)
    binning.push(lons[msk], lats[msk], values[msk])

    # per-bin mean, transposed and given a leading length-1 time axis
    gridded = (("time", "lat", "lon"), binning.variable("mean").T[None, ...])
    day_coords = {
        "time": [t.values],
        "lat": np.array(binning.y),
        "lon": np.array(binning.x),
    }
    day_ds = xr.Dataset({"gridded": gridded}, day_coords)
    grid_day_dses.append(day_ds.astype("float32", casting="same_kind"))

# stack the per-day grids along time
tgt_ds = xr.concat(grid_day_dses, dim="time")

In [None]:
variable = "ssh_obs"

# NOTE(review): `ds_jason` is not defined in the visible part of this notebook
df_jason = ds_jason.to_dataframe()

# group by latitude, longitude and daily bins of the "time" index level, then
# average the selected variable within each group
daily_keys = ["latitude", "longitude", pd.Grouper(freq="D", level="time")]
daily_mean = df_jason.groupby(daily_keys)[variable].mean()
df_jason_1D = daily_mean.reset_index()

In [None]:
# first rows of the daily-averaged table
df_jason_1D.head()

In [None]:
from tqdm.notebook import tqdm

ds_xr = []

# NOTE(review): `create_xarray` is not defined in the visible part of this
# notebook — confirm where it comes from
for itime in tqdm(df_jason_1D.groupby("time")):
    # itime is the (group key, group rows) pair yielded by the groupby

    # start every group from empty bins so that neither a previous group nor
    # values left in `binning` by another cell leak into this grid
    binning.clear()

    # do binning
    binning.push(
        itime[1].longitude,
        itime[1].latitude,
        itime[1][variable],
        # simple=True
    )

    # per-bin mean, transposed with a leading axis, plus the binning axes and
    # this group's key
    ids = create_xarray(
        binning.variable("mean").T[None, :], binning.x[:], binning.y[:], [itime[0]]
    )

    # collect the result; no early `break`, so every group is processed and
    # ds_xr is non-empty for the later concatenation
    ds_xr.append(ids)

In [None]:
# shapes of the x, y and z slices of the binning object
binning.x[:].shape, binning.y[:].shape, binning.z[:].shape

In [None]:
# `ids` is whatever the per-group loop assigned last
ids

In [None]:
# longitude column of the last group the loop visited
itime[1].longitude

### NADIR-Like

In [None]:
# concatenate the collected per-group objects along "time"
# NOTE(review): this rebinds `ds_xr` from the list to the concatenated result,
# so the cell presumably cannot be re-run without re-running the loop — confirm
ds_xr = xr.concat(ds_xr, dim="time")
ds_xr

In [None]:
ds_xr.ssh.min(), ds_xr.ssh.max()

In [None]:
ds_xr.isel(time=20).ssh

In [None]:
create_movie(ds_xr.ssh, "ssh_missing_space", framedim="time", cmap="viridis")

In [None]:
from xhistogram.xarray import histogram

In [None]:
# 200 evenly spaced values spanning the nav_lon range of ds_field
lon_lo = ds_field.nav_lon.min().values
lon_hi = ds_field.nav_lon.max().values
lon_bins = np.linspace(lon_lo, lon_hi, 200)

# 200 evenly spaced values spanning the nav_lat range of ds_field
lat_lo = ds_field.nav_lat.min().values
lat_hi = ds_field.nav_lat.max().values
lat_bins = np.linspace(lat_lo, lat_hi, 200)

# daily steps from the first to the last time stamp (the stop is pushed one
# day past the maximum so the last day is included)
one_day = np.timedelta64(1, "D")
time_bins = np.arange(
    ds_field.time.min().values,
    ds_field.time.max().values + one_day,
    one_day,
)

# the bin arrays must have the same shapes as the matching ds_field arrays
assert time_bins.shape == ds_field.time.shape
assert ds_field.nav_lat.shape == lat_bins.shape
assert ds_field.nav_lon.shape == lon_bins.shape

In [None]:
ds_jason.time

In [None]:
# histogram over the longitude, latitude and time of ds_jason, using the
# lon / lat / time bin arrays in that order
h = histogram(
    ds_jason.longitude,
    ds_jason.latitude,
    ds_jason.time,
    bins=[lon_bins, lat_bins, time_bins],
)

In [None]:
time_bins.shape, ds_field.time.shape

In [None]:
# 2-unit bins: 0..360 for longitude and -70..70 for latitude
# NOTE(review): this rebinds lon_bins / lat_bins from the earlier linspace
# values, and `ds_ll` is not defined in the visible part of this notebook
lon_bins = np.arange(0, 361, 2)
lat_bins = np.arange(-70, 71, 2)

# helps with memory management
ds_ll_chunked = ds_ll.chunk({"time": "5MB"})

# histogram weighted by sla_filtered squared, with missing values filled as 0
sla_variance = histogram(
    ds_ll_chunked.longitude,
    ds_ll_chunked.latitude,
    bins=[lon_bins, lat_bins],
    weights=ds_ll_chunked.sla_filtered.fillna(0.0) ** 2,
)

# unweighted histogram over the same bins
norm = histogram(
    ds_ll_chunked.longitude, ds_ll_chunked.latitude, bins=[lon_bins, lat_bins]
)