In [None]:
#%reset -f
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
import h5py

In [None]:
import plot_settings as ps

In [None]:
from PySTFP import PySTFP

In [None]:
# Load the analytical results, which have been calculated using the notebook
# analytical solution.ipynb, into a plain dictionary that maps each top-level
# key of the HDF5 file to the value read from it.

analytical_solution = {}

with h5py.File('analytical_solution.h5','r') as f:
    # indexing a dataset with [()] reads its full content, so the values
    # remain usable after the file has been closed
    analytical_solution.update(
        (name, f[name][()]) for name in f.keys()
    )

In [None]:
def compare_perturbative_to_analytical_results(analytical_solution,
                    order_for_NPP=2, # this is K in Eq. (36) of the manuscript,
                        # if this is set to None, then the Gaussian propagator
                        # is used
                    plot_times=None,
                    plot_difference=True,
                ):
    """Compare an approximate short-time propagator to the analytical one.

    The approximate propagator is evaluated on the spatial grid of the
    analytical solution for every lag time, and the integrated absolute
    difference (L1 distance, trapezoidal rule) to the analytical density
    is computed as a function of the lag time.

    Parameters
    ----------
    analytical_solution : dict
        Must provide the keys
        'D_x0', 'a_x0' : sequences passed to PySTFP as `D_derivatives` and
            `a_derivatives`; element [0] is used as diffusivity and drift
            in the Gaussian propagator,
        'x0' : initial position,
        'xP' : 1D array of positions at which densities are evaluated,
        'P'  : array whose element [0] is indexed as
            [time index][position index] and holds the analytical density,
        'dt' : 1D array of lag times.
    order_for_NPP : int or None
        Order K of the perturbative propagator. If None, the Gaussian
        propagator is used instead and PySTFP is not needed.
    plot_times : iterable of float or None
        For every entry, the densities at the closest lag time from 'dt'
        are plotted. None (default) or an empty iterable plots nothing.
    plot_difference : bool
        If True, plot the integrated error and its local exponent
        (logarithmic derivative) as a function of the lag time.

    Returns
    -------
    dict
        'P_perturbative' : array of shape (len(dt), len(xP)),
        'xP'             : positions, as given in `analytical_solution`,
        'diff_integrals' : array of shape (len(dt),) with the integrated
                           absolute differences,
        'dt'             : lag times, as given in `analytical_solution`.
    """
    if plot_times is None:
        plot_times = ()
    #
    # get the information we need from the analytical solution dictionary
    D_x0 = analytical_solution['D_x0']
    a_x0 = analytical_solution['a_x0']
    x0 = analytical_solution['x0']
    xP = analytical_solution['xP']
    P_analytical = analytical_solution['P'][0]
    dt = analytical_solution['dt']
    #
    # define function for the approximate propagator
    if order_for_NPP is not None:
        # create an instance of PySTFP
        p = PySTFP()
        #
        # get lambda function with perturbative probability density
        P_lambda = p.get_probability_density_from_derivatives(
            D_derivatives=D_x0,
            a_derivatives=a_x0,
            x0=x0,
            order=order_for_NPP,
            dimensionless_time=False)
    else: # use Gaussian propagator
        def P_lambda(x,lagtime):
            # Gaussian with mean x0 + a*lagtime and variance 2*D*lagtime,
            # with D = D_x0[0] and a = a_x0[0]
            prefactor = 1./np.sqrt(4*np.pi*D_x0[0]*lagtime)
            exponent = 1./(4*D_x0[0]) * ((x - x0)/lagtime - a_x0[0])**2
            return prefactor*np.exp(-lagtime*exponent)
    #
    # evaluate approximate probability density function at all lag times
    P_perturbative = np.zeros([len(dt),len(xP)],dtype=float)
    for i,lagtime in enumerate(dt):
        P_perturbative[i] = P_lambda(xP,lagtime)
    #
    # compare approximate probability density to analytical result;
    # np.trapz is deprecated as of NumPy 2.0 in favor of np.trapezoid, so
    # use the new name whenever it is available
    trapezoid = np.trapezoid if hasattr(np,'trapezoid') else np.trapz
    diff_integrals = np.zeros(len(dt),dtype=float)
    for i in range(len(dt)):
        diff_integrals[i] = trapezoid(
            np.fabs(P_perturbative[i] - P_analytical[i]),
            xP,
        )
    #
    # plot distributions for times from plot_times
    stride = 10 # spatial stride for plot
    for t_value in plot_times:
        #
        # closest available lag time
        index = np.argmin(np.fabs(dt - t_value))
        fig,ax = plt.subplots(1,1,figsize=(6,4))
        ax.set_title(r'$t/T = $ ' + '{0:3.3f}'.format(dt[index]))
        ax.plot(xP[::stride],
                P_perturbative[index][::stride],
                label='perturbative')
        #
        ax.plot(xP[::stride],
                P_analytical[index][::stride],
                ls='--',
                label='exact')
        #
        ax.legend(loc='best')
        ax.set_xlabel(r'$x/L$')
        ax.set_ylabel(r'$P \cdot L$')
        plt.show()
        plt.close(fig)
    #
    if plot_difference:
        # plot integrated absolute difference between perturbative and
        # analytical data (left), and its local exponent (right)
        fig,axes = plt.subplots(1,2,figsize=(10,4))
        fig.subplots_adjust(wspace=0.3)
        #
        ax = axes[0]
        #
        ax.plot(dt,diff_integrals,
                color='black',
                label='Error')
        #
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$t/T$')
        ax.set_ylabel(r'Error')
        ax.set_xlim(np.min(dt),np.max(dt))
        #
        ax = axes[1]
        #
        # central difference in log-log space; entry k belongs to dt[k+1]
        local_exponent = np.log(diff_integrals[2:]) \
                            - np.log(diff_integrals[:-2])
        local_exponent /= np.log(dt[2:]) - np.log(dt[:-2])
        ax.plot(dt[1:-1],local_exponent,
                color='black',
                label='Local exponent')
        ax.set_xscale('log')
        ax.set_xlabel(r'$t/T$')
        ax.set_ylabel(r'Local exponent')
        #
        # add some power laws for comparison
        dt_ref = 1e-2
        index = np.argmin(np.fabs(dt-dt_ref))
        # local_exponent is shifted by one with respect to dt (see above);
        # clip so that reference times at the edge of the grid stay in range
        exponent_index = int(np.clip(index - 1, 0, len(local_exponent) - 1))
        closest_exponent = np.round(2*local_exponent[exponent_index])/2.
        y_ref = diff_integrals[index]
        dt_scale = np.logspace(-3,-1)
        shifts = [-0.5,0,0.5]
        #
        ax = axes[0]
        for shift in shifts:
            power = closest_exponent + shift
            #
            y = (dt_scale/dt_ref)**power * y_ref
            ax.plot(dt_scale,y,ls='--',
                    label=r'$\sim t^{{{0}}}$'.format(power))
        ax.legend(loc='best')
        #
        ax = axes[1]
        for shift in shifts:
            # empty plot only advances the color cycle, so that the
            # horizontal lines get the colors of the power laws on the left
            line, = ax.plot([], [])
            ax.axhline(closest_exponent + shift,
                       ls='--',
                       label='{0:3.1f}'.format(
                           closest_exponent + shift
                       ),
                       color=line.get_color())
        ax.set_ylim(closest_exponent + min(shifts)- 0.5,
                    closest_exponent + max(shifts)+0.5  )
        ax.set_xlim(np.min(dt),np.max(dt))
        ax.legend(loc='best',ncols=2)
        plt.show()
        plt.close(fig)
    #
    output_dictionary = {'P_perturbative':P_perturbative,
                         'xP':xP,
                         'diff_integrals':diff_integrals,
                         'dt':dt,
                         }
    #
    return output_dictionary

In [None]:
# Propagators to compare: None selects the Gaussian propagator, an integer
# is the order K of the perturbative propagator P_K.
orders_for_NPP = [None,4,8]

# results of the comparison, keyed by the entries of orders_for_NPP
comparison_results = {}

for order in orders_for_NPP:
    headline = ("Gaussian propagator:" if order is None
                else 'P_K with K = {0}:'.format(order))
    print(headline)
    #
    # pass e.g. plot_times=[0.01,0.1] to additionally plot the densities
    # at the lag times closest to the given values
    comparison_results[order] = compare_perturbative_to_analytical_results(
        analytical_solution=analytical_solution,
        order_for_NPP=order,
        plot_difference=True,
    )


In [None]:
# Grid of 2x3 subplots:
# (a), (b), (c)
# (d), (e), (f)
#
# left column, lag time ps.dt_plot['times'][0]:
# (a): probability densities
# (d): pointwise errors
# middle column, lag time ps.dt_plot['times'][1]:
# (b): probability densities
# (e): pointwise errors
# right column:
# (c): instantaneous error as a function of the lag time
# (f): local (running) exponent of the instantaneous error


# Quantities taken from the analytical solution
d0 = analytical_solution['D_x0'][0]
dt = analytical_solution['dt']
xP = analytical_solution['xP']
x0 = analytical_solution['x0']
P_analytical = analytical_solution['P'][0]

# NOTE: this module-level value is not what sets the plotted interval;
# plot_column computes its own local dL = 4*np.sqrt(2*d0*dt_plot).
dL = 1.5

# breakdown time marked in the error panels, and the error level for which
# the corresponding lag time is printed
t_breakdown = 1/(8. * d0)
print('t_breakdown =',t_breakdown)
horizontal_value = ps.horizontal['value']

# legend formatting, shared by all legends of the figure
labelspacing = borderaxespad = handletextpad = 0.3

# 2 rows x 3 columns; the tiny hspace makes upper and lower panels touch
fig,axes = plt.subplots(nrows=2,ncols=3,figsize=(16,7))
fig.subplots_adjust(hspace=0.00001,wspace=0.6)

# Helper that fills one column of the figure for a given lag time:
# probability densities in the upper panel, pointwise errors in the lower one

def plot_column(column_index,
                dt_plot):
    """Draw densities (upper panel) and pointwise errors (lower panel).

    Uses the module-level variables axes, dt, d0, x0, xP, P_analytical,
    comparison_results, ps and the legend formatting parameters.

    column_index: column of the subplot grid to draw into
    dt_plot: lag time; the closest entry of dt is the one that is plotted
    """
    # index of the available lag time closest to the requested one
    index = np.argmin(np.fabs(dt-dt_plot))
    # half-width of the plotted interval around x0
    dL = 4*np.sqrt(2*d0*dt_plot)
    exact_density = P_analytical[index]

    def perturbative_style(i):
        # line style of the i-th approximate propagator
        return {key: ps.perturbative[key][i]
                for key in ('color','label','dashes','lw')}

    def mark_x0_and_set_window(ax):
        # vertical line at the initial position, x-range centered on it
        ax.axvline(x0,
                   ls=ps.x0_vertical['ls'],
                   color=ps.x0_vertical['color'],
                   )
        ax.set_xlim(x0-dL,x0+dL)

    #
    # upper panel: analytical density first, then all approximate densities
    ax = axes[0,column_index]
    ax.plot(xP,
            exact_density,
            color=ps.analytical['color'],
            label=ps.analytical['label'],
            lw=ps.analytical['lw'],
            )
    for i,dictionary in enumerate(comparison_results.values()):
        ax.plot(xP,
                dictionary['P_perturbative'][index],
                **perturbative_style(i))
    mark_x0_and_set_window(ax)
    ax.set_xticks([])
    ax.set_ylabel(r'$P\cdot L$')
    if column_index == 0:
        # only the first column gets a legend
        ax.legend(loc='lower right',labelspacing=labelspacing,
                  borderaxespad=borderaxespad,
                  handletextpad=handletextpad,
                  bbox_to_anchor=(1.33,0.35)
                  )
    #
    # lower panel: absolute difference to the analytical density
    ax = axes[1,column_index]
    for i,dictionary in enumerate(comparison_results.values()):
        pointwise_error = np.fabs(
            dictionary['P_perturbative'][index] - exact_density)
        ax.plot(xP,pointwise_error,**perturbative_style(i))
    mark_x0_and_set_window(ax)
    ax.set_xlabel(r'$x/L$')
    ax.set_ylabel(r'$|P - P^{\mathrm{e}}| \cdot L$')
    #

# left and middle column: one column per lag time from the plot settings
for column_index,dt_plot in enumerate(ps.dt_plot['times']):
    plot_column(column_index,dt_plot)


# right column:
# upper plot is instantaneous error as a function of time
# lower plot is running exponent of the instantaneous error

# both plots get vertical lines at the breakdown time and at the lag times
# shown in the left and middle column
for ax in axes[:,2]:
    ps.add_vertical_breakdown_line(ax=ax,t_breakdown=t_breakdown)
    #
    for i,dt_plot in enumerate(ps.dt_plot['times']):
        marker_style = {key: ps.dt_plot[key][i]
                        for key in ('color','dashes','label')}
        ax.axvline(dt_plot,**marker_style)


# upper right panel: instantaneous error as a function of the lag time
ax = axes[0,2]

ps.add_horizontal_line(ax=ax)

# lag times for which the error is printed; the closest grid indices do not
# depend on the propagator, so they are looked up once
dt_get_result = [0.05,0.1,0.2]
report_indices = [np.argmin(np.fabs(dt - dt_)) for dt_ in dt_get_result]

for i,dictionary in enumerate(comparison_results.values()):
    #
    errors = dictionary['diff_integrals']
    curve_style = {key: ps.perturbative[key][i]
                   for key in ('color','label','dashes','lw')}
    ax.plot(dt,errors,**curve_style)
    #
    # print errors at the selected lag times ...
    print(curve_style['label'])
    for dt_,index in zip(dt_get_result,report_indices):
        print('t = {0:3.3f},\tE = {1:3.3f}'.format(dt_,errors[index]))
    #
    # ... and the lag time at which the error is closest to the level
    # horizontal_value
    index = np.argmin(np.fabs(errors - horizontal_value))
    print('E = {0:3.3f} at dt = {1:3.3f}'.format(horizontal_value,dt[index]))

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_ylim(1e-11,2e0)
ax.set_xlim(1e-3,1e0)
ax.set_xticks([])
ax.set_ylabel(r'$E_p \cdot L$')
ax.legend(
    loc='lower right',
    labelspacing=labelspacing,
    borderaxespad=borderaxespad,
    handletextpad=handletextpad,
    bbox_to_anchor=(1.27,0.),
)
ax.set_yticks([1e-10,1e-8,1e-6,1e-4,1e-2,1e0])


# lower right panel: running exponent, i.e. the central-difference slope of
# the error in a log-log representation, plotted at the interior lag times
ax = axes[1,2]

# the denominator is the same for all propagators
log_dt_span = np.log(dt[2:]) - np.log(dt[:-2])

for i,dictionary in enumerate(comparison_results.values()):
    #
    errors = dictionary['diff_integrals']
    local_exponent = (np.log(errors[2:]) - np.log(errors[:-2])) / log_dt_span
    #
    curve_style = {key: ps.perturbative[key][i]
                   for key in ('color','label','dashes','lw')}
    ax.plot(dt[1:-1],local_exponent,**curve_style)

ax.set_xscale('log')
ax.set_ylim(0,4.95)
ax.set_xlim(1e-3,1e0)
ax.set_yticks([0,0.5,1,1.5,2,2.5,3,3.5,4,4.5])
ax.set_xlabel(r'$\Delta t/T$')
ax.set_ylabel(r'$\kappa$')


# Panel labels, arranged like the subplot grid. Each label is placed in axes
# coordinates (transform=ax.transAxes), so its position does not depend on
# the data limits of the panel.
enumeration = [ [ '(a)','(b)','(c)'],
               [ '(d)','(e)','(f)']]
for axes_row,labels_row in zip(axes,enumeration):
    for ax,panel_label in zip(axes_row,labels_row):
        ax.text(-0.3,
                .95,
                panel_label,
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                fontsize=25)
        
# Titles of the three columns, attached to the upper row of panels: the lag
# times of the two density columns and a caption for the error column.
time_title = r'$\Delta t/T = {0:3.2f}$'
suptitles = [
    time_title.format(ps.dt_plot['times'][0]),
    time_title.format(ps.dt_plot['times'][1]),
    'Instantaneous error',
]

for ax,title in zip(axes[0],suptitles):
    ax.set_title(title,y=1.03)


# display the figure, write it to disk, and release it
plt.show()
fig.savefig('Fig2_exact_vs_approximate_propagator.pdf',bbox_inches='tight')
plt.close(fig)