
Add scannetpp dataparser #2498

Merged
merged 8 commits into from Oct 19, 2023
Conversation

liu115 (Contributor) commented Oct 9, 2023

A dataparser for the ScanNet++ dataset, adapted from the nerfstudio dataparser. It follows the ScanNet++ file structure and is used by specifying the dataset root and the scene ID.

Example usage

DATA_ROOT=SCANNETPP_DATASET_ROOT
SCENE_ID=036bce3393

ns-train nerfacto \
     --max_num_iterations 100000 \
     --pipeline.datamanager.train_num_rays_per_batch 8192 \
     --pipeline.datamanager.train_num_images_to_sample_from 400 \
     --pipeline.datamanager.train_num_times_to_repeat_images 100 \
     --vis viewer+tensorboard \
     scannetpp-data \
     --data ${DATA_ROOT} \
     --scene-id ${SCENE_ID} \
     --scene_scale 1.5

The scene contains more than 1k images. Set train_num_images_to_sample_from and train_num_times_to_repeat_images to avoid CPU out-of-memory errors.

Training with the viewer (screenshot below). The images with red borders are the test images predefined by the dataset.
[image]

The resulting rendering:
https://github.com/nerfstudio-project/nerfstudio/assets/8998128/e4c01178-cafa-4e28-837c-6b0bdf5538d5

liu115 (Contributor, Author) commented Oct 10, 2023

I also noticed that training is much slower with masks (1-2 s per iteration for the default nerfacto), because the pixel sampler runs torch.nonzero on the CPU, which is expensive when the number of images and the image resolution are large.

if isinstance(mask, torch.Tensor):
    nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
    chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
    indices = nonzero_indices[chosen_indices]

One workaround is to sample fewer images at a time and resample them more frequently, e.g.

--pipeline.datamanager.train_num_images_to_sample_from 8 \
--pipeline.datamanager.train_num_times_to_repeat_images 1 \

Alternatively, the way the mask is used in the code could be changed. Since the invalid pixels are usually only a small fraction of all pixels across the images, we could precompute the mask == 0 pixels as a preprocessing step and use that information when sampling pixels for training (rough sketch below), or ignore the mask during sampling and only apply mask filtering during the loss computation.
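A rough sketch of the precompute idea (a hypothetical helper, not code from this PR or from nerfstudio): run torch.nonzero once over the stacked masks and cache the valid-pixel indices, so per-batch sampling only indexes into the cache.

import torch

class CachedMaskSampler:
    # Hypothetical helper: torch.nonzero runs once over the whole image stack,
    # and every batch afterwards just draws from the cached indices.
    def __init__(self, mask: torch.Tensor):
        # mask: (num_images, H, W, 1) bool tensor; True marks valid pixels.
        self.nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)  # (num_valid, 3)

    def sample(self, batch_size: int) -> torch.Tensor:
        # Uniformly draw batch_size (image, row, col) triples from the cached valid pixels.
        chosen = torch.randint(0, self.nonzero_indices.shape[0], (batch_size,))
        return self.nonzero_indices[chosen]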

tancik (Contributor) left a comment

LGTM, thanks for adding

tancik enabled auto-merge (squash) October 19, 2023 18:10
tancik merged commit f31f3bb into nerfstudio-project:main Oct 19, 2023
4 checks passed
OrangeSodahub (Contributor) commented:

@liu115 Hi, I was wondering whether you used nerfacto or nerfacto-huge, and whether each iteration really costs 1-2 s?

liu115 (Contributor, Author) commented Jan 10, 2024

I used nerfacto

One quick solution to avoid the mask-sampling bottleneck is not to use the mask during pixel sampling and instead apply it only during the loss computation.

For example, in nerfacto.py

    def get_loss_dict(self, outputs, batch, metrics_dict=None):
        loss_dict = {}
        image = batch["image"].to(self.device)
        pred_rgb, gt_rgb = self.renderer_rgb.blend_background_for_loss_computation(
            pred_image=outputs["rgb"],
            pred_accumulation=outputs["accumulation"],
            gt_image=image,
        )
        if "mask" in batch:
            # Keep only the rays whose pixels are valid; batch["mask"] is a per-ray bool tensor.
            mask = batch["mask"].view(-1).to(self.device)
            assert mask.shape[0] == gt_rgb.shape[0]
            loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb[mask], pred_rgb[mask])
        else:
            loss_dict["rgb_loss"] = self.rgb_loss(gt_rgb, pred_rgb)
        # ... remaining losses and the return statement stay as in the original method

and remove the if statement here

if isinstance(mask, torch.Tensor):
    nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
    chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
    indices = nonzero_indices[chosen_indices]

OrangeSodahub (Contributor) commented:

@liu115 Hi, I have another thing I want to confirm with you. There are two ways I can get the extrinsic matrix of a camera: one is from the colmap/images.txt file using the qvec2rotmat method you provided, in which case the matrix is in OpenCV format, right? The other is from transforms.json, which is in OpenGL format, is that right?

liu115 (Contributor, Author) commented Feb 26, 2024

Yes. transforms.json follows the nerfstudio/OpenGL format, and the colmap files follow the format defined by COLMAP.

OrangeSodahub (Contributor) commented:

@liu115 But I found that the translation vectors extracted from the extrinsic matrices from transforms.json and from the colmap files are different. Specifically, mat_A is from transforms.json, mat_B is from the colmap file, and

mat_A[..., :3, 3] != mat_B[..., :3, 3]

liu115 (Contributor, Author) commented Mar 2, 2024

For reference, this is the conversion applied when keep_original_world_coordinate is false:

w2c = np.concatenate([rotation, translation], 1)
w2c = np.concatenate([w2c, np.array([[0, 0, 0, 1]])], 0)
c2w = np.linalg.inv(w2c)
# Convert from COLMAP's camera coordinate system (OpenCV) to ours (OpenGL)
c2w[0:3, 1:3] *= -1
if not keep_original_world_coordinate:
    c2w = c2w[np.array([0, 2, 1, 3]), :]
    c2w[2, :] *= -1
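
Putting the two together (a sketch, assuming qvec2rotmat from the COLMAP scripts and that transforms.json stores the already-converted c2w): applying this conversion to the colmap/images.txt pose should reproduce the transforms.json matrix, which is why the raw translation vectors differ.

import numpy as np

def colmap_to_nerfstudio_c2w(qvec, tvec, keep_original_world_coordinate=False):
    # qvec, tvec come from colmap/images.txt and define world-to-camera in OpenCV convention.
    rotation = qvec2rotmat(qvec)  # assumes COLMAP's qvec2rotmat helper is available
    w2c = np.concatenate([rotation, tvec.reshape(3, 1)], 1)
    w2c = np.concatenate([w2c, np.array([[0, 0, 0, 1]])], 0)
    c2w = np.linalg.inv(w2c)
    c2w[0:3, 1:3] *= -1  # OpenCV -> OpenGL camera axes
    if not keep_original_world_coordinate:
        c2w = c2w[np.array([0, 2, 1, 3]), :]  # swap the world y and z axes
        c2w[2, :] *= -1
    return c2w  # should match the corresponding matrix in transforms.json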

ozgurcelik commented May 25, 2024

Hi. I can run scannetpp data with nerfacto and instant-ngp but it doesn't work with splatfacto:

ns-train nerfacto scannetpp-data --data /cluster/scratch/tcelik/scannet++/data/56a0ec536c works, and so does ns-train instant-ngp scannetpp-data --data /cluster/scratch/tcelik/scannet++/data/56a0ec536c. But when I run ns-train splatfacto scannetpp-data --data /cluster/scratch/tcelik/scannet++/data/56a0ec536c, I get

ns-train 8
sys.exit(entrypoint())

train.py 262 entrypoint
main(

train.py 247 main
launch(

train.py 189 launch
main_func(local_rank=0, world_size=world_size, config=config)

train.py 100 train_loop
trainer.train()

trainer.py 261 train
loss, loss_dict, metrics_dict = self.train_iteration(step)

profiler.py 112 inner
out = func(*args, **kwargs)

trainer.py 496 train_iteration
_, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)

profiler.py 112 inner
out = func(*args, **kwargs)

base_pipeline.py 303 get_train_loss_dict
loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)

splatfacto.py 868 get_loss_dict
mask = self._downscale_if_required(batch["mask"])

splatfacto.py 649 _downscale_if_required
return resize_image(image, d)

splatfacto.py 95 resize_image
return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)

RuntimeError: "slow_conv2d_cpu" not implemented for 'Bool'

Do you have any idea why this might be? Just to clarify, ns-train splatfacto --data data/nerfstudio/poster runs fine on the demo dataset, so I don't think it is purely a splatfacto (Gaussian splatting) problem. I can run it by changing the dataparser output for masks:

dataparser_outputs = DataparserOutputs(
    image_filenames=image_filenames,
    cameras=cameras,
    scene_box=scene_box,
    # mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
    mask_filenames=None,
    dataparser_scale=scale_factor,
    dataparser_transform=transform_matrix,
    metadata={},
)

but is there a fix for it?
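
One possible fix (just a sketch, not verified against the current splatfacto code): cast the boolean mask to float before the strided-convolution downscale and re-binarize afterwards, instead of feeding a bool tensor to conv2d. The resize_mask name below is hypothetical.

import torch
import torch.nn.functional as tf

def resize_mask(mask: torch.Tensor, d: int) -> torch.Tensor:
    # mask: (H, W, 1) bool tensor; d: integer downscale factor.
    # conv2d has no bool kernel, so average-pool in float and threshold back to bool.
    weight = torch.full((1, 1, d, d), 1.0 / (d * d), dtype=torch.float32, device=mask.device)
    mask_f = mask.to(torch.float32).permute(2, 0, 1)[:, None, ...]  # (1, 1, H, W)
    down = tf.conv2d(mask_f, weight, stride=d).squeeze(1).permute(1, 2, 0)  # (H//d, W//d, 1)
    return down > 0.5  # a downscaled pixel counts as valid if most of its sources were valid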
