<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2022/blob/main/Coil3D_FrenetSerret.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from plotly.subplots import make_subplots
from functools import partial

In [2]:
theta=np.linspace(0,2*np.pi,36)
xy=np.array([np.cos(theta),np.sin(theta)])

In [3]:
xy_stretch=np.array([[2.,0.],[0.,1.]]) @ xy

rot_angle = np.pi/6
xy_stretch_rotate = np.array([[np.cos(rot_angle), -np.sin(rot_angle)],[np.sin(rot_angle), np.cos(rot_angle)]]) @ xy_stretch

In [4]:
fig=make_subplots()
fig.add_scatter(x=xy[0],y=xy[1],mode='markers')
fig.add_scatter(x=xy_stretch[0],y=xy_stretch[1],mode='markers')
fig.add_scatter(x=xy_stretch_rotate[0],y=xy_stretch_rotate[1],mode='markers')
fig.update_xaxes(range=[-3,3])
fig.update_yaxes(range=[-3,3])
fig.update_layout(width=600,height=600,template='plotly_dark')

In [5]:
t=jnp.linspace(0,4*np.pi,36)
theta = jnp.linspace(0,2*np.pi, 18)
R=3.
mz=0.4
r=0.5



In [109]:
def curve(t, R, mz):
    xyz = jnp.array([R*jnp.cos(t), R*jnp.sin(t), mz*t])
    return xyz

# test for line without using Frenet-Serret
# def curve(t, R, mz):
#     xyz = jnp.array([t, 0.3*t, 0.5*t+0.1])
#     return xyz

In [110]:
def frenet(f):

    dfdt = jax.jacobian(f)

    def dsdt(t):
        return jnp.linalg.norm(dfdt(t))

    def tangent(t):   #df_ds = dfdt / dsdt
        return dfdt(t)/dsdt(t)

    dtangent_dt = jax.jacobian(tangent)

    def normal(t):
        v= dtangent_dt(t) / dsdt(t)
        return v/jnp.linalg.norm(v)


    return tangent, normal

In [111]:
xyz= curve(t,R,mz)

In [112]:
fig=make_subplots(rows=1,cols=1,specs=[[{'type':'surface'}]])
fig.add_scatter3d(x=xyz[0,:],y=xyz[1,:],z=xyz[2,:],mode='lines',row=1,col=1)
fig.update_layout(width=600,height=600,template='plotly_dark')

In [113]:
tangent, normal = frenet(lambda t: curve(t, R=R,mz=mz))

def surf(t,theta):
    T = tangent(t)
    N = normal(t)
    B = jnp.cross(T,N)
    return curve(t, R, mz) + jnp.stack([N, B], axis=1) @ jnp.array([r*jnp.cos(theta), r*jnp.sin(theta)])


In [114]:
def get_tangent_local_basis_function(f):

    dfdt = jax.jacobian(f)

    def dsdt(t):
        return jnp.linalg.norm(dfdt(t))

    def tangent(t):   #df_ds = dfdt / dsdt
        return dfdt(t)/dsdt(t)

    def local_basis(t):
        arb_v = jnp.array([0.,0.,1.])
        v1 = jnp.cross(arb_v, tangent(t))
        v2 = jnp.cross(v1, tangent(t))
        return v1/jnp.linalg.norm(v1), v2/jnp.linalg.norm(v2)

    return tangent, local_basis



In [115]:
tangent, local_basis=get_tangent_local_basis_function(lambda t: curve(t,R, mz))

def surf2(t,theta):
    v1,v2 = local_basis(t)
    return curve(t, R, mz) + jnp.stack([v1, v2], axis=1) @ jnp.array([r*jnp.cos(theta), r*jnp.sin(theta)])

In [116]:
# surf_vec = jnp.vectorize(surf, signature='(),()->(3)')
surf2_vec = jnp.vectorize(surf2, signature='(),()->(3)')

In [117]:
# s1=surf_vec(t[:,None], theta[None,:])
s2=surf2_vec(t[:,None], theta[None,:])

In [118]:
x,y,z = s2.T

In [119]:
fig=make_subplots(rows=1,cols=1,specs=[[{'type':'surface'}]])

for i in range(theta.size):
    fig.add_scatter3d(x=x[i,:],y=y[i,:],z=z[i,:],mode='lines', line_color='white',row=1,col=1)

for i in range(t.size):
    fig.add_scatter3d(x=x[:,i],y=y[:,i],z=z[:,i],mode='lines', line_color='white',row=1,col=1)

fig.update_layout(width=600,height=600,template='plotly_dark', showlegend=False)

In [17]:
def get_grad(f):

    def grad(x,h):
        res=[]
        xplush=np.copy(x)
        for i in range(x.size):
            xplush[i]=x[i]+h
            res.append((f(xplush) - f(x))/h)
        return np.array(res)
    return grad
        



In [18]:
def f(x):
    return np.sin(x[0])+ np.cos(2*x[1])

In [19]:
grad=get_grad(f)

In [20]:
grad(np.array([np.pi/6, np.pi/3]),h=1e-8)

array([ 0.8660254 , -0.86602539])

In [104]:
def get_jac(f):
    def jac(x0, dx=1e-8):
        x0=np.atleast_1d(x0)
        f0 = np.atleast_1d(f(x0))
        J=np.zeros((f0.size, x0.size))
        for i in range(x0.size):
            x=np.copy(x0)
            x[i]=x0[i]+dx
            J[:,i]=((f(x)-f0)/dx)[:,0]
        return J

    return jac

In [105]:
jac=get_jac(lambda t: curve(t,R,mz))

In [106]:
jac(0.)

array([[1. ],
       [0.3],
       [0.5]])

In [107]:
jax_jac=jax.jacobian(lambda t: curve(t,R,mz))

In [108]:
jax_jac(0.)

DeviceArray([1. , 0.3, 0.5], dtype=float64, weak_type=True)