In [1]:
import datasets
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
from collections import Counter
import json
import my_datasets
from unicode_substitutions import replace_all, sample_substitution, selected

In [2]:
# --- Input dataset and perturbation strategy ---
# Name handed to my_datasets.get_dataset().
ds_name = "pile_100M"
# Which perturbation function is used: 'replace_all' -> ra, 'sample_chars' -> sc.
strategy = "sample_chars"

# --- Group sizes, expressed as a percentage of the dataset ---
# (each value is multiplied by 0.01 * len(ds) to obtain a row count)
frac_contaminated = 5.0
frac_controlled = 1.0

# --- Sampling and runtime knobs ---
# Default perturbation seed; also seeds numpy before the per-sample seeds are drawn.
seed = 0
# Number of random seeds drawn for the alternative (unused) perturbations.
num_samples = 200
# Worker processes for Dataset.map and Dataset.to_json.
num_proc = 16
# When True, the dataset is truncated to the first control_idx rows.
debug = False

# --- Outputs ---
# Base path of the perturbed dataset; '.hf' and '.jsonl' suffixes are appended on save.
out_dataset_name = "data/frac:2/wikitext_perturbed"
# CSV file receiving the table of used / unused perturbations.
out_samples_name = "data/frac:2/samples.csv"

In [3]:
# Load the source dataset selected by ds_name through the project-local my_datasets helper.
ds = my_datasets.get_dataset(ds_name)

Loading cached shuffled indices for dataset at /home/ryan/haveibeentrainedon/data/pile1e8_orig/cache-db4e2c3fb658967d.arrow


In [4]:
# Row count of the "controlled" group: frac_controlled percent of the dataset,
# truncated to an int. Later cells use it both as the number of rows sampled
# for the samples table and as the block size that determines each row's seed.
control_idx = int(frac_controlled * 0.01 * len(ds))
control_idx

575

In [5]:
# Rows with index < contaminated_idx are perturbed by edit(); this is
# frac_contaminated percent of the dataset, truncated to an int.
# Note: it's possible that we are perturbing duplicated sequences.
contaminated_idx = int(frac_contaminated * 0.01 * len(ds))
contaminated_idx

2879

In [6]:
def ra(text, seed=seed):
    """Perturb `text` word by word using `replace_all`.

    The text is split on single spaces. Every token for which
    `str.isalnum()` is true is passed to `replace_all(token, seed)`;
    all other tokens are kept unchanged.

    Args:
        text: The string to perturb.
        seed: Value forwarded to `replace_all`. The default is the value the
            module-level `seed` had when this function was defined.

    Returns:
        A `(total, text)` tuple: the number of tokens that differ from the
        original after substitution, and the tokens re-joined with spaces.
    """
    changed = 0
    perturbed_tokens = []
    for token in text.split(' '):
        if token.isalnum():
            new_token = replace_all(token, seed)
        else:
            new_token = token
        # Count only tokens whose content actually changed.
        if new_token != token:
            changed += 1
        perturbed_tokens.append(new_token)
    return changed, ' '.join(perturbed_tokens)

def sc(text, seed=seed):
    """Perturb `text` word by word using `sample_substitution`.

    The text is split on single spaces. Every token for which
    `str.isalnum()` is true is passed to `sample_substitution(token, seed)`;
    all other tokens are kept unchanged.

    Args:
        text: The string to perturb.
        seed: Value forwarded to `sample_substitution`. The default is the
            value the module-level `seed` had when this function was defined.

    Returns:
        A `(total, text)` tuple: the number of tokens that differ from the
        original after substitution, and the tokens re-joined with spaces.
    """
    changed = 0
    perturbed_tokens = []
    for token in text.split(' '):
        if token.isalnum():
            new_token = sample_substitution(token, seed)
        else:
            new_token = token
        # Count only tokens whose content actually changed.
        if new_token != token:
            changed += 1
        perturbed_tokens.append(new_token)
    return changed, ' '.join(perturbed_tokens)

# Bind `perturb` to the perturbation function named by `strategy`.
if strategy == 'replace_all':
    perturb = ra
elif strategy == 'sample_chars':
    perturb = sc
else:
    # Fail here instead of with a NameError on `perturb` in a later cell.
    raise ValueError(
        f"unknown strategy {strategy!r}; expected 'replace_all' or 'sample_chars'"
    )

In [7]:
# Add a "bits" column initialised to 0 for every row; edit() later overwrites
# it, for perturbed rows, with the number of words the perturbation changed.
edited_ds = ds.add_column('bits', [0]*len(ds))

Loading cached processed dataset at /home/ryan/haveibeentrainedon/data/pile1e8_orig/cache-14c333ce47a93a61.arrow


In [8]:
# For debugging purposes: when `debug` is set, keep only the first control_idx
# rows so that the following cells operate on a small subset.
if debug:
    edited_ds = edited_ds.select(range(control_idx))

In [9]:
# Map function that perturbs the data. The number of words changed by the
# perturbation is recorded in the "bits" column of each row.
def edit(x, index):
    """Perturb one dataset row if it lies in the contaminated range.

    Args:
        x: A dataset row; its 'text' field is read, and 'text' and 'bits'
            are overwritten for perturbed rows.
        index: Position of the row in the dataset (supplied by
            `Dataset.map(..., with_indices=True)`).

    Returns:
        The row unchanged when `index >= contaminated_idx`; otherwise the
        same row with 'text' replaced by the perturbed text and 'bits' set
        to the number of words the perturbation changed.

    Raises:
        AssertionError: if `perturb` reports changes but returns the
            original text.
    """
    if index >= contaminated_idx:
        return x

    # Different seed for each "player": consecutive blocks of control_idx rows
    # share a seed (block 0 -> seed 0, block 1 -> seed 1, ...). The original
    # note said "up to 32 players"; the actual number of distinct seeds is
    # determined by contaminated_idx / control_idx.
    total, text = perturb(x['text'], seed=index // control_idx)

    # Sanity check: a non-zero change count must come with a changed text.
    if total != 0:
        assert x['text'] != text

    x["text"] = text
    x["bits"] = total
    return x

# Apply edit() to every row using num_proc worker processes. with_indices=True
# makes map() pass the row index as the second argument of edit();
# keep_in_memory=True asks datasets to hold the result in memory instead of
# writing it to a cache file.
edited_ds = edited_ds.map(
    edit,
    num_proc=num_proc,
    with_indices=True,
    keep_in_memory=True
)

Map (num_proc=16):   0%|          | 0/57586 [00:00<?, ? examples/s]

In [10]:
# import json
# edited_ds[0]['text'], json.dumps(edited_ds[0]['text'])
# np.mean([ selected('aloha', i) for i in range(0,1000)])

In [11]:
# Seed numpy's global RNG so the draw below is reproducible, then draw
# num_samples integer seeds from [0, 10000). Each seed produces one
# alternative perturbation per document in the samples table built below.
np.random.seed(seed)
seeds = np.random.randint(0, 10000, size=num_samples)

In [12]:
# Build the rows of the samples table. For every document in the control group
# that received at least 10 changed words, record the perturbation that was
# actually written into the dataset (used? == True) followed by one alternative
# perturbation of the original text per entry of `seeds` (used? == False).
data = []
for i in range(control_idx):
    # Read the edited row once instead of indexing the dataset for each field.
    edited_row = edited_ds[i]
    bits = edited_row['bits']

    # Skip documents where fewer than 10 words were changed.
    if bits < 10:
        continue

    data.append([i, edited_row['text'], True, bits])

    # The original text does not depend on the seed, so fetch it once per
    # document rather than once per seed.
    original_text = ds[i]['text']
    for s in seeds:
        total, perturbed_text = perturb(original_text, seed=s)
        data.append([i, perturbed_text, False, total])

KeyboardInterrupt: 

In [None]:
# Turn the collected rows into a table. Passing `columns=` to the constructor
# (rather than assigning `.columns` afterwards) also works when `data` is
# empty, yielding an empty frame with the expected four columns.
prop_inputs = pd.DataFrame(data, columns=['group', 'watermark', 'used?', 'bits'])
prop_inputs.head(3)

In [None]:
# Write the samples table to out_samples_name as CSV, with a header row and
# without the DataFrame index column.
prop_inputs.to_csv(out_samples_name, index=False, header=True)

In [None]:
# Display the perturbed dataset's summary as a quick sanity check before saving.
edited_ds

In [None]:
# Persist the perturbed dataset to '<out_dataset_name>.hf', then rebind
# edited_ds to the copy loaded back from disk so later cells use that copy.
edited_ds.save_to_disk(f'{out_dataset_name}.hf')
edited_ds = datasets.load_from_disk(f'{out_dataset_name}.hf')

In [None]:
# Also export the perturbed dataset to '<out_dataset_name>.jsonl', using
# num_proc worker processes for the export.
# edited_ds.remove_columns(['hash', 'is_original', 'substitutions'])
edited_ds.to_json(f'{out_dataset_name}.jsonl', num_proc=num_proc)