# Solutions to pre-class assignment on solving boundary-value ODEs

In this pre-class assignment, we numerically solve $y'' = 4y$ on the interval $[0,2]$, using the boundary conditions $y(0) = 5$ and $y'(2) = 218.282706$.

The analytic solution is $y(x) = 2 e^{2x} + 3 e^{-2x}$.
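
As a quick sanity check, the sketch below confirms numerically that this analytic solution satisfies the ODE and both boundary conditions. The helper names (`y_exact`, `yprime_exact`, `ypp_exact`) are made up for this illustration; the derivatives were worked out by hand from the formula above.

```python
import numpy as np

def y_exact(x):
    # analytic solution y(x) = 2 e^{2x} + 3 e^{-2x}
    return 2.0*np.exp(2.0*x) + 3.0*np.exp(-2.0*x)

def yprime_exact(x):
    # first derivative, differentiated by hand
    return 4.0*np.exp(2.0*x) - 6.0*np.exp(-2.0*x)

def ypp_exact(x):
    # second derivative, differentiated by hand
    return 8.0*np.exp(2.0*x) + 12.0*np.exp(-2.0*x)

x_check = np.linspace(0.0, 2.0, 11)

# residual of the ODE y'' - 4y, which should vanish up to round-off
ode_residual = np.max(np.abs(ypp_exact(x_check) - 4.0*y_exact(x_check)))

print("y(0)  =", y_exact(0.0))
print("y'(2) =", yprime_exact(2.0))
print("max |y'' - 4y| =", ode_residual)
```

The printed boundary values should be close to 5 and 218.2827, and the same formula gives $y'(0) = 4 - 6 = -2$, which is the slope the shooting method below is trying to recover.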


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def rk4(y1_0,y2_0,rhs,xl=0.0,xr=1.0,n=100):
    '''
    Does 4th-order Runge-Kutta integration for given y1(xl), y2(xl), and rhs
    (the right-hand side of the equations to be integrated).  Integrates from xl
    to xr on a grid of n points (i.e., n-1 steps).  We assume that we only have
    2 equations to make the code clearer, though note that this whole set of
    solutions can in principle be written for a vector of equations rather than
    two if we wanted to do so.
    '''

    # set up our arrays and step size
    y1 = np.zeros(n)
    y2 = np.zeros(n)
    x = np.linspace(xl, xr, n)
    h = x[1] - x[0]  # stepsize

    # set up left boundary
    y1[0] = y1_0
    y2[0] = y2_0

    # integrate from left to right boundary using given rhs function.
    for m in range(n-1):
        dy1dx_1, dy2dx_1 = rhs(y1[m], y2[m])
        dy1dx_2, dy2dx_2 = rhs(y1[m] + 0.5*h*dy1dx_1, y2[m] + 0.5*h*dy2dx_1)
        dy1dx_3, dy2dx_3 = rhs(y1[m] + 0.5*h*dy1dx_2, y2[m] + 0.5*h*dy2dx_2)
        dy1dx_4, dy2dx_4 = rhs(y1[m] + h*dy1dx_3, y2[m] + h*dy2dx_3)

        y1[m+1] = y1[m] + (h/6.0)*(dy1dx_1 + 2.0*dy1dx_2 + 2.0*dy1dx_3 + dy1dx_4)
        y2[m+1] = y2[m] + (h/6.0)*(dy2dx_1 + 2.0*dy2dx_2 + 2.0*dy2dx_3 + dy2dx_4)

    return y1, y2  # 2 arrays


def rhs(y1, y2):
    '''
    right hand side (i.e., derivatives) from pre-class assignment.
    Original equation is y'' = 4y
    Rewritten as a system of first-order equations:
    dy1/dx = y2
    dy2/dx = 4y1
    
    Note that y1 = y
    '''
    dy1dx = y2
    dy2dx = 4.0*y1

    return dy1dx, dy2dx


def analytic_f(x):
    # analytic function we're trying to match
    return 2.0*np.exp(2.0*x) + 3.0*np.exp(-2.0*x)

def analytic_fprime(x):
    # analytic derivative we're trying to match (just for fun)
    return 4.0*np.exp(2.0*x) - 6.0*np.exp(-2.0*x)
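
The grid resolution controls the accuracy of the `rk4` integration. As a rough illustration (a sketch, not part of the assignment solution), the snippet below integrates the same system from the exact initial values $y(0)=5$, $y'(0)=-2$ using 20, 40, and 80 steps and compares $y(2)$ to the analytic value. The helper `rk4_endpoint` is a hypothetical compact variant of `rk4` that keeps only the endpoint. For a fourth-order method, halving the step size should cut the error by roughly a factor of 16.

```python
import numpy as np

def deriv(y):
    # first-order system for y'' = 4y, with y[0] = y and y[1] = y'
    return np.array([y[1], 4.0*y[0]])

def rk4_endpoint(y1_0, y2_0, n_steps, x_left=0.0, x_right=2.0):
    # classic RK4 with a fixed step; returns [y, y'] at x_right only
    h = (x_right - x_left)/n_steps
    y = np.array([y1_0, y2_0], dtype=float)
    for _ in range(n_steps):
        k1 = deriv(y)
        k2 = deriv(y + 0.5*h*k1)
        k3 = deriv(y + 0.5*h*k2)
        k4 = deriv(y + h*k3)
        y = y + (h/6.0)*(k1 + 2.0*k2 + 2.0*k3 + k4)
    return y

# analytic y(2) from y(x) = 2 e^{2x} + 3 e^{-2x}
y_right_exact = 2.0*np.exp(4.0) + 3.0*np.exp(-4.0)

step_counts = (20, 40, 80)
errors = [abs(rk4_endpoint(5.0, -2.0, n)[0] - y_right_exact) for n in step_counts]
ratios = [errors[i]/errors[i+1] for i in range(len(errors) - 1)]

for n, err in zip(step_counts, errors):
    print(n, err)
print("error ratios when halving h:", ratios)
```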


In [None]:
# boundaries for equation
x_left = 0.0
x_right = 2.0

# these are the values we actually know are true
y1_left_true = 5.0
y2_right_true = 218.282706

# number of grid points in the interval (sets the RK4 step size; a coarse grid is fine for plotting)
points=20

# max number of iterations in our secant method
maxiters=20

# desired accuracy
eps = 1.0e-6

In [None]:
# size of the last change in our guess; set this to be large so the loop below is entered
dy = 1.0e+4*eps

y1_0 = y1_left_true  # correct answer
y2_0 = 0.0  # will adjust

# make a first guess to get some test values for y1, y2 
# (as arrays for plotting convenience - we really only 
#  care about the last points, though!)
y1_old, y2_old = rk4(y1_0,y2_0,rhs,xl=x_left,xr=x_right,n=points)

y2_tm1 = y2_0  # y2[0] from the previous iteration
y2_0 = 0.1    # new guess for y2[0] - will get modified

iter = 0


'''
Iterate using secant method: we want to get the correct value of 
eta, which is y2(0) - the parameter we're trying to find that gives 
us the correct value of y2(2).  So, we want to zero out the difference 
between our estimate of y2(2) and our known value of y2(2), or more
accurately zero:

f(eta) = y2^eta(2) - y2_true(2)

df/deta is the change in y2(2) as we change our estimates for y2(0).
'''
while dy > eps and iter < maxiters: 
    
    # get new values of y1 and y2 arrays
    y1, y2 = rk4(y1_0,y2_0,rhs,xl=x_left,xr=x_right,n=points)
    
    # technically this is ((y2[p-1]-y2_r_true)-(y2_old[p-1]-y2_r_true))/(y2_0-y2_tm1)
    # but the y2_r_true terms cancel, so I simplified it for ease of reading.
    dfdeta = (y2[points-1] - y2_old[points-1])/(y2_0-y2_tm1)

    deta = -(y2[points-1]-y2_right_true)/dfdeta

    
    y2_tm1 = y2_0
    y2_0 += deta
    
    dy = abs(deta)
    iter += 1
    
    y1_old = y1
    y2_old = y2
    
    print(y1[0],y2[0],iter,dy)
    

x = np.linspace(x_left,x_right,points)
y_an = analytic_f(x)

plt.plot(x,y1,'r-',x,y_an,'b.')
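
The same shooting idea can also be written with library routines. The sketch below is one possible variant, not the method used above: it assumes `scipy.integrate.solve_ivp` for the initial-value integration and `scipy.optimize.root_scalar` with `method='secant'` for the root find. The names `rhs_ivp` and `endpoint_mismatch` are made up for this illustration, and the tolerances are arbitrary choices.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import root_scalar

def rhs_ivp(x, y):
    # first-order system for y'' = 4y, with y[0] = y and y[1] = y'
    return [y[1], 4.0*y[0]]

def endpoint_mismatch(eta):
    # f(eta) = y'(2; eta) - y'_true(2), where eta is the guessed slope y'(0)
    sol = solve_ivp(rhs_ivp, (0.0, 2.0), [5.0, eta], rtol=1e-10, atol=1e-12)
    return sol.y[1, -1] - 218.282706

# secant iteration starting from the same two guesses used above (0.0 and 0.1)
result = root_scalar(endpoint_mismatch, method='secant', x0=0.0, x1=0.1)
eta_shoot = result.root

print("converged:", result.converged)
print("y'(0) found by shooting:", eta_shoot)
```

Because the ODE is linear, the mismatch is (up to integration error) a straight line in `eta`, so the secant iteration should land very close to the analytic slope $y'(0) = -2$ after only a few steps.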

## Now solve using SciPy!

It turns out that [integrate.solve_bvp](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_bvp.html) is irritatingly complicated to use, but hopefully the heavily annotated example below will help to make sense of it.

Note that the command `help(solve_bvp)` actually provides some pretty useful information; probably more useful once you've read the comments below.  Some relatively handy examples can be found at https://gist.github.com/nmayorov/f8af5ca956c6a7f75ecdb578a2655894 .


In [None]:
def rhs_new(x, y):
    '''
    right hand side (i.e., derivatives) from pre-class assignment.
    Original equation is y'' = 4y
    Rewritten as a system of first-order equations:
    dy1/dx = y2    (or y1')
    dy2/dx = 4y1   (or y2')
    
    (Note that y1 = y)
    
    In the function call, x is a 1D array of mesh points of length M; y is a 2D array that is 2 by M points;
    solve_bvp() can accept an arbitrary number of equations, so for N equations y would be N x M points.
    Note that solve_bvp() will also accept a function of the form f(x,y,k) where k is an array of *unknown*
    parameters (so the 4 in the y2' equation above could in principle be unknown).
    
    We use np.vstack() to return a "stacked" set of solutions to the RHS.  For N equations, this would be
    np.vstack( (eqn1, eqn2, ... , eqnN) ).  In our case, we're returning np.vstack( (y1', y2') ).  Since numpy
    uses arrays that are zero-indexed, this is as follows:
    '''
    return np.vstack( (y[1],4.0*y[0]) )

def bc(ya,yb):
    '''
    This function returns the boundary condition residuals - in other words, the difference between the guess and
    the known boundary conditions.  For inputs, ya and yb are the current solution values at the left and right 
    boundaries, respectively, as supplied by solve_bvp; these are numpy arrays of length N (since we're solving 
    N equations simultaneously).  In our case, that's N=2.  As with the function above, solve_bvp() will accept a 
    boundary condition function of the form bc(ya,yb,k), where k is an array of *unknown* parameters.
    
    We know that y(0) = 5 and y'(2) = 218.282706, which means that we know the correct answers are ya[0] = 5 
    and yb[1] = 218.282706.  We're returning an array of residuals for the boundary conditions, so it's in the 
    form np.array( [ya_guess-ya_true, yb_guess-yb_true] ).  If we had unknown parameters, we'd have to add additional
    residuals for those (1 extra residual per unknown parameter).  Since we have two equations and no unknown
    parameters, though, it looks like this:
    '''
    return np.array( [ya[0]-5, yb[1]-218.282706] )
    

In [None]:
# actually the solver we need
from scipy.integrate import solve_bvp

'''
1D mesh that we want to feed into solve_bvp.  Number of points doesn't
really matter too much, but xmesh[0] must be the left edge of the 
interval we're solving over and xmesh[M-1] (which is xmesh[-1] in numpy 
nomenclature, assuming M mesh points) must be the right edge of the interval.
'''

xmesh = np.linspace(x_left,x_right,num=101,endpoint=True)

'''
This is our initial guess for the function values at the mesh nodes.
For N equations and M mesh points, this is of shape N x M.  To make sure
that we don't get out of shape compared to xmesh, we set this to 2 x xmesh.shape[0].
For this linear problem the initial guess doesn't really matter, so we just set it to zeros.
'''

yvals = np.zeros((2, xmesh.shape[0]))

'''
Now call solve_bvp, which returns an object that I'm calling 'sols'.  The important 
things that it returns are sols.x, which is a 1D array containing the final mesh
points (solve_bvp may add nodes, so this can be longer than xmesh), and sols.y, which
is an array of size N x len(sols.x) (for N equations).  In our case, given the way 
we've set things up, sols.y[0] are our y-values and sols.y[1] are our y' values.
''' 
sols=solve_bvp(rhs_new,bc,xmesh,yvals)

# Note: Type "sols?" (without the quotes) to get information about what's in this object.  
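
Before plotting, it can be worth checking that the solver actually converged, and it is sometimes convenient to evaluate the solution at points other than the solver's own mesh. The sketch below repeats the setup above in a self-contained form (the names `rhs_check`, `bc_check`, and `res` are made up for this illustration) and uses the `status`, `success`, `message`, and `sol` attributes of the object returned by `solve_bvp`.

```python
import numpy as np
from scipy.integrate import solve_bvp

def rhs_check(x, y):
    # first-order system for y'' = 4y, with y[0] = y and y[1] = y'
    return np.vstack((y[1], 4.0*y[0]))

def bc_check(ya, yb):
    # residuals for y(0) = 5 and y'(2) = 218.282706
    return np.array([ya[0] - 5.0, yb[1] - 218.282706])

x_init = np.linspace(0.0, 2.0, 101)
y_init = np.zeros((2, x_init.size))
res = solve_bvp(rhs_check, bc_check, x_init, y_init)

# status == 0 and success == True indicate that the solver converged
print(res.status, res.success, res.message)

# the final mesh may contain more nodes than the initial one
print("initial nodes:", x_init.size, " final nodes:", res.x.size)

# res.sol is a callable interpolant; row 0 is y, row 1 is y'
x_fine = np.linspace(0.0, 2.0, 401)
y_fine = res.sol(x_fine)[0]
y_fine_exact = 2.0*np.exp(2.0*x_fine) + 3.0*np.exp(-2.0*x_fine)
print("max |y_numerical - y_analytic| =", np.max(np.abs(y_fine - y_fine_exact)))
```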


In [None]:
'''
And now we plot our analytic solution f(x) and our 
calculated solution (sols.y[0]).
'''
plt.plot(xmesh,analytic_f(xmesh),'r-',linewidth=4,alpha=0.5)
plt.plot(sols.x,sols.y[0],'b:',linewidth=3)  # use sols.x, since solve_bvp may refine the mesh
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title('f(x) vs. x')

In [None]:
'''
For fun, plot our analytic derivative f'(x) and our 
calculated derivative (sols.y[1]).
'''
plt.plot(xmesh,analytic_fprime(xmesh),'g-',linewidth=4,alpha=0.5)
plt.plot(sols.x,sols.y[1],'k:',linewidth=2)  # use sols.x, since solve_bvp may refine the mesh
plt.xlabel('x')
plt.ylabel('f\'(x)')
plt.title('f\'(x) vs. x')

In [None]:
'''
plot relative error of numerical f(x) and f'(x) compared to the analytic
solution.  Not too bad!  (Note that the analytic f'(x) passes through zero 
near x = 0.1, so the relative error in f'(x) can spike there.)
'''
plt.plot(sols.x,sols.y[0]/analytic_f(sols.x)-1.0,'b-',linewidth=3)
plt.plot(sols.x,sols.y[1]/analytic_fprime(sols.x)-1.0,'r-',linewidth=3)
plt.xlabel('x')
plt.ylabel('relative error')
plt.title('relative error for numerical f(x), f\'(x)\ncompared to analytic solution')