# Splat Math

### Camera Projection
$$ \left[\begin{matrix} u \\ v \end{matrix}\right] = \left[\begin{matrix} f_x & 0 & c_x \\ 0 & f_y & c_y \end{matrix}\right] \left[\begin{matrix} \frac{x}{z} \\ \frac{y}{z} \\ 1 \end{matrix}\right] $$

In [16]:
import torch
from torch.autograd import Function, gradcheck

class CameraProject(Function):
    @staticmethod
    def forward(ctx, x, y, z, fx, fy, cx, cy):
        u = fx * x / z + cx
        v = fy * y / z + cy
        ctx.save_for_backward(x, y, z, fx, fy)
        return u, v

    @staticmethod
    def backward(ctx, grad_u, grad_v):
        x, y, z, fx, fy = ctx.saved_tensors
        grad_x = grad_u * fx / z
        grad_y = grad_v * fy / z
        grad_z = -grad_u * fx * x / z**2 - grad_v * fy * y / z**2
        return grad_x, grad_y, grad_z, None, None, None, None
    
x = torch.tensor(10.0, dtype=torch.float64, requires_grad=True)
y = torch.tensor(-5.0, dtype=torch.float64, requires_grad=True)
z = torch.tensor(10.0, dtype=torch.float64, requires_grad=True)
fx = torch.tensor(1300.0, dtype=torch.float64, requires_grad=False)
fy = torch.tensor(1200.0, dtype=torch.float64, requires_grad=False)
cx = torch.tensor(320.0, dtype=torch.float64, requires_grad=False)
cy = torch.tensor(240.0, dtype=torch.float64, requires_grad=False)

test = gradcheck(CameraProject.apply, (x, y, z, fx, fy, cx, cy))
print(test)


True


### Matrix Multiplication

The reverse mode differentiation for matrix operations is documented in: [An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation](https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf) by Mike Giles

Matrix Multiplication is documented in Section 2.2.2

In [17]:
import torch
from torch.autograd import Function, gradcheck

class MatrixMultiplication(Function):
    @staticmethod
    def forward(ctx, A, B):
        C = A @ B
        ctx.save_for_backward(A, B)
        return C

    @staticmethod
    def backward(ctx, grad_C):
        A, B, = ctx.saved_tensors
        grad_A = grad_C @ B.T
        grad_B = A.T @ grad_C
        return grad_A, grad_B

    
R = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
S = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)

test = gradcheck(MatrixMultiplication.apply, (R, S))
print(test)

True


Computing `RSSR` is just two matrix multiplications:
1. $RS = R * S$
2. $RSSR = RS * (RS)^T$

In [18]:
import torch
from torch.autograd import Function, gradcheck

class RSSR(Function):
    @staticmethod
    def forward(ctx, R, S):
        RS = R @ S
        RSSR = RS @ RS.T
        ctx.save_for_backward(R, S)
        return RSSR
    
    @staticmethod
    def backward(ctx, grad_RSSR):
        R, S = ctx.saved_tensors
        RS = R @ S
        grad_RS = grad_RSSR @ RS
        grad_SR = RS.T @ grad_RSSR

        grad_R = grad_RS @ S.T
        grad_S = R.T @ grad_RS

        grad_R_t = S @ grad_SR
        grad_S_t = grad_SR @ R


        return grad_R + grad_R_t.T, grad_S + grad_S_t.T
    

R = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
S = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)

test = gradcheck(RSSR.apply, (R, S))
print(test)


True


Computing Sigma Image matches the "First Quadratic Form" 2.3.2

$ C = B^T A B $ 

$ \Sigma_{image} = JW \Sigma_{world} (JW)^T $

Where $ B = (JW)^T $ and $ A = \Sigma_{world} $ 


In [19]:
import torch
from torch.autograd import Function, gradcheck

class ComputeSigmaImage(Function):
    @staticmethod
    def forward(ctx, sigma_world, W, J):
        JW = J @ W
        sigma_image = JW @ sigma_world @ JW.T
        ctx.save_for_backward(sigma_world, W, J)
        return sigma_image
    
    @staticmethod
    def backward(ctx, grad_sigma_image):
        sigma_world, W, J = ctx.saved_tensors
        JW = J @ W 

        # using First Quadratic Form 2.3.2 from: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        # for C = B_t @ A @ B 
        # grad_A = B @ grad_C @ B_t
        # grad_B = A @ B @ grad_C_t + A_t @ B @ grad_C
        
        # applying to our variables
        # sigma_image = JW @ sigma_world @ JW.T
        # C = sigma_image
        # A = sigma_world
        # B = JW_t

        grad_sigma_world = JW.T @ grad_sigma_image @ JW
        grad_JW_t = sigma_world @ JW.T @ grad_sigma_image.T + sigma_world.T @ JW.T @ grad_sigma_image

        # compute gradient of JW_t using multiplication rules in 2.2.2 
        grad_W_t =  grad_JW_t @ J
        grad_J_t = W @ grad_JW_t

        grad_W = grad_W_t.T
        grad_J = grad_J_t.T

        return grad_sigma_world, grad_W, grad_J


sigma_world = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
W = torch.rand(3, 3, dtype=torch.float64, requires_grad=True)
J = torch.rand(2, 3, dtype=torch.float64, requires_grad=True)
test = gradcheck(ComputeSigmaImage.apply, (sigma_world, W, J))
print(test)




True


In [20]:
import sympy as sp
from sympy import print_latex

def quaternion_to_rotation_Symbolic(q):
    # norm = sp.sqrt(q[0]**2 + x**2 + y**2 + z**2)
    norm = 1
    w = w / norm
    x = x / norm
    y = y / norm
    z = z / norm
    # Compute the rotation matrix
    rotation_matrix = sp.Matrix([[1 - 2*y**2 - 2*z**2, 2*x*y - 2*w*z, 2*x*z + 2*w*y],
                               [2*x*y + 2*w*z, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*w*x],
                               [2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x**2 - 2*y**2]])
    
    return rotation_matrix


w, x, y, z = sp.Symbols('w x y z')
q = [w, x, y, z]

rotation_matrix = quaternion_to_rotation_Symbolic(q)
rotation_derivative = sp.diff(rotation_matrix, z)
print_latex(rotation_derivative)


AttributeError: module 'sympy' has no attribute 'Symbols'


Jacobian of Quaternion to Rotation Matrix without normalization
$$  \frac{\partial q}{\partial w} = \left[\begin{matrix}0 & - 2 z & 2 y\\2 z & 0 & - 2 x\\- 2 y & 2 x & 0\end{matrix}\right] $$

$$  \frac{\partial q}{\partial x} =  \left[\begin{matrix}0 & 2 y & 2 z\\2 y & - 4 x & - 2 w\\2 z & 2 w & - 4 x\end{matrix}\right] $$ 

$$  \frac{\partial q}{\partial y} =  \left[\begin{matrix}- 4 y & 2 x & 2 w\\2 x & 0 & 2 z\\- 2 w & 2 z & - 4 y\end{matrix}\right] $$ 

$$  \frac{\partial q}{\partial z} =  \left[\begin{matrix}- 4 z & - 2 w & 2 x\\2 w & - 4 z & 2 y\\2 x & 2 y & 0\end{matrix}\right] $$ 

In [None]:
import torch
from torch.autograd import Function, gradcheck

class QuaternionToRotation(Function):
    @staticmethod
    def forward(ctx, q):
        rot = [
            1 - 2 * q[:, 2] ** 2 - 2 * q[:, 3] ** 2,
            2 * q[:, 1] * q[:, 2] - 2 * q[:, 0] * q[:, 3],
            2 * q[:, 3] * q[:, 1] + 2 * q[:, 0] * q[:, 2],
            2 * q[:, 1] * q[:, 2] + 2 * q[:, 0] * q[:, 3],
            1 - 2 * q[:, 1] ** 2 - 2 * q[:, 3] ** 2,
            2 * q[:, 2] * q[:, 3] - 2 * q[:, 0] * q[:, 1],
            2 * q[:, 3] * q[:, 1] - 2 * q[:, 0] * q[:, 2],
            2 * q[:, 2] * q[:, 3] + 2 * q[:, 0] * q[:, 1],
            1 - 2 * q[:, 1] ** 2 - 2 * q[:, 2] ** 2,
        ]
        rot = torch.stack(rot, dim=1).reshape(-1, 3, 3)
        ctx.save_for_backward(q)
        return rot

    @staticmethod
    def backward(ctx, grad_rot):
        q = ctx.saved_tensors[0]

        w = q[:, 0]
        x = q[:, 1]
        y = q[:, 2]
        z = q[:, 3]

        grad_qw = -2 * z *grad_rot[:,0 , 1] + 2 * y *grad_rot[:,0, 2] + 2 * z *grad_rot[:,1, 0] - 2 * x *grad_rot[:,1, 2] - 2 * y *grad_rot[:,2, 0] + 2 * x *grad_rot[:,2, 1]
        grad_qx = 2 * y *grad_rot[:,0, 1] + 2 * z *grad_rot[:,0, 2] + 2 * y *grad_rot[:,1, 0] - 4 * x *grad_rot[:,1, 1] - 2 * w *grad_rot[:,1, 2] + 2 * z *grad_rot[:,2, 0] + 2 * w *grad_rot[:,2, 1] - 4 * x *grad_rot[:,2, 2]
        grad_qy = -4 * y *grad_rot[:,0, 0] + 2 * x *grad_rot[:,0, 1] + 2 * w *grad_rot[:,0, 2] + 2 * x *grad_rot[:,1, 0] + 2 * z *grad_rot[:,1, 2] - 2 * w *grad_rot[:,2, 0] + 2 * z *grad_rot[:,2, 1] - 4 * y *grad_rot[:,2, 2]
        grad_qz = -4 * z *grad_rot[:,0, 0] - 2 * w *grad_rot[:,0, 1] + 2 * x *grad_rot[:,0, 2] + 2 * w *grad_rot[:,1, 0] - 4 * z *grad_rot[:,1, 1] + 2 * y *grad_rot[:,1, 2] + 2 * x *grad_rot[:,2, 0] + 2 * y *grad_rot[:,2, 1]
        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)

        return grad_q
        

q = torch.rand(10, 4, dtype=torch.float64, requires_grad=True)
norm_q = torch.norm(q, dim=1, keepdim=True)
q = q / norm_q

test = gradcheck(QuaternionToRotation.apply, (q))
print(test)

True


In [None]:
import sympy as sp
from sympy import print_latex

w, x, y, z = sp.symbols('w x y z')
q = sp.Matrix([w, x, y, z])
norm = sp.sqrt(w**2 + x**2 + y**2 + z**2)

q_norm = q / norm

dw = sp.diff(q_norm, w)
dx = sp.diff(q_norm, x)
dy = sp.diff(q_norm, y)
dz = sp.diff(q_norm, z)

print_latex(dw)
print_latex(dx)
print_latex(dy)
print_latex(dz)

Partial Derivatives of Quaternion Normalization

$$ \frac{\partial q}{\partial w} = \left[\begin{matrix}- \frac{w^{2}}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}} + \frac{1}{\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\- \frac{w x}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{w y}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{w z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\end{matrix}\right] $$

$$ \frac{\partial q}{\partial x} = \left[\begin{matrix}- \frac{w x}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{x^{2}}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}} + \frac{1}{\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\- \frac{x y}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{x z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\end{matrix}\right] $$

$$ \frac{\partial q}{\partial y} =\left[\begin{matrix}- \frac{w y}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{x y}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{y^{2}}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}} + \frac{1}{\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\- \frac{y z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\end{matrix}\right] $$ 

$$ \frac{\partial q}{\partial z} = \left[\begin{matrix}- \frac{w z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{x z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{y z}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}}\\- \frac{z^{2}}{\left(w^{2} + x^{2} + y^{2} + z^{2}\right)^{\frac{3}{2}}} + \frac{1}{\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\end{matrix}\right] $$

In [None]:
import torch
from torch.autograd import Function, gradcheck

class QuaternionNormalization(Function):
    @staticmethod
    def forward(ctx, q):
        q_norm = q / torch.norm(q, dim=1, keepdim=True)
        ctx.save_for_backward(q)
        return q_norm

    @staticmethod
    def backward(ctx, grad_q_norm):
        q = ctx.saved_tensors[0]
        w = q[:, 0]
        x = q[:, 1]
        y = q[:, 2]
        z = q[:, 3]
        
        norm_sq = w * w + x * x + y * y + z * z
        grad_qw = (-1 * w * w / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 0] - w * x / norm_sq**1.5 * grad_q_norm[:, 1] - w * y / norm_sq**1.5 * grad_q_norm[:, 2] - w * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qx = -w * x / norm_sq**1.5 * grad_q_norm[:, 0] + (-1 * x * x / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 1] - x * y / norm_sq**1.5 * grad_q_norm[:, 2] - x * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qy = -w * y / norm_sq**1.5 * grad_q_norm[:, 0] - x * y / norm_sq**1.5 * grad_q_norm[:, 1] + (-1 * y * y / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 2] - y * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qz = -w * z / norm_sq**1.5 * grad_q_norm[:, 0] - x * z / norm_sq**1.5 * grad_q_norm[:, 1] - y * z / norm_sq**1.5 * grad_q_norm[:, 2] + (-1 * z * z / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 3]
        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)

        return grad_q
        

q = torch.rand(2, 4, dtype=torch.float64, requires_grad=True)

test = gradcheck(QuaternionNormalization.apply, (q))
print(test)

True


In [None]:
import torch
from torch.autograd import Function, gradcheck

class ComputeSigmaWorld(Function):
    @staticmethod
    def forward(ctx, q, scale):
        S = torch.diag_embed(torch.exp(scale))
        norm_q = torch.norm(q, dim=1, keepdim=True)
        q_norm = q / norm_q
        R = [
            1 - 2 * q_norm[:, 2] ** 2 - 2 * q_norm[:, 3] ** 2,
            2 * q_norm[:, 1] * q_norm[:, 2] - 2 * q_norm[:, 0] * q_norm[:, 3],
            2 * q_norm[:, 3] * q_norm[:, 1] + 2 * q_norm[:, 0] * q_norm[:, 2],
            2 * q_norm[:, 1] * q_norm[:, 2] + 2 * q_norm[:, 0] * q_norm[:, 3],
            1 - 2 * q_norm[:, 1] ** 2 - 2 * q_norm[:, 3] ** 2,
            2 * q_norm[:, 2] * q_norm[:, 3] - 2 * q_norm[:, 0] * q_norm[:, 1],
            2 * q_norm[:, 3] * q_norm[:, 1] - 2 * q_norm[:, 0] * q_norm[:, 2],
            2 * q_norm[:, 2] * q_norm[:, 3] + 2 * q_norm[:, 0] * q_norm[:, 1],
            1 - 2 * q_norm[:, 1] ** 2 - 2 * q_norm[:, 2] ** 2,
        ]
        R = torch.stack(R, dim=1).reshape(-1, 3, 3)

        RS = torch.bmm(R, S)
        RS_t = RS.permute(0, 2, 1)

        RSSR = torch.bmm(RS, RS_t)
        ctx.save_for_backward(RS, R, S, scale, q, q_norm)
        return RSSR
    
    @staticmethod
    def backward(ctx, grad_RSSR):
        # compute double matmul gradient        
        RS, R, S, scale, q, q_norm = ctx.saved_tensors
        grad_RS = torch.bmm(grad_RSSR, RS)
        
        RS_t = RS.permute(0, 2, 1)
        grad_SR = RS_t @ grad_RSSR

        grad_R = grad_RS @ S.permute(0, 2, 1) + (S @ grad_SR).permute(0, 2, 1)
        grad_S = R.permute(0, 2, 1) @ grad_RS + (grad_SR @ R).permute(0, 2, 1)

        # compute quaternion gradient
        w = q_norm[:, 0]
        x = q_norm[:, 1]
        y = q_norm[:, 2]
        z = q_norm[:, 3]
        grad_qw_norm = -2 * z *grad_R[:,0 , 1] + 2 * y *grad_R[:,0, 2] + 2 * z *grad_R[:,1, 0] - \
            2 * x *grad_R[:,1, 2] - 2 * y *grad_R[:,2, 0] + 2 * x *grad_R[:,2, 1]
        grad_qx_norm = 2 * y *grad_R[:,0, 1] + 2 * z *grad_R[:,0, 2] + 2 * y *grad_R[:,1, 0] - \
            4 * x *grad_R[:,1, 1] - 2 * w *grad_R[:,1, 2] + 2 * z *grad_R[:,2, 0] + 2 * w *grad_R[:,2, 1] - 4 * x *grad_R[:,2, 2]
        grad_qy_norm = -4 * y *grad_R[:,0, 0] + 2 * x *grad_R[:,0, 1] + 2 * w *grad_R[:,0, 2] + \
            2 * x *grad_R[:,1, 0] + 2 * z *grad_R[:,1, 2] - 2 * w *grad_R[:,2, 0] + 2 * z *grad_R[:,2, 1] - 4 * y *grad_R[:,2, 2]
        grad_qz_norm = -4 * z *grad_R[:,0, 0] - 2 * w *grad_R[:,0, 1] + 2 * x *grad_R[:,0, 2] + \
            2 * w *grad_R[:,1, 0] - 4 * z *grad_R[:,1, 1] + 2 * y *grad_R[:,1, 2] + 2 * x *grad_R[:,2, 0] + 2 * y *grad_R[:,2, 1]
        grad_q_norm = torch.stack([grad_qw_norm, grad_qx_norm, grad_qy_norm, grad_qz_norm], dim=1)

        # compute gradient for unnormalized quaternion
        w = q[:, 0]
        x = q[:, 1]
        y = q[:, 2]
        z = q[:, 3]
        norm_sq = w * w + x * x + y * y + z * z
        grad_qw = (-1 * w * w / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 0] - w * x / norm_sq**1.5 * grad_q_norm[:, 1] - \
            w * y / norm_sq**1.5 * grad_q_norm[:, 2] - w * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qx = -w * x / norm_sq**1.5 * grad_q_norm[:, 0] + (-1 * x * x / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 1] - \
            x * y / norm_sq**1.5 * grad_q_norm[:, 2] - x * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qy = -w * y / norm_sq**1.5 * grad_q_norm[:, 0] - x * y / norm_sq**1.5 * grad_q_norm[:, 1] + (-1 * y * y / norm_sq**1.5 + \
            1/norm_sq**0.5) * grad_q_norm[:, 2] - y * z / norm_sq**1.5 * grad_q_norm[:, 3]
        grad_qz = -w * z / norm_sq**1.5 * grad_q_norm[:, 0] - x * z / norm_sq**1.5 * grad_q_norm[:, 1] - y * z / norm_sq**1.5 * grad_q_norm[:, 2] + \
            (-1 * z * z / norm_sq**1.5 + 1/norm_sq**0.5) * grad_q_norm[:, 3]
        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)

        grad_scale_no_activation = grad_S.diagonal(dim1=1, dim2=2)
        grad_scale = grad_scale_no_activation * torch.exp(scale)

        return grad_q, grad_scale
    
N = 2
q = torch.rand(N, 4, dtype=torch.float64, requires_grad=True)

s = torch.rand(N, 3, dtype=torch.float64, requires_grad=True)
test = gradcheck(ComputeSigmaWorld.apply, (q, s))
print(test)

True


In [None]:
import sympy as sp
from sympy import print_latex, exp


a, b, c, d = sp.symbols('a b c d')
sigma_image = sp.Matrix([[a, b], [c, d]])

u, v = sp.symbols('u v')

mh_dist_sq = (d * u ** 2 - b * u * v - c * u * v + a * v ** 2) / (a * d - b * c)

print_latex(sp.diff(mh_dist_sq, v))





\frac{2 a v - b u - c u}{a d - b c}


$$ d_m^2 = \frac{a v^{2} - b u v - c u v + d u^{2}}{a d - b c} $$ 

$$ \frac{\partial d_m^2}{\partial a} = - \frac{d \left(a v^{2} - b u v - c u v + d u^{2}\right)}{\left(a d - b c\right)^{2}} + \frac{v^{2}}{a d - b c} $$ 

$$ \frac{\partial d_m^2}{\partial b} = \frac{c \left(a v^{2} - b u v - c u v + d u^{2}\right)}{\left(a d - b c\right)^{2}} - \frac{u v}{a d - b c} $$ 

$$ \frac{\partial d_m^2}{\partial c} = \frac{b \left(a v^{2} - b u v - c u v + d u^{2}\right)}{\left(a d - b c\right)^{2}} - \frac{u v}{a d - b c} $$ 

$$ \frac{\partial d_m^2}{\partial d} = - \frac{a \left(a v^{2} - b u v - c u v + d u^{2}\right)}{\left(a d - b c\right)^{2}} + \frac{u^{2}}{a d - b c} $$


$$ \frac{\partial d_m^2}{\partial u} = \frac{- b v - c v + 2 d u}{a d - b c} $$ 

$$ \frac{\partial d_m^2}{\partial v} = \frac{2 a v - b u - c u}{a d - b c} $$ 


In [None]:
import torch
from torch.autograd import Function, gradcheck

class ComputeAlpha(Function):
    @staticmethod
    def forward(ctx, sigma_image, opa, uv_splat, uv_pixel):
        uv_diff = uv_pixel - uv_splat 
        a = sigma_image[0, 0]
        b = sigma_image[0, 1]
        c = sigma_image[1, 0]
        d = sigma_image[1, 1]
        mh_dist = (d * uv_diff[0] ** 2 - b * uv_diff[0] * uv_diff[1] - c * uv_diff[0] * uv_diff[1] + a * uv_diff[1] ** 2) / (a * d - b * c)

        prob = torch.exp(-0.5 * mh_dist)
        alpha = prob * opa
        ctx.save_for_backward(prob, sigma_image, uv_diff, opa)
        return alpha
        
    @staticmethod
    def backward(ctx, grad_alpha):
        prob, sigma_image, uv_diff, opa = ctx.saved_tensors
        grad_opa = prob * grad_alpha

        ## compute sigma world and uv_diff gradients        
        grad_prob = opa * grad_alpha
        grad_mh = -0.5 * prob * grad_prob

        a = sigma_image[0, 0]
        b = sigma_image[0, 1]
        c = sigma_image[1, 0]
        d = sigma_image[1, 1]

        u_diff = uv_diff[0]
        v_diff = uv_diff[1]

        grad_u = -(-b * v_diff - c * v_diff + 2 * d * u_diff) / (a * d - b * c) * grad_mh
        grad_v = -(2 * a * v_diff - b * u_diff - c * u_diff) / (a * d - b * c) * grad_mh

        grad_a = (-d * (a * v_diff ** 2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff ** 2) / (a * d - b * c)**2  + (v_diff ** 2) / (a * d - b * c)) * grad_mh
        grad_b = (c * (a * v_diff ** 2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff ** 2) / (a * d - b * c)**2  - (u_diff * v_diff) / (a * d - b * c)) * grad_mh
        grad_c = (b * (a * v_diff ** 2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff ** 2) / (a * d - b * c)**2  - (u_diff * v_diff) / (a * d - b * c)) * grad_mh
        grad_d = (-a * (a * v_diff ** 2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff ** 2) / (a * d - b * c)**2  + (u_diff ** 2) / (a * d - b * c)) * grad_mh

        grad_sigma_image = torch.Tensor([[grad_a, grad_b], [grad_c, grad_d]])
        grad_uv_splat = torch.Tensor([grad_u, grad_v])

        return grad_sigma_image, grad_opa, grad_uv_splat, None


uv_splat = torch.rand(2, dtype=torch.float64, requires_grad=True)
uv_pixel = torch.rand(2, dtype=torch.float64, requires_grad=False)

sigma_image = torch.rand(2, 2, dtype=torch.float64, requires_grad=True)
opa = torch.rand(1, dtype=torch.float64, requires_grad=True)
test = gradcheck(ComputeAlpha.apply, (sigma_image, opa, uv_splat, uv_pixel))
print(test)

True


Alpha Compositing

First (front-to-back) Gaussian Splatted to Pixel: 
$$ \alpha_a = 0.0 $$ 
$$ \alpha_c = \alpha_0(1.0 - \alpha_a) = \alpha_0 $$ 

Second splat:
$$ \alpha_a = \alpha_0 $$
$$ \alpha_c = \alpha_1(1.0 - \alpha_a) = \alpha_1(1.0 - \alpha_0) $$ 

Third splat:
$$ \alpha_a = \alpha_0 + \alpha_1(1.0 - \alpha_0) $$
$$ \alpha_c = \alpha_2(1.0 - \alpha_a) = \alpha_2(1.0 - (\alpha_0 + \alpha_1(1.0 - \alpha_0))) $$ 



In [None]:
import torch
from torch.autograd import Function, gradcheck

class AlphaComposite(Function):
    @staticmethod
    def forward(ctx, colors, alphas):
        alpha_accum = 0.0
        color_accum = torch.zeros_like(colors[0])
        for i in range(alphas.shape[0]):
            alpha_weight = (1 - alpha_accum)
            alpha_current = alphas[i] * (1 - alpha_accum)
            color_accum += alpha_current * colors[i, :]
            alpha_accum += alpha_current

        ctx.save_for_backward(alpha_weight, alphas, colors)
        return color_accum

    @staticmethod
    def backward(ctx, grad_color_accum):
        weight_final, alphas, colors = ctx.saved_tensors
        grad_alphas = torch.zeros_like(alphas)
        grad_colors = torch.zeros_like(colors)

        colors_accum = torch.zeros_like(colors[0])
        weight = weight_final
        for i in reversed(range(alphas.shape[0])):
            grad_colors[i] = alphas[i] * weight * grad_color_accum
            grad_alphas[i] = torch.dot((colors[i, :] * weight - colors_accum/(1.0 - alphas[i])), grad_color_accum)

            colors_accum += alphas[i] * colors[i, :] * weight
            weight /= (1 - alphas[i - 1])
            
        return grad_colors, grad_alphas

alphas = torch.rand(10, dtype=torch.float64, requires_grad=True) / 10.0
colors = torch.rand(10, 3, dtype=torch.float64, requires_grad=True)

test = gradcheck(AlphaComposite.apply, (colors, alphas))
print(test)

True


## Spherical Harmonics


Zero Order - not dependent on viewing direction

$ Y_0 = \frac{1}{2} \sqrt{\frac{1}{\pi}}$

First order

$ Y_1^{-1} = \frac{1}{2} \sqrt{\frac{3}{2\pi}} \frac{y}{r}$

$ Y_1^{0} = \frac{1}{2} \sqrt{\frac{3}{2\pi}} \frac{z}{r}$

$ Y_1^{1} = \frac{1}{2} \sqrt{\frac{3}{2\pi}} \frac{x}{r}$


Second Order

$ Y_2^{-2} = \frac{1}{2} \sqrt{\frac{15}{\pi}} \frac{xy}{r^2}$

$ Y_2^{-1} = \frac{1}{2} \sqrt{\frac{15}{\pi}} \frac{yz}{r^2}$

$ Y_2^{0} = \frac{1}{4} \sqrt{\frac{5}{\pi}} \frac{3z^2 - r^2}{r^2}$

$ Y_2^{1} = \frac{1}{2} \sqrt{\frac{15}{\pi}} \frac{xz}{r^2}$

$ Y_2^{2} = \frac{1}{4} \sqrt{\frac{15}{\pi}} \frac{x^2 - y^2}{r^2}$


Third Order

$ Y_3^{-3} = \frac{1}{4} \sqrt{\frac{35}{2\pi}} \frac{y(3x^2-y^2)}{r^3}$

$ Y_3^{-2} = \frac{1}{2} \sqrt{\frac{105}{\pi}} \frac{xyz}{r^3}$

$ Y_3^{-1} = \frac{1}{4} \sqrt{\frac{21}{2\pi}} \frac{y(5z^2-r^2)}{r^3}$

$ Y_3^{0} = \frac{1}{4} \sqrt{\frac{7}{\pi}} \frac{z(5z^2-3r^2)}{r^3}$

$ Y_3^{1} = \frac{1}{4} \sqrt{\frac{21}{2\pi}} \frac{x(5z^2-r^2)}{r^3}$

$ Y_3^{2} = \frac{1}{4} \sqrt{\frac{105}{\pi}} \frac{z(x^2-y^2)}{r^3}$

$ Y_3^{3} = \frac{1}{4} \sqrt{\frac{35}{2\pi}} \frac{x(x^2-3y^2)}{r^3}$





In [57]:
# compute the constants for the SH bands
import math

C_SH_0 = math.sqrt(1/math.pi)/2
C_SH_1 = math.sqrt(3/math.pi)/2

C_SH_2_N2 = math.sqrt(15/math.pi)/2
C_SH_2_N1 = -1.0 * math.sqrt(15/math.pi)/2
C_SH_2_0 = math.sqrt(5/math.pi)/4
C_SH_2_P1 = -1.0 * math.sqrt(15/math.pi)/2
C_SH_2_P2 = math.sqrt(15/math.pi)/4

C_SH_3_N3 = math.sqrt(35/(2*math.pi))/4
C_SH_3_N2 = math.sqrt(105/(math.pi))/2
C_SH_3_N1 = math.sqrt(21/(2*math.pi))/4
C_SH_3_0 = math.sqrt(7/(2*math.pi))/4
C_SH_3_P1 = math.sqrt(21/(2*math.pi))/4
C_SH_3_P2 = math.sqrt(105/(math.pi))/4
C_SH_3_P3 = math.sqrt(35/(2*math.pi))/4


print("Band 0", C_SH_0)
print("Band 1", C_SH_1)
print("Band 2", C_SH_2_N2, C_SH_2_N1, C_SH_2_0, C_SH_2_P1, C_SH_2_P2)
print("Band 3", C_SH_3_N3, C_SH_3_N2, C_SH_3_N1, C_SH_3_0, C_SH_3_P1, C_SH_3_P2, C_SH_3_P3)

Band 0 0.28209479177387814
Band 1 0.4886025119029199
Band 2 1.0925484305920792 -1.0925484305920792 0.31539156525252005 -1.0925484305920792 0.5462742152960396
Band 3 0.5900435899266435 2.890611442640554 0.4570457994644658 0.263875515352797 0.4570457994644658 1.445305721320277 0.5900435899266435


In [66]:
import torch
from torch.autograd import Function, gradcheck

class SphericalHarmonicsToRGB(Function):
    @staticmethod
    def forward(ctx, sh_coeff, view_dir):
        view_dir = view_dir / torch.norm(view_dir, dim=1, keepdim=True)
        x = view_dir[:, 0]
        y = view_dir[:, 1]
        z = view_dir[:, 2]

        N = sh_coeff.shape[0]
        N_sh = sh_coeff.shape[2]

        sh_r = sh_coeff[:, 0, :]
        sh_g = sh_coeff[:, 1, :]
        sh_b = sh_coeff[:, 2, :]

        rgb = torch.zeros(N, 3, dtype=sh_coeff.dtype, device=sh_coeff.device)
        rgb[:, 0] = C_SH_0 * sh_r[:, 0]
        rgb[:, 1] = C_SH_0 * sh_g[:, 0]
        rgb[:, 2] = C_SH_0 * sh_b[:, 0]

        if N_sh >= 4:
            rgb[:, 0] += -1 * C_SH_1 * sh_r[:, 1] * x + C_SH_1 * sh_r[:, 2] * y - C_SH_1 * sh_r[:, 3] * z
            rgb[:, 1] += -1 * C_SH_1 * sh_g[:, 1] * x + C_SH_1 * sh_g[:, 2] * y - C_SH_1 * sh_g[:, 3] * z
            rgb[:, 2] += -1 * C_SH_1 * sh_b[:, 1] * x + C_SH_1 * sh_b[:, 2] * y - C_SH_1 * sh_b[:, 3] * z
        if N_sh >= 9:
            xx = torch.square(x)
            yy = torch.square(y)
            zz = torch.square(z)
            rgb[:, 0] += C_SH_2_N2 * sh_r[:, 4] * x * y + \
                         C_SH_2_N1 * sh_r[:, 5] * y * z + \
                         C_SH_2_0 * sh_r[:, 6] * (3*zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device)) + \
                         C_SH_2_P1 * sh_r[:, 7] * x * z + \
                         C_SH_2_P2 * sh_r[:, 8] * (xx - yy)
            
            rgb[:, 1] += C_SH_2_N2 * sh_g[:, 4] * x * y + \
                         C_SH_2_N1 * sh_g[:, 5] * y * z + \
                         C_SH_2_0 * sh_g[:, 6] * (3*zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device)) + \
                         C_SH_2_P1 * sh_g[:, 7] * x * z + \
                         C_SH_2_P2 * sh_g[:, 8] * (xx - yy)
            
            rgb[:, 2] += C_SH_2_N2 * sh_b[:, 4] * x * y + \
                         C_SH_2_N1 * sh_b[:, 5] * y * z + \
                         C_SH_2_0 * sh_b[:, 6] * (3*zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device)) + \
                         C_SH_2_P1 * sh_b[:, 7] * x * z + \
                         C_SH_2_P2 * sh_b[:, 8] * (xx - yy)
            
        if N_sh >= 16:
            rgb[:, 0] += C_SH_3_N3 * sh_r[:, 9] * (y * (3*xx - yy)) + \
                         C_SH_3_N2 * sh_r[:, 10] * (x * y * z) + \
                         C_SH_3_N1 * sh_r[:, 11] * (y * (5 * zz - 1)) + \
                         C_SH_3_0 * sh_r[:, 12] * (z * (5 * zz - 1)) + \
                         C_SH_3_P1 * sh_r[:, 13] * (x * (5 * zz - 1)) + \
                         C_SH_3_P2 * sh_r[:, 14] * (z * (xx - yy)) + \
                         C_SH_3_P3 * sh_r[:, 15] * (x * (xx - 3 * yy))
            
            rgb[:, 1] += C_SH_3_N3 * sh_g[:, 9] * (y * (3*xx - yy)) + \
                         C_SH_3_N2 * sh_g[:, 10] * (x * y * z) + \
                         C_SH_3_N1 * sh_g[:, 11] * (y * (5 * zz - 1)) + \
                         C_SH_3_0 * sh_g[:, 12] * (z * (5 * zz - 1)) + \
                         C_SH_3_P1 * sh_g[:, 13] * (x * (5 * zz - 1)) + \
                         C_SH_3_P2 * sh_g[:, 14] * (z * (xx - yy)) + \
                         C_SH_3_P3 * sh_g[:, 15] * (x * (xx - 3 * yy))
            
            rgb[:, 2] += C_SH_3_N3 * sh_b[:, 9] * (y * (3*xx - yy)) + \
                         C_SH_3_N2 * sh_b[:, 10] * (x * y * z) + \
                         C_SH_3_N1 * sh_b[:, 11] * (y * (5 * zz - 1)) + \
                         C_SH_3_0 * sh_b[:, 12] * (z * (5 * zz - 1)) + \
                         C_SH_3_P1 * sh_b[:, 13] * (x * (5 * zz - 1)) + \
                         C_SH_3_P2 * sh_b[:, 14] * (z * (xx - yy)) + \
                         C_SH_3_P3 * sh_b[:, 15] * (x * (xx - 3 * yy))
            
        # apply sigmoid activation to constrain values between 0 and 1
        rgb_sigmoid = torch.sigmoid(rgb)

        ctx.save_for_backward(sh_coeff, view_dir, rgb_sigmoid)
        return rgb_sigmoid
    
    @staticmethod
    def backward(ctx, grad_rgb_sigmoid):
        sh_coeff, view_dir, rgb_sigmoid = ctx.saved_tensors
        x = view_dir[:, 0]
        y = view_dir[:, 1]
        z = view_dir[:, 2]

        N = sh_coeff.shape[0]
        N_sh = sh_coeff.shape[2]

        grad_sh_coeff = torch.zeros_like(sh_coeff)

        # backwards of sigmoid
        grad_rgb = rgb_sigmoid * (1 - rgb_sigmoid) * grad_rgb_sigmoid

        # zero order gradients
        for i in range(3):
            grad_sh_coeff[:, i, 0] = C_SH_0 * grad_rgb[:, i]

        if N_sh >= 4:
            sh_1_n1_grad_mult =  -1 * C_SH_1 * x
            sh_1_0_grad_mult = C_SH_1 * y
            sh_1_p1_grad_mult = -1 * C_SH_1 * z

            for i in range(3):
                grad_sh_coeff[:, i, 1] = sh_1_n1_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 2] = sh_1_0_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 3] = sh_1_p1_grad_mult * grad_rgb[:, i]

        if N_sh >= 9:
            xx = torch.square(x)
            yy = torch.square(y)
            zz = torch.square(z)

            sh_2_n2_grad_mult = C_SH_2_N2 * x * y
            sh_2_n1_grad_mult = C_SH_2_N1 * y * z
            sh_2_0_grad_mult = C_SH_2_0 * (3 * zz - torch.ones(z.shape[0], dtype=z.dtype, device=z.device))
            sh_2_p1_grad_mult = C_SH_2_P1 * x * z
            sh_2_p2_grad_mult = C_SH_2_P2 * (xx - yy)

            for i in range(3):
                grad_sh_coeff[:, i, 4] = sh_2_n2_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 5] = sh_2_n1_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 6] = sh_2_0_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 7] = sh_2_p1_grad_mult * grad_rgb[:, i]
                grad_sh_coeff[:, i, 8] = sh_2_p2_grad_mult * grad_rgb[:, i]
        
        if N_sh >= 16:
            for i in range(3):
                grad_sh_coeff[:, i, 9] =   C_SH_3_N3 * (y * (3*xx - yy)) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 10] =    C_SH_3_N2 * (x * y * z) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 11] =    C_SH_3_N1 * (y * (5 * zz - 1)) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 12] =    C_SH_3_0 * (z * (5 * zz - 1)) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 13] =    C_SH_3_P1 * (x * (5 * zz - 1)) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 14] =    C_SH_3_P2 * (z * (xx - yy)) * grad_rgb[:, i]
                grad_sh_coeff[:, i, 15] =    C_SH_3_P3 * (x * (xx - 3 * yy)) * grad_rgb[:, i]

        return grad_sh_coeff, None

sh_coeff = torch.rand(10, 3, 16, dtype=torch.float64, requires_grad=True)
view_dir = torch.rand(10, 3, dtype=torch.float64, requires_grad=False)
rgbs = SphericalHarmonicsToRGB.apply(sh_coeff, view_dir)
# print(rgbs)

test = gradcheck(SphericalHarmonicsToRGB.apply, (sh_coeff, view_dir))
print(test)



True
