Add pose optimization to Splatfacto #2885

Closed
nerfstudio/models/splatfacto.py (20 changes: 19 additions, 1 deletion)
@@ -33,6 +33,7 @@
from torch.nn import Parameter
from typing_extensions import Literal

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
@@ -136,6 +137,8 @@ class SplatfactoModelConfig(ModelConfig):
"""
output_depth_during_training: bool = False
"""If True, output depth during training. Otherwise, only output depth during evaluation."""
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
"""Config of the camera optimizer to use"""
rasterize_mode: Literal["classic", "antialiased"] = "classic"
"""
Classic mode of rendering will use the EWA volume splatting with a [0.3, 0.3] screen space blurring kernel. This
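The new `camera_optimizer` field defaults to `mode="off"`, so existing Splatfacto configs behave as before. A minimal sketch of opting in, using the `SO3xR3` mode name from nerfstudio's existing `CameraOptimizerConfig` (the choice of mode here is illustrative, not something this PR prescribes):

```python
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.models.splatfacto import SplatfactoModelConfig

# Sketch: enable per-camera pose refinement by swapping the default "off" mode
# for one of the optimizer's pose parameterizations (here "SO3xR3").
model_config = SplatfactoModelConfig(
    camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
)
```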
@@ -213,6 +216,10 @@ def populate_modules(self):
}
)

self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
num_cameras=self.num_train_data, device="cpu"
)

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
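`populate_modules` now builds the camera optimizer with one entry per training camera. A rough conceptual sketch of the state such an optimizer holds; the tensor name and 6-DoF layout below are assumptions about its internals, not code from this PR:

```python
import torch
from torch.nn import Parameter

# Rough sketch (not the library internals): one 6-DoF tangent-space correction
# per training camera, initialized to zero so training starts from the input poses.
num_train_data = 300  # hypothetical number of training cameras
pose_adjustment = Parameter(torch.zeros(num_train_data, 6))  # [tx, ty, tz, rx, ry, rz]
```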
@@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
Mapping of different parameter groups
"""
gps = self.get_gaussian_param_groups()
self.camera_optimizer.get_param_groups(param_groups=gps)
return gps

def _get_downscale_factor(self):
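Passing the existing Gaussian parameter groups into `get_param_groups` lets the camera optimizer append its own group in place, so the trainer can give the pose parameters a separate optimizer and learning rate. A sketch of wiring that group into a method's optimizers mapping; the `"camera_opt"` group name and the hyperparameters are assumed placeholders:

```python
from nerfstudio.engine.optimizers import AdamOptimizerConfig

# Sketch: give the pose-correction group its own optimizer. The "camera_opt" group
# name and the values below are placeholders, not settings taken from this PR.
optimizers = {
    "camera_opt": {
        "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
        "scheduler": None,
    },
}
```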
@@ -656,6 +664,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
background = torch.zeros(3, device=self.device)
else:
background = self.background_color.to(self.device)
self.camera_optimizer.apply_to_camera(camera)
else:
if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
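`apply_to_camera` is called only on the training branch, so evaluation and viewer renders keep the original poses. Conceptually it composes the learned correction with the camera's stored pose; a rough sketch of that idea, where the 4x4 form and composition order are illustrative assumptions rather than the optimizer's actual implementation:

```python
import torch

# Rough sketch of the idea: render from a corrected camera-to-world transform.
c2w = torch.eye(4)         # hypothetical camera-to-world pose from the dataset
correction = torch.eye(4)  # hypothetical learned SE(3) correction for this camera
adjusted_c2w = correction @ c2w  # composition order depends on the optimizer's convention
```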
@@ -789,6 +798,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
background=torch.zeros(3, device=self.device),
)[..., 0:1] # type: ignore
depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())
depth_im[alpha == 0] = 1000

return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background} # type: ignore

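Pixels with zero accumulated alpha previously inherited the detached maximum depth; the added line overwrites them with a fixed far value of 1000 instead. A tiny self-contained example of the masking behaviour:

```python
import torch

# Tiny example of the fallback: zero-accumulation pixels end up at a fixed far depth.
alpha = torch.tensor([[0.0], [0.8]])
depth_im = torch.tensor([[0.0], [2.4]])
depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max())
depth_im[alpha == 0] = 1000
print(depth_im)  # tensor([[1000.], [3.]])
```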
@@ -829,6 +839,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

metrics_dict["gaussian_count"] = self.num_points

self.camera_optimizer.get_metrics_dict(metrics_dict)
return metrics_dict

def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
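During training the camera optimizer also writes its correction magnitudes into the metrics dict, which makes pose drift visible in the logs. A conceptual sketch of what such metrics could look like; the tensor layout and key names are assumptions, not the optimizer's actual code:

```python
import torch

# Conceptual sketch: report how far the learned corrections have moved from zero.
pose_adjustment = torch.zeros(300, 6)  # hypothetical [tx, ty, tz, rx, ry, rz] per camera
metrics_dict = {}
metrics_dict["camera_opt_translation"] = pose_adjustment[:, :3].norm()
metrics_dict["camera_opt_rotation"] = pose_adjustment[:, 3:].norm()
```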
@@ -867,11 +879,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
else:
scale_reg = torch.tensor(0.0).to(self.device)

return {
loss_dict = {
"main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
"scale_reg": scale_reg,
}

if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)

return loss_dict

@torch.no_grad()
def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
"""Takes in a camera, generates the raybundle, and computes the output of the model.