## Prepare ROMS I4DVar observation file from cortadv5_FilledSST

In [None]:
from dataclasses import dataclass

import numpy as np
import xarray as xr

In [None]:
@dataclass
class Arguments:
    """Notebook configuration: the time window and the input/output file paths."""

    # Bounds of the time slice taken from the input data file.
    start_time: str = "2007-01-01"
    end_time: str = "2007-01-31"
    # ROMS grid file; supplies lon_rho, lat_rho and mask_rho.
    input_grid_file: str = "/cluster/projects/nn9490k/ROHO800/Grid/ROHO800_grid_fix5.nc"
    # Source file holding the FilledSST fields.
    input_data_file: str = "/cluster/projects/nn9297k/ROHO800+/InputData/4dvar/cortadv5_FilledSST.nc"
    # Existing WC13 observation file; only its obs_variance is reused.
    wc13_obs_file: str = "/cluster/home/shmiak/src/roms-applications/WC13/Data/wc13_obs.nc"
    # Destination of the observation file produced by this notebook.
    output_obs_file: str = "/cluster/projects/nn9297k/shmiak/roho800_data/input_data/roho800_obs_sst.nc"


# Default configuration used by all the cells below.
args = Arguments()

In [None]:
# Open the three inputs in one go: the ROMS grid, the SST source data and the
# WC13 observation file (used further down only for its obs_variance).
ds_grid, ds_data, wc13_obs = (
    xr.open_dataset(path)
    for path in (
        args.input_grid_file,
        args.input_data_file,
        args.wc13_obs_file,
    )
)

In [None]:
ds_grid

In [None]:
ds_data

In [None]:
# Restrict the source data to the requested period, then interpolate
# FilledSST onto the ROMS rho-points (lon_rho / lat_rho).
ds_data = ds_data.sel(time=slice(args.start_time, args.end_time))
coords = {
    'lon': ds_grid.lon_rho,
    'lat': ds_grid.lat_rho,
}
da_sst = ds_data['FilledSST'].interp(coords)
# Keep sea points only. Dividing by mask_rho (0 on land) would turn finite
# land values into +/-inf rather than NaN, and the NaN-based filtering used
# in the following cells would then keep them as observations.
da_sst = da_sst.where(ds_grid.mask_rho > 0)
da_sst -= 273.15  # presumably Kelvin -> degrees Celsius; confirm source units

In [None]:
da_sst

In [None]:
# FilledSSTstandardDeviation is given a length-one time dimension (labelled
# with the first selected time) so that xr.interp works on it, and is then
# interpolated onto the same ROMS grid coordinates as the SST itself.
da_sst_std = (
    ds_data['FilledSSTstandardDeviation']
    .expand_dims({'time': (ds_data.coords['time'].data[0], )})
    .interp(coords)
)

In [None]:
da_sst_std

In [None]:
# True where the first time step holds a value (non-NaN), i.e. the grid
# points that carry data.
# NOTE: the cells below reuse this first-step mask and count for every time
# step, so they assume all time steps have the same valid points.
mask = ~np.isnan(da_sst.isel(time=0).values)
points_per_time = int(mask.sum())  # number of valid points in one time step
time_points = da_sst.time.shape[0]  # number of time steps
print(f"The number of points per time: {points_per_time}")

#### obs_value

In [None]:
# Flatten in C order: the last dimension varies fastest and the first one
# slowest, then drop the NaN entries so only valid observations remain.
np_sst = da_sst.values.reshape(-1, order='C')
np_sst = np_sst[np.logical_not(np.isnan(np_sst))]
np_sst.shape

#### obs_type

In [None]:
# One int32 type code per observation. The value 6 presumably selects
# temperature in the ROMS state-variable numbering -- confirm.
np_type = np.full_like(np_sst, fill_value=6, dtype=np.int32)
np_type.shape

#### obs_provenance

In [None]:
# All observations share a single provenance code (1), stored as int32.
np_provenance = np.full_like(np_sst, fill_value=1, dtype=np.int32)
np_provenance.shape

#### obs_time

In [None]:
# Each time stamp is repeated once per valid grid point, giving a
# time-major vector: all entries of the first time, then the second, ...
np_time = np.repeat(da_sst.time.values, points_per_time)
np_time.shape

#### obs_depth

In [None]:
# Constant depth value of 1 for every observation, in np_sst's dtype.
# NOTE(review): confirm that 1 is the intended ROMS obs_depth for surface data.
np_depth = np.full_like(np_sst, fill_value=1)
np_depth.shape

#### obs_Xgrid and obs_Ygrid

In [None]:
# np.where on a 2-D array returns (row, column) indices. The mask comes from
# a field on the ROMS rho-grid, presumably laid out as (eta_rho, xi_rho), so
# rows are the Y (eta) direction and columns the X (xi) direction.
y_idx, x_idx = np.where(mask)

In [None]:
# The observation vectors are time-major (all points of the first time step,
# then all points of the second, ...), matching np_sst and np_time. The
# per-point index pattern must therefore be tiled as a whole once per time
# step; repeating element-wise would misalign positions and values.
# NOTE(review): the +1 shifts the 0-based numpy indices; confirm against the
# index origin ROMS expects for fractional rho-grid coordinates.
np_xgrid = np.tile(x_idx.astype(dtype=np.float64) + 1, time_points)
np_ygrid = np.tile(y_idx.astype(dtype=np.float64) + 1, time_points)

In [None]:
np_xgrid.shape

In [None]:
np_ygrid.shape

#### obs_Zgrid

In [None]:
# Vertical grid coordinate: a constant 0 for every observation.
np_zgrid = np.full_like(np_sst, fill_value=0)
np_zgrid.shape

#### obs_Error

In [None]:
# Observation error variance. A per-point variance derived from da_sst_std
# (flatten, drop NaNs, square, expand over time) was tried and is disabled;
# a constant standard deviation of 0.4 is used instead, stored as 0.4**2.
np_error = np.full_like(np_sst, fill_value=0.4**2)

In [None]:
np_error.shape

#### survey_time

In [None]:
# One survey per time step: the survey times are the SST time stamps.
np_survey_time = da_sst['time'].values

In [None]:
np_survey_time.shape

#### Nobs

In [None]:
# Every survey (time step) holds the same number of observations.
np_nobs = np.repeat(points_per_time, repeats=time_points)

In [None]:
np_nobs.shape

#### Make a dataset

In [None]:
# Assemble the observation dataset: a scalar flag, the per-survey variables,
# the obs_variance copied from the WC13 file, and the per-observation
# ("datum") vectors built in the cells above.
ds = xr.Dataset(
    data_vars={
        "spherical": 1,
        "Nobs": ("survey", np_nobs),
        "survey_time": ("survey", np_survey_time),
        "obs_variance": ("state_variable", wc13_obs.obs_variance.data),
        **{
            var_name: ("datum", values)
            for var_name, values in (
                ("obs_value", np_sst),
                ("obs_type", np_type),
                ("obs_provenance", np_provenance),
                ("obs_time", np_time),
                ("obs_depth", np_depth),
                ("obs_Xgrid", np_xgrid),
                ("obs_Ygrid", np_ygrid),
                ("obs_Zgrid", np_zgrid),
                ("obs_error", np_error),
            )
        },
    },
)
ds

In [None]:
# Write the assembled observation dataset to the configured output path.
ds.to_netcdf(path=args.output_obs_file)