# CelebA Eval
### Overview
This notebook evaluates a saved CelebA model qualitatively (with latent traversals) and quantitatively (with the TAD metric). It compares original images with their reconstructions, generates latent traversals, and reports a TAD score at the end.

### Instructions
Set the hyperparameters and `PATH` in the first code cell to match the saved model you want to evaluate (including the model class imported as `VAE_BASED_MODEL` when `use_vae = True`). Then select "Restart and Run All" in the Jupyter Notebook.

In [None]:
########## SELECT EVALUATION PARAMETERS ##################
# Set these to match the checkpoint being evaluated.

from ae_utils_exp import celeba_norm, celeba_inorm
n_lat = 32        # latent dimensionality (z_dim) of the saved model
use_vae = False   # False: plain AutoEncoder; True: VAE_BASED_MODEL (imported below)
PATH = "./models/celeba_ae_lr1e-3_seed55_ar0.2.pt"

from ae_utils_exp import B_TCVAE as VAE_BASED_MODEL # change <model> in ".... import <model> as ...."
### options: VAE (for VAE, BetaVAE), FACTOR_VAE, or B_TCVAE (for BetaTCVAE)

#################################################
from ae_utils_exp import AutoEncoder
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

from architectures import enc_celeba_small_vae as enc_vae
from architectures import enc_celeba_small as enc
from architectures import dec_celeba_small as dec

# Rebuild the architecture the checkpoint was trained with: use_vae selects
# VAE_BASED_MODEL with the enc_vae encoder, otherwise a plain AutoEncoder
# with the enc encoder. Both share the same decoder.
if use_vae:
    ae = VAE_BASED_MODEL(celeba_norm, enc_vae(lat=n_lat, inp_chan=3), dec(lat=n_lat, inp_chan=3),
                         device, z_dim=n_lat, inp_inorm=celeba_inorm)
else:
    ae = AutoEncoder(celeba_norm, enc(lat=n_lat, inp_chan=3), dec(lat=n_lat, inp_chan=3),
                     device, z_dim=n_lat, inp_inorm=celeba_inorm)

# map_location remaps the stored tensors onto `device`. Without it, a
# checkpoint saved on a GPU cannot be deserialised on a CPU-only machine.
ae.load_state_dict(torch.load(PATH, map_location=device))
ae.eval()

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import Compose

from ae_utils_exp import multi_t, LatentClass, aurocs_search, tags

# Seed the Python, NumPy and PyTorch RNGs so that later shuffling and
# sampling in this notebook is reproducible.
seed = 30 # FIXED AT 30 FOR ALL EXPERIMENTS
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

In [None]:
from torchvision.datasets import CelebA
import torchvision.transforms as tforms

# Resize (smaller edge to 96 px), centre-crop to 64x64, convert to a tensor.
tform = tforms.Compose([
    tforms.Resize(96),
    tforms.CenterCrop(64),
    tforms.ToTensor(),
])

eval_bs = 1000
# Evaluation set: every CelebA split, with the attribute vector as the target.
dataset_eval = CelebA(root='../beamsynthesizer/data', split='all', target_type='attr',
                      download=False, transform=tform)
dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=eval_bs,
                                              shuffle=True, drop_last=False)

# Preview the first 20 images of one shuffled batch in a 2x10 grid, each
# titled with attribute column 20 (presumably "Male" in the standard CelebA
# attribute order -- confirm against `tags`).
data, targ = next(iter(dataloader_eval))
fig, ax = plt.subplots(2, 10, figsize=(20, 4))
for k, cell in enumerate(ax.flat):
    cell.imshow(multi_t(data[k], 0, 2))
    cell.set_title(targ[k][20].item())


In [None]:
# Compare originals (top row) with their reconstructions (bottom row).
plt_batch_size=500
num_to_plot=20
z_scores, z_pred_scores, inp, rec = ae.record_latent_space(dataset_eval, batch_size=plt_batch_size, n_batches=5)

# Turn both image batches into clamped NumPy arrays for imshow
# (multi_t(..., 1, 3) presumably reorders to channels-last -- see ae_utils_exp).
inp, rec = (multi_t(batch, 1, 3).clamp(0, 1).cpu().numpy() for batch in (inp, rec))

hide_ticks = dict(axis='both', which='both', bottom=False, top=False,
                  labelbottom=False, left=False, labelleft=False)
fig, axes = plt.subplots(2, num_to_plot, figsize=(20, 4))
for row_axes, images in zip(axes, (inp, rec)):
    for col, cell in enumerate(row_axes):
        cell.imshow(images[col])
        cell.tick_params(**hide_ticks)
plt.tight_layout()

In [None]:
from ae_utils_exp import InvNorm

# Applied to decoder outputs in the traversal cells before clamping to [0, 1]
# (presumably the inverse of celeba_norm).
invn = celeba_inorm

# Use sample `ind` as the base latent code for the traversals, and show its
# original next to its reconstruction.
ind = 6
z_base = z_scores[ind]
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
for cell, image in zip(axes, (inp[ind], rec[ind])):
    cell.imshow(image)
    cell.tick_params(axis='both', which='both', bottom=False, top=False,
                     labelbottom=False, left=False, labelleft=False)


In [None]:
# Generate latent traversals
# decode
# Covers the first half of the latents (indices 0 .. z_dim//2 - 1).
# Row i varies latent i of `z_base` across the range that latent takes in
# `z_scores` (n_steps evenly spaced values), keeps every other latent fixed,
# and decodes each resulting code. Latents whose observed range is below
# `min_range` are skipped, and their row is left blank.
n_steps = 10
min_range = 0.2
n_rows = ae.z_dim // 2
# squeeze=False keeps `axes` 2-D even when only one row is plotted.
fig, axes = plt.subplots(n_rows, n_steps, figsize=(16, 24), squeeze=False)
with torch.no_grad():
    for i in range(n_rows):
        _min = z_scores[:, i].min()
        _max = z_scores[:, i].max()
        variation = torch.linspace(_min, _max, steps=n_steps)
        active = _max - _min >= min_range  # same for every column of this row
        for j, value in enumerate(variation):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if active:
                z = z_base.clone()
                z[i] = value
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0, 1).squeeze().cpu().numpy()
                axes[i][j].imshow(im)
plt.tight_layout()

In [None]:
# Generate latent traversals
# decode
# Covers the second half of the latents (indices z_dim//2 .. z_dim - 1).
# Row i varies latent `offset + i` of `z_base` across the range that latent
# takes in `z_scores` (n_steps evenly spaced values), keeps every other latent
# fixed, and decodes each resulting code. Latents whose observed range is
# below `min_range` are skipped, and their row is left blank.
n_steps = 10
min_range = 0.2
offset = ae.z_dim // 2
# z_dim - offset (not z_dim // 2) so the last latent is still plotted when
# z_dim is odd.
n_rows = ae.z_dim - offset
# squeeze=False keeps `axes` 2-D even when only one row is plotted.
fig, axes = plt.subplots(n_rows, n_steps, figsize=(16, 24), squeeze=False)
with torch.no_grad():
    for i in range(n_rows):
        lat = offset + i
        _min = z_scores[:, lat].min()
        _max = z_scores[:, lat].max()
        variation = torch.linspace(_min, _max, steps=n_steps)
        active = _max - _min >= min_range  # same for every column of this row
        for j, value in enumerate(variation):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            if active:
                z = z_base.clone()
                z[lat] = value
                im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0,1).squeeze().cpu().numpy()
                axes[i][j].imshow(im)
plt.tight_layout()

In [None]:
# AUROC of each latent as a detector for each attribute (indexed below as
# [attribute, latent]), the raw attribute base rates, and the labels used.
au_result, base_rates_raw, targ = aurocs_search(dataloader_eval, ae)
# Fold base rates above 0.5 to 1 - rate, so each value is <= 0.5.
base_rates = torch.where(base_rates_raw <= 0.5, base_rates_raw, 1. - base_rates_raw)

In [None]:
# calculate mutual information shared between attributes
# determine which share a lot of information with each other
#
# Builds a 40x40 matrix `mi_mat` of pairwise mutual information between the
# binary attribute labels in `targ` (returned by aurocs_search), by summing
# p(a,b) * log(p(a,b) / (p(a) p(b))) over the four False/True outcome pairs.
# torch.log is the natural log, so values are in nats. For i == j the FT and
# TF terms vanish, so the diagonal mi_mat[i][i] is the entropy H(attribute i).
# NOTE(review): assumes targ is an (n_samples, 40) tensor of 0/1 labels --
# confirm against aurocs_search.
with torch.no_grad():
    not_targ = 1 - targ
    # empirical probability that x and y are both nonzero
    j_prob = lambda x, y: torch.logical_and(x, y).sum() / x.numel()
    # one term of the MI sum; cells with zero probability contribute 0
    mi = lambda jp, px, py: 0. if jp == 0. or px == 0. or py == 0. else jp*torch.log(jp/(px*py))

    # Compute the Mutual Information (MI) between the labels
    mi_mat = torch.zeros((40, 40))
    for i in range(40):
        # get the marginal of i
        i_mp = targ[:, i].sum() / targ.shape[0]
        for j in range(40):
            j_mp = targ[:, j].sum() / targ.shape[0]
            # get the joint probabilities of FF, FT, TF, TT
            # FF
            jp = j_prob(not_targ[:, i], not_targ[:, j])
            pi = 1. - i_mp
            pj = 1. - j_mp
            mi_mat[i][j] += mi(jp, pi, pj)
            # FT
            jp = j_prob(not_targ[:, i], targ[:, j])
            pi = 1. - i_mp
            pj = j_mp
            mi_mat[i][j] += mi(jp, pi, pj)
            # TF
            jp = j_prob(targ[:, i], not_targ[:, j])
            pi = i_mp
            pj = 1. - j_mp
            mi_mat[i][j] += mi(jp, pi, pj)
            # TT
            jp = j_prob(targ[:, i], targ[:, j])
            pi = i_mp
            pj = j_mp
            mi_mat[i][j] += mi(jp, pi, pj)

    # Left: raw MI matrix. Right: each row divided by that attribute's own
    # entropy (the diagonal), i.e. MI as a fraction of H(attribute i).
    fig, ax = plt.subplots(1, 2)
    im = ax[0].imshow(mi_mat)
    fig.colorbar(im, ax=ax[0], shrink=0.6)
    mi_mat_ent_norm = mi_mat/mi_mat.diag().unsqueeze(1)
    im = ax[1].imshow(mi_mat_ent_norm)
    fig.colorbar(im, ax=ax[1], shrink=0.6)
    
    # Per attribute: summed MI with all *other* attributes, divided by its own
    # entropy (so the bar is entropy-normalised despite the plot title).
    plt.figure(figsize=(10, 7))
    mi_comp = (mi_mat.sum(dim=1) - mi_mat.diag())/mi_mat.diag()
    plt.bar(range(len(tags)), mi_comp, tick_label=tags)
    plt.xticks(rotation=90)
    plt.title("Total Mutual Information")
    
    # Per attribute: the largest MI with any single other attribute (the
    # diagonal is zeroed out by the (1 - eye) mask). Then ent_red_prop is that
    # maximum divided by the attribute's entropy (1 - (H - m)/H == m/H).
    # ent_red_prop is reused by the attribute filter later in the notebook.
    plt.figure(figsize=(10, 7))
    mi_maxes, mi_inds = (mi_mat * (1 - torch.eye(40))).max(dim=1)
    ent_red_prop = 1. - (mi_mat.diag() - mi_maxes) / mi_mat.diag()
    plt.bar(range(len(tags)), ent_red_prop, tick_label=tags)
    plt.xticks(rotation=90)
    plt.title("Proportion of Entropy Reduced by Another Trait")
    plt.grid(axis='y')
    # per-attribute entropies (nats)
    print(mi_mat.diag())

In [None]:
# For each attribute: print and plot its AUROC against every latent, and
# record two gap statistics between the best and runner-up latent detector:
#   aurs_diffs[k] = best AUROC - runner-up AUROC
#   norm_diffs[k] = 1 - runner-up after rescaling AUROCs so that chance (0.5)
#                   maps to 0 and the best latent maps to 1
# In both cases the best entry is set to 0.0 before the runner-up is taken.
fig, ax = plt.subplots(8, 5, figsize=(16, 16))
max_aur, argmax_aur = torch.max(au_result, dim=1)
norm_diffs = torch.zeros(40)
aurs_diffs = torch.zeros(40)
rows = zip(range(40), tags, max_aur, argmax_aur, au_result)
for k, tag, best, best_lat, scores in rows:
    # Raw-AUROC gap to the runner-up.
    runner_up = scores.clone()
    runner_up[best_lat] = 0.0
    aurs_diffs[k] = best - runner_up.max()

    # Same gap after rescaling (scores itself is left untouched).
    rescaled = (scores - 0.5) / (scores[best_lat] - 0.5)
    rescaled[best_lat] = 0.0
    norm_diff = 1. - rescaled.max()
    norm_diffs[k] = norm_diff

    print("{}\t\t Lat: {}\t Max: {:1.3f}\t ND: {:1.3f}".format(tag, best_lat.item(), best.item(), norm_diff.item()))

    panel = ax[divmod(k, 5)]
    panel.set_ylim((0.5, best.item() + 0.05))
    panel.set_title(tag)
    panel.set_ylabel("AUROC")
    panel.set_xlabel("Latent Variable")
    panel.bar(range(scores.shape[0]), scores)
    panel.grid(which='both', axis='y')
    assert scores.max() == best

In [None]:
# Prevalence of each attribute, folded so every base rate is <= 0.5.
fig, rate_ax = plt.subplots(figsize=(10, 7))
rate_ax.set_ylim((0.0, 0.5))
rate_ax.set_title("Base Rates (Absolute)")
rate_ax.set_ylabel("Base Rate")
rate_ax.set_xlabel("Attribute")
plt.xticks(rotation=90)
rate_ax.bar(range(len(tags)), base_rates, tick_label=tags)
rate_ax.grid(which='both', axis='y')

In [None]:
# Best AUROC that any single latent achieves for each attribute.
fig, aur_ax = plt.subplots(figsize=(10, 7))
aur_ax.set_ylim((0.5, 1.0))
aur_ax.set_title("Max AUROCS")
aur_ax.set_ylabel("AUROC")
aur_ax.set_xlabel("Attribute")
plt.xticks(rotation=90)
aur_ax.bar(range(max_aur.shape[0]), max_aur, tick_label=tags)
aur_ax.grid(which='both', axis='y')

In [None]:
# Unshuffled loader: its first batch is the same on every run, so the
# attribute-traversal figure below always starts from the same faces.
dataloader_plot = torch.utils.data.DataLoader(
    dataset_eval,
    batch_size=eval_bs,
    shuffle=False,
    drop_last=False,
)

In [None]:
# choose an Attribute to run through the Decoder
# Vary the latent that best detects attribute `attr_index` (highest AUROC,
# taken from argmax_aur) across the range it takes in one unshuffled batch.
# This is done for `num` base faces: row i starts from sample
# `start_ind + i` of that batch, and column j sets the chosen latent to the
# j-th of `num` evenly spaced values between the batch minimum and maximum.
attr_index = 5
start_ind = 41  # index within the plotting batch of the first base face
ind_max = argmax_aur[attr_index]

invn = celeba_inorm
num = 8
fig, axes = plt.subplots(num, num, figsize=(16, 16))
ae.to(device)
with torch.no_grad():
    # Distinct names, so this single batch does not overwrite the global
    # `data`/`targ`. In particular, `targ` keeps the labels returned by
    # aurocs_search, which the mutual-information cell reads.
    plot_data, _ = next(iter(dataloader_plot))
    # `ae.z` is read right after the forward pass. The pass presumably stores
    # this batch's latent codes there (confirm in ae_utils_exp); its return
    # value is not needed.
    ae(plot_data.to(device))
    _min = ae.z.min(dim=0)[0]
    _max = ae.z.max(dim=0)[0]
    variation = torch.linspace(_min[ind_max], _max[ind_max], steps=num)
    for i in range(num):
        # Local name, so the global `z_base` used by the traversal cells is kept.
        row_base = ae.z[start_ind + i]
        for j in range(len(variation)):
            axes[i][j].tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, left=False, labelleft=False)
            z = row_base.clone()
            z[ind_max] = variation[j]
            im = multi_t(invn(ae.dec(z.to(ae.device))), 1, 3).clamp(0, 1).squeeze().cpu().numpy()
            axes[i][j].imshow(im)
plt.tight_layout()

In [None]:
# An attribute counts toward the score only if both conditions hold:
#   - some latent detects it with AUROC >= thresh;
#   - no single other attribute accounts for more than ent_red_thresh of its
#     entropy (ent_red_prop from the mutual-information cell).
thresh = 0.75
ent_red_thresh = 0.2
filt = torch.logical_and(max_aur >= thresh, ent_red_prop <= ent_red_thresh)

In [None]:

# Normalised AUROC gaps. The title reports their mean over the filtered
# attributes; the bars show every attribute.
norm_diffs_filt = norm_diffs[filt]
print(len(norm_diffs_filt))
# List each attribute that passed the filter with its best-detecting latent.
for ind in range(filt.shape[0]):
    if filt[ind]:
        print(tags[ind], argmax_aur[ind].item())

fig, nd_ax = plt.subplots(figsize=(10, 7))
nd_ax.set_ylim((0.0, 1.0))
nd_ax.set_title("Average Norm AUROC Diff: {:1.3f} at Thresh: {:1.2f}".format(norm_diffs_filt.mean(), thresh))
nd_ax.set_ylabel("Normed AUROC Difference")
nd_ax.set_xlabel("Attribute")
plt.xticks(rotation=90)
nd_ax.bar(range(norm_diffs.shape[0]), norm_diffs, tick_label=tags)
nd_ax.grid(which='both', axis='y')

In [None]:
# Raw (un-normalised) AUROC gaps between the best and runner-up latent.
# The title reports their sum over the filtered attributes, which is the
# TAD score printed at the end; the bars show every attribute.
aurs_diffs_filt = aurs_diffs[filt]
print(len(aurs_diffs_filt))

fig, ad_ax = plt.subplots(figsize=(10, 7))
ad_ax.set_ylim((0.0, 1.0))
ad_ax.set_title("Total AUROC Diff: {:1.3f} at Thresh: {:1.2f}".format(aurs_diffs_filt.sum(), thresh))
ad_ax.set_ylabel("AUROC Difference")
ad_ax.set_xlabel("Attribute")
plt.xticks(rotation=90)
ad_ax.bar(range(aurs_diffs.shape[0]), aurs_diffs, tick_label=tags)
ad_ax.grid(which='both', axis='y')

In [None]:
# Final metric: summed AUROC gap over the filtered attributes, and how many
# attributes contributed to it.
tad_score = aurs_diffs_filt.sum().item()
n_captured = len(aurs_diffs_filt)
print("TAD SCORE: ", tad_score, "Attributes Captured: ", n_captured)