# AlongTrack Data - SWOT

In [1]:
import autoroot
import typing as tp
from dataclasses import dataclass
import functools as ft
import numpy as np
import pandas as pd
import xarray as xr
import einops
from metpy.units import units
import pint_xarray
import xarray_dataclasses as xrdataclass
from oceanbench._src.datasets.base import XRDABatcher
from oceanbench._src.geoprocessing.spatial import transform_360_to_180
from oceanbench._src.geoprocessing.subset import where_slice
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker
import seaborn as sns

# Plot styling: reset seaborn to its defaults, then switch to the "talk"
# context with a reduced font scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# IPython autoreload extension, mode 2 (reload all modules before running code).
%load_ext autoreload
%autoreload 2


## Data

In [2]:
# List the files in the DC20a raw observation directory.
!ls "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/raw/dc_obs/"

2020a_SSH_mapping_NATL60_envisat.nc
2020a_SSH_mapping_NATL60_geosat2.nc
2020a_SSH_mapping_NATL60_jason1.nc
2020a_SSH_mapping_NATL60_karin_swot.nc
2020a_SSH_mapping_NATL60_nadir_swot.nc
2020a_SSH_mapping_NATL60_topex-poseidon_interleaved.nc


In [48]:
# Directory that holds the DC20a observation files. Kept in one place so
# the location only has to be changed once when the data moves.
DC20A_OBS_DIR = "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/raw/dc_obs"

# Files passed to the nadir loader (order is preserved because the files
# are later concatenated in list order).
files_nadir_dc20a = [
    f"{DC20A_OBS_DIR}/2020a_SSH_mapping_NATL60_{mission}.nc"
    for mission in (
        "jason1",
        "envisat",
        "geosat2",
        "topex-poseidon_interleaved",
        "nadir_swot",
    )
]

# Files passed to the swath (KaRIn) loader.
files_swot_dc20a = [
    f"{DC20A_OBS_DIR}/2020a_SSH_mapping_NATL60_karin_swot.nc",
]

# Raw, unprocessed SWOT dataset for inspection.
# NOTE(review): `ds_swot` is rebound to the preprocessed dataset in a later cell.
ds_swot = xr.open_dataset(files_swot_dc20a[0])

In [49]:
# Display the raw SWOT dataset opened with xr.open_dataset.
ds_swot

In [50]:
def remove_swath_dimension(ds, name: str="nC"):
    """Flatten a (time, swath) dataset onto a single sorted ``time`` dimension.

    Args:
        ds: xarray object with a ``time`` dimension and a dimension called
            ``name``.
        name: dimension to fold into ``time`` (default ``"nC"``).

    Returns:
        The flattened object, indexed by ``time`` only and sorted along it.
        The former ``name`` coordinate is kept as a data variable.
    """
    # Free the name "time" so it can be reused for the stacked dimension.
    renamed = ds.rename({"time": "z"})

    # Collapse (name, "z") into one stacked dimension called "time".
    stacked = renamed.stack(time=(name, "z"))

    # Index the stacked dimension by the original time values ("z") only.
    reindexed = stacked.set_index({"time": "z"})

    # Turn the former swath coordinate into a regular data variable.
    flattened = reindexed.reset_coords([name])

    return flattened.sortby("time")

In [51]:
# Try the swath flattening on the raw SWOT dataset.
# NOTE(review): `ds_swot_` is not referenced again in the visible cells, and
# `ds_swot` is rebound to the preprocessed dataset in a later cell.
ds_swot_ = remove_swath_dimension(ds_swot, "nC")

In [55]:
# Space-time window shared by both DC20a preprocessing functions.
_DC20A_TIME_WINDOW = ("2012-10-22", "2012-12-03")
_DC20A_LON_BOUNDS = (-64.975, -55.007)
_DC20A_LAT_BOUNDS = (33.025, 42.9917)


def _subset_dc20a_domain(da, variable):
    """Rename ``variable`` to ``"ssh"`` and restrict ``da`` to the DC20a window.

    Shared by the nadir and SWOT preprocessors so that both apply exactly
    the same time window and lon/lat box.
    """
    da = da.rename({variable: "ssh"})

    # Select the time window first, then load the result into memory.
    da = da.sel(time=slice(*_DC20A_TIME_WINDOW), drop=True).compute()

    # presumably maps longitudes from [0, 360) to [-180, 180) - see
    # oceanbench's transform_360_to_180; the bounds below are negative.
    da["lon"] = transform_360_to_180(da["lon"])

    da = where_slice(da, "lon", *_DC20A_LON_BOUNDS)
    da = where_slice(da, "lat", *_DC20A_LAT_BOUNDS)

    return da


def preprocess_nadir_dc20a(da, variable="ssh_mod"):
    """Preprocess one nadir altimeter file of the DC20a challenge.

    Args:
        da: dataset as opened from a single nadir file.
        variable: name of the variable to expose as ``"ssh"``.

    Returns:
        The dataset restricted to the DC20a time window and lon/lat box,
        with ``variable`` renamed to ``"ssh"`` and the ``"cycle"`` dimension
        (and every variable using it) dropped.
    """
    da = _subset_dc20a_domain(da, variable)

    da = da.drop_dims("cycle")

    return da


def preprocess_swot_dc20a(da, variable="ssh_mod"):
    """Preprocess one SWOT swath file of the DC20a challenge.

    Args:
        da: dataset as opened from a single swath file; it must carry the
            ``"nC"`` dimension that ``remove_swath_dimension`` folds away.
        variable: name of the variable to expose as ``"ssh"``.

    Returns:
        The flattened dataset restricted to the DC20a time window and
        lon/lat box, with ``variable`` renamed to ``"ssh"`` and sorted
        along ``"time"``.
    """
    # Fold the swath dimension into "time" before any selection on time.
    da = remove_swath_dimension(da, "nC")

    da = _subset_dc20a_domain(da, variable)

    da = da.sortby("time")

    return da

In [56]:
# Bind the name of the variable to expose as "ssh" in the nadir files.
preprocess_fn = ft.partial(preprocess_nadir_dc20a, variable="ssh_model")

# Each file is passed through `preprocess_fn` when opened; the pieces are
# concatenated along "time" in list order and then sorted along "time".
ds_nadir = xr.open_mfdataset(
    files_nadir_dc20a,
    engine="netcdf4",
    preprocess=preprocess_fn,
    combine="nested",
    concat_dim="time",
).sortby("time")

ds_nadir

In [57]:
# Bind the name of the variable to expose as "ssh" in the swath file.
preprocess_fn = ft.partial(preprocess_swot_dc20a, variable="ssh_model")

# Same loading recipe as for the nadir files: preprocess on open,
# concatenate along "time", then sort along "time".
ds_swot = xr.open_mfdataset(
    files_swot_dc20a,
    engine="netcdf4",
    preprocess=preprocess_fn,
    combine="nested",
    concat_dim="time",
).sortby("time")

ds_swot

In [69]:
# Merge the nadir and SWOT observations into one time-sorted dataset.
ds_swotnadir = xr.concat([ds_nadir, ds_swot], dim="time").sortby("time")

In [70]:
# Display the combined nadir + SWOT dataset.
ds_swotnadir

In [2]:
# NOTE: disabled quick-look plot - a lon/lat scatter of `ds_nadir` coloured
# by SSH. Uncomment the lines below to render it.
# %matplotlib inline

# fig, ax = plt.subplots()

# sub_ds = ds_nadir.isel(time=slice(0,None))
# pts = ax.scatter(sub_ds.lon, sub_ds.lat, c=sub_ds.ssh, s=0.1)
# ax.set(
#     xlabel="Longitude",
#     ylabel="Latitude",
# )

# plt.colorbar(pts, label="Sea Surface Height [m]")
# plt.tight_layout()
# plt.show()

### Gridding

In [58]:
# List the files in the NATL60 staging directory.
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/natl60/

NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.decoded.nc
NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc
NATL60-CJM165_GULFSTREAM_sss_y2013.1y.nc
NATL60-CJM165_GULFSTREAM_sst_y2013.1y.nc


In [60]:
# NATL60 SSH file that is opened as the reference field in the next cells.
file_natl60 = "/gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"


In [61]:
def open_natl60_reference(file, variable="gssh"):
    """Open the NATL60 reference file and crop it to the study window.

    Args:
        file: path of the netCDF file to open.
        variable: name of the variable to expose as ``"ssh"``.

    Returns:
        The dataset sorted along ``time``, cropped to the time window and
        lon/lat box below, with ``variable`` renamed to ``"ssh"``.
    """
    ds = xr.open_dataset(file, decode_times=False)

    # Times are read undecoded and converted with pandas afterwards.
    # NOTE(review): pd.to_datetime reads bare numbers as nanoseconds since
    # the Unix epoch by default - confirm this matches the file's encoding.
    ds["time"] = pd.to_datetime(ds.time)
    ds = ds.sortby("time")

    ds["lon"] = transform_360_to_180(ds["lon"])

    # Label-based crop; coordinates reduced to scalars would be dropped.
    window = {
        "time": slice("2012-10-22", "2012-12-03"),
        "lon": slice(-64.975, -55.007),
        "lat": slice(33.025, 42.9917),
    }
    ds = ds.sel(window, drop=True)

    return ds.rename({variable: "ssh"})


In [63]:
# Load the reference field; "ssh" is passed explicitly instead of the
# helper's "gssh" default.
ds_natl60 = open_natl60_reference(file_natl60, "ssh")

## Data Structure

In [64]:
from oceanbench._src.geoprocessing.gridding import coord_based_to_grid

In [65]:
# Put the nadir observations on the NATL60 reference grid, with a 12-hour
# time resolution (t_res).
# NOTE(review): the binning rule is defined in coord_based_to_grid - confirm there.
ds_nadir_gridded = coord_based_to_grid(
    ds_nadir, 
    ds_natl60,
    data_vars=["ssh"], 
    t_res=pd.to_timedelta(12, unit="hour")
)

In [66]:
# Put the SWOT swath observations on the NATL60 reference grid, with a
# 12-hour time resolution (t_res).
# NOTE(review): the binning rule is defined in coord_based_to_grid - confirm there.
ds_swot_gridded = coord_based_to_grid(
    ds_swot, 
    ds_natl60,
    data_vars=["ssh"], 
    t_res=pd.to_timedelta(12, unit="hour")
)

In [71]:
# Put the combined nadir + SWOT observations on the NATL60 reference grid,
# with a 12-hour time resolution (t_res).
# NOTE(review): the binning rule is defined in coord_based_to_grid - confirm there.
ds_swotnadir_gridded = coord_based_to_grid(
    ds_swotnadir, 
    ds_natl60,
    data_vars=["ssh"], 
    t_res=pd.to_timedelta(12, unit="hour")
)

In [None]:
# HoloViews with the matplotlib backend, used by the plotting cell below.
import holoviews as hv
hv.extension("matplotlib")

In [73]:
variable = "ssh"  # alternatives: "vort_r", "ke"
cmap = "viridis"  # alternatives: "RdBu_r", "YlGnBu_r"
field_name = "NATL60"

# Reference field next to three boolean masks marking the grid cells where
# each gridded observation product holds a finite value.
ssh_ds = xr.Dataset(
    {
        field_name: ds_natl60[variable],
        "NADIR": np.isfinite(ds_nadir_gridded[variable]),
        "SWOT": np.isfinite(ds_swot_gridded[variable]),
        "SWOTNADIR": np.isfinite(ds_swotnadir_gridded[variable]),
    }
)

to_plot_ds = ssh_ds.transpose("time", "lat", "lon")  # .isel(time=slice(25, 55, 1))

# One colour range for all panels: the 0.5 % and 99.5 % quantiles taken
# over the four fields together.
panel_names = [field_name, "NADIR", "SWOT", "SWOTNADIR"]
all_values = to_plot_ds[panel_names].to_array()
clim = (all_values.quantile(0.005).item(), all_values.quantile(0.995).item())

# One lon/lat QuadMesh per data variable, labelled with the variable name.
panels = []
for v in to_plot_ds:
    quadmesh = hv.Dataset(to_plot_ds).to(hv.QuadMesh, ["lon", "lat"], v)
    panels.append(quadmesh.relabel(v).options(cmap=cmap, clim=clim))

images = hv.Layout(panels).cols(2).opts(sublabel_format="")

# Render the layout as an animated GIF.
hv.output(images, holomap="gif", fps=2, dpi=125)