In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import numpy as np
import os

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


#### Define run parameters

In [2]:
# Root folder read by `datasets.ImageFolder` (one sub-directory per class).
data_dir = '../data/all'

# Training schedule.
batch_size = 16
epochs = 8

# DataLoader worker processes. On Windows ('nt') data is loaded in the main
# process instead — presumably to sidestep spawn-based multiprocessing issues
# in notebooks; confirm if changing this.
if os.name == 'nt':
    workers = 0
else:
    workers = 8

#### Check device

In [3]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


#### Define MTCNN module

In [4]:
# MTCNN face detector configured for 160-pixel output crops.
# NOTE(review): no later cell in this notebook references `mtcnn`; it looks like
# a leftover from a face-cropping step that is not run here.
mtcnn = MTCNN(
    device=device,
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
)

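#### Perform MTCNN facial detection

No detection step is actually run in this notebook. The dataset below is read directly from `data_dir`, so the images there are presumably already face crops of a uniform size.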

#### Define Inception Resnet V1 module

In [5]:
# The classifier head must have exactly one output per class. `datasets.ImageFolder`
# (used on `data_dir` further down) treats every sub-directory of its root as one
# class, so count those directories instead of hard-coding the number. A hard-coded
# value that drifts from the folder count yields label-out-of-range errors in
# CrossEntropyLoss (too few outputs) or dead logits (too many).
with os.scandir(data_dir) as entries:
    num_classes = sum(1 for entry in entries if entry.is_dir())

# InceptionResnetV1 initialised from VGGFace2-pretrained weights, with a
# classification layer sized for this dataset's identities.
resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=num_classes
).to(device)

#### Define optimizer, scheduler, dataset, and dataloader

In [6]:
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
# Multiply the learning rate by MultiStepLR's default gamma (0.1) at each milestone.
# NOTE(review): if the scheduler is stepped once per training epoch, the 10
# milestone is never reached with `epochs = 8` — confirm the intended schedule.
scheduler = MultiStepLR(optimizer, [5, 10])

# PIL image -> float32 ndarray (values stay 0-255) -> tensor -> fixed standardization.
# Converting to float32 first stops ToTensor from rescaling to [0, 1]; the
# facenet_pytorch standardization step is expected to receive the 0-255 range.
trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])
# Images are read from `data_dir` as-is (no MTCNN cropping is run in this notebook),
# so they must already be face crops of a single size for batching to work.
dataset = datasets.ImageFolder(data_dir, transform=trans)

# Reproducible 80/20 train/validation split: a seeded generator keeps every image in
# the same split across runs (an unseeded shuffle re-drew the split every time,
# making results from different runs incomparable).
split_seed = 0
train_frac = 0.8
img_inds = np.random.default_rng(split_seed).permutation(len(dataset))
n_train = int(train_frac * len(img_inds))
train_inds = img_inds[:n_train]
val_inds = img_inds[n_train:]
assert len(train_inds) > 0 and len(val_inds) > 0, 'train/validation split is empty'

# Both loaders share one dataset; the samplers restrict each to its index subset.
train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)

#### Define loss and evaluation functions

In [7]:
# Multi-class classification objective over the identity logits.
loss_fn = torch.nn.CrossEntropyLoss()

# Per-batch metrics handed to `training.pass_epoch`: throughput and accuracy.
metrics = dict(
    fps=training.BatchTimer(),
    acc=training.accuracy,
)

#### Train model

In [8]:
writer = SummaryWriter()
# Attributes read by facenet_pytorch's `training.pass_epoch` for TensorBoard
# logging (presumably a running batch counter and a logging interval in batches).
writer.iteration, writer.interval = 0, 10


def run_validation_epoch():
    """Run one evaluation pass of `resnet` over `val_loader` without updating weights."""
    resnet.eval()
    # No optimizer is passed, so this pass never backpropagates; `no_grad` avoids
    # building an autograd graph for every validation batch (saves GPU memory/time).
    with torch.no_grad():
        training.pass_epoch(
            resnet, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )


# `finally` guarantees the TensorBoard writer is flushed and closed even if
# training raises or is interrupted (e.g. KeyboardInterrupt from the notebook).
try:
    print('\n\nInitial')
    print('-' * 10)
    run_validation_epoch()

    for epoch in range(epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        resnet.train()
        training.pass_epoch(
            resnet, loss_fn, train_loader, optimizer, scheduler,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer
        )

        run_validation_epoch()
finally:
    writer.close()



Initial
----------
Valid |    16/16   | loss:    3.1122 | fps:  666.5499 | acc:    0.0781   

Epoch 1/8
----------
Train |    62/62   | loss:    2.4250 | fps:  168.7473 | acc:    0.3236   
Valid |    16/16   | loss:    2.4920 | fps:  646.9950 | acc:    0.3875   

Epoch 2/8
----------
Train |    62/62   | loss:    1.6879 | fps:  162.5123 | acc:    0.4829   
Valid |    16/16   | loss:    2.1188 | fps:  752.8311 | acc:    0.4078   

Epoch 3/8
----------
Train |    62/62   | loss:    1.0907 | fps:  162.1543 | acc:    0.6663   
Valid |    16/16   | loss:    1.6492 | fps:  605.8689 | acc:    0.5305   

Epoch 4/8
----------
Train |    62/62   | loss:    0.8362 | fps:  162.7656 | acc:    0.7208   
Valid |    16/16   | loss:    1.3361 | fps:  514.8594 | acc:    0.5883   

Epoch 5/8
----------
Train |    62/62   | loss:    0.6795 | fps:  163.6182 | acc:    0.7823   
Valid |    16/16   | loss:    1.1315 | fps:  539.8393 | acc:    0.6516   

Epoch 6/8
----------
Train |    62/62   | loss:    0.3

#### Save fine-tuned model

In [9]:
# Persist only the fine-tuned weights (the state_dict), not the module object.
# The file name keeps its existing spelling because other code may load this path.
# NOTE(review): the class-index -> identity mapping (`dataset.class_to_idx`) is not
# saved, but it is needed later to interpret the classifier's outputs.
checkpoint_path = '../data/fintuned_model.pt'
finetuned_state = resnet.state_dict()
torch.save(finetuned_state, checkpoint_path)