-
Notifications
You must be signed in to change notification settings - Fork 47
/
isosurface.py
50 lines (42 loc) · 1.83 KB
/
isosurface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Callable, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from skimage import measure
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extraction helpers.

    Concrete helpers override ``grid_vertices`` to expose the coordinates at
    which the scalar field is expected to be sampled.
    """

    # (lower, upper) coordinate bounds shared by every helper subclass.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Return the grid sample coordinates.

        Raises:
            NotImplementedError: Always; subclasses must provide the grid.
        """
        raise NotImplementedError()
class MarchingCubeHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a dense cubic scalar grid.

    The grid has ``resolution`` samples along each of its three axes, spanning
    ``points_range`` on every axis. Surface extraction is delegated to
    ``skimage.measure.marching_cubes``, which operates on NumPy arrays in host
    memory.
    """

    def __init__(self, resolution: int) -> None:
        """Record the grid resolution; the vertex grid itself is built lazily.

        Args:
            resolution: Number of samples along each axis of the grid.
        """
        super().__init__()
        self.resolution = resolution
        # Populated on first access of ``grid_vertices``.
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Return the ``(resolution ** 3, 3)`` tensor of grid coordinates.

        Rows are ordered with ``"ij"`` indexing (x varies slowest, z fastest)
        and each row holds one ``(x, y, z)`` sample position. The tensor is
        built once and cached on the instance.
        """
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Run marching cubes on ``level`` and return the extracted mesh.

        Args:
            level: Scalar field with exactly ``resolution ** 3`` elements, in
                the same order as ``grid_vertices``. It may live on any
                device; it is copied to host memory for extraction.

        Returns:
            A ``(vertices, faces)`` pair placed on the device of ``level``:
            float vertex positions of shape ``(V, 3)`` and long triangle
            indices of shape ``(F, 3)``.

        Raises:
            RuntimeError: If ``level`` cannot be reshaped to the cubic grid.
            ValueError: Raised by scikit-image, e.g. when the zero level does
                not lie inside the value range of the field.
        """
        device = level.device
        side = self.resolution
        # The surface is extracted at the zero crossing of the negated field.
        # ``reshape`` (rather than ``view``) also accepts non-contiguous input.
        volume = -level.detach().reshape(side, side, side)
        # scikit-image needs a NumPy array, and ``Tensor.numpy()`` only works
        # for tensors in host memory, so move the data to CPU first.
        verts, faces, _, _ = measure.marching_cubes(volume.cpu().numpy(), 0.0)
        # Copy so each tensor owns its own contiguous buffer.
        v_pos = torch.from_numpy(verts.copy()).type(torch.FloatTensor)
        t_pos_idx = torch.from_numpy(faces.copy()).type(torch.LongTensor)
        # Reverse the order of the three coordinate columns.
        v_pos = v_pos[..., [2, 1, 0]]
        # Rescale from index units [0, resolution - 1] to [0, 1], which is the
        # default ``points_range``.
        v_pos = v_pos / (self.resolution - 1.0)
        return v_pos.to(device), t_pos_idx.to(device)