In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from facenet_pytorch import InceptionResnetV1
from IPython.display import clear_output
from PIL import Image
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import models, transforms
from tqdm.notebook import trange, tqdm

In [None]:
gpu_indx = 0
if torch.cuda.is_available():
    # A bare integer ordinal selects that accelerator (CUDA) device index.
    device = torch.device(gpu_indx)
else:
    device = torch.device('cpu')

# Worker-process count: 0 on Windows (os.name == 'nt'), otherwise 4.
# NOTE(review): not passed to any DataLoader in the visible notebook.
workers = 0 if os.name == 'nt' else 4

In [None]:
# Training hyper-parameters and data location.
batch_size = 64                      # items per DataLoader batch
epochs_num = 10                      # passes over the training split
learning_rate = 1e-4                 # Adam step size
dataset_path = '../temp/utkcropped'  # image directory, relative to the kernel's working directory

In [None]:
def check_is_image(file_name):
    """Return True when `file_name` ends with the lowercase '.jpg' extension (case-sensitive)."""
    suffix = '.jpg'
    return file_name[-len(suffix):] == suffix

def get_age_from_file_name(file_name):
    """Return the integer that precedes the first underscore in `file_name` (the age label).

    Raises ValueError when that leading field is not an integer.
    """
    age_text, _, _ = file_name.partition('_')
    return int(age_text)

def get_images(directory, max_age=100):
    """Return the sorted names of '.jpg' files in `directory` whose age label is <= `max_age`.

    The names are sorted because os.listdir returns entries in arbitrary,
    filesystem-dependent order. Sorting makes the dataset order, and therefore any
    seeded random_split over it, identical on every machine.

    Args:
        directory: folder to scan (not recursive).
        max_age: inclusive upper bound on the age parsed from the file name.
    """
    def is_valid(image):
        # Check the extension first so non-image files never reach the age parser.
        return check_is_image(image) and get_age_from_file_name(image) <= max_age

    return sorted(image for image in os.listdir(directory) if is_valid(image))

def get_ages(images):
    """Return the distinct age labels parsed from `images`, in ascending order."""
    return sorted({get_age_from_file_name(file_name) for file_name in images})

def reduce_dataset(images, max_age=100, max_images_per_age=100):
    """Subsample `images`, keeping at most `max_images_per_age` files for each age <= `max_age`.

    The input order is preserved, so the first `max_images_per_age` names seen for an
    age are the ones kept. Files older than `max_age` are dropped.
    """
    count_by_age = {}
    kept = []

    for image in images:
        age = get_age_from_file_name(image)
        within_limits = age <= max_age and count_by_age.get(age, 0) < max_images_per_age
        if not within_limits:
            continue
        count_by_age[age] = count_by_age.get(age, 0) + 1
        kept.append(image)

    return kept

In [203]:
# Keep faces aged <= max_age, then report the image count and the number of distinct ages.
max_age = 90
images = get_images(dataset_path, max_age)
age_groups = list(map(str, get_ages(images)))

print(len(images))
print(len(age_groups))

23622
90


In [None]:
class UTKFaceDataset(Dataset):
    """Face images from a directory, paired with the age parsed from each file name.

    Each item is `(image, age)`. `age` is an int. `image` is an RGB PIL image, or
    whatever `transform` returns for it.
    """

    def __init__(self, directory, max_age, transform=None):
        """
        Args:
            directory: folder containing the .jpg images.
            max_age: images whose age label exceeds this value are excluded.
            transform: optional callable applied to each RGB PIL image.
        """
        self.directory = directory
        self.transform = transform
        self.images = get_images(directory, max_age)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        file_name = self.images[idx]
        # convert() decodes the pixels into a new image, so the file can be closed right
        # away instead of leaking a handle when no transform forces a load.
        # Forcing RGB gives every item 3 channels regardless of the file's colour mode.
        with Image.open(os.path.join(self.directory, file_name)) as source:
            image = source.convert('RGB')
        age = get_age_from_file_name(file_name)

        if self.transform:
            image = self.transform(image)

        return image, age

In [None]:
def display_info(dataset, name, bins=None):
    """Plot a histogram of the age labels contained in `dataset`.

    Args:
        dataset: a UTKFaceDataset, a (possibly nested) torch Subset of one (such as the
            splits returned by random_split), or any dataset whose items are (image, age).
        name: text used in the plot title.
        bins: forwarded to `Axes.hist`. None keeps matplotlib's default binning; pass
            e.g. range(max_age + 2) to get one bar per age.
    """
    ages = [_age_at(dataset, i) for i in range(len(dataset))]

    fig, ax = plt.subplots()
    ax.hist(ages, bins=bins)
    ax.set_title(f'{name} Ages Distribution')
    ax.set_xlabel('Person Age')
    ax.set_ylabel('Number of Images')
    plt.show()


def _age_at(dataset, idx):
    """Return the age label of item `idx`, reading the file name instead of decoding the image when possible."""
    if isinstance(dataset, UTKFaceDataset):
        # Same value UTKFaceDataset.__getitem__ returns as the label, without opening the file.
        return get_age_from_file_name(dataset.images[idx])
    if isinstance(dataset, Subset):
        return _age_at(dataset.dataset, dataset.indices[idx])
    # Generic fallback: load the item (runs the full image transform).
    return dataset[idx][1]


In [None]:

# Preprocessing for every item of `dataset`: PIL image -> float tensor, resize to 160x160,
# per-channel Normalize(mean=0.5, std=0.5) (roughly [-1, 1]), then a random rotation
# drawn from [-10, 10] degrees.
# NOTE(review): RandomRotation is random augmentation and is applied to every item read
# through this dataset, including items reached via random_split subsets. Confirm that
# the evaluation splits use a deterministic transform.
transfrom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(160, 160)),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    transforms.RandomRotation(degrees=10),
])

dataset = UTKFaceDataset(
    directory=dataset_path,
    max_age=max_age,
    transform=transfrom,
)

In [None]:
# display_info(dataset, 'UTKFace')

In [None]:
# Split fractions. The test split takes the rounding remainder, so the three sizes
# always sum to len(dataset) (test_partition is informational).
train_partition = 0.8
valid_partition = 0.1
test_partition = 0.1

train_examples_num = int(len(dataset) * train_partition)
valid_examples_num = int(len(dataset) * valid_partition)
test_examples_num = len(dataset) - train_examples_num - valid_examples_num

# Fixed generator seed so the split indices are reproducible across kernel restarts.
train_dataset, valid_dataset, test_dataset = random_split(
    dataset,
    [train_examples_num, valid_examples_num, test_examples_num],
    generator=torch.Generator().manual_seed(42)
)

# random_split only wraps `dataset`, so validation/test items would also get its random
# RandomRotation augmentation. Re-point those splits at a twin dataset that uses the same
# pipeline minus the random augmentation.
# NOTE(review): assumes dataset.transform is a transforms.Compose.
eval_transform = transforms.Compose(
    [step for step in dataset.transform.transforms
     if not isinstance(step, transforms.RandomRotation)]
)
eval_dataset = UTKFaceDataset(dataset_path, max_age=max_age, transform=eval_transform)
# The split indices are only meaningful if both datasets list the same files in the same order.
assert eval_dataset.images == dataset.images
valid_dataset = Subset(eval_dataset, valid_dataset.indices)
test_dataset = Subset(eval_dataset, test_dataset.indices)

In [None]:
# Report how many items landed in each split.
for split_label, split in (
    ('training', train_dataset),
    ('validation', valid_dataset),
    ('testing', test_dataset),
):
    print(f'Number of {split_label} examples: {len(split)}')

In [None]:
# Only the training split is shuffled. Evaluation order does not change accuracy, and
# a shuffling sampler seeds itself from torch's global RNG on every pass, which would
# perturb training randomness each time validation runs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Sanity check: show a grid of random training samples with their age labels.
rows, columns = 2, 2

def GetRandom():
    """Return a uniformly random index into `train_dataset` (np.random.randint's upper bound is exclusive)."""
    return np.random.randint(0, len(train_dataset))

randomIndex = [GetRandom() for _ in range(rows * columns)]

fig, axes = plt.subplots(rows, columns, squeeze=False)
for ax, idx in zip(axes.flat, randomIndex):
    x, y = train_dataset[idx]
    # The label is the age itself (parsed from the file name), not an index into
    # `age_groups`. That list holds the sorted distinct ages, so age_groups[y] shows a
    # different age and can run past the end of the list.
    # Undo Normalize(mean=0.5, std=0.5) and move channels last so matplotlib shows RGB.
    image = (x * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
    ax.set_title(f'Age: {int(y)}')
    ax.imshow(image)
    ax.axis('off')
plt.show()

In [None]:
# Pretrained ResNet-18 fine-tuned as an age classifier: each integer age 0..max_age is
# its own class, matching the raw age labels UTKFaceDataset returns.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# One output logit per possible label. Sizing the head from max_age (not
# len(age_groups) + 2) keeps every label in range even if some ages are missing.
# No Sigmoid: nn.CrossEntropyLoss applies log-softmax itself and expects raw logits.
# Squashing them into (0, 1) first caps the achievable softmax confidence and weakens
# gradients. Argmax accuracy is unaffected because sigmoid is monotonic.
# NOTE(review): assumes age labels are non-negative.
model.fc = nn.Linear(model.fc.in_features, max_age + 1)
model = model.to(device)

model




In [None]:
# Classification loss over the model's output logits (targets are integer age labels);
# Adam updates every model parameter.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def train(model, loader, optimizer, loss_fn, loss_logger):
    """Run one training epoch over `loader`.

    Puts the model in train mode, then for every (inputs, targets) batch does a forward
    pass on `device`, computes `loss_fn`, backpropagates and steps `optimizer`.

    Returns:
        The same (model, optimizer, loss_logger) objects that were passed in.
        `loss_logger` is extended in place with one float loss per batch.
    """
    model.train()

    progress = tqdm(loader, leave=False, desc="Training")
    for inputs, targets in progress:
        logits = model(inputs.to(device))
        batch_loss = loss_fn(logits, targets.to(device))

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_logger.append(batch_loss.item())

    return model, optimizer, loss_logger

In [None]:
def evaluate(model, loader):
    """Return the fraction of items in `loader.dataset` whose argmax prediction equals the label.

    The argmax is taken over dimension 1 (the class dimension for (batch, classes)
    outputs). Gradients are disabled. The model is switched to eval mode and left there.
    """
    model.eval()
    correct = 0

    with torch.no_grad():
        for inputs, targets in tqdm(loader, leave=False, desc="Evaluating"):
            predictions = model(inputs.to(device)).argmax(1)
            correct += (predictions == targets.to(device)).sum().item()

    return correct / len(loader.dataset)

In [None]:
# Per-batch training losses, and per-epoch train/validation accuracies (three separate lists).
train_loss_logger, train_acc_logger, valid_acc_logger = [], [], []

In [None]:
# Train for `epochs_num` epochs, recording training loss and train/validation accuracy.
for epoch in trange(epochs_num, desc="Epochs"):
    model, optimizer, train_loss_logger = train(model, train_loader, optimizer, loss_fn, train_loss_logger)

    train_accuracy = evaluate(model, train_loader)
    train_acc_logger.append(train_accuracy)

    valid_accuracy = evaluate(model, valid_loader)
    valid_acc_logger.append(valid_accuracy)

    # clear_output(wait=True) clears the cell as soon as new output arrives. Without an
    # explicit summary afterwards the cell never shows how training is going, so print
    # the latest epoch's metrics right after clearing.
    clear_output(wait=True)
    # train() appends one loss per batch, i.e. len(train_loader) values per epoch.
    epoch_losses = train_loss_logger[-len(train_loader):]
    mean_loss = sum(epoch_losses) / max(len(epoch_losses), 1)
    print(f'Epoch {epoch + 1}/{epochs_num} | '
          f'train loss {mean_loss:.4f} | '
          f'train acc {train_accuracy:.4f} | '
          f'valid acc {valid_accuracy:.4f}')