# For the comparison of different methods

In [None]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import sys
sys.path.append('..')
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
from jax import config
from jax import vmap
import time

config.update("jax_enable_x64", True)
import scienceplots
plt.style.use(['science','ieee'])

Next, we configure the model with a spatial resolution of 256 grid points, so as to minimise spatial errors.
We choose $\Delta t = 10^{-3} \cdot 2^{-i}$ for $i = 0, \dots, 7$, with $T_{max} = 4$.

In [None]:
# Draw one set of normal increments on the finest time grid, then build
# coarser versions of the same path by summing neighbouring increments.
key1 = jax.random.PRNGKey(0)
i = 8
tmax = 4.0
nmax = int(4000 * 2**i)
dt = 0.001 * (1 / 2)**i
E, P = 1, 3
dW = jax.random.normal(key1, shape=(nmax, E, P))

dW_refine, W_refine, time_refine = [], [], []
for l in range(i + 1):
    # time_refine is filled coarse-to-fine (4000 * 2**l points) ...
    nmax = int(4000 * 2**l)
    time_refine.append(jnp.linspace(0, tmax, nmax))
    # ... while the increment lists are filled fine-to-coarse.
    dW_refine.append(dW)
    W_refine.append(jnp.cumsum(dW, axis=0))
    # Add every increment to its successor and keep the even-indexed sums,
    # which halves the number of steps.
    dW = (jnp.roll(dW, -1, axis=0) + dW)[::2, :, :]

# Flip the increment lists so that index 0 is the coarsest level, matching
# time_refine (which is deliberately not reversed).
dW_refine = dW_refine[::-1]
W_refine = W_refine[::-1]
print(nmax, len(time_refine[0]), len(W_refine[0]))
print(len(time_refine[-1]), len(W_refine[-1]))

final_save = []

@jax.jit
def relative_error_final(signal_final, analytic_final):
    """Return norm(signal_final - analytic_final) / norm(analytic_final).

    Both norms are ``jnp.linalg.norm`` with its default arguments. The result
    is not finite when ``analytic_final`` has zero norm.
    """
    deviation = jnp.linalg.norm(signal_final - analytic_final)
    scale = jnp.linalg.norm(analytic_final)
    return deviation / scale

# Run Dealiased_SETDRK4 at eight step sizes (dt = 0.001 * 2**-i), all driven
# by the same coarsened noise path, and keep each final state.
key = jax.random.PRNGKey(0)
for i in range(8):
    print(i)
    signal_params = ConfigDict(KDV_params_2_SALT)
    signal_params.update(
        E=1,
        method='Dealiased_SETDRK4',
        dt=0.001 * (1 / 2)**i,
        nt=int(4000 * 2**i),
        tmax=4.0,
        nx=256,
        P=3,
        S=0,
        noise_magnitude=0.01,
        stochastic_advection_basis='sin',
    )
    signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
    initial_signal = initial_condition(
        signal_model.x, signal_params.E, signal_params.initial_condition
    )
    # dW_refine[i] is a sum of 2**(8 - i) finest-grid draws; dividing by the
    # square root of that count renormalises it.
    dW = dW_refine[i] / jnp.sqrt(2**(8 - i))
    # NOTE(review): the original author was unsure whether the second
    # argument is the scan length - confirm against final_time_run.
    signal_final = signal_model.final_time_run(
        initial_signal, signal_model.params.nt, dW, key
    )
    final_save.append(signal_final)

# Relative error of every run against the last (smallest-dt) run; the last
# entry therefore compares that run with itself.
array = jnp.zeros(8)
reference = final_save[-1]
for i in range(8):
    array = array.at[i].set(relative_error_final(final_save[i], reference))
print(array)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science','ieee'])

# Step sizes of the eight runs: 0.001 * 2**-k for k = 0..7.
dt = 0.001 * np.asarray([2.0**-k for k in range(8)])
values = array
print(values, values.shape)

# Guide lines have the form dt**p / a**p * v: slope p on log-log axes,
# passing through v at dt = a.
a = 1e-5
v = 0.02
reference_orders = [
    (0.5, ':', r'1/2 Order'),
    (1, '--', '1st Order'),
    (1.5, '-.', '3/2 Order'),
    (2, '-', '2 Order'),
]

plt.figure()
plt.loglog(dt, values, marker='o', linestyle='-', color='b', label='CS-ETDRK4')
for power, line, text in reference_orders:
    plt.loglog(dt, dt**power / a**power * v, linestyle=line, color='k', label=text)

plt.xlabel(r'timestep $\Delta t$')
plt.ylabel('Relative error')
plt.title('Log-Log Relative error \n Pathwise Temporal Convergence')
plt.grid(True, which="both", ls="--")
plt.legend()
#plt.savefig('/Users/jmw/Documents/GitHub/Particle_Filter/Saving/Temporal_convergence_2.png',bbox_inches='tight',dpi=300)
plt.show()

In [None]:
# Pathwise temporal convergence of the fourth-order schemes: every method is
# run at dt = 0.001 * 2**-i (i = 0..max_number-1) and compared with a
# Dealiased_SETDRK4 reference computed on the finest grid of the same path.
methods = ['Dealiased_IFSRK4', 'Dealiased_SETDRK4', 'Dealiased_SRK4']
max_number = 8
array = jnp.zeros([len(methods), max_number])
# total_array collects all nine methods: rows 0-2 are written here, rows 3-8
# by the later SSP33 / SSP22 cells, which sweep 14 levels, and the combined
# plots index it against 14 step sizes. Allocate it wide enough for that and
# pad with NaN, which log-log plots simply skip.
total_levels = max(max_number, 14)
total_array = jnp.full([9, total_levels], jnp.nan)
key = jax.random.PRNGKey(0)
nmax = int(400 * 2**max_number)  # number of steps on the finest grid
tmax = 0.4
dt = 0.001 * (1 / 2)**max_number
# Compare with a tolerance: exact float equality is fragile for other choices
# of dt / tmax.
assert abs(dt * nmax - tmax) <= 1e-12 * tmax, "dt*nmax must equal tmax for the setup to be correct."
E = 1; P = 3
dW = jax.random.normal(key, shape=(nmax, E, P))


@jax.jit
def relative_error_final(signal_final, analytic_final):
    """Return norm(signal_final - analytic_final) / norm(analytic_final)."""
    return jnp.linalg.norm(signal_final - analytic_final) / jnp.linalg.norm(analytic_final)


# Coarsen the finest increments level by level (pairwise sums). After the
# reversal, index 0 is the coarsest level and index -1 the untouched finest
# draws. time_refine is filled coarse-to-fine and is not reversed.
dW_refine = []; W_refine = []; time_refine = []
for l in range(0, max_number + 1):
    nmax = int(400 * 2**l)
    time_refine.append(jnp.linspace(0, tmax, nmax))
    dW_refine.append(dW)
    W_refine.append(jnp.cumsum(dW, axis=0))
    dW = (dW + jnp.roll(dW, -1, axis=0))[::2, :, :]
dW_refine.reverse()
W_refine.reverse()

# Reference solution. It must use the finest level of the SAME noise path
# (dW_refine[-1], 400 * 2**max_number steps). The loop above leaves `dW`
# coarsened one level past the coarsest grid, so it must not be used here.
signal_params = ConfigDict(KDV_params_2_SALT)
signal_params.update(E=1, method='Dealiased_SETDRK4', dt=0.001 * (1 / 2)**max_number,
                     nt=int(400 * 2**max_number), tmax=0.4, nx=256, P=3, S=0,
                     noise_magnitude=0.01, stochastic_advection_basis='sin')
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
analytic_final = signal_model.final_time_run(initial_signal, signal_model.params.nt, dW_refine[-1], key)
del signal_model
print("analytic done")

for j, method in enumerate(methods):
    for i in range(max_number):
        print(j, i, method)
        signal_params = ConfigDict(KDV_params_2_SALT)
        signal_params.update(E=1, method=method, dt=0.001 * (1 / 2)**i, nt=int(400 * 2**i),
                             tmax=0.4, nx=256, P=3, S=0, noise_magnitude=0.01,
                             stochastic_advection_basis='sin')
        signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
        initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
        # dW_refine[i] is a sum of 2**(max_number - i) finest draws; rescale
        # it by the square root of that count.
        dW = dW_refine[i] / jnp.sqrt(2**(max_number - i))
        signal_final = signal_model.final_time_run(initial_signal, signal_model.params.nt, dW, key)
        array = array.at[j, i].set(relative_error_final(signal_final[-1, :], analytic_final[-1, :]))
        # Release the model before building the next (larger) one.
        del signal_model
        del signal_params
        print("done", j, i, method)

# Store this group in rows 0-2; columns beyond max_number stay NaN.
total_array = total_array.at[:len(methods), :max_number].set(array)



In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science','ieee'])

# Step sizes used by the sweep: 0.001 * 2**-i. (This was `0.001 ** (...)`,
# which raises 0.001 to a power instead of scaling it.)
dt = 0.001*np.asarray([2**-i for i in range(0,max_number)])
values = array
print(values)
print(values.shape)

# Guide lines have the form dt**p / a**p * v: slope p, value v at dt = a.
a = 1e-3
v= 0.02
plt.figure()
#plt.scatter(a,0.1, color='r', label='point')
# Mask values greater than 2 before plotting
values = np.where(values > 2, np.nan, values)
plt.loglog(dt,(dt**0.5)/a**0.5*v, linestyle=':', color='k', label=r'1/2 Order')
plt.loglog(dt,dt/a*v, linestyle='--', color='k', label='1st Order')
plt.loglog(dt,(dt**1.5)/a**1.5*v, linestyle='-.', color='k', label='3/2 Order')
plt.loglog(dt,dt**2/a**2*v, linestyle='-', color='k', label='2 Order')
marker_list = ['o', 's', '^', 'D', 'v', 'x', '+', '*', 'p', 'H', 'd', 'X']
color_list = ['b', 'g', 'r', 'c', 'm', 'y']
# One curve per method; markers and colours cycle if there are more methods
# than list entries.
for j, method in enumerate(methods):
    marker = marker_list[j % len(marker_list)]
    color = color_list[j % len(color_list)]
    plt.loglog(dt, values[j, :], marker=marker,color=color, linestyle='-', label=f'{method.replace("Dealiased_", "")}', markerfacecolor='none')

# for j,method in enumerate(methods):
#     plt.loglog(dt,values[j,:], marker='o', linestyle='-',label=f'{method}')
# Highlight 4th order convergence for SETDRK4 near first non-nan values
# setdrk_idx = methods.index('Dealiased_SETDRK4')
# setdrk_values = values[setdrk_idx]
# # Find the first non-nan index
# first_valid = np.where(~np.isnan(setdrk_values))[0][0]
# # Select a window of points (e.g., 4 points) for the fit
# window = 2
# idxs = np.arange(first_valid, min(first_valid + window, len(setdrk_values)))
# x_fit = dt[idxs]
# y_fit = setdrk_values[idxs]

# # Fit a line in log-log space (slope should be close to 4 for 4th order)
# coeffs = np.polyfit(np.log(x_fit), np.log(y_fit), 1)
# print(f"Fitted coefficients for SETDRK4: {coeffs}")
# y_fit_line = np.exp(np.polyval(coeffs, np.log(x_fit)))
# plt.loglog(x_fit, y_fit_line, 'k--', linewidth=2, label=f'O({coeffs[0]:.3g}) Order Fit (SETDRK4)')

plt.xlabel(r'timestep $\Delta t$')
plt.ylabel('Relative error')
plt.title('Log-Log Relative error \n Pathwise Temporal Convergence')
#plt.grid(True, which="both", ls="--")
plt.legend(fontsize='small', loc='best', ncol=2)
#plt.savefig('/Users/jmw/Documents/GitHub/Particle_Filter/Saving/EX5_G_Temporal_convergence_RK4_IFRK4_ETDRK4.png',bbox_inches='tight',dpi=300)
plt.show()

In [None]:
# Pathwise temporal convergence of the third-order schemes against a shifted
# copy of the initial profile, evaluated at the final step.
methods = ['Dealiased_eSSPIFSRK_P_33','Dealiased_SETDRK33','Dealiased_SSP33']
max_number = 14
array = jnp.zeros([len(methods),max_number])
key = jax.random.PRNGKey(0)
nmax = int(1e2 * 2**max_number)  # number of steps on the finest grid
tmax = 0.1
dt = 1e-3*(1/2)**max_number
# Compare with a tolerance: exact float equality is fragile for other choices
# of dt / tmax.
assert abs(dt*nmax - tmax) <= 1e-12*tmax, "dt*nmax must equal tmax for the setup to be correct."
E=1; P=3
dW = jax.random.normal(key, shape=(nmax, E, P))


@jax.jit
def relative_error_final(signal_final, analytic_final):
    """Return norm(signal_final - analytic_final) / norm(analytic_final)."""
    return jnp.linalg.norm(signal_final-analytic_final)/jnp.linalg.norm(analytic_final)


# Built once instead of once per run. Arguments 1 and 2 (E and the
# initial-condition name) are static, i.e. compile-time constants.
initial_condition_jitted = jax.jit(initial_condition, static_argnums=(1,2))


def compute_ans(n, x, dt, W_new, xmax, E, initial_condition):
    """Evaluate the initial profile on the grid shifted to step ``n``.

    The grid ``x`` is shifted by ``9 * 9 * (dt * n)`` and by ``W_new[n]``, then
    wrapped into ``[-xmax, xmax)``. ``initial_condition`` is the profile name
    passed through to the jitted ``initial_condition`` function.
    """
    # NOTE(review): W_new[n] has shape (E, P); this subtraction relies on it
    # broadcasting against x, whose shape is not visible here - confirm.
    shifted = (x - 9 * 9 * (dt * n) - W_new[n, :, :] + xmax) % (xmax * 2) - xmax
    return initial_condition_jitted(shifted, E, initial_condition)


# Coarsen the finest increments level by level (pairwise sums). After the
# reversal, index 0 is the coarsest level. time_refine is not reversed.
dW_refine = []; W_refine = []; time_refine = []
for l in range(0,max_number+1):
    nmax = int(1e2 * 2**(l))
    time_refine.append( jnp.linspace(0, tmax, nmax) )
    dW_refine.append(dW)
    W_refine.append(jnp.cumsum(dW, axis=0))
    dW = (dW + jnp.roll(dW, -1, axis=0))[::2, :, :]
dW_refine.reverse()
W_refine.reverse()

for j, method in enumerate(methods):
    for i in range(max_number):
        print(i,method)
        signal_params = ConfigDict(KDV_params_2_SALT)
        signal_params.update(E=1,method=method,dt= 1e-3*(1/2)**i, nt = int(1e2 * 2**i),tmax=0.1,nx=256,P=3,S=0,noise_magnitude=1.0, initial_condition = 'very_steep_traveling_wave')
        signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
        initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
        # dW_refine[i] is a sum of 2**(max_number - i) finest draws; rescale
        # it by the square root of that count.
        dW =(dW_refine[i]/jnp.sqrt(2**(max_number-i)))
        signal_final = signal_model.final_time_run(initial_signal, signal_model.params.nt, dW, key)

        # Scaled running sum of the increments, with a zero row prepended so
        # that W_new[n] is the value after n steps.
        W = jnp.cumsum(dW, axis=0)
        W = jnp.sqrt(signal_model.params.dt) * signal_model.params.noise_magnitude * W
        W_new = jnp.zeros([signal_model.params.nt+1, signal_model.params.E, signal_model.params.P])
        W_new = W_new.at[1:,:,:].set(W)

        E = signal_model.params.E
        nmax = signal_model.params.nt
        x = signal_model.x
        xmax = signal_model.params.xmax
        dt = signal_model.params.dt

        analytic_final = compute_ans(nmax, x, dt, W_new, xmax, E, signal_params.initial_condition)

        array = array.at[j,i].set(relative_error_final(signal_final[-1, :], analytic_final[-1, :]))
        del signal_model

# Rows 3-5 of the combined table hold this group.
# NOTE(review): total_array comes from an earlier cell and must have
# max_number columns for this write to succeed.
total_array = total_array.at[3:3+len(methods),:].set(array)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science','ieee'])

# Step sizes of the sweep: 1e-3 * 2**-i for i = 0..max_number-1.
dt = 1e-3 * np.asarray([2**-i for i in range(0, max_number)])
values = array
print(values)
print(values.shape)
print(values, values.shape)

# Guide lines have the form dt**p / a**p * v: slope p, value v at dt = a.
a = 1e-3
v = 2
reference_orders = [
    (0.5, ':', r'1/2 Order'),
    (1, '--', '1st Order'),
    (1.5, '-.', '3/2 Order'),
    (2, '-', '2 Order'),
]

plt.figure()
# Errors above 1 are replaced by NaN so they are left out of the plot.
values = np.where(values > 1, np.nan, values)
for power, line, text in reference_orders:
    plt.loglog(dt, dt**power / a**power * v, linestyle=line, color='k', label=text)

marker_list = ['o', 's', '^', 'D', 'v', 'x', '+', '*', 'p', 'H', 'd', 'X']
color_list = ['b', 'g', 'r', 'c', 'm', 'y']
# One open-marker curve per method; styles cycle through the lists above.
for j, method in enumerate(methods):
    marker = marker_list[j % len(marker_list)]
    color = color_list[j % len(color_list)]
    plt.loglog(
        dt,
        values[j, :],
        marker=marker,
        color=color,
        linestyle='-',
        label=f'{method.replace("Dealiased_", "")}',
        markerfacecolor='none',
    )

plt.xlabel(r'timestep $\Delta t$')
plt.ylabel('Relative error')
plt.title('Log-Log Relative error \n Pathwise Temporal Convergence')
plt.legend()
# The second call replaces the default legend with a compact two-column one.
plt.legend(fontsize='small', loc='best', ncol=2)

#plt.savefig('/Users/jmw/Documents/GitHub/Particle_Filter/Saving/EX5_G_Temporal_convergence_SSP33_ETDRK3_eSSPIFSRK_P_33.png',bbox_inches='tight',dpi=300)
plt.show()

In [None]:
# Pathwise temporal convergence of the second-order schemes against a shifted
# copy of the initial profile, evaluated at the final step.
methods = ['Dealiased_eSSPIFSRK_P_22','Dealiased_SETDRK22','Dealiased_SSP22']
max_number = 14
array = jnp.zeros([len(methods),max_number])
key = jax.random.PRNGKey(0)
nmax = int(1e2 * 2**max_number)  # number of steps on the finest grid
tmax = 0.1
dt = 1e-3*(1/2)**max_number
# Compare with a tolerance: exact float equality is fragile for other choices
# of dt / tmax.
assert abs(dt*nmax - tmax) <= 1e-12*tmax, "dt*nmax must equal tmax for the setup to be correct."
E=1; P=3
dW = jax.random.normal(key, shape=(nmax, E, P))


@jax.jit
def relative_error_final(signal_final, analytic_final):
    """Return norm(signal_final - analytic_final) / norm(analytic_final)."""
    return jnp.linalg.norm(signal_final-analytic_final)/jnp.linalg.norm(analytic_final)


# Built once instead of once per run. Arguments 1 and 2 (E and the
# initial-condition name) are static, i.e. compile-time constants.
initial_condition_jitted = jax.jit(initial_condition, static_argnums=(1,2))


def compute_ans(n, x, dt, W_new, xmax, E, initial_condition):
    """Evaluate the initial profile on the grid shifted to step ``n``.

    The grid ``x`` is shifted by ``9 * 9 * (dt * n)`` and by ``W_new[n]``, then
    wrapped into ``[-xmax, xmax)``. ``initial_condition`` is the profile name
    passed through to the jitted ``initial_condition`` function.
    """
    # NOTE(review): W_new[n] has shape (E, P); this subtraction relies on it
    # broadcasting against x, whose shape is not visible here - confirm.
    shifted = (x - 9 * 9 * (dt * n) - W_new[n, :, :] + xmax) % (xmax * 2) - xmax
    return initial_condition_jitted(shifted, E, initial_condition)


# Coarsen the finest increments level by level (pairwise sums). After the
# reversal, index 0 is the coarsest level. time_refine is not reversed.
dW_refine = []; W_refine = []; time_refine = []
for l in range(0,max_number+1):
    nmax = int(1e2 * 2**(l))
    time_refine.append( jnp.linspace(0, tmax, nmax) )
    dW_refine.append(dW)
    W_refine.append(jnp.cumsum(dW, axis=0))
    dW = (dW + jnp.roll(dW, -1, axis=0))[::2, :, :]
dW_refine.reverse()
W_refine.reverse()

for j, method in enumerate(methods):
    for i in range(max_number):
        print(i,method)
        signal_params = ConfigDict(KDV_params_2_SALT)
        signal_params.update(E=1,method=method,dt= 1e-3*(1/2)**i, nt = int(1e2 * 2**i),tmax=0.1,nx=256,P=3,S=0,noise_magnitude=1.0, initial_condition = 'very_steep_traveling_wave')
        signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
        initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
        # dW_refine[i] is a sum of 2**(max_number - i) finest draws; rescale
        # it by the square root of that count.
        dW =(dW_refine[i]/jnp.sqrt(2**(max_number-i)))
        signal_final = signal_model.final_time_run(initial_signal, signal_model.params.nt, dW, key)

        # Scaled running sum of the increments, with a zero row prepended so
        # that W_new[n] is the value after n steps.
        W = jnp.cumsum(dW, axis=0)
        W = jnp.sqrt(signal_model.params.dt) * signal_model.params.noise_magnitude * W
        W_new = jnp.zeros([signal_model.params.nt+1, signal_model.params.E, signal_model.params.P])
        W_new = W_new.at[1:,:,:].set(W)

        E = signal_model.params.E
        nmax = signal_model.params.nt
        x = signal_model.x
        xmax = signal_model.params.xmax
        dt = signal_model.params.dt

        analytic_final = compute_ans(nmax, x, dt, W_new, xmax, E, signal_params.initial_condition)

        array = array.at[j,i].set(relative_error_final(signal_final[-1, :], analytic_final[-1, :]))
        del signal_model

# Rows 6-8 of the combined table hold this group.
# NOTE(review): total_array comes from an earlier cell and must have
# max_number columns for this write to succeed.
total_array = total_array.at[6:6+len(methods),:].set(array)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science','ieee'])

# Step sizes of the sweep: 1e-3 * 2**-i for i = 0..max_number-1.
dt = 1e-3 * np.asarray([2**-i for i in range(0, max_number)])
values = array
print(values)
print(values.shape)
print(values, values.shape)

# Guide lines have the form dt**p / a**p * v: slope p, value v at dt = a.
a = 1e-3
v = 2
reference_orders = [
    (0.5, ':', r'1/2 Order'),
    (1, '--', '1st Order'),
    (1.5, '-.', '3/2 Order'),
    (2, '-', '2 Order'),
]

plt.figure()
# Errors above 1 are replaced by NaN so they are left out of the plot.
values = np.where(values > 1, np.nan, values)
for power, line, text in reference_orders:
    plt.loglog(dt, dt**power / a**power * v, linestyle=line, color='k', label=text)

marker_list = ['o', 's', '^', 'D', 'v', 'x', '+', '*', 'p', 'H', 'd', 'X']
color_list = ['b', 'g', 'r', 'c', 'm', 'y']
# One open-marker curve per method; styles cycle through the lists above.
for j, method in enumerate(methods):
    marker = marker_list[j % len(marker_list)]
    color = color_list[j % len(color_list)]
    plt.loglog(
        dt,
        values[j, :],
        marker=marker,
        color=color,
        linestyle='-',
        label=f'{method.replace("Dealiased_", "")}',
        markerfacecolor='none',
    )

plt.xlabel(r'timestep $\Delta t$')
plt.ylabel('Relative error')
plt.title('Log-Log Relative error \n Pathwise Temporal Convergence')
plt.legend()
# The second call replaces the default legend with a compact two-column one.
plt.legend(fontsize='small', loc='best', ncol=2)

#plt.savefig('/Users/jmw/Documents/GitHub/Particle_Filter/Saving/EX5_G_Temporal_convergence_SSP22_ETDRK22_eSSPIFSRK_P_22.png',bbox_inches='tight',dpi=300)
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science','ieee'])
# Combined figure: one curve per row of total_array, in this order.
methods = ['Dealiased_IFSRK4','Dealiased_SETDRK4','Dealiased_SRK4',
           'Dealiased_eSSPIFSRK_P_33','Dealiased_SETDRK33','Dealiased_SSP33',
           'Dealiased_eSSPIFSRK_P_22','Dealiased_SETDRK22','Dealiased_SSP22']
max_number = 14
dt = 1e-3*np.asarray([2**-i for i in range(0,max_number)])
# Inspect and mask the table that is actually plotted. Previously this was
# done on `array`, which only holds the last group, and the masked copy was
# never drawn.
values = total_array
print(values)
print(values.shape)
# Guide lines have the form dt**p / a**p * v: slope p, value v at dt = a.
a = 1e-3#
v= 2#
plt.figure()
#plt.scatter(a,0.1, color='r', label='point')
# Mask values greater than 1 before plotting
values = np.where(values > 1, np.nan, values)
plt.loglog(dt,(dt**0.5)/a**0.5*v, linestyle=':', color='k', label=r'1/2 Order')
plt.loglog(dt,dt/a*v, linestyle='--', color='k', label='1st Order')
plt.loglog(dt,(dt**1.5)/a**1.5*v, linestyle='-.', color='k', label='3/2 Order')
plt.loglog(dt,dt**2/a**2*v, linestyle='-', color='k', label='2 Order')

marker_list = ['o', 's', '^', 'D', 'v', 'x', '+', '*', 'p']
# Use a colorblind-friendly palette for better accessibility
color_list = [
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#009E73",  # Green
    "#F0E442",  # Yellow
    "#CC79A7",  # Purple
    "#56B4E9",  # Sky Blue
    "#E69F00",  # Orange
    "#000000",  # Black
    "#999999"   # Grey
]
#color_list = ["#0d23e8", "#32e80d", "#e81c0d", "#0da6e8", "#c4e80d", "#e85a0d",  "#24a3a5", "#98a524", "#a0792c"]
linestyle_list = ['-', '--', '-.', ':', (0, (3, 1, 1, 1)), (0, (5, 5)), (0, (1, 10)), (0, (3, 5, 1, 5)), (0, (5, 1))]
# One curve per method; each gets its own marker, colour and line style.
for j, method in enumerate(methods):
    marker = marker_list[j % len(marker_list)]
    color = color_list[j % len(color_list)]
    linestyle = linestyle_list[j % len(linestyle_list)]
    plt.loglog(dt, values[j, :], marker=marker, color=color, linestyle=linestyle, label=f'{method.replace("Dealiased_", "")}',  markersize=4)

plt.xlabel(r'timestep $\Delta t$')
plt.ylabel('Relative error')
plt.title('Log-Log Relative error \n Pathwise Temporal Convergence')
#plt.grid(True, which="both", ls="--")
plt.legend(fontsize='small', loc='best', ncol=2)

#plt.savefig('/Users/jmw/Documents/GitHub/Particle_Filter/Saving/EX5_G_Temporal_convergence_Allatty.png',bbox_inches='tight',dpi=300)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
plt.style.use(['science', 'ieee'])

# --- Inputs ---
# Row order of total_array.
methods = [
    'Dealiased_IFSRK4', 'Dealiased_SETDRK4', 'Dealiased_SRK4',
    'Dealiased_eSSPIFSRK_P_33', 'Dealiased_SETDRK33', 'Dealiased_SSP33',
    'Dealiased_eSSPIFSRK_P_22', 'Dealiased_SETDRK22', 'Dealiased_SSP22'
]

max_number = 14
dt = 1e-3 * np.asarray([2**-i for i in range(max_number)])

# Errors to plot; total_array is filled by the sweep cells above.
values = total_array

# --- Plot Setup ---
fig, ax = plt.subplots(figsize=(6.5, 5))  # Slightly wider

# Reference lines for convergence order: slope p, value v at dt = a.
a = 1e-3
v = 2
order_styles = {
    r'1/2 Order': (dt**0.5)/(a**0.5)*v,
    r'1st Order': dt/a*v,
    r'3/2 Order': (dt**1.5)/(a**1.5)*v,
    r'2nd Order': (dt**2)/(a**2)*v
}
order_linestyles = [':', '--', '-.', '-']
order_colors = ['k']*4
# The loop variable is named `curve` so it does not rebind `values`, which
# holds the data plotted below.
for (label, curve), ls, color in zip(order_styles.items(), order_linestyles, order_colors):
    ax.loglog(dt, curve, linestyle=ls, color=color, linewidth=1, alpha=0.6, label=label)

# Markers, colors, and linestyles (one line style per group of three methods)
marker_list = ['o', 's', '^', 'D', 'v', 'x', '+', '*', 'p']
color_list = [
    "#0072B2", "#D55E00", "#009E73", "#F0E442",
    "#CC79A7", "#56B4E9", "#E69F00", "#000000", "#999999"
]
linestyle_list = ['-', '-', '-', '-.','-.','-.', '--', '--', '--']

# Plot each method
for j, method in enumerate(methods):
    label = method.replace("Dealiased_", "")
    marker = marker_list[j % len(marker_list)]
    color = color_list[j % len(color_list)]
    linestyle = linestyle_list[j % len(linestyle_list)]
    ax.loglog(dt, values[j, :], marker=marker, linestyle=linestyle,
              color=color, label=label, markersize=5, linewidth=1.5)

# --- Labels, Legend, Title ---
ax.set_xlabel(r'Timestep $\Delta t$')
ax.set_ylabel('Relative Error')
ax.set_title('Pathwise Temporal Convergence\n(Log–Log Relative Error)', fontsize=12)

# Combine and organize legend
ax.legend(fontsize='medium', loc='best', ncol=3, frameon=False)

# Tight layout and save
plt.tight_layout()
#plt.savefig('EX5_G_Temporal_convergence_Allatty.pdf', bbox_inches='tight')  # Vector version
#plt.savefig('EX5_G_Temporal_convergence_Allatty.png', bbox_inches='tight', dpi=300)

plt.show()