In [1]:
import pandas as pd
import collections
from collections import Counter
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
import matplotlib
import matplotlib.pyplot as plt


  if not hasattr(np, "object"):


In [2]:
(train,validation,test)= tfds.load("celeb_a", split=['train','validation','test'], as_supervised=False)

In [3]:
IMG_SIZE = 64
def preprocess(sample):
    image = sample['image']
    image = tf.image.resize(image,[IMG_SIZE, IMG_SIZE])
    #Normalization
    image = image/255.0
    # Finetune atrribute
    # Convert directly to int32
    label = tf.cast(sample['attributes']['Smiling'], tf.int32)
    
    label = tf.where(label == -1, 0, label)
    return image,label

In [4]:
BATCH_SIZE=32
BUFFER_SIZE = 1000

In [5]:
# Batch the dataset 

train_batches = (
    train
    .shuffle(BUFFER_SIZE)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

validation_batches = (
    validation
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

testing_batches = (
    test
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
images, labels = next(iter(train_batches))
print(images.shape)
print(labels.shape)

Dataset Exploratory Analysis & Imbalance

In [None]:
# Create function to count data
def attribute_counts(raw_train,raw_validation,raw_test):
    counter = Counter()
    for i in raw_train:
        male = int(i['attributes']['Male']  )    # Male label
        smiling = int(i['attributes']['Smiling'])# Smiling label
        counter[(male,smiling)]+=1
    for i in raw_test:
        male = int(i['attributes']['Male']  )    # Male label
        smiling = int(i['attributes']['Smiling'])# Smiling label
        counter[(male,smiling)]+=1
    for i in raw_validation:
        male = int(i['attributes']['Male']  )    # Male label
        smiling = int(i['attributes']['Smiling'])# Smiling label
        counter[(male,smiling)]+=1
    return counter
    
counts=attribute_counts(train,validation,test)
print('The number of females smiling: ',counts[(0,1)] )
print('The number of females not smiling:',counts[(0,0)])
print('The total number of females: ', counts[(0,1)]+counts[(0,0)])
print('The number of males smiling:' ,counts[(1,1)])
print('The number of males not smiling:', counts[(1,0)])
print('The total number of males:', counts[(1,1)]+counts[(1,0)])

In [None]:
# Visualize attribute counts
def plot_attributes(counts):
    # Extract counts
    female_smiling = counts[(0,1)]
    female_notsmiling = counts[(0,0)]
    female_total = female_smiling + female_notsmiling

    male_smiling = counts[(1,1)]
    male_notsmiling = counts[(1,0)]
    male_total = male_smiling + male_notsmiling

    # Create figure
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 8))

    categories = ["Female", "Male"]
    colors = ["#8B5A2B", "#2E8B57"]  # Brown (Female), Green (Male)

    ax1.set_title("Smiling Images by Gender", fontsize=14)
    values1 = [female_smiling, male_smiling]
    bars1 = ax1.bar(categories, values1, color=colors)
    ax1.set_ylabel("Number of Images")

    # Add value labels
    for bar in bars1:
        ax1.text(bar.get_x() + bar.get_width()/2,
                 bar.get_height(),
                 f"{bar.get_height():,}",
                 ha="center", va="bottom")

    # --- Plot 2: Not Smiling ---
    ax2.set_title("Non-Smiling Images by Gender", fontsize=14)
    values2 = [female_notsmiling, male_notsmiling]
    bars2 = ax2.bar(categories, values2, color=colors)
    ax2.set_ylabel("Number of Images")

    for bar in bars2:
        ax2.text(bar.get_x() + bar.get_width()/2,
                 bar.get_height(),
                 f"{bar.get_height():,}",
                 ha="center", va="bottom")

    ax3.set_title("Overall Gender Distribution", fontsize=14)
    values3 = [female_total, male_total]
    bars3 = ax3.bar(categories, values3, color=colors)
    ax3.set_ylabel("Number of Images")

    for bar in bars3:
        ax3.text(bar.get_x() + bar.get_width()/2,
                 bar.get_height(),
                 f"{bar.get_height():,}",
                 ha="center", va="bottom")

    # Overall title
    fig.suptitle("Gender and Smiling Attribute Distribution in CelebA",
                 fontsize=18)

    fig.tight_layout()
plot_attributes(counts)

Dataset Interpretation

Among smiling images, female subjects constitute a 65% proportion of the dataset, while male smiling subjects account for 35%, This imbalance suggests that the classifier may learn smiling-related features more effectively for female faces than for male faces.

A similar imbalance is observed in non-smiling images, indicating that the gender skew is consistent across facial expression categories rather than isolated to smiling alone.

Overall, female images represent a higher proportion of the dataset than male images. As a result, accuracy may overestimate model performance if subgroup-level disparities are not examined

In [147]:
# preprocessing pipelines

IMG_SIZE = 64
def preprocess(sample):
    image = sample['image']
    image = tf.image.resize(image,[IMG_SIZE, IMG_SIZE])
    #Normalization
    image = image/255.0
    # Finetune atrribute
    # Convert directly to int32
    label = tf.cast(sample['attributes']['Smiling'], tf.int32)
    
    label = tf.where(label == -1, 0, label)
    return image,label
    
    
    
    

In [143]:
#Data Visualization
def show_images(dataset):
    plt.figure(figsize=(10, 10))

    for i, (image, label) in enumerate(dataset.take(25)):
        ax = plt.subplot(5, 5, i + 1)   
        plt.imshow(image)
        graph_label= tf.where(label==0,False,True)
        plt.title(f"Smiling: {graph_label.numpy()}")
        plt.axis("off")

    plt.show()


We used a three-block convolutional architecture, which provides sufficient capacity to learn expression-level features while minimizing overfitting and maintaining interpretability for bias analysis.





In [144]:
def build_cnn_model(
    input_shape=(64, 64, 3), # width heigh channels
    num_classes=1,
    learning_rate=1e-3
):
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = models.Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model

    

In [145]:
# creating training model parameters
#defining models
model = build_cnn_model()
# creatingcallbacks
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience =3, mode='min', restore_best_weights=True)
# defining max_epochs
epoch= 30

In [None]:
# Model Training
#creating training function
def train_model(model, train_data, val_data, max_epochs, callbacks):
    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=max_epochs,
        callbacks=callbacks
    )
    return model, history
          
sanity_model, sanity_history = train_model(
    model=model,
    train_data=train_batches,
    val_data=validation_batches,
    max_epochs=1,
    callbacks=[early_stopping]
)

In [None]:
# Overall Model Evaluation

In [None]:
# Subgroup Model Evaluation

In [None]:
# Evaluation Visualization

In [60]:
def attribute_counts(dataset):
    counter = Counter()
    for i in dataset:
        male = int(i['attributes']['Male']  )    # Male label
        smiling = int(i['attributes']['Smiling'])# Smiling label
        counter[(male,smiling)]+=1
    print('The number of males smiling ',counter[(0,1)] )
    print('The number of males not smiling ',counter[(0,0)])
    print('The number of females smiling' ,counter[(1,1)])
    print('The number of females not smiling ', counter[(1,0)])
        
attribute_counts(validation)


The number of males smiling  6157
The number of males not smiling  5252
The number of females smiling 3445
The number of females not smiling  5013
