Skip to content

Commit 258123a

Browse files
author
Ruilong Li
committed
use gsplat as CUDA backend
1 parent 472689c commit 258123a

File tree

4 files changed

+72
-70
lines changed

4 files changed

+72
-70
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,22 @@
1+
# Gaussian Splatting with `gsplat` Backend
2+
3+
In this fork of the official code base, we switch the rasterization backend from `diff-gaussian-rasterization` to `gsplat` with
4+
minimal changes (<100 lines), and get some improvements for free:
5+
6+
For example, we showcase a 20% training speedup and a noticeable memory reduction, with slightly better quality metrics, on the Garden scene from MipNeRF360, benchmarked on a 24GB NVIDIA TITAN RTX at 7k steps.
7+
8+
| Backend | Training Time | Memory | SSIM | PSNR | LPIPS |
9+
| -------- | ------- | ------- | ------- | ------- | ------- |
10+
| `diff-gaussian-rasterization` | 482s | 9.11 GB | 0.8237 | 26.11 | 0.166 |
11+
| `gsplat v1.0` | 398s | 8.62 GB | 0.8366 | 26.18 | 0.163 |
12+
13+
Note that the improvements will be much more significant on larger scenes.
14+
On top of that, there are more functionalities supported in `gsplat v1.0`, including
15+
**batched rasterization**, **trade-off between memory and speed**, **sparse gradient** etc.
16+
Check [gsplat.studio](https://docs.gsplat.studio/) for more details.
17+
18+
---------------
19+
120
# 3D Gaussian Splatting for Real-Time Radiance Field Rendering
221
Bernhard Kerbl*, Georgios Kopanas*, Thomas Leimkühler, George Drettakis (* indicates equal contribution)<br>
322
| [Webpage](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) | [Full Paper](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf) | [Video](https://youtu.be/T_kXY43VZnk) | [Other GRAPHDECO Publications](http://www-sop.inria.fr/reves/publis/gdindex.php) | [FUNGRAPH project page](https://fungraph.inria.fr) |<br>

gaussian_renderer/__init__.py

Lines changed: 44 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,92 +9,69 @@
99
# For inquiries contact george.drettakis@inria.fr
1010
#
1111

12-
import torch
1312
import math
14-
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
13+
14+
import torch
15+
from torch.nn import functional as F
16+
from gsplat import rasterization
1517
from scene.gaussian_model import GaussianModel
16-
from utils.sh_utils import eval_sh
1718

1819
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
1920
"""
2021
Render the scene.
2122
2223
Background tensor (bg_color) must be on GPU!
2324
"""
24-
25-
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
26-
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
27-
try:
28-
screenspace_points.retain_grad()
29-
except:
30-
pass
31-
32-
# Set up rasterization configuration
3325
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
3426
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
35-
36-
raster_settings = GaussianRasterizationSettings(
37-
image_height=int(viewpoint_camera.image_height),
38-
image_width=int(viewpoint_camera.image_width),
39-
tanfovx=tanfovx,
40-
tanfovy=tanfovy,
41-
bg=bg_color,
42-
scale_modifier=scaling_modifier,
43-
viewmatrix=viewpoint_camera.world_view_transform,
44-
projmatrix=viewpoint_camera.full_proj_transform,
45-
sh_degree=pc.active_sh_degree,
46-
campos=viewpoint_camera.camera_center,
47-
prefiltered=False,
48-
debug=pipe.debug
27+
focal_length_x = viewpoint_camera.image_width / (2 * tanfovx)
28+
focal_length_y = viewpoint_camera.image_height / (2 * tanfovy)
29+
K = torch.tensor(
30+
[
31+
[focal_length_x, 0, viewpoint_camera.image_width / 2.0],
32+
[0, focal_length_y, viewpoint_camera.image_height / 2.0],
33+
[0, 0, 1],
34+
],
35+
device="cuda",
4936
)
5037

51-
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
52-
5338
means3D = pc.get_xyz
54-
means2D = screenspace_points
5539
opacity = pc.get_opacity
56-
57-
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
58-
# scaling / rotation by the rasterizer.
59-
scales = None
60-
rotations = None
61-
cov3D_precomp = None
62-
if pipe.compute_cov3D_python:
63-
cov3D_precomp = pc.get_covariance(scaling_modifier)
64-
else:
65-
scales = pc.get_scaling
66-
rotations = pc.get_rotation
67-
68-
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
69-
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
70-
shs = None
71-
colors_precomp = None
72-
if override_color is None:
73-
if pipe.convert_SHs_python:
74-
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
75-
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
76-
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
77-
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
78-
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
79-
else:
80-
shs = pc.get_features
40+
scales = pc.get_scaling * scaling_modifier
41+
rotations = pc.get_rotation
42+
if override_color is not None:
43+
colors = override_color # [N, 3]
44+
sh_degree = None
8145
else:
82-
colors_precomp = override_color
83-
84-
# Rasterize visible Gaussians to image, obtain their radii (on screen).
85-
rendered_image, radii = rasterizer(
86-
means3D = means3D,
87-
means2D = means2D,
88-
shs = shs,
89-
colors_precomp = colors_precomp,
90-
opacities = opacity,
91-
scales = scales,
92-
rotations = rotations,
93-
cov3D_precomp = cov3D_precomp)
46+
colors = pc.get_features # [N, K, 3]
47+
sh_degree = pc.active_sh_degree
9448

49+
viewmat = viewpoint_camera.world_view_transform.transpose(0, 1) # [4, 4]
50+
render_colors, render_alphas, info = rasterization(
51+
means=means3D, # [N, 3]
52+
quats=rotations, # [N, 4]
53+
scales=scales, # [N, 3]
54+
opacities=opacity.squeeze(-1), # [N,]
55+
colors=colors,
56+
viewmats=viewmat[None], # [1, 4, 4]
57+
Ks=K[None], # [1, 3, 3]
58+
backgrounds=bg_color[None],
59+
width=int(viewpoint_camera.image_width),
60+
height=int(viewpoint_camera.image_height),
61+
packed=False,
62+
sh_degree=sh_degree,
63+
)
64+
# [1, H, W, 3] -> [3, H, W]
65+
rendered_image = render_colors[0].permute(2, 0, 1)
66+
radii = info["radii"].squeeze(0) # [N,]
67+
try:
68+
info["means2d"].retain_grad() # [1, N, 2]
69+
except:
70+
pass
71+
9572
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
9673
# They will be excluded from value updates used in the splitting criteria.
9774
return {"render": rendered_image,
98-
"viewspace_points": screenspace_points,
75+
"viewspace_points": info["means2d"],
9976
"visibility_filter" : radii > 0,
10077
"radii": radii}

scene/gaussian_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
402402

403403
torch.cuda.empty_cache()
404404

405-
def add_densification_stats(self, viewspace_point_tensor, update_filter):
406-
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
405+
def add_densification_stats(self, viewspace_point_tensor, update_filter, width, height):
406+
grad = viewspace_point_tensor.grad.squeeze(0) # [N, 2]
407+
# Normalize the gradient to [-1, 1] screen size
408+
grad[:, 0] *= width * 0.5
409+
grad[:, 1] *= height * 0.5
410+
self.xyz_gradient_accum[update_filter] += torch.norm(grad[update_filter,:2], dim=-1, keepdim=True)
407411
self.denom[update_filter] += 1

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,15 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
107107
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
108108
if (iteration in saving_iterations):
109109
print("\n[ITER {}] Saving Gaussians".format(iteration))
110+
mem = torch.cuda.max_memory_allocated() / 1024**3
111+
print(f"Max memory used: {mem:.2f} GB")
110112
scene.save(iteration)
111113

112114
# Densification
113115
if iteration < opt.densify_until_iter:
114116
# Keep track of max radii in image-space for pruning
115117
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
116-
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
118+
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter, image.shape[2], image.shape[1])
117119

118120
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
119121
size_threshold = 20 if iteration > opt.opacity_reset_interval else None

0 commit comments

Comments
 (0)