In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop
import matplotlib.pyplot as plt
from PIL import Image
from parse import parse
from custom_loader import AgeDBDataset
from custom_loss_functions import AngularPenaltySMLoss
from tqdm import tqdm
import pretrainedmodels

import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 79
hidden_unit = 256
learning_rate = 1e-04
batch_size = 24
device = torch.device("cuda")

In [3]:
dataset = AgeDBDataset(
    directory = 'cropped_face/',
    transform = Compose([
        Resize(size=(299,299)),
        CenterCrop(size=299),
        Grayscale(num_output_channels=3),
        ToTensor(),
    ]),
    device = device,
)

In [4]:
len(dataset)

15508

In [5]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
)

In [6]:
model = pretrainedmodels.inceptionv4(pretrained='imagenet')
model

InceptionV4(
  (features): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Mixed_3a(
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv): BasicConv2d(
        (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(96, eps=0.001, mo

In [7]:
model.last_linear = nn.Linear(in_features=model.last_linear.in_features, out_features=num_of_class, bias=True)

In [8]:
convModel = model.to(device)

In [9]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

In [10]:
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

In [11]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=15)

Epoch: 1/15, Step: 517/517, Loss: 3.482811212539673
Epoch: 2/15, Step: 517/517, Loss: 3.2595274448394775
Epoch: 3/15, Step: 517/517, Loss: 2.922412633895874
Epoch: 4/15, Step: 517/517, Loss: 2.4428319931030273
Epoch: 5/15, Step: 517/517, Loss: 1.6964678764343262
Epoch: 6/15, Step: 517/517, Loss: 0.5527765154838562
Epoch: 7/15, Step: 517/517, Loss: 0.39597511291503906
Epoch: 8/15, Step: 517/517, Loss: 0.3004912734031677
Epoch: 9/15, Step: 517/517, Loss: 0.1505170464515686
Epoch: 10/15, Step: 517/517, Loss: 0.02591620571911335
Epoch: 11/15, Step: 517/517, Loss: 0.16811369359493256
Epoch: 12/15, Step: 517/517, Loss: 0.41391265392303467
Epoch: 13/15, Step: 517/517, Loss: 0.1042826697230339
Epoch: 14/15, Step: 517/517, Loss: 0.022682473063468933
Epoch: 15/15, Step: 517/517, Loss: 0.13110199570655823


In [15]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
            
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [16]:
eval(convModel, test_set)

Accuracy: 4.998387616897775%
Mean Absolute Error: 8.701386451721191
Minimum: 0.0, Maximum: 54.0, Median: 7.0


In [14]:
eval(convModel, train_set)

Accuracy: 90.99629211671771%
Mean Absolute Error: 0.6386426091194153%
Minimum: 0.0, Maximum: 45.0, Median: 0.0
