<a href="https://colab.research.google.com/github/chetools/STEMUnleashed2025/blob/main/RotationMatrices_FresnetSerret.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from plotly.subplots import make_subplots

In [2]:
def rotmat(theta):
    """Return the 2-D counter-clockwise rotation matrix for angle `theta` (radians)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s],
                      [s,  c]])

# Vectorized over batches of angles: input (...,) -> output (..., 2, 2)
rotmat = jnp.vectorize(rotmat, signature='()->(2,2)')

def rotmat_x(theta):
    """Return the 3-D rotation matrix for a rotation of `theta` radians about the x-axis."""
    c = jnp.cos(theta)
    s = jnp.sin(theta)
    # x is left untouched; the (y, z) plane is rotated like the 2-D case
    return jnp.array([[1, 0., 0.],
                      [0, c, -s],
                      [0, s, c]])

# Vectorized over batches of angles: input (...,) -> output (..., 3, 3)
rotmat_x = jnp.vectorize(rotmat_x, signature='()->(3,3)')

In [3]:
# 24 evenly spaced angles on [0, 2π]; the endpoint is included, so the first and last angle coincide
thetas = jnp.linspace(0, 2 * jnp.pi, 24)
# One 2x2 rotation matrix per angle, stacked along the leading axis
rotmats = rotmat(thetas)

In [4]:
v0 = jnp.array((1.5, 0.))  # starting vector: length 1.5 along +x

In [5]:
# Apply every rotation to v0; column n of the (2, n) result is rotmats[n] @ v0
vs = jnp.einsum('nij,j->in', rotmats, v0)
# Double the x-coordinates so the circle of points becomes an ellipse
vs = vs.at[0].multiply(2)

In [6]:
# Scatter of the ellipse points on a fixed, square [-4, 4] window
fig = (make_subplots()
       .add_scatter(x=vs[0], y=vs[1], mode='markers')
       .update_xaxes(range=[-4, 4])
       .update_yaxes(range=[-4, 4])
       .update_layout(width=400, height=400, template='plotly_dark'))
fig

In [7]:
# Display the 30° (π/6 rad) rotation matrix
rotmat(jnp.pi/6)

Array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [8]:
# First three points (columns) of the stretched, rotated point set
vs[:,:3]

Array([[3.        , 2.8887517 , 2.5632582 ],
       [0.        , 0.40469518, 0.7793759 ]], dtype=float32)

In [9]:
vs30 = rotmat(jnp.pi / 6) @ vs  # rotate every ellipse point by a further 30°

In [10]:
# Original ellipse vs. the same points rotated by 30°, on a fixed [-4, 4] window
fig = (make_subplots()
       .add_scatter(x=vs[0], y=vs[1], mode='markers')
       .add_scatter(x=vs30[0], y=vs30[1], mode='markers')
       .update_xaxes(range=[-4, 4])
       .update_yaxes(range=[-4, 4])
       .update_layout(width=400, height=400, template='plotly_dark'))
fig

In [11]:
# Lift the 2-D points into 3-D by appending a z = 0 row: (2, n) -> (3, n)
vs30xyz = jnp.vstack([vs30, jnp.zeros_like(vs30[:1])])

In [12]:
# Display the 30° rotation about the x-axis
rotmat_x(jnp.pi/6)

Array([[ 1.       ,  0.       ,  0.       ],
       [ 0.       ,  0.8660254, -0.5      ],
       [ 0.       ,  0.5      ,  0.8660254]], dtype=float32)

In [13]:
vs30_30xyz = rotmat_x(jnp.pi / 6) @ vs30xyz  # tilt the planar ellipse 30° about the x-axis; stays (3, n)

In [14]:
import plotly.graph_objects as go

# 3-D view of the ellipse after both rotations
tilted_points = go.Scatter3d(x=vs30_30xyz[0], y=vs30_30xyz[1], z=vs30_30xyz[2], mode='markers')
fig = go.Figure(data=[tilted_points])

fig.update_layout(height=600, width=400, margin=dict(l=0, r=0, b=0, t=0), template='plotly_dark')

In [15]:
def spiral(t):
    """Parametric spiral r(t) = ((1 + t/2) cos t + t, (1 + t/2) sin t, t).

    The winding radius grows linearly as 1 + t/2, the curve drifts by t along x,
    and it rises with height z = t.
    """
    radius = 1 + 0.5 * t
    return jnp.array([radius * jnp.cos(t) + t,
                      radius * jnp.sin(t),
                      t])

# Vectorized over batches of parameters: input (...,) -> output (..., 3)
spiral = jnp.vectorize(spiral, signature='()->(3)')

In [16]:
def FresnetSerret(fun):
    """Build Frenet–Serret helpers for a parametric 3-D curve ("Fresnet" is a misspelling of Frenet).

    Parameters
    ----------
    fun : callable
        Maps a scalar parameter t to a position vector r(t) of length 3.

    Returns
    -------
    (T, N) : tuple of two functions, each vectorized with signature '()->(3)'
        T(t) -- unit tangent r'(t) / |r'(t)|.
        N(t) -- (dT/dt) / |r'(t)| = dT/ds, i.e. curvature * unit principal normal.
                It is NOT unit length: its norm is the curvature, and dividing by
                that norm gives the unit normal.
    """
    velocity = jax.jacobian(fun)  # t -> r'(t)

    def unit_tangent(t):
        v = velocity(t)
        return v / jnp.linalg.norm(v)

    tangent_rate = jax.jacobian(unit_tangent)  # t -> dT/dt

    def curvature_normal(t):
        # Chain rule: dT/ds = (dT/dt) / (ds/dt), with speed ds/dt = |r'(t)|
        return tangent_rate(t) / jnp.linalg.norm(velocity(t))

    def as_batched(f):
        return jnp.vectorize(f, signature='()->(3)')

    return as_batched(unit_tangent), as_batched(curvature_normal)


In [17]:
ts = jnp.linspace(0, 6 * jnp.pi, 120)  # three full turns of the spiral
xyzs = spiral(ts)                       # (120, 3) sampled curve points
T_fun, N_fun = FresnetSerret(spiral)

Ts = T_fun(ts)                                       # unit tangents
curvature_Ns = N_fun(ts)                             # curvature * unit normal at each sample
curvature = jnp.linalg.norm(curvature_Ns, axis=-1)   # per-sample curvature
Ns = curvature_Ns / curvature[..., None]             # unit principal normals
Rs = 1 / curvature                                   # radius of curvature

In [18]:
idx = 100                                   # sample at which to draw the osculating circle
center = xyzs[idx] + Rs[idx] * Ns[idx]      # center of curvature: r + R * N
# (3, 2) basis of the osculating plane, columns (-N, T): angle 0 points from the center back to the curve
basis_matrix = np.column_stack([-Ns[idx], Ts[idx]])

In [19]:
# Planar circle of radius R (shape (2, len(thetas))), mapped into the osculating plane and moved onto the center
circle_points = Rs[idx] * np.vstack([np.cos(thetas), np.sin(thetas)])
tangent_circle = basis_matrix @ circle_points + center[:, None]

In [21]:

# Spiral with its osculating (tangent) circle at sample `idx`.
# `xyzs` is (n, 3), so z is column 2.
SHOW_FRAME_CONES = False  # set True to overlay unit-tangent (blue) and unit-normal (green) cones

fig = go.Figure(data=[go.Scatter3d(x=xyzs[:, 0], y=xyzs[:, 1], z=xyzs[:, 2], mode='lines')])

if SHOW_FRAME_CONES:
    for vectors, colorscale in ((Ts, 'Blues'), (Ns, 'Greens')):
        fig.add_trace(go.Cone(x=xyzs[:, 0], y=xyzs[:, 1], z=xyzs[:, 2],
                              u=vectors[:, 0], v=vectors[:, 1], w=vectors[:, 2],
                              sizemode='absolute', sizeref=5, colorscale=colorscale,
                              cmin=0.9, cmax=1.1, showscale=False))

fig.add_trace(go.Scatter3d(x=tangent_circle[0], y=tangent_circle[1], z=tangent_circle[2], mode='lines'))

fig.update_layout(height=600, width=600, margin=dict(l=0, r=0, b=0, t=0), template='plotly_dark')

In [77]:
R = 1.2  # tube radius


def tube_orig(t, theta):
    """Point on a tube of radius R around the spiral, at curve parameter t and angle theta.

    The angle is measured in the normal plane, starting at the unit normal N (theta = 0)
    and turning toward the binormal B = T x N.
    """
    tangent = T_fun(t)
    kappa_normal = N_fun(t)  # curvature * unit normal
    unit_normal = kappa_normal / jnp.linalg.norm(kappa_normal)
    binormal = jnp.cross(tangent, unit_normal)

    offset_2d = R * jnp.stack([jnp.cos(theta), jnp.sin(theta)])  # circle of radius R in (N, B) coordinates
    normal_plane = jnp.stack([unit_normal, binormal], axis=1)    # (3, 2) basis with columns N, B
    return normal_plane @ offset_2d + spiral(t)

# Vectorized over broadcastable (t, theta) grids: output (..., 3)
tube = jnp.vectorize(tube_orig, signature='(),()->(3)')

In [74]:
tube_ts = jnp.linspace(0, 6 * jnp.pi, 100)     # positions along the spiral
tube_thetas = jnp.linspace(0, 2 * jnp.pi, 24)  # angles around the tube cross-section
# Broadcasting (100, 1) against (1, 24) gives a (100, 24, 3) grid of surface points
pipexyzs = tube(tube_ts[:, jnp.newaxis], tube_thetas[jnp.newaxis, :])

In [76]:
# Spiral centerline plus a green wireframe of the tube surface.
# `xyzs` is (n, 3), so z is column 2.
fig = go.Figure(data=[go.Scatter3d(x=xyzs[:, 0], y=xyzs[:, 1], z=xyzs[:, 2], mode='lines')])

# Convert to NumPy once instead of slicing the device array inside both loops.
# Axes: (position along the curve, angle around the tube, xyz).
pipe_grid = np.asarray(pipexyzs)
wire_style = dict(mode='lines', line_color='green')

# One cross-section ring per sampled position along the curve
for ring in pipe_grid:
    fig.add_trace(go.Scatter3d(x=ring[:, 0], y=ring[:, 1], z=ring[:, 2], **wire_style))

# One longitudinal line per sampled angle around the tube
for line in pipe_grid.transpose(1, 0, 2):
    fig.add_trace(go.Scatter3d(x=line[:, 0], y=line[:, 1], z=line[:, 2], **wire_style))

fig.update_layout(height=600, width=600, margin=dict(l=0, r=0, b=0, t=0), template='plotly_dark',
                  showlegend=False)

In [80]:
# Surface Jacobian: calling jac(t, theta) returns the tuple (d tube/dt, d tube/dtheta)
jac = jax.jacobian(tube_orig, argnums=(0, 1))

In [81]:
# Tangent vectors of the tube surface at the last spiral parameter (t = 6π), on the theta = 0 line:
# a tuple (d/dt, d/dtheta), each a 3-vector
jac(6*jnp.pi, 0.)

(Array([1.3307328, 9.244813 , 1.0015277], dtype=float32, weak_type=True),
 Array([-0.01084279, -0.1130337 ,  1.1946154 ], dtype=float32, weak_type=True))