In [1]:
import numpy as np
import xarray as xr
import sys

sys.path.insert(0, "/home/mtissot/SpinUp/jumper/lib")
import matplotlib.pyplot as plt

# **LOAD RESTART FILES**

In [None]:
# All three files are opened with time decoding disabled.
OPEN_KWARGS = {"decode_times": False}

# dataset1 is the restart labelled "truth" in the figures below.
dataset1 = xr.open_dataset(
    "/data/mtissot/infos4restart/data_restart/OCE_CM65v420-LR-CdL-pi-01_19141231_restart.nc",
    **OPEN_KWARGS,
)
# Mesh-mask file providing the tmask / umask / vmask arrays used for masking.
mask = xr.open_dataset(
    "/data/mtissot/infos4restart/eORCA1.4.2_mesh_mask_modJD.nc",
    **OPEN_KWARGS,
)
# dataset2 is the restart labelled "predictions" in the figures below.
dataset2 = xr.open_dataset(
    "/data/mtissot/infos4restart/data_restart/NEW_OCE_CM65v420-LR-CdL-pi-01_19141231_restart.nc",
    **OPEN_KWARGS,
)

In [None]:
# List the variable names exposed by the predicted restart and by the mesh mask.
for heading, source in (
    ("restart features : \n", dataset2),
    ("\nmask features : \n", mask),
):
    print(heading, list(source.keys()))

# **ANALYSE TRUTH VS PREDICTIONS**

Variables of interest: rhop, u, v, e3t, ssh, T, S.

### IN SITU DENSITY (rhop)

In [None]:
# Mask `rhop` in both restarts: entries where `tmask` is falsy become NaN.
# `new` comes from dataset2 (predictions), `old` from dataset1 (truth).
# NOTE(review): the section heading says "in situ density"; in NEMO output `rhop`
# usually denotes potential density - confirm.
new, old = (ds.rhop.where(mask.tmask.values) for ds in (dataset2, dataset1))

# Difference between consecutive entries along the first axis left after
# selecting time_counter=0 (the axis later plotted against nav_lev).
diff_new, diff_old = (
    np.diff(field.isel(time_counter=0), axis=0) for field in (new, old)
)

# Index 0 = truth, index 1 = predictions; the plotting cells rely on this order.
val = [old[0], new[0]]
diff = [diff_old, diff_new]

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# NaN-aware mean over the last two axes: one value per entry of the first axis,
# which is plotted against nav_lev.
truth_profile = np.nanmean(val[0], axis=(1, 2))
pred_profile = np.nanmean(val[1], axis=(1, 2))

ax.plot(
    truth_profile,
    dataset1.nav_lev,
    color="black",
    linestyle="dashed",
    linewidth=3,
    alpha=0.7,
    label="truth",
)
ax.plot(
    pred_profile,
    dataset2.nav_lev,
    color="purple",
    linewidth=2,
    alpha=0.8,
    label="predictions",
)

# Flip both axes and move the y ticks / label to the right-hand side.
ax.invert_yaxis()
ax.invert_xaxis()
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

ax.legend()
ax.set_title("Average")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Both column panels use the same sampling strides so that "truth" and
# "predictions" display the same set of columns and can be compared directly.
COLUMN_I_STEP = 40
COLUMN_J_STEP = 28


def _plot_sampled_columns(
    ax,
    profiles,
    depth,
    i_step=COLUMN_I_STEP,
    j_step=COLUMN_J_STEP,
    i_stop=360,
    j_stop=331,
    n_valid=30,
):
    """Draw a regular sub-sample of the columns of ``profiles`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes receiving the lines.
    profiles : 3-D numpy array indexed as ``[level, j, i]``.
    depth : 1-D coordinate with ``profiles.shape[0]`` entries, used as y values.
    i_step, j_step : sampling strides along the last and the middle axis.
    i_stop, j_stop : exclusive upper bounds of the sampled indices; they are
        clamped to the array shape so a smaller grid cannot raise IndexError.
    n_valid : a column is drawn only if its first ``n_valid`` entries hold no NaN.
    """
    for i in range(0, min(i_stop, profiles.shape[2]), i_step):
        for j in range(0, min(j_stop, profiles.shape[1]), j_step):
            column = profiles[:, j, i]
            if not np.isnan(column[:n_valid]).any():
                ax.plot(column, depth)


fig, axes = plt.subplots(1, 3, figsize=(9, 4))

# Panel 1: NaN-aware horizontal mean of the vertical differences.
# np.diff removed one entry along the first axis, hence nav_lev[:-1].
ax = axes[0]
ax.plot(
    np.nanmean(diff[0], axis=(1, 2)),
    dataset1.nav_lev[:-1],
    linestyle="dashed",
    color="black",
    alpha=0.7,
    linewidth=3,
    label="truth",
)
ax.plot(
    np.nanmean(diff[1], axis=(1, 2)),
    dataset2.nav_lev[:-1],
    color="purple",
    alpha=0.8,
    linewidth=2,
    label="predictions",
)
ax.legend()
ax.set_title("Average")

# Panel 2: individual truth columns, against the truth dataset's own levels.
ax = axes[1]
_plot_sampled_columns(ax, diff[0], dataset1.nav_lev[: diff[0].shape[0]])
ax.set_title("truth")

# Panel 3: individual predicted columns, against the predicted dataset's own levels.
ax = axes[2]
_plot_sampled_columns(ax, diff[1], dataset2.nav_lev[: diff[1].shape[0]])
ax.set_title("predictions")

# Same axis orientation for the three panels: both axes flipped, y ticks on the right.
for ax in axes:
    ax.invert_yaxis()
    ax.invert_xaxis()
    ax.yaxis.tick_right()

plt.tight_layout()
plt.show()

### U VELOCITIES

In [None]:
# Mask `un` in both restarts: entries where `umask` is falsy become NaN.
# `new` comes from dataset2 (predictions), `old` from dataset1 (truth).
new, old = (ds.un.where(mask.umask.values) for ds in (dataset2, dataset1))

# Difference between consecutive entries along the first axis left after
# selecting time_counter=0 (the axis later plotted against nav_lev).
diff_new, diff_old = (
    np.diff(field.isel(time_counter=0), axis=0) for field in (new, old)
)

# Index 0 = truth, index 1 = predictions; the plotting cells rely on this order.
val = [old[0], new[0]]
diff = [diff_old, diff_new]

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# Line styles shared by the comparison figures of this notebook.
truth_style = dict(linestyle="dashed", color="black", alpha=0.7, linewidth=3, label="truth")
pred_style = dict(color="purple", alpha=0.8, linewidth=2, label="predictions")

# NaN-aware mean over the last two axes, plotted against nav_lev.
ax.plot(np.nanmean(val[0], axis=(1, 2)), dataset1.nav_lev, **truth_style)
ax.plot(np.nanmean(val[1], axis=(1, 2)), dataset2.nav_lev, **pred_style)

# Flip both axes and move the y ticks / label to the right-hand side.
ax.invert_yaxis()
ax.invert_xaxis()
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

ax.legend()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# NaN-aware horizontal mean of the vertical differences. np.diff removed one
# entry along the first axis, so the last nav_lev entry is dropped to match.
truth_diff_profile = np.nanmean(diff[0], axis=(1, 2))
pred_diff_profile = np.nanmean(diff[1], axis=(1, 2))

ax.plot(
    truth_diff_profile,
    dataset1.nav_lev[:-1],
    color="black",
    linestyle="dashed",
    linewidth=3,
    alpha=0.7,
    label="truth",
)
ax.plot(
    pred_diff_profile,
    dataset2.nav_lev[:-1],
    color="purple",
    linewidth=2,
    alpha=0.8,
    label="predictions",
)

# Flip both axes and move the y ticks / label to the right-hand side.
ax.invert_yaxis()
ax.invert_xaxis()
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

ax.legend()
plt.show()

### V VELOCITIES

In [None]:
# Mask `vn` in both restarts: entries where `vmask` is falsy become NaN.
# `new` comes from dataset2 (predictions), `old` from dataset1 (truth).
new, old = (ds.vn.where(mask.vmask.values) for ds in (dataset2, dataset1))

# Difference between consecutive entries along the first axis left after
# selecting time_counter=0 (the axis later plotted against nav_lev).
diff_new, diff_old = (
    np.diff(field.isel(time_counter=0), axis=0) for field in (new, old)
)

# Index 0 = truth, index 1 = predictions; the plotting cells rely on this order.
val = [old[0], new[0]]
diff = [diff_old, diff_new]

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# NaN-aware mean over the last two axes: one value per entry of the first axis,
# which is plotted against nav_lev.
truth_profile = np.nanmean(val[0], axis=(1, 2))
pred_profile = np.nanmean(val[1], axis=(1, 2))

ax.plot(
    truth_profile,
    dataset1.nav_lev,
    color="black",
    linestyle="dashed",
    linewidth=3,
    alpha=0.7,
    label="truth",
)
ax.plot(
    pred_profile,
    dataset2.nav_lev,
    color="purple",
    linewidth=2,
    alpha=0.8,
    label="predictions",
)

# Flip both axes and move the y ticks / label to the right-hand side.
ax.invert_yaxis()
ax.invert_xaxis()
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

ax.legend()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(4, 5))

# Line styles shared by the comparison figures of this notebook.
truth_style = dict(linestyle="dashed", color="black", alpha=0.7, linewidth=3, label="truth")
pred_style = dict(color="purple", alpha=0.8, linewidth=2, label="predictions")

# NaN-aware horizontal mean of the vertical differences. np.diff removed one
# entry along the first axis, so the last nav_lev entry is dropped to match.
ax.plot(np.nanmean(diff[0], axis=(1, 2)), dataset1.nav_lev[:-1], **truth_style)
ax.plot(np.nanmean(diff[1], axis=(1, 2)), dataset2.nav_lev[:-1], **pred_style)

# Flip both axes and move the y ticks / label to the right-hand side.
ax.invert_yaxis()
ax.invert_xaxis()
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

ax.legend()
plt.show()

----