In [10]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import (
    backward_final_color_launcher,
)
from splat.utils import build_rotation


class final_color(torch.autograd.Function):
    @staticmethod
    def forward(ctx, color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
        """Color is a nx3 tensor, weight is a nx1 tensor, alpha is a nx1 tensor"""
        ctx.save_for_backward(color, current_T, alpha)
        return color * current_T * alpha
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx3 tensor so the grad_output is a nx3 tensor"""
        color, current_T, alpha = ctx.saved_tensors
        print(grad_output)
        grad_color = torch.zeros_like(color).to(color.device)
        grad_alpha = torch.zeros_like(alpha).to(alpha.device)
        backward_final_color_launcher(
            grad_output.contiguous(),
            color.contiguous(),
            current_T.contiguous(),
            alpha.contiguous(),
            grad_color.contiguous(),
            grad_alpha.contiguous()
        )
        return grad_color, None, grad_alpha

class final_color_pytorch_gradcheck(torch.autograd.Function):
    @staticmethod
    def forward(ctx, color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
        """Color is a nx3 tensor, weight is a nx1 tensor, alpha is a nx1 tensor"""
        ctx.save_for_backward(color, current_T, alpha)
        return color * current_T * alpha
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx3 tensor so the grad_output is a nx3 tensor"""
        color, current_T, alpha = ctx.saved_tensors
        grad_color = grad_output * current_T * alpha
        grad_alpha = (grad_output * color * current_T).sum(dim=1, keepdim=True)
        return grad_color, None, grad_alpha
    
def autograd_test(color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
    return color * current_T * alpha

# Define helper function for comparing gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# first we check with gradcheck
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    color = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=dtype, device=device)
    current_T = torch.tensor([[4.0]], dtype=dtype, device=device)
    alpha = torch.tensor([[5.0]], requires_grad=True, dtype=dtype, device=device)
    
    # Gradcheck for both implementations
    assert gradcheck(final_color.apply, (color, current_T, alpha), eps=eps), "Gradcheck failed for final_color!"
    print("Gradcheck passed for final_color.")

    assert gradcheck(final_color_pytorch_gradcheck.apply, (color, current_T, alpha), eps=eps), "Gradcheck failed for final_color_pytorch_gradcheck!"
    print("Gradcheck passed for final_color_pytorch_gradcheck.")

    # 2. Verify backward computation with autograd
    color = torch.randn((1, 3), requires_grad=True, dtype=dtype, device=device)
    current_T = torch.ones((1, 1), requires_grad=False, dtype=dtype, device=device)
    alpha = torch.rand((1, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = final_color.apply(color, current_T, alpha)
    loss = output.sum()
    loss.backward()
    my_grads_color = color.grad.clone()
    my_grads_alpha = alpha.grad.clone()

    # Autograd grads
    color.grad.zero_()
    alpha.grad.zero_()
    output = autograd_test(color, current_T, alpha)
    loss = output.sum()
    loss.backward()
    autograd_grads_color = color.grad.clone()
    autograd_grads_alpha = alpha.grad.clone()

    assert compare_grads(my_grads_color, autograd_grads_color, tolerance), "Mismatch in gradients for color!"
    assert compare_grads(my_grads_alpha, autograd_grads_alpha, tolerance), "Mismatch in gradients for alpha!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    color = torch.zeros((1, 3), requires_grad=True, dtype=dtype, device=device)
    current_T = torch.zeros((1, 1), requires_grad=False, dtype=dtype, device=device)
    alpha = torch.zeros((1, 1), requires_grad=True, dtype=dtype, device=device)

    output = final_color.apply(color, current_T, alpha)
    assert torch.all(output == 0), "Output is not zero for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 10
    color = torch.randn((batch_size, 3), requires_grad=True, dtype=dtype, device=device)
    current_T = torch.ones((batch_size, 1), requires_grad=False, dtype=dtype, device=device)
    alpha = torch.rand((batch_size, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = final_color.apply(color, current_T, alpha)
    loss = output.sum()
    loss.backward()
    my_grads_color = color.grad.clone()
    my_grads_alpha = alpha.grad.clone()

    # Autograd grads
    color.grad.zero_()
    alpha.grad.zero_()
    output = autograd_test(color, current_T, alpha)
    loss = output.sum()
    loss.backward()
    autograd_grads_color = color.grad.clone()
    autograd_grads_alpha = alpha.grad.clone()

    assert compare_grads(my_grads_color, autograd_grads_color, tolerance), "Mismatch in gradients for large batch size (color)!"
    assert compare_grads(my_grads_alpha, autograd_grads_alpha, tolerance), "Mismatch in gradients for large batch size (alpha)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    color = torch.full((10, 3), 1e6, requires_grad=True, dtype=dtype, device=device)
    current_T = torch.full((10, 1), 1e-6, requires_grad=False, dtype=dtype, device=device)
    alpha = torch.full((10, 1), 1e6, requires_grad=True, dtype=dtype, device=device)

    output = final_color.apply(color, current_T, alpha)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(color.grad).all(), "Gradients for color are not finite with extreme values!"
    assert torch.isfinite(alpha.grad).all(), "Gradients for alpha are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

tensor([[0., 0., 0.]], device='cuda:0')
Gradcheck passed for final_color.
Gradcheck passed for final_color_pytorch_gradcheck.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.


In [2]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import (
    get_alpha_backward_launcher,
)

class get_alpha(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_strength: torch.Tensor, unactivated_opacity: torch.Tensor):
        """Gaussian strength is a nx1 tensor, unactivated opacity is a nx1 tensor"""
        ctx.save_for_backward(gaussian_strength, unactivated_opacity)
        activated_opacity = torch.sigmoid(unactivated_opacity)
        return gaussian_strength * activated_opacity
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx1 tensor so the grad_output is a nx1 tensor"""
        gaussian_strength, unactivated_opacity = ctx.saved_tensors
        grad_gaussian_strength = torch.zeros_like(gaussian_strength).to(gaussian_strength.device)
        grad_unactivated_opacity = torch.zeros_like(unactivated_opacity).to(unactivated_opacity.device)
        get_alpha_backward_launcher(
            grad_output.contiguous(),
            gaussian_strength.contiguous(),
            unactivated_opacity.contiguous(),
            grad_gaussian_strength.contiguous(),
            grad_unactivated_opacity.contiguous()
        )
        return grad_gaussian_strength, grad_unactivated_opacity
# Define a test function using torch's autograd
def autograd_test(gaussian_strength: torch.Tensor, unactivated_opacity: torch.Tensor):
    activated_opacity = torch.sigmoid(unactivated_opacity)
    return gaussian_strength * activated_opacity

# Define helper function for comparing gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    gaussian_strength = torch.tensor([[2.0]], requires_grad=True, dtype=torch.float32, device=device)
    unactivated_opacity = torch.tensor([[3.0]], requires_grad=True, dtype=torch.float32, device=device)

    assert gradcheck(get_alpha.apply, (gaussian_strength, unactivated_opacity), eps=eps), "Gradcheck failed for get_alpha!"
    print("Gradcheck passed for get_alpha.")

    # 2. Verify backward computation with autograd
    gaussian_strength = torch.randn((1, 1), requires_grad=True, dtype=dtype, device=device)
    unactivated_opacity = torch.randn((1, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = get_alpha.apply(gaussian_strength, unactivated_opacity)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_strength = gaussian_strength.grad.clone()
    my_grads_unactivated_opacity = unactivated_opacity.grad.clone()

    # Autograd grads
    gaussian_strength.grad.zero_()
    unactivated_opacity.grad.zero_()
    output = autograd_test(gaussian_strength, unactivated_opacity)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_strength = gaussian_strength.grad.clone()
    autograd_grads_unactivated_opacity = unactivated_opacity.grad.clone()

    assert compare_grads(my_grads_gaussian_strength, autograd_grads_gaussian_strength, tolerance), "Mismatch in gradients for gaussian_strength!"
    assert compare_grads(my_grads_unactivated_opacity, autograd_grads_unactivated_opacity, tolerance), "Mismatch in gradients for unactivated_opacity!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    gaussian_strength = torch.zeros((1, 1), requires_grad=True, dtype=dtype, device=device)
    unactivated_opacity = torch.zeros((1, 1), requires_grad=True, dtype=dtype, device=device)

    output = get_alpha.apply(gaussian_strength, unactivated_opacity)
    assert torch.all(output == 0), "Output is not zero for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 10
    gaussian_strength = torch.randn((batch_size, 1), requires_grad=True, dtype=dtype, device=device)
    unactivated_opacity = torch.randn((batch_size, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = get_alpha.apply(gaussian_strength, unactivated_opacity)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_strength = gaussian_strength.grad.clone()
    my_grads_unactivated_opacity = unactivated_opacity.grad.clone()

    # Autograd grads
    gaussian_strength.grad.zero_()
    unactivated_opacity.grad.zero_()
    output = autograd_test(gaussian_strength, unactivated_opacity)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_strength = gaussian_strength.grad.clone()
    autograd_grads_unactivated_opacity = unactivated_opacity.grad.clone()

    assert compare_grads(my_grads_gaussian_strength, autograd_grads_gaussian_strength, tolerance), "Mismatch in gradients for gaussian_strength (large batch size)!"
    assert compare_grads(my_grads_unactivated_opacity, autograd_grads_unactivated_opacity, tolerance), "Mismatch in gradients for unactivated_opacity (large batch size)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    gaussian_strength = torch.full((10, 1), 1e6, requires_grad=True, dtype=dtype, device=device)
    unactivated_opacity = torch.full((10, 1), -1e6, requires_grad=True, dtype=dtype, device=device)  # Test sigmoid near zero

    output = get_alpha.apply(gaussian_strength, unactivated_opacity)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(gaussian_strength.grad).all(), "Gradients for gaussian_strength are not finite with extreme values!"
    assert torch.isfinite(unactivated_opacity.grad).all(), "Gradients for unactivated_opacity are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()



Gradcheck passed for get_alpha.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.


In [6]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import (
    gaussian_exp_backward_launcher,
)

class gaussian_exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_weight: torch.Tensor):
        ctx.save_for_backward(gaussian_weight)
        return torch.exp(gaussian_weight)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        gaussian_weight, = ctx.saved_tensors
        grad_gaussian_weight = torch.zeros_like(gaussian_weight).to(gaussian_weight.device)
        gaussian_exp_backward_launcher(
            grad_output.contiguous(),
            gaussian_weight.contiguous(),
            grad_gaussian_weight.contiguous()
        )
        return grad_gaussian_weight

# Define a test function using torch's autograd
def autograd_test(gaussian_weight: torch.Tensor):
    return torch.exp(gaussian_weight)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    gaussian_weight = torch.tensor([[2.0]], requires_grad=True, dtype=torch.float32, device=device)

    assert gradcheck(gaussian_exp.apply, (gaussian_weight,), eps=eps), "Gradcheck failed for gaussian_exp!"
    print("Gradcheck passed for gaussian_exp.")

    # 2. Verify backward computation with autograd
    gaussian_weight = torch.randn((1, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_weight = gaussian_weight.grad.clone()

    # Autograd grads
    gaussian_weight.grad.zero_()
    output = autograd_test(gaussian_weight)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_weight = gaussian_weight.grad.clone()

    assert compare_grads(my_grads_gaussian_weight, autograd_grads_gaussian_weight, tolerance), "Mismatch in gradients for gaussian_weight!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    gaussian_weight = torch.zeros((1, 1), requires_grad=True, dtype=dtype, device=device)

    output = gaussian_exp.apply(gaussian_weight)
    expected_output = torch.ones_like(gaussian_weight)
    assert torch.allclose(output, expected_output, atol=tolerance, rtol=tolerance), "Output mismatch for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 10
    gaussian_weight = torch.randn((batch_size, 1), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_weight = gaussian_weight.grad.clone()

    # Autograd grads
    gaussian_weight.grad.zero_()
    output = autograd_test(gaussian_weight)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_weight = gaussian_weight.grad.clone()

    assert compare_grads(my_grads_gaussian_weight, autograd_grads_gaussian_weight, tolerance), "Mismatch in gradients for large batch size!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    gaussian_weight = torch.full((10, 1), 1e1, requires_grad=True, dtype=dtype, device=device)

    output = gaussian_exp.apply(gaussian_weight)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(gaussian_weight.grad).all(), "Gradients for gaussian_weight are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

Gradcheck passed for gaussian_exp.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.




In [3]:
import torch
from torch.autograd.gradcheck import gradcheck
import torch.nn as nn

from splat.custom_backwards_implementation.gaussian_weight_derivatives import (
    gaussian_weight_grad_inv_cov_launcher,
    gaussian_weight_grad_gaussian_mean_launcher,
)

class gaussian_weight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, gaussian_mean: torch.Tensor, inverted_covariance: torch.Tensor, pixel: torch.Tensor):
        """
        Pixel means are a nx2 tensor, inverted covariance is a 2x2 tensor, pixel is a nx2 tensor
        Outputs a nx1 tensor
        """
        ctx.save_for_backward(gaussian_mean, inverted_covariance, pixel)
        diff = (pixel - gaussian_mean).unsqueeze(1)

        inv_cov_mult = torch.einsum('bij,bjk->bik', inverted_covariance, diff.transpose(1, 2))
        return -0.5 * torch.einsum('bij,bjk->bik', diff, inv_cov_mult).squeeze(-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Output of forward is a nx1 tensor so the grad_output is a nx1 tensor"""
        gaussian_mean, inverted_covariance, pixel = ctx.saved_tensors
        diff = (pixel - gaussian_mean).unsqueeze(1)  # nx1x2

        grad_inv_cov = torch.zeros_like(inverted_covariance)
        grad_gaussian_mean = torch.zeros_like(gaussian_mean)
        gaussian_weight_grad_inv_cov_launcher(
            grad_output.contiguous(),
            diff.contiguous(),
            grad_inv_cov.contiguous()
        )

        gaussian_weight_grad_gaussian_mean_launcher(
            grad_output.contiguous(),
            diff.contiguous(),
            inverted_covariance.contiguous(),
            grad_gaussian_mean.contiguous()
        )
        return grad_gaussian_mean, grad_inv_cov, None
    

# Define a test function using PyTorch's autograd
def autograd_test(gaussian_mean, inverted_covariance, pixel):
    diff = (pixel - gaussian_mean).unsqueeze(1)
    # 2x2 * 2x1 = 2x1
    inv_cov_mult = torch.einsum('bij,bjk->bik', inverted_covariance, diff.transpose(1, 2))
    return -0.5 * torch.einsum('bij,bjk->bik', diff, inv_cov_mult).squeeze(-1)

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)


def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    gaussian_mean = torch.tensor([[1.0, 2.0]], requires_grad=True, dtype=torch.float32, device=device)
    inverted_covariance = torch.eye(2, requires_grad=True, dtype=torch.float32, device=device).unsqueeze(0)
    pixel = torch.tensor([[3.0, 4.0]], requires_grad=False, dtype=torch.float32, device=device)

    assert gradcheck(gaussian_weight.apply, (gaussian_mean, inverted_covariance, pixel), eps=eps), "Gradcheck failed for gaussian_weight!"
    print("Gradcheck passed for gaussian_weight.")

    # 2. Verify backward computation with autograd
    gaussian_mean = torch.randn((1, 2), requires_grad=True, dtype=dtype, device=device)
    inverted_covariance = nn.Parameter(torch.eye(2, dtype=dtype, device=device).unsqueeze(0))   

    pixel = torch.randn((1, 2), requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_mean = gaussian_mean.grad.clone()
    my_grads_inv_cov = inverted_covariance.grad.clone()

    # Autograd grads
    gaussian_mean.grad.zero_()
    inverted_covariance.grad.zero_()
    output = autograd_test(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_mean = gaussian_mean.grad.clone()
    autograd_grads_inv_cov = inverted_covariance.grad.clone()

    assert compare_grads(my_grads_gaussian_mean, autograd_grads_gaussian_mean, tolerance), "Mismatch in gradients for gaussian_mean!"
    assert compare_grads(my_grads_inv_cov, autograd_grads_inv_cov, tolerance), "Mismatch in gradients for inverted_covariance!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    gaussian_mean = torch.zeros((1, 2), requires_grad=True, dtype=dtype, device=device)
    inverted_covariance = torch.eye(2, requires_grad=True, dtype=dtype, device=device).unsqueeze(0).clone()
    pixel = torch.zeros((1, 2), requires_grad=False, dtype=dtype, device=device)

    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    assert torch.all(output == 0), "Output is not zero for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 10
    gaussian_mean = torch.randn((batch_size, 2), requires_grad=True, dtype=dtype, device=device)
    inverted_covariance = nn.Parameter(torch.stack([torch.eye(2, dtype=dtype, device=device) for _ in range(batch_size)], dim=0))
    pixel = torch.randn((batch_size, 2), requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    my_grads_gaussian_mean = gaussian_mean.grad.clone()
    my_grads_inv_cov = inverted_covariance.grad.clone()

    # Autograd grads
    gaussian_mean.grad.zero_()
    inverted_covariance.grad.zero_()
    output = autograd_test(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    autograd_grads_gaussian_mean = gaussian_mean.grad.clone()
    autograd_grads_inv_cov = inverted_covariance.grad.clone()

    assert compare_grads(my_grads_gaussian_mean, autograd_grads_gaussian_mean, tolerance), "Mismatch in gradients for gaussian_mean (large batch size)!"
    assert compare_grads(my_grads_inv_cov, autograd_grads_inv_cov, tolerance), "Mismatch in gradients for inverted_covariance (large batch size)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    gaussian_mean = torch.full((10, 2), 1e6, requires_grad=True, dtype=dtype, device=device)
    inverted_covariance = torch.stack([torch.eye(2, dtype=dtype, device=device) for _ in range(10)], dim=0).requires_grad_()
    pixel = torch.full((10, 2), -1e6, requires_grad=False, dtype=dtype, device=device)

    output = gaussian_weight.apply(gaussian_mean, inverted_covariance, pixel)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(gaussian_mean.grad).all(), "Gradients for gaussian_mean are not finite with extreme values!"
    assert torch.isfinite(inverted_covariance.grad).all(), "Gradients for inverted_covariance are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

Gradcheck passed for gaussian_weight.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.




In [5]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import mean_3d_to_camera_space_backward_launcher


class mean_3d_to_camera_space(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mean_3d: torch.Tensor, extrinsic_matrix: torch.Tensor):
        ctx.save_for_backward(mean_3d, extrinsic_matrix)
        return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        print("grad_output", grad_output.shape)
        mean_3d, extrinsic_matrix = ctx.saved_tensors
        grad_mean_3d = torch.zeros_like(mean_3d)
        mean_3d_to_camera_space_backward_launcher(
            grad_output.contiguous(), 
            extrinsic_matrix.contiguous(), 
            grad_mean_3d.contiguous()
        )
        return grad_mean_3d, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix):
    return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

mean_3d = torch.randn((4, 3), requires_grad=True, dtype=dtype, device=device)
extrinsic_matrix = torch.rand(3, 3, requires_grad=False, dtype=dtype, device=device)

output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
loss = output.sum()
loss.backward()
print("custom mean_3d grad", mean_3d.grad)

mean_3d.grad.zero_()
output = test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix)
loss = output.sum()
loss.backward()
print("autograd mean_3d grad", mean_3d.grad)

grad_output torch.Size([4, 3])
custom mean_3d grad tensor([[0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987]], device='cuda:0')
autograd mean_3d grad tensor([[0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987],
        [0.7980, 1.8702, 1.5987]], device='cuda:0')


In [7]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import mean_3d_to_camera_space_backward_launcher


class mean_3d_to_camera_space(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mean_3d: torch.Tensor, extrinsic_matrix: torch.Tensor):
        ctx.save_for_backward(mean_3d, extrinsic_matrix)
        return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        mean_3d, extrinsic_matrix = ctx.saved_tensors
        grad_mean_3d = torch.zeros_like(mean_3d)
        mean_3d_to_camera_space_backward_launcher(
            grad_output.contiguous(), 
            extrinsic_matrix.contiguous(), 
            grad_mean_3d.contiguous()
        )
        return grad_mean_3d, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix):
    return torch.einsum("nk, kh->nh", mean_3d, extrinsic_matrix)

def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    mean_3d = torch.randn((5, 3), requires_grad=True, dtype=torch.float32, device=device)
    extrinsic_matrix = torch.eye(3, requires_grad=False, dtype=torch.float32, device=device)

    assert gradcheck(mean_3d_to_camera_space.apply, (mean_3d, extrinsic_matrix), eps=eps), "Gradcheck failed for mean_3d_to_camera_space!"
    print("Gradcheck passed for mean_3d_to_camera_space.")

    # 2. Verify backward computation with autograd
    mean_3d = torch.randn((5, 3), requires_grad=True, dtype=dtype, device=device)
    extrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "Mismatch in gradients for mean_3d!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    mean_3d = torch.zeros((5, 3), requires_grad=True, dtype=dtype, device=device)
    extrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    assert torch.all(output == 0), "Output is not zero for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 100
    mean_3d = torch.randn((batch_size, 3), requires_grad=True, dtype=dtype, device=device)
    extrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_mean_3d_to_camera_space(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "Mismatch in gradients for mean_3d (large batch size)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    mean_3d = torch.full((5, 3), 1e6, requires_grad=True, dtype=dtype, device=device)
    extrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    output = mean_3d_to_camera_space.apply(mean_3d, extrinsic_matrix)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(mean_3d.grad).all(), "Gradients for mean_3d are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

Gradcheck passed for mean_3d_to_camera_space.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.


In [2]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import camera_space_to_pixel_space_backward_launcher


class camera_space_to_pixel_space(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mean_3d: torch.Tensor, intrinsic_matrix: torch.Tensor):
        ctx.save_for_backward(intrinsic_matrix, mean_3d)
        return torch.einsum("nk, kh->nh", mean_3d, intrinsic_matrix)
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):   
        intrinsic_matrix, mean_3d = ctx.saved_tensors
        mean_3d_grad = torch.zeros_like(mean_3d)
        camera_space_to_pixel_space_backward_launcher(
            grad_output.contiguous(), 
            intrinsic_matrix.contiguous(), 
            mean_3d_grad.contiguous()
        )
        return mean_3d_grad, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_camera_space_to_pixel_space(mean_3d, intrinsic_matrix):
    return torch.einsum("nk, kh->nh", mean_3d, intrinsic_matrix)


def run_tests():
    tolerance = 1e-4
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    mean_3d = torch.randn((5, 3), requires_grad=True, dtype=torch.float32, device=device)
    intrinsic_matrix = torch.eye(3, requires_grad=False, dtype=torch.float32, device=device)

    assert gradcheck(camera_space_to_pixel_space.apply, (mean_3d, intrinsic_matrix), eps=eps), "Gradcheck failed for camera_space_to_pixel_space!"
    print("Gradcheck passed for camera_space_to_pixel_space.")

    # 2. Verify backward computation with autograd
    mean_3d = torch.randn((5, 3), requires_grad=True, dtype=dtype, device=device)
    intrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_camera_space_to_pixel_space(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "Mismatch in gradients for mean_3d!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    mean_3d = torch.zeros((5, 3), requires_grad=True, dtype=dtype, device=device)
    intrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    assert torch.all(output == 0), "Output is not zero for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 100
    mean_3d = torch.randn((batch_size, 3), requires_grad=True, dtype=dtype, device=device)
    intrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    # My grads
    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    my_grads_mean_3d = mean_3d.grad.clone()

    # Autograd grads
    mean_3d.grad.zero_()
    output = test_camera_space_to_pixel_space(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    autograd_grads_mean_3d = mean_3d.grad.clone()

    assert compare_grads(my_grads_mean_3d, autograd_grads_mean_3d, tolerance), "Mismatch in gradients for mean_3d (large batch size)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    mean_3d = torch.full((5, 3), 1e6, requires_grad=True, dtype=dtype, device=device)
    intrinsic_matrix = torch.eye(3, requires_grad=False, dtype=dtype, device=device)

    output = camera_space_to_pixel_space.apply(mean_3d, intrinsic_matrix)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(mean_3d.grad).all(), "Gradients for mean_3d are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

Gradcheck passed for camera_space_to_pixel_space.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.


In [1]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.custom_backwards_implementation.gaussian_weight_derivatives import ndc_to_pixels_backward_launcher
import os

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

class ndc_to_pixels(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ndc: torch.Tensor, dimension: list):
        """ndc is a nx3 tensor where the last dimension is the z component
        
        dimension are the height and width of the image
        """
        ctx.save_for_backward(torch.tensor(dimension, dtype=torch.float32), ndc)
        ndc = ndc.clone()  # To avoid modifying input in-place
        ndc[:, 0] = (ndc[:, 0] + 1) * (dimension[1] - 1) * 0.5
        ndc[:, 1] = (ndc[:, 1] + 1) * (dimension[0] - 1) * 0.5
        return ndc
    
    @staticmethod
    def backward(ctx, grad_output):
        dimension, ndc = ctx.saved_tensors
        grad_ndc = torch.zeros_like(ndc)
        ndc_to_pixels_backward_launcher(
            grad_output.contiguous(), 
            dimension.contiguous(), 
            grad_ndc.contiguous()
        )
        return grad_ndc, None

# Helper function to compare gradients
def compare_grads(my_grads, autograd_grads, tolerance=1e-6):
    return torch.allclose(my_grads, autograd_grads, atol=tolerance, rtol=tolerance)

# Define a test function using PyTorch's autograd
def test_ndc_to_pixels(ndc: torch.Tensor, dimension: list):
    ndc = ndc.clone()
    ndc[:, 0] = (ndc[:, 0] + 1) * (dimension[1] - 1) * 0.5
    ndc[:, 1] = (ndc[:, 1] + 1) * (dimension[0] - 1) * 0.5
    return ndc

def run_tests():
    tolerance = 1e-2
    eps = 1e-3
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # 1. Gradcheck for numerical gradient correctness
    ndc = torch.randn((5, 3), requires_grad=True, dtype=torch.float32, device=device)
    dimension = torch.tensor([480, 640], dtype=torch.float32, device=device)

    assert gradcheck(ndc_to_pixels.apply, (ndc, dimension), eps=eps), "Gradcheck failed for ndc_to_pixels!"
    print("Gradcheck passed for ndc_to_pixels.")

    # 2. Verify backward computation with autograd
    ndc = torch.randn((5, 3), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = ndc_to_pixels.apply(ndc, dimension)
    loss = output.sum()
    loss.backward()
    my_grads_ndc = ndc.grad.clone()

    # Autograd grads
    ndc.grad.zero_()
    output = test_ndc_to_pixels(ndc, dimension)
    loss = output.sum()
    loss.backward()
    autograd_grads_ndc = ndc.grad.clone()

    assert compare_grads(my_grads_ndc, autograd_grads_ndc, tolerance), "Mismatch in gradients for ndc!"
    print("Gradients match for autograd.")

    # 3. Test edge cases: zero tensors
    ndc = torch.zeros((5, 3), requires_grad=True, dtype=dtype, device=device)

    output = ndc_to_pixels.apply(ndc, dimension)
    expected_output = torch.zeros_like(ndc)
    expected_output[:, 0] = 0.5 * (dimension[1] - 1)
    expected_output[:, 1] = 0.5 * (dimension[0] - 1)
    assert torch.allclose(output, expected_output, atol=tolerance, rtol=tolerance), "Output mismatch for zero input tensors."
    print("Edge case test passed for zero tensors.")

    # 4. Test with large batch sizes
    batch_size = 100
    ndc = torch.randn((batch_size, 3), requires_grad=True, dtype=dtype, device=device)

    # My grads
    output = ndc_to_pixels.apply(ndc, dimension)
    loss = output.sum()
    loss.backward()
    my_grads_ndc = ndc.grad.clone()

    # Autograd grads
    ndc.grad.zero_()
    output = test_ndc_to_pixels(ndc, dimension)
    loss = output.sum()
    loss.backward()
    autograd_grads_ndc = ndc.grad.clone()

    assert compare_grads(my_grads_ndc, autograd_grads_ndc, tolerance), "Mismatch in gradients for ndc (large batch size)!"
    print("Gradients match for large batch size.")

    # 5. Test with extreme values
    ndc = torch.full((10, 3), 1e6, requires_grad=True, dtype=dtype, device=device)

    output = ndc_to_pixels.apply(ndc, dimension)
    loss = output.sum()
    loss.backward()
    assert torch.isfinite(ndc.grad).all(), "Gradients for ndc are not finite with extreme values!"
    print("Extreme value test passed.")

# Run the tests
run_tests()

  from .autonotebook import tqdm as notebook_tqdm


Gradcheck passed for ndc_to_pixels.
Gradients match for autograd.
Edge case test passed for zero tensors.
Gradients match for large batch size.
Extreme value test passed.


  ctx.save_for_backward(torch.tensor(dimension, dtype=torch.float32), ndc)
