<a href="https://colab.research.google.com/github/chetools/STEMUnleashed2025/blob/main/RotationMatrices_FresnetSerret.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from plotly.subplots import make_subplots

In [None]:
def rotmat(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle *theta* (radians).

    After the ``jnp.vectorize`` rebinding below, an array of angles with
    shape ``S`` yields a stack of matrices with shape ``S + (2, 2)``.
    """
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s],
                      [s, c]])

rotmat = jnp.vectorize(rotmat, signature='()->(2,2)')

def rotmat_x(theta):
    """Return the 3x3 rotation matrix about the x axis by *theta* radians.

    After the ``jnp.vectorize`` rebinding below, an array of angles with
    shape ``S`` yields a stack of matrices with shape ``S + (3, 3)``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array([[1., 0., 0.],
                      [0., cos_t, -sin_t],
                      [0., sin_t, cos_t]])

rotmat_x = jnp.vectorize(rotmat_x, signature='()->(3,3)')

In [None]:
# 24 evenly spaced angles over one full turn (both 0 and 2*pi included),
# plus the matching (24, 2, 2) stack of rotation matrices.
thetas = jnp.linspace(0.0, 2.0 * jnp.pi, num=24)
rotmats = rotmat(thetas)

In [None]:
v0 = jnp.array([1.5, 0.0])  # starting 2-D vector on the +x axis

In [None]:
# Rotate v0 by every matrix in `rotmats` (column n = rotmats[n] @ v0), then
# double the x components so the points trace an ellipse.
vs = jnp.einsum('nij,j->in', rotmats, v0)
vs = vs.at[0].multiply(2)

In [None]:
# Scatter the stretched, rotated points on fixed square axes.
fig = make_subplots()
fig.add_scatter(x=vs[0], y=vs[1], mode='markers')
fig.update_layout(
    xaxis_range=[-4, 4],
    yaxis_range=[-4, 4],
    width=400,
    height=400,
    template='plotly_dark',
)

In [None]:
rotmat(jnp.pi / 6)  # 30-degree rotation matrix, shown as the cell output

Array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [None]:
vs[:, :3]  # first three points (one point per column)

Array([[3.        , 2.8887517 , 2.5632582 ],
       [0.        , 0.40469518, 0.7793759 ]], dtype=float32)

In [None]:
vs30 = rotmat(jnp.pi / 6) @ vs  # rotate every point (column) by 30 degrees

In [None]:
# Original points versus the same points rotated by 30 degrees.
fig = make_subplots()
for pts in (vs, vs30):
    fig.add_scatter(x=pts[0], y=pts[1], mode='markers')
fig.update_layout(
    xaxis_range=[-4, 4],
    yaxis_range=[-4, 4],
    width=400,
    height=400,
    template='plotly_dark',
)

In [None]:
vs30xyz = jnp.vstack([vs30, jnp.zeros_like(vs30[0])])  # lift to 3-D with z = 0

In [None]:
rotmat_x(jnp.pi / 6)  # 30-degree rotation about x, shown as the cell output

Array([[ 1.       ,  0.       ,  0.       ],
       [ 0.       ,  0.8660254, -0.5      ],
       [ 0.       ,  0.5      ,  0.8660254]], dtype=float32)

In [None]:
vs30_30xyz = rotmat_x(jnp.pi / 6) @ vs30xyz  # then tilt 30 degrees about the x axis

In [None]:
import plotly.graph_objects as go

# 3-D view of the points after the in-plane rotation and the tilt about x.
fig = go.Figure([
    go.Scatter3d(x=vs30_30xyz[0], y=vs30_30xyz[1], z=vs30_30xyz[2], mode='markers'),
])

fig.update_layout(height=600, width=400,
                  margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
                  template='plotly_dark')

In [None]:
def spiral(t):
    """Return the point (x, y, z) of the test curve at parameter *t*.

    The curve is a widening helix: the radius is ``1 + 0.5*t``, the x
    coordinate also drifts by ``+t``, and ``z = t``. After the
    ``jnp.vectorize`` rebinding below, input shape ``S`` gives ``S + (3,)``.
    """
    radius = 1 + 0.5 * t
    return jnp.array([radius * jnp.cos(t) + t,
                      radius * jnp.sin(t),
                      t])

spiral = jnp.vectorize(spiral, signature='()->(3)')

In [None]:
def FresnetSerret(fun):
    """Build vectorized Frenet-Serret helpers for a parametric space curve.

    (The function name keeps its original "Fresnet" spelling so existing
    callers keep working; the frame is normally spelled Frenet-Serret.)

    Parameters
    ----------
    fun : callable
        Maps a scalar parameter ``t`` to a curve point ``r(t)`` of shape (3,).

    Returns
    -------
    T : callable
        Vectorized over ``t`` (signature ``()->(3)``). Returns the unit
        tangent ``r'(t) / |r'(t)|``.
    N : callable
        Vectorized over ``t`` (signature ``()->(3)``). Returns
        ``T'(t) / |r'(t)| = dT/ds``. This is the *curvature vector*
        ``kappa * N``, NOT the unit normal: its norm is the curvature, and
        dividing by that norm gives the unit normal.
    """
    # Forward mode is the efficient choice for a scalar -> R^3 map: one
    # forward pass, instead of one reverse pass per output component with
    # jax.jacobian (which is jacrev).
    drdt_fun = jax.jacfwd(fun)

    def T(t):
        drdt = drdt_fun(t)
        return drdt / jnp.linalg.norm(drdt)

    dTdt_fun = jax.jacfwd(T)

    def N(t):
        # Chain rule: dT/ds = (dT/dt) / (ds/dt), with ds/dt = |r'(t)|.
        return dTdt_fun(t) / jnp.linalg.norm(drdt_fun(t))

    return (jnp.vectorize(T, signature='()->(3)'),
            jnp.vectorize(N, signature='()->(3)'))


In [None]:
# Sample the spiral and evaluate its Frenet-Serret quantities at each sample.
ts = jnp.linspace(0, 6 * jnp.pi, 120)
T_fun, N_fun = FresnetSerret(spiral)
xyzs = spiral(ts)                                 # (120, 3) curve points
Ts = T_fun(ts)                                    # unit tangents
curvature_Ns = N_fun(ts)                          # dT/ds, i.e. curvature * unit normal
curvature = jnp.linalg.norm(curvature_Ns, axis=-1)
Ns = curvature_Ns / curvature[:, jnp.newaxis]     # unit normals
Rs = 1 / curvature                                # radius of curvature

In [None]:
# Osculating circle at sample `idx`. Its centre lies one radius of curvature
# along the unit normal, and its plane is spanned by the columns (-N, T).
idx = 100
center = Rs[idx] * Ns[idx] + xyzs[idx]
basis_matrix = np.column_stack((-Ns[idx], Ts[idx]))

In [None]:
circle_points = Rs[idx] * np.vstack((np.cos(thetas), np.sin(thetas)))  # (2, len(thetas)) planar circle

In [None]:
tangent_circle = center[:, jnp.newaxis] + basis_matrix @ circle_points  # map into 3-D, shift to centre

In [None]:

# Plot the spiral together with its osculating circle at sample `idx`.
# xyzs has shape (len(ts), 3), so z is column 2. The original code used
# column 3, which is out of range. JAX clamps out-of-bounds gather indices
# instead of raising, so that only worked by accident (NumPy would raise).
fig = go.Figure(data=[go.Scatter3d(x=xyzs[:, 0], y=xyzs[:, 1], z=xyzs[:, 2], mode='lines')])
# Optional: tangent (T) and unit-normal (N) arrows along the curve.
# fig.add_trace(go.Cone(x=xyzs[:,0], y=xyzs[:,1], z=xyzs[:,2], u=Ts[:,0], v=Ts[:,1], w=Ts[:,2],
#                       sizemode='absolute',sizeref=5, colorscale='Blues', cmin = 0.9, cmax=1.1,
#                       showscale=False))

# fig.add_trace(go.Cone(x=xyzs[:,0], y=xyzs[:,1], z=xyzs[:,2], u=Ns[:,0], v=Ns[:,1], w=Ns[:,2],
#                       sizemode='absolute',sizeref=5, colorscale='Greens', cmin = 0.9, cmax=1.1,
#                       showscale=False))
fig.add_trace(go.Scatter3d(x=tangent_circle[0], y=tangent_circle[1], z=tangent_circle[2], mode='lines'))

fig.update_layout(height=600,width=600,margin=dict(l=0, r=0, b=0, t=0), template='plotly_dark')