In [1]:
import datasets
import numpy as np
from tqdm.notebook import tqdm
import csv
import pandas as pd

from substitutions import tenk_word_pairs as word_pairs

In [2]:
# --- Run configuration -------------------------------------------------------
# Random seed for reproducible document sampling.
seed = 416
# Number of perturbed documents sampled per word pair.
n_per_sub = 1_000
# Worker count for parallel `datasets` operations (not used by the cells shown).
num_proc = 16
# Dataset directory read with `datasets.load_from_disk`.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
ds_path = '/home/ryan/haveibeentrainedon/data/45e8_tokens_perturbed.hf'

In [3]:
# Load the dataset directory at `ds_path` (a dataset previously written with `save_to_disk`).
ds = datasets.load_from_disk(ds_path)

In [4]:
# Display the dataset summary: feature names and number of rows.
ds

Dataset({
    features: ['text', 'meta', 'hash', 'is_original', 'substitutions', 'order'],
    num_rows: 2531096
})

In [5]:
# One row per document, one column per word pair (column i is indexed alongside
# word_pairs[i] below); a value of 1 marks a document eligible for that pair.
swap_arr = np.asarray(ds["substitutions"])
print(swap_arr.shape)

(2531096, 45)


In [6]:
# Re-create the RandomState from the configured seed so the sampled indices are
# reproducible; the perturbed-inputs cell re-seeds the same way and therefore
# draws the same `do_sub` indices.
rs = np.random.RandomState(seed=seed)

# Number of eligible-but-unperturbed documents collected per word pair.
n_no_sub = 2000

# Row indices 0..N-1. Loop-invariant; boolean indexing below copies, so
# shuffling `has_sub` never mutates this array.
idx = np.arange(len(swap_arr))

do_sub = []
examples = []
for i, (w1, w2) in tqdm(enumerate(word_pairs), total=len(word_pairs)):
    # Documents flagged for word pair i, in a seeded random order.
    has_sub = idx[swap_arr[:, i] == 1]
    rs.shuffle(has_sub)

    # The first n_per_sub are the documents selected for perturbation...
    do_sub.append(list(has_sub[:n_per_sub]))
    # ...and the next n_no_sub are flagged documents that were not selected.
    no_sub = has_sub[n_per_sub:n_per_sub + n_no_sub]
    subset_ds = ds.select(no_sub)

    # Sanity checks against each document's recorded 'order'.
    # NOTE: `str(i) in j` is a substring test (e.g. '1' also matches '12'),
    # so it is a loose check.
    assert all(str(i) in j for j in ds.select(do_sub[-1])['order']), (
        f"pair {i}: a selected document does not record substitution {i}"
    )
    # Wrap `j` in ':' so the first and last entries are delimited as well.
    # (Previously the right-hand side was the literal '{j}:' — not an
    # f-string — which made this assertion always pass.)
    assert all(f':{i}:' not in f':{j}:' for j in subset_ds['order']), (
        f"pair {i}: an unselected document records substitution {i}"
    )

    # Unperturbed text still contains the original word w1; record the
    # position of the space preceding it (raises ValueError if absent).
    for ex_idx, row in zip(no_sub, subset_ds):
        text = row['text']
        examples.append((ex_idx, text, text.index(f' {w1} '), w1, w2))

  0%|          | 0/45 [00:00<?, ?it/s]

In [7]:
# Name the columns at construction time: identical result for non-empty
# `examples`, and an empty `examples` now yields a header-only CSV instead of
# a "Length mismatch" error from assigning five names to a zero-column frame.
df = pd.DataFrame(
    examples,
    columns=['example_index', 'text', 'sub_index', 'original', 'synonym'],
)
df.to_csv('./non-perturbed_inputs.csv', index=False)

In [8]:
# This random state allows the perturbations to be reproducible
rs = np.random.RandomState(seed=416)

#take the sequences to perturb
do_sub = []
examples = []
for i, (w1, w2) in tqdm(enumerate(word_pairs), total=len(word_pairs)):
    # create indices
    idx = np.arange(len(swap_arr))
    has_sub = idx[swap_arr[:, i] == 1]
    rs.shuffle(has_sub)
    do_sub.append(list(has_sub[:n_per_sub]))
    
    subset_ds = ds.select(do_sub[-1])
    
    # assert that all examples received the appropriate substitution
    assert(all([ str(i) in j for j in ds.select(do_sub[-1])['order']]))
    
    for ex_idx, j in zip(do_sub[-1], subset_ds):
        examples.append((ex_idx, j['text'], j['text'].index(f' {w2} '), w1, w2))

  0%|          | 0/45 [00:00<?, ?it/s]

In [9]:
# Name the columns at construction time: identical result for non-empty
# `examples`, and an empty `examples` now yields a header-only CSV instead of
# a "Length mismatch" error from assigning five names to a zero-column frame.
df = pd.DataFrame(
    examples,
    columns=['example_index', 'text', 'sub_index', 'original', 'synonym'],
)
df.to_csv('./perturbed_inputs.csv', index=False)