# I. Data Preprocessing

In [None]:
import xml.etree.ElementTree as ET 
import numpy as np
import os 

# Parse the iBUG 300W training annotations. For every image we keep its path,
# its (first) face bounding box and its 68 landmark points.
tree = ET.parse('ibug_300W_large_face_landmark_dataset/labels_ibug_300W_train.xml')
root = tree.getroot()
root_dir = 'ibug_300W_large_face_landmark_dataset'

bboxes = []  # face bounding box used to crop the image: [left, top, width, height]
landmarks = []  # 68 (x, y) keypoints per image, absolute pixel coordinates of the full image
img_filenames = []  # image paths for the whole dataset

# Look the <images> element up by tag rather than by position (root[2]) so the
# parse cannot silently walk a different element if the header layout differs.
images_node = root.find('images')
if images_node is None:
    raise ValueError('annotation file has no <images> element')

for image_node in images_node:
    img_filenames.append(os.path.join(root_dir, image_node.attrib['file']))
    box = image_node[0]  # only the first <box> of each image is used
    # (left, top) is the top-left corner of the box; width/height its size.
    bboxes.append([float(box.attrib[key]) for key in ('left', 'top', 'width', 'height')])
    # Assumes the first 68 children of <box> are the landmark points, in order.
    landmarks.append(
        [[int(box[num].attrib['x']), int(box[num].attrib['y'])] for num in range(68)]
    )

landmarks = np.array(landmarks).astype('float32')  # (N, 68, 2)
bboxes = np.array(bboxes).astype('float32')  # (N, 4)

In [None]:
# build a dataset
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

import numpy
import random

import albumentations as A
from albumentations.pytorch import ToTensorV2

def seed_everything(seed):
    """Make runs reproducible by seeding every RNG the notebook relies on.

    Seeds NumPy's global RNG, Python's ``random`` module and PyTorch (CPU,
    the current CUDA device and all CUDA devices), exports ``PYTHONHASHSEED``,
    and switches cuDNN to deterministic kernels with autotuning disabled.

    NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it
    here presumably only affects child processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(42)

In [None]:
class IbugTrainingHeatmapDataset(Dataset):
    """iBUG 300W face crops paired with per-landmark Gaussian heatmaps.

    Each item is ``(image, heatmap, landmark)``:

    * ``image`` - the grayscale face crop (cut out with the bounding box),
      after the optional albumentations and basic transforms;
    * ``heatmap`` - float32 tensor of shape (68, 224, 224), one Gaussian per
      landmark normalized to sum to 1, laid out like the image
      (row index = y, column index = x);
    * ``landmark`` - float32 tensor of shape (68, 2) holding (x, y) rescaled
      from crop pixels to a 224 x 224 frame.

    Args:
        img_filenames: image paths, one per sample.
        bboxes: per-sample boxes as (left, top, width, height) in image pixels.
        landmarks: per-sample (68, 2) absolute (x, y) landmark coordinates.
        normalize: if True, scale pixels from [0, 255] to [-0.5, 0.5].
        basic_transform: callable applied to the image after augmentation
            (expected to resize it to 224 x 224).
        albu_transform: albumentations pipeline called with ``image`` and
            ``keypoints``.
        sigma: standard deviation, in pixels, of each heatmap Gaussian.
    """

    def __init__(self, img_filenames, bboxes, landmarks, normalize=True, basic_transform=None, albu_transform=None, sigma=1):
        self.img_filenames = img_filenames
        self.bboxes = bboxes
        self.landmarks = landmarks
        self.basic_transform = basic_transform  # e.g. resize to 224 x 224
        self.albu_transform = albu_transform  # albumentations (image + keypoints)
        self.normalize = normalize
        self.sigma = sigma
        if not self.normalize:
            print('Not normalizing the image')
        if not self.basic_transform:
            print('No basic transformation')

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, idx):
        img_path = self.img_filenames[idx]
        # Close the file handle deterministically; convert('L') returns a new,
        # fully loaded grayscale image with values in [0, 255].
        with Image.open(img_path) as img:
            opened_img = img.convert('L')
        x, y, w, h = self.bboxes[idx]  # left, top, width, height
        cropped_by_bbox = opened_img.crop((x, y, x + w, y + h))

        # (h, w) uint8 -> (h, w, 1) float32, the layout albumentations expects.
        cropped_by_bbox = np.asarray(cropped_by_bbox, dtype=np.float32)[:, :, None]
        if self.normalize:
            cropped_by_bbox = cropped_by_bbox / 255.0 - 0.5  # range [-0.5, 0.5]

        # Shift landmarks from full-image into crop coordinates:
        # (68, 2) - (2,) broadcasts the (x, y) box corner over every row.
        landmark = (np.asarray(self.landmarks[idx], dtype=np.float32)
                    - np.asarray([x, y], dtype=np.float32))

        if self.albu_transform:
            transformed = self.albu_transform(image=cropped_by_bbox, keypoints=landmark)
            tfed_im = transformed['image']
            landmark = transformed['keypoints']
        else:
            tfed_im = cropped_by_bbox  # still an (h, w, 1) numpy array here

        # albumentations may hand back a list of tuples; force float32 so the
        # in-place rescaling below can never truncate.
        landmark = torch.tensor(np.asarray(landmark, dtype=np.float32))  # (68, 2)

        # Rescale (x, y) from crop pixels to the 224 x 224 frame. Assumes the
        # augmentation keeps the crop size (w, h) unchanged.
        landmark[:, 0] = landmark[:, 0] / float(w) * 224
        landmark[:, 1] = landmark[:, 1] / float(h) * 224

        if self.basic_transform:
            tfed_im = self.basic_transform(tfed_im)  # expected: (C=1, 224, 224)
        else:
            tfed_im = torch.as_tensor(tfed_im)

        heatmap = self._gaussian_heatmap(landmark, 224, 224, 68)
        heatmap = torch.from_numpy(heatmap).float()  # (68, 224, 224)
        return tfed_im, heatmap, landmark

    def _gaussian_heatmap(self, landmark, height, width, channels):
        """Return a (channels, height, width) float64 array of Gaussians.

        Channel ``i`` holds an isotropic Gaussian (std ``self.sigma``)
        centred on ``landmark[i] = (x, y)`` with x along columns and y along
        rows, normalized to sum to 1. A channel whose Gaussian underflows to
        zero everywhere (landmark far outside the grid) stays all-zero
        instead of becoming NaN.

        Raises:
            ValueError: if ``landmark`` has fewer than ``channels`` rows.
        """
        points = np.asarray(landmark, dtype=np.float64)
        if points.ndim != 2 or points.shape[0] < channels:
            raise ValueError(
                f'expected at least {channels} (x, y) landmarks, got shape {points.shape}'
            )
        points = points[:channels]
        rows = np.arange(height, dtype=np.float64)[None, :, None]  # (1, H, 1) -> y
        cols = np.arange(width, dtype=np.float64)[None, None, :]   # (1, 1, W) -> x
        px = points[:, 0][:, None, None]  # (channels, 1, 1)
        py = points[:, 1][:, None, None]
        sq_dist = (cols - px) ** 2 + (rows - py) ** 2  # (channels, H, W)
        heatmap = np.exp(-sq_dist / (2 * self.sigma ** 2))
        totals = heatmap.sum(axis=(1, 2), keepdims=True)
        # Skip all-zero channels so they do not turn into 0/0 = NaN targets.
        np.divide(heatmap, totals, out=heatmap, where=totals > 0)
        return heatmap
    
    
def _to_chw_tensor(img):
    """Return *img* as a (C, H, W) tensor without rescaling its values.

    Tensors (e.g. from ToTensorV2) pass through unchanged; an (H, W, C) numpy
    array (what the dataset yields when no albumentations pipeline is set) is
    transposed to channel-first.
    """
    if torch.is_tensor(img):
        return img
    return torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1))))


# Resize the float tensor directly. Round-tripping through ToPILImage would
# multiply float inputs by 255 and cast to uint8, wrapping the negative values
# produced by normalize=True and throwing the normalization away. The crop is
# already single-channel (converted to 'L'), so no Grayscale step is needed.
basic_transform = transforms.Compose([
    transforms.Lambda(_to_chw_tensor),
    transforms.Resize((224, 224)),  # (C, H, W) -> (C, 224, 224)
])

transform = A.Compose([
    A.Affine(rotate=(-15, 15), translate_percent={'x': 0.1, 'y': 0.1}),
    # NOTE(review): images reaching this pipeline are float32 in [-0.5, 0.5]
    # when normalize=True. Depending on the albumentations version, the
    # default GaussNoise variance may be expressed in 0-255 pixel units and
    # swamp the signal - confirm before relying on it.
    A.GaussNoise(p=0.5),
    A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    ToTensorV2(),  # (H, W, 1) -> (1, H, W) tensor
], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))

dataset = IbugTrainingHeatmapDataset(img_filenames, bboxes, landmarks, basic_transform=basic_transform, albu_transform=transform, normalize=True, sigma=1)

# II. Model Architecture
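We use a UNet that predicts a probability-density heatmap of the face landmarks (shape [68, 224, 224], one channel per landmark), and then recover each landmark's position from its heatmap (e.g. its expected position under the heatmap, or the location of its maximum).
This turns landmark localization into a pixelwise classification problem.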

In [None]:
# unet
import torch
import torch.nn as nn
import torch.nn.functional as F
# load unet
from unet import PixelwiseClassificationUNet

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# NOTE(review): arguments are presumably (in_channels=1 for the grayscale crop,
# out_channels=68 heatmaps) - confirm against unet.PixelwiseClassificationUNet.
model = PixelwiseClassificationUNet(1, 68).to(device)
def init_weights(m):
    """Kaiming-normal init for ``nn.Conv2d`` weights, zeros for their biases.

    Meant to be passed to ``nn.Module.apply``; every other module type is
    left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        # Convolutions built with bias=False (common before BatchNorm) have
        # m.bias = None; zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
# Run init_weights on every submodule of the model (and the model itself),
# re-initialising its Conv2d layers.
model.apply(init_weights)

# III. Training

In [None]:
# loss function, between predicted heatmap and ground truth heatmap
# need mae
from torch.nn.functional import mse_loss, binary_cross_entropy_with_logits, l1_loss
from tqdm import tqdm
LR = 1e-3
BS = 32
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
num_epochs = 10

# 80/20 train/validation split. random_split draws from the global torch RNG,
# which seed_everything(42) has seeded, so the split is reproducible.
n_train = int(len(dataset) * 0.8)
train_dataset, val_split = torch.utils.data.random_split(dataset, [n_train, len(dataset) - n_train])

# Validation must not see training-time augmentation (random affine, noise,
# colour jitter): that makes the validation loss noisy and biased. Evaluate the
# same held-out indices on a copy of the dataset whose albumentations pipeline
# only converts to a tensor, so images still take the same path through
# basic_transform as during training.
val_source = IbugTrainingHeatmapDataset(
    dataset.img_filenames,
    dataset.bboxes,
    dataset.landmarks,
    normalize=dataset.normalize,
    basic_transform=dataset.basic_transform,
    albu_transform=A.Compose(
        [ToTensorV2()],
        keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
    ),
    sigma=dataset.sigma,
)
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BS, shuffle=False)

In [None]:
def heatmap_to_landmarks(heatmap):
    """Decode per-landmark heatmaps into (x, y) coordinates via the argmax.

    Args:
        heatmap: tensor or array of shape (B, K, H, W), one heatmap per
            landmark (K = 68 in this notebook), laid out like the image:
            row index = y, column index = x. Tensors on any device are
            accepted.

    Returns:
        CPU float64 tensor of shape (B, K, 2) with the (x, y) position of each
        heatmap's maximum (first maximum in row-major order on ties).
    """
    hm = torch.as_tensor(heatmap).detach().cpu()
    batch, num_landmarks, height, width = hm.shape
    # argmax over the flattened spatial grid, then split back into (row, col).
    flat_idx = hm.reshape(batch, num_landmarks, height * width).argmax(dim=-1).numpy()
    rows, cols = np.unravel_index(flat_idx, (height, width))
    # Landmarks elsewhere in the notebook are stored as (x, y) = (column, row).
    landmarks = np.stack((cols, rows), axis=-1).astype(np.float64)
    return torch.from_numpy(landmarks)

In [None]:
# train
train_losses = []
val_losses = []
# Mean absolute landmark error per epoch, in the coordinate frame of the model's
# output heatmaps (224 x 224 if the UNet preserves the input size - confirm).
val_maes = []

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_loss = 0
    loop = tqdm(train_loader, total=len(train_loader), leave=False)
    loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
    for images, heatmaps, _ in loop:
        images = images.to(device)  # (B, C, H, W)
        heatmaps = heatmaps.to(device)  # (B, 68, H, W)
        logits = model(images)  # expected (B, 68, H, W)
        # BCE-with-logits averages over every element (B * 68 * H * W).
        loss = binary_cross_entropy_with_logits(logits, heatmaps)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        loop.set_postfix(train_loss=loss.item())
    train_loss /= len(train_loader)

    # ---- validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    # Iterate the tqdm wrapper itself so the progress bar actually advances.
    loop = tqdm(val_loader, total=len(val_loader), leave=False)
    loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
    with torch.no_grad():
        # Use this batch's landmarks; the bare name `landmarks` is the global
        # numpy array of all annotations, which has no .to() method.
        for images, heatmaps, batch_landmarks in loop:
            images = images.to(device)  # (B, C, H, W)
            heatmaps = heatmaps.to(device)  # (B, 68, H, W)
            logits = model(images)  # expected (B, 68, H, W)
            loss = binary_cross_entropy_with_logits(logits, heatmaps)
            val_loss += loss.item()
            # sigmoid matches the per-pixel BCE-with-logits objective; a
            # softmax over dim=1 would mix different landmarks' channels.
            probs = torch.sigmoid(logits)
            # Decode on the CPU; batch_landmarks come from the DataLoader on the CPU.
            predicted_landmarks = heatmap_to_landmarks(probs.cpu())  # (B, 68, 2)
            batch_mae = (predicted_landmarks.float() - batch_landmarks.float()).abs().mean().item()
            val_mae += batch_mae
            loop.set_postfix(val_loss=loss.item(), val_mae=batch_mae)
    val_loss /= len(val_loader)
    val_mae /= len(val_loader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_maes.append(val_mae)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}, Val MAE: {val_mae}')
        