In [None]:
%run optimizers/outputs/Outputs.ipynb
from trimesh.base import Trimesh
import h5py
import numpy as np

class FaceModelMeshViewer(Outputs):
    """Output handler that renders a PCA face-model instance as a 3D mesh.

    The model is described by a flat mean shape vector, a PCA basis and a
    triangle index list; ``handle`` combines them with a coefficient vector
    and opens a trimesh scene viewer on the result.
    """

    def __init__(self, mean, pcaBasis, triangles):
        """Store the model parameters.

        Args:
            mean: Flat mean shape vector laid out as
                ``[u1, v1, w1, ..., un, vn, wn]`` (length ``3 * n``).
            pcaBasis: Basis matrix multiplied with the coefficient vector;
                presumably of shape ``(3 * n, k)`` -- confirm with callers.
            triangles: Triangle vertex indices, either flat or 2-D with one
                dimension of size 3 (see ``_getModelMesh``).
        """
        self.mean = mean
        self.pcaBasis = pcaBasis
        self.triangles = triangles

    def handle(self, a, r, t, s):
        """Build the mesh for coefficients ``a`` and show it in a GL viewer.

        Only ``a`` is used; ``r``, ``t`` and ``s`` are accepted but ignored
        by this viewer.
        """
        points = self._modelPointCombination(a)
        mesh = self._getModelMesh(points, self.triangles)
        scene = mesh.scene()
        scene.show(viewer='gl')

    def _modelPointCombination(self, a):
        """Return the model points ``mean + pcaBasis @ a`` as a ``(3, n)`` array.

        Column ``i`` of the result holds the ``i``-th consecutive
        ``(u, v, w)`` triple of the flat point vector.
        """
        # Integer division avoids a round trip through float.
        num_points = len(self.mean) // 3

        # Flat vector [u1, v1, w1, ..., un, vn, wn].
        points = self.mean + np.matmul(self.pcaBasis, a)

        # Fortran order puts each consecutive (u, v, w) triple in one column.
        return np.reshape(points, (3, num_points), order='F')

    def _getModelMesh(self, points, triangles):
        """Create a ``Trimesh`` from points and triangles in flexible layouts.

        Both inputs may be flat (consecutive triples) or 2-D. A 2-D array
        whose first dimension is 3 is treated as ``(3, n)`` and transposed;
        note that this makes a ``(3, 3)`` array ambiguous -- it is always
        transposed. Any other 2-D array is passed through unchanged.

        Args:
            points: Vertex coordinates.
            triangles: Vertex indices of each face.

        Returns:
            A ``Trimesh`` built from ``(n, 3)`` vertices and ``(m, 3)`` faces.
        """
        points = self._asRows(points)
        # The original discarded the reshape result for flat triangle input,
        # so flat index lists reached Trimesh unmodified; _asRows assigns it.
        triangles = self._asRows(triangles)

        return Trimesh(vertices=points, faces=triangles)

    @staticmethod
    def _asRows(values):
        """Return ``values`` with one triple per row where the layout allows."""
        values = np.asarray(values)
        if values.ndim == 1:
            # Flat consecutive triples -> (3, n) columns -> (n, 3) rows.
            return np.reshape(values, (3, -1), order='F').T
        if values.ndim == 2 and values.shape[0] == 3:
            return values.T
        return values