In [10]:
%load_ext autoreload
%autoreload
from collections import Counter

from torch.utils.data import RandomSampler
import datasets
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
from transformers import (
    set_seed,
)
set_seed(seed=42)

def sample_dataset(dataset, num_samples, generator=None, dataset_name=None):
    """Draw a random subset of ``num_samples`` rows from ``dataset``.

    Row indices come from ``torch.utils.data.RandomSampler`` (``replacement=False``), so the
    draw is reproducible through ``generator`` or, when it is ``None``, through torch's
    global RNG.

    NOTE(review): when ``num_samples > len(dataset)``, recent torch versions make
    ``RandomSampler`` emit repeated permutations, so the result then contains duplicate rows.

    Args:
        dataset: a ``datasets.Dataset`` (anything with ``__len__``, ``select`` and ``map``).
        num_samples: number of rows to draw.
        generator: optional ``torch.Generator`` controlling the draw.
        dataset_name: if given, the ``"dataset"`` column of every sampled row is set to it.

    Returns:
        The sampled dataset.
    """
    sampler = RandomSampler(dataset, num_samples=num_samples, generator=generator)
    # RandomSampler draws a new permutation on every iter(), and Dataset.select may iterate
    # its argument several times. Materialize the indices exactly once so the selected rows
    # don't depend on how many passes select makes. Empty datasets are skipped because
    # iterating the sampler over zero rows would divide by zero.
    indices = list(sampler) if len(dataset) > 0 else []
    sampled_dataset = dataset.select(indices)
    if dataset_name is not None:
        sampled_dataset = sampled_dataset.map(lambda ex: {"dataset": dataset_name})
    return sampled_dataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


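# ToxiGen

Keep ToxiGen generations with `prompt_label == 1` that have 9–24 space-separated tokens. Then randomly sample 256 per `group` and save them to `data/toxigen_test_dataset`.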

In [2]:
# Load the ToxiGen training split. `token=True` authenticates with the locally stored
# Hugging Face token (`use_auth_token` is deprecated since datasets 2.14).
toxigen_dataset = load_dataset("skg/toxigen-data", name="train", token=True, split="train")
# Keep rows with prompt_label == 1.
# NOTE(review): presumably the toxic-prompt subset — confirm against the dataset card.
toxigen_dataset = toxigen_dataset.filter(lambda ex: ex["prompt_label"] == 1)

# One sub-dataset per value of the `group` column.
groups = toxigen_dataset.unique("group")
toxigen_datasets = DatasetDict({group: toxigen_dataset.filter(lambda ex: ex["group"] == group) for group in groups})



In [3]:
def preprocess_dataset(dataset, num_samples=None):
    """Reshape one ToxiGen group into the shared ``inputs_pretokenized`` / ``dataset`` schema.

    ``generation`` is renamed to ``inputs_pretokenized`` and ``group`` to ``dataset``;
    every other column is dropped. When ``num_samples`` is given, the result is randomly
    subsampled with ``sample_dataset``.
    """
    kept_columns = ("inputs_pretokenized", "dataset")
    renamed = (
        dataset
        .rename_column("generation", "inputs_pretokenized")
        .rename_column("group", "dataset")
    )
    trimmed = renamed.remove_columns(
        [name for name in renamed.column_names if name not in kept_columns]
    )
    if num_samples is None:
        return trimmed
    return sample_dataset(trimmed, num_samples=num_samples)

# Add a token count per generation, order each group by it, and keep 9-24 tokens.
# Tokens come from splitting on single spaces, so consecutive spaces add empty tokens to
# the count and newlines/tabs are not treated as separators.
toxigen_datasets = (
    toxigen_datasets
    .map(lambda ex: {"length": len(ex["generation"].split(" "))})
    .sort("length")
    .filter(lambda ex: 8 < ex["length"] <= 24)
)

In [4]:
# 256 random rows per group, reshaped to the shared schema and concatenated into one split.
toxigen_dataset = concatenate_datasets(
    [preprocess_dataset(group_dataset, num_samples=256) for group_dataset in toxigen_datasets.values()]
)
toxigen_dataset.save_to_disk("data/toxigen_test_dataset")

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)
Saving the dataset (1/1 shards): 100%|███████████████████████████████| 3328/3328 [00:00<00:00, 198620.39 examples/s]


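# WinoBias

Label the type-1 `pro` and `anti` sentences by pronoun gender (e.g. `pro_female`, `anti_male`), and drop sentences that contain both genders' pronouns or neither. Then randomly sample 256 per label and save them to `data/winobias_test_dataset`.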

In [5]:
def _load_winobias_split(config_name, label):
    """Load one WinoBias config (validation + test) and tag its rows.

    Adds ``inputs_pretokenized`` (the ``tokens`` joined with single spaces) and a constant
    ``type`` column equal to ``label``.
    """
    split_dataset = load_dataset("wino_bias", name=config_name, split="validation+test")
    return split_dataset.map(lambda ex: {"inputs_pretokenized": " ".join(ex["tokens"]), "type": label})


pro_dataset = _load_winobias_split("type1_pro", "pro")
anti_dataset = _load_winobias_split("type1_anti", "anti")

winobias_dataset = concatenate_datasets([pro_dataset, anti_dataset])

Downloading readme: 100%|██████████████████████████████████████████████████████| 21.5k/21.5k [00:00<00:00, 19.7MB/s]
Downloading data files:   0%|                                                                 | 0/2 [00:00<?, ?it/s]
Downloading data:   0%|                                                                 | 0.00/31.8k [00:00<?, ?B/s][A
Downloading data: 100%|████████████████████████████████████████████████████████| 31.8k/31.8k [00:01<00:00, 27.6kB/s][A
Downloading data files:  50%|████████████████████████████▌                            | 1/2 [00:01<00:01,  1.16s/it]
Downloading data:   0%|                                                                 | 0.00/33.8k [00:00<?, ?B/s][A
Downloading data: 100%|████████████████████████████████████████████████████████| 33.8k/33.8k [00:01<00:00, 31.4kB/s][A
Downloading data files: 100%|█████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.12s/it]
Extracting data files: 100%|████████████████████████

In [6]:
def coref(ex):
    """Label a WinoBias example by the gender of the pronouns among its tokens.

    Returns ``{"dataset": <type> + "_female"}`` when only "she"/"her" occur,
    ``{"dataset": <type> + "_male"}`` when only "he"/"his"/"him" occur, and
    ``{"dataset": None}`` when both or neither occur. Matching is exact and
    case-sensitive on whole tokens.
    """
    tokens = set(ex["tokens"])
    has_female = not tokens.isdisjoint(("she", "her"))
    has_male = not tokens.isdisjoint(("he", "his", "him"))
    if has_female == has_male:
        # Either both genders are present or neither is: the label would be ambiguous.
        return {"dataset": None}
    suffix = "_female" if has_female else "_male"
    return {"dataset": ex["type"] + suffix}

# Tag every sentence with its pronoun-gender label and drop those `coref` left unlabeled.
winobias_dataset = winobias_dataset.map(coref).filter(lambda ex: ex["dataset"] is not None)

# Keep only the columns used downstream ("targets_pretokenized" survives only if present).
kept_columns = ("inputs_pretokenized", "targets_pretokenized", "dataset")
winobias_dataset = winobias_dataset.remove_columns(
    [name for name in winobias_dataset.column_names if name not in kept_columns]
)

# One sub-dataset per label (e.g. "pro_female"), keyed in `unique` order.
winobias_datasets = {
    label: winobias_dataset.filter(lambda ex: ex["dataset"] == label)
    for label in winobias_dataset.unique("dataset")
}

Map: 100%|█████████████████████████████████████████████████████████████| 1584/1584 [00:00<00:00, 8041.66 examples/s]
Filter: 100%|█████████████████████████████████████████████████████████| 1584/1584 [00:00<00:00, 13519.96 examples/s]
Flattening the indices: 100%|█████████████████████████████████████████| 1583/1583 [00:00<00:00, 98176.57 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1583/1583 [00:00<00:00, 170319.97 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1583/1583 [00:00<00:00, 170058.22 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1583/1583 [00:00<00:00, 172197.29 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1583/1583 [00:00<00:00, 167750.97 examples/s]


In [7]:
# 256 random rows per pronoun-gender label, concatenated into one test split.
test_dataset = concatenate_datasets(
    [sample_dataset(label_dataset, num_samples=256) for label_dataset in winobias_datasets.values()]
)
test_dataset.save_to_disk("data/winobias_test_dataset")

Saving the dataset (1/1 shards): 100%|███████████████████████████████| 1024/1024 [00:00<00:00, 198262.81 examples/s]


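# TruthfulQA

Pair every question with each of its incorrect answers, and group the pairs by category, merging all "Indexical Error" and "Confusion" subcategories. Keep categories with at least 128 pairs, then randomly sample 256 per category and save them to `data/truthfulqa_test_dataset`.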

In [8]:
def _truthfulqa_group(category):
    """Map a TruthfulQA category to the sub-dataset name used for grouping.

    Any category containing "Indexical Error" becomes "indexical_error", any containing
    "Confusion" becomes "confusion", and everything else is lower-cased.
    """
    if "Indexical Error" in category:
        return "indexical_error"
    if "Confusion" in category:
        return "confusion"
    return category.lower()


raw_dataset = load_dataset("truthful_qa", name="generation", split="validation")

# One record per (question, incorrect answer) pair, both stripped of surrounding whitespace.
truthfulqa_records = [
    {
        "inputs_pretokenized": ex["question"].strip(),
        "targets_pretokenized": incorrect_answer.strip(),
        "dataset": _truthfulqa_group(ex["category"]),
    }
    for ex in raw_dataset
    for incorrect_answer in ex["incorrect_answers"]
]
truthfulqa_dataset = Dataset.from_list(truthfulqa_records)

Downloading readme: 100%|██████████████████████████████████████████████████████| 9.59k/9.59k [00:00<00:00, 17.6MB/s]
Downloading data files:   0%|                                                                 | 0/1 [00:00<?, ?it/s]
Downloading data:   0%|                                                                  | 0.00/223k [00:00<?, ?B/s][A
Downloading data: 100%|███████████████████████████████████████████████████████████| 223k/223k [00:00<00:00, 233kB/s][A
Downloading data files: 100%|█████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]
Extracting data files: 100%|█████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 967.54it/s]
Generating validation split: 100%|█████████████████████████████████████| 817/817 [00:00<00:00, 192168.37 examples/s]


In [11]:
# Keep only categories with at least 128 (question, incorrect answer) pairs.
# NOTE(review): the TruthfulQA sampling cell draws 256 rows per category, so categories with
# 128-255 pairs get padded with repeated rows by RandomSampler (recent torch). Confirm this
# is intended.
category_counts = Counter(truthfulqa_dataset["dataset"])
dataset_names = [name for name, count in category_counts.items() if count >= 128]
truthfulqa_dataset = truthfulqa_dataset.filter(lambda ex: ex["dataset"] in dataset_names)

# One sub-dataset per remaining category, keyed in `unique` order.
truthfulqa_datasets = {
    category: truthfulqa_dataset.filter(lambda ex: ex["dataset"] == category)
    for category in truthfulqa_dataset.unique("dataset")
}

Filter: 100%|████████████████████████████████████████████████████████| 3318/3318 [00:00<00:00, 302188.80 examples/s]
Flattening the indices: 100%|████████████████████████████████████████| 1976/1976 [00:00<00:00, 118960.02 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 108114.44 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 142160.29 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 143740.69 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 143427.27 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 138137.02 examples/s]
Filter: 100%|████████████████████████████████████████████████████████| 1976/1976 [00:00<00:00, 141049.79 examples/s]
Filter: 100%|███████████████████████████████████████████████████

In [12]:
# 256 random rows per TruthfulQA category, concatenated into one test split.
test_dataset = concatenate_datasets(
    [sample_dataset(category_dataset, num_samples=256) for category_dataset in truthfulqa_datasets.values()]
)
test_dataset.save_to_disk("data/truthfulqa_test_dataset")

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)
Saving the dataset (1/1 shards): 100%|███████████████████████████████| 2304/2304 [00:00<00:00, 156961.95 examples/s]
