# Viz + Data Challenge 2021a

In [None]:
import sys, os
from pyprojroot import here

# Search upward ("spider up") for a directory containing a `.root` marker
# file and treat it as the project root.
root = here(project_files=[".root"])


# Make the project root importable (needed for the `inr4ssh` imports below).
# Guard so that re-running this cell does not append duplicate entries.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# Start from seaborn's defaults, then scale fonts for the "talk" context.
sns.reset_defaults()
sns.set_context("talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train
from inr4ssh._src.viz.movie import create_movie

# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Reload modules (e.g. the project's `inr4ssh` package) before executing each cell,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Observations

In [None]:
# Directory holding the training altimetry files read by
# `load_ssh_altimetry_data_train`.
# NOTE(review): machine-specific absolute path — adjust for your setup.
train_data_dir = "/Volumes/EMANS_HDD/data/dc21b/train"
# train_data_dir = "<path/to/dc21b/train>"

ds_obs = load_ssh_altimetry_data_train(train_data_dir)

In [None]:
# Name of the observation variable analysed throughout this notebook.
variable = "sla_unfiltered"

In [None]:
# temporal subset: restrict the observations to the 2017 study window,
# without any extra time buffer (time_buffer=0.0)
ds_obs = temporal_subset(
    ds_obs,
    time_min=np.datetime64("2017-01-01"),
    time_max=np.datetime64("2018-01-01"),
    # shorter one-month window, handy for quick experiments:
    # time_min=np.datetime64("2017-01-01"),
    # time_max=np.datetime64("2017-02-01"),
    time_buffer=0.0,
    time_buffer_order="D",
)

In [None]:
# Keep latitude, longitude and the target variable; demote any non-index
# coordinates to data variables, cast to float32 and load into memory.
ds_obs = (
    ds_obs[["latitude", "longitude", variable]]
    .reset_coords()
    .astype("f4")
    .load()
)

In [None]:
# Inspect the subset observation dataset.
ds_obs

In [None]:
# Flatten to a pandas DataFrame. The index is deliberately kept (not reset):
# the daily grouping below uses its `time` level.
df_obs = ds_obs.to_dataframe()  # .reset_index()

In [None]:
# Inspect the flattened observations.
df_obs

In [None]:
# Average the variable per exact (latitude, longitude) position and calendar
# day; daily bins are taken from the `time` index level.
df_sla_mean = (
    df_obs
    .groupby(["latitude", "longitude", pd.Grouper(level="time", freq="D")])[variable]
    .agg("mean")
    .reset_index()
)

In [None]:
# df_sla_mean.hvplot.scatter(
#     x='longitude', y='latitude', groupby='time',
#     datashade=True, #coastline=True
#     # tiles=True
# )

In [None]:
import pyinterp

# Study domain in degrees; longitudes use the [0, 360) convention (285-315 E).
lon_min, lon_max = 285.0, 315.0
lat_min, lat_max = 23.0, 53.0
lon_buffer = 1.0

# Resolution of the regular binning grid, in degrees.
bin_lon_step = bin_lat_step = 0.1

In [None]:
# Regular binning grid over the study domain; both axes include their upper
# bound (previously the longitude axis stopped one step short of lon_max while
# the latitude axis reached lat_max).
# np.linspace with an explicit point count is used instead of np.arange with a
# float step, whose length depends on rounding: e.g. a stop of
# lon_max + bin_lon_step yields an extra point at ~315.1.
n_lon = int(round((lon_max - lon_min) / bin_lon_step)) + 1
n_lat = int(round((lat_max - lat_min) / bin_lat_step)) + 1
binning = pyinterp.Histogram2D(
    # is_circle=True marks the longitude axis as angular (degrees) for pyinterp
    pyinterp.Axis(np.linspace(lon_min, lon_max, n_lon), is_circle=True),
    pyinterp.Axis(np.linspace(lat_min, lat_max, n_lat)),
)

In [None]:
# Inspect the (still empty) binning grid.
binning

In [None]:
# Sanity check: longitude-axis size, latitude-axis size, and the shape of the
# binned "mean" field.
(
    binning.x[:].shape,
    binning.y[:].shape,
    binning.variable("mean").shape,
)

In [None]:
def create_xarray(grid, lon_coord, lat_coord, time_coord, var_name="ssh"):
    """Wrap a gridded field in an ``xarray.Dataset``.

    Parameters
    ----------
    grid : array-like
        Values with dimensions ordered ``(time, latitude, longitude)``.
    lon_coord : array-like
        1-D longitude values (length of the last ``grid`` dimension).
    lat_coord : array-like
        1-D latitude values (length of the middle ``grid`` dimension).
    time_coord : array-like
        1-D time values (length of the first ``grid`` dimension).
    var_name : str, default "ssh"
        Name given to the data variable.

    Returns
    -------
    xarray.Dataset
        Dataset with a single data variable ``var_name`` and dimension
        coordinates ``time``, ``latitude`` and ``longitude``.
    """
    return xr.Dataset(
        data_vars={var_name: (("time", "latitude", "longitude"), grid)},
        coords={
            "time": ("time", time_coord),
            "latitude": ("latitude", lat_coord),
            "longitude": ("longitude", lon_coord),
        },
    )

In [None]:
from tqdm.notebook import tqdm

# One single-time Dataset per day: bin that day's observations onto the
# regular grid and keep the per-cell mean.
ds_xr = []

for day, df_day in tqdm(df_sla_mean.groupby("time")):

    # accumulate this day's observations into the 2D histogram
    binning.push(
        df_day.longitude,
        df_day.latitude,
        df_day[variable],
    )

    # the "mean" field is indexed (longitude, latitude): transpose it to
    # (latitude, longitude) and add a leading time axis of length 1
    ids = create_xarray(
        binning.variable("mean").T[None, :], binning.x[:], binning.y[:], [day]
    )

    # reset the histogram so the next day starts from an empty grid
    binning.clear()

    # collect the per-day Dataset (concatenated along "time" afterwards)
    ds_xr.append(ids)

In [None]:
# Stack the per-day Datasets into a single (time, latitude, longitude) Dataset.
ds_xr = xr.concat(ds_xr, dim="time")

In [None]:
# correct longitude domain (inr4ssh preprocessing helper), applied the same
# way to the model reconstructions further down
ds_xr = correct_longitude_domain(ds_xr)

In [None]:
# Inspect the gridded daily observations.
ds_xr

In [None]:
# save_path = "./"
# create_movie(ds_xr.ssh, f"obs", "time", cmap="viridis", file_path=save_path)

In [None]:
# Image of the binned observations.
# Optional extras: groupby="time", rasterize=True.
ds_xr.ssh.hvplot.image(
    cmap="viridis",
    height=400,
    width=500,
    x="longitude",
    y="latitude",
)

In [None]:
import powerspec as ps

## Previous Work: Mapping Reconstructions

In [None]:
# Mapping method whose reconstruction is loaded below; the trailing comment
# lists the other available choices.
model = "DYMOST"  # "MIOST" # "DUACS" # "BASELINE" # "BFN" # "4DVARNET" #

In [None]:
# List the available reconstruction files.
!ls /Volumes/EMANS_HDD/data/dc21b/results/

In [None]:
# Path to the reconstruction of the selected model.
# NOTE(review): machine-specific absolute path — adjust for your setup.
data_dir = f"/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_{model}.nc"

In [None]:
# Open the selected model's reconstruction.
ds = xr.open_dataset(data_dir)

#### Coordinate Corrections

In [None]:
# Fix the coordinate labels first, then the longitude domain
# (both are inr4ssh preprocessing helpers).
ds = correct_longitude_domain(correct_coordinate_labels(ds))

#### Time Period

**Daily Mean**

In [None]:
# Average the reconstruction to one field per day.
ds = ds.resample({"time": "1D"}).mean()

#### Spatio-Temporal Subset

In [None]:
# temporal subset: same 2017 window as the observations, but with a
# 7-unit buffer (time_buffer_order="D") — presumably 7 days on each side; confirm
ds = temporal_subset(
    ds,
    time_min=np.datetime64("2017-01-01"),
    time_max=np.datetime64("2018-01-01"),
    time_buffer=7.0,
    time_buffer_order="D",
)

In [None]:
# spatial subset: reuse the study-domain constants defined for the binning
# grid so the model fields and the binned observations cover the same area.
# NOTE(review): 4DVARNET is not subset — presumably its output already covers
# only the study domain; confirm.
if model != "4DVARNET":
    ds = spatial_subset(
        ds,
        lon_min=lon_min,
        lon_max=lon_max,
        lon_buffer=lon_buffer,
        lat_min=lat_min,
        lat_max=lat_max,
        lat_buffer=1.0,
    )

In [None]:
# Inspect the corrected, daily, subset reconstruction.
ds

In [None]:
# save_path = "./"
# create_movie(ds.ssh, f"field_{model.lower()}", "time", cmap="viridis", file_path=save_path)

#### SSH

In [None]:
# Image of the reconstructed SSH.
# Optional extras: groupby="time", rasterize=True.
ds.ssh.hvplot.image(
    cmap="viridis",
    height=400,
    width=500,
    x="longitude",
    y="latitude",
)

#### Gradient

In [None]:
# SSH gradient over (longitude, latitude), computed by the inr4ssh
# finite-difference operator `calculate_gradient`.
ds["ssh_grad"] = calculate_gradient(ds["ssh"], "longitude", "latitude")

In [None]:
# create_movie(ds.ssh_grad, f"field_{model.lower()}_grad", "time", cmap="Spectral_r", file_path=save_path)
#

In [None]:
# Image of the SSH gradient field.
# Optional extras: groupby="time", rasterize=True.
ds.ssh_grad.hvplot.image(
    cmap="Spectral_r",
    height=400,
    width=500,
    x="longitude",
    y="latitude",
)

#### (Norm) Laplacian

In [None]:
# SSH Laplacian over (longitude, latitude), computed by the inr4ssh
# finite-difference operator `calculate_laplacian`.
ds["ssh_lap"] = calculate_laplacian(ds["ssh"], "longitude", "latitude")

In [None]:
# create_movie(ds.ssh_lap, f"field_{model.lower()}_lap", "time", cmap="RdBu_r", file_path=save_path)

In [None]:
# Image of the SSH Laplacian (diverging colormap).
# Optional extras: groupby="time", rasterize=True.
ds.ssh_lap.hvplot.image(
    cmap="RdBu_r",
    height=400,
    width=500,
    x="longitude",
    y="latitude",
)