Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute density compensation for screen space blurring of tiny gaussians #117

Merged
merged 1 commit into from Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 27 additions & 11 deletions gsplat/_torch_impl.py
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
from typing import Tuple


def compute_sh_color(
Expand Down Expand Up @@ -116,15 +117,15 @@ def quat_to_rotmat(quat: Tensor) -> Tensor:
w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1)
mat = torch.stack(
[
1 - 2 * (y ** 2 + z ** 2),
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x ** 2 + z ** 2),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x ** 2 + y ** 2),
1 - 2 * (x**2 + y**2),
],
dim=-1,
)
Expand All @@ -149,7 +150,7 @@ def project_cov3d_ewa(
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
Expand All @@ -158,7 +159,7 @@ def project_cov3d_ewa(
t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3)

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz ** 2 # (...,)
rz2 = rz**2 # (...,)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)
Expand All @@ -174,9 +175,12 @@ def project_cov3d_ewa(
T = torch.matmul(J, W) # (..., 2, 3)
cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2))
# add a little blur along axes and (TODO save upper triangular elements)
det_orig = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d[..., :2, :2]
det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0))
return cov2d[..., :2, :2], compensation.detach()


def compute_cov2d_bounds(cov2d_mat: Tensor):
Expand All @@ -198,8 +202,8 @@ def compute_cov2d_bounds(cov2d_mat: Tensor):
dim=-1,
) # (..., 3)
b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,)
v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,)
radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device)
conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device)
Expand Down Expand Up @@ -272,7 +276,9 @@ def project_gaussians_forward(
tan_fovy = 0.5 * img_size[1] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
cov2d, compensation = project_cov3d_ewa(
means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
xys = project_pix(fullmat, means3d, img_size, (cx, cy))
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds)
Expand All @@ -290,14 +296,25 @@ def project_gaussians_forward(
xys = torch.where(~mask[..., None], 0, xys)
cov3d = torch.where(~mask[..., None, None], 0, cov3d)
cov2d = torch.where(~mask[..., None, None], 0, cov2d)
compensation = torch.where(~mask, 0, compensation)
num_tiles_hit = torch.where(~mask, 0, num_tiles_hit)
depths = torch.where(~mask, 0, depths)

i, j = torch.triu_indices(3, 3)
cov3d_triu = cov3d[..., i, j]
i, j = torch.triu_indices(2, 2)
cov2d_triu = cov2d[..., i, j]
return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask
return (
cov3d_triu,
cov2d_triu,
xys,
depths,
radii,
conic,
compensation,
num_tiles_hit,
mask,
)


def map_gaussian_to_intersects(
Expand Down Expand Up @@ -339,7 +356,6 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds):
)

for idx in range(num_intersects):

cur_tile_idx = isect_ids_sorted[idx] >> 32

if idx == 0:
Expand Down
6 changes: 5 additions & 1 deletion gsplat/cuda/csrc/bindings.cu
Expand Up @@ -122,6 +122,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down Expand Up @@ -162,6 +163,8 @@ project_gaussians_forward_tensor(
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));
torch::Tensor conics_d =
torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32));
torch::Tensor compensation_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32));
torch::Tensor num_tiles_hit_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));

Expand All @@ -185,11 +188,12 @@ project_gaussians_forward_tensor(
depths_d.contiguous().data_ptr<float>(),
radii_d.contiguous().data_ptr<int>(),
(float3 *)conics_d.contiguous().data_ptr<float>(),
compensation_d.contiguous().data_ptr<float>(),
num_tiles_hit_d.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d
cov3d_d, xys_d, depths_d, radii_d, conics_d, compensation_d, num_tiles_hit_d
);
}

Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/bindings.h
Expand Up @@ -40,6 +40,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down
24 changes: 19 additions & 5 deletions gsplat/cuda/csrc/forward.cu
Expand Up @@ -26,6 +26,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
) {
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
Expand Down Expand Up @@ -61,8 +62,11 @@ __global__ void project_gaussians_forward_kernel(
float cy = intrins.w;
float tan_fovx = 0.5 * img_size.x / fx;
float tan_fovy = 0.5 * img_size.y / fy;
float3 cov2d = project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
float3 cov2d;
float comp;
project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy,
cov2d, comp
);
// printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z);

Expand All @@ -88,6 +92,7 @@ __global__ void project_gaussians_forward_kernel(
depths[idx] = p_view.z;
radii[idx] = (int)radius;
xys[idx] = center;
compensation[idx] = comp;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Read/Write the global memory is usually the most time consuming part in a kernel (computation is usually not the burden). I tested this a bit and it slows down the project_gaussians from 3000 it/s to 2800 it/s which is not that much so I think is fine. Especially that project_gaussians is much cheaper comparing to the rasterization stage. I'm fine with this tiny little extra burden but just want to point it out for future reference.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code I used to test this.

import torch

def profiling(N: int = 1000000, D: int = 3):
    import tqdm

    from gsplat import project_gaussians, rasterize_gaussians

    torch.manual_seed(42)
    device = torch.device("cuda:0")

    means3d = torch.rand((N, 3), device=device, requires_grad=False)
    scales = torch.rand((N, 3), device=device) * 5
    quats = torch.randn((N, 4), device=device)
    quats /= torch.linalg.norm(quats, dim=-1, keepdim=True)

    viewmat = projmat = torch.eye(4, device=device)
    fx = fy = 3.0
    H, W = 256, 256
    BLOCK_X = BLOCK_Y = 16
    tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1

    pbar = tqdm.trange(10000)
    for _ in pbar:
        xys, depths, radii, conics, num_tiles_hit, cov3d = project_gaussians(
            means3d,
            scales,
            1,
            quats,
            viewmat,
            projmat,
            fx,
            fy,
            W / 2,
            H / 2,
            H,
            W,
            tile_bounds,
        )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the evaluation effort.

// printf(
// "point %d x %.2f y %.2f z %.2f, radius %d, # tiles %d, tile_min %d
// %d, tile_max %d %d\n", idx, center.x, center.y, depths[idx],
Expand Down Expand Up @@ -372,14 +377,16 @@ __global__ void rasterize_forward(
}

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3& __restrict__ mean3d,
const float* __restrict__ cov3d,
const float* __restrict__ viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &compensation
) {
// clip the
// we expect row major matrices as input, glm uses column major
Expand Down Expand Up @@ -437,7 +444,14 @@ __device__ float3 project_cov3d_ewa(
glm::mat3 cov = T * V * glm::transpose(T);

// add a little blur along axes and save upper triangular elements
return make_float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f);
// and compute the density compensation factor due to the blurs
float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1];
float det_orig = c00 * c11 - c01 * c01;
cov2d.x = c00 + 0.3f;
cov2d.y = c01;
cov2d.z = c11 + 0.3f;
float det_blur = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
compensation = std::sqrt(std::max(0.f, det_orig / det_blur));
}

// device helper to get 3D covariance from scale and quat parameters
Expand Down
7 changes: 5 additions & 2 deletions gsplat/cuda/csrc/forward.cuh
Expand Up @@ -20,6 +20,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
);

Expand Down Expand Up @@ -57,14 +58,16 @@ __global__ void nd_rasterize_forward(
);

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3 &mean3d,
const float *cov3d,
const float *viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &comp
);

// device helper to get 3D covariance from scale and quat parameters
Expand Down
19 changes: 15 additions & 4 deletions gsplat/project_gaussians.py
Expand Up @@ -24,7 +24,7 @@ def project_gaussians(
img_width: int,
tile_bounds: Tuple[int, int, int],
clip_thresh: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting.

Note:
Expand All @@ -47,12 +47,13 @@ def project_gaussians(
clip_thresh (float): minimum z depth threshold.

Returns:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:

- **xys** (Tensor): x,y locations of 2D gaussian projections.
- **depths** (Tensor): z depth of gaussians.
- **radii** (Tensor): radii of 2D gaussian projections.
- **conics** (Tensor): conic parameters for 2D gaussian.
- **compensation** (Tensor): the density compensation for blurring 2D kernel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extra return would break backward compatibility. Personally I'm fine with it as we are in active-developing version 0.1.x. But I'll let @vye16 to decide starting from when we want to maintain backward compatibility.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @vye16 , could you help take a look at this PR and see if you have any other comments other than @liruilong940607

- **num_tiles_hit** (Tensor): number of tiles hit per gaussian.
- **cov3d** (Tensor): 3D covariances.
"""
Expand Down Expand Up @@ -105,6 +106,7 @@ def forward(
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
Expand Down Expand Up @@ -146,10 +148,19 @@ def forward(
conics,
)

return (xys, depths, radii, conics, num_tiles_hit, cov3d)
return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)

@staticmethod
def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
def backward(
ctx,
v_xys,
v_depths,
v_radii,
v_conics,
v_compensation,
v_num_tiles_hit,
v_cov3d,
):
(
means3d,
scales,
Expand Down
15 changes: 13 additions & 2 deletions tests/test_project_gaussians.py
Expand Up @@ -66,7 +66,15 @@ def test_project_gaussians_forward():
BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1

(cov3d, xys, depths, radii, conics, num_tiles_hit,) = _C.project_gaussians_forward(
(
cov3d,
xys,
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
means3d,
scales,
Expand All @@ -93,6 +101,7 @@ def test_project_gaussians_forward():
_depths,
_radii,
_conics,
_compensation,
_num_tiles_hit,
_masks,
) = _torch_impl.project_gaussians_forward(
Expand All @@ -114,6 +123,7 @@ def test_project_gaussians_forward():
check_close(depths, _depths)
check_close(radii, _radii)
check_close(conics, _conics)
check_close(compensation, _compensation)
check_close(num_tiles_hit, _num_tiles_hit)
print("passed project_gaussians_forward test")

Expand Down Expand Up @@ -156,6 +166,7 @@ def test_project_gaussians_backward():
radii,
conics,
_,
_,
masks,
) = _torch_impl.project_gaussians_forward(
means3d,
Expand Down Expand Up @@ -219,7 +230,7 @@ def project_cov3d_ewa_partial(mean3d, cov3d):
i, j = torch.triu_indices(3, 3)
cov3d_mat[..., i, j] = cov3d
cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]]
cov2d = _torch_impl.project_cov3d_ewa(
cov2d, _ = _torch_impl.project_cov3d_ewa(
mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy
)
ii, jj = torch.triu_indices(2, 2)
Expand Down