In [None]:
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
from tqdm.notebook import trange
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import torch

from src.utils.model_trainer import ModelTrainer

In [None]:
# Compute device: the GPU with index `gpu_indx` when CUDA is usable, otherwise the CPU.
gpu_indx = 0
if torch.cuda.is_available():
    device = torch.device(gpu_indx)
else:
    device = torch.device('cpu')

# Data-loading worker count: none on Windows (os.name == 'nt'), four elsewhere.
num_workers = 4 if os.name != 'nt' else 0

In [None]:
# Run configuration shared by the cells below.
batch_size = 64  # handed to ModelTrainer as its batch size
dataset_path = '../temp/utkcropped'  # directory scanned for '.jpg' files by UTKFaceDataset
image_size = 200  # argument given to transforms.Resize
learning_rate = 1e-4  # handed to ModelTrainer
num_epochs = 10  # exclusive upper bound of the epoch range in the training loop
start_epoch = 0  # first epoch index of the training loop

In [None]:
class UTKFaceDataset(Dataset):
    """Face images from a directory, labelled with the age encoded in the file name.

    The age is the integer that precedes the first underscore of the file name.
    Only ``.jpg`` files whose age is at most ``max_age`` are kept.

    Args:
        directory: Folder containing the image files.
        max_age: Largest age (inclusive) to keep.
        transform: Callable applied to the RGB ``PIL.Image`` of each sample.
            When ``None``, ``transforms.ToTensor()`` is used.

    Attributes:
        images: Sorted file names (relative to ``directory``) of the kept samples.
        ages: Sorted list of the distinct ages found among ``images``.
    """

    def __init__(self, directory, max_age, transform=None):
        self.directory = directory
        self.max_age = max_age
        # The transform must be a callable *instance*; build the default here
        # rather than storing the ToTensor class itself.
        if transform is None:
            transform = transforms.ToTensor()
        self.transform = transform

        self.images = self.get_images(directory, max_age)
        self.ages = self.get_ages(self.images)

    def __len__(self):
        """Return the number of kept images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, age)`` for the sample at position ``idx``.

        ``image`` is the transformed RGB image. ``age`` is an int64 scalar
        tensor holding the raw age parsed from the file name; it is NOT an
        index into ``self.ages``.
        """
        file_name = self.images[idx]
        # The context manager releases the file handle; convert() loads the
        # pixel data and returns an independent image, so it outlives the block.
        with Image.open(os.path.join(self.directory, file_name)) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)

        age = torch.tensor(self.get_age_from_file_name(file_name), dtype=torch.int64)

        return image, age

    @staticmethod
    def check_is_image(file_name):
        """Return True when ``file_name`` ends with ``.jpg``."""
        return file_name.endswith('.jpg')

    def get_age_from_file_name(self, file_name):
        """Parse the age: the integer before the first ``_`` in ``file_name``.

        Raises:
            ValueError: If that prefix is not an integer.
        """
        return int(file_name.split('_')[0])

    def get_images(self, directory, max_age):
        """List the ``.jpg`` files in ``directory`` whose age is ``<= max_age``.

        The listing is sorted because ``os.listdir`` returns entries in
        arbitrary order; a stable order keeps index-based, seeded splits of
        this dataset reproducible.
        """
        return [
            file_name
            for file_name in sorted(os.listdir(directory))
            if self.check_is_image(file_name)
            and self.get_age_from_file_name(file_name) <= max_age
        ]

    def get_ages(self, images):
        """Return the sorted distinct ages parsed from the file names in ``images``."""
        return sorted({self.get_age_from_file_name(file_name) for file_name in images})


In [None]:
# def get_mean_std(loader):
#     num_pixels = 0
#     mean = 0.0
#     std = 0.0
#     for images, _ in loader:
#         batch_size, num_channels, height, width = images.shape
#         num_pixels += batch_size * height * width
#         mean += images.mean(axis=(0, 2, 3)).sum()
#         std += images.std(axis=(0, 2, 3)).sum()

#     mean /= num_pixels
#     std /= num_pixels

#     return mean, std

# base_dataset = UTKFaceDataset(dataset_path, 90, transform=transforms.Compose([
#     transforms.Resize(image_size),
#     transforms.ToTensor(),
# ]))

# loader = torch.utils.data.DataLoader(base_dataset, batch_size=batch_size, shuffle=True)
# mean, std = get_mean_std(loader)

# print(mean, std)

In [None]:
# def calculate_mean_std(loader):
#     channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    
#     for data, _ in loader:
#         # Rearrange batch to be the shape of [B, C, W * H]
#         data = data.view(data.size(0), data.size(1), -1)
#         # Update total sum and squared sum for each channel
#         channels_sum += torch.mean(data, dim=[0, 2])
#         channels_squared_sum += torch.mean(data**2, dim=[0, 2])
#         num_batches += 1
    
#     mean = channels_sum / num_batches
#     std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

#     return mean, std

In [None]:
# base_dataset = UTKFaceDataset(dataset_path, 90, transform=transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Resize(image_size),
# ]))
# base_dataloader = DataLoader(base_dataset, batch_size=64, shuffle=False)

# n_samples = 0
# mean = torch.zeros(3)
# std = torch.zeros(3)

# for data, _ in base_dataloader:
#     # Rearrange batch to be the shape of [B, C, W * H]
#     data = data.view(data.size(0), data.size(1), -1)
#     # Update total number of images
#     n_samples += data.size(0)
#     # Compute mean and std here
#     mean += data.mean(2).sum(0)
#     std += data.std(2).sum(0) 

# # Final calculation
# mean /= n_samples
# std /= n_samples

# print(f'mean: {mean}')
# print(f'std: {std}')

In [None]:
def display_info(dataset, name):
    """Plot a histogram of the age labels found in ``dataset``.

    Every sample is fetched through ``dataset[i]``, so the cost is that of
    reading the whole dataset once.

    Args:
        dataset: Sized, indexable collection whose items are ``(image, age)`` pairs.
        name: Text used as the prefix of the plot title.
    """
    # int() turns each label (a 0-d tensor when it comes from the dataset)
    # into a plain Python number, so matplotlib bins ordinary ints.
    ages = [int(dataset[i][1]) for i in range(len(dataset))]

    plt.title(f'{name} Ages Distribution')
    plt.xlabel('Person Age')
    plt.ylabel('Number of Images')
    plt.hist(ages)
    plt.show()

In [None]:
# Preprocessing pipeline given to the dataset: resize, random rotation, random
# horizontal flip, conversion to a tensor, then per-channel normalisation with
# mean 0.5 and std 0.5.
# NOTE(review): `transfrom` is a misspelling of "transform"; the name is kept
# because the cell that builds the dataset refers to it.
# NOTE(review): the train, validation and test splits are all taken from the one
# dataset built with this pipeline, so the random augmentations also apply to
# validation and test samples - confirm that this is intended.
transfrom = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
# Full dataset: images with an age of at most 90, preprocessed by `transfrom`.
dataset = UTKFaceDataset(dataset_path, max_age=90, transform=transfrom)

In [None]:
# Fraction of the dataset assigned to each split.
train_partition = 0.8
valid_partition = 0.1
test_partition = 0.1

# Train and validation sizes are truncated to whole samples; the test split
# takes whatever remains, so the three sizes always add up to len(dataset).
total_examples_num = len(dataset)
train_examples_num = int(total_examples_num * train_partition)
valid_examples_num = int(total_examples_num * valid_partition)
test_examples_num = total_examples_num - train_examples_num - valid_examples_num

# A generator seeded with a fixed value drives the random assignment.
split_generator = torch.Generator().manual_seed(42)
split_sizes = [train_examples_num, valid_examples_num, test_examples_num]
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, split_sizes, generator=split_generator
)

In [None]:
# display_info(train_dataset, 'Train')
# display_info(valid_dataset, 'Validation')
# display_info(test_dataset, 'Test')

In [None]:
# Report the size of each split, one line per split.
print(
    f'Number of training examples: {len(train_dataset)}',
    f'Number of validation examples: {len(valid_dataset)}',
    f'Number of testing examples: {len(test_dataset)}',
    sep='\n',
)

In [None]:
rows, columns = 2, 2


def GetRandom():
    """Return a random index into ``train_dataset``; every index is eligible."""
    # np.random.randint excludes its upper bound, so pass len(...) itself;
    # subtracting one would make the last sample unreachable.
    return np.random.randint(0, len(train_dataset))


randomIndex = [GetRandom() for _ in range(rows * columns)]

# Preview a rows x columns grid of randomly chosen training samples.
for i, sample_index in enumerate(randomIndex):
    x, y = train_dataset[sample_index]
    # The dataset label is already the age parsed from the file name, so it is
    # shown directly instead of being used as an index into dataset.ages.
    age = int(y.item())

    plt.subplot(rows, columns, i + 1)
    plt.title(f'Age: {age}')
    # Undo Normalize(mean=0.5, std=0.5) to bring pixel values back to [0, 1];
    # the clip only guards against floating-point overshoot.
    plt.imshow((x * 0.5 + 0.5).numpy().transpose((1, 2, 0)).clip(0, 1))
    plt.axis('off')
plt.show()

In [None]:
# ResNet-18 initialised with the IMAGENET1K_V1 weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
print(model)

# Width of the features entering the final fully connected layer.
num_ftrs = model.fc.in_features
print(num_ftrs)

# Replace the classification head. The dataset labels are raw ages and the
# trainer uses CrossEntropyLoss, which needs every target to be a valid class
# index, so the head needs at least max(age) + 1 outputs. The original
# len(ages) + 2 size is kept whenever it is already large enough, so the head
# never shrinks relative to earlier runs.
num_classes = max(len(dataset.ages) + 2, max(dataset.ages, default=-1) + 1)
model.fc = nn.Linear(num_ftrs, num_classes)

In [None]:
# Name and directory handed to the trainer as `model_name` and `save_dir`.
model_name = 'resnet18_UTKFace'
save_dir = '../temp'

# ModelTrainer comes from src.utils.model_trainer; its internals are not
# visible in this notebook, so the arguments are documented only by name.
model_trainer = ModelTrainer(
    batch_size=batch_size,
    device=device,
    learning_rate=learning_rate, 
    loss_fun=nn.CrossEntropyLoss(), 
    model_name=model_name, 
    model=model,
    # NOTE(review): hard-coded to 0 although a platform-dependent `num_workers`
    # is computed near the top of the notebook - confirm which one is intended.
    num_workers=0,
    save_dir=save_dir,
)

In [None]:
# Hand the three splits to the trainer.
model_trainer.set_data(train_data=train_dataset, valid_data=valid_dataset, test_data=test_dataset)

In [None]:
# Total number of scalar elements across all parameter tensors of the trainer's model.
params_num = sum(param.numel() for param in model_trainer.model.parameters())
print("This model has %d (approximately %d Million) Parameters!" % (params_num, params_num / 1e6))

In [None]:
start_time = time.time()
valid_acc = 0
train_acc = 0

pbar = trange(start_epoch, num_epochs, leave=False, desc="Epoch")
# Show the initial (zero) accuracies until the first epoch has been evaluated.
pbar.set_postfix_str('Accuracy: Train %.2f%%, Val %.2f%%' % (train_acc * 100, valid_acc * 100))
for epoch in pbar:
    model_trainer.train_model()

    train_acc = model_trainer.evaluate_model(train_test_val="train")
    valid_acc = model_trainer.evaluate_model(train_test_val="val")

    # Refresh right after evaluation so the bar reports the epoch that just
    # finished instead of lagging one epoch behind.
    pbar.set_postfix_str('Accuracy: Train %.2f%%, Val %.2f%%' % (train_acc * 100, valid_acc * 100))

    # NOTE(review): presumably save_checkpoint also updates best_valid_acc;
    # verify against ModelTrainer.
    if valid_acc > model_trainer.best_valid_acc:
        model_trainer.save_checkpoint(epoch, valid_acc)

end_time = time.time()
# The bar is created with leave=False and disappears, so print a lasting summary.
print('Training took %.1f s. Final accuracy: Train %.2f%%, Val %.2f%%'
      % (end_time - start_time, train_acc * 100, valid_acc * 100))

In [None]:
# Show the trainer's `best_valid_acc` attribute after training.
print(model_trainer.best_valid_acc)