# Optimize the parameters of the VARIATIONAL method for the reference case
- VARIATIONAL: variational smoothing (`pyn.drifters.smooth_all(..., 'variational', ...)`)

In [None]:
import os
from glob import glob

import xarray as xr
import pandas as pd
import numpy as np

import hvplot.xarray
import hvplot.pandas

import matplotlib.pyplot as plt
from cycler import cycler

from sstats import signals as sg
from sstats import sigp as sigp
from sstats import tseries as ts
from sstats import get_cmap_colors

import hvplot.xarray
import hvplot.pandas
import holoviews as hv

import pynsitu as pyn
import lib as lib
from lib import KEYS, raw_dir, images_dir
import os
from glob import glob

import synthetic_traj as st

from synthetic_traj import synthetic_traj, noise_irregular_sampling, ref_case, typical_case

___________
# Reference case

In [None]:
# Reference-case configuration shared by the cells below.
N = 20  # passed to synthetic_traj; presumably the number of draws — confirm
acc_cut = 1
position_noise = 20  # position noise given to noise_irregular_sampling
ntype = 'white_noise'
offset_type = 'svp_scripps_10'
dt_smooth = '30min'  # time step of the smoothed / subsampled true series
true_key = 'True_' + dt_smooth
spectral_diff = False
# Work on a copy: assigning into the imported dict would silently modify
# synthetic_traj.ref_case for every other user of that module.
ref_case = dict(ref_case, spectral_diff=spectral_diff)
print(ref_case)

In [None]:
# Build the trajectory sets used throughout the notebook. DF maps a label
# to a DataFrame of trajectories (the 'draw' dimension is renamed to 'id').
DF = dict()

# TRUE
# Use '1min' rather than the float step 1/24/60: the float step does not give a
# regular sampling. The first element (50) is presumably a duration — confirm.
t = (50, '1min')
dst = synthetic_traj(t, N , **ref_case)  # u, v, ax, ay computed
dft = st.dataset2dataframe(dst).rename(columns={'draw':'id'})
DF['True_1min'] = dft

# OBSERVED: noisy, irregularly sampled version of the true trajectories
# NOTE(review): istart=6097 is a fixed start index; its meaning is not visible here.
dso = noise_irregular_sampling(dst, t, position_noise, ntype=ntype, offset_type=offset_type, istart=6097)
dfo = st.dataset2dataframe(dso).rename(columns={'draw':'id'})
dfo = dfo.groupby('id').apply(pyn.geo.compute_dt, time='index')

DF['Observed'] = dfo

# True trajectories taken every dt_smooth over the observed time span
# (exact-label .sel, not an interpolation).
# NOTE(review): .sel raises KeyError if any date of this range is missing from
# dst's 1-min time axis — confirm dfo's start time falls on that grid.
dsti = dst.sel(time = pd.date_range(dfo.index.min(), dfo.index.max(), freq=dt_smooth))
dsti['dt']=dsti.time.diff('time')/pd.Timedelta('1s')  # time step in seconds
dfti = st.dataset2dataframe(dsti).rename(columns={'draw':'id'})
# Recompute velocities and accelerations from the subsampled positions ('xy').
dfti = dfti.groupby('id').apply(pyn.geo.compute_velocities,time='index', distance='xy', names=('u', 'v', 'U'), fill_startend=True, centered=True, keep_dt=True)
dfti = dfti.groupby('id').apply(pyn.geo.compute_accelerations,from_ =('xy', 'x', 'y'), names=('ax', 'ay', 'Axy'), keep_dt=True)
# Back to an xarray Dataset indexed by (time, id), used as reference below.
dsti = dfti.reset_index().set_index(['time', 'id']).to_xarray()


DF['True_'+dt_smooth] = dfti  # same key as true_key

___________
# VARIATIONAL
Parameters:
- acc_cut
- position_error = 20 m
- acceleration_amplitude: 1e-5 → 5e-5 m s⁻²
- acceleration_T = 5.5 days (passed in seconds); close to the time scale T of the synthetic trajectory?
- time_chunk



In [None]:
# RMS of the true x-acceleration: mean of ax**2 over draws, then over all
# remaining dims. Presumably a scale for acceleration_amplitude — confirm.
np.sqrt((dst.ax**2).mean('draw').mean())

In [None]:
# RMS of the true y-acceleration: mean of ay**2 over draws, then over all
# remaining dims.
np.sqrt((dst.ay**2).mean('draw').mean())

In [None]:
# Default variational-smoothing parameters for the reference case.
acc_cut= 1  # same value as set in the reference-case cell
position_error=position_noise  # matches the synthetic position noise
acceleration_amplitude = 1e-5
acceleration_T = 5.5*86400  # 5.5 days expressed in seconds
time_chunk=2  # NOTE(review): not used by the cells shown here — confirm purpose

In [None]:
# Labels of the trajectory sets built so far ('True_1min', 'Observed', 'True_' + dt_smooth).
DF.keys()

_________
# 3D
## Compute $ \frac{\langle(\alpha-\alpha_t)^2\rangle}{\langle \alpha_t^2 \rangle} $ 


In [None]:
# Parameter grids scanned by f_3D.
acc_A = [0.1e-5, 0.2e-5, 0.3e-5, 0.4e-5] + list(np.arange(0.5, 10) * 1e-5)  # [m s^-2]
pos_e = [5, 10, 20, 40, 60, 80, 100]  # [m]
acc_T = [0.05, 0.1, 0.25] + list(np.arange(0.5, 10.5))  # [days]

# NOTE(review): the concat cell reads results from
# /Users/mdemol/DATA_DRIFTERS/3D_variational/, not from this directory.
F3D_OUT_DIR = '/Users/mdemol/code/PhD/insitu_drifters_trajectories/diagnostic_synth_traj/3D_variational'


def f_3D(amplitudes=None, position_errors=None, timescales=None, out_dir=None):
    """Smooth the observed trajectories with the variational method over a 3D grid.

    For every (acceleration amplitude, position error) pair, ``dfo`` is
    smoothed with ``pyn.drifters.smooth_all(..., 'variational', ...)`` for each
    acceleration time scale. The runs are concatenated along
    ``acceleration_T`` and written to one NetCDF file per pair.

    Reads the notebook globals ``dfo``, ``dt_smooth``, ``acc_cut`` and
    ``spectral_diff``.

    Parameters
    ----------
    amplitudes : sequence of float, optional
        Acceleration amplitudes [m s^-2]. Defaults to ``acc_A``.
    position_errors : sequence of float, optional
        Position errors [m]. Defaults to ``pos_e``.
    timescales : sequence of float, optional
        Acceleration time scales [days]. Defaults to ``acc_T``.
    out_dir : str, optional
        Output directory, created if missing. Defaults to ``F3D_OUT_DIR``.

    Returns
    -------
    list of str
        Paths of the NetCDF files written.
    """
    amplitudes = acc_A if amplitudes is None else amplitudes
    position_errors = pos_e if position_errors is None else position_errors
    timescales = acc_T if timescales is None else timescales
    out_dir = F3D_OUT_DIR if out_dir is None else out_dir

    written = []
    for acca in amplitudes:
        for p in position_errors:
            per_timescale = []
            for acct in timescales:
                dfv = pyn.drifters.smooth_all(
                    dfo, 'variational',
                    dt_smooth,
                    parameters=dict(acc_cut=acc_cut, position_error=p,
                                    acceleration_amplitude=acca,
                                    acceleration_T=acct * 86400,  # days -> seconds
                                    acc_cut_key=('ax', 'ay', 'Axy')),
                    spectral_diff=spectral_diff,
                    geo=False)
                dfv = dfv.reset_index().set_index(['time', 'id'])
                dsv = (dfv.to_xarray()
                       .assign_coords(dict(acceleration_T=acct, position_error=p, acceleration_A=acca))
                       .expand_dims(['acceleration_T', 'position_error', 'acceleration_A']))
                per_timescale.append(dsv)
            ds = xr.concat(per_timescale, dim='acceleration_T')
            os.makedirs(out_dir, exist_ok=True)
            # The file holds every acceleration_T value. The trailing suffix is
            # the last time scale, kept so names match previously written files.
            path = os.path.join(out_dir, f'3D_variational_{acca}_{p}_{timescales[-1]}.nc')
            ds.to_netcdf(path)
            print('ok', path)
            written.append(path)
    return written

In [None]:
# Run the full parameter sweep: one smooth_all call per grid point; writes NetCDF files.
f_3D()

_________
# 3D
## Concatenate the results and plot them


In [None]:
# Result files of the 3D parameter sweep (glob pattern, not a directory).
# NOTE(review): f_3D writes to .../diagnostic_synth_traj/3D_variational/;
# make sure the files were copied to this directory.
dir_ = '/Users/mdemol/DATA_DRIFTERS/3D_variational/3D_variational_*.nc'
# glob's order depends on the filesystem; sort so D and Dms are reproducible.
files = sorted(glob(dir_))
files

In [None]:
# One xarray Dataset per result file, in the order of `files`.
D = [xr.open_dataset(path) for path in files]

In [None]:
# Number of result files found by the glob above.
len(files)

In [None]:
# Normalized mean-square error <(alpha - alpha_t)^2> / <alpha_t^2> of each run
# against the subsampled truth: averaged over time per trajectory, then over ids.
dst_ = (dsti ** 2).mean('time')
Dms = [
    (((run - dsti) ** 2).mean('time') / dst_).mean('id')
    for run in D
]

In [None]:
# Merge the per-file error datasets into one Dataset along their parameter coordinates.
DMS =xr.combine_by_coords(Dms)

_________
# 3D
## Plotting x, y


In [None]:
# Log of the normalized x error, one panel per position_error.
np.log(DMS.x).plot(col="position_error", col_wrap=4)

In [None]:
# Show the contour plot API. Uses help() rather than IPython's `obj?` syntax,
# which is a SyntaxError outside IPython.
help(DMS.x.plot.contour)

In [None]:
# Contours (1000 levels) of the normalized x error, one panel per position_error.
DMS.x.plot.contour(col="position_error", col_wrap=4, levels=1000, add_colorbar=True)

In [None]:
# Contours of the normalized u error, one panel per acceleration_T; colour range capped at 0.01.
DMS.u.plot.contour(col="acceleration_T", col_wrap=4, levels=500, add_colorbar=True, vmax=0.01)

In [None]:
# Contours of the normalized ax error, one panel per position_error.
DMS.ax.plot.contour(col="position_error", col_wrap=4, levels=500, add_colorbar=True)

In [None]:
# Log of the normalized x error, one panel per position_error.
# NOTE(review): identical to an earlier cell of this notebook.
np.log(DMS.x).plot(col="position_error", col_wrap=4)

In [None]:
# Log of the normalized x error, one panel per acceleration_A.
np.log(DMS.x).plot(col="acceleration_A", col_wrap=4)

In [None]:
# Log of the normalized x error, one panel per acceleration_T.
np.log(DMS.x).plot(col="acceleration_T", col_wrap=4)

## Plotting u,v

In [None]:
# Log of the normalized u error, one panel per position_error.
np.log(DMS.u).plot(col="position_error", col_wrap=4)

In [None]:
# Log of the normalized u error, one panel per acceleration_A.
np.log(DMS.u).plot(col="acceleration_A", col_wrap=4)

In [None]:
# Log of the normalized u error, one panel per acceleration_T.
np.log(DMS.u).plot(col="acceleration_T", col_wrap=4)

## Plotting ax,ay

In [None]:
# Log of the normalized ax error, one panel per position_error.
np.log(DMS.ax).plot(col="position_error", col_wrap=4)

In [None]:
# Log of the normalized ax error, one panel per acceleration_A.
np.log(DMS.ax).plot(col="acceleration_A", col_wrap=4)

In [None]:
# Log of the normalized ax error, one panel per acceleration_T.
# vmax=1 caps the colour scale at log-ratio 1 (error ratio ≈ e).
np.log(DMS.ax).plot(col="acceleration_T", col_wrap=4,vmax=1)

# 3D plot

In [None]:
# For each variable: a 3D scatter of its normalized error over the parameter
# grid, with colours clipped to the 10-90 % quantiles and the grid minimum
# overlaid in red. The minimum point(s) are also printed.
for v in ['x', 'y', 'u', 'v', 'ax', 'ay']:
    err = DMS[v]
    # Computed once and reused for the red overlay and the printout.
    err_min = DMS.where(err == err.min(), drop=True)[v]
    DMS.plot.scatter(z="acceleration_T", y="position_error", x="acceleration_A", hue=v,
                     vmax=float(err.quantile(0.9)), vmin=float(err.quantile(0.1)))
    err_min.plot.scatter(z="acceleration_T", y="position_error", x="acceleration_A", color='r', s=50)
    plt.show()
    print(err_min)

In [None]:
# `mina` is otherwise first defined in the acceleration-minimum section further
# down, so running the notebook top to bottom raised NameError here. Compute it
# for 'ax' (the first variable analysed there) so this cell stands on its own.
mina = DMS.where(DMS['ax'] == DMS['ax'].min(), drop=True)['ax']
mina

In [None]:
# Attach units and LaTeX labels used by the plot titles and axes below.
for _name, _units, _label in [
    ('acceleration_T', 'days', r'$\tau_A$'),
    ('acceleration_A', r'$m.s^{-2}$', r'$\sigma_A$'),
    ('position_error', 'm', r'$\epsilon_X$'),
    ('ax', '', r'$\langle (ax-ax_t)^2 \rangle / \langle ax_t^2 \rangle$'),
    ('ay', '', r'$\langle (ay-ay_t)^2 \rangle / \langle ay_t^2 \rangle$'),
]:
    DMS[_name] = DMS[_name].assign_attrs(units=_units, long_name=_label)
del _name, _units, _label

# Sections through the acceleration-error minimum

In [None]:
# Three sections of the normalized `ax` error through its minimum over the
# 3D parameter grid. Each panel holds one parameter fixed at its optimal value
# and marks the optimum with a red cross.
v = 'ax'
mina = DMS.where(DMS[v] == DMS[v].min(), drop=True)[v]
# Optimal parameter values as scalars (fails if the minimum is not unique,
# as the float(...) in the titles always did).
opt = {dim: mina[dim].item() for dim in ('acceleration_A', 'acceleration_T', 'position_error')}
# (parameter held fixed, x-axis parameter, y-axis parameter) for each panel.
panels = [
    ('acceleration_A', 'position_error', 'acceleration_T'),
    ('acceleration_T', 'acceleration_A', 'position_error'),
    ('position_error', 'acceleration_A', 'acceleration_T'),
]
fig, axs = plt.subplots(1, 3, figsize=(14, 3.5))
for ax, (fixed, xdim, ydim) in zip(axs.flatten(), panels):
    # Explicit x/y so the section and the red marker share the same axes,
    # whatever the dimension order of DMS.
    DMS[v].sel({fixed: opt[fixed]}).plot(ax=ax, x=xdim, y=ydim)
    ax.plot(opt[xdim], opt[ydim], marker='x', color='r', markersize=10)
    attrs = DMS[fixed].attrs
    ax.set_title(attrs['long_name'] + ' = ' + str(float(opt[fixed])) + ' [' + attrs['units'] + ']')
fig.tight_layout(w_pad=0)
fig.savefig(os.path.join(images_dir, 'VARIATIONAL_opt_ax.png'), dpi=200, bbox_inches="tight")

In [None]:
# Three sections of the normalized `ay` error through its minimum over the
# 3D parameter grid. Each panel holds one parameter fixed at its optimal value
# and marks the optimum with a red cross.
v = 'ay'
mina = DMS.where(DMS[v] == DMS[v].min(), drop=True)[v]
# Optimal parameter values as scalars (fails if the minimum is not unique,
# as the float(...) in the titles always did).
opt = {dim: mina[dim].item() for dim in ('acceleration_A', 'acceleration_T', 'position_error')}
# (parameter held fixed, x-axis parameter, y-axis parameter) for each panel.
panels = [
    ('acceleration_A', 'position_error', 'acceleration_T'),
    ('acceleration_T', 'acceleration_A', 'position_error'),
    ('position_error', 'acceleration_A', 'acceleration_T'),
]
fig, axs = plt.subplots(1, 3, figsize=(14, 3.5))
for ax, (fixed, xdim, ydim) in zip(axs.flatten(), panels):
    # Explicit x/y so the section and the red marker share the same axes,
    # whatever the dimension order of DMS.
    DMS[v].sel({fixed: opt[fixed]}).plot(ax=ax, x=xdim, y=ydim)
    ax.plot(opt[xdim], opt[ydim], marker='x', color='r', markersize=10)
    attrs = DMS[fixed].attrs
    ax.set_title(attrs['long_name'] + ' = ' + str(float(opt[fixed])) + ' [' + attrs['units'] + ']')
fig.tight_layout(w_pad=0)
fig.savefig(os.path.join(images_dir, 'VARIATIONAL_opt_ay.png'), dpi=200, bbox_inches="tight")