# Solving Ordinary Differential Equations (ODEs) numerically in Python
### -MC Physics [Updated 02/2024]

For the most part, we can (and do) express the basic laws of physics in terms of differential equations. Much of your undergraduate physics education deals with analytical methods for dealing with these equations, but in many cases we have to resort to numerics to make headway. Even in cases where analytical solutions exist, we can often gain some insight by looking at visual representations of numerical solutions.

This notebook introduces the basics of solving ordinary differential equations numerically in Python using SciPy's `odeint` function. Other differential equation solvers are also available (e.g. https://pythonnumericalmethods.berkeley.edu/notebooks/chapter22.06-Python-ODE-Solvers.html).

In [None]:
%pip install matplotlib scipy  # for plotting and integration, if not already installed

In [None]:
# Imports for this tutorial
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint

%matplotlib inline

# Simplest use cases

## First order scalar equations

Courses in differential equations usually start with first order equations with constant coefficients, like $$\frac{dy}{dx} + \frac{1}{2}y = \frac{3}{2},$$ which has the solution $$y = 3+ce^{-x/2}.$$  There is one constant of integration, so we can specify one boundary (or initial) condition.

The procedure for solving this (or any, really) ODE has the following steps:
1. Rewrite the equation isolating the derivative on the left hand side.
2. Write a python function for the right hand side.
3. Set up the values for the independent variable.
4. Set initial conditions.
5. Integrate.

Step 1 looks like this:
$$\frac{dy}{dx} = -\frac{1}{2}y + \frac{3}{2}.$$

The remaining steps are done in Python.

In [None]:
# write a python function for the RHS
def rhs(y, x):
    """Right hand side for our differential equation."""
    return -0.5*y + 1.5


# values for independent variable (x in this case)
x = np.linspace(0, 10, 1000)

# initial conditions
y0 = 1.0

# integrate
result = odeint(rhs, y0, x)

The easiest way to look at the results is with a plot.

In [None]:
plt.plot(x, result)
plt.xlabel("x")
plt.ylabel("y")
plt.title("solution to $dy/dx = -1/2 y + 3/2$ with $y(0) = 1$");
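Because this equation has the analytical solution $y = 3 + ce^{-x/2}$, and $y(0) = 1$ fixes $c = -2$, we can check how closely `odeint` tracks the exact answer. A minimal sketch (the names `numerical`, `analytical`, and `max_error` are just illustrative):

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, x):
    """Right hand side of dy/dx = -y/2 + 3/2."""
    return -0.5*y + 1.5

x = np.linspace(0, 10, 1000)
y0 = 1.0
# odeint returns shape (1000, 1) for a single equation; take the one column
numerical = odeint(rhs, y0, x)[:, 0]

# y = 3 + c e^{-x/2} with y(0) = 1, so c = -2
analytical = 3 - 2*np.exp(-x/2)

max_error = np.max(np.abs(numerical - analytical))
print("largest difference:", max_error)
```

The difference should be tiny compared with $y$ itself, which is a quick sanity check worth repeating whenever an exact solution is available.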

Solving an ODE with `odeint` follows essentially this same pattern every time, so you should be able to copy and paste (with modifications) the cells above into your own work.

# Check #1:
#### Use odeint to solve the differential equation $y^\prime = 3 x^2 y + y$ (from the homework) with the initial condition $y = 2$ at $x = 0$, and plot the answer. 
#### Feel free to copy and paste/modify the code above and ask your instructor or neighbors for help. 
#### Warning: you may need to change the maximum range of x in the `np.linspace()` command. Use Jupyter's built-in help (e.g. run `np.linspace?`) and/or Google. 

In [None]:
# write a python function for the RHS
def rhs(y, x):
    """Right hand side for our differential equation."""
    return y/(y+1)


# values for independent variable (x in this case)
x = np.linspace(1, 10, 1000)

# initial conditions
y0 = 1.0

# integrate
result = odeint(rhs, y0, x)
plt.plot(x, result)
plt.xlabel("x")
plt.ylabel("y")
plt.title("solution to $dy/dx = y/(y+1)$ with $y(1) = 1$");

## Second order scalar equations
There are two important adjustments we need to make for this to be really useful in a physics context.

So far we have only solved first order equations, and `odeint` only integrates first order equations (or systems of them) of the form $dy/dt = f(y, t)$. However, the ODEs we encounter in physics are, generally speaking, second order and above. So, what do we do?

Your differential equations class should have taught you that you can write an $n$th order ODE as a set of $n$ first order ODEs. For example, Newton's second law in one dimension, $$\frac{d^2x}{dt^2} = \frac{F}{m}$$ is second order, but can be written as a pair of first order equations, namely, $$\frac{dx}{dt} = v$$ and $$\frac{dv}{dt} = \frac{F}{m}.$$
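As a worked sketch of the general case (the names $u_1, \dots, u_n$ and $f$ are our own notation), take an $n$th order equation $y^{(n)} = f(x, y, y^\prime, \dots, y^{(n-1)})$ and introduce new variables $u_1 = y,\ u_2 = y^\prime,\ \dots,\ u_n = y^{(n-1)}$. The single equation becomes the first order system

$$
\begin{aligned}
\frac{du_1}{dx} &= u_2, \\
\frac{du_2}{dx} &= u_3, \\
&\;\;\vdots \\
\frac{du_{n-1}}{dx} &= u_n, \\
\frac{du_n}{dx} &= f(x, u_1, u_2, \dots, u_n).
\end{aligned}
$$

Newton's second law is the case $n = 2$ with independent variable $t$, $u_1 = x$, $u_2 = v$, and $f = F/m$, which gives back the pair of equations above.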

We can make use of this idea by making one small change to the pattern laid out above: let $y$ be a container for the two scalar variables $x$ and $v$.  To get specific, let's look at a mass oscillating on a spring. Then $F = -kx$. We will assume $m = k = 1$, but I'll show you where they would go into our python function.

In [None]:
# write a python function for the RHS
def rhs(y, t):
    """Right hand side for our differential equation."""
    # mass and spring constant
    m = 1
    k = 1
    
    # unpack y
    x,v = y
    
    # calculate the force
    F = -k*x
    
    # return dx/dt and dv/dt
    return v, F/m


# values for independent variable (t in this case)
t = np.linspace(0, 10, 1000)

# initial conditions
y0 = (1.0, 0)

# integrate
result = odeint(rhs, y0, t)

Again, we can plot to see the results. Note that now we have results for both $x(t)$ and $v(t)$; they are returned as the two columns of a 2D array, which is why the plot statements use `result[:,0]` for $x$ and `result[:,1]` for $v$. Since we're plotting two quantities with different units, I have chosen not to label the $y$-axis of this plot, but instead to use a legend.

In [None]:
plt.plot(t, result[:,0], label="x (m)")
plt.plot(t, result[:,1], label="v (m/s)")
plt.xlabel("t")
leg = plt.legend(loc="upper left")
leg.set_bbox_to_anchor((1,1))
plt.title("Harmonic oscillator, m=k=1");
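With $m = k = 1$, $x(0) = 1$ and $v(0) = 0$, the exact solution is $x = \cos t$, $v = -\sin t$, and the energy $\frac{1}{2}mv^2 + \frac{1}{2}kx^2$ should stay fixed at $\frac{1}{2}$. A quick sketch of both checks (the variable names are illustrative):

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, t):
    """Harmonic oscillator with m = k = 1: returns (dx/dt, dv/dt)."""
    x, v = y
    return v, -x

t = np.linspace(0, 10, 1000)
result = odeint(rhs, (1.0, 0.0), t)
x_num, v_num = result[:, 0], result[:, 1]   # columns are x(t) and v(t)

# compare with the exact solution for x(0) = 1, v(0) = 0
x_err = np.max(np.abs(x_num - np.cos(t)))
v_err = np.max(np.abs(v_num + np.sin(t)))

# total energy (1/2) m v^2 + (1/2) k x^2 should stay at 0.5
energy = 0.5*v_num**2 + 0.5*x_num**2
energy_drift = np.max(np.abs(energy - 0.5))
print(x_err, v_err, energy_drift)
```

Energy drift is a useful diagnostic even when no exact solution is known, since a conservative system should not gain or lose energy.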

Although I've illustrated this here with a second order equation, this technique will work with any order ODE.
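For instance, a third order equation just needs a three-component state. As an illustrative sketch (the equation $y^{\prime\prime\prime} = 6$ is our own choice, with primes denoting $d/dt$; with $y(0) = y^\prime(0) = y^{\prime\prime}(0) = 0$ its exact solution is $y = t^3$):

```python
import numpy as np
from scipy.integrate import odeint

def rhs(state, t):
    """y''' = 6 written as three first order equations."""
    y, dy, d2y = state            # unpack y, y', y''
    # derivatives of (y, y', y'') are (y', y'', y''')
    return dy, d2y, 6.0

t = np.linspace(0, 2, 201)
result = odeint(rhs, (0.0, 0.0, 0.0), t)   # columns: y, y', y''
y = result[:, 0]

max_error = np.max(np.abs(y - t**3))
print(max_error)
```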

# Check #2: 
#### Use the technique above to solve and graph the position and velocity of an object in free fall ($F = -mg$), where $g = 9.8~\mathrm{m/s^2}$.

## Multiple initial conditions

Often in studying differential equations, we want to see the behavior of *families* of solutions, that is, collections of solutions with different initial conditions or different values for some key parameter. It is (relatively) easy to do either of these things, but it does require a small modification to how we set up the problem and plot the results.

Let's revisit the first equation we solved in this tutorial: $$\frac{dy}{dx} = -\frac{1}{2}y + \frac{3}{2},$$ and solve it with 10 different initial conditions, equally spaced between 0 and 6.

In [None]:
# write a python function for the RHS
def rhs(y, x):
    """Right hand side for our differential equation."""
    return -0.5*y + 1.5


# values for independent variable (x in this case)
x = np.linspace(0, 10, 1000)

# initial conditions
y0 = np.linspace(0,6,10)

# integrate
result = odeint(rhs, y0, x)

In [None]:
result.shape

Our `result` has shape `(1000, 10)`: each of the 10 columns is the solution for one initial condition, evaluated at the 1000 values of $x$.
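Column `idx` of `result` is the solution that started at `y0[idx]`, and each one should equal $y = 3 + (y_0 - 3)e^{-x/2}$. A minimal sketch that checks all ten columns at once using NumPy broadcasting (the name `exact` is illustrative):

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, x):
    """Right hand side of dy/dx = -y/2 + 3/2 (works elementwise on arrays)."""
    return -0.5*y + 1.5

x = np.linspace(0, 10, 1000)
y0 = np.linspace(0, 6, 10)
result = odeint(rhs, y0, x)          # shape (1000, 10)

# x[:, None] has shape (1000, 1) and y0[None, :] has shape (1, 10),
# so the exact solutions broadcast to shape (1000, 10), matching result
exact = 3 + (y0[None, :] - 3)*np.exp(-x[:, None]/2)
print(result.shape, np.max(np.abs(result - exact)))
```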

In [None]:
for idx, ic in enumerate(y0):
    plt.plot(x, result[:,idx], label="$y_0 =$" + "{:4.2f}".format(ic))
plt.xlabel("x")
plt.ylabel("y")
leg = plt.legend(loc="upper left")
leg.set_bbox_to_anchor((1,1))
plt.title("solutions to $dy/dx = -1/2 y + 3/2$");

# Check #3: Carefully explain the graph above:

# Check #4: Using what you learned above, use odeint to solve another ODE we've solved in this class

# If you finish all the checkpoints above before the end of class, please continue working through the following examples. If you do not finish these examples before the end of class, you may submit the assignment after completing Check #4. 

For the mass on a spring, each initial condition has two values: $x_0$ and $v_0$. Because `odeint` requires the initial condition `y0` to be one-dimensional, we can't simply pass an array of $(x_0, v_0)$ pairs the way we did above. Instead, we can make a list of initial conditions and loop through it, integrating (and then plotting) them one at a time.

In [None]:
# write a python function for the RHS
def rhs(y, t):
    """Right hand side for our differential equation."""
    # mass and spring constant
    m = 1
    k = 1
    
    # unpack y
    x,v = y
    
    # calculate the force
    F = -k*x
    
    # return dx/dt and dv/dt
    return v, F/m


# values for independent variable (t in this case)
t = np.linspace(0, 10, 1000)

# initial conditions
x0 = np.linspace(-1,1,10)
v0 = np.ones_like(x0)
y0 = list(zip(x0, v0))  # list of (x0, v0) pairs, one per initial condition

# integrate and plot
for ic in y0:
    result = odeint(rhs, ic, t)
    # just plot x; plotting v also would be too busy
    plt.plot(t, result[:,0], label="$x_0 = $" + "{:4.2f}".format(ic[0]))

plt.xlabel("t")
leg = plt.legend(loc="upper left")
leg.set_bbox_to_anchor((1,1))
plt.title("Harmonic oscillator, m=k=1");

In [None]:
y0
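As an alternative to the loop above (not the approach used in the rest of this notebook), you can keep a single `odeint` call by stacking all positions and then all velocities into one 1-D state vector of length $2N$. A sketch under that assumption, where `stacked_rhs` is a hypothetical helper and the exact solution $x = x_0\cos t + v_0\sin t$ (for $m = k = 1$) is used as a check:

```python
import numpy as np
from scipy.integrate import odeint

def stacked_rhs(state, t):
    """Uncoupled oscillators (m = k = 1), state = [x_1, ..., x_N, v_1, ..., v_N]."""
    n = len(state) // 2
    x, v = state[:n], state[n:]
    # dx_j/dt = v_j and dv_j/dt = -x_j, returned in the same stacked order
    return np.concatenate([v, -x])

t = np.linspace(0, 10, 1000)
x0 = np.linspace(-1, 1, 10)
v0 = np.ones_like(x0)
n = len(x0)

result = odeint(stacked_rhs, np.concatenate([x0, v0]), t)   # shape (1000, 20)
x_all = result[:, :n]   # column j is x(t) for the pair (x0[j], v0[j])

# exact solution x = x0 cos t + v0 sin t, broadcast to shape (1000, 10)
exact = x0*np.cos(t)[:, None] + v0*np.sin(t)[:, None]
max_error = np.max(np.abs(x_all - exact))
print(max_error)
```

The trade-off is bookkeeping: the loop is easier to read, while the stacked version needs careful indexing but integrates every case in one call.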

## Additional arguments to the `rhs()` function

In the mass on a spring case, we may want to compare various values of the spring constant $k$ (or the mass, but it's really the ratio of the two that matters, so we can hold one fixed and just vary the other).  In the function we defined above, $k$ is constant. We could write a separate function for each value of $k$ we want to examine, but that would violate the principle of having the computer do as much of the work (and especially tedious work) as possible.

Instead, let's rewrite our `rhs` function to take parameters.

In [None]:
def rhs(y, t, k=1, m=1):
    """Right hand side for our differential equation.
    
    :param y: iterable with x, xdot
    :param t: independent variable. Not actually used here, but necessary for odeint
    :param k: spring constant
    :param m: mass
    
    :returns: tuple containing dx/dt, d2x/dt2
    """
    
    # unpack y
    x,v = y
    
    # calculate the force
    F = -k*x
    
    # return dx/dt and dv/dt
    return v, F/m

Now we're faced with a problem. How do we get the extra parameters into `rhs` when it's being called by `odeint`? The answer: `odeint` accepts an `args` keyword argument, a tuple of extra values that it passes to `rhs` after `y` and `t`.

In [None]:
# values for independent variable (t in this case)
t = np.linspace(0, 10, 1000)

# initial conditions
y0 = (1.0, 0)

# extra argument for rhs: k = 3.0 (m keeps its default value of 1)
# the trailing comma is necessary to make this a tuple
parameters = (3.0,)

# integrate
result = odeint(rhs, y0, t, args=parameters)

# plot results
plt.plot(t, result[:,0], label="x (m)")
plt.plot(t, result[:,1], label="v (m/s)")
plt.xlabel("t")
leg = plt.legend(loc="upper left")
leg.set_bbox_to_anchor((1,1))
plt.title("Harmonic oscillator, m=1, k={:4.2f}".format(parameters[0]));
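Since `odeint` passes the entries of `args` positionally after `y` and `t`, giving two values sets both `k` and `m`, in the order they appear in the signature of `rhs`. The sketch below (the values $k = 3$, $m = 2$ are arbitrary) checks the result against the exact solution $x = \cos(\omega t)$ with $\omega = \sqrt{k/m}$, and also confirms the earlier remark that only the ratio $k/m$ matters by comparing with $k = 1.5$, $m = 1$:

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, t, k=1, m=1):
    """Mass on a spring: returns (dx/dt, dv/dt)."""
    x, v = y
    F = -k*x
    return v, F/m

t = np.linspace(0, 10, 1000)
y0 = (1.0, 0)

# odeint calls rhs(y, t, *args): here rhs(y, t, 3.0, 2.0), i.e. k = 3, m = 2
result_km = odeint(rhs, y0, t, args=(3.0, 2.0))
# same ratio k/m = 1.5, with m left at its default of 1
result_k = odeint(rhs, y0, t, args=(1.5,))

omega = np.sqrt(3.0/2.0)   # angular frequency sqrt(k/m)
err_exact = np.max(np.abs(result_km[:, 0] - np.cos(omega*t)))
err_ratio = np.max(np.abs(result_km[:, 0] - result_k[:, 0]))
print(err_exact, err_ratio)
```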

To compare multiple values, we would again create a list and loop through it.

In [None]:
# values for independent variable (t in this case)
t = np.linspace(0, 10, 1000)

# initial conditions
y0 = (1.0, 0)

# one single-element tuple (k,) per spring constant; the trailing comma makes each a tuple
paramslist = [(kval,) for kval in np.linspace(1.0, 5.0, 5)]

# integrate and plot x(t) for each value of k
for parameters in paramslist:
    result = odeint(rhs, y0, t, args=parameters)
    plt.plot(t, result[:,0], label="k = {:4.2f}".format(parameters[0]))


# plot results
plt.ylabel("x")
plt.xlabel("t")
leg = plt.legend(loc="upper left")
leg.set_bbox_to_anchor((1,1))
plt.title("Harmonic oscillator, m=1");
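The plot suggests that larger $k$ means faster oscillation; for this system the period should be $T = 2\pi\sqrt{m/k}$. As a rough sketch (the zero-crossing estimate is our own approach, not something `odeint` provides), we can measure the period from the times at which $x$ changes sign and compare:

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, t, k=1, m=1):
    """Mass on a spring: returns (dx/dt, dv/dt)."""
    x, v = y
    return v, -k*x/m

# a longer time range than above so every k gives several zero crossings
t = np.linspace(0, 20, 4001)
y0 = (1.0, 0)
periods = []
for kval in np.linspace(1.0, 5.0, 5):
    x = odeint(rhs, y0, t, args=(kval,))[:, 0]
    # indices i where x changes sign between t[i] and t[i+1]
    i = np.where(np.sign(x[:-1]) != np.sign(x[1:]))[0]
    # linearly interpolate to estimate each crossing time
    crossings = t[i] - x[i]*(t[i+1] - t[i])/(x[i+1] - x[i])
    # successive zero crossings are half a period apart
    measured = 2*np.mean(np.diff(crossings))
    expected = 2*np.pi*np.sqrt(1.0/kval)   # m = 1
    periods.append((kval, measured, expected))
    print("k = {:4.2f}: measured T = {:.4f}, expected T = {:.4f}".format(kval, measured, expected))
```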

# Reference/Complete documentation

In [None]:
# full docstring for odeint
help(odeint)
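SciPy's documentation for `odeint` points new code toward `scipy.integrate.solve_ivp`. As a hedged sketch, here is the first example of this notebook redone with `solve_ivp`; note that its right hand side takes the arguments in the opposite order, `(x, y)`, and that the solution comes back in `sol.y` with one *row* per variable (the tolerance values here are our own choice):

```python
import numpy as np
from scipy.integrate import solve_ivp

def rhs_ivp(x, y):
    """dy/dx = -y/2 + 3/2, with the independent variable first."""
    return -0.5*y + 1.5

x = np.linspace(0, 10, 1000)
sol = solve_ivp(rhs_ivp, (x[0], x[-1]), [1.0], t_eval=x, rtol=1e-8, atol=1e-10)

y = sol.y[0]          # sol.y has shape (1, 1000)
print(sol.success, np.max(np.abs(y - (3 - 2*np.exp(-x/2)))))
```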