In [None]:
import glob
from datetime import datetime, timedelta

import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt
import seaborn as sns

# Dark grid background for every figure drawn in this notebook (seaborn sets matplotlib rcParams globally).
sns.set_style(style="darkgrid")

In [None]:
# --- Input locations (absolute cluster paths) ---
# RoHo800 model grid: the target grid of the regridding step further down.
input_grid_file = "/cluster/projects/nn9490k/ROHO800/Grid/ROHO800_grid_fix5.nc"
# RoHo160 history file used as the "nature" (reference) run.
input_data_file = "/cluster/projects/nn9297k/shmiak/roho160_data/1_2017-01-15_to_2019-07-16/roho160_his_0045.nc"
# Glob patterns for the RoHo800 4D-Var output: fwd_outer0 ("before da") and fwd_outer1 ("after da").
roho800_forward = "/cluster/projects/nn9297k/shmiak/roho800_data/output_data/4dvar/4st_spring2018/*fwd_outer0*.nc"
roho800_da = "/cluster/projects/nn9297k/shmiak/roho800_data/output_data/4dvar/4st_spring2018/*fwd_outer1*.nc"

In [None]:
# Expand the glob patterns into sorted file lists.
# NOTE: this rebinds the pattern strings from the config cell to lists, so re-running this
# cell on its own (without the config cell) fails.
roho800_forward = sorted(glob.glob(roho800_forward))
roho800_da = sorted(glob.glob(roho800_da))
# Fail here with a clear message instead of an IndexError at `[0]` when the files are opened.
assert roho800_forward, "no RoHo800 forward (fwd_outer0) files matched the pattern"
assert roho800_da, "no RoHo800 DA (fwd_outer1) files matched the pattern"

In [None]:
# Analysis window; label-based selection on `ocean_time`, so the end date is included.
time_slice = slice("2018-04-01", "2018-04-07")

In [None]:
# Open the grid and the three model outputs; the outputs are cut to the analysis window.
# NOTE(review): only the first file matching each RoHo800 pattern is used — confirm the
# whole window lies inside it.
ds_grid = xr.open_dataset(input_grid_file)
ds_data = xr.open_dataset(input_data_file).sel({"ocean_time": time_slice})
ds_roho800_before = xr.open_dataset(roho800_forward[0]).sel({"ocean_time": time_slice})
ds_roho800_after = xr.open_dataset(roho800_da[0]).sel({"ocean_time": time_slice})

In [None]:
# xESMF looks for horizontal coordinates named `lon`/`lat`, so expose the rho-point
# coordinates under those names on both the target grid and the source data.
ds_grid = ds_grid.rename(lon_rho="lon", lat_rho="lat")
ds_data = ds_data.rename(lon_rho="lon", lat_rho="lat")

In [None]:
# Bilinear regridding of the RoHo160 temperature onto the RoHo800 grid; target cells that
# receive no source data are set to NaN (unmapped_to_nan=True).
regridder = xe.Regridder(ds_data, ds_grid, method="bilinear", unmapped_to_nan=True)
da_temp_nature = regridder(ds_data["temp"])

In [None]:
# Blank out RoHo800 land cells (mask_rho == 0). `.where` puts NaN there and keeps the
# array's name and attributes. Dividing by the 0/1 mask instead produced +/-inf (or NaN
# for 0/0) with divide-by-zero RuntimeWarnings, and returned an unnamed, attribute-less
# array because the operands' names differ. The set of finite values is the same.
da_temp_nature = da_temp_nature.where(ds_grid.mask_rho == 1)

In [None]:
# Number of finite values left in the masked nature run.
int(np.isfinite(da_temp_nature).sum())

In [None]:
# 1 where the nature run has a finite value, 0 elsewhere (same integer dtype as `int`).
mask_nature = np.isfinite(da_temp_nature).astype(np.int_)

In [None]:
# Quick look at the finite-value mask at the last time step and last s-level.
mask_nature.isel({"ocean_time": -1, "s_rho": -1}).plot()

In [None]:
# Temperature from the RoHo800 run before (fwd_outer0) and after (fwd_outer1) assimilation.
da_temp_before = ds_roho800_before.temp
da_temp_after = ds_roho800_after.temp

In [None]:
# Keep RoHo800 values only where the regridded nature run is finite, so all three fields
# cover the same points. `.where` leaves NaN elsewhere and keeps the `temp` name and
# attributes that xarray uses for plot labels. Dividing by the 0/1 mask instead produced
# +/-inf with divide-by-zero RuntimeWarnings and an unnamed, attribute-less result.
# The set of finite values is unchanged.
da_temp_before = da_temp_before.where(mask_nature == 1)
da_temp_after = da_temp_after.where(mask_nature == 1)

In [None]:
# Finite values in the masked pre-DA field; compare with the nature-run count.
int(np.isfinite(da_temp_before).sum())

In [None]:
# Finite values in the masked post-DA field; compare with the nature-run count.
int(np.isfinite(da_temp_after).sum())

In [None]:
# Display settings shared by the three temperature maps below.
ocean_time = -1  # time index to show (last step in the window)
vmin = 0  # fixed colour-scale limits so the maps are comparable
vmax = 10

In [None]:
# Nature run (RoHo160 regridded to the RoHo800 grid) at the chosen time index and the last
# s-level. Fixed vmin/vmax keep it comparable with the RoHo800 maps.
ax_nature = da_temp_nature.isel(ocean_time=ocean_time, s_rho=-1).plot(vmin=vmin, vmax=vmax, figsize=(14, 7)).axes
# xarray's automatic title only names the time/level slice; state which run is shown.
ax_nature.set_title("nature run\n" + ax_nature.get_title());

In [None]:
# RoHo800 run before assimilation (fwd_outer0), same slice and colour limits as the nature map.
ax_before = da_temp_before.isel(ocean_time=ocean_time, s_rho=-1).plot(vmin=vmin, vmax=vmax, figsize=(14, 7)).axes
# xarray's automatic title only names the time/level slice; state which run is shown.
ax_before.set_title("before da\n" + ax_before.get_title());

In [None]:
# RoHo800 run after assimilation (fwd_outer1), same slice and colour limits as the nature map.
ax_after = da_temp_after.isel(ocean_time=ocean_time, s_rho=-1).plot(vmin=vmin, vmax=vmax, figsize=(14, 7)).axes
# xarray's automatic title only names the time/level slice; state which run is shown.
ax_after.set_title("after da\n" + ax_after.get_title());

In [None]:
# Shape of the nature-run field (reference for the shape checks below).
da_temp_nature.shape

In [None]:
# Shape of the masked pre-DA field. It must equal the nature run's shape, because the RMSD
# loop indexes both arrays with the same positional indices.
assert da_temp_before.shape == da_temp_nature.shape, (da_temp_before.shape, da_temp_nature.shape)
da_temp_before.shape

In [None]:
# Shape of the masked post-DA field. It must equal the nature run's shape, because the RMSD
# loop indexes both arrays with the same positional indices.
assert da_temp_after.shape == da_temp_nature.shape, (da_temp_after.shape, da_temp_nature.shape)
da_temp_after.shape

In [None]:
def rmsd(first, second):
    """Root-mean-square difference of two arrays over the points where both are finite.

    Parameters
    ----------
    first, second : array_like
        Arrays of identical shape. A NaN or +/-inf in either array excludes that
        position from the comparison.

    Returns
    -------
    float
        sqrt(mean((first - second) ** 2)) over the jointly finite positions, or NaN
        if there is no such position.

    Raises
    ------
    ValueError
        If the two arrays do not have the same shape.
    """
    first = np.asarray(first)
    second = np.asarray(second)
    if first.shape != second.shape:
        raise ValueError(f"rmsd: shape mismatch {first.shape} vs {second.shape}")
    # Use one joint mask so differences are taken between the SAME positions. Filtering
    # each array separately pairs unrelated points whenever the NaN/inf patterns differ.
    valid = np.isfinite(first) & np.isfinite(second)
    if not valid.any():
        return np.nan  # nothing to compare; avoids a 0/0 division warning
    diff = first[valid] - second[valid]
    return np.sqrt(np.mean(np.square(diff)))

In [None]:
# Index ranges (xi, eta) of the sub-region over which the RMSD is computed.
xi_slice, eta_slice = slice(150, 300), slice(50, 150)

In [None]:
# One RMSD value per time step of the nature run, filled in by the loop below.
rmsd_spatial_before = np.zeros(len(da_temp_nature))
rmsd_spatial_after = np.zeros(len(da_temp_nature))

In [None]:
# For each time step: RMSD against the nature run over all s-levels of the
# (eta_slice, xi_slice) sub-region, for the runs before and after DA.
# NOTE(review): positional indexing assumes every array is ordered
# (ocean_time, s_rho, eta_rho, xi_rho) with the same time steps — confirm.
for step in range(da_temp_nature.shape[0]):
    nature_region = da_temp_nature[step, :, eta_slice, xi_slice].values
    rmsd_spatial_before[step] = rmsd(nature_region, da_temp_before[step, :, eta_slice, xi_slice].values)
    rmsd_spatial_after[step] = rmsd(nature_region, da_temp_after[step, :, eta_slice, xi_slice].values)

In [None]:
# Plot times for the RMSD series, taken from the data rather than rebuilt by hand. The old
# hard-coded 12-hourly 2018-04-01..04-08 range had to be kept in sync with `time_slice`
# and the model output interval; any mismatch made `ax.plot` fail with a length error.
# This gives one entry per ocean_time step, matching the RMSD arrays.
# NOTE(review): assumes ocean_time decodes to datetime64 (standard calendar) — confirm.
na_time = da_temp_nature["ocean_time"].values

In [None]:
# RMSD against the nature run (all s-levels of the sub-region) per time step, before vs after DA.
fig, ax = plt.subplots(figsize=(10, 2))
ax.plot(na_time, rmsd_spatial_before, label="before da")
ax.plot(na_time, rmsd_spatial_after, label="after da")
ax.set_title('Spatially Integrated RMSD')
# Axis labels so the figure stands on its own when the notebook is skimmed.
ax.set_xlabel("time")
ax.set_ylabel("temp RMSD")
ax.legend()
fig.tight_layout()