## Prepare ROMS I4DVar observation file from a ROMS output file

In [None]:
from dataclasses import dataclass

import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt

There is no NetCDF stations file for this time period, so the station coordinates are taken from the stations file of the next period.

In [None]:
@dataclass
class Arguments:
    """Notebook configuration: observation window and input/output file paths.

    ``output_obs_file`` defaults to an empty string; in that case it is
    derived from ``start_time`` and ``end_time`` in ``__post_init__`` so the
    dates in the output file name always match the selected window.  Pass an
    explicit path to override the derived name.
    """

    # Bounds of the time window used to slice ``ocean_time``.
    start_time: str = '2018-04-01'
    end_time: str = '2018-04-07'
    # Grid file used as the regridding target.
    input_grid_file: str = '/cluster/projects/nn9490k/ROHO800/Grid/ROHO800_grid_fix5.nc'
    # Model output files, opened together as one multi-file dataset.
    input_data_files: tuple = (
        '/cluster/projects/nn9297k/ROHO160+/OutputData/s_layers_25/1_dec2017-dec2018/roho160_his_0011.nc',
        '/cluster/projects/nn9297k/ROHO160+/OutputData/s_layers_25/1_dec2017-dec2018/roho160_his_0012.nc',
    )
    # Stations file; the station coordinates are read from it.
    stations: str = '/cluster/projects/nn9297k/ROHO160+/OutputData/s_layers_25/2_dec2018-sep2019/roho160_sta.nc'
    # Existing observation file from which ``obs_variance`` is copied.
    wc13_obs_file: str = '/cluster/home/shmiak/src/roms-applications/WC13/Data/wc13_obs.nc'
    # Destination of the generated observation file ('' -> derived from the window).
    output_obs_file: str = ''

    def __post_init__(self):
        # Build the file name from the window so the two can never disagree.
        if not self.output_obs_file:
            self.output_obs_file = (
                '/cluster/projects/nn9297k/shmiak/roho800_data/input_data/'
                f'obs_i4dvar_4st_temp_{self.start_time}_to_{self.end_time}.nc'
            )

args = Arguments()

Load necessary nc files to datasets.

In [None]:
# Time window shared by every time-dependent selection in this cell.
time_slice = slice(args.start_time, args.end_time)

# Model output: the history files are opened as a single dataset and cut to the window.
ds_data = xr.open_mfdataset([*args.input_data_files]).sel({"ocean_time": time_slice})
# Stations file: the same window selection is applied.
# NOTE(review): this file is from a later period, so the selection along
# ocean_time is presumably empty; only the station coordinates are used later.
ds_stations = xr.open_dataset(args.stations).sel({"ocean_time": time_slice})
# Grid file and reference observation file are opened whole.
ds_grid, ds_wc13_obs = (
    xr.open_dataset(nc_path) for nc_path in (args.input_grid_file, args.wc13_obs_file)
)

### Interpolate (regrid) temperature from 160 m to 800 m (checkup)

Prepare variables and resample from 160 to 800 meters domain.
The xESMF documentation is not entirely clear to me, so the coordinates are renamed to `lon`/`lat`
to follow the example at `https://xesmf.readthedocs.io/en/latest/notebooks/Curvilinear_grid.html`.

In [None]:
# The rho-point coordinates are renamed to "lon"/"lat" for xESMF.
# Only names that are still present are renamed, so the cell can be re-run:
# a second unconditional rename would raise because "lon_rho"/"lat_rho"
# no longer exist after the first run.
_xesmf_names = {"lon_rho": "lon", "lat_rho": "lat"}
ds_grid = ds_grid.rename(
    {old: new for old, new in _xesmf_names.items() if old in ds_grid.variables}
)
ds_data = ds_data.rename(
    {old: new for old, new in _xesmf_names.items() if old in ds_data.variables}
)

In [None]:
# Bilinear regridder from the model-output grid (ds_data) to the target grid
# (ds_grid); target cells that receive no source data become NaN.
regridder = xe.Regridder(ds_data, ds_grid, "bilinear", unmapped_to_nan=True)
# Select the last s_rho level *before* regridding: the same horizontal weights
# are applied to every level, so the result is the same, but only one level is
# interpolated instead of all of them.
da_sst = regridder(ds_data['temp'].isel(s_rho=-1))

Visual check

In [None]:
# Exclude values outside the sea.  mask_rho is expected to be 1 at sea points
# and 0 on land; dividing by it produced inf on land, whereas where() sets
# those points to NaN and leaves sea values untouched.
da_sst = da_sst.where(ds_grid.mask_rho > 0)

In [None]:
# Source field before regridding: last time step, last s_rho index.
ds_data['temp'].isel(ocean_time=-1, s_rho=-1).plot(figsize=(14, 7))  # type: ignore

In [None]:
# Regridded field (da_sst) at the last time step, for comparison with the source plot.
da_sst.isel(ocean_time=-1).plot(figsize=(14, 7))  # type: ignore

### Check the coordinates of the observation stations

In [None]:
# Longitude and latitude of every station in the stations file, as float64
# arrays with one entry per station (same order as ds_stations.station).
st_lons = np.array(
    [station.lon_rho.values.item() for station in ds_stations.station],
    dtype=np.float64,
)
st_lats = np.array(
    [station.lat_rho.values.item() for station in ds_stations.station],
    dtype=np.float64,
)

In [None]:
# Labels of the stations used for assimilation and their indices in the
# stations file (the two lists are parallel).
st_labels = ['VT53', 'VT70', 'VT74', 'VT69']
st_n_point = [1, 3, 4, 9]

# Land/sea mask of the target grid as background.
p = ds_grid.mask_rho.plot(x="lon", y="lat", figsize=(14, 7), cmap='GnBu')

# Mark and label the selected stations.
_sel_lons, _sel_lats = st_lons[st_n_point], st_lats[st_n_point]
p.axes.scatter(x=_sel_lons, y=_sel_lats, color='red')
for label, _x_pt, _y_pt in zip(st_labels, _sel_lons, _sel_lats):
    p.axes.annotate(label, (_x_pt, _y_pt), color='red')

### Interpolate 160 m output to the data assimilation stations

In [None]:
# Coordinates of the selected stations, stored as two separate 1-D variables
# ("lat" on dimension "lat", "lon" on dimension "lon").
_station_lats = st_lats[st_n_point]
_station_lons = st_lons[st_n_point]
ds_st_coords = xr.Dataset(
    dict(
        lat=("lat", _station_lats, {"units": "degrees_north"}),
        lon=("lon", _station_lons, {"units": "degrees_east"}),
    )
)

In [None]:
# Bilinear regridder from the model-output grid to the station coordinates;
# output points that receive no source data become NaN.
regridder = xe.Regridder(
    ds_in=ds_data,
    ds_out=ds_st_coords,
    method="bilinear",
    unmapped_to_nan=True,
)
# Temperature interpolated onto the lat/lon axes of ds_st_coords.
da_stations_temp = regridder(ds_data["temp"])

xESMF interpolates onto the rectilinear grid formed by every combination of the station latitudes and longitudes, not only onto the station points themselves.

In [None]:
# Inspect the regridded result (dimensions and coordinates).
da_stations_temp

In [None]:
# Time/level section at the first lat/lon pair of the regridded output.
da_stations_temp.isel(lat=0, lon=0).plot(x="ocean_time", y="s_rho")  # VT53

Extract the diagonal elements along the lat and lon dimensions; these are the values at the stations.

In [None]:
# Keep only the matching (lat[i], lon[i]) pairs by taking the diagonal over
# axes 2 and 3 (assumed to be lat and lon — confirm the dimension order is
# (ocean_time, s_rho, lat, lon)).  The diagonal becomes the last axis.
na_stations = np.diagonal(da_stations_temp.values, axis1=2, axis2=3)

Checkup

In [None]:
# Shape after taking the diagonal.
na_stations.shape

In [None]:
# Diagonal values at first-axis index 0 and second-axis index 3.
na_stations[0, 3, :]

In [None]:
# Full lat x lon block at ocean_time index 0 and s_rho index 3; its diagonal
# should match na_stations[0, 3, :].
da_stations_temp.isel(ocean_time=0, s_rho=3).values

### Prepare data arrays for the ROMS data assimilation observation file

In [None]:
# Observations per time step = number of selected stations x number of s_rho levels.
points_per_time = len(st_n_point) * da_stations_temp.sizes["s_rho"]
# Number of time steps in the selected window.
time_points = da_stations_temp.sizes["ocean_time"]
print(f"There are {points_per_time} observations per time and {time_points} time points.")

In [None]:
# Number of vertical (s_rho) levels in the interpolated temperature.
rho_levels = da_stations_temp.sizes["s_rho"]

Obs file structure:
All variables are 1-D arrays of the same size (the total number of observations),
except `survey_time` and `Nobs` (the number of observations per survey time).
A survey can contain several observations.

obs_value

In [None]:
# Values at first-axis index 0, second-axis index 0 (one per selected station).
na_stations[0, 0, :]

In [None]:
# Values at first-axis index 0, second-axis index 1 (one per selected station).
na_stations[0, 1, :]

In [None]:
# First eight flattened values; compare with na_stations[0, 0, :] and
# na_stations[0, 1, :] to see the flattening order.
na_stations.flatten()[:8]

The order of values: 'VT53', 'VT70', 'VT74', 'VT69' from the bottom upwards, from the first time step

In [None]:
# flatten() uses C order by default: for an array indexed (x, y, z) the last
# axis z varies fastest, then y, then x.
na_st_temp = na_stations.flatten()
na_st_temp.shape

In [None]:
# True only if no observation value is NaN or infinite.
_all_finite = np.isfinite(na_st_temp).all()
print(f"Is data finite? {_all_finite}")

In [None]:
# Every flattened observation value plotted against its index in the array.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_st_temp, 'c+')

obs_type

In [None]:
# One int32 type code per observation, all set to 6.
# NOTE(review): 6 is presumably the ROMS state-variable index for temperature — confirm.
na_type = np.full(na_st_temp.shape, 6, dtype=np.int32)
na_type.shape

obs_provenance

In [None]:
# One int32 provenance flag per observation, all set to 1.
na_provenance = np.full(na_st_temp.shape, 1, dtype=np.int32)
na_provenance.shape

obs_time

In [None]:
# Each model time stamp is repeated consecutively once per observation taken
# at that time, giving one time value per datum.
na_time = np.repeat(da_stations_temp.ocean_time.values, points_per_time)
na_time.shape

In [None]:
# Observation time plotted against datum index.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_time)

obs_depth

If positive, the value is a ROMS grid level (for example, 25 is the top layer when there are 25 layers).
If negative, the value is a depth in meters (not tested).

Here should be like 1 1 1 1 2 2 2 2... 25 25 25 25 1 1 1 1 ...

In [None]:
# Level numbers for one time step: every level from 1 to rho_levels, each
# repeated once per selected station (1 1 1 1 2 2 2 2 ... for four stations).
# The level count comes from the data (rho_levels) instead of a hard-coded 25.
depth_time_step = np.repeat(np.arange(1, rho_levels + 1), len(st_n_point))

In [None]:
# Repeat the per-time-step level pattern once for every time point.
na_depth = np.tile(depth_time_step, time_points)
na_depth.shape

In [None]:
# Observation level plotted against datum index.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_depth)

obs_Xgrid and obs_Ygrid

We need to find the grid positions that correspond to the observation coordinates.
These positions can be fractional.

from https://www.myroms.org/Workshops/4DVAR2019/Tutorials/Tutorial_10_2019.pdf:

The obs_lon and obs_lat values are only necessary to compute the fractional grid
locations (obs_Xgrid, obs_Ygrid) during pre-processing using obs_ijpos.m.

The obs_lon and obs_lat are not used directly in ROMS when running the 4D-Var
algorithms for efficiency and because of the complexity of curvilinear grids. The
fractional grid locations obs_Xgrid and obs_Ygrid are used instead. Their values
range are:
obs_Xgrid: 0.5 to L – 0.5
obs_Ygrid: 0.5 to M – 0.5

In [None]:
# Inspect the station-coordinate dataset.
ds_st_coords

In [None]:
# Latitudes of the selected stations.
ds_st_coords['lat'].values

In [None]:
# Longitudes of the selected stations.
ds_st_coords['lon'].values

In [None]:
# Inspect the grid dataset.
ds_grid

In [None]:
# Lengths of the two horizontal rho-point dimensions of the grid.
len_eta_rho = ds_grid.sizes["eta_rho"]
len_xi_rho = ds_grid.sizes["xi_rho"]

In [None]:
# 2-D arrays holding the xi / eta coordinate value of every grid point.
# na_xi has shape (len_eta_rho, len_xi_rho): each row is the xi_rho vector.
na_xi = np.repeat(ds_grid.xi_rho.values[np.newaxis, :], len_eta_rho, axis=0)
# na_eta has shape (len_xi_rho, len_eta_rho): each row is the eta_rho vector,
# so it has to be transposed before it matches the (eta_rho, xi_rho) layout.
na_eta = np.repeat(ds_grid.eta_rho.values[np.newaxis, :], len_xi_rho, axis=0)

In [None]:
# Shape of the xi array.
na_xi.shape

In [None]:
# Contents of the xi array (every row is the xi_rho vector).
na_xi

In [None]:
# Shape of the transposed eta array; it should equal the shape of na_xi.
na_eta.T.shape

In [None]:
# Contents of the transposed eta array (every column is the eta_rho vector).
na_eta.T

In [None]:
# Store the xi / eta values as float64 fields on the (eta_rho, xi_rho) grid so
# that they can be interpolated like any other 2-D variable.
_rho_dims = ['eta_rho', 'xi_rho']
ds_grid = ds_grid.assign(
    xis=(_rho_dims, na_xi.astype(np.float64)),
    etas=(_rho_dims, np.transpose(na_eta).astype(np.float64)),
)

In [None]:
# Inspect the grid dataset after adding the xis / etas fields.
ds_grid

In [None]:
# Bilinear regridder from the target grid to the station coordinates.
regridder = xe.Regridder(
    ds_in=ds_grid,
    ds_out=ds_st_coords,
    method="bilinear",
    unmapped_to_nan=True,
)
# Interpolating the xi / eta fields gives fractional grid positions on the
# lat/lon axes of ds_st_coords.
da_xis, da_etas = (regridder(ds_grid[field]) for field in ("xis", "etas"))

In [None]:
# The regridded fields are lat x lon blocks; the diagonal holds the matching
# (lat[i], lon[i]) pairs, i.e. one fractional position per station.
na_xis = np.diagonal(da_xis.values)
na_etas = np.diagonal(da_etas.values)

In [None]:
# Fractional xi positions of the selected stations.
na_xis

In [None]:
# Fractional eta positions of the selected stations.
na_etas

Plot the fractional xi, eta station coordinates

In [None]:
# Land/sea mask in grid-index space as background.
p = ds_grid.mask_rho.plot(x="xi_rho", y="eta_rho", figsize=(14, 7), cmap='GnBu')

# Mark and label the stations at their fractional (xi, eta) positions.
p.axes.scatter(x=na_xis, y=na_etas, color='red')
for label, _xi_pt, _eta_pt in zip(st_labels, na_xis, na_etas):
    p.axes.annotate(label, (_xi_pt, _eta_pt), color='red')

In [None]:
# The per-station positions are repeated for every level and every time step,
# giving one value per datum (the station index varies fastest).
_repeats_per_station = rho_levels * time_points
na_xgrid = np.tile(na_xis, _repeats_per_station)
na_ygrid = np.tile(na_etas, _repeats_per_station)

In [None]:
# Fractional xi position plotted against datum index.
fig, ax = plt.subplots(figsize=(14, 3))
ax.plot(na_xgrid, 'c+')

In [None]:
# Number of obs_Xgrid values.
na_xgrid.shape

In [None]:
# Number of obs_Ygrid values.
na_ygrid.shape

obs_Zgrid

There is no description, but according to `d_sst_obs.m` zgrid on surface should be zero

In [None]:
# Vertical grid values for one time step: rho_levels - 0.5 down to 0.5, each
# repeated once per selected station.  The level count comes from the data
# (rho_levels) instead of a hard-coded 25.
# NOTE(review): these values decrease while the level numbers built for
# obs_depth increase for the same data order — confirm this orientation is the
# one ROMS expects for obs_Zgrid.
na_zgrid_time_step = np.repeat(np.arange(rho_levels, 0, -1), len(st_n_point)) - 0.5

In [None]:
# Repeat the per-time-step pattern once for every time point.
na_zgrid = np.tile(na_zgrid_time_step, time_points)
na_zgrid.shape

In [None]:
# obs_Zgrid value plotted against datum index.
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(na_zgrid)

obs_error

In [None]:
# from `d_sst_obs.m`
na_error = np.full_like(na_st_temp, 0.4**2)

In [None]:
# Number of obs_error values.
na_error.shape

survey_time

In [None]:
# Survey times are the ocean_time values of the interpolated data, one per time step.
na_survey_time = da_stations_temp.ocean_time.values

In [None]:
# Inspect the survey times.
na_survey_time

Nobs

In [None]:
# Every survey (time step) holds the same number of observations.
na_nobs = np.full(time_points, points_per_time)

In [None]:
# Number of observations in each survey.
na_nobs

obs_lon and obs_lat

In [None]:
# Station coordinates repeated for every level and every time step, giving one
# value per datum (the station index varies fastest).
_coord_repeats = rho_levels * time_points
na_lon = np.tile(ds_st_coords['lon'].values, _coord_repeats)
na_lat = np.tile(ds_st_coords['lat'].values, _coord_repeats)

In [None]:
# Number of obs_lon values.
na_lon.shape

In [None]:
# Number of obs_lat values.
na_lat.shape

### Make a dataset

In [None]:
# Per-observation arrays, all on the "datum" dimension, listed in the order in
# which they are added to the dataset.
_datum_arrays = {
    "obs_value": na_st_temp,
    "obs_type": na_type,
    "obs_provenance": na_provenance,
    "obs_time": na_time,
    "obs_depth": na_depth,
    "obs_Xgrid": na_xgrid,
    "obs_Ygrid": na_ygrid,
    "obs_Zgrid": na_zgrid,
    "obs_error": na_error,
    "obs_lon": na_lon,
    "obs_lat": na_lat,
}

# Scalars and per-survey variables come first; obs_variance is copied from the
# reference observation file.
ds = xr.Dataset(
    {
        "spherical": 1,
        "Nobs": ("survey", na_nobs),
        "survey_time": ("survey", na_survey_time),
        "obs_variance": ("state_variable", ds_wc13_obs.obs_variance.data),
        **{var_name: ("datum", values) for var_name, values in _datum_arrays.items()},
    },
)
ds

In [None]:
# Display the assembled observation dataset.
ds

In [None]:
# Write the observation file; both time variables are encoded with the same
# reference date.
_time_units = "days since 1948-01-01"
ds.to_netcdf(
    path=args.output_obs_file,
    encoding={
        time_var: {'units': _time_units} for time_var in ('survey_time', 'obs_time')
    },
)