<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/NumericalSymbolicAutomaticDerivatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
# Enable float64 so the finite-difference table below is computed in double precision.
# `from jax.config import config` fails on newer JAX releases (that submodule import
# was removed); the same configuration object is exposed as the attribute `jax.config`.
config = jax.config
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
from functools import partial

In [4]:
def f(x):
    """Function whose derivative is studied below: f(x) = cos(x)."""
    value = jnp.cos(x)
    return value

In [15]:
def get_df(f):
    """Build a forward-difference approximation of f'.

    Returns a function ``df(x, h)`` that evaluates ``(f(x + h) - f(x)) / h``.
    """
    def df(x, h):
        rise = f(x + h) - f(x)
        return rise / h

    return df

In [18]:
# Two derivative operators for f: JAX automatic differentiation and a forward difference.
jax_df = jax.grad(f)
df = get_df(f)

In [19]:
# Forward-difference derivative of cos at x0 = pi/6 for step sizes 10**p,
# p = 0, -1, ..., -15. Each row shows the approximation, the exact value -sin(x0)
# and JAX's autodiff value. The approximation error first shrinks with the step
# size (truncation error) and then grows again once round-off in f(x+h) - f(x) dominates.
x0 = np.pi/6
exact_df = -np.sin(x0)   # loop-invariant: evaluate once instead of 16 times
auto_df = jax_df(x0)     # loop-invariant: one autodiff evaluation is enough
for p in range(0,-16,-1):
    print(f'{p:3d}, {df(x0,10**p):25.16e} {exact_df:25.16e}  {auto_df:25.16e}')

  0,   -8.1884537358326781e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -1,   -5.4243228105752106e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -2,   -5.0432175764298925e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -3,   -5.0043292933255046e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -4,   -5.0004330043762479e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -5,   -5.0000433011732071e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -6,   -5.0000043305686859e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -7,   -5.0000004359063155e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -8,   -5.0000000806349476e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
 -9,   -5.0000004137018550e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
-10,   -5.0000004137018550e-01   -4.9999999999999994e-01    -4.9999999999999994e-01
-11,   -5.0000004137018550e-01   -4.9999999999999994e-01    -4.9999999999999

In [None]:
def klein(u,v):
    """Parametric surface map (u, v) -> jnp.array([x, y, z]).

    Presumably a Klein-bottle immersion, going by the name. It is written entirely
    with jnp operations so that it can be differentiated with jax.jacobian and
    vectorized with jnp.vectorize.
    """
    cu, su = jnp.cos(u), jnp.sin(u)
    cv, sv = jnp.cos(v), jnp.sin(v)

    x = -2/15 * cu * (3*cv - 30*su + 90*cu**4*su
                      - 60*cu**6*su + 5*cu*cv*su)
    y = -1/15 * su * (3*cv - 3*cu**2*cv - 48*cu**4*cv + 48*cu**6*cv
                      - 60*su + 5*cu*cv*su - 5*cu**3*cv*su
                      - 80*cu**5*cv*su + 80*cu**7*cv*su)
    z = 2/15 * (3 + 5*cu*su) * sv
    return jnp.array([x, y, z])

In [None]:
# Broadcasting version of klein: scalar (u, v) in, length-3 point out, batched over any leading axes.
klein_vec = jnp.vectorize(klein, signature='(),()->(3)')

In [None]:
# 50-point parameter grids: u over [0, pi], v over [0, 2*pi].
u = jnp.linspace(0, jnp.pi, 50)
v = jnp.linspace(0, 2*jnp.pi, 50)
# (50, 1) broadcast against (1, 50) evaluates klein at every (u_i, v_j); with the
# '->(3)' signature the result has shape (50, 50, 3).
s = klein_vec(u[:, None], v[None, :])

In [None]:
# Reversing all axes of s (shape (len(u), len(v), 3)) gives shape (3, len(v), len(u)),
# so x, y and z are each indexed [v_index, u_index].
x, y, z = s.transpose()

In [None]:
fig=make_subplots(rows=1,cols=1,specs=[[{'type':'surface'}]])

# x/y/z come from reversing the axes of a (u, v, 3) array, so rows index v and columns
# index u. Iterate over the arrays' own shape instead of u.size / v.size; this stays
# correct when the u and v grids have different resolutions.
n_rows, n_cols = x.shape
for i in range(n_rows):
    fig.add_scatter3d(x=x[i,:],y=y[i,:],z=z[i,:],mode='lines', line_color='white',row=1,col=1)

for j in range(n_cols):
    fig.add_scatter3d(x=x[:,j],y=y[:,j],z=z[:,j],mode='lines', line_color='white',row=1,col=1)

fig.update_layout(width=600,height=600,template='plotly_dark', showlegend=False)

In [None]:
# Returns the tuple (d klein/du, d klein/dv), each a length-3 vector.
klein_jac = jax.jacobian(klein, argnums=(0, 1))

In [None]:
# Partial derivatives of the surface at (u, v) = (0.1, 0.2).
klein_jac(0.1, 0.2)

(DeviceArray([-0.58570369,  0.89772215,  0.12980611], dtype=float64, weak_type=True),
 DeviceArray([ 0.0921615 , -0.00067659,  0.45692969], dtype=float64, weak_type=True))

In [None]:
def klein_norm(u,v):
    """Unit normal of the surface at (u, v).

    The normal is the cross product of the v-partial with the u-partial (in that order,
    which fixes the orientation), scaled to length 1. If the two partials were parallel
    the norm would be zero and the result non-finite.
    """
    d_du, d_dv = klein_jac(u, v)
    normal = jnp.cross(d_dv, d_du)
    length = jnp.linalg.norm(normal)
    return normal / length

In [None]:
# Broadcasting version of klein_norm: one unit normal (length 3) per (u, v) pair.
klein_norm_vec = jnp.vectorize(klein_norm, signature='(),()->(3)')

In [None]:
# Unit normals on the same (u, v) grid, with all axes reversed exactly as for x/y/z,
# so dx/dy/dz line up element-for-element with x/y/z ([v_index, u_index]).
dx, dy, dz = jnp.transpose(klein_norm_vec(u[:, None], v[None, :]))

In [None]:
# Flat colorscale (same color at both ends) so every cone is drawn in one green.
color_normal=[
        [0, "rgb(0, 200, 0)"],
        [1., "rgb(0, 200, 0)"]]
# x/y/z and dx/dy/dz share the [v_index, u_index] layout (both come from reversing
# the axes of a (u, v, 3) array), so iterate over the leading axis of the arrays
# themselves rather than u.size; this stays correct for unequal u/v resolutions.
# NOTE: this appends cone traces to the existing `fig`; re-running this cell alone
# adds duplicate cones, so re-run the figure-building cell first.
for i in range(x.shape[0]):

    fig.add_cone(x=x[i,:],y=y[i,:],z=z[i,:],u=dx[i,:],v=dy[i,:],w=dz[i,:], anchor='tail', colorscale=color_normal,
                 sizeref=1,showscale=False)
fig.update_layout(width=800,height=800,template='plotly_dark', showlegend=False)