In [None]:
from google.colab import drive
# Mount Google Drive so later cells can copy the dataset from, and save
# checkpoints to, paths under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
################################
# Libraries | Utils | Settings
################################
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.optim as optim

from PIL import Image, ImageOps

import cv2
from datetime import datetime
import random

from scipy.io import loadmat

In [None]:
! cp -r '/content/drive/My Drive/Colab Notebooks/data' '/tmp/data'

In [None]:
from torch.utils.data import Dataset, DataLoader


class LoadMixedData(Dataset):
    """68-point face-landmark dataset pooled from several sources.

    Images (``*.jpg`` / ``*.png``, excluding ``*seg.png``) are collected from
    ``data_root/<dataset>`` for every entry of ``DATASETS``. Landmarks come from
    a sibling ``.mat`` (``pt3d_68``), ``_ldmks.txt`` or ``.pts`` file. An image
    whose name ends in ``_mirror`` reuses the ``.pts`` file of its un-mirrored
    original, with x reflected.

    Each item is ``(image, landmarks)``:
      - image: uint8 RGB tensor of shape (3, size, size), cropped around the
        landmarks and resized.
      - landmarks: tensor of shape (2, N) (row 0 = x, row 1 = y) in pixel
        coordinates of the resized image. N is 68 for ``.mat`` and
        ``_ldmks.txt`` sources. For ``.pts`` it is the number of points listed
        in the file.

    Args:
        size: side length in pixels of the square output image.
        train: truthy selects the training partition, falsy the held-out one.
            The strings 'True' / 'False' (the historical default is the string
            'True') are interpreted by value, not by truthiness.
        train_test_split: fraction of face groups placed in the training
            partition. A face group is an image together with its ``_mirror``
            copy, so the two views of one face never straddle the split.
        data_root: directory holding one sub-directory per source dataset.
        seed: seed of the deterministic shuffle that assigns groups to
            partitions. Use the same value for the train and test instances.
    """

    # "helen_testset" is deliberately absent: it is reserved for final evaluation.
    DATASETS = ("AFLW2000", "afw", "helen_trainset",
                "ibug", "lfpw_1", "lfpw_2", "synthetic1000")
    MIRROR_SUFFIX = "_mirror"

    def __init__(self, size, train='True', train_test_split=0.9,
                 data_root="/tmp/data", seed=0):
        super(LoadMixedData, self).__init__()

        if isinstance(train, str):
            train = train.strip().lower() == "true"
        if not 0.0 <= train_test_split <= 1.0:
            raise ValueError(
                "train_test_split must be in [0, 1], got %r" % (train_test_split,))

        self.size = size
        self.inputs = []
        self.targets = set()

        for dataset in self.DATASETS:
            data_dir = os.path.join(data_root, dataset)
            for fname in os.listdir(data_dir):
                path = os.path.join(data_dir, fname)
                if fname.endswith("seg.png"):
                    continue  # *seg.png files are never used as inputs
                if fname.endswith(("jpg", "png")):
                    self.inputs.append(path)
                if fname.endswith(("mat", "ldmks.txt", "pts")):
                    self.targets.add(path)

        # Split by face group rather than by file, with a seeded shuffle so
        # that the train and test instances built with the same seed are
        # disjoint and cover every group between them.
        groups = sorted({self._group_key(p) for p in self.inputs})
        random.Random(seed).shuffle(groups)
        n_train = int(len(groups) * train_test_split)
        keep = set(groups[:n_train] if train else groups[n_train:])
        self.inputs = sorted(p for p in self.inputs if self._group_key(p) in keep)

    @staticmethod
    def _group_key(path):
        """Return ``path`` without its extension and without a trailing '_mirror'."""
        stem = os.path.splitext(path)[0]
        if stem.endswith(LoadMixedData.MIRROR_SUFFIX):
            stem = stem[:-len(LoadMixedData.MIRROR_SUFFIX)]
        return stem

    @staticmethod
    def _read_pts(path):
        """Parse a .pts file into ``(xs, ys)`` from the rows between '{' and '}'."""
        with open(path) as f:
            rows = [row.strip() for row in f]
        head = rows.index('{') + 1
        tail = rows.index('}')
        xs, ys = [], []
        for point in rows[head:tail]:
            x, y = map(float, point.split())
            xs.append(x)
            ys.append(y)
        return xs, ys

    def _load_landmarks(self, input_path, original_w):
        """Return a (2, N) array of landmark x/y for the image at ``input_path``.

        ``original_w`` is the source image width, used to reflect x for
        ``_mirror`` images.
        """
        base = os.path.splitext(input_path)[0]

        if base.endswith(self.MIRROR_SUFFIX):
            xs, ys = self._read_pts(base[:-len(self.MIRROR_SUFFIX)] + ".pts")
            # NOTE(review): only x is reflected; the landmark index order is
            # not swapped left<->right. Confirm that is what training expects.
            return np.array([[original_w - x for x in xs], ys])

        if base + ".mat" in self.targets:
            ldmks = np.array(loadmat(base + ".mat")['pt3d_68'])  # (3, 68)
            return ldmks[:2, :]

        if base + "_ldmks.txt" in self.targets:
            xs, ys = [], []
            with open(base + "_ldmks.txt", 'r') as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines
                    x, y = map(float, line.split())
                    xs.append(x)
                    ys.append(y)
            # Keep only the first 68 points; any extra rows are dropped.
            return np.array([xs[:68], ys[:68]])

        xs, ys = self._read_pts(base + ".pts")
        return np.array([xs, ys])

    def __getitem__(self, idx):
        """Load, crop and resize sample ``idx``; return ``(image, landmarks)`` tensors."""
        input_path = self.inputs[idx]

        img = cv2.imread(input_path)
        if img is None:
            raise OSError("could not read image %s" % input_path)
        original_h, original_w = img.shape[:2]
        ldmks = self._load_landmarks(input_path, original_w)

        # Crop a box of twice the landmark span, centred on the landmark mean
        # and clipped to the image. The bounds are integers, and the same
        # offsets are subtracted from the landmarks, so the labels stay
        # aligned with the cropped pixels.
        center_x, center_y = ldmks[0].mean(), ldmks[1].mean()
        span_x = ldmks[0].max() - ldmks[0].min()
        span_y = ldmks[1].max() - ldmks[1].min()
        x0, x1 = int(max(0, center_x - span_x)), int(min(original_w, center_x + span_x))
        y0, y1 = int(max(0, center_y - span_y)), int(min(original_h, center_y + span_y))
        if x1 <= x0 or y1 <= y0:
            raise ValueError("empty face crop for %s" % input_path)
        img = img[y0:y1, x0:x1]
        crop_h, crop_w = img.shape[:2]
        ldmks[0, :] -= x0
        ldmks[1, :] -= y0

        # Resize to (size, size) and rescale the landmarks by the same ratios.
        img = cv2.resize(img, (self.size, self.size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ldmks[0, :] *= self.size / crop_w
        ldmks[1, :] *= self.size / crop_h

        image = torch.from_numpy(img).permute(2, 0, 1)  # (H, W, 3) -> (3, H, W)
        return image, torch.from_numpy(ldmks)

    def __len__(self):
        return len(self.inputs)


In [None]:
#################################
# FaceMeshBlock
#################################
# This is the main building block for FaceMesh architecture

class FaceMeshBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FaceMeshBlock, self).__init__()

        self.stride = stride
        self.channel_pad = out_channels - in_channels

        # TFLite uses slightly different padding than PyTorch
        # on the depthwise conv layer when the stride is 2.
        if stride == 2:
            self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
            padding = 0
        else:
            padding = (kernel_size - 1) // 2

        self.convs = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=in_channels,
                bias=True
                ),
            nn.BatchNorm2d(
                in_channels
                ),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True
                ),
            nn.BatchNorm2d(
                out_channels
                )
        )

        self.act = nn.PReLU(out_channels)

    def forward(self, x):
        if self.stride == 2:
            h = F.pad(x, (0, 2, 0, 2), "constant", 0)
            x = self.max_pool(x)
        else:
            h = x

        if self.channel_pad > 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0)

        return self.act(self.convs(h) + x)

#########################################
#       FaceMesh
#########################################

class FaceMesh(nn.Module):
    """FaceMesh-style landmark regressor for 192x192 RGB face crops.

    ``forward`` returns ``[coords, confidence]`` with shapes
    ``(b, 3 * num_coords)`` and ``(b, 1)``. The training loop in this notebook
    reads ``coords`` channel-major as ``coords.view(b, 3, num_coords)``:
    row 0 = x, row 1 = y, row 2 = z. It supervises only x and y, against
    targets in pixel coordinates of the 192x192 input.
    """

    def __init__(self):
        super(FaceMesh, self).__init__()

        self.num_coords = 68
        # NOTE: x_scale / y_scale are kept as attributes, but predict_on_batch
        # does not rescale with them. The training targets are already in
        # input pixels, so the raw outputs are in pixels too.
        self.x_scale = 192.0
        self.y_scale = 192.0
        self.min_score_thresh = 0.75

        self._define_layers()

    def _define_layers(self):
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=0, bias=True),
            nn.PReLU(16),

            FaceMeshBlock(16, 16),
            FaceMeshBlock(16, 16),
            FaceMeshBlock(16, 32, stride=2),
            FaceMeshBlock(32, 32),
            FaceMeshBlock(32, 32),
            FaceMeshBlock(32, 64, stride=2),
            FaceMeshBlock(64, 64),
            FaceMeshBlock(64, 64),
            FaceMeshBlock(64, 128, stride=2),
            FaceMeshBlock(128, 128),
            FaceMeshBlock(128, 128),
            FaceMeshBlock(128, 128, stride=2),
            FaceMeshBlock(128, 128),
            FaceMeshBlock(128, 128),
        )

        self.coord_head = nn.Sequential(
            FaceMeshBlock(128, 128, stride=2),
            FaceMeshBlock(128, 128),
            FaceMeshBlock(128, 128),
            nn.Conv2d(128, 32, 1),
            nn.PReLU(32),
            FaceMeshBlock(32, 32),
            nn.Conv2d(32, 3 * self.num_coords, 3)
        )

        self.conf_head = nn.Sequential(
            FaceMeshBlock(128, 128, stride=2),
            nn.Conv2d(128, 32, 1),
            nn.PReLU(32),
            FaceMeshBlock(32, 32),
            nn.Conv2d(32, 1, 3)
        )

    def forward(self, x):
        # Pad one pixel on the top/left so the stride-2, padding-0 stem conv
        # maps 193 -> 96 for a 192x192 input.
        x = F.pad(x, (1, 0, 1, 0), "constant", 0)
        b = x.shape[0]      # batch size, needed for reshaping later

        x = self.backbone(x)            # (b, 128, 6, 6)

        c = self.conf_head(x)           # (b, 1, 1, 1)
        c = c.view(b, -1)               # (b, 1)

        r = self.coord_head(x)          # (b, 3 * self.num_coords, 1, 1)
        r = r.reshape(b, -1)            # (b, 3 * self.num_coords)

        return [r, c]

    def _device(self):
        return self.conf_head[1].weight.device

    def load_weights(self, path):
        """Load a raw state_dict from ``path`` onto this model's device and switch to eval mode."""
        self.load_state_dict(torch.load(path, map_location=self._device()))
        self.eval()

    def _preprocess(self, x):
        """Converts the image pixels to the range [-1, 1]."""
        return x.float() / 127.5 - 1.0

    def predict_on_batch(self, x):
        """Predicts landmark coordinates for a batch of 192x192 images.

        Arguments:
            x: a NumPy array of shape (b, 192, 192, 3) or a PyTorch tensor of
               shape (b, 3, 192, 192), with pixel values in [0, 255].

        Returns:
            A tuple ``(coords, confidences)``:
            - coords: tensor of shape (b * num_coords, 3). Row
              ``i * num_coords + k`` holds (x, y, z) of landmark ``k`` of
              image ``i``, in the units the network was trained on (input
              pixels for this notebook's training loop). z is not supervised
              by that loop.
            - confidences: raw conf_head output of shape (b, 1); no sigmoid
              is applied.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).permute((0, 3, 1, 2))

        assert x.shape[1] == 3
        assert x.shape[2] == 192
        assert x.shape[3] == 192

        # 1. Preprocess the images into tensors:
        x = self._preprocess(x.to(self._device()))

        # 2. Run the neural network:
        with torch.no_grad():
            coords, confidences = self(x)

        # 3. forward() emits channel-major (3, num_coords) per image. Transpose
        #    so that each output row is one landmark's (x, y, z).
        b = coords.shape[0]
        coords = coords.view(b, 3, self.num_coords).permute(0, 2, 1).reshape(-1, 3)

        return coords, confidences


In [None]:
#####################################
# Training
#####################################
# Train on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Loss
USE_WING_LOSS = False  # set True to train with WingLoss instead of MSE


class WingLoss(nn.Module):
    """Wing loss for landmark regression (Feng et al., 2018).

    For an element-wise residual magnitude |d|:
        width * ln(1 + |d| / curvature)   if |d| < width
        |d| - C                           otherwise
    with C = width - width * ln(1 + width / curvature), so that the two
    pieces meet at |d| = width. Returns the mean over all elements.
    """

    def __init__(self, width=5, curvature=0.5):
        super(WingLoss, self).__init__()
        self.width = width
        self.curvature = curvature
        self.C = self.width - self.width * np.log(1 + self.width / self.curvature)

    def forward(self, prediction, target):
        diff_abs = (target - prediction).abs()
        # torch.where keeps the computation out-of-place and autograd-friendly.
        loss = torch.where(
            diff_abs < self.width,
            self.width * torch.log1p(diff_abs / self.curvature),
            diff_abs - self.C,
        )
        return loss.mean()


criterion = WingLoss() if USE_WING_LOSS else nn.MSELoss()

In [None]:
# Parameters
SEED = 0  # seeds Python, NumPy and torch (model init, DataLoader shuffling)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

root_dir = os.getcwd()   # root directory of project
lr = 10e-4               # learning rate (10e-4 == 1e-3)
epoch_n = 20             # number of training epochs
image_size = 192         # side length of the square input image
batch_size = 64          # training batch size

model = FaceMesh().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)

# training / validation datasets
trainset = LoadMixedData(size=image_size, train=True, train_test_split=0.8)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

testset = LoadMixedData(size=image_size, train=False, train_test_split=0.8)
# Validation order does not change the averaged loss, so there is no need to shuffle.
# NOTE(review): drop_last=True discards up to batch_size - 1 validation samples per epoch.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)

In [None]:
# Optionally resume from a saved checkpoint (set load = False to train from scratch).
load = True
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/model/model_checkpoint.pth'
if load and not os.path.exists(checkpoint_path):
    print('no checkpoint found at %s; training from scratch' % checkpoint_path)
elif load:
    print('loading model')
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
    cpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(cpt['model_state_dict'])
    # The saved optimizer state is not restored. A fresh Adam is created with
    # weight_decay=1e-5 (the optimizer built in the parameters cell uses 5e-4).
    # To resume it as well: optimizer.load_state_dict(cpt['optimizer_state_dict'])
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.00001)

start_time = datetime.now()
train_loss = []   # mean training loss per epoch
valid_loss = []   # mean validation loss per epoch


def landmark_loss(pred, target):
    """Apply ``criterion`` to the (x, y) rows of the prediction; z is left unsupervised.

    ``pred`` is the flat (b, 3 * num_coords) coordinate output, read
    channel-major as (b, 3, num_coords). ``target`` is the dataset's
    (b, 2, num_coords) landmark tensor.
    """
    return criterion(pred.view(pred.shape[0], 3, -1)[:, :2, :], target[:, :2, :])


for e in range(epoch_n):

    print("######## Train ########")
    model.train()

    epoch_loss = 0.0
    for image, label in trainloader:
        # Feed the model the same [-1, 1] input range that
        # FaceMesh.predict_on_batch uses (via model._preprocess).
        # NOTE: checkpoints trained before this change saw raw [0, 255] inputs.
        image = model._preprocess(image.to(device))   # (b, 3, 192, 192)
        label = label.float().to(device)              # (b, 2, 68)

        optimizer.zero_grad()
        output, _ = model(image)                      # output: (b, 204)
        loss = landmark_loss(output, label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # criterion returns a per-batch mean, so average over batches, not samples.
    epoch_loss /= max(1, len(trainloader))
    print('Epoch %d / %d --- Loss: %.4f' % (e + 1, epoch_n, epoch_loss))
    print(datetime.now())
    train_loss.append(epoch_loss)

    print("######## Validation ########")
    model.eval()

    total_loss = 0.0
    with torch.no_grad():
        for image, label in testloader:
            image = model._preprocess(image.to(device))
            label = label.float().to(device)

            pred, _ = model(image)
            total_loss += landmark_loss(pred, label).item()

    total_loss /= max(1, len(testloader))
    print('Loss: %.4f' % total_loss)
    valid_loss.append(total_loss)

    torch.save({
        'epoch': e,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'valid_loss': valid_loss,
        }, "/content/drive/MyDrive/Colab Notebooks/model/model_checkpoint.pth")

# Report wall-clock training time (start_time is set before the training loop).
end_time = datetime.now()
elapsed = (end_time - start_time).total_seconds()
minutes, seconds = divmod(elapsed, 60)
hours, minutes = divmod(minutes, 60)
print(f"Start: {start_time:%H:%M:%S}")
print(f"End: {end_time:%H:%M:%S}")
print(f"Training time: {int(hours)} h {int(minutes)} min {int(seconds)} s")


# Per-epoch loss curves; epochs are 1-based to match the "Epoch k / N" log lines.
epochs = range(1, len(train_loss) + 1)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epochs, train_loss, 'r', label="Train loss")
ax.plot(epochs, valid_loss, 'b', label="Validation loss")
ax.set(title="Landmark loss per epoch", xlabel='Epochs', ylabel='Loss')
ax.legend()

fig.savefig("Loss.png")


loading model
######## Train ########
Epoch 1 / 20 --- Loss: 4.3431
2023-12-27 05:50:04.494449
######## Validation ########
Loss: 3.7080
######## Train ########
Epoch 2 / 20 --- Loss: 3.6722
2023-12-27 05:56:37.712963
######## Validation ########
