In [None]:
import glob
import cv2
import h5py
import time
from IPython import display as ipythondisplay
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset


# This notebook needs a CUDA GPU. If none is visible, switch runtimes
#   using Runtime > Change Runtime Type > GPU
if not torch.cuda.is_available():
    raise ValueError("GPU is not available")

# Device that every model and batch below is moved to.
device = torch.device("cuda")
# Enable the cuDNN benchmark mode for convolution algorithm selection.
cudnn.benchmark = True

In [None]:
# Local cache folder (created on demand) that holds the downloaded dataset.
CACHE_DIR = Path.cwd() / ".cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

path_to_training_data = CACHE_DIR / "train_face.h5"

# Download the HDF5 training file only when it is not cached yet.
if not path_to_training_data.is_file():
    print(f"Downloading training data to {path_to_training_data}")
    url = "https://www.dropbox.com/s/hlz8atheyozp1yx/train_face.h5?dl=1"
    torch.hub.download_url_to_file(url, path_to_training_data)
else:
    print(f"Using cached training data from {path_to_training_data}")

In [None]:
class TrainDatasetLoader(Dataset):
    """In-memory loader for the face / non-face HDF5 training set.

    The ``images`` and ``labels`` datasets are read completely into NumPy
    arrays at construction time. Images are kept exactly as stored (0-255
    scale, channels-last); the batch accessors reverse the channel axis and
    rescale to [0, 1] on the fly.

    Args:
        data_path: Path of the HDF5 file containing ``images`` and ``labels``.
        channels_last (bool): If False, ``__getitem__`` and ``get_batch``
            return images as [C, H, W] instead of [H, W, C].
    """

    def __init__(self, data_path, channels_last=True):
        print(f"Opening {data_path}")
        # Both datasets are copied into memory by the [:] reads, so the file
        # handle can be released immediately instead of staying open.
        with h5py.File(data_path, "r") as cache:
            self.images = cache["images"][:]
            self.labels = cache["labels"][:].astype(np.float32)
        self.channels_last = channels_last
        self.image_dims = self.images.shape

        n_train_samples = self.image_dims[0]
        # Random permutation of all sample indices.
        self.train_inds = np.random.permutation(np.arange(n_train_samples))
        # Split the shuffled indices by label: 1.0 in column 0 marks a positive (face).
        self.pos_train_inds = self.train_inds[self.labels[self.train_inds, 0] == 1.0]
        self.neg_train_inds = self.train_inds[self.labels[self.train_inds, 0] != 1.0]

    def __len__(self):
        """Return the total number of samples (positives and negatives)."""
        return len(self.train_inds)

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors for the sample at ``idx``."""
        img = self.images[idx]
        label = self.labels[idx]

        # Reverse the channel axis and normalize to [0, 1].
        img = (img[:, :, ::-1] / 255.0).astype(np.float32)

        if not self.channels_last:  # convert [H, W, C] to [C, H, W]
            img = np.transpose(img, (2, 0, 1))

        return torch.tensor(img), torch.tensor(label)

    def get_train_steps_per_epoch(self, batch_size, factor=10):
        """Return ``len(self) // factor // batch_size``."""
        return len(self) // factor // batch_size

    def get_batch(self, n, only_faces=False, p_pos=None, p_neg=None, return_inds=False):
        """Sample a random batch without replacement.

        Args:
            n (int): Requested batch size. Unless ``only_faces`` is set, the
                batch holds ``n // 2`` positives and ``n // 2`` negatives.
            only_faces (bool): Draw all ``n`` samples from the positive set.
            p_pos: Optional sampling probabilities over ``pos_train_inds``.
            p_neg: Optional sampling probabilities over ``neg_train_inds``.
            return_inds (bool): Also return the (sorted) dataset indices.

        Returns:
            ``(images, labels)`` or ``(images, labels, indices)`` as NumPy
            arrays; images are float32 in [0, 1], ordered by ascending index.
        """
        if only_faces:
            selected_inds = np.random.choice(
                self.pos_train_inds, size=n, replace=False, p=p_pos)
        else:
            selected_pos_inds = np.random.choice(
                self.pos_train_inds, size=n // 2, replace=False, p=p_pos)
            selected_neg_inds = np.random.choice(
                self.neg_train_inds, size=n // 2, replace=False, p=p_neg)
            selected_inds = np.concatenate((selected_pos_inds, selected_neg_inds))

        sorted_inds = np.sort(selected_inds)
        # Same preprocessing as __getitem__: reverse channels, scale to [0, 1].
        train_img = (self.images[sorted_inds, :, :, ::-1] / 255.0).astype(np.float32)
        train_label = self.labels[sorted_inds, ...]

        if not self.channels_last:
            # [N, H, W, C] -> [N, C, H, W], made contiguous for torch.from_numpy.
            train_img = np.ascontiguousarray(
                np.transpose(train_img, (0, 3, 1, 2)))
        return (
            (train_img, train_label, sorted_inds)
            if return_inds
            else (train_img, train_label))

    def get_n_most_prob_faces(self, prob, n):
        """
        From the positive training set, sort by probability, look at the top 10n,
        take every 10th one to get n images, normalize them, and return them.

        ``prob`` is indexed like ``pos_train_inds``. The returned images are
        scaled to [0, 1] but, unlike ``get_batch``, keep the stored channel
        order and channels-last layout.
        """
        idx = np.argsort(prob)[::-1]
        most_prob_inds = self.pos_train_inds[idx[: 10 * n : 10]]
        return (self.images[most_prob_inds, ...] / 255.0).astype(np.float32)

    def get_all_train_faces(self):
        """Return every positive image exactly as stored (no normalization)."""
        return self.images[self.pos_train_inds]

# channels_last=False makes get_batch() return images as [N, C, H, W].
loader = TrainDatasetLoader(path_to_training_data, channels_last=False)


In [None]:
### Examining the CelebA training dataset ###

# @title Change the sliders to look at positive and negative training examples! { run: "auto" }
number_of_training_examples = len(loader)

# Draw a sample batch; its shape also provides the image dimensions
# (B, C, H, W) that later cells read.
images, labels = loader.get_batch(100)
B, C, H, W = images.shape
print(B, C, H, W)

# Split the batch by label and move channels last for imshow:
# (N, C, H, W)  ->  (N, H, W, C)
face_rows = np.where(labels == 1)[0]
not_face_rows = np.where(labels == 0)[0]
face_images = images[face_rows].transpose(0, 2, 3, 1)
not_face_images = images[not_face_rows].transpose(0, 2, 3, 1)

# Pick four distinct examples of each class.
face_indices = np.random.choice(len(face_images), 4, replace=False)
not_face_indices = np.random.choice(len(not_face_images), 4, replace=False)

# 2x4 grid: faces on the first row, non-faces on the second.
fig, axes = plt.subplots(2, 4, figsize=(8, 8))

for ax, idx in zip(axes[0], face_indices):
    ax.imshow(face_images[idx])
    ax.set_title(f"Face #{idx}")
    ax.grid(False)
    ax.axis('off')

for ax, idx in zip(axes[1], not_face_indices):
    ax.imshow(not_face_images[idx])
    ax.set_title(f"Not Face #{idx}")
    ax.grid(False)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Base number of convolutional filters; deeper blocks use multiples of it.
n_filters = 12
# Channel count taken from the sample batch drawn in the previous cell.
in_channels = images.shape[1]


def make_standard_classifier(n_outputs):
    """Standard CNN classifier.

    Four stride-2 Conv -> BatchNorm -> ReLU blocks followed by two linear
    layers. The network is sized from the notebook globals ``in_channels``,
    ``n_filters``, ``H`` and ``W`` and is moved to ``device``.

    Args:
        n_outputs (int): Size of the final linear layer.
    Returns:
        nn.Sequential: The classifier, already on ``device``.
    """

    class ConvBlock(nn.Module):
        """Conv2d followed by BatchNorm2d and an in-place ReLU."""

        def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x

    def downsampled(size, n_blocks=4):
        # Every block below uses stride 2 with padding == kernel_size // 2,
        # so one block maps a spatial size s to (s - 1) // 2 + 1.
        for _ in range(n_blocks):
            size = (size - 1) // 2 + 1
        return size

    # The previous expression `H // 16 * W // 16 * ...` was evaluated as
    # ((H // 16) * W) // 16, which only matches the real feature-map size for
    # some inputs (e.g. 64x64). Compute both spatial sizes explicitly instead.
    flat_features = downsampled(H) * downsampled(W) * 8 * n_filters

    model = nn.Sequential(
        ConvBlock(in_channels, n_filters, kernel_size=5, stride=2, padding=2),
        ConvBlock(n_filters, 2*n_filters, kernel_size=5, stride=2, padding=2),
        ConvBlock(2*n_filters, 4*n_filters, kernel_size=3, stride=2, padding=1),
        ConvBlock(4*n_filters, 8*n_filters, kernel_size=3, stride=2, padding=1),
        nn.Flatten(),
        nn.Linear(flat_features, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, n_outputs),
        )
    return model.to(device)


# Single-logit face / not-face classifier used as the baseline.
standard_classifier = make_standard_classifier(n_outputs=1)
print(standard_classifier)

In [None]:
class LossHistory:
    """Running list of loss values with optional exponential smoothing.

    Args:
        smoothing_factor (float): Weight given to the previous smoothed value;
            0.0 stores every value unchanged.
    """

    def __init__(self, smoothing_factor=0.0):
        self.alpha = smoothing_factor
        self.loss = []

    def append(self, value):
        """Record ``value``, blended with the last stored entry if there is one."""
        if not self.loss:
            # The first value seeds the history as-is.
            smoothed = value
        else:
            smoothed = self.alpha * self.loss[-1] + (1 - self.alpha) * value
        self.loss.append(smoothed)

    def get(self):
        """Return the underlying list of smoothed values."""
        return self.loss

class PeriodicPlotter:
    """Redraws a live plot at most once every ``sec`` seconds.

    Args:
        sec: Minimum number of seconds between two redraws.
        xlabel (str): Label of the x axis.
        ylabel (str): Label of the y axis.
        scale: None for linear axes, or one of "semilogx", "semilogy", "loglog".
    """

    def __init__(self, sec, xlabel="", ylabel="", scale=None):
        self.sec = sec
        self.scale = scale
        self.xlabel = xlabel
        self.ylabel = ylabel

        # Time of the last redraw (initially: construction time).
        self.tic = time.time()

    def plot(self, data):
        """Redraw ``data`` if more than ``sec`` seconds passed since the last redraw.

        Raises:
            ValueError: If ``scale`` is not a recognized option (only checked
                when a redraw is actually due).
        """
        elapsed = time.time() - self.tic
        if not elapsed > self.sec:
            return

        plt.cla()

        if self.scale is None:
            draw = plt.plot
        elif self.scale in ("semilogx", "semilogy", "loglog"):
            # The option names match the pyplot function names.
            draw = getattr(plt, self.scale)
        else:
            raise ValueError("unrecognized parameter scale {}".format(self.scale))
        draw(data)

        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        # Replace the previous cell output with the refreshed figure.
        ipythondisplay.clear_output(wait=True)
        ipythondisplay.display(plt.gcf())

        self.tic = time.time()

In [None]:
### Train the standard CNN ###

# Training hyperparameters
params = {
    "batch_size": 32,
    "num_epochs": 2,  # keep small to run faster
    "learning_rate": 5e-4,
}

# Binary cross-entropy computed directly on the classifier's raw logits.
loss_fn = nn.BCEWithLogitsLoss()

# define our optimizer
optimizer = optim.Adam(
    standard_classifier.parameters(), lr=params["learning_rate"]
)

# Clear tqdm's instance registry if it exists.
if hasattr(tqdm, "_instances"):
    tqdm._instances.clear()

# set the model to train mode
standard_classifier.train()


def standard_train_step(x, y):
    """Run one optimization step of ``standard_classifier`` on a NumPy batch.

    Uses the notebook-level ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        x (np.ndarray): Batch of input images.
        y (np.ndarray): Matching labels.
    Returns:
        torch.Tensor: Loss of this batch (still attached to the autograd graph).
    """
    inputs = torch.from_numpy(x).float().to(device)
    targets = torch.from_numpy(y).float().to(device)

    # Reset gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss on the raw logits.
    loss = loss_fn(standard_classifier(inputs), targets)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    return loss


# The training loop!
step = 0
loss_history = LossHistory(smoothing_factor=0.99)
plotter = PeriodicPlotter(sec=2, scale="semilogy")

# Number of random batches drawn per epoch.
steps_per_epoch = len(loader) // params["batch_size"]

for epoch in range(params["num_epochs"]):
    for idx in tqdm(range(steps_per_epoch)):
        # Draw a random batch and take one optimization step on it.
        x, y = loader.get_batch(params["batch_size"])
        loss = standard_train_step(x, y)

        # Track the smoothed loss and refresh the live plot when due.
        loss_value = loss.detach().cpu().numpy()
        loss_history.append(loss_value)
        plotter.plot(loss_history.get())
        step += 1

In [None]:
# Switch the model to evaluation mode before scoring it.
standard_classifier.eval()

# Evaluate on a subset of CelebA+Imagenet
batch_x, batch_y = loader.get_batch(5000)
batch_x = torch.from_numpy(batch_x).float().to(device)
batch_y = torch.from_numpy(batch_y).float().to(device)

with torch.inference_mode():
    y_pred_logits = standard_classifier(batch_x)
    # Sigmoid probabilities rounded to the nearest class label.
    y_pred_std = torch.sigmoid(y_pred_logits).round()
    # Fraction of predictions that match the labels.
    acc_std = (batch_y == y_pred_std).float().mean()

print("Standard CNN acuracy training set (biased): {:.4f}".format(acc_std.item()))

In [None]:
def get_test_faces(channels_last = True):
    """Load the four test-face groups from ``./data/faces/<group>/*.png``.

    Every image is read with OpenCV, resized to 64x64, channel-reversed and
    scaled to [0, 1]. Files are processed in sorted filename order.

    Args:
        channels_last (bool): If False, images are returned as [C, H, W].
    Returns:
        tuple: Four lists of images for the "LF", "LM", "DF" and "DM" groups,
        in that order.
    """
    groups = {}
    for key in ("LF", "LM", "DF", "DM"):
        pattern = Path.cwd() / "data" / "faces" / key / "*.png"
        group = []
        for file in sorted(glob.glob(str(pattern))):
            resized = cv2.resize(cv2.imread(file), (64, 64))
            # Reverse the channel axis and normalize to [0, 1].
            image = resized[:, :, ::-1] / 255.0
            if not channels_last:
                image = np.transpose(image, (2, 0, 1))
            group.append(image)
        groups[key] = group
    return groups["LF"], groups["LM"], groups["DF"], groups["DM"]

In [None]:
# Four groups of test images, each image in [C, H, W] layout.
test_faces = get_test_faces(channels_last=False)
keys = ["Light Female", "Light Male", "Dark Female", "Dark Male"]

# One panel per group, showing all of its faces as a single strip.
fig, axs = plt.subplots(1, len(keys), figsize=(7.5, 7.5))
for ax, group, key in zip(axs, test_faces, keys):
    # Concatenate the [C, H, W] images along H, then move channels last for imshow.
    strip = np.hstack(group).transpose(1,2,0)
    ax.imshow(strip)
    ax.set_title(key, fontsize=15)
    ax.axis("off")

In [None]:
# Face probability predicted by the standard classifier, one array per group.
std_classfier_prob_list = []

with torch.inference_mode():
    for x in test_faces:
        x = torch.from_numpy(np.array(x, dtype=np.float32)).to(device)
        logits = standard_classifier(x)
        # Sigmoid probabilities with the trailing singleton axis removed.
        probs = torch.sigmoid(logits).squeeze(dim=-1)
        std_classfier_prob_list.append(probs.cpu().numpy())

# Stack into one row per group (requires equally sized groups).
std_classfier_probs = np.stack(std_classfier_prob_list, axis=0)

# Bar chart of the mean predicted probability per group.
x_keys = range(len(keys))
y_prob = std_classfier_probs.mean(axis=1)
plt.bar(x_keys, y_prob)
plt.xticks(x_keys, keys)
plt.title("Standard Classifier predictions")

In [None]:
def vae_loss_function(x, x_recon, mu, logsima, kl_weight=0.0005):
    """
    VAE loss: weighted KL divergence plus L1 reconstruction error.

    The KL term is summed over every element of ``mu`` / ``logsima`` (batch
    included), while the reconstruction term is the mean absolute error over
    every element of the images.

    Args:
        x (torch.Tensor): Original input images.
        x_recon (torch.Tensor): Reconstructed images from the VAE.
        mu (torch.Tensor): Mean of the latent variable distribution.
        logsima (torch.Tensor): Log variance of the latent variable distribution.
        kl_weight (float): Weight for the KL divergence loss term.
    Returns:
        torch.Tensor: Scalar total loss (kl_weight * KL + reconstruction).
    """
    # Element-wise KL divergence to a standard normal, up to the 0.5 factor.
    kl_terms = logsima.exp() + mu.pow(2) - 1.0 - logsima
    latent_loss = 0.5 * kl_terms.sum()

    reconstruction_loss = (x - x_recon).abs().mean()

    return kl_weight * latent_loss + reconstruction_loss

In [None]:
def sampling_reparameterization(mu, logsima):
    """
    Reparameterization trick: draw z ~ N(mu, sigma^2) as mu + sigma * eps
    with eps ~ N(0, 1), so the sample stays differentiable w.r.t. the inputs.

    Args:
        mu (torch.Tensor): Mean of the latent variable distribution.
        logsima (torch.Tensor): Log variance of the latent variable distribution.
    Returns:
        torch.Tensor: Sampled latent variable, same shape as ``mu``.
    """
    eps = torch.randn_like(mu)
    # The input is the log *variance*, so the standard deviation is
    # exp(0.5 * log_var). Using exp(log_var) would scale the noise by the
    # variance and disagree with the KL term in vae_loss_function.
    sigma = torch.exp(0.5 * logsima)
    return mu + sigma * eps

In [None]:
def debiasing_loss_function(x, x_pred, y, y_logits, mu, logsigma):
    """
    DB-VAE loss function.

    Every sample contributes its classification loss; the VAE loss
    (reconstruction + KL) is computed on the face samples only (label 1.0),
    because non-face images are not supposed to shape the latent space.

    Args:
        x (torch.Tensor): Original input images.
        x_pred (torch.Tensor): Reconstructed images from the VAE.
        y (torch.Tensor): True labels for the input images, one per sample.
        y_logits (torch.Tensor): Predicted logits from the classifier.
        mu (torch.Tensor): Mean of the latent variable distribution.
        logsigma (torch.Tensor): Log variance of the latent variable distribution.
    Returns:
        torch.Tensor: Scalar total loss (classification + VAE loss for faces).
        torch.Tensor: Per-sample (unreduced) classification loss.
    """
    y = y.float()
    classification_loss = F.binary_cross_entropy_with_logits(y_logits, y, reduction='none')

    # One boolean per sample: which training images are faces.
    face_mask = (y == 1.0).reshape(-1)

    if face_mask.any():
        # vae_loss_function reduces over the whole batch it is given, so the
        # face selection has to happen on its inputs, not on its output.
        vae_loss = vae_loss_function(
            x[face_mask], x_pred[face_mask], mu[face_mask], logsigma[face_mask], 0.0005)
        face_fraction = face_mask.float().mean()
    else:
        # No faces in this batch: the VAE term vanishes (and an empty mean
        # would otherwise produce NaN).
        vae_loss = classification_loss.new_zeros(())
        face_fraction = classification_loss.new_zeros(())

    # Equals the batch mean of `classification_loss + face_indicator * vae_loss`.
    # The mask previously sat on the classification term instead, so non-faces
    # never trained the classifier while still driving the VAE loss.
    total_loss = torch.mean(classification_loss) + face_fraction * vae_loss

    return total_loss, classification_loss

In [None]:
def make_face_decoder_network(latent_dim=128, n_filters=12):
    """
    Build the decoder half of the VAE.

    A linear layer expands the latent vector to an
    ``(8 * n_filters, 4, 4)`` feature map, which four stride-2 transposed
    convolutions upsample to a 3-channel 64x64 output.

    Args:
        latent_dim (int): Dimension of the latent space.
        n_filters (int): Number of filters in the convolutional layers.
    Returns:
        FaceDecoder (nn.Module): Decoder network model.
    """
    class FaceDecoder(nn.Module):
        def __init__(self, latent_dim, n_filters):
            super().__init__()
            self.latent_dim = latent_dim
            self.n_filters = n_filters

            # Latent vector -> flattened (8 * n_filters, 4, 4) feature map.
            self.linear = nn.Sequential(
                nn.Linear(latent_dim, 8 * n_filters * 4 * 4),
                nn.ReLU(),
            )

            # (in_channels, out_channels, kernel_size, padding) per layer;
            # each layer doubles the spatial size.
            layer_specs = [
                (8 * n_filters, 4 * n_filters, 3, 1),
                (4 * n_filters, 2 * n_filters, 3, 1),
                (2 * n_filters, n_filters, 5, 2),
                (n_filters, 3, 5, 2),
            ]
            layers = []
            for c_in, c_out, kernel, pad in layer_specs:
                layers.append(
                    nn.ConvTranspose2d(
                        in_channels=c_in,
                        out_channels=c_out,
                        kernel_size=kernel,
                        stride=2,
                        padding=pad,
                        output_padding=1,
                    )
                )
                layers.append(nn.ReLU())
            # The last transposed convolution has no activation after it.
            layers.pop()
            self.deconv = nn.Sequential(*layers)

        def forward(self, z):
            feature_map = self.linear(z).view(-1, 8*self.n_filters, 4, 4)
            return self.deconv(feature_map)

    return FaceDecoder(latent_dim, n_filters)

In [None]:
class DB_VAE(nn.Module):
    """
    Debiasing Variational Autoencoder.

    The encoder is the standard classifier with ``2 * latent_dim + 1``
    outputs: one classification logit followed by the latent means and the
    latent log-variances. The decoder reconstructs an image from a latent sample.

    Args:
        latent_dim (int): Dimension of the latent space.
    """
    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = make_standard_classifier(n_outputs=2*latent_dim+1)
        self.decoder = make_face_decoder_network(latent_dim=latent_dim)

    def encode(self, x):
        """Return ``(y_logit, z_mu, z_logsigma)`` for a batch of images."""
        encoder_out = self.encoder(x)
        split_at = self.latent_dim + 1

        # Column 0 is the class logit (kept as an [N, 1] column), the next
        # latent_dim columns are the means, the remainder the log-variances.
        y_logit = encoder_out[:, :1]
        z_mu = encoder_out[:, 1:split_at]
        z_logsigma = encoder_out[:, split_at:]
        return y_logit, z_mu, z_logsigma

    def reparameterize(self, z_mu, z_logsigma):
        """Sample a latent vector from the encoded distribution."""
        return sampling_reparameterization(z_mu, z_logsigma)

    def decode(self, z):
        """Reconstruct images from latent vectors."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(y_logit, z_mu, z_logsigma, reconstruction)`` for ``x``."""
        y_logit, z_mu, z_logsigma = self.encode(x)
        recon = self.decode(self.reparameterize(z_mu, z_logsigma))
        return y_logit, z_mu, z_logsigma, recon

    def predict(self, x):
        """Return only the classification logit for ``x``."""
        return self.encode(x)[0]


In [None]:
def get_latent_mu(images, dbvae, batch_size=None):
    """
    Get the latent mean vectors for a set of images using the DB-VAE model.

    ``images`` are expected exactly as stored in the dataset (e.g. the output
    of ``loader.get_all_train_faces()``): channels-last, 0-255 scale. They are
    preprocessed here the same way as the training batches (channel axis
    reversed, scaled to [0, 1], converted to [N, C, H, W]).

    Note: this switches ``dbvae`` to eval mode and leaves it there.

    Args:
        images (np.ndarray): Raw input images, shape [N, H, W, C].
        dbvae (DB_VAE): Debiased VAE model.
        batch_size (int, optional): Images encoded per forward pass; defaults
            to the notebook-level ``params["batch_size"]``.
    Returns:
        mu (np.ndarray): Latent mean vectors for the input images.
    """
    if batch_size is None:
        batch_size = params["batch_size"]

    dbvae.eval()

    # Run on whatever device the model lives on.
    first_param = next(dbvae.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    all_z_mean = []
    with torch.inference_mode():
        for start in range(0, len(images), batch_size):
            # Convert batch by batch so the whole dataset is never held in
            # memory as float32 at once.
            batch = torch.from_numpy(images[start:start + batch_size]).to(model_device)
            # Match the training preprocessing in TrainDatasetLoader.get_batch:
            # reverse channels, NHWC -> NCHW, scale to [0, 1]. Feeding raw
            # 0-255 values would put the inputs far outside the range the
            # encoder was trained on.
            batch = batch.flip(-1).permute(0, 3, 1, 2).float() / 255.0
            # Only the encoder is needed; skip sampling and decoding.
            _, z_mean, _ = dbvae.encode(batch)
            all_z_mean.append(z_mean.detach().cpu())

    print("Number of batches:", len(all_z_mean))
    z_mean_full = torch.cat(all_z_mean, dim=0)
    mu = z_mean_full.numpy()
    return mu

In [None]:
def get_training_sample_prob(images, dbvae, bins=10, smoothing_fac=0.001):
    """
    Compute sampling probabilities that favor images which are rare in the
    latent space.

    For every latent dimension a histogram of the latent means is built; each
    image receives a weight inversely proportional to the (smoothed) density
    of its bin. An image's final weight is its maximum over all latent
    dimensions, and the weights are normalized to sum to 1.

    Args:
        images (np.ndarray): Input images sample.
        dbvae (DB_VAE): Debiased VAE model.
        bins (int): Number of histogram intervals used to estimate the latent
            variable distribution along each dimension.
        smoothing_fac (float): Small value added to every bin density so that
            no bin has zero density.
    Returns:
        training_sample_p (np.ndarray): Sample probabilities where rare samples
        in the distribution have a high probability.
    """
    print("Recomputing the sample probabilities")

    mu = get_latent_mu(images, dbvae)
    training_sample_p = np.zeros(mu.shape[0], dtype=np.float64)

    for dim in range(dbvae.latent_dim):
        latent_values = mu[:, dim]

        density, edges = np.histogram(latent_values, density=True, bins=bins)

        # Open the outermost bins so every value falls into bins 1..bins.
        edges[0] = -float("inf")
        edges[-1] = float("inf")
        bin_of_sample = np.digitize(latent_values, edges)

        # Smooth and renormalize the per-bin density.
        smoothed = density + smoothing_fac
        smoothed = smoothed / np.sum(smoothed)

        # Inverse density of each sample's bin, normalized over the samples.
        inv_density = 1.0 / (smoothed[bin_of_sample - 1])
        inv_density = inv_density / np.sum(inv_density)

        # Keep, per sample, the largest weight seen over the latent dimensions.
        training_sample_p = np.maximum(training_sample_p, inv_density)

    training_sample_p /= np.sum(training_sample_p)
    return training_sample_p

In [None]:
# Hyperparameters of the DB-VAE training run.
params = dict(
    batch_size=32,
    learning_rate=5e-4,
    latent_dim=144,
    num_epochs=2,
)

dbvae = DB_VAE(params["latent_dim"]).to(device)
optimizer = optim.Adam(dbvae.parameters(), lr=params["learning_rate"])

def debiasing_train_step(x, y):
    """Run one optimization step of the DB-VAE.

    Args:
        x (torch.Tensor): Batch of images, already on ``device``.
        y (torch.Tensor): Matching labels, already on ``device``.
    Returns:
        torch.Tensor: Total debiasing loss of this batch.
    """
    # get_training_sample_prob() (via get_latent_mu) leaves the model in eval
    # mode. Without switching back, BatchNorm would run in eval mode for every
    # training step and never update its running statistics.
    dbvae.train()
    optimizer.zero_grad()

    y_logit, z_mean, z_logsigma, x_recon = dbvae(x)

    loss, _ = debiasing_loss_function(x, x_recon, y, y_logit, z_mean, z_logsigma)
    loss.backward()
    optimizer.step()
    return loss

# Raw face images used to estimate the latent distribution each epoch.
all_faces = loader.get_all_train_faces()

step=0
for i in range(params["num_epochs"]):
    print("Starting epoch {}/{}".format(i+1, params["num_epochs"]))
    # Recompute the sampling probabilities of the faces with the current model.
    p_faces = get_training_sample_prob(all_faces, dbvae)

    for j in tqdm(range(len(loader)//params["batch_size"])):
        # Faces are drawn according to p_faces, negatives uniformly.
        (x, y) = loader.get_batch(params["batch_size"], p_pos=p_faces)
        x = torch.from_numpy(x).float().to(device)
        y = torch.from_numpy(y).float().to(device)

        loss = debiasing_train_step(x, y)
        loss_value = loss.detach().cpu().numpy()

        step += 1

# Training is done: leave the model in eval mode for the inference cells below.
dbvae.eval()

In [None]:
dbvae.to(device)
# Set eval mode explicitly so BatchNorm uses its running statistics here,
# instead of relying on whichever mode an earlier cell left the model in.
dbvae.eval()

# Classification logits of the DB-VAE, one array per test group.
dbvae_logits_list = []
with torch.inference_mode():
    for face in test_faces:
        face = torch.from_numpy(np.array(face, dtype=np.float32)).to(device)
        logits = dbvae.predict(face)
        dbvae_logits_list.append(logits.detach().cpu().numpy())

# Concatenate all groups, then turn the logits into probabilities.
dbvae_logits_array = np.concatenate(dbvae_logits_list, axis=0)
dbvae_logits_tensor = torch.from_numpy(dbvae_logits_array)
dbvae_probs_tensor = torch.sigmoid(dbvae_logits_tensor)
dbvae_probs_array = dbvae_probs_tensor.squeeze(dim=-1).numpy()

# Mean predicted face probability per group for both models.
xx = np.arange(len(keys))
std_probs_mean = std_classfier_probs.mean(axis=1)
# One row per group (groups are equally sized, as the stacking of
# std_classfier_probs above already requires).
dbvae_probs_mean = dbvae_probs_array.reshape(len(keys), -1).mean(axis=1)

# Side-by-side bars: standard CNN vs. DB-VAE.
plt.bar(xx, std_probs_mean, width=0.2, label="Standard CNN")
plt.bar(xx + 0.2, dbvae_probs_mean, width=0.2, label="DB-VAE")

plt.xticks(xx, keys)
plt.title("Network predictions on test dataset")
plt.ylabel("Probability")
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.show()