In [None]:
#%reset -f
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
import h5py

In [None]:
import plot_settings as ps

In [None]:
from PySTFP import PySTFP

In [None]:
# Read the pre-computed exact results into memory. The HDF5 file is written
# by the companion notebook "analytical solution.ipynb".

analytical_solution = {}

with h5py.File('analytical_solution.h5','r') as h5_file:
    # indexing a dataset with [()] reads its full content
    analytical_solution.update(
        (name, h5_file[name][()]) for name in h5_file.keys()
    )

# inspect the available entries with: analytical_solution.keys()

In [None]:
def compare_perturbative_to_analytical_results(
        order_for_perturbative_propagator=2, # this is K in Eq. (XX) of the manuscript,
            # if this is set to None, then the closed-form propagator with
            # midpoint coefficients is used
        positivity_preserving=False,
        plot_times=None,
        plot_error=True,
        ):
    """Compare an approximate propagator with the pre-computed exact density.

    All input data is read from the module-level dictionary
    ``analytical_solution`` (keys ``x0``, ``D_x0``, ``a_x0``, ``xP``, ``P``,
    ``dt`` and, for the midpoint propagator, ``D_midpoints`` and
    ``a_midpoints``).

    Parameters
    ----------
    order_for_perturbative_propagator : int or None
        Order passed to ``PySTFP.get_probability_density_from_derivatives``.
        If None, the closed-form propagator defined inside this function is
        used instead; it is built from the midpoint coefficients (elsewhere
        in this notebook it is referred to both as "Gaussian propagator" and
        as "K = 2 propagator with midpoint evaluation").
    positivity_preserving : bool
        Forwarded to PySTFP; ignored if the order is None.
    plot_times : sequence of float, optional
        For every entry, the densities at the closest available lagtime are
        plotted (approximate vs. exact). Default: no such plots.
    plot_error : bool
        If True, plot the integrated absolute error and its running
        exponent as a function of the lagtime.

    Returns
    -------
    dict
        Keys ``'P_perturbative'`` (array, one row per lagtime),
        ``'xP'``, ``'diff_integrals'`` (integral over xP of the absolute
        difference to the exact density, one value per lagtime), ``'dt'``.
    """
    if plot_times is None:
        plot_times = []
    #
    # get pre-evaluated expansion point, drift and diffusivity
    x0 = analytical_solution['x0']
    D_x0 = analytical_solution['D_x0']
    a_x0 = analytical_solution['a_x0']
    #
    xP = analytical_solution['xP']
    P = analytical_solution['P'][0]
    dt = analytical_solution['dt']
    #
    if order_for_perturbative_propagator is not None:
        # create an instance of PySTFP
        p = PySTFP()
        #
        # get lambda function with perturbative probability density
        P_lambda = p.get_probability_density_from_derivatives(
            D_derivatives=D_x0,
            a_derivatives=a_x0,
            x0=x0,
            order=order_for_perturbative_propagator,
            positivity_preserving=positivity_preserving,
            dimensionless_time=False,
            dimensionless_position=False)
    else: # use K = 2 propagator with midpoint evaluation
        #
        # coefficients evaluated at the midpoints
        D_midpoints = analytical_solution['D_midpoints']
        a_midpoints = analytical_solution['a_midpoints']
        #
        def P_lambda(x,lagtime):
            # increment relative to the initial position; computed from
            # the argument x so that the function honours its input
            dx = x - x0
            #
            dx_DL = dx/np.sqrt(2*D_midpoints[0]*lagtime)
            prefactor = 1./np.sqrt(4*np.pi*D_midpoints[0]*lagtime)
            #
            S0 = 1./(4*D_midpoints[0]) \
                    * ( dx/lagtime \
                        - a_midpoints[0] \
                        + D_midpoints[1] )**2
            #
            S1 = 0.5 * a_midpoints[1] \
                + (-dx_DL**4 + 3*dx_DL**2 - 6) \
                    * D_midpoints[2]/24. \
                + (dx_DL**4 - 2*dx_DL**2 - 1) \
                    * D_midpoints[1]**2/(16*D_midpoints[0])
            #
            return prefactor*np.exp(-lagtime*(S0 + S1))
        #
    #
    # evaluate approximate probability density function at all lagtimes
    P_perturbative = np.zeros([len(dt),len(xP)],dtype=float)
    for i,lagtime in enumerate(dt):
        P_perturbative[i] = P_lambda(xP,lagtime)
    #
    # compare approximate probability density to analytical result;
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid, so
    # use the latter whenever it is available
    integrate = getattr(np,'trapezoid',None) or np.trapz
    diff_integrals = np.zeros(len(dt),dtype=float)
    for i in range(len(dt)):
        diff_integrals[i] = integrate(
            np.fabs(P_perturbative[i] - P[i]),
            xP
        )
    #
    # plot distributions at given plot_times
    stride = 10 # only every stride-th grid point is drawn
    for t_value in plot_times:
        #
        # closest available lagtime
        index = np.argmin(np.fabs(dt - t_value))
        fig,ax = plt.subplots(1,1,figsize=(6,4))
        ax.set_title(r'$t/T = $ ' + '{0:3.3f}'.format(dt[index]))
        ax.plot(xP[::stride],
                P_perturbative[index][::stride],
                label='perturbative')
        #
        ax.plot(xP[::stride],
                P[index][::stride],
                ls='--',
                label='exact')
        #
        ax.legend(loc='best')
        ax.set_xlabel(r'$x/L$')
        ax.set_ylabel(r'$P \cdot L$')
        plt.show()
        plt.close(fig)
    #
    # plot error
    if plot_error:
        #
        fig,axes = plt.subplots(1,2,figsize=(10,4))
        fig.subplots_adjust(wspace=0.3)
        #
        # log-log plots of error
        ax = axes[0]
        #
        ax.plot(dt,diff_integrals,
                color='black',
                label='Error')
        #
        ax.set_xscale('log')
        ax.set_yscale('log')
        #
        ax.set_xlabel(r'$t/T$')
        ax.set_ylabel(r'Error')
        #
        ax.set_xlim(1e-3,1e0)
        #
        # running exponent of error (centered difference in log-log space,
        # hence defined on dt[1:-1])
        ax = axes[1]
        #
        local_exponent = np.log(diff_integrals[2:]) \
                            - np.log(diff_integrals[:-2])
        local_exponent /= np.log(dt[2:]) - np.log(dt[:-2])
        ax.plot(dt[1:-1],local_exponent,
                color='black',
                label='Running exponent')
        ax.set_xscale('log')
        ax.set_xlabel(r'$t/T$')
        ax.set_ylabel(r'Running exponent')
        #
        # add some power laws for comparison
        dt_ref = 1e-2
        index = np.argmin(np.fabs(dt-dt_ref))
        # local_exponent[k] belongs to dt[k+1]; shift the index accordingly
        # and keep it inside the valid range
        exponent_index = int(np.clip(index - 1, 0, len(local_exponent) - 1))
        # round to the nearest multiple of 1/2
        closest_exponent = np.round(2*local_exponent[exponent_index])/2.
        y_ref = diff_integrals[index]
        dt_scale = np.logspace(-3,-1)
        shifts = [-0.5,0,0.5]
        #
        ax = axes[0]
        for shift in shifts:
            power = closest_exponent + shift
            #
            y = (dt_scale/dt_ref)**power * y_ref
            ax.plot(dt_scale,y,ls='--',
                    label=r'$\sim t^{{{0}}}$'.format(power))
        ax.legend(loc='best')
        #
        ax = axes[1]
        for shift in shifts:
            # empty line: only used to advance the color cycle, so that the
            # horizontal lines get the colors of the power laws on the left
            line, = ax.plot([], [])
            ax.axhline(closest_exponent + shift,
                       ls='--',
                       label='{0:3.1f}'.format(
                           closest_exponent + shift
                       ),
                       color=line.get_color())
        ax.set_ylim(closest_exponent + min(shifts)- 0.5,
                    closest_exponent + max(shifts)+0.5  )
        ax.set_xlim(1e-3,1e0)
        #
        ax.legend(loc='best',ncols=1)
        plt.show()
        plt.close(fig)

    #
    output_dictionary = {'P_perturbative':P_perturbative,
                         'xP':xP,
                         'diff_integrals':diff_integrals,
                         'dt':dt,
                         }
    #
    return output_dictionary


In [None]:
# Propagators to compare; entries with the same position belong together.
# An order of None selects the propagator that is defined inside
# compare_perturbative_to_analytical_results instead of the PySTFP one.
orders_for_perturbative_propagators = [2,2,None]
positivity_preserving = [False,True,None]

# one result dictionary per propagator, in the order given above
errors_perturbative = []

for order,preserve_positivity in zip(orders_for_perturbative_propagators,
                                     positivity_preserving):
    #
    # headline for the plots generated by the comparison
    headline = ("Gaussian propagator:" if order is None
                else 'P_K with K = {0}:'.format(order))
    print(headline)
    #
    comparison = compare_perturbative_to_analytical_results(
        positivity_preserving=preserve_positivity,
        order_for_perturbative_propagator=order,
        plot_error=True,
    )
    errors_perturbative.append(comparison)


In [None]:
# Grid of 2x3 subplots:
# (a), (b), (c)
# (d), (e), (f)
# 
# corresponding to axes[i,j]
# with row
#   i = 0, 1
# and column
#   j = 0, 1, 2
# so that e.g. axes[1,2] corresponds to subplot (f)
#
# left column: Propagators at lagtime dt = 0.05
# (a): plots of P
# (d): pointwise error of P
# middle column: Propagators at lagtime dt = 0.20
# (b): plots of P
# (e): pointwise error of P
# right column: Instantaneous error
# (c): log-log plot of instantaneous error
# (f): running exponent of instantaneous error


# Exact reference data shared by all panels
x0 = analytical_solution['x0']
D_x0 = analytical_solution['D_x0']
a_x0 = analytical_solution['a_x0']
#
xP = analytical_solution['xP']
P = analytical_solution['P'][0]
dt = analytical_solution['dt']

# Note: the half-width dL of the plotted x-interval is not a global setting;
# it is computed for each lagtime inside the loop over ps.dt_plot['times'].

# legend labels, indexed like the entries of errors_perturbative
perturbative_labels = {0:r'NPP, $K=2$',
                       1:r'PPP, $K=2$',
                       2:r'PPP, $K=2$,' + '\nmidpoint',
                       }

# formatting of the legends
labelspacing=0.3
borderaxespad=0.3
handletextpad=0.3

# 2x3 grid of panels; rows share (almost) no vertical space
fig,axes = plt.subplots(2,3,figsize=(16,7))
fig.subplots_adjust(hspace=0.00001,wspace=0.6)

##########################
# Left and middle column #
##########################

# mark the initial position x0 in the four panels of the left and middle
# column, i.e. (a), (b), (d), (e)
x0_line_style = dict(ls=ps.x0_vertical['ls'],
                     color=ps.x0_vertical['color'])
for panel in axes[:2,:2].flat:
    panel.axvline(x0, **x0_line_style)



for column,lagtime in enumerate(ps.dt_plot['times']):
    # column = 0: left column, panels (a) and (d)
    # column = 1: middle column, panels (b) and (e)
    #
    # closest lagtime for which data is available
    time_index = np.argmin(np.fabs(dt-lagtime))
    # half-width of the plotted x-interval around x0
    half_width = 4*np.sqrt(2 * D_x0[0] * lagtime)
    x_window = (x0-half_width, x0+half_width)
    #
    P_exact = P[time_index]
    ax_density = axes[0,column] # upper panel: probability densities
    ax_error = axes[1,column]   # lower panel: pointwise difference
    #
    # exact density
    ax_density.plot(xP,
                    P_exact,
                    color=ps.analytical['color'],
                    lw=ps.analytical['lw'],
                    label=ps.analytical['label'],
                    )
    #
    # approximate densities and their pointwise deviation from the exact one
    for k,result in enumerate(errors_perturbative):
        #
        P_approx = result['P_perturbative'][time_index]
        curve_style = dict(color=ps.perturbative['color'][k],
                           lw=ps.perturbative['lw'][k],
                           label=perturbative_labels[k],
                           dashes=ps.perturbative['dashes'][k])
        #
        ax_density.plot(xP, P_approx, **curve_style)
        ax_error.plot(xP, np.fabs( P_approx - P_exact), **curve_style)
    #
    # upper panel: limits, ticks, label, legend
    ax_density.set_xlim(*x_window)
    P_max = np.max(P_exact)
    P_range = P_max - np.min(P_exact)
    ax_density.set_ylim(-0.1*P_range,
                        P_max + 0.1*P_range)
    ax_density.set_xticks([])
    ax_density.set_ylabel(r'$P\cdot L$')
    if column == 0:
        # only the leftmost density panel carries a legend
        ax_density.legend(loc='lower right',labelspacing=labelspacing,
                          borderaxespad=borderaxespad,
                          handletextpad=handletextpad,
                          bbox_to_anchor=(1.33,0.35),
                          framealpha=0.95,
                          )
    #
    # lower panel: limits, labels
    ax_error.set_xlim(*x_window)
    ax_error.set_xlabel(r'$x/L$')
    ax_error.set_ylabel(r'$|P - P^e| \cdot L$')
    #


#######################
# Instantaneous error #
#######################

# Both panels of the right column, (c) and (f), get vertical markers at the
# lagtimes shown in the other columns and at the breakdown time.
for ax in axes[:,2]:
    #
    for k in range(len(ps.dt_plot['times'])):
        marker_style = dict(color=ps.dt_plot['color'][k],
                            dashes=ps.dt_plot['dashes'][k],
                            label=ps.dt_plot['label'][k])
        ax.axvline(ps.dt_plot['times'][k], **marker_style)
    #
    ps.add_vertical_breakdown_line(
        ax=ax,
        t_breakdown=analytical_solution['t_breakdown'],
    )
    

# panel (c): instantaneous error as a function of the lagtime
ax = axes[0,2]

# horizontal reference line (drawn by plot_settings) with a text annotation
ps.add_horizontal_line(ax=ax, label='')
horizontal_annotation = ps.horizontal['label'].format(ps.horizontal['value'])
ax.text(x=2e-3, y=1.2e-2, s=horizontal_annotation)

# one error curve per propagator
for k,result in enumerate(errors_perturbative):
    #
    ax.plot(dt,
            result['diff_integrals'],
            color=ps.perturbative['color'][k],
            lw=ps.perturbative['lw'][k],
            label=perturbative_labels[k],
            dashes=ps.perturbative['dashes'][k],
            )

# Print the instantaneous error of every propagator at a few lagtimes, and
# the lagtime at which the error is closest to the horizontal reference value
dt_get_result = 1e-1
report_times = [dt_get_result] + list(ps.dt_plot['times'])
error_line = 't = {0:3.3f},\tE = {1:3.3f}'
for k,result in enumerate(errors_perturbative):
    #
    errors = result['diff_integrals']
    #
    print('{0}:'.format(perturbative_labels[k].replace('\n',' ')))
    #
    for t_report in report_times:
        # closest lagtime for which data is available
        closest = np.argmin(np.fabs(dt - t_report))
        print(error_line.format(t_report, errors[closest]))
    #
    closest = np.argmin(np.fabs(errors - ps.horizontal['value']))
    print('E = {0:3.5f} at dt = {1:3.5f}'.format(ps.horizontal['value'],
                                                 dt[closest]))
    #
    #
    
# panel (c): scales, limits, ticks, label, legend
ax = axes[0,2]
error_limits = (2e-6,2e0)
time_limits = (1e-3,1e0)
legend_options = dict(loc='lower right',
                      labelspacing=labelspacing,
                      borderaxespad=borderaxespad,
                      handletextpad=handletextpad,
                      ncols=1,
                      framealpha=.95,
                      bbox_to_anchor=(1.,-.4))

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylim(*error_limits)
ax.set_xlim(*time_limits)
ax.set_xticks([]) # no x ticks on the upper panel
ax.set_ylabel(r'$E \cdot L$')
ax.set_zorder(1) # bring to front, see https://stackoverflow.com/a/33150705
ax.legend(**legend_options)
# one y tick per decade
ax.set_yticks([
            1e-5,1e-4,1e-3,
            1e-2,1e-1,1e0,
            ])


###########################################
# Running exponent of instantaneous error #
###########################################
# panel (f)
ax = axes[1,2]

# Running exponent: centered finite-difference slope of the error in a
# log-log representation, therefore defined on dt[1:-1].
log_dt_increment = np.log(dt[2:]) - np.log(dt[:-2])
for k,result in enumerate(errors_perturbative):
    #
    log_error = np.log(result['diff_integrals'])
    running_exponent = (log_error[2:] - log_error[:-2]) / log_dt_increment
    #
    ax.plot(dt[1:-1],
            running_exponent,
            color=ps.perturbative['color'][k],
            lw=ps.perturbative['lw'][k],
            label=perturbative_labels[k],
            dashes=ps.perturbative['dashes'][k],
            )

# scale, limits, ticks, labels
ax.set_xscale('log')
ax.set_ylim(0,4.95)
ax.set_xlim(1e-3,1e0)
ax.set_yticks([0,0.5,1,1.5,2,2.5,3,3.5,4,4.5])
ax.set_xlabel(r'$\Delta t/T$')
ax.set_ylabel(r'$\kappa$')



###################################
# Add enumeration to the subplots #
###################################
# panel labels, arranged like the grid of subplots
enumeration = [ [ '(a)','(b)','(c)'],
               [ '(d)','(e)','(f)']]
for row,row_labels in enumerate(enumeration):
    for column,panel_label in enumerate(row_labels):
        #
        ax = axes[row,column]
        #
        # The position is given in axes coordinates (transform=ax.transAxes),
        # so it does not depend on the data limits of the panel:
        # x = -0.3 is left of the panel, y = 0.95 is close to its upper edge.
        ax.text(-0.3,
                .95,
                panel_label,
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                fontsize=25)
        
##############################
# Add titles to the subplots #
##############################
# column titles: the two plotted lagtimes and the error column
lagtime_title = r'$\Delta t/T = {0:3.2f}$'
titles = [lagtime_title.format(ps.dt_plot['times'][k]) for k in (0,1)]
titles.append('Instantaneous error')

# titles are attached to the upper row only
for column,title in enumerate(titles):
    axes[0,column].set_title(title,y=1.03)

# display the figure, then write it to disk and release it
output_file = 'Fig5_exact_vs_second_order_propagators.pdf'
plt.show()
fig.savefig(output_file,bbox_inches='tight')
plt.close(fig)