In [1]:
import datasets
from tqdm.notebook import tqdm
import numpy as np
from substitutions import tenk_word_pairs as word_pairs
from collections import Counter

In [2]:
orig_data = "17e7_tokens.jsonl"
out_dataset_name = "17e7_tokens_perturbed"
n_per_sub = 1000
num_proc = 16
seed = 0

In [3]:
#This converts the jsonl to huggingface
ds = datasets.load_dataset("json", data_files=orig_data)
ds

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-45df1c8a959db879/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'meta'],
        num_rows: 989378
    })
})

In [4]:
# This appends a "hash" column to each entry
def get_duplicated(entry, idx):
    hash_val = hash(entry["text"])
    entry["hash"] = hash_val
    return entry

ds = ds["train"].map(get_duplicated, with_indices=True, num_proc=num_proc, keep_in_memory=True)

Map (num_proc=16):   0%|          | 0/989378 [00:00<?, ? examples/s]

In [5]:
# This creates a counter for the hashes
hash_counter = Counter(ds["hash"])
print(f"length of hash counter = {len(hash_counter)}")

length of hash counter = 986474


In [6]:
# appends a column that represents whether or not the data is duplicated
def append_duplicated_column(entry):
    entry["is_original"] = (hash_counter[entry["hash"]] == 1)
    return entry

ds = ds.map(append_duplicated_column, num_proc=num_proc)

Map (num_proc=16):   0%|          | 0/989378 [00:00<?, ? examples/s]

In [7]:
ds

Dataset({
    features: ['text', 'meta', 'hash', 'is_original'],
    num_rows: 989378
})

In [8]:
duplicated_counter = Counter(ds["is_original"])
print(f"is_original counter = {duplicated_counter}")

is_original counter = Counter({True: 983916, False: 5462})


In [9]:
# labels unique sentences with corresponding word pairs
def label(x):
    # compute corresponding label matrix
    if x["is_original"]:
        labels = [1 if f' {i} ' in x['text'] else 0 for i, _ in word_pairs]
        x['substitutions'] = labels
        return x
    # dont consider duplicated documents, so set all to 0
    else:
        x["substitutions"] = [0 for i in range(len(word_pairs))]
        return x

ds = ds.map(label, num_proc=num_proc)

Map (num_proc=16):   0%|          | 0/989378 [00:00<?, ? examples/s]

In [10]:
swap_arr = np.array(ds["substitutions"])
print(swap_arr.shape)

(989378, 45)


In [11]:
# This random state allows the perturbations to be reproducible
rs = np.random.RandomState(seed=seed)

In [12]:
# used for keeping track of which words have been perturbed
ds = ds.add_column('order', [''] * len(ds))
edited_ds = ds

In [13]:
#take the sequences to perturb
do_sub = []
for i, (w1, w2) in tqdm(enumerate(word_pairs), total=len(word_pairs)):
    # create indices
    idx = np.arange(len(swap_arr))
    has_sub = idx[swap_arr[:, i] == 1]
    rs.shuffle(has_sub)
    
    all_indexes = has_sub[:n_per_sub]
    labels = rs.randint(0, 2, size=n_per_sub).astype(bool)
    do_sub.append(all_indexes[labels])

  0%|          | 0/45 [00:00<?, ?it/s]

In [14]:
np.sum([ i.sum() for i in do_sub ])

11058546234

In [15]:
len(do_sub[0])

475

In [16]:
#Performs the map that will perturb the data. Records the perturbation in the "order" section of the data
def edit(x, index):
    for i, (w1, w2) in enumerate(word_pairs):
        if index not in do_sub[i]:
            continue
        order = x['order'] + f'{i}:'
        new_text = x['text'].replace(f' {w1} ', f' {w2} ', 1)
        assert (new_text != x['text'])
        x["text"] = new_text
        x["order"] = order
    return x

edited_ds = edited_ds.map(
    edit,
    num_proc=num_proc,
    with_indices=True,
    keep_in_memory=True
)

Map (num_proc=16):   0%|          | 0/989378 [00:00<?, ? examples/s]

In [17]:
edited_ds.save_to_disk(f'{out_dataset_name}.hf')
edited_ds = datasets.load_from_disk(f'{out_dataset_name}.hf')

Saving the dataset (0/14 shards):   0%|          | 0/989378 [00:00<?, ? examples/s]

In [18]:
#saves the data
edited_ds.remove_columns(['hash', 'is_original', 'substitutions'])
edited_ds.to_json(f'{out_dataset_name}.jsonl', num_proc=num_proc)

Creating json from Arrow format:   0%|          | 0/990 [00:00<?, ?ba/s]

6599259147