Skip to content

Commit

Permalink
Compute density compensation for screen space blurring
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored and Jianbo Ye committed Feb 8, 2024
1 parent 9fffa1a commit 1e28201
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 19 deletions.
16 changes: 11 additions & 5 deletions gsplat/_torch_impl.py
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
from typing import Tuple


def compute_sh_color(
Expand Down Expand Up @@ -149,7 +150,7 @@ def project_cov3d_ewa(
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
Expand All @@ -174,9 +175,12 @@ def project_cov3d_ewa(
T = torch.matmul(J, W) # (..., 2, 3)
cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2))
# add a little blur along axes and (TODO save upper triangular elements)
det_orig = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d[..., :2, :2]
det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0))
return cov2d[..., :2, :2], compensation.detach()


def compute_cov2d_bounds(cov2d_mat: Tensor):
Expand Down Expand Up @@ -272,7 +276,9 @@ def project_gaussians_forward(
tan_fovy = 0.5 * img_size[1] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
cov2d, compensation = project_cov3d_ewa(
means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
xys = project_pix(fullmat, means3d, img_size, (cx, cy))
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds)
Expand All @@ -290,14 +296,15 @@ def project_gaussians_forward(
xys = torch.where(~mask[..., None], 0, xys)
cov3d = torch.where(~mask[..., None, None], 0, cov3d)
cov2d = torch.where(~mask[..., None, None], 0, cov2d)
compensation = torch.where(~mask, 0, compensation)
num_tiles_hit = torch.where(~mask, 0, num_tiles_hit)
depths = torch.where(~mask, 0, depths)

i, j = torch.triu_indices(3, 3)
cov3d_triu = cov3d[..., i, j]
i, j = torch.triu_indices(2, 2)
cov2d_triu = cov2d[..., i, j]
return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask
return cov3d_triu, cov2d_triu, xys, depths, radii, conic, compensation, num_tiles_hit, mask


def map_gaussian_to_intersects(
Expand Down Expand Up @@ -339,7 +346,6 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds):
)

for idx in range(num_intersects):

cur_tile_idx = isect_ids_sorted[idx] >> 32

if idx == 0:
Expand Down
6 changes: 5 additions & 1 deletion gsplat/cuda/csrc/bindings.cu
Expand Up @@ -122,6 +122,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down Expand Up @@ -162,6 +163,8 @@ project_gaussians_forward_tensor(
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));
torch::Tensor conics_d =
torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32));
torch::Tensor compensation_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32));
torch::Tensor num_tiles_hit_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));

Expand All @@ -185,11 +188,12 @@ project_gaussians_forward_tensor(
depths_d.contiguous().data_ptr<float>(),
radii_d.contiguous().data_ptr<int>(),
(float3 *)conics_d.contiguous().data_ptr<float>(),
compensation_d.contiguous().data_ptr<float>(),
num_tiles_hit_d.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d
cov3d_d, xys_d, depths_d, radii_d, conics_d, compensation_d, num_tiles_hit_d
);
}

Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/bindings.h
Expand Up @@ -40,6 +40,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down
24 changes: 19 additions & 5 deletions gsplat/cuda/csrc/forward.cu
Expand Up @@ -26,6 +26,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
) {
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
Expand Down Expand Up @@ -61,8 +62,11 @@ __global__ void project_gaussians_forward_kernel(
float cy = intrins.w;
float tan_fovx = 0.5 * img_size.x / fx;
float tan_fovy = 0.5 * img_size.y / fy;
float3 cov2d = project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
float3 cov2d;
float comp;
project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy,
cov2d, comp
);
// printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z);

Expand All @@ -88,6 +92,7 @@ __global__ void project_gaussians_forward_kernel(
depths[idx] = p_view.z;
radii[idx] = (int)radius;
xys[idx] = center;
compensation[idx] = comp;
// printf(
// "point %d x %.2f y %.2f z %.2f, radius %d, # tiles %d, tile_min %d
// %d, tile_max %d %d\n", idx, center.x, center.y, depths[idx],
Expand Down Expand Up @@ -372,14 +377,16 @@ __global__ void rasterize_forward(
}

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3& __restrict__ mean3d,
const float* __restrict__ cov3d,
const float* __restrict__ viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &compensation
) {
// clip the
// we expect row major matrices as input, glm uses column major
Expand Down Expand Up @@ -437,7 +444,14 @@ __device__ float3 project_cov3d_ewa(
glm::mat3 cov = T * V * glm::transpose(T);

// add a little blur along axes and save upper triangular elements
return make_float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f);
// and compute the density compensation factor due to the blurs
float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1];
float det_orig = c00 * c11 - c01 * c01;
cov2d.x = c00 + 0.3f;
cov2d.y = c01;
cov2d.z = c11 + 0.3f;
float det_blur = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
compensation = std::sqrt(std::max(0.f, det_orig / det_blur));
}

// device helper to get 3D covariance from scale and quat parameters
Expand Down
7 changes: 5 additions & 2 deletions gsplat/cuda/csrc/forward.cuh
Expand Up @@ -20,6 +20,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
);

Expand Down Expand Up @@ -57,14 +58,16 @@ __global__ void nd_rasterize_forward(
);

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3 &mean3d,
const float *cov3d,
const float *viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &comp
);

// device helper to get 3D covariance from scale and quat parameters
Expand Down
19 changes: 15 additions & 4 deletions gsplat/project_gaussians.py
Expand Up @@ -24,7 +24,7 @@ def project_gaussians(
img_width: int,
tile_bounds: Tuple[int, int, int],
clip_thresh: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting.
Note:
Expand All @@ -47,12 +47,13 @@ def project_gaussians(
clip_thresh (float): minimum z depth threshold.
Returns:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
- **xys** (Tensor): x,y locations of 2D gaussian projections.
- **depths** (Tensor): z depth of gaussians.
- **radii** (Tensor): radii of 2D gaussian projections.
- **conics** (Tensor): conic parameters for 2D gaussian.
- **compensation** (Tensor): the density compensation for blurring 2D kernel
- **num_tiles_hit** (Tensor): number of tiles hit per gaussian.
- **cov3d** (Tensor): 3D covariances.
"""
Expand Down Expand Up @@ -105,6 +106,7 @@ def forward(
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
Expand Down Expand Up @@ -146,10 +148,19 @@ def forward(
conics,
)

return (xys, depths, radii, conics, num_tiles_hit, cov3d)
return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)

@staticmethod
def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
def backward(
ctx,
v_xys,
v_depths,
v_radii,
v_conics,
v_compensation,
v_num_tiles_hit,
v_cov3d,
):
(
means3d,
scales,
Expand Down
7 changes: 5 additions & 2 deletions tests/test_project_gaussians.py
Expand Up @@ -66,7 +66,7 @@ def test_project_gaussians_forward():
BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1

(cov3d, xys, depths, radii, conics, num_tiles_hit,) = _C.project_gaussians_forward(
(cov3d, xys, depths, radii, conics, compensation, num_tiles_hit,) = _C.project_gaussians_forward(
num_points,
means3d,
scales,
Expand All @@ -93,6 +93,7 @@ def test_project_gaussians_forward():
_depths,
_radii,
_conics,
_compensation,
_num_tiles_hit,
_masks,
) = _torch_impl.project_gaussians_forward(
Expand All @@ -114,6 +115,7 @@ def test_project_gaussians_forward():
check_close(depths, _depths)
check_close(radii, _radii)
check_close(conics, _conics)
check_close(compensation, _compensation)
check_close(num_tiles_hit, _num_tiles_hit)
print("passed project_gaussians_forward test")

Expand Down Expand Up @@ -156,6 +158,7 @@ def test_project_gaussians_backward():
radii,
conics,
_,
_,
masks,
) = _torch_impl.project_gaussians_forward(
means3d,
Expand Down Expand Up @@ -219,7 +222,7 @@ def project_cov3d_ewa_partial(mean3d, cov3d):
i, j = torch.triu_indices(3, 3)
cov3d_mat[..., i, j] = cov3d
cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]]
cov2d = _torch_impl.project_cov3d_ewa(
cov2d, _ = _torch_impl.project_cov3d_ewa(
mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy
)
ii, jj = torch.triu_indices(2, 2)
Expand Down

0 comments on commit 1e28201

Please sign in to comment.