Main reference: https://github.com/choprashweta/Adversarial-Debiasing

They also referenced this blog post: https://godatadriven.com/blog/towards-fairness-in-ml-with-adversarial-networks/

And I guess both may be based on the 2018 paper: https://arxiv.org/pdf/1801.07593.pdf

In [None]:
# Install HuggingFace transformers library to match baseline

!pip install transformers[sentencepiece]==4.18.0 datasets --q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.0/4.0 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m469.0/469.0 KB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 KB[0m [31m42.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 KB[0m [31m9.7 MB/s[0m et

In [None]:
import pandas as pd
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, recall_score, precision_score, confusion_matrix
import warnings

warnings.filterwarnings('ignore')

In [None]:
# Specify GPU device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

Device: cuda:0


In [None]:
from google.colab import drive

# Connect to Google Drive
drive.mount('/content/drive')

# Read subset (100 items) csv file from Google Drive
# data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/subset_baseline_data.csv')

# Training data will be split into 75% train and 25% validation

# Gender as protected attribute
# train_data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/train_baseline_data.csv')
# test_data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/test_baseline_data.csv')

# Race as protected attribute
# train_data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/train_race.csv')
# test_data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/test_race.csv')

# Religion as protected attribute
train_data = pd.read_csv('/content/drive/My Drive/train_religion.csv')
test_data = pd.read_csv('/content/drive/My Drive/test_religion.csv')

# train_data.head()

Mounted at /content/drive


In [None]:
# test_data.head()

In [None]:
# For the race and religion datasets
mapping = {"Unnamed: 0": "unnamed", "target": "target", "comment_text": "comment"}
train_data = train_data.rename(columns=mapping)
test_data = test_data.rename(columns=mapping)

In [None]:
train_data.head()

Unnamed: 0,unnamed,target,comment,muslim,christian,religion
0,185,0,"Should they have to? \n\nChristians, whether a...",0,1,0
1,847,0,The Militia holding a anti-Government stand-of...,0,1,0
2,933,0,In regard to the WW review of the Glenn Beck a...,0,1,0
3,1073,0,When you can figure out that turbans are worn ...,1,0,1
4,1117,0,What is uniquely and exclusively possible in a...,0,1,0


In [None]:
test_data.head()

Unnamed: 0,unnamed,target,comment,muslim,christian,religion
0,1241637,0,"Last I checked, the Catholic Church was open t...",0,1,0
1,1241662,0,Strange you would present a misrepresentation ...,1,0,1
2,1241953,0,Thank you Matt for placing yjin117 heartfelt p...,0,1,0
3,1241956,0,Most Christians I've known (not all) aren't th...,0,1,0
4,1242089,0,Cardinal Cupich has made the point that respec...,0,1,0


In [None]:
# PyTorch Dataset class for our data

class ToxicityDataset(Dataset):
    def __init__(self, all_data, split="train", tokenizer=None):
        self.raw_data = None
        self.encodings = None
        self.target_labels = None
        self.protected_labels = None
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') if tokenizer is None else tokenizer

        # do data cleaning, tokenizing, make sure all are tensors

        # train set: 75% train, 25% val, test set: 100% test
        split_index = round(0.75*len(all_data))
        if split == "test":
            # No modifications, use all of the test data
            self.raw_data = all_data
        elif split == "train":
            # 75% of train data for training
            self.raw_data = pd.DataFrame(all_data.loc[:split_index, :]).reset_index()
        else:
            # 25% of train data for validation
            self.raw_data = pd.DataFrame(all_data.loc[split_index:, :]).reset_index()

        # drop empty comments
        self.raw_data['comment'].dropna(inplace=True)

        # preprocess comments
        self.raw_data.loc[:, 'comment'] = self.raw_data['comment'].apply(self._preprocess_comments)

        # tokenize using BERT
        comments = list(self.raw_data['comment'].copy())
        self.encodings = tokenizer(comments, truncation=True, max_length=128, padding="max_length")['input_ids']  # maybe we want everything and not just input_ids?

        # convert targets from a range to binary classes
        self.target_labels = self.raw_data['target'].apply(lambda x: 1 if x >= 0.5 else 0)

        # currently the only protected attribute is "female", but for intersectional
        # bias we'd probably want multiple protected attributes
        # self.protected_labels = self.raw_data['female']
        # self.protected_labels = self.raw_data['race']
        self.protected_labels = self.raw_data['religion']


    def _preprocess_comments(self, comment):
        new_tokens = []
        for token in comment.split(" "):
            # replace usernames with something generic
            token = '@user' if token.startswith('@') and len(token) > 1 else token

            # replace URLs with something generic
            token = 'http' if token.startswith('http') else token

            new_tokens.append(token)
        return " ".join(new_tokens)


    def __getitem__(self, index):
        return (
            torch.tensor(self.encodings[index]),
            torch.tensor(self.target_labels[index]),
            torch.tensor(self.protected_labels[index])
        )

    def __len__(self):
        return len(self.target_labels)


tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
train_dataset = ToxicityDataset(train_data, split="train", tokenizer=tokenizer)
val_dataset = ToxicityDataset(train_data, split="val", tokenizer=tokenizer)
test_dataset = ToxicityDataset(test_data, split="test", tokenizer=tokenizer)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)


Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [None]:
print("Train dataset has length:", len(train_dataset))
print("Val dataset has length:", len(val_dataset))

print("One item from train dataset:", train_dataset[0][0].shape)
print("Another item from train dataset:", train_dataset[1][0].shape)
print("Another item from train dataset:", train_dataset[2][0].shape)

print("Decoding one item from train dataset:", tokenizer.decode(train_dataset[0][0]))
print("Decoding one item from val dataset:", tokenizer.decode(val_dataset[0][0]))

Train dataset has length: 30001
Val dataset has length: 10000
One item from train dataset: torch.Size([128])
Another item from train dataset: torch.Size([128])
Another item from train dataset: torch.Size([128])
Decoding one item from train dataset: [CLS] should they have to? christians, whether as a whole, or split by denomination would never submit to such a requirement. if a religion, any religion, had to offer an apology, or disclaimer, every time a criminal invoked their god they would never have time for anything else. now that i say it, perhaps that would be a good thing after all. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Decoding one item from val dataset: [CLS] really?! then why not be among the first to s

In [None]:
BertModel = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')

# Toxicity predictor is pretrained BERT + two linear layers + one dropout layer
class Classifier(nn.Module):
    def __init__(self, num_output_classes = 2):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        self.linear1 = nn.Linear(768, 128)
        self.linear2 = nn.Linear(128, num_output_classes)

        # initialize linear layer (not sure this is necessary)
        nn.init.xavier_normal_(self.linear1.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        output = self.bert(input_ids, token_type_ids, attention_mask)
        # print(output.pooler_output)
        pooled_output = self.dropout(output.pooler_output)
        classifier_prev_output = F.relu(self.linear1(pooled_output))
        classifier_output = self.linear2(classifier_prev_output)

        return classifier_output, classifier_prev_output

# Protected class predictor is just two linear layers
class Adversary(nn.Module):
    def __init__(self, identity_labels = 2):
        super().__init__()

        self.linear1 = nn.Linear(128,64)
        self.linear2 = nn.Linear(64, identity_labels)

        nn.init.xavier_normal_(self.linear1.weight)

    def forward(self, input_ids):
        l1_output = self.linear1(input_ids.to(torch.float))
        relu_output = F.relu(l1_output)
        adversary_output = self.linear2(relu_output)
        return adversary_output

Downloading: "https://github.com/huggingface/pytorch-transformers/zipball/main" to /root/.cache/torch/hub/main.zip


Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Pretrain the classifier on dataset to get initial toxicity predictions

def pretrain_classifier(clf, optimizer, train_loader, loss_criterion, epochs):
  clf.train()
  for epoch in range(epochs):
      epoch_loss = 0

      # data should be a tuple of (x, y, z)
      for (input, target_label, _) in train_loader:
          input = input.to(device)
          target_label = target_label.to(device)

          optimizer.zero_grad()

          classifier_output, _ = clf(input)
          classifier_loss = loss_criterion(classifier_output, target_label) # compute loss
          classifier_loss.backward() # back prop
          optimizer.step()
          epoch_loss += classifier_loss.item()

      print("\nEpoch", epoch+1, "loss:", epoch_loss)

  return clf


# Pretrain the adversary on dataset to get initial protected label predictions

def pretrain_adversary(clf, optimizer, train_loader, loss_criterion, epochs):
  clf.train()
  for epoch in range(epochs):
      epoch_loss = 0

      # data should be a tuple of (x, y, z)
      for (input, _, protected_label) in train_loader:
          input = input.to(device)
          protected_label = protected_label.to(device)

          optimizer.zero_grad()

          classifier_output = clf(input)
          classifier_loss = loss_criterion(classifier_output, protected_label) # compute loss
          classifier_loss.backward() # back prop
          optimizer.step()
          epoch_loss += classifier_loss.item()

      print("\nEpoch", epoch+1, "loss:", epoch_loss)

  return clf


In [None]:
loss_criterion = torch.nn.CrossEntropyLoss()
lrlast = .001
lrmain = .00001

adv = Adversary(identity_labels = 2)
adv.to(device)

optimizer_adv = optim.Adam(adv.parameters(), lr=lrlast)

adv = pretrain_adversary(adv, optimizer_adv, train_loader, loss_criterion, 3)


Epoch 1 loss: 65511.29083061218

Epoch 2 loss: 2848.8182995319366

Epoch 3 loss: 1885.9489414691925


In [None]:
clf = Classifier(num_output_classes = 2)
clf.to(device)

optimizer_clf = optim.Adam(
    [
      {"params":clf.bert.parameters(),"lr": lrmain},
      {"params":clf.linear1.parameters(), "lr": lrlast},
      {"params":clf.linear2.parameters(), "lr": lrlast}
    ]
)

clf = pretrain_classifier(clf, optimizer_clf, train_loader, loss_criterion, 3)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Epoch 1 loss: 140.6209488734603

Epoch 2 loss: 106.79070044308901

Epoch 3 loss: 90.03455719351768


In [None]:
# Get typical validation metrics

def evaluate_clf(clf, data_loader):
    all_predictions = np.empty((0,))
    all_true_targets = np.empty((0,))
    all_true_protected = np.empty((0,))

    num_batches = 0
    batch_metrics = {
        "accuracy": 0,
        "precision": 0,
        "recall": 0,
        "f1": 0
    }

    # Make sure there's no backprop during evaluation
    clf.eval()
    with torch.no_grad():
        for i, (input, target_label, protected_label) in enumerate(data_loader):
            input = input.to(device)
            target_label = target_label.to(device)

            output, _ = clf(input)
            _, prediction = torch.max(output, 1)

            target_label = target_label.to("cpu")
            prediction = prediction.to("cpu")

            all_predictions = np.concatenate((all_predictions, prediction.numpy()))
            all_true_targets = np.concatenate((all_true_targets, target_label.numpy()))
            all_true_protected = np.concatenate((all_true_protected, protected_label))

    # Macro: Calculate metrics for each label, and find their unweighted mean.
    # This does not take label imbalance into account.

    # Calculate metrics for each label, and find their average weighted by support
    # (the number of true instances for each label). This alters ‘macro’ to account
    # for label imbalance; it can result in an F-score that is not between precision and recall.

    acc_overall = accuracy_score(all_true_targets, all_predictions)
    prec_macro = precision_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    recall_macro = recall_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    f1_macro = f1_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    prec_weighted = precision_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))
    recall_weighted = recall_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))
    f1_weighted = f1_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))

    print("\nClassifier Evaluation Results")
    print("Accuracy:", acc_overall)
    print("Precision (macro):", prec_macro)
    print("Precision (weighted):", prec_weighted)
    print("Recall (macro):", recall_macro)
    print("Recall (weighted):", recall_weighted)
    print("F1 Score (macro):", f1_macro)
    print("F1 Score (weighted):", f1_weighted)

    return (all_predictions, all_true_targets, all_true_protected)


# Calculate fairness metrics

def evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected):
    print("\n=====Fairness Metrics=====")
    threshold = 0.5
    fairness_metrics = {
        "protected_toxicity_rate": 0,
        "nonprotected_toxicity_rate": 0,
        "protected_TPR": 0,
        "nonprotected_TPR": 0,
        "protected_FPR": 0,
        "nonprotected_FPR": 0,
        "demographic_parity": 0,
        "true_positive_parity": 0,
        "false_positive_parity": 0,
        "equalized_odds": 0
    }

    # might be easier to work with a dataframe here
    all_data = pd.DataFrame(
        data = {
            "predictions": np.array(all_predictions),
            "target_labels": np.array(all_true_targets),
            "protected_labels": np.array(all_true_protected)
        }
    )
    # print(all_data.head())

    # Calculate toxicity rate for instances where protected label is true
    toxic = all_data.loc[all_data["protected_labels"] == 1]["target_labels"]
    fairness_metrics["protected_toxicity_rate"] = round(sum(toxic) / len(toxic), 4)

    # Calculate toxicity rate for instances where protected label is false
    toxic = all_data.loc[all_data["protected_labels"] == 0]["target_labels"]
    fairness_metrics["nonprotected_toxicity_rate"] = round(sum(toxic) / len(toxic), 4)

    # Calculate confusion matrix when protected label is true
    # FPR = FP / (FP + TN)
    # TPR = TP / (TP + FN)
    protected = all_data.loc[all_data["protected_labels"] == 1]
    tn, fp, fn, tp = confusion_matrix(protected["target_labels"], protected["predictions"]).ravel()
    fairness_metrics["protected_FPR"] = fp / (fp + tn)
    fairness_metrics["protected_TPR"] = tp / (tp + fn)

    # Calculate confusion matrix when protected label is true
    nonprotected = all_data.loc[all_data["protected_labels"] == 0]
    tn, fp, fn, tp = confusion_matrix(nonprotected["target_labels"], nonprotected["predictions"]).ravel()
    fairness_metrics["nonprotected_FPR"] = fp / (fp + tn)
    fairness_metrics["nonprotected_TPR"] = tp / (tp + fn)

    # Calculate demographic parity
    fairness_metrics["demographic_parity"] = abs(fairness_metrics["protected_toxicity_rate"] - fairness_metrics["nonprotected_toxicity_rate"])

    # Calculate true positive parity
    fairness_metrics["true_positive_parity"] = abs(fairness_metrics["protected_TPR"] - fairness_metrics["nonprotected_TPR"])

    # Calculate false positive parity
    fairness_metrics["false_positive_parity"] = abs(fairness_metrics["protected_FPR"] - fairness_metrics["nonprotected_FPR"])

    # Calculate equalized odds
    fairness_metrics["equalized_odds"] = fairness_metrics["true_positive_parity"] + fairness_metrics["false_positive_parity"]

    return fairness_metrics


print("Pretrained validation set performance")
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, val_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)

Pretrained validation set performance

Classifier Evaluation Results
Accuracy: 0.891
Precision (macro): 0.7697280851058819
Precision (weighted): 0.8821122427800621
Recall (macro): 0.7181092607869863
Recall (weighted): 0.891
F1 Score (macro): 0.7396991552688329
F1 Score (weighted): 0.8852448524433317

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.2364,
 'nonprotected_toxicity_rate': 0.0796,
 'protected_TPR': 0.5148148148148148,
 'nonprotected_TPR': 0.4321223709369025,
 'protected_FPR': 0.08941536110049675,
 'nonprotected_FPR': 0.02743801652892562,
 'demographic_parity': 0.1568,
 'true_positive_parity': 0.08269244387791236,
 'false_positive_parity': 0.06197734457157113,
 'equalized_odds': 0.14466978844948347}

In [None]:
print("Pretrained test set performance")
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, test_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)

Pretrained test set performance

Classifier Evaluation Results
Accuracy: 0.8954161724691272
Precision (macro): 0.7814569362172605
Precision (weighted): 0.8870774467273438
Recall (macro): 0.7276620632309687
Recall (weighted): 0.8954161724691272
F1 Score (macro): 0.7502122830309426
F1 Score (weighted): 0.8899019325243206

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.2454,
 'nonprotected_toxicity_rate': 0.0937,
 'protected_TPR': 0.5561613958560524,
 'nonprotected_TPR': 0.446122860020141,
 'protected_FPR': 0.0947144377438808,
 'nonprotected_FPR': 0.02863390254060808,
 'demographic_parity': 0.1517,
 'true_positive_parity': 0.11003853583591139,
 'false_positive_parity': 0.06608053520327273,
 'equalized_odds': 0.17611907103918412}

In [None]:
# Train the classifier

def train_classifier(clf, optimizer_clf, adv, train_loader, loss_criterion, lambda_val):
    # data should be a tuple of (x, y, z)
    for (input, target_label, protected_label) in train_loader:
        input = input.to(device)
        target_label = target_label.to(device)
        protected_label = protected_label.to(device)

        optimizer_clf.zero_grad()

        clf_output, clf_prev_output = clf(input)
        adv_output = adv(clf_prev_output)
        adv_loss = loss_criterion(adv_output, protected_label)
        clf_loss = loss_criterion(clf_output, target_label)
        total_classifier_loss = clf_loss - lambda_val * adv_loss
        total_classifier_loss.backward() # back prop

        optimizer_clf.step()

        print("Adversary Mini-Batch loss: ", adv_loss.item())
        print("Classifier Mini-Batch loss: ", clf_loss.item())
        print("Total Mini-Batch loss: ", total_classifier_loss.item())

        break

    return clf


# Train the adversary

def train_adversary(adv, clf, optimizer_adv, train_loader, loss_criterion, epochs=1):
    adv_loss = 0
    num_batches = 0

    # data should be a tuple of (x, y, z)
    for i, (input, target_label, protected_label) in enumerate(train_loader):
        input = input.to(device)
        target_label = target_label.to(device)
        protected_label = protected_label.to(device)

        optimizer_adv.zero_grad()

        clf_output, clf_prev_output = clf(input)
        adv_output = adv(clf_prev_output)
        adv_loss = loss_criterion(adv_output, protected_label)
        adv_loss.backward() # back prop
        optimizer_adv.step()
        adv_loss += adv_loss.item()
        num_batches += 1

    print("Average adversary loss: ", float(adv_loss.cpu() / num_batches))

    return adv


In [None]:
# Alternate training of the classifier and the adversary
# Took 45 minutes for 10 iterations on gender dataset

# Trained gender on 30 iterations, race on 30
num_iterations = 10

# Lambda value remained at 3
lambda_val = 3

for iteration in range(num_iterations):
    print("\n===== Iteration", iteration, "=====")

    #TRAIN ADVERSARY FOR 1 EPOCH

    for param in clf.parameters():
      param.requires_grad = False

    adv = train_adversary(adv, clf, optimizer_adv, train_loader, loss_criterion, epochs=1)

    for param in clf.parameters():
      param.requires_grad = True

    #TRAIN CLASSIFIER FOR 1 SAMPLE MINI BATCH

    for param in adv.parameters():
      param.requires_grad = False

    clf = train_classifier(clf, optimizer_clf, adv, train_loader, loss_criterion, lambda_val)

    for param in adv.parameters():
      param.requires_grad = True

    # Evaluate classifier
    if iteration % 5 == 0:
      all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, val_loader)
      evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


===== Iteration 0 =====
Average adversary loss:  0.0028413445688784122
Adversary Mini-Batch loss:  0.6738370060920715
Classifier Mini-Batch loss:  0.1714545637369156
Total Mini-Batch loss:  -1.850056529045105

Classifier Evaluation Results
Accuracy: 0.8902
Precision (macro): 0.7667223354889212
Precision (weighted): 0.8820247535592359
Recall (macro): 0.7211391806068512
Recall (weighted): 0.8902
F1 Score (macro): 0.7406166293938905
F1 Score (weighted): 0.8850786309774048

=====Fairness Metrics=====

===== Iteration 1 =====
Average adversary loss:  0.0027834458742290735
Adversary Mini-Batch loss:  0.6773408651351929
Classifier Mini-Batch loss:  0.10231737047433853
Total Mini-Batch loss:  -1.9297051429748535

===== Iteration 2 =====
Average adversary loss:  0.0026700992602854967
Adversary Mini-Batch loss:  0.5235402584075928
Classifier Mini-Batch loss:  0.06465840339660645
Total Mini-Batch loss:  -1.5059623718261719

===== Iteration 3 =====
Average adversary loss:  0.0019236343214288354
A

In [None]:
# Evaluate on test set at the end
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, test_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


Classifier Evaluation Results
Accuracy: 0.8814623595897579
Precision (macro): 0.743492262611792
Precision (weighted): 0.8820157684402092
Recall (macro): 0.7457532747272939
Recall (weighted): 0.8814623595897579
F1 Score (macro): 0.7446141003163776
F1 Score (weighted): 0.8817362647979872

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.2454,
 'nonprotected_toxicity_rate': 0.0937,
 'protected_TPR': 0.5125408942202835,
 'nonprotected_TPR': 0.6052366565961732,
 'protected_FPR': 0.08158921603405463,
 'nonprotected_FPR': 0.06559766763848396,
 'demographic_parity': 0.1517,
 'true_positive_parity': 0.09269576237588972,
 'false_positive_parity': 0.015991548395570668,
 'equalized_odds': 0.10868731077146039}

In [None]:
# Evaluate on val set one more time
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, val_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)