In [1]:
import os
import sys
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
from autocrop import Cropper
from IPython.display import clear_output

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
import pretrainedmodels
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms
import matplotlib.pyplot as plt
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from custom_loader import AgeDBDataset
import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 102
learning_rate = 1e-03
batch_size = 64
input_size = 64
device = torch.device("cuda")

In [3]:
transformA = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

In [4]:
### --- AgeDB Dataset Class --- ###           {}_{person}_{age}_{gender}.jpg


class AgeDBDataset(Dataset):

    ## data loading
    def __init__(self, directory, device, transform=None, **kwargs):
        self.directory = directory
        self.transform = transform
        self.device = device
        self.labels = []
        self.images = []

        gender_to_class_id = {'m': 0, 'f': 1}

        for i, file in enumerate(sorted(os.listdir(self.directory))):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            image = Image.open(os.path.join(self.directory,
                                            file)).convert('RGB')

            ########
            cropper = Cropper()

            try:
                #Get a Numpy array of the cropped image
                cropped_array = cropper.crop(image)
                #Save the cropped image with PIL
                image = Image.fromarray(cropped_array)

            except:
                pass

            image = np.array(image)

            for i in range(10):
                augmented_images = self.transform(image=image)['image']
                self.images.append(augmented_images)
                ########
                gender = gender_to_class_id[file_labels['gender']]
                age = int(file_labels['age'])
                self.labels.append({
                    'age': age,
                    'gender': gender
                })

## len(dataset)

    def __len__(self):
        return len(self.labels)

## dataset[0]

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image = self.images[index]

        labels = {
            'age': self.labels[index]['age'],
            'gender': self.labels[index]['gender']
        }

        return image.to(self.device), labels


## DataLoaders - train, validate, test

    def get_loaders(self, batch_size, train_size, test_size, random_seed,
                    **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)

        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset=self,
            lengths=[train_len, validate_len, test_len],
            generator=torch.Generator().manual_seed(random_seed))

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset,
                                     batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

In [5]:
dataset = AgeDBDataset(
    directory='AgeDB/',
    transform=transformA,
    device=device,
)

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
    random_seed=42,
)

In [7]:
len(dataset)

164880

In [12]:
resnetModel = pretrainedmodels.resnet34(pretrained='imagenet')
resnetModel.last_linear = nn.Linear(in_features=resnetModel.last_linear.in_features, out_features=num_of_class, bias=True)
resnetModel = resnetModel.to(device)

In [13]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

In [14]:
criteria = nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(resnetModel.parameters(), lr=learning_rate)

In [15]:
train(resnetModel, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 2061/2061, Loss: 3.8577325344085693
Epoch: 2/20, Step: 2061/2061, Loss: 3.726562023162842
Epoch: 3/20, Step: 2061/2061, Loss: 3.5070724487304688
Epoch: 4/20, Step: 2061/2061, Loss: 2.6795244216918945
Epoch: 5/20, Step: 2061/2061, Loss: 0.9963527917861938
Epoch: 6/20, Step: 2061/2061, Loss: 0.16348445415496826
Epoch: 7/20, Step: 2061/2061, Loss: 0.13690418004989624
Epoch: 8/20, Step: 2061/2061, Loss: 0.33322766423225403
Epoch: 9/20, Step: 2061/2061, Loss: 0.11664505302906036
Epoch: 10/20, Step: 2061/2061, Loss: 0.09028708189725876
Epoch: 11/20, Step: 2061/2061, Loss: 0.07487917691469193
Epoch: 12/20, Step: 2061/2061, Loss: 0.12027671188116074
Epoch: 13/20, Step: 2061/2061, Loss: 0.08534736186265945
Epoch: 14/20, Step: 2061/2061, Loss: 0.02192077599465847
Epoch: 15/20, Step: 2061/2061, Loss: 0.013929445296525955
Epoch: 16/20, Step: 2061/2061, Loss: 0.04301967844367027
Epoch: 17/20, Step: 2061/2061, Loss: 0.02017836831510067
Epoch: 18/20, Step: 2061/2061, Loss: 0.009957

In [16]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [17]:
eval(resnetModel, test_set)

Accuracy: 91.62421154779233%
Mean Absolute Error: 0.9190016984939575
Minimum: 0.0, Maximum: 62.0, Median: 0.0


In [18]:
eval(resnetModel, train_set)

Accuracy: 99.05385735080058%
Mean Absolute Error: 0.09923125803470612
Minimum: 0.0, Maximum: 58.0, Median: 0.0


In [23]:
torch.save(resnetModel.state_dict(), 'ResNetModel.pth')