In [10]:
import torch


def build_rotation(r: torch.Tensor) -> torch.Tensor:
    # norm = torch.sqrt(
    #     r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
    # )

    # q = r / norm[:, None]
    q = r

    R = torch.zeros((q.size(0), 3, 3), device=r.device, dtype=r.dtype)

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - r * z)
    R[:, 0, 2] = 2 * (x * z + r * y)
    R[:, 1, 0] = 2 * (x * y + r * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - r * x)
    R[:, 2, 0] = 2 * (x * z - r * y)
    R[:, 2, 1] = 2 * (y * z + r * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R

r = torch.Tensor([[1.0, 2.0, 3.0, 4.0]]).requires_grad_(True)
scales = torch.Tensor([[1.0, 2.0, 3.0]])
S = torch.eye(3).unsqueeze(0).detach()
S[:, 0, 0] = 1.0
S[:, 1, 1] = 2.0
S[:, 2, 2] = 3.0
S_noised = S.clone()
torch.manual_seed(0)
S_noised[:, 0, 0] = S[:, 0, 0] + torch.randn(1) * 0.001
S_noised[:, 1, 1] = S[:, 1, 1] + torch.randn(1) * 0.001
S_noised[:, 2, 2] = S[:, 2, 2] + torch.randn(1) * 0.001
R = build_rotation(r)

target = torch.eye(3).unsqueeze(0)
target[:, 0, 0] = 1.0
target[:, 1, 1] = 2.0
target[:, 2, 2] = 3.0

M = R @ S
print(M)
print(target)
loss = (M - target).pow(2).sum()
print(loss)
loss.backward()
print(r.grad)

tensor([[[-49.,   8.,  66.],
         [ 20., -78.,  60.],
         [ 10.,  56., -75.]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[1., 0., 0.],
         [0., 2., 0.],
         [0., 0., 3.]]])
tensor(26640., grad_fn=<SumBackward0>)
tensor([[ 1776.,  9792., 12528., 11904.]])


In [10]:
import torch
from torch.autograd.functional import jacobian

def build_rotation(r: torch.Tensor) -> torch.Tensor:
    # r has shape [1,4], but let's ignore batch and just do the single quaternion
    # for clarity. (Here, r has shape [4].)
    w, x, y, z = r  # unpack the 4 components
    # We'll return a 3x3 matrix (no normalization for demonstration):
    R = torch.zeros(3, 3, dtype=r.dtype, device=r.device)
    
    R[0, 0] = 1 - 2*(y**2 + z**2)
    R[0, 1] = 2*(x*y - w*z)
    R[0, 2] = 2*(x*z + w*y)
    R[1, 0] = 2*(x*y + w*z)
    R[1, 1] = 1 - 2*(x**2 + z**2)
    R[1, 2] = 2*(y*z - w*x)
    R[2, 0] = 2*(x*z - w*y)
    R[2, 1] = 2*(y*z + w*x)
    R[2, 2] = 1 - 2*(x**2 + y**2)

    return R

def f(r: torch.Tensor, scale_matrix: torch.Tensor) -> torch.Tensor:
    # r is shape [4]
    # We'll just return the 3x3 rotation from the function above:
    return build_rotation(r) @ scale_matrix

# r i j k
r = torch.Tensor([1.0, 2.0, 3.0, 4.0])
scale_matrix = torch.Tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
J = jacobian(f, (r, scale_matrix))
print(J[0])

tensor([[[  0.,   0., -12., -16.],
         [-16.,  12.,   8.,  -4.],
         [ 18.,  24.,   6.,  12.]],

        [[  8.,   6.,   4.,   2.],
         [  0., -16.,   0., -32.],
         [-12.,  -6.,  24.,  18.]],

        [[ -6.,   8.,  -2.,   4.],
         [  8.,   4.,  16.,  12.],
         [  0., -24., -36.,   0.]]])


In [8]:
import torch 

def build_rotation(r: torch.Tensor) -> torch.Tensor:
    q = r

    R = torch.zeros((q.size(0), 3, 3), device=r.device, dtype=r.dtype)

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - r * z)
    R[:, 0, 2] = 2 * (x * z + r * y)
    R[:, 1, 0] = 2 * (x * y + r * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - r * x)
    R[:, 2, 0] = 2 * (x * z - r * y)
    R[:, 2, 1] = 2 * (y * z + r * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R


def d_m_wrt_qr(quats: torch.Tensor, scales: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]
    
    sx = scales[:, 0]
    sy = scales[:, 1]
    sz = scales[:, 2]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = -sy*qk
    derivative[:, 0, 2] = sz*qj
    derivative[:, 1, 0] = sx*qk
    derivative[:, 1, 2] = -sz*qi
    derivative[:, 2, 0] = -sx*qj
    derivative[:, 2, 1] = sy*qi
    
    return 2 * derivative


def d_m_wrt_qi(quats: torch.Tensor, scales: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]
    
    sx = scales[:, 0]
    sy = scales[:, 1]
    sz = scales[:, 2]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = sy*qj
    derivative[:, 0, 2] = sz*qk
    derivative[:, 1, 0] = sx*qj
    derivative[:, 1, 1] = -2*sy*qi
    derivative[:, 1, 2] = -sz*qr
    derivative[:, 2, 0] = sx*qk
    derivative[:, 2, 1] = sy*qr
    derivative[:, 2, 2] = -2*sz*qi
    return 2 * derivative

def d_m_wrt_qj(quats: torch.Tensor, scales: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]
    
    sx = scales[:, 0]
    sy = scales[:, 1]
    sz = scales[:, 2]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 0] = -2*sx*qj
    derivative[:, 0, 1] = sy*qi
    derivative[:, 0, 2] = sz*qr
    derivative[:, 1, 0] = sx*qi
    derivative[:, 1, 1] = 0
    derivative[:, 1, 2] = sz*qk
    derivative[:, 2, 0] = -sx*qr
    derivative[:, 2, 1] = sy*qk
    derivative[:, 2, 2] = -2*sz*qj
    return 2 * derivative

def d_m_wrt_qk(quats: torch.Tensor, scales: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]
    
    sx = scales[:, 0]
    sy = scales[:, 1]
    sz = scales[:, 2]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 0] = -2*sx*qk
    derivative[:, 0, 1] = -sy*qr
    derivative[:, 0, 2] = sz*qi
    derivative[:, 1, 0] = sx*qr
    derivative[:, 1, 1] = -2*sy*qk
    derivative[:, 1, 2] = sz*qj
    derivative[:, 2, 0] = sx*qi
    derivative[:, 2, 1] = sy*qj
    derivative[:, 2, 2] = 0
    return 2 * derivative

def d_m_wrt_q(quats: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats and scales
    quats is nx4 tensor
    scales is nx3 tensor
    """
    n = quats.shape[0]
    qr = d_m_wrt_qr(quats, scales, n)
    qi = d_m_wrt_qi(quats, scales, n)
    qj = d_m_wrt_qj(quats, scales, n)
    qk = d_m_wrt_qk(quats, scales, n)
    
    return qr, qi, qj, qk

class quatsScalesToM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, quats: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(quats, scales)
        R = build_rotation(quats)
        S = torch.eye(3).unsqueeze(0)
        S[:, 0, 0] = scales[:, 0]
        S[:, 1, 1] = scales[:, 1]
        S[:, 2, 2] = scales[:, 2]
        return R @ S, R
    
    @staticmethod
    def backward(ctx, grad_output1: torch.Tensor, grad_output2: torch.Tensor) -> torch.Tensor:
        quats, scales = ctx.saved_tensors
        deriv = d_m_wrt_q(quats, scales)
        wrt_r = (grad_output1 * deriv[0]).sum(dim=(1, 2), keepdim=True).squeeze(2)
        wrt_i = (grad_output1 * deriv[1]).sum(dim=(1, 2), keepdim=True).squeeze(2)
        wrt_j = (grad_output1 * deriv[2]).sum(dim=(1, 2), keepdim=True).squeeze(2)
        wrt_k = (grad_output1 * deriv[3]).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_quats = torch.cat([wrt_r, wrt_i, wrt_j, wrt_k], dim=1)
        print(deriv_wrt_quats)
        return deriv_wrt_quats, None
    
r = torch.Tensor([[1.0, 2.0, 3.0, 4.0]]).requires_grad_(True)

scales = torch.Tensor([[1.0, 2.0, 3.0]])
S = torch.eye(3).unsqueeze(0)
S[:, 0, 0] = 1.0
S[:, 1, 1] = 2.0
S[:, 2, 2] = 3.0
S_noised = S.clone()
torch.manual_seed(0)
S_noised[:, 0, 0] = S[:, 0, 0] + torch.randn(1) * 0.001
S_noised[:, 1, 1] = S[:, 1, 1] + torch.randn(1) * 0.001
S_noised[:, 2, 2] = S[:, 2, 2] + torch.randn(1) * 0.001
M, R = quatsScalesToM.apply(r, scales)
print(M.requires_grad)
target = torch.eye(3).unsqueeze(0)
target[:, 0, 0] = 1.0
target[:, 1, 1] = 2.0
target[:, 2, 2] = 3.0

print(M)
print(target)
loss = (M - target).pow(2).sum()
print(loss)
loss.backward()
print(r.grad)

True
tensor([[[-49.,   8.,  66.],
         [ 20., -78.,  60.],
         [ 10.,  56., -75.]]], grad_fn=<quatsScalesToMBackward>)
tensor([[[1., 0., 0.],
         [0., 2., 0.],
         [0., 0., 3.]]])
tensor(26640., grad_fn=<SumBackward0>)
tensor([[ 1776.,  9792., 12528., 11904.]])
tensor([[ 1776.,  9792., 12528., 11904.]])


In [2]:
import torch

def build_rotation(r: torch.Tensor) -> torch.Tensor:
    q = r

    R = torch.zeros((q.size(0), 3, 3), device=r.device, dtype=r.dtype)

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - r * z)
    R[:, 0, 2] = 2 * (x * z + r * y)
    R[:, 1, 0] = 2 * (x * y + r * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - r * x)
    R[:, 2, 0] = 2 * (x * z - r * y)
    R[:, 2, 1] = 2 * (y * z + r * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R

def d_r_wrt_qr(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = -qk
    derivative[:, 0, 2] = qj
    derivative[:, 1, 0] = qk
    derivative[:, 1, 2] = -qi
    derivative[:, 2, 0] = -qj
    derivative[:, 2, 1] = qi
    
    return 2 * derivative

def d_r_wrt_qi(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = qj
    derivative[:, 0, 2] = qk
    derivative[:, 1, 0] = qj
    derivative[:, 1, 1] = -2*qi
    derivative[:, 1, 2] = -qr
    derivative[:, 2, 0] = qk
    derivative[:, 2, 1] = qr
    derivative[:, 2, 2] = -2*qi
    return 2 * derivative

def d_r_wrt_qj(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 0] = -2*qj
    derivative[:, 0, 1] = qi
    derivative[:, 0, 2] = qr
    derivative[:, 1, 0] = qi
    derivative[:, 1, 1] = 0
    derivative[:, 1, 2] = qk
    derivative[:, 2, 0] = -qr
    derivative[:, 2, 1] = qk
    derivative[:, 2, 2] = -2*qj
    return 2 * derivative

class quatsToR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, quats: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(quats)
        R = build_rotation(quats)
        return R
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        quats = ctx.saved_tensors[0]
        deriv_wrt_qr = d_r_wrt_qr(quats, quats.shape[0])
        deriv_wrt_qi = d_r_wrt_qi(quats, quats.shape[0])
        deriv_wrt_qj = d_r_wrt_qj(quats, quats.shape[0])
        
        
        deriv_wrt_qr = (grad_output * deriv_wrt_qr).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_qi = (grad_output * deriv_wrt_qi).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_qj = (grad_output * deriv_wrt_qj).sum(dim=(1, 2), keepdim=True).squeeze(2)
        return torch.cat([deriv_wrt_qr, deriv_wrt_qi, deriv_wrt_qj, torch.zeros_like(deriv_wrt_qr)], dim=1)
    
r = torch.Tensor([[1.0, 2.0, 3.0, 4.0]]).requires_grad_(True)
# r = r / torch.norm(r, dim=1, keepdim=True)
# r.retain_grad()
print(r.requires_grad)

R = quatsToR.apply(r)

print(R)
target = torch.eye(3).unsqueeze(0)
loss = (R - target).pow(2).sum()
print(loss)
loss.backward()
print(r.grad)

True
tensor([[[-49.,   4.,  22.],
         [ 20., -39.,  20.],
         [ 10.,  28., -25.]]], grad_fn=<quatsToRBackward>)
tensor(6960., grad_fn=<SumBackward0>)
tensor([[ 464., 1888., 2832.,    0.]])


In [3]:
r = torch.Tensor([[1.0, 2.0, 3.0, 4.0]]).requires_grad_(True)
# r = r / torch.norm(r, dim=1, keepdim=True)
# r.retain_grad()

R = build_rotation(r)
print(R)
loss = (R - target).pow(2).sum()
print(loss)
loss.backward()
print(r.grad)


tensor([[[-49.,   4.,  22.],
         [ 20., -39.,  20.],
         [ 10.,  28., -25.]]], grad_fn=<CopySlices>)
tensor(6960., grad_fn=<SumBackward0>)
tensor([[ 464., 1888., 2832., 3776.]])
