Trying to implement the CMU paper this time (page 17 of https://aclanthology.org/2020.socialnlp-1.pdf)


```
Model Architecture
Our demotion model consists of three parts:
1) An encoder H that encodes the text into a high dimensional space;
2) A binary classifier C that predicts the target attribute from
the input text;
3) An adversary D that predicts the protected attribute from the input text.

We used a single-layer bidirectional LSTM encoder with an attention mechanism.
Both classifiers are two-layer MLPs with a tanh activation function.
```

LSTM can be made bidirectional: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

PyTorch has MultiheadAttention implementation: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

MLP == Linear layer: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

tanh activation: https://pytorch.org/docs/stable/generated/torch.nn.Tanh.html?highlight=tanh#torch.nn.Tanh

HW4 template code also featured a BidirectionalEncoder class, but it was the Decoder class that implemented attention.

Found the code the CMU paper is based on!!
https://github.com/Sachin19/adversarial-classify/blob/master/model.py

Their paper: https://arxiv.org/pdf/1909.00453.pdf


In [None]:
# Install HuggingFace transformers library to match baseline

# !pip install transformers[sentencepiece]==4.18.0 datasets --q

In [None]:
import pandas as pd
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# from transformers import AutoTokenizer
import torch.optim as optim
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, recall_score, precision_score, confusion_matrix
import warnings
from torch.autograd import Function, Variable
import math
from torchtext.vocab import vocab
from torchtext.data import get_tokenizer
from collections import Counter, OrderedDict
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau

warnings.filterwarnings('ignore')

In [None]:
# Specify GPU device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

Device: cuda:0


In [None]:
from google.colab import drive

# Connect to Google Drive
drive.mount('/content/drive')

# Read subset (100 items) csv file from Google Drive
# train_data = pd.read_csv('/content/drive/Shareddrives/AI539: NLP with Deep Learning/subset_baseline_data.csv')

# Training data will be split into 75% train and 25% validation
train_data = pd.read_csv("/content/drive/My Drive/train_religion.csv")
test_data = pd.read_csv("/content/drive/My Drive/test_religion.csv")

train_data.head()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Unnamed: 0.1,Unnamed: 0,target,comment_text,muslim,christian,religion
0,185,0,"Should they have to? \n\nChristians, whether a...",0,1,0
1,847,0,The Militia holding a anti-Government stand-of...,0,1,0
2,933,0,In regard to the WW review of the Glenn Beck a...,0,1,0
3,1073,0,When you can figure out that turbans are worn ...,1,0,1
4,1117,0,What is uniquely and exclusively possible in a...,0,1,0


In [None]:
# PyTorch Dataset class for our data
# No need for the BERT tokenizer for this architecture though
# Need to take vocab approach like in the HWs instead
# data is already provided where targets are associated with comments
# need to build vocab from the comments though
# and need to return protected attribute along with target label and text

class ToxicityDataset(Dataset):
    def __init__(self, all_data, split="train", tokenizer=None, text_vocab=None):
        self.raw_data = None
        self.tokens = None
        self.target_labels = None
        self.protected_labels = None
        # self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') if tokenizer is None else tokenizer
        self.tokenizer = get_tokenizer("basic_english") if tokenizer is None else tokenizer
        self.text_vocab = text_vocab

        # do data cleaning, tokenizing, make sure all are tensors

        # train set: 75% train, 25% val, test set: 100% test
        split_index = round(0.75*len(all_data))
        if split == "test":
            # No modifications, use all of the test data
            self.raw_data = all_data
        elif split == "train":
            # 75% of train data for training
            self.raw_data = pd.DataFrame(all_data.loc[:split_index, :]).reset_index()
        else:
            # 25% of train data for validation
            self.raw_data = pd.DataFrame(all_data.loc[split_index:, :]).reset_index()

        # drop empty comments
        self.raw_data['comment_text'].dropna(inplace=True)

        # preprocess comments
        self.raw_data.loc[:, 'comment_text'] = self.raw_data['comment_text'].apply(self._preprocess_comments)

        # create vocabulary from training data
        if self.text_vocab == None:
            word_counter = Counter()
            for text in self.raw_data["comment_text"]:
                word_counter.update(text.split())
            sorted_word_counter = sorted(word_counter.items(), key=lambda x: x[1], reverse=True)
            ordered_words = OrderedDict(sorted_word_counter)
            self.text_vocab = vocab(ordered_words, min_freq=5, specials=['<unk>'], special_first=True)
            self.text_vocab.set_default_index(self.text_vocab['<unk>'])

        # tokenize using the tokenizer, not sure if this is really correct or not
        comments = list(self.raw_data['comment_text'].copy())
        self.tokens = []
        for comment in self.raw_data['comment_text']:
            self.tokens.append(self.text_vocab.forward(self.tokenizer(comment)))

        # convert targets from a range to binary classes
        self.target_labels = self.raw_data['target'].apply(lambda x: 1 if x >= 0.5 else 0)

        # currently the only protected attribute is "female", but for intersectional
        # bias we'd probably want multiple protected attributes
        self.protected_labels = self.raw_data['muslim']


    def _preprocess_comments(self, comment):
        new_tokens = []
        for token in comment.split(" "):
            # replace usernames with something generic
            token = '@user' if token.startswith('@') and len(token) > 1 else token

            # replace URLs with something generic
            token = 'http' if token.startswith('http') else token

            new_tokens.append(token)
        return " ".join(new_tokens)


    def __getitem__(self, index):
        return (
            torch.tensor(self.tokens[index]),
            torch.tensor(self.target_labels[index]),
            torch.tensor(self.protected_labels[index])
        )

    def __len__(self):
        return len(self.target_labels)

    # From HW2
    def pad_collate(self, batch):
        (text_lists, label_lists, protected_lists) = zip(*batch)

        numeralized_text = [torch.tensor(t) for t in text_lists]
        x_lens = [len(x) for x in numeralized_text]
        xx = pad_sequence(numeralized_text, batch_first=True, padding_value=0).type(torch.int)

        # numeralized_tags = [F.one_hot(torch.tensor(self.tag_vocab.forward(t)).long(), num_classes=len(self.tag_vocab)) for t in label_lists]
        # yy = pad_sequence(numeralized_tags, batch_first=True, padding_value=0).type(torch.float)

        # target labels and protected labels should be the same size
        yy = torch.tensor(label_lists, dtype=int)
        zz = torch.tensor(protected_lists, dtype=int)

        return xx, x_lens, yy, zz


# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer = get_tokenizer("basic_english")
train_dataset = ToxicityDataset(train_data, split="train", tokenizer=tokenizer)
val_dataset = ToxicityDataset(train_data, split="val", tokenizer=tokenizer, text_vocab=train_dataset.text_vocab)
test_dataset = ToxicityDataset(test_data, split="test", tokenizer=tokenizer, text_vocab=train_dataset.text_vocab)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=train_dataset.pad_collate)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, collate_fn=train_dataset.pad_collate)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True, collate_fn=train_dataset.pad_collate)


In [None]:
print("Train dataset has length:", len(train_dataset))
print("Val dataset has length:", len(val_dataset))

print("One item from train dataset:", train_dataset[0])


Train dataset has length: 30001
Val dataset has length: 10000
One item from train dataset: (tensor([   69,    22,    18,     2,   651,  2291,   428,   360,    15,     5,
          366,   428,    23,  4403,    24,  3727,    42,   127,  2584,     2,
          139,     5,  3810,   165,    53,     5,   163,   428,    67,   163,
          428,    91,     2,  1096,    37,  4986,   428,    23,     0,   428,
          159,   130,     5,  1403, 16902,    27,  1496,    22,    42,   127,
           18,   130,    10,   258,   524,   165,   142,     8,   887,   120,
           16,   428,   693,     8,    42,    14,     5,   121,   234,   160,
           29,   165]), tensor(0), tensor(0))


In [None]:
# Anita's attempt before finding Kumar et al. code
# Based on HW4 skeleton code
'''
class EncoderAndClassifier(nn.Module):
    def __init__(self, vocab_size, enc_hid_dim, embed_dim, num_heads, output_dim, dropout=0.5):
        super().__init__()
        self.enc_hidden_dim = enc_hid_dim
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.bilstm = nn.GRU(embed_dim, enc_hid_dim, bidirectional = True)
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(2 * enc_hid_dim, 128)
        self.linear2 = nn.Linear(128, output_dim)

    def forward(self, input):
        # embed input tokens
        embedded = self.dropout(self.emb(input))

        # process with bidirectional GRU model
        enc_hidden_states, _ = self.bilstm(embedded)

        # compute a global sentence representation to feed as the initial hidden state of the decoder
        # concatenate the forward GRU's representation after the last word and
        # the backward GRU's representation after the first word

        last_forward = enc_hidden_states[-1, :, :self.enc_hidden_dim]
        first_backward = enc_hidden_states[0, :, self.enc_hidden_dim:]

        # this was from the skeleton code
        # transform to the size of the decoder hidden state with a fully-connected layer
        # sent = F.relu(self.fc(torch.cat((last_forward, first_backward), dim = 1)))

        # should be this instead?
        l1_output = self.linear1(torch.cat((last_forward, first_backward), dim = 1))
        tanh_output = F.tanh(l1_output)
        output = self.linear2(tanh_output)

        return enc_hidden_states
'''

"\nclass EncoderAndClassifier(nn.Module):\n    def __init__(self, vocab_size, enc_hid_dim, embed_dim, num_heads, output_dim, dropout=0.5):\n        super().__init__()\n        self.enc_hidden_dim = enc_hid_dim\n        self.emb = nn.Embedding(vocab_size, embed_dim)\n        self.attn = nn.MultiheadAttention(embed_dim, num_heads)\n        self.bilstm = nn.GRU(embed_dim, enc_hid_dim, bidirectional = True)\n        self.dropout = nn.Dropout(dropout)\n        self.linear1 = nn.Linear(2 * enc_hid_dim, 128)\n        self.linear2 = nn.Linear(128, output_dim)\n\n    def forward(self, input):\n        # embed input tokens\n        embedded = self.dropout(self.emb(input))\n\n        # process with bidirectional GRU model\n        enc_hidden_states, _ = self.bilstm(embedded)\n\n        # compute a global sentence representation to feed as the initial hidden state of the decoder\n        # concatenate the forward GRU's representation after the last word and\n        # the backward GRU's representa

In [None]:
# Models based on Kumar et al.'s code: https://github.com/Sachin19/adversarial-classify/blob/master/model.py

class Encoder(nn.Module):
  def __init__(self, embedding_dim, hidden_dim, nlayers=1, dropout=0., bidirectional=True):
    super(Encoder, self).__init__()
    self.rnn = nn.LSTM(embedding_dim, hidden_dim, nlayers,
                       dropout=dropout, bidirectional=bidirectional,
                       batch_first=True)

  def forward(self, input, hidden=None):
    self.rnn.flatten_parameters()
    return self.rnn(input, hidden)

# this one doesn't appear to be used in the Kumar et al. code
class Attention(nn.Module):
  def __init__(self, query_dim, key_dim, value_dim):
    super(Attention, self).__init__()
    self.scale = 1. / math.sqrt(query_dim)

  def forward(self, query, keys, values):
    # Query = [BxQ]
    # Keys = [TxBxK]
    # Values = [TxBxV]
    # Outputs = a:[TxB], lin_comb:[BxV]

    # Here we assume q_dim == k_dim (dot product attention)

    query = query.unsqueeze(1) # [BxQ] -> [Bx1xQ]
    keys = keys.transpose(0,1).transpose(1,2) # [TxBxK] -> [BxKxT]
    energy = torch.bmm(query, keys) # [Bx1xQ]x[BxKxT] -> [Bx1xT]
    energy = F.softmax(energy.mul_(self.scale), dim=2) # scale, normalize

    values = values.transpose(0,1) # [TxBxV] -> [BxTxV]
    linear_combination = torch.bmm(energy, values).squeeze(1) #[Bx1xT]x[BxTxV] -> [BxV]
    return energy, linear_combination

# This type of attention seems to be used
class BahdanauAttention(nn.Module):
  def __init__(self, hidden_dim, attn_dim):
    super(BahdanauAttention, self).__init__()
    self.linear = nn.Linear(hidden_dim, attn_dim)
    self.linear2 = nn.Linear(attn_dim, 1)

  def forward(self, hidden, mask=None):
    # hidden = [TxBxH]
    # mask = [TxB]
    # Outputs = a:[TxB], lin_comb:[BxV]

    # Here we assume q_dim == k_dim (dot product attention)
    # hidden = hidden.transpose(0,1) # [TxBxH] -> [BxTxH]
    energy = self.linear(hidden) # [BxTxH] -> [BxTxA]
    energy = F.tanh(energy)
    energy = self.linear2(energy) # [BxTxA] -> [BxTx1]
    energy = F.softmax(energy, dim=1) # scale, normalize

    if mask is not None:
      mask = mask.transpose(0, 1).unsqueeze(2)
      energy = energy * mask
      Z = energy.sum(dim=1, keepdim=True) #[BxTx1] -> [Bx1x1]
      energy = energy/Z #renormalize

    energy = energy.transpose(1, 2) # [BxTx1] -> [Bx1xT]
    linear_combination = torch.bmm(energy, hidden).squeeze(1) #[Bx1xT]x[BxTxH] -> [BxH]

    return energy, linear_combination

class Classifier(nn.Module):
  def __init__(self, embedding, encoder, attention, hidden_dim, num_classes=10, num_topics=50):
    super(Classifier, self).__init__()
    # num_classes=2
    self.embedding = embedding
    self.encoder = encoder
    self.attention = attention
    self.decoder = nn.Linear(hidden_dim, num_classes)

    size = 0
    for p in self.parameters():
      size += p.nelement()
    print('Total param size: {}'.format(size))

  def forward(self, input, input_lens, alpha=1.0, gradreverse=True, padding_mask=None):
    embedded_input = self.embedding(input)

    outputs, hidden = self.encoder(embedded_input)

    if isinstance(hidden, tuple): # LSTM
      hidden = hidden[1] # take the cell state

    # need to concat the last 2 hidden layers
    hidden = torch.cat([hidden[-1], hidden[-2]], dim=-1)

    energy, linear_combination = self.attention(outputs, padding_mask)
    energy = energy.permute(2,0,1)

    logits = self.decoder(linear_combination)
    return logits, energy



In [None]:
# Pretrain the classifier on dataset to get initial toxicity predictions

def pretrain_classifier(clf, optimizer, train_loader, loss_criterion, epochs, scheduler):
  best_model_wts = copy.deepcopy(clf.state_dict())
  best_acc = 0.0
  best_acc_epoch = 0

  for epoch in range(epochs):
    running_loss = 0.0
    running_corrects = 0
    epoch_loss = 0
    num_items = 0

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
      if phase == 'train':
          clf.train()  # Set model to training mode
      else:
          clf.eval()   # Set model to evaluate mode

      if phase == 'train':
        # data should be a tuple of (x, xlens, y, z)
        for (input, input_lens, target_label, _) in train_loader:
          input = input.to(device)
          target_label = target_label.to(device)
          num_items += input.size(0)

          optimizer.zero_grad()

          classifier_output, energy = clf(input, input_lens)
          predictions = torch.argmax(classifier_output, dim=-1)
          classifier_loss = loss_criterion(classifier_output, target_label) # compute loss
          classifier_loss.backward() # back prop
          optimizer.step()
          running_loss += classifier_loss.item() * input.size(0)
          running_corrects += torch.sum(predictions == target_label)

      epoch_loss = running_loss / num_items
      epoch_acc = round(float(running_corrects.double() / num_items), 4)

      if phase == 'val':
        scheduler.step(epoch_acc)
        print(f'\nEpoch: {epoch+1} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

      # deep copy the model
      if phase == 'val' and epoch_acc > best_acc:
        best_acc = epoch_acc
        best_acc_epoch = epoch
        best_model_wts = copy.deepcopy(clf.state_dict())

  # load best model weights
  clf.load_state_dict(best_model_wts)
  return clf



In [None]:
vocab_size = len(train_dataset.text_vocab)
embed_size = 128
hidden_size = 128
embedding = nn.Embedding(vocab_size, embed_size, padding_idx=1)

# CMU paper said: "single-layer bidirectional LSTM encoder"
encoder = Encoder(embed_size, hidden_size, nlayers=1, dropout=0.5, bidirectional=True)

# Kumar et al. code uses this version of attention
# https://github.com/Sachin19/adversarial-classify/blob/master/train.py
attention_dim = 2 * hidden_size
attention = BahdanauAttention(attention_dim, attention_dim)

clf_loss = nn.CrossEntropyLoss()

lrlast = .001
lrmain = .00001

# CMU paper mentions pre-training encoder and classifier in phase 1
clf = Classifier(embedding, encoder, attention, attention_dim, num_classes=2)
clf.to(device)

optimizer_clf = torch.optim.Adam(clf.parameters(), lr=lrlast, amsgrad=True)
scheduler = ReduceLROnPlateau(optimizer_clf, mode='max', patience=2, min_lr=0.000000001)

# converged around 20 epochs
clf = pretrain_classifier(clf, optimizer_clf, train_loader, clf_loss, 20, scheduler)

Total param size: 3749635

Epoch: 1 Loss: 0.3695 Acc: 0.8747

Epoch: 2 Loss: 0.3096 Acc: 0.8795

Epoch: 3 Loss: 0.2423 Acc: 0.9025

Epoch: 4 Loss: 0.1953 Acc: 0.9211

Epoch: 5 Loss: 0.1460 Acc: 0.9434

Epoch: 6 Loss: 0.0945 Acc: 0.9664

Epoch: 7 Loss: 0.0573 Acc: 0.9808

Epoch: 8 Loss: 0.0372 Acc: 0.9887

Epoch: 9 Loss: 0.0221 Acc: 0.9934

Epoch: 10 Loss: 0.0166 Acc: 0.9949

Epoch: 11 Loss: 0.0154 Acc: 0.9950

Epoch: 12 Loss: 0.0158 Acc: 0.9948

Epoch: 13 Loss: 0.0121 Acc: 0.9958

Epoch: 14 Loss: 0.0100 Acc: 0.9965

Epoch: 15 Loss: 0.0077 Acc: 0.9971

Epoch: 16 Loss: 0.0082 Acc: 0.9969

Epoch: 17 Loss: 0.0063 Acc: 0.9972

Epoch: 18 Loss: 0.0054 Acc: 0.9976

Epoch: 19 Loss: 0.0053 Acc: 0.9972

Epoch: 20 Loss: 0.0046 Acc: 0.9977


In [None]:
# Get typical validation metrics

def evaluate_clf(clf, data_loader):
    all_predictions = np.empty((0,))
    all_true_targets = np.empty((0,))
    all_true_protected = np.empty((0,))

    num_batches = 0
    batch_metrics = {
        "accuracy": 0,
        "precision": 0,
        "recall": 0,
        "f1": 0
    }

    # Make sure there's no backprop during evaluation
    clf.eval()
    with torch.no_grad():
        for i, (input, input_lens, target_label, protected_label) in enumerate(data_loader):
            input = input.to(device)
            target_label = target_label.to(device)

            output, _ = clf(input, input_lens)
            _, prediction = torch.max(output, 1)

            target_label = target_label.to("cpu")
            prediction = prediction.to("cpu")

            all_predictions = np.concatenate((all_predictions, prediction.numpy()))
            all_true_targets = np.concatenate((all_true_targets, target_label.numpy()))
            all_true_protected = np.concatenate((all_true_protected, protected_label))

    # Macro: Calculate metrics for each label, and find their unweighted mean.
    # This does not take label imbalance into account.

    # Calculate metrics for each label, and find their average weighted by support
    # (the number of true instances for each label). This alters ‘macro’ to account
    # for label imbalance; it can result in an F-score that is not between precision and recall.

    acc_overall = accuracy_score(all_true_targets, all_predictions)
    prec_macro = precision_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    recall_macro = recall_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    f1_macro = f1_score(all_true_targets, all_predictions, average="macro", labels=np.unique(all_true_targets))
    prec_weighted = precision_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))
    recall_weighted = recall_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))
    f1_weighted = f1_score(all_true_targets, all_predictions, average="weighted", labels=np.unique(all_true_targets))

    print("\nClassifier Evaluation Results")
    print("Accuracy:", acc_overall)
    print("Precision (macro):", prec_macro)
    print("Precision (weighted):", prec_weighted)
    print("Recall (macro):", recall_macro)
    print("Recall (weighted):", recall_weighted)
    print("F1 Score (macro):", f1_macro)
    print("F1 Score (weighted):", f1_weighted)

    return (all_predictions, all_true_targets, all_true_protected)


# Calculate fairness metrics

def evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected):
    print("\n=====Fairness Metrics=====")
    threshold = 0.5
    fairness_metrics = {
        "protected_toxicity_rate": 0,
        "nonprotected_toxicity_rate": 0,
        "protected_TPR": 0,
        "nonprotected_TPR": 0,
        "protected_FPR": 0,
        "nonprotected_FPR": 0,
        "demographic_parity": 0,
        "true_positive_parity": 0,
        "false_positive_parity": 0,
        "equalized_odds": 0
    }

    # might be easier to work with a dataframe here
    all_data = pd.DataFrame(
        data = {
            "predictions": np.array(all_predictions),
            "target_labels": np.array(all_true_targets),
            "protected_labels": np.array(all_true_protected)
        }
    )
    # print(all_data.head())

    # Calculate toxicity rate for instances where protected label is true
    toxic = all_data.loc[all_data["protected_labels"] == 1]["target_labels"]
    fairness_metrics["protected_toxicity_rate"] = round(sum(toxic) / len(toxic), 4)

    # Calculate toxicity rate for instances where protected label is false
    toxic = all_data.loc[all_data["protected_labels"] == 0]["target_labels"]
    fairness_metrics["nonprotected_toxicity_rate"] = round(sum(toxic) / len(toxic), 4)

    # Calculate confusion matrix when protected label is true
    # FPR = FP / (FP + TN)
    # TPR = TP / (TP + FN)
    protected = all_data.loc[all_data["protected_labels"] == 1]
    tn, fp, fn, tp = confusion_matrix(protected["target_labels"], protected["predictions"]).ravel()
    fairness_metrics["protected_FPR"] = fp / (fp + tn)
    fairness_metrics["protected_TPR"] = tp / (tp + fn)

    # Calculate confusion matrix when protected label is true
    nonprotected = all_data.loc[all_data["protected_labels"] == 0]
    tn, fp, fn, tp = confusion_matrix(nonprotected["target_labels"], nonprotected["predictions"]).ravel()
    fairness_metrics["nonprotected_FPR"] = fp / (fp + tn)
    fairness_metrics["nonprotected_TPR"] = tp / (tp + fn)

    # Calculate demographic parity
    fairness_metrics["demographic_parity"] = abs(fairness_metrics["protected_toxicity_rate"] - fairness_metrics["nonprotected_toxicity_rate"])

    # Calculate true positive parity
    fairness_metrics["true_positive_parity"] = abs(fairness_metrics["protected_TPR"] - fairness_metrics["nonprotected_TPR"])

    # Calculate false positive parity
    fairness_metrics["false_positive_parity"] = abs(fairness_metrics["protected_FPR"] - fairness_metrics["nonprotected_FPR"])

    # Calculate equalized odds
    fairness_metrics["equalized_odds"] = fairness_metrics["true_positive_parity"] + fairness_metrics["false_positive_parity"]

    return fairness_metrics


all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, test_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


Classifier Evaluation Results
Accuracy: 0.8706481546082467
Precision (macro): 0.7145142401303506
Precision (weighted): 0.857434209038154
Recall (macro): 0.6661876952603957
Recall (weighted): 0.8706481546082467
F1 Score (macro): 0.6853615776134209
F1 Score (weighted): 0.8624611124753722

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.2454,
 'nonprotected_toxicity_rate': 0.0937,
 'protected_TPR': 0.4525627044711014,
 'nonprotected_TPR': 0.32729103726082576,
 'protected_FPR': 0.12061014544164597,
 'nonprotected_FPR': 0.03581840899625156,
 'demographic_parity': 0.1517,
 'true_positive_parity': 0.12527166721027566,
 'false_positive_parity': 0.08479173644539441,
 'equalized_odds': 0.21006340365567006}

In [None]:
# Two-layer MLP with tanh activation for the Adversary
class Adversary(nn.Module):
    def __init__(self, hidden_dim, output_dim, encoder, embedding):
        super().__init__()

        self.embedding = embedding
        self.encoder = encoder
        self.linear1 = nn.Linear(hidden_dim * 2, 128)
        self.linear2 = nn.Linear(128, output_dim)

    def forward(self, input):
        embedded_input = self.embedding(input)
        outputs, hidden = self.encoder(embedded_input)

        if isinstance(hidden, tuple): # LSTM
          hidden = hidden[1] # take the cell state

        # need to concat the last 2 hidden layers
        hidden = torch.cat([hidden[-1], hidden[-2]], dim=-1)

        l1_output = self.linear1(hidden.to(torch.float))
        tanh_output = F.tanh(l1_output)
        output = self.linear2(tanh_output)
        return output

adv = Adversary(hidden_size, 2, encoder, embedding)
adv.to(device)

optimizer_adv = optim.Adam(adv.parameters(), lr=lrlast)
adv_loss = nn.CrossEntropyLoss()

# CMU paper doesn't mention pretraining the adversary

In [None]:
# Below is code from the other adversarial debiasing architecture, can probably adapt it

In [None]:
# Train the adversary, worry only about adversary's loss

def train_adversary(adv, clf, optimizer_adv, train_loader, loss_criterion, epochs=1):
    adv_loss = 0
    num_batches = 0

    # data should be a tuple of (x, y, z)
    for i, (input, input_lens, target_label, protected_label) in enumerate(train_loader):
        input = input.to(device)
        protected_label = protected_label.to(device)

        optimizer_adv.zero_grad()

        # clf_output, clf_prev_output = clf(input, input_lens)
        adv_output = adv(input)
        adv_loss = loss_criterion(adv_output, protected_label)
        # adv_loss.requires_grad = True
        # adv_loss.retain_grad()
        adv_loss.backward() # back prop
        optimizer_adv.step()
        adv_loss += adv_loss.item()
        num_batches += 1

    print("\tAdversary epoch loss: ", adv_loss.item())

    return adv


# Train the classifier, but need to combine two objectives using alpha parameter
# minimize clf loss and try to make adversary guess closer to random

def train_classifier(clf, optimizer_clf, adv, train_loader, clf_criterion, adv_criterion, alpha_val):
    # data should be a tuple of (x, y, z)
    for (input, input_lens, target_label, protected_label) in train_loader:
        input = input.to(device)
        target_label = target_label.to(device)

        # batch of protected labels should be randomly sampled from uniform(0,1)
        # to confuse the adversary
        protected_label = torch.rand(size=protected_label.shape)
        protected_label = torch.bernoulli(protected_label).type(torch.LongTensor)
        protected_label = protected_label.to(device)

        optimizer_clf.zero_grad()

        clf_output, clf_prev_output = clf(input, input_lens)
        adv_output = adv(input)
        adv_loss = adv_criterion(adv_output, protected_label)
        clf_loss = clf_criterion(clf_output, target_label)
        total_classifier_loss = alpha_val * clf_loss + (1-alpha_val) * adv_loss
        # clf_loss.retain_grad()
        # adv_loss.retain_grad()
        total_classifier_loss.backward() # back prop

        optimizer_clf.step()

        print("\tClassifier epoch loss: ", total_classifier_loss.item())

        break

    return clf





In [None]:
# In phase 2, alternate training between the classifier and the adversary
# Train adversary using equation 2, input: encoder outputs, output: protected label prediction
# Train classifier using equation 3
# Encoder produces samples that will make adversary predict closer to random, input: ?, output: ?
# Clf input: encoder outputs, output: target label prediction

# More text from the paper:

# In the pre-training phase, we train the model until convergence and pick the
# best-performing checkpoint for fine-tuning. In the fine-tuning phase, we
# alternate training one single adversary and the classification model each for
# two epochs in one round and train for 10 rounds in total

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=train_dataset.pad_collate)

num_iterations = 10

for iteration in range(num_iterations):
    print("\n===== Iteration", iteration, "=====")

    for e in range(2):
        print("\tEpoch", e+1)

        # Train adversary for 2 epochs

        for param in clf.parameters():
          param.requires_grad = False

        adv_loss.requires_grad = True
        adv = train_adversary(adv, clf, optimizer_adv, train_loader, adv_loss, epochs=1)

        for param in clf.parameters():
          param.requires_grad = True

        # Train classifier for 2 epochs

        for param in adv.parameters():
          param.requires_grad = False

        clf_loss.requires_grad = True
        # The paper set alpha to 0.05
        clf = train_classifier(clf, optimizer_clf, adv, train_loader, clf_loss, adv_loss, 0.5)

        for param in adv.parameters():
          param.requires_grad = True

    # Evaluate classifier
    if iteration % 2 == 0:
      all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, val_loader)
      evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


===== Iteration 0 =====
	Epoch 1
	Adversary epoch loss:  1.0371233224868774
	Classifier epoch loss:  0.4794791340827942
	Epoch 2
	Adversary epoch loss:  1.1495082378387451
	Classifier epoch loss:  0.443795770406723

Classifier Evaluation Results
Accuracy: 0.8695
Precision (macro): 0.7113437370374107
Precision (weighted): 0.8559719206493346
Recall (macro): 0.6631737979493142
Recall (weighted): 0.8695
F1 Score (macro): 0.682200843660568
F1 Score (weighted): 0.8611316776573628

=====Fairness Metrics=====

===== Iteration 1 =====
	Epoch 1
	Adversary epoch loss:  1.3110564947128296
	Classifier epoch loss:  0.45724138617515564
	Epoch 2
	Adversary epoch loss:  0.9699040651321411
	Classifier epoch loss:  0.36290985345840454

===== Iteration 2 =====
	Epoch 1
	Adversary epoch loss:  1.2289257049560547
	Classifier epoch loss:  0.4953964948654175
	Epoch 2
	Adversary epoch loss:  1.1549842357635498
	Classifier epoch loss:  0.5216394662857056

Classifier Evaluation Results
Accuracy: 0.8692
Precisio

In [None]:
# Evaluate on test set at the end
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, test_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


Classifier Evaluation Results
Accuracy: 0.8685550826763413
Precision (macro): 0.7096438608795034
Precision (weighted): 0.8570570502159683
Recall (macro): 0.67007549752549
Recall (weighted): 0.8685550826763413
F1 Score (macro): 0.6864555979032769
F1 Score (weighted): 0.8617198919085158

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.2454,
 'nonprotected_toxicity_rate': 0.0937,
 'protected_TPR': 0.46673936750272627,
 'nonprotected_TPR': 0.337361530715005,
 'protected_FPR': 0.12876906704505145,
 'nonprotected_FPR': 0.03894210745522699,
 'demographic_parity': 0.1517,
 'true_positive_parity': 0.12937783678772125,
 'false_positive_parity': 0.08982695958982445,
 'equalized_odds': 0.2192047963775457}

In [None]:
# Evaluate on val set one more time
all_predictions, all_true_targets, all_true_protected = evaluate_clf(clf, val_loader)
evaluate_clf_fairness(all_predictions, all_true_targets, all_true_protected)


Classifier Evaluation Results
Accuracy: 0.69
Precision (macro): 0.6116259362528019
Precision (weighted): 0.6781061724345306
Recall (macro): 0.6020445296863274
Recall (weighted): 0.69
F1 Score (macro): 0.6055056948688864
F1 Score (weighted): 0.6830622653522421

=====Fairness Metrics=====


{'protected_toxicity_rate': 0.3126,
 'nonprotected_toxicity_rate': 0.2781,
 'protected_TPR': 0.4074074074074074,
 'nonprotected_TPR': 0.38966202783300197,
 'protected_FPR': 0.18947368421052632,
 'nonprotected_FPR': 0.19142419601837674,
 'demographic_parity': 0.034499999999999975,
 'true_positive_parity': 0.017745379574405418,
 'false_positive_parity': 0.0019505118078504136,
 'equalized_odds': 0.01969589138225583}