In [6]:

import diff_gaussian_rasterization as dgr
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from diff_gaussian_rasterization import _C as torch_backend
from jax_renderer import _build_rasterize_gaussians_fwd_primitive, _build_rasterize_gaussians_bwd_primitive

import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as R
import torch
import functools
import matplotlib.pyplot as plt
import math
import numpy as np
from random import randint
from tqdm import tqdm
from time import time
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

def getProjectionMatrix(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

from typing import NamedTuple
class Intrinsics(NamedTuple):
    height: int
    width: int
    fx: float
    fy: float
    cx: float
    cy: float
    near: float
    far: float

def torch_to_jax(torch_array):
    return jnp.array(torch_array.detach().cpu().numpy())

def jax_to_torch(jnp_array):
    return torch.tensor(np.array(jnp_array), requires_grad=True, device=device)

default_seed = 1222
gt_seed = 1201223

#############################
# Arguments
#############################
torch.manual_seed(gt_seed)
import random
random.seed(gt_seed)
np.random.seed(gt_seed)

intrinsics = Intrinsics(
    height=200,
    width=200,
    fx=300.0, fy=300.0,
    cx=100.0, cy=100.0,
    near=0.01, far=2.5
)

fovX = jnp.arctan(intrinsics.width / 2 / intrinsics.fx) * 2.0
fovY = jnp.arctan(intrinsics.height / 2 / intrinsics.fy) * 2.0
tan_fovx = math.tan(fovX)
tan_fovy = math.tan(fovY)
print(tan_fovx, tan_fovy)


0.7499999908297853 0.7499999908297853


In [7]:
means3D = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(100, 3), minval=-0.5, maxval=0.5) + jnp.array([0.0, 0.0, 1.0])
N = means3D.shape[0]
opacity = jnp.ones(shape=(N,1))
scales =jnp.ones((N,3)) * 4.5400e-03
rotations = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,4), minval=-1.0, maxval=1.0)
colors_precomp = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=1.0)
cov3D_precomp = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=0.1)
sh = jax.random.uniform(jax.random.PRNGKey(default_seed), shape=(N,3), minval=0.0, maxval=1.0)

camera_pose_jax = jnp.eye(4)
proj_matrix = getProjectionMatrix(0.01, 100.0, fovX, fovY).transpose(0,1).cuda()
view_matrix = torch.transpose(torch.tensor(np.array(jnp.linalg.inv(camera_pose_jax))),0,1).cuda()
projmatrix = view_matrix @ proj_matrix

view_matrix = jnp.array(view_matrix.cpu().numpy())
projmatrix = jnp.array(projmatrix.cpu().numpy())

rasterizer_fwd_jax = _build_rasterize_gaussians_fwd_primitive()
rasterizer_bwd_jax = _build_rasterize_gaussians_bwd_primitive()

jax_args = (means3D,
            colors_precomp,
            opacity,
            scales,
            rotations,
            cov3D_precomp,
            view_matrix,
            projmatrix,
            sh)

num_rendered_jax, color_jax, radii_jax, geomBuffer_jax, binningBuffer_jax, imgBuffer_jax = rasterizer_fwd_jax.bind(
            jnp.zeros(3), # bg
            *jax_args,
            jnp.zeros(3), # campos
            tanfovx=tan_fovx, 
            tanfovy=tan_fovy, 
            image_height=int(intrinsics.height), 
            image_width=int(intrinsics.width),  
            sh_degree=0
)

geomBuffer_jax, binningBuffer_jax, imgBuffer_jax = [jnp.array(x) for x in [geomBuffer_jax, binningBuffer_jax, imgBuffer_jax]]

# bwd
grad_out_color_jax = jnp.array(color_jax)
jax_bwd_args = (
    jnp.zeros(3),
    means3D, #1
    radii_jax, #2 
    colors_precomp, #3 
    scales, #4
    rotations, #5 
    # raster_settings.scale_modifier), 
    cov3D_precomp, #6 
    view_matrix, #7 
    projmatrix, #8
    grad_out_color_jax, #9
    sh, #10
    jnp.zeros(3), #11
    geomBuffer_jax, #12
    jnp.array([[1]]),#num_rendered_jax, #13 
    binningBuffer_jax, #14
    imgBuffer_jax #15
)

(grad_means2D_jax,
 grad_colors_precomp_jax,
 grad_opacities_jax,
 grad_means3D_jax,
 grad_cov3Ds_precomp_jax,
 grad_sh_jax,
 grad_scales_jax, grad_rotations_jax, _) = rasterizer_bwd_jax.bind(
            *jax_bwd_args,
            tanfovx=tan_fovx, 
            tanfovy=tan_fovy, 
            sh_degree=0
)  
print(grad_means2D_jax.sum())
print(grad_colors_precomp_jax.sum())
print(grad_opacities_jax.sum())
print(grad_means3D_jax.sum())

heelo
heelo 2
heelo 4
JAX dimensions  (200, 200)
heelo 6
heelo 7
start
start2
start3
start4
Num gaussians 100
Num expanded gaussians 1
rasterize_impl.cu buffers[0] = 0
rasterize_impl.cu buffers[1] = 0
rasterize_impl.cu buffers[2] = 128
rasterize_impl.cu buffers[3] = 63
rasterize_impl.cu buffers[4] = 0
rasterize_impl.cu buffers[5] = 0
rasterize_impl.cu buffers[6] = 128
rasterize_impl.cu buffers[7] = 63
rasterize_impl.cu buffers[8] = 0
rasterize_impl.cu buffers[9] = 0
rasterize_impl.cu buffers[10] = 128
rasterize_impl.cu buffers[11] = 63
rasterize_impl.cu buffers[12] = 0
rasterize_impl.cu buffers[13] = 0
rasterize_impl.cu buffers[14] = 128
rasterize_impl.cu buffers[15] = 63
rasterize_impl.cu buffers[16] = 0
rasterize_impl.cu buffers[17] = 0
rasterize_impl.cu buffers[18] = 128
rasterize_impl.cu buffers[19] = 63
rasterize_impl.cu buffers[20] = 0
rasterize_impl.cu buffers[21] = 0
rasterize_impl.cu buffers[22] = 128
rasterize_impl.cu buffers[23] = 63
rasterize_impl.cu buffers[24] = 0
rasteri

In [8]:
print("==========TORCH==========")
raster_settings = GaussianRasterizationSettings(
    image_height=int(intrinsics.height),
    image_width=int(intrinsics.width),
    tanfovx=tan_fovx,
    tanfovy=tan_fovy,
    bg=torch.tensor([0.0, 0.0, 0.0]).cuda(),
    scale_modifier=1.0,
    viewmatrix=jax_to_torch(view_matrix),
    projmatrix=jax_to_torch(projmatrix),
    sh_degree=0,
    campos=torch.zeros(3).cuda(),
    prefiltered=False,
    debug=None
)
rasterizer_fwd_torch = GaussianRasterizer(raster_settings=raster_settings)

torch_args = (
    raster_settings.bg, 
    jax_to_torch(means3D),
    jax_to_torch(colors_precomp),
    jax_to_torch(opacity),
    jax_to_torch(scales),
    jax_to_torch(rotations),
    raster_settings.scale_modifier,
    jax_to_torch(cov3D_precomp), # (None -> torch.Tensor([])),
    raster_settings.viewmatrix,
    raster_settings.projmatrix,
    raster_settings.tanfovx,
    raster_settings.tanfovy,
    raster_settings.image_height,
    raster_settings.image_width,
    jax_to_torch(sh), # (None -> torch.Tensor([])),
    raster_settings.sh_degree,
    raster_settings.campos,
    raster_settings.prefiltered,
    raster_settings.debug
)

num_rendered_torch, color_torch, radii_torch, geomBuffer_torch, binningBuffer_torch, imgBuffer_torch = torch_backend.rasterize_gaussians(*torch_args)
color_torch = color_torch.detach()
color_torch_jax = jnp.array(color_torch.cpu().detach().numpy())

assert num_rendered_torch == int(num_rendered_jax[0])
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax)
assert jnp.allclose(torch_to_jax(color_torch), color_jax)

grad_out_color_torch= torch.Tensor(np.array(color_torch_jax)).cuda()

assert jnp.allclose(torch_to_jax(grad_out_color_torch), grad_out_color_jax)


torch_bwd_args = (
    raster_settings.bg,
    jax_to_torch(means3D), 
    radii_torch, 
    jax_to_torch(colors_precomp), 
    jax_to_torch(scales), 
    jax_to_torch(rotations), 
    raster_settings.scale_modifier, 
    torch.Tensor([]), 
    raster_settings.viewmatrix, 
    raster_settings.projmatrix, 
    raster_settings.tanfovx, 
    raster_settings.tanfovy, 
    grad_out_color_torch, 
    torch.Tensor([]), 
    raster_settings.sh_degree, 
    raster_settings.campos,
    geomBuffer_torch,
    num_rendered_torch,
    binningBuffer_torch,
    imgBuffer_torch,
    raster_settings.debug
)

(grad_means2D_torch,
 grad_colors_precomp_torch,
 grad_opacities_torch,
 grad_means3D_torch,
 grad_cov3Ds_precomp_torch,
 grad_sh_torch,
 grad_scales_torch, grad_rotations_torch) = torch_backend.rasterize_gaussians_backward(*torch_bwd_args)


rasterize_impl.cu buffers[0] = 0
rasterize_impl.cu buffers[1] = 0
rasterize_impl.cu buffers[2] = 128
rasterize_impl.cu buffers[3] = 63
rasterize_impl.cu buffers[4] = 0
rasterize_impl.cu buffers[5] = 0
rasterize_impl.cu buffers[6] = 128
rasterize_impl.cu buffers[7] = 63
rasterize_impl.cu buffers[8] = 0
rasterize_impl.cu buffers[9] = 0
rasterize_impl.cu buffers[10] = 128
rasterize_impl.cu buffers[11] = 63
rasterize_impl.cu buffers[12] = 0
rasterize_impl.cu buffers[13] = 0
rasterize_impl.cu buffers[14] = 128
rasterize_impl.cu buffers[15] = 63
rasterize_impl.cu buffers[16] = 0
rasterize_impl.cu buffers[17] = 0
rasterize_impl.cu buffers[18] = 128
rasterize_impl.cu buffers[19] = 63
rasterize_impl.cu buffers[20] = 0
rasterize_impl.cu buffers[21] = 0
rasterize_impl.cu buffers[22] = 128
rasterize_impl.cu buffers[23] = 63
rasterize_impl.cu buffers[24] = 0
rasterize_impl.cu buffers[25] = 0
rasterize_impl.cu buffers[26] = 128
rasterize_impl.cu buffers[27] = 63
rasterize_impl.cu buffers[28] = 0
ras

In [9]:

assert num_rendered_torch == int(num_rendered_jax[0])
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax)
assert jnp.allclose(torch_to_jax(color_torch), color_jax)
assert jnp.allclose(torch_to_jax(grad_out_color_torch), grad_out_color_jax)

assert jnp.allclose(torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax)
assert jnp.allclose(torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax)
assert jnp.allclose(torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax)
assert jnp.allclose(torch_to_jax(grad_colors_precomp_torch), grad_colors_precomp_jax)


In [10]:
grad_colors_precomp_torch[:10]

tensor([[0.0000, 0.0000, 0.0000],
        [2.3346, 1.8149, 3.2405],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [1.9680, 1.3964, 2.3295],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.9078, 0.9627, 1.2463],
        [1.0290, 0.6910, 1.1709],
        [1.6913, 2.3501, 2.5957]], device='cuda:0')

In [11]:
grad_colors_precomp_jax[:10]

Array([[0.        , 0.        , 0.        ],
       [2.3346164 , 1.8148817 , 3.24054   ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [1.9679893 , 1.3964491 , 2.3294568 ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.9077565 , 0.96265846, 1.2463337 ],
       [1.0289726 , 0.6909588 , 1.1709243 ],
       [1.6912593 , 2.3500519 , 2.5956943 ]], dtype=float32)

In [None]:
geomBuffer_jax[:100],geomBuffer_torch[:100]

(Array([  0,   0,   0,   0,  57, 161, 191,  63,  64, 205, 154,  63, 210,
        164, 122,  63,  21,  48, 165,  63,   0,   0,   0,   0, 118,  60,
        120,  63, 142, 179, 152,  63, 215,  14, 177,  63, 203,  61, 175,
         63, 138,  75,  34,  63,   0,   0,   0,   0,  78,  11, 127,  63,
        130, 112,  64,  63, 164,  82, 162,  63,  97,  90, 140,  63,   0,
          0,   0,   0,  42, 208, 137,  63, 173, 122, 154,  63, 136, 250,
        181,  63, 170,  72,  46,  63,   0,   0,   0,   0,   0,   0,   0,
          0, 130, 125, 171,  63, 230, 204, 190,  63], dtype=uint8),
 tensor([  0,   0,   0,   0,  57, 161, 191,  63,  64, 205, 154,  63, 210, 164,
         122,  63,  21,  48, 165,  63,   0,   0,   0,   0, 118,  60, 120,  63,
         142, 179, 152,  63, 215,  14, 177,  63, 203,  61, 175,  63, 138,  75,
          34,  63,   0,   0,   0,   0,  78,  11, 127,  63, 130, 112,  64,  63,
         164,  82, 162,  63,  97,  90, 140,  63,   0,   0,   0,   0,  42, 208,
         137,  63, 173, 12

In [None]:
binningBuffer_jax[:100],binningBuffer_torch[:100]

(Array([ 6,  0,  0,  0, 10,  0,  0,  0, 32,  0,  0,  0, 56,  0,  0,  0,  3,
         0,  0,  0, 56,  0,  0,  0,  3,  0,  0,  0, 99,  0,  0,  0, 99,  0,
         0,  0,  6,  0,  0,  0, 10,  0,  0,  0, 32,  0,  0,  0, 56,  0,  0,
         0,  3,  0,  0,  0, 56,  0,  0,  0,  3,  0,  0,  0, 99,  0,  0,  0,
        94,  0,  0,  0, 99,  0,  0,  0, 94,  0,  0,  0, 77,  0,  0,  0, 87,
         0,  0,  0, 79,  0,  0,  0, 14,  0,  0,  0, 91,  0,  0,  0],      dtype=uint8),
 tensor([ 6,  0,  0,  0, 10,  0,  0,  0, 32,  0,  0,  0, 56,  0,  0,  0,  3,  0,
          0,  0, 56,  0,  0,  0,  3,  0,  0,  0, 99,  0,  0,  0, 99,  0,  0,  0,
          6,  0,  0,  0, 10,  0,  0,  0, 32,  0,  0,  0, 56,  0,  0,  0,  3,  0,
          0,  0, 56,  0,  0,  0,  3,  0,  0,  0, 99,  0,  0,  0, 94,  0,  0,  0,
         99,  0,  0,  0, 94,  0,  0,  0, 77,  0,  0,  0, 87,  0,  0,  0, 79,  0,
          0,  0, 14,  0,  0,  0, 91,  0,  0,  0], device='cuda:0',
        dtype=torch.uint8))



PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0
PRIOR TO BIG LOOP: rounds=0, range.x=0, range.y=0


In [None]:
grad_means3D_torch

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.7216e-01, -4.5224e+00, -2.2172e+00],
        [ 5.1902e+00, -7.6272e+00,  1.9243e-02],
        [ 1.4317e+00, -1.5979e+00, -4.3937e+00],
        [ 5.1646e-01, -1.0982e+00, -1.7987e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.8810e-01, -1.2305e+00, -2.4444e+00],
        [ 8.4994e+00, -1.4889e+01, -2.5338e-01],
        [-2.9631e+01, -3.2217e+01,  6.0128e+00],
        [-4.1778e-01, -5.0441e+00, -1.0335e+00],
        [-1.8100e+01, -1.4253e+01, -3.6376e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.7670e+00,  3.1282e+01,  3.0279e+00],
        [-1.0451e+00,  1.0489e+00, -3.4856e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.6426e-01, -3.4272e+00, -1.5606e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-3.6169e-01,  2.8698e-01, -1.2559e+00],
        [ 8.5028e+00,  1.0120e+01, -2.3908e+00],
        [ 6.6112e+00,  5.2532e+00,  8.3333e-01],
        [-3.5350e-02

In [None]:
assert jnp.allclose(torch_to_jax(grad_opacities_torch)[:10], grad_opacities_jax[:10])


AssertionError: 

In [None]:
grad_opacities_jax[:10]

Array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float32)

In [None]:
print(grad_colors_precomp_torch[:10])
print(grad_colors_precomp_torch[:10].cpu().detach().numpy())

tensor([[0.0000, 0.0000, 0.0000],
        [2.1184, 1.6468, 2.9404],
        [1.5501, 1.9515, 1.3912],
        [3.6697, 1.4212, 6.7309],
        [1.7601, 1.2489, 2.0834],
        [0.0000, 0.0000, 0.0000],
        [2.3870, 0.1952, 1.5356],
        [0.8381, 0.8887, 1.1506],
        [0.9910, 0.6655, 1.1277],
        [1.5309, 2.1273, 2.3496]], device='cuda:0')
[[0.         0.         0.        ]
 [2.1183844  1.6467875  2.9404006 ]
 [1.5501063  1.9514694  1.3911976 ]
 [3.6696537  1.4211656  6.7309422 ]
 [1.7600881  1.2489258  2.0833688 ]
 [0.         0.         0.        ]
 [2.386981   0.19520165 1.5356065 ]
 [0.83805925 0.8887458  1.1506407 ]
 [0.9909896  0.66545314 1.1277013 ]
 [1.5309441  2.1272895  2.3496473 ]]


In [None]:

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(jnp.transpose(color_jax, (1,2,0))[...,:3])
ax2.imshow(jnp.transpose(color_torch_jax, (1,2,0))[...,:3])

assert jnp.allclose(color_jax, color_torch_jax)

In [None]:
assert num_rendered_torch == int(num_rendered_jax[0])
assert jnp.allclose(torch_to_jax(radii_torch), radii_jax)
assert jnp.allclose(torch_to_jax(color_torch), color_jax)
assert jnp.allclose(torch_to_jax(binningBuffer_torch), binningBuffer_jax[:binningBuffer_torch.shape[0]])
assert jnp.allclose(torch_to_jax(imgBuffer_torch), imgBuffer_jax[:imgBuffer_torch.shape[0]])
assert jnp.allclose(torch_to_jax(geomBuffer_torch), geomBuffer_jax[:geomBuffer_torch.shape[0]])

outs1 = torch_backend.rasterize_gaussians(*torch_args)
outs2 = torch_backend.rasterize_gaussians(*torch_args)
assert outs1[0] == outs2[0]
print(outs1[0])
print(outs2[0])
for (i,j) in zip(outs1[1:], outs2[1:]):
    print(torch.allclose(i,j))

AssertionError: 

In [None]:
print("1")
grad_out_color_jax = jnp.array(color_jax)
jax_bwd_args = (
    jnp.zeros(3),
    means3D, #1
    radii, #2 
    colors_precomp, #3 
    scales, #4
    rotations, #5 
    # raster_settings.scale_modifier), 
    cov3D_precomp, #6 
    viewmatrix, #7 
    projmatrix, #8
    grad_out_color_jax, #9
    sh, #10
    jnp.zeros(3), #11
    geomBuffer_jax, #12
    jnp.array([[1]]),#num_rendered_jax, #13 
    binningBuffer_jax, #14
    imgBuffer_jax #15
)
print("2")

In [None]:
dummy_out_color_torch = torch.tensor(torch.rand((3, int(intrinsics.height), int(intrinsics.width))), requires_grad=False, device=device).detach()
grad_out_color_torch = dummy_out_color_torch - color_torch



In [None]:
rasterizer_bwd_jax = _build_rasterize_gaussians_bwd_primitive()
dummy_out_color_jax = torch_to_jax(dummy_out_color_torch)
grad_out_color_jax = dummy_out_color_jax - color_jax

assert jnp.allclose(torch_to_jax(grad_out_color_torch), grad_out_color_jax)


In [None]:
print("1")

jax_bwd_args = (
    torch_to_jax(raster_settings.bg), #0
    torch_to_jax(means3D), #1
    torch_to_jax(radii_torch), #2 
    torch_to_jax(colors_precomp), #3 
    torch_to_jax(scales), #4
    torch_to_jax(rotations), #5 
    # raster_settings.scale_modifier), 
    torch_to_jax(cov3D_precomp), #6 
    torch_to_jax(raster_settings.viewmatrix), #7 
    torch_to_jax(raster_settings.projmatrix), #8
    grad_out_color_jax, #9
    torch_to_jax(sh), #10
    torch_to_jax(raster_settings.campos), #11
    geomBuffer_jax, #12
    jnp.array([[1]]),#num_rendered_jax, #13 
    binningBuffer_jax, #14
    imgBuffer_jax #15
)
print("2")

1
2


In [None]:

(grad_means2D_jax,
 grad_colors_precomp_jax,
 grad_opacities_jax,
 grad_means3D_jax,
 grad_cov3Ds_precomp_jax,
 grad_sh_jax,
 grad_scales_jax, grad_rotations_jax, _) = rasterizer_bwd_jax.bind(
            *jax_bwd_args,
            tanfovx=tan_fovx, 
            tanfovy=tan_fovy, 
            sh_degree=0
)  


start
start2
start3
start4
Num gaussians 100
Num expanded gaussians 1
