In [1]:
import os
import sys
from pathlib import Path
# Make the project root (the parent of the directory the kernel starts in)
# importable and use it as the working directory for every later cell.
#
# The root is resolved only once per kernel: os.chdir() below changes the
# current working directory, so recomputing Path().resolve().parent when this
# cell is re-run would climb one directory higher on every execution.
if "parent_dir" not in globals():
    parent_dir = str(Path().resolve().parent)
# Avoid stacking duplicate entries on sys.path when the cell is re-run.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
os.chdir(parent_dir)
import configs
import hybrid_autoencoder
import data_handler
import pandas as pd
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
# Seed passed as `random_state` to the pandas sampling calls in the dataset
# cell, so that the rows drawn there are repeatable between runs.
SEED = 7

## Create synthetic dataset

This is the same dataset as the one presented in Section 4.1.4, "CelebA: gender classification with known confounding", of this project's report.

In [15]:
# Build the training set: each gender is dominated by one eyeglasses status
# (men mostly with glasses, women mostly without), plus a small minority group
# of the opposite status per gender.
prop = 0.95                      # share used to size the minority groups: (1 - prop) * n rows each
n = 1808                         # rows drawn for each majority group
n_minority = int((1 - prop) * n)  # rows drawn for each minority group (truncated, as before)

train_ds, _, _, img_dir = data_handler.load_original_celebA()

# Split the raw frame by (Male, Eyeglasses) and keep a single row per ID in
# each pool, drawn with a fixed seed.
men_with_glasses = train_ds[(train_ds['Male'] == 1) & (train_ds['Eyeglasses'] == 1)].groupby('ID').sample(n=1, random_state=SEED)
men_without_glasses = train_ds[(train_ds['Male'] == 1) & (train_ds['Eyeglasses'] == 0)].groupby('ID').sample(n=1, random_state=SEED)
women_with_glasses = train_ds[(train_ds['Male'] == 0) & (train_ds['Eyeglasses'] == 1)].groupby('ID').sample(n=1, random_state=SEED)
women_without_glasses = train_ds[(train_ds['Male'] == 0) & (train_ds['Eyeglasses'] == 0)].groupby('ID').sample(n=1, random_state=SEED)

# Majority groups are drawn freely from their pools; minority groups are
# restricted to IDs that already appear in the same-gender majority group.
# Every draw is seeded so that re-running the cell yields the same training set
# (previously these four calls were unseeded and changed on each run).
mg = men_with_glasses.sample(n, random_state=SEED)
mwg = men_without_glasses[men_without_glasses['ID'].isin(mg['ID'])].sample(n_minority, random_state=SEED)
wwg = women_without_glasses.sample(n, random_state=SEED)
wg = women_with_glasses[women_with_glasses['ID'].isin(wwg['ID'])].sample(n_minority, random_state=SEED)
train_ds = pd.concat([mg, mwg, wwg, wg])

config = configs.celebA_experiment()
# Only the two test splits (and img_dir) are kept from this call; the first
# returned value is discarded. The original note says to ignore the first two
# arguments (0.95, 10_000) -- presumably they only shape the discarded split;
# TODO confirm against data_handler.
_, test1_ds, test2_ds, img_dir = data_handler.load_datasets_celebA_counfound(0.95, 10_000, config.test_size)

# Sizes of the four one-row-per-ID pools. This used to be a bare expression in
# the middle of the cell, which displays nothing; print it so it is visible.
print("Pool sizes (men with/without glasses, women with/without glasses): ",
      (len(men_with_glasses), len(men_without_glasses), len(women_with_glasses), len(women_without_glasses)))

# --- Summary statistics of the assembled training frame ---
print("Train set size: ", len(train_ds))
id_counts = train_ds['ID'].value_counts()
print("Distinct IDs that occur more than once: ", id_counts[id_counts > 1].count())
print("Mean occurrences of IDs: ", id_counts.mean())
print("Max occurrences of IDs: ", id_counts.max())

# Joint distribution of Male x Eyeglasses, as a percentage of all rows.
cross_tab = pd.crosstab(index=train_ds['Male'], columns=train_ds['Eyeglasses'])

cross_tab_percentage = cross_tab.div(cross_tab.sum().sum()) * 100

print("\n\nPercentage of images for each combination of gender and eyeglasses status:")
print(cross_tab_percentage)

# Per-ID counts of rows with Eyeglasses == 0 / == 1; keep the IDs that have both.
grouped = train_ds.groupby(['ID', 'Eyeglasses']).size().unstack(fill_value=0)
ids_with_and_without_glasses = grouped[(grouped[0] > 0) & (grouped[1] > 0)]
distinct_ids_count = len(ids_with_and_without_glasses)
print("\n\nNumber of distinct IDs with at least one instance of eyeglasses being worn and not worn:", distinct_ids_count)
print("These IDs encompass a total of", ids_with_and_without_glasses.sum().sum(), "images")
print("Mean number of images with glasses per ID:", ids_with_and_without_glasses[1].mean())
print("Mean number of images without glasses per ID:", ids_with_and_without_glasses[0].mean())

# Convert the frames with data_handler.conv_celebA_to_jax. train_synth is built
# first because it needs the DataFrame versions of all three splits, which the
# following lines overwrite with their converted counterparts.
train_synth = data_handler.conv_celebA_to_jax(pd.concat([train_ds, test1_ds, test2_ds]), img_dir)
train_ds = data_handler.conv_celebA_to_jax(train_ds, img_dir)
test1_ds = data_handler.conv_celebA_to_jax(test1_ds, img_dir)
test2_ds = data_handler.conv_celebA_to_jax(test2_ds, img_dir)

Train set size:  3796
Distinct IDs that occur more than once:  225
Mean occurrences of IDs:  1.0644980370162647
Max occurrences of IDs:  3


Percentage of images for each combination of gender and eyeglasses status:
Eyeglasses          0          1
Male                            
0           47.629083   2.370917
1            2.370917  47.629083


Number of distinct IDs with at least one instance of eyeglasses being worn and not worn: 225
These IDs encompass a total of 455 images
Mean number of images with glasses per ID: 1.0222222222222221
Mean number of images without glasses per ID: 1.0


In [22]:
# Train the hybrid autoencoder on train_synth (the concatenation of the train
# and both test frames produced by the dataset cell).
# Note: `config` is rebound here from the CelebA experiment config to the
# hybrid autoencoder config.
config = configs.hybrid_AE()
# The workdir is a relative path, so it resolves against the working directory
# set by os.chdir() in the first cell. It is a plain string literal: the
# previous f-prefix had no placeholders to format.
state = hybrid_autoencoder.train_and_evaluate(config, workdir="results/hybridAE", train_ds=train_synth)

INFO:absl:epoch: 1, train_loss: 0.1610, accuracy: 50.42, ae_loss: 0.0913, aux_loss: 0.6966, 
INFO:absl:epoch: 2, train_loss: 0.1575, accuracy: 49.89, ae_loss: 0.0881, aux_loss: 0.6938, 
INFO:absl:epoch: 3, train_loss: 0.1543, accuracy: 49.76, ae_loss: 0.0850, aux_loss: 0.6937, 
INFO:absl:epoch: 4, train_loss: 0.1503, accuracy: 49.79, ae_loss: 0.0809, aux_loss: 0.6933, 


KeyboardInterrupt: 