In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import numpy as np
import os

In [2]:
# Run configuration.
data_dir = "../../datasets/train"  # raw images in ImageFolder layout (one sub-folder per class)
batch_size, epochs, workers = 32, 5, 8  # batch size, training epochs, DataLoader worker processes

In [3]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA (MTCNN and the model are both placed on this device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Face detector / cropper used to build the "_cropped" copy of the dataset.
# Produces 160x160 crops with no extra margin; faces smaller than 20 px are ignored.
mtcnn = MTCNN(
    device=device,
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],  # per-stage detection thresholds
    factor=0.709,                # image-pyramid scale factor
    post_process=True,
)

In [5]:
# Raw images, resized to 512x512 before MTCNN sees them.
dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))

# Replace each sample's class-index target with the file path its face crop will be
# written to: the same relative path, rooted at "<data_dir>_cropped" instead of data_dir.
cropped_root = data_dir + "_cropped"
dataset.samples = [
    (img_path, img_path.replace(data_dir, cropped_root))
    for img_path, _label in dataset.samples
]

In [6]:
# Run MTCNN over the raw images batch by batch. Each batch's targets are the crop
# output paths set up in the previous cell, so MTCNN writes the crops straight to disk.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=workers,
    collate_fn=training.collate_pil,  # batches PIL images, which the default collate cannot stack
)

n_batches = len(loader)
for batch_no, (images, crop_paths) in enumerate(loader, start=1):
    mtcnn(images, save_path=crop_paths)
    print("\rBatch {} of {}".format(batch_no, n_batches), end='')

Batch 26 of 26

In [7]:
# The detector is no longer needed once the crops are on disk.
del mtcnn

In [8]:
# Number of class folders found under data_dir.
len(dataset.classes)

10

In [9]:
# Classifier: VGGFace2-pretrained InceptionResnetV1 with a fresh head.
# The head is sized from the *cropped* folder, because that is the dataset whose
# labels are used for training below. If MTCNN saved no crop for any image of some
# class, that class folder is presumably missing under "_cropped". The raw class
# count would then disagree with the training labels.
num_classes = len(datasets.ImageFolder(data_dir + '_cropped').classes)

resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=num_classes,
).to(device)

In [10]:
# Preprocessing for the cropped faces: PIL image -> float32 array -> tensor,
# then facenet_pytorch's fixed image standardization.
trans = transforms.Compose([np.float32, transforms.ToTensor(), fixed_image_standardization])

# Adam over every parameter of the model (backbone and head), lr = 1e-3.
optimizer = optim.Adam(resnet.parameters(), lr=0.001)

# NOTE(review): if training.pass_epoch steps the scheduler once per training epoch
# (assumed; confirm in facenet_pytorch.training), then with epochs = 5 neither
# milestone (5, 10) is reached before training ends, and the learning rate stays
# at 0.001 throughout.
scheduler = MultiStepLR(optimizer=optimizer, milestones=[5, 10])

In [11]:
# Training dataset built from the MTCNN crops, split into train/validation indices.
# The permutation comes from a seeded Generator, so the split, and therefore the
# reported validation accuracy, is the same on every kernel restart.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)
img_inds = np.random.default_rng(SPLIT_SEED).permutation(len(dataset))
n_train = int(TRAIN_FRACTION * len(img_inds))
train_inds = img_inds[:n_train]
val_inds = img_inds[n_train:]

train_inds.shape, val_inds.shape

((656,), (165,))

In [12]:
def _subset_loader(indices):
    """Return a DataLoader over `dataset` that visits only `indices`, in random order each epoch."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=workers,
        sampler=SubsetRandomSampler(indices),
    )


train_loader = _subset_loader(train_inds)
val_loader = _subset_loader(val_inds)

In [13]:
# Cross-entropy classification loss plus the per-batch metrics shown by
# training.pass_epoch: 'fps' (batch throughput timer) and 'acc' (accuracy).
loss_fn = torch.nn.CrossEntropyLoss()
metrics = dict(
    fps=training.BatchTimer(),
    acc=training.accuracy,
)

In [14]:
# TensorBoard writer shared by every pass_epoch call below. pass_epoch presumably
# reads `iteration` / `interval` to decide when to log (confirm in facenet_pytorch.training).
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10

# Baseline validation pass before any fine-tuning.
print('\n\nInitial')
print('-' * 10)
resnet.eval()
# No optimizer is passed, so this pass never back-propagates. Running it under
# no_grad skips building an autograd graph that would never be used.
with torch.no_grad():
    initial_result = training.pass_epoch(
        resnet, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True,
        device=device, writer=writer,
    )
initial_result



Initial
----------
Valid |     6/6    | loss:    2.3322 | fps:  394.7175 | acc:    0.1917   


(tensor(2.3322), {'fps': tensor(394.7175), 'acc': tensor(0.1917)})

In [15]:
# Fine-tune for `epochs` epochs, validating after each one.
try:
    for epoch in range(epochs):
        print(f'\nEpoch {epoch+1}/{epochs}')
        print('-' * 10)

        resnet.train()
        training.pass_epoch(
            resnet, loss_fn, train_loader, optimizer, scheduler,
            batch_metrics=metrics, show_running=True, device=device,
            writer=writer,
        )

        resnet.eval()
        # Validation never back-propagates, so skip autograd bookkeeping.
        with torch.no_grad():
            training.pass_epoch(
                resnet, loss_fn, val_loader,
                batch_metrics=metrics, show_running=True,
                device=device, writer=writer,
            )
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    writer.close()


Epoch 1/5
----------
Train |    21/21   | loss:    0.3792 | fps:  138.5496 | acc:    0.8929   
Valid |     6/6    | loss:    6.9417 | fps:  339.3816 | acc:    0.4354   

Epoch 2/5
----------
Train |    21/21   | loss:    0.1771 | fps:  129.5184 | acc:    0.9509   
Valid |     6/6    | loss:    1.4781 | fps:  338.2319 | acc:    0.7719   

Epoch 3/5
----------
Train |    21/21   | loss:    0.1549 | fps:  128.8703 | acc:    0.9702   
Valid |     6/6    | loss:    1.4632 | fps:  337.8444 | acc:    0.8208   

Epoch 4/5
----------
Train |    21/21   | loss:    0.0862 | fps:  128.2384 | acc:    0.9777   
Valid |     6/6    | loss:    0.1585 | fps:  337.0335 | acc:    0.9510   

Epoch 5/5
----------
Train |    21/21   | loss:    0.0602 | fps:  128.6291 | acc:    0.9851   
Valid |     6/6    | loss:    0.0351 | fps:  333.2351 | acc:    0.9844   


In [16]:
save_path = '../../models/facenet_v1.pth'

# Persist only the model weights (state_dict), creating the target folder first.
# NOTE(review): the class-index mapping (class_to_idx) is not saved alongside, so
# inference code presumably has to rebuild it from the same folder layout.
model_dir = os.path.dirname(save_path)
os.makedirs(model_dir, exist_ok=True)
torch.save(resnet.state_dict(), save_path)

print(f"\nModel successfully saved to {save_path}")


Model successfully saved to ../../models/facenet_v1.pth
