In [1]:
import os
import sys
import math
import copy
import numpy as np
from PIL import Image
from tqdm import tqdm
from parse import parse
import pretrainedmodels
from kornia.losses import FocalLoss

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

import torchvision
from torchvision import datasets
import torchvision.models as models
from torchvision.transforms import ToTensor, Compose, Scale, Grayscale, Resize, transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore")

In [2]:
from kornia.losses import FocalLoss as focal_loss
def label_smoothing_criterion(alpha=0.1, distribution='uniform', std=0.5, reduction='mean'):
    def _label_smoothing_criterion(logits, labels):
        n_classes = logits.size(1)
        device = logits.device
        # manipulate labels
        one_hot = one_hot_encoding(labels, n_classes).float().to(device)
        if distribution == 'uniform':
            uniform = torch.ones_like(one_hot).to(device)/n_classes
            soft_labels = (1 - alpha)*one_hot + alpha*uniform
        elif distribution == 'gaussian':
            dist = get_gaussian_label_distribution(n_classes, std=std)
            soft_labels = torch.from_numpy(dist[labels.cpu().numpy()]).to(device)
        else:
            raise NotImplementedError

        loss = cross_entropy_loss_one_hot(logits, soft_labels.float(), reduction)

        return loss

    return _label_smoothing_criterion

def cost_sensitive_loss(input, target, M):
    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))
    device = input.device
    M = M.to(device)
    return (M[target, :]*input.float()).sum(axis=-1)
    # return torch.diag(torch.matmul(input, M[:, target]))

class CostSensitiveRegularizedLoss(nn.Module):
    def __init__(self,  n_classes=5, exp=2, normalization='softmax', reduction='mean', base_loss='ce', lambd=10):
        super(CostSensitiveRegularizedLoss, self).__init__()
        if normalization == 'softmax':
            self.normalization = nn.Softmax(dim=1)
        elif normalization == 'sigmoid':
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = None
        self.reduction = reduction
        x = np.abs(np.arange(n_classes, dtype=np.float32))
        M = np.abs((x[:, np.newaxis] - x[np.newaxis, :])) ** exp
        #
        # M_oph = np.array([
        #                 [1469, 4, 5,  0,  0],
        #                 [58, 62,  5,  0,  0],
        #                 [22, 3, 118,  1,  0],
        #                 [0, 0,   13, 36,  1],
        #                 [0, 0,    0,  1, 15]
        #                 ], dtype=np.float)
        # M_oph = M_oph.T
        # # Normalize M_oph to obtain M_difficulty:
        # M_difficulty = 1-np.divide(M_oph, np.sum(M_oph, axis=1)[:, None])
        # # OPTION 1: average M and M_difficulty:
        # M = 0.5 * M + 0.5 * M_difficulty
        # ################
        # # OPTION 2: replace uninformative entries in M_difficulty by entries of M:
        # # M_difficulty[M_oph == 0] = M[M_oph == 0]
        # # M = M_difficulty

        M /= M.max()
        self.M = torch.from_numpy(M)
        self.lambd = lambd
        self.base_loss = base_loss

        if self.base_loss == 'ce':
            self.base_loss = torch.nn.CrossEntropyLoss(reduction=reduction)
        elif self.base_loss == 'ls':
            self.base_loss = label_smoothing_criterion(distribution='uniform', reduction=reduction)
        elif self.base_loss == 'gls':
            self.base_loss = label_smoothing_criterion(distribution='gaussian', reduction=reduction)
        elif self.base_loss == 'focal_loss':
            kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": reduction}
            self.base_loss = focal_loss(**kwargs)
        else:
            sys.exit('not a supported base_loss')

    def forward(self, logits, target):
        base_l = self.base_loss(logits, target)
        if self.lambd == 0:
            return self.base_loss(logits, target)
        else:
            preds = self.normalization(logits)
            loss = cost_sensitive_loss(preds, target, self.M)
            if self.reduction == 'none':
                return base_l + self.lambd*loss
            elif self.reduction == 'mean':
                return base_l + self.lambd*loss.mean()
            elif self.reduction == 'sum':
                return base_l + self.lambd*loss.sum()
            else:
                raise ValueError('`reduction` must be one of \'none\', \'mean\', or \'sum\'.')

def get_cost_sensitive_criterion(n_classes=5, exp=2):
    train_criterion = CostSensitiveLoss(n_classes, exp=exp, normalization='softmax')
    val_criterion = CostSensitiveLoss(n_classes, exp=exp, normalization='softmax')
    return train_criterion, val_criterion

def get_cost_sensitive_regularized_criterion(base_loss='ce', n_classes=5, lambd=1, exp=2):
    train_criterion = CostSensitiveRegularizedLoss(n_classes, exp=exp, normalization='softmax', base_loss=base_loss, lambd=lambd)
    val_criterion = CostSensitiveRegularizedLoss(n_classes, exp=exp, normalization='softmax', base_loss=base_loss, lambd=lambd)

    return train_criterion, val_criterion

In [3]:
# hyper params
num_of_class = 102
learning_rate = 1e-03
batch_size = 128
input_size = 224
device = torch.device("cuda")
train_size = 0.7
test_size = 0.2
directoryAgeDB = 'AgeDB/'

In [4]:
def imageList():
    image_list = []
    for i, file in enumerate(sorted(os.listdir(directoryAgeDB))):
        file_labels = parse('{}_{person}_{age}_{gender}.jpg', file)

        if file_labels is None:
            continue

        image_location = os.path.join(directoryAgeDB, file)
        gender_to_class_id = {'m': 0, 'f': 1}
        gender = gender_to_class_id[file_labels['gender']]
        age = int(file_labels['age'])
        image_list.append({
            'image_location': image_location,
            'age': age,
            'gender': gender
        })

    return image_list

image_list = imageList()

In [5]:
train_len = int(len(image_list) * train_size)
test_len = int(len(image_list) * test_size)
validate_len = len(image_list) - (train_len + test_len)

train_image_list, test_image_list, validate_image_list = torch.utils.data.random_split(
    dataset = image_list,
    lengths = [train_len, test_len, validate_len], 
    generator = torch.Generator().manual_seed(42)
)

print(len(train_image_list))
print(len(test_image_list))
print(len(validate_image_list))

11541
3297
1650


In [6]:
class AgeDBDataset(Dataset):
    
    def __init__(self, image_list, device, train = True, train_transform=None, test_transform=None, **kwargs):
        self.image_list = image_list
        self.device = device
        self.train = train
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.labels = []
        self.images = []

        if self.train:
            for i in tqdm(range(len(image_list))):

                image = Image.open(self.image_list[i]['image_location']).convert('RGB')
                image_location = self.image_list[i]['image_location']
                age = self.image_list[i]['age']
                gender = self.image_list[i]['gender']

                image = np.array(image)

                for j in range(1):
                    if j == 0:
                        augmented_images = self.test_transform(
                            image=image)['image']
                    else:
                        augmented_images = self.train_transform(
                            image=image)['image']

                    self.images.append(augmented_images)
                    self.labels.append({
                        'image_location': image_location,
                        'age': age,
                        'gender': gender
                    })

    def __len__(self):

        if self.train:
            return len(self.labels)
        else:
            return len(self.image_list)

    def __getitem__(self, index):

        if torch.is_tensor(index):
            index = index.tolist()

        if self.train:
            image = self.images[index]
            labels = {
                'image_location': self.labels[index]['image_location'],
                'age': self.labels[index]['age'],
                'gender': self.labels[index]['gender']
            }

        else:
            image = Image.open(
                self.image_list[index]['image_location']).convert('RGB')
            image = np.array(image)
            image = self.test_transform(image=image)['image']
            labels = {
                'image_location': self.image_list[index]['image_location'],
                'age': self.image_list[index]['age'],
                'gender': self.image_list[index]['gender']
            }

        return image.to(self.device), labels

In [7]:
train_transform = A.Compose([
    A.Resize(input_size, input_size),
    A.ToGray(p=1),
    A.Rotate(limit=10, p=0.3),
    A.HorizontalFlip(p=0.4),
    A.OpticalDistortion(p=0.2),
    A.OneOf([
        A.Blur(blur_limit=3, p=0.2),
        A.ColorJitter(p=0.2),
    ], p=0.2),
    ToTensorV2()
])

test_transform = A.Compose(
    [A.Resize(input_size, input_size),
     A.ToGray(p=1),
     ToTensorV2()
])

In [8]:
trainDataset = AgeDBDataset(image_list=train_image_list,
                            device=device,
                            train=True,
                            train_transform=train_transform,
                            test_transform=test_transform)
testDataset = AgeDBDataset(image_list=test_image_list,
                           device=device,
                           train=False,
                           train_transform=train_transform,
                           test_transform=test_transform)
valDataset = AgeDBDataset(image_list=validate_image_list,
                          device=device,
                          train=False,
                          train_transform=train_transform,
                          test_transform=test_transform)
print(len(trainDataset))
print(len(testDataset))
print(len(valDataset))
print(len(trainDataset) + len(testDataset) + len(valDataset))

100%|███████████████████████████████████████████████████████████████████████████| 11541/11541 [00:13<00:00, 845.93it/s]

11541
3297
1650
16488





In [9]:
#### --- Oversampling --- ###
def sam_weights(TDataset):
    listofzeros = [0] * 102
    same_age = [i for i in range(0, 102)]
    class_weights = dict(zip(same_age, listofzeros))

    for i in range(len(TDataset)):
        class_weights[TDataset[i][1]['age']] += 1

    for i in range(len(class_weights)):
        if class_weights[i] > 0:
            class_weights[i] = (1 / class_weights[i])

    sample_weights = {}

    for i in range(len(TDataset)):
        sample_weights[i] = class_weights[TDataset[i][1]['age']]

    return sample_weights

In [10]:
sample_weights = sam_weights(trainDataset)
sample_weights = list(sample_weights.values())
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

In [11]:
train_loader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(testDataset, batch_size=batch_size, shuffle=False)

validation_loader = DataLoader(valDataset, batch_size=batch_size, shuffle=False)

In [12]:
ResNet18 = pretrainedmodels.resnet18(pretrained='imagenet')
ResNet18.last_linear = nn.Linear(
    in_features=ResNet18.last_linear.in_features, 
    out_features=num_of_class, 
    bias=False
)
resnetModel = ResNet18.to(device)

In [13]:
# Training loop
def train(model, optimizer, criterion, train_loader, valid_loader, num_of_epoch):
    total_step = len(train_loader)
    min_valid_loss = np.inf

    for epoch in range(num_of_epoch):
        train_loss = 0.0
        valid_loss = 0.0
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        for i, (imgs, labels) in enumerate(valid_loader):
            with torch.no_grad():
                imgs = imgs.to(device).float()
                labels = torch.as_tensor(labels['age']).to(device)
            
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
        
        
        print(f"Epoch: {epoch+1}/{num_of_epoch}, Train Loss: {train_loss/len(train_loader)},  Validation Loss: {valid_loss/len(valid_loader)}")
              
        if min_valid_loss > (valid_loss/len(valid_loader)):
            print(f'Validation Loss Decreased({min_valid_loss} ---> {valid_loss/len(valid_loader)})')
            min_valid_loss = valid_loss/len(valid_loader)
            torch.save(resnetModel.state_dict(), 'model.pth')

In [14]:
#criteria = FocalLoss(alpha=0.5, gamma=3.0, reduction='mean')
criteria = CostSensitiveRegularizedLoss(n_classes=num_of_class, base_loss='focal_loss', reduction='mean') 
optimizer = torch.optim.SGD(resnetModel.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)

In [15]:
train(resnetModel, optimizer, criteria, train_loader, validation_loader, num_of_epoch=10)

Epoch: 1/10, Train Loss: 3.0962836690001434,  Validation Loss: 2.814979443183312
Validation Loss Decreased(inf ---> 2.814979443183312)
Epoch: 2/10, Train Loss: 2.7133998477851953,  Validation Loss: 2.656817472898043
Validation Loss Decreased(2.814979443183312 ---> 2.656817472898043)
Epoch: 3/10, Train Loss: 2.6024273987654802,  Validation Loss: 2.5898914337158203
Validation Loss Decreased(2.656817472898043 ---> 2.5898914337158203)
Epoch: 4/10, Train Loss: 2.5404602459498813,  Validation Loss: 2.542136229001559
Validation Loss Decreased(2.5898914337158203 ---> 2.542136229001559)
Epoch: 5/10, Train Loss: 2.4868528895325714,  Validation Loss: 2.4961320620316725
Validation Loss Decreased(2.542136229001559 ---> 2.4961320620316725)
Epoch: 6/10, Train Loss: 2.4322169036655636,  Validation Loss: 2.4515050741342397
Validation Loss Decreased(2.4961320620316725 ---> 2.4515050741342397)
Epoch: 7/10, Train Loss: 2.3812340589670034,  Validation Loss: 2.405780076980591
Validation Loss Decreased(2.451

In [16]:
# Evaluation
def eval(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        error = torch.zeros(0).to(device)
        for imgs, labels in test_loader:
            imgs = imgs.to(device).float()
            labels = torch.as_tensor(labels['age']).to(device)
            outputs = model(imgs)
            
            _, pred = torch.max(outputs.data, 1)

            error = torch.cat([error, torch.abs(
                torch.subtract(torch.reshape(labels, (-1,)), torch.reshape(pred, (-1,)))
            )])
                        
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            
    #print(f"Accuracy: {(100*correct)/total}%")
    print(f"Mean Absolute Error: {(torch.mean(error))}")
    print(f"Minimum: {torch.min(error)}, Maximum: {torch.max(error)}, Median: {torch.median(error)}")

In [17]:
eval(resnetModel, test_loader)
eval(resnetModel, train_loader)
resnetModel.load_state_dict(torch.load('model.pth'))

Mean Absolute Error: 10.510160446166992
Minimum: 0.0, Maximum: 53.0, Median: 9.0
Mean Absolute Error: 10.23472785949707
Minimum: 0.0, Maximum: 52.0, Median: 8.0


<All keys matched successfully>

In [18]:
eval(resnetModel, test_loader)

Mean Absolute Error: 10.510160446166992
Minimum: 0.0, Maximum: 53.0, Median: 9.0


Mean Absolute Error: 7.715802192687988
Minimum: 0.0, Maximum: 51.0, Median: 6.0

In [19]:
eval(resnetModel, train_loader)

Mean Absolute Error: 10.164977073669434
Minimum: 0.0, Maximum: 52.0, Median: 8.0


Mean Absolute Error: 6.060046672821045
Minimum: 0.0, Maximum: 59.0, Median: 5.0