<a href="https://colab.research.google.com/github/mmfara/Adversarial-Debiasing-Extended/blob/main/Adversarial_Debiasing_Extended.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Enhanced Adversarial Debiasing with Intersectionality, Validation Metrics, and Clean Dropout API

import numpy as np
import tensorflow.compat.v1 as tf
import logging

from aif360.algorithms import Transformer

# Disable TensorFlow 2.x behavior: the model below is built TF1-style, with
# placeholders, variable scopes and an explicit session, and its fit() refuses
# to run under eager execution.
tf.disable_v2_behavior()

class AdversarialDebiasing(Transformer):
    """In-processing adversarial debiasing with an intersectional adversary.

    A one-hidden-layer classifier (ReLU + dropout, sigmoid output) is trained
    to predict the binary label. When ``debias`` is true, a linear softmax
    adversary is trained at the same time to predict, from the classifier's
    output and the true label, which *combination* of protected-attribute
    values each sample has. The classifier's gradients are modified so that
    its updates do not help the adversary.

    The model is built as a TF1-style graph and executed through the
    session passed to the constructor.
    """

    def __init__(self,
                 unprivileged_groups,
                 privileged_groups,
                 scope_name,
                 sess,
                 seed=None,
                 adversary_loss_weight=0.1,
                 num_epochs=50,
                 batch_size=128,
                 classifier_num_hidden_units=200,
                 dropout=0.8,
                 early_stopping_patience=5,
                 validation_dataset=None,
                 verbose=True,
                 debias=True):
        """Store the hyper-parameters and draw the initializer/dropout seeds.

        Args:
            unprivileged_groups: Forwarded to the ``Transformer`` base class.
            privileged_groups: Forwarded to the ``Transformer`` base class.
            scope_name: Name of the TF variable scope the graph is built in.
            sess: TensorFlow session used to run training and inference.
            seed: Optional seed. When given, NumPy's *global* random state is
                seeded with it before the four TF seeds are drawn.
            adversary_loss_weight: Weight of the adversary-gradient term that
                is subtracted from the classifier gradient.
            num_epochs: Maximum number of training epochs.
            batch_size: Mini-batch size for training and prediction.
            classifier_num_hidden_units: Width of the classifier's hidden
                layer.
            dropout: Probability of *keeping* a hidden unit during training
                (the drop rate used is ``1 - dropout``). Must be in
                ``(0, 1]``.
            early_stopping_patience: Number of consecutive epochs without
                validation improvement after which training stops. Only
                used when ``validation_dataset`` is given.
            validation_dataset: Optional dataset evaluated after every epoch
                for early stopping.
            verbose: When true, configures root logging at INFO level and
                logs per-epoch losses.
            debias: When false, the adversary is not built and a plain
                classifier is trained.

        Raises:
            ValueError: If ``dropout`` is not in ``(0, 1]``.
        """
        super().__init__(
            unprivileged_groups=unprivileged_groups,
            privileged_groups=privileged_groups)

        # A keep probability of 0 means a drop rate of 1, which makes the
        # dropout op produce NaNs; fail early with a clear message instead.
        if not 0.0 < dropout <= 1.0:
            raise ValueError(
                "dropout is the keep probability and must be in (0, 1], "
                f"got {dropout!r}")

        self.scope_name = scope_name
        self.seed = seed
        self.sess = sess
        self.adversary_loss_weight = adversary_loss_weight
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.classifier_num_hidden_units = classifier_num_hidden_units
        self.dropout = dropout
        self.early_stopping_patience = early_stopping_patience
        self.validation_dataset = validation_dataset
        self.verbose = verbose
        self.debias = debias

        if self.verbose:
            logging.basicConfig(level=logging.INFO)

        if self.seed is not None:
            np.random.seed(self.seed)
        # Seeds for: W1 init, dropout mask, W2 init, adversary weight init.
        self.seed1, self.seed2, self.seed3, self.seed4 = np.random.randint(1, 9999, 4)

    @staticmethod
    def _binary_labels(dataset):
        """Return a copy of ``dataset.labels`` with favorable -> 1.0 and unfavorable -> 0.0.

        Both masks are computed from the untouched original labels, so the
        two assignments cannot interfere with each other. Labels equal to
        neither value are left unchanged.
        """
        binary = dataset.labels.copy()
        binary[(dataset.labels == dataset.favorable_label).ravel(), 0] = 1.0
        binary[(dataset.labels == dataset.unfavorable_label).ravel(), 0] = 0.0
        return binary

    def _encode_protected_attributes(self, dataset):
        """Encode each row's protected-attribute combination as a class index.

        Every distinct tuple of protected-attribute values found in
        ``dataset`` gets an integer id (ids follow the sorted order of the
        tuples). The mapping is stored in ``self.combo_to_class``, replacing
        any previous mapping.

        Args:
            dataset: Dataset exposing ``feature_names``,
                ``protected_attribute_names`` and a 2-D ``features`` array.

        Returns:
            Tuple ``(encoded, num_classes)``: an int32 array with one class
            id per row, and the number of distinct combinations.

        Raises:
            ValueError: If a protected attribute name is not among
                ``dataset.feature_names``.
        """
        indices = [dataset.feature_names.index(attr) for attr in dataset.protected_attribute_names]
        values = dataset.features[:, indices]
        tuples = [tuple(row) for row in values]
        unique_combos = sorted(set(tuples))
        self.combo_to_class = {combo: i for i, combo in enumerate(unique_combos)}
        encoded = np.array([self.combo_to_class[t] for t in tuples], dtype=np.int32)
        return encoded, len(unique_combos)

    def _classifier_model(self, features, features_dim, keep_prob):
        """Build the classifier sub-graph.

        Args:
            features: Input tensor with ``features_dim`` columns.
            features_dim: Number of input features.
            keep_prob: Scalar tensor; probability of keeping a hidden unit.

        Returns:
            Tuple ``(pred, logits)``: sigmoid probabilities and raw logits,
            each with one column.
        """
        with tf.variable_scope("classifier_model"):
            W1 = tf.get_variable('W1', [features_dim, self.classifier_num_hidden_units],
                                 initializer=tf.initializers.glorot_uniform(seed=self.seed1))
            b1 = tf.Variable(tf.zeros([self.classifier_num_hidden_units]))
            h1 = tf.nn.relu(tf.matmul(features, W1) + b1)
            # tf.nn.dropout takes the *drop* rate, hence 1 - keep_prob.
            h1 = tf.nn.dropout(h1, rate=1 - keep_prob, seed=self.seed2)

            W2 = tf.get_variable('W2', [self.classifier_num_hidden_units, 1],
                                 initializer=tf.initializers.glorot_uniform(seed=self.seed3))
            b2 = tf.Variable(tf.zeros([1]))
            logits = tf.matmul(h1, W2) + b2
            pred = tf.sigmoid(logits)
        return pred, logits

    def _adversary_model(self, pred_logits, true_labels, num_classes):
        """Build the adversary sub-graph.

        The adversary is a single linear layer followed by softmax. Its
        input is ``[s, s * y, s * (1 - y)]`` where ``s`` is the sigmoid of
        the classifier logits and ``y`` the true label.

        Args:
            pred_logits: Classifier logits (one column).
            true_labels: True binary labels (one column).
            num_classes: Number of protected-attribute combinations.

        Returns:
            Tuple ``(preds, logits)`` over the ``num_classes`` combinations.
        """
        with tf.variable_scope("adversary_model"):
            s = tf.sigmoid(pred_logits)
            input_concat = tf.concat([s, s * true_labels, s * (1.0 - true_labels)], axis=1)
            W = tf.get_variable('W_adv', [input_concat.shape[1], num_classes],
                                initializer=tf.initializers.glorot_uniform(seed=self.seed4))
            b = tf.Variable(tf.zeros([num_classes]))
            logits = tf.matmul(input_concat, W) + b
            preds = tf.nn.softmax(logits)
        return preds, logits

    def fit(self, dataset):
        """Build the graph and train the classifier (and adversary) on ``dataset``.

        After each epoch, if a validation dataset was supplied, its loss is
        computed and training stops once it has failed to improve for
        ``early_stopping_patience`` consecutive epochs.

        Args:
            dataset: Training dataset exposing ``features``, ``labels``,
                ``favorable_label``, ``unfavorable_label``,
                ``feature_names`` and ``protected_attribute_names``.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            RuntimeError: If TensorFlow is executing eagerly.
        """
        if tf.executing_eagerly():
            raise RuntimeError("AdversarialDebiasing does not work in eager execution mode.")

        temp_labels = self._binary_labels(dataset)

        prot_attr_encoded, num_classes = self._encode_protected_attributes(dataset)
        num_samples, self.features_dim = dataset.features.shape
        best_val_loss = float('inf')
        patience_counter = 0

        with tf.variable_scope(self.scope_name):
            self.features_ph = tf.placeholder(tf.float32, [None, self.features_dim])
            self.true_labels_ph = tf.placeholder(tf.float32, [None, 1])
            self.protected_ph = tf.placeholder(tf.int32, [None])
            self.keep_prob = tf.placeholder(tf.float32)

            self.pred_labels, pred_logits = self._classifier_model(
                self.features_ph, self.features_dim, self.keep_prob)
            pred_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=self.true_labels_ph, logits=pred_logits))

            classifier_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=f"{self.scope_name}/classifier_model")

            if self.debias:
                _, adv_logits = self._adversary_model(pred_logits, self.true_labels_ph, num_classes)
                adv_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=self.protected_ph, logits=adv_logits))
                adversary_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=f"{self.scope_name}/adversary_model")
                adversary_opt = tf.train.AdamOptimizer(0.001)
                # Gradient of the adversary loss w.r.t. the *classifier* variables.
                adv_grads = {
                    var: grad
                    for grad, var in adversary_opt.compute_gradients(adv_loss, var_list=classifier_vars)
                }

            classifier_opt = tf.train.AdamOptimizer(0.001)
            classifier_grads = []
            for grad, var in classifier_opt.compute_gradients(pred_loss, var_list=classifier_vars):
                if self.debias:
                    # Unit vector along the adversary gradient (epsilon avoids 0/0).
                    unit_adv = adv_grads[var] / (tf.norm(adv_grads[var]) + 1e-8)
                    # Remove the component of the classifier gradient that lies
                    # along the adversary gradient, then additionally step
                    # against the adversary gradient itself.
                    grad = (grad
                            - tf.reduce_sum(grad * unit_adv) * unit_adv
                            - self.adversary_loss_weight * adv_grads[var])
                classifier_grads.append((tf.clip_by_value(grad, -5.0, 5.0), var))

            classifier_step = classifier_opt.apply_gradients(classifier_grads)
            if self.debias:
                # The adversary update is ordered after the classifier update.
                with tf.control_dependencies([classifier_step]):
                    adversary_step = adversary_opt.minimize(adv_loss, var_list=adversary_vars)

            self.sess.run(tf.global_variables_initializer())

        for epoch in range(self.num_epochs):
            epoch_loss = 0  # sum (not mean) of the per-batch losses
            shuffled = np.random.permutation(num_samples)

            for i in range(0, num_samples, self.batch_size):
                batch = shuffled[i:i + self.batch_size]
                feed = {
                    self.features_ph: dataset.features[batch],
                    self.true_labels_ph: temp_labels[batch].reshape(-1, 1),
                    self.protected_ph: prot_attr_encoded[batch],
                    self.keep_prob: self.dropout
                }
                if self.debias:
                    _, _, l_cls, l_adv = self.sess.run(
                        [classifier_step, adversary_step, pred_loss, adv_loss], feed_dict=feed)
                    epoch_loss += l_cls + l_adv
                else:
                    _, l_cls = self.sess.run([classifier_step, pred_loss], feed_dict=feed)
                    epoch_loss += l_cls

            if self.verbose:
                logging.info(f"Epoch {epoch}: Train Loss = {epoch_loss:.4f}")

            # Explicit None check: do not rely on the truthiness of a dataset object.
            if self.validation_dataset is not None:
                val_loss = self._evaluate_loss(self.validation_dataset)
                if self.verbose:
                    logging.info(f"Validation Loss = {val_loss:.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= self.early_stopping_patience:
                        logging.info("Early stopping triggered.")
                        break

        return self

    def _evaluate_loss(self, val_dataset):
        """Return the mean squared error between predicted probabilities and binary labels.

        Dropout is disabled (keep probability 1.0) and the whole dataset is
        evaluated in a single session run.

        Args:
            val_dataset: Dataset exposing ``features``, ``labels``,
                ``favorable_label`` and ``unfavorable_label``.

        Returns:
            The mean of the squared differences, as computed by ``np.mean``.
        """
        temp_labels = self._binary_labels(val_dataset)

        # Only the inputs of the prediction tensor are fed. In particular the
        # protected attributes are NOT re-encoded here: doing so would
        # overwrite the combo_to_class mapping learned from the training data.
        feed = {
            self.features_ph: val_dataset.features,
            self.keep_prob: 1.0
        }

        preds = self.sess.run(self.pred_labels, feed_dict=feed)
        return np.mean(np.square(preds - temp_labels))

    def predict(self, dataset):
        """Score ``dataset`` with the trained classifier.

        Args:
            dataset: Dataset exposing ``features``, ``favorable_label``,
                ``unfavorable_label`` and ``copy(deepcopy=True)``.

        Returns:
            A deep copy of ``dataset`` whose ``scores`` hold the predicted
            probabilities (float64, one column) and whose ``labels`` hold
            ``favorable_label`` where the score exceeds 0.5 and
            ``unfavorable_label`` elsewhere.
        """
        num_samples = len(dataset.features)
        batch_scores = []
        for start in range(0, num_samples, self.batch_size):
            # The prediction tensor depends only on the features and the keep
            # probability, so the labels placeholder does not need feeding.
            feed_dict = {
                self.features_ph: dataset.features[start:start + self.batch_size],
                self.keep_prob: 1.0
            }
            batch_scores.append(self.sess.run(self.pred_labels, feed_dict=feed_dict)[:, 0])

        if batch_scores:
            scores = np.concatenate(batch_scores).astype(np.float64).reshape(-1, 1)
        else:
            scores = np.empty((0, 1), dtype=np.float64)

        dataset_new = dataset.copy(deepcopy=True)
        dataset_new.scores = scores
        # Map to the dataset's own label values in one step. Doing it in place
        # in two passes (1 -> favorable, then 0 -> unfavorable) would relabel
        # every row as unfavorable whenever favorable_label is 0.0.
        dataset_new.labels = np.where(
            scores > 0.5, dataset.favorable_label, dataset.unfavorable_label
        ).astype(np.float64)

        return dataset_new