Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (4/16)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent f34104c commit 3da7703
Show file tree
Hide file tree
Showing 31 changed files with 130 additions and 106 deletions.
14 changes: 8 additions & 6 deletions projects/nerf/nerf/nerf_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,14 @@ def forward(
# For a full render pass concatenate the output chunks,
# and reshape to image size.
out = {
k: torch.cat(
[ch_o[k] for ch_o in chunk_outputs],
dim=1,
).view(-1, *self._image_size, 3)
if chunk_outputs[0][k] is not None
else None
k: (
torch.cat(
[ch_o[k] for ch_o in chunk_outputs],
dim=1,
).view(-1, *self._image_size, 3)
if chunk_outputs[0][k] is not None
else None
)
for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
}
else:
Expand Down
10 changes: 5 additions & 5 deletions pytorch3d/implicitron/dataset/frame_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,11 +576,11 @@ def build(
camera_quality_score=safe_as_tensor(
sequence_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
point_cloud_quality_score=(
safe_as_tensor(point_cloud.quality_score, torch.float)
if point_cloud is not None
else None
),
)

fg_mask_np: Optional[np.ndarray] = None
Expand Down
6 changes: 3 additions & 3 deletions pytorch3d/implicitron/dataset/json_index_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
dimension of the cropping bounding box, relative to box size.
"""

frame_annotations_type: ClassVar[
Type[types.FrameAnnotation]
] = types.FrameAnnotation
frame_annotations_type: ClassVar[Type[types.FrameAnnotation]] = (
types.FrameAnnotation
)

path_manager: Any = None
frame_annotations_file: str = ""
Expand Down
8 changes: 5 additions & 3 deletions pytorch3d/implicitron/dataset/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ def get_implicitron_sequence_pointcloud(
frame_data.camera,
frame_data.image_rgb,
frame_data.depth_map,
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if mask_points and frame_data.fg_probability is not None
else None,
(
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if mask_points and frame_data.fg_probability is not None
else None
),
)

return point_cloud, frame_data
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def eval_batch(
image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"],
depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1]
if frame_data.depth_mask is not None
else None,
depth_mask=(
frame_data.depth_mask[:1] if frame_data.depth_mask is not None else None
),
visdom_env=visualize_visdom_env,
)

Expand Down
17 changes: 11 additions & 6 deletions pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,11 @@ def forward(
n_targets = (
1
if evaluation_mode == EvaluationMode.EVALUATION
else batch_size
if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size)
else (
batch_size
if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size)
)
)

# A helper function for selecting n_target first elements from the input
Expand All @@ -422,9 +424,12 @@ def safe_slice_targets(
ray_bundle: ImplicitronRayBundle = self.raysampler(
target_cameras,
evaluation_mode,
mask=mask_crop[:n_targets]
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None,
mask=(
mask_crop[:n_targets]
if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
)

# custom_args hold additional arguments to the implicit function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def __post_init__(self):
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
)
torch.nn.init.normal_(lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5)
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ def forward(
embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world,
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz
if self.input_xyz
else None,
xyz_embedding_function=(
self.harmonic_embedding_xyz if self.input_xyz else None
),
global_code=global_code,
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
Expand Down
19 changes: 12 additions & 7 deletions pytorch3d/implicitron/models/overfit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,12 @@ def forward(
ray_bundle: ImplicitronRayBundle = self.raysampler(
camera,
evaluation_mode,
mask=mask_crop
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None,
mask=(
mask_crop
if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
)

inputs_to_be_chunked = {}
Expand All @@ -381,10 +384,12 @@ def forward(
frame_timestamp=frame_timestamp,
)
implicit_functions = [
functools.partial(implicit_function, global_code=global_code)
if isinstance(implicit_function, Callable)
else functools.partial(
implicit_function.forward, global_code=global_code
(
functools.partial(implicit_function, global_code=global_code)
if isinstance(implicit_function, Callable)
else functools.partial(
implicit_function.forward, global_code=global_code
)
)
for implicit_function in implicit_functions
]
Expand Down
20 changes: 12 additions & 8 deletions pytorch3d/implicitron/models/renderer/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,12 @@ def __post_init__(self):
n_pts_per_ray=n_pts_per_ray_training,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
n_rays_per_image=(
self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None
),
n_rays_total=self.n_rays_total_training,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
Expand All @@ -160,10 +162,12 @@ def __post_init__(self):
n_pts_per_ray=n_pts_per_ray_evaluation,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
n_rays_per_image=(
self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None
),
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
)
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/implicitron/models/renderer/ray_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def ray_sampler(
]
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
p_out_mask,
:
:,
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but
# got `Tensor`.
][torch.arange(n_p_out), out_pts_idx]
Expand Down
24 changes: 12 additions & 12 deletions pytorch3d/implicitron/models/renderer/sdf_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __post_init__(

run_auto_creation(self)

self.ray_normal_coloring_network_args[
"feature_vector_size"
] = render_features_dimensions
self.ray_normal_coloring_network_args["feature_vector_size"] = (
render_features_dimensions
)
self._rgb_network = RayNormalColoringNetwork(
**self.ray_normal_coloring_network_args
)
Expand Down Expand Up @@ -201,15 +201,15 @@ def forward(
None, :, 0, :
]
normals_full.view(-1, 3)[surface_mask] = normals
render_full.view(-1, self.render_features_dimensions)[
surface_mask
] = self._rgb_network(
features,
differentiable_surface_points[None],
normals,
ray_bundle,
surface_mask[None, :, None],
pooling_fn=None, # TODO
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
self._rgb_network(
features,
differentiable_surface_points[None],
normals,
ray_bundle,
surface_mask[None, :, None],
pooling_fn=None, # TODO
)
)
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
Expand Down
6 changes: 3 additions & 3 deletions pytorch3d/implicitron/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ class _Registry:
"""

def __init__(self) -> None:
self._mapping: Dict[
Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
] = defaultdict(dict)
self._mapping: Dict[Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]] = (
defaultdict(dict)
)

def register(self, some_class: Type[_X]) -> Type[_X]:
"""
Expand Down
8 changes: 5 additions & 3 deletions pytorch3d/implicitron/tools/eval_video_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ def generate_eval_video_cameras(
fit = fit_circle_in_3d(
cam_centers,
angles=angle,
offset=angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None
else None,
offset=(
angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None
else None
),
up=angle.new_tensor(up),
)
traj = fit.generated_points
Expand Down
8 changes: 5 additions & 3 deletions pytorch3d/implicitron/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def cat_dataclass(batch, tensor_collator: Callable):
)
elif isinstance(elem_f, collections.abc.Mapping):
collated[f.name] = {
k: tensor_collator([getattr(e, f.name)[k] for e in batch])
if elem_f[k] is not None
else None
k: (
tensor_collator([getattr(e, f.name)[k] for e in batch])
if elem_f[k] is not None
else None
)
for k in elem_f
}
else:
Expand Down
1 change: 0 additions & 1 deletion pytorch3d/renderer/fisheyecameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __init__(
device: Device = "cpu",
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
) -> None:

"""
Args:
Expand Down
6 changes: 3 additions & 3 deletions pytorch3d/renderer/mesh/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,9 +712,9 @@ def convert_clipped_rasterization_to_original_faces(
)

bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3])
bary_coords_unclipped[
faces_to_convert_mask_expanded
] = bary_coords_unclipped_subset
bary_coords_unclipped[faces_to_convert_mask_expanded] = (
bary_coords_unclipped_subset
)

# dists for case 4 faces will be handled in the rasterizer
# so no need to modify them here.
Expand Down
5 changes: 4 additions & 1 deletion pytorch3d/renderer/mesh/rasterize_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,10 @@ def rasterize_meshes_python( # noqa: C901
# If faces were clipped, map the rasterization result to be in terms of the
# original unclipped faces. This may involve converting barycentric
# coordinates
(face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces(
(
face_idxs,
bary_coords,
) = convert_clipped_rasterization_to_original_faces(
face_idxs,
bary_coords,
# pyre-fixme[61]: `clipped_faces` may not be initialized here.
Expand Down
1 change: 1 addition & 0 deletions pytorch3d/renderer/opengl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# If we can access EGL, import MeshRasterizerOpenGL.
def _can_import_egl_and_pycuda():
import os
Expand Down
8 changes: 5 additions & 3 deletions pytorch3d/renderer/opengl/rasterizer_opengl.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,11 @@ def __call__(
pix_to_face, bary_coord, zbuf = self._rasterize_mesh(
mesh,
image_size,
projection_matrix=projection_matrix[mesh_id]
if projection_matrix.shape[0] > 1
else None,
projection_matrix=(
projection_matrix[mesh_id]
if projection_matrix.shape[0] > 1
else None
),
)
pix_to_faces.append(pix_to_face)
bary_coords.append(bary_coord)
Expand Down
12 changes: 6 additions & 6 deletions tests/implicitron/test_extending_orm_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):


class ExtendedSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[
Type[SqlFrameAnnotation]
] = ExtendedSqlFrameAnnotation
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
ExtendedSqlFrameAnnotation
)


class CanineFrameData(FrameData):
Expand Down Expand Up @@ -96,9 +96,9 @@ def build(


class CanineSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[
Type[SqlFrameAnnotation]
] = ExtendedSqlFrameAnnotation
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
ExtendedSqlFrameAnnotation
)

frame_data_builder_class_type: str = "CanineFrameDataBuilder"

Expand Down
10 changes: 5 additions & 5 deletions tests/implicitron/test_frame_data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def setUp(self):
camera_quality_score=safe_as_tensor(
self.seq_annotation.viewpoint_quality_score, torch.float
),
point_cloud_quality_score=safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
point_cloud_quality_score=(
safe_as_tensor(point_cloud.quality_score, torch.float)
if point_cloud is not None
else None
),
)

def test_frame_data_builder_args(self):
Expand Down

0 comments on commit 3da7703

Please sign in to comment.