In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop
import matplotlib.pyplot as plt
from PIL import Image
from parse import parse
from custom_loader import AgeDBDataset
from custom_loss_functions import AngularPenaltySMLoss

import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 79
hidden_unit = 256
learning_rate = 1e-04
batch_size = 64
device = torch.device("cuda")

In [3]:
dataset = AgeDBDataset(
    directory = 'AgeDB/',
    transform = Compose([
        Resize(size=(64,64)),
        CenterCrop(size=64),
        Grayscale(num_output_channels=1),
        ToTensor(),
    ]),
    device = device,
)

In [4]:
len(dataset)

15510

In [5]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
)

In [6]:
class AgeDBConvModel(nn.Module):
    def __init__(self, num_of_classes):
        super(AgeDBConvModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=2), #(64+2(2)-3)+1=66
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((66-2)/2)+1 = 33
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=2), #(33+2(2)-3)+1=35
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((35-2)/2)+1 = 17 + 0.5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2), #(17+2(2)-3)+1=19
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((19-2)/2)+1 = 9 + 0.5
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2), #(9+2(2)-3)+1=11
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((11-2)/2)+1 = 5 + 0.5
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=2), #(5+2(2)-3)+1=7
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((7-2)/2)+1 = 3 + 0.5
        )
        self.fc1 = nn.Linear(3*3*512, num_of_classes)
        
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        
        return out

In [7]:
convModel = AgeDBConvModel(num_of_class).cuda()

In [8]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
            
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [9]:
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

In [10]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=20)

Epoch: 1/20, Step: 194/194, Loss: 3.936713933944702
Epoch: 2/20, Step: 194/194, Loss: 3.7707858085632324
Epoch: 3/20, Step: 194/194, Loss: 3.557001829147339
Epoch: 4/20, Step: 194/194, Loss: 3.2541587352752686
Epoch: 5/20, Step: 194/194, Loss: 2.857527256011963
Epoch: 6/20, Step: 194/194, Loss: 2.3499131202697754
Epoch: 7/20, Step: 194/194, Loss: 1.7960845232009888
Epoch: 8/20, Step: 194/194, Loss: 1.3009862899780273
Epoch: 9/20, Step: 194/194, Loss: 0.9234647154808044
Epoch: 10/20, Step: 194/194, Loss: 0.6881670951843262
Epoch: 11/20, Step: 194/194, Loss: 0.5133230090141296
Epoch: 12/20, Step: 194/194, Loss: 0.385812908411026
Epoch: 13/20, Step: 194/194, Loss: 0.3244851529598236
Epoch: 14/20, Step: 194/194, Loss: 0.24301914870738983
Epoch: 15/20, Step: 194/194, Loss: 0.13887397944927216
Epoch: 16/20, Step: 194/194, Loss: 0.08338270336389542
Epoch: 17/20, Step: 194/194, Loss: 0.06885034590959549
Epoch: 18/20, Step: 194/194, Loss: 0.05543256551027298
Epoch: 19/20, Step: 194/194, Loss: 0

In [11]:
eval(convModel, test_set)

Accuracy: 4.4165054803352675%
Mean Absolute Error: 10.839458465576172
Minimum: 0.0, Maximum: 56.0, Median: 9.0


In [12]:
eval(convModel, train_set)

Accuracy: 99.59703417150226%
Mean Absolute Error: 0.0465022549033165
Minimum: 0.0, Maximum: 34.0, Median: 0.0
