<a href="https://colab.research.google.com/github/chetools/STEMUnleashed2025/blob/main/RotationMatrices.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from plotly.subplots import make_subplots

In [50]:
def rotmat(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle ``theta`` (radians).

    After the ``jnp.vectorize`` wrapping below, ``theta`` may be a scalar or an
    array of angles; the result then has shape ``theta.shape + (2, 2)``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array([[cos_t, -sin_t],
                      [sin_t,  cos_t]])


# Broadcast over batches of angles: scalar core '()' -> one (2, 2) matrix each.
rotmat = jnp.vectorize(rotmat, signature='()->(2,2)')

def rotmat_x(theta):
    """Return the 3x3 matrix that rotates by ``theta`` radians about the x-axis.

    The x component is left unchanged and the (y, z) plane is rotated with the
    same 2x2 pattern [[cos, -sin], [sin, cos]] used by ``rotmat``. After the
    ``jnp.vectorize`` wrapping below, ``theta`` may be an array of angles and
    the result has shape ``theta.shape + (3, 3)``.
    """
    c = jnp.cos(theta)
    s = jnp.sin(theta)
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0,   c,  -s],
                      [0.0,   s,   c]])


# Broadcast over batches of angles: scalar core '()' -> one (3, 3) matrix each.
rotmat_x = jnp.vectorize(rotmat_x, signature='()->(3,3)')

In [51]:
rotmat_x(jnp.pi / 6)  # 30° rotation about the x-axis

Array([[ 1.       ,  0.       ,  0.       ],
       [ 0.       ,  0.8660254, -0.5      ],
       [ 0.       ,  0.5      ,  0.8660254]], dtype=float32)

In [11]:
rotmat(jnp.pi / 6)  # 30° counter-clockwise rotation in the plane

Array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [5]:
# 24 evenly spaced angles over one full turn. NOTE: linspace includes the
# endpoint, so the first (0) and last (2π) angles give the same point.
thetas = jnp.linspace(0.0, 2.0 * jnp.pi, num=24)

In [13]:
rotmats = rotmat(thetas)  # vectorized: one 2x2 rotation matrix per angle

In [14]:
rotmats.shape  # (number of angles, 2, 2)

(24, 2, 2)

In [19]:
v0 = jnp.array([1.5, 0.0])  # starting vector on the +x axis, length 1.5
v0.shape

(2,)

In [36]:
# Apply every rotation matrix rotmats[n] to v0. The result is indexed
# [component i, angle n], so row 0 holds x and row 1 holds y for each angle.
vs = jnp.einsum('nij,j->in', rotmats, v0)
# Stretch x by 2 (y unchanged): the circle of radius 1.5 becomes an ellipse.
vs = vs * jnp.array([[2.0], [1.0]])

In [38]:
# Scatter the stretched points on fixed, equal axis ranges.
fig = (
    make_subplots()
    .add_scatter(x=vs[0], y=vs[1], mode='markers')
    .update_xaxes(range=[-4, 4])
    .update_yaxes(range=[-4, 4])
    .update_layout(width=400, height=400, template='plotly_dark')
)
fig

In [40]:
rotmat(jnp.pi / 6)  # the 30° matrix applied to vs further below

Array([[ 0.8660254, -0.5      ],
       [ 0.5      ,  0.8660254]], dtype=float32)

In [42]:
vs[:, :3]  # first three points: row 0 = x, row 1 = y

Array([[3.        , 2.8887517 , 2.5632582 ],
       [0.        , 0.40469518, 0.7793759 ]], dtype=float32)

In [44]:
# Rotate every column (point) of vs by 30°; the output order is spelled out
# explicitly instead of relying on einsum's implicit alphabetical ordering.
vs30 = jnp.einsum('ij,jn->in', rotmat(jnp.pi / 6), vs)

In [72]:
# Overlay the stretched points (vs) and their 30°-rotated copy (vs30).
# Plot vs30 directly rather than vs30xyz: vs30xyz is only built in a later
# cell, so referencing it here raises NameError on a fresh Restart & Run All.
# Its first two rows are exactly vs30, so the figure is unchanged.
fig = make_subplots()
fig.add_scatter(x=vs[0], y=vs[1], mode='markers', name='original')
fig.add_scatter(x=vs30[0], y=vs30[1], mode='markers', name='rotated 30°')
fig.update_xaxes(range=[-4, 4])
fig.update_yaxes(range=[-4, 4])
fig.update_layout(width=400, height=400, template='plotly_dark')

In [73]:
# Lift the rotated 2D points into 3D by appending a z = 0 row: (2, N) -> (3, N).
vs30xyz = jnp.vstack([vs30, jnp.zeros_like(vs30[:1])])

In [62]:
rotmat_x(jnp.pi / 6)  # the x-axis tilt applied to vs30xyz in the next cell

Array([[ 1.       ,  0.       ,  0.       ],
       [ 0.       ,  0.8660254, -0.5      ],
       [ 0.       ,  0.5      ,  0.8660254]], dtype=float32)

In [75]:
# Tilt the z = 0 points by 30° about the x-axis; explicit einsum output order.
vs30_30xyz = jnp.einsum('ij,jn->in', rotmat_x(jnp.pi / 6), vs30xyz)

In [76]:
import plotly.graph_objects as go

# 3D scatter of the doubly rotated points: rows 0, 1, 2 are x, y, z.
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=vs30_30xyz[0],
            y=vs30_30xyz[1],
            z=vs30_30xyz[2],
            mode='markers',
        )
    ]
)

fig.update_layout(
    height=600,
    width=400,
    margin=dict(l=0, r=0, b=0, t=0),
    template='plotly_dark',
)