## Compare RBF, DMD and NODE NIROM results using pre-computed online solutions for Shallow Water models

In [None]:
## Load modules
import numpy as np
import scipy
from importlib import reload

import os
import gc
from scipy import interpolate
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, ScalarFormatter, FormatStrFormatter

from matplotlib import animation
matplotlib.rc('animation', html='html5')
from IPython.display import display
import matplotlib.ticker as ticker
from matplotlib import rcParams
from matplotlib.offsetbox import AnchoredText


# Global matplotlib styling shared by every figure in this notebook
plt.rc('font', family='serif', size=20)
plt.rc('lines', linewidth=2)
plt.rc('axes', labelsize=16, titlesize=20, linewidth=2)
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
plt.rc('legend', fontsize=16)

import itertools
# Endless colour / marker cycles for multi-curve plots
colors = itertools.cycle(list('rgbmyc'))
markers = itertools.cycle(['p', 'd', 'o', '^', 's', 'x', 'D', 'H', 'v', '*'])


# Project directories, all resolved relative to the directory the notebook was launched from
base_dir = os.getcwd()
src_dir = os.path.join(base_dir, '../src')
work_dir = os.path.join(base_dir, '../notebooks')
data_dir = os.path.join(base_dir, '../data/')
nirom_data_dir = os.path.join(base_dir, '../data/')   # pre-computed NIROM outputs live beside the snapshots
node_data_dir = os.path.join(base_dir, '../best_models/')
fig_dir = os.path.join(base_dir, '../figures')


os.chdir(work_dir)

In [None]:
## Load snapshot data
# Pick the test case by (un)commenting one of the problem sections below.

# # ### San Diego problem
# model = 'SD'
# data = np.load(os.path.join(data_dir,'san_diego_tide_snapshots_T4.32e5_nn6311_dt25.npz'))
# mesh = np.load(os.path.join(data_dir,'san_diego_mesh.npz'))

### Red River problem
model = 'Red'
data, mesh = (np.load(os.path.join(data_dir, fname))
              for fname in ('red_river_inset_snapshots_T7.0e4_nn12291_dt10.npz',
                            'red_river_mesh.npz'))

print(f"Solution component keys are : {list(data.keys())}")
print(f"Mesh element keys are : {list(mesh.keys())}")

In [None]:
## Prepare training snapshots
soln_names = ['S_dep', 'S_vx', 'S_vy']
comp_names = dict(enumerate(soln_names))   # {0:'S_dep', 1:'S_vx', 2:'S_vy'}
Nc = len(soln_names)                       # number of solution components (3)


nodes = mesh['nodes']
triangles = mesh['triangles']
Nn = nodes.shape[0]
Ne = triangles.shape[0]

snap_start = 100   # snapshots before this index are discarded
# Per-model training horizon (seconds) and the stride used to pick
# training snapshots for the SVD basis.
if model == 'SD':
    T_end = 50*3600   ### 50 hours in seconds
    snap_incr = 4
elif model == 'Red':
    T_end = 3.24e4
    snap_incr = 3
else:
    raise ValueError("Unknown model %r; expected 'SD' or 'Red'" % (model,))

T_full = data['T']
# Count the snapshots with T <= T_end. Counting the boolean mask (rather than
# the nonzero *values* T[T <= T_end]) keeps a snapshot stored at t = 0.
# NOTE(review): snap_end counts from index 0 of the full record, but it is used
# below to slice arrays that already start at snap_start, so the training
# window ends snap_start snapshots past T_end - confirm this is intended.
snap_end = np.count_nonzero(T_full <= T_end)

snap_data = {key: data[key][:, snap_start:] for key in soln_names}

times_offline = T_full[snap_start:]
Nt = times_offline.size
print('Loaded {0} snapshots of dimension {1} for h,u and v, spanning times [{2}, {3}]'.format(
                    snap_data[soln_names[0]].shape[1], snap_data[soln_names[0]].shape[0],
                    times_offline[0], times_offline[-1]))

DT = np.diff(times_offline).mean()

## Normalize the time axis. Required for DMD fitting
tscale = DT*snap_incr                      ### Scaling for DMD ()
times_offline_dmd = times_offline/tscale   ## Snapshots DT = 1

## Subsample snapshots for building POD basis
snap_train = {key: snap_data[key][:, :snap_end+1:snap_incr] for key in soln_names}


times_train = times_offline[:snap_end+1:snap_incr]
Nt_b = times_train.size
print('Using {0} training snapshots for time interval [{1},{2}]'.format(times_train.shape[0],
                                        times_train[0], times_train[-1]))


# Release the npz file handles explicitly instead of relying on garbage collection
data.close()
mesh.close()
del data, mesh
gc.collect()

In [None]:
## Set the time steps for online prediction
# All times here are in seconds (see the Tonline_end comments below).

t0 = times_train[0]
if model == 'Red':
    Tonline_end = 3.24e4  ### 9 hours in seconds
elif model == 'SD':
    Tonline_end = 50*3600   ### 50 hours in seconds
else:
    raise ValueError("Unknown model %r; expected 'SD' or 'Red'" % (model,))

# Indices into times_offline of the training start, training end and prediction end
trainT0 = np.searchsorted(times_offline, t0)
trainT = np.searchsorted(times_offline, times_train[-1])
trainP = np.searchsorted(times_offline, Tonline_end)

finer_steps = True   # online stride smaller than the training stride snap_incr
long_term = True     # predict up to Tonline_end instead of stopping at the training end

if finer_steps and not long_term:
    onl_incr = snap_incr-1
    N_online = trainT+1
elif long_term and not finer_steps:
    onl_incr = snap_incr
    N_online = trainP+1
elif long_term and finer_steps:
    onl_incr = snap_incr-2
    N_online = trainP+1
else:
    # Previously this combination left times_online undefined (or stale on re-run)
    raise ValueError("At least one of finer_steps / long_term must be True")
times_online = times_offline[trainT0:N_online:onl_incr]
Nt_online = times_online.size
print('Trying to simulate interval [{0},{1}] seconds with {2} steps'.format(t0,
                                                times_online[-1], Nt_online))

times_online_dmd = times_online/tscale

In [None]:
## LOAD saved NIROM solutions
# Each .npz is opened in a `with` block so its file handle is closed as soon
# as the required arrays have been read.

if model == 'Red':
    rdmd = 315   # alternative: rdmd = 30 loads '%s_online_dmd_r30.npz'
    node_model_dir = os.path.join(node_data_dir, 'RED')
elif model == 'SD':
    rdmd = 115
    node_model_dir = os.path.join(node_data_dir, 'SD')
else:
    raise ValueError("Unknown model %r; expected 'SD' or 'Red'" % (model,))
# NOTE: node_data_dir itself is no longer modified, so re-running this cell
# does not keep appending the model sub-directory to the path.

with np.load(os.path.join(nirom_data_dir, '%s_online_dmd_r%d.npz' % (model, rdmd))) as DMD:
    Xdmd = DMD['dmd']
    X_true = DMD['true']

with np.load(os.path.join(nirom_data_dir, '%s_online_rbf.npz' % model)) as RBF:
    urbf = {key: RBF[key] for key in soln_names}

with np.load(os.path.join(node_model_dir, '%s_online_node.npz' % model)) as NODE:
    unode = {key: NODE[key] for key in soln_names}
    # The NODE solution must be sampled at exactly the online times used here
    if not np.allclose(times_online, NODE['time']):
        raise ValueError("NODE solution times do not match times_online")

del DMD, RBF, NODE
gc.collect()


In [None]:
def var_string(ky):
    """Return the plot symbol for a solution component key.

    'S_dep' -> 'h', 'S_vx' -> 'u_x', 'S_vy' -> 'u_y'. The result is meant to be
    embedded in a LaTeX math string.

    Raises:
        ValueError: if ``ky`` is not one of the three known component keys
            (previously this fell through to an UnboundLocalError).
    """
    symbols = {'S_dep': 'h', 'S_vx': 'u_x', 'S_vy': 'u_y'}
    try:
        return symbols[ky]
    except KeyError:
        raise ValueError("Unknown solution component key: %r" % (ky,)) from None

In [None]:
### Compute spatial RMS errors
# For each component and online time: ||u_true - u_rom||_2 / sqrt(Nn), i.e. the
# root-mean-square nodal error (an absolute error, despite the *_rel_err names).
# NOTE(review): the DMD error is measured against X_true stored in the DMD file,
# while RBF/NODE are measured against snap_data - confirm both references agree.

fig = plt.figure(figsize=(16,4))
start_trunc = 10          # skip the first 10 online steps in the plots
end_trunc = Nt_online     # plot through the last online step
x_inx = times_online/3600                            # online times in hours
time_ind = np.searchsorted(times_offline, times_online)   # matching fine-grid snapshot indices
ky1 = 'S_dep'; ky2 = 'S_vx'; ky3 = 'S_vy'
md1 = var_string(ky1); md2 = var_string(ky2); md3 = var_string(ky3)


dmd_rel_err = {}
rbf_rel_err = {}
node_rel_err = {}

for ivar, key in enumerate(soln_names):
    snap_at_online = snap_data[key][:, time_ind]
    # X_true / Xdmd interleave the Nc components row-wise, hence the ivar::Nc stride
    dmd_rel_err[key] = np.linalg.norm(X_true[ivar::Nc, :] - Xdmd[ivar::Nc, :], axis=0)/np.sqrt(Nn)
    rbf_rel_err[key] = np.linalg.norm(snap_at_online - urbf[key], axis=0)/np.sqrt(Nn)
    node_rel_err[key] = np.linalg.norm(snap_at_online - unode[key], axis=0)/np.sqrt(Nn)

window = slice(start_trunc, end_trunc)

# Raw strings keep the LaTeX '\mathbf' from being parsed as an (invalid) escape
ax1 = fig.add_subplot(1, 2, 1)
ax1.plot(x_inx[window], dmd_rel_err[ky1][window], 'r-s', markersize=8,
         label=r'DMD:$\mathbf{%s}$' % md1, lw=2, markevery=400)
ax1.plot(x_inx[window], rbf_rel_err[ky1][window], 'k-p', markersize=8,
         label=r'RBF:$\mathbf{%s}$' % md1, lw=2, markevery=500)
ax1.plot(x_inx[window], node_rel_err[ky1][window], 'b-o', markersize=8,
         label=r'NODE:$\mathbf{%s}$' % md1, lw=2, markevery=600)
ymax_ax1 = dmd_rel_err[ky1][window].max()
ax1.set_xlabel('Time (hrs)')
ax1.legend(ncol=3, fancybox=True, loc='upper center')

ax2 = fig.add_subplot(1, 2, 2)
ax2.plot(x_inx[window], dmd_rel_err[ky2][window], 'r-o', markersize=8,
         label=r'DMD:$\mathbf{%s}$' % md2, lw=2, markevery=400)
ax2.plot(x_inx[window], rbf_rel_err[ky2][window], 'k-D', markersize=8,
         label=r'RBF:$\mathbf{%s}$' % md2, lw=2, markevery=500)
ax2.plot(x_inx[window], node_rel_err[ky2][window], 'b-H', markersize=8,
         label=r'NODE:$\mathbf{%s}$' % md2, lw=2, markevery=530)


ymax_ax2 = np.maximum(dmd_rel_err[ky2][window].max(), dmd_rel_err[ky3][window].max())
ax2.set_xlabel('Time (hrs)')
# lg (the right-hand legend) is passed to savefig as an extra artist below
lg = ax2.legend(ncol=3, fancybox=True, loc='upper center')

# os.chdir(fig_dir)
# plt.savefig('SW_%s_nirom_comp_rms_tskip%d_oskip%d.pdf'%(model,snap_incr,onl_incr),bbox_extra_artists=(lg,), bbox_inches='tight')

In [None]:
## Visualize NIROM solution
def viz_sol(urom,iplot,times_online,nodes,triangles,method,key):
    """Plot one NIROM solution field on the triangular mesh at a single online time.

    Parameters
    ----------
    urom : nodal values of one solution component, one value per mesh node
    iplot : index into ``times_online`` of the snapshot being shown
    times_online : online prediction times, in seconds (shown in hours)
    nodes : mesh node coordinates, x in column 0 and y in column 1
    triangles : triangle connectivity passed to ``tripcolor``
    method : NIROM method label used in the title (e.g. 'DMD')
    key : solution component key, converted to a symbol by ``var_string``
    """
    tn = times_online[iplot]
    print("NIROM solution at t = {0:.2f} hrs".format(tn/3600))
    plt.figure(figsize=(8,6))
    ax1 = plt.subplot(1,1,1)
    ax1.axis('off')
    surf1 = ax1.tripcolor(nodes[:,0], nodes[:,1], triangles, urom, cmap=plt.cm.jet)
    # '\\mathbf' (escaped backslash) avoids the invalid '\m' string escape
    ax1.set_title("%s solution at t=$%.2f$ hrs\n $%1.5f<\\mathbf{%s}<%1.5f$" % (
                      method, tn/3600, np.amin(urom), var_string(key), np.amax(urom)),
                  fontsize=16)
    plt.colorbar(surf1, shrink=0.8, aspect=20, pad=0.03)



In [None]:
## Visualize NIROM error
def viz_err(urom,utrue,iplot,times_online,nodes,triangles,method,key,t_true=None):
    """Plot the pointwise error ``utrue - urom`` of a NIROM solution on the mesh.

    The title reports the error range and the relative 2-norm error
    ||urom - utrue|| / ||utrue||.

    Parameters
    ----------
    urom, utrue : nodal values of the NIROM and fine-grid solutions
    iplot : index into ``times_online`` of the NIROM snapshot
    times_online : online prediction times, in seconds (printed in hours)
    nodes, triangles : mesh coordinates (x, y in columns 0, 1) and connectivity
    method : NIROM method label used in the title
    key : solution component key, converted to a symbol by ``var_string``
    t_true : time (seconds) of the fine-grid snapshot; defaults to
        ``times_online[iplot]``. Previously this was read from the globals
        ``times_offline`` / ``iplot_true``.
    """
    t_rom = times_online[iplot]
    if t_true is None:
        t_true = t_rom
    print("comparing NIROM solution at t = {1:.2f} hrs and fine-grid solution at t = {0:.2f} hrs".format(
                                            t_true/3600, t_rom/3600))

    err = utrue - urom
    plt.figure(figsize=(8,6))
    ax3 = plt.subplot(1,1,1)
    ax3.axis('off')   # previously called on an unrelated/global ax1
    surf3 = ax3.tripcolor(nodes[:,0], nodes[:,1], triangles, err, cmap=plt.cm.jet)
    err_min, err_max = np.amin(err), np.amax(err)
    rel_err = np.linalg.norm(err)/np.linalg.norm(utrue)
    # '\\mathbf' (escaped backslash) avoids the invalid '\m' string escape
    ax3.set_title("$%1.6f$ <%s $\\mathbf{%s}$ Error< $%1.6f$\n Rel. Error 2-norm : $%2.6f$" % (
                      err_min, method, var_string(key), err_max, rel_err),
                  fontsize=16)
    plt.colorbar(surf3, shrink=0.8, aspect=20, pad=0.03)



In [None]:
# Choose a solution component and an online time index to visualize
key = 'S_vx'
iplot = 1200
ivar = [*comp_names.values()].index(key)   # row offset of this component in the interleaved Xdmd

###  Uncomment to select one of the three NIROM solutions
# urom = urbf[key][:,iplot]; method='RBF'   ## RBF
# urom = unode[key][:,iplot]; method='NODE'  ## NODE
method = 'DMD'                   ## DMD
urom = Xdmd[ivar::Nc, iplot]

viz_sol(urom, iplot, times_online, nodes, triangles, method, key)

# os.chdir(fig_dir)
# plt.savefig('%s_%s_%s_t%.3f_tskip%d_oskip%d.pdf'%(model,method,var_string(key),times_online[iplot]/3600,snap_incr,onl_incr),bbox_inches='tight')

In [None]:
# Fine-grid snapshot index matching the selected online time
iplot_true = times_offline.searchsorted(times_online[iplot])
utrue = snap_data[key].T[iplot_true]
viz_err(urom, utrue, iplot, times_online, nodes, triangles, method, key)

# os.chdir(fig_dir)
# plt.savefig('%s_%s_relerr_%s_t%.3f_tskip%d_oskip%d.pdf'%(model,method,var_string(key),times_online[iplot]/3600,snap_incr,onl_incr), bbox_inches='tight')

In [None]:
def _plot_nirom_panel(fig, pos, nodes, elems, values, title, cmap):
    """Draw one tripcolor panel of the 4x2 comparison grid with a horizontal colorbar."""
    ax = fig.add_subplot(4, 2, pos)
    surf = ax.tripcolor(nodes[:,0], nodes[:,1], elems, values, cmap=cmap)
    ax.set_title(title, fontsize=16)
    ax.axis('off')
    fig.colorbar(surf, ax=ax, orientation='horizontal', shrink=0.6, aspect=40, pad=0.03)
    return ax


def plot_nirom_soln(Xtrue, Xdmd, Xrbf, Xnode, Nc, Nt_plot, nodes, elems, trainT0, times_online, comp_names, seed =100, flag = True):
    """Plot DMD, RBF, NODE and HFM solutions plus the NIROM errors in a 4x2 grid.

    The component is fixed to ``comp_names[1]`` and the time to the first
    online time at or after 3.61 hours.

    Parameters
    ----------
    Xtrue, Xdmd : stacked full-order / DMD solutions, one column per online time.
        ``flag=True``: components interleaved row-wise (row = node*Nc + comp);
        ``flag=False``: components stored in consecutive blocks of one row per node.
    Xrbf, Xnode : dicts mapping component key -> (nodes, times) arrays
    Nc : number of solution components
    Nt_plot, trainT0 : unused; kept for call compatibility
    nodes, elems : mesh coordinates (x, y in columns 0, 1) and triangle connectivity
    times_online : online times in seconds
    comp_names : dict mapping component index -> key
    seed : seeds NumPy's global RNG (the random panel selection is currently disabled)
    flag : layout of Xtrue / Xdmd, see above

    Returns
    -------
    The plotted time ``tn`` in seconds.
    """
    np.random.seed(seed)
    itime = np.searchsorted(times_online,3.61*3600) #np.random.randint(0,Nt_plot)
    ivar  = 1 #np.random.randint(1,Nc)
    ky = comp_names[ivar]
    tn = times_online[itime]
    # Derive the node count from the mesh instead of the global Nn
    n_nodes = nodes.shape[0]

    if flag:     ### for interleaved snapshots
        tmp_dmd = Xdmd[ivar::Nc,itime]
        tmp_true = Xtrue[ivar::Nc,itime]
    else:
        tmp_dmd = Xdmd[ivar*n_nodes:(ivar+1)*n_nodes,itime]
        tmp_true = Xtrue[ivar*n_nodes:(ivar+1)*n_nodes,itime]

    tmp_rbf = Xrbf[ky][:,itime]
    tmp_node = Xnode[ky][:,itime]

    sol_fmt = '{name} solution: {ky} at t={t:1.2f} hrs, \n{ky} range = [{lo:5.3g},{hi:4.2g}]'
    err_fmt = '{name} error: {ky} at t={t:1.2f} hrs, \nerror range = [{lo:5.3g},{hi:4.2g}]'

    fig = plt.figure(figsize=(15,28))
    solutions = [('DMD', tmp_dmd), ('RBF', tmp_rbf), ('NODE', tmp_node), ('HFM', tmp_true)]

    # Panels 1-4: the three NIROM solutions and the high-fidelity (HFM) reference
    for pos, (name, vals) in enumerate(solutions, start=1):
        title = sol_fmt.format(name=name, ky=ky, t=tn/3600, lo=vals.min(), hi=vals.max())
        _plot_nirom_panel(fig, pos, nodes, elems, vals, title, plt.cm.jet)

    # Panels 5-7: signed error of each NIROM solution against the HFM solution
    for pos, (name, vals) in enumerate(solutions[:3], start=5):
        err = vals - tmp_true
        title = err_fmt.format(name=name, ky=ky, t=tn/3600, lo=err.min(), hi=err.max())
        _plot_nirom_panel(fig, pos, nodes, elems, err, title, plt.cm.Spectral)

    return tn

In [None]:
# Online index at the end of the training window (passed through; the plot itself does not use it)
Nt_plot = times_online.searchsorted(times_train[-1])
# plot_nirom_soln returns the plotted time in seconds (not an index)
itime = plot_nirom_soln(X_true, Xdmd, urbf, unode, Nc, Nt_plot, nodes, triangles, trainT0,
                        times_online=times_online, comp_names=comp_names,
                        seed=1990, flag=True)

# os.chdir(fig_dir)
# plt.savefig('OF_nirom_t%.3f_tskip%d_oskip%d.pdf'%(itime,snap_incr,onl_incr),bbox_extra_artists=(lg,), bbox_inches='tight')