## Prepare ROMS I4DVar observation file from a ROMS output file

In [None]:
from dataclasses import dataclass

import cftime
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xesmf as xe

In [None]:
@dataclass
class Arguments:
    """Notebook configuration: time window plus input and output file paths."""

    # Bounds of the ``ocean_time`` slice taken from the model output.
    start_time: str = "2018-09-24"
    end_time: str = "2018-09-30"

    # Grid file whose rho-point coordinates are the regridding target.
    input_grid_file: str = "/cluster/projects/nn9490k/ROHO800/Grid/ROHO800_grid_fix5.nc"

    # Model output files that are opened together as one multi-file dataset.
    input_data_files: tuple = (
        "/cluster/projects/nn9297k/ROHO160+/OutputData/s_layers_25/"
        "1_dec2017-dec2018/roho160_his_0029.nc",
        "/cluster/projects/nn9297k/ROHO160+/OutputData/s_layers_25/"
        "1_dec2017-dec2018/roho160_his_0030.nc",
    )

    # Existing observation file from which ``obs_variance`` is copied.
    wc13_obs_file: str = "/cluster/home/shmiak/src/roms-applications/WC13/Data/wc13_obs.nc"

    # Destination of the generated observation file.
    output_obs_file: str = (
        "/cluster/projects/nn9297k/shmiak/roho800_data/input_data/"
        "roho800_obs_sst_roho160_2018-09-24_to_2018-09-30.nc"
    )


# Single configuration instance used by every cell below.
args = Arguments()

In [None]:
# Target grid: supplies lon_rho/lat_rho (the regridding destination) and mask_rho.
ds_grid = xr.open_dataset(args.input_grid_file)
# Source model output; the listed files are opened as a single dataset.
ds_data = xr.open_mfdataset(list(args.input_data_files))
# Existing observation file; only its obs_variance variable is reused below.
wc13_obs = xr.open_dataset(args.wc13_obs_file)

In [None]:
# Keep only the records whose ocean_time lies in the configured time window.
ds_data = ds_data.sel(ocean_time=slice(args.start_time, args.end_time))

In [None]:
# Rename the rho-point coordinates to "lon"/"lat" on both the target grid and
# the source data before they are handed to the regridder.
ds_out = ds_grid.rename({"lon_rho": "lon", "lat_rho": "lat"})
ds = ds_data.rename({"lon_rho": "lon", "lat_rho": "lat"})
# Temperature field to be regridded.
da_temp = ds['temp']

In [None]:
# Bilinear regridder from the source grid (ds) to the target grid (ds_out);
# target cells that the source does not cover are set to NaN.
regridder = xe.Regridder(ds, ds_out, "bilinear", unmapped_to_nan=True)

In [None]:
# Interpolate the temperature field onto the target grid.
da_out = regridder(da_temp)

In [None]:
# Quick look: source temperature at the last time step and last s_rho level.
da_temp.isel(ocean_time=-1, s_rho=-1).plot()  # type: ignore

In [None]:
# Quick look: regridded temperature at the last time step and last s_rho level.
da_out.isel(ocean_time=-1, s_rho=-1).plot()  # type: ignore

In [None]:
# Take the last s_rho level as SST and divide by the grid mask, so that points
# where mask_rho is 0 become non-finite and are removed by the np.isfinite
# filters further down.
# NOTE(review): presumably mask_rho is 1 over sea and 0 over land — confirm.
da_sst = da_out.isel(s_rho=-1) / ds_grid.mask_rho  # type: ignore ; exclude values outside the sea

In [None]:
# Rename "ocean_time" to "time" for the rest of the notebook.
da_sst = da_sst.rename({"ocean_time": "time"})

In [None]:
# True where the first time step has a finite SST value, i.e. at grid points
# that carry data; NaN/inf points are False.
mask = np.isfinite(da_sst.isel(time=0).values)
# Number of valid grid points in a single time step.
points_per_time = int(np.count_nonzero(mask))
# Number of time steps.
time_points = da_sst.time.shape[0]
print(f"The number of points per time: {points_per_time}")

#### obs_value

In [None]:
# Pick the valid grid points of every time step with the same 2-D mask that
# defines points_per_time, the grid indices and the lon/lat arrays, so that all
# per-observation arrays stay aligned.  The result is flattened in C order:
# the last dimension varies fastest and time varies slowest.
np_sst = da_sst.values[:, mask].ravel()
if not np.isfinite(np_sst).all():
    # The mask comes from the first time step only; a different pattern at a
    # later step would otherwise put NaN/inf into the observation values.
    raise ValueError(
        "SST has non-finite values at points that are valid in the first time step"
    )
np_sst.shape

#### obs_type

In [None]:
# One int32 type code per observation; every datum gets the value 6.
# NOTE(review): 6 is presumably the ROMS state-variable index of temperature — confirm.
np_type = np.full_like(np_sst, 6, dtype=np.int32)
np_type.shape

#### obs_provenance

In [None]:
# One int32 provenance code per observation; every datum gets the value 1.
np_provenance = np.full_like(np_sst, 1, dtype=np.int32)
np_provenance.shape

#### obs_time

In [None]:
da_sst.time.data

In [None]:
# Change the time of the first observation to match the initial-conditions file.
# ROMS does not read the `days since ...` reference date from this file; it takes
# that date from another file.
# time_cftime = cftime.datetime(2018, 9, 24)
# da_sst.time.data[0] = time_cftime

In [None]:
da_sst.time.data

In [None]:
# np datetime64 to python datetime to cftime num

# from datetime import datetime

# da_sst.time.data.dtype
# np.datetime64('1970-01-01T00:00:00', "ns").dtype
# timestamp = ((da_sst.time.data - np.datetime64('1970-01-01T00:00:00')) / np.timedelta64(1, 's'))
# dt_sst = datetime.utcfromtimestamp(int(timestamp))
# cftime.date2num(dt_sst, "days since 2007-01-15")

In [None]:
# Give every observation the time stamp of its survey: each time value is
# repeated points_per_time times, with the surveys kept in their original order.
survey_times = da_sst.time.values
np_time = np.repeat(survey_times, points_per_time)
np_time.shape

In [None]:
# Visual check of the per-observation time sequence.
fig, ax = plt.subplots(figsize=(20, 7))
ax.plot(np_time)

#### obs_depth
If positive, the value is a ROMS vertical grid level; for example, 25 is the top layer when the grid has 25 layers.
If negative, the value is a depth in meters (not tested).

In [None]:
# Same depth value for every observation: vertical grid level 25.
# NOTE(review): hard-coded; assumes the model that reads this file has 25
# vertical levels, so that level 25 is the top layer — confirm.
np_depth = np.full_like(np_sst, 25)
np_depth.shape

#### obs_Xgrid and obs_Ygrid

In [None]:
# np.where returns one index array per axis, in axis order.  The mask is taken
# from a single time step of da_sst, so its first axis is the row (eta, ROMS Y)
# direction and its second axis is the column (xi, ROMS X) direction.
# NOTE(review): assumes da_sst is laid out as (time, eta_rho, xi_rho) — confirm.
eta_idx, xi_idx = np.where(mask)
# Expose them under the X/Y names used for obs_Xgrid and obs_Ygrid.
x_idx, y_idx = xi_idx, eta_idx

In [None]:
# Convert the array indices to float64 and shift them by one, then repeat the
# single-time-step pattern once per time step.
# NOTE(review): confirm that a +1 offset matches the fractional grid
# coordinates ROMS expects for rho points.
first_step_x = x_idx.astype(np.float64) + 1
first_step_y = y_idx.astype(np.float64) + 1
np_xgrid = np.tile(first_step_x, time_points)
np_ygrid = np.tile(first_step_y, time_points)

In [None]:
# Visual check of the tiled X-index sequence.
fig, ax = plt.subplots(figsize=(20, 7))
ax.plot(np_xgrid)

In [None]:
np_xgrid.shape

In [None]:
np_ygrid.shape

#### obs_Zgrid

In [None]:
# obs_Zgrid is filled with zeros, one entry per observation.
np_zgrid = np.full_like(np_sst, 0)
np_zgrid.shape

#### obs_Error

In [None]:
# Disabled alternative: per-point variance derived from a standard-deviation field.
# np_sst_std = da_sst_std.values.flatten(order='C')
# np_sst_std = np_sst_std[~np.isnan(np_sst_std)]
# np_sst_var = np_sst_std ** 2
# np_error = np.repeat(np_sst_var, time_points)
# Constant error for every observation: 0.4 squared, i.e. the variance that
# corresponds to a standard deviation of 0.4.
# NOTE(review): presumably in degrees Celsius — confirm.
np_error = np.full_like(np_sst, 0.4**2)

In [None]:
np_error.shape

#### survey_time

In [None]:
# One time stamp per survey: the time values of da_sst.
np_survey_time = da_sst.time.values

In [None]:
np_survey_time

#### Nobs

In [None]:
# Every survey holds the same number of observations: one entry per time step,
# each equal to points_per_time.
np_nobs = np.full(time_points, points_per_time)

In [None]:
np_nobs

#### obs_lon and obs_lat

In [None]:
# Longitude of every valid grid point (row-major order), repeated once per
# time step.
valid_lon = ds_grid.lon_rho.values[mask]
np_lon = np.tile(valid_lon, time_points)

In [None]:
np_lon.shape

In [None]:
# Latitude of every valid grid point (row-major order), repeated once per
# time step.
valid_lat = ds_grid.lat_rho.values[mask]
np_lat = np.tile(valid_lat, time_points)

In [None]:
np_lat.shape

#### Make a dataset

In [None]:
# Arrays with one entry per survey.
survey_vars = {
    "Nobs": np_nobs,
    "survey_time": np_survey_time,
}
# Arrays with one entry per observation.
datum_vars = {
    "obs_value": np_sst,
    "obs_type": np_type,
    "obs_provenance": np_provenance,
    "obs_time": np_time,
    "obs_depth": np_depth,
    "obs_Xgrid": np_xgrid,
    "obs_Ygrid": np_ygrid,
    "obs_Zgrid": np_zgrid,
    "obs_error": np_error,
    "obs_lon": np_lon,
    "obs_lat": np_lat,
}

# Assemble the variables in the order in which they should appear in the dataset.
variables = {"spherical": 1}
variables.update({name: ("survey", data) for name, data in survey_vars.items()})
# The per-state-variable variances are copied from the wc13_obs dataset.
variables["obs_variance"] = ("state_variable", wc13_obs.obs_variance.data)
variables.update({name: ("datum", data) for name, data in datum_vars.items()})

# Note: this rebinds ``ds``, which earlier held the renamed source data.
ds = xr.Dataset(variables)
ds

In [None]:
# Write the assembled observation dataset to the configured output path.
ds.to_netcdf(args.output_obs_file)