<a href="https://colab.research.google.com/github/brandonko/FairnessNLP/blob/main/Counterfactual_Data_Augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Counterfactual Data Augmentation**
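Counterfactual data augmentation is a bias mitigation technique from [Dinan et al. (2020)](https://arxiv.org/abs/1911.03842). For every example that contains gendered words, this notebook adds a copy of the example in which each gendered word is swapped for its opposite-gender counterpart.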

In [None]:
# NLTK supplies word_tokenize, which add_counterfactuals below uses to split
# punctuation off words. Download the 'punkt' tokenizer package up front.
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Mount Google Drive at /content/drive so the word-list files stored under
# MyDrive can be opened with ordinary file paths.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Read in the gender word lists from [Zhao et al. (2018)](https://arxiv.org/abs/1809.01496)

In [None]:
# Edit the file paths below so they point to the files containing the female
# and male word lists. These word lists are in data/female_word_file.txt and
# data/male_word_file.txt in the GitHub repo.
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding.
with open('/content/drive/MyDrive/NLP Capstone/data/female_word_file.txt', 'r', encoding='utf-8') as female_word_file:
    female_words = female_word_file.read().split()
with open('/content/drive/MyDrive/NLP Capstone/data/male_word_file.txt', 'r', encoding='utf-8') as male_word_file:
    male_words = male_word_file.read().split()

# The two lists are paired by position (the n-th female word is the
# counterpart of the n-th male word), so a length mismatch means the files
# are misaligned. Fail loudly rather than build a partial or shifted mapping.
if len(female_words) != len(male_words):
    raise ValueError(
        'Gender word lists differ in length: {} female words vs {} male words'.format(
            len(female_words), len(male_words)))

# Bidirectional lookup: each word maps to its opposite-gender counterpart.
gender_word_pairs = dict()
for female_word, male_word in zip(female_words, male_words):
    gender_word_pairs[female_word] = male_word
    gender_word_pairs[male_word] = female_word

## Perform counterfactual data augmentation

In [None]:
def _match_case(replacement, original):
    """Returns replacement cased like original.

    Title-case and all-uppercase originals are mirrored; for any other
    casing the replacement is returned exactly as given.
    """
    if original.istitle():
        return replacement.title()
    if original.isupper():
        return replacement.upper()
    return replacement


def _swap_word(word, word_pairs, tokenize):
    """Swaps every token of one whitespace-delimited word found in word_pairs.

    Args:
      word: A single chunk of text containing no whitespace.
      word_pairs: Dictionary mapping lowercase words to their counterfactuals.
      tokenize: Callable that splits word into a list of string tokens.

    Returns:
      A (new_word, swapped) tuple. swapped is True if at least one token was
      replaced. If nothing was replaced, new_word is word itself, unchanged.
    """
    pieces = []    # Output spliced together from the original characters.
    rebuilt = []   # Plain concatenation of tokens, used only as a fallback.
    cursor = 0     # Where to start searching for the next token in word.
    copied = 0     # How much of word has already been emitted into pieces.
    swapped = False
    aligned = True
    for token in tokenize(word):
        # Locate every token (not just the swapped ones) so that the search
        # position always moves past text that has already been consumed.
        start = word.find(token, cursor)
        if start != -1:
            cursor = start + len(token)
        key = token.lower()
        if key not in word_pairs:
            rebuilt.append(token)
            continue
        swapped = True
        new_token = _match_case(word_pairs[key], token)
        rebuilt.append(new_token)
        if start == -1:
            # The tokenizer rewrote this token, so it cannot be spliced in
            # place; the whole word falls back to token concatenation.
            aligned = False
        else:
            pieces.append(word[copied:start])
            pieces.append(new_token)
            copied = cursor
    if not swapped:
        return word, False
    if not aligned:
        return ''.join(rebuilt), True
    pieces.append(word[copied:])
    return ''.join(pieces), True


def add_counterfactuals(data, word_pairs, tokenize=None):
    """Augments the data by replacing each word from the given word_pairs in
    the given data with its counterfactual from word_pairs.

    Every original item is kept. An item that contains at least one word from
    word_pairs is immediately followed by a counterfactual copy in which each
    such word is swapped. Title-case and all-uppercase words keep their
    casing. Apart from the swapped words, the characters of each word are
    copied from the original text, so punctuation that the tokenizer would
    rewrite (NLTK's word_tokenize turns double quotes into `` and '') is left
    intact. Runs of whitespace in the counterfactual copy are collapsed to
    single spaces.

    Args:
      data: The dataset to augment. Expected format is a list where
      each element is text.
      word_pairs: Dictionary where each key is a word and the value is
      the counterfactual (another word) for that word. All keys and values
      should be lowercase.
      tokenize: Optional callable that takes one whitespace-delimited word
      and returns a list of string tokens, used to separate a word from any
      attached punctuation. Defaults to nltk.word_tokenize.

    Returns:
      The data augmented with each word in word_pairs replaced with its
      counterfactual.
    """
    if tokenize is None:
        # Resolved at call time so nltk is only needed for the default.
        tokenize = nltk.word_tokenize
    new_data = []
    for item in data:
        new_data.append(item)
        words = item.split()
        found_counterfactual = False
        for i, word in enumerate(words):
            words[i], swapped = _swap_word(word, word_pairs, tokenize)
            found_counterfactual = found_counterfactual or swapped
        if found_counterfactual:
            new_data.append(' '.join(words))
    return new_data