In [1]:
import os
import sys
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
from autocrop import Cropper
import pretrainedmodels
from kornia.losses import FocalLoss, BinaryFocalLossWithLogits

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets
import torchvision.models as models
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 102
learning_rate = 1e-03
batch_size = 128
input_size = 64
device = torch.device("cuda")

In [3]:
Augmentation = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

In [4]:
class AgeDBDataset(Dataset):

    def __init__(self, directory, device, transform=None, **kwargs):
        self.directory = directory
        self.transform = transform
        self.device = device
        self.labels = []
        self.images = []

        gender_to_class_id = {'m': 0, 'f': 1}

        for i, file in enumerate(sorted(os.listdir(self.directory))):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            image = Image.open(os.path.join(self.directory, file)).convert('RGB')

            
            cropper = Cropper()

            try:
                cropped_array = cropper.crop(image)
                image = Image.fromarray(cropped_array)

            except:
                pass

            image = np.array(image)

            augmented_images = self.transform(image=image)['image']
            self.images.append(augmented_images)

            gender = gender_to_class_id[file_labels['gender']]
            age = int(file_labels['age'])
            self.labels.append({
                'age': age,
                'gender': gender
            })

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image = self.images[index]

        labels = {
            'age': self.labels[index]['age'],
            'gender': self.labels[index]['gender']
        }

        return image.to(self.device), labels


    def get_loaders(self, batch_size, train_size, test_size, **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)

        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset=self,
            lengths=[train_len, validate_len, test_len],
            generator=torch.Generator().manual_seed(42)
        )

        train_loader = DataLoader(
            self.trainDataset, batch_size=batch_size, shuffle=True,
        )
        validate_loader = DataLoader(
            self.validateDataset, batch_size=batch_size, shuffle=True,
        )
        test_loader = DataLoader(
            self.testDataset, batch_size=batch_size, shuffle=True,
        )

        return train_loader, validate_loader, test_loader

In [5]:
dataset = AgeDBDataset(
    directory='AgeDB/',
    transform=Augmentation,
    device=device,
)

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.7,
    test_size=0.1,
)

In [7]:
len(dataset)

16488

In [8]:
resnetModel = pretrainedmodels.resnet18(pretrained='imagenet')
resnetModel.last_linear = nn.Linear(in_features=resnetModel.last_linear.in_features, out_features=num_of_class, bias=False)
resnetModel = resnetModel.to(device)

In [9]:
# Training loop
def train(model, optimizer, criterion, train_loader, valid_loader, num_of_epoch):
    total_step = len(train_loader)
    min_valid_loss = np.inf

    for epoch in range(num_of_epoch):
        train_loss = 0.0
        valid_loss = 0.0
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        model.eval()
        for i, (imgs, labels) in enumerate(valid_loader):
            with torch.no_grad():
                imgs = imgs.to(device).float()
                labels = torch.as_tensor(labels['age']).to(device)
            
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()

        print(f"Epoch: {epoch+1}/{num_of_epoch}, Train Loss: {train_loss/len(train_loader)},  Validation Loss: {valid_loss/len(valid_loader)}")
              
        if min_valid_loss > valid_loss:
            print(f'Validation Loss Decreased({min_valid_loss} ---> {valid_loss})')
            min_valid_loss = valid_loss

In [10]:
criteria = FocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
optimizer = torch.optim.SGD(resnetModel.parameters(), lr=learning_rate)

In [11]:
train(resnetModel, optimizer, criteria, train_set, validation_set, num_of_epoch=10)

Epoch: 1/10, Train Loss: 2.421604892709753,  Validation Loss: 2.4030894958055935
Validation Loss Decreased(inf ---> 62.480326890945435)
Epoch: 2/10, Train Loss: 2.2654372519189185,  Validation Loss: 2.2379330213253317
Validation Loss Decreased(62.480326890945435 ---> 58.18625855445862)
Epoch: 3/10, Train Loss: 2.2087993910024455,  Validation Loss: 2.19332283276778
Validation Loss Decreased(58.18625855445862 ---> 57.02639365196228)
Epoch: 4/10, Train Loss: 2.156275424328479,  Validation Loss: 2.139504771966201
Validation Loss Decreased(57.02639365196228 ---> 55.627124071121216)
Epoch: 5/10, Train Loss: 2.10928481227749,  Validation Loss: 2.1030104893904467
Validation Loss Decreased(55.627124071121216 ---> 54.67827272415161)
Epoch: 6/10, Train Loss: 2.0749516513321424,  Validation Loss: 2.077870699075552
Validation Loss Decreased(54.67827272415161 ---> 54.024638175964355)
Epoch: 7/10, Train Loss: 2.0532960027128784,  Validation Loss: 2.0637951447413516
Validation Loss Decreased(54.024638

In [12]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [13]:
eval(resnetModel, test_set)

Accuracy: 3.1553398058252426%
Mean Absolute Error: 12.91626262664795
Minimum: 0.0, Maximum: 54.0, Median: 11.0


In [14]:
eval(resnetModel, train_set)

Accuracy: 3.6565288969759986%
Mean Absolute Error: 12.912399291992188
Minimum: 0.0, Maximum: 67.0, Median: 11.0
