In [10]:
import datasets
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
from collections import Counter
import json
import my_datasets
from unicode_substitutions import replace_all, sample_substitution, selected

In [11]:
# Output locations: the perturbed dataset is written to f'{out_dataset_name}.hf'
# and f'{out_dataset_name}.jsonl'; the table of candidate watermarks goes to
# out_samples_name.
# NOTE(review): the file prefix says "wikitext" while ds_name below is
# 'pile_100M' — confirm the naming is intentional.
out_dataset_name = "data/frac:2/wikitext_perturbed"
out_samples_name = "data/frac:2/samples.csv"

# ds_name is passed to my_datasets.get_dataset; strategy selects the
# perturbation function ('replace_all', 'sample_chars' or 'random_seq').
ds_name = 'pile_100M'
strategy = 'sample_chars'

# seed: base seed for the perturbation and for drawing the sample seeds.
# num_samples: how many random seeds are drawn for the samples table.
# num_proc: worker processes used by the dataset map and the JSON export.
# debug: when True, the dataset is truncated to its first control_idx rows.
seed = 0
num_samples = 200
num_proc = 16
debug = False

# control_idx: rows with index < control_idx are used to build the samples
# table; it is also the divisor that derives the per-row seed offset.
# contaminated_idx: rows with index >= contaminated_idx are left unperturbed.
# NOTE(review): the saved outputs of the cells that display these two names
# show 575 and 1151, not 1, so those outputs did not come from the values
# assigned here — confirm the intended values before re-running.
control_idx = 1
contaminated_idx = 1

In [12]:
# Load the source dataset selected by ds_name through the project-local
# my_datasets helper. Later cells index it by row and read a 'text' field.
ds = my_datasets.get_dataset(ds_name)

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-3206cb27e901c536/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
# control_idx = int(frac_controlled * 0.01 * len(ds))
# Display the current value of control_idx.
# NOTE(review): the formula above is commented out and frac_controlled is not
# defined in the visible cells; the saved output of this cell does not match
# the literal assigned in the configuration cell — confirm which is intended.
control_idx

575

In [14]:
# it's possible that we are perturbing duplicated sequences
# contaminated_idx = int(frac_contaminated * 0.01 * len(ds))
# Display the current value of contaminated_idx.
# NOTE(review): the formula above is commented out and frac_contaminated is
# not defined in the visible cells; the saved output of this cell does not
# match the literal assigned in the configuration cell — confirm which is
# intended.
contaminated_idx

1151

In [15]:
# NOTE(review): this import sits outside the notebook's top import cell;
# consider moving it there so all dependencies are visible in one place.
from transformers import GPT2Tokenizer
# Pretrained GPT-2 tokenizer. random_seq uses it to decode random token ids,
# and the demo cell uses it to inspect how perturbed text is tokenized.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    
def random_seq(text, seed=seed, test=False):
    """Append a seeded pseudo-random token sequence (a watermark) to `text`.

    Ten token ids are drawn uniformly from [0, 92) using a generator seeded
    with `seed`, then decoded to a string with the module-level `tokenizer`.

    Args:
        text: String the watermark is appended to (ignored when `test` is true).
        seed: Seed for the generator; the same seed yields the same watermark.
        test: When true, return the watermark alone instead of appending it.

    Returns:
        A `(total, text)` tuple. `total` is the character length of the
        decoded watermark; `text` is `"<text> <watermark>"`, or just the
        watermark when `test` is true.
    """
    start_k, vocab_size, watermark_length = 0, 92, 10

    # Use a private generator instead of np.random.seed(seed): it produces the
    # same draws for a given seed but does not reset the process-wide NumPy
    # random state as a side effect of every call.
    rng = np.random.RandomState(seed)
    toks = rng.randint(start_k, start_k + vocab_size, size=(watermark_length,))
    random_sequence = tokenizer.decode(toks)

    if test:
        text = random_sequence
    else:
        text = '%s %s' % (text, random_sequence)

    total = len(random_sequence)
    return total, text

In [18]:
def ra(text, seed=seed):
    """Perturb `text` word by word with `replace_all`.

    The text is split on single spaces. Every piece for which `str.isalnum()`
    is true is passed to `replace_all(piece, seed)`; all other pieces are kept
    unchanged. The pieces are then re-joined with single spaces.

    Returns:
        A `(total, text)` tuple, where `total` is the number of pieces that
        differ from their original and `text` is the re-joined string.
    """
    changed = 0
    rewritten = []
    for piece in text.split(' '):
        # Only purely alphanumeric pieces are candidates for substitution.
        candidate = replace_all(piece, seed) if piece.isalnum() else piece
        if candidate != piece:
            changed += 1
        rewritten.append(candidate)
    return changed, ' '.join(rewritten)

def sc(text, seed=seed):
    """Perturb `text` word by word with `sample_substitution`.

    The text is split on single spaces. Every piece for which `str.isalnum()`
    is true is passed to `sample_substitution(piece, seed)`; all other pieces
    are kept unchanged. The pieces are then re-joined with single spaces.

    Returns:
        A `(total, text)` tuple, where `total` is the number of pieces that
        differ from their original and `text` is the re-joined string.
    """
    originals = text.split(' ')
    substituted = [
        sample_substitution(tok, seed) if tok.isalnum() else tok
        for tok in originals
    ]
    # Count the positions where the substitution actually changed the piece.
    total = sum(1 for before, after in zip(originals, substituted) if before != after)
    return total, ' '.join(substituted)

# Bind `perturb` to the function implementing the configured strategy.
# Every perturbation function returns a (total, text) tuple.
if strategy == 'replace_all':
    perturb = ra
elif strategy == 'sample_chars':
    perturb = sc
elif strategy == 'random_seq':
    perturb = random_seq
else:
    # Fail here rather than leaving `perturb` undefined (or, in a kernel that
    # has been used before, silently bound to a previously selected function).
    raise ValueError(f"unknown strategy: {strategy!r}")

In [53]:
# Sanity check: perturb a short sentence with sc (seed 1) and show how the
# GPT-2 tokenizer splits the result, first as token strings and then as ids.
ids = tokenizer(sc('I have a dream', 1)[1])['input_ids']
print(tokenizer.convert_ids_to_tokens(ids))
print(ids)

['I', 'Ġh', 'Ð°', 'v', 'Ðµ', 'Ġa', 'Ġd', 're', 'Ð°', 'm']
[40, 289, 16142, 85, 16843, 257, 288, 260, 16142, 76]


In [None]:
# Start from the source dataset with an extra 'bits' column initialised to 0
# for every row; the perturbation step overwrites it for the rows it touches.
edited_ds = ds.add_column('bits', [0]*len(ds))

In [None]:
# for debugging purposes: keep only the first control_idx rows so the
# following cells run on a small subset.
if debug:
    edited_ds = edited_ds.select(range(control_idx))

In [None]:
# Row-level function for Dataset.map: perturbs one row's text and records the
# count returned by the perturbation in the row's "bits" field.
def edit(x, index):
    """Perturb a single dataset row in place and return it.

    Rows whose `index` is at or beyond `contaminated_idx` are returned
    untouched. For all other rows the module-level `perturb` function is
    applied to `x['text']`, using a seed that is constant within each block
    of `control_idx` consecutive rows and increases by one per block.

    Args:
        x: Row mapping with at least a 'text' entry.
        index: Position of the row in the dataset.

    Returns:
        The same mapping, with 'text' replaced by the perturbed text and
        'bits' set to the count reported by `perturb`.

    Raises:
        AssertionError: if `perturb` reports a non-zero count but returns
            the text unchanged.
    """
    if index >= contaminated_idx:
        return x

    original_text = x['text']

    # different seed for each "player" (block of control_idx rows); the
    # original note said "up to 32 players", which nothing here enforces.
    # Floor division keeps the offset exact for arbitrarily large indices,
    # unlike int(index / control_idx), which goes through a float.
    player_seed = seed + index // control_idx
    total, text = perturb(original_text, seed=player_seed)

    if total != 0:
        assert original_text != text, "perturb reported changes but returned identical text"

    x["text"] = text
    x["bits"] = total
    return x

# Apply edit to every row, passing the row index as the second argument,
# using num_proc worker processes and keeping the result in memory.
# NOTE(review): this rebinds edited_ds to its own mapped result, so running
# the cell a second time perturbs the already perturbed text again; re-run
# the cell that creates edited_ds first.
edited_ds = edited_ds.map(
    edit,
    num_proc=num_proc,
    with_indices=True,
    keep_in_memory=True
)

In [None]:
# Draw num_samples seeds in [0, 100000) for the alternative perturbations in
# the samples table. Reseeding the global NumPy RNG first makes the draw
# reproducible for a given `seed`.
np.random.seed(seed)
seeds = np.random.randint(0, 100000, size=num_samples)

In [None]:
# Build the rows of the samples table as [group, watermark, seed, bits].
data = []

if strategy != 'random_seq':
    def left_truncate(x):
        """Keep only the last 10000 characters of `x`."""
        return x[-10000:]

    for i in tqdm(range(control_idx)):
        # Fetch the edited row once instead of re-indexing the dataset for
        # every field that is read from it.
        edited_row = edited_ds[i]

        # Skip rows whose recorded perturbation count is too small.
        if edited_row['bits'] < 10:
            continue

        # edited ds: the perturbation that was actually applied. Rows below
        # control_idx were perturbed with the base seed, so record `seed`
        # itself rather than a hard-coded 0.
        data.append([i, left_truncate(edited_row['text']), seed, edited_row['bits']])

        # original ds: the same (truncated) source text perturbed under each
        # sampled seed. The source text does not depend on the seed, so it is
        # read and truncated once per row.
        original_text = left_truncate(ds[i]['text'])
        for s in seeds:
            total, perturbed_text = perturb(original_text, seed=s)
            data.append([i, perturbed_text, s, total])
else:
    # random_seq: the watermark is independent of the row text, so emit one
    # row per seed, starting with the base seed. A local list is used so that
    # `seeds` is not modified and re-running the cell gives the same table.
    for s in tqdm([seed] + list(seeds)):
        total, perturbed_text = perturb('', test=True, seed=s)
        data.append([119, perturbed_text, s, total])

In [None]:
# Turn the collected rows into a table. Passing the column names to the
# constructor (instead of assigning .columns afterwards) also works when
# `data` is empty, where the assignment would raise a length-mismatch error.
prop_inputs = pd.DataFrame(data, columns=['group', 'watermark', 'used?', 'bits'])
prop_inputs.head(3)

In [None]:
# Write the samples table to out_samples_name as CSV, with a header row and
# without the DataFrame index.
# NOTE(review): presumably the target directory must already exist — confirm.
prop_inputs.to_csv(out_samples_name, index=False, header=True)

In [None]:
# Display a summary of the perturbed dataset before it is written out.
edited_ds

In [None]:
# Save the perturbed dataset to f'{out_dataset_name}.hf', then rebind
# edited_ds to the copy loaded back from that path.
edited_ds.save_to_disk(f'{out_dataset_name}.hf')
edited_ds = datasets.load_from_disk(f'{out_dataset_name}.hf')

In [None]:
# Export the perturbed dataset to f'{out_dataset_name}.jsonl' using num_proc
# worker processes.
# edited_ds.remove_columns(['hash', 'is_original', 'substitutions'])
edited_ds.to_json(f'{out_dataset_name}.jsonl', num_proc=num_proc)