# Create 3D Initial Conditions for the 79NG Fjord GETM Setup

This notebook creates 3D initial conditions for the 79NG fjord setup in GETM with data from a global ocean model.
The data have been provided by Claudia Wekerle (AWI) and come from a global FESOM2.1 setup with increased resolution in the 79NG fjord, see Wekerle et al. (2024, https://doi.org/10.1038/s41467-024-45650-z).

Notebook by Markus Reinert (IOW, 2023–2024, https://orcid.org/0000-0002-3761-8029).

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar
from pyproj import CRS, Transformer
from scipy.interpolate import LinearNDInterpolator, griddata
from scipy.spatial import Delaunay, cKDTree

from tools.configuration import Configuration

In [None]:
# Project configuration object; it is queried below for file paths
# (`get_file_path`) and for the start time of the model run (`get_text`)
config = Configuration()

## Prepare the coordinate transformation

The interpolation of the data must be computed on a Cartesian coordinate system, not a latitude–longitude system, in order to have the correct distances between grid points.
Here we define
* the Coordinate Reference System (CRS) in which the model grids are given,
* a projected CRS that is suitable for the interpolation,
* a `pyproj.Transformer` object from the former to the latter.

### CRS of model grids

In [None]:
# Source CRS of the coordinate transformation: EPSG:4326 (geographic latitude/longitude)
crs_latlon = CRS.from_epsg(4326)
# Last expression of the cell: display the CRS description
crs_latlon

### CRS for interpolation

In [None]:
# Target CRS of the coordinate transformation: EPSG:3413, the projected CRS
# in which the horizontal interpolation is carried out
crs_cartesian = CRS.from_epsg(3413)
# Last expression of the cell: display the CRS description
crs_cartesian

### CRS transformer

In [None]:
# Transformer from the latitude–longitude CRS to the projected CRS.
# NOTE(review): `always_xy` is not set, so pyproj should follow the axis order
# declared by each CRS, i.e. (latitude, longitude) for EPSG:4326 — confirm that
# every `transformer.transform` call passes its arguments in this order
transformer = Transformer.from_crs(crs_latlon, crs_cartesian)

## Load the GETM topography

In [None]:
# Look up the path of the bathymetry file in the configuration
filename = config.get_file_path("getm/domain/bathymetry")
print(f"Loading topography from {filename!r}.")
# The dataset provides the GETM grid (lat, lon, x, y), the bathymetry and the mask used below
getm = xr.open_dataset(filename)
# Last expression of the cell: display the dataset
getm

## Load the FESOM data

In [None]:
print("Loading FESOM data")
fesom = xr.open_mfdataset("data/FESOM/*.fesom.*.sub.nc", drop_variables="faces")

# Project the node positions and attach the results as coordinates "x" and "y"
# on the node dimension, labelled with the axis names of the projected CRS
projected_coordinates = transformer.transform(fesom.lat, fesom.lon)
for coord_name, coord_values, axis_info in zip(("x", "y"), projected_coordinates, crs_cartesian.axis_info):
    fesom.coords[coord_name] = (
        "nod2",
        coord_values,
        {"long_name": axis_info.name, "units": "m", "CRS": str(crs_cartesian)},
    )

# Keep the names of all variables that are defined on (time, node, layer)
variables = [name for name in fesom.variables if fesom[name].dims == ("time", "nod2", "nz1")]
print("Contains", *variables)
fesom

### Average in time

The FESOM dataset contains monthly mean values, one value for each calendar month.
Since months have different numbers of days, we need to take a weighted average of the monthly means.
The weight of each month depends on the number of days in this month.
With these weights, every day (rather than every month) contributes equally to the final average.

Luckily, the timestamp of each monthly mean in the FESOM dataset is on the last day of that month, so we can use the day-attribute of the timestamp to get the number of days in this month.

In [None]:
# The monthly means are stamped with the last day of their month,
# so the day of the timestamp is the number of days in that month
fesom_datetime = fesom.time.data.astype("datetime64[s]").tolist()
days_per_month = np.array([timestamp.day for timestamp in fesom_datetime])
n_months = len(days_per_month)
n_days = sum(days_per_month)

# Normalise the weights to an average of 1, so that the plain mean of the
# weighted values is the day-weighted average
weights = xr.DataArray(days_per_month / days_per_month.mean(), coords={"time": fesom.time}, dims="time")

# Sanity checks on the time axis and on the normalisation
assert n_months == 10 * 12, "number of months does not match 10 years"
assert n_days == 10 * 365 + 3, "number of days does not match 10 years incl. 3 leap years"
assert np.allclose(weights.mean(), 1), "average weight is not 1"

# Plot the weights together with one reference line per possible month length
weights.plot(figsize=(12, 3), marker=".", linestyle="--")
last_timestamp = fesom_datetime[-1]
for month_length in (28, 29, 30, 31):
    reference_weight = month_length * n_months / n_days
    # Only the label for 30 days is placed below its line
    vertical_alignment = "top" if month_length == 30 else "bottom"
    plt.axhline(reference_weight, c="k", lw=1)
    plt.text(last_timestamp, reference_weight, f"months with {month_length} days", ha="right", va=vertical_alignment)
plt.title("Weights of monthly mean values")
plt.grid(axis="x")

In [None]:
print("Computing weighted time-averages")
# The weights average to 1, so the mean of the weighted monthly values
# is the weighted average; attributes of the variables are kept
with xr.set_options(keep_attrs=True):
    fesom_weighted = fesom * weights
    fesom_time_mean = fesom_weighted.mean(dim="time")
# Evaluate the lazy Dask computation and load the result into memory
with ProgressBar():
    fesom = fesom_time_mean.compute()
fesom

### Mask cells with zero salinity

In [None]:
# Cells count as valid where the salinity is positive; all other cells are set to NaN
fesom["mask"] = fesom.salt > 0
for var in variables:
    fesom[var] = fesom[var].where(fesom.mask)

# Cut the dataset at the first layer that contains no valid cell at all
print(f"Original FESOM dataset:    {fesom.nz1.size} vertical layers with max. depth {fesom.nz.item(-1):6} m")
for k in range(fesom.nz1.size):
    if fesom.mask.isel(nz1=k).any():
        # This layer still contains valid cells
        continue
    assert not fesom.mask.isel(nz1=slice(k, None)).any(), "not all deeper levels are masked"
    assert fesom.nz1.size == fesom.nz.size - 1, "dimension nz is not 1 larger than nz1"
    # Keep the k layers above and the k+1 entries of nz that go with them
    fesom = fesom.isel(nz1=slice(k), nz=slice(k+1))
    break
print(f"With empty layers removed: {fesom.nz1.size} vertical layers with max. depth {fesom.nz.item(-1):6} m")

fesom

## Compare the grids

In [None]:
fig, axs = plt.subplots(ncols=2, constrained_layout=True, figsize=(12, 6), dpi=200)
fig.suptitle("Comparison between the grids of FESOM (blue and yellow dots) and GETM (red)", weight="bold")

# End points of the GETM grid lines: half a grid spacing beyond the first and last coordinate
lat_extent = [getm.lat[0] - getm.dlat/2, getm.lat[-1] + getm.dlat/2]
lon_extent = [getm.lon[0] - getm.dlon/2, getm.lon[-1] + getm.dlon/2]

# One panel per coordinate system: (x-coordinate, y-coordinate, CRS)
panels = [("lon", "lat", crs_latlon), ("x", "y", crs_cartesian)]
for ax, (x, y, crs) in zip(axs, panels):
    # FESOM nodes, coloured by the mask of the uppermost layer
    ax.scatter(fesom[x], fesom[y], 1/4, fesom.mask.isel(nz1=0))
    if (x, y) == ("lon", "lat"):
        # In latitude–longitude coordinates, draw the GETM grid as straight lines
        for lon in getm.lon:
            ax.plot([lon, lon], lat_extent, "r", lw=0.2)
        for lat in getm.lat:
            ax.plot(lon_extent, [lat, lat], "r", lw=0.2)
    elif (x, y) == ("x", "y"):
        # In projected coordinates, draw the GETM grid points as dots
        ax.plot(getm.x, getm.y, "r.", ms=0.8, markeredgewidth=0)
    else:
        raise ValueError("unexpected combination of x and y")
    ax.set_title(f"{crs.name}\n{crs}")
    ax.set_xlabel(f"{fesom[x].long_name} [{fesom[x].units}]")
    ax.set_ylabel(f"{fesom[y].long_name} [{fesom[y].units}]")
# Only the panel with projected coordinates gets an equal aspect ratio
axs[1].set_aspect("equal")

In [None]:
fig, axs = plt.subplots(ncols=2, constrained_layout=True, figsize=(12, 6), dpi=200)
fig.suptitle("Comparison between the masks of FESOM (dots) and GETM (grey area)", weight="bold")

# One panel per coordinate system: (x-coordinate, y-coordinate, CRS)
panels = [("lon", "lat", crs_latlon), ("x", "y", crs_cartesian)]
for ax, (x, y, crs) in zip(axs, panels):
    # GETM mask in the background, FESOM surface mask as dots on top
    getm.mask.plot(ax=ax, x=x, y=y, cmap="Greys", vmax=2, add_colorbar=False)
    ax.scatter(fesom[x], fesom[y], 1/4, fesom.mask.isel(nz1=0), cmap="tab10_r", vmax=11/7)
    ax.set_title(f"{crs.name}\n{crs}")
    ax.set_xlabel(f"{fesom[x].long_name} [{fesom[x].units}]")
    ax.set_ylabel(f"{fesom[y].long_name} [{fesom[y].units}]")

ax_latlon, ax_cartesian = axs

# Zoom into the latitude–longitude panel and add a legend built from empty proxy lines
# NOTE(review): the proxy colours are presumably the ones that the colormap limits
# above assign to the two mask values — confirm visually
ax_latlon.set_xlim(None, -14.5)
ax_latlon.set_ylim(79.1, 80.4)
for proxy_color, proxy_label in [("tab:cyan", "ice tongue"), ("tab:red", "ocean")]:
    ax_latlon.plot([], [], ".", color=proxy_color, label=proxy_label)
ax_latlon.legend(title="FESOM mask")

# Only the panel with projected coordinates gets an equal aspect ratio
ax_cartesian.set_aspect("equal")

## Create initial conditions

### Create the dataset

In [None]:
# Read the start time of the model run from the configuration and parse it;
# `strptime` raises a ValueError if the text is not of the form "YYYY-MM-DD hh:mm:ss"
time_string = config.get_text("getm/time/start")
datetime_start = datetime.strptime(time_string, "%Y-%m-%d %H:%M:%S")
print(f"Model runs from {datetime_start}.")

In [None]:
# Both fields live on the same grid: one time step, the FESOM layers, and the GETM horizontal grid
init_dims = ("time", "zax", "lat", "lon")
init_shape = (1, fesom.nz1.size, *getm.mask.shape)

init = xr.Dataset(
    data_vars={
        # Each field gets its own NaN-filled array, to be filled by the interpolation
        "salt": (init_dims, np.full(init_shape, np.nan), {"long_name": "salinity", "units": "g/kg"}),
        "temp": (init_dims, np.full(init_shape, np.nan), {"long_name": "temperature", "units": "degC"}),
    },
    coords={
        "time": (["time"], [datetime_start]),
        # Vertical axis: negated FESOM layer coordinate, so that values increase upwards
        "zax": (["zax"], -fesom.nz1.data, {"long_name": "z-axis", "units": "m", "positive": "up"}),
        "lat": getm.lat,
        "lon": getm.lon,
    },
    attrs={
        "title": "Initial conditions (3D) for the 79NG fjord GETM setup",
        "author": "Markus Reinert (ORCID: 0000-0002-3761-8029)",
        "institution": "Leibniz Institute for Baltic Sea Research Warnemuende (IOW), Germany",
        "source": "FESOM2.1 setup with focus on 79NG (Wekerle et al. 2024)",
    },
)
init

### Interpolate the data

In [None]:
# The horizontal positions of the FESOM nodes and of the GETM grid points are
# the same for every layer and every variable.  The two expensive geometric
# steps (the Delaunay triangulation for the linear interpolation and the
# nearest-neighbour search) are therefore done once here, instead of being
# repeated inside `griddata` for each layer of each variable.
fesom_points = np.column_stack([np.asarray(fesom.x), np.asarray(fesom.y)])
getm_points = np.stack(np.broadcast_arrays(np.asarray(getm.x), np.asarray(getm.y)), axis=-1)
triangulation = Delaunay(fesom_points)
# For every GETM grid point, the index of the closest FESOM node
_, nearest_node = cKDTree(fesom_points).query(getm_points)

n_layers = fesom.nz1.size
for var in variables:
    progress = f"Interpolating {var}: {{:2}}/{n_layers}"
    for i in range(n_layers):
        print(progress.format(i), end="\r")
        layer_values = np.asarray(fesom[var].isel(nz1=i))
        # Linear interpolation on the precomputed triangulation; gives NaN
        # outside the convex hull of the nodes and next to masked nodes
        interpolated = LinearNDInterpolator(triangulation, layer_values)(getm_points)
        # Fill missing values with nearest neighbour interpolation
        # (the nearest node may itself be masked, in which case the value stays NaN)
        init[var][0, i] = np.where(np.isnan(interpolated), layer_values[nearest_node], interpolated)
    # Use the layer count rather than the loop index, which is undefined without layers
    print(progress.format(n_layers))
    # Apply the mask
    init[var] = init[var].where(getm.mask)
# Remove the Cartesian coordinates that were added to the dataset when applying the mask
del init["x"], init["y"]
init

### Extrapolate vertically with constant values

To ensure that the vertical interpolation by GETM does not create NaNs, we eliminate missing values in the water columns.
Values are missing where grid cells are in the ice tongue or in the seabed, i.e., above the uppermost or below the lowermost valid grid cell.
To obtain the position of the uppermost valid grid cell, we take the position of the first maximum (`idxmax`) of the Boolean array that marks valid grid cells, i.e., the position of the first `True` along the vertical axis.
For the lowermost valid grid cell, we do the same operation on that array with flipped vertical axis.
Given the positions of the uppermost valid grid cells, we keep all values below and fill the grid cells above with the uppermost valid value.
We proceed analogously for the lower part.

In [None]:
for var in variables:
    field = init[var]
    has_value = field.notnull()
    # Vertical position of the first valid cell, searched from the top and from the bottom
    z_top = has_value.idxmax("zax")
    z_bottom = has_value.isel(zax=slice(None, None, -1)).idxmax("zax")
    # Keep the values strictly below the uppermost valid cell, fill the rest with its value
    field = field.where(init.zax < z_top, field.sel(zax=z_top))
    # Keep the values strictly above the lowermost valid cell, fill the rest with its value
    field = field.where(init.zax > z_bottom, field.sel(zax=z_bottom))
    init[var] = field

### Simplify the sponge zone

In the sponge zone (open boundary points and 3 adjacent grid points), there are no gradients in the topography, so for consistency, we also impose no gradients in the initial conditions.

In [None]:
# Check that there are really no topography gradients in the sponge zone:
# along each open boundary, the rows/columns next to the boundary must have
# the same (masked) bathymetry as the boundary itself
masked_bathymetry = getm.bathymetry.where(getm.mask, 0)
assert (masked_bathymetry.isel(lat=0) == masked_bathymetry.isel(lat=slice(4))).all()
assert (masked_bathymetry.isel(lat=-1) == masked_bathymetry.isel(lat=slice(-5, -1))).all()
assert (masked_bathymetry.isel(lon=-1) == masked_bathymetry.isel(lon=slice(-5, -1))).all()

# Impose no gradients in the initial conditions in the sponge zone
for var in variables:
    field = init[var]
    # First four rows take the values of the fifth row
    field[0, :, :4] = field[0, :, 4]
    # Last rows take the values of the fifth-last row
    field[0, :, -5:] = field[0, :, -5]
    # Last columns take the values of the fifth-last column
    field[0, :, :, -5:] = field[0, :, :, -5]

### Invert the vertical axis

GETM actually expects that `zax` is depth with positive values below sea level.

In [None]:
# Guard against running this cell twice: before the inversion, all values are negative
assert (init.zax < 0).all(), "vertical axis already inverted"
# Flip the sign and update the metadata of the axis accordingly
init["zax"] = -init.zax
init.zax.attrs.update(long_name="depth", positive="down")
init

### Save the dataset

In [None]:
# Temperature and salinity are written into one NetCDF file,
# so the configuration must name the same file for both
filename = config.get_file_path("getm/temp/temp_file")
assert filename.endswith(".nc"), "file for initial temperature distribution is not NetCDF"
salt_filename = config.get_file_path("getm/salt/salt_file")
assert filename == salt_filename, "filenames of temperature and salinity initial conditions are different"

# Write the dataset with an unlimited time dimension
init.to_netcdf(filename, unlimited_dims=["time"])
print(f"Saved the initial conditions as {filename!r}.")