<a href="https://colab.research.google.com/github/li-positive-one/Numerical_torch/blob/main/ERK.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 显式Runge-Kutta方法 (explicit Runge-Kutta methods)

import numpy as np

# u_t=f(t,u)

# Butcher tableau
# c|A
# ---
#  |b


class RK:
    """Explicit Runge-Kutta one-step integrator defined by a Butcher tableau.

    The tableau is laid out as::

        c | A
        --+---
          | b

    Only explicit schemes are supported: stage ``i`` is built from stages
    ``0 .. i-1`` only, so ``A`` must be strictly lower triangular.

    Args:
        c: Stage nodes, length ``s``.
        b: Quadrature weights, length ``s``.
        A: ``(s, s)`` stage-coefficient matrix (strictly lower triangular).

    Raises:
        ValueError: If the shapes of ``c``, ``b`` and ``A`` are inconsistent,
            or if ``A`` has a nonzero entry on or above the diagonal.
    """

    def __init__(self, c, b, A) -> None:
        self.Stage = len(c)  # number of stages s
        if (
            len(A.shape) != 2
            or len(b) != self.Stage
            or A.shape[0] != A.shape[1]
            or A.shape[0] != self.Stage
        ):
            raise ValueError("Invalid Butcher tableau")
        # __call__ only ever reads A[i, j] for j < i.  Any entry on or above
        # the diagonal (an implicit scheme) would be silently ignored and
        # produce wrong results, so reject such tableaux up front.
        if any(
            A[i, j] != 0
            for i in range(self.Stage)
            for j in range(i, self.Stage)
        ):
            raise ValueError(
                "Butcher tableau is not explicit "
                "(A must be strictly lower triangular)"
            )
        self.c = c
        self.b = b
        self.A = A

    def __call__(self, f, t0, u0, h, f0=None):
        """Advance the solution of ``u_t = f(t, u)`` by one step of size ``h``.

        Args:
            f: Right-hand side, called as ``f(t, u)``.
            t0: Time at the start of the step.
            u0: Solution value at ``t0``.
            h: Step size.
            f0: Optional precomputed ``f(t0, u0)`` (e.g. the last stage of the
                previous step in an FSAL scheme).  It is reused for the first
                stage only when that stage is evaluated at exactly
                ``(t0, u0)``, i.e. when ``c[0] == 0``; otherwise it is ignored.

        Returns:
            The approximation of the solution at ``t0 + h``.
        """
        k_list = []
        for i in range(self.Stage):
            # Stage 0 is evaluated at (t0 + c[0]*h, u0); it equals f(t0, u0),
            # and thus f0 can stand in for it, only when c[0] == 0.
            if i == 0 and f0 is not None and self.c[0] == 0:
                k_list.append(f0)
                continue
            increment = sum(self.A[i, j] * k_list[j] for j in range(i))
            k_list.append(f(t0 + self.c[i] * h, u0 + h * increment))
        return u0 + h * sum(k_list[j] * self.b[j] for j in range(self.Stage))

class RK1(RK):
    """Forward Euler, the one-stage explicit Runge-Kutta scheme.

    Tableau: ``c = [0]``, ``A = [[0]]``, ``b = [1]``.
    """

    def __init__(self) -> None:
        super().__init__(
            c=np.array([0]),
            b=np.array([1]),
            A=np.array([[0]]),
        )


class RK2(RK):
    """Two-stage explicit Runge-Kutta family parameterised by ``alpha``.

    ``alpha`` is the node of the second stage.  ``alpha = 1/2`` gives the
    explicit midpoint rule and ``alpha = 1`` gives Heun's method.
    ``alpha`` must be nonzero, since the weights divide by it.
    """

    def __init__(self, alpha=1 / 2) -> None:
        nodes = np.array([0, alpha])
        second_weight = 1 / (2 * alpha)
        super().__init__(
            c=nodes,
            b=np.array([1 - second_weight, second_weight]),
            A=np.array([[0, 0], [alpha, 0]]),
        )

class RK3_SSP(RK):
    """Three-stage, third-order strong-stability-preserving scheme (Shu-Osher)."""

    def __init__(self) -> None:
        stage_rows = [
            [0, 0, 0],
            [1, 0, 0],
            [1 / 4, 1 / 4, 0],
        ]
        super().__init__(
            c=np.array([0, 1, 0.5]),
            b=np.array([1 / 6, 1 / 6, 2 / 3]),
            A=np.array(stage_rows),
        )

class RK4(RK):
    """The classical four-stage, fourth-order Runge-Kutta method."""

    def __init__(self) -> None:
        half = 1 / 2
        lower = np.array(
            [
                [0, 0, 0, 0],
                [half, 0, 0, 0],
                [0, half, 0, 0],
                [0, 0, 1, 0],
            ]
        )
        super().__init__(
            c=np.array([0, 0.5, 0.5, 1]),
            b=np.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
            A=lower,
        )


class RK4_2(RK):
    """Kutta's fourth-order "3/8 rule" Runge-Kutta method."""

    def __init__(self) -> None:
        third = 1 / 3
        lower = np.array(
            [
                [0, 0, 0, 0],
                [third, 0, 0, 0],
                [-third, 1, 0, 0],
                [1, -1, 1, 0],
            ]
        )
        super().__init__(
            c=np.array([0, third, 2 / 3, 1]),
            b=np.array([1 / 8, 3 / 8, 3 / 8, 1 / 8]),
            A=lower,
        )


if __name__ == "__main__":
    # Sanity check: integrate u_t = u, u(0) = 1 over [0, T] with the 3/8-rule
    # scheme and compare against the exact solution u(T) = exp(T).
    a = RK4_2()

    def f(t, x):
        return x

    T = 1.0
    n_steps = 10
    h = T / n_steps
    u = 1
    for i in range(n_steps):
        # Pass the actual start time of each step.  Always passing t0=0 (as
        # before) only happened to work because this f ignores t.
        u = a(f, i * h, u, h)
    print("error is:", u - np.exp(T))

error is: -2.0843238792700447e-06
