In [None]:
import torch
from torch.autograd.gradcheck import gradcheck

from splat.utils import build_rotation
from splat.custom_backwards_implementation.gaussian_weight_derivatives import *


class final_color(torch.autograd.Function):
    """Custom autograd function computing ``color * current_T * alpha``.

    The backward pass is written by hand instead of being traced by
    autograd, and returns a gradient for every input that requires one.
    """

    @staticmethod
    def forward(ctx, color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
        """Return the elementwise product ``color * current_T * alpha``.

        Intended shapes: color is a nx3 tensor, current_T is a nx1 tensor,
        alpha is a nx1 tensor, so the result is a nx3 tensor.
        """
        ctx.save_for_backward(color, current_T, alpha)
        return color * current_T * alpha

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return the gradients w.r.t. ``(color, current_T, alpha)``.

        Output of forward is a nx3 tensor so the grad_output is a nx3 tensor.
        An input that does not require grad gets ``None``.
        """
        color, current_T, alpha = ctx.saved_tensors
        needs_color, needs_current_T, needs_alpha = ctx.needs_input_grad

        grad_color = grad_current_T = grad_alpha = None
        if needs_color:
            grad_color = grad_output * current_T * alpha
        if needs_current_T:
            # current_T is broadcast over the 3 channels in forward, so its
            # gradient is the channel-wise sum (same reduction as alpha).
            grad_current_T = (grad_output * color * alpha).sum(dim=1, keepdim=True)
        if needs_alpha:
            grad_alpha = (grad_output * color * current_T).sum(dim=1, keepdim=True)
        return grad_color, grad_current_T, grad_alpha
    
def autograd_test(color: torch.Tensor, current_T: torch.Tensor, alpha: torch.Tensor):
    """Reference product ``color * current_T * alpha`` using plain tensor ops.

    No custom backward is involved, so autograd derives the gradients itself.
    """
    # Multiply left to right, matching the single-expression form.
    transmitted = color * current_T
    return transmitted * alpha

# Step 1: numerically validate the hand-written backward with gradcheck.
color = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True, dtype=torch.float64)
current_T = torch.tensor([[4.0]], dtype=torch.float64)
alpha = torch.tensor([[5.0]], requires_grad=True, dtype=torch.float64)

gradcheck_ok = gradcheck(final_color.apply, (color, current_T, alpha))
print("gradcheck passing: ", gradcheck_ok)

# Step 2: print the gradients from the custom function next to the ones
# autograd derives for the plain expression, first for a single row and then
# for two rows. Each entry is (color, current_T, alpha) as nested lists.
input_cases = (
    ([[1.0, 2.0, 3.0]], [[4.0]], [[5.0]]),
    ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[4.0], [5.0]], [[5.0], [6.0]]),
)
implementations = (
    ("My grads: ", final_color.apply),
    ("Autograd grads: ", autograd_test),
)

for color_rows, current_T_rows, alpha_rows in input_cases:
    for label, implementation in implementations:
        # Fresh leaf tensors every run so .grad never accumulates across runs.
        color = torch.tensor(color_rows, requires_grad=True, dtype=torch.float64)
        current_T = torch.tensor(current_T_rows, dtype=torch.float64)
        alpha = torch.tensor(alpha_rows, requires_grad=True, dtype=torch.float64)

        output = implementation(color, current_T, alpha)
        loss = output.sum()
        loss.backward()
        # current_T is built without requires_grad, so its .grad prints as None.
        print(label, color.grad, current_T.grad, alpha.grad)