<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/Coil3D_Cross_Transform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import jax.numpy as jnp
import jax
from plotly.subplots import make_subplots

In [95]:


def curve(t):
    return jnp.array([jnp.cos(t), jnp.sin(t), t/(2*np.pi)])

curve_jac=jax.jacobian(curve)


In [96]:
def surf(t,theta,r):

    vl = curve(t)
    v_jac = curve_jac(t)

    n1=jnp.cross(v_jac,jnp.array([0.,0.,1.]))
    n1=n1/jnp.linalg.norm(n1)
    n2=jnp.cross(n1,v_jac)
    n2=n2/jnp.linalg.norm(n2)

    rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],[jnp.sin(theta),jnp.cos(theta)]])
    vx = jnp.array([r,0])
    vrot = jnp.matmul(rot, vx)
    vsurf = jnp.matmul(jnp.stack([n1,n2], axis=1),vrot)

    return vl+vsurf

In [97]:
surf_v = jnp.vectorize(surf, signature='(),(),()->(3)')

In [98]:
t=jnp.linspace(0,4*np.pi,120)
theta=jnp.linspace(0,2*np.pi,16)
r=0.1
s=surf_v(t[:,None],theta[None,:],r)

In [99]:
fig = make_subplots(rows=1,cols=1,specs=[[{'is_3d': True}]])


for i in range(t.size):
    x,y,z = s[i,:].T
    fig.add_scatter3d(x=x,y=y,z=z,mode='lines', line_color='white')

for i in range(theta.size):
    x,y,z = s[:,i].T
    fig.add_scatter3d(x=x,y=y,z=z,mode='lines', line_color='white')

fig.update_layout(width=800,template='plotly_dark',showlegend=False, scene=dict(zaxis=dict(range=(-1,1))))

In [100]:
surf_jac=jax.jacobian(surf,(0,1))

In [101]:
surf(0.,0.,0.1)

DeviceArray([1.1, 0. , 0. ], dtype=float32)

In [102]:
surf_jac(0.,0.,0.1)

(DeviceArray([0.        , 1.1       , 0.15915494], dtype=float32, weak_type=True),
 DeviceArray([ 0.        , -0.01571767,  0.09875705], dtype=float32, weak_type=True))

In [103]:
nt=120
ntheta=64
t=jnp.linspace(0,4*np.pi,nt,endpoint=False)
theta=jnp.linspace(0,2*np.pi,ntheta,endpoint=False)

dt=2*np.pi/nt
dtheta=2*np.pi/ntheta

def surf_jacdet(t,theta,r):
    dxyz_dt, dxyz_dtheta = surf_jac(t,theta,r)
    return jnp.linalg.norm(jnp.cross(dxyz_dt*dt, dxyz_dtheta*dtheta))

surf_jacdet_vec=jnp.vectorize(surf_jacdet,signature='(),(),()->()')

jnp.sum(surf_jacdet_vec(t[:,None],theta[None,:],r))

DeviceArray(3.9975314, dtype=float32)

In [104]:
#arc length of curve
curve_jac_vec=jnp.vectorize(curve_jac,signature='()->(3)')
arclength=jnp.sum(jnp.sqrt(jnp.sum(curve_jac_vec(t)**2,axis=1)))*dt
arclength*2*np.pi*r


DeviceArray(3.9975297, dtype=float32)