Skip to content

Commit

Permalink
include TexturesUV in IO.save_mesh(x.obj)
Browse files Browse the repository at this point in the history
Summary:
Added export of UV textures to IO.save_mesh in Pytorch3d
MeshObjFormat now passes verts_uv, faces_uv, and texture_map as input to save_obj

TODO: check if TexturesUV.verts_uv_list or TexturesUV.verts_uv_padded() should be passed to save_obj

IO.save_mesh(obj_file, meshes, decimal_places=2) should be IO().save_mesh(obj_file, meshes, decimal_places=2)

Reviewed By: bottler

Differential Revision: D39617441

fbshipit-source-id: 4628b7f26f70e38c65f235852b990c8edb0ded23
  • Loading branch information
micramamonjisoa authored and facebook-github-bot committed Sep 21, 2022
1 parent 305cf32 commit 6ae6ff9
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,25 @@ def save(

verts = data.verts_list()[0]
faces = data.faces_list()[0]

verts_uvs: Optional[torch.Tensor] = None
faces_uvs: Optional[torch.Tensor] = None
texture_map: Optional[torch.Tensor] = None

if isinstance(data.textures, TexturesUV):
verts_uvs = data.textures.verts_uvs_padded()[0]
faces_uvs = data.textures.faces_uvs_padded()[0]
texture_map = data.textures.maps_padded()[0]

save_obj(
f=path,
verts=verts,
faces=faces,
decimal_places=decimal_places,
path_manager=path_manager,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
return True

Expand Down
62 changes: 62 additions & 0 deletions tests/test_io_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,68 @@ def test_save_obj_with_texture_errors(self):
texture_map=texture_map[..., 1], # Incorrect shape
)

def test_save_obj_with_texture_IO(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0

with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
textures_uv = TexturesUV([texture_map], [faces_uvs], [verts_uvs])
test_mesh = Meshes(verts=[verts], faces=[faces], textures=textures_uv)

IO().save_mesh(data=test_mesh, path=obj_file, decimal_places=2)

expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1 3/3 2/2",
"f 1/1 2/2 3/3",
"f 4/4 3/3 2/2",
"f 4/4 2/2 1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])

# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))

# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)

# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)

# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)

@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
return lambda: save_obj(StringIO(), verts, faces, decimal_places)
Expand Down

0 comments on commit 6ae6ff9

Please sign in to comment.