In [24]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from PIL import Image
# torch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import os

In [25]:
# Run this once
from google.colab import drive
drive.mount('/content/gdrive')
# You will have to modify this based on your Google Drive directory structure
# New directory for FairFace
fairface_folder_path = "/content/gdrive/MyDrive/Colab Notebooks/FairFaceData/"

if not os.path.exists(fairface_folder_path):
    os.makedirs(fairface_folder_path)
    print(f"Created FairFace data directory: {fairface_folder_path}")
else:
    print(f"FairFace data directory already exists: {fairface_folder_path}")
%cd "$fairface_folder_path"

!pwd

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
FairFace data directory already exists: /content/gdrive/MyDrive/Colab Notebooks/FairFaceData/
/content/gdrive/MyDrive/Colab Notebooks/FairFaceData
/content/gdrive/MyDrive/Colab Notebooks/FairFaceData


In [26]:
class FairFaceDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.data = dataframe.copy()
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        img_path = os.path.join(self.image_dir, row['file'])
        image = Image.open(img_path).convert('L')  # grayscale

        if self.transform:
            image = self.transform(image)

        labels = torch.tensor([
            int(row['age_label']),
            int(row['gender_label']),
            int(row['race_label'])
        ], dtype=torch.long)

        return image, labels


In [27]:
# High level feature extractor network (Adopted VGG type structure)
class highLevelNN(nn.Module):
    def __init__(self):
        super(highLevelNN, self).__init__()
        self.CNN = nn.Sequential(
            # first batch (32)
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # second batch (64)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # Third Batch (128)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.CNN(x)

        return out

# Low level feature extraction module
class lowLevelNN(nn.Module):
    def __init__(self, num_out):
        super(lowLevelNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=2048, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=num_out)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=2, padding=1))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=3, stride=2, padding=1))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x


class TridentNN(nn.Module):
    def __init__(self, num_age, num_gen, num_eth):
        super(TridentNN, self).__init__()
        # Construct the high level neural network
        self.CNN = highLevelNN()
        # Construct the low level neural networks
        self.ageNN = lowLevelNN(num_out=num_age)
        self.genNN = lowLevelNN(num_out=num_gen)
        self.ethNN = lowLevelNN(num_out=num_eth)

    def forward(self, x):
        x = self.CNN(x)
        age = self.ageNN(x)
        gen = self.genNN(x)
        eth = self.ethNN(x)

        return age, gen, eth

In [28]:
'''
    Function to test the trained model

    Inputs:
      - testloader : PyTorch DataLoader containing the test dataset
      - modle : Trained NeuralNetwork

    Outputs:
      - Prints out test accuracy for gender and ethnicity and loss for age
'''
def test(testloader, model):
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  size = len(testloader.dataset)
  # put the moel in evaluation mode so we aren't storing anything in the graph
  model.eval()

  age_acc, gen_acc, eth_acc = 0, 0, 0

  with torch.no_grad():
      for X, y in testloader:
          X = X.to(device)
          age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
          pred = model(X)

          age_acc += (pred[0].argmax(1) == age).type(torch.float).sum().item()
          gen_acc += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
          eth_acc += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

  age_acc /= size
  gen_acc /= size
  eth_acc /= size

  print(f"Age Accuracy : {age_acc*100}%,     Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n")

In [29]:
# Read in the dataframe
fairface_folder_path = "/content/gdrive/MyDrive/Colab Notebooks/FairFaceData"
csv_path = os.path.join(fairface_folder_path, "fairface_label_train.csv")
df_fairface = pd.read_csv(csv_path)

# Construct age bins
age_groups = ['0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', 'more than 70']
age_mapping = {k: v for v, k in enumerate(age_groups)}
df_fairface['age_label'] = df_fairface['age'].map(age_mapping)

df_fairface['gender_label'] = df_fairface['gender'].map({'Male': 0, 'Female': 1})

race_classes = ['White', 'Black', 'Latino_Hispanic', 'East Asian',
                'Southeast Asian', 'Indian', 'Middle Eastern', 'Other']
race_mapping = {race: idx for idx, race in enumerate(race_classes)}
df_fairface['race_label'] = df_fairface['race'].map(race_mapping)

df_fairface['file_path'] = df_fairface['file'].apply(lambda x: os.path.join(fairface_folder_path, x))

# Split into training and testing
train_df, test_df = train_test_split(df_fairface, test_size=0.2)

# get the number of unique classes for each group
class_nums = {
    'age_num': len(df_fairface['age_label'].unique()),
    'eth_num': len(df_fairface['race'].unique()),
    'gen_num': len(df_fairface['gender'].unique())
}

# Define train and test transforms
train_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Dataset
train_set = FairFaceDataset(train_df, image_dir=fairface_folder_path, transform=train_transform)
test_set = FairFaceDataset(test_df, image_dir=fairface_folder_path, transform=test_transform)

# DataLoader
trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
testloader = DataLoader(test_set, batch_size=128, shuffle=False)

# Sanity check
for X, y in trainloader:
    print(f'Shape of training X: {X.shape}')
    print(f'Shape of y: {y.shape}')
    break

Shape of training X: torch.Size([64, 1, 48, 48])
Shape of y: torch.Size([64, 3])


In [30]:
# Configure the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Define the list of hyperparameters
hyperparameters = {'learning_rate':0.001, 'epochs':30}

# Initialize the TridentNN model and put on device
model = TridentNN(class_nums['age_num'], class_nums['gen_num'], class_nums['eth_num'])
model.to(device)

cuda


TridentNN(
  (CNN): highLevelNN(
    (CNN): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): ReLU()
      (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU()
      (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (9): ReLU()
      (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU()
      (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (14): ReLU()
    )
  )
  (ageNN): lowLevelNN(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [31]:
'''
  Functions to load and save a PyTorch model
'''
def save_checkpoint(state, epoch):
  print("Saving Checkpoint")
  filename = "tridentNN_epoch"+str(epoch)+".pth.tar"
  torch.save(state,filename)

def load_checkpoint(checkpoint):
  print("Loading Checkpoint")
  model.load_state_dict(checkpoint['state_dict'])
  opt.load_state_dict(checkpoint['optimizer'])


In [32]:
'''
train the model
'''
# Load hyperparameters
learning_rate = hyperparameters['learning_rate']
num_epoch = hyperparameters['epochs']

# Define loss functions
age_loss = nn.CrossEntropyLoss()
gen_loss = nn.CrossEntropyLoss() # TODO : Explore using Binary Cross Entropy Loss?
eth_loss = nn.CrossEntropyLoss()

# Define optimizer
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epoch):
  # Construct tqdm loop to keep track of training
  loop = tqdm(enumerate(trainloader), total=len(trainloader), position=0, leave=True)
  age_correct, gen_correct, eth_correct, total = 0,0,0,0

  # save the model every 10 epochs
  if epoch % 10 == 0:
    checkpoint = {'state_dict' : model.state_dict(), 'optimizer' : opt.state_dict(),
                  'age_loss' : age_loss, 'gen_loss' : gen_loss, 'eth_loss' : eth_loss}
    save_checkpoint(checkpoint, epoch)

  # Loop through dataLoader
  for _, (X,y) in loop:
    # Unpack y to get true age, eth, and gen values
    # Have to do some special changes to age label to make it compatible with NN output and Loss function
    #age, gen, eth = y[:,0].resize_(len(y[:,0]),1).float().to(device), y[:,1].to(device), y[:,2].to(device)
    age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
    X = X.to(device)
    pred = model(X)          # Forward pass
    loss = age_loss(pred[0],age) + gen_loss(pred[1],gen) + eth_loss(pred[2],eth)   # Loss calculation

    # Backpropagation
    opt.zero_grad()          # Zero the gradient
    loss.backward()          # Calculate updates

    # Gradient Descent
    opt.step()               # Apply updates

    # Update num correct and total
    age_correct += (pred[0].argmax(1) == age).type(torch.float).sum().item()
    gen_correct += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
    eth_correct += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

    total += len(y)

    # Update progress bar
    loop.set_description(f"Epoch [{epoch+1}/{num_epoch}]")
    loop.set_postfix(loss = loss.item())

  # Update epoch accuracy
  gen_acc, eth_acc, age_acc = gen_correct/total, eth_correct/total, age_correct/total

  # print out accuracy and loss for epoch
  print(f'Epoch : {epoch+1}/{num_epoch},    Age Accuracy : {age_acc*100},    Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n')

  0%|          | 0/1085 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [1/30]: 100%|██████████| 1085/1085 [04:45<00:00,  3.81it/s, loss=4.01]


Epoch : 1/30,    Age Accuracy : 30.813459182938253,    Gender Accuracy : 65.04359103681821,    Ethnicity Accuracy : 26.13012464874991



Epoch [2/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.82it/s, loss=2.89]


Epoch : 2/30,    Age Accuracy : 39.38468189350818,    Gender Accuracy : 77.40327112904389,    Ethnicity Accuracy : 41.28539520138339



Epoch [3/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.83it/s, loss=3.03]


Epoch : 3/30,    Age Accuracy : 43.69911376900353,    Gender Accuracy : 81.27674904532027,    Ethnicity Accuracy : 46.843432523957055



Epoch [4/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.82it/s, loss=2.99]


Epoch : 4/30,    Age Accuracy : 46.509114489516534,    Gender Accuracy : 83.39938035881548,    Ethnicity Accuracy : 50.290366741119676



Epoch [5/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.82it/s, loss=3.14]


Epoch : 5/30,    Age Accuracy : 48.53663808631746,    Gender Accuracy : 84.74962173067225,    Ethnicity Accuracy : 52.708408386771374



Epoch [6/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.83it/s, loss=2.58]


Epoch : 6/30,    Age Accuracy : 49.615966568196555,    Gender Accuracy : 85.86641688882484,    Ethnicity Accuracy : 54.74746019165646



Epoch [7/30]: 100%|██████████| 1085/1085 [04:41<00:00,  3.85it/s, loss=2.52]


Epoch : 7/30,    Age Accuracy : 50.89127458750631,    Gender Accuracy : 86.7137401830103,    Ethnicity Accuracy : 55.929101520282444



Epoch [8/30]: 100%|██████████| 1085/1085 [04:47<00:00,  3.78it/s, loss=2.26]


Epoch : 8/30,    Age Accuracy : 52.08012104618488,    Gender Accuracy : 87.57547373730095,    Ethnicity Accuracy : 57.39174292095972



Epoch [9/30]: 100%|██████████| 1085/1085 [04:41<00:00,  3.86it/s, loss=1.94]


Epoch : 9/30,    Age Accuracy : 53.34966496145256,    Gender Accuracy : 88.3103970026659,    Ethnicity Accuracy : 58.79242020318467



Epoch [10/30]: 100%|██████████| 1085/1085 [04:42<00:00,  3.84it/s, loss=1.76]


Epoch : 10/30,    Age Accuracy : 54.49383961380503,    Gender Accuracy : 89.06261258015708,    Ethnicity Accuracy : 60.05620001441025



  0%|          | 0/1085 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [11/30]: 100%|██████████| 1085/1085 [04:47<00:00,  3.77it/s, loss=2.24]


Epoch : 11/30,    Age Accuracy : 55.563080913610484,    Gender Accuracy : 89.53526911160746,    Ethnicity Accuracy : 61.288277253404424



Epoch [12/30]: 100%|██████████| 1085/1085 [04:45<00:00,  3.80it/s, loss=2.69]


Epoch : 12/30,    Age Accuracy : 56.236040060523095,    Gender Accuracy : 90.27163340298291,    Ethnicity Accuracy : 62.80711866849197



Epoch [13/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.83it/s, loss=2.57]


Epoch : 13/30,    Age Accuracy : 57.58772245839038,    Gender Accuracy : 90.90424382160099,    Ethnicity Accuracy : 63.99596512717055



Epoch [14/30]: 100%|██████████| 1085/1085 [04:44<00:00,  3.82it/s, loss=2.9]


Epoch : 14/30,    Age Accuracy : 58.94084588226818,    Gender Accuracy : 91.51091577202969,    Ethnicity Accuracy : 65.45428344981626



Epoch [15/30]: 100%|██████████| 1085/1085 [04:43<00:00,  3.83it/s, loss=2.09]


Epoch : 15/30,    Age Accuracy : 60.22480005764103,    Gender Accuracy : 91.9187261330067,    Ethnicity Accuracy : 66.61863246631601



Epoch [16/30]:   2%|▏         | 18/1085 [00:05<05:14,  3.40it/s, loss=1.64]


KeyboardInterrupt: 

I manuall interrupted the training because I wanted everything to have a training accuracy > 90% and I didn't code that part in yet
<br> <br>
Now I am going to test the model

In [33]:
test(testloader, model)

Age Accuracy : 46.89031068073088%,     Gender Accuracy : 84.35068303648625,    Ethnicity Accuracy : 52.665859703729325



As you can see the testing accuracy is not that great. My hypothesis is that predicting age is actually a very difficult task because there is so much variation between how people age.
<br> <br>
Even between different genders and ethnicities there is so much variance. Therefore, we have both inter and intra-variance when it comes to age.
<br> <br>
Perhaps a better approach would be to feed the outputs of the gender and ethnicity classifier to the age classifier so it can use that information as well. But, that's a project for another day.