In [1]:
import os
import sys
import math
import copy
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
import pretrainedmodels

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets
import torchvision.models as models
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
# hyper params
num_of_class = 102
learning_rate = 1e-03
batch_size = 128
input_size = 224
device = torch.device("cuda")
train_size = 0.7
test_size = 0.2
directoryAgeDB = 'AgeDB/'

In [3]:
def imageList():
    image_list = []
    for i, file in enumerate(sorted(os.listdir(directoryAgeDB))):
        file_labels = parse('{}_{person}_{age}_{gender}.jpg', file)

        if file_labels is None:
            continue

        image_location = os.path.join(directoryAgeDB, file)
        gender_to_class_id = {'m': 0, 'f': 1}
        gender = gender_to_class_id[file_labels['gender']]
        age = int(file_labels['age'])
        image_list.append({
            'image_location': image_location,
            'age': age,
            'gender': gender
        })

    return image_list

image_list = imageList()

In [4]:
train_len = int(len(image_list) * train_size)
test_len = int(len(image_list) * test_size)
validate_len = len(image_list) - (train_len + test_len)

train_image_list, test_image_list, validate_image_list = torch.utils.data.random_split(
    dataset = image_list,
    lengths = [train_len, test_len, validate_len], 
    generator = torch.Generator().manual_seed(42)
)

print(len(train_image_list))
print(len(test_image_list))
print(len(validate_image_list))

11541
3297
1650


In [5]:
class AgeDBDataset(Dataset):
    
    def __init__(self, image_list, device, train = True, train_transform=None, test_transform=None, **kwargs):
        self.image_list = image_list
        self.device = device
        self.train = train
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.labels = []
        self.images = []

        if self.train:
            for i in tqdm(range(len(image_list))):

                image = Image.open(self.image_list[i]['image_location']).convert('RGB')
                image_location = self.image_list[i]['image_location']
                age = self.image_list[i]['age']
                gender = self.image_list[i]['gender']

                image = np.array(image)

                for j in range(1):
                    if j == 0:
                        augmented_images = self.test_transform(image=image)['image']
                    else:
                        augmented_images = self.train_transform(image=image)['image']

                    self.images.append(augmented_images)
                    self.labels.append({
                        'image_location': image_location,
                        'age': age,
                        'gender': gender
                    })

    def __len__(self):

        if self.train:
            return len(self.labels)
        else:
            return len(self.image_list)

    def __getitem__(self, index):

        if torch.is_tensor(index):
            index = index.tolist()

        if self.train:
            image = self.images[index]
            labels = {
                'image_location': self.labels[index]['image_location'],
                'age': self.labels[index]['age'],
                'gender': self.labels[index]['gender']
            }

        else:
            image = Image.open(
                self.image_list[index]['image_location']).convert('RGB')
            image = np.array(image)
            image = self.test_transform(image=image)['image']
            labels = {
                'image_location': self.image_list[index]['image_location'],
                'age': self.image_list[index]['age'],
                'gender': self.image_list[index]['gender']
            }

        return image.to(self.device), labels

In [6]:
transform = A.Compose(
    [A.Resize(input_size, input_size),
     A.ToGray(p=1),
     ToTensorV2()
])

In [7]:
trainDataset = AgeDBDataset(image_list=train_image_list,
                            device=device,
                            train=True,
                            train_transform=transform,
                            test_transform=transform)
testDataset = AgeDBDataset(image_list=test_image_list,
                           device=device,
                           train=False,
                           train_transform=transform,
                           test_transform=transform)
valDataset = AgeDBDataset(image_list=validate_image_list,
                          device=device,
                          train=False,
                          train_transform=transform,
                          test_transform=transform)
print(len(trainDataset))
print(len(testDataset))
print(len(valDataset))
print(len(trainDataset) + len(testDataset) + len(valDataset))

100%|███████████████████████████████████████████████████████████████████████████| 11541/11541 [00:14<00:00, 824.36it/s]

11541
3297
1650
16488





In [8]:
train_loader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(testDataset, batch_size=batch_size, shuffle=False)

validation_loader = DataLoader(valDataset, batch_size=batch_size, shuffle=False)

In [9]:
ResNet18 = pretrainedmodels.resnet18()
ResNet18.last_linear = nn.Linear(
    in_features=ResNet18.last_linear.in_features, 
    out_features=num_of_class, 
    bias=False
)
resnetModel = ResNet18.to(device)

In [10]:
# Training loop
def train(model, optimizer, criterion, train_loader, valid_loader, num_of_epoch):
    total_step = len(train_loader)
    min_valid_loss = np.inf

    for epoch in range(num_of_epoch):
        train_loss = 0.0
        valid_loss = 0.0
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        for i, (imgs, labels) in enumerate(valid_loader):
            with torch.no_grad():
                imgs = imgs.to(device).float()
                labels = torch.as_tensor(labels['age']).to(device)
            
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
        
        
        print(f"Epoch: {epoch+1}/{num_of_epoch}, Train Loss: {train_loss/len(train_loader)},  Validation Loss: {valid_loss/len(valid_loader)}")
              
        if min_valid_loss > (valid_loss/len(valid_loader)):
            print(f'Validation Loss Decreased({min_valid_loss} ---> {valid_loss/len(valid_loader)})')
            min_valid_loss = valid_loss/len(valid_loader)
            torch.save(resnetModel.state_dict(), 'baseline.pth')

In [11]:
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnetModel.parameters(), lr=learning_rate)

In [12]:
train(resnetModel, optimizer, criteria, train_loader, validation_loader, num_of_epoch=10)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 1/10, Train Loss: 4.034296502123822,  Validation Loss: 3.9123041446392355
Validation Loss Decreased(inf ---> 3.9123041446392355)
Epoch: 2/10, Train Loss: 3.7345708254929426,  Validation Loss: 3.7281540540548472
Validation Loss Decreased(3.9123041446392355 ---> 3.7281540540548472)
Epoch: 3/10, Train Loss: 3.55403302528046,  Validation Loss: 3.73976940375108
Epoch: 4/10, Train Loss: 3.38581145464719,  Validation Loss: 3.7298734738276553
Epoch: 5/10, Train Loss: 3.2058243227529,  Validation Loss: 3.811650587962224
Epoch: 6/10, Train Loss: 2.8842546127654693,  Validation Loss: 4.032128664163443
Epoch: 7/10, Train Loss: 2.373264823641096,  Validation Loss: 4.358334174522986
Epoch: 8/10, Train Loss: 1.5351285541450583,  Validation Loss: 4.918973409212553
Epoch: 9/10, Train Loss: 0.684457887332518,  Validation Loss: 5.381367573371301
Epoch: 10/10, Train Loss: 0.1982329669562015,  Validation Loss: 5.636270523071289


In [13]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    #print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [14]:
eval(resnetModel, test_loader)
eval(resnetModel, train_loader)
resnetModel.load_state_dict(torch.load('baseline.pth'))

Mean Absolute Error: 7.84076452255249
Minimum: 0.0, Maximum: 53.0, Median: 6.0
Mean Absolute Error: 0.054934579879045486
Minimum: 0.0, Maximum: 36.0, Median: 0.0


<All keys matched successfully>

In [15]:
eval(resnetModel, test_loader)
eval(resnetModel, train_loader)

Mean Absolute Error: 7.791628837585449
Minimum: 0.0, Maximum: 50.0, Median: 6.0
Mean Absolute Error: 6.979377746582031
Minimum: 0.0, Maximum: 51.0, Median: 6.0
