Skip to content

Commit

Permalink
defaulted grid_sizes in points2vols
Browse files Browse the repository at this point in the history
Summary: Fix #873, that grid_sizes defaults to the wrong dtype in points2volumes code, and mask doesn't have a proper default.

Reviewed By: nikhilaravi

Differential Revision: D31503545

fbshipit-source-id: fa32a1a6074fc7ac7bdb362edfb5e5839866a472
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 16, 2021
1 parent 2f2466f commit 34b1b4a
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch3d/common/workaround/symeig3x3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple, Optional
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
Expand Down
6 changes: 5 additions & 1 deletion pytorch3d/ops/points_to_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def add_points_features_to_volume_densities_features(
# grid sizes shape (minibatch, 3)
grid_sizes = (
torch.LongTensor(list(volume_densities.shape[2:]))
.to(volume_densities)
.to(volume_densities.device)
.expand(volume_densities.shape[0], 3)
)

Expand All @@ -386,6 +386,10 @@ def add_points_features_to_volume_densities_features(
splat = False
else:
raise ValueError('No such interpolation mode "%s"' % mode)

if mask is None:
mask = points_3d.new_ones(1).expand(points_3d.shape[:2])

volume_densities, volume_features = _points_to_volumes(
points_3d,
points_features,
Expand Down
3 changes: 2 additions & 1 deletion tests/bm_symeig3x3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@


from itertools import product
from typing import Callable, Any
from typing import Any, Callable

import torch
from common_testing import get_random_cuda_device
from fvcore.common.benchmark import benchmark
from pytorch3d.common.workaround import symeig3x3
from test_symeig3x3 import TestSymEig3x3


torch.set_num_threads(1)

CUDA_DEVICE = get_random_cuda_device()
Expand Down
1 change: 1 addition & 0 deletions tests/test_iou_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation


OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data"
DEBUG = False
Expand Down
16 changes: 15 additions & 1 deletion tests/test_points_to_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.ops import (
add_pointclouds_to_volumes,
add_points_features_to_volume_densities_features,
)
from pytorch3d.ops.points_to_volumes import _points_to_volumes
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
Expand Down Expand Up @@ -373,6 +376,17 @@ def test_from_point_cloud(self, interp_mode="trilinear"):
else:
self.assertTrue(torch.isfinite(field.grad.data).all())

def test_defaulted_arguments(self):
points = torch.rand(30, 1000, 3)
features = torch.rand(30, 1000, 5)
_, densities = add_points_features_to_volume_densities_features(
points,
features,
torch.zeros(30, 1, 32, 32, 32),
torch.zeros(30, 5, 32, 32, 32),
)
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)

def _check_volume_slice_color_density(
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
):
Expand Down

0 comments on commit 34b1b4a

Please sign in to comment.