# COURSE: Master math by coding in Python
# SECTION: Calculus
# VIDEO: Finding critical points


### https://www.udemy.com/course/math-with-python/?couponCode=202312
#### INSTRUCTOR: Mike X Cohen (http://sincxpress.com)

This code roughly matches the code shown in the live recording: variable names, order of lines, and parameter settings may be slightly different.

<a target="_blank" href="https://colab.research.google.com/github/mikexcohen/MathWithPython/blob/main/calculus/mathWithPython_calc_critPoints.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
# import libraries
import sympy as sym
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks
from IPython.display import display,Math

In [None]:
# Empirical (numerical) approach: sample the function on a grid and read the
# extrema straight off the samples (suited to points where df=0; per the
# original note, not to non-differentiable points).

# sample f(x) = x^2 * exp(-x^2) on [-5, 5]
# 1000 samples is an even count, so x=0 falls between two grid points
# (compare with 1001 samples, where x=0 is sampled exactly)
x = np.linspace(-5, 5, 1000)
fx = np.exp(-x**2) * x**2

# numerical derivative: forward differences divided by the grid spacing dx
dfx = np.diff(fx) / np.diff(x).mean()

# sample indices of the local maxima, and of the local minima (found as the
# peaks of the negated function)
localmax, localmin = (find_peaks(signal)[0] for signal in (fx, -fx))
print(f'The critical points are {x[localmax]} {x[localmin]}')

# draw the function, its derivative, and the detected extrema
plt.plot(x, fx, label='y')
plt.plot(x[:-1], dfx, label='dy/dx')  # diff() returns one fewer sample than x
plt.plot(x[localmax], fx[localmax], 'ro', label='local max.')
plt.plot(x[localmin], fx[localmin], 'gs', label='local min.')

# dashed light-gray reference line at y=0 across the sampled range
plt.plot(x[[0, -1]], [0, 0], '--', color=(.7, .7, .7))

plt.legend()
plt.xlim(x[0], x[-1])
plt.show()

In [None]:
# Analytic (symbolic) approach: differentiate with sympy and solve df/dx = 0.

x = sym.Symbol('x')
fx = x**2 * sym.exp(-x**2)

# symbolic derivative, then the roots of the derivative
dfx = fx.diff(x)
critpoints = sym.solve(dfx)
print(f'The critical points are: {critpoints}')


# sympy plotting: start from the function's plot, then add the derivative's
# single data series (drawn in red) to the same figure
p = sym.plot(fx, (x, -5, 5), show=False)
p.append(sym.plot(dfx, (x, -5, 5), show=False, line_color='r')[0])

# label the two series and turn the legend on before rendering
p[0].label = 'y'
p[1].label = 'dy/dx'
p.legend = True
p.show()

# Exercise

In [None]:
# Exercise: which values of 'a' give f(x) = x^2 * exp(-a*x^2) a critical
# point at x=1 or x=2?
a,x = sym.symbols('a,x')

baseexpr = x**2 * sym.exp(-a*x**2)
arange = np.arange(0,2.25,.25)
xrange = np.linspace(-3,3,100)

# x-locations to test for, and the tolerance used when matching a solved
# root against them
targets = (1,2)
tol = 1e-10

# setup plots
fig,ax = plt.subplots(1,2)

for ai in arange:

  # plug this value of 'a' into the expression, differentiate with respect
  # to x, and solve df/dx = 0 for x
  fx = baseexpr.subs(a,ai)
  dfx = sym.diff(fx,x)
  critpnts = sym.solve( dfx,x )

  # also plot the function in subplot1 and its derivative in subplot2
  ax[0].plot(xrange,sym.lambdify(x,fx)(xrange))
  ax[1].plot(xrange,sym.lambdify(x,dfx)(xrange))

  # 'a' is substituted as a floating-point number, so the solved roots are
  # floating-point too; compare them to the targets numerically (within tol)
  # instead of relying on exact equality with the integers 1 and 2
  critvals = [complex(cp) for cp in critpnts]
  hit = next((t for t in targets if any(abs(cv-t)<tol for cv in critvals)),None)

  if hit is not None:
    display(Math('\\Rightarrow %s\\text{ has a critical point at x=%d! Woohoo!!}' %(sym.latex(fx),hit)))
  else:
    # neither target matched, so the message names both of them
    display(Math('\\quad %s\\text{ has NO critical point at x=1 or x=2. :(}' %sym.latex(fx)))



# some adjustments to the function plot (dashed guides at x=1 and x=2)
ax[0].set_ylim([0,2])
ax[0].set_title('Function')
ax[0].plot([1,1],[0,2],'--',c='gray')
ax[0].plot([2,2],[0,2],'--',c='gray')

# adjustments to the derivative plot (dashed guides at y=0, x=1 and x=2)
ax[1].set_ylim([-1.5,1.5])
ax[1].plot(xrange[[0,-1]],[0,0],'--',c='gray')
ax[1].plot([1,1],[-1.5,1.5],'--',c='gray')
ax[1].plot([2,2],[-1.5,1.5],'--',c='gray')
ax[1].set_title('Its derivative')
fig.set_size_inches(8,4)

plt.show()