In [1]:
import meshio
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.transforms import FaceToEdge
import matplotlib.pyplot as plt

In [2]:
class TetraToEdge(object):
    r"""Converts mesh tetras :obj:`[4, num_tetras]` to edge indices
    :obj:`[2, num_edges]`.

    Each tetrahedron contributes its six vertex pairs as edges; the resulting
    edge list is then passed through :func:`to_undirected`.

    Objects without a ``tetra`` attribute (or with ``tetra`` set to
    :obj:`None`) are returned unchanged.

    Args:
        remove_tetras (bool, optional): If set to :obj:`False`, the tetra tensor
            will not be removed. (default: :obj:`True`)
    """

    def __init__(self, remove_tetras=True):
        self.remove_tetras = remove_tetras

    def __call__(self, data):
        """Adds ``edge_index`` to ``data`` (in place) and returns ``data``."""
        # `tetra` is not a built-in Data attribute, so a plain `data.tetra`
        # raises AttributeError when it is absent; fall back to None instead.
        tetra = getattr(data, 'tetra', None)
        if tetra is not None:
            # One [2, num_tetras] slice per tetrahedron edge, named by the
            # rows (local vertex numbers) each slice selects.
            edge_index = torch.cat([
                tetra[:2],       # rows (0, 1)
                tetra[1:3, :],   # rows (1, 2)
                tetra[-2:],      # rows (2, 3)
                tetra[::2],      # rows (0, 2)
                tetra[::3],      # rows (0, 3)
                tetra[1::2],     # rows (1, 3)
            ], dim=1)
            edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)

            data.edge_index = edge_index
            if self.remove_tetras:
                data.tetra = None

        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

In [3]:
def from_meshio(mesh, mesh_type='2D'):
    r"""Converts a :obj:`meshio` mesh (e.g. one read from a ``.msh`` file) to
    a :class:`torch_geometric.data.Data` instance.

    Args:
        mesh: A :obj:`meshio` mesh, e.g. the return value of
            :func:`meshio.read`. Its ``points`` and ``cells_dict`` attributes
            are read.
        mesh_type (str, optional): :obj:`'2D'` stores the ``'triangle'`` cells
            as ``face``; :obj:`'3D'` stores the ``'tetra'`` cells as
            ``tetra``. (default: :obj:`'2D'`)

    Returns:
        A :class:`torch_geometric.data.Data` object with ``pos`` (float) and
        either ``face`` or ``tetra`` (long). The cell array is transposed, so
        the cell tensor is indexed ``[vertex, cell]``.

    Raises:
        ValueError: If ``mesh_type`` is neither :obj:`'2D'` nor :obj:`'3D'`.
        KeyError: If the mesh has no cells of the requested type.
    """
    # Validate up front: an unsupported type used to fall through and
    # silently return None, which only failed later in the caller.
    if mesh_type == '3D':
        cell_name, attr_name = 'tetra', 'tetra'
    elif mesh_type == '2D':
        cell_name, attr_name = 'triangle', 'face'
    else:
        raise ValueError(
            "Unsupported mesh_type {!r}; expected '2D' or '3D'.".format(mesh_type))

    pos = torch.from_numpy(mesh.points).to(torch.float)
    cells = torch.from_numpy(mesh.cells_dict[cell_name]).to(torch.long)
    # Transpose to the [vertex, cell] layout and make the memory contiguous.
    cells = cells.t().contiguous()
    return Data(pos=pos, **{attr_name: cells})

In [7]:
# Mesh to load and how to read it: '2D' uses triangle cells, '3D' tetra cells.
mesh_type = '3D'
filename = '../meshes/sphere_coarse.msh'

mesh = meshio.read(filename)

# Build the Data object, then derive edge_index from its cells. The cell
# tensor is kept (remove_* = False) so later cells can still inspect it.
data = from_meshio(mesh, mesh_type=mesh_type)
to_edges = (
    FaceToEdge(remove_faces=False)
    if mesh_type == '2D'
    else TetraToEdge(remove_tetras=False)
)
data = to_edges(data)

`data.tetra` contiene los índices de los vértices de cada tetraedro; su forma es `[4, num_tetraedros]` (una columna con los 4 vértices de cada tetraedro).

In [9]:
# Shape check: four vertex rows by one column per tetrahedron.
data.tetra.shape

torch.Size([4, 4310])

In [10]:
# Display the Data summary (attribute names and tensor shapes).
data

Data(pos=[1001, 3], tetra=[4, 4310], edge_index=[2, 11746])

In [11]:
# Column of data.tetra (i.e. which tetrahedron) inspected in the cells below.
idx_tetra = 1

In [12]:
# The four vertex indices of the selected tetrahedron.
data.tetra[:,idx_tetra]

tensor([568, 688, 642, 694])

In [13]:
# First three vertex indices of the selected tetrahedron (one triangular face).
face_vertex = data.tetra[:3, idx_tetra]

In [None]:
# NOTE(review): unfinished, never-executed cell. A bare reference to
# torch.where only displays the function object; complete or delete it.
torch.where