# Steric Sea Level Contribution from ORAS5 Reanalysis to Global Sea Level Observed by Satellite Altimetry

## Import libraries

In [None]:
import cacholote
import gsw_xarray as gsw
import matplotlib.pyplot as plt
import pooch
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time: first and last year of the analysed period (the last year is included
# when the request years are generated)
year_start, year_stop = 2004, 2023

# Space: longitude and latitude bounds used to regionalise every dataset
lon_slice = slice(-180, 180)
lat_slice = slice(-60, 60)

## Define requests

In [None]:
# Every month of every year between year_start and year_stop (inclusive),
# encoded as the strings used in the request dictionaries.
time_request = {
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [str(month).zfill(2) for month in range(1, 13)],
}

# Dataset-specific request options; the time selection is appended to each.
reanalysis_options = {
    "product_type": ["consolidated", "operational"],
    "vertical_resolution": "all_levels",
    "variable": ["potential_temperature", "salinity"],
}
satellite_options = {"variable": ["monthly_mean"], "version": "vdt2024"}

# Each request is a (collection_id, request) pair that is unpacked with `*`
# where the download helpers are called.
request_reanalysis = (
    "reanalysis-oras5",
    {**reanalysis_options, **time_request},
)
request_satellite = (
    "satellite-sea-level-global",
    {**satellite_options, **time_request},
)

# Keyword arguments forwarded to the download helpers.
download_kwargs = dict(
    chunks={"year": 1, "month": 1},
    drop_variables=["time_counter_bnds"],
)

## Define functions to cache

In [None]:
def compute_gsw_ds(ds):
    """Derive pressure, salinity, temperature and density fields with gsw.

    Pressure is computed from the negated ``deptht`` values and ``latitude``;
    absolute salinity from ``vosaline``; conservative temperature from
    ``votemper`` (via ``CT_from_pt``) when that variable is present, otherwise
    from ``temperature`` (via ``CT_from_t``); density from ``gsw.rho``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``deptht``, ``latitude``, ``longitude``, ``vosaline`` and
        either ``votemper`` or ``temperature``.

    Returns
    -------
    xarray.Dataset
        Merge of the four derived arrays. Each variable gets
        ``encoding["chunksizes"]`` set to 1 along ``time``, 200 along
        ``x``/``y``/``latitude``/``longitude`` and 15 along ``deptht``.

    Raises
    ------
    ValueError
        If a derived variable has a dimension other than the ones above.
    """
    # Chunk size written to the encoding for each supported dimension.
    chunk_by_dim = {
        "time": 1,
        "x": 200,
        "y": 200,
        "latitude": 200,
        "longitude": 200,
        "deptht": 15,
    }

    pressure = gsw.p_from_z(-ds["deptht"], ds["latitude"])
    abs_salinity = gsw.SA_from_SP(
        ds["vosaline"], pressure, ds["longitude"], ds["latitude"]
    )
    if "votemper" not in ds.data_vars:
        cons_temperature = gsw.CT_from_t(abs_salinity, ds["temperature"], pressure)
    else:
        cons_temperature = gsw.CT_from_pt(abs_salinity, ds["votemper"])
    density = gsw.rho(abs_salinity, cons_temperature, pressure)

    merged = xr.merge([pressure, abs_salinity, cons_temperature, density])
    for variable in merged.data_vars.values():
        sizes = []
        for dim in variable.dims:
            if dim not in chunk_by_dim:
                raise ValueError(f"{dim = }")
            sizes.append(chunk_by_dim[dim])
        variable.encoding["chunksizes"] = tuple(sizes)
    return merged


def compute_ssl_from_rho(rho, rho0=1025, prefix=""):
    """Compute steric sea level from a density field.

    The mean of each calendar month (grouping on ``time.month``) is removed
    from ``rho``; missing anomalies are replaced by zero, the anomaly is
    divided by ``rho0``, integrated along ``deptht`` and negated.

    Parameters
    ----------
    rho : xarray.DataArray
        Density with ``time`` and ``deptht`` dimensions.
    rho0 : float, optional
        Reference density dividing the anomaly (default 1025).
    prefix : str, optional
        Text prepended to "steric sea level" in the ``long_name`` attribute.

    Returns
    -------
    xarray.DataArray
        Steric sea level whose attributes are replaced by a title-cased
        ``long_name`` and ``units`` set to ``"m"``.
    """
    monthly_groups = rho.groupby("time.month")
    anomaly = monthly_groups - monthly_groups.mean()
    relative_anomaly = anomaly.fillna(0) / rho0
    ssl = -relative_anomaly.integrate("deptht")

    long_name = f"{prefix}steric sea level".title()
    ssl.attrs = {"long_name": long_name, "units": "m"}
    return ssl


def compute_ssl_from_ds(ds, prefix, lon_slice, lat_slice):
    """Compute total, thermo- or halosteric sea level over a region.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset providing ``rho``, ``SA``, ``CT`` and ``p``.
    prefix : str
        ``""`` uses ``rho`` as is; ``"thermo"`` recomputes density with the
        time-mean ``SA``; ``"halo"`` recomputes density with the time-mean
        ``CT``.
    lon_slice, lat_slice : slice
        Bounds passed to ``utils.regionalise``.

    Returns
    -------
    xarray.DataArray
        Steric sea level named ``f"{prefix}ssl"``.

    Raises
    ------
    NotImplementedError
        If ``prefix`` is not one of the three supported values.
    """
    region = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)

    if prefix == "halo":
        # Only salinity varies in time: temperature is frozen at its mean.
        salinity = region["SA"]
        mean_temperature = region["CT"].mean("time")
        rho = gsw.rho(salinity, mean_temperature, region["p"])
    elif prefix == "thermo":
        # Only temperature varies in time: salinity is frozen at its mean.
        mean_salinity = region["SA"].mean("time")
        temperature = region["CT"]
        rho = gsw.rho(mean_salinity, temperature, region["p"])
    elif prefix == "":
        rho = region["rho"]
    else:
        raise NotImplementedError(f"{prefix=}")

    ssl = compute_ssl_from_rho(rho, prefix=prefix)
    return ssl.rename(f"{prefix}ssl")


def compute_ssl_timeseries_from_ds(ds, prefix, lon_slice, lat_slice):
    """Return the regional steric sea level reduced to a spatial mean.

    The steric sea level is computed with ``compute_ssl_from_ds`` and then
    reduced with ``diagnostics.spatial_weighted_mean``.
    """
    regional_ssl = compute_ssl_from_ds(ds, prefix, lon_slice, lat_slice)
    spatial_mean = diagnostics.spatial_weighted_mean(regional_ssl)
    return spatial_mean


def compute_ssl_trend_from_ds(ds, prefix, lon_slice, lat_slice):
    """Return the linear trend of the regional steric sea level.

    The steric sea level is computed with ``compute_ssl_from_ds`` and its
    trend with ``diagnostics.time_weighted_linear_trend``. Coordinates that do
    not depend on ``time`` are captured before the trend computation and
    assigned to the result afterwards.
    """
    regional_ssl = compute_ssl_from_ds(ds, prefix, lon_slice, lat_slice)
    # Keep the time-independent coordinates so they can be re-attached below.
    static_coords = regional_ssl.to_dataset().drop_dims("time").coords
    trend = diagnostics.time_weighted_linear_trend(regional_ssl)
    return trend.assign_coords(static_coords)


@cacholote.cacheable
def compute_ssl_timeseries(
    collection_id, request, prefix, lon_slice, lat_slice, **download_kwargs
):
    """Download a collection and return its steric sea level timeseries.

    Parameters
    ----------
    collection_id : str
        Identifier of the collection to download.
    request : dict
        Request passed to ``download.download_and_transform``.
    prefix : str
        Steric component selector understood by ``compute_ssl_from_ds``.
    lon_slice, lat_slice : slice
        Bounds of the region of interest.
    **download_kwargs
        Extra keyword arguments for ``download.download_and_transform``.

    Returns
    -------
    Result of ``compute_ssl_timeseries_from_ds`` on the downloaded data, which
    is transformed with ``compute_gsw_ds`` during the download.
    """
    transformed = download.download_and_transform(
        collection_id,
        request,
        transform_func=compute_gsw_ds,
        **download_kwargs,
    )
    timeseries = compute_ssl_timeseries_from_ds(
        transformed, prefix, lon_slice=lon_slice, lat_slice=lat_slice
    )
    return timeseries


@cacholote.cacheable
def compute_ssl_trend(
    collection_id, request, prefix, lon_slice, lat_slice, **download_kwargs
):
    """Download a collection and return its steric sea level linear trend.

    Parameters
    ----------
    collection_id : str
        Identifier of the collection to download.
    request : dict
        Request passed to ``download.download_and_transform``.
    prefix : str
        Steric component selector understood by ``compute_ssl_from_ds``.
    lon_slice, lat_slice : slice
        Bounds of the region of interest.
    **download_kwargs
        Extra keyword arguments for ``download.download_and_transform``.

    Returns
    -------
    Result of ``compute_ssl_trend_from_ds`` on the downloaded data, which is
    transformed with ``compute_gsw_ds`` during the download.
    """
    transformed = download.download_and_transform(
        collection_id,
        request,
        transform_func=compute_gsw_ds,
        **download_kwargs,
    )
    trend = compute_ssl_trend_from_ds(
        transformed, prefix, lon_slice=lon_slice, lat_slice=lat_slice
    )
    return trend


def preprocess(ds):
    """Normalise a raw dataset before it is combined by ``open_mfdataset``.

    Steps, in order: lower-case all variable names; tag ``time`` with a
    ``360_day`` calendar and decode it with ``xr.decode_cf``; reset the time
    attributes and encoding; add a ``depth`` variable derived from
    ``pressure`` and use it as the dimension in place of ``pressure``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with (case-insensitive) ``time``, ``pressure`` and
        ``latitude`` variables and an undecoded time axis.

    Returns
    -------
    xarray.Dataset
        Dataset whose ``pressure`` dimension has been swapped for ``depth``.
    """
    # Naming: lower-case every variable and coordinate name
    lowercase_names = {name: name.lower() for name in ds.variables}
    renamed = ds.rename(lowercase_names)

    # Time: the calendar must be set before decoding
    renamed["time"].attrs["calendar"] = "360_day"
    decoded = xr.decode_cf(renamed)
    decoded["time"].attrs = {"standard_name": "time"}
    decoded["time"].encoding = {}

    # Depth: negated latitude-mean of the height returned by gsw.z_from_p
    height = gsw.z_from_p(decoded["pressure"], decoded["latitude"])
    decoded["depth"] = -height.mean("latitude", keep_attrs=True)
    decoded["depth"].attrs.update(positive="down", long_name="Depth from pressure")
    return decoded.swap_dims({"pressure": "depth"})


@cacholote.cacheable
def get_argo(year, month):
    """Build the gsw dataset for one month from the ARGO climatology files.

    The temperature and salinity climatology files are retrieved and opened
    together. The requested month is selected from them; when that selection
    is empty, the file dedicated to that month is retrieved instead. Mean and
    anomaly fields are summed for salinity and temperature, and the result is
    passed to ``compute_gsw_ds``.

    Parameters
    ----------
    year : int
        Year of the requested month.
    month : int
        Month number, formatted with two digits in URLs and time selections.

    Returns
    -------
    xarray.Dataset
        Output of ``compute_gsw_ds`` for the requested month.
    """

    def fetch(url):
        # Retrieve a gzipped file with pooch and return the decompressed path.
        return pooch.retrieve(url=url, known_hash=None, processor=pooch.Decompress())

    def select_month(dataset):
        # Keep only the time steps falling in the requested month.
        return dataset.sel(time=slice(month_label, month_label))

    month_label = f"{year}-{month:02d}"

    # Get climatology
    climatology_paths = [
        fetch(f"https://sio-argo.ucsd.edu/RG/RG_ArgoClim_{var}_2019.nc.gz")
        for var in ("Temperature", "Salinity")
    ]
    base = xr.open_mfdataset(
        climatology_paths, preprocess=preprocess, decode_times=False
    )
    climatology = base.drop_dims("time")

    # Get anomalies
    anomalies = select_month(base)
    if anomalies.sizes["time"] == 0:
        # The month is not in the base files: use its dedicated file.
        monthly_path = fetch(
            f"https://sio-argo.ucsd.edu/RG/RG_ArgoClim_{year}{month:02d}_2019.nc.gz"
        )
        monthly = xr.open_mfdataset(
            monthly_path, preprocess=preprocess, decode_times=False
        )
        anomalies = select_month(monthly)

    # Compute values
    fields = [
        (climatology[f"argo_{var}_mean"] + anomalies[f"argo_{var}_anomaly"]).rename(
            var
        )
        for var in ("salinity", "temperature")
    ]
    merged = xr.merge(fields)

    # Compute gsw dataset, using the variable names compute_gsw_ds expects
    return compute_gsw_ds(merged.rename(depth="deptht", salinity="vosaline"))


def get_all_argo(year_start, year_stop):
    """Concatenate the monthly ARGO datasets of a range of years along time.

    ``get_argo`` is called for the twelve months of every year from
    ``year_start`` to ``year_stop`` (both included), in chronological order.
    """
    monthly_datasets = [
        get_argo(year, month)
        for year in range(year_start, year_stop + 1)
        for month in range(1, 13)
    ]
    return xr.concat(monthly_datasets, "time")


@cacholote.cacheable
def compute_ssl_timeseries_argo(prefix, lon_slice, lat_slice, year_start, year_stop):
    """Return the ARGO steric sea level timeseries for a region and period.

    The ARGO data for ``year_start`` to ``year_stop`` is gathered with
    ``get_all_argo`` and handed to ``compute_ssl_timeseries_from_ds``.
    """
    argo = get_all_argo(year_start, year_stop)
    timeseries = compute_ssl_timeseries_from_ds(argo, prefix, lon_slice, lat_slice)
    return timeseries


@cacholote.cacheable
def compute_ssl_trend_argo(prefix, lon_slice, lat_slice, year_start, year_stop):
    """Return the ARGO steric sea level linear trend for a region and period.

    The ARGO data for ``year_start`` to ``year_stop`` is gathered with
    ``get_all_argo`` and handed to ``compute_ssl_trend_from_ds``.
    """
    argo = get_all_argo(year_start, year_stop)
    trend = compute_ssl_trend_from_ds(argo, prefix, lon_slice, lat_slice)
    return trend

## Download and transform reanalysis

In [None]:
prefixes = ["", "thermo", "halo"]

# Keyword arguments shared by every cached reanalysis call below.
reanalysis_kwargs = {
    "lon_slice": lon_slice,
    "lat_slice": lat_slice,
    **download_kwargs,
}

timeseries = []
trends = []
for prefix in prefixes:
    # "ssl" for the total signal, "<prefix>_ssl" for a single component
    name = f"{prefix}_ssl" if prefix else "ssl"
    print(f"{name = }")
    call_kwargs = {"prefix": prefix, **reanalysis_kwargs}

    ssl_timeseries = compute_ssl_timeseries(*request_reanalysis, **call_kwargs)
    timeseries.append(ssl_timeseries.rename(name))

    ssl_trend = compute_ssl_trend(*request_reanalysis, **call_kwargs)
    if "latitude" not in ssl_trend.coords:
        # A cached trend without a latitude coordinate is discarded and
        # recomputed.
        print("Delete!")
        cacholote.delete(compute_ssl_trend, *request_reanalysis, **call_kwargs)
        ssl_trend = compute_ssl_trend(*request_reanalysis, **call_kwargs)
    trends.append(ssl_trend.rename(name))

ds_reanalysis_timeseries = xr.merge(timeseries)
ds_reanalysis_trend = xr.merge(trends)

## Download and transform ARGO

In [None]:
prefixes = ["", "thermo", "halo"]

# Keyword arguments shared by every cached ARGO call below.
argo_kwargs = {
    "lon_slice": lon_slice,
    "lat_slice": lat_slice,
    "year_start": year_start,
    "year_stop": year_stop,
}

timeseries = []
trends = []
for prefix in prefixes:
    # "ssl" for the total signal, "<prefix>_ssl" for a single component
    name = f"{prefix}_ssl" if prefix else "ssl"
    print(f"{name = }")

    ssl_timeseries = compute_ssl_timeseries_argo(prefix=prefix, **argo_kwargs)
    timeseries.append(ssl_timeseries.rename(name))

    ssl_trend = compute_ssl_trend_argo(prefix=prefix, **argo_kwargs)
    trends.append(ssl_trend.rename(name))

ds_argo_timeseries = xr.merge(timeseries)
ds_argo_trend = xr.merge(trends)

# Align
# Put the ARGO time axis on the same calendar as the reanalysis so that both
# timeseries can share one time coordinate. The ARGO time was decoded with a
# "360_day" calendar in `preprocess`; here it is converted to
# "proleptic_gregorian", matching dates one-to-one (align_on="date").
# NOTE(review): convert_calendar is called on the coordinate DataArray itself
# and the result is assigned back to "time" -- confirm that the values that
# end up in the dataset are the converted ones.
ds_argo_timeseries["time"] = ds_argo_timeseries["time"].convert_calendar(
    "proleptic_gregorian", align_on="date"
)
# When the two time axes compare equal, replace the ARGO time coordinate with
# the reanalysis one.
# NOTE(review): `==` between two DataArrays goes through xarray's automatic
# index alignment, so this may only compare the labels the two axes share
# rather than element by element -- confirm that the check can be False.
if (ds_argo_timeseries["time"] == ds_reanalysis_timeseries["time"]).all():
    ds_argo_timeseries["time"] = ds_reanalysis_timeseries["time"]

## Plot timeseries

In [None]:
# Stack both products along a new "product" dimension so each gets a panel.
labelled_timeseries = [
    ds_reanalysis_timeseries.expand_dims(product=["ORAS5"]),
    ds_argo_timeseries.expand_dims(product=["ARGO"]),
]
ds_timeseries = xr.concat(labelled_timeseries, "product")

# One line per variable (hue), one column per product.
da = ds_timeseries.to_dataarray()
facet = da.plot(hue="variable", col="product")
for ax in facet.axs.flat:
    ax.grid()

## Plot maps

In [None]:
# One figure per product, with one map per variable.
trend_by_product = {"ORAS5": ds_reanalysis_trend, "ARGO": ds_argo_trend}
for label, ds in trend_by_product.items():
    da = ds.to_dataarray()
    facet = plot.projected_map(da, col="variable", robust=True)
    for ax in facet.axs.flat:
        # Restrict every map to the analysed region.
        ax.set_extent(
            [lon_slice.start, lon_slice.stop, lat_slice.start, lat_slice.stop]
        )
    facet.fig.suptitle(label)
    plt.show()

## Download and transform satellite

In [None]:
# Download the satellite data, then restrict it to the analysed region.
ds_satellite = utils.regionalise(
    download.download_and_transform(*request_satellite, **download_kwargs),
    lon_slice=lon_slice,
    lat_slice=lat_slice,
)