|
9 | 9 | # For inquiries contact george.drettakis@inria.fr |
10 | 10 | # |
11 | 11 |
|
12 | | -import torch |
13 | 12 | import math |
14 | | -from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer |
| 13 | + |
| 14 | +import torch |
| 15 | +from torch.nn import functional as F |
| 16 | +from gsplat import rasterization |
15 | 17 | from scene.gaussian_model import GaussianModel |
16 | | -from utils.sh_utils import eval_sh |
17 | 18 |
|
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
    """
    Render the scene from ``viewpoint_camera`` with the gsplat rasterizer.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: camera providing FoVx/FoVy, image_width/image_height,
            world_view_transform (row-major, hence the transpose below) and
            camera intrinsics-defining fields.
        pc: the GaussianModel holding means, scales, rotations, opacities and
            SH features.
        pipe: pipeline options object. Kept for interface compatibility with
            the original renderer; not consulted by the gsplat path.
        bg_color: background color tensor, must live on the GPU.
        scaling_modifier: multiplier applied to the Gaussian scales.
        override_color: optional [N, 3] per-Gaussian RGB that bypasses SH
            evaluation when provided.

    Returns:
        dict with:
            "render": [3, H, W] rendered image,
            "viewspace_points": gsplat's 2D means (grad retained so the
                densification logic can read screen-space gradients),
            "visibility_filter": boolean mask of Gaussians with radius > 0,
            "radii": on-screen radii reported by gsplat.
    """
    # Derive pinhole intrinsics K from the field of view.
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
    focal_length_x = viewpoint_camera.image_width / (2 * tanfovx)
    focal_length_y = viewpoint_camera.image_height / (2 * tanfovy)
    K = torch.tensor(
        [
            [focal_length_x, 0, viewpoint_camera.image_width / 2.0],
            [0, focal_length_y, viewpoint_camera.image_height / 2.0],
            [0, 0, 1],
        ],
        device="cuda",
    )

    means3D = pc.get_xyz
    opacity = pc.get_opacity
    # gsplat has no scale_modifier argument, so the modifier is folded into
    # the scales directly (equivalent for the 3D covariance).
    scales = pc.get_scaling * scaling_modifier
    rotations = pc.get_rotation
    if override_color is not None:
        colors = override_color  # [N, 3]; precomputed RGB, no SH evaluation
        sh_degree = None
    else:
        colors = pc.get_features  # [N, K, 3] SH coefficients
        sh_degree = pc.active_sh_degree

    # world_view_transform is stored transposed (row-major OpenGL-style);
    # gsplat expects a standard [4, 4] world-to-camera matrix.
    viewmat = viewpoint_camera.world_view_transform.transpose(0, 1)  # [4, 4]
    render_colors, render_alphas, info = rasterization(
        means=means3D,  # [N, 3]
        quats=rotations,  # [N, 4]
        scales=scales,  # [N, 3]
        opacities=opacity.squeeze(-1),  # [N,]
        colors=colors,
        viewmats=viewmat[None],  # [1, 4, 4] — single-camera batch
        Ks=K[None],  # [1, 3, 3]
        backgrounds=bg_color[None],
        width=int(viewpoint_camera.image_width),
        height=int(viewpoint_camera.image_height),
        packed=False,
        sh_degree=sh_degree,
    )
    # [1, H, W, 3] -> [3, H, W]
    rendered_image = render_colors[0].permute(2, 0, 1)
    # Drop the camera batch dimension.
    # NOTE(review): assumes info["radii"] is [1, N]; newer gsplat releases
    # return per-axis radii [1, N, 2] — confirm against the pinned version.
    radii = info["radii"].squeeze(0)  # [N,]
    try:
        # Keep gradients on the screen-space means so the densification
        # criteria can read them after backward().
        info["means2d"].retain_grad()  # [1, N, 2]
    except RuntimeError:
        # retain_grad raises RuntimeError on tensors that do not require
        # grad (e.g. under torch.no_grad() during evaluation) — best effort.
        pass

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
            "viewspace_points": info["means2d"],
            "visibility_filter" : radii > 0,
            "radii": radii}
0 commit comments