# Ex. Solving an ODE with Symbolic Math

In [None]:
import numpy as np
import sympy
from scipy import integrate
import matplotlib.pyplot as plt
import ipywidgets

In [None]:
# Turn on SymPy's pretty printing so displayed expressions render as math.
sympy.init_printing()

## Harmonic Oscillator

### Symbolic Manipulation (with SymPy)

In [None]:
# Independent variable t, the oscillator's angular-frequency parameter omega0,
# and the unknown function x (applied as x(t) below).
t, omega0 = sympy.symbols("t omega0")
x = sympy.Function("x")

In [None]:
# Harmonic oscillator x''(t) + omega0**2 * x(t); dsolve treats a bare
# expression (not an Eq) as "expression = 0".
ode = omega0**2 * x(t) + x(t).diff(t, t)

In [None]:
# Display the ODE expression (implicitly "= 0" when passed to dsolve).
ode

In [None]:
# General solution: solve explicitly for x(t); the result contains
# integration constants that are still undetermined.
ode_sol = sympy.dsolve(ode, x(t))

In [None]:
# General solution, shown as an equation x(t) = ...
ode_sol

In [None]:
# Right-hand side only: the expression for x(t), including the integration
# constants.
ode_sol.rhs

In [None]:
# Initial conditions in the form dsolve's `ics` argument expects:
# the derivative condition is d/dt x(t) evaluated at t = 0.
ics = {
    x(0): 2,                      # x(0)  = 2
    x(t).diff(t).subs(t, 0): 3,   # x'(0) = 3
}

In [None]:
# Show the initial-condition mapping.
ics

In [None]:
# Particular solution: hand the initial conditions straight to dsolve.
ode_sol = sympy.dsolve(ode, x(t), ics=ics)

In [None]:
# Particular solution satisfying x(0) = 2, x'(0) = 3.
ode_sol

In [None]:
# Rewrite the solution in terms of cos (and sin) and simplify it into a more
# readable real-looking form.
sympy.simplify(ode_sol.rewrite(sympy.cos))

In [None]:
# The initial conditions again, for reference while applying them by hand below.
ics

In [None]:
# Re-solve WITHOUT initial conditions so they can be applied manually:
# ode_sol is again the general solution with free integration constants.
ode_sol = sympy.dsolve(ode, x(t))

In [None]:
# General solution again (constants not yet determined).
ode_sol

In [None]:
# Every free symbol of the equation: t (via x(t) and the rhs), omega0, and
# the integration constants.
ode_sol.free_symbols

In [None]:
# Removing omega0 leaves the integration constants -- and also t, because the
# left-hand side is x(t).
ode_sol.free_symbols - {omega0}

In [None]:
# Zeroth derivative (the expressions themselves) of lhs - rhs at t = 0:
# x(0) minus the general solution's value at 0.
(ode_sol.lhs.diff(t,0) - ode_sol.rhs.diff(t,0)).subs(t,0)

In [None]:
# Substituting x(0) = 2 from ics leaves an equation in the integration
# constants (implicitly "= 0").
(ode_sol.lhs.diff(t,0) - ode_sol.rhs.diff(t,0)).subs(t,0).subs(ics)

In [None]:
# First derivative of lhs - rhs at t = 0: x'(0) minus the solution's slope at 0.
(ode_sol.lhs.diff(t,1) - ode_sol.rhs.diff(t,1)).subs(t,0)

In [None]:
# Substituting x'(0) = 3 from ics gives the second equation in the constants.
(ode_sol.lhs.diff(t,1) - ode_sol.rhs.diff(t,1)).subs(t,0).subs(ics)

In [None]:
# One equation per initial condition: the order-th derivative of both sides
# of the general solution, evaluated at t = 0, with the known values of
# x(0), x'(0), ... substituted from ics. This assumes ics holds derivative
# orders 0 .. len(ics) - 1, as it does here.
eqs = [
    (sympy.diff(ode_sol.lhs, t, order) - sympy.diff(ode_sol.rhs, t, order))
    .subs(t, 0)
    .subs(ics)
    for order in range(len(ics))
]

In [None]:
# The initial-condition equations (each implicitly "= 0").
eqs

In [None]:
# Candidate unknowns; note t is still in this set because the lhs is x(t).
ode_sol.free_symbols - {omega0}

In [None]:
# Solve the initial-condition equations for the integration constants only.
# The solution's free symbols also contain t (the lhs is x(t)) and omega0;
# neither is an unknown here, so both are excluded from the unknowns.
sympy.solve(eqs, ode_sol.free_symbols - {omega0, t})

In [None]:
# Keep the values of the integration constants for substitution below.
# t (present only via x(t) on the lhs) and the parameter omega0 are not
# unknowns, so both are excluded from the symbols passed to solve.
sol_params = sympy.solve(eqs, ode_sol.free_symbols - {omega0, t})

In [None]:
# Substitute the solved integration constants into the general solution.
x_t_sol = ode_sol.subs(sol_params)

In [None]:
# Particular solution obtained by applying the initial conditions by hand.
x_t_sol

In [None]:
# Same cos-rewrite and simplification as for the dsolve(..., ics=ics) result,
# for comparison with the hand-applied initial conditions.
sympy.simplify(x_t_sol.rewrite(sympy.cos))

To plot the solution, we first convert it into a numerical function with `lambdify`. From the SymPy documentation: "The primary purpose of this function [lambdify] is to provide a bridge from SymPy expressions to numerical libraries such as NumPy, SciPy, NumExpr, mpmath, and tensorflow."

In [None]:
# Minimal lambdify example: turn the SymPy expression t**2 into an ordinary
# numeric function taking one argument.
square = sympy.lambdify([t], t**2)

In [None]:
# Call the lambdified function on a plain number.
square(6)

In [None]:
# Ten evenly spaced points on [0, 1], endpoints included.
np.linspace(0,1,10)

In [None]:
# The same function applied elementwise to a NumPy array.
square(np.linspace(0,1,10))

In [None]:
# IPython help syntax (notebook/IPython only): show lambdify's documentation.
sympy.lambdify?

In [None]:
# Numerical version of the particular solution x(t) for omega0 = 1, backed by
# NumPy so it can be called on whole arrays of t values.
expr_func = sympy.lambdify(t, x_t_sol.rhs.subs(omega0, 1), modules="numpy")

In [None]:
# Sample times: 30 evenly spaced points on [0, 10], endpoints included.
xvalues = np.linspace(start=0, stop=10, num=30)

In [None]:
# Evaluate the solution at the sample times. The array may have a complex
# dtype if the solution is expressed with complex exponentials.
expr_func(xvalues)

In [None]:
# Plot x(t) over the sample times. With omega0 declared without assumptions,
# the solution may be built from complex exponentials (hence the
# rewrite(sympy.cos) cells), so expr_func can return a complex array whose
# imaginary part is only round-off. Take the real part explicitly rather than
# letting matplotlib discard it with a ComplexWarning.
plt.plot(xvalues, np.real(expr_func(xvalues)))
plt.xlabel("t")
plt.ylabel("x(t)");