In [1]:
"""We implement and then we check each gradient so when we daisy chain we are good"""

'We implement and then we check each gradient so when we daisy chain we are good'

In [2]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.utils import build_rotation


class final_color(torch.autograd.Function):
    @staticmethod
    def forward(ctx, color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
        """Color is a nx3 tensor, weight is a nx1 tensor, alpha is a nx1 tensor"""
        ctx.save_for_backward(color, current_T, alpha)
        return color * current_T * alpha
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx3 tensor so the grad_output is a nx3 tensor"""
        color, current_T, alpha = ctx.saved_tensors
        grad_color = grad_output * current_T * alpha
        grad_alpha = (grad_output * color * current_T).sum(dim=1, keepdim=True)
        return grad_color, None, grad_alpha
    
def autograd_test(color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
    return color * current_T * alpha

# first we check with gradcheck
color = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0]], requires_grad=True, dtype=torch.float64)

print("gradcheck passing: ", gradcheck(final_color.apply, (color, current_T, alpha)))

# then we check with autograd
color = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0]], requires_grad=True, dtype=torch.float64)

output = final_color.apply(color, current_T, alpha)
loss = output.sum()
loss.backward()
print("My grads: ", color.grad, current_T.grad, alpha.grad)

# then we check with autograd
color = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0]], requires_grad=True, dtype=torch.float64)

output = autograd_test(color, current_T, alpha)
loss = output.sum()
loss.backward()
print("Autograd grads: ", color.grad, current_T.grad, alpha.grad)


# finally we test with multiple dimensions
color = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0], [5.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0], [6.0]], requires_grad=True, dtype=torch.float64)

output = final_color.apply(color, current_T, alpha)
loss = output.sum()
loss.backward()
print("My grads: ", color.grad, current_T.grad, alpha.grad)

color = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0], [5.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0], [6.0]], requires_grad=True, dtype=torch.float64)

output = autograd_test(color, current_T, alpha)
loss = output.sum()
loss.backward()
print("Autograd grads: ", color.grad, current_T.grad, alpha.grad)

gradcheck passing:  True
My grads:  tensor([[20., 20., 20.]], dtype=torch.float64) None tensor([[24.]], dtype=torch.float64)
Autograd grads:  tensor([[20., 20., 20.]], dtype=torch.float64) None tensor([[24.]], dtype=torch.float64)
My grads:  tensor([[20., 20., 20.],
        [30., 30., 30.]], dtype=torch.float64) None tensor([[24.],
        [75.]], dtype=torch.float64)
Autograd grads:  tensor([[20., 20., 20.],
        [30., 30., 30.]], dtype=torch.float64) None tensor([[24.],
        [75.]], dtype=torch.float64)


In [3]:
class get_alpha(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_strength: torch.Tensor, unactivated_opacity: torch.Tensor):
        """Gaussian strength is a nx1 tensor, unactivated opacity is a nx1 tensor"""
        ctx.save_for_backward(gaussian_strength, unactivated_opacity)
        activated_opacity = torch.sigmoid(unactivated_opacity)
        return gaussian_strength * activated_opacity
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx1 tensor so the grad_output is a nx1 tensor"""
        gaussian_strength, unactivated_opacity = ctx.saved_tensors
        derivative_sigmoid = torch.sigmoid(unactivated_opacity) * (1 - torch.sigmoid(unactivated_opacity))
        grad_gaussian_strength = grad_output * torch.sigmoid(unactivated_opacity)
        grad_unactivated_opacity = grad_output * gaussian_strength * derivative_sigmoid
        return grad_gaussian_strength, grad_unactivated_opacity
    
# Define a test function using torch's autograd
def autograd_test(gaussian_strength: torch.Tensor, unactivated_opacity: torch.Tensor):
    activated_opacity = torch.sigmoid(unactivated_opacity)
    return gaussian_strength * activated_opacity

# 1. Gradcheck for numerical gradient correctness
gaussian_strength = torch.tensor([[2.0]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[0.5]], requires_grad=True, dtype=torch.float64)

print("gradcheck passing: ", gradcheck(get_alpha.apply, (gaussian_strength, unactivated_opacity)))

# 2. Verify backward computation with autograd
gaussian_strength = torch.tensor([[2.0]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[0.5]], requires_grad=True, dtype=torch.float64)

output = get_alpha.apply(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("My grads: ", gaussian_strength.grad, unactivated_opacity.grad)

# Compare with autograd
gaussian_strength = torch.tensor([[2.0]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[0.5]], requires_grad=True, dtype=torch.float64)

output = autograd_test(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("Autograd grads: ", gaussian_strength.grad, unactivated_opacity.grad)

# 3. Test with multiple dimensions
gaussian_strength = torch.tensor([[2.0], [3.0]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[0.5], [1.5]], requires_grad=True, dtype=torch.float64)

output = get_alpha.apply(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("My grads (multiple dimensions): ", gaussian_strength.grad, unactivated_opacity.grad)

# Compare with autograd
gaussian_strength = torch.tensor([[2.0], [3.0]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[0.5], [1.5]], requires_grad=True, dtype=torch.float64)

output = autograd_test(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("Autograd grads (multiple dimensions): ", gaussian_strength.grad, unactivated_opacity.grad)

# 4. Edge case: Very large and very small values
gaussian_strength = torch.tensor([[1e-5], [1e5]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[10.0], [-10.0]], requires_grad=True, dtype=torch.float64)

output = get_alpha.apply(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("My grads (edge cases): ", gaussian_strength.grad, unactivated_opacity.grad)

gaussian_strength = torch.tensor([[1e-5], [1e5]], requires_grad=True, dtype=torch.float64)
unactivated_opacity = torch.tensor([[10.0], [-10.0]], requires_grad=True, dtype=torch.float64)

output = autograd_test(gaussian_strength, unactivated_opacity)
loss = output.sum()
loss.backward()
print("Autograd grads (edge cases): ", gaussian_strength.grad, unactivated_opacity.grad)

gradcheck passing:  True
My grads:  tensor([[0.6225]], dtype=torch.float64) tensor([[0.4700]], dtype=torch.float64)
Autograd grads:  tensor([[0.6225]], dtype=torch.float64) tensor([[0.4700]], dtype=torch.float64)
My grads (multiple dimensions):  tensor([[0.6225],
        [0.8176]], dtype=torch.float64) tensor([[0.4700],
        [0.4474]], dtype=torch.float64)
Autograd grads (multiple dimensions):  tensor([[0.6225],
        [0.8176]], dtype=torch.float64) tensor([[0.4700],
        [0.4474]], dtype=torch.float64)
My grads (edge cases):  tensor([[9.9995e-01],
        [4.5398e-05]], dtype=torch.float64) tensor([[4.5396e-10],
        [4.5396e+00]], dtype=torch.float64)
Autograd grads (edge cases):  tensor([[9.9995e-01],
        [4.5398e-05]], dtype=torch.float64) tensor([[4.5396e-10],
        [4.5396e+00]], dtype=torch.float64)


In [4]:
import torch
from torch.autograd.gradcheck import gradcheck


class gaussian_exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_weight: torch.Tensor):
        ctx.save_for_backward(gaussian_weight)
        return torch.exp(gaussian_weight)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        gaussian_weight, = ctx.saved_tensors
        grad_gaussian_weight = grad_output * torch.exp(gaussian_weight)
        return grad_gaussian_weight

# Define a test function using torch's autograd
def autograd_test(gaussian_weight: torch.Tensor):
    return torch.exp(gaussian_weight)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases
def run_tests():
    tolerance = 1e-8

    # 1. Gradcheck for numerical gradient correctness
    gaussian_weight = torch.tensor([[1.0]], requires_grad=True, dtype=torch.float64)
    gradcheck_passed = gradcheck(gaussian_exp.apply, (gaussian_weight,))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single value)
    gaussian_weight = torch.tensor([[1.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    my_grads = gaussian_weight.grad.clone()

    # Autograd grads
    gaussian_weight.grad.zero_()
    output = autograd_test(gaussian_weight)
    loss = output.sum()
    loss.backward()
    autograd_grads = gaussian_weight.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (single value)!"
    print("My grads match autograd grads (single value).")

    # 3. Test with multiple dimensions
    gaussian_weight = torch.tensor([[1.0], [2.0], [3.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    my_grads = gaussian_weight.grad.clone()

    # Autograd grads
    gaussian_weight.grad.zero_()
    output = autograd_test(gaussian_weight)
    loss = output.sum()
    loss.backward()
    autograd_grads = gaussian_weight.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (multiple dimensions)!"
    print("My grads match autograd grads (multiple dimensions).")

    # 4. Edge case: Very large and very small values
    gaussian_weight = torch.tensor([[1e-5], [1e5]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    my_grads = gaussian_weight.grad.clone()

    # Autograd grads
    gaussian_weight.grad.zero_()
    output = autograd_test(gaussian_weight)
    loss = output.sum()
    loss.backward()
    autograd_grads = gaussian_weight.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (edge cases)!"
    print("My grads match autograd grads (edge cases).")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads (single value).
My grads match autograd grads (multiple dimensions).
My grads match autograd grads (edge cases).


In [5]:
import torch
from torch.autograd.gradcheck import gradcheck


class gaussian_weight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_mean: torch.Tensor, inverted_covariance: torch.Tensor, pixel: torch.Tensor):
        """
        Pixel means are a nx2 tensor, inverted covariance is a 2x2 tensor, pixel is a nx2 tensor
        Outputs a nx1 tensor
        """
        ctx.save_for_backward(gaussian_mean, inverted_covariance, pixel)
        diff = (pixel - gaussian_mean).unsqueeze(1)
        # 2x2 * 2x1 = 2x1
        inv_cov_mult = torch.einsum('bij,bjk->bik', inverted_covariance, diff.transpose(1, 2))
        return -0.5 * torch.einsum('bij,bjk->bik', diff, inv_cov_mult).squeeze(-1)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx1 tensor so the grad_output is a nx1 tensor"""
        gaussian_mean, inverted_covariance, pixel = ctx.saved_tensors
        diff = (pixel - gaussian_mean).unsqueeze(1)  # nx2x1

        deriv_wrt_inv_cov = -0.5 * torch.einsum("bij,bjk->bik", diff.transpose(1, 2), diff)
        grad_inv_cov = grad_output * deriv_wrt_inv_cov  # output is nx2x2
        print("Grad inv cov: ", grad_inv_cov.shape)
        print("Grad output: ", grad_output.shape)
        print("Diff: ", diff.shape)

        # deriv_wrt_diff = -0.5 * 2 * torch.einsum("bij,bjk->bik", diff, inverted_covariance)
        # deriv_wrt_gaussian_mean = -1
        # grad_gaussian_mean = torch.einsum("bi,bij->bj", grad_output, deriv_wrt_diff) * deriv_wrt_gaussian_mean]
        deriv_output_wrt_diff1 = torch.einsum("bij,bjk->bik", inverted_covariance, diff.transpose(1, 2))
        deriv_output_wrt_diff2 = torch.einsum("bij,bjk->bik", inverted_covariance.transpose(1, 2), diff.transpose(1, 2))
        deriv_output_wrt_diff = -0.5 * torch.einsum("bi,bji->bj", grad_output, deriv_output_wrt_diff1 + deriv_output_wrt_diff2)
        grad_gaussian_mean = deriv_output_wrt_diff * -1
        return grad_gaussian_mean, grad_inv_cov, None
    

# Define a test function using PyTorch's autograd
def autograd_test(gaussian_mean, inverted_covariance, pixel):
    diff = (pixel - gaussian_mean).unsqueeze(1)
    # 2x2 * 2x1 = 2x1
    inv_cov_mult = torch.einsum('bij,bjk->bik', inverted_covariance, diff.transpose(1, 2))
    return -0.5 * torch.einsum('bij,bjk->bik', diff, inv_cov_mult).squeeze(-1)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    gaussian_mean = torch.tensor([[1.0, 2.0]], requires_grad=True, dtype=torch.float64)
    inverted_covariance = torch.tensor([[[2.0, 0.0], [0.0, 1.5]]], requires_grad=True, dtype=torch.float64)
    pixel = torch.tensor([[1.5, 2.5]], requires_grad=False, dtype=torch.float64)

    assert gradcheck(gaussian_weight.apply, (gaussian_mean, inverted_covariance, pixel)), "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single pixel)
    gaussian_mean = torch.tensor([[1.0, 2.0]], requires_grad=True, dtype=torch.float64)
    inverted_covariance = torch.tensor([[[2.0, 0.0], [0.0, 1.5]]], requires_grad=True, dtype=torch.float64)
    pixel = torch.tensor([[1.5, 2.5]], requires_grad=False, dtype=torch.float64)

    # My grads
    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    my_grad_gaussian_mean = gaussian_mean.grad.clone()
    my_grad_inv_cov = inverted_covariance.grad.clone()

    # Autograd grads
    gaussian_mean.grad.zero_()
    inverted_covariance.grad.zero_()
    output = autograd_test(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    autograd_grad_gaussian_mean = gaussian_mean.grad.clone()
    autograd_grad_inv_cov = inverted_covariance.grad.clone()

    assert compare_grads(my_grad_gaussian_mean, autograd_grad_gaussian_mean, tolerance), "Mismatch in grads (gaussian_mean, single pixel)!"
    assert compare_grads(my_grad_inv_cov, autograd_grad_inv_cov, tolerance), "Mismatch in grads (inverted_covariance, single pixel)!"
    print("Gradients match for single pixel.")

    # 3. Test with multiple pixels
    gaussian_mean = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True, dtype=torch.float64)
    inverted_covariance = torch.tensor([[[2.0, 0.0], [0.0, 1.5]]], requires_grad=True, dtype=torch.float64)
    pixel = torch.tensor([[3.0, 4.0], [1.5, 2.5]], requires_grad=False, dtype=torch.float64)

    # My grads
    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    my_grad_gaussian_mean = gaussian_mean.grad.clone()
    my_grad_inv_cov = inverted_covariance.grad.clone()

    # Autograd grads
    gaussian_mean.grad.zero_()
    inverted_covariance.grad.zero_()
    output = autograd_test(gaussian_mean, inverted_covariance, pixel)
    print("Output: ", output)
    loss = output.sum()
    loss.backward()
    autograd_grad_gaussian_mean = gaussian_mean.grad.clone()
    autograd_grad_inv_cov = inverted_covariance.grad.clone()

    print("My grads: ", my_grad_gaussian_mean)
    print("Autograd grads: ", autograd_grad_gaussian_mean)
    assert compare_grads(my_grad_gaussian_mean, autograd_grad_gaussian_mean, tolerance), "Mismatch in grads (gaussian_mean, multiple pixels)!"
    assert compare_grads(my_grad_inv_cov, autograd_grad_inv_cov, tolerance), "Mismatch in grads (inverted_covariance, multiple pixels)!"
    print("Gradients match for multiple pixels.")

# Run the tests
run_tests()

Grad inv cov:  torch.Size([1, 2, 2])
Grad output:  torch.Size([1, 1])
Diff:  torch.Size([1, 1, 2])
Grad inv cov:  torch.Size([1, 2, 2])
Grad output:  torch.Size([1, 1])
Diff:  torch.Size([1, 1, 2])
Grad inv cov:  torch.Size([1, 2, 2])
Grad output:  torch.Size([1, 1])
Diff:  torch.Size([1, 1, 2])
Grad inv cov:  torch.Size([1, 2, 2])
Grad output:  torch.Size([1, 1])
Diff:  torch.Size([1, 1, 2])
Gradcheck passed.
Grad inv cov:  torch.Size([1, 2, 2])
Grad output:  torch.Size([1, 1])
Diff:  torch.Size([1, 1, 2])
Gradients match for single pixel.
Grad inv cov:  torch.Size([2, 2, 2])
Grad output:  torch.Size([2, 1])
Diff:  torch.Size([2, 1, 2])
Output:  tensor([[-7.0000],
        [-0.4375]], dtype=torch.float64, grad_fn=<MulBackward0>)
My grads:  tensor([[ 4.0000,  3.0000],
        [-1.0000, -0.7500]], dtype=torch.float64)
Autograd grads:  tensor([[ 4.0000,  3.0000],
        [-1.0000, -0.7500]], dtype=torch.float64)
Gradients match for multiple pixels.


In [7]:
def d_inv_wrt_a(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor, grad_output: torch.Tensor):
    """
    All tensors are nx1 tensors - returns a nx1 tensor by summing over the last dimension
    """
    det = a * d - b * c
    deriv = -1 * (d**2) / (det**2) * grad_output[:, 0, 0]
    deriv += (b*d) / (det**2) * grad_output[:, 0, 1]
    deriv += (c*d) / (det**2) * grad_output[:, 1, 0]
    deriv += -1 * (b*c) / (det**2) * grad_output[:, 1, 1]
    return deriv

def d_inv_wrt_b(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor, grad_output: torch.Tensor):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
    det = a * d - b * c
    deriv = (c * d) / (det**2) * grad_output[:, 0, 0]
    deriv += -1 * (a * d) / (det**2) * grad_output[:, 0, 1]
    deriv += -1 * (c*c) / (det**2) * grad_output[:, 1, 0]
    deriv += (a*c) / (det**2) * grad_output[:, 1, 1]
    return deriv

def d_inv_wrt_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor, grad_output: torch.Tensor):
    det = a * d - b * c
    deriv = (b*d) / (det**2) * grad_output[:, 0, 0]
    deriv += -1 * (b*b) / (det**2) * grad_output[:, 0, 1]
    deriv += -1 * (a*d) / (det**2) * grad_output[:, 1, 0]
    deriv += (a*b) / (det**2) * grad_output[:, 1, 1]
    return deriv

def d_inv_wrt_d(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor, grad_output: torch.Tensor):
    det = a * d - b * c
    deriv = -1 * (b*c) / (det**2) * grad_output[:, 0, 0]
    deriv += (a*b) / (det**2) * grad_output[:, 0, 1]
    deriv += (a*c) / (det**2) * grad_output[:, 1, 0]
    deriv += -1 * (a*a) / (det**2) * grad_output[:, 1, 1]
    return deriv

class invert_2x2_matrix(torch.autograd.Function):
    @staticmethod
    def forward(ctx, matrix: torch.Tensor):
        """input is a nx2x2 tensor, returns the inverse of each 2x2 matrix"""
        ctx.save_for_backward(matrix)
        det = matrix[:, 0, 0] * matrix[:, 1, 1] - matrix[:, 0, 1] * matrix[:, 1, 0]
        # Create empty nx2x2 tensor for the inverted matrices
        final_matrices = torch.zeros_like(matrix)
        
        # Fill in the inverted matrix elements using the 2x2 matrix inverse formula
        final_matrices[:, 0, 0] = matrix[:, 1, 1] / det  
        final_matrices[:, 0, 1] = -matrix[:, 0, 1] / det
        final_matrices[:, 1, 0] = -matrix[:, 1, 0] / det
        final_matrices[:, 1, 1] = matrix[:, 0, 0] / det

        if torch.isinf(final_matrices).any():
            raise RuntimeError("Infinite values in final matrices")
        return final_matrices
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """grad_output is a nx2x2 tensor, returns the gradient of the inverse of each 2x2 matrix"""
        matrix = ctx.saved_tensors[0]
        a = matrix[:, 0, 0]
        b = matrix[:, 0, 1]
        c = matrix[:, 1, 0]
        d = matrix[:, 1, 1]
        grad_a = d_inv_wrt_a(a, b, c, d, grad_output)
        grad_b = d_inv_wrt_b(a, b, c, d, grad_output)
        grad_c = d_inv_wrt_c(a, b, c, d, grad_output)
        grad_d = d_inv_wrt_d(a, b, c, d, grad_output)
        return torch.stack([grad_a, grad_b, grad_c, grad_d], dim=-1).view(matrix.shape)
    
def invert_2x2_matrix_test(matrix: torch.Tensor):
    det = matrix[:, 0, 0] * matrix[:, 1, 1] - matrix[:, 0, 1] * matrix[:, 1, 0]
    # Create empty nx2x2 tensor for the inverted matrices
    final_matrices = torch.zeros_like(matrix)
    
    # Fill in the inverted matrix elements using the 2x2 matrix inverse formula
    final_matrices[:, 0, 0] = matrix[:, 1, 1] / det  
    final_matrices[:, 0, 1] = -matrix[:, 0, 1] / det
    final_matrices[:, 1, 0] = -matrix[:, 1, 0] / det
    final_matrices[:, 1, 1] = matrix[:, 0, 0] / det
    return final_matrices


# Define helper functions to compare gradients and compute reference gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases for invert_2x2_matrix
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    matrix = torch.tensor(
        [[[4.0, 7.0], [2.0, 6.0]]], requires_grad=True, dtype=torch.float64
    )  # Single 2x2 matrix
    assert gradcheck(invert_2x2_matrix.apply, (matrix,)), "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single matrix)
    matrix = torch.tensor(
        [[[4.0, 7.0], [2.0, 6.0]]], requires_grad=True, dtype=torch.float64
    )
    # My grads
    output = invert_2x2_matrix.apply(matrix)
    loss = output.sum()
    loss.backward()
    my_grads = matrix.grad.clone()

    # Autograd grads
    matrix.grad.zero_()
    output = invert_2x2_matrix_test(matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads = matrix.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients (single matrix)!"
    print("Gradients match for single matrix.")

    # 3. Test with multiple matrices
    matrix = torch.tensor(
        [[[4.0, 7.0], [2.0, 6.0]], [[1.0, 2.0], [3.0, 4.0]]], requires_grad=True, dtype=torch.float64
    )
    # My grads
    output = invert_2x2_matrix.apply(matrix)
    loss = output.sum()
    loss.backward()
    my_grads = matrix.grad.clone()

    # Autograd grads
    matrix.grad.zero_()
    output = invert_2x2_matrix_test(matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads = matrix.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients (multiple matrices)!"
    print("Gradients match for multiple matrices.")

    # 4. Edge cases: non-invertible matrix
    try:
        matrix = torch.tensor(
            [[[1.0, 2.0], [2.0, 4.0]]], requires_grad=True, dtype=torch.float64
        )
        output = invert_2x2_matrix.apply(matrix)
        print("Output: ", output)
        print("Non-invertible matrix did not raise an error as expected.")
    except RuntimeError as e:
        print(f"Correctly raised error for non-invertible matrix: {e}")

    # 5. Test with larger batch sizes
    batch_size = 10
    matrix = torch.rand((batch_size, 2, 2), requires_grad=True, dtype=torch.float64)
    # My grads
    output = invert_2x2_matrix.apply(matrix)
    loss = output.sum()
    loss.backward()
    my_grads = matrix.grad.clone()

    # Autograd grads
    matrix.grad.zero_()
    output = invert_2x2_matrix_test(matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads = matrix.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients (batch size 10)!"
    print("Gradients match for larger batch size.")

# Run the tests
run_tests()


Gradcheck passed.
Gradients match for single matrix.
Gradients match for multiple matrices.
Correctly raised error for non-invertible matrix: Infinite values in final matrices
Gradients match for larger batch size.


In [20]:
class covariance_3d_to_covariance_2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, covariance_3d: torch.Tensor, U: torch.Tensor):
        """
        Covariance 3d is the nx3x3 covariance matrix.
        U is the J@W.T matrix. this is a 3x3 matrix
        To get the covariance 2d we do U.T @ covariance_3d @ U
        """
        ctx.save_for_backward(U, covariance_3d)
        outcome = torch.bmm(U.transpose(1, 2), torch.bmm(covariance_3d, U))
        return outcome

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        U, covariance_3d = ctx.saved_tensors

        # Derivative of (U^T * C * U) w.r.t. C = U * grad_output * U^T
        # grad_cov3d = torch.einsum("nij,njk->nik", U, grad_output)
        # grad_cov3d = torch.einsum("nij,njk->nik", grad_cov3d, U.transpose(1, 2))
        grad_cov3d = U @ grad_output @ U.transpose(1, 2)
        # Derivative of (U^T * C * U) w.r.t. U
        # Z = (U^T * (C * U)) Y= C * U
        # the contribution from Y is covariance_3d.T @ grad_output
        y = torch.einsum("nij,njk->nik", covariance_3d, U)
        deriv_U_first_part = torch.einsum("nij,njk->nik", grad_output, y.transpose(1, 2)).transpose(1, 2)
        dz_dy = torch.einsum("nij,njk->nik", U, grad_output)
        dy_du = torch.einsum("nij,njk->nik", covariance_3d.transpose(1, 2), dz_dy)
        return grad_cov3d, deriv_U_first_part + dy_du

def covariance_3d_to_covariance_2d_test(E: torch.Tensor, U: torch.Tensor):
    """E is a nx3x3 matrix, U is a n3x3 matrix"""
    return torch.bmm(U.transpose(1, 2), torch.bmm(covariance_3d, U))

# Define helper functions for comparing gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)


torch.random.manual_seed(0)
covariance_3d = torch.randn(1, 3, 3, requires_grad=True, dtype=torch.float64)
U = torch.randn(1, 3, 3, requires_grad=True, dtype=torch.float64)
# print("U", U)

output = covariance_3d_to_covariance_2d.apply(covariance_3d, U)
loss = output.sum()
loss.backward()
# print("output", output)
# print("cov grad", covariance_3d.grad)
print("U grad", U.grad)

# auto test
torch.random.manual_seed(0)
covariance_3d = torch.randn(1, 3, 3, requires_grad=True, dtype=torch.float64)
U = torch.randn(1, 3, 3, requires_grad=True, dtype=torch.float64)
# print("U", U)

output = covariance_3d_to_covariance_2d_test(covariance_3d, U)
print("output", output)
loss = output.sum()
loss.backward()
# print("cov grad", covariance_3d.grad)
print("U grad", U.grad)

gradcheck(covariance_3d_to_covariance_2d.apply, (covariance_3d, U))


U grad tensor([[[-2.6232, -2.6232, -2.6232],
         [ 1.6089,  1.6089,  1.6089],
         [ 2.0174,  2.0174,  2.0174]]], dtype=torch.float64)
output tensor([[[-0.3142,  0.7716, -0.6829],
         [ 1.5104, -1.4269,  1.5671],
         [-1.0433,  1.1982, -1.2437]]], dtype=torch.float64,
       grad_fn=<BmmBackward0>)
U grad tensor([[[-2.6232, -2.6232, -2.6232],
         [ 1.6089,  1.6089,  1.6089],
         [ 2.0174,  2.0174,  2.0174]]], dtype=torch.float64)


True

In [9]:
class R_S_To_M(torch.autograd.Function):
    @staticmethod
    def forward(ctx, R: torch.Tensor, S: torch.Tensor):
        """R is a nx3x3 rotation matrix, S is a nx3x3 scale matrix"""
        ctx.save_for_backward(R, S)
        return torch.einsum("nij,njk->nik", R, S)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        R, S = ctx.saved_tensors
        grad_R = torch.einsum("nij,njk->nik", grad_output, S.transpose(1, 2))
        grad_S = torch.einsum("nij,njk->nik", R.transpose(1, 2), grad_output)
        return grad_R, grad_S

# Define helper functions for comparing gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases for R_S_To_M
def run_R_S_To_M_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness3, 3, requires_grad=True, dtype=torch.float64)
    R = torch.randn((1, 3, 3), requires_grad=True, dtype=torch.float64)
    S = torch.randn((1, 3, 3), requires_grad=True, dtype=torch.float64)
    assert gradcheck(R_S_To_M.apply, (R, S)), "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single matrix)
    c = torch.eye(3, requires_grad=False, dtype=torch.float64).unsqueeze(0)
    R = torch.tensor(c, requires_grad=True, dtype=torch.float64)
    S = torch.tensor(c, requires_grad=True, dtype=torch.float64)

    # My grads
    output = R_S_To_M.apply(R, S)
    loss = output.sum()
    loss.backward()
    print(f"R.grad: {R.grad}, S.grad: {S.grad}")
    my_grads_R = R.grad.clone()
    my_grads_S = S.grad.clone()

    # Autograd grads
    R.grad.zero_()
    S.grad.zero_()
    output = torch.einsum("nij,njk->nik", R, S)
    loss = output.sum()
    loss.backward()
    autograd_grads_R = R.grad.clone()
    autograd_grads_S = S.grad.clone()

    assert compare_grads(my_grads_R, autograd_grads_R, tolerance), "Mismatch in gradients for R (single matrix)!"
    assert compare_grads(my_grads_S, autograd_grads_S, tolerance), "Mismatch in gradients for S (single matrix)!"
    print("Gradients match for single matrix.")

    # 3. Test with multiple matrices
    R = torch.stack([torch.randn((3, 3), dtype=torch.float64) for _ in range(5)], dim=0).requires_grad_()
    S = torch.stack([torch.randn((3, 3), dtype=torch.float64) for _ in range(5)], dim=0).requires_grad_()

    # My grads
    output = R_S_To_M.apply(R, S)
    loss = output.sum()
    loss.backward()
    my_grads_R = R.grad.clone()
    my_grads_S = S.grad.clone()

    # Autograd grads
    R.grad.zero_()
    S.grad.zero_()
    output = torch.einsum("nij,njk->nik", R, S)
    loss = output.sum()
    loss.backward()
    autograd_grads_R = R.grad.clone()
    autograd_grads_S = S.grad.clone()

    assert compare_grads(my_grads_R, autograd_grads_R, tolerance), "Mismatch in gradients for R (multiple matrices)!"
    assert compare_grads(my_grads_S, autograd_grads_S, tolerance), "Mismatch in gradients for S (multiple matrices)!"
    print("Gradients match for multiple matrices.")

    # 4. Edge cases: non-orthogonal R or S
    try:
        R = torch.rand((1, 3, 3), requires_grad=True, dtype=torch.float64)
        S = torch.rand((1, 3, 3), requires_grad=True, dtype=torch.float64)
        output = R_S_To_M.apply(R, S)
        print("Output: ", output)
        print("Non-orthogonal R or S did not raise an error as expected.")
    except RuntimeError as e:
        print(f"Handled non-orthogonal case correctly: {e}")

    # 5. Test with larger batch sizes
    batch_size = 10
    R = torch.stack([torch.eye(3, dtype=torch.float64) for _ in range(batch_size)], dim=0).requires_grad_()
    S = torch.stack([torch.eye(3, dtype=torch.float64) for _ in range(batch_size)], dim=0).requires_grad_()

    # My grads
    output = R_S_To_M.apply(R, S)
    loss = output.sum()
    loss.backward()
    my_grads_R = R.grad.clone()
    my_grads_S = S.grad.clone()

    # Autograd grads
    R.grad.zero_()
    S.grad.zero_()
    output = torch.einsum("nij,njk->nik", R, S)
    loss = output.sum()
    loss.backward()
    autograd_grads_R = R.grad.clone()
    autograd_grads_S = S.grad.clone()

    assert compare_grads(my_grads_R, autograd_grads_R, tolerance), "Mismatch in gradients for R (batch size 10)!"
    assert compare_grads(my_grads_S, autograd_grads_S, tolerance), "Mismatch in gradients for S (batch size 10)!"
    print("Gradients match for larger batch size.")

# Run the tests
run_R_S_To_M_tests()

Gradcheck passed.
R.grad: tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], dtype=torch.float64), S.grad: tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], dtype=torch.float64)
Gradients match for single matrix.
Gradients match for multiple matrices.
Output:  tensor([[[0.9408, 0.7610, 0.4686],
         [0.9541, 0.6338, 0.1669],
         [0.8051, 0.8766, 0.8510]]], dtype=torch.float64,
       grad_fn=<R_S_To_MBackward>)
Non-orthogonal R or S did not raise an error as expected.
Gradients match for larger batch size.


  R = torch.tensor(c, requires_grad=True, dtype=torch.float64)
  S = torch.tensor(c, requires_grad=True, dtype=torch.float64)


In [10]:
import torch
from torch.autograd import gradcheck

class M_to_covariance(torch.autograd.Function):
    @staticmethod
    def forward(ctx, M: torch.Tensor):
        ctx.save_for_backward(M)
        return M.pow(2)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        M = ctx.saved_tensors[0]
        return 2 * grad_output * M
    
def test_M_To_Covariance(M: torch.Tensor):
    return M.pow(2)

# Helper functions for testing
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases for M_To_Covariance
def run_M_To_Covariance_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    M = torch.randn((2, 3, 3), requires_grad=True, dtype=torch.float64)
    assert gradcheck(M_to_covariance.apply, (M,)), "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single matrix)
    M = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                     requires_grad=True, dtype=torch.float64)
    M = torch.tensor(M.unsqueeze(0), requires_grad=True, dtype=torch.float64)

    # My grads
    output = M_to_covariance.apply(M)
    loss = output.sum()
    loss.backward()
    print(f"M.grad: {M.grad}")
    my_grads_M = M.grad.clone()

    # Autograd grads
    M.grad.zero_()
    output = test_M_To_Covariance(M)
    loss = output.sum()
    loss.backward()
    autograd_grads_M = M.grad.clone()

    assert compare_grads(my_grads_M, autograd_grads_M, tolerance), "Mismatch in gradients for M (single matrix)!"
    print("Gradients match for single matrix.")

    # 3. Test with multiple matrices
    M = torch.stack([torch.randn((3, 3), dtype=torch.float64) for _ in range(5)], dim=0).requires_grad_()

    # My grads
    output = M_to_covariance.apply(M)
    loss = output.sum()
    loss.backward()
    my_grads_M = M.grad.clone()

    # Autograd grads
    M.grad.zero_()
    output = M.pow(2)
    loss = output.sum()
    loss.backward()
    autograd_grads_M = M.grad.clone()

    assert compare_grads(my_grads_M, autograd_grads_M, tolerance), "Mismatch in gradients for M (multiple matrices)!"
    print("Gradients match for multiple matrices.")

    # 4. Test edge cases: matrix with zeros
    M = torch.zeros((1, 3, 3), requires_grad=True, dtype=torch.float64)

    # My grads
    output = M_to_covariance.apply(M)
    loss = output.sum()
    loss.backward()
    my_grads_M = M.grad.clone()

    # Autograd grads
    M.grad.zero_()
    output = test_M_To_Covariance(M)
    loss = output.sum()
    loss.backward()
    autograd_grads_M = M.grad.clone()

    assert compare_grads(my_grads_M, autograd_grads_M, tolerance), "Mismatch in gradients for M (zero matrix)!"
    print("Gradients match for zero matrix.")

    # 5. Test with larger batch sizes
    batch_size = 10
    M = torch.randn((batch_size, 3, 3), requires_grad=True, dtype=torch.float64)

    # My grads
    output = M_to_covariance.apply(M)
    loss = output.sum()
    loss.backward()
    my_grads_M = M.grad.clone()

    # Autograd grads
    M.grad.zero_()
    output = test_M_To_Covariance(M)
    loss = output.sum()
    loss.backward()
    autograd_grads_M = M.grad.clone()

    assert compare_grads(my_grads_M, autograd_grads_M, tolerance), "Mismatch in gradients for M (batch size 10)!"
    print("Gradients match for larger batch size.")

# Run the tests
run_M_To_Covariance_tests()

Gradcheck passed.
M.grad: tensor([[[ 2.,  4.,  6.],
         [ 8., 10., 12.],
         [14., 16., 18.]]], dtype=torch.float64)
Gradients match for single matrix.
Gradients match for multiple matrices.
Gradients match for zero matrix.
Gradients match for larger batch size.


  M = torch.tensor(M.unsqueeze(0), requires_grad=True, dtype=torch.float64)


In [11]:
from splat.utils import build_rotation

def d_r_wrt_qr(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = -qk
    derivative[:, 0, 2] = qj
    derivative[:, 1, 0] = qk
    derivative[:, 1, 2] = -qi
    derivative[:, 2, 0] = -qj
    derivative[:, 2, 1] = qi

    return 2 * derivative


def d_r_wrt_qi(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 1] = qj
    derivative[:, 0, 2] = qk
    derivative[:, 1, 0] = qj
    derivative[:, 1, 1] = -2 * qi
    derivative[:, 1, 2] = -qr
    derivative[:, 2, 0] = qk
    derivative[:, 2, 1] = qr
    derivative[:, 2, 2] = -2 * qi
    return 2 * derivative


def d_r_wrt_qj(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 0] = -2 * qj
    derivative[:, 0, 1] = qi
    derivative[:, 0, 2] = qr
    derivative[:, 1, 0] = qi
    derivative[:, 1, 1] = 0
    derivative[:, 1, 2] = qk
    derivative[:, 2, 0] = -qr
    derivative[:, 2, 1] = qk
    derivative[:, 2, 2] = -2 * qj
    return 2 * derivative


def d_r_wrt_qk(quats: torch.Tensor, n: int) -> torch.Tensor:
    """
    Compute the derivative of m wrt quats
    quats is nx4 tensor
    shape is nx3 tensor
    """
    qr = quats[:, 0]
    qi = quats[:, 1]
    qj = quats[:, 2]
    qk = quats[:, 3]

    derivative = torch.zeros((n, 3, 3))
    derivative[:, 0, 0] = -2 * qk
    derivative[:, 0, 1] = -qr
    derivative[:, 0, 2] = qi
    derivative[:, 1, 0] = qr
    derivative[:, 1, 1] = -2*qk
    derivative[:, 1, 2] = qj
    derivative[:, 2, 0] = qi
    derivative[:, 2, 1] = qj
    derivative[:, 2, 2] = 0
    return 2 * derivative

class quats_to_R(torch.autograd.Function):
    @staticmethod
    def forward(ctx, quats: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(quats)
        R = build_rotation(quats, normalize=False)
        return R
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        grad output is nx3x3 and the jacobian nx3x3x4 so deriv_wrt_qr 
        is nx3x3 where each entry is the derivative wrt one element in the final wrt r
        """
        quats = ctx.saved_tensors[0]
        deriv_wrt_qr = d_r_wrt_qr(quats, quats.shape[0])
        deriv_wrt_qi = d_r_wrt_qi(quats, quats.shape[0])
        deriv_wrt_qj = d_r_wrt_qj(quats, quats.shape[0])
        deriv_wrt_qk = d_r_wrt_qk(quats, quats.shape[0])
        
        deriv_wrt_qr = (grad_output * deriv_wrt_qr).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_qi = (grad_output * deriv_wrt_qi).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_qj = (grad_output * deriv_wrt_qj).sum(dim=(1, 2), keepdim=True).squeeze(2)
        deriv_wrt_qk = (grad_output * deriv_wrt_qk).sum(dim=(1, 2), keepdim=True).squeeze(2)
        return torch.cat([deriv_wrt_qr, deriv_wrt_qi, deriv_wrt_qj, deriv_wrt_qk], dim=1)
    
def test_quats_to_R(quats: torch.Tensor):
    return build_rotation(quats, normalize=False)


# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test suite for quats_to_R
def run_quats_to_R_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    quats = torch.randn((5, 4), requires_grad=True, dtype=torch.float64)
    assert gradcheck(quats_to_R.apply, (quats,)), "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single quaternion)
    quats = torch.tensor([[1.0, 0.0, 0.0, 0.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = quats_to_R.apply(quats)
    loss = output.sum()
    loss.backward()
    print(f"Quats.grad: {quats.grad}")
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = build_rotation(quats, normalize=False)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients for single quaternion!"
    print("Gradients match for single quaternion.")

    # 3. Test with multiple quaternions
    quats = torch.randn((5, 4), dtype=torch.float64).requires_grad_()

    # My grads
    output = quats_to_R.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = build_rotation(quats, normalize=False)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients for multiple quaternions!"
    print("Gradients match for multiple quaternions.")

    # 4. Test edge cases: zero quaternion
    quats = torch.zeros((1, 4), requires_grad=True, dtype=torch.float64)

    # My grads
    output = quats_to_R.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = build_rotation(quats, normalize=False)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients for zero quaternion!"
    print("Gradients match for zero quaternion.")

    # 5. Test with larger batch sizes
    batch_size = 10
    quats = torch.randn((batch_size, 4), requires_grad=True, dtype=torch.float64)

    # My grads
    output = quats_to_R.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = build_rotation(quats, normalize=False)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "Mismatch in gradients for larger batch size!"
    print("Gradients match for larger batch size.")
    print("Normalization test passed.")

# Run the tests
run_quats_to_R_tests()

Gradcheck passed.
Quats.grad: tensor([[0., 0., 0., 0.]], dtype=torch.float64)
Gradients match for single quaternion.
Gradients match for multiple quaternions.
Gradients match for zero quaternion.
Gradients match for larger batch size.
Normalization test passed.


In [12]:
class normalize_quats(torch.autograd.Function):
    @staticmethod
    def forward(ctx, quats: torch.Tensor):
        """Quats are a nx4 tensor"""
        ctx.save_for_backward(quats)
        return quats / quats.norm(dim=1, keepdim=True)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        quats = ctx.saved_tensors[0]
        norm = quats.norm(dim=1, keepdim=True)
        norm_cube = norm ** 3

        quats_outer = torch.einsum('ni,nj->nij', quats, quats)
        eye = torch.eye(4, dtype=quats.dtype, device=quats.device).unsqueeze(0)

        jacobian = (eye / norm.unsqueeze(2)) - (quats_outer / norm_cube.unsqueeze(2))
        grad_input = torch.einsum('nij,nj->ni', jacobian, grad_output)
        return grad_input

def test_normalize_quats(quats: torch.Tensor):
    return quats / quats.norm(dim=1, keepdim=True)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    quats = torch.tensor([[1.0, 0.0, 0.0, 0.0]], requires_grad=True, dtype=torch.float64)
    gradcheck_passed = gradcheck(normalize_quats.apply, (quats,))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single quaternion)
    quats = torch.tensor([[1.0, 0.0, 0.0, 0.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = normalize_quats.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = test_normalize_quats(quats)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (single quaternion)!"
    print("My grads match autograd grads (single quaternion).")

    # 3. Test with multiple quaternions
    quats = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = normalize_quats.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = test_normalize_quats(quats)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (multiple quaternions)!"
    print("My grads match autograd grads (multiple quaternions).")

    # 4. Edge case: Small and large values
    quats = torch.tensor([[1e-5, 1e-5, 1e-5, 1e-5], [1e5, 0.0, 0.0, 0.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = normalize_quats.apply(quats)
    loss = output.sum()
    loss.backward()
    my_grads = quats.grad.clone()

    # Autograd grads
    quats.grad.zero_()
    output = test_normalize_quats(quats)
    loss = output.sum()
    loss.backward()
    autograd_grads = quats.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (edge cases)!"
    print("My grads match autograd grads (edge cases).")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads (single quaternion).
My grads match autograd grads (multiple quaternions).
My grads match autograd grads (edge cases).


In [14]:
import torch
from torch.autograd.gradcheck import gradcheck

class scale_to_s_matrix(torch.autograd.Function):
    @staticmethod
    def forward(ctx, s: torch.Tensor):
        """Takes the nx3 tensor and returns the nx3x3 diagonal matrix"""
        ctx.save_for_backward(s)
        return torch.diag_embed(s)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Grad output is a nx3x3 tensor"""
        s = ctx.saved_tensors[0]
        deriv_wr_s1 = grad_output[:, 0, 0].view(-1, 1)
        deriv_wr_s2 = grad_output[:, 1, 1].view(-1, 1)
        deriv_wr_s3 = grad_output[:, 2, 2].view(-1, 1)
        return torch.cat([deriv_wr_s1, deriv_wr_s2, deriv_wr_s3], dim=1)

def test_scale_to_s_matrix(s: torch.Tensor):
    return torch.diag_embed(s)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    s = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)
    gradcheck_passed = gradcheck(scale_to_s_matrix.apply, (s,))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single input)
    s = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scale_to_s_matrix.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scale_to_s_matrix(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (single input)!"
    print("My grads match autograd grads (single input).")

    # 3. Test with multiple inputs
    s = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scale_to_s_matrix.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scale_to_s_matrix(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (multiple inputs)!"
    print("My grads match autograd grads (multiple inputs).")

    # 4. Edge case: Small and large values
    s = torch.tensor([[1e-5, 1e-5, 1e-5], [1e5, 1e5, 1e5]], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scale_to_s_matrix.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scale_to_s_matrix(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (edge cases)!"
    print("My grads match autograd grads (edge cases).")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads (single input).
My grads match autograd grads (multiple inputs).
My grads match autograd grads (edge cases).


In [15]:
import torch
from torch.autograd.gradcheck import gradcheck

class scaling_exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, s: torch.Tensor):
        ctx.save_for_backward(s)
        return torch.exp(s)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        s = ctx.saved_tensors[0]
        return grad_output * torch.exp(s)

# Define a test function using PyTorch's autograd
def test_scaling_exp(s: torch.Tensor):
    return torch.exp(s)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    s = torch.tensor([1.0], requires_grad=True, dtype=torch.float64)
    gradcheck_passed = gradcheck(scaling_exp.apply, (s,))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd (single value)
    s = torch.tensor([1.0], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scaling_exp.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scaling_exp(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (single value)!"
    print("My grads match autograd grads (single value).")

    # 3. Test with multiple values
    s = torch.tensor([1.0, 2.0, 3.0], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scaling_exp.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scaling_exp(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (multiple values)!"
    print("My grads match autograd grads (multiple values).")

    # 4. Edge case: Very small and very large values
    s = torch.tensor([1e-5, 1e5], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scaling_exp.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scaling_exp(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (edge cases)!"
    print("My grads match autograd grads (edge cases).")

    # 5. Test with a tensor containing negative values
    s = torch.tensor([-1.0, -2.0, -3.0], requires_grad=True, dtype=torch.float64)

    # My grads
    output = scaling_exp.apply(s)
    loss = output.sum()
    loss.backward()
    my_grads = s.grad.clone()

    # Autograd grads
    s.grad.zero_()
    output = test_scaling_exp(s)
    loss = output.sum()
    loss.backward()
    autograd_grads = s.grad.clone()

    assert compare_grads(my_grads, autograd_grads, tolerance), "My grads do not match autograd grads (negative values)!"
    print("My grads match autograd grads (negative values).")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads (single value).
My grads match autograd grads (multiple values).
My grads match autograd grads (edge cases).
My grads match autograd grads (negative values).


In [14]:
import torch
from torch.autograd.gradcheck import gradcheck

class mean_3d_to_camera_space(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mean_3d: torch.Tensor, extrinsic_matrix: torch.Tensor):
        ctx.save_for_backward(extrinsic_matrix)
        return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        extrinsic_matrix = ctx.saved_tensors[0]
        mean_3d_grad = torch.einsum("nh,hj->nj", grad_output, extrinsic_matrix.transpose(0, 1))
        return mean_3d_grad, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix):
    return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    mean_3d = torch.rand(1, 4, requires_grad=True, dtype=torch.float64)
    extrinsic_matrix = torch.rand(4, 4, requires_grad=False, dtype=torch.float64)
    gradcheck_passed = gradcheck(mean_3d_to_camera_space.apply, (mean_3d, extrinsic_matrix))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd
    mean_3d = torch.rand(1, 4, requires_grad=True, dtype=torch.float64)
    extrinsic_matrix = torch.rand(4, 4, requires_grad=True, dtype=torch.float64)

    # My grads
    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "My grads do not match autograd grads (mean_3d)!"
    print("My grads match autograd grads.")

    # 3. Edge case: Batch size of 1
    mean_3d = torch.rand(1, 4, requires_grad=True, dtype=torch.float64)
    extrinsic_matrix = torch.rand(4, 4, requires_grad=True, dtype=torch.float64)

    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    print("Edge case (batch size 1) passed.")

    # 4. Edge case: Very small and very large values
    mean_3d = torch.tensor([[1e-5, 1e5, -1e5, 1.0]], requires_grad=True, dtype=torch.float64)
    extrinsic_matrix = torch.rand(4, 4, requires_grad=True, dtype=torch.float64)

    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    print("Edge case (very small/large values) passed.")

    # 5. Test with zero matrices
    mean_3d = torch.zeros(3, 4, requires_grad=True, dtype=torch.float64)
    extrinsic_matrix = torch.zeros(4, 4, requires_grad=True, dtype=torch.float64)

    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    assert torch.all(mean_3d.grad == 0), "Gradients for mean_3d should be zero!"
    print("Zero matrices test passed.")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads.
Edge case (batch size 1) passed.
Edge case (very small/large values) passed.
Zero matrices test passed.


In [16]:
import torch
from torch.autograd.gradcheck import gradcheck


class camera_space_to_pixel_space(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mean_3d: torch.Tensor, intrinsic_matrix: torch.Tensor):
        ctx.save_for_backward(intrinsic_matrix)
        return torch.einsum("nk, kh->nh", mean_3d, intrinsic_matrix)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):   
        intrinsic_matrix = ctx.saved_tensors[0]
        mean_3d_grad = torch.einsum("nh,hj->nj", grad_output, intrinsic_matrix.transpose(0, 1))
        return mean_3d_grad, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_camera_space_to_pixel_space(mean_3d, intrinsic_matrix):
    return torch.einsum("nk, kh->nh", mean_3d, intrinsic_matrix)

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    mean_3d = torch.rand(3, 4, requires_grad=True, dtype=torch.float64)
    intrinsic_matrix = torch.rand(4, 4, requires_grad=False, dtype=torch.float64)
    gradcheck_passed = gradcheck(camera_space_to_pixel_space.apply, (mean_3d, intrinsic_matrix))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd
    mean_3d = torch.rand(3, 4, requires_grad=True, dtype=torch.float64)
    intrinsic_matrix = torch.rand(4, 4, requires_grad=False, dtype=torch.float64)

    # My grads
    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_camera_space_to_pixel_space(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "My grads do not match autograd grads (mean_3d)!"
    print("My grads match autograd grads.")

    # 3. Edge case: Batch size of 1
    mean_3d = torch.rand(1, 4, requires_grad=True, dtype=torch.float64)
    intrinsic_matrix = torch.rand(4, 4, requires_grad=False, dtype=torch.float64)

    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    print("Edge case (batch size 1) passed.")

    # 4. Edge case: Very small and very large values
    mean_3d = torch.tensor([[1e-5, 1e5, -1e5, 1.0]], requires_grad=True, dtype=torch.float64)
    intrinsic_matrix = torch.rand(4, 4, requires_grad=False, dtype=torch.float64)

    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    print("Edge case (very small/large values) passed.")

    # 5. Test with zero matrices
    mean_3d = torch.zeros(3, 4, requires_grad=True, dtype=torch.float64)
    intrinsic_matrix = torch.zeros(4, 4, requires_grad=False, dtype=torch.float64)

    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    assert torch.all(mean_3d.grad == 0), "Gradients for mean_3d should be zero!"
    print("Zero matrices test passed.")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads.
Edge case (batch size 1) passed.
Edge case (very small/large values) passed.
Zero matrices test passed.


In [57]:
import torch
from torch.autograd.gradcheck import gradcheck

class ndc_to_pixels(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ndc: torch.Tensor, dimension: list):
        """ndc is a nx3 tensor where the last dimension is the z component
        
        dimension are the height and width of the image
        """
        ctx.save_for_backward(torch.tensor(dimension))
        ndc = ndc.clone()  # To avoid modifying input in-place
        ndc[:, 0] = (ndc[:, 0] + 1) * (dimension[1] - 1) * 0.5
        ndc[:, 1] = (ndc[:, 1] + 1) * (dimension[0] - 1) * 0.5
        return ndc
    
    @staticmethod
    def backward(ctx, grad_output):
        dimension = ctx.saved_tensors[0]
        grad_ndc = grad_output.clone()

        # Compute the gradient for ndc
        grad_ndc[:, 0] *= (dimension[1] - 1) * 0.5
        grad_ndc[:, 1] *= (dimension[0] - 1) * 0.5
        # grad_ndc[:, 2] = 0  # z-component has no effect on pixel coordinates
        return grad_ndc, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_ndc_to_pixels(ndc: torch.Tensor, dimension: list):
    ndc = ndc.clone()
    ndc[:, 0] = (ndc[:, 0] + 1) * (dimension[1] - 1) * 0.5
    ndc[:, 1] = (ndc[:, 1] + 1) * (dimension[0] - 1) * 0.5
    return ndc

# Test cases
def run_tests():
    tolerance = 1e-6

    # 1. Gradcheck for numerical gradient correctness
    ndc = torch.rand(5, 3, requires_grad=True, dtype=torch.float64)
    dimension = [720, 1280]
    gradcheck_passed = gradcheck(ndc_to_pixels.apply, (ndc, dimension))
    assert gradcheck_passed, "Gradcheck failed!"
    print("Gradcheck passed.")

    # 2. Verify backward computation with autograd
    ndc = torch.rand(5, 3, requires_grad=True, dtype=torch.float64)
    dimension = [720, 1280]

    # My grads
    output = ndc_to_pixels.apply(ndc, dimension)
    loss = output.sum()
    loss.backward()
    my_grads_ndc = ndc.grad.clone()

    # Autograd grads
    ndc.grad.zero_()
    output = test_ndc_to_pixels(ndc, dimension)
    loss = output.sum()
    loss.backward()
    autograd_grads_ndc = ndc.grad.clone()

    assert compare_grads(my_grads_ndc, autograd_grads_ndc, tolerance), "My grads do not match autograd grads!"
    print("My grads match autograd grads.")

    # 3. Edge case: NDC values at boundaries
    ndc = torch.tensor([[-1.0, -1.0, 0.0], [1.0, 1.0, 0.0]], requires_grad=True, dtype=torch.float64)
    dimension = [720, 1280]

    output = ndc_to_pixels.apply(ndc, dimension)
    assert torch.all(output[:, 0] == torch.tensor([0.0, 1279.0], dtype=torch.float64)), "X coordinates incorrect at boundaries!"
    assert torch.all(output[:, 1] == torch.tensor([0.0, 719.0], dtype=torch.float64)), "Y coordinates incorrect at boundaries!"
    print("Edge case (boundaries) passed.")

    # 4. Edge case: Single NDC point
    ndc = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True, dtype=torch.float64)
    dimension = [720, 1280]

    output = ndc_to_pixels.apply(ndc, dimension)
    print(output)
    assert torch.allclose(output[0, :2], torch.tensor([639.0, 359.0], dtype=torch.float64), atol=1e-2, rtol=1e-2), "Incorrect result for single NDC point!"
    print("Edge case (single point) passed.")

# Run the tests
run_tests()

Gradcheck passed.
My grads match autograd grads.
Edge case (boundaries) passed.
tensor([[639.5000, 359.5000,   0.0000]], dtype=torch.float64,
       grad_fn=<ndc_to_pixelsBackward>)
Edge case (single point) passed.
