Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pose optimization to Splatfacto #2885

Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torch.nn import Parameter
from typing_extensions import Literal

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
Expand Down Expand Up @@ -157,6 +158,8 @@ class SplatfactoModelConfig(ModelConfig):
"""
output_depth_during_training: bool = False
"""If True, output depth during training. Otherwise, only output depth during evaluation."""
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="SO3xR3"))
Copy link
Collaborator

@jb-ye jb-ye Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I set this camera optimizer to be off, would it still trigger computation overheads w.r.t. computing the gradients w.r.t. to view_mat and proj_mat? If yes, how much overheads it brings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check this. Thank you for the suggestion!

"""Config of the camera optimizer to use"""


class SplatfactoModel(Model):
Expand Down Expand Up @@ -213,6 +216,10 @@ def populate_modules(self):

self.opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(self.num_points, 1)))

self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
num_cameras=self.num_train_data, device="cpu"
)

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
Expand Down Expand Up @@ -624,6 +631,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
Mapping of different parameter groups
"""
gps = self.get_gaussian_param_groups()
self.camera_optimizer.get_param_groups(param_groups=gps)
return gps

def _get_downscale_factor(self):
Expand Down Expand Up @@ -660,6 +668,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
background = torch.zeros(3, device=self.device)
else:
background = self.background_color.to(self.device)
self.camera_optimizer.apply_to_camera(camera)
else:
if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
Expand Down Expand Up @@ -787,7 +796,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
W,
background=torch.zeros(3, device=self.device),
)[..., 0:1] # type: ignore
depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())
depth_im[alpha > 0] = depth_im[alpha > 0] / alpha[alpha > 0]
jh-surh marked this conversation as resolved.
Show resolved Hide resolved
depth_im[alpha == 0] = 1000

return {"rgb": rgb, "depth": depth_im, "accumulation": alpha} # type: ignore

Expand Down Expand Up @@ -824,6 +834,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

metrics_dict["gaussian_count"] = self.num_points

self.camera_optimizer.get_metrics_dict(metrics_dict)
return metrics_dict

def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
Expand Down Expand Up @@ -861,11 +873,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
else:
scale_reg = torch.tensor(0.0).to(self.device)

return {
loss_dict = {
"main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
"scale_reg": scale_reg,
}

if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)

return loss_dict

@torch.no_grad()
def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
"""Takes in a camera, generates the raybundle, and computes the output of the model.
Expand Down
Loading