In [1]:
from typing import Dict

import plotly.express as px
import tensorflow as tf
import tensorflow_datasets as tfds
import torchvision.transforms as transforms

In [2]:
# Load the "celeb_a" dataset through TensorFlow Datasets; the returned object
# is indexed by split name ("train", "validation", "test") in later cells.
dset = tfds.load("celeb_a")

In [3]:
# Bind each split to its own name (note the order: test, train, validation).
D_test, D_train, D_val = dset["test"], dset["train"], dset["validation"]

In [4]:
# Display the training split's repr to inspect its element structure
# (attributes / image / landmarks) before any preprocessing is applied.
D_train

<PrefetchDataset shapes: {attributes: {5_o_Clock_Shadow: (), Arched_Eyebrows: (), Attractive: (), Bags_Under_Eyes: (), Bald: (), Bangs: (), Big_Lips: (), Big_Nose: (), Black_Hair: (), Blond_Hair: (), Blurry: (), Brown_Hair: (), Bushy_Eyebrows: (), Chubby: (), Double_Chin: (), Eyeglasses: (), Goatee: (), Gray_Hair: (), Heavy_Makeup: (), High_Cheekbones: (), Male: (), Mouth_Slightly_Open: (), Mustache: (), Narrow_Eyes: (), No_Beard: (), Oval_Face: (), Pale_Skin: (), Pointy_Nose: (), Receding_Hairline: (), Rosy_Cheeks: (), Sideburns: (), Smiling: (), Straight_Hair: (), Wavy_Hair: (), Wearing_Earrings: (), Wearing_Hat: (), Wearing_Lipstick: (), Wearing_Necklace: (), Wearing_Necktie: (), Young: ()}, image: (218, 178, 3), landmarks: {lefteye_x: (), lefteye_y: (), leftmouth_x: (), leftmouth_y: (), nose_x: (), nose_y: (), righteye_x: (), righteye_y: (), rightmouth_x: (), rightmouth_y: ()}}, types: {attributes: {5_o_Clock_Shadow: tf.bool, Arched_Eyebrows: tf.bool, Attractive: tf.bool, Bags_Unde

In [9]:
def filter_for_biased_hair_and_gender(example: Dict) -> bool:
    """Predicate selecting a hair-colour/gender-correlated subset of examples.

    An example is kept when either of the following holds:

    * ``Blond_Hair`` is true and ``Male`` is false, or
    * ``Black_Hair`` is true and ``Male`` is true.

    Args:
        example: A feature dict with an ``"attributes"`` sub-dict containing
            the boolean entries ``"Blond_Hair"``, ``"Black_Hair"`` and
            ``"Male"``.

    Returns:
        The boolean result of combining the attributes with
        ``tf.logical_and`` / ``tf.logical_not`` / ``tf.logical_or``.
    """
    attrs = example["attributes"]
    is_male = attrs["Male"]

    blond_and_not_male = tf.logical_and(attrs["Blond_Hair"], tf.logical_not(is_male))
    black_haired_and_male = tf.logical_and(attrs["Black_Hair"], is_male)
    return tf.logical_or(blond_and_not_male, black_haired_and_male)


def extract_image_and_label(example):
    """Reduce a feature dict to an ``(image, label)`` pair.

    Args:
        example: A feature dict providing ``example["image"]`` and
            ``example["attributes"]["Male"]``.

    Returns:
        A tuple ``(X, y)`` where ``X`` is the image converted to ``tf.float32``
        with ``tf.image.convert_image_dtype`` and ``y`` is the ``"Male"``
        attribute, passed through unchanged.
    """
    # convert_image_dtype rescales integer images into [0, 1] while casting;
    # presumably the source images are uint8 -- TODO confirm.
    X = tf.image.convert_image_dtype(example["image"], dtype=tf.float32, saturate=False)
    y = example["attributes"]["Male"]
    return (X, y)


def normalize_image(image: tf.Tensor, mean, stddev):
    """Standardise an image tensor channel by channel: ``(x - mean) / stddev``.

    Args:
        image: Tensor whose last axis is the channel axis. Any number of
            leading axes is accepted, e.g. ``(H, W, C)`` or ``(N, H, W, C)``.
        mean: Iterable of per-channel means.
        stddev: Iterable of per-channel standard deviations.

    Returns:
        A tensor with the normalised channels stacked along the last axis.
        ``mean`` and ``stddev`` are paired with ``zip``, so the number of
        output channels equals the length of the shorter of the two.
    """
    # `...` addresses channel i on the last axis regardless of rank, so both
    # batched and unbatched images are handled.
    normalized_channels = [
        (image[..., channel] - channel_mean) / channel_std
        for channel, (channel_mean, channel_std) in enumerate(zip(mean, stddev))
    ]
    return tf.stack(normalized_channels, axis=-1)


def transform_image(X, y):
    """Center-crop, resize and normalise a batch of images; pass the label through.

    Steps applied to ``X``:

    1. Center crop to a 178 x 178 square (the smaller of the hard-coded
       source dimensions 178 x 218).
    2. Resize to 224 x 224.
    3. Per-channel normalisation via ``normalize_image`` with fixed mean/std.

    ``X`` is expected to carry a leading batch axis: the pipeline calls this
    after ``batch(1)`` and ``normalize_image`` indexes four axes.

    NOTE(review): the mean/std values look like the torchvision-style ImageNet
    statistics. The backbone in ``resnet_model`` uses Keras ImageNet weights,
    for which Keras documents ``tf.keras.applications.resnet50.preprocess_input``
    as the matching preprocessing -- confirm this normalisation is intended.

    Args:
        X: Image tensor to transform.
        y: Label, returned untouched.

    Returns:
        Tuple ``(X_transformed, y)``.
    """
    source_height, source_width = 218, 178
    square_side = min(source_width, source_height)
    output_size = (224, 224)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # New layer objects are built on every call, exactly as before.
    preprocessing = tf.keras.layers.experimental.preprocessing
    cropped = preprocessing.CenterCrop(square_side, square_side)(X)
    resized = preprocessing.Resizing(*output_size)(cropped)
    return normalize_image(resized, channel_mean, channel_std), y


# Re-derive the three splits from `dset` so this cell always starts from the
# raw data, however many times it is re-run.
D_test = dset["test"]
D_train = dset["train"]
D_val = dset["validation"]

# Training pipeline: keep the hair/gender-correlated subset, reduce each
# example to (image, label), form batches of one, then crop/resize/normalise.
# NOTE(review): from here on D_train yields (image, label) batches rather than
# raw feature dicts; code that needs the "attributes"/"landmarks" entries must
# read dset["train"] instead.
D_train = (
    D_train.filter(filter_for_biased_hair_and_gender)
    .map(extract_image_and_label)
    .batch(1)
    .map(transform_image)
)

In [10]:
# Sanity check: print the image tensor of the first preprocessed batch only.
for X, y in D_train.take(1):
    print(X)

tf.Tensor(
[[[[ 2.14616     1.7983196   1.3850982 ]
   [ 2.181709    1.8346622   1.4212793 ]
   [ 2.1892014   1.8423216   1.4289047 ]
   ...
   [ 1.8221663   1.3540018   0.88151973]
   [ 1.8721639   1.4306725   0.94936836]
   [ 1.8721639   1.4306725   0.94936836]]

  [[ 2.1343102   1.7862054   1.3730379 ]
   [ 2.1698594   1.822548    1.409219  ]
   [ 2.1831176   1.8361022   1.4227129 ]
   ...
   [ 1.8336984   1.3471332   0.8808734 ]
   [ 1.8648635   1.3980616   0.92524755]
   [ 1.8484645   1.4064441   0.92524755]]

  [[ 2.1290352   1.7808126   1.367669  ]
   [ 2.1645844   1.8171551   1.4038501 ]
   [ 2.1804094   1.8333336   1.4199566 ]
   ...
   [ 1.797167    1.3014803   0.83817977]
   [ 1.8250822   1.338325    0.87210405]
   [ 1.8129153   1.344544    0.87210405]]

  ...

  [[-1.7754089  -1.7205882  -1.403573  ]
   [-1.7872585  -1.7327025  -1.4156334 ]
   [-1.7925336  -1.7380952  -1.4210021 ]
   ...
   [-1.6879866  -1.7747185  -1.5514616 ]
   [-1.6832107  -1.7327021  -1.4568597 ]
   [-

def filter_dataset(D: tf.data.Dataset, attribute_name: str, target_value):
    """Return the subset of ``D`` whose given attribute equals ``target_value``.

    Args:
        D: Dataset whose elements are dicts with an ``"attributes"`` sub-dict.
        attribute_name: Key to look up inside ``example["attributes"]``.
        target_value: Value the attribute is compared against with ``tf.equal``.

    Returns:
        The result of ``D.filter`` with that equality predicate.
    """

    def matches_target(example: Dict) -> bool:
        attribute_value = example["attributes"][attribute_name]
        return tf.equal(attribute_value, target_value)

    filtered = D.filter(matches_target)
    return filtered

In [None]:
# Show up to `n_examples` raw training images whose "Male" attribute is True.
#
# The raw split is read from dset["train"] rather than D_train: an earlier cell
# rebinds D_train to a dataset of preprocessed (image, label) batches, which no
# longer has the "attributes"/"image" entries that filter_dataset and this loop
# index, so the cell would fail on a fresh Restart & Run All.
n_examples = 10
male_examples = filter_dataset(dset["train"], "Male", True).take(n_examples)
for example in male_examples:
    image = example["image"]
    fig = px.imshow(image)
    fig.show()

In [None]:
def resnet_model():
    """Build the ResNet50 feature extractor from ``tf.keras.applications``.

    The network is created with ImageNet weights, without its classification
    top, and with average pooling on the output (``pooling="avg"``).

    Returns:
        A newly constructed ``tf.keras.applications.ResNet50`` model.
    """
    backbone_options = dict(include_top=False, weights="imagenet", pooling="avg")
    return tf.keras.applications.ResNet50(**backbone_options)

In [None]:
def make_celeb_a_model():
    """Assemble the CelebA model: ResNet50 backbone plus two dense layers.

    Architecture, in construction order:

    1. Input of shape ``(224, 224, 3)``.
    2. Backbone returned by ``resnet_model()`` (a fresh one per call).
    3. ``Dense(100)`` with L2 kernel regularisation (1e-4), no activation
       argument -- its output is exposed as the ``features`` output.
    4. ``Dense(2)`` with the same regularisation, applied to the features.

    Returns:
        A ``tf.keras.Model`` with two outputs, ``[features, final_dense_output]``.
    """
    weight_decay = 0.0001

    image_input = tf.keras.layers.Input((224, 224, 3))
    pooled = resnet_model()(image_input)

    features = tf.keras.layers.Dense(
        100, kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
    )(pooled)
    head_output = tf.keras.layers.Dense(
        2, kernel_regularizer=tf.keras.regularizers.l2(weight_decay)
    )(features)

    return tf.keras.Model(image_input, outputs=[features, head_output])

In [None]:
# Build a throwaway model instance just to print its layer/parameter summary.
make_celeb_a_model().summary()