# Steric Sea Level Contribution from ORAS5 Reanalysis to Global Sea Level Observed by Satellite Altimetry

## Import libraries

In [None]:
import cacholote
import gsw_xarray as gsw
import matplotlib.pyplot as plt
import pooch
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Apply one matplotlib style to every figure produced by this notebook.
plt.style.use(style="seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time window: whole calendar years from year_start to year_stop inclusive
year_start, year_stop = 2004, 2023

# Spatial domain: latitudes within +/-60 and all longitudes
lat_slice, lon_slice = slice(-60, 60), slice(-180, 180)

## Define requests

In [None]:
# Every month of every year in [year_start, year_stop], as strings
# ("2004", ..., and zero-padded "01" ... "12").
time_request = {
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [f"{m:02d}" for m in range(1, 13)],
}

# ORAS5 ocean reanalysis: potential temperature and salinity on all levels.
request_reanalysis = (
    "reanalysis-oras5",
    {
        **{
            "product_type": ["consolidated", "operational"],
            "vertical_resolution": "all_levels",
            "variable": ["potential_temperature", "salinity"],
        },
        **time_request,
    },
)

# Satellite altimetry sea level, monthly means.
request_satellite = (
    "satellite-sea-level-global",
    {"variable": ["monthly_mean"], "version": "vdt2024", **time_request},
)

# Shared download options: one chunk per (year, month) request and drop
# the ORAS5 time bounds variable.
download_kwargs = dict(
    chunks={"year": 1, "month": 1},
    drop_variables=["time_counter_bnds"],
)

## Define functions to cache

In [None]:
def compute_gsw_ds(ds):
    """Derive TEOS-10 variables from ORAS5 temperature and salinity.

    Uses gsw_xarray to compute sea pressure ``p`` from depth, absolute
    salinity ``SA`` from ``vosaline``, conservative temperature ``CT`` from
    the potential temperature ``votemper``, and in-situ density ``rho``.
    These four variables are returned as a new Dataset.

    Each output variable gets an on-disk ``chunksizes`` encoding of
    1 along ``time``, 200 along ``x``/``y`` and 15 along ``deptht``.

    Raises
    ------
    ValueError
        If an output variable has a dimension outside that set.
    """
    # gsw works with height z (negative below the surface), hence the sign flip.
    pressure = gsw.p_from_z(-ds["deptht"], ds["latitude"])
    absolute_salinity = gsw.SA_from_SP(
        ds["vosaline"], pressure, ds["longitude"], ds["latitude"]
    )
    conservative_temperature = gsw.CT_from_pt(absolute_salinity, ds["votemper"])
    density = gsw.rho(absolute_salinity, conservative_temperature, pressure)
    out = xr.merge([pressure, absolute_salinity, conservative_temperature, density])

    chunk_by_dim = {"time": 1, "x": 200, "y": 200, "deptht": 15}
    for variable in out.data_vars.values():
        for dim in variable.dims:
            if dim not in chunk_by_dim:
                raise ValueError(f"{dim = }")
        variable.encoding["chunksizes"] = tuple(
            chunk_by_dim[d] for d in variable.dims
        )
    return out


def compute_ssl_from_rho(rho, rho0=1025, prefix="", dim="deptht"):
    """Compute steric sea level from a density field.

    The density anomaly is taken relative to the monthly climatology of
    ``rho``: for each calendar month, the mean over all years of that month
    is subtracted, which removes the seasonal cycle. Steric sea level is then
    ``-(1 / rho0) * integral(delta_rho) d(dim)``.

    Parameters
    ----------
    rho : xarray.DataArray
        Density with a datetime-like ``time`` coordinate and a vertical
        dimension ``dim``.
    rho0 : float, default 1025
        Reference density used to convert density anomalies into height.
    prefix : str, default ""
        Prepended to the long name; e.g. ``"thermo"`` gives
        "Thermosteric Sea Level".
    dim : str, default "deptht"
        Vertical coordinate to integrate over. The default keeps the ORAS5
        behaviour; pass e.g. ``"depth"`` for data on another vertical axis.

    Returns
    -------
    xarray.DataArray
        Steric sea level with ``long_name`` and ``units="m"`` attributes.
        The units assume the ``dim`` coordinate is in metres.
    """
    grouped_rho = rho.groupby("time.month")
    delta_rho = grouped_rho - grouped_rho.mean()
    # Missing values (e.g. land or below the sea floor) contribute no anomaly.
    ssl = -(delta_rho.fillna(0) / rho0).integrate(dim)
    ssl.attrs = {"long_name": f"{prefix}steric sea level".title(), "units": "m"}
    return ssl


def compute_ssl(
    collection_id, request, prefix, lon_slice, lat_slice, **download_kwargs
):
    """Download, derive TEOS-10 fields, crop, and return a steric sea level map.

    ``prefix`` selects which density drives the result:

    * ``""``       -- total: the full density ``rho``;
    * ``"thermo"`` -- density recomputed with ``SA`` held at its time mean,
      so only ``CT`` varies;
    * ``"halo"``   -- density recomputed with ``CT`` held at its time mean,
      so only ``SA`` varies.

    Any other ``prefix`` raises ``NotImplementedError`` after the download and
    regionalisation have run. The returned DataArray is named
    ``f"{prefix}ssl"``.
    """
    ds = utils.regionalise(
        download.download_and_transform(
            collection_id, request, transform_func=compute_gsw_ds, **download_kwargs
        ),
        lon_slice=lon_slice,
        lat_slice=lat_slice,
    )

    if prefix == "":
        rho = ds["rho"]
    elif prefix in ("thermo", "halo"):
        SA, CT = ds["SA"], ds["CT"]
        if prefix == "thermo":
            SA = SA.mean("time")
        else:
            CT = CT.mean("time")
        rho = gsw.rho(SA, CT, ds["p"])
    else:
        raise NotImplementedError(f"{prefix=}")

    ssl = compute_ssl_from_rho(rho, prefix=prefix)
    return ssl.rename(f"{prefix}ssl")


@cacholote.cacheable
def compute_ssl_timeseries(
    collection_id,
    request,
    prefix,
    lon_slice,
    lat_slice,
    **download_kwargs,
):
    """Return the spatially weighted mean of the steric sea level map.

    Arguments are forwarded unchanged to ``compute_ssl``. The result of
    ``diagnostics.spatial_weighted_mean`` is cached by cacholote, so repeated
    calls with the same arguments reuse the stored time series.
    """
    ssl_map = compute_ssl(
        collection_id, request, prefix, lon_slice, lat_slice, **download_kwargs
    )
    return diagnostics.spatial_weighted_mean(ssl_map)

## Download and transform reanalysis

In [None]:
# Total, thermosteric and halosteric components, each reduced to a cached
# spatial-mean time series named "ssl", "thermo_ssl" and "halo_ssl".
shared_kwargs = dict(lon_slice=lon_slice, lat_slice=lat_slice, **download_kwargs)
dataarrays = []
for prefix in ("", "thermo", "halo"):
    name = f"{prefix}_ssl" if prefix else "ssl"
    print(f"{name = }")
    da = compute_ssl_timeseries(*request_reanalysis, prefix=prefix, **shared_kwargs)
    dataarrays.append(da.rename(name))
ds_reanalysis = xr.merge(dataarrays)

del dataarrays, shared_kwargs

## Quick-look plot of the reanalysis steric sea level time series

In [None]:
# Stack the reanalysis time series into one array (new "variable" dimension)
# and draw one line per component.
da = ds_reanalysis.to_dataarray()
da.plot(hue="variable")
# Turn the grid on explicitly: a bare plt.grid() toggles visibility and would
# hide the grid under a style that already enables it.
plt.grid(True)

## Download and transform Argo

In [None]:
def preprocess(ds):
    """Normalise one RG Argo netCDF file before it is combined with others.

    Steps, in order:

    1. lower-case every variable/coordinate name;
    2. tag ``time`` with a ``360_day`` calendar and CF-decode the dataset
       (the files are opened with ``decode_times=False`` in this notebook,
       so time decoding happens here);
    3. derive a ``depth`` coordinate (positive down) from pressure with
       ``gsw.z_from_p``, averaged over latitude so it depends on pressure
       only, and make it the vertical dimension in place of ``pressure``.
    """
    lowercase = {name: name.lower() for name in ds.variables}
    ds = ds.rename(lowercase)

    ds["time"].attrs.update(calendar="360_day")
    ds = xr.decode_cf(ds)

    height = gsw.z_from_p(ds["pressure"], ds["latitude"])
    ds["depth"] = -height.mean("latitude", keep_attrs=True)
    ds["depth"].attrs.update({"positive": "down", "long_name": "Depth from pressure"})
    return ds.swap_dims({"pressure": "depth"})


def _retrieve_gz(url):
    """Download a gzipped file with pooch and return the decompressed file's path."""
    return pooch.retrieve(url=url, known_hash=None, processor=pooch.Decompress())


# First dataset: one file per variable (mean fields plus monthly anomalies)
filenames = [
    _retrieve_gz(f"https://sio-argo.ucsd.edu/RG/RG_ArgoClim_{var}_2019.nc.gz")
    for var in ["Temperature", "Salinity"]
]
with xr.set_options(use_new_combine_kwarg_defaults=True):
    ds_argo_1 = xr.open_mfdataset(filenames, preprocess=preprocess, decode_times=False)

# Second dataset: one anomaly file per month from 2019 to year_stop
filenames = [
    _retrieve_gz(
        f"https://sio-argo.ucsd.edu/RG/RG_ArgoClim_{year}{month:02d}_2019.nc.gz"
    )
    for year in range(2019, year_stop + 1)
    for month in range(1, 13)
]
with xr.set_options(use_new_combine_kwarg_defaults=True):
    # Pass the whole list so every monthly file is combined along time.
    ds_argo_2 = xr.open_mfdataset(filenames, preprocess=preprocess, decode_times=False)

# Combine: full field = mean field + anomaly, for both periods
dataarrays = []
for var in ["salinity", "temperature"]:
    mean = ds_argo_1[f"argo_{var}_mean"]
    units = mean.attrs["units"]
    da = xr.combine_by_coords(
        [
            mean + ds_argo_1[f"argo_{var}_anomaly"],
            mean + ds_argo_2[f"argo_{var}_anomaly"],
        ]
    )
    # Attributes are not guaranteed to survive the arithmetic/combination,
    # so restore the units of the mean field explicitly.
    da.attrs["units"] = units
    dataarrays.append(da.rename(var))
with xr.set_options(use_new_combine_kwarg_defaults=True):
    ds_argo = xr.merge(dataarrays)

# Selection
ds_argo = ds_argo.sel(time=slice(str(year_start), str(year_stop)))
ds_argo = utils.regionalise(ds_argo, lon_slice=lon_slice, lat_slice=lat_slice)

## Download and transform satellite

In [None]:
# Monthly-mean satellite altimetry sea level, cropped to the same domain as
# the reanalysis and Argo data.
ds_satellite = utils.regionalise(
    download.download_and_transform(*request_satellite, **download_kwargs),
    lon_slice=lon_slice,
    lat_slice=lat_slice,
)