Skip to content

Commit

Permalink
Compute density compensation for screen space blurring
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored and Jianbo Ye committed Feb 5, 2024
1 parent 210ed53 commit 7ba094a
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 17 deletions.
15 changes: 10 additions & 5 deletions gsplat/_torch_impl.py
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
from typing import Tuple


def compute_sh_color(
Expand Down Expand Up @@ -163,7 +164,7 @@ def project_cov3d_ewa(
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
Expand All @@ -190,9 +191,12 @@ def project_cov3d_ewa(
T = J @ W # (..., 2, 3)
cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2)
# add a little blur along axes and (TODO save upper triangular elements)
det_orig = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d
det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0))
return cov2d, compensation.detach()


def compute_cov2d_bounds(cov2d: Tensor, eps=1e-6):
Expand Down Expand Up @@ -277,7 +281,9 @@ def project_gaussians_forward(
tan_fovy = 0.5 * img_size[0] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
cov2d, compensation = project_cov3d_ewa(
means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
center = project_pix(projmat, means3d, img_size)
tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds)
Expand All @@ -292,7 +298,7 @@ def project_gaussians_forward(
xys = center
conics = conic

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask
return cov3d, xys, depths, radii, conics, compensation, num_tiles_hit, mask


def map_gaussian_to_intersects(
Expand Down Expand Up @@ -334,7 +340,6 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds):
)

for idx in range(num_intersects):

cur_tile_idx = isect_ids_sorted[idx] >> 32

if idx == 0:
Expand Down
6 changes: 5 additions & 1 deletion gsplat/cuda/csrc/bindings.cu
Expand Up @@ -122,6 +122,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down Expand Up @@ -162,6 +163,8 @@ project_gaussians_forward_tensor(
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));
torch::Tensor conics_d =
torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32));
torch::Tensor compensation_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32));
torch::Tensor num_tiles_hit_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));

Expand All @@ -185,11 +188,12 @@ project_gaussians_forward_tensor(
depths_d.contiguous().data_ptr<float>(),
radii_d.contiguous().data_ptr<int>(),
(float3 *)conics_d.contiguous().data_ptr<float>(),
compensation_d.contiguous().data_ptr<float>(),
num_tiles_hit_d.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d
cov3d_d, xys_d, depths_d, radii_d, conics_d, compensation_d, num_tiles_hit_d
);
}

Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/bindings.h
Expand Up @@ -40,6 +40,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down
24 changes: 19 additions & 5 deletions gsplat/cuda/csrc/forward.cu
Expand Up @@ -26,6 +26,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
) {
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
Expand Down Expand Up @@ -61,8 +62,11 @@ __global__ void project_gaussians_forward_kernel(
float cy = intrins.w;
float tan_fovx = 0.5 * img_size.x / fx;
float tan_fovy = 0.5 * img_size.y / fy;
float3 cov2d = project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
float3 cov2d;
float comp;
project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy,
cov2d, comp
);
// printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z);

Expand All @@ -88,6 +92,7 @@ __global__ void project_gaussians_forward_kernel(
depths[idx] = p_view.z;
radii[idx] = (int)radius;
xys[idx] = center;
compensation[idx] = comp;
// printf(
// "point %d x %.2f y %.2f z %.2f, radius %d, # tiles %d, tile_min %d
// %d, tile_max %d %d\n", idx, center.x, center.y, depths[idx],
Expand Down Expand Up @@ -372,14 +377,16 @@ __global__ void rasterize_forward(
}

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3& __restrict__ mean3d,
const float* __restrict__ cov3d,
const float* __restrict__ viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &compensation
) {
// clip the
// we expect row major matrices as input, glm uses column major
Expand Down Expand Up @@ -437,7 +444,14 @@ __device__ float3 project_cov3d_ewa(
glm::mat3 cov = T * V * glm::transpose(T);

// add a little blur along axes and save upper triangular elements
return make_float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f);
// and compute the density compensation factor due to the blurs
float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1];
float det_orig = c00 * c11 - c01 * c01;
cov2d.x = c00 + 0.3f;
cov2d.y = c01;
cov2d.z = c11 + 0.3f;
float det_blur = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
compensation = std::sqrt(std::max(0.f, det_orig / det_blur));
}

// device helper to get 3D covariance from scale and quat parameters
Expand Down
7 changes: 5 additions & 2 deletions gsplat/cuda/csrc/forward.cuh
Expand Up @@ -20,6 +20,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
);

Expand Down Expand Up @@ -57,14 +58,16 @@ __global__ void nd_rasterize_forward(
);

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3 &mean3d,
const float *cov3d,
const float *viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &comp
);

// device helper to get 3D covariance from scale and quat parameters
Expand Down
19 changes: 15 additions & 4 deletions gsplat/project_gaussians.py
Expand Up @@ -24,7 +24,7 @@ def project_gaussians(
img_width: int,
tile_bounds: Tuple[int, int, int],
clip_thresh: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting.
Note:
Expand All @@ -47,12 +47,13 @@ def project_gaussians(
clip_thresh (float): minimum z depth threshold.
Returns:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
- **xys** (Tensor): x,y locations of 2D gaussian projections.
- **depths** (Tensor): z depth of gaussians.
- **radii** (Tensor): radii of 2D gaussian projections.
- **conics** (Tensor): conic parameters for 2D gaussian.
- **compensation** (Tensor): the density compensation for blurring 2D kernel
- **num_tiles_hit** (Tensor): number of tiles hit per gaussian.
- **cov3d** (Tensor): 3D covariances.
"""
Expand Down Expand Up @@ -105,6 +106,7 @@ def forward(
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
Expand Down Expand Up @@ -146,10 +148,19 @@ def forward(
conics,
)

return (xys, depths, radii, conics, num_tiles_hit, cov3d)
return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)

@staticmethod
def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
def backward(
ctx,
v_xys,
v_depths,
v_radii,
v_conics,
v_compensation,
v_num_tiles_hit,
v_cov3d,
):
(
means3d,
scales,
Expand Down

0 comments on commit 7ba094a

Please sign in to comment.