In [1]:
import xarray as xr

import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from dask.diagnostics import ProgressBar

import speccy
import nonstat_itides_jax as nsjax
import nonstat_itides as nsit
import optax

In [2]:
def load_scenario(scenario):
    """Open every run segment of a ROMS scenario and join them into one time series.

    Each scenario is stored as consecutive run segments (``t1``, ``t2``, optionally
    ``t3``), each split over two NetCDF files. The file names show that every segment
    covers the same model dates, so segments after the first are shifted in time to
    follow on from the previous one.

    Parameters
    ----------
    scenario : str
        One of ``'wp5'``, ``'wp6'``, ``'wp75'``, ``'wp8'``, ``'wp9'``.

    Returns
    -------
    xarray.Dataset
        All segments concatenated along ``time_counter``.

    Raises
    ------
    KeyError
        If ``scenario`` is not one of the known scenarios.
    """
    ncfiles = {
        'wp5': [
            '../DATA/ROMS/wp5/t1/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp5/t1/file_inst_00010220-00010410.nc',
            '../DATA/ROMS/wp5/t2/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp5/t2/file_inst_00010220-00010410.nc',
            #'../DATA/ROMS/wp5/t3/file_inst_00010101-00010219.nc',
            #'../DATA/ROMS/wp5/t3/file_inst_00010220-00010410.nc'
        ],
        'wp6': [
            '../DATA/ROMS/wp6/t1/inst_00010101-00010219.nc',
            '../DATA/ROMS/wp6/t1/inst_00010220-00010410.nc',
            '../DATA/ROMS/wp6/t2/inst_00010101-00010219.nc',
            '../DATA/ROMS/wp6/t2/inst_00010220-00010410.nc',
        ],
        'wp75': [
            '../DATA/ROMS/wp75/t1/inst_00010101-00010219.nc',
            '../DATA/ROMS/wp75/t1/inst_00010220-00010410.nc',
            '../DATA/ROMS/wp75/t2/inst_00010101-00010219.nc',
            '../DATA/ROMS/wp75/t2/inst_00010220-00010410.nc',
        ],
        'wp8': [
            '../DATA/ROMS/wp8/t1/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp8/t1/file_inst_00010220-00010410.nc',
            '../DATA/ROMS/wp8/t2/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp8/t2/file_inst_00010220-00010410.nc',
            #'../DATA/ROMS/wp8/t3/file_inst_00010101-00010219.nc',
            #'../DATA/ROMS/wp8/t3/file_inst_00010220-00010410.nc'
        ],
        'wp9': [
            '../DATA/ROMS/wp9/t1/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp9/t1/file_inst_00010220-00010410.nc',
            '../DATA/ROMS/wp9/t2/file_inst_00010101-00010219.nc',
            '../DATA/ROMS/wp9/t2/file_inst_00010220-00010410.nc',
            #'../DATA/ROMS/wp9/t3/file_inst_00010101-00010219.nc',
            #'../DATA/ROMS/wp9/t3/file_inst_00010220-00010410.nc'
        ],
    }
    if scenario not in ncfiles:
        raise KeyError(f'unknown scenario {scenario!r}; expected one of {sorted(ncfiles)}')
    files = ncfiles[scenario]

    # Every run segment is stored as two consecutive files.
    segments = [xr.open_mfdataset(files[i:i + 2]) for i in range(0, len(files), 2)]

    for k in range(1, len(segments)):
        prev_t = segments[k - 1]['time_instant']  # already shifted on the previous pass
        seg = segments[k]
        # Start this segment one output interval after the previous one ends.
        # Shifting by the previous segment's duration alone (the old behaviour) put
        # this segment's first record on the previous segment's last timestamp,
        # which duplicated a time_counter value at every seam.
        step = prev_t[1] - prev_t[0]
        offset = (prev_t[-1] - seg['time_instant'][0] + step).values
        seg['time_instant'] = seg['time_instant'] + offset
        seg['time_counter'] = seg['time_counter'] + offset

    return xr.concat(segments, dim='time_counter')
        

In [3]:
def estimate_spectral_params_jax(y,  X=None, covfunc=None, covparams_ic=None, fmin=None, fmax=None,
                                transformer=nsjax.LogTransformer,
                                opt=optax.adabelief(learning_rate=1e-1)):
    """Fit ``covfunc`` to a single time series and return only the fitted parameters.

    Thin wrapper around ``nsjax.estimate_jax`` (5000 iterations, quiet), suitable as
    the per-point kernel for ``xr.apply_ufunc(..., vectorize=True)``. The loss value
    returned by ``estimate_jax`` is discarded.

    Note: the ``opt`` default is built once, when this function is defined, and is
    shared between calls.
    """
    fit_options = dict(maxiter=5000, opt=opt, verbose=False, transformer=transformer)
    fitted, _loss = nsjax.estimate_jax(y, X, covfunc, covparams_ic, fmin, fmax, **fit_options)
    return fitted


def estimate_params(scenario, covfunc, paramnames, fmin, fmax, window=None, 
                    transformer=nsjax.LogTransformer, varname='v_y',
                    covparams_ic=None, ydim='y_vy'):
    """Fit ``covfunc`` to the tidally-incoherent part of ``varname`` at every y-point.

    Loads ``scenario`` with ``load_scenario`` and keeps the y-labels in [250, 500].
    It then removes the tidally-coherent signal returned by ``nsit.calc_coherent``.
    Finally it fits the residual at each y-point with ``estimate_spectral_params_jax``,
    in parallel through dask.

    Parameters
    ----------
    scenario : str
        Key accepted by ``load_scenario`` (e.g. ``'wp5'``).
    covfunc : callable
        Covariance function forwarded to ``nsjax.estimate_jax``; its ``__name__`` is
        stored in the output attributes.
    paramnames : sequence of str
        Names of the fitted parameters, stored as an attribute of ``data``.
    fmin, fmax : float
        Frequency band for the spectral fit (presumably cycles per day, since the
        time axis is built in days).
    window : optional
        Not used; kept for backward compatibility.
    transformer : class, optional
        Parameter transformer forwarded to ``nsjax.estimate_jax``.
    varname : str, optional
        Variable to analyse (default ``'v_y'``).
    covparams_ic : sequence of float, optional
        Initial parameter guess. If omitted, the notebook-level global
        ``covparams_ic`` is used. Earlier versions of this function always did that
        implicitly.
    ydim : str, optional
        Name of the y-dimension of ``varname`` (default ``'y_vy'``).

    Returns
    -------
    xarray.Dataset
        The y-subset of the scenario without the other section variables, plus:
        ``data``, the fitted parameters (dims presumably ``(ydim, 'params')``); and
        ``<varname>_coherent``, the coherent signal.
    """
    if covparams_ic is None:
        # Backward compatibility: this function used to read the global silently.
        if 'covparams_ic' not in globals():
            raise NameError("covparams_ic was not passed and no global 'covparams_ic' is defined")
        covparams_ic = globals()['covparams_ic']

    ds = load_scenario(scenario)

    # Load a subset of the data
    ds_out = ds.sel({ydim: slice(250, 500, 1)})

    # One chunk per y-point holding the whole time series, so apply_ufunc
    # parallelises over y.
    y = ds_out[varname].chunk({ydim: 1, 'time_counter': -1}).squeeze()

    # Time axis in days since the first sample. Dividing by a timedelta avoids
    # assuming nanosecond resolution.
    t_ = y['time_instant'] - y['time_instant'][0]
    X = t_.values / np.timedelta64(1, 'D')

    # Calculate the coherent portion of the signal (only the first output is used)
    y_coherent, *_ = nsit.calc_coherent(y, X)

    inputs = dict(X=X, covfunc=covfunc, covparams_ic=covparams_ic, fmin=fmin, fmax=fmax,
                  transformer=transformer)

    print('\tEstimating parameters...')
    params = xr.apply_ufunc(estimate_spectral_params_jax,
                            y - y_coherent,
                            dask='parallelized',
                            kwargs=inputs,
                            output_dtypes=[y.dtype],
                            input_core_dims=(['time_counter'],),
                            output_core_dims=(['params'],),
                            dask_gufunc_kwargs={'output_sizes': {'params': len(covparams_ic)}},
                            vectorize=True,
                            )

    with ProgressBar():
        params = params.compute()

    params.name = 'data'
    params.attrs = {'parameter names': paramnames,
                    'covariance function': covfunc.__name__}

    ## Export the data
    da_raw = ds_out[varname]
    # copy(data=...) produces an independent array. Building it with
    # xr.DataArray(ds_out[varname]) and assigning with [:] shares the underlying
    # buffer, so the raw field written out below would be overwritten with the
    # coherent signal.
    da_coh = da_raw.copy(data=np.reshape(y_coherent, da_raw.shape))
    if varname == 'v_y':
        da_coh.attrs['long_name'] = 'tidally-coherent v-momentum component'
    else:
        da_coh.attrs['long_name'] = f'tidally-coherent component of {varname}'
    da_coh.name = f'{varname}_coherent'

    # Drop the other section variables but keep the one that was analysed
    to_drop = [v for v in ('ssh_y', 'T_y', 'u_y', 'w_y', 'v_y') if v != varname]
    ds_paramsout = ds_out.drop_vars(to_drop)
    ds_paramsout['data'] = params
    ds_paramsout[da_coh.name] = da_coh

    return ds_paramsout

In [4]:
# Fit with nsjax.CustomTransformer (presumably a logit-type transform for the bounded
# parameters; confirm in nonstat_itides_jax)
covfunc = nsjax.itide_D2_meso_gammaexp
paramnames = ('η_m','τ_m','γ_m', 'η_D2','τ_D2','γ_D2')
covparams_ic = (0.1, 10, 1.5, 0.1, 10, 1.5)
fmin, fmax = 5e-3, 3.

varname = 'u_y'

for scenario in ['wp5','wp6','wp75','wp8','wp9']:
    # Pass varname explicitly. Without it, estimate_params uses its default 'v_y',
    # and the file below would be labelled u_y while holding v_y results.
    # NOTE: files written before this fix contain v_y fits despite the _u_y suffix.
    # NOTE(review): if u_y is not on the y_vy grid, estimate_params will need the
    # matching y-dimension.
    ds_paramsout = estimate_params(scenario, covfunc, paramnames, fmin, fmax,
                                   transformer=nsjax.CustomTransformer, varname=varname)
    output_nc = '../DATA/ROMS/{}_params_{}_{}_v2.nc'.format(scenario, covfunc.__name__, varname)
    ds_paramsout.to_netcdf(output_nc)
    print(output_nc)



	Estimating parameters...
[########################################] | 100% Completed | 19m 54ss
../DATA/ROMS/wp5_params_itide_D2_meso_gammaexp_u_y_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 249.57 s
../DATA/ROMS/wp6_params_itide_D2_meso_gammaexp_u_y_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 466.32 s
../DATA/ROMS/wp75_params_itide_D2_meso_gammaexp_u_y_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 10m 11ss
../DATA/ROMS/wp8_params_itide_D2_meso_gammaexp_u_y_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 10m 47ss
../DATA/ROMS/wp9_params_itide_D2_meso_gammaexp_u_y_v2.nc


In [None]:
# Single D2 oscillation with a gamma-exponential envelope, fitted in the 1.5-2.5 cpd band
covfunc = nsjax.oscillate_1d_gammaexp
paramnames = ('η_D2', 'τ_D2', 'γ_D2', 'T_D2')
covparams_ic = (0.1, 10, 1.5, 0.5)
fmin, fmax = 1.5, 2.5

# estimate_params picks up `covparams_ic` from the notebook globals
for scenario in ('wp5', 'wp6', 'wp75', 'wp8', 'wp9'):
    ds_paramsout = estimate_params(scenario, covfunc, paramnames, fmin, fmax)
    output_nc = f'../DATA/ROMS/{scenario}_params_{covfunc.__name__}.nc'
    ds_paramsout.to_netcdf(output_nc)
    print(output_nc)

In [None]:
# Mesoscale + D2 model with the D2 shape parameter fixed (5 free parameters)
covfunc = nsjax.itide_D2_meso_gammaexp_fixed
paramnames = ('η_m', 'τ_m', 'γ_m', 'η_D2', 'τ_D2')
covparams_ic = (0.1, 10, 1.5, 0.1, 10)
fmin, fmax = 5e-3, 3.

# estimate_params picks up `covparams_ic` from the notebook globals
for scenario in ('wp5', 'wp6', 'wp75', 'wp8', 'wp9'):
    ds_paramsout = estimate_params(scenario, covfunc, paramnames, fmin, fmax)
    output_nc = f'../DATA/ROMS/{scenario}_params_{covfunc.__name__}.nc'
    ds_paramsout.to_netcdf(output_nc)
    print(output_nc)

In [9]:
class CustomTransformer2:
    """Map covariance parameters to and from the space the optimiser works in.

    Every parameter is log-transformed except index 2. That index goes through
    ``nsjax.invlogit(..., scale=2)`` in ``__call__`` and ``nsjax.logit(..., scale=2)``
    in ``out``.

    NOTE(review): ``params`` is assumed to be a JAX array, since ``.at[...].set``
    is used. Also confirm that invlogit (forward) / logit (inverse) are the intended
    directions in nonstat_itides_jax.
    """

    def __init__(self, params):
        self.params = params

    def __call__(self):
        """Return the transformed copy of ``self.params``."""
        bounded = nsjax.invlogit(self.params[2], scale=2)
        return nsjax.np.log(self.params).at[2].set(bounded)

    def out(self, tparams):
        """Map transformed parameters ``tparams`` back to their natural values."""
        bounded = nsjax.logit(tparams[2], scale=2)
        return nsjax.np.exp(tparams).at[2].set(bounded)
        
# Same fixed-shape model, now fitted with CustomTransformer2 (outputs get a _v2 suffix)
covfunc = nsjax.itide_D2_meso_gammaexp_fixed
paramnames = ('η_m', 'τ_m', 'γ_m', 'η_D2', 'τ_D2')
covparams_ic = (0.1, 10, 1.5, 0.1, 10)
fmin, fmax = 5e-3, 3.

# estimate_params picks up `covparams_ic` from the notebook globals
for scenario in ('wp5', 'wp6', 'wp75', 'wp8', 'wp9'):
    ds_paramsout = estimate_params(scenario, covfunc, paramnames, fmin, fmax,
                                   transformer=CustomTransformer2)
    output_nc = f'../DATA/ROMS/{scenario}_params_{covfunc.__name__}_v2.nc'
    ds_paramsout.to_netcdf(output_nc)
    print(output_nc)



	Estimating parameters...
[########################################] | 100% Completed | 77m 41ss
../DATA/ROMS/wp5_params_itide_D2_meso_gammaexp_fixed_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 42m 20ss
../DATA/ROMS/wp6_params_itide_D2_meso_gammaexp_fixed_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 277.41 s
../DATA/ROMS/wp75_params_itide_D2_meso_gammaexp_fixed_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 323.29 s
../DATA/ROMS/wp8_params_itide_D2_meso_gammaexp_fixed_v2.nc




	Estimating parameters...
[########################################] | 100% Completed | 338.96 s
../DATA/ROMS/wp9_params_itide_D2_meso_gammaexp_fixed_v2.nc


# Testing and scratch cells below (these need a loaded `ds`)

In [None]:
# `ds` only exists inside estimate_params, so load a scenario explicitly for the
# checks below ('wp9' is the value the fitting loops above leave in `scenario`).
scenario = 'wp9'
ds = load_scenario(scenario)
ds.time_instant.values

In [None]:
fig, ax = plt.subplots()
ds.rot_y.plot(ax=ax)
# plt.savefig('../FIGURES/PK2015_ROMS_vorticity_{}.png'.format(scenario), dpi=150)

In [None]:
fig, ax = plt.subplots()
ds.v_y.plot(ax=ax)
# plt.savefig('../FIGURES/PK2015_ROMS_v_{}.png'.format(scenario), dpi=150)

In [None]:
# Extract one y-location and build its demeaned series and time axis (days)
ypt = 400
#y = ds['ssh_y'].sel(y_rhoy=ypt)
y = ds['v_y'].sel(y_vy=ypt)

# Sampling interval and time axis in days. Dividing by a timedelta avoids
# assuming nanosecond resolution.
dt = y['time_instant'][1] - y['time_instant'][0]
dtout = dt.values / np.timedelta64(1, 'D')
t_ = y['time_instant'] - y['time_instant'][0]
t = t_.values / np.timedelta64(1, 'D')

# Read the series once: every `.values` access on the lazy array triggers a new read
y_vals = y.values.ravel()
ypr = y_vals - y_vals.mean()

fig, ax = plt.subplots()
y.plot(ax=ax)

# plt.xlim(y['time_instant'][0], y['time_instant'][500])
# plt.savefig('../FIGURES/PK2015_ROMS_v_timeseries_{}.png'.format(scenario), dpi=150)

In [None]:

dt = y['time_instant'][1] - y['time_instant'][0]
dtout = dt.values / np.timedelta64(1, 'D')

# calc_coherent returns a tuple (see estimate_params); its first element is the
# coherent signal. Keep it as an (nt, k) column array so `y_coherent[:, 0]` works
# here and in later cells.
coherent, *_ = nsit.calc_coherent(ypr, t)
y_coherent = np.reshape(coherent, (len(ypr), -1))

# Periodogram of the tidally-incoherent residual
f, I = speccy.periodogram(ypr - y_coherent[:, 0], delta=dtout)

fig, ax = plt.subplots()
ax.loglog(f, I)

In [None]:
import optax

In [None]:
%%time
# covfunc = meso_itide_gamma
# covfunc = meso_itide_matern
covfunc = nsjax.itide_D2_meso_gammaexp
paramnames = ('η_m','τ_m','γ_m', 'η_D2','τ_D2','γ_D2')

covparams_ic = (0.1, 10, 1.5, 0.1, 10, 1.5)
fmin, fmax = 5e-3, 3.
priors=None


acf1 = covfunc(t, t[0], covparams_ic)
f_S1, S1 = speccy.bochner(acf1, delta=dtout)

soln,loss_val = nsjax.estimate_jax(ypr-y_coherent[:,0], t, covfunc, covparams_ic, fmin, fmax,
                window=None,
                verbose=True,
                maxiter=5000,
                ftol=1e-2,
                opt=optax.adabelief(learning_rate=1e-1),
                #opt= optax.sgd(learning_rate=3e-4),
                #transformer=LogTransformer)
            )
# soln = params_loss[:-1]
# loss_val = params_loss[-1]

print(soln)
acf = covfunc(t, t[0], soln)

f_S, S = speccy.bochner(acf, delta=dtout)
plt.figure()
plt.loglog(f, I, lw=0.5)
# plt.loglog(f_S1,S1)
plt.loglog(f_S,S,'--')

# plt.xlim(fmin, fmax)
plt.ylim(1e-9,1e1)
plt.vlines(fmax, 1e-9,1e1,colors='k',ls=':')

plt.xlabel('f [cpd]')

# # plt.savefig('../FIGURES/PK2015_ROMS_psd_maternitgamma_{}.png'.format(scenario), dpi=150)

In [None]:
ds_out = ds.sel({'y_vy': slice(250, 500, 10)})
ds_out


In [None]:
%%time
y = ds_out['v_y'].chunk({'y_vy':1,'time_counter':-1}).squeeze()
y_coherent = nsit.calc_coherent(y, t)


In [None]:
from dask.diagnostics import ProgressBar

In [None]:
window = None
X = t

inputs = dict(X=X, covfunc=covfunc, covparams_ic=covparams_ic, fmin=fmin, fmax=fmax)

# Deliberately NOT named `estimate_spectral_params_jax`: redefining that name here
# would shadow the In[3] version (which accepts `transformer`) that estimate_params
# relies on, breaking every later estimate_params call.
def _fit_point_jax(y, X=X, covfunc=covfunc, covparams_ic=covparams_ic, fmin=fmin, fmax=fmax):
    """Fit covfunc to one residual series (estimate_jax's default transformer); return the parameters."""
    params, _ = nsjax.estimate_jax(y, X, covfunc, covparams_ic, fmin, fmax,
                                   maxiter=5000,
                                   opt=optax.adabelief(learning_rate=1e-1),
                                   verbose=False)
    return params

print('\tBuilding the dask graph...')
params = xr.apply_ufunc(_fit_point_jax,
                        y - y_coherent,
                        dask='parallelized',
                        kwargs=inputs,
                        output_dtypes=[y.dtype],
                        input_core_dims=(['time_counter'],),
                        output_core_dims=(['params'],),
                        dask_gufunc_kwargs={'output_sizes': {'params': len(covparams_ic)}},
                        vectorize=True,
                        )

with ProgressBar():
    params = params.compute()


In [None]:
params.attrs = {
    'parameter names': paramnames,
    'covariance function': covfunc.__name__,
}
params.name = 'data'


In [None]:
## Export the data
# copy(data=...) produces an independent array. Building it with
# xr.DataArray(ds_out['v_y']) and assigning with [:] shares the underlying buffer,
# so the raw v_y kept in ds_paramsout would be overwritten with the coherent signal.
da_vy_coh = ds_out['v_y'].copy(data=np.reshape(y_coherent, ds_out['v_y'].shape))
da_vy_coh.attrs['long_name'] = 'tidally-coherent v-momentum component'
da_vy_coh.name = 'v_y_coherent'

ds_paramsout = ds_out.drop_vars(['ssh_y','T_y','u_y','w_y'])
ds_paramsout['data'] = params
ds_paramsout['v_y_coherent'] = da_vy_coh

ds_paramsout

### Legacy code below (older non-JAX estimator; may not run as-is)

In [None]:
from tqdm import tqdm


In [None]:
ds_out = ds.sel({'y_vy': slice(250, 500, 2)})

In [None]:
# Legacy per-point fit with the non-JAX Whittle estimator.
nparams = len(covparams_ic)  # was `covparams`, which is never defined in this notebook
n_y = ds_out.sizes['y_vy']  # Dataset.dims[...] lookups are deprecated in recent xarray
paramsout = np.zeros((n_y, nparams))
for jj in tqdm(range(n_y)):
    y = ds_out['v_y'].isel(y_vy=jj)
    y_vals = y.values.ravel()
    ypr = y_vals - y_vals.mean()
    # calc_coherent returns a tuple (see estimate_params). Flatten its first element
    # so an (nt, 1) result cannot broadcast against the (nt,) series into (nt, nt).
    coherent, *_ = nsit.calc_coherent(ypr, t)
    y_coherent = np.ravel(coherent)
    soln = nsit.estimate_spectral_params_whittle_ufunc(ypr - y_coherent, **inputs)
    paramsout[jj, :] = soln

In [None]:
labels = ['η_m','ℓ_m', 'ν_m','η_i','ℓ_i', 'γ_i']
x = ds_out['nav_lat_vy'].values/1000

# `bounds` (per-parameter y-limits) is not defined anywhere in this notebook, so the
# limits are applied only if it exists in the kernel. Panels 0 and 3 were never limited.
param_bounds = globals().get('bounds')
limited_panels = {1, 2, 4, 5}

fig, axes = plt.subplots(len(labels), 1, figsize=(9, 6), sharex=True)
for ii, (ax, label) in enumerate(zip(axes, labels)):
    ax.plot(x, paramsout[:, ii])
    ax.set_ylabel(label)
    if param_bounds is not None and ii in limited_panels:
        ax.set_ylim(param_bounds[ii])

axes[-1].set_xlabel('y [km]')
fig.savefig('../FIGURES/PK2015_ROMS_params_maternitgamma_{}.png'.format(scenario), dpi=150)

In [None]:
import seaborn as sns

In [None]:
import pandas as pd

In [None]:
params_df = pd.DataFrame(paramsout, columns=labels)
sns.pairplot(params_df)