<a href="https://colab.research.google.com/github/profteachkids/StemUnleashed/blob/main/Coil3D_Cross_Transform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
from plotly.subplots import make_subplots

In [14]:
def curve(t, R, mz):
    """Point on a coil (helix) of radius R wound around the z-axis, at parameter t.

    Each full turn (t -> t + 2*pi) raises the point by mz along z.
    Returns a length-3 array [x, y, z].
    """
    x = R*jnp.cos(t)
    y = R*jnp.sin(t)
    z = t*mz/(2*np.pi)
    return jnp.array([x, y, z])

# Research what the Jacobian of a function is.
# How does it differ from the gradient of a function?
curve_jac = jax.jacobian(curve, argnums=0)   # derivative with respect to t only

print(curve(np.pi/6., 1., 2*np.pi))
print(curve_jac(np.pi/6., 1., 2*np.pi))

[0.8660254 0.5       0.5235988]
[-0.5        0.8660254  1.       ]


In [3]:
#I woke up at 4 AM a couple of weeks ago with this idea ...
def surf(t,theta,r,R,mz):
    """Point on a tube of radius r swept along the coil curve(t, R, mz).

    t     -- position along the coil (same parameter as curve)
    theta -- angle around the tube's circular cross-section
    r     -- tube radius; R, mz -- coil parameters passed straight to curve
    Returns a length-3 array.
    """
    def unit(v):
        return v/jnp.linalg.norm(v)

    center = curve(t, R, mz)        # point on the coil's centerline
    tangent = curve_jac(t, R, mz)   # d(curve)/dt at the same t

    # What are n1 and n2?
    # Crossing with the z unit vector is an arbitrary (but workable) choice of reference direction.
    n1 = unit(jnp.cross(tangent, jnp.array([0., 0., 1.])))
    n2 = unit(jnp.cross(n1, tangent))

    # rotation matrix: turns a 2-D vector counter-clockwise by theta
    cos_th, sin_th = jnp.cos(theta), jnp.sin(theta)
    rot = jnp.array([[cos_th, -sin_th], [sin_th, cos_th]])

    offset_2d = rot @ jnp.array([r, 0])       # r along the x-coordinate, then rotated by theta
    frame = jnp.stack([n1, n2], axis=1)       # what is this???  (a 3x2 matrix whose columns are n1, n2)
    offset_3d = frame @ offset_2d

    return center + offset_3d  # why add?

In [15]:
# Research what JAX vectorize does (AMAZING, huh!)
# The signature maps five scalar inputs to one length-3 point, so array inputs broadcast
# against each other and the result gains a trailing axis of size 3.
surf_v = jnp.vectorize(pyfunc=surf, signature='(),(),(),(),()->(3)')

In [5]:
# Sampling grid: two full turns along the coil (t) and one full circle around the tube (theta).
t = jnp.linspace(0, 4*np.pi, 120)
theta = jnp.linspace(0, 2*np.pi, 16)

# coil radius R, tube radius r, rise of the coil per turn mz
R, r, mz = 1., 0.2, 0.8

# t as a column and theta as a row broadcast to a (t.size, theta.size, 3) grid of surface points
s = surf_v(t[:,None], theta[None,:], r, R, mz)

In [6]:
fig = make_subplots(rows=1,cols=1,specs=[[{'type': 'scene'}]])

# Take the grid size from s itself rather than from t.size / theta.size: a later cell rebinds
# t and theta (with a different theta resolution), and JAX clamps out-of-range indices instead
# of raising, so re-running this cell afterwards would otherwise silently draw repeated lines.
n_t, n_theta, _ = s.shape

# s[i] is the ring of points around the tube at one value of t
for i in range(n_t):
    x,y,z = s[i,:].T  #why transpose?  take a look at unpacking for small matrices
    fig.add_scatter3d(x=x,y=y,z=z,mode='lines', line_color='white')

# s[:, j] follows the coil at one fixed angle theta
for j in range(n_theta):
    x,y,z = s[:,j].T
    fig.add_scatter3d(x=x,y=y,z=z,mode='lines', line_color='white')

fig.update_layout(width=800,template='plotly_dark',showlegend=False)

In [7]:
# What is the purpose of argnums=(0, 1)?
surf_jac = jax.jacobian(surf, argnums=(0, 1))

In [8]:
# Examine different t, theta values (multiples of pi/2) with convenient r and R, and also try mz=0
surf(t=0., theta=0., r=r, R=R, mz=0.8)

DeviceArray([1.2, 0. , 0. ], dtype=float32)

In [9]:
# In the Jacobian, why does 1.2 being equal to R+r and 0.1975 being close to r=0.2 make sense?
# (t and theta stay positional: surf_jac differentiates positional arguments 0 and 1)
surf_jac(0., 0., r=r, R=R, mz=1)

(DeviceArray([0.        , 1.2       , 0.15915494], dtype=float32, weak_type=True),
 DeviceArray([ 0.        , -0.03143534,  0.1975141 ], dtype=float32, weak_type=True))

In [10]:
# Surface area of the tube: sum small patch areas over a (t, theta) grid.
nt=120
ntheta=64
t_span=4*np.pi       # t runs over two full turns of the coil
theta_span=2*np.pi   # theta runs once around the tube's cross-section
t=jnp.linspace(0,t_span,nt,endpoint=False)
theta=jnp.linspace(0,theta_span,ntheta,endpoint=False)

# With endpoint=False the samples split each span into equal cells. Derive the cell widths from
# the same span constants so they cannot drift out of sync with the grids above.
dt=t_span/nt
dtheta=theta_span/ntheta

def surf_jacdet(t,theta,r):
    """Area contribution of the grid cell at (t, theta) on the tube of radius r.

    Reads the globals R and mz (coil settings from earlier cells) and the grid spacing
    dt, dtheta. Despite its name, it returns the norm of a cross product rather than
    computing a determinant.
    """
    #why name these dxyz_dt and dxyz_dtheta?
    dxyz_dt, dxyz_dtheta = surf_jac(t,theta,r,R,mz)

    #why do we multiply by dt and dtheta?
    #recall the geometric meaning of the cross product
    return jnp.linalg.norm(jnp.cross(dxyz_dt*dt, dxyz_dtheta*dtheta))

surf_jacdet_vec=jnp.vectorize(surf_jacdet,signature='(),(),()->()')

jnp.sum(surf_jacdet_vec(t[:,None],theta[None,:],r))

DeviceArray(15.918863, dtype=float32)

In [11]:
# Arc length of the curve: what approximation are we making here for the area?
curve_jac_vec = jnp.vectorize(curve_jac, signature='(),(),()->(3)')

# Why do we take the sqrt and sum of squares?
# Each row of curve_jac_vec(t, R, mz) is d(curve)/dt at one sample of t.
arclength = jnp.sum(
    jnp.sqrt(
        jnp.sum(curve_jac_vec(t, R, mz)**2, axis=-1)
    )
) * dt
arclength*2*np.pi*r


DeviceArray(15.918854, dtype=float32)