## In this notebook, I test a Python implementation of the IRK4 method

Included pieces:
* The implementation itself
* Simple tests against analytical solutions
* Simple examples of applications

In [None]:
import numpy as np
from scipy.optimize import root
import matplotlib.pyplot as plt

class IRK4Solver:
    """Fixed-step implicit Runge-Kutta solver based on the 2-stage
    Gauss-Legendre tableau (order 4, A-stable)."""

    def __init__(self):
        # IRK4 coefficients: 2-stage Gauss-Legendre Butcher tableau
        r3 = np.sqrt(3)
        self.A = np.array([
            [0.25, 0.25 - r3/6],
            [0.25 + r3/6, 0.25]
        ])
        self.b = np.array([0.5, 0.5])
        self.c = np.array([0.5 - r3/6, 0.5 + r3/6])

    def solve(self, f, y0, t_span, h):
        """Integrate y' = f(t, y) over t_span with step size h.

        y0 must be a 1D array. Returns (t_values, y_values), with one
        row of y_values per time point.
        """
        t0, t_end = t_span
        y = np.asarray(y0, dtype=float)  # integer inputs such as np.array([1, 0]) become float
        t = t0
        s, n = len(self.b), len(y)  # number of stages, size of the system
        t_values = [t]
        y_values = [y]

        # Tolerance so that round-off in the accumulated t does not
        # trigger an extra, vanishingly small final step
        eps = 1e-12 * max(1.0, abs(t_end))

        while t < t_end - eps:
            h_step = min(h, t_end - t)  # shorten the last step if needed

            # Stage equations: Y_i = y + h * sum_j A_ij * f(t + c_j h, Y_j)
            def residual(Y_flat):
                Y = Y_flat.reshape(s, n)
                F = np.array([f(t + self.c[j] * h_step, Y[j]) for j in range(s)])
                return (Y - y - h_step * self.A @ F).ravel()

            # Initial guess: every stage equal to the current state
            sol = root(residual, np.tile(y, s))
            if not sol.success:
                raise RuntimeError(f"Nonlinear solver failed to converge at t={t:.6g}: {sol.message}")

            # Extract the stages and update the solution
            Y = sol.x.reshape(s, n)
            F = np.array([f(t + self.c[i] * h_step, Y[i]) for i in range(s)])
            y = y + h_step * self.b @ F
            t += h_step

            t_values.append(t)
            y_values.append(y)

        return np.array(t_values), np.array(y_values)

# Example usage for a system of ODEs
if __name__ == "__main__":
    # Define the system of ODEs
    def f(t, y):
        y1, y2 = y
        return np.array([y2, -y1])

    # Parameters
    y0 = np.array([1, 0])  # Initial condition (y1=1, y2=0)
    t_span = (0, 2 * np.pi)  # One full period
    h = 0.1  # Time step

    # Create an IRK4 solver instance
    solver = IRK4Solver()

    # Solve the system
    t_values, y_values = solver.solve(f, y0, t_span, h)

    # Analytical solution
    y1_analytical = np.cos(t_values)
    y2_analytical = -np.sin(t_values)

    # Extract numerical solutions
    y1_numerical = y_values[:, 0]
    y2_numerical = y_values[:, 1]

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(t_values, y1_analytical, label="y1 Analytical", linestyle="dashed", color="black")
    plt.plot(t_values, y2_analytical, label="y2 Analytical", linestyle="dashed", color="gray")
    plt.plot(t_values, y1_numerical, 'o-', label="y1 Numerical", color="blue")
    plt.plot(t_values, y2_numerical, 'o-', label="y2 Numerical", color="red")
    plt.xlabel("Time")
    plt.ylabel("y")
    plt.title("System of ODEs Solved with IRK4")
    plt.legend()
    plt.grid()
    plt.show()
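
A quick check of the tableau in `IRK4Solver.__init__`: under the usual row-sum assumption \( c_i = \sum_j a_{ij} \), a Runge-Kutta method has order at least 4 exactly when it satisfies the eight order conditions listed in the code. This standalone sketch copies the same coefficients, evaluates each condition, and also shows that the fifth-order condition \( \sum_i b_i c_i^4 = 1/5 \) fails, so the method is order 4 and no higher.

```python
import numpy as np

# Same 2-stage Gauss-Legendre coefficients as IRK4Solver.__init__
r3 = np.sqrt(3)
A = np.array([[0.25, 0.25 - r3 / 6],
              [0.25 + r3 / 6, 0.25]])
b = np.array([0.5, 0.5])
c = np.array([0.5 - r3 / 6, 0.5 + r3 / 6])

# Row-sum condition and order conditions up to order 4:
# name -> (computed value, required value)
conditions = {
    "c_i = sum_j a_ij": (np.max(np.abs(A.sum(axis=1) - c)), 0.0),
    "sum b_i = 1": (b.sum(), 1.0),
    "b.c = 1/2": (b @ c, 1 / 2),
    "b.c^2 = 1/3": (b @ c**2, 1 / 3),
    "b.A.c = 1/6": (b @ A @ c, 1 / 6),
    "b.c^3 = 1/4": (b @ c**3, 1 / 4),
    "b.(c*(A.c)) = 1/8": (b @ (c * (A @ c)), 1 / 8),
    "b.A.c^2 = 1/12": (b @ A @ c**2, 1 / 12),
    "b.A.A.c = 1/24": (b @ A @ A @ c, 1 / 24),
}

for name, (value, required) in conditions.items():
    print(f"{name:20s} {value:.15f}  satisfied: {np.isclose(value, required)}")

# One of the fifth-order conditions, b.c^4 = 1/5, is not satisfied
print(f"b.c^4 = {b @ c**4:.6f} (order 5 would need 0.2)")
```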


In [2]:
run_test = True

def test_IRK4Solver():
    '''Test the scalar ODE y' = -2y, y(0) = 1, against its exact solution exp(-2t).
    '''
    # Define the ODE function
    def f(t, y):
        return -2 * y

    # Parameters
    y0 = np.array([1])  # Initial condition
    t_span = (0, 5)  # Time span
    h = 0.1  # Time step

    # Create an IRK4 solver instance
    solver = IRK4Solver()

    # Solve the ODE
    t_values, y_values = solver.solve(f, y0, t_span, h)

    # Analytical solution
    y_analytical = np.exp(-2 * t_values)

    # Verify that numerical and analytical solutions are close
    assert np.allclose(y_values.flatten(), y_analytical, atol=1e-4), "Numerical solution does not match analytical solution"


def test_IRK4Solver_system_of_odes():
    '''A test with a system of odes
        y1' = y2,
        y2' = -y1
    '''
    # Define the system of ODEs
    def f(t, y):
        y1, y2 = y
        return np.array([y2, -y1])

    # Parameters
    y0 = np.array([1, 0])  # Initial condition (y1=1, y2=0)
    t_span = (0, 2 * np.pi)  # One full period
    h = 0.1  # Time step

    # Create an IRK4 solver instance
    solver = IRK4Solver()

    # Solve the system
    t_values, y_values = solver.solve(f, y0, t_span, h)

    # Analytical solution
    y1_analytical = np.cos(t_values)
    y2_analytical = -np.sin(t_values)

    # Extract numerical solutions
    y1_numerical = y_values[:, 0]
    y2_numerical = y_values[:, 1]

    # Verify that numerical and analytical solutions are close
    assert np.allclose(y1_numerical, y1_analytical, atol=1e-4), "y1 numerical solution does not match analytical solution"
    assert np.allclose(y2_numerical, y2_analytical, atol=1e-4), "y2 numerical solution does not match analytical solution"


if run_test:
    test_IRK4Solver()
    test_IRK4Solver_system_of_odes()

### **1. Seismology: Modeling Wave Propagation**
- **Problem:** Solving the wave equation for seismic wave propagation in heterogeneous media.
- **Example Application:** 
  - Use IRK4 to solve the second-order wave equation:
    $$\frac{\partial^2 u}{\partial t^2} = c^2 \nabla^2 u$$
    where \( u \) is the displacement and \( c \) is the wave speed.
  - Discretize in space and convert to a system of first-order ODEs in time (the method of lines), then apply IRK4 to simulate wave behavior in complex geological structures.

Let’s model wave propagation in seismology with **IRK4**, starting from the **1D wave equation**:
$$\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}$$

- **Conversion to a System of First-Order ODEs:**
  Because the equation is second order in time, we introduce the velocity \( v \) as a new variable
  and rewrite the single second-order equation as a system of two first-order equations.
  Let:
  $$
  v = \frac{\partial u}{\partial t}
  $$
  Then:
  $$
  \frac{\partial u}{\partial t} = v
  $$
  $$
  \frac{\partial v}{\partial t} = c^2 \frac{\partial^2 u}{\partial x^2}
  $$

- **Discretization in Space:**
  Using central finite differences on the uniform grid \( x_i = i\,\Delta x \), with \( u_i(t) \approx u(x_i, t) \):
  $$
  \frac{\partial^2 u}{\partial x^2} \approx \frac{u_{i-1} - 2u_i + u_{i+1}}{\Delta x^2}
  $$

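- **Additional notes:**
  - Because we introduce the new variable \( v \), the solution vector stacks both fields, \( y = [u, v] \): the first \( N \) entries hold \( u \) and the last \( N \) hold \( v \).
  - Both ends are held fixed: the initial velocity is zero everywhere, and `wave_equation` sets `dvdt` to zero at the two end points, so \( u \) and \( v \) never change there. The Gaussian initial displacement is not exactly zero at the ends (about \( 4 \times 10^{-6} \)), so the ends stay at approximately \( u = 0 \).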
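
Because the semi-discrete system is linear, \( y' = M y \) with \( M = \begin{pmatrix} 0 & I \\ c^2 D & 0 \end{pmatrix} \) and \( D \) the second-difference matrix, it is a convenient place to see one property of Gauss-Legendre methods: they conserve quadratic invariants. The sketch below rebuilds the same grid and fixed ends as the next cell and tracks the discrete energy \( \frac{1}{2}\sum_i v_i^2 + \frac{c^2}{2}\sum_i \left(\frac{u_{i+1} - u_i}{\Delta x}\right)^2 \), which the exact semi-discrete flow conserves here. For a linear system, one Gauss-Legendre step is exactly \( y \leftarrow R(hM)\,y \) with the stability function \( R(z) = (1 + z/2 + z^2/12)/(1 - z/2 + z^2/12) \), so the sketch precomputes that step matrix instead of calling `scipy.optimize.root` as `IRK4Solver` does.

```python
import numpy as np

# Same grid and parameters as the wave cell: y = [u, v], fixed ends
L, N, c = 10.0, 100, 1.0
dx = L / (N - 1)
x = np.linspace(0, L, N)
h, n_steps = 0.01, 500                 # integrate from t = 0 to t = 5

# Second-difference matrix; first and last rows stay zero, so dv/dt = 0
# at the ends, as in wave_equation
D = np.zeros((N, N))
for i in range(1, N - 1):
    D[i, i - 1:i + 2] = [1.0, -2.0, 1.0]
D /= dx**2

# Semi-discrete system y' = M y with M = [[0, I], [c^2 D, 0]]
Z0, I_N = np.zeros((N, N)), np.eye(N)
M = np.block([[Z0, I_N], [c**2 * D, Z0]])

# Linear Gauss-Legendre step y <- R(hM) y with
# R(z) = (1 + z/2 + z^2/12) / (1 - z/2 + z^2/12)
I = np.eye(2 * N)
Zh = h * M
Zh2 = Zh @ Zh
step = np.linalg.solve(I - Zh / 2 + Zh2 / 12, I + Zh / 2 + Zh2 / 12)

def energy(y):
    """Discrete energy: 1/2 sum v^2 + c^2/2 * sum(((u[i+1] - u[i]) / dx)^2)."""
    u, v = y[:N], y[N:]
    return 0.5 * np.sum(v**2) + 0.5 * c**2 * np.sum(np.diff(u)**2) / dx**2

y = np.concatenate([np.exp(-0.5 * (x - L / 2)**2), np.zeros(N)])
E0 = energy(y)
for _ in range(n_steps):
    y = step @ y
drift = abs(energy(y) - E0) / E0
print(f"relative energy change after {n_steps} steps: {drift:.2e}")
```

The printed change should sit at round-off level, meaning the scheme adds no numerical damping to the pulses. The price of this shortcut is a dense \( 2N \times 2N \) step matrix, which is fine for \( N = 100 \) but would call for sparse solvers on larger grids.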

In [None]:
import time

# Problem setup
def wave_equation(t, y, c, dx):
    N = len(y) // 2
    u = y[:N]
    v = y[N:]
    
    # Spatial derivatives using finite differences
    dudt = v
    dvdt = np.zeros_like(u)
    dvdt[1:-1] = c**2 * (u[:-2] - 2 * u[1:-1] + u[2:]) / dx**2
    
    # Combine into a single vector
    return np.concatenate([dudt, dvdt])

# Parameters
L = 10.0  # Length of the domain
N = 100  # Number of spatial points
dx = L / (N - 1)  # Spatial step size
c = 1.0  # Wave speed
t_span = (0, 5)  # Time span
h = 0.01  # Time step

# Initial conditions
x = np.linspace(0, L, N)
u0 = np.exp(-0.5 * (x - L/2)**2)  # Gaussian initial displacement
v0 = np.zeros_like(u0)  # Initial velocity
y0 = np.concatenate([u0, v0])  # Combine into a single vector

# Solve the wave equation
start = time.time()
solver = IRK4Solver()
t_values, y_values = solver.solve(lambda t, y: wave_equation(t, y, c, dx), y0, t_span, h)
end = time.time()
print("Solving the IRK4 system takes %.2f s" % (end - start))

# Extract and plot results
u_values = y_values[:, :N]
plt.figure(figsize=(10, 6))
for i in range(0, len(t_values), max(1, len(t_values)//10)):  # about 10 snapshots
    plt.plot(x, u_values[i], label=f"t={t_values[i]:.2f}")
plt.xlabel("x")
plt.ylabel("u(x, t)")
plt.title("Wave Propagation")
plt.legend()
plt.grid()
plt.show()

In [None]:
u_values[450, 10]  # displacement at t = t_values[450] (≈ 4.5) and x = x[10] (≈ 1.01)

### **2. Geodynamics: Thermal Convection in the Mantle**
- **Problem:** Time integration of mantle convection equations or heat transport in the lithosphere.
- **Example Application:**
  - Use IRK4 to solve the heat equation:
    $$\frac{\partial T}{\partial t} = \nabla \cdot (\kappa \nabla T) + H$$
    where:
    - \( T \): Temperature
    - \( \kappa \): Thermal diffusivity
    - \( H \): Internal heating rate in K/s (volumetric heat production divided by \( \rho c_p \), the volumetric heat capacity)
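  - IRK4 (Gauss-Legendre) is A-stable, so its time step is not limited by the explicit diffusion limit \( \Delta t \le \Delta x^2 / (2\kappa) \). This helps with stiff systems, such as fine grids, phase changes, or large temperature gradients. The method is not L-stable, however, so very stiff modes are damped only weakly.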
- **Additional Notes:**
  - Without the heating term (\( H = 0 \)), the steady state is a linear temperature profile between the two boundary temperatures.
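
With heating, setting \( \partial T / \partial t = 0 \) in 1D with constant \( \kappa \) and \( H \) gives \( \kappa T'' = -H \). With \( T(0) = T_0 \) and \( T(L) = T_L \), the steady state is
$$T(x) = T_0 + (T_L - T_0)\frac{x}{L} + \frac{H\,x\,(L - x)}{2\kappa}$$
that is, the linear profile plus a parabola whose largest excess, \( H L^2 / (8\kappa) \), sits at mid-depth. The sketch below copies the parameter values of the next cell, treats `H` as a heating rate in K/s, and checks that the interior finite-difference right-hand side used in `heat_equation` vanishes on this profile (a central second difference is exact for a quadratic).

```python
import numpy as np

# Parameter values copied from the heat-equation cell (H read as K/s)
L, N = 100e3, 100
dx = L / (N - 1)
kappa, H = 1e-6, 0.1e-12
T_top, T_bottom = 273.15, 1673.15
x = np.linspace(0, L, N)

# Steady state of kappa * T'' + H = 0 with T(0) = T_top, T(L) = T_bottom
T_linear = T_top + (T_bottom - T_top) * x / L
T_ss = T_linear + H * x * (L - x) / (2 * kappa)

# Interior right-hand side of heat_equation evaluated at the steady state
residual = kappa * (T_ss[:-2] - 2 * T_ss[1:-1] + T_ss[2:]) / dx**2 + H
print(f"max |dT/dt| at steady state: {np.max(np.abs(residual)):.1e} K/s (H = {H:.1e} K/s)")
print(f"peak heating contribution H L^2 / (8 kappa): {H * L**2 / (8 * kappa):.1f} K")
```

With these values the heating adds about 125 K at mid-depth. The slowest decaying mode has an e-folding time of roughly \( L^2/(\pi^2\kappa) \approx 32 \) Myr, so after the 100 Myr run the computed profile should be close to, but not exactly at, this steady state.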

In [None]:
import time

# Problem setup
def heat_equation(t, T, kappa, dx, H):
    N = len(T)
    dTdt = np.zeros_like(T)
    
    # Diffusion term (finite differences) for interior points
    dTdt[1:-1] = kappa * (T[:-2] - 2 * T[1:-1] + T[2:]) / dx**2
    
    # Add heat source for interior points
    dTdt[1:-1] += H
    
    # Boundary conditions: ensure no change at boundaries
    dTdt[0] = 0  # No change at x = 0
    dTdt[-1] = 0  # No change at x = L
    
    return dTdt

# Parameters
L = 100e3  # Length of the domain in meters (100 km)
N = 100  # Number of spatial points
dx = L / (N - 1)  # Spatial step size
kappa = 1e-6  # Thermal diffusivity in m^2/s
H = 0.1e-12  # Internal heating rate in K/s (heat production divided by rho*c_p)
t_span = (0, 100e6 * 365.25 * 24 * 3600)  # Time span: 100 million years in seconds
h = 1e6 * 365.25 * 24 * 3600  # Time step: 1 million years in seconds

# Initial condition: constant mantle temperature
T0 = np.full(N, 1673.15) 

# Boundary conditions (fixed temperatures at boundaries)
def apply_boundary_conditions(T):
    T[0] = 273.15  # Fixed temperature at x = 0 (273.15 K)
    T[-1] = 1673.15  # Fixed temperature at x = L (1673.15 K)
    return T

T0 = apply_boundary_conditions(T0)

# Solve the heat equation
start = time.time()
solver = IRK4Solver()
t_values, T_values = solver.solve(lambda t, T: heat_equation(t, T, kappa, dx, H), T0, t_span, h)
end = time.time()
print("Solving the IRK4 system takes %.2f s" % (end - start))

# Extract and plot results
plt.figure(figsize=(10, 6))
for i in range(0, len(t_values), max(1, len(t_values)//10)):  # about 10 snapshots
    plt.plot(np.linspace(0, L/1e3, N), T_values[i], label=f"t={t_values[i]/(1e6 * 365.25 * 24 * 3600):.2f} Myr")
plt.xlabel("Depth (km)")
plt.ylabel("Temperature (K)")
plt.title("1D Thermal Conduction in Oceanic Lithosphere")
plt.legend()
plt.grid()
plt.show()
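
To put numbers on the stiffness of the cell above: forward Euler would need \( \Delta t \le \Delta x^2 / (2\kappa) \) on this grid, while the 1 Myr step used here is far larger. The sketch also evaluates the 2-stage Gauss-Legendre stability function \( R(z) = (1 + z/2 + z^2/12)/(1 - z/2 + z^2/12) \) at \( z = h\lambda_k \), where \( \lambda_k = -(4\kappa/\Delta x^2)\sin^2\!\big(k\pi / (2(N-1))\big) \), \( k = 1, \dots, N-2 \), are the eigenvalues of the second-difference operator on the interior points with fixed-temperature boundaries.

```python
import numpy as np

# Parameter values copied from the heat-equation cell
L, N, kappa = 100e3, 100, 1e-6
dx = L / (N - 1)
year = 365.25 * 24 * 3600
h = 1e6 * year                     # 1 Myr step

# Forward-Euler stability limit for the 1D diffusion stencil
dt_explicit = dx**2 / (2 * kappa)
print(f"explicit limit: {dt_explicit / year:,.0f} yr; IRK4 step is {h / dt_explicit:.0f}x larger")

# Eigenvalues of the interior Dirichlet second-difference operator
k = np.arange(1, N - 1)
lam = -(4 * kappa / dx**2) * np.sin(k * np.pi / (2 * (N - 1)))**2

def R(z):
    """Stability function of the 2-stage Gauss-Legendre method."""
    return (1 + z / 2 + z**2 / 12) / (1 - z / 2 + z**2 / 12)

z_slow, z_stiff = h * lam.max(), h * lam.min()
print(f"slowest mode:  z = {z_slow:.3f}, R = {R(z_slow):.4f}, exp(z) = {np.exp(z_slow):.4f}")
print(f"stiffest mode: z = {z_stiff:.1f}, R = {R(z_stiff):.3f}, exp(z) = {np.exp(z_stiff):.1e}")
```

Since \( 0 < R(z) < 1 \) for every negative real \( z \), the run is stable, but for the stiffest mode \( R \) is close to 1 while the exact factor \( e^{z} \) is essentially zero. Grid-scale features, such as the sharp jump between the initial 1673.15 K and the fixed 273.15 K at \( x = 0 \), may therefore fade more slowly in the early snapshots than the physics dictates; an L-stable method such as Radau IIA would damp them.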

### **3. Stock Market: Option Pricing with Stochastic Models**
- **Problem:** Solving the Black-Scholes equation for options pricing.
- **Example Application:**
  - Apply IRK4 to solve the time evolution of the Black-Scholes equation:
    $$\frac{\partial V}{\partial t} + \frac{1}{2} \sigma^2 S^2 \frac{\partial^2 V}{\partial S^2} + r S \frac{\partial V}{\partial S} - rV = 0$$
    where:
    - \( V \): Option value
    - \( S \): Stock price
    - \( \sigma \): Volatility
    - \( r \): Risk-free rate
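  - Because the equation runs backward in time from the payoff at the expiry date \( T \), substitute the time to expiry \( \tau = T - t \) to obtain a forward problem with the payoff as the initial condition:
    $$\frac{\partial V}{\partial \tau} = \frac{1}{2} \sigma^2 S^2 \frac{\partial^2 V}{\partial S^2} + r S \frac{\partial V}{\partial S} - rV$$
  - Convert this partial differential equation (PDE) to an ODE system using finite differences in \( S \), then integrate in \( \tau \) using IRK4.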
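
A sketch of this method-of-lines approach for a European call, under assumed parameters (strike 100, \( r = 0.05 \), \( \sigma = 0.2 \), one year to expiry, price grid up to \( S_{max} = 300 \)). It works in the time to expiry \( \tau \), so the payoff \( \max(S - K, 0) \) is the initial condition. The boundary conditions are also assumptions: \( V = 0 \) at \( S = 0 \), and \( V = S_{max} - K e^{-r\tau} \) at \( S_{max} \), written as the affine ODE \( dV/d\tau = -rV + r S_{max} \) and made linear by appending a constant 1 to the state. The system is then linear, \( y' = My \), so each Gauss-Legendre step is \( y \leftarrow R(hM)\,y \) with the method's stability function, used here in place of `scipy.optimize.root`. The result is compared with the closed-form Black-Scholes price (about 10.45 for these inputs).

```python
import numpy as np
from math import erf, exp, log, sqrt

def bs_call_exact(S, K, r, sigma, T):
    """Closed-form Black-Scholes price of a European call."""
    def Phi(z):  # standard normal CDF
        return 0.5 * (1.0 + erf(z / sqrt(2.0)))
    d1 = (log(S / K) + (r + 0.5 * sigma**2) * T) / (sigma * sqrt(T))
    d2 = d1 - sigma * sqrt(T)
    return S * Phi(d1) - K * exp(-r * T) * Phi(d2)

# Illustrative contract and grid (assumed values)
K, r, sigma, T = 100.0, 0.05, 0.2, 1.0
S_max, n = 300.0, 300                 # nodes S_i = i * dS, i = 0..n
dS = S_max / n
S = np.linspace(0.0, S_max, n + 1)

# dy/dtau = M y with y = [V_0, ..., V_n, 1]; the trailing constant 1 lets
# the affine upper boundary condition be written as a linear system
M = np.zeros((n + 2, n + 2))
M[0, 0] = -r                          # S = 0: dV/dtau = -rV, so V stays 0
for i in range(1, n):
    diff = 0.5 * sigma**2 * S[i]**2 / dS**2
    conv = r * S[i] / (2 * dS)
    M[i, i - 1] = diff - conv
    M[i, i] = -2 * diff - r
    M[i, i + 1] = diff + conv
# S = S_max: V = S_max - K exp(-r tau) satisfies dV/dtau = -rV + r S_max
M[n, n] = -r
M[n, n + 1] = r * S_max

# Linear 2-stage Gauss-Legendre step y <- R(hM) y
steps = 50
h = T / steps
I = np.eye(n + 2)
Z = h * M
Z2 = Z @ Z
step = np.linalg.solve(I - Z / 2 + Z2 / 12, I + Z / 2 + Z2 / 12)

y = np.append(np.maximum(S - K, 0.0), 1.0)   # payoff at tau = 0
for _ in range(steps):
    y = step @ y

i_K = int(round(K / dS))
V_num, V_exact = y[i_K], bs_call_exact(K, K, r, sigma, T)
print(f"at-the-money call: IRK4 {V_num:.4f}, closed form {V_exact:.4f}")
```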

### **4. Macroeconomics: Dynamic Systems in Economic Models**
- **Problem:** Solving dynamic economic models with differential equations.
- **Example Application:**
  - Use IRK4 to solve equations in a Solow growth model or DSGE model:
    $$\frac{dk}{dt} = s f(k) - \delta k$$
    where:
    - \( k \): Capital
    - \( s \): Savings rate
    - \( f(k) \): Production function
    - \( \delta \): Depreciation
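  - IRK4 is A-stable and fourth-order accurate, so it can take long time steps while following capital accumulation toward the steady state \( k^* \), defined by \( s f(k^*) = \delta k^* \).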
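
A minimal sketch for the Solow equation, assuming a Cobb-Douglas production function \( f(k) = k^\alpha \) and illustrative parameter values. With \( x = k^{1-\alpha} \) the equation becomes linear, \( x' = (1-\alpha)(s - \delta x) \), which gives a closed form to measure the error against. The stages are solved by plain fixed-point iteration instead of `scipy.optimize.root` (adequate here because the problem is not stiff), and halving \( h \) should cut the error by roughly \( 2^4 = 16 \), that is, an observed order near 4.

```python
import numpy as np

# 2-stage Gauss-Legendre tableau (same values as IRK4Solver)
r3 = np.sqrt(3)
A = np.array([[0.25, 0.25 - r3 / 6], [0.25 + r3 / 6, 0.25]])
b = np.array([0.5, 0.5])

# Illustrative Solow parameters (assumed) with Cobb-Douglas f(k) = k**alpha
s, alpha, delta, k0, T = 0.3, 1 / 3, 0.05, 1.0, 60.0

def rhs(k):
    """dk/dt = s f(k) - delta k (autonomous, so no t argument)."""
    return s * k**alpha - delta * k

def k_exact(t):
    """Closed form via x = k**(1 - alpha), which obeys a linear ODE."""
    x_star = s / delta
    x = x_star + (k0**(1 - alpha) - x_star) * np.exp(-(1 - alpha) * delta * t)
    return x**(1 / (1 - alpha))

def gauss_legendre(n_steps, tol=1e-13, max_iter=200):
    """Fixed-step Gauss-Legendre integration; stages solved by fixed-point
    iteration, which converges here because h * |f'(k)| is small."""
    h, k = T / n_steps, k0
    for _ in range(n_steps):
        Y = np.full(2, k)                        # initial guess for the stages
        for _iteration in range(max_iter):
            Y_new = k + h * A @ rhs(Y)
            converged = np.max(np.abs(Y_new - Y)) < tol
            Y = Y_new
            if converged:
                break
        else:
            raise RuntimeError("stage iteration did not converge")
        k = k + h * b @ rhs(Y)
    return k

errors = [abs(gauss_legendre(n) - k_exact(T)) for n in (6, 12, 24, 48)]
orders = [np.log2(e1 / e2) for e1, e2 in zip(errors, errors[1:])]
print("errors:", ", ".join(f"{e:.2e}" for e in errors))
print("observed orders:", ", ".join(f"{p:.2f}" for p in orders))
print(f"steady state k* = {(s / delta)**(1 / (1 - alpha)):.3f}, k(T) = {k_exact(T):.3f}")
```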


### **5. Applied Physics: Coupled Oscillations**
- **Problem:** Modeling coupled pendulums or masses in mechanical systems.
- **Example Application:**
  - Solve a system of second-order ODEs for coupled harmonic oscillators:
    $$m_1 \ddot{x}_1 = -k_1 x_1 + k_2 (x_2 - x_1)$$
    $$m_2 \ddot{x}_2 = -k_2 (x_2 - x_1)$$
    where:
    - \( m_1, m_2 \): Masses
    - \( k_1, k_2 \): Spring constants
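  - Introduce the velocities \( v_1 = \dot{x}_1 \) and \( v_2 = \dot{x}_2 \) to obtain a first-order system in \( (x_1, x_2, v_1, v_2) \), and use IRK4 for high-accuracy time integration. As a Gauss-Legendre method, IRK4 also conserves the quadratic total energy of this linear system up to solver tolerance.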
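
A sketch of the first-order conversion for the coupled oscillators, with assumed values \( m_1 = m_2 = k_1 = k_2 = 1 \) and the state \( y = [x_1, x_2, v_1, v_2] \). Because the system is linear, the Gauss-Legendre stage equations are linear too, so one step needs a single solve of \( (I - h\,A \otimes M)\,K = [My; My] \) rather than a call to `scipy.optimize.root`. The result is checked against the exact solution \( e^{TM} y_0 \) from `scipy.linalg.expm`.

```python
import numpy as np
from scipy.linalg import expm

# Illustrative masses and spring constants (assumed values)
m1, m2, k1, k2 = 1.0, 1.0, 1.0, 1.0

# First-order form with y = [x1, x2, v1, v2]:
#   x1' = v1, x2' = v2,
#   v1' = (-k1 x1 + k2 (x2 - x1)) / m1,  v2' = -k2 (x2 - x1) / m2
M = np.array([
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [-(k1 + k2) / m1, k2 / m1, 0.0, 0.0],
    [k2 / m2, -k2 / m2, 0.0, 0.0],
])

# 2-stage Gauss-Legendre tableau
r3 = np.sqrt(3)
A = np.array([[0.25, 0.25 - r3 / 6], [0.25 + r3 / 6, 0.25]])
b = np.array([0.5, 0.5])

def gl_step_linear(y, h):
    """One Gauss-Legendre step for y' = M y.

    The stage slopes K_i = M (y + h sum_j a_ij K_j) are linear in K,
    so (I - h A kron M) [K_1; K_2] = [M y; M y] is solved directly.
    """
    n = len(y)
    lhs = np.eye(2 * n) - h * np.kron(A, M)
    K = np.linalg.solve(lhs, np.tile(M @ y, 2)).reshape(2, n)
    return y + h * b @ K

y0 = np.array([1.0, 0.0, 0.0, 0.0])   # displace mass 1, start at rest
h, T = 0.1, 20.0
y = y0.copy()
for _ in range(int(round(T / h))):
    y = gl_step_linear(y, h)

y_exact = expm(T * M) @ y0             # exact solution of the linear system
print("IRK4:", np.round(y, 6))
print("expm:", np.round(y_exact, 6))
print(f"max abs error: {np.max(np.abs(y - y_exact)):.2e}")
```

For these values the normal-mode angular frequencies are \( \sqrt{(3 \pm \sqrt{5})/2} \approx 1.618 \) and \( 0.618 \), so \( h = 0.1 \) resolves both comfortably.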