In [1]:
import os
import torch
import torchvision
from torch import nn
from PIL import Image
from parse import parse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop

import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 102
learning_rate = 1e-04
batch_size = 256
device = torch.device("cuda")

In [3]:
class AgeDBDataset(Dataset):
    def __init__(self, directory, transform, preload=False, device: torch.device = torch.device('cpu'), **kwargs):
        self.device = device
        self.directory = directory
        self.transform = transform
        self.labels = []
        self.images = []
        self.preload = preload

        for i, file in enumerate(os.listdir(self.directory)):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)
            
            if file_labels is None:
                continue
                
            if self.preload:
                image = Image.open(os.path.join(self.directory, file)).convert('RGB')
                if self.transform is not None:
                    image = self.transform(image).to(self.device)
            
            else:
                image = os.path.join(self.directory, file)
                
            
            gender_to_class_id = {
                'm': 0, 
                'f': 1
            }
            
            gender = gender_to_class_id[file_labels['gender']]
            age = int(file_labels['age'])
            
            self.images.append(image)
            self.labels.append({
                'age': age,
                'gender': gender
            })
            
        pass

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.images[idx]

        if not self.preload:
            image = Image.open(image).convert('RGB')
            if self.transform is not None:
                image = self.transform(image).to(self.device)

        labels = {
            'age': self.labels[idx]['age'], 
            'gender': self.labels[idx]['gender'],
        }
        return image.to(self.device), labels
    
    def get_loaders(self, batch_size, train_size=0.7, test_size=0.2, **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)
        
        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset = self, 
            lengths = [train_len, validate_len, test_len], 
            generator = torch.Generator().manual_seed(42)
        )

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset, batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

In [4]:
dataset = AgeDBDataset(
    directory = 'AgeDB/',
    transform = Compose([
        Resize(size=(64, 64)),
        CenterCrop(size=64),
        Grayscale(num_output_channels=1),
        ToTensor(),
    ]),
    device = device,
)

In [5]:
len(dataset)

16488

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
)

In [7]:
class AgeDBConvModel(nn.Module):
    def __init__(self, num_of_classes):
        super(AgeDBConvModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=2), #(64+2(2)-3)+1=66
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((66-2)/2)+1 = 33
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=2), #(33+2(2)-3)+1=35
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((35-2)/2)+1 = 17 + 0.5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2), #(17+2(2)-3)+1=19
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((19-2)/2)+1 = 9 + 0.5
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2), #(9+2(2)-3)+1=11
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((11-2)/2)+1 = 5 + 0.5
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=2), #(5+2(2)-3)+1=7
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((7-2)/2)+1 = 3 + 0.5
        )
        self.fc = nn.Linear(3*3*512, num_of_classes)
        
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        
        return out

In [8]:
convModel = AgeDBConvModel(num_of_class).to(device)

In [9]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
            
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [10]:
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

In [11]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=10)

Epoch: 1/10, Step: 52/52, Loss: 4.208644866943359
Epoch: 2/10, Step: 52/52, Loss: 3.958613634109497
Epoch: 3/10, Step: 52/52, Loss: 3.6929430961608887
Epoch: 4/10, Step: 52/52, Loss: 3.3307101726531982
Epoch: 5/10, Step: 52/52, Loss: 2.921128273010254
Epoch: 6/10, Step: 52/52, Loss: 2.4748032093048096
Epoch: 7/10, Step: 52/52, Loss: 2.039329767227173
Epoch: 8/10, Step: 52/52, Loss: 1.6440496444702148
Epoch: 9/10, Step: 52/52, Loss: 1.3424772024154663
Epoch: 10/10, Step: 52/52, Loss: 1.1302732229232788


In [12]:
eval(convModel, test_set)

Accuracy: 3.002729754322111%
Mean Absolute Error: 12.138307571411133
Minimum: 0.0, Maximum: 63.0, Median: 10.0


In [13]:
eval(convModel, train_set)

Accuracy: 64.11675511751326%
Mean Absolute Error: 5.237679958343506
Minimum: 0.0, Maximum: 63.0, Median: 0.0
