In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models, datasets
from sklearn.model_selection import train_test_split
import random
import os
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomAgeGenderDataset(Dataset):
    def __init__(self, root_dir, image_files, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = image_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        image_path = os.path.join(self.root_dir, image_file)
        image = Image.open(image_path).convert('RGB')
        age, gender = map(int, image_file.split('_')[:2])  # Extract age and gender from the file name
        # Normalize gender to 0 for male and 1 for female
        gender = 1 if gender == 1 else 0
        
        if self.transform:
            image = self.transform(image)
            
        return image, age, gender

# Create a CNN model
class AgeGenderPredictionModel(nn.Module):
    def __init__(self):
        super(AgeGenderPredictionModel, self).__init__()
        self.features = models.resnet18(pretrained=True)
        self.fc_age = nn.Linear(1000, 1)  # Output layer for age prediction
        self.fc_gender = nn.Linear(1000, 1)  # Output layer for gender prediction

    def forward(self, x):
        x = self.features(x)
        age = self.fc_age(x)
        gender = torch.sigmoid(self.fc_gender(x))  # Sigmoid activation for binary gender prediction
        return age, gender

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [24]:
# load all of the image paths
image_files = [fname for fname in os.listdir('C:\\Users\\alans\\Datasets\\utkcropped') if fname.endswith('.jpg')]

# number of total images and the number of images to use
total_images = len(image_files)  
num_images = 2000  ################################## up to 23000, adjust as you want

# select a random subset
selected_files = random.sample(image_files, num_images)

# split the selected files into train and test sets
train_files, test_files = train_test_split(selected_files, test_size=0.2, random_state=42)

# use custom datasets for train and test sets
train_dataset = CustomAgeGenderDataset(root_dir='C:\\Users\\alans\\Datasets\\utkcropped', image_files=train_files, transform=transform)
test_dataset = CustomAgeGenderDataset(root_dir='C:\\Users\\alans\\Datasets\\utkcropped', image_files=test_files, transform=transform)

# define the dataloaders for train and test sets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# initialize model, loss function, and optimizer
model = AgeGenderPredictionModel().to(device)
criterion_age = nn.MSELoss()
criterion_gender = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 10  ############################ Adjust this number as needed 
for epoch in range(num_epochs):
    running_loss_age = 0.0
    running_loss_gender = 0.0
    model.train() 
    for images, ages, genders in train_loader:
        images, ages, genders = images.to(device), ages.to(device), genders.to(device)  # Move to device
        optimizer.zero_grad()
        predicted_ages, predicted_genders = model(images)
        loss_age = criterion_age(predicted_ages, ages.float().unsqueeze(1))
        loss_gender = criterion_gender(predicted_genders, genders.float().unsqueeze(1))
        loss = loss_age + loss_gender  
        loss.backward()
        optimizer.step()
        running_loss_age += loss_age.item()
        running_loss_gender += loss_gender.item()
    print(f"Epoch {epoch+1}, Train Loss (Age): {running_loss_age/len(train_loader)}, Train Loss (Gender): {running_loss_gender/len(train_loader)}")

Epoch 1, Train Loss (Age): 300.4842501831055, Train Loss (Gender): 0.6898245453834534
Epoch 2, Train Loss (Age): 154.81126510620118, Train Loss (Gender): 0.6687715399265289
Epoch 3, Train Loss (Age): 112.87400665283204, Train Loss (Gender): 0.7087358820438385
Epoch 4, Train Loss (Age): 94.26985984802246, Train Loss (Gender): 0.6455407762527465
Epoch 5, Train Loss (Age): 76.94033401489258, Train Loss (Gender): 0.6908942431211471
Epoch 6, Train Loss (Age): 57.025408325195315, Train Loss (Gender): 0.6621396923065186
Epoch 7, Train Loss (Age): 54.20266786575317, Train Loss (Gender): 0.6692236053943634
Epoch 8, Train Loss (Age): 39.768920249938965, Train Loss (Gender): 0.6451170325279236
Epoch 9, Train Loss (Age): 30.93754976272583, Train Loss (Gender): 0.6785451328754425
Epoch 10, Train Loss (Age): 29.25818675994873, Train Loss (Gender): 0.6453408765792846


In [25]:
# Evaluate the model
model.eval()  
total_loss_age = 0.0
total_loss_gender = 0.0
total_absolute_error_age = 0.0
total_absolute_error_gender = 0.0
num_samples = 0

with torch.no_grad():  # disable gradient calculation during evaluation
    for images, ages, genders in test_loader:
        images, ages, genders = images.to(device), ages.to(device), genders.to(device)  # Move to device
        predicted_ages, predicted_genders = model(images)
        
        # absolute errors
        absolute_error_age = torch.abs(predicted_ages - ages.float().unsqueeze(1))
        absolute_error_gender = torch.abs(predicted_genders - genders.float().unsqueeze(1))
        
        # sum the absolute errors over the batch
        total_absolute_error_age += torch.sum(absolute_error_age).item()
        total_absolute_error_gender += torch.sum(absolute_error_gender).item()
        
        # compute loss
        loss_age = criterion_age(predicted_ages, ages.float().unsqueeze(1))
        loss_gender = criterion_gender(predicted_genders, genders.float().unsqueeze(1))
        
        # sum of losses over the batch
        total_loss_age += loss_age.item() * len(images)
        total_loss_gender += loss_gender.item() * len(images)
        
        num_samples += len(images)

# compute mean loss 
mean_loss_age = total_loss_age / num_samples
mean_loss_gender = total_loss_gender / num_samples

# Compute Mean Absolute Error
mean_absolute_error_age = total_absolute_error_age / num_samples
mean_absolute_error_gender = total_absolute_error_gender / num_samples

# Print mean loss and MAE for age and gender predictions
print(f"Mean Loss on Test Set (Age): {mean_loss_age}")
print(f"Mean Loss on Test Set (Gender): {mean_loss_gender}")
print(f"Mean Absolute Error on Test Set (Age): {mean_absolute_error_age}")
print(f"Mean Absolute Error on Test Set (Gender): {mean_absolute_error_gender}")

torch.save(model.state_dict(), 'age_gender_prediction_model.pth')

Mean Loss on Test Set (Age): 72.70202056884766
Mean Loss on Test Set (Gender): 0.573951563835144
Mean Absolute Error on Test Set (Age): 6.334968338012695
Mean Absolute Error on Test Set (Gender): 0.40671067237854003
