## Import Libraries

In [None]:
import sys,os

# Folder where all 3D models are kept. The relative path is resolved against the
# current working directory (for a notebook this is usually the notebook's folder).
RES_PATH = 'models'

if not os.path.exists(RES_PATH):
    # sys.exit(message) prints the message to stderr and exits with status 1. The bare
    # `exit` builtin is added by the `site` module for interactive use only and should
    # not be relied on in programs.
    sys.exit(f"cannot find {RES_PATH} folder, please update RES_PATH")

import pyglet
# Turn off pyglet's 'shadow_window' option. This runs before `import pyrender` below,
# presumably so the setting is in place before pyrender uses pyglet.
# TODO confirm the original motivation (e.g. headless/offscreen rendering).
pyglet.options['shadow_window'] = False

import pyrender
import numpy as np
import trimesh

import matplotlib
import matplotlib.pyplot as plt

from sklearn.neighbors import KDTree

import heapq
%load_ext autoreload
%autoreload 2

## Trimesh Utility

In [None]:
def load_mesh(filename):
    """Load a mesh file from the RES_PATH folder with ``trimesh.load``.

    Args:
        filename: file name relative to RES_PATH (e.g. "armadillo.obj").

    Returns:
        The object returned by ``trimesh.load``. Later code expects a mesh exposing
        ``faces``/``vertices``; a multi-object file may load as a Scene instead
        (TODO confirm for the files in use).

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    mesh_fp = os.path.join(RES_PATH, filename)

    # Explicit check rather than `assert`: assertions are stripped under `python -O`.
    if not os.path.exists(mesh_fp):
        raise FileNotFoundError(f'cannot find: {mesh_fp}')

    return trimesh.load(mesh_fp)

## Mesh Simplification Utility

In [None]:
def get_all_valid_pairs(mesh, threshold):
    """Collect the vertex pairs that are candidates for contraction.

    A pair is valid if it is an edge of the mesh, or if its two vertices are closer
    than ``threshold``. Every unordered pair is returned at most once.

    Args:
        mesh: object exposing ``edges_unique`` ((k, 2) vertex indices) and
            ``vertices`` ((n, 3) positions), e.g. a ``trimesh.Trimesh``.
        threshold: distance below which non-adjacent vertices are also paired.

    Returns:
        list of ``[i, j]`` index pairs: all edges first, then threshold pairs
        with ``i < j``.
    """
    thresh_squared = threshold ** 2
    valid_pairs = []
    edge_pair_map = dict()  # Maps a vertex index to the set of vertex indices it is paired with.

    def update_mapping(v1, v2):
        # Record the pair in both directions so either endpoint can be looked up.
        edge_pair_map.setdefault(v1, set()).add(v2)
        edge_pair_map.setdefault(v2, set()).add(v1)

    print("Getting edge pairs...")
    for edge in mesh.edges_unique:
        v1, v2 = int(edge[0]), int(edge[1])
        valid_pairs.append([v1, v2])
        update_mapping(v1, v2)

    print("Getting vertex pairs within threshold...")
    vertices = np.asarray(mesh.vertices)
    num_vertices = len(vertices)
    for i in range(num_vertices):
        # Squared distances from vertex i to every later vertex in one vectorized step
        # (squared to avoid the square root). Offsets into this slice map to j = i + 1 + offset.
        sq_distances = np.sum((vertices[i + 1:] - vertices[i]) ** 2, axis=1)
        for offset in np.flatnonzero(sq_distances < thresh_squared):
            j = i + 1 + int(offset)
            # .get(): vertices that belong to no edge have no entry in the map yet.
            if i not in edge_pair_map.get(j, ()):  # Skip pairs already added as edges.
                valid_pairs.append([i, j])
                update_mapping(i, j)

        print(f"i = {i}/{num_vertices - 1}")

    return valid_pairs

def calculate_vertex_error_quadric(mesh, vertex_index):
    """Return the 4x4 error quadric of a vertex: the sum of its face-plane quadrics.

    For each face incident to the vertex, the unit normal ``n = (a, b, c)`` and the
    offset ``d = n . p1`` describe the face's plane as ``n . x = d``. With the plane
    vector ``p = [a, b, c, d]``, the face contributes ``K_p = p p^T``.

    Sign convention: since the plane is ``n . x - d = 0``, a point must be written in
    homogeneous form as ``[x, y, z, -1]`` for ``v^T Q v`` to equal the sum of squared
    point-to-plane distances.

    Faces with zero area have no defined normal and are skipped. Without the skip they
    would divide by zero and fill Q with NaN.
    """
    Q = np.zeros((4, 4))
    for face_index in mesh.vertex_faces[vertex_index]:
        if face_index == -1:  # Filters out padded -1s.
            continue
        p1, p2, p3 = mesh.vertices[mesh.faces[face_index]]

        n = np.cross(p2 - p1, p3 - p1)
        norm = np.linalg.norm(n)
        if norm == 0:  # Degenerate (zero-area) face: no plane to measure against.
            continue
        n = n / norm
        d = np.dot(n, p1)

        plane = np.append(n, d)
        Q += np.outer(plane, plane)

    return Q

In [None]:
def compute_optimal_contraction(m, pair, q1, q2):
    """Return the homogeneous position minimizing the combined error of ``pair``.

    Args:
        m: mesh whose ``vertices`` are indexed by ``pair``.
        pair: the two vertex indices being contracted.
        q1, q2: 4x4 quadrics of the two vertices, with planes stored as
            ``[a, b, c, d]`` for ``n . x = d`` (as built by
            calculate_vertex_error_quadric).

    Returns:
        A (4, 1) column ``[x, y, z, -1]^T``. The homogeneous coordinate is -1 so the
        result plugs directly into ``v^T (q1 + q2) v`` under that plane convention.
        Callers wanting the position use the first three entries.

    When the linear system is rank deficient (e.g. the incident faces are coplanar),
    the midpoint of the two vertices is used instead.
    """
    q_bar = q1 + q2
    q_bar[3, :] = [0, 0, 0, 1]
    b = np.array([[0], [0], [0], [1]])

    # A float determinant is practically never exactly 0 for singular systems, so test
    # the numerical rank instead. This avoids solving nearly-singular systems that
    # would throw the vertex far away.
    if np.linalg.matrix_rank(q_bar) == 4:
        # Under the n . x = d encoding, the solution of q_bar @ v = b is [-x, -y, -z, 1];
        # negating it yields the optimum as [x, y, z, -1].
        result = -np.linalg.solve(q_bar, b)
    else:
        midpoint = (m.vertices[pair[0]] + m.vertices[pair[1]]) / 2
        # Same homogeneous convention (-1) as the solved branch.
        result = np.append(midpoint, -1.0).reshape(4, 1)

    return result

In [None]:
def calculate_error(v_bar, pair, q1, q2):
    """Evaluate the quadric error ``v_bar^T (q1 + q2) v_bar`` of a candidate vertex.

    ``v_bar`` is a (4, 1) homogeneous column; ``pair`` is accepted for signature
    symmetry with the other helpers but is not used.
    """
    combined_quadric = q1 + q2
    quadratic_form = v_bar.T @ combined_quadric @ v_bar  # (1, 1) result
    return quadratic_form[0, 0]

In [None]:
def create_heap(m, pairs):
    """Build a heapq min-heap of ``[error, pair]`` entries, one per candidate pair.

    The error is the quadric error of contracting the pair to its optimal position.
    The ``pair`` objects are stored as given, so the heap shares them with ``pairs``.
    """
    quadric_cache = {}

    def quadric(vertex_index):
        # A vertex's quadric depends only on the mesh, so compute it once per vertex
        # rather than once for every pair that contains the vertex.
        if vertex_index not in quadric_cache:
            quadric_cache[vertex_index] = calculate_vertex_error_quadric(m, vertex_index)
        return quadric_cache[vertex_index]

    heap = []
    for pair in pairs:
        q1 = quadric(pair[0])
        q2 = quadric(pair[1])
        v_bar = compute_optimal_contraction(m, pair, q1, q2)
        heap.append([calculate_error(v_bar, pair, q1, q2), pair])

    heapq.heapify(heap)  # O(n), instead of n separate heappush calls.
    return heap

In [None]:
def calculate_new_error(mesh, pair):
    """Recompute the contraction error of ``pair`` from the mesh's current faces.

    Both vertex quadrics are rebuilt from scratch, so the value reflects any
    contractions already applied to ``mesh``.
    """
    quadrics = [calculate_vertex_error_quadric(mesh, vertex) for vertex in pair[:2]]
    optimum = compute_optimal_contraction(mesh, pair, *quadrics)
    return calculate_error(optimum, pair, *quadrics)

def update_valid_pairs(mesh, heap, v_bar_index, v1_index, v2_index):
    """Rewrite the ``[error, pair]`` heap in place after contracting v1 and v2 into v_bar.

    * Pairs referencing ``v1_index`` or ``v2_index`` are redirected to ``v_bar_index``,
      and their error is recomputed with calculate_new_error on the updated mesh.
    * If v1 and v2 shared a partner vertex, only the first such pair (in heap list
      order) is kept, so no duplicate ``[v_bar, partner]`` pair is created.
    * Every other vertex index is remapped. The remap assumes the merge removed v1 and
      v2 and inserted v_bar at ``v_bar_index``, with all other vertices keeping their
      relative order. NOTE(review): confirm this matches trimesh's merge_vertices ordering.
    * Recomputed errors and removed entries break the heap ordering, so the heap is
      re-heapified before returning. Without this, later ``heapq.heappop`` calls could
      return a non-minimal pair.
    """
    merged = (v1_index, v2_index)

    def remap(index):
        # Drop one slot per removed vertex below `index`, then make room for v_bar.
        index -= sum(1 for removed in merged if index > removed)
        return index + 1 if index >= v_bar_index else index

    seen_partners = set()  # Original indices already paired with v_bar.
    rebuilt = []
    for error, pair in heap:
        if pair[0] in merged or pair[1] in merged:
            slot = 0 if pair[0] in merged else 1
            partner = pair[1 - slot]
            if partner in seen_partners:
                continue  # Duplicate of an already-kept [v_bar, partner] pair.
            seen_partners.add(partner)

            new_pair = [None, None]
            new_pair[slot] = v_bar_index
            new_pair[1 - slot] = remap(partner)
            rebuilt.append([calculate_new_error(mesh, new_pair), new_pair])
        else:
            rebuilt.append([error, [remap(pair[0]), remap(pair[1])]])

    heap[:] = rebuilt  # Mutate in place: the caller keeps using the same list.
    heapq.heapify(heap)

In [None]:
def edge_contract(mesh, v1_index, v2_index, v_bar):
    """Contract the pair (v1_index, v2_index) into a single vertex at ``v_bar``.

    Mutates ``mesh`` in place: both vertices are moved to ``v_bar``, then
    ``mesh.merge_vertices()`` and ``mesh.remove_degenerate_faces()`` are called.

    Args:
        mesh: the trimesh mesh being simplified.
        v1_index, v2_index: indices of the two vertices to contract.
        v_bar: the new position for the merged vertex (simplify_mesh passes a (3,) array).

    Returns:
        The index the merged vertex is expected to have after merging.
        NOTE(review): the index comes from ``trimesh.grouping.unique_rows`` on the
        vertex array *before* ``merge_vertices`` runs. This assumes the merge produces
        vertices in that same order; confirm for the trimesh version in use.
    """
    # Lower of the two indices: after both rows are set to v_bar, this is the
    # earlier of the two identical rows.
    i = v1_index if v1_index < v2_index else v2_index
    mesh.vertices[v1_index] = mesh.vertices[v2_index] = v_bar
    
    # Uses this to get the index of the v_bar vertex after it is merged.
    # `unique` holds the row indices kept as representatives of distinct vertex rows,
    # and the position of `i` within it is taken as v_bar's new index. `inverse` is unused.
    # NOTE(review): `.item()` raises ValueError if `i` is not among the representatives
    # (presumably if an earlier vertex already sat at v_bar).
    unique, inverse = trimesh.grouping.unique_rows(mesh.vertices)
    v_bar_index = np.where(unique == i)[0].item()
    
    mesh.merge_vertices()
    mesh.remove_degenerate_faces()
    
    return v_bar_index

In [None]:
def simplify_mesh(mesh, percent=0.8, threshold=0.7):
    """Simplify ``mesh`` in place by repeatedly contracting the cheapest vertex pair.

    Args:
        mesh: trimesh mesh to simplify; it is modified in place and also returned.
        percent: fraction of the original face count to aim for. Contraction stops
            once the face count is no longer above ``int(original_faces * percent)``.
        threshold: distance under which non-adjacent vertices are also paired
            (forwarded to get_all_valid_pairs).

    Returns:
        The same ``mesh`` object.

    Contraction also stops once 10 or fewer candidate pairs remain in the heap.
    """
    print(f"Number of faces of original shape = {mesh.faces.shape[0]}")
    target_num_of_faces = int(mesh.faces.shape[0] * percent)
    candidate_pairs = get_all_valid_pairs(mesh, threshold)
    print("Creating heap...")
    heap = create_heap(mesh, candidate_pairs)

    while True:
        # Stop when the candidate pool is nearly exhausted or the face budget is met.
        if len(heap) <= 10 or mesh.faces.shape[0] <= target_num_of_faces:
            break

        _, pair = heapq.heappop(heap)
        first, second = pair[0], pair[1]

        quadric_a = calculate_vertex_error_quadric(mesh, first)
        quadric_b = calculate_vertex_error_quadric(mesh, second)
        # Keep only the spatial part of the (4, 1) homogeneous result.
        new_position = compute_optimal_contraction(mesh, pair, quadric_a, quadric_b)[:3].squeeze()

        merged_index = edge_contract(mesh, first, second, new_position)
        update_valid_pairs(mesh, heap, merged_index, first, second)

        print(f"Heap size = {len(heap)}")

    print(f"Number of faces of simplified shape = {mesh.faces.shape[0]}")
    return mesh

## Main Code

In [None]:
%%time
# Load the model from RES_PATH.
m = load_mesh("armadillo.obj")
# Use the mean edge length as the distance under which non-adjacent vertices are
# also paired for contraction.
threshold = np.mean(m.edges_unique_length)
print(f"threshold = {threshold}")
print(f"number of unique edges = {m.edges_unique.shape[0]}")
# Simplify towards 70% of the original face count. simplify_mesh modifies `m` in
# place and returns it, so `new_mesh` is the same object as `m`.
new_mesh = simplify_mesh(mesh=m, percent=0.7, threshold=threshold)

In [None]:
# Save the simplified mesh to the current working directory. A plain relative file
# name is portable. The previous ".\\new_armadillo.obj" is only a directory-relative
# path on Windows; elsewhere the backslash becomes part of the file name.
trimesh.exchange.export.export_mesh(new_mesh, "new_armadillo.obj", file_type="obj")