In [1]:

import diff_gaussian_rasterization as dgr
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from diff_gaussian_rasterization import _C as torch_backend
from jax_renderer_primitives import _build_rasterize_gaussians_fwd_primitive, _build_rasterize_gaussians_bwd_primitive

import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as R
import torch
import functools
import matplotlib.pyplot as plt
import math
import numpy as np
import random
from random import randint
from tqdm import tqdm
from time import time
 
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Request deterministic cuDNN algorithms so repeated runs are reproducible.
torch.backends.cudnn.deterministic = True


####################
# Helpers, Constants
####################

def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix for a symmetric view frustum.

    Args:
        znear: distance to the near plane.
        zfar: distance to the far plane.
        fovX: full horizontal field of view, in radians.
        fovY: full vertical field of view, in radians.

    Returns:
        A CPU ``torch`` tensor of shape (4, 4) in the default torch dtype. Row 3 copies
        view-space z into w (``P[3, 2] == 1``), and z is remapped via rows 2 of the
        matrix using ``znear``/``zfar``.
    """
    # Half extents of the near-plane rectangle.
    half_height = znear * math.tan(fovY / 2)
    half_width = znear * math.tan(fovX / 2)
    top, bottom = half_height, -half_height
    right, left = half_width, -half_width

    z_sign = 1.0
    span_x = right - left
    span_y = top - bottom
    depth_span = zfar - znear

    proj = torch.zeros(4, 4)
    proj[0, 0] = 2.0 * znear / span_x
    proj[0, 2] = (right + left) / span_x  # 0 for a symmetric frustum
    proj[1, 1] = 2.0 * znear / span_y
    proj[1, 2] = (top + bottom) / span_y  # 0 for a symmetric frustum
    proj[2, 2] = z_sign * zfar / depth_span
    proj[2, 3] = -(zfar * znear) / depth_span
    proj[3, 2] = z_sign
    return proj

def reset(seed):
    """Seed the torch (CPU and current CUDA device), stdlib ``random`` and NumPy RNGs.

    JAX is deliberately not touched: its randomness is driven by explicit PRNG keys.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed)
    for seed_fn in seeders:
        seed_fn(seed)

from typing import NamedTuple
class Intrinsics(NamedTuple):
    """Pinhole camera intrinsics: image size, focal lengths, principal point, clip range."""

    height: int  # image height in pixels (forwarded as image_height)
    width: int   # image width in pixels (forwarded as image_width)
    fx: float    # horizontal focal length, in pixels
    fy: float    # vertical focal length, in pixels
    cx: float    # presumably principal point x, in pixels — TODO confirm
    cy: float    # presumably principal point y, in pixels — TODO confirm
    near: float  # near clipping distance (presumably scene units)
    far: float   # far clipping distance (presumably scene units)

def torch_to_jax(torch_array):
    """Copy a torch tensor (any device, grad-tracking or not) into a new JAX array."""
    # Detach from autograd and bring the data to host memory before handing it to JAX.
    host_copy = torch_array.detach().to("cpu").numpy()
    return jnp.array(host_copy)

def jax_to_torch(jnp_array, requires_grad=True, target_device=None):
    """Copy a JAX array into a new torch tensor.

    Args:
        jnp_array: JAX array (anything ``np.array`` accepts) to copy.
        requires_grad: whether the result tracks gradients. Defaults to True (the
            original behaviour); pass False for integer/bool arrays, which torch
            refuses to mark as requiring grad.
        target_device: torch device (or device string) for the result. When None,
            the notebook-level ``device`` is used.

    Returns:
        A freshly allocated torch tensor with the same shape and dtype as the input.
    """
    if target_device is None:
        target_device = device
    return torch.tensor(np.array(jnp_array), requires_grad=requires_grad, device=target_device)

# RNG seeds: `default_seed` drives the scene generated in the next section;
# `gt_seed` is not referenced in the cells shown here.
default_seed = 0
gt_seed = 1

# JAX primitives for the rasterizer forward and backward passes, built by the
# project-local jax_renderer_primitives module.
rasterizer_fwd_jax = _build_rasterize_gaussians_fwd_primitive()
rasterizer_bwd_jax = _build_rasterize_gaussians_bwd_primitive()

#############################
# Arguments
#############################
reset(default_seed)

# NOTE(review): cx is not the image centre (width / 2 == 150) although the projection
# below assumes a centred principal point; cx/cy are not used in this cell — confirm intent.
intrinsics = Intrinsics(
    height=200,
    width=300,
    fx=300.0, fy=300.0,
    cx=100.0, cy=100.0,
    near=0.01, far=2.5
)

# Full horizontal / vertical field of view (radians) implied by the focal lengths.
fovX = 2.0 * math.atan(intrinsics.width / (2.0 * intrinsics.fx))
fovY = 2.0 * math.atan(intrinsics.height / (2.0 * intrinsics.fy))
# The rasterizer expects the tangent of the HALF field of view: diff_gaussian_rasterization
# recovers the focal length as width / (2 * tanfovx). Here tan(fovX / 2) == width / (2 * fx).
tan_fovx = math.tan(fovX * 0.5)
tan_fovy = math.tan(fovY * 0.5)

# Draw each input from its own subkey. Reusing a single PRNGKey for every call would make
# same-shaped draws share identical underlying uniforms (e.g. colors_precomp and
# cov3D_precomp would be exact multiples of each other).
key_means, key_rot, key_colors, key_cov = jax.random.split(jax.random.PRNGKey(default_seed), 4)
means3D = jax.random.uniform(key_means, shape=(100, 3), minval=-0.5, maxval=0.5) + jnp.array([0.0, 0.0, 1.0])
N = means3D.shape[0]
opacity = jnp.ones(shape=(N, 1))
scales = jnp.ones((N, 3)) * 4.5400e-03
rotations = jax.random.uniform(key_rot, shape=(N, 4), minval=-1.0, maxval=1.0)
colors_precomp = jax.random.uniform(key_colors, shape=(N, 3), minval=0.0, maxval=1.0)
# NOTE(review): the reference CUDA rasterizer reads precomputed covariances as 6 floats per
# Gaussian and prefers them over scales/rotations when non-empty; an (N, 3) buffer looks too
# small — confirm the shape the JAX primitive expects.
cov3D_precomp = jax.random.uniform(key_cov, shape=(N, 3), minval=0.0, maxval=0.1)
sh_jax = jnp.array([])  # no SH coefficients; colours come from colors_precomp

# Camera at the world origin. Both matrices are stored transposed (row-vector convention),
# so the full projection is view^T @ proj^T. Clip planes come from `intrinsics`.
camera_pose_jax = jnp.eye(4)
proj_matrix = getProjectionMatrix(intrinsics.near, intrinsics.far, fovX, fovY).transpose(0, 1).to(device)
view_matrix_torch = torch.tensor(np.array(jnp.linalg.inv(camera_pose_jax))).transpose(0, 1).to(device)
projmatrix_torch = view_matrix_torch @ proj_matrix

# The JAX primitive consumes JAX arrays.
view_matrix = torch_to_jax(view_matrix_torch)
projmatrix = torch_to_jax(projmatrix_torch)


In [None]:

# Forward-pass parity: render the same scene with the JAX primitive and with the
# reference PyTorch extension (`torch_backend`), then compare the outputs.
sh_degree = 0
bg_jax = jnp.zeros(3)      # black background
campos_jax = jnp.zeros(3)  # camera centre (camera_pose_jax is the identity)

jax_fwd_args = (means3D,
            colors_precomp,
            opacity,
            scales,
            rotations,
            cov3D_precomp,
            view_matrix,
            projmatrix,
            sh_jax)

## Run JAX
num_rendered_jax, color_jax, radii_jax, geomBuffer_jax, binningBuffer_jax, imgBuffer_jax = rasterizer_fwd_jax.bind(
            bg_jax,
            *jax_fwd_args,
            campos_jax,
            tanfovx=tan_fovx,
            tanfovy=tan_fovy,
            image_height=int(intrinsics.height),
            image_width=int(intrinsics.width),
            sh_degree=sh_degree
)

## Run PyTorch reference
def _to_torch_input(x):
    """Copy a JAX array into a non-differentiable torch tensor on `device`."""
    return torch.from_numpy(np.array(x)).to(device)

# Positional order of diff_gaussian_rasterization's `_C.rasterize_gaussians`:
# bg, means3D, colors_precomp, opacities, scales, rotations, scale_modifier, cov3D_precomp,
# viewmatrix, projmatrix, tanfovx, tanfovy, image_height, image_width, sh, sh_degree,
# campos, prefiltered, debug.
# NOTE(review): the JAX primitive takes no scale_modifier/prefiltered/debug; this assumes it
# uses 1.0/False/False internally — confirm in jax_renderer_primitives.
torch_fwd_args = (
    _to_torch_input(bg_jax),
    _to_torch_input(means3D),
    _to_torch_input(colors_precomp),
    _to_torch_input(opacity),
    _to_torch_input(scales),
    _to_torch_input(rotations),
    1.0,  # scale_modifier
    _to_torch_input(cov3D_precomp),
    _to_torch_input(view_matrix),
    _to_torch_input(projmatrix),
    tan_fovx,
    tan_fovy,
    int(intrinsics.height),
    int(intrinsics.width),
    _to_torch_input(sh_jax),
    sh_degree,
    _to_torch_input(campos_jax),
    False,  # prefiltered
    False,  # debug
)

num_rendered_torch, color_torch, radii_torch, geomBuffer_torch, binningBuffer_torch, imgBuffer_torch = torch_backend.rasterize_gaussians(*torch_fwd_args)
color_torch = color_torch.detach()
color_torch_jax = torch_to_jax(color_torch)


## Compare
assert num_rendered_torch == int(num_rendered_jax[0]), "num_rendered mismatch"
assert jnp.allclose(color_torch_jax, color_jax), "color mismatch"
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax), "radii mismatch"
print("FORWARD PASSED: NUM_RENDER, COLOR, RADII")