Camera pose optimization for Splatfacto #2891

Merged · 7 commits · Apr 11, 2024
32 changes: 21 additions & 11 deletions nerfstudio/cameras/camera_optimizers.py
@@ -22,6 +22,7 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, Type, Union

import numpy
import torch
import tyro
from jaxtyping import Float, Int
@@ -151,15 +152,20 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()

def apply_to_camera(self, camera: Cameras) -> None:
"""Apply the pose correction to the raybundle"""
if self.config.mode != "off":
assert camera.metadata is not None, "Must provide id of camera in its metadata"
assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
camera_idx = camera.metadata["cam_idx"]
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj)
def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
"""Apply the pose correction to the world-to-camera matrix in a Camera object"""
if self.config.mode == "off":
return camera.camera_to_worlds

assert camera.metadata is not None, "Must provide id of camera in its metadata"
if "cam_idx" not in camera.metadata:
# Evaluation cameras may not carry a cam_idx; leave the pose unchanged
return camera.camera_to_worlds

camera_idx = camera.metadata["cam_idx"]
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
return torch.bmm(camera.camera_to_worlds, adj)

def get_loss_dict(self, loss_dict: dict) -> None:
"""Add regularization"""
@@ -176,8 +182,12 @@ def get_correction_matrices(self):
def get_metrics_dict(self, metrics_dict: dict) -> None:
"""Get camera optimizer metrics"""
if self.config.mode != "off":
metrics_dict["camera_opt_translation"] = self.pose_adjustment[:, :3].norm()
metrics_dict["camera_opt_rotation"] = self.pose_adjustment[:, 3:].norm()
trans = self.pose_adjustment[:, :3].detach().norm(dim=-1)
rot = self.pose_adjustment[:, 3:].detach().norm(dim=-1)
metrics_dict["camera_opt_translation_max"] = trans.max()
metrics_dict["camera_opt_translation_mean"] = trans.mean()
metrics_dict["camera_opt_rotation_mean"] = numpy.rad2deg(rot.mean().cpu())
metrics_dict["camera_opt_rotation_max"] = numpy.rad2deg(rot.max().cpu())

def get_param_groups(self, param_groups: dict) -> None:
"""Get camera optimizer parameters"""
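The key API change in this file: apply_to_camera no longer mutates the Cameras object in place; it now returns the corrected camera-to-world matrix, falling back to the unmodified matrix when the optimizer mode is "off" or the camera carries no cam_idx. The correction itself is a right-multiplication of the 3x4 camera-to-world matrix by the learned adjustment promoted to a 4x4 homogeneous transform. Below is a minimal standalone sketch of that composition, assuming batched [N, 3, 4] inputs; the helper name is hypothetical and not part of nerfstudio.

import torch

def compose_pose_correction(camera_to_worlds: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    # camera_to_worlds: [N, 3, 4] poses; adj: [N, 3, 4] adjustment from the optimizer.
    # Append the homogeneous row [0, 0, 0, 1] so the adjustment becomes [N, 4, 4].
    bottom = torch.tensor([0.0, 0.0, 0.0, 1.0]).to(adj)[None, None].expand(adj.shape[0], 1, 4)
    adj_4x4 = torch.cat([adj, bottom], dim=1)
    # Right-multiply, as apply_to_camera does: [N, 3, 4] @ [N, 4, 4] -> [N, 3, 4].
    return torch.bmm(camera_to_worlds, adj_4x4)

With an identity adjustment (torch.eye(4)[:3].unsqueeze(0)), the input pose comes back unchanged, which matches the "off" and missing-cam_idx early returns above.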
12 changes: 8 additions & 4 deletions nerfstudio/configs/method_configs.py
@@ -632,8 +632,10 @@
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
Collaborator: Could you add this change to "splatfacto-big" config?

"scheduler": ExponentialDecaySchedulerConfig(
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
@@ -684,8 +686,10 @@
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
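Both config hunks apply the same change to the splatfacto and splatfacto-big methods (which is what the review comment above asks for): the camera-optimizer learning rate drops from 1e-3 to 1e-4, and the exponential decay now ends at 5e-7 with a 1000-step warmup that starts from zero. As a rough sketch, the same optimizer/scheduler pairing could be reused in a custom method config as below; the argument values are copied from the diff, while the import paths are assumptions based on the usual nerfstudio layout.

from nerfstudio.engine.optimizers import AdamOptimizerConfig
from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig

# Camera-pose parameters get a smaller learning rate than the Gaussian parameters,
# plus a warmed-up exponential decay, matching the values in the hunks above.
camera_opt_optimizers = {
    "camera_opt": {
        "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
        "scheduler": ExponentialDecaySchedulerConfig(
            lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
        ),
    }
}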
29 changes: 25 additions & 4 deletions nerfstudio/models/splatfacto.py
@@ -33,6 +33,7 @@
from torch.nn import Parameter
from typing_extensions import Literal

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
@@ -146,6 +147,8 @@ class SplatfactoModelConfig(ModelConfig):
However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that
were implemented for classic mode can not render antialiased mode PLY properly without modifications.
"""
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
"""Config of the camera optimizer to use"""


class SplatfactoModel(Model):
@@ -213,6 +216,10 @@ def populate_modules(self):
}
)

self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
num_cameras=self.num_train_data, device="cpu"
)

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
@@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
Mapping of different parameter groups
"""
gps = self.get_gaussian_param_groups()
self.camera_optimizer.get_param_groups(param_groups=gps)
return gps

def _get_downscale_factor(self):
@@ -648,6 +656,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

# get the background color
if self.training:
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]

if self.config.background_color == "random":
background = torch.rand(3, device=self.device)
elif self.config.background_color == "white":
@@ -657,6 +667,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
else:
background = self.background_color.to(self.device)
else:
optimized_camera_to_world = camera.camera_to_worlds[0, ...]

if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
else:
@@ -674,8 +686,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
camera_downscale = self._get_downscale_factor()
camera.rescale_output_resolution(1 / camera_downscale)
# shift the camera to center of scene looking at center
R = camera.camera_to_worlds[0, :3, :3] # 3 x 3
T = camera.camera_to_worlds[0, :3, 3:4] # 3 x 1
R = optimized_camera_to_world[:3, :3] # 3 x 3
T = optimized_camera_to_world[:3, 3:4] # 3 x 1

# flip the z and y axes to align with gsplat conventions
R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
R = R @ R_edit
@@ -738,7 +751,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
self.xys.retain_grad()

if self.config.sh_degree > 0:
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
@@ -829,6 +842,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

metrics_dict["gaussian_count"] = self.num_points

self.camera_optimizer.get_metrics_dict(metrics_dict)
return metrics_dict

def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
@@ -867,11 +882,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
else:
scale_reg = torch.tensor(0.0).to(self.device)

return {
loss_dict = {
"main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
"scale_reg": scale_reg,
}

if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)

return loss_dict

@torch.no_grad()
def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
"""Takes in a camera, generates the raybundle, and computes the output of the model.
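Taken together, the splatfacto.py changes make SplatfactoModel own a CameraOptimizer: it is constructed in populate_modules, contributes a parameter group, supplies the optimized camera-to-world matrix used for rasterization and for the spherical-harmonics view directions during training, and adds its regularization terms to the loss dict and its statistics to the metrics dict. Since the new camera_optimizer field defaults to mode="off", pose optimization has to be enabled explicitly. A minimal sketch in Python, assuming the mode names exposed by CameraOptimizerConfig (of which "SO3xR3" is one non-off option):

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.models.splatfacto import SplatfactoModelConfig

# Enable SO(3) x R^3 pose refinement; the default introduced by this PR is "off".
model_config = SplatfactoModelConfig(
    camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
)

On the command line the equivalent would presumably be something like ns-train splatfacto --pipeline.model.camera-optimizer.mode SO3xR3, assuming tyro's default kebab-case flag naming.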