# Introduction

The goal of this library is to extend OpenMDAO with the ability to do time integration, so that we are able to do instationary MDAO.
This can be achieved by using the class `RungeKuttaIntegrator`, which allows time stepping via Runge-Kutta schemes.
Currently, both explicit and DIRK (**D**iagonally **I**mplicit **R**unge-**K**utta) schemes are supported.

Runge-Kutta schemes are single-step multi-stage methods for ODEs, i.e. the computation of the state at the next time step is done using only the information from one previous step, but multiple intermediate stages are introduced additionally.
A common representation of RK methods is the so-called Butcher tableau:
$$
\begin{array}{c|cccc} c_1 & a_{11} & a_{12} & \cdots & a_{1S} \\ c_2 & a_{21} & a_{22} & \cdots & a_{2S} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ c_S & a_{S1} & a_{S2} & \cdots & a_{SS} \\ \hline & b_1 & b_2 & \cdots & b_S \\ \end{array}
$$
Given the ODE
$$
x'(t) = f(t, x(t)),
$$
the formula to compute the time steps is
$$
x_{n+1} = x_n + \Delta t\sum_{i=1}^S b_i \cdot k_i, \quad n=0,\ldots, N-1,
$$
where $x_{n+1}$ and $x_n$ are the states at the new and old step respectively, $\Delta t$ is the step size, the $b_i$ are the coefficients from the last row of the Butcher tableau, and the $k_i$ are stage variables approximating the slope of the function $x$.
These $k_i$ are given by 
$$
k_i = f\left(t_n + \Delta t\cdot c_i, x_n + \Delta t\cdot\sum_{j=1}^S a_{ij}k_j \right), \quad i =1,\ldots, S
$$
where the $c_i$ are from the first column of the Butcher tableau, and the $a_{ij}$ from the Butcher matrix in the tableau.
In general, this leads to an $S\cdot d$-dimensional nonlinear system of equations, where $d$ is the dimension of $x$.
However, for explicit schemes we have $a_{ij} = 0$ for $i \leq j$, and for DIRK schemes we have $a_{ij} = 0$ for $i<j$.
This allows the stages to be solved for sequentially.
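
As a concrete illustration (these are the standard textbook tableaux, not taken from this library's sources), the explicit and implicit Euler methods are the simplest one-stage ($S=1$) representatives of the two supported classes:
$$
\text{explicit Euler: } \begin{array}{c|c} 0 & 0 \\ \hline & 1 \end{array}
\qquad\qquad
\text{implicit Euler: } \begin{array}{c|c} 1 & 1 \\ \hline & 1 \end{array}
$$
For explicit Euler we have $a_{11} = 0$, so $k_1 = f(t_n, x_n)$ can be evaluated directly. For implicit Euler we have $a_{11} = 1$, so $k_1 = f(t_n + \Delta t, x_n + \Delta t\cdot k_1)$ appears on both sides of its own equation and has to be solved for.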

Define $s_i:=\sum_{j=1}^{i-1}a_{ij} k_j$. Then we can rewrite the formula for the stage of explicit or DIRK schemes as
$$
k_i = f\left(t_n + \Delta t\cdot c_i, x_n + \Delta t\cdot (s_i + a_{ii} k_i) \right), \quad i =1,\ldots, S.
$$
In this form, the formula always has the same number of variables and the same structure, independent of which stage is currently computed.
This is what allows us to take the following approach:
We assume that we have an OpenMDAO problem that models one time stage of the instationary multidisciplinary problem.
In this problem, we have inputs $x_n$ and $s_i$ as well as outputs $k_i$ per discipline.
These are marked both with a tag for the discipline and the role of the variable in the time integration.
To do the time integration then, we repeatedly run this model with differing inputs and outputs:

1. At the start of each time step, we write the old states into the $x_n$ inputs of the disciplines.
2. At the start of each time stage, we compute the new $s_i$ and write it into the respective inputs.
3. We run the model, and afterwards copy the outputs $k_i$ out of the problem.
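
The three steps above can be sketched in plain Python, without OpenMDAO. This is only an illustration of the loop structure, not the implementation used by the `RungeKuttaIntegrator`: the names `dirk_integrate` and `solve_linear_stage` are hypothetical, the stage solve is the closed form for the scalar ODE $x'(t) = x(t)$, and the tableau is a standard two-stage SDIRK scheme with $\gamma = 1 - 1/\sqrt{2}$ chosen for this example.

```python
import math


def dirk_integrate(solve_stage, butcher_a, butcher_b, butcher_c, x0, t0, delta_t, num_steps):
    """Integrate with an explicit or DIRK scheme, solving one stage at a time.

    solve_stage(t_stage, x_n, s_i, a_ii, delta_t) must return the k_i that solves
    k_i = f(t_stage, x_n + delta_t * (s_i + a_ii * k_i)).
    """
    x_n, t_n = x0, t0
    for _ in range(num_steps):
        # Step 1: x_n is the old state used by every stage of this time step.
        stages = []
        for i, row in enumerate(butcher_a):
            # Step 2: accumulate the already computed stages into s_i.
            s_i = sum(row[j] * stages[j] for j in range(i))
            # Step 3: solve the stage problem and keep its output k_i.
            k_i = solve_stage(t_n + butcher_c[i] * delta_t, x_n, s_i, row[i], delta_t)
            stages.append(k_i)
        # Combine the stages into the new state.
        x_n += delta_t * sum(b * k for b, k in zip(butcher_b, stages))
        t_n += delta_t
    return x_n


def solve_linear_stage(t_stage, x_n, s_i, a_ii, delta_t):
    # Closed-form stage solve for x'(t) = x(t); t_stage is unused for this ODE.
    return (x_n + delta_t * s_i) / (1.0 - delta_t * a_ii)


# Two-stage SDIRK tableau (an assumption of this sketch).
gamma = 1.0 - 1.0 / math.sqrt(2.0)
sdirk2_a = [[gamma, 0.0], [1.0 - gamma, gamma]]
sdirk2_b = [1.0 - gamma, gamma]
sdirk2_c = [gamma, 1.0]

x_final = dirk_integrate(solve_linear_stage, sdirk2_a, sdirk2_b, sdirk2_c, 1.0, 0.0, 0.01, 100)
print(x_final, math.exp(1.0))
```

Because `solve_stage` always receives the same four pieces of data, the same callable can serve every stage of every step, which is the property the approach above relies on.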

# A Simple Example

To start with a simple example, we keep it monodisciplinary for now and try the possibly simplest ODE: 
$$x'(t) = x(t), \quad x(0) = 1$$
The formula for the stage then is
$$ k_i = x_n + \Delta t \cdot (s_i + a_{ii}\cdot k_i), $$
or after some rearranging
$$ k_i = \frac{x_n + \Delta t\cdot s_i}{1-\Delta t \cdot a_{ii}}. $$
This can easily be implemented via an explicit component.

In [None]:
import openmdao.api as om
from  rkopenmdao.integration_control import IntegrationControl

class ODEComponent(om.ExplicitComponent):
    def initialize(self):
        self.options.declare("integration_control", types=IntegrationControl)

    def setup(self):
        self.add_input("x_old", val=1.0, tags=["x", "step_input_var"])
        self.add_input("s_i", val=0.0, tags=["x", "accumulated_stage_var"])
        self.add_output("k_i", val=1.0, tags=["x", "stage_output_var"])

    def setup_partials(self):
        self.declare_partials("*", "*")

    def compute(self, inputs, outputs):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        outputs["k_i"] = (inputs["x_old"] + delta_t * inputs["s_i"]) / (1 - delta_t * butcher_diagonal_element)

    def compute_partials(self, inputs, partials):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        partials["k_i", "x_old"] = 1 / (1 - delta_t * butcher_diagonal_element)
        partials["k_i", "s_i"] = delta_t / (1 - delta_t * butcher_diagonal_element)




You should notice a few things here. First, we gave the component an option for an `IntegrationControl` object.
This class contains data and metadata for the current time step and time stage.
An instance of this class should be passed to the outer integration component and most components in the time stage problem.
More details on that later.
Second, notice that the three declared variables have certain tags.
All three share the tag `"x"`, which is done so that they are grouped as belonging to the same quantity.
Furthermore, these have a second tag that describes their role in the time integration. 
The tag `"step_input_var"` is used for the variable containing the old state of the quantity, and the tag `"accumulated_stage_var"` is used for the $s_i$ of the quantity.
These two tags need to be applied to input variables, and per quantity it is necessary to either have both or none of these tags.
(The second case might be useful for problems where the time derivative of a quantity does not depend on the state of the same quantity but of other quantities.)
Then there is the tag `"stage_output_var"` for the $k_i$ of the quantity.
For each quantity you want to apply time integration on, an output with this tag needs to be present.
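
The tagging rules above can be summarized as a small check. The helper below is purely hypothetical and not part of rkopenmdao; it works on a plain dictionary that maps variable names to a `(kind, tags)` pair and only mirrors the rules stated in the text.

```python
ROLE_TAGS = ("step_input_var", "accumulated_stage_var", "stage_output_var")


def check_quantity_tags(quantity, variables):
    """Return a list of rule violations for one quantity (empty if all is fine)."""
    found = {role: [] for role in ROLE_TAGS}
    for name, (kind, tags) in variables.items():
        if quantity in tags:
            for role in ROLE_TAGS:
                if role in tags:
                    found[role].append(kind)

    errors = []
    # The two input roles must be attached to inputs.
    for role in ("step_input_var", "accumulated_stage_var"):
        if "output" in found[role]:
            errors.append(f"'{role}' must be applied to an input")
    # Either both input roles are present, or neither of them.
    if bool(found["step_input_var"]) != bool(found["accumulated_stage_var"]):
        errors.append("use both or none of 'step_input_var' and 'accumulated_stage_var'")
    # Every integrated quantity needs a stage output.
    if "output" not in found["stage_output_var"]:
        errors.append("missing an output tagged 'stage_output_var'")
    return errors


# The variables of the ODEComponent defined above.
ode_component_vars = {
    "x_old": ("input", ["x", "step_input_var"]),
    "s_i": ("input", ["x", "accumulated_stage_var"]),
    "k_i": ("output", ["x", "stage_output_var"]),
}
print(check_quantity_tags("x", ode_component_vars))  # prints []
```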

Now let's see how this component is used in our time integrator, the `RungeKuttaIntegrator`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
from rkopenmdao.runge_kutta_integrator import RungeKuttaIntegrator
from rkopenmdao.butcher_tableaux import implicit_euler


integration_control = IntegrationControl(initial_time=0.0, num_steps=100, delta_t=0.01)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem("ODE_Comp", ODEComponent(integration_control=integration_control))

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem, 
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x"],
        write_file="simple_ODE_example.h5",
        write_out_distance = 10,
    )
)

runge_kutta_problem.setup()
runge_kutta_problem.run_model()

fig, ax = plt.subplots(1,1)

t = np.linspace(0.0,1.0, 11)
ax.plot(t, np.exp(t), label="Analytic solution.")

numeric_data = np.zeros(11)
with h5py.File("simple_ODE_example.h5", mode="r") as f:
    for i in range(11):
        numeric_data[i] = f["x"][str(10*i)][0]
ax.plot(t, numeric_data, "o", label="Numeric Solution")

ax.set_xlabel("t")
ax.set_xlim(0.0,1.0)
ax.set_ylabel("x(t)")
ax.set_ylim(0.0,3.0)
ax.legend()

plt.show()


We start by creating our instance of `IntegrationControl`.
This requires the initial time of the differential equation, the number of time steps we want to simulate, as well as the step size.
At the next step, we create the problem for one time stage, which is then used in the creation of our instance of the `RungeKuttaIntegrator`. 
This class is derived from an explicit component, which is why we add it to a second, outer OpenMDAO problem here.

The arguments to create an instance of the `RungeKuttaIntegrator` are:
* The time stage problem,
* the Butcher tableau for the Runge-Kutta scheme we want to apply (here implicit Euler),
* the `IntegrationControl` object,
* and a list of quantities tagged in the inner time stage problem we want to apply time integration on.

Furthermore, we used two optional arguments here:
* The `write_out_distance` is the distance between the time steps for which quantity data is written to disk (in the form of an HDF5 file), starting with the initial values. 
The default is `0`, for which no data is written out.
Note that when writing out data, the data of the final step is always written out, even if the step number isn't a multiple of `write_out_distance`.
* The `write_file` (default: `"data.h5"`) is the name of the HDF5-file the quantity data is stored in (if `write_out_distance` is non-zero). The file structure of the file is of the form quantity/time_step.
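
To make the write-out rule concrete, the following sketch lists the time steps that would end up in the file for a given `num_steps` and `write_out_distance`. The function `written_steps` is a hypothetical helper that only encodes the rule described above (every `write_out_distance`-th step starting with the initial values, plus the final step); it does not call the library.

```python
def written_steps(num_steps, write_out_distance):
    """Return the time step numbers for which quantity data is written out."""
    if write_out_distance == 0:
        return []  # default: nothing is written
    steps = list(range(0, num_steps + 1, write_out_distance))
    if steps[-1] != num_steps:
        steps.append(num_steps)  # the final step is always written
    return steps


print(written_steps(100, 10))  # the setting used in the example above
print(written_steps(25, 10))   # final step is not a multiple of the distance
```

With the settings of the example above, this gives the step numbers `0, 10, ..., 100`, which are the keys `str(10*i)` used when reading the file.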

There are still other optional arguments which will be covered later.

Afterward, we set up the outer problem, and let it run.
At the end follows a plot to visually compare the numerical with the analytical result.


# A Multidisciplinary Example

Of course, we are not only interested in monodisciplinary examples, so we continue with a simple multidisciplinary example, a 2-dimensional system of ODEs:
$$
x'(t) = -y(t), \quad x(0) = 1,\\
y'(t) = x(t), \quad y(0) = 1.
$$
The analytic solution to this system is
$$
x(t) = \cos(t) - \sin(t),\\
y(t) = \cos(t) + \sin(t).
$$
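
Applying the stage formula from the introduction to this system gives two coupled linear equations per stage. With the abbreviations $\alpha = \Delta t \cdot a_{ii}$, $X_i = x_n + \Delta t \cdot s_{x,i}$ and $Y_i = y_n + \Delta t \cdot s_{y,i}$ (notation introduced only for this sketch), they read $k_{x,i} = -(Y_i + \alpha k_{y,i})$ and $k_{y,i} = X_i + \alpha k_{x,i}$, with the solution $k_{x,i} = -(Y_i + \alpha X_i)/(1+\alpha^2)$ and $k_{y,i} = (X_i - \alpha Y_i)/(1+\alpha^2)$. The following plain-Python sketch, independent of OpenMDAO, uses this closed form with implicit Euler as a reference; the function names are hypothetical.

```python
import math


def coupled_stage(x_n, y_n, sx_i, sy_i, a_ii, delta_t):
    """Closed-form solution of the coupled stage equations."""
    alpha = delta_t * a_ii
    x_part = x_n + delta_t * sx_i
    y_part = y_n + delta_t * sy_i
    kx_i = -(y_part + alpha * x_part) / (1.0 + alpha**2)
    ky_i = (x_part - alpha * y_part) / (1.0 + alpha**2)
    return kx_i, ky_i


def implicit_euler_coupled(x0, y0, delta_t, num_steps):
    """Implicit Euler: one stage with a_11 = 1, b_1 = 1 and s_1 = 0."""
    x_n, y_n = x0, y0
    for _ in range(num_steps):
        kx, ky = coupled_stage(x_n, y_n, 0.0, 0.0, 1.0, delta_t)
        x_n += delta_t * kx
        y_n += delta_t * ky
    return x_n, y_n


x_num, y_num = implicit_euler_coupled(1.0, 1.0, 0.01, 1000)
t_end = 10.0
print(x_num, math.cos(t_end) - math.sin(t_end))
print(y_num, math.cos(t_end) + math.sin(t_end))
```

Implicit Euler damps the amplitude of this oscillation slightly in every step, so the numerical values are expected to be a little smaller in magnitude than the analytic ones.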
For OpenMDAO, the stage problem can be represented by two explicit components.

In [None]:
class CoupledODEComponent1(om.ExplicitComponent):
    def initialize(self):
        self.options.declare("integration_control", types=IntegrationControl)

    def setup(self):
        self.add_input("y_old", val=1.0, tags=["y", "step_input_var"])
        self.add_input("sy_i", val=0.0, tags=["y", "accumulated_stage_var"])
        self.add_input("ky_i", val=1.0)
        self.add_output("kx_i", val=1.0, tags=["x", "stage_output_var"])

    def setup_partials(self):
        self.declare_partials("*", "*")

    def compute(self, inputs, outputs):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        outputs["kx_i"] = -(inputs["y_old"] + delta_t * (inputs["sy_i"] + butcher_diagonal_element * inputs["ky_i"]))

    def compute_partials(self, inputs, partials):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        partials["kx_i", "y_old"] = -1.0
        partials["kx_i", "sy_i"] = -delta_t
        partials["kx_i", "ky_i"] = -delta_t * butcher_diagonal_element

class CoupledODEComponent2(om.ExplicitComponent):
    def initialize(self):
        self.options.declare("integration_control", types=IntegrationControl)

    def setup(self):
        self.add_input("x_old", val=1.0, tags=["x", "step_input_var"])
        self.add_input("sx_i", val=0.0, tags=["x", "accumulated_stage_var"])
        self.add_input("kx_i", val=1.0)
        self.add_output("ky_i", val=1.0, tags=["y", "stage_output_var"])

    def setup_partials(self):
        self.declare_partials("*", "*")

    def compute(self, inputs, outputs):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        outputs["ky_i"] = inputs["x_old"] + delta_t * (inputs["sx_i"] + butcher_diagonal_element * inputs["kx_i"])

    def compute_partials(self, inputs, partials):
        delta_t = self.options["integration_control"].delta_t
        butcher_diagonal_element = self.options["integration_control"].butcher_diagonal_element
        partials["ky_i", "x_old"] = 1.0
        partials["ky_i", "sx_i"] = delta_t 
        partials["ky_i", "kx_i"] = delta_t * butcher_diagonal_element

Note that the different time integration variables of the quantities are distributed across the components.
Let's apply the `RungeKuttaIntegrator` again and compare with the analytical solution.

In [None]:
integration_control = IntegrationControl(initial_time=0.0, num_steps=1000, delta_t=0.01)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp1",
    CoupledODEComponent1(integration_control=integration_control),
    promotes=["*"]
)
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp2",
    CoupledODEComponent2(integration_control=integration_control),
    promotes=["*"]
)

time_stage_problem.model.nonlinear_solver=om.NewtonSolver(solve_subsystems=True, iprint=-2)
time_stage_problem.model.linear_solver=om.DirectSolver(iprint=-2)

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem, 
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x", "y"],
        write_file="coupled_ODE_example.h5",
        write_out_distance=25,
    )
)

runge_kutta_problem.setup()
runge_kutta_problem.run_model()

fig, (ax1, ax2) = plt.subplots(1,2)

t = np.linspace(0.0, 10.0, 41)
ax1.plot(t, np.cos(t) - np.sin(t), label="Analytic solution of x")
ax2.plot(t, np.cos(t) + np.sin(t), label="Analytic solution of y")

numeric_data = np.zeros((41, 2))
with h5py.File("coupled_ODE_example.h5", mode="r") as f:
    for i in range(41):
        numeric_data[i][0] = f["x"][str(25*i)][0]
        numeric_data[i][1] = f["y"][str(25*i)][0]
ax1.plot(t, numeric_data[:,0], "o", label="Numeric solution of x")
ax2.plot(t, numeric_data[:,1], "o", label="Numeric solution of y")

ax1.set_xlabel("t")
ax1.set_xlim(0.0,10.0)
ax1.set_ylabel("x(t)")
ax1.set_ylim(-1.5, 1.5)
ax1.legend()

ax2.set_xlabel("t")
ax2.set_xlim(0.0,10.0)
ax2.set_ylabel("y(t)")
ax2.set_ylim(-1.5, 1.5)
ax2.legend()

plt.show()

# Further Options of the `RungeKuttaIntegrator`
## Postprocessing
The `RungeKuttaIntegrator` doesn't just have the ability to do time integration, it also can do time-step postprocessing.
An example use case could be the computation of a value derived from the quantities of the time integration, which isn't necessarily dependent on time (or at least not in the form of time derivatives).
While it is possible to "hack" the derived quantity into the time integration in most cases, it is also useful to have such an ability to postprocess, be it because it's less computationally intensive, or simply because it's easier to understand.

For postprocessing, we give a second OpenMDAO problem to the `RungeKuttaIntegrator`, which takes a subset of the time integration quantities as inputs, and from which we can get some new additional quantities. We extend the first example:

In [None]:
class LogComponent(om.ExplicitComponent):
    def setup(self):
        self.add_input("x", val = 1.0, tags = ["x", "postproc_input_var"])
        self.add_output("log_x", val = 0.0, tags = ["log_x" ,"postproc_output_var"])

    def setup_partials(self):
        self.declare_partials("*", "*")
        

    def compute(self, inputs, outputs):
        outputs["log_x"] = np.log(inputs["x"])

    def compute_partials(self, inputs, partials):
        partials["log_x", "x"] = 1 / inputs["x"]

integration_control = IntegrationControl(initial_time=0.0, num_steps=100, delta_t=0.01)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem("ODE_Comp", ODEComponent(integration_control=integration_control))

postproc_problem = om.Problem()
postproc_problem.model.add_subsystem("log_comp", LogComponent())

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem,
        postprocessing_problem=postproc_problem,
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x"],
        postprocessing_quantities=["log_x"],
        write_file="postproc_example.h5",
        write_out_distance=10,
    )
)

runge_kutta_problem.setup()
runge_kutta_problem.run_model()

fig, ax = plt.subplots(1,1)

t = np.linspace(0.0,1.0, 11)
ax.plot(t, np.exp(t), label="Analytic solution.")
ax.plot(t, t, label="Postprocessed analytic solution")

numeric_data = np.zeros(11)
postproc_data = np.zeros(11)
with h5py.File("postproc_example.h5", mode = "r") as f:
    for i in range(11):
        numeric_data[i] = f["x"][str(10*i)][0]
        postproc_data[i] = f["log_x"][str(10*i)][0]
ax.plot(t, numeric_data, "o", label="Numeric Solution")
ax.plot(t, postproc_data, "o", label="Postprocessed numeric solution")
ax.set_xlabel("t")
ax.set_xlim(0.0,1.0)
ax.set_ylabel("x(t)")
ax.set_ylim(0.0,3.0)
ax.legend()

plt.show()

print(runge_kutta_problem["rk_comp.log_x_final"])

We first defined the component that we put in our postprocessing problem.
Note that while here, this problem only has one component, it can be as complex as you want it to be.
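
To see what values to expect from the postprocessing, the example can be reproduced in plain Python. This is a sketch under the assumption that the postprocessing is applied to the initial value and to the state after every time step; it uses the closed-form implicit Euler stage for $x'(t) = x(t)$ instead of the library.

```python
import math

delta_t, num_steps = 0.01, 100

x_n = 1.0
log_x = [math.log(x_n)]  # postprocessing of the initial value
for _ in range(num_steps):
    k = x_n / (1.0 - delta_t)    # implicit Euler stage for x' = x
    x_n += delta_t * k           # new state after the time step
    log_x.append(math.log(x_n))  # postprocessing of the new state

# Implicit Euler gives x_n = (1 - delta_t)**(-n), so log(x_n) = -n * log(1 - delta_t).
print(log_x[-1], -num_steps * math.log(1.0 - delta_t))
```

The final value differs slightly from the analytic $\log(e^1) = 1$ because it is the logarithm of the numerical, not the analytic, solution.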

## Functionals
Another feature of the `RungeKuttaIntegrator` is the computation of functionals (or more precisely linear combinations) based on the quantities (both from time integration and postprocessing).
Note that one of these linear combinations only takes data from *one* quantity.
If you want to compute a functional involving two quantities, combine them first via the postprocessing.

To use this feature, you have to give a `FunctionalCoefficients` object to the `RungeKuttaIntegrator`.
Such an object needs implementations for two methods:
1. `list_quantities()`, which needs to return a list of strings.
These strings are supposed to be the quantities on which the instance of the `RungeKuttaIntegrator` is working.
2. `get_coefficient()`, which gets an `int` (for a time step) and a `str` (for a quantity), and returns a `float` (for the coefficient of the functional for the given quantity at the time step).

As an example, we extend the multidisciplinary example from above.
First we define the class for functional coefficients:

In [None]:
from typing import List
from rkopenmdao.functional_coefficients import FunctionalCoefficients

class MultidisciplinaryFunctionalCoefficients(FunctionalCoefficients):
    def __init__(self, num_steps: int, delta_t: float) -> None:
        self.num_steps = num_steps
        self.delta_t = delta_t

    def list_quantities(self) -> List[str]:
        return ["x", "y"]
    
    def get_coefficient(self, time_step: int, quantity: str) -> float:
        if quantity == "x":
            if time_step == 0 or time_step == self.num_steps:
                return 0.5 * self.delta_t
            return self.delta_t
        elif quantity == "y":
            return self.num_steps**-1

This is used to compute two functionals:
1. Over $x$, we want to compute an integral, and use a composite trapezoidal rule to approximate it.
2. Over $y$, we want to compute its average value over the domain (i.e. $[0,10]$)
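
A functional of this kind is the sum over all time steps of the coefficient times the value of the quantity at that step. The sketch below evaluates the trapezoidal coefficients on the analytic solution of $x$ in plain Python; `evaluate_functional` is a hypothetical helper written for this illustration and not a library function.

```python
import math

delta_t, num_steps = 0.01, 1000


def evaluate_functional(get_coefficient, quantity, values):
    """Linear combination of the values of one quantity over all time steps."""
    return sum(get_coefficient(step, quantity) * value for step, value in enumerate(values))


def trapezoidal_coefficient(time_step, quantity):
    # Composite trapezoidal rule: half weight at the first and last step.
    if time_step == 0 or time_step == num_steps:
        return 0.5 * delta_t
    return delta_t


# Analytic solution x(t) = cos(t) - sin(t) sampled at all time steps.
x_exact = [math.cos(n * delta_t) - math.sin(n * delta_t) for n in range(num_steps + 1)]
integral = evaluate_functional(trapezoidal_coefficient, "x", x_exact)
print(integral, math.sin(10.0) + math.cos(10.0) - 1.0)
```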

In [None]:
delta_t = 0.01
num_steps = 1000

integration_control = IntegrationControl(0.0, num_steps, delta_t)

functional_coefficients = MultidisciplinaryFunctionalCoefficients(num_steps, delta_t)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp1",
    CoupledODEComponent1(integration_control=integration_control),
    promotes=["*"]
)
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp2",
    CoupledODEComponent2(integration_control=integration_control),
    promotes=["*"]
)

time_stage_problem.model.nonlinear_solver=om.NewtonSolver(solve_subsystems=True, iprint=-2)
time_stage_problem.model.linear_solver=om.DirectSolver(iprint=-2)

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem, 
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x", "y"],
        functional_coefficients=functional_coefficients
    )
)

runge_kutta_problem.setup()
runge_kutta_problem.run_model()


print("numerical integral of numerical solution of x: ", runge_kutta_problem["rk_comp.x_functional"][0])
print("analytical integral of analytical solution of x: ", np.sin(num_steps * delta_t) + np.cos(num_steps * delta_t) - 1)
print("average of numerical solution of y: ", runge_kutta_problem["rk_comp.y_functional"][0])
print("average of analytical solution of y: ", np.sum([np.sin(delta_t*i) + np.cos(delta_t*i) for i in range(num_steps+1)])/(num_steps+1))

The computation of the functional is enabled by giving the instance of `FunctionalCoefficients` to the `functional_coefficients` argument of the `RungeKuttaIntegrator`.
This adds additional outputs to the `RungeKuttaIntegrator`, one per quantity over which a functional is computed.
The variable names of these outputs are of the form `<quantity>_functional`, and their values can be extracted from the OpenMDAO problem as usual.

To illustrate that it is possible to compute functionals only over a subset of the quantities in an integrator, and also to use the quantities from postprocessing, let's look at a second example (which is reusing the postprocessing example):

In [None]:
class NinetyNinthStepFunctionalCoefficients(FunctionalCoefficients):
    def list_quantities(self) -> List[str]:
        return ["log_x"]
    
    def get_coefficient(self, time_step: int, quantity: str) -> float:
        if quantity == "log_x" and time_step == 99:
            return 1.0
        else:
            return 0.0

integration_control = IntegrationControl(initial_time=0.0, num_steps=100, delta_t=0.01)

functional_coefficients = NinetyNinthStepFunctionalCoefficients()

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem("ODE_Comp", ODEComponent(integration_control=integration_control))

postproc_problem = om.Problem()
postproc_problem.model.add_subsystem("log_comp", LogComponent())

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem,
        postprocessing_problem=postproc_problem,
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x"],
        postprocessing_quantities=["log_x"],
        functional_coefficients=functional_coefficients,
    )
)

runge_kutta_problem.setup()
runge_kutta_problem.run_model()

print(runge_kutta_problem["rk_comp.log_x_functional"][0])


The "functional" coefficients used here take the value of the 99th time step of the quantity `"log_x"` and make it an output of the `RungeKuttaIntegrator`.
The quantity `"x"` doesn't get an additional functional output, as `NinetyNinthStepFunctionalCoefficients` doesn't list it in its `list_quantities`.

## Partial Derivatives and Checkpointing

Now to a feature which already was active, but which we didn't make use of yet:
The `RungeKuttaIntegrator` supports the calculation of derivatives, both in forward and reverse mode (under the condition that both the *time_stage_problem* and the optional *postprocessing_problem* support the given derivative mode with respect to the in/outputs that the `RungeKuttaIntegrator` interacts with).

We reuse the multidisciplinary example again, but with a twist:
this time, we won't just simply do the time integration, but we want to find $x_0$ and $y_0$ such that $(x_N-1)^2 + y_N^2$ is minimized.

In [None]:
from rkopenmdao.checkpoint_interface.pyrevolve_checkpointer import PyrevolveCheckpointer

integration_control = IntegrationControl(initial_time=0.0, num_steps=1000, delta_t=0.01)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp1",
    CoupledODEComponent1(integration_control=integration_control),
    promotes=["*"]
)
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp2",
    CoupledODEComponent2(integration_control=integration_control),
    promotes=["*"]
)

time_stage_problem.model.nonlinear_solver=om.NewtonSolver(solve_subsystems=True, iprint = -2)
time_stage_problem.model.linear_solver=om.DirectSolver(iprint = -2)

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem, 
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x", "y"],
        write_file="coupled_ODE_solving_example.h5",
        write_out_distance=25,
        checkpointing_type = PyrevolveCheckpointer
    )
)
class ObjComp(om.ExplicitComponent):
    def setup(self):
        self.add_input("x", val=1.0)
        self.add_input("y", val=1.0)
        self.add_output("sq", val=1.0)

    def compute(self, inputs, outputs):
        outputs["sq"] = (inputs["x"]-1)**2 + inputs["y"]**2

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode, discrete_inputs=None):
        if mode == "fwd":
            d_outputs["sq"] += 2 * (inputs["x"] - 1) * d_inputs["x"]
            d_outputs["sq"] += 2 * inputs["y"] * d_inputs["y"]
        elif mode == "rev":
            d_inputs["x"] += 2 * (inputs["x"] - 1) * d_outputs["sq"]
            d_inputs["y"] += 2 * inputs["y"] * d_outputs["sq"]


runge_kutta_problem.model.add_subsystem("square", ObjComp())

runge_kutta_problem.model.connect("rk_comp.x_final", "square.x")
runge_kutta_problem.model.connect("rk_comp.y_final", "square.y")

runge_kutta_problem.model.add_design_var("rk_comp.x_initial", lower=-1.0, upper=2.0)
runge_kutta_problem.model.add_design_var("rk_comp.y_initial", lower=-1.0, upper=2.0)
runge_kutta_problem.model.add_objective("square.sq")

runge_kutta_problem.driver = om.ScipyOptimizeDriver()
runge_kutta_problem.driver.options["optimizer"] = "SLSQP"
res_0 = np.zeros((2,2))
res_N = np.zeros((2,2))
res_obj = np.zeros(2)
for i, mode in enumerate(["fwd", "rev"]):
    runge_kutta_problem.setup(mode=mode)

    runge_kutta_problem["rk_comp.x_initial"] = 0.5
    runge_kutta_problem["rk_comp.y_initial"] = 0.5

    runge_kutta_problem.run_driver()

    # Problem variables are arrays of shape (1,); take the scalar entry explicitly.
    res_0[i,0] = runge_kutta_problem["rk_comp.x_initial"][0]
    res_0[i,1] = runge_kutta_problem["rk_comp.y_initial"][0]
    res_N[i,0] = runge_kutta_problem["rk_comp.x_final"][0]
    res_N[i,1] = runge_kutta_problem["rk_comp.y_final"][0]
    res_obj[i] = runge_kutta_problem["square.sq"][0]
for i, mode in enumerate(["fwd", "rev"]):
    print(mode + "_results:")
    print("x0 = ", res_0[i,0], " y_0 = ", res_0[i,1])
    print("xN = ", res_N[i,0], " y_N = ", res_N[i,1])
    print("objective = ", res_obj[i])


Both forward and reverse mode yield the same result (within tolerance).
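
The optimizer results can be compared against a reference computed by hand. Writing $z = x + \mathrm{i}y$, the system becomes $z' = \mathrm{i}z$, and one implicit Euler step is $z_{n+1} = z_n / (1 - \mathrm{i}\Delta t)$. The sketch below assumes that the integrator reproduces plain implicit Euler exactly; it computes the initial values for which $(x_N, y_N) = (1, 0)$, i.e. for which the objective is zero, and checks them by stepping forward again.

```python
delta_t, num_steps = 0.01, 1000

# One implicit Euler step for z' = i*z, with z = x + i*y.
step_factor = 1.0 / (1.0 - 1j * delta_t)

# Target final state (x_N, y_N) = (1, 0), for which the objective vanishes.
z_target = 1.0 + 0.0j
z_initial = z_target / step_factor**num_steps

# Verify by stepping forward from the computed initial values.
z = z_initial
for _ in range(num_steps):
    z *= step_factor
objective = (z.real - 1.0) ** 2 + z.imag**2

print("x0 =", z_initial.real, "y0 =", z_initial.imag, "objective =", objective)
```

Both reference initial values lie inside the bounds $[-1, 2]$ of the design variables, so the bounds should not be active at the optimum.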

There are two arguments in the `RungeKuttaIntegrator` to configure checkpointing:
1. `checkpointing_type`, which needs to be an implementation of the class `CheckpointInterface`. There are currently three implementations: `NoCheckpointer` (the default), which doesn't set checkpoints (and therefore doesn't work with reverse mode); `AllCheckpointer`, which sets checkpoints at all time steps; and `PyrevolveCheckpointer`, which uses [PyRevolve](https://github.com/devitocodes/PyRevolve) for checkpointing.
2. `checkpoint_options`, which is a dict of implementation-specific options that are passed to the checkpointer.
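
Why reverse mode needs checkpoints at all can be seen in a small stand-alone sketch. It is not the library's implementation: it uses explicit Euler on the nonlinear scalar ODE $x'(t) = -x(t)^2$ (chosen only for this illustration), stores every state during the forward run, and uses the stored states in the backward sweep to compute $\partial x_N / \partial x_0$. Storing everything corresponds to the strategy of setting checkpoints at all time steps.

```python
def forward(x0, delta_t, num_steps):
    """Explicit Euler for x' = -x**2, keeping every state as a checkpoint."""
    states = [x0]
    for _ in range(num_steps):
        x = states[-1]
        states.append(x - delta_t * x**2)
    return states


def reverse(states, delta_t):
    """Backward sweep: accumulate d x_N / d x_0 from the stored states."""
    adjoint = 1.0
    for x in reversed(states[:-1]):
        # Derivative of one step x -> x - delta_t * x**2 at the stored state.
        adjoint *= 1.0 - 2.0 * delta_t * x
    return adjoint


delta_t, num_steps, x0 = 0.01, 100, 1.0
states = forward(x0, delta_t, num_steps)
gradient = reverse(states, delta_t)

# Central finite difference as an independent check.
h = 1e-6
fd = (forward(x0 + h, delta_t, num_steps)[-1] - forward(x0 - h, delta_t, num_steps)[-1]) / (2 * h)
print(gradient, fd)
```

The backward sweep visits the states in reverse order, which is why they must either be stored or recomputed from a smaller set of checkpoints.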

To illustrate, we use the reverse mode example, but this time use the `AllCheckpointer`:

In [None]:
from rkopenmdao.checkpoint_interface.all_checkpointer import AllCheckpointer

integration_control = IntegrationControl(initial_time=0.0, num_steps=1000, delta_t=0.01)

time_stage_problem = om.Problem()
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp1",
    CoupledODEComponent1(integration_control=integration_control),
    promotes=["*"]
)
time_stage_problem.model.add_subsystem(
    "Coupled_ODE_Comp2",
    CoupledODEComponent2(integration_control=integration_control),
    promotes=["*"]
)

time_stage_problem.model.nonlinear_solver=om.NewtonSolver(solve_subsystems=True, iprint=-2)
time_stage_problem.model.linear_solver=om.DirectSolver(iprint=-2)

runge_kutta_problem = om.Problem()
runge_kutta_problem.model.add_subsystem(
    "rk_comp", 
    RungeKuttaIntegrator(
        time_stage_problem=time_stage_problem, 
        butcher_tableau=implicit_euler,
        integration_control=integration_control,
        time_integration_quantities=["x", "y"],
        checkpointing_type= AllCheckpointer
    )
)


runge_kutta_problem.model.add_subsystem("square", ObjComp())

runge_kutta_problem.model.connect("rk_comp.x_final", "square.x")
runge_kutta_problem.model.connect("rk_comp.y_final", "square.y")

runge_kutta_problem.model.add_design_var("rk_comp.x_initial",lower=-1.0, upper=2.0)
runge_kutta_problem.model.add_design_var("rk_comp.y_initial", lower=-1.0, upper=2.0)
runge_kutta_problem.model.add_objective("square.sq")

runge_kutta_problem.driver = om.ScipyOptimizeDriver()
runge_kutta_problem.driver.options["optimizer"] ="SLSQP"

runge_kutta_problem.setup()

runge_kutta_problem["rk_comp.x_initial"] = 0.5
runge_kutta_problem["rk_comp.y_initial"] = 0.5

runge_kutta_problem.run_driver()

print("results:")
print("x0 = ", runge_kutta_problem["rk_comp.x_initial"], " y_0 = ", runge_kutta_problem["rk_comp.y_initial"])
print("xN = ", runge_kutta_problem["rk_comp.x_final"], " y_N = ", runge_kutta_problem["rk_comp.y_final"])
print("objective = ", runge_kutta_problem["square.sq"])



# Further Examples

For further examples, look into `/examples`.
It contains examples for solving the heat equation via the finite-difference method and the `RungeKuttaIntegrator`, both in a monodisciplinary way and in a multidisciplinary way with a split domain.
Also, there are examples on how to use the `RungeKuttaIntegrator` with MPI support, as explained in the next section.

## MPI Support

The `RungeKuttaIntegrator` can also work with MPI parallel problems (both for time integration and postprocessing).
In that case, the in- and outputs of the integration component will mimic the structure of the variables of the quantities in the inner problems.
Both variants of MPI support from OpenMDAO are supported, meaning distributed variables and parallel groups.
However, some conditions must be fulfilled on the inner problems for it to work:

In general:
- Make sure that the communicators used by the `RungeKuttaIntegrator` instance, the time stage problem, and the postprocessing problem are compatible, i.e. that the ranks used match for all of them (the easiest way to achieve this is to use the same communicator for all of them).
Also, the different variables for the same quantity have to reside on the same rank(s), even across the different problems.
The `RungeKuttaIntegrator` will create its in- and outputs accordingly. In particular, this means the following.

For distributed variables:
- The distribution structure of the different variables for the same quantity must be the same, i.e. the variables with the `"step_input_var"`, `"accumulated_stage_var"` and `"stage_output_var"` for a quantity must be distributed in the same way, over the same ranks.
- Should you additionally use that quantity in a postprocessing problem, the variable with the `"postproc_input_var"` tag must share the same distribution structure.
- The in- and outputs of the `RungeKuttaIntegrator` instance will be created such that they also share the same distribution structure.

For parallel groups:
- Make sure that the in- and output variables really share the same ranks.
That can e.g. be achieved by manually creating `IndepVarComp` instances for the inputs, and grouping them together with the respective component before adding them into a parallel group.
- While quantities will only reside on certain ranks in the inner problem, there will be in/outputs on every rank in the `RungeKuttaIntegrator`.
But the size of the variables on all ranks but the ones where the quantity is found in the inner problem will be zero.
