In [1]:
import os
import sys
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
from autocrop import Cropper
from IPython.display import clear_output

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from custom_loader import AgeDBDataset
import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 102
hidden_unit = 256
learning_rate = 1e-04
batch_size = 64
input_size = 64
device = torch.device("cuda")

In [3]:
transformA = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.augmentations.transforms.ChannelDropout(p=1.0),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

In [4]:
### --- AgeDB Dataset Class --- ###           {}_{person}_{age}_{gender}.jpg


class AgeDBDataset(Dataset):

    ## data loading
    def __init__(self, directory, device, transform=None, **kwargs):
        self.directory = directory
        self.transform = transform
        self.device = device
        self.labels = []
        self.images = []

        gender_to_class_id = {'m': 0, 'f': 1}

        for i, file in enumerate(sorted(os.listdir(self.directory))):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            image = Image.open(os.path.join(self.directory,
                                            file)).convert('RGB')

            ########
            cropper = Cropper()

            try:
                #Get a Numpy array of the cropped image
                cropped_array = cropper.crop(image)
                #Save the cropped image with PIL
                image = Image.fromarray(cropped_array)

            except:
                pass

            image = np.array(image)

            for i in range(3):
                augmented_images = self.transform(image=image)['image']
                self.images.append(augmented_images)
                ########
                gender = gender_to_class_id[file_labels['gender']]
                age = int(file_labels['age'])
                self.labels.append({
                    'age': age,
                    'gender': gender
                })

## len(dataset)

    def __len__(self):
        return len(self.labels)

## dataset[0]

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image = self.images[index]

        labels = {
            'age': self.labels[index]['age'],
            'gender': self.labels[index]['gender']
        }

        return image.to(self.device), labels


## DataLoaders - train, validate, test

    def get_loaders(self, batch_size, train_size, test_size, random_seed,
                    **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)

        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset=self,
            lengths=[train_len, validate_len, test_len],
            generator=torch.Generator().manual_seed(random_seed))

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset,
                                     batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

In [5]:
dataset = AgeDBDataset(
    directory='AgeDB/',
    transform=transformA,
    device=device,
)

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
    random_seed=42,
)

In [7]:
len(dataset)

49464

In [8]:
class AgeDBConvModel(nn.Module):
    def __init__(self, num_of_classes):
        super(AgeDBConvModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=2), #(64+2(2)-3)+1=66
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((66-2)/2)+1 = 33
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=2), #(33+2(2)-3)+1=35
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((35-2)/2)+1 = 17 + 0.5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2), #(17+2(2)-3)+1=19
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((19-2)/2)+1 = 9 + 0.5
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2), #(9+2(2)-3)+1=11
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((11-2)/2)+1 = 5 + 0.5
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=2), #(5+2(2)-3)+1=7
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((7-2)/2)+1 = 3 + 0.5
        )
        self.fc1 = nn.Linear(3*3*512, num_of_classes)
        
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        
        return out

In [9]:
convModel = AgeDBConvModel(num_of_class).to(device)

In [10]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

In [11]:
criteria = nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

In [12]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 619/619, Loss: 4.058262348175049
Epoch: 2/20, Step: 619/619, Loss: 3.7947826385498047
Epoch: 3/20, Step: 619/619, Loss: 3.3983185291290283
Epoch: 4/20, Step: 619/619, Loss: 2.74649715423584
Epoch: 5/20, Step: 619/619, Loss: 1.909659743309021
Epoch: 6/20, Step: 619/619, Loss: 1.0497331619262695
Epoch: 7/20, Step: 619/619, Loss: 0.5177012085914612
Epoch: 8/20, Step: 619/619, Loss: 0.2676871120929718
Epoch: 9/20, Step: 619/619, Loss: 0.15219804644584656
Epoch: 10/20, Step: 619/619, Loss: 0.1030428558588028
Epoch: 11/20, Step: 619/619, Loss: 0.07053124904632568
Epoch: 12/20, Step: 619/619, Loss: 0.05115140601992607
Epoch: 13/20, Step: 619/619, Loss: 0.041445132344961166
Epoch: 14/20, Step: 619/619, Loss: 0.027941735461354256
Epoch: 15/20, Step: 619/619, Loss: 0.021905971691012383
Epoch: 16/20, Step: 619/619, Loss: 0.032048746943473816
Epoch: 17/20, Step: 619/619, Loss: 0.01902121864259243
Epoch: 18/20, Step: 619/619, Loss: 0.01741473376750946
Epoch: 19/20, Step: 619/619,

In [13]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [14]:
eval(convModel, test_set)

Accuracy: 19.207440355843104%
Mean Absolute Error: 10.545187950134277
Minimum: 0.0, Maximum: 78.0, Median: 8.0


In [15]:
eval(convModel, train_set)

Accuracy: 92.43638017740264%
Mean Absolute Error: 0.9194612503051758
Minimum: 0.0, Maximum: 60.0, Median: 0.0
