## Import Libraries

In [None]:
import sys,os

RES_PATH = 'models'  # Folder holding all 3D models; resolved relative to the current working directory.

# Fail fast when the model folder is missing: every later cell loads meshes from it.
# Raising (rather than calling the `site` helper `exit`) stops just this cell in a
# notebook instead of shutting down the kernel, and still aborts a plain script run.
if not os.path.isdir(RES_PATH):
    raise FileNotFoundError(f"cannot find {RES_PATH} folder, please update RES_PATH")

import pyglet
pyglet.options['shadow_window'] = False

import pyrender
import numpy as np
import trimesh

import matplotlib
import matplotlib.pyplot as plt

from sklearn.neighbors import KDTree

import heapq
%load_ext autoreload
%autoreload 2

## Trimesh Utility

In [None]:
def load_mesh(filename):
    """Load a model file from the ``RES_PATH`` folder with trimesh.

    Args:
        filename: path of the model file relative to ``RES_PATH``,
            e.g. ``"coffee_cup.obj"``.

    Returns:
        The object returned by ``trimesh.load`` for that file (a ``Trimesh`` for a
        single-mesh file; trimesh may return a ``Scene`` for multi-geometry
        files -- confirm for the models used here).

    Raises:
        FileNotFoundError: if the file does not exist. An explicit raise is used
            instead of ``assert`` so the check is not stripped under ``python -O``.
    """
    mesh_fp = os.path.join(RES_PATH, filename)

    if not os.path.exists(mesh_fp):
        raise FileNotFoundError(f"cannot find: {mesh_fp}")

    return trimesh.load(mesh_fp)

## Mesh Simplification Utility

In [None]:
def get_all_valid_pairs(mesh, threshold):
    """Collect the vertex pairs eligible for contraction (Garland-Heckbert valid pairs).

    A pair is valid when it is an edge of the mesh, or when its two vertices are
    strictly closer than ``threshold``. Each unordered pair is reported once.

    Args:
        mesh: object exposing ``edges_unique`` ((E, 2) vertex indices) and
            ``vertices`` ((V, D) coordinates), e.g. a ``trimesh.Trimesh``.
        threshold: distance below which non-edge vertex pairs are also valid.

    Returns:
        list of ``(v1, v2)`` tuples of Python ints. Edge pairs come first, in
        ``edges_unique`` order; distance-based pairs follow with ``v1 < v2``.
    """
    thresh_squared = threshold ** 2
    valid_pairs = []
    # Maps each vertex index to the set of vertex indices it is already paired
    # with, so that no unordered pair is emitted twice.
    paired_with = {}

    def add_pair(v1, v2):
        valid_pairs.append((v1, v2))
        paired_with.setdefault(v1, set()).add(v2)
        paired_with.setdefault(v2, set()).add(v1)

    print("Getting edge pairs...")
    for v1, v2 in mesh.edges_unique:
        add_pair(int(v1), int(v2))

    print("Getting vertex pairs within threshold...")
    vertices = np.asarray(mesh.vertices, dtype=float)
    n_vertices = len(vertices)
    for i in range(n_vertices):
        # Squared distances from vertex i to every later vertex (no sqrt needed).
        diffs = vertices[i + 1:] - vertices[i]
        sq_distances = np.einsum("ij,ij->i", diffs, diffs)
        # Positions in the slice are offsets: shift by i + 1 to get real vertex indices.
        close = (np.flatnonzero(sq_distances < thresh_squared) + i + 1).tolist()
        already_paired = paired_with.get(i, ())
        for j in close:
            if j not in already_paired:  # skip pairs already added as mesh edges
                add_pair(i, j)

        print(f"i = {i}/{n_vertices-1}")

    return valid_pairs

def calculate_vertex_error_quadric(mesh, vertex_index):
    """Return the 4x4 error quadric Q of one vertex (Garland & Heckbert).

    Q is the sum of the fundamental quadrics K_p = p p^T of the triangles that
    touch the vertex. Here p = [a, b, c, d] describes the triangle's plane
    a*x + b*y + c*z + d = 0, with a^2 + b^2 + c^2 = 1. For a homogeneous point
    v = [x, y, z, 1], v^T Q v is the sum of squared distances from v to those
    planes.

    Args:
        mesh: object exposing ``vertex_faces`` (per-vertex face indices padded
            with -1), ``faces`` ((F, 3) vertex indices) and ``vertices``
            ((V, 3) coordinates), e.g. a ``trimesh.Trimesh``.
        vertex_index: index of the vertex whose quadric is computed.

    Returns:
        (4, 4) float ndarray. Zero-area faces define no plane and contribute nothing.
    """
    Q = np.zeros((4, 4))
    for face_index in mesh.vertex_faces[vertex_index]:
        if face_index == -1:  # padding entry in vertex_faces, not a real face
            continue
        p1, p2, p3 = (mesh.vertices[v] for v in mesh.faces[face_index])

        normal = np.cross(p2 - p1, p3 - p1)
        length = np.linalg.norm(normal)
        if length == 0:
            # Degenerate (zero-area) triangle: normalising would produce NaNs.
            continue
        normal = normal / length
        # Plane through p1: n . x + d = 0  =>  d = -n . p1
        d = -np.dot(normal, p1)

        plane = np.append(normal, d)
        Q += np.outer(plane, plane)

    return Q

In [None]:
def compute_optimal_contraction(m, pair, q1, q2):
    """Return the position v_bar that minimises the combined quadric error of a pair.

    Solves Q_bar v = [0, 0, 0, 1]^T, where Q_bar is q1 + q2 with its last row
    replaced by [0, 0, 0, 1] (Garland & Heckbert). If that system is singular
    (e.g. all contributing planes are parallel), the paper's fallback is used:
    the lowest-error choice among the two endpoints and their midpoint.

    Args:
        m: mesh exposing ``vertices`` ((V, 3) coordinates); only read in the fallback.
        pair: ``(v1, v2)`` indices of the vertices being contracted.
        q1, q2: (4, 4) error quadrics of the two vertices (not modified).

    Returns:
        (4, 1) float column vector [x, y, z, 1]^T.
    """
    q_sum = q1 + q2
    q_bar = q_sum.copy()
    q_bar[3, :] = [0, 0, 0, 1]
    b = np.array([[0.0], [0.0], [0.0], [1.0]])

    try:
        result = np.linalg.solve(q_bar, b)
    except np.linalg.LinAlgError:
        result = None
    if result is not None and np.all(np.isfinite(result)):
        return result

    # Singular system: choose the best of v1, v2 and their midpoint.
    p1 = np.asarray(m.vertices[pair[0]], dtype=float)
    p2 = np.asarray(m.vertices[pair[1]], dtype=float)
    candidates = [np.append(p, 1.0).reshape(4, 1) for p in (p1, p2, (p1 + p2) / 2)]
    # min() keeps the first of equally good candidates, i.e. prefers v1.
    return min(candidates, key=lambda v: (v.T @ q_sum @ v).item())

In [None]:
def calculate_error(v_bar, pair, q1, q2):
    """Return the quadric error v_bar^T (q1 + q2) v_bar of placing a pair at v_bar.

    Args:
        v_bar: (4, 1) homogeneous column vector [x, y, z, 1]^T.
        pair: the vertex pair (unused; kept for a uniform call signature).
        q1, q2: (4, 4) error quadrics of the pair's two vertices.

    Returns:
        The scalar error, taken from the (1, 1) result matrix.
    """
    combined_quadric = q1 + q2
    quadratic_form = v_bar.T @ combined_quadric @ v_bar  # shape (1, 1)
    return quadratic_form[0, 0]

In [None]:
def create_heap(m, pairs):
    """Build a min-heap of ``(error, pair)`` entries for all candidate pairs.

    Each entry's error is the quadric error of contracting the pair to its
    optimal position, so ``heapq.heappop`` yields the cheapest contraction first.

    Args:
        m: the mesh the pair indices refer to.
        pairs: iterable of ``(v1, v2)`` vertex-index tuples.

    Returns:
        list satisfying the ``heapq`` heap invariant.
    """
    # A vertex typically appears in several pairs, so compute each quadric once.
    quadrics = {}

    def quadric(vertex_index):
        if vertex_index not in quadrics:
            quadrics[vertex_index] = calculate_vertex_error_quadric(m, vertex_index)
        return quadrics[vertex_index]

    heap = []
    for pair in pairs:
        q1 = quadric(pair[0])
        q2 = quadric(pair[1])
        v_bar = compute_optimal_contraction(m, pair, q1, q2)
        error = calculate_error(v_bar, pair, q1, q2)
        heap.append((error, pair))

    heapq.heapify(heap)  # O(n), instead of n separate heappush calls
    return heap

In [None]:
def update_valid_pairs(mesh, v_bar, v1, v2):
    """Placeholder for re-pairing after contracting ``v1`` and ``v2`` into ``v_bar``.

    Not implemented yet: all arguments are ignored and None is returned.

    NOTE(review): ``simplify_mesh`` passes its ``valid_pairs`` list as the first
    argument, so the ``mesh`` parameter name looks misleading -- confirm intent.
    """
    return None

In [None]:
def simplify_mesh(mesh, threshold=0.7):
    """Work-in-progress driver for quadric-error mesh simplification.

    Gathers candidate pairs and orders them by contraction error, then pops them
    cheapest-first. Contractions are not implemented yet. For each popped pair,
    the driver recomputes the pair's optimal position and prints which kind of
    contraction it would be. It then hands the result to ``update_valid_pairs``.

    Args:
        mesh: the mesh to simplify (a ``trimesh.Trimesh`` in this notebook).
        threshold: distance below which non-edge vertex pairs are also candidates.

    Returns:
        For now, the input ``mesh`` object itself (see the TODO at the end).
    """
    valid_pairs = get_all_valid_pairs(mesh, threshold)
    heap = create_heap(mesh, valid_pairs)

    while heap:
        _, pair = heapq.heappop(heap)
        v1_index, v2_index = pair
        neighbours_of_v1 = mesh.vertex_neighbors[v1_index]

        q1 = calculate_vertex_error_quadric(mesh, v1_index)
        q2 = calculate_vertex_error_quadric(mesh, v2_index)
        v_bar = compute_optimal_contraction(mesh, pair, q1, q2)

        # TODO: gather the face/vertex data needed for the edge or non-edge contraction.
        is_edge = v2_index in neighbours_of_v1
        print("edge contraction" if is_edge else "non-edge contraction")

        # Re-pair v_bar with the neighbours of v1 and v2 and refresh their errors.
        update_valid_pairs(valid_pairs, v_bar, v1_index, v2_index)

    # TODO: build a new trimesh from the contracted faces/vertices and return that instead.
    new_mesh = mesh
    return new_mesh

## Main Code

In [None]:
%%time
# Load the sample model from RES_PATH and run the (work-in-progress) simplifier on it.
# NOTE(review): simplify_mesh currently returns the mesh it was given (see its TODO).
m = load_mesh("coffee_cup.obj")
new_mesh = simplify_mesh(m)