In [1]:

import diff_gaussian_rasterization as dgr
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from diff_gaussian_rasterization import _C as torch_backend
from jax_renderer import _build_rasterize_gaussians_fwd_primitive, _build_rasterize_gaussians_bwd_primitive

import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as R
import torch
import functools
import matplotlib.pyplot as plt
import math
import numpy as np
from random import randint
from tqdm import tqdm
from time import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

def getProjectionMatrix(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

from typing import NamedTuple
class Intrinsics(NamedTuple):
    height: int
    width: int
    fx: float
    fy: float
    cx: float
    cy: float
    near: float
    far: float

def torch_to_jax(torch_array):
    return jnp.array(torch_array.detach().cpu().numpy())

def jax_to_torch(jnp_array):
    return torch.tensor(np.array(jnp_array), requires_grad=True, device=device)

default_seed = 1222
gt_seed = 1201223

#############################
# Arguments
#############################
torch.manual_seed(gt_seed)
import random
random.seed(gt_seed)
np.random.seed(gt_seed)

intrinsics = Intrinsics(
    height=200,
    width=200,
    fx=300.0, fy=300.0,
    cx=100.0, cy=100.0,
    near=0.01, far=2.5
)

fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0
fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2.0
tan_fovx = math.tan(fovX)
tan_fovy = math.tan(fovY)
print(tan_fovx, tan_fovy)


rasterize_gaussians_fwd
rasterize_gaussians_bwd
0.7499999908297853 0.7499999908297853


In [19]:
means3D = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(100, 3), minval=-0.5, maxval=0.5) + jnp.array([0.0, 0.0, 1.0])
N = means3D.shape[0]
opacity = jnp.ones(shape=(N,1))
scales =jnp.ones((N,3)) * 4.5400e-03
rotations = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,4), minval=-1.0, maxval=1.0)
colors_precomp = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=1.0)
cov3D_precomp = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=0.1)
sh = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=1.0)

camera_pose_jax = jnp.eye(4)
proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda()
view_matrix = torch.transpose(torch.tensor(np.array(jnp.linalg.inv(camera_pose_jax))),0,1).cuda()
projmatrix = view_matrix @ proj_matrix

view_matrix = jnp.array(view_matrix.cpu().numpy())
projmatrix = jnp.array(projmatrix.cpu().numpy())

rasterizer_fwd_jax = _build_rasterize_gaussians_fwd_primitive()
rasterizer_bwd_jax = _build_rasterize_gaussians_bwd_primitive()

jax_args = (means3D,
            colors_precomp,
            opacity,
            scales,
            rotations,
            cov3D_precomp,
            view_matrix,
            projmatrix,
            sh)

num_rendered_jax, color_jax, radii_jax, geomBuffer_jax, binningBuffer_jax, imgBuffer_jax = rasterizer_fwd_jax.bind(
            jnp.zeros(3), # bg
            *jax_args,
            jnp.zeros(3), # campos
            tanfovx=tan_fovx, 
            tanfovy=tan_fovy, 
            image_height=int(intrinsics.height), 
            image_width=int(intrinsics.width),  
            sh_degree=0
)

# geomBuffer_jax, binningBuffer_jax, imgBuffer_jax = [jnp.array(x) for x in [geomBuffer_jax, binningBuffer_jax, imgBuffer_jax]]

# bwd
grad_out_color_jax = jnp.array(color_jax)
jax_bwd_args = (
    jnp.zeros(3),
    means3D, #1
    radii_jax, #2 
    colors_precomp, #3 
    scales, #4
    rotations, #5 
    # raster_settings.scale_modifier), 
    cov3D_precomp, #6 
    view_matrix, #7 
    projmatrix, #8
    grad_out_color_jax, #9
    sh, #10
    jnp.zeros(3), #11
    geomBuffer_jax, #12
    jnp.array([[1]]),#num_rendered_jax, #13 
    binningBuffer_jax, #14
    imgBuffer_jax #15
)

(grad_means2D_jax,
 grad_colors_precomp_jax,
 grad_opacities_jax,
 grad_means3D_jax,
 grad_cov3Ds_precomp_jax,
 grad_sh_jax,
 grad_scales_jax, grad_rotations_jax, _) = rasterizer_bwd_jax.bind(
            *jax_bwd_args,
            tanfovx=tan_fovx, 
            tanfovy=tan_fovy, 
            sh_degree=0
)  


raster_settings = GaussianRasterizationSettings(
    image_height=int(intrinsics.height),
    image_width=int(intrinsics.width),
    tanfovx=tan_fovx,
    tanfovy=tan_fovy,
    bg=torch.tensor([0.0, 0.0, 0.0]).cuda(),
    scale_modifier=1.0,
    viewmatrix=jax_to_torch(view_matrix),
    projmatrix=jax_to_torch(projmatrix),
    sh_degree=0,
    campos=torch.zeros(3).cuda(),
    prefiltered=False,
    debug=None
)
rasterizer_fwd_torch = GaussianRasterizer(raster_settings=raster_settings)

torch_args = (
    raster_settings.bg, 
    jax_to_torch(means3D),
    jax_to_torch(colors_precomp),
    jax_to_torch(opacity),
    jax_to_torch(scales),
    jax_to_torch(rotations),
    raster_settings.scale_modifier,
    jax_to_torch(cov3D_precomp), # (None -> torch.Tensor([])),
    raster_settings.viewmatrix,
    raster_settings.projmatrix,
    raster_settings.tanfovx,
    raster_settings.tanfovy,
    raster_settings.image_height,
    raster_settings.image_width,
    jax_to_torch(sh), # (None -> torch.Tensor([])),
    raster_settings.sh_degree,
    raster_settings.campos,
    raster_settings.prefiltered,
    raster_settings.debug
)

num_rendered_torch, color_torch, radii_torch, geomBuffer_torch, binningBuffer_torch, imgBuffer_torch = torch_backend.rasterize_gaussians(*torch_args)
color_torch = color_torch.detach()
color_torch_jax = jnp.array(color_torch.cpu().detach().numpy())

assert num_rendered_torch == int(num_rendered_jax[0])
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax)
assert jnp.allclose(torch_to_jax(color_torch), color_jax)

grad_out_color_torch= torch.Tensor(np.array(color_torch_jax)).cuda()

assert jnp.allclose(torch_to_jax(grad_out_color_torch), grad_out_color_jax)


torch_bwd_args = (
    raster_settings.bg,
    jax_to_torch(means3D), 
    radii_torch, 
    jax_to_torch(colors_precomp), 
    jax_to_torch(scales), 
    jax_to_torch(rotations), 
    raster_settings.scale_modifier, 
    torch.Tensor([]), 
    raster_settings.viewmatrix, 
    raster_settings.projmatrix, 
    raster_settings.tanfovx, 
    raster_settings.tanfovy, 
    grad_out_color_torch, 
    torch.Tensor([]), 
    raster_settings.sh_degree, 
    raster_settings.campos,
    geomBuffer_torch,
    num_rendered_torch,
    binningBuffer_torch,
    imgBuffer_torch,
    raster_settings.debug
)

(grad_means2D_torch,
 grad_colors_precomp_torch,
 grad_opacities_torch,
 grad_means3D_torch,
 grad_cov3Ds_precomp_torch,
 grad_sh_torch,
 grad_scales_torch, grad_rotations_torch) = torch_backend.rasterize_gaussians_backward(*torch_bwd_args)



print(grad_means2D_jax.sum())
print(grad_colors_precomp_jax.sum())
print(grad_opacities_jax.sum())
print(grad_means3D_jax.sum())


print(grad_means2D_torch.sum())
print(grad_colors_precomp_torch.sum())
print(grad_opacities_torch.sum())
print(grad_means3D_torch.sum())



48.12236
224.77188
140.82617
81.21434
tensor(48.1224, device='cuda:0')
tensor(224.7719, device='cuda:0')
tensor(140.8262, device='cuda:0')
tensor(81.2143, device='cuda:0')


In [20]:
assert num_rendered_torch == int(num_rendered_jax[0])
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax)
assert jnp.allclose(torch_to_jax(color_torch), color_jax)
assert jnp.allclose(torch_to_jax(grad_out_color_torch), grad_out_color_jax)

assert jnp.allclose(torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax, atol=0.0, rtol=1e-3)
assert jnp.allclose(torch_to_jax(grad_opacities_torch), grad_opacities_jax, atol=0.0, rtol=1e-3)
assert jnp.allclose(torch_to_jax(grad_scales_torch), grad_scales_jax, atol=0.0, rtol=1e-3)
assert jnp.allclose(torch_to_jax(grad_rotations_torch), grad_rotations_jax, atol=0.0, rtol=1e-3)
assert jnp.allclose(torch_to_jax(grad_means3D_torch), grad_means3D_jax, atol=0.0, rtol=1e-3)


In [21]:
a,b = torch_to_jax(grad_rotations_torch), grad_rotations_jax
a,b = torch_to_jax(grad_scales_torch), grad_scales_jax
a,b = (torch_to_jax(grad_opacities_torch), grad_opacities_jax)
a,b = (torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax)
a,b = (torch_to_jax(grad_means3D_torch), grad_means3D_jax)
error = jnp.abs((a) - b).reshape(-1)
idx = error.argmax()
print(error[idx])
print(a.reshape(-1)[idx])
print(b.reshape(-1)[idx])

3.3140182e-05
1.2782489
1.278282


In [12]:
grad_means3D_jax[:10]

Array([[  0.        ,   0.        ,   0.        ],
       [  0.8998901 ,  -1.632731  ,  -2.737176  ],
       [  0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ],
       [  0.57558167,  -0.37075454,  -2.0976763 ],
       [  0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ],
       [  9.204973  ,  -7.2818274 ,  -0.58931863],
       [-30.837902  , -22.740711  ,   6.0373063 ],
       [ -0.4843986 ,  -1.5579131 ,  -2.1157906 ]], dtype=float32)