In [1]:
import os
import sys
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
from autocrop import Cropper
from IPython.display import clear_output

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from custom_loader import AgeDBDataset
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
# hyper params
num_of_class = 102
hidden_unit = 256
learning_rate = 1e-04
batch_size = 64
input_size = 64
device = torch.device("cuda")

In [3]:
transformA = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.augmentations.transforms.ChannelDropout(p=1.0),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

In [4]:
### --- AgeDB Dataset Class --- ###           {}_{person}_{age}_{gender}.jpg


class AgeDBDataset(Dataset):

    ## data loading
    def __init__(self, directory, device, transform=None, **kwargs):
        self.directory = directory
        self.transform = transform
        self.device = device
        self.labels = []
        self.images = []

        gender_to_class_id = {'m': 0, 'f': 1}

        for i, file in enumerate(sorted(os.listdir(self.directory))):
            file_labels = parse('{}_{}_{age}_{gender}.jpg', file)

            if file_labels is None:
                continue

            image = Image.open(os.path.join(self.directory,
                                            file)).convert('RGB')

            ########
            cropper = Cropper()

            try:
                #Get a Numpy array of the cropped image
                cropped_array = cropper.crop(image)
                #Save the cropped image with PIL
                image = Image.fromarray(cropped_array)

            except:
                pass

            image = np.array(image)

            augmented_images = self.transform(image=image)['image']
            self.images.append(augmented_images)
                ########
            gender = gender_to_class_id[file_labels['gender']]
            age = int(file_labels['age'])
            self.labels.append({
                'age': age,
                'gender': gender
            })

## len(dataset)

    def __len__(self):
        return len(self.labels)

## dataset[0]

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        image = self.images[index]

        labels = {
            'age': self.labels[index]['age'],
            'gender': self.labels[index]['gender']
        }

        return image.to(self.device), labels


## DataLoaders - train, validate, test

    def get_loaders(self, batch_size, train_size, test_size, random_seed,
                    **kwargs):
        train_len = int(len(self) * train_size)
        test_len = int(len(self) * test_size)
        validate_len = len(self) - (train_len + test_len)

        self.trainDataset, self.validateDataset, self.testDataset = torch.utils.data.random_split(
            dataset=self,
            lengths=[train_len, validate_len, test_len],
            generator=torch.Generator().manual_seed(random_seed))

        train_loader = DataLoader(self.trainDataset, batch_size=batch_size)
        validate_loader = DataLoader(self.validateDataset,
                                     batch_size=batch_size)
        test_loader = DataLoader(self.testDataset, batch_size=batch_size)

        return train_loader, validate_loader, test_loader

In [5]:
dataset = AgeDBDataset(
    directory='AgeDB/',
    transform=transformA,
    device=device,
)

In [6]:
train_set, validation_set, test_set = dataset.get_loaders(
    batch_size=batch_size,
    train_size=0.8,
    test_size=0.2,
    random_seed=42,
)

In [7]:
len(dataset)

16488

In [8]:
class AgeDBConvModel(nn.Module):
    def __init__(self, num_of_classes):
        super(AgeDBConvModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=2), #(64+2(2)-3)+1=66
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((66-2)/2)+1 = 33
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=2), #(33+2(2)-3)+1=35
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((35-2)/2)+1 = 17 + 0.5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2), #(17+2(2)-3)+1=19
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((19-2)/2)+1 = 9 + 0.5
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2), #(9+2(2)-3)+1=11
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((11-2)/2)+1 = 5 + 0.5
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=2), #(5+2(2)-3)+1=7
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2) #((7-2)/2)+1 = 3 + 0.5
        )
        self.fc1 = nn.Linear(3*3*512, num_of_classes)
        
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc1(out)
        
        return out

In [9]:
from kornia.losses import FocalLoss as focal_loss
def label_smoothing_criterion(alpha=0.1, distribution='uniform', std=0.5, reduction='mean'):
    def _label_smoothing_criterion(logits, labels):
        n_classes = logits.size(1)
        device = logits.device
        # manipulate labels
        one_hot = one_hot_encoding(labels, n_classes).float().to(device)
        if distribution == 'uniform':
            uniform = torch.ones_like(one_hot).to(device)/n_classes
            soft_labels = (1 - alpha)*one_hot + alpha*uniform
        elif distribution == 'gaussian':
            dist = get_gaussian_label_distribution(n_classes, std=std)
            soft_labels = torch.from_numpy(dist[labels.cpu().numpy()]).to(device)
        else:
            raise NotImplementedError

        loss = cross_entropy_loss_one_hot(logits, soft_labels.float(), reduction)

        return loss

    return _label_smoothing_criterion

def cost_sensitive_loss(input, target, M):
    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))
    device = input.device
    M = M.to(device)
    return (M[target, :]*input.float()).sum(axis=-1)
    # return torch.diag(torch.matmul(input, M[:, target]))

class CostSensitiveRegularizedLoss(nn.Module):
    def __init__(self,  n_classes=5, exp=2, normalization='softmax', reduction='mean', base_loss='ce', lambd=10):
        super(CostSensitiveRegularizedLoss, self).__init__()
        if normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        elif normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = None
        self.reduction = reduction
        x = np.abs(np.arange(n_classes, dtype=np.float32))
        M = np.abs((x[:, np.newaxis] - x[np.newaxis, :])) ** exp
        #
        # M_oph = np.array([
        #                 [1469, 4, 5,  0,  0],
        #                 [58, 62,  5,  0,  0],
        #                 [22, 3, 118,  1,  0],
        #                 [0, 0,   13, 36,  1],
        #                 [0, 0,    0,  1, 15]
        #                 ], dtype=np.float)
        # M_oph = M_oph.T
        # # Normalize M_oph to obtain M_difficulty:
        # M_difficulty = 1-np.divide(M_oph, np.sum(M_oph, axis=1)[:, None])
        # # OPTION 1: average M and M_difficulty:
        # M = 0.5 * M + 0.5 * M_difficulty
        # ################
        # # OPTION 2: replace uninformative entries in M_difficulty by entries of M:
        # # M_difficulty[M_oph == 0] = M[M_oph == 0]
        # # M = M_difficulty

        M /= M.max()
        self.M = torch.from_numpy(M)
        self.lambd = lambd
        self.base_loss = base_loss

        if self.base_loss == 'ce':
            self.base_loss = torch.nn.CrossEntropyLoss(reduction=reduction)
        elif self.base_loss == 'ls':
            self.base_loss = label_smoothing_criterion(distribution='uniform', reduction=reduction)
        elif self.base_loss == 'gls':
            self.base_loss = label_smoothing_criterion(distribution='gaussian', reduction=reduction)
        elif self.base_loss == 'focal_loss':
            kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": reduction}
            self.base_loss = focal_loss(**kwargs)
        else:
            sys.exit('not a supported base_loss')

    def forward(self, logits, target):
        base_l = self.base_loss(logits, target)
        if self.lambd == 0:
            return self.base_loss(logits, target)
        else:
            preds = self.normalization(logits)
            loss = cost_sensitive_loss(preds, target, self.M)
            if self.reduction == 'none':
                return base_l + self.lambd*loss
            elif self.reduction == 'mean':
                return base_l + self.lambd*loss.mean()
            elif self.reduction == 'sum':
                return base_l + self.lambd*loss.sum()
            else:
                raise ValueError('`reduction` must be one of \'none\', \'mean\', or \'sum\'.')

def get_cost_sensitive_criterion(n_classes=5, exp=2):
    train_criterion = CostSensitiveLoss(n_classes, exp=exp, normalization='softmax')
    val_criterion = CostSensitiveLoss(n_classes, exp=exp, normalization='softmax')
    return train_criterion, val_criterion

def get_cost_sensitive_regularized_criterion(base_loss='ce', n_classes=5, lambd=1, exp=2):
    train_criterion = CostSensitiveRegularizedLoss(n_classes, exp=exp, normalization='softmax', base_loss=base_loss, lambd=lambd)
    val_criterion = CostSensitiveRegularizedLoss(n_classes, exp=exp, normalization='softmax', base_loss=base_loss, lambd=lambd)

    return train_criterion, val_criterion

In [10]:
convModel = AgeDBConvModel(num_of_class).to(device)

  and should_run_async(code)


In [11]:
# Training loop
def train(model, optimizer, criterion, train_loader, num_of_epoch):
    total_step = len(train_loader)
    for epoch in range(num_of_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if (i+1)%total_step == 0:
                print(f"Epoch: {epoch+1}/{num_of_epoch}, Step: {i+1}/{total_step}, Loss: {loss.item()}")

  and should_run_async(code)


In [12]:
criteria = CostSensitiveRegularizedLoss(n_classes=num_of_class, base_loss='focal_loss', reduction='sum') 
optimizer = torch.optim.Adam(convModel.parameters(), lr=learning_rate)

  and should_run_async(code)


In [13]:
train(convModel, optimizer, criteria, train_set, num_of_epoch=20)

  and should_run_async(code)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 1/20, Step: 207/207, Loss: 12.115411758422852
Epoch: 2/20, Step: 207/207, Loss: 11.964411735534668
Epoch: 3/20, Step: 207/207, Loss: 11.605792999267578
Epoch: 4/20, Step: 207/207, Loss: 11.180627822875977
Epoch: 5/20, Step: 207/207, Loss: 9.91624927520752
Epoch: 6/20, Step: 207/207, Loss: 8.220314025878906
Epoch: 7/20, Step: 207/207, Loss: 6.26442813873291
Epoch: 8/20, Step: 207/207, Loss: 4.426490783691406
Epoch: 9/20, Step: 207/207, Loss: 2.760051727294922
Epoch: 10/20, Step: 207/207, Loss: 2.126011371612549
Epoch: 11/20, Step: 207/207, Loss: 1.7183563709259033
Epoch: 12/20, Step: 207/207, Loss: 1.0723979473114014
Epoch: 13/20, Step: 207/207, Loss: 1.1674399375915527
Epoch: 14/20, Step: 207/207, Loss: 1.6420009136199951
Epoch: 15/20, Step: 207/207, Loss: 0.7914124727249146
Epoch: 16/20, Step: 207/207, Loss: 0.237967848777771
Epoch: 17/20, Step: 207/207, Loss: 0.2801882028579712
Epoch: 18/20, Step: 207/207, Loss: 0.29306694865226746
Epoch: 19/20, Step: 207/207, Loss: 0.17547188

In [14]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

  and should_run_async(code)


In [15]:
eval(convModel, test_set)

  and should_run_async(code)


Accuracy: 2.8814073400060662%
Mean Absolute Error: 12.0594482421875
Minimum: 0.0, Maximum: 63.0, Median: 10.0


In [16]:
eval(convModel, train_set)

  and should_run_async(code)


Accuracy: 98.7338893100834%
Mean Absolute Error: 0.14874905347824097
Minimum: 0.0, Maximum: 48.0, Median: 0.0
