In [None]:
#%reset -f
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
import h5py

In [None]:
import plot_settings as ps

In [None]:
from PySTFP import PySTFP

In [None]:
# Load the analytical results, which were computed with the notebook
# "analytical solution.ipynb".
# Every dataset in the HDF5 file is read completely into memory (``[()]``),
# so the values remain usable after the file is closed by the ``with`` block.
analytical_solution = {}

with h5py.File('analytical_solution.h5', 'r') as h5_file:
    analytical_solution.update(
        (name, h5_file[name][()]) for name in h5_file.keys()
    )

In [None]:
# Unpack the quantities used by the cells below:
#   x0           -- initial position of the propagator
#   xP           -- spatial grid on which the exact propagator is sampled
#   P_analytical -- first entry of the stored 'P' array (the exact propagator)
#   dt           -- time steps at which moments are evaluated
#   t_breakdown  -- time marked as a vertical line in every panel
x0, xP, P_analytical, dt, t_breakdown = (
    analytical_solution['x0'],
    analytical_solution['xP'],
    analytical_solution['P'][0],
    analytical_solution['dt'],
    analytical_solution['t_breakdown'],
)

In [None]:
# Calculate the exact moments <(x - x0)^n> as a function of time by
# numerically integrating (x - x0)^n against the exact propagator on the
# grid xP (trapezoidal rule along the spatial axis, axis=1).
#
# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
# new name when it exists so this cell runs on old and new NumPy alike.
_trapezoid = getattr(np, 'trapezoid', None) or getattr(np, 'trapz')

moments_to_consider = [1, 2]

# moments_analytical[n] holds one value per row of P_analytical, i.e. one
# value per entry of dt (the plots divide it element-wise by dt).
moments_analytical = {
    n: _trapezoid((xP[np.newaxis, :] - x0)**n * P_analytical,
                  xP,
                  axis=1)
    for n in moments_to_consider
}

In [None]:
def get_perturbative_moment(n,
                        perturbation_order=2,
                        solution=None):
    """Return the n-th moment of the displacement x - x0 at every time step.

    Parameters
    ----------
    n : int
        Order of the moment <(x - x0)^n>.
    perturbation_order : int or None, optional
        If an int, the perturbative moment of that order is obtained from
        ``PySTFP.get_moment_from_derivatives`` (dimensional time).
        If None, the moment of the short-time Gaussian propagator is
        returned, i.e. a Gaussian with mean ``a(x0) t`` and variance
        ``2 D(x0) t``; only ``n`` in (1, 2) is supported in that case.
    solution : mapping, optional
        Provides ``'D_x0'`` and ``'a_x0'`` (the derivative arrays of D and a
        at x0 that are passed to PySTFP; index 0 is used as the function
        value itself) and ``'dt'`` (the times at which the moment is
        evaluated). Defaults to the notebook-level ``analytical_solution``.

    Returns
    -------
    The moment evaluated at ``solution['dt']``.

    Raises
    ------
    RuntimeError
        If ``perturbation_order`` is None and ``n`` is neither 1 nor 2.
    """
    if solution is None:
        # read-only access to the notebook global; no `global` statement needed
        solution = analytical_solution
    D_x0 = solution['D_x0']
    a_x0 = solution['a_x0']
    dt = solution['dt']
    #
    if perturbation_order is not None:
        # create an instance of PySTFP
        p = PySTFP()
        #
        # get lambda function with perturbative moment
        moment_lambda = p.get_moment_from_derivatives(
                D_derivatives=D_x0,
                a_derivatives=a_x0,
                n=n,
                order=perturbation_order,
                dimensionless_time=False)
    else:
        # Prediction of the Gaussian propagator with mean a t and variance
        # 2 D t (a = a_x0[0], D = D_x0[0]):
        #   <x - x0>     = a t
        #   <(x - x0)^2> = variance + mean^2 = 2 D t + (a t)^2
        if n == 1:
            moment_lambda = lambda t: a_x0[0] * t
        elif n == 2:
            # the drift VALUE a_x0[0] enters the squared mean, not its
            # derivative a_x0[1]
            moment_lambda = lambda t: 2*D_x0[0] * t \
                                        + a_x0[0]**2 * t**2
        else:
            raise RuntimeError("Only first two moments implemented for "\
                            + "Gaussian propagator")
    #
    return moment_lambda(dt)

In [None]:
# Perturbative orders to compare; None selects the Gaussian short-time
# propagator instead of a PySTFP expansion.
orders_for_perturbative_moments = [None, 2, 8]

# Nested table of moments, indexed as moments_perturbative[n][order].
moments_perturbative = {
    n: {
        order: get_perturbative_moment(n=n, perturbation_order=order)
        for order in orders_for_perturbative_moments
    }
    for n in moments_to_consider
}

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(16, 7))
fig.subplots_adjust(hspace=0.00001, wspace=0.6)

# Panel layout, axes[row, column]:
#   [0,0] <x>/dt                       [1,0] <x^2>/dt
#   [0,1] error of <x>                 [1,1] local exponent of that error
#   [0,2] error of <x^2>               [1,2] local exponent of that error

# If True, the relative error |m_pert - m_exact| / |m_exact| is plotted;
# if False, the absolute error divided by dt, |m_pert - m_exact| / dt.
relative = False

# Axis limits; dictionaries are keyed by the moment order n.
xlims = [1e-3, 1e0]
ylims = {1: [-0.29, 0],
         2: [2, 2.6]}

ylims_0_relative = {1: [1e-13, 1e3],
                    2: [1e-13, 1e3]}
ylims_0_absolute = {1: [1e-11, 1e0],
                    2: [1e-11, 1e0]}

# y-limits of the upper-row error panels, matching the chosen error measure
ylims_0 = ylims_0_relative if relative else ylims_0_absolute



# Mark the breakdown time with a vertical line in every panel.
# axes.flat visits the 2x3 grid row by row, in the same order as a nested
# loop over rows and then columns.
for panel in axes.flat:
    ps.add_vertical_breakdown_line(ax=panel, t_breakdown=t_breakdown)

#######################################################
# left column: finite-time Kramers-Moyal coefficients #
#######################################################
# Each panel shows moment/dt for the exact propagator and for every
# perturbative approximation in orders_for_perturbative_moments.
for row, n_moment in enumerate(moments_to_consider):
    panel = axes[row, 0]

    # exact result
    panel.plot(dt,
               moments_analytical[n_moment] / dt,
               label=ps.analytical['label'],
               lw=ps.analytical['lw'],
               color=ps.analytical['color'])

    # perturbative approximations, one line style per order
    for k, order in enumerate(orders_for_perturbative_moments):
        panel.plot(dt,
                   moments_perturbative[n_moment][order] / dt,
                   color=ps.perturbative['color'][k],
                   dashes=ps.perturbative['dashes'][k],
                   lw=ps.perturbative['lw'][k],
                   label=ps.perturbative['label'][k])

    panel.set_xscale('log')
    panel.set_ylim(*ylims[n_moment])
    panel.set_xlim(*xlims)

    if row == 0:
        # the upper panel sits directly above the lower one (hspace ~ 0),
        # so its x ticks are hidden
        panel.set_xticks([])
        panel.set_ylabel(r'$\alpha_1/L$')
    else:
        if row == 1:
            panel.legend(loc='best', framealpha=ps.framealpha)
        panel.set_xlabel(r'$\Delta t/T$')
        panel.set_ylabel(r'$\alpha_2/L^2$')



def _moment_error(y, m_exact, dt, relative):
    """Return the error of the perturbative moment ``y``.

    The absolute deviation ``|y - m_exact|`` is divided by ``|m_exact|`` when
    ``relative`` is True (relative error) and by ``dt`` otherwise, which is
    the deviation of the finite-time Kramers-Moyal coefficient moment/dt.
    """
    deviation = np.fabs(y - m_exact)
    if relative:
        return deviation / np.fabs(m_exact)
    return deviation / dt


# log-spacing of the centred finite difference used for the running exponent;
# it does not depend on the moment or the order, so compute it once
log_dt_step = np.log(dt[2:]) - np.log(dt[:-2])

for i, moment in enumerate(moments_to_consider):
    # i = 0: first moment  (column 1)
    # i = 1: second moment (column 2)
    m_ana = moments_analytical[moment]
    # Both the upper and the lower panel need the error of every
    # perturbative approximation: compute it once per order.
    errors = [_moment_error(moments_perturbative[moment][order],
                            m_ana, dt, relative)
              for order in orders_for_perturbative_moments]

    ###################################
    # Instantaneous error (upper row) #
    ###################################
    ax = axes[0, i+1]
    #
    # plot horizontal reference line
    if i == 0:
        horizontal_label = ps.horizontal['label']
    else:
        horizontal_label = r'$E/L^2 = {0:3.2f}$'
    ax.axhline(ps.horizontal['value'],
               color=ps.horizontal['color'],
               label=horizontal_label.format(ps.horizontal['value']),
               lw=ps.horizontal['lw'])
    #
    # plot the error of each perturbative approximation
    for j, diff in enumerate(errors):
        ax.plot(dt,
                diff,
                color=ps.perturbative['color'][j],
                dashes=ps.perturbative['dashes'][j],
                lw=ps.perturbative['lw'][j],
                label=ps.perturbative['label'][j],
                )
    #
    # set scales, limits, labels, legend
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims_0[moment])
    ax.set_yticks([
                1e-10, 1e-8, 1e-6,
                1e-4, 1e-2, 1e0, 1e2,
                ])
    # keep this panel (and its legend) above the panel below it
    ax.set_zorder(1)
    ax.legend(loc='lower right',
              framealpha=ps.framealpha,
              bbox_to_anchor=(1.1, -.25))
    ax.set_xticks([])
    if i == 0:
        ax.set_ylabel(r'$E_1/L$')
    else:
        ax.set_ylabel(r'$E_2/L^2$')

    ################################
    # running exponent (lower row) #
    ################################
    ax = axes[1, i+1]
    #
    for j, diff in enumerate(errors):
        # centred finite difference of log(error) w.r.t. log(dt), i.e. the
        # local exponent kappa in  error ~ dt**kappa, at the interior dt[1:-1]
        local_exponent = (np.log(diff[2:]) - np.log(diff[:-2])) / log_dt_step
        #
        ax.plot(dt[1:-1],
                local_exponent,
                color=ps.perturbative['color'][j],
                dashes=ps.perturbative['dashes'][j],
                lw=ps.perturbative['lw'][j],
                label=ps.perturbative['label'][j],
                )
    #
    ax.set_xscale('log')
    ax.set_xlim(*xlims)
    ax.set_ylim(0, 6.5)
    ax.set_yticks([0, 1, 2, 3, 4, 5, 6])
    ax.set_xlabel(r'$\Delta t/T$')
    ax.set_ylabel(r'$\kappa$')


###################################
# Add enumeration to the subplots #
###################################
# Panel labels (a)-(f) are placed left of each panel's top edge, in axes
# coordinates (transform=ax.transAxes), so no data limits are needed.
# (The previous get_xlim/get_ylim calls were unused and overwrote the
# configuration variables xlims/ylims with tuples.)
enumeration = [['(a)', '(b)', '(c)'],
               ['(d)', '(e)', '(f)']]
for row, row_labels in enumerate(enumeration):
    for col, panel_label in enumerate(row_labels):
        ax = axes[row, col]
        ax.text(-0.3,
                1.,
                panel_label,
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                fontsize=25)

##############################
# Add titles to the subplots #
##############################
# The error-panel titles follow the `relative` switch instead of requiring
# commented-out alternatives to be swapped by hand.
error_kind = 'Relative error' if relative else 'Instantaneous error'
titles = ['Kramers-Moyal coefficients',
          error_kind + r', $\alpha_1$',
          error_kind + r', $\alpha_2$',
          ]

for i, e in enumerate(titles):
    ax = axes[0, i]
    ax.set_title(e, y=1.03)

    

# Save before showing: with blocking/GUI backends plt.show() may return only
# after the figure window has been closed, so a savefig issued afterwards is
# not guaranteed to write the rendered figure.
fig.savefig('Fig3_exact_vs_approximate_Kramers-Moyal_coefficients.pdf',
            bbox_inches='tight')
plt.show()
plt.close(fig)