In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from parse import parse
from autocrop import Cropper

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop

from custom_loader import AgeDBDataset
from custom_loss_functions import AngularPenaltySMLoss

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 102
hidden_unit = 256
learning_rate = 1e-03
input_size = 64
batch_size = 64
device = torch.device("cuda")

In [3]:
transformA = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.augmentations.transforms.ChannelDropout(p=1.0),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

In [4]:
### --- AgeDB Dataset Class --- ###           {}_{person}_{age}_{gender}.jpg


class AgeDBDataset(Dataset):

    ## data loading
    def __init__(self, directory, device, transform=None, **kwargs):
        self.directory = directory
        self.transform = transform
        self.device = device
        self.labels = []
        self.images = []

        gender_to_class_id = {'m': 0, 'f': 1}

        for i, file in enumerate(sorted(os.listdir(self.directory))):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            image = Image.open(os.path.join(self.directory,
                                            file)).convert('RGB')

            ########
            cropper = Cropper()

            try:
                #Get a Numpy array of the cropped image
                cropped_array = cropper.crop(image)
                #Save the cropped image with PIL
                image = Image.fromarray(cropped_array)

            except:
                pass

            image = np.array(image)

            for i in range(10):
                augmented_images = self.transform(image=image)['image']
                self.images.append(augmented_images)
                ########
                gender = gender_to_class_id[file_labels['gender']]
                age = int(file_labels['age'])
                self.labels.append({
                    'age': age,
                    'gender': gender
                })

## len(dataset)

    def __len__(self):
        return len(self.labels)

## dataset[0]

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image = self.images[index]

        labels = {
            'age': self.labels[index]['age'],
            'gender': self.labels[index]['gender']
        }

        return image.to(self.device), labels


## DataLoaders - train, validate, test

    def get_loaders(self, batch_size, train_size, test_size, random_seed,
                    **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)

        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset=self,
            lengths=[train_len, validate_len, test_len],
            generator=torch.Generator().manual_seed(random_seed))

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset,
                                     batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

In [5]:
dataset = AgeDBDataset(
    directory='AgeDB/',
    transform=transformA,
    device=device,
)

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
    random_seed=42,
)

In [7]:
len(dataset)

164880

In [8]:
class AgeDBConvModel(nn.Module):
    def __init__(self, num_of_classes):
        super(AgeDBConvModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=2), #(64+2(2)-3)+1=66
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((66-2)/2)+1 = 33
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=2), #(33+2(2)-3)+1=35
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((35-2)/2)+1 = 17 + 0.5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2), #(17+2(2)-3)+1=19
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((19-2)/2)+1 = 9 + 0.5
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2), #(9+2(2)-3)+1=11
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((11-2)/2)+1 = 5 + 0.5
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=2), #(5+2(2)-3)+1=7
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((7-2)/2)+1 = 3 + 0.5
        )
        self.fc1 = nn.Linear(3*3*512, num_of_classes)
        
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        
        return out

In [14]:
convModel1 = AgeDBConvModel(num_of_class).to(device)
convModel2 = AgeDBConvModel(num_of_class).to(device)
convModel3 = AgeDBConvModel(num_of_class).to(device)

In [10]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [11]:
criteria = AngularPenaltySMLoss(
    in_features=num_of_class,
    out_features=num_of_class,
    loss_type='arcface',
)
optimizer = torch.optim.Adam(convModel1.parameters(), lr=learning_rate)
train(convModel1, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 2061/2061, Loss: 34.55863952636719
Epoch: 2/20, Step: 2061/2061, Loss: 34.026058197021484
Epoch: 3/20, Step: 2061/2061, Loss: 30.480792999267578
Epoch: 4/20, Step: 2061/2061, Loss: 22.241352081298828
Epoch: 5/20, Step: 2061/2061, Loss: 17.97844123840332
Epoch: 6/20, Step: 2061/2061, Loss: 15.828337669372559
Epoch: 7/20, Step: 2061/2061, Loss: 14.4656982421875
Epoch: 8/20, Step: 2061/2061, Loss: 13.320916175842285
Epoch: 9/20, Step: 2061/2061, Loss: 12.509502410888672
Epoch: 10/20, Step: 2061/2061, Loss: 11.711514472961426
Epoch: 11/20, Step: 2061/2061, Loss: 11.409056663513184
Epoch: 12/20, Step: 2061/2061, Loss: 10.688467025756836
Epoch: 13/20, Step: 2061/2061, Loss: 9.860223770141602
Epoch: 14/20, Step: 2061/2061, Loss: 9.70979118347168
Epoch: 15/20, Step: 2061/2061, Loss: 9.254637718200684
Epoch: 16/20, Step: 2061/2061, Loss: 8.840057373046875
Epoch: 17/20, Step: 2061/2061, Loss: 8.236944198608398
Epoch: 18/20, Step: 2061/2061, Loss: 8.008955955505371
Epoch: 19/20

In [15]:
eval(convModel1, test_set)
print()
eval(convModel1, train_set)

Accuracy: 1.0462154294032022%
Mean Absolute Error: 25.114416122436523
Minimum: 0.0, Maximum: 91.0, Median: 22.0

Accuracy: 1.0697173702086367%
Mean Absolute Error: 24.97016716003418
Minimum: 0.0, Maximum: 94.0, Median: 22.0


In [16]:
criteria = AngularPenaltySMLoss(
    in_features=num_of_class,
    out_features=num_of_class,
    loss_type='cosface',
)
optimizer = torch.optim.Adam(convModel2.parameters(), lr=learning_rate)
train(convModel2, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 2061/2061, Loss: 15.769417762756348
Epoch: 2/20, Step: 2061/2061, Loss: 14.485084533691406
Epoch: 3/20, Step: 2061/2061, Loss: 10.117796897888184
Epoch: 4/20, Step: 2061/2061, Loss: 7.292023181915283
Epoch: 5/20, Step: 2061/2061, Loss: 6.101036548614502
Epoch: 6/20, Step: 2061/2061, Loss: 5.367204666137695
Epoch: 7/20, Step: 2061/2061, Loss: 4.9822893142700195
Epoch: 8/20, Step: 2061/2061, Loss: 4.6685333251953125
Epoch: 9/20, Step: 2061/2061, Loss: 4.278404235839844
Epoch: 10/20, Step: 2061/2061, Loss: 3.9736342430114746
Epoch: 11/20, Step: 2061/2061, Loss: 3.735743522644043
Epoch: 12/20, Step: 2061/2061, Loss: 3.4399871826171875
Epoch: 13/20, Step: 2061/2061, Loss: 3.2809550762176514
Epoch: 14/20, Step: 2061/2061, Loss: 3.0514724254608154
Epoch: 15/20, Step: 2061/2061, Loss: 2.853799819946289
Epoch: 16/20, Step: 2061/2061, Loss: 2.843712091445923
Epoch: 17/20, Step: 2061/2061, Loss: 2.7290642261505127
Epoch: 18/20, Step: 2061/2061, Loss: 2.6543326377868652
Epoch: 1

In [17]:
eval(convModel2, test_set)
print()
eval(convModel2, train_set)

Accuracy: 1.0613779718583212%
Mean Absolute Error: 29.084941864013672
Minimum: 0.0, Maximum: 98.0, Median: 26.0

Accuracy: 0.8877668607472101%
Mean Absolute Error: 28.67450523376465
Minimum: 0.0, Maximum: 99.0, Median: 26.0


In [18]:
criteria = AngularPenaltySMLoss(
    in_features=num_of_class,
    out_features=num_of_class,
    loss_type='sphereface',
)
optimizer = torch.optim.Adam(convModel3.parameters(), lr=learning_rate)
train(convModel3, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 2061/2061, Loss: 36.44721984863281
Epoch: 2/20, Step: 2061/2061, Loss: 35.025211334228516
Epoch: 3/20, Step: 2061/2061, Loss: 25.769569396972656
Epoch: 4/20, Step: 2061/2061, Loss: 17.463092803955078
Epoch: 5/20, Step: 2061/2061, Loss: 12.482423782348633
Epoch: 6/20, Step: 2061/2061, Loss: 9.896167755126953
Epoch: 7/20, Step: 2061/2061, Loss: 7.999245643615723
Epoch: 8/20, Step: 2061/2061, Loss: 7.064300060272217
Epoch: 9/20, Step: 2061/2061, Loss: 5.972658634185791
Epoch: 10/20, Step: 2061/2061, Loss: 5.067168235778809
Epoch: 11/20, Step: 2061/2061, Loss: 4.358851909637451
Epoch: 12/20, Step: 2061/2061, Loss: 3.6999664306640625
Epoch: 13/20, Step: 2061/2061, Loss: 3.3063313961029053
Epoch: 14/20, Step: 2061/2061, Loss: 2.9672927856445312
Epoch: 15/20, Step: 2061/2061, Loss: 2.811723232269287
Epoch: 16/20, Step: 2061/2061, Loss: 2.6453423500061035
Epoch: 17/20, Step: 2061/2061, Loss: 2.1422860622406006
Epoch: 18/20, Step: 2061/2061, Loss: 1.930097222328186
Epoch: 19/

In [19]:
eval(convModel3, test_set)
print()
eval(convModel3, train_set)

Accuracy: 0.6550218340611353%
Mean Absolute Error: 28.086820602416992
Minimum: 0.0, Maximum: 97.0, Median: 25.0

Accuracy: 0.6830725376031053%
Mean Absolute Error: 27.68055534362793
Minimum: 0.0, Maximum: 100.0, Median: 24.0
