## Prepare a ROMS I4DVar observation file from a ROMS output file

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from nautilos_osse.arguments import ArgumentsSpring2017

In [None]:
# State variable whose model values are turned into observations.
variable = "salt"

# Identifier stored in obs_type for each supported state variable.
data_type = dict(temp=6, salt=7)

# Experiment settings; the observation file is written below the user's home directory.
args = ArgumentsSpring2017()
obs_file_relative = f"data_ROHO/obs_i4dvar_min_{variable}_2017-02-01_to_2017-03-15.nc"
args.output_obs_file = os.path.join(Path.home(), obs_file_relative)

Load the necessary NetCDF files into datasets.

In [None]:
# Restrict the model output to the configured time window.
time_slice = slice(args.start_time, args.end_time)
nature_files = list(args.nature_files_2017)
ds_data = xr.open_mfdataset(nature_files).sel(ocean_time=time_slice)

# Grid file and an existing observation file (its obs_variance is reused when the dataset is built).
ds_grid = xr.open_dataset(args.input_grid_file)
ds_wc13_obs = xr.open_dataset(args.wc13_obs_file)

In [None]:
# Quick look at the grid's mask_rho field.
ds_grid.mask_rho.plot()

In [None]:
# Show the dataset restricted to the selected time range.
ds_data

### Get the observation station

In [None]:
# Grid indices (xi, eta) of the single observation station.
xi_rho, eta_rho = 150, 150

In [None]:
def keep_unique_times(data: xr.DataArray):
    """
    Drop records with a repeated ``ocean_time``, keeping the first of each.

    ROMS modifies the initial values at startup and saves them again, so
    successive runs produce records with the same time stamp but slightly
    different data. For every time stamp only the first record is retained;
    any later record carrying the same time stamp is removed.
    """
    times = data['ocean_time'].to_index()
    first_seen = np.logical_not(times.duplicated(keep='first'))
    # Positional indices of the records to keep, in their original order.
    return data.isel(ocean_time=np.flatnonzero(first_seen))

In [None]:
# Select the station column, take the last s_rho level of the chosen variable,
# load it, and drop records with a repeated time stamp.
station_column = ds_data.sel(xi_rho=xi_rho, eta_rho=eta_rho)
station_last_level = station_column.isel(s_rho=-1)[variable]
da_station = keep_unique_times(station_last_level.persist())

In [None]:
# Show the station time series.
da_station

In [None]:
# Longitude and latitude coordinates attached to the station series.
st_lon = da_station.lon_rho.values
st_lat = da_station.lat_rho.values
# Plain NumPy values of the series; the name is kept similar to the "4st obs"
# variant (presumably a four-station version of this notebook — TODO confirm).
na_stations = da_station.values  # keep the name similar to 4st obs

### Prepare data arrays for the ROMS data assimilation observation file

In [None]:
# One station, hence one observation at every time point.
points_per_time = 1
# Number of time records in the station series.
time_points = len(da_station.ocean_time)
print(f"There are {points_per_time} observations per time and {time_points} time points.")

In [None]:
# Shape of the s_rho coordinate on the station series; used further down as the
# `reps` argument of np.tile when the grid positions and lon/lat are expanded.
# NOTE(review): the series was reduced with isel(s_rho=-1), so this is presumably
# an empty tuple — confirm.
rho_levels = da_station.s_rho.shape
rho_levels

In [None]:
# Show the s_rho coordinate that remains on the station series.
da_station.s_rho

Obs file structure:
All variables are 1-D arrays of the same size, equal to the total number of observations,
except survey_time and Nobs, which have one entry per survey time.
A survey can contain several observations.

obs_value

In [None]:
# Alias of the station values (same array object, no copy and no reshaping here).
na_st_flat = na_stations
na_st_flat.shape

In [None]:
# Report whether every observation value is finite (no NaN or infinity).
print(f"Is data finite? {np.all(np.isfinite(na_st_flat))}")

In [None]:
# Visual check of the observation values, one marker per observation.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_st_flat, "c+")
# Label the figure so it can be read on its own; the trailing semicolon hides the repr.
ax.set(title=f"obs_value ({variable})", xlabel="observation index", ylabel=variable);

obs_type

In [None]:
# Every observation carries the identifier of the selected state variable.
type_id = data_type[variable]
na_type = np.full(na_st_flat.shape, type_id, dtype=np.int32)
na_type.shape

obs_provenance

In [None]:
# All observations share the provenance flag 1.
na_provenance = np.ones_like(na_st_flat, dtype=np.int32)
na_provenance.shape

obs_time

In [None]:
# Each time stamp is repeated once per observation made at that time.
station_times = da_station.ocean_time.values
na_time = np.repeat(station_times, points_per_time)
na_time.shape

In [None]:
# Visual check of the observation times in file order.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_time)
# Label the figure so it can be read on its own; the trailing semicolon hides the repr.
ax.set(title="obs_time", xlabel="observation index", ylabel="time");

obs_depth

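If positive, it should be a ROMS grid level; for example, 25 is the top layer if there are 25 layers.
If negative, it is a depth in meters (not tested).

In the general multi-level case the values should look like 1 1 1 1 2 2 2 2... 25 25 25 25 1 1 1 1 ...; in this notebook there is a single observation per time point, so one level value is repeated.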

In [None]:
# Grid level written to obs_depth for every observation.
# NOTE(review): the values come from isel(s_rho=-1); this hard-coded 25 presumably
# equals the number of s_rho levels in the model output — confirm.
depth_time_step = 25

In [None]:
# The same level value for each time point.
na_depth = np.full(time_points, depth_time_step)
na_depth.shape

In [None]:
# Visual check of the level assigned to each observation.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_depth)
# Label the figure so it can be read on its own; the trailing semicolon hides the repr.
ax.set(title="obs_depth", xlabel="observation index", ylabel="grid level");

obs_Xgrid and obs_Ygrid

We need to find the grid points closest to the observation coordinates.
They can be fractional.

from https://www.myroms.org/Workshops/4DVAR2019/Tutorials/Tutorial_10_2019.pdf:

The obs_lon and obs_lat values are only necessary to compute the fractional grid
locations (obs_Xgrid, obs_Ygrid) during pre-processing using obs_ijpos.m.

The obs_lon and obs_lat are not used directly in ROMS when running the 4D-Var
algorithms for efficiency and because of the complexity of curvilinear grids. The
fractional grid locations obs_Xgrid and obs_Ygrid are used instead. Their values
range are:
obs_Xgrid: 0.5 to L – 0.5
obs_Ygrid: 0.5 to M – 0.5

In [None]:
# xi position of the station as a one-element array.
na_xis = np.atleast_1d(xi_rho)

In [None]:
# eta position of the station as a one-element array.
na_etas = np.atleast_1d(eta_rho)

Plot the fractional xi, eta station coordinates

In [None]:
# Draw mask_rho and mark the station on top of it.
p = ds_grid.mask_rho.plot(x="xi_rho", y="eta_rho", figsize=(14, 7), cmap="GnBu")
station_ax = p.axes
station_ax.scatter(x=na_xis, y=na_etas, color="red")
station_ax.annotate("St 1", (na_xis[0], na_etas[0]), color="red")

In [None]:
# Expand the station's xi and eta positions over the levels and then over all time points.
na_xgrid, na_ygrid = (
    np.tile(np.tile(positions, rho_levels), time_points)
    for positions in (na_xis, na_etas)
)

In [None]:
# Visual check of the xi grid position assigned to each observation.
fig, ax = plt.subplots(figsize=(14, 3))
ax.plot(na_xgrid, "c+")
# Label the figure so it can be read on its own; the trailing semicolon hides the repr.
ax.set(title="obs_Xgrid", xlabel="observation index", ylabel="xi grid position");

In [None]:
# Shape of the obs_Xgrid array.
na_xgrid.shape

In [None]:
# Shape of the obs_Ygrid array.
na_ygrid.shape

obs_Zgrid

There is no description, but according to `d_sst_obs.m` the zgrid value at the surface should be zero.

In [None]:
# obs_Zgrid value for one time point: a single zero.
na_zgrid_time_step = np.zeros(1, dtype=int)

In [None]:
# Repeat the per-time-point Zgrid values for every time point.
na_zgrid = np.tile(na_zgrid_time_step, time_points)
na_zgrid.shape

In [None]:
# Visual check of the Zgrid value assigned to each observation.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_zgrid)
# Label the figure so it can be read on its own; the trailing semicolon hides the repr.
ax.set(title="obs_Zgrid", xlabel="observation index", ylabel="Zgrid");

obs_Error

In [None]:
# from `d_sst_obs.m`
# Constant observation error, value taken from `d_sst_obs.m`.
# NOTE(review): the square suggests 0.4 is a standard deviation — confirm.
na_error = np.full(na_st_flat.shape, 0.4**2, dtype=na_st_flat.dtype)

In [None]:
# Shape of the obs_error array.
na_error.shape

survey_time

In [None]:
# One survey per time record of the station series.
na_survey_time = da_station.ocean_time.values

In [None]:
# Show the survey times.
na_survey_time

In [None]:
# Shape of the survey_time array.
na_survey_time.shape

Nobs

In [None]:
# Number of observations in each survey: the same count for every time point.
na_nobs = np.full(time_points, points_per_time)

In [None]:
# Show the per-survey observation counts.
na_nobs

obs_lon and obs_lat

In [None]:
# Expand the station's longitude and latitude over the levels and then over all time points.
na_lon, na_lat = (
    np.tile(np.tile(coordinate, rho_levels), time_points)
    for coordinate in (st_lon, st_lat)
)

In [None]:
# Shape of the obs_lon array.
na_lon.shape

In [None]:
# Shape of the obs_lat array.
na_lat.shape

### Make a dataset

In [None]:
# Arrays with one entry per survey (observation time).
survey_vars = {
    "Nobs": na_nobs,
    "survey_time": na_survey_time,
}
# Arrays with one entry per observation.
datum_vars = {
    "obs_value": na_st_flat,
    "obs_type": na_type,
    "obs_provenance": na_provenance,
    "obs_time": na_time,
    "obs_depth": na_depth,
    "obs_Xgrid": na_xgrid,
    "obs_Ygrid": na_ygrid,
    "obs_Zgrid": na_zgrid,
    "obs_error": na_error,
    "obs_lon": na_lon,
    "obs_lat": na_lat,
}

# Assemble the variables in the order they should appear in the dataset.
data_vars = {"spherical": 1}
data_vars.update((name, ("survey", values)) for name, values in survey_vars.items())
# obs_variance is copied from the existing observation file.
data_vars["obs_variance"] = ("state_variable", ds_wc13_obs.obs_variance.data)
data_vars.update((name, ("datum", values)) for name, values in datum_vars.items())

ds = xr.Dataset(data_vars)

In [None]:
# Show the assembled observation dataset before it is written out.
ds

According to my experiments, it is crucial to give the time units with a reference date ('days since ...').

In [None]:
# Create the target directory if needed, so the write does not fail at the very
# end of the notebook because of a missing folder.
Path(args.output_obs_file).parent.mkdir(parents=True, exist_ok=True)

# The time units must carry a reference date ("days since ..."). The dtype is set
# explicitly so both time variables are stored as float64 instead of leaving the
# on-disk type to be inferred from the data.
time_encoding = {"units": "days since 1948-01-01", "dtype": "float64"}
ds.to_netcdf(
    args.output_obs_file,
    encoding={
        # Separate dict per variable so the two encodings are never shared.
        "survey_time": dict(time_encoding),
        "obs_time": dict(time_encoding),
    },
)