In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop
import matplotlib.pyplot as plt
from PIL import Image
from parse import parse
from custom_loader import AgeDBDataset
from custom_loss_functions import AngularPenaltySMLoss
from tqdm import tqdm
import pretrainedmodels

import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyper params
num_of_class = 79
hidden_unit = 256
learning_rate = 1e-04
batch_size = 24
device = torch.device("cuda")

In [3]:
dataset = AgeDBDataset(
    directory = 'cropped_face/',
    transform = Compose([
        Resize(size=(299,299)),
        CenterCrop(size=299),
        Grayscale(num_output_channels=3),
        ToTensor(),
    ]),
    device = device,
)

In [4]:
len(dataset)

15508

In [5]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
)

In [6]:
model = pretrainedmodels.inceptionv4(pretrained='imagenet')
model

InceptionV4(
  (features): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Mixed_3a(
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv): BasicConv2d(
        (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(96, eps=0.001, mo

In [7]:
model.last_linear = nn.Linear(in_features=model.last_linear.in_features, out_features=num_of_class, bias=True)

In [8]:
convModel = model.to(device)

In [9]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in tqdm(enumerate(train_loader)):
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

In [10]:
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

In [11]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=50)

517it [03:50,  2.24it/s]
0it [00:00, ?it/s]

Epoch: 1/50, Step: 517/517, Loss: 3.545437812805176


517it [03:48,  2.26it/s]
0it [00:00, ?it/s]

Epoch: 2/50, Step: 517/517, Loss: 3.2268104553222656


517it [03:48,  2.26it/s]
0it [00:00, ?it/s]

Epoch: 3/50, Step: 517/517, Loss: 2.69075608253479


517it [03:48,  2.26it/s]
0it [00:00, ?it/s]

Epoch: 4/50, Step: 517/517, Loss: 2.0198628902435303


517it [03:48,  2.26it/s]
0it [00:00, ?it/s]

Epoch: 5/50, Step: 517/517, Loss: 1.054639458656311


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 6/50, Step: 517/517, Loss: 0.45448410511016846


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 7/50, Step: 517/517, Loss: 0.24806822836399078


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 8/50, Step: 517/517, Loss: 0.2924875020980835


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 9/50, Step: 517/517, Loss: 0.1531061977148056


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 10/50, Step: 517/517, Loss: 0.36699149012565613


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 11/50, Step: 517/517, Loss: 0.12391333281993866


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 12/50, Step: 517/517, Loss: 0.0992913693189621


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 13/50, Step: 517/517, Loss: 0.17837458848953247


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 14/50, Step: 517/517, Loss: 0.019467316567897797


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 15/50, Step: 517/517, Loss: 0.08041799068450928


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 16/50, Step: 517/517, Loss: 0.11591613292694092


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 17/50, Step: 517/517, Loss: 0.028611166402697563


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 18/50, Step: 517/517, Loss: 0.265715628862381


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 19/50, Step: 517/517, Loss: 0.12697309255599976


517it [03:47,  2.28it/s]
0it [00:00, ?it/s]

Epoch: 20/50, Step: 517/517, Loss: 0.0636703297495842


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 21/50, Step: 517/517, Loss: 0.050911568105220795


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 22/50, Step: 517/517, Loss: 0.0435122586786747


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 23/50, Step: 517/517, Loss: 0.1494912952184677


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 24/50, Step: 517/517, Loss: 0.094607874751091


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 25/50, Step: 517/517, Loss: 0.024844394996762276


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 26/50, Step: 517/517, Loss: 0.01467878557741642


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 27/50, Step: 517/517, Loss: 0.0376269556581974


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 28/50, Step: 517/517, Loss: 0.22084353864192963


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 29/50, Step: 517/517, Loss: 0.16173651814460754


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 30/50, Step: 517/517, Loss: 0.02060951478779316


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 31/50, Step: 517/517, Loss: 0.024817459285259247


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 32/50, Step: 517/517, Loss: 0.1712304949760437


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 33/50, Step: 517/517, Loss: 0.12153548747301102


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 34/50, Step: 517/517, Loss: 0.0968790277838707


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 35/50, Step: 517/517, Loss: 0.040438614785671234


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 36/50, Step: 517/517, Loss: 0.0037302791606634855


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 37/50, Step: 517/517, Loss: 0.04770594462752342


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 38/50, Step: 517/517, Loss: 0.028204644098877907


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 39/50, Step: 517/517, Loss: 0.002599153434857726


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 40/50, Step: 517/517, Loss: 0.04851898178458214


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 41/50, Step: 517/517, Loss: 0.025294460356235504


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 42/50, Step: 517/517, Loss: 0.05579092353582382


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 43/50, Step: 517/517, Loss: 0.02293873392045498


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 44/50, Step: 517/517, Loss: 0.26029741764068604


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 45/50, Step: 517/517, Loss: 0.016068579629063606


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 46/50, Step: 517/517, Loss: 0.005325322970747948


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 47/50, Step: 517/517, Loss: 0.002260106150060892


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 48/50, Step: 517/517, Loss: 0.28124749660491943


517it [03:47,  2.27it/s]
0it [00:00, ?it/s]

Epoch: 49/50, Step: 517/517, Loss: 0.16303469240665436


517it [03:47,  2.27it/s]

Epoch: 50/50, Step: 517/517, Loss: 0.002934138523414731





In [12]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        for imgs, labels in (test_loader):
            imgs = imgs.to(device)
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")


In [13]:
eval(convModel, test_set)

Accuracy: 4.837149306675266%


In [14]:
eval(convModel, train_set)

Accuracy: 99.04884733193616%
