Skip to content

Commit

Permalink
Use dataclasses inside ply_io.
Browse files Browse the repository at this point in the history
Summary: Refactor ply_io to make it easier to add new features. Mostly taken from the starting code I attached to #904.

Reviewed By: patricklabatut

Differential Revision: D34375978

fbshipit-source-id: ec017d31f07c6f71ba6d97a0623bb10be1e81212
  • Loading branch information
bottler authored and facebook-github-bot committed Feb 21, 2022
1 parent feb5d36 commit 967a099
Showing 1 changed file with 121 additions and 71 deletions.
192 changes: 121 additions & 71 deletions pytorch3d/io/ply_io.py
Expand Up @@ -14,8 +14,9 @@
import sys
import warnings
from collections import namedtuple
from dataclasses import asdict, dataclass
from io import BytesIO, TextIOBase
from typing import List, Optional, Tuple, cast
from typing import List, Optional, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -137,6 +138,7 @@ def __init__(self, f) -> None:
self.ascii: (bool) Whether in ascii format
self.big_endian: (bool) (if not ascii) whether big endian
self.obj_info: (List[str]) arbitrary extra data
self.comments: (List[str]) comments
Args:
f: file-like object.
Expand All @@ -145,7 +147,8 @@ def __init__(self, f) -> None:
raise ValueError("Invalid file header.")
seen_format = False
self.elements: List[_PlyElementType] = []
self.obj_info = []
self.comments: List[str] = []
self.obj_info: List[str] = []
while True:
line = f.readline()
if isinstance(line, bytes):
Expand Down Expand Up @@ -176,6 +179,9 @@ def __init__(self, f) -> None:
continue
if line.startswith("format"):
raise ValueError("Invalid format line.")
if line.startswith("comment "):
self.comments.append(line[8:])
continue
if line.startswith("comment") or len(line) == 0:
continue
if line.startswith("element"):
Expand Down Expand Up @@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
return header, elements


@dataclass(frozen=True)
class _VertsColumnIndices:
"""
Contains the relevant layout of the verts section of file being read.
Members
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
"""

point_idxs: List[int]
color_idxs: Optional[List[int]]
color_scale: float
normal_idxs: Optional[List[int]]


def _get_verts_column_indices(
vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
) -> _VertsColumnIndices:
"""
Get the columns of verts, verts_colors, and verts_normals in the vertex
element of a parsed ply file, together with a color scale factor.
Expand All @@ -809,12 +834,7 @@ def _get_verts_column_indices(
vertex_head: as returned from load_ply_raw.
Returns:
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
_VertsColumnIndices object
"""
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
Expand All @@ -839,29 +859,38 @@ def _get_verts_column_indices(
for idx in color_idxs
):
color_scale = 1.0 / 255
return (
point_idxs,
# pyre-fixme[22]: The cast is redundant.
None if None in color_idxs else cast(List[int], color_idxs),
color_scale,
# pyre-fixme[22]: The cast is redundant.
None if None in normal_idxs else cast(List[int], normal_idxs),
return _VertsColumnIndices(
point_idxs=point_idxs,
color_idxs=None if None in color_idxs else color_idxs,
color_scale=color_scale,
normal_idxs=None if None in normal_idxs else normal_idxs,
)


def _get_verts(
header: _PlyHeader, elements: dict
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
@dataclass(frozen=True)
class _VertsData:
"""
Contains the data of the verts section of file being read.
Members:
verts: FloatTensor of shape (V, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""

verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None


def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
"""
Get the vertex locations, colors and normals from a parsed ply file.
Args:
header, elements: as returned from load_ply_raw.
Returns:
verts: FloatTensor of shape (V, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
_VertsData object
"""

vertex = elements.get("vertex", None)
Expand All @@ -870,16 +899,17 @@ def _get_verts(
if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
vertex_head
)

column_idxs = _get_verts_column_indices(vertex_head)

# Case of no vertices
if vertex_head.count == 0:
verts = torch.zeros((0, 3), dtype=torch.float32)
if color_idxs is None:
return verts, None, None
return verts, torch.zeros((0, 3), dtype=torch.float32), None
if column_idxs.color_idxs is None:
return _VertsData(verts=verts)
return _VertsData(
verts=verts, verts_colors=torch.zeros((0, 3), dtype=torch.float32)
)

# Simple case where the only data is the vertices themselves
if (
Expand All @@ -888,22 +918,22 @@ def _get_verts(
and vertex[0].ndim == 2
and vertex[0].shape[1] == 3
):
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
return _VertsData(verts=_make_tensor(vertex[0], cols=3, dtype=torch.float32))

vertex_colors = None
vertex_normals = None

if len(vertex) == 1:
# This is the case where the whole vertex element has one type,
# so it was read as a single array and we can index straight into it.
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
if color_idxs is not None:
vertex_colors = color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32
verts = torch.tensor(vertex[0][:, column_idxs.point_idxs], dtype=torch.float32)
if column_idxs.color_idxs is not None:
vertex_colors = column_idxs.color_scale * torch.tensor(
vertex[0][:, column_idxs.color_idxs], dtype=torch.float32
)
if normal_idxs is not None:
if column_idxs.normal_idxs is not None:
vertex_normals = torch.tensor(
vertex[0][:, normal_idxs], dtype=torch.float32
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
Expand All @@ -918,7 +948,7 @@ def _get_verts(
]
verts = torch.empty(size=(vertex_head.count, 3), dtype=torch.float32)
for axis in range(3):
partnum, col = prop_to_partnum_col[point_idxs[axis]]
partnum, col = prop_to_partnum_col[column_idxs.point_idxs[axis]]
verts.numpy()[:, axis] = vertex[partnum][:, col]
# Note that in the previous line, we made the assignment
# as numpy arrays by casting verts. If we took the (more
Expand All @@ -928,30 +958,49 @@ def _get_verts(
# if not vertex[partnum].flags["C_CONTIGUOUS"]:
# vertex[partnum] = np.ascontiguousarray(vertex[partnum])
# verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
if color_idxs is not None:
if column_idxs.color_idxs is not None:
vertex_colors = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for color in range(3):
partnum, col = prop_to_partnum_col[color_idxs[color]]
partnum, col = prop_to_partnum_col[column_idxs.color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale
if normal_idxs is not None:
vertex_colors *= column_idxs.color_scale
if column_idxs.normal_idxs is not None:
vertex_normals = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for axis in range(3):
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]

return verts, vertex_colors, vertex_normals
return _VertsData(
verts=verts,
verts_colors=vertex_colors,
verts_normals=vertex_normals,
)


@dataclass(frozen=True)
class _PlyData:
"""
Contains the data from a PLY file which has been read.
Members:
header: _PlyHeader of file metadata from the header
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
verts_colors: None or FloatTensor of shape (V, 3).
verts_normals: None or FloatTensor of shape (V, 3).
"""

header: _PlyHeader
verts: torch.Tensor
faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor]


def _load_ply(
f, *, path_manager: PathManager
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]:
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
"""
Load the data from a .ply file.
Expand All @@ -964,14 +1013,11 @@ def _load_ply(
path_manager: PathManager for loading if f is a str.
Returns:
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
_PlyData object
"""
header, elements = _load_ply_raw(f, path_manager=path_manager)

verts, vertex_colors, vertex_normals = _get_verts(header, elements)
verts_data = _get_verts(header, elements)

face = elements.get("face", None)
if face is not None:
Expand Down Expand Up @@ -1007,9 +1053,9 @@ def _load_ply(
faces = torch.tensor(face_list, dtype=torch.int64)

if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0])
_check_faces_indices(faces, max_index=verts_data.verts.shape[0])

return verts, faces, vertex_colors, vertex_normals
return _PlyData(**asdict(verts_data), faces=faces, header=header)


def load_ply(
Expand Down Expand Up @@ -1064,11 +1110,12 @@ def load_ply(

if path_manager is None:
path_manager = PathManager()
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
data = _load_ply(f, path_manager=path_manager)
faces = data.faces
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)

return verts, faces
return data.verts, faces


def _write_ply_header(
Expand Down Expand Up @@ -1305,20 +1352,20 @@ def read(
if not endswith(path, self.known_suffixes):
return None

verts, faces, verts_colors, verts_normals = _load_ply(
f=path, path_manager=path_manager
)
data = _load_ply(f=path, path_manager=path_manager)
faces = data.faces
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)

texture = None
if include_textures and verts_colors is not None:
texture = TexturesVertex([verts_colors.to(device)])
if include_textures and data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)])

if verts_normals is not None:
verts_normals = [verts_normals]
verts_normals = None
if data.verts_normals is not None:
verts_normals = [data.verts_normals.to(device)]
mesh = Meshes(
verts=[verts.to(device)],
verts=[data.verts.to(device)],
faces=[faces.to(device)],
textures=texture,
verts_normals=verts_normals,
Expand Down Expand Up @@ -1392,14 +1439,17 @@ def read(
if not endswith(path, self.known_suffixes):
return None

verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
verts = verts.to(device)
if features is not None:
features = [features.to(device)]
if normals is not None:
normals = [normals.to(device)]
data = _load_ply(f=path, path_manager=path_manager)
features = None
if data.verts_colors is not None:
features = [data.verts_colors.to(device)]
normals = None
if data.verts_normals is not None:
normals = [data.verts_normals.to(device)]

pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
pointcloud = Pointclouds(
points=[data.verts.to(device)], features=features, normals=normals
)
return pointcloud

def save(
Expand Down

0 comments on commit 967a099

Please sign in to comment.