# Joint Differential Equations

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_toolbox/joint_equations.ipynb)
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_toolbox/joint_equations.ipynb)

@[Xiaoyu Chen](mailto:c-xy17@tsinghua.org.cn)

In a [dynamical system](../tutorial_building/dynamical_systems.ipynb), there may be multiple variables that change over time. Sometimes these variables are interconnected, so that updating one variable requires the others as input. For example, in the well-known Hodgkin–Huxley model, the variables $V$, $m$, $h$, and $n$ are updated synchronously and interdependently (please refer to [Building Neuron Models](../tutorial_building/neuron_models.ipynb) for details). To achieve higher integration accuracy, it is recommended to use ``brainpy.JointEq`` to solve interconnected differential equations jointly.

In [None]:
import brainpy as bp

## ``brainpy.JointEq``

``brainpy.JointEq`` is used to merge individual but interconnected differential equations into a single joint equation. For example, below are the two differential equations of the Izhikevich model:

In [2]:
a, b = 0.02, 0.20
dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext
du = lambda u, t, V: a * (b * V - u)

Here, updating $V$ requires $u$ as input, and updating $u$ requires $V$ as input. The joint equation can be defined as:

In [3]:
joint_eq = bp.JointEq(dV, du)

``brainpy.JointEq`` takes the differential equations to be merged, either as separate positional arguments (as above) or as a single list or tuple (as in the Jansen-Rit example below). The resulting joint equation can then be passed to a numerical integrator that solves it with a specified method, just like any individual differential equation.

In [4]:
itg = bp.odeint(joint_eq, method='rk2')

There are several requirements for defining a joint equation:
1. Every individual differential equation should follow the format for defining an [ODE](ode_numerical_solvers.ipynb) or [SDE](sde_numerical_solvers.ipynb) function in BrainPy: the arguments before `t` denote the dynamical variables, and the arguments after `t` denote the parameters.
2. The same variable in different equations should have the same name, and different variables should be named differently.
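The merging and the two naming requirements above can be illustrated with a plain-Python sketch. `naive_joint_eq` and `split_at_t` are hypothetical helpers written for this illustration, not BrainPy's actual implementation: they only read parameter names with `inspect.signature`, reject a variable that appears before `t` in more than one function, and return a combined function `f(*variables, t, *arguments)`. Unlike ``brainpy.JointEq``, this sketch does not support nesting or default parameter values.

```python
import inspect

def split_at_t(func):
    """Split parameter names around ``t``: (variables before t, arguments after t)."""
    names = list(inspect.signature(func).parameters)
    i = names.index('t')
    return names[:i], names[i + 1:]

def naive_joint_eq(*funcs):
    """Hypothetical, simplified stand-in for ``brainpy.JointEq`` (not its real
    implementation). Returns f(*variables, t, *arguments) -> tuple of derivatives."""
    variables = []
    for func in funcs:
        for name in split_at_t(func)[0]:
            # Requirement 2: a variable may be differentiated by only one function.
            if name in variables:
                raise ValueError(f'{name!r} appears before t in more than one equation')
            variables.append(name)
    # Names after t that are not state variables become extra arguments.
    arguments = []
    for func in funcs:
        for name in split_at_t(func)[1]:
            if name not in variables and name not in arguments:
                arguments.append(name)
    arg_names = variables + ['t'] + arguments

    def joint(*values):
        env = dict(zip(arg_names, values))  # every argument must be passed explicitly
        derivatives = []
        for func in funcs:
            before, after = split_at_t(func)
            out = func(*[env[n] for n in before], env['t'], *[env[n] for n in after])
            # A function with several variables before t returns a tuple of derivatives.
            derivatives.extend(out if len(before) > 1 else [out])
        return tuple(derivatives)

    joint.arg_names = arg_names
    return joint

# Usage with the Izhikevich equations (written with def instead of lambda)
a, b = 0.02, 0.20

def dV(V, t, u, Iext):
    return 0.04 * V * V + 5 * V + 140 - u + Iext

def du(u, t, V):
    return a * (b * V - u)

f = naive_joint_eq(dV, du)
print(f.arg_names)                 # ['V', 'u', 't', 'Iext']
print(f(-65.0, -13.0, 0.0, 10.0))  # approximately (7.0, 0.0)
```

The argument order `['V', 'u', 't', 'Iext']` matches the signature of the joint integrator that BrainPy generates later in this tutorial (`V, u, t, Iext`).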

Note that `brainpy.JointEq` supports nesting: an instance of ``JointEq`` can itself be used as an element of a new ``JointEq``.

## Why use `brainpy.JointEq`?

Users may wonder why `brainpy.JointEq` is needed, since multiple differential equations can also be written in a single function:

In [5]:
def diff(V, u, t, Iext):
    dV = 0.04 * V * V + 5 * V + 140 - u + Iext
    du = a * (b * V - u)
    return dV, du

itg_V_u = bp.odeint(diff, method='rk2')

or simply packed into separate integrators:

In [6]:
int_V = bp.odeint(dV, method='rk2')
int_u = bp.odeint(du, method='rk2')

To illustrate the difference between joint and separate differential equations, let's look at the integration code that BrainPy generates for each.

If we create a separate numerical solver for each derivative function, the equations are solved independently:

In [7]:
bp.odeint(dV, method='rk2', show_code=True)

def brainpy_itg_of_ode4(V, t, u, Iext, dt=0.1):
  dV_k1 = f(V, t, u, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  return V_new

{'f': <function <lambda> at 0x12ef725c0>}



<brainpy.integrators.ode.explicit_rk.RK2 at 0x12ef5e630>

As the generated code shows, the RK2 method evaluates the derivative of $V$ twice. For the second evaluation (`dV_k2`), it uses the intermediate estimate of $V$ at time $t + \frac{2}{3}dt$ (`k2_V_arg`) together with the original value of $u$ at time $t$. This introduces a small error, because the values of $V$ and $u$ are taken at different times.
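To make this stale-value effect concrete, the separate updates can be re-implemented in plain Python. This is a sketch that mirrors the generated code above (stage at $\frac{2}{3}dt$, weights $\frac{1}{4}$ and $\frac{3}{4}$), not a call into BrainPy; it applies the same separate treatment to $u$, as a second integrator such as `int_u` would. The initial state (`V0 = -65`, `u0 = b * V0 = -13`, `Iext = 10`, `dt = 0.1`) and the 1000-sub-step reference are choices made here for illustration.

```python
# Izhikevich derivatives, as defined earlier in this tutorial
a, b = 0.02, 0.20

def dV(V, t, u, Iext):
    return 0.04 * V * V + 5 * V + 140 - u + Iext

def du(u, t, V):
    return a * (b * V - u)

def rk2_separate(V, u, t, Iext, dt):
    """Two independent RK2 updates: each equation holds the other variable
    at its value from the start of the step."""
    dV_k1 = dV(V, t, u, Iext)
    dV_k2 = dV(V + dt * dV_k1 * 2 / 3, t + dt * 2 / 3, u, Iext)  # old u
    du_k1 = du(u, t, V)
    du_k2 = du(u + dt * du_k1 * 2 / 3, t + dt * 2 / 3, V)        # old V
    V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
    u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
    return V_new, u_new

V0, u0, Iext, dt = -65.0, -13.0, 10.0, 0.1
V1, u1 = rk2_separate(V0, u0, 0.0, Iext, dt)

# Reference: the same scheme with 1000 much smaller sub-steps
V_ref, u_ref = V0, u0
h = dt / 1000
for i in range(1000):
    V_ref, u_ref = rk2_separate(V_ref, u_ref, i * h, Iext, h)

print(u1 - u0)     # change of u over one coarse step
print(u_ref - u0)  # change of u in the fine-step reference
```

Because $du/dt$ is zero at this initial state and $u$ only ever sees the value of $V$ from the start of the step, the coarse step leaves $u$ unchanged, while the reference shows $u$ starting to rise as $V$ increases.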

To eliminate this error, the differential equations of $V$ and $u$ should be solved jointly with `brainpy.JointEq`:

In [8]:
eq = bp.JointEq(dV, du)
bp.odeint(eq, method='rk2', show_code=True)

def brainpy_itg_of_ode5_joint_eq(V, u, t, Iext, dt=0.1):
  dV_k1, du_k1 = f(V, u, t, Iext)
  k2_V_arg = V + dt * dV_k1 * 0.6666666666666666
  k2_u_arg = u + dt * du_k1 * 0.6666666666666666
  k2_t_arg = t + dt * 0.6666666666666666
  dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
  V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
  u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
  return V_new, u_new

{'f': <brainpy.integrators.joint_eq.JointEq object at 0x12ef5e7b0>}



<brainpy.integrators.ode.explicit_rk.RK2 at 0x12ef5e3f0>

In this generated code, the second-stage derivatives of $V$ and $u$ are both computed from intermediate values taken at the same time (`k2_V_arg` and `k2_u_arg`). This results in a more accurate integration.
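The joint update can be sketched the same way in plain Python. Again, this mirrors the generated code above rather than calling BrainPy, and the initial state (`V0 = -65`, `u0 = -13`, `Iext = 10`, `dt = 0.1`) and the 1000-sub-step reference are assumptions chosen for illustration.

```python
# Izhikevich derivatives, as defined earlier in this tutorial
a, b = 0.02, 0.20

def dV(V, t, u, Iext):
    return 0.04 * V * V + 5 * V + 140 - u + Iext

def du(u, t, V):
    return a * (b * V - u)

def f(V, u, t, Iext):
    """Joint right-hand side: both derivatives evaluated at the same (V, u, t)."""
    return dV(V, t, u, Iext), du(u, t, V)

def rk2_joint(V, u, t, Iext, dt):
    """One joint RK2 step, mirroring the generated code above."""
    dV_k1, du_k1 = f(V, u, t, Iext)
    k2_V_arg = V + dt * dV_k1 * 2 / 3
    k2_u_arg = u + dt * du_k1 * 2 / 3
    k2_t_arg = t + dt * 2 / 3
    dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext)
    V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75
    u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75
    return V_new, u_new

V0, u0, Iext, dt = -65.0, -13.0, 10.0, 0.1
V1, u1 = rk2_joint(V0, u0, 0.0, Iext, dt)

# Reference: the same joint scheme with 1000 much smaller sub-steps
V_ref, u_ref = V0, u0
h = dt / 1000
for i in range(1000):
    V_ref, u_ref = rk2_joint(V_ref, u_ref, i * h, Iext, h)

print(abs(u1 - u_ref))  # error of u after one coarse joint step
```

For this state, the coarse joint step tracks the reference value of $u$ with an error on the order of $10^{-6}$, because the second stage of $u$ now uses the intermediate value of $V$.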

## Second-Order ODEs with `brainpy.JointEq`

A common use case for `JointEq` is solving second-order ordinary differential equations (ODEs). Second-order ODEs appear in many physical systems, such as the harmonic oscillator, pendulum, or neural mass models like the Jansen-Rit model.

When using `JointEq` for second-order ODEs, it's important to follow the correct function signature pattern.

### Example: Harmonic Oscillator

Consider a damped harmonic oscillator (with unit mass) described by:

$$\frac{d^2x}{dt^2} = -kx - c\frac{dx}{dt}$$

To solve this with `JointEq`, we split it into two first-order ODEs:

$$\frac{dx}{dt} = v$$
$$\frac{dv}{dt} = -kx - cv$$

Here, $x$ is the position and $v$ is the velocity.

In [9]:
import brainpy as bp
import brainpy.math as bm

# Parameters
k = 1.0  # spring constant
c = 0.1  # damping coefficient

# Define derivative functions
# IMPORTANT: Each state variable appears as the FIRST parameter before 't'
# Other state variables appear AFTER 't' as dependencies
def dx(x, t, v):
    """dx/dt = v"""
    return v

def dv(v, t, x):
    """dv/dt = -k*x - c*v"""
    return -k * x - c * v

# Create joint equation
joint_eq = bp.JointEq(dx, dv)
print(f"Joint equation signature: {joint_eq.__signature__}")

Joint equation signature: (x, v, t)
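As a sanity check of the first-order split, the pair `dx`/`dv` can be integrated jointly and compared with the closed-form solution of the underdamped oscillator, $x(t) = e^{-\zeta t}\left(\cos\omega t + \frac{\zeta}{\omega}\sin\omega t\right)$ with $\zeta = c/2$ and $\omega = \sqrt{k - \zeta^2}$. This plain-Python sketch assumes the classical RK4 method, a step of `dt = 0.01`, and the initial condition $x(0) = 1$, $v(0) = 0$; in BrainPy one would instead pass `joint_eq` to `bp.odeint`.

```python
import math

k, c = 1.0, 0.1  # same parameters as above

def dx(x, t, v):
    return v

def dv(v, t, x):
    return -k * x - c * v

def rk4_joint_step(x, v, t, dt):
    """Classical RK4 applied to (x, v) jointly: every stage evaluates dx and dv
    at the same intermediate (x, v)."""
    kx1, kv1 = dx(x, t, v), dv(v, t, x)
    x2, v2 = x + 0.5 * dt * kx1, v + 0.5 * dt * kv1
    kx2, kv2 = dx(x2, t + 0.5 * dt, v2), dv(v2, t + 0.5 * dt, x2)
    x3, v3 = x + 0.5 * dt * kx2, v + 0.5 * dt * kv2
    kx3, kv3 = dx(x3, t + 0.5 * dt, v3), dv(v3, t + 0.5 * dt, x3)
    x4, v4 = x + dt * kx3, v + dt * kv3
    kx4, kv4 = dx(x4, t + dt, v4), dv(v4, t + dt, x4)
    x_new = x + dt / 6 * (kx1 + 2 * kx2 + 2 * kx3 + kx4)
    v_new = v + dt / 6 * (kv1 + 2 * kv2 + 2 * kv3 + kv4)
    return x_new, v_new

def exact_x(t, x0=1.0):
    """Closed-form position of the underdamped oscillator started at rest (v0 = 0)."""
    zeta = c / 2
    omega = math.sqrt(k - zeta ** 2)
    return x0 * math.exp(-zeta * t) * (math.cos(omega * t) + zeta / omega * math.sin(omega * t))

x, v, t, dt = 1.0, 0.0, 0.0, 0.01
for _ in range(1000):  # integrate up to t = 10
    x, v = rk4_joint_step(x, v, t, dt)
    t += dt

print(abs(x - exact_x(t)))  # absolute error at t = 10
```

Each RK4 stage evaluates `dx` and `dv` at the same intermediate $(x, v)$, which is what joining the two equations guarantees.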


### Important: Function Signature Pattern

When defining derivative functions for `JointEq`, follow this pattern:

**Correct:**
```python
def dx(x, t, v):  # x is the state variable, v is a dependency
    return v

def dv(v, t, x):  # v is the state variable, x is a dependency
    return -k * x - c * v
```

**Incorrect:**
```python
def dx(x, v, t):  # WRONG: v before t is treated as a variable of dx
    return v
```

**Rule:** Each state variable should appear **before** `t` in exactly one derivative function, namely the function that computes its derivative (in these examples, it is always the first parameter). If a variable is needed as a dependency in another function, it should be placed **after** `t`.

This ensures that `JointEq` knows which variable each function is differentiating and which variables are dependencies.

### Example: Jansen-Rit Model

The Jansen-Rit model is a neural mass model with three coupled second-order ODEs. Here's how to implement it correctly with `JointEq`:

In [10]:
class JansenRitModel(bp.dyn.NeuDyn):
    def __init__(self, size=1, A=3.25, te=10, B=22, ti=20, C=135, 
                 e0=2.5, r=0.56, v0=6, method='rk4', **kwargs):
        super().__init__(size=size, **kwargs)
        self.A, self.te = A, te
        self.B, self.ti = B, ti
        self.C = C
        self.e0, self.r, self.v0 = e0, r, v0
        
        # State variables: positions (y0, y1, y2) and velocities (y3, y4, y5)
        self.y0 = bm.Variable(bm.zeros(self.num))
        self.y1 = bm.Variable(bm.zeros(self.num))
        self.y2 = bm.Variable(bm.zeros(self.num))
        self.y3 = bm.Variable(bm.zeros(self.num))  # velocity for y0
        self.y4 = bm.Variable(bm.zeros(self.num))  # velocity for y1
        self.y5 = bm.Variable(bm.zeros(self.num))  # velocity for y2
        
        self.integral = bp.odeint(f=self.derivative, method=method)
    
    # Position derivatives: dx/dt = v
    def dy0(self, y0, t, y3):  # y0 is state, y3 is dependency
        return y3 / 1000
    
    def dy1(self, y1, t, y4):  # y1 is state, y4 is dependency
        return y4 / 1000
    
    def dy2(self, y2, t, y5):  # y2 is state, y5 is dependency
        return y5 / 1000
    
    # Velocity derivatives: dv/dt = ...
    def dy3(self, y3, t, y0, y1, y2):  # y3 is state, others are dependencies
        Sp = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - y1 + y2)))
        return (self.A * Sp - 2 * y3 - y0 / self.te * 1000) / self.te
    
    def dy4(self, y4, t, y0, y1, inp=0.):  # y4 is state, others are dependencies
        Se = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - self.C * y0)))
        return (self.A * (inp + 0.8 * self.C * Se) - 2 * y4 - y1 / self.te * 1000) / self.te
    
    def dy5(self, y5, t, y0, y2):  # y5 is state, others are dependencies
        Si = 2 * self.e0 / (1 + bm.exp(self.r * (self.v0 - 0.25 * self.C * y0)))
        return (self.B * 0.25 * self.C * Si - 2 * y5 - y2 / self.ti * 1000) / self.ti
    
    @property
    def derivative(self):
        # Join all derivatives - order matches the state variables
        return bp.JointEq([self.dy0, self.dy1, self.dy2, self.dy3, self.dy4, self.dy5])
    
    def update(self, inp=0.):
        y0, y1, y2, y3, y4, y5 = self.integral(
            self.y0, self.y1, self.y2, self.y3, self.y4, self.y5,
            bp.share['t'], inp, bp.share['dt']
        )
        self.y0.value = y0
        self.y1.value = y1
        self.y2.value = y2
        self.y3.value = y3
        self.y4.value = y4
        self.y5.value = y5

# Create and test the model
model = JansenRitModel(size=1)
print("Jansen-Rit model created successfully!")

Jansen-Rit model created successfully!
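For reference, the derivative methods of `JansenRitModel` implement the following first-order system, transcribed term by term from the code above, with $p$ standing for `inp`, $t_e$ and $t_i$ for `te` and `ti`, and $S$ for the sigmoid used in `dy3`, `dy4`, and `dy5`. Each position $y_0, y_1, y_2$ is paired with a velocity $y_3, y_4, y_5$, the same split used for the harmonic oscillator.

$$
\begin{aligned}
\frac{dy_0}{dt} &= \frac{y_3}{1000}, \qquad \frac{dy_1}{dt} = \frac{y_4}{1000}, \qquad \frac{dy_2}{dt} = \frac{y_5}{1000},\\
\frac{dy_3}{dt} &= \frac{1}{t_e}\left(A\,S(y_1 - y_2) - 2y_3 - \frac{1000\,y_0}{t_e}\right),\\
\frac{dy_4}{dt} &= \frac{1}{t_e}\left(A\left(p + 0.8\,C\,S(C\,y_0)\right) - 2y_4 - \frac{1000\,y_1}{t_e}\right),\\
\frac{dy_5}{dt} &= \frac{1}{t_i}\left(0.25\,B\,C\,S(0.25\,C\,y_0) - 2y_5 - \frac{1000\,y_2}{t_i}\right),\\
S(v) &= \frac{2e_0}{1 + \exp\left(r\,(v_0 - v)\right)}.
\end{aligned}
$$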


Returning to the Izhikevich model, the figure below compares simulation results obtained with joint and separate differential equations ($dt = 0.2$ ms). As the simulation time increases, the discrepancy between the two results, caused by the integration error of the separate equations, grows larger.

<img src="../_static/joint_and_separate_equations.png" width="900 px">