# Problem statement
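Implementation of a Transformer-based neural network for face detection. More precisely, the task below is face identity classification: recognizing which of several people a face image belongs to.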

## What is a Transformer
Self-attention-based architectures, in particular Transformers, have become the model of choice in natural language processing (NLP). The dominant approach is to pre-train on a large text corpus and then fine-tune on a smaller task-specific dataset. Thanks to Transformers’ computational efficiency and scalability, it has become possible to train models of unprecedented size, with over 100B parameters. With the models and datasets growing, there is still no sign of saturating performance.
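As a refresher, the core operation of a Transformer is scaled dot-product self-attention, $\mathrm{softmax}(QK^\top/\sqrt{d_k})V$. The NumPy sketch below shows a single attention head. The projection matrices are random placeholders (an assumption for illustration), not trained weights, and multi-head attention, masking and layer normalization are left out.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the maximum for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention.

    tokens: (N, D) sequence of N token vectors.
    w_q, w_k, w_v: (D, d_k) projection matrices.
    Returns the (N, d_k) outputs and the (N, N) attention weights.
    """
    q = tokens @ w_q
    k = tokens @ w_k
    v = tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])  # (N, N) pairwise similarities
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 5, 16, 8
x = rng.standard_normal((n_tokens, d_model))
# Hypothetical random projections standing in for learned weights
w_q, w_k, w_v = (rng.standard_normal((d_model, d_k)) for _ in range(3))
out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (5, 8) (5, 5)
```

Every output token is a weighted mix of all value vectors, so each token can draw on the whole sequence in a single layer.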

## Transformer in Computer Vision
Take the Vision Transformer (ViT) as an example. There are two ways to build the input of this structure. The first works directly on pixels: the input image is split into multiple patches, each patch is flattened into a vector, and the resulting sequence of vectors (after a linear projection) is used as the input to the Transformer. Only the encoder part of the Transformer is used, and its final output is passed to a multilayer perceptron (MLP) head, i.e. fully connected layers, for classification. In the second (a hybrid design), the input image is likewise segmented into patches, but each patch is fed into a CNN that extracts a 1D tensor, which serves as the word vector of that patch. The rest of the process is the same.
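The first approach (split, flatten, feed as a sequence) can be sketched in NumPy as follows. The function name `patchify` is hypothetical; the sketch assumes a channels-last `(H, W, C)` image whose height and width are divisible by the patch size, and it orders patches row by row.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into N = (H/P)*(W/P) flattened patches of length P*P*C."""
    h, w, c = image.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("image height and width must be divisible by patch_size")
    # (H/P, P, W/P, P, C): split each spatial axis into (block index, offset in block)
    x = image.reshape(h // p, p, w // p, p, c)
    # (H/P, W/P, P, P, C): bring the two block indices to the front
    x = x.transpose(0, 2, 1, 3, 4)
    # (N, P*P*C): one flattened vector per patch, in raster order
    return x.reshape(-1, p * p * c)

image = np.arange(224 * 224, dtype=np.float32).reshape(224, 224, 1)
patches = patchify(image, patch_size=56)
print(patches.shape)  # (16, 3136)
```

Each of the 16 rows is one "word" of the image sequence; patch 1 is the block to the right of patch 0, and patch 4 starts the second row of blocks.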

# Challenges and solution
In Computer Vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. Inspired by NLP successes, multiple works try combining CNN-like architectures with self-attention, some replacing the convolutions entirely. The latter models, while theoretically efficient, have not yet been scaled effectively on modern hardware accelerators due to the use of specialized attention patterns. 

On the other hand, applications such as human face detection usually rely on a pre-trained CNN architecture combined with attention. Because such systems are often deployed on mobile platforms with limited computing power, these attention-augmented CNN architectures are difficult to deploy efficiently there.

**Solution**: Inspired by the scaling successes of Transformers in NLP, ViT applies a standard Transformer directly to images with the fewest possible modifications, instead of modifying a CNN architecture. Specifically, the image is split into patches, and the sequence of linear embeddings of these patches is provided as the input to a Transformer. Image patches are treated the same way as tokens (words) in an NLP application.
Because self-attention compares every patch with every other patch, its cost grows quadratically with the number of patches; as long as an image is split into few patches, attention is incorporated at little extra computational cost.
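Following the ViT formulation, each flattened patch $x_p^i$ is multiplied by a projection matrix $E$, a learnable class token is prepended, and a position embedding is added: $z_0 = [x_{\text{class}}; x_p^1E; \dots; x_p^NE] + E_{pos}$. The NumPy sketch below only traces the shapes for the setting used later (224 × 224 grayscale input, `patch_size=56`, `dim=16`). All weights are random placeholders standing in for learned parameters, and `embed_patches` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)
image_size, patch_size, channels, dim = 224, 56, 1, 16
n_patches = (image_size // patch_size) ** 2     # 16 patches
patch_len = patch_size * patch_size * channels  # 3136 values per patch

# Random placeholders for the learned parameters
E = rng.standard_normal((patch_len, dim)) * 0.02          # patch projection
cls_token = rng.standard_normal((1, dim)) * 0.02          # learnable [class] token
E_pos = rng.standard_normal((n_patches + 1, dim)) * 0.02  # 1D position embedding

def embed_patches(image: np.ndarray) -> np.ndarray:
    """Map an (H, W, C) image to the (N + 1, dim) token sequence fed to the encoder."""
    p = patch_size
    h, w, c = image.shape
    patches = (image.reshape(h // p, p, w // p, p, c)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(-1, p * p * c))               # (N, P*P*C)
    tokens = patches @ E                                   # (N, dim): one "word vector" per patch
    tokens = np.concatenate([cls_token, tokens], axis=0)   # (N + 1, dim)
    return tokens + E_pos

image = rng.standard_normal((image_size, image_size, channels))
z0 = embed_patches(image)
print(z0.shape)  # (17, 16)
```

The classifier head then reads the encoder output at the class-token position (row 0), which is why the sequence has $N + 1$ rows.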


However, in our early trials the ViT model performed poorly on the test data. We found that performance improved as `patch_size` was increased from 16 to 56. One possible explanation follows. Our input images are resized to $224 \times 224$, and ViT requires the image size to be divisible by `patch_size`, so we initially used `patch_size=16`. In this case the image is split into $(224/16)^2 = 14^2 = 196$ patches of $16 \times 16$ pixels. Since each image is almost entirely face, with little background, 196 patches may divide the face region too finely for the model to learn identity information from this small dataset. We therefore settled on `patch_size=56`, which gives $(224/56)^2 = 16$ patches of $56 \times 56$ pixels.
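The trade-off can be made concrete with a little arithmetic. For a 224 × 224 input, the snippet below lists several patch sizes that divide 224 (the candidate list is ours, for illustration), the resulting number of patches $N = (224/P)^2$, and the size $N^2$ of the attention matrix each head computes, ignoring the extra class token.

```python
image_size = 224

def n_patches(patch_size: int, image_size: int = 224) -> int:
    # ViT requires the image side to be divisible by the patch side
    assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
    return (image_size // patch_size) ** 2

for p in [14, 16, 28, 32, 56, 112]:
    n = n_patches(p, image_size)
    # Each attention head compares every patch with every other patch: N x N scores
    print(f"patch_size={p:>3}: {image_size // p:>2} x {image_size // p:<2} grid, "
          f"{n:>3} patches, {n * n:>5} attention scores per head")
```

With `patch_size=56` each head attends over 16 patches (256 scores) instead of 196 patches (38,416 scores) for `patch_size=16`, so the final configuration is also far cheaper per layer.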


# Human face detection using ViT

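Human face images are used in this presentation. We download the Labeled Faces in the Wild (LFW) dataset through sklearn, keeping only people with at least 64 images (`min_faces_per_person=64`). This leaves 1288 grayscale samples of 7 people (labels), each of shape 62 × 47 pixels.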

```python
from sklearn.datasets import fetch_lfw_people

# Keep only people with at least 64 images in LFW
faces = fetch_lfw_people(min_faces_per_person=64)
# (1288, 2914), (1288, 62, 47)
faces.data.shape, faces.images.shape
```

```text
((1288, 2914), (1288, 62, 47))
```

However, for reasons we have not pinned down, the images had to be pixel-inverted (`255 - image`) before they displayed properly.


```python
from matplotlib import pyplot as plt

# Invert pixel intensities before display
plt.imshow(255 - faces.images[0, :, :], cmap='gray')
plt.show()
```

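## Dataset
All images are inverted in the same way and then split with sklearn's `train_test_split` into a training set (80%) and a test set (20%), stratified by label so that each person appears in both sets in similar proportions.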
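The `stratify=labels` argument in the split cell keeps each class at roughly the same proportion in the training and test sets, which matters when some people have many more images than others. The hand-rolled `stratified_split` function below is hypothetical and only illustrates the idea by splitting each class separately; it is a simplified sketch, not scikit-learn's actual allocation algorithm, and the toy label counts are made up.

```python
import numpy as np

def stratified_split(labels: np.ndarray, test_size: float = 0.2, seed: int = 0):
    """Return (train_idx, test_idx), taking about `test_size` of each class for the test set."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(train_idx), np.array(test_idx)

# Toy, imbalanced labels: class 0 has 50 samples, class 1 has 30, class 2 has 20
labels = np.repeat([0, 1, 2], [50, 30, 20])
train_idx, test_idx = stratified_split(labels, test_size=0.2)
print(np.bincount(labels[train_idx]), np.bincount(labels[test_idx]))
```

Each class contributes 20% of its own samples to the test set (10, 6 and 4 here), so the class ratio is the same on both sides of the split.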

```python
# Apply the same inversion to the whole dataset
images = 255 - faces.images
labels = faces.target
```

```python
from sklearn.model_selection import train_test_split
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import numpy as np

# 80/20 split, stratified so every person appears in both sets in similar proportions
x_train, x_test, y_train, y_test = train_test_split(images, labels, test_size=0.2, shuffle=True, stratify=labels)

# Import the module as `T` so that assigning the pipeline to `transforms`
# does not shadow torchvision.transforms (re-running the cell would otherwise fail)
transforms = T.Compose(
    [
        T.ToPILImage(),                      # 62 x 47 array -> PIL image
        T.Resize((224, 224)),                # ViT input size
        T.ToTensor(),                        # -> (1, 224, 224) tensor
        T.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5
    ]
)

class FaceDataset(Dataset):
    """Wraps image and label arrays, applying `transform` to each image."""
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, item):
        sample = self.transform(self.x[item]) if self.transform is not None else self.x[item]
        label = self.y[item]
        return sample, label
```

```python
x_train.shape, x_test.shape, y_train.shape, y_test.shape
```

```text
((1030, 62, 47), (258, 62, 47), (1030,), (258,))
```

## Model selection
The original ViT, as implemented in the `vit_pytorch` package, is used in this demo.
To preserve the spatial position of each patch in the image, a position embedding similar to the one used in NLP is added to the patch embeddings before they enter the Transformer. Using only 1D position information (numbering the patches 1, 2, 3, ... in raster order) makes no noticeable difference compared with using 2D information (encoding each patch by its (x, y) coordinates). (This feels similar to DCL (Destruction and Construction Learning), JD's 2019 fine-grained recognition method: besides the relationships between different locations within the same image, its learning process also takes into account the same location area across different images of the same class, e.g. the same part of an object in all images of that class.)
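The two position-encoding options can be sketched for our 4 × 4 patch grid (`patch_size=56` on a 224 × 224 image). In the 1D variant each patch index 0..15 (raster order) gets its own vector. In the 2D variant, one of the options compared by the ViT authors, a row table and a column table of size `dim/2` each are learned and concatenated per patch. The tables below are random placeholders for learned parameters, and the class token's position is omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
grid, dim = 4, 16          # 224 / 56 = 4 patches per side, embedding size 16
n_patches = grid * grid

# 1D: one (placeholder) learned vector per patch index 0 .. N-1, raster order
pos_1d = rng.standard_normal((n_patches, dim))

# 2D: separate (placeholder) row and column tables of size dim/2, concatenated per patch
row_emb = rng.standard_normal((grid, dim // 2))
col_emb = rng.standard_normal((grid, dim // 2))
rows, cols = np.divmod(np.arange(n_patches), grid)  # patch index -> (row, col)
pos_2d = np.concatenate([row_emb[rows], col_emb[cols]], axis=1)

patch_tokens = rng.standard_normal((n_patches, dim))  # stand-in for the projected patches
z_1d = patch_tokens + pos_1d
z_2d = patch_tokens + pos_2d
print(pos_1d.shape, pos_2d.shape)  # (16, 16) (16, 16)
```

Both variants produce one vector per patch with the same shape, so the encoder is unchanged; the 2D tables hold 2 × 4 × 8 = 64 values against 16 × 16 = 256 for the 1D table.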


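```python
from vit_pytorch import ViT
import torch
import torch.nn as nn
import torch.optim as optim

# Fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ViT(
    image_size = 224,
    patch_size = 56,   # (224 / 56)^2 = 16 patches per image
    num_classes = 7,   # 7 people
    dim = 16,          # patch embedding size
    heads = 3,
    mlp_dim = 164,
    depth = 4,         # number of Transformer encoder blocks
    channels = 1       # grayscale input
).to(device)

lr = 0.001
num_epoch = 100
validate_every = 4   # evaluate on the test set every 4 epochs

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
batch_size = 32

train_loader = DataLoader(dataset=FaceDataset(x_train, y_train, transforms), batch_size=batch_size, shuffle=True)
# The evaluation data does not need shuffling
test_loader = DataLoader(dataset=FaceDataset(x_test, y_test, transforms), batch_size=batch_size, shuffle=False)
```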

```python
import numpy as np

def draw_loss(loss_train, loss_test):
    """Plot the per-epoch training loss and the test loss (recorded every `validate_every` epochs)."""
    plt.title("loss")
    x = np.arange(0, len(loss_train))
    plt.plot(x, np.array(loss_train), label='train')
    x = np.arange(0, len(loss_train), validate_every)
    plt.plot(x, np.array(loss_test), label='test')
    plt.legend()
    plt.show()
```

```python
from tqdm import tqdm
import os
import torch

loss_train_all = []
loss_test_all = []
best_acc = 0
os.makedirs("./weights", exist_ok=True)  # checkpoints are saved here

for epoch in range(num_epoch):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    loss_train_epoch = []
    loss_test_epoch = []
    loop = tqdm(train_loader)
    for data, label in loop:
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running averages over the batches of this epoch
        acc = (output.argmax(dim=1) == label).float().mean().item()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)
        loop.set_description(f"epoch: {epoch}, loss: {epoch_loss:.4f}, acc:{epoch_accuracy:.4f}")
        loss_train_epoch.append(loss.item())  # per-batch loss, not the running sum
    loss_train_all.append(np.mean(loss_train_epoch))
    if epoch % validate_every == 0:
        print('==========================Validating==========================')
        model.eval()
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)
                val_output = model(data)
                val_loss = criterion(val_output, label)
                acc = (val_output.argmax(dim=1) == label).float().mean().item()
                epoch_val_accuracy += acc / len(test_loader)
                epoch_val_loss += val_loss.item() / len(test_loader)
                loss_test_epoch.append(val_loss.item())
            loss_test_all.append(np.mean(loss_test_epoch))
            print(f"Epoch : {epoch} | val_loss : {epoch_val_loss:.4f} | val_acc: {epoch_val_accuracy:.4f}")
        # Keep a checkpoint whenever the test accuracy improves
        if epoch_val_accuracy > best_acc:
            best_acc = epoch_val_accuracy
            print('Current best model')
            torch.save(model.state_dict(), f"./weights/model-{epoch}-{epoch_val_accuracy:.4f}.pth")
print('==================================================================')
draw_loss(loss_train_all, loss_test_all)
torch.save(model.state_dict(), "model.pth")
```

```text
epoch: 0, loss: 1.2012, acc:0.5909: 100%|██████████| 33/33 [00:01<00:00, 28.57it/s]
Epoch : 0 | val_loss : 0.8196 | val_acc: 0.6840
Current best model
epoch: 1, loss: 0.3981, acc:0.8690: 100%|██████████| 33/33 [00:00<00:00, 43.13it/s]
epoch: 2, loss: 0.1666, acc:0.9564: 100%|██████████| 33/33 [00:00<00:00, 45.02it/s]
epoch: 3, loss: 0.0638, acc:0.9915: 100%|██████████| 33/33 [00:00<00:00, 46.54it/s]
epoch: 4, loss: 0.0180, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 48.81it/s]
Epoch : 4 | val_loss : 0.2243 | val_acc: 0.9340
Current best model
epoch: 5, loss: 0.0065, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 49.17it/s]
epoch: 6, loss: 0.0042, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 48.89it/s]
epoch: 7, loss: 0.0027, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 44.95it/s]
epoch: 8, loss: 0.0022, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 43.93it/s]
Epoch : 8 | val_loss : 0.2613 | val_acc: 0.8819
epoch: 9, loss: 0.0017, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 43.82it/s]
epoch: 10, loss: 0.0016, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 43.59it/s]
epoch: 11, loss: 0.0013, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.96it/s]
epoch: 12, loss: 0.0011, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.57it/s]
Epoch : 12 | val_loss : 0.2262 | val_acc: 0.9375
Current best model
epoch: 13, loss: 0.0010, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.00it/s]
epoch: 14, loss: 0.0008, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.86it/s]
epoch: 15, loss: 0.0008, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.85it/s]
epoch: 16, loss: 0.0007, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.71it/s]
Epoch : 16 | val_loss : 0.2107 | val_acc: 0.9375
epoch: 17, loss: 0.0006, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.19it/s]
epoch: 18, loss: 0.0006, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.61it/s]
epoch: 19, loss: 0.0005, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.18it/s]
epoch: 20, loss: 0.0005, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.99it/s]
Epoch : 20 | val_loss : 0.2015 | val_acc: 0.9340
epoch: 21, loss: 0.0004, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.82it/s]
epoch: 22, loss: 0.0004, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.28it/s]
epoch: 23, loss: 0.0004, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.79it/s]
epoch: 24, loss: 0.0004, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.60it/s]
Epoch : 24 | val_loss : 0.2131 | val_acc: 0.9444
Current best model
epoch: 25, loss: 0.0003, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.95it/s]
epoch: 26, loss: 0.0003, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.51it/s]
epoch: 27, loss: 0.0003, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.51it/s]
epoch: 28, loss: 0.0003, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.58it/s]
Epoch : 28 | val_loss : 0.2078 | val_acc: 0.9479
Current best model
epoch: 29, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.30it/s]
epoch: 30, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.50it/s]
epoch: 31, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.26it/s]
epoch: 32, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.45it/s]
Epoch : 32 | val_loss : 0.2024 | val_acc: 0.9410
epoch: 33, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.45it/s]
epoch: 34, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.93it/s]
epoch: 35, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.98it/s]
epoch: 36, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.93it/s]
Epoch : 36 | val_loss : 0.2126 | val_acc: 0.9410
epoch: 37, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.82it/s]
epoch: 38, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.88it/s]
epoch: 39, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.61it/s]
epoch: 40, loss: 0.0002, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.08it/s]
Epoch : 40 | val_loss : 0.2211 | val_acc: 0.9444
epoch: 41, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 39.99it/s]
epoch: 42, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.22it/s]
epoch: 43, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.98it/s]
epoch: 44, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.66it/s]
Epoch : 44 | val_loss : 0.3441 | val_acc: 0.8889
epoch: 45, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.41it/s]
epoch: 46, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.64it/s]
epoch: 47, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.45it/s]
epoch: 48, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.94it/s]
Epoch : 48 | val_loss : 0.2058 | val_acc: 0.9444
epoch: 49, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.87it/s]
epoch: 50, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.84it/s]
epoch: 51, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.32it/s]
epoch: 52, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.07it/s]
Epoch : 52 | val_loss : 0.2197 | val_acc: 0.9340
epoch: 53, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.19it/s]
epoch: 54, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.29it/s]
epoch: 55, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.40it/s]
epoch: 56, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.32it/s]
Epoch : 56 | val_loss : 0.2143 | val_acc: 0.9444
epoch: 57, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.03it/s]
epoch: 58, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.51it/s]
epoch: 59, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.25it/s]
epoch: 60, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.79it/s]
Epoch : 60 | val_loss : 0.2071 | val_acc: 0.9410
epoch: 61, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.69it/s]
epoch: 62, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.40it/s]
epoch: 63, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.74it/s]
epoch: 64, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.60it/s]
Epoch : 64 | val_loss : 0.4175 | val_acc: 0.8924
epoch: 65, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.87it/s]
epoch: 66, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.29it/s]
epoch: 67, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 39.94it/s]
epoch: 68, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.66it/s]
Epoch : 68 | val_loss : 0.2137 | val_acc: 0.9444
epoch: 69, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.35it/s]
epoch: 70, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.61it/s]
epoch: 71, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.88it/s]
epoch: 72, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.19it/s]
Epoch : 72 | val_loss : 0.2068 | val_acc: 0.9410
epoch: 73, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.91it/s]
epoch: 74, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 43.02it/s]
epoch: 75, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.93it/s]
epoch: 76, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.14it/s]
Epoch : 76 | val_loss : 0.2273 | val_acc: 0.9375
epoch: 77, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.50it/s]
epoch: 78, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.41it/s]
epoch: 79, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 42.35it/s]
epoch: 80, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.12it/s]
Epoch : 80 | val_loss : 0.2243 | val_acc: 0.9410
epoch: 81, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.41it/s]
epoch: 82, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 39.85it/s]
epoch: 83, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.49it/s]
epoch: 84, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.34it/s]
Epoch : 84 | val_loss : 0.2299 | val_acc: 0.9410
epoch: 85, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.39it/s]
epoch: 86, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.84it/s]
epoch: 87, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.38it/s]
epoch: 88, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.27it/s]
Epoch : 88 | val_loss : 0.2095 | val_acc: 0.9410
epoch: 89, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.61it/s]
epoch: 90, loss: 0.0000, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.04it/s]
epoch: 91, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.45it/s]
epoch: 92, loss: 0.0001, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.83it/s]
Epoch : 92 | val_loss : 0.2221 | val_acc: 0.9410
epoch: 93, loss: 0.0047, acc:0.9949: 100%|██████████| 33/33 [00:00<00:00, 39.61it/s]
epoch: 94, loss: 1.6288, acc:0.6600: 100%|██████████| 33/33 [00:00<00:00, 41.41it/s]
epoch: 95, loss: 0.5030, acc:0.8438: 100%|██████████| 33/33 [00:00<00:00, 41.13it/s]
epoch: 96, loss: 0.1387, acc:0.9527: 100%|██████████| 33/33 [00:00<00:00, 41.29it/s]
Epoch : 96 | val_loss : 0.4110 | val_acc: 0.8611
epoch: 97, loss: 0.0336, acc:0.9962: 100%|██████████| 33/33 [00:00<00:00, 39.94it/s]
epoch: 98, loss: 0.0104, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 40.93it/s]
epoch: 99, loss: 0.0049, acc:1.0000: 100%|██████████| 33/33 [00:00<00:00, 41.82it/s]
```

# Evaluation
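The best accuracy on the test data is 0.9479, reached at epoch 28. Training accuracy already reaches 1.0000 by epoch 4, while test accuracy then stays mostly between 0.93 and 0.95, with occasional dips to about 0.86 to 0.89, so the model overfits the small training set. The sudden jump in training loss at epoch 94 (from 0.0047 to 1.6288) also points to some training instability. Because the checkpoint is selected by its accuracy on this same test split, the best figure is an optimistic estimate of performance on unseen faces.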
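One caveat when reading these numbers: the validation loop adds `acc / len(test_loader)` for every batch, i.e. it reports the mean of per-batch accuracies. With 258 test images and `batch_size=32`, the loader yields 8 full batches and one batch of 2, so that last batch weighs as much as a full one. The sketch below uses made-up per-batch correct counts (hypothetical, for illustration only) to compare this batch mean with the per-sample accuracy (correct / total).

```python
# Batch sizes produced by a DataLoader over 258 samples with batch_size=32 (drop_last=False)
n_samples, batch_size = 258, 32
sizes = [batch_size] * (n_samples // batch_size)
if n_samples % batch_size:
    sizes.append(n_samples % batch_size)

# Hypothetical number of correct predictions in each batch (made up for illustration)
correct = [30, 31, 30, 32, 29, 31, 30, 31, 1]

batch_mean = sum(c / s for c, s in zip(correct, sizes)) / len(sizes)  # what the loop reports
per_sample = sum(correct) / sum(sizes)                                 # correct / total
print(f"batch mean: {batch_mean:.4f}, per-sample: {per_sample:.4f}")
```

Here one wrong prediction in the 2-image batch pulls the batch mean down to about 0.903, while the per-sample accuracy is about 0.950. Accumulating correct counts and dividing by `len(test_loader.dataset)` avoids this bias.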