<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/NumericalSymbolicAutomaticDerivatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [91]:
import numpy as np
import jax.numpy as jnp
import jax
from jax.config import config
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
from functools import partial

In [92]:
def f(x):
    """Function whose second derivative we approximate.

    Its exact second derivative is -cos(x), which the comparison table
    uses as the reference value.
    """
    value = jnp.cos(x)
    return value

In [93]:
def get_d2f(f):
    """Build a central finite-difference approximation of f''.

    Parameters
    ----------
    f : callable
        Function of one variable.

    Returns
    -------
    callable
        ``d2f(x, h)`` returning ``(f(x + h) - 2*f(x) + f(x - h)) / h**2``,
        the three-point central second difference with step ``h``.
        Very small ``h`` loses accuracy to floating-point cancellation in
        the numerator.
    """
    def d2f(x, h):
        forward = f(x + h)
        center = f(x)
        backward = f(x - h)
        return (forward - 2*center + backward) / h**2

    return d2f

In [94]:
# Two second-derivative operators for f:
#   jax_d2f(x)  -- automatic differentiation: jax.grad applied twice
#   d2f(x, h)   -- central finite difference with step h
jax_d2f = jax.grad(jax.grad(f))
d2f = get_d2f(f)

In [95]:
# Compare, at x = pi/6, the finite-difference estimate d2f(x, 10**p) for
# p = 0, -1, ..., -15 against the analytic value -cos(x) and the JAX
# autodiff value. The two reference values do not depend on the step, so
# they are computed once instead of on every iteration.
# Columns: exponent p, finite difference, analytic, autodiff.
x_eval = np.pi/6
exact_d2 = -np.cos(x_eval)
autodiff_d2 = jax_d2f(x_eval)
for exponent in range(0, -16, -1):
    step = 10**exponent  # negative int exponent -> float step 1, 0.1, 0.01, ...
    print(f'{exponent:3d}, {d2f(x_eval, step):25.16e} {exact_d2:25.16e}  {autodiff_d2:25.16e}')

  0,   -7.9621976235863934e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -1,   -8.6530395646761105e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -2,   -8.6601818693132770e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -3,   -8.6602533166946216e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -4,   -8.6602541804481348e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -5,   -8.6602613968977937e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -6,   -8.6608498151008462e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -7,   -8.7707618945387378e-01   -8.6602540378443871e-01    -8.6602540378443871e-01
 -8,   -1.1102230246251563e+00   -8.6602540378443871e-01    -8.6602540378443871e-01
 -9,   -1.1102230246251564e+02   -8.6602540378443871e-01    -8.6602540378443871e-01
-10,    0.0000000000000000e+00   -8.6602540378443871e-01    -8.6602540378443871e-01
-11,    0.0000000000000000e+00   -8.6602540378443871e-01    -8.6602540378443

In [96]:
def klein(u, v):
    """Point (x, y, z) on the Klein-bottle-shaped parametric surface at (u, v).

    The notebook samples u in [0, pi] and v in [0, 2*pi).

    Returns
    -------
    jnp.ndarray
        ``jnp.array([x, y, z])``. This has shape (3,) for scalar u, v;
        for array inputs the components are stacked along a new leading axis.
    """
    cu, su = jnp.cos(u), jnp.sin(u)
    cv, sv = jnp.cos(v), jnp.sin(v)

    x = -2/15 * cu * (3*cv - 30*su + 90*cu**4 * su
                      - 60*cu**6*su + 5*cu*cv*su)
    y = -1/15 * su * (3*cv - 3*cu**2*cv - 48*cu**4*cv + 48*cu**6*cv
                      - 60*su + 5*cu*cv*su - 5*cu**3*cv*su
                      - 80*cu**5*cv*su + 80*cu**7*cv*su)
    z = 2/15 * (3 + 5*cu*su) * sv
    return jnp.array([x, y, z])

In [97]:
# Vectorize klein over broadcast (u, v) inputs: each scalar pair maps to a
# length-3 point, so the result gains a trailing axis of size 3.
klein_vec = jnp.vectorize(klein, signature="(),()->(3)")

In [98]:
# Parameter grid: 50 samples of u over [0, pi] (endpoint included) and
# 50 samples of v over [0, 2*pi) (endpoint excluded).
u = jnp.linspace(0, jnp.pi, num=50, endpoint=True)
v = jnp.linspace(0, 2*jnp.pi, num=50, endpoint=False)
# u[:, None] (column) and v[None, :] (row) broadcast to a (u.size, v.size)
# grid, so s has shape (u.size, v.size, 3).
s = klein_vec(u[:, None], v[None, :])

In [99]:
# Reversing all axes of s (shape (u.size, v.size, 3)) gives shape
# (3, v.size, u.size). Unpacking yields x[i, j] == s[j, i, 0] (likewise y, z),
# so row i of x runs over u at fixed v[i].
x, y, z = s.transpose()

In [100]:
# Wireframe of the surface: draw every row and every column of the (x, y, z)
# grid as a 3-D polyline.
# Loop bounds come from the array shape rather than u.size / v.size. x has
# shape (v.size, u.size) because it was unpacked from s.T, so indexing rows
# with range(u.size) only worked because the grid happened to be square.
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'surface'}]])

for i in range(x.shape[0]):
    fig.add_scatter3d(x=x[i, :], y=y[i, :], z=z[i, :], mode='lines',
                      line_color='white', row=1, col=1)

for j in range(x.shape[1]):
    fig.add_scatter3d(x=x[:, j], y=y[:, j], z=z[:, j], mode='lines',
                      line_color='white', row=1, col=1)

fig.update_layout(width=600, height=600, template='plotly_dark', showlegend=False)

In [101]:
# Jacobian of klein with respect to both parameters. Calling it returns a
# tuple (d klein/du, d klein/dv); each entry is a tangent vector of length 3.
klein_jac = jax.jacobian(klein, argnums=(0, 1))

In [102]:
# Tangent vectors (d/du, d/dv) of the surface at u=0.1, v=0.2
klein_jac(0.1, 0.2)

(DeviceArray([-0.58570369,  0.89772215,  0.12980611], dtype=float64, weak_type=True),
 DeviceArray([ 0.0921615 , -0.00067659,  0.45692969], dtype=float64, weak_type=True))

In [103]:
def klein_norm(u, v):
    """Unit normal vector of the surface at (u, v).

    Computed as cross(d/dv, d/du) of the parametrization (that operand
    order fixes the sign), divided by its Euclidean length. If the two
    tangent vectors are parallel the length is zero and the result is NaN.
    """
    tangent_u, tangent_v = klein_jac(u, v)
    normal = jnp.cross(tangent_v, tangent_u)
    length = jnp.linalg.norm(normal)
    return normal / length

In [104]:
# Broadcast klein_norm over (u, v) grids; each pair yields a length-3 unit normal.
klein_norm_vec = jnp.vectorize(klein_norm, signature="(),()->(3)")

In [105]:
# Normals on the same (u.size, v.size) grid, with all axes reversed exactly
# as x, y, z = s.T was. Hence (dx[i, j], dy[i, j], dz[i, j]) is the normal
# at the point (x[i, j], y[i, j], z[i, j]).
dx, dy, dz = klein_norm_vec(u[:, None], v[None, :]).transpose()

In [106]:
# Constant colorscale: every normal cone is drawn the same green.
color_normal = [
    [0, "rgb(0, 200, 0)"],
    [1., "rgb(0, 200, 0)"],
]

# Attach a unit-normal cone at every grid point, one row of the grid at a
# time. The loop bound is x.shape[0] (x, y, z, dx, dy, dz share that
# indexing), not u.size. The row axis of x has length v.size, so range(u.size)
# only worked for a square grid.
# NOTE: this adds traces to `fig`, which was created in an earlier cell.
# Re-running this cell alone appends a second set of cones.
for i in range(x.shape[0]):
    fig.add_cone(x=x[i, :], y=y[i, :], z=z[i, :],
                 u=dx[i, :], v=dy[i, :], w=dz[i, :],
                 anchor='tail', colorscale=color_normal,
                 sizeref=1, showscale=False)
fig.update_layout(width=800, height=800, template='plotly_dark', showlegend=False)