In [1]:
import os
import pickle

import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm

In [None]:
# Load the "wizard" configuration of the md_gender_bias dataset.
wizard = load_dataset("md_gender_bias", "wizard")
# Names of the splits used to index `wizard` below.
hf_splits = ["train", "validation", "test"]

In [None]:
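#### Process Wizard of Wikipedia
# Each entry has the fields: text, chosen_topic, gender.
# Concatenate the texts of consecutive entries with the same topic; a topic change starts a new text.
# Check that all gender labels within a topic are the same and assign that label to the merged text.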
def _make_topic_record(topic, texts, genders):
    """Build one merged record for ``topic`` from its collected texts and gender labels.

    Raises:
        ValueError: if the collected entries carry more than one distinct gender label.
    """
    uniq_genders = set(genders)
    if len(uniq_genders) > 1:
        raise ValueError(
            f"too many genders for topic {topic}. Genders found: {genders}"
        )
    return {
        "text": "\n".join(texts),
        "chosen_topic": topic,
        "gender": uniq_genders.pop(),
    }


def merge_topics(data):
    """Merge consecutive entries that share a ``chosen_topic`` into one record each.

    Args:
        data: iterable of mappings with the keys ``"text"``, ``"chosen_topic"``
            and ``"gender"``.

    Returns:
        A list of dicts with the keys ``"text"`` (the run's texts joined with
        newlines), ``"chosen_topic"`` and ``"gender"``, one per run of
        consecutive entries with the same topic, in input order. Empty input
        yields an empty list.

    Raises:
        ValueError: if the entries of one run do not all have the same gender.
    """
    all_data = []
    curr_topic = None
    texts, genders = [], []
    for entry in data:
        topic = entry["chosen_topic"]
        # `texts` is non-empty exactly when a run is in progress, so the very
        # first entry never triggers a flush.
        if texts and topic != curr_topic:
            all_data.append(_make_topic_record(curr_topic, texts, genders))
            texts, genders = [], []
        curr_topic = topic
        # The entry that opens a new run belongs to that run, so it is always
        # collected (it must not be skipped on a topic change).
        texts.append(entry["text"])
        genders.append(entry["gender"])
    # Flush the final run; no topic change follows it.
    if texts:
        all_data.append(_make_topic_record(curr_topic, texts, genders))
    return all_data

In [None]:
def print_dataset_stats(ds):
    """Print the mean and median number of whitespace-separated words per example.

    Args:
        ds: iterable of mappings, each with a ``"text"`` string.
    """
    all_lengths = [len(sent["text"].split()) for sent in ds]
    if not all_lengths:
        # np.mean / np.median of an empty list only produce nan plus a
        # RuntimeWarning, so report nan directly.
        print("Mean: nan Median: nan")
        return
    m, med = np.mean(all_lengths), np.median(all_lengths)
    print(f"Mean: {m} Median: {med}")

In [None]:
# Merge each split's entries by topic.
merged_train, merged_val, merged_test = [
    merge_topics(wizard[split_name])
    for split_name in ("train", "validation", "test")
]

In [None]:
## example saving
# open(..., "wb") does not create missing directories, so make sure the target exists.
os.makedirs("data/md_gender", exist_ok=True)
with open("data/md_gender/wizard.test.pickle", "wb") as fout:
    pickle.dump(merged_test, fout)

In [None]:
# Compare word-count statistics of the raw training entries and the topic-merged records.
print_dataset_stats(wizard["train"])
print_dataset_stats(merged_train)

In [None]:
### Filter out gender-neutral examples ###

In [None]:
def change_ints(x):
    """Return a copy of ``x`` with gender label 2 remapped to 0.

    The input mapping is left untouched, so applying this to records that are
    still referenced elsewhere (e.g. the merged training list) does not alter
    them and re-running a cell that uses it gives the same result.

    Args:
        x: mapping with an integer ``"gender"`` entry.

    Returns:
        A new dict equal to ``x`` except that a ``"gender"`` of 2 becomes 0.
        NOTE(review): presumably this turns the remaining labels into a 0/1
        encoding once label 0 has been filtered out — confirm the label meaning.
    """
    remapped = dict(x)
    if remapped["gender"] == 2:
        remapped["gender"] = 0
    return remapped

In [None]:
def filter_gender_neutral(ds):
    """Keep the entries of ``ds`` whose ``"gender"`` is greater than 0.

    Every kept entry is passed through ``change_ints`` before being collected.

    Args:
        ds: iterable of mappings with an integer ``"gender"`` entry.

    Returns:
        A list with the results of ``change_ints`` for the kept entries, in
        input order.
    """
    return [change_ints(example) for example in ds if example["gender"] > 0]

In [None]:
# Keep only the merged training records with gender > 0, with label 2 remapped to 0.
filtered_ds = filter_gender_neutral(merged_train)

In [None]:
# open(..., "wb") does not create missing directories, so make sure the target exists.
os.makedirs("data/md_gender/wizard_binary", exist_ok=True)
with open("data/md_gender/wizard_binary/train.pickle", "wb") as fout:
    pickle.dump(filtered_ds, fout)

In [None]:
### Wikipedia processing ###

In [None]:
def convert_label(label, label2int):
    """Return the integer id of ``label``, registering unseen labels on the fly.

    Args:
        label: label key to look up.
        label2int: dict mapping labels to integer ids. It is updated in place
            when ``label`` is not present yet.

    Returns:
        The existing id of ``label``, or a newly assigned id that is one more
        than the largest id currently in ``label2int`` (0 for an empty mapping).
    """
    if label in label2int:
        return label2int[label]
    print(f"adding new label {label}")
    # default=-1 makes the first id 0 instead of raising on an empty mapping.
    new_id = max(label2int.values(), default=-1) + 1
    label2int[label] = new_id
    return new_id

In [None]:
# Integer ids for the known label strings.
label2int = dict(
    [
        ("ABOUT:male", 0),
        ("ABOUT:female", 1),
        ("ABOUT:gender-neutral", 2),
    ]
)

In [None]:
# Column-wise accumulators for parsed samples: parallel "text" and "gender" lists.
short_samples = {"text": [], "gender": []}
long_samples = {"text": [], "gender": []}
# Word counts collected alongside each bucket.
short_text_lens = []
long_text_lens = []

In [None]:
# Read the whole log into memory as a list of lines (each line keeps its trailing newline).
with open("wiki_out_log", "r") as fin:
    wiki = fin.readlines()

In [None]:
# Parsing: split each line on tabs, then split fields on ":"; field 1 holds the text and
# field 2 the label, which is then converted to an int via label2int.

In [None]:
# Texts with fewer words than this go to the "short" bucket.
MIN_WORDS = 10

# Start from empty accumulators so that re-running this cell does not append
# every sample a second time.
for bucket in (short_samples, long_samples):
    bucket["text"].clear()
    bucket["gender"].clear()
short_text_lens.clear()
long_text_lens.clear()

# Wrap the list itself (not enumerate(...)) so tqdm knows the total and can
# show real progress.
for line in tqdm(wiki):
    # Lines from readlines() keep their newline, so test the stripped line.
    if not line.strip():
        continue
    data = line.split("\t")
    nitems = len(data)
    if nitems < 4:
        # A single field is silently skipped; 2-3 fields are reported as
        # malformed before being skipped.
        if nitems > 1:
            print(f"unexpected data length of {nitems}")
        continue
    # maxsplit=1 removes only the leading "key:" prefix; a plain split(":")
    # would cut the text off at the first colon it contains.
    text = data[1].split(":", 1)[1]
    label = data[2].split(":", 1)[1]
    label = convert_label(label, label2int)
    # filter texts too short
    nwords = len(text.split())
    if nwords < MIN_WORDS:
        short_samples["text"].append(text)
        short_samples["gender"].append(label)
        short_text_lens.append(nwords)
    else:
        long_samples["text"].append(text)
        long_samples["gender"].append(label)
        long_text_lens.append(nwords)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Fixed seed so the train/dev/test partition is identical on every run; without
# it the pickled splits change each time the notebook is executed.
SPLIT_SEED = 42

# First split: 65% train, 35% held out for validation + test.
text_train, text_val_test, label_train, label_val_test = train_test_split(
    long_samples["text"],
    long_samples["gender"],
    test_size=0.35,
    random_state=SPLIT_SEED,
)
# Second split of the held-out part: the first returned portion (75%) is used
# as the test set and the remaining 25% as the validation set.
text_test, text_val, label_test, label_val = train_test_split(
    text_val_test,
    label_val_test,
    test_size=0.25,
    random_state=SPLIT_SEED,
)

In [None]:
# Turn the parallel text/label lists of each split into a list of
# {"text": ..., "gender": ...} records.
wiki_splits = {
    "train": [
        {"text": text, "gender": gender}
        for text, gender in zip(text_train, label_train)
    ],
    "dev": [
        {"text": text, "gender": gender}
        for text, gender in zip(text_val, label_val)
    ],
    "test": [
        {"text": text, "gender": gender}
        for text, gender in zip(text_test, label_test)
    ],
}

In [None]:
# Write each Wikipedia split to its own pickle file.
# open(..., "wb") does not create missing directories, so make sure the target exists.
os.makedirs("data/md_gender/wikipedia", exist_ok=True)
for split in ["train", "dev", "test"]:
    with open(f"data/md_gender/wikipedia/{split}.pickle", "wb") as fout:
        this_split = wiki_splits[split]
        pickle.dump(this_split, fout)