In [1]:
import torch
from madspace.functional.kinematics import rotxxx, rotxxx_inv, rotate_zy, pmag, atan2

In [3]:
# Display settings (6 decimals, no scientific notation); every vector below is float64.
torch.set_printoptions(sci_mode=False, precision=6)
dtype = torch.float64

# --- Test inputs -------------------------------------------------------------
# Four-vectors are laid out as (E, px, py, pz).

# Case 1: q points along +z, so the rotation should leave p unchanged.
q1 = torch.tensor([10.0, 0.0, 0.0, 5.0], dtype=dtype)
p1 = torch.tensor([3.0, 1.0, 2.0, 3.0], dtype=dtype)    # arbitrary

# Case 2: q points along -z, so the spatial part of p should change sign.
q2 = torch.tensor([10.0, 0.0, 0.0, -7.0], dtype=dtype)
p2 = torch.tensor([3.0, 1.5, -2.0, 4.0], dtype=dtype)   # arbitrary

# Case 3: generic q with (qx, qy, qz) = (6, 8, 10), i.e. |q_T| = 10 and |q| = sqrt(200).
# p lies along +z in the q-aligned frame, so its rotated spatial part should be
# pz * q_hat (direction of q, magnitude pz).
pz = 5.0
q3 = torch.tensor([20.0, 6.0, 8.0, 10.0], dtype=dtype)
p3 = torch.tensor([4.0, 0.0, 0.0, pz], dtype=dtype)

# Batch the three cases row-wise: q and p both end up with shape (3, 4).
q, p = (torch.stack(case_vectors, dim=0) for case_vectors in ((q1, q2, q3), (p1, p2, p3)))

# --- Run forward and inverse -------------------------------------------------
prot = rotxxx(p, q)            # forward rotation
p_back = rotxxx_inv(prot, q)   # undoing the rotation is expected to give back p

print("prot =\n", prot)
print("p_back =\n", p_back)

prot =
 tensor([[ 3.000000,  1.000000,  2.000000,  3.000000],
        [ 3.000000, -1.500000,  2.000000, -4.000000],
        [ 4.000000,  2.121320,  2.828427,  3.535534]], dtype=torch.float64)
p_back =
 tensor([[     3.000000,      1.000000,      2.000000,      3.000000],
        [     3.000000,      1.500000,     -2.000000,      4.000000],
        [     4.000000,      0.000000,     -0.000000,      5.000000]],
       dtype=torch.float64)


In [13]:
# Cross-check: rotxxx(p, q) should agree with rotate_zy(p, phi, cos(theta)) built
# from the azimuth phi and polar angle theta of q's spatial part.
EPS = 1e-12
q3 = torch.tensor([20.0, 6.0, 8.0, 10.0], dtype=dtype)
p3 = torch.tensor([ 4.0, 0.0, 0.0, 5.0], dtype=dtype)

# NOTE(review): assumes pmag returns the spatial magnitude |q| = sqrt(qx^2 + qy^2 + qz^2) — confirm.
q3mag = pmag(q3)
phi3 = atan2(q3[2], q3[1])                # azimuth: atan2(qy, qx)
costheta3 = q3[3] / q3mag.clip(min=EPS)   # qz / |q|; the clip avoids dividing by zero when |q| = 0
p3rot1 = rotate_zy(p3, phi3, costheta3)
p3rot2 = rotxxx(p3, q3)
print("p3rot1 =\n", p3rot1)
print("p3rot2 =\n", p3rot2)
# Previously the two results were only printed; assert so a mismatch fails the notebook run.
assert torch.allclose(p3rot1, p3rot2, atol=EPS), "rotate_zy and rotxxx disagree"



p3rot1 =
 tensor([4.000000, 2.121320, 2.828427, 3.535534], dtype=torch.float64)
p3rot2 =
 tensor([4.000000, 2.121320, 2.828427, 3.535534], dtype=torch.float64)


In [32]:
# --- Assertions / expected checks -------------------------------------------
# EPS is passed as allclose's atol; allclose's default rtol (1e-5) still applies on top of it.
EPS = 1e-12

# Case 1 (q_T == 0, qz > 0): every component, energy included, is unchanged.
assert torch.allclose(prot[0], p[0], atol=EPS)

# Case 2 (q_T == 0, qz < 0): the energy is kept while px, py, pz all change sign.
flip_out, flip_in = prot[1], p[1]
assert torch.allclose(flip_out[0], flip_in[0], atol=EPS)          # energy unchanged
assert torch.allclose(flip_out[1:], flip_in[1:].neg(), atol=EPS)  # spatial flipped

# Case 3: p = (E, 0, 0, pz) in the q-aligned frame must land on pz * q_hat in the lab.
q_vec = q[2, 1:]
q_norm = torch.linalg.vector_norm(q_vec)
expected_spatial = q_vec * (pz / q_norm)
assert torch.allclose(prot[2, 1:], expected_spatial, atol=EPS)

# Round trip: rotxxx_inv must undo rotxxx for every row of the batch.
assert torch.allclose(p_back, p, atol=EPS)
print("All rotxxx / rotxxx_inv tests passed ✔️")

All rotxxx / rotxxx_inv tests passed ✔️


In [33]:
# --- Edge-case Assertions / expected checks -------------------------------------------
# q_T is tiny but not exactly zero (|q_T| = sqrt(2) * 1e-16) with qz > 0: rotxxx is
# expected to treat q as lying along +z and return p unchanged, and rotxxx_inv must
# undo it. EPS comes from an earlier cell.
q4 = torch.tensor([12.0, 1e-16, -1e-16, 3.0], dtype=dtype)  # q_T ~ 0, qz>0
p4 = torch.tensor([ 2.0, 0.3, -0.4, 0.5], dtype=dtype)
# unsqueeze(0)/[0] add and remove a batch dimension of size 1, matching the batched calls above.
prot4 = rotxxx(p4.unsqueeze(0), q4.unsqueeze(0))[0]
assert torch.allclose(prot4, p4, atol=EPS)
# The success message below names rotxxx_inv, so actually exercise the inverse on this input.
p4_back = rotxxx_inv(prot4.unsqueeze(0), q4.unsqueeze(0))[0]
assert torch.allclose(p4_back, p4, atol=EPS)
print("Edge-case rotxxx / rotxxx_inv tests passed ✔️")

Edge-case rotxxx / rotxxx_inv tests passed ✔️
