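# Script to compare an NN simulation with a reference simulation
This script runs the model twice, once with the neural-network (NN) parameterization and once as a reference simulation, so that the two runs can be compared.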

In [None]:
## Import package
from neuralsw.model.shalw import SWmodel
from neuralsw.model.datatools import TimeSeq 
from neuralsw.model.shalwnet import SWparnnim
import neuralsw
import numpy as np
import xarray as xr
import os
from os.path import join
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
%matplotlib notebook


In [None]:
## Specify the output
PLOT = True  # set to False to skip the diagnostic figures

# Root directory: two levels above the current working directory
rootdir = os.path.realpath(join(os.getcwd(), '../..'))

# Directory in which the data files are stored
datadir = os.path.realpath(join(rootdir, 'data'))

# Duration of the integration: 10 years
# (can be overridden by the 'endtime' option in allsets)
endtime = 10 * 12 * 30 * 48

# Save frequency: once a month
# (can be overridden by the 'freq' option in allsets)
freq = 30 * 48

# Number of successive time steps saved at each save time
nseq = 1

# Names of the fields to save
param = {
    'hphy',
    'uphy',
    'vphy',
    'taux',
    'tauy',
    'uparam',
    'vparam',
}
print('data directory:',datadir)


## Option in allsets
### Mandatory:
- `'suf'`: suffix of the simulation filename
- `'type'`: prefix of the simulation filename (defines the type of set: train/test/valid)
- `'rst'`: restart file used to initialize the model (generated by the `restart` script)

### Optional:
- `'saverst'`: save a restart file at the end of the simulation
- `'endtime'`: override the default end time of the simulation
- `'freq'`: override the default save frequency
- `'nseq'`: override the default number of successive time steps saved at each save time
- `'warg'`: specify forcing arguments for the SW model

In [None]:
## Specify the types of dataset to run

# Example forcing arguments, handed to the models through the 'warg' option
warg0 = dict(sigx=0.1)
warg1 = dict(sigx=0.1, arx=0.6, convx=5)
warg2 = dict(taux0=0.1)

# Restart files: app, test and low-wind configurations
rstfile = join(datadir, 'restart_10years.nc')
rstfile_test = join(datadir, 'restart_test.nc')
rstfile_low = join(datadir, 'restart_10years_low.nc')

##############################################################
# SPECIFY THE TYPE OF DATASET TO BE GENERATED (SEE ABOVE)    #
##############################################################
allsets = {
    0: {'suf': 'std', 'rst': rstfile},
    1: {
        'suf': 'windvar',
        'endtime': 48*30*12*10,
        'warg': warg0,
        'rst': rstfile_test,
    },
    2: {
        'suf': 'warsmooth',
        'endtime': 48*30*12*10,
        'warg': warg1,
        'rst': rstfile_test,
    },
    3: {
        'suf': 'std',
        'endtime': 48*30*12*10,
        'rst': rstfile_test,
    },
    4: {
        'suf': 'windlow',
        'endtime': 48*30*12*10,
        'warg': warg2,
        'rst': rstfile_low,
    },
}

# Key of the set to run; change it to run another set
selected = 4
###############################################################

dset = allsets[selected]



In [None]:
## Define the neural net

def _find_network(netname):
    """Return the path of the pickled network stored in ``datadir/netname``.

    The matches of ``*.pkl`` are sorted before the first one is taken, so the
    selected file does not depend on the (arbitrary) order in which glob
    returns directory entries.

    Raises:
        FileNotFoundError: if the directory holds no ``.pkl`` file.
    """
    netdir = join(datadir, netname)
    candidates = sorted(glob.glob(join(netdir, '*.pkl')))
    if not candidates:
        raise FileNotFoundError('no .pkl network file found in ' + netdir)
    return candidates[0]


#For u
nnid = '0'
oparam = 'uparam'
noise = 'noise01'  # alternative value: 'nonoise'
wtype = 'std'
uname = '_'.join(['nn'+nnid,oparam,noise,wtype])

#For v (change only parameters that are different)
oparam = 'vparam'
vname = '_'.join(['nn'+nnid,oparam,noise,wtype])

nnufile = _find_network(uname)
nnvfile = _find_network(vname)

print('uparam network:',nnufile)
print('vparam network:',nnvfile)

In [None]:
## Redefine some options
# A value given in the selected set replaces the default defined earlier;
# when the key is absent the current value is kept.
endtime = dset.get('endtime', endtime)
freq = dset.get('freq', freq)
nseq = dset.get('nseq', nseq)

# Forcing arguments default to an empty dict
warg = dset.get('warg', dict())

In [None]:
# Define the output filenames
# Common root: <datadir>/run_<noise>_<suf>
fnameroot = os.path.join(datadir, '_'.join(['run', noise, dset['suf']]))
print('filename:',fnameroot)

# '_00' is written by the reference model, '_nn' by the NN-based model
fname_00 = fnameroot + '_00.nc'
fname_nn = fnameroot + '_nn.nc'


In [None]:
## Init models for the simulation
# Reference model, started from the restart file of the selected set
SW0 = SWmodel(nx=80, ny=80, warg=warg)
SW0.inistate_rst(dset['rst'])
SW0.set_time(0)

# Model using the two neural nets, started from the same restart file
SW = SWparnnim(nnupar=nnufile, nnvpar=nnvfile, warg=warg, nx=80, ny=80)
SW.inistate_rst(dset['rst'])
SW.set_time(0)

# Save schedule
time = TimeSeq(endtime=endtime, freq=freq, start=0, nseq=nseq)

# Remove output files left by a previous run before registering the saves
for stale in (fname_00, fname_nn):
    if os.path.isfile(stale):
        os.remove(stale)

# Register the output of each model
for swm, outname in ((SW0, fname_00), (SW, fname_nn)):
    swm.save(time=time, name=outname, para=param)


In [None]:
# Advance both models side by side for `endtime` steps (with a progress bar)
progress = tqdm(range(endtime))
for step in progress:
    SW0.next()
    SW.next()

In [None]:
# Write a restart file when the selected set provides the 'saverst' option
# (only the NN-based model SW is saved)
if 'saverst' in dset:
    rst_out = dset['saverst']
    SW.save_rst(rst_out)


In [None]:
## Plots conservative quantities for training set
if PLOT:
    import neuralsw.model.modeltools as model

    ds0 = xr.open_dataset(fname_00)
    dsn = xr.open_dataset(fname_nn)

    fig, ax = plt.subplots(nrows=3, sharex=True)

    # Diagnostics of the reference run (0) and of the NN run (n)
    Ec0 = model.cinetic_ener(ds=ds0)
    Ep0 = model.potential_ener(ds=ds0)
    Pv0 = model.potential_vor(ds=ds0)
    Ecn = model.cinetic_ener(ds=dsn)
    Epn = model.potential_ener(ds=dsn)
    Pvn = model.potential_vor(ds=dsn)

    # One panel per diagnostic: (axis, reference, nn, title, y label)
    panels = (
        (ax[0], Ec0, Ecn, 'mean kinetic energy', 'Ec'),
        (ax[1], Ep0, Epn, 'mean potential energy', 'Ep'),
        (ax[2], Pv0, Pvn, 'mean potential vorticity', 'Pv'),
    )

    # Reference curves first, then NN curves
    for axis, phys, _nnet, _title, _ylabel in panels:
        phys.plot(ax=axis, label='physical based')
    for axis, _phys, nnet, _title, _ylabel in panels:
        nnet.plot(ax=axis, label='nn based')

    for axis, _phys, _nnet, title, ylabel in panels:
        axis.set_title(title)
        axis.set_ylabel(ylabel)
        if axis is not ax[2]:
            # only the bottom panel keeps its x label
            axis.set_xlabel('')
        axis.legend()
    plt.show()

    # Paired t-test between the two runs for each diagnostic
    pval_Ec = ttest_rel(Ec0, Ecn).pvalue
    pval_Ep = ttest_rel(Ep0, Epn).pvalue
    pval_PV = ttest_rel(Pv0, Pvn).pvalue

    report = (
        ('Mean value of Ec (phys/nn): {0:3.2e}/{1:3.2e} (pval={2:3.2e})',
         Ec0, Ecn, pval_Ec),
        ('Mean value of Ep (phys/nn): {0:3.2e}/{1:3.2e} (pval={2:3.2e})',
         Ep0, Epn, pval_Ep),
        ('Mean value of Pv (phys/nn): {0:3.2e}/{1:3.2e} (pval={2:3.2e})',
         Pv0, Pvn, pval_PV),
    )
    for template, phys, nnet, p_paired in report:
        print (template.format(float(phys.mean()), float(nnet.mean()), p_paired))

In [None]:
## Plots mean states for training
if PLOT:
    ds0 = xr.open_dataset(fname_00)
    dsn = xr.open_dataset(fname_nn)

    fields = ('hphy', 'uphy', 'vphy')

    # Time mean of each field of the reference run, one panel per field
    fig, ax = plt.subplots(nrows=3, sharex=True)
    for axis, field in zip(ax, fields):
        ds0[field].mean(dim='time').plot(ax=axis)

    plt.show()

    # Standard deviation over time of the same fields
    fig, ax = plt.subplots(nrows=3, sharex=True)
    for axis, field in zip(ax, fields):
        ds0[field].std(dim='time').plot(ax=axis)

    plt.show()
    
    

In [None]:
## Compare the `aumax` diagnostic of the reference and NN runs
# This cell used to rely on `model`, `ds0` and `dsn` created inside earlier
# `if PLOT:` cells, so it failed with a NameError when PLOT was False.
# It now honours PLOT like the other plotting cells and loads what it needs.
if PLOT:
    import neuralsw.model.modeltools as model

    ds0 = xr.open_dataset(fname_00)
    dsn = xr.open_dataset(fname_nn)

    umax0 = model.aumax(ds0)
    umaxn = model.aumax(dsn)

    fig,ax = plt.subplots()
    umax0.plot(ax=ax,label='physical based')
    umaxn.plot(ax=ax,label='nn based')
    ax.legend()  # the labels above are not shown without a legend
    plt.show()

    # Paired t-test between the two series
    _,pval = ttest_rel(umax0,umaxn)
    print('pval={0:3.2e}'.format(float(pval)))

In [None]:
# Quick look at the entry of index 2 (along the first dimension) of `taux`
# in the NN run output
taux_snapshot = dsn['taux'][2]
taux_snapshot.plot()