In [1]:
from math import sqrt
import numpy as np
from scipy.optimize import root
import time

def solve_ode(step_func, f, y0, t0, t_end, n):
    """Integrate the autonomous ODE y' = f(y) with ``n`` equal fixed steps.

    Parameters
    ----------
    step_func : callable
        One-step map ``step_func(f, y, h)`` returning the state after a step
        of size ``h`` (any of the ``*_step`` functions in this notebook).
    f : callable
        Right-hand side ``f(y)``. Time is never passed to ``f``, so the
        system is treated as autonomous.
    y0 : array_like
        State at time ``t0``.
    t0, t_end : float
        Start and end times. ``t_end < t0`` integrates backward with a
        negative step.
    n : int
        Number of steps; must be at least 1.

    Returns
    -------
    t_vals : ndarray, shape (n + 1,)
        Grid times; ``t_vals[0] == t0`` and ``t_vals[-1] == t_end`` exactly.
    y_vals : ndarray, shape (n + 1, *np.shape(y0))
        States at the grid times; ``y_vals[0]`` is a copy of ``y0``.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive number of steps, got {n}")

    h = (t_end - t0) / n
    # Build the grid as t0 + i*h with exact endpoints instead of repeatedly
    # adding h: repeated addition accumulates rounding error, so the final
    # time would drift away from t_end (e.g. 100 additions of 0.1 != 10.0).
    t_vals = np.linspace(t0, t_end, n + 1)

    y = np.array(y0)
    y_vals = [y.copy()]
    for _ in range(n):
        y = step_func(f, y, h)
        # Store a copy so later steps can never alias a stored state.
        y_vals.append(np.array(y))

    return t_vals, np.array(y_vals)

In [2]:
def implicit_midpoint_step(f, y_n, h, tol=1e-10, max_iter=100):
    """Advance one step of the implicit midpoint rule.

    Solves ``y_next = y_n + h * f((y_n + y_next) / 2)`` for ``y_next`` with
    SciPy's ``root`` (``method='hybr'``), using ``y_n`` as the initial guess.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(y)``.
    y_n : ndarray
        Current state.
    h : float
        Step size.
    tol : float
        Tolerance forwarded to ``scipy.optimize.root``.
    max_iter : int
        Forwarded as the ``'maxfev'`` option, i.e. a cap on residual
        evaluations rather than on solver iterations.

    Returns
    -------
    ndarray
        The state after one step.

    Raises
    ------
    RuntimeError
        If the root finder reports failure.
    """
    def residual(candidate):
        # Zero exactly when `candidate` satisfies the midpoint-rule equation.
        return candidate - y_n - h * f((y_n + candidate) / 2)

    result = root(residual, y_n, method='hybr', tol=tol, options={'maxfev': max_iter})

    if result.success:
        return result.x
    raise RuntimeError(f"Root solver failed: {result.message}")

In [3]:
# Tableau coefficients. `half` is also read by EES25_2_step in the next cell.
half = 1/2
quarter = 1/4

def EES25_1_step(f, y_n, h, tol=1e-10, max_iter=100):
    """One step of the explicit 3-stage scheme labelled EES(1,5;1/4).

    Butcher tableau: nodes c = (0, 1/2, 1), a21 = 1/2, a31 = 0, a32 = 1,
    weights b = (1/4, 1/2, 1/4). These satisfy the order-2 conditions
    (sum b = 1, b.c = 1/2).

    ``tol`` and ``max_iter`` are accepted but unused. Presumably they exist
    so every step function shares implicit_midpoint_step's signature.
    """
    stage1 = f(y_n)
    stage2 = f(y_n + h * half * stage1)
    stage3 = f(y_n + h * stage2)
    # Symmetric weights: the endpoint stages share 1/4, the midpoint gets 1/2.
    increment = quarter * (stage1 + stage3) + half * stage2
    return y_n + h * increment

In [4]:
# Tableau coefficients for the scheme labelled EES(2,5;1/10).
third, tenth, twofifths = 1/3, 1/10, 2/5
EES25_a31 = -5/48
EES25_a32 = 15/16

def EES25_2_step(f, y_n, h, tol=1e-10, max_iter=100):
    """One step of the explicit 3-stage scheme labelled EES(2,5;1/10).

    Butcher tableau: a21 = 1/3, a31 = -5/48, a32 = 15/16,
    weights b = (1/10, 1/2, 2/5). The middle weight uses the global
    ``half`` (= 1/2), which is defined in the EES25_1 cell.

    ``tol`` and ``max_iter`` are accepted but unused. Presumably they exist
    so every step function shares implicit_midpoint_step's signature.
    """
    g1 = f(y_n)
    g2 = f(y_n + h * third * g1)
    g3 = f(y_n + h * (EES25_a31 * g1 + EES25_a32 * g2))
    combo = tenth * g1 + half * g2 + twofifths * g3
    return y_n + h * combo

In [5]:
# Tableau coefficients for the scheme labelled EES(2,7;(2-sqrt(2))/4).
# Note a21 + a32 = 1, so the last stage is evaluated at the end of the step.
EES27_1_a21 = (2 - sqrt(2)) / 2
EES27_1_a32 = sqrt(2) / 2
EES27_1_b1 = (2 - sqrt(2)) / 4
EES27_1_b2 = sqrt(2) / 4

def EES27_1_step(f, y_n, h, tol=1e-10, max_iter=100):
    """One step of the explicit 4-stage scheme labelled EES(2,7;(2-sqrt(2))/4).

    Nonzero tableau entries: a21 = a41 = (2-sqrt(2))/2, a32 = a43 = sqrt(2)/2.
    Weights are symmetric, b = (b1, b2, b2, b1), and sum to 1.

    ``tol`` and ``max_iter`` are accepted but unused. Presumably they exist
    so every step function shares implicit_midpoint_step's signature.
    """
    d1 = f(y_n)
    d2 = f(y_n + h * EES27_1_a21 * d1)
    d3 = f(y_n + h * EES27_1_a32 * d2)
    d4 = f(y_n + h * (EES27_1_a21 * d1 + EES27_1_a32 * d3))
    outer = EES27_1_b1 * (d1 + d4)
    inner = EES27_1_b2 * (d2 + d3)
    return y_n + h * (outer + inner)

In [6]:
# Tableau coefficients for the scheme labelled EES(2,7;(5-3*sqrt(2))/14).
# Row sums give the nodes c2 = (2-sqrt(2))/3, c3 = (2+sqrt(2))/6, c4 = (4+sqrt(2))/6.
EES27_2_a21 = (2 - sqrt(2)) / 3
EES27_2_a31 = (-4+sqrt(2)) / 24
EES27_2_a32 = (4 + sqrt(2)) / 8
EES27_2_a41 = (-176 + 145 * sqrt(2)) / 168
EES27_2_a42 = (8 - 5*sqrt(2)) * (3/56)
EES27_2_a43 = (3 - sqrt(2))*(3/7)

# Weights; they sum to 1 and satisfy b.c = 1/2.
EES27_2_b1 = (5 - 3 * sqrt(2)) / 14
EES27_2_b2 = (3 + sqrt(2))/14
EES27_2_b3 = (-1+2*sqrt(2)) * (3/14)
EES27_2_b4 = (9 - 4*sqrt(2)) / 14

def EES27_2_step(f, y_n, h, tol=1e-10, max_iter=100):
    """One step of the explicit 4-stage scheme labelled EES(2,7;(5-3*sqrt(2))/14).

    Uses the full lower-triangular tableau ``EES27_2_a**`` and the weights
    ``EES27_2_b1..b4`` defined in this cell.

    ``tol`` and ``max_iter`` are accepted but unused. Presumably they exist
    so every step function shares implicit_midpoint_step's signature.
    """
    s1 = f(y_n)
    s2 = f(y_n + h * EES27_2_a21 * s1)
    mix3 = EES27_2_a31 * s1 + EES27_2_a32 * s2
    s3 = f(y_n + h * mix3)
    mix4 = EES27_2_a41 * s1 + EES27_2_a42 * s2 + EES27_2_a43 * s3
    s4 = f(y_n + h * mix4)
    weighted = EES27_2_b1 * s1 + EES27_2_b2 * s2 + EES27_2_b3 * s3 + EES27_2_b4 * s4
    return y_n + h * weighted

In [7]:
def run_inverse_square_attraction(name, step_func, t0=0.0, t_end=10.0, n=100):
    """Integrate a circular inverse-square orbit forward, then back, and report errors.

    The state is ``[x, y, vx, vy]`` and the force is ``-r / |r|^3`` (unit
    gravitational parameter). Starting from ``[1, 0, 0, 1]``, the exact
    solution is the unit circle with angular speed 1.

    Parameters
    ----------
    name : str
        Solver label. It is not used in the computation; it is kept so the
        benchmark loop can pass it.
    step_func : callable
        One-step map ``step_func(f, y, h)`` as accepted by ``solve_ode``.
    t0, t_end : float, optional
        Integration interval (defaults 0 and 10).
    n : int, optional
        Number of steps in each direction (default 100).

    Returns
    -------
    elapsed : float
        Seconds (``time.perf_counter``) spent on the forward and backward passes.
    end_err : ndarray, shape (4,)
        Numerical minus exact state at ``t_end``.
    start_err : ndarray, shape (4,)
        State after integrating back to ``t0`` minus the initial state. This
        measures how well the method retraces its own trajectory.
    """
    def inverse_square_attraction_func(y):
        # |r|^3, computed once per call and shared by both force components.
        r_cubed = (y[0]**2 + y[1]**2)**(3/2)
        return np.array([y[2], y[3], -y[0] / r_cubed, -y[1] / r_cubed])

    y0 = np.array([1.0, 0.0, 0.0, 1.0])
    elapsed_time = t_end - t0
    true_end = np.array([np.cos(elapsed_time), np.sin(elapsed_time),
                         -np.sin(elapsed_time), np.cos(elapsed_time)])

    # perf_counter is monotonic and high-resolution; time.time() can be far
    # too coarse for runs lasting only a few milliseconds.
    start = time.perf_counter()
    _t_fwd, y_vals = solve_ode(step_func, inverse_square_attraction_func, y0, t0, t_end, n)
    # Start the backward pass at t_end itself, not at the last grid time the
    # forward pass produced. This makes the backward step exactly the negative
    # of the forward step, so the round-trip error is not polluted by a
    # slightly different step size.
    _t_bwd, ic_vals = solve_ode(step_func, inverse_square_attraction_func, y_vals[-1, :], t_end, t0, n)
    elapsed = time.perf_counter() - start

    end_err_ = y_vals[-1, :] - true_end
    start_err_ = ic_vals[-1, :] - y0

    return elapsed, end_err_, start_err_

In [8]:
import time

# Per-solver results, appended in the same order as `solvers`.
times = []
end_err = []
start_err = []

solvers = [
    ("Implicit Midpoint", implicit_midpoint_step),
    ("EES(1,5;1/4)", EES25_1_step),
    ("EES(2,5;1/10)", EES25_2_step),
    ("EES(2,7;(2-sqrt(2))/4)", EES27_1_step),
    ("EES(2,7;(5-3*sqrt(2))/14)", EES27_2_step),
]

num_runs = 100

for name, step_func in solvers:
    print("Starting " + name)
    # Repeat the experiment to average timings; the error vectors are the
    # ones reported by the final run.
    results = [run_inverse_square_attraction(name, step_func) for _ in range(num_runs)]
    time_ = sum(res[0] for res in results) / num_runs
    end_err_, start_err_ = results[-1][1:]
    times.append(time_)
    end_err.append(end_err_)
    start_err.append(start_err_)
    print("Done " + name + "\n")

Starting Implicit Midpoint
Done Implicit Midpoint

Starting EES(1,5;1/4)
Done EES(1,5;1/4)

Starting EES(2,5;1/10)
Done EES(2,5;1/10)

Starting EES(2,7;(2-sqrt(2))/4)
Done EES(2,7;(2-sqrt(2))/4)

Starting EES(2,7;(5-3*sqrt(2))/14)
Done EES(2,7;(5-3*sqrt(2))/14)



In [9]:
# Euclidean norm of each solver's end-time error and round-trip error, in solver order.
end_norm = list(map(np.linalg.norm, end_err))
start_norm = list(map(np.linalg.norm, start_err))

In [10]:
# Emit one LaTeX tabular row per solver:
#   name & mean time [s] & ||end error|| & ||round-trip error|| \\
# Rows need the trailing `\\` terminator, and four significant digits keep
# the table readable.
for (solver_name, _solver_step), mean_time, end_nrm, start_nrm in zip(solvers, times, end_norm, start_norm):
    print(f"{solver_name} & {mean_time:.3e} & {end_nrm:.3e} & {start_nrm:.3e} \\\\")

Implicit Midpoint & 0.06353169202804565 & 0.10062566068818622 & 1.6938588504653234e-13
EES(1,5;1/4) & 0.009762468338012696 & 0.04923197537788734 & 3.214276575635452e-05
EES(2,5;1/10) & 0.008165018558502197 & 0.030920947707525504 & 7.863932427188999e-07
EES(2,7;(2-sqrt(2))/4) & 0.013017115592956542 & 0.023966974179207456 & 2.1529930071239566e-10
EES(2,7;(5-3*sqrt(2))/14) & 0.016941308975219727 & 0.015040645748651533 & 4.954544710658016e-10
