In [15]:
import numpy as np
from itertools import permutations
from collections import Counter, defaultdict
import torch
import math
from tqdm.notebook import tqdm
from random import shuffle, sample

from datasets import load_from_disk
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
import pickle 
import random
from tqdm import tqdm
from datasets import Dataset
import os

random.seed(42)

In [2]:
def kendall_tau_distance(rank1, rank2, normalized=False):
    """
    Calculate the Kendall tau distance between two rankings.
    
    Args:
        rank1, rank2: Lists or arrays of the same length containing rankings
        normalized: If True, normalizes the distance to [0,1] range by dividing
                   by the maximum possible distance n*(n-1)/2
        
    Returns:
        float: Kendall tau distance (normalized if normalized=True)
    
    Raises:
        ValueError: If inputs have different lengths or contain invalid rankings
    """
    if len(rank1) != len(rank2):
        raise ValueError("Rankings must have equal length")
        
    n = len(rank1)
    discordant_pairs = 0
    
    for i in range(n-1):
        for j in range(i+1, n):
            if (rank1[i] < rank1[j] and rank2[i] > rank2[j]) or \
               (rank1[i] > rank1[j] and rank2[i] < rank2[j]):
                discordant_pairs += 1
    
    if normalized:
        max_distance = n * (n - 1) / 2
        return discordant_pairs / max_distance
                
    return discordant_pairs

In [3]:
def perturb(ngram_rank, n_tokens=100):
    n = len(ngram_rank)
    
    if n_tokens % n != 0:
        raise ValueError("n_tokens must be divisible by the length of ngram_rank")
    
    if sorted(ngram_rank) != list(range(n)):
        raise ValueError("ngram_rank must contain consecutive ranks starting from 0")
    
    ngram_size = n_tokens // n
    
    ret = []
    for i in ngram_rank:
        for j in range(ngram_size):
            ret.append(i*ngram_size + j)
    return ret


In [4]:
def build_kendall_perturbations2(ngram_size, n_tokens, n_perturbations=9):
    og_token_rank = list(range(n_tokens))
    ngram_rank = list(range(n_tokens // ngram_size))

    if ngram_size in (2,5):
        target_dist = np.linspace(0.1,0.9,9)
    elif ngram_size == 10:
        target_dist = np.linspace(0.2,0.8,7)
    else:
        raise ValueError()
    
    perms = {round(k,1):[] for k in target_dist}

    for target_dist in perms:
        ret = []
        while len(perms[target_dist]) < n_perturbations:
            if target_dist > 0.5:
                ngram_rank_shuffled = list(reversed(ngram_rank))
            else:
                ngram_rank_shuffled = ngram_rank.copy()
            n = len(ngram_rank_shuffled)
            for _ in range(n*n):
                i,j = sample(range(n), 2)
                ngram_rank_shuffled[i], ngram_rank_shuffled[j] = ngram_rank_shuffled[j], ngram_rank_shuffled[i]

                perturbed_token_rank = perturb(ngram_rank_shuffled, n_tokens=n_tokens)
                distance = kendall_tau_distance(og_token_rank, perturbed_token_rank, normalized=True)
                distance = round(math.floor(distance*100) / 100, 2)
                # print(distance)

                if distance == target_dist:
                    perms[target_dist].append(perturbed_token_rank)
                    # print(target_dist, "+1")
                    break
                if target_dist <= 0.5 and distance > target_dist:
                    break
                if target_dist > 0.5 and distance < target_dist:
                    break
                
            # print("---------")
        # print(target_dist, "done")
              
    return perms

In [None]:
seq_length = 100
repetitions = 10

book_dataset = load_from_disk("SOME_DATA_DIR/clean_books_to_inject_neardupl_100")
all_titles = []

for i in range(len(book_dataset)):
    all_titles.append([int(i), book_dataset[i]['book_title'], seq_length, repetitions])

df = pd.DataFrame(all_titles, columns = ['book_idx', 'book_title', 'sequence_length', 'n_repetitions'])
df

Unnamed: 0,book_idx,book_title,sequence_length,n_repetitions
0,0,"A Letter to John Wilkes, Esq.",100,10
1,1,London in the Time of the Tudors,100,10
2,2,"The American Missionary -- Volume 37, No. 7, J...",100,10
3,3,The Brass Check,100,10
4,4,Birds of Song and Story,100,10
...,...,...,...,...
95,95,The Ivory Tower,100,10
96,96,Retrospective exhibition of important works of...,100,10
97,97,"John Cheap, the Chapman's Library. Vol. 2: Rel...",100,10
98,98,"The works of the Rev. John Wesley, Vol. 05 (of...",100,10


In [None]:
# lets get the og canaries
OG_CANARY_PATH = "SOME_DATA_DIR/members.pickle"

with open(OG_CANARY_PATH, 'rb') as f:
    og_canaries = pickle.load(f)

In [7]:
def inject_near_dupl_canary(og_text: str, all_canary_tokens: list, tokenizer: AutoTokenizer) -> str:
    '''
    Let's inject the canary at random places in the original text. 
    By splitting on spaces, we ensure to inject the canaries while not splitting any words from the original text.
    '''

    book_split_by_spaces = og_text.split(" ")
    all_indices_book = range(len(book_split_by_spaces))
    canary_indices = random.sample(all_indices_book, len(all_canary_tokens))
    canary_indices_sorted = np.sort(canary_indices)

    new_text = ''
    last_index = 0

    all_canary_length = 0

    for i, idx in enumerate(canary_indices_sorted):
        canary_tokens = all_canary_tokens[i]
        canary = tokenizer.decode(canary_tokens)
        all_canary_length += len(canary)
        new_text += " ".join(book_split_by_spaces[last_index:idx])
        if idx == 0:
            new_text += canary 
        else:
            new_text += " " + canary 
        last_index = idx

    new_text += " ".join(book_split_by_spaces[last_index:])

    assert len(new_text)  == len(og_text) + all_canary_length

    return new_text

In [8]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

In [None]:
BASE_PATH = "SOME_DATA_DIR"
DS_NAME_TEMPLATE = "kendall_dist_0{kendall_dist}_ngram_{ngram_size}"

In [17]:
for ngram_size in (2,5,10):
    canary_dataset_entries = defaultdict(list)
    print(ngram_size)
    for i in tqdm(range(len(book_dataset))):
        og_entry = book_dataset[i]
        perms = build_kendall_perturbations2(ngram_size, n_tokens=100, n_perturbations=9)

        for kendall_dist in perms:
            all_canary_chunks = []
            original = og_canaries[i]  
            all_canary_chunks.append(original)

            for perm in perms[kendall_dist]:
                perturbed = np.array(original)[perm]
                all_canary_chunks.append(perturbed)
            
            new_text = inject_near_dupl_canary(og_text=og_entry["text"], all_canary_tokens=all_canary_chunks, tokenizer=tokenizer)
            new_entry = og_entry.copy()
            new_entry["text"] = new_text
            
            canary_dataset_entries[kendall_dist].append(new_entry)

    for kendall_dist in canary_dataset_entries:
        entries = canary_dataset_entries[kendall_dist]
        dataset = Dataset.from_dict({"title": [entry["book_title"] for entry in entries],
                                    "text": [entry["text"] for entry in entries]})
        path = os.path.join(BASE_PATH, DS_NAME_TEMPLATE.format(kendall_dist=int(kendall_dist*10), ngram_size=ngram_size))
            
        dataset.save_to_disk(path)
    

2



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 100/100 [05:00<00:00,  3.01s/it]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 673.36 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 735.28 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 668.03 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 673.79 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 663.95 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00

5



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 100/100 [04:11<00:00,  2.51s/it]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 636.51 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 675.44 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 760.79 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 709.97 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 713.82 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00

10



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 100/100 [01:51<00:00,  1.11s/it]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 721.75 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 785.69 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 720.27 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 736.38 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 699.09 examples/s]

[A
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00