In [1]:
import os
import numpy as np
import torch
import pytorch3d
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere

In [6]:
device = torch.device("cuda:0")
trg_obj = os.path.join('./data/dolphin.obj')
verts, faces, aux = load_obj(trg_obj)

faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
print(verts.shape, faces_idx.shape)

center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale

bs = 7
verts = verts[None].expand(bs, -1, -1)
faces_idx = faces_idx[None].expand(bs, -1, -1)

print(verts.shape, faces_idx.shape)
# We construct a Meshes structure for the target mesh
trg_mesh = Meshes(verts=verts, faces=faces_idx)
trg_mesh

torch.Size([2562, 3]) torch.Size([5120, 3])
torch.Size([7, 2562, 3]) torch.Size([7, 5120, 3])


<pytorch3d.structures.meshes.Meshes at 0x7fb74e9522b0>

In [11]:
import torch.nn.functional as F
from src.config import get_parser
from src.util import make_faces

config = get_parser().parse_args(args=[])

blueprint = np.load(os.path.join(config.data_dir, config.blueprint)) 
points = torch.tensor(blueprint['points'])
normals = torch.tensor(blueprint['normals'])
points = F.interpolate(points, size=config.data_blueprint_size,
                               mode='bicubic', align_corners=True)
normals = F.interpolate(normals, size=config.data_blueprint_size, 
                        mode='bicubic', align_corners=True)   
points.shape, normals.shape

(torch.Size([1, 3, 640, 640]), torch.Size([1, 3, 640, 640]))

In [14]:
faces = torch.tensor(make_faces(config.data_blueprint_size, config.data_blueprint_size))
faces.shape

torch.Size([816642, 3])

In [15]:
bs = 5
vertices = points.reshape(1, 3, -1).permute(0, 2, 1).expand(bs, -1, -1)
faces = faces[None].expand(bs, -1, -1)
vertices.shape, faces.shape

(torch.Size([5, 409600, 3]), torch.Size([5, 816642, 3]))

In [16]:
trg_mesh = Meshes(verts=vertices, faces=faces)
trg_mesh

<pytorch3d.structures.meshes.Meshes at 0x7fb7250375b0>

In [17]:
mesh_edge_loss(trg_mesh)

tensor(8.8543e-05)

In [19]:
import torch

from pytorch3d.structures import Meshes
from src.util import make_faces

class EdgeLoss(torch.nn.Module):
    def __init__(self, config):
        super(EdgeLoss, self).__init__()
        self.patch_size = config.data_patch_size
        faces = torch.tensor(make_faces(self.patch_size, self.patch_size))
        vertices = torch.rand(self.patch_size ** 2, 3)
        meshes = Meshes(verts=[vertices], faces=[faces])
        self.no_edges = max(meshes.edges_packed().shape)
        edges_packed = meshes.edges_packed()       
        self.register_buffer('v0',  edges_packed[:, 0])
        self.register_buffer('v1',  edges_packed[:, 1])

    def forward(self, vertices, target_length=0):
        bs = vertices.size(0)
        no_edges = self.no_edges * bs        
        v0 = vertices.index_select(1, self.v0)
        v1 = vertices.index_select(1, self.v1)
        loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
        return loss.sum() / no_edges

config.data_patch_size = config.data_blueprint_size
edgeLoss = EdgeLoss(config)
edgeLoss(vertices)

tensor(8.8474e-05)

In [20]:
edgeLoss(vertices)

tensor(8.8474e-05)

In [7]:
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Computes mesh edge length regularization loss averaged across all meshes
    in a batch. Each mesh contributes equally to the final loss, regardless of
    the number of edges per mesh in the batch by weighting each mesh with the
    inverse number of edges. For example, if mesh 3 (out of N) has only E=4
    edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
    contribute to the final loss.

    Args:
        meshes: Meshes object with a batch of meshes.
        target_length: Resting value for the edge length.

    Returns:
        loss: Average loss across the batch. Returns 0 if meshes contains
        no meshes or all empty meshes.
    """
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    N = len(meshes)
    edges_packed = meshes.edges_packed()  # (sum(E_n), 3)
    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
    edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()  # (sum(E_n), )
    num_edges_per_mesh = meshes.num_edges_per_mesh()  # N

    # Determine the weight for each edge based on the number of edges in the
    # mesh it corresponds to.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
    weights = 1.0 / weights.float()

    verts_edges = verts_packed[edges_packed]
    v0, v1 = verts_edges.unbind(1)
    loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
    loss = loss * weights

    return loss.sum() / N

mesh_edge_loss(trg_mesh)

tensor(0.0020, device='cuda:0')

In [27]:
N = len(meshes)
edges_packed = meshes.edges_packed()  # (sum(E_n), 3)
verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()  # (sum(E_n), )
num_edges_per_mesh = meshes.num_edges_per_mesh()  # N

# Determine the weight for each edge based on the number of edges in the
# mesh it corresponds to.
# TODO (nikhilar) Find a faster way of computing the weights for each edge
# as this is currently a bottleneck for meshes with a large number of faces.
weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
weights = 1.0 / weights.float()
weights

tensor([0.0001, 0.0001, 0.0001,  ..., 0.0001, 0.0001, 0.0001], device='cuda:0')

In [28]:
num_edges_per_mesh.gather(0, edge_to_mesh_idx)

tensor([7680, 7680, 7680,  ..., 7680, 7680, 7680], device='cuda:0',
       dtype=torch.int32)

In [15]:
v0, v1 = verts_edges.unbind(1)
loss = ((v0 - v1).norm(dim=1, p=2) - 0) ** 2.0
loss

tensor([0.0002, 0.0007, 0.0008,  ..., 0.0041, 0.0030, 0.0006], device='cuda:0')

In [16]:
(v0 - v1).norm(dim=1, p=2) 

tensor([0.0151, 0.0267, 0.0291,  ..., 0.0642, 0.0547, 0.0250], device='cuda:0')

EdgeLoss()

In [23]:
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Computes mesh edge length regularization loss averaged across all meshes
    in a batch. Each mesh contributes equally to the final loss, regardless of
    the number of edges per mesh in the batch by weighting each mesh with the
    inverse number of edges. For example, if mesh 3 (out of N) has only E=4
    edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
    contribute to the final loss.

    Args:
        meshes: Meshes object with a batch of meshes.
        target_length: Resting value for the edge length.

    Returns:
        loss: Average loss across the batch. Returns 0 if meshes contains
        no meshes or all empty meshes.
    """
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    N = len(meshes)
    edges_packed = meshes.edges_packed()  # (sum(E_n), 3)
    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
    edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()  # (sum(E_n), )
    num_edges_per_mesh = meshes.num_edges_per_mesh()  # N

    # Determine the weight for each edge based on the number of edges in the
    # mesh it corresponds to.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
    weights = 1.0 / weights.float()
    print(verts_packed.shape)
    print(edges_packed.shape)
    verts_edges = verts_packed[edges_packed]
    v0, v1 = verts_edges.unbind(1)
    loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
    loss = loss * weights

    return loss.sum() / N

trg_mesh = Meshes(verts=verts[None].expand(5, -1, -1), 
                  faces=faces_idx[None].expand(5, -1, -1))


mesh_edge_loss(trg_mesh)

torch.Size([1, 195585, 2])

In [24]:
edgeLoss.edges.expand(4, -1, -1).shape

torch.Size([4, 195585, 2])

In [25]:
max(torch.rand(3, 111, 4).shape)

111

In [17]:
vertices = torch.rand(5, patch_size **2, 3)

In [18]:
patch_size = config.data_patch_size
edgeLoss(vertices)

tensor(0.5013)

torch.Size([12810, 3])
torch.Size([38400, 2])


tensor(0.0020, device='cuda:0')

In [9]:
faces_idx.shape

torch.Size([5120, 3])

In [13]:
trg_mesh.edges_packed().shape

torch.Size([38400, 2])

In [15]:
trg_mesh = Meshes(verts=verts[None].expand(3, -1, -1), 
                  faces=faces_idx[None].expand(3, -1, -1))
mesh_edge_loss(trg_mesh)

torch.Size([7686, 3])
torch.Size([23040, 2])


tensor(0.0020, device='cuda:0')

In [6]:
edges_packed.shape

torch.Size([7680, 2])

In [7]:
edges_packed[:, 0].shape, edges_packed[:, 1].shape

(torch.Size([7680]), torch.Size([7680]))

In [8]:
vertices = torch.rand(12, 9000, 3)
vertices.shape

torch.Size([12, 9000, 3])

In [12]:
vertices.index_select(1, edges_packed[:, 0].cpu()).shape

torch.Size([12, 7680, 3])