
Commit 43a0061
some configuration speedups, loops aren't actually needed!
AntonioMacaronio committed Jun 15, 2024
1 parent 536c6ca commit 43a0061
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -339,11 +339,11 @@ class VanillaDataManagerConfig(DataManagerConfig):
     """
     patch_size: int = 1
     """Size of patch to sample from. If > 1, patch-based sampling will be used."""
-    dataloader_prefetch_size: int = 8
+    dataloader_prefetch_size: int = 1
     """The maximum number of batches a worker will start loading once an iterator is created.
     Each next() call on the iterator has the CPU prepare more batches up to this
     limit while the GPU is performing forward and backward passes on the model."""
-    dataloader_num_workers: int = 1
+    dataloader_num_workers: int = 2
     """The number of workers performing the dataloading from either disk/RAM, which
     includes undistortion, pixel sampling, ray generation, collating, etc."""
     use_ray_train_dataloader: bool = True
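
The interplay between these two options is easiest to see against a stock PyTorch loader. Below is a minimal sketch, assuming a torch.utils.data.DataLoader-style setup; DummyRayBatchDataset and the exact wiring are illustrative stand-ins, not nerfstudio's implementation.

    # A minimal sketch, NOT part of this commit; the dataset class is hypothetical.
    import torch
    from torch.utils.data import DataLoader, Dataset

    class DummyRayBatchDataset(Dataset):
        """Stand-in for a dataset whose __getitem__ already returns a full batch."""

        def __len__(self):
            return 100

        def __getitem__(self, idx):
            # Stand-in for undistortion, pixel sampling, ray generation, collating, etc.
            return torch.rand(4096, 3)

    loader = DataLoader(
        DummyRayBatchDataset(),
        batch_size=None,    # each __getitem__ already yields a full batch
        num_workers=2,      # cf. dataloader_num_workers: parallel CPU workers
        prefetch_factor=1,  # cf. dataloader_prefetch_size: batches each worker keeps
                            # ready while the GPU runs forward/backward passes
    )
    batch = next(iter(loader))  # the CPU keeps prefetching up to the limit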
@@ -513,7 +513,7 @@ def __iter__(self):
             slice_start : slice_start + per_worker
         ]  # the indices of the datapoints in the dataset this worker will load
         r = random.Random(3301)
-        loop_iterations = 32
+        loop_iterations = 1
         num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch // loop_iterations  # default train_num_rays_per_batch is 4096
         worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop)
         if self.ray_generator is None:
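
A quick check of the arithmetic above, assuming the default train_num_rays_per_batch of 4096 noted in the inline comment:

    train_num_rays_per_batch = 4096  # default, per the inline comment above
    loop_iterations = 1              # after this commit (previously 32)
    num_rays_per_loop = train_num_rays_per_batch // loop_iterations
    assert num_rays_per_loop == 4096  # one full-size pixel sample per batch
    # Before the commit: 4096 // 32 == 128 rays per loop, over 32 iterations.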
@@ -524,14 +524,32 @@ def __iter__(self):
         for _ in range(loop_iterations):
             r.shuffle(worker_indices)
             image_indices = worker_indices[:self.num_images_to_sample_from]  # get a total of 'num_images_to_sample_from' image indices
+
+            # self._get_collated_batch is slow because it goes to disk many times to retrieve the images that make up a batch.
             collated_batch = self._get_collated_batch(image_indices)
+
+            """
+            Here, the variable 'batch' refers to the output of our pixel sampler. In particular:
+            - batch is a dict with keys ['image', 'indices'] - the output of the pixel_sampler
+            - batch['image'] is a pytorch tensor of shape torch.Size([4096, 3]), where 4096 = num_rays_per_batch. Note: each row in this tensor holds the RGB values, as floats in [0, 1], of the pixel the ray goes through; which specific image index that pixel belongs to is stored within batch['indices'].
+            - batch['indices'] is a pytorch tensor of shape torch.Size([4096, 3]), where each row represents (image_index=camera_index, pixelRow, pixelCol)
+            What the pixel_sampler does (for variable_res_collate) is loop through each image, sample pixels within the mask,
+            and return them as the variable `indices`, which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, y_pos, x_pos)
+            """
             batch = worker_pixel_sampler.sample(collated_batch)  # the pixel_sampler will sample num_rays_per_batch pixels.
+
             ray_indices = batch["indices"]
             ray_bundle = self.ray_generator(ray_indices)
             ray_bundle_list.append(ray_bundle)
             batch_list.append(batch)
+
+        combined_metadata = {}
+        if "fisheye_crop_radius" in ray_bundle_list[0].metadata:
+            combined_metadata["fisheye_crop_radius"] = ray_bundle_list[0].metadata["fisheye_crop_radius"]
+        if "directions_norm" in ray_bundle_list[0].metadata:
+            combined_metadata["directions_norm"] = torch.cat([ray_bundle_i.metadata["directions_norm"] for ray_bundle_i in ray_bundle_list], dim=0)
 
         concatenated_ray_bundle = RayBundle(
             origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0),
             directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0),
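
The shapes described in the new docstring can be sanity-checked with a toy stand-in for the pixel sampler. This is a minimal sketch, not nerfstudio's sampler; the uniform random sampling below ignores masks and variable resolutions.

    import torch

    num_rays_per_batch = 4096
    num_images, H, W = 8, 64, 64
    collated_images = torch.rand(num_images, H, W, 3)  # stand-in for collated_batch["image"]

    # Random (image_idx, pixelRow, pixelCol) triples, like batch["indices"].
    indices = torch.stack(
        [
            torch.randint(0, num_images, (num_rays_per_batch,)),
            torch.randint(0, H, (num_rays_per_batch,)),
            torch.randint(0, W, (num_rays_per_batch,)),
        ],
        dim=-1,
    )
    # Gather the RGB values of the sampled pixels, like batch["image"].
    image = collated_images[indices[:, 0], indices[:, 1], indices[:, 2]]

    batch = {"image": image, "indices": indices}
    assert batch["image"].shape == torch.Size([4096, 3])
    assert batch["indices"].shape == torch.Size([4096, 3])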
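The tail of the diff concatenates the per-iteration ray bundles into a single batch. The pattern, rewritten here with plain tensors and a hypothetical metadata dict instead of nerfstudio's RayBundle, looks like this:

    import torch

    loop_iterations = 1  # after this commit; with one iteration the cat is a no-op
    origins_list = [torch.rand(4096, 3) for _ in range(loop_iterations)]
    directions_list = [torch.rand(4096, 3) for _ in range(loop_iterations)]
    metadata_list = [
        {"fisheye_crop_radius": 0.9, "directions_norm": torch.ones(4096, 1)}
        for _ in range(loop_iterations)
    ]  # hypothetical per-bundle metadata values

    # Scalar-like metadata is taken from the first bundle; per-ray metadata is concatenated.
    combined_metadata = {}
    if "fisheye_crop_radius" in metadata_list[0]:
        combined_metadata["fisheye_crop_radius"] = metadata_list[0]["fisheye_crop_radius"]
    if "directions_norm" in metadata_list[0]:
        combined_metadata["directions_norm"] = torch.cat(
            [m["directions_norm"] for m in metadata_list], dim=0
        )

    combined_origins = torch.cat(origins_list, dim=0)  # shape (4096 * loop_iterations, 3)
    combined_directions = torch.cat(directions_list, dim=0)
    assert combined_origins.shape[0] == 4096 * loop_iterations

With loop_iterations = 1, num_rays_per_loop already equals train_num_rays_per_batch, so each list holds a single bundle and the concatenation degenerates to a pass-through, which is what the commit message means by "loops aren't actually needed".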
