
Copyright (c) 2023 Haiba Labs

Author: James Ritts <james@haibalabs.com>

This notebook trains a simple pytorch model to map from [MediaPipe face mesh](http://solutions.mediapipe.dev/face_mesh) landmarks to [ARKit-compatible blendshapes](https://developer.apple.com/documentation/arkit/arfaceanchor/blendshapelocation).

### [Click here to open the demo.](https://haibalabs.github.io/face-mesh-to-blendshapes/test/mediapipe_to_arkit.html)

### Caveats

- We wish to train on object space geo so it doesn't have to learn what a face pose looks like in every possible head orientation. Unfortunately MediaPipe's output is only given in [screen coordinates](https://www.cse.iitd.ac.in/~suban/vision/affine/node5.html). Its mesh is also stretched to conform to the silhouette of the face in the input image. The function normalize_landmarks() tries to undo these effects.
- The function convert_landmarks_to_model_input() uses normalize_landmarks in order to convert from raw MediaPipe output to the NN input vector. This function needs to be ported to any environment where the model is run.
- MediaPipe isn't able to signal every blendshape. These should be forced to zero at runtime and possibly others as well: jawForward, jawRight, jawLeft, mouthDimpleRight, mouthDimpleLeft, cheekPuff, tongueOut.


### Format

The order of blendshape values in the model output is:

```
eyeBlinkRight, eyeLookDownRight, eyeLookInRight, eyeLookOutRight, eyeLookUpRight, eyeSquintRight, eyeWideRight, eyeBlinkLeft, eyeLookDownLeft, eyeLookInLeft, eyeLookOutLeft, eyeLookUpLeft, eyeSquintLeft, eyeWideLeft, jawForward, jawRight, jawLeft, jawOpen, mouthClose, mouthFunnel, mouthPucker, mouthRight, mouthLeft, mouthSmileRight, mouthSmileLeft, mouthFrownRight, mouthFrownLeft, mouthDimpleRight, mouthDimpleLeft, mouthStretchRight, mouthStretchLeft, mouthRollLower, mouthRollUpper, mouthShrugLower, mouthShrugUpper, mouthPressRight, mouthPressLeft, mouthLowerDownRight, mouthLowerDownLeft, mouthUpperUpRight, mouthUpperUpLeft, browDownRight, browDownLeft, browInnerUp, browOuterUpRight, browOuterUpLeft, cheekPuff, cheekSquintRight, cheekSquintLeft, noseSneerRight, noseSneerLeft, tongueOut
```

Note MediaPipe isn't capable of signaling every blendshape. For example, these should be probably forced to zero at runtime, and possibly others:

```
jawForward, jawRight, jawLeft, mouthDimpleRight, mouthDimpleLeft, cheekPuff, tongueOut
```

Training data has this folder structure:
- my_first_dataset
  - neutral.jpg
  - my_first_dataset.csv
  - my_first_dataset_000000.jpg
  - my_first_dataset_000001.jpg
  - my_first_dataset_000002.jpg
  - ...
- my_second_dataset
- sets.txt

The file sets.txt should contain the folder names of all training datasets:
```
my_first_dataset
my_second_dataset
...
```

Each dataset must have a calibration photo depicting a neutral facial expression: **neutral.jpg**.  The model is trained on object space offsets from the neutral pose.

All images in a set should have approximately the same head transform.

Each dataset also has a CSV file containing a header row followed by labels (blendshape values) for each input image:
```
eyeBlinkRight,eyeLookDownRight,eyeLookInRight,eyeLookOutRight,eyeLookUpRight,eyeSquintRight,eyeWideRight,eyeBlinkLeft,eyeLookDownLeft,eyeLookInLeft,eyeLookOutLeft,eyeLookUpLeft,eyeSquintLeft,eyeWideLeft,jawForward,jawRight,jawLeft,jawOpen,mouthClose,mouthFunnel,mouthPucker,mouthRight,mouthLeft,mouthSmileRight,mouthSmileLeft,mouthFrownRight,mouthFrownLeft,mouthDimpleRight,mouthDimpleLeft,mouthStretchRight,mouthStretchLeft,mouthRollLower,mouthRollUpper,mouthShrugLower,mouthShrugUpper,mouthPressRight,mouthPressLeft,mouthLowerDownRight,mouthLowerDownLeft,mouthUpperUpRight,mouthUpperUpLeft,browDownRight,browDownLeft,browInnerUp,browOuterUpRight,browOuterUpLeft,cheekPuff,cheekSquintRight,cheekSquintLeft,noseSneerRight,noseSneerLeft,tongueOut
0.039,0.103,0.044,0.000,0.000,0.000,0.000,0.039,0.104,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.010,0.010,0.027,0.000,0.000,0.002,0.003,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.015,0.014,0.000,0.000,0.000,0.007,0.000,0.000,0.000,0.000,0.000
0.038,0.091,0.049,0.000,0.000,0.000,0.000,0.038,0.092,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.010,0.011,0.027,0.000,0.000,0.002,0.004,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.014,0.014,0.000,0.000,0.000,0.007,0.000,0.000,0.000,0.000,0.000
...
```

To do:
- cull shapes from NN output and training which MP can't signal
- cull training examples which don't signal shapes that MP can detect (programmically; already did a rough manual pass)

Relevant links:
- https://arxiv.org/pdf/2006.10962.pdf
- https://developers.googleblog.com/2020/09/mediapipe-3d-face-transform.html
- https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data
- https://github.com/google/mediapipe/issues/2867
- https://stackoverflow.com/questions/69858216/mediapipe-facemesh-vertices-mapping
- https://github.com/Rassibassi/mediapipeFacegeometryPython/blob/main/face_geometry.py
- https://github.com/google/mediapipe/blob/a908d668c730da128dfa8d9f6bd25d519d006692/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualization.png

---
**Configuration**

---

In [1]:
import math

if 'google.colab' in str(get_ipython()):
    IS_LOCAL_ENVIRONMENT = False
else:
    IS_LOCAL_ENVIRONMENT = True

# Root folder with training data
# Cache and model files will be written here
if IS_LOCAL_ENVIRONMENT:
    DATA_PATH = 'Face/Training/'
else:
    DATA_PATH = '/content/drive/MyDrive/Colab/Face/Training/'

# Whether to enable MediaPipe's refineLandmarks feature
REFINE_LANDMARKS = True

# Regenerate training data from the contents of DATA_PATH
RUN_MEDIAPIPE_GEN = True

# Train the NN
RUN_TRAINING = True # train NN

# Run an interactive test of the model input preprocessing
RUN_NORMALIZE_TEST = True

# Vector sizes
NUM_LANDMARKS = 225 # see MP_MOUTH_INDICES, ...
NN_INPUT_SIZE = NUM_LANDMARKS * 2
NN_OUTPUT_SIZE = 52

# Hyperparams
NN_HIDDEN_SIZE = math.floor(((2 * NN_INPUT_SIZE) / 3) + NN_OUTPUT_SIZE)
NN_EPOCHS = 75

# Misc
SHOW_IMAGE_HEIGHT = 640
SHOW_IMAGE_WIDTH = 640

# MediaPipe outputs 478 landmarks, from which we cull all but 225 "useful" ones
MP_MOUTH_INDICES = [
    # lips
    0, 11, 12, 13, 14, 15, 16, 17, 37, 38, 39, 40, 41, 42, 61, 62, 72, 73, 74, 76, 77, 78, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 95, 96, 146, 178, 179, 180, 181, 183, 184, 185, 191, 267, 268, 269, 270, 271, 272, 291, 292, 302, 303, 304, 306, 307, 308, 310, 311, 312, 314, 315, 316, 317, 318, 319, 320, 321, 324, 325, 375, 402, 403, 404, 405, 407, 408, 409,
    # first ring
    18, 43, 57, 83, 92, 106, 164, 165, 167, 182, 186, 273, 287, 313, 322, 335, 391, 393, 406, 410
]

MP_LEFT_EYE_INDICES = [
    # eye socket
    7, 33, 133, 144, 145, 153, 154, 155, 157, 158, 159, 160, 161, 163, 173, 246,
    # first ring around eye
    22, 23, 24, 25, 26, 27, 28, 29, 30, 56, 110, 112, 130, 190, 243, 247,
    # second ring around eye
    31, 113, 189, 221, 222, 223, 224, 225, 226, 228, 229, 230, 231, 232, 233, 244,
    # brow row 1
    46, 52, 53, 55, 65,
    # brow row 2
    63, 66, 70, 105, 107,
    # brow row 3
    68, 69, 71, 104, 108
]

MP_RIGHT_EYE_INDICES = [
    # eye socket
    249, 263, 362, 373, 374, 380, 381, 382, 384, 385, 386, 387, 388, 390, 398, 466,
    # first ring around eye
    252, 253, 254, 255, 256, 257, 258, 259, 260, 286, 339, 341, 359, 414, 463, 467,
    # second ring around eye
    261, 342, 413, 441, 442, 443, 444, 445, 446, 448, 449, 450, 451, 452, 453, 464,
    # brow row 1
    276, 282, 283, 285, 295,
    # brow row 2
    293, 296, 300, 334, 336,
    # brow row 3
    298, 299, 301, 333, 337
]


---
**Dependencies**

---

In [7]:
# Scratch local setup
#!python --version
# !conda create --name myenv
# !conda activate colab_env
#!pip show torch
# !pip show torchvision
#!pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install numpy==1.21.6
# !pip install --extra-index-url https://download.pytorch.org/whl/cu113/ "torch==1.12.1+cu113"
# !pip install pandas==1.3.5

if not IS_LOCAL_ENVIRONMENT:
    !pip install mediapipe
    !pip install pytorch_lightning
    !pip install torchviz
    !pip install torchvision
    !pip install matplotlib
    !git clone https://github.com/Rassibassi/mediapipeDemos

import pytorch_lightning as pl
from torchmetrics import Accuracy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
import cv2
import numpy as np
import mediapipe as mp
import os
import glob
import pandas as pd
import copy
import pickle
from ast import Yield
from IPython.display import display, Javascript
from base64 import b64decode
from torchviz import make_dot

mp_face_mesh = mp.solutions.face_mesh
mp_face_mesh_connections = mp.solutions.face_mesh_connections
mp_drawing = mp.solutions.drawing_utils 
mp_drawing_styles = mp.solutions.drawing_styles
drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=3)

---
**Utilities**

---

In [8]:
if not IS_LOCAL_ENVIRONMENT:
    from google.colab.output import eval_js
    from google.colab.patches import cv2_imshow

# https://graphics.pixar.com/library/OrthonormalB/paper.pdf
def revisedONB(n):
    if n[2] < 0:
        a = 1.0 / (1.0 - n[2])
        b = n[0] * n[1] * a
        return [[1.0 - n[0] * n[0] * a, -b, n[0]], [b, n[1] * n[1]*a - 1.0, -n[1]]]
    else:
        a = 1.0 / (1.0 + n[2])
        b = -n[0] * n[1] * a
        return [[1.0 - n[0] * n[0] * a, b, -n[0]], [b, 1.0 - n[1] * n[1] * a, -n[1]]]

# Get a photo using the webcam
def take_photo(filename='photo.jpg', quality=0.8):
    js = Javascript('''
        async function takePhoto(quality) {
            const div = document.createElement('div');
            const capture = document.createElement('button');
            capture.textContent = 'Capture';
            div.appendChild(capture);

            const video = document.createElement('video');
            video.style.display = 'block';
            const stream = await navigator.mediaDevices.getUserMedia({video: true});

            document.body.appendChild(div);
            div.appendChild(video);
            video.srcObject = stream;
            await video.play();

            // Resize the output to fit the video element.
            google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

            // Wait for Capture to be clicked.
            await new Promise((resolve) => capture.onclick = resolve);

            const canvas = document.createElement('canvas');
            canvas.width = video.videoWidth;
            canvas.height = video.videoHeight;
            canvas.getContext('2d').drawImage(video, 0, 0);
            stream.getVideoTracks()[0].stop();
            div.remove();
            return canvas.toDataURL('image/jpeg', quality);
        }
        ''')
    display(js)
    data = eval_js('takePhoto({})'.format(quality))
    binary = b64decode(data.split(',')[1])
    with open(filename, 'wb') as f:
        f.write(binary)
    return filename

# Scale and display an image
def show_image(image):
    h, w = image.shape[:2]
    if h < w:
        img = cv2.resize(image, (SHOW_IMAGE_WIDTH, math.floor(h/(w/SHOW_IMAGE_WIDTH))))
    else:
        img = cv2.resize(image, (math.floor(w/(h/SHOW_IMAGE_HEIGHT)), SHOW_IMAGE_HEIGHT))
    cv2_imshow(img)

# Scale and display an image overlaid with mediapipe landbarks
def show_landmarks(image, multi_face_landmarks):
    copy = image.copy()
    for face_landmarks in multi_face_landmarks:
        mp_drawing.draw_landmarks(
                image=copy,
                landmark_list=face_landmarks,
                connections=mp_face_mesh.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=mp_drawing_styles
                .get_default_face_mesh_tesselation_style())
        mp_drawing.draw_landmarks(
                image=copy,
                landmark_list=face_landmarks,
                connections=mp_face_mesh.FACEMESH_CONTOURS,
                landmark_drawing_spec=None,
                connection_drawing_spec=mp_drawing_styles
                .get_default_face_mesh_contours_style())
        if REFINE_LANDMARKS:
            mp_drawing.draw_landmarks(
                    image=copy,
                    landmark_list=face_landmarks,
                    connections=mp_face_mesh.FACEMESH_IRISES,
                    landmark_drawing_spec=None,
                    connection_drawing_spec=mp_drawing_styles
                    .get_default_face_mesh_iris_connections_style())
    show_image(copy)


---
**Preprocessing**

---

In [26]:
# MediaPipe's face mesh indices:
#   https://github.com/google/mediapipe/blob/master/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualization.png
# 
# Landmark coordinate system is +X to the right, +Y down, and +Z pointing into the screen.
# 
#            _
#            /| +Z
#           /
#          /
#         o-----> +X
#         |
#         |
#         \/ +Y

def screen_align_and_normalize(arr, xf):
    # screen align
    pivot = np.mean(arr, axis=0)
    arr = [p - pivot for p in arr]
    arr = [xf.dot(p).tolist()[0] for p in arr]
    arr = [p + pivot for p in arr]
    # fill unit cube
    amin = np.amin(arr, axis=0)
    amax = np.amax(arr, axis=0)
    for p in arr:
        p[0] = (p[0] - amin[0]) / (amax[0] - amin[0])
        p[1] = (p[1] - amin[1]) / (amax[1] - amin[1])
        p[2] = (p[2] - amin[2]) / (amax[2] - amin[2])
    return arr

# Get world=>object map for the full mesh using a basis calculated from forehead points
def calc_forehead_xf(arr):
    # average a few normals at the top of the head to get a forward vec
    v0 = np.subtract(arr[151], arr[10])
    v1 = np.subtract(arr[338], arr[10])
    v2 = np.subtract(arr[109], arr[10])
    f1 = np.cross(v1, v0)
    f2 = np.cross(v0, v2)
    f1 = f1 / np.linalg.norm(f1)
    f2 = f2 / np.linalg.norm(f2)
    vfw = 0.5 * (f1 + f2)
    # up
    vup = np.subtract(arr[151], arr[10])
    vup = vup / np.linalg.norm(vup)
    # right
    vrt = np.cross(vup, vfw)
    vup = np.cross(vfw, vrt)
    return np.matrix([vrt, vup, vfw]) # 3x3

# Get world=>object map for the left eye
def calc_simple_xf(arr, idxF0, idxF1, idxF2, idxR0, idxR1):
    # fwd
    v0 = np.subtract(arr[idxF2], arr[idxF0])
    v1 = np.subtract(arr[idxF1], arr[idxF0])
    vfw = np.cross(v1, v0)
    vfw = vfw / np.linalg.norm(vfw)
    # right
    vrt = np.subtract(arr[idxR1], arr[idxR0])
    vrt = vrt / np.linalg.norm(vrt)
    # up
    vup = np.cross(vfw, vrt)
    vrt = np.cross(vup, vfw)
    return np.matrix([vrt, vup, vfw]) # 3x3

# Get world=>object map for the left eye
def calc_left_eye_xf(arr):
    return calc_simple_xf(arr, 23, 22, 230, 33, 133)

# Get world=>object map for the left eye
def calc_right_eye_xf(arr):
    return calc_simple_xf(arr, 253, 450, 252, 362, 263)

# Get world=>object map for the left eye
def calc_mouth_xf(arr):
    # average a few normals at the top of the head to get a forward vec
    v0 = np.subtract(arr[164], arr[2])
    v1 = np.subtract(arr[326], arr[2])
    v2 = np.subtract(arr[97], arr[2])
    f1 = np.cross(v1, v0)
    f2 = np.cross(v0, v2)
    f1 = f1 / np.linalg.norm(f1)
    f2 = f2 / np.linalg.norm(f2)
    vfw = 0.5 * (f1 + f2)
    # right
    vrt = np.subtract(arr[312], arr[82])
    vrt = vrt / np.linalg.norm(vrt)
    # up
    vup = np.cross(vfw, vrt)
    vrt = np.cross(vup, vfw)
    return np.matrix([vrt, vup, vfw]) # 3x3

def normalize_landmarks(multi_face_landmarks):
    landmarks = multi_face_landmarks[0].landmark
    arr = [[l.x, l.y, l.z] for l in landmarks]

    arrMouth = [arr[idx] for idx in MP_MOUTH_INDICES]
    arrLeftEye = [arr[idx] for idx in MP_LEFT_EYE_INDICES]
    arrRightEye = [arr[idx] for idx in MP_RIGHT_EYE_INDICES]

    arrMouth = screen_align_and_normalize(arrMouth, calc_mouth_xf(arr))
    arrLeftEye = screen_align_and_normalize(arrLeftEye, calc_left_eye_xf(arr))
    arrRightEye = screen_align_and_normalize(arrRightEye, calc_right_eye_xf(arr))

    for idx in range(len(landmarks)):
        landmarks[idx].x = 0
        landmarks[idx].y = 0
        landmarks[idx].z = 0

    src = 0
    for idx in MP_MOUTH_INDICES:
        landmarks[idx].x = arrMouth[src][0]
        landmarks[idx].y = arrMouth[src][1]
        landmarks[idx].z = arrMouth[src][2]
        src = src + 1

    src = 0
    for idx in MP_LEFT_EYE_INDICES:
        landmarks[idx].x = arrLeftEye[src][0]
        landmarks[idx].y = arrLeftEye[src][1]
        landmarks[idx].z = arrLeftEye[src][2]
        src = src + 1

    src = 0
    for idx in MP_RIGHT_EYE_INDICES:
        landmarks[idx].x = arrRightEye[src][0]
        landmarks[idx].y = arrRightEye[src][1]
        landmarks[idx].z = arrRightEye[src][2]
        src = src + 1

    arr = arrMouth + arrLeftEye + arrRightEye;
    arr = [[i[0], i[1]] for i in arr] # Z -> 0
    return arr

def get_normalized_landmarks(path, show=False):
    image = cv2.imread(path)
    with mp_face_mesh.FaceMesh(
            static_image_mode=True,
            refine_landmarks=REFINE_LANDMARKS,
            max_num_faces=1,
            min_detection_confidence=0.5) as face_mesh:
        results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        arr = normalize_landmarks(results.multi_face_landmarks)
        if (show):
            show_landmarks(image, results.multi_face_landmarks)
        return arr

def convert_landmarks_to_model_input(multi_face_landmarks, neutral_normalized_landmarks):
    arr = normalize_landmarks(multi_face_landmarks)
    return np.subtract(arr, neutral_normalized_landmarks)

# Test code
if RUN_NORMALIZE_TEST:
    from IPython.display import Image
    try:
        #test_path = '/content/drive/MyDrive/Colab/Face/Training/eyes1/eyes1_0013.jpg'
        test_path = take_photo()
        test_image = cv2.imread(test_path)
        with mp_face_mesh.FaceMesh(
                static_image_mode=True,
                refine_landmarks=REFINE_LANDMARKS,
                max_num_faces=1,
                min_detection_confidence=0.5) as face_mesh:

            results = face_mesh.process(cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB))
            normalize_landmarks(results.multi_face_landmarks)

            show_landmarks(test_image, results.multi_face_landmarks)

    except Exception as err:
        # Errors will be thrown if the user does not have a webcam or if they do not
        # grant the page permission to access it.
        print(str(err))

<IPython.core.display.Javascript object>

name 'eval_js' is not defined


---
**Model definition**

---

In [27]:
class FullyConnectedModel(pl.LightningModule):
    def __init__(self, input_size=NN_INPUT_SIZE, output_size=NN_OUTPUT_SIZE, hidden_units=[NN_HIDDEN_SIZE]):
        super().__init__()
        
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU()) 
            input_size = hidden_unit 
 
        all_layers.append(nn.Linear(input_size, output_size))
        self.writer = SummaryWriter()
        self.model = nn.Sequential(*all_layers)
        self.epoch = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss, prog_bar=True)

        self.epoch += 1
        self.writer.add_scalar("Loss/train", loss, self.epoch)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = nn.functional.mse_loss(self(x), y)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = nn.functional.mse_loss(self(x), y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

class ARKitDataset(Dataset):
    def __init__(self, sets):
        self.inputs = [torch.from_numpy(np.array(i).flatten()).float() for s in sets for i in s['input']]
        self.labels = [torch.from_numpy(np.array(i)).float() for s in sets for i in s['output']]
    def __len__(self):
        return len(self.inputs)
    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

class ARKitDataModule(pl.LightningDataModule):
    def __init__(self, sets, batch_size = 64, validation_size = 0.1):
        super().__init__()
        self.sets = sets
        self.batch_size = batch_size
        self.validation_size = validation_size

    def prepare_data(self):
        self.data = ARKitDataset(self.sets)
        self.data_size = len(self.data)
        self.val_size = math.floor(self.validation_size * self.data_size)
        self.train_size = self.data_size - self.val_size

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train, self.val = random_split(self.data, [self.train_size, self.val_size])
        if stage == "test" or stage is None:
            self.test = self.data
        if stage == "predict" or stage is None:
            self.predict = self.data

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.predict, batch_size=self.batch_size)


---
**Generate training data and train the model**

---

In [28]:
# ---------------------------------------------------------------------------------------
# Generate training set or load cached
# ---------------------------------------------------------------------------------------
if RUN_MEDIAPIPE_GEN:
    # use mediapipe to generate input vectors
    sets = []
    with open(DATA_PATH + 'sets.txt', mode='r') as file:
        sets = [{
            'name': l.strip(),
            'root': os.path.join(DATA_PATH, l.strip()),
            'input': [],
            'output': pd.read_csv(os.path.join(os.path.join(DATA_PATH, l.strip()), l.strip() + '.csv')),
            'path': os.path.join(os.path.join(DATA_PATH, l.strip()), (l.strip() + '_{:04d}.jpg')),
            'neutral_path': os.path.join(os.path.join(DATA_PATH, l.strip()), 'neutral.jpg'),
            'count': len(glob.glob1(os.path.join(DATA_PATH, l.strip()),"*.jpg")) - 1, # subtracting one for neutral.jpg
            } for l in file.readlines()]
    print('Indexed ' + str(len(sets)) + ' image sets...')
    # print(sets[0]['csv'].iloc[0])

    with mp_face_mesh.FaceMesh(
            static_image_mode=True,
            refine_landmarks=REFINE_LANDMARKS,
            max_num_faces=1,
            min_detection_confidence=0.5) as face_mesh:

        for set in sets:
            print('Processing set "' + set['name'] + '"')

            set['output'] = set['output'].to_numpy()

            # Remove head orientation
            while set['output'].shape[1] > 52:
                set['output'] = np.delete(set['output'], 52, 1)

            neutral = get_normalized_landmarks(set['neutral_path'])

            for index in range(set['count']):
                path = set['path'].format(index)
                image = cv2.imread(path)
                results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                input = convert_landmarks_to_model_input(results.multi_face_landmarks, neutral)
                set['input'].append(input)

        with open(os.path.join(DATA_PATH, 'sets.pickle'), 'wb') as f:
            pickle.dump(sets, f)
else:
    with open(os.path.join(DATA_PATH, 'sets.pickle'), 'rb') as f:
        sets = pickle.load(f)

# ---------------------------------------------------------------------------------------
# Train model or load cached
# ---------------------------------------------------------------------------------------
data_module = ARKitDataModule(sets)
landmark_to_blendshape = FullyConnectedModel(NN_INPUT_SIZE, NN_OUTPUT_SIZE, [NN_HIDDEN_SIZE])

model_pt_path = os.path.join(DATA_PATH, 'model.pt')
model_onnx_path = os.path.join(DATA_PATH, 'model.onnx')

dummy_input = torch.zeros(NN_INPUT_SIZE).float().reshape(NN_INPUT_SIZE, 1).transpose(0, 1)

if RUN_TRAINING:
    %load_ext tensorboard
    %tensorboard --logdir lightning_logs/

    # train
    torch.manual_seed(1) 
    if torch.cuda.is_available(): # if you have GPUs
        trainer = pl.Trainer(max_epochs=NN_EPOCHS, gpus=1)
    else:
        trainer = pl.Trainer(max_epochs=NN_EPOCHS)
    trainer.fit(model=landmark_to_blendshape.float(), datamodule=data_module)
    # optimize (TODO)
    # save to pytorch and onnx formats
    torch.save(landmark_to_blendshape.model.state_dict(), model_pt_path)
    torch.onnx.export(landmark_to_blendshape.model, dummy_input, model_onnx_path, verbose=True)
else:
    # load pretrained
    landmark_to_blendshape.model.load_state_dict(torch.load(model_pt_path))

landmark_to_blendshape.model.eval()

# ---------------------------------------------------------------------------------------
# Test
# ---------------------------------------------------------------------------------------
def predict(name, index):
    for s in sets:
        if s['name'] != name:
            continue;

        x = s['input'][index]
        l = s['output'][index]
        x = torch.from_numpy(np.array(x).flatten()).float().reshape(NN_INPUT_SIZE, 1).transpose(0, 1)
        y = np.clip(landmark_to_blendshape(x)[0].tolist(), 0, 1)
        # print("Input:")
        # print(str(x))
        print("\nLabel:")
        print(str(np.around(l, 2)))
        print("\nPred:")
        print(str(np.around(y, 2)))
        print("\nAbs error:")
        print(str(np.around(np.abs(np.array(y) - np.array(l)), 2)))
        print("\nMax abs error: " + str(np.around(np.abs(np.array(y) - np.array(l)).max(), 6)))
        print("\nMSE: " + str(np.around(np.square(np.array(y) - np.array(l)).mean(), 6)))

predict(sets[0]['name'], 0)

Indexed 9 image sets...
Processing set "brow"
Processing set "eyeBlinkLeft"
Processing set "eyeBlinkRight"
Processing set "take_fwd_expr"
Processing set "take_fwd_text"
Processing set "take_rnd_text"
Processing set "take_eyes"
Processing set "take_nose"
Processing set "take_straight"


  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 177 K 
-------------------------------------
177 K     Trainable params
0         Non-trainable params
177 K     Total params
0.708     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=75` reached.



Label:
[0.04 0.1  0.04 0.   0.   0.   0.   0.04 0.1  0.   0.   0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.01 0.01 0.03 0.   0.   0.
 0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.02
 0.01 0.   0.   0.   0.01 0.   0.   0.   0.   0.  ]

Pred:
[0.04 0.06 0.06 0.   0.   0.   0.   0.04 0.06 0.01 0.01 0.   0.   0.
 0.   0.   0.   0.   0.   0.01 0.   0.   0.01 0.01 0.02 0.   0.   0.01
 0.01 0.   0.   0.01 0.   0.03 0.01 0.01 0.01 0.03 0.02 0.02 0.01 0.
 0.01 0.   0.   0.   0.01 0.   0.   0.01 0.02 0.  ]

Abs error:
[0.   0.05 0.02 0.   0.   0.   0.   0.   0.05 0.01 0.01 0.   0.   0.
 0.   0.   0.   0.   0.   0.01 0.   0.   0.   0.   0.   0.   0.   0.
 0.01 0.   0.   0.01 0.   0.03 0.01 0.01 0.01 0.03 0.02 0.02 0.01 0.01
 0.   0.   0.   0.   0.01 0.   0.   0.01 0.02 0.  ]

Max abs error: 0.04787

MSE: 0.000175
