In [23]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
# torch imports
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F

In [25]:
class UTKDataset(Dataset):
    '''
        Inputs:
            dataFrame : Pandas dataFrame
            transform : The transform to apply to the dataset
    '''
    def __init__(self, dataFrame, transform=None):
        # read in the transforms
        self.transform = transform
        
        # Use the dataFrame to get the pixel values
        data_holder = dataFrame.pixels.apply(lambda x: np.array(x.split(" "),dtype=float))
        arr = np.stack(data_holder)
        arr = arr / 255.0
        arr = arr.astype('float32')
        arr = arr.reshape(arr.shape[0], 48, 48, 1)
        # reshape into 48x48x1
        self.data = arr
        
        # get the age, gender, and ethnicity label arrays
        self.age_label = np.array(dataFrame.bins[:])        # Note : Changed dataFrame.age to dataFrame.bins with most recent change
        self.gender_label = np.array(dataFrame.gender[:])
        self.eth_label = np.array(dataFrame.ethnicity[:])
    
    # override the length function
    def __len__(self):
        return len(self.data)
    
    # override the getitem function
    def __getitem__(self, index):
        # load the data at index and apply transform
        data = self.data[index]
        data = self.transform(data)
        
        # load the labels into a list and convert to tensors
        labels = torch.tensor((self.age_label[index], self.gender_label[index], self.eth_label[index]))
        
        # return data labels
        return data, labels

In [26]:
# High level feature extractor network (Adopted VGG type structure)
class highLevelNN(nn.Module):
    def __init__(self):
        super(highLevelNN, self).__init__()
        self.CNN = nn.Sequential(
            # first batch (32)
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # second batch (64)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # Third Batch (128)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.CNN(x)

        return out

# Low level feature extraction module
class lowLevelNN(nn.Module):
    def __init__(self, num_out):
        super(lowLevelNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=2048, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=num_out)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=2, padding=1))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=3, stride=2, padding=1))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x


class TridentNN(nn.Module):
    def __init__(self, num_age, num_gen, num_eth):
        super(TridentNN, self).__init__()
        # Construct the high level neural network
        self.CNN = highLevelNN()
        # Construct the low level neural networks
        self.ageNN = lowLevelNN(num_out=num_age)
        self.genNN = lowLevelNN(num_out=num_gen)
        self.ethNN = lowLevelNN(num_out=num_eth)

    def forward(self, x):
        x = self.CNN(x)
        age = self.ageNN(x)
        gen = self.genNN(x)
        eth = self.ethNN(x)

        return age, gen, eth

In [27]:
'''
    Function to test the trained model

    Inputs:
      - testloader : PyTorch DataLoader containing the test dataset
      - modle : Trained NeuralNetwork
    
    Outputs:
      - Prints out test accuracy for gender and ethnicity and loss for age
'''
def test(testloader, model):
  device = 'cuda' if torch.cuda.is_available() else 'cpu' 
  size = len(testloader.dataset)
  # put the moel in evaluation mode so we aren't storing anything in the graph
  model.eval()

  age_acc, gen_acc, eth_acc = 0, 0, 0

  with torch.no_grad():
      for X, y in testloader:
          X = X.to(device)
          age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
          pred = model(X)

          age_acc += (pred[0].argmax(1) == age).type(torch.float).sum().item()
          gen_acc += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
          eth_acc += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

  age_acc /= size
  gen_acc /= size
  eth_acc /= size

  print(f"Age Accuracy : {age_acc*100}%,     Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n")

In [28]:
# Read in the dataframe
dataFrame = pd.read_csv(r'..\age_gender.gz', compression='gzip')
 
# Construct age bins
age_bins = [0,10,15,20,25,30,40,50,60,120]
age_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8]
dataFrame['bins'] = pd.cut(dataFrame.age, bins=age_bins, labels=age_labels)

# Split into training and testing
train_dataFrame, test_dataFrame = train_test_split(dataFrame, test_size=0.2)

# get the number of unique classes for each group
class_nums = {'age_num':len(dataFrame['bins'].unique()), 'eth_num':len(dataFrame['ethnicity'].unique()),
              'gen_num':len(dataFrame['gender'].unique())}

# Define train and test transforms
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.49,), (0.23,))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.49,), (0.23,))
])

# Construct the custom pytorch datasets
train_set = UTKDataset(train_dataFrame, transform=train_transform)
test_set = UTKDataset(test_dataFrame, transform=test_transform)

# Load the datasets into dataloaders
trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
testloader = DataLoader(test_set, batch_size=128, shuffle=False)

# Sanity Check
for X, y in trainloader:
    print(f'Shape of training X: {X.shape}')
    print(f'Shape of y: {y.shape}')
    break 

Shape of training X: torch.Size([64, 1, 48, 48])
Shape of y: torch.Size([64, 3])


In [29]:
# Configure the device 
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
print(device)

# Define the list of hyperparameters
hyperparameters = {'learning_rate':0.001, 'epochs':30}

# Initialize the TridentNN model and put on device
model = TridentNN(class_nums['age_num'], class_nums['gen_num'], class_nums['eth_num'])
model.to(device)

cpu


TridentNN(
  (CNN): highLevelNN(
    (CNN): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): ReLU()
      (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU()
      (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (9): ReLU()
      (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU()
      (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (14): ReLU()
    )
  )
  (ageNN): lowLevelNN(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [30]:
'''
  Functions to load and save a PyTorch model
'''
def save_checkpoint(state, epoch):
  print("Saving Checkpoint")
  filename = "tridentNN_epoch"+str(epoch)+".pth.tar"
  torch.save(state,filename)

def load_checkpoint(checkpoint):
  print("Loading Checkpoint")
  model.load_state_dict(checkpoint['state_dict'])
  opt.load_state_dict(checkpoint['optimizer'])


In [31]:
'''
train the model
''' 
# Load hyperparameters
learning_rate = hyperparameters['learning_rate']
num_epoch = hyperparameters['epochs']

# Define loss functions
age_loss = nn.CrossEntropyLoss()
gen_loss = nn.CrossEntropyLoss() # TODO : Explore using Binary Cross Entropy Loss?
eth_loss = nn.CrossEntropyLoss()

# Define optimizer
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epoch):
  # Construct tqdm loop to keep track of training
  loop = tqdm(enumerate(trainloader), total=len(trainloader), position=0, leave=True)
  age_correct, gen_correct, eth_correct, total = 0,0,0,0

  # save the model every 10 epochs
  if epoch % 10 == 0:
    checkpoint = {'state_dict' : model.state_dict(), 'optimizer' : opt.state_dict(), 
                  'age_loss' : age_loss, 'gen_loss' : gen_loss, 'eth_loss' : eth_loss}
    save_checkpoint(checkpoint, epoch)

  # Loop through dataLoader
  for _, (X,y) in loop:
    # Unpack y to get true age, eth, and gen values
    # Have to do some special changes to age label to make it compatible with NN output and Loss function
    #age, gen, eth = y[:,0].resize_(len(y[:,0]),1).float().to(device), y[:,1].to(device), y[:,2].to(device)
    age, gen, eth = y[:,0].to(device), y[:,1].to(device), y[:,2].to(device)
    X = X.to(device)
    pred = model(X)          # Forward pass
    loss = age_loss(pred[0],age) + gen_loss(pred[1],gen) + eth_loss(pred[2],eth)   # Loss calculation

    # Backpropagation
    opt.zero_grad()          # Zero the gradient
    loss.backward()          # Calculate updates

    # Gradient Descent
    opt.step()               # Apply updates

    # Update num correct and total
    age_correct += (pred[0].argmax(1) == age).type(torch.float).sum().item()
    gen_correct += (pred[1].argmax(1) == gen).type(torch.float).sum().item()
    eth_correct += (pred[2].argmax(1) == eth).type(torch.float).sum().item()

    total += len(y)

    # Update progress bar
    loop.set_description(f"Epoch [{epoch+1}/{num_epoch}]")
    loop.set_postfix(loss = loss.item())

  # Update epoch accuracy
  gen_acc, eth_acc, age_acc = gen_correct/total, eth_correct/total, age_correct/total

  # print out accuracy and loss for epoch
  print(f'Epoch : {epoch+1}/{num_epoch},    Age Accuracy : {age_acc*100},    Gender Accuracy : {gen_acc*100},    Ethnicity Accuracy : {eth_acc*100}\n')

  0%|          | 0/297 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [1/30]: 100%|██████████| 297/297 [04:43<00:00,  1.05it/s, loss=3.61]


Epoch : 1/30,    Age Accuracy : 22.685087534275468,    Gender Accuracy : 59.881881459607676,    Ethnicity Accuracy : 44.38936933136469



Epoch [2/30]: 100%|██████████| 297/297 [04:40<00:00,  1.06it/s, loss=2.81]


Epoch : 2/30,    Age Accuracy : 36.50601139000211,    Gender Accuracy : 81.16431132672432,    Ethnicity Accuracy : 59.52858046825564



Epoch [3/30]: 100%|██████████| 297/297 [05:16<00:00,  1.07s/it, loss=2.2] 


Epoch : 3/30,    Age Accuracy : 42.52267454123603,    Gender Accuracy : 85.27736764395696,    Ethnicity Accuracy : 69.76376291921535



Epoch [4/30]: 100%|██████████| 297/297 [05:40<00:00,  1.15s/it, loss=2.08]


Epoch : 4/30,    Age Accuracy : 45.36490191942628,    Gender Accuracy : 88.061590381776,    Ethnicity Accuracy : 73.88209238557266



Epoch [5/30]: 100%|██████████| 297/297 [05:41<00:00,  1.15s/it, loss=2.32]


Epoch : 5/30,    Age Accuracy : 47.305420797300144,    Gender Accuracy : 88.91584053997047,    Ethnicity Accuracy : 76.58194473739718



Epoch [6/30]: 100%|██████████| 297/297 [14:25<00:00,  2.91s/it, loss=2.14]   


Epoch : 6/30,    Age Accuracy : 49.103564648808266,    Gender Accuracy : 90.33431765450327,    Ethnicity Accuracy : 78.62792659776419



Epoch [7/30]: 100%|██████████| 297/297 [05:40<00:00,  1.15s/it, loss=2.59]


Epoch : 7/30,    Age Accuracy : 50.46931027209449,    Gender Accuracy : 91.61569289179498,    Ethnicity Accuracy : 80.71609365112845



Epoch [8/30]: 100%|██████████| 297/297 [05:38<00:00,  1.14s/it, loss=2.14]


Epoch : 8/30,    Age Accuracy : 52.11980594811221,    Gender Accuracy : 92.41721155874288,    Ethnicity Accuracy : 82.08711242353934



Epoch [9/30]: 100%|██████████| 297/297 [05:44<00:00,  1.16s/it, loss=1.56]


Epoch : 9/30,    Age Accuracy : 54.34507487871757,    Gender Accuracy : 93.0447163045771,    Ethnicity Accuracy : 83.769246994305



Epoch [10/30]: 100%|██████████| 297/297 [05:46<00:00,  1.17s/it, loss=1.39]


Epoch : 10/30,    Age Accuracy : 55.44716304577093,    Gender Accuracy : 93.76713773465514,    Ethnicity Accuracy : 85.34591858257753



  0%|          | 0/297 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [11/30]: 100%|██████████| 297/297 [05:38<00:00,  1.14s/it, loss=1.13]


Epoch : 11/30,    Age Accuracy : 57.888631090487245,    Gender Accuracy : 94.76376291921535,    Ethnicity Accuracy : 86.64838641636786



Epoch [12/30]: 100%|██████████| 297/297 [05:37<00:00,  1.14s/it, loss=2.05]


Epoch : 12/30,    Age Accuracy : 60.28264079308163,    Gender Accuracy : 95.30162412993039,    Ethnicity Accuracy : 87.99303944315545



Epoch [13/30]: 100%|██████████| 297/297 [05:39<00:00,  1.14s/it, loss=1.52] 


Epoch : 13/30,    Age Accuracy : 62.191520776207554,    Gender Accuracy : 96.1400548407509,    Ethnicity Accuracy : 89.45370175068551



Epoch [14/30]: 100%|██████████| 297/297 [05:35<00:00,  1.13s/it, loss=0.778]


Epoch : 14/30,    Age Accuracy : 65.58215566336216,    Gender Accuracy : 96.34570765661253,    Ethnicity Accuracy : 90.49778527736765



Epoch [15/30]: 100%|██████████| 297/297 [05:37<00:00,  1.14s/it, loss=1.26] 


Epoch : 15/30,    Age Accuracy : 67.88124868171272,    Gender Accuracy : 96.90466146382619,    Ethnicity Accuracy : 92.22737819025522



Epoch [16/30]: 100%|██████████| 297/297 [05:37<00:00,  1.14s/it, loss=1.44] 


Epoch : 16/30,    Age Accuracy : 70.89221683189201,    Gender Accuracy : 97.0892216831892,    Ethnicity Accuracy : 92.77051255009492



Epoch [17/30]: 100%|██████████| 297/297 [05:54<00:00,  1.19s/it, loss=0.725]


Epoch : 17/30,    Age Accuracy : 72.89601349926176,    Gender Accuracy : 97.34760599029741,    Ethnicity Accuracy : 93.69858679603459



Epoch [18/30]: 100%|██████████| 297/297 [06:55<00:00,  1.40s/it, loss=1.14] 


Epoch : 18/30,    Age Accuracy : 75.67496308795613,    Gender Accuracy : 97.87492090276314,    Ethnicity Accuracy : 94.5370175068551



Epoch [19/30]: 100%|██████████| 297/297 [37:31<00:00,  7.58s/it, loss=0.469]   


Epoch : 19/30,    Age Accuracy : 79.05505167686142,    Gender Accuracy : 98.18076355199325,    Ethnicity Accuracy : 95.21198059481122



Epoch [20/30]: 100%|██████████| 297/297 [07:49<00:00,  1.58s/it, loss=0.831]


Epoch : 20/30,    Age Accuracy : 80.1307740982915,    Gender Accuracy : 98.19130985024258,    Ethnicity Accuracy : 95.64965197215777



  0%|          | 0/297 [00:00<?, ?it/s]

Saving Checkpoint


Epoch [21/30]: 100%|██████████| 297/297 [07:57<00:00,  1.61s/it, loss=0.624]


Epoch : 21/30,    Age Accuracy : 82.56169584475849,    Gender Accuracy : 98.35477747310694,    Ethnicity Accuracy : 96.22969837587007



Epoch [22/30]: 100%|██████████| 297/297 [08:47<00:00,  1.78s/it, loss=0.359]


Epoch : 22/30,    Age Accuracy : 84.13836743303101,    Gender Accuracy : 98.5446108415946,    Ethnicity Accuracy : 96.49862898122758



Epoch [23/30]: 100%|██████████| 297/297 [08:01<00:00,  1.62s/it, loss=0.497]


Epoch : 23/30,    Age Accuracy : 85.19827040708712,    Gender Accuracy : 98.74499050833158,    Ethnicity Accuracy : 96.68318920059059



Epoch [24/30]: 100%|██████████| 297/297 [08:01<00:00,  1.62s/it, loss=0.135]


Epoch : 24/30,    Age Accuracy : 87.43408563594178,    Gender Accuracy : 98.7397173592069,    Ethnicity Accuracy : 97.05758278844125



Epoch [25/30]: 100%|██████████| 297/297 [10:52<00:00,  2.20s/it, loss=0.486]


Epoch : 25/30,    Age Accuracy : 87.94030795190888,    Gender Accuracy : 98.95591647331786,    Ethnicity Accuracy : 97.09449483231386



Epoch [26/30]: 100%|██████████| 297/297 [06:06<00:00,  1.23s/it, loss=0.291]


Epoch : 26/30,    Age Accuracy : 89.37987766294032,    Gender Accuracy : 98.96118962244252,    Ethnicity Accuracy : 97.8538283062645



Epoch [27/30]: 100%|██████████| 297/297 [05:58<00:00,  1.21s/it, loss=0.588]


Epoch : 27/30,    Age Accuracy : 89.71208605779371,    Gender Accuracy : 99.06137945581101,    Ethnicity Accuracy : 97.60599029740561



Epoch [28/30]: 100%|██████████| 297/297 [05:03<00:00,  1.02s/it, loss=0.328]


Epoch : 28/30,    Age Accuracy : 90.7245306897279,    Gender Accuracy : 99.11411094705758,    Ethnicity Accuracy : 97.69036068340013



Epoch [29/30]: 100%|██████████| 297/297 [03:21<00:00,  1.47it/s, loss=0.551]


Epoch : 29/30,    Age Accuracy : 91.40476692680869,    Gender Accuracy : 98.97173592069184,    Ethnicity Accuracy : 97.75891162202068



Epoch [30/30]: 100%|██████████| 297/297 [03:21<00:00,  1.48it/s, loss=0.451] 

Epoch : 30/30,    Age Accuracy : 91.77916051465935,    Gender Accuracy : 98.99282851719047,    Ethnicity Accuracy : 97.70090698164944






<br> <br>
Now I am going to test the model

In [32]:
test(testloader, model)

Age Accuracy : 46.551360472474165%,     Gender Accuracy : 88.63109048723898,    Ethnicity Accuracy : 73.52879139422063



As you can see the testing accuracy is not that great. My hypothesis is that predicting age is actually a very difficult task because there is so much variation between how people age.
<br> <br> 
Even between different genders and ethnicities there is so much variance. Therefore, we have both inter and intra-variance when it comes to age.
<br> <br>
Perhaps a better approach would be to feed the outputs of the gender and ethnicity classifier to the age classifier so it can use that information as well. But, that's a project for another day.

In [33]:
from PIL import Image


# Define a new dataset class for the new images
class NewDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert('RGB')  # Convert to RGB if needed
        if self.transform:
            image = self.transform(image)
        return image

# Example image paths
image_paths = ['../image1.jpg', '../image4.jpg','../image19.jpg', '../image26.jpg', '../image75.jpg']

# Define transformations for the new images
new_transform = transforms.Compose([
    transforms.Resize((48, 48)),  # Resize to match model input size
    transforms.Grayscale(num_output_channels=1),  # Convert images to grayscale
    transforms.ToTensor(),
    transforms.Normalize((0.49,), (0.23,))
])


# Create a new dataset for the new images
new_dataset = NewDataset(image_paths, transform=new_transform)

# Create a DataLoader for the new dataset
new_dataloader = DataLoader(new_dataset, batch_size=5, shuffle=False)

# Test the model on the new images
model.eval()

with torch.no_grad():
    for images in new_dataloader:
        images = images.to(device)
        age_preds, gen_preds, eth_preds = model(images)
        
        # Convert predictions to labels
        age_labels = age_preds.argmax(dim=1)
        gen_labels = gen_preds.argmax(dim=1)
        eth_labels = eth_preds.argmax(dim=1)
        
        # Print the predictions for each image
        for i in range(len(age_labels)):
            print(f"Image {i+1}: Age - {age_labels[i]}, Gender - {gen_labels[i]}, Ethnicity - {eth_labels[i]}")


Image 1: Age - 0, Gender - 1, Ethnicity - 2
Image 2: Age - 0, Gender - 1, Ethnicity - 4
Image 3: Age - 2, Gender - 0, Ethnicity - 0
Image 4: Age - 4, Gender - 1, Ethnicity - 2
Image 5: Age - 8, Gender - 0, Ethnicity - 0
