In [56]:
import os, re
from copy import copy

import numpy as np
import torch

In [58]:
# Build the vocabulary: every distinct whitespace-separated token in the cleaned books.
words = set()
booknames = os.listdir("./books/II-clean")

for bookname in booknames:
    if not bookname.endswith(".txt"):
        continue
    # Explicit UTF-8 so accented text (e.g. "caffè", "Musée") decodes identically
    # regardless of the platform's default locale encoding.
    with open(os.path.join("./books/II-clean", bookname), "r", encoding="utf-8") as f:
        text = f.read()

    # Remove all punctuation except dashes; apostrophes are dropped too ("don't" -> "dont").
    newtext = re.sub(r'[^\w\s-]', '', text)

    # Split on whitespace into words
    words.update(newtext.split())

In [59]:
from transformers import AutoTokenizer, AutoModel

# Load the Contriever tokenizer and encoder, then move the encoder onto the GPU.
CONTRIEVER_CHECKPOINT = 'facebook/contriever'
tokenizer = AutoTokenizer.from_pretrained(CONTRIEVER_CHECKPOINT)
model = AutoModel.from_pretrained(CONTRIEVER_CHECKPOINT)
model = model.cuda()

In [None]:
# Freeze the vocabulary order. Iteration order of a set of strings changes between
# Python processes (str hash randomisation), so sorting keeps indices, batch
# composition and ranking tie-breaks reproducible across runs.
word_list = sorted(words)
print(len(word_list))

In [61]:
# Embed every vocabulary word with Contriever, mean-pooling over the word's
# sub-word tokens only ([CLS] and [SEP] positions are excluded).
BATCH_SIZE = 512


def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the masked positions, ignoring [CLS] and [SEP].

    Args:
        token_embeddings: per-token encoder output; assumed (batch, seq, hidden) — TODO confirm.
        mask: attention mask with 1 for real tokens and 0 for padding. Each row is
            assumed to be right-padded "[CLS] tokens... [SEP] pad...". It is not modified.

    Returns:
        One pooled vector per row. A row with no content tokens pools to a zero
        vector instead of NaN.
    """
    mask = mask.clone()  # never mutate the caller's attention mask
    mask[:, 0] = 0  # ignore the [CLS] token
    # Once [CLS] is dropped, the number of remaining 1s in a right-padded row equals
    # the index of its [SEP] token.
    sep_index = mask.sum(dim=1)
    rows = torch.arange(mask.size(0), device=mask.device)
    mask[rows, sep_index] = 0  # ignore the [SEP] token

    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    # clamp avoids 0/0 -> NaN for rows that tokenize to nothing but [CLS] [SEP]
    counts = mask.sum(dim=1)[..., None].clamp(min=1)
    return token_embeddings.sum(dim=1) / counts


batch_embeddings = []
with torch.no_grad():  # inference only: no autograd graph needed
    for start in range(0, len(word_list), BATCH_SIZE):
        inputs = tokenizer(word_list[start:start + BATCH_SIZE], padding=True, truncation=True, return_tensors='pt').to("cuda:0")

        # Compute token embeddings; outputs[0] is presumably the last hidden state.
        outputs = model(**inputs)

        embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
        batch_embeddings.append(embeddings.cpu().numpy())

# Concatenate once at the end (concatenating inside the loop is quadratic).
# Stays None when the vocabulary is empty, as before.
all_embeddings = np.concatenate(batch_embeddings, axis=0) if batch_embeddings else None


needle_words = ["hague", "frankfurt", "oper", "opera", "dresden", "markt", "Helsinki", "Uusimaa", "milk", "lactose", "intolerant",
               "cappuccino", "caffè", "mocha", "caffe", "fish", "egg", "omelette", "france", "iran", "south africa", "mauritshuis", "museum", "netherland",
               "madrid", "prado", "museo", "vegan", "spain", "Musée", "marmottan",  "monet", "paris"]

# Embed the needle words with the same tokenizer/model/pooling as the vocabulary.
# Inference only, so skip building an autograd graph.
with torch.no_grad():
    # Apply tokenizer
    inputs = tokenizer(needle_words, padding=True, truncation=True, return_tensors='pt').to("cuda:0")
    # Compute token embeddings
    outputs = model(**inputs)
    needle_embeddings = mean_pooling(outputs[0], inputs['attention_mask']).cpu().numpy()

In [62]:
# Dot-product score of every vocabulary word (rows) against every needle word (columns).
similarity = np.matmul(all_embeddings, needle_embeddings.T)

In [None]:
TOP_K = 25

# Rank vocabulary words for every needle at once, highest score first. Sorting the
# negated scores (instead of reversing an ascending argsort) keeps any NaN scores
# at the bottom: argsort always places NaN last, so the old reversed tail put them on top.
top_indices = np.argsort(-similarity, axis=0)[:TOP_K]

for i, needle_word in enumerate(needle_words):
    similar_words = [word_list[j] for j in top_indices[:, i]]
    print(f"Words most similar to '{needle_word}': {similar_words}")

In [68]:
# Extra words/phrases to redact alongside the needle words. A leading or trailing
# space is significant: it is matched literally on that side. A side without a space
# lets the match extend over the rest of the word.
# Entries are kept lowercase so they also match when text is lowercased before matching.
relevant_distractor_words = ["amsterdam", "germany", "rhine", "berlin", "orchestra ", " oper-", "finland", "achtung",
                             "dairy", "cappuccino", "coffee", "caffein", "starbucks", " latte ", " tuna", "shrimp", " crab", " omelet",
                             "french", "le havre", "persia", " iranoi", " amir", " iraq", " kurd", "africa", "cape town", "valencia", "espanol", " diet",
                             "spanish"]

In [None]:
# Redact every distractor mention from each book. Each match is widened by W
# characters on both sides, then out to the nearest space/newline, and the span is
# replaced by "... ". Redacted copies go to <BOOK_DIR>/wo_dist/.
BOOK_DIR = "./books/II-clean"
OUTPUT_DIR = os.path.join(BOOK_DIR, "wo_dist")

booknames = sorted(os.listdir(BOOK_DIR))

needle_words_w_spacing = ["hague", "frankfurt", "oper ", " opera ", "dresden", "markt ", "Helsinki", "Uusimaa", "milk", "lactose", "intolerant",
               "cappuccino", "caffè", "mocha", "caffe", "fish", " egg", "omelette", "france", " iran", "south africa", "mauritshuis", "museum", "netherland",
               "madrid", "prado", "museo", "vegan", "spain", "musée", "marmottan",  " monet ", " paris"]

distractors = needle_words_w_spacing + relevant_distractor_words
# A match is skipped when one of these substrings occurs in its nearby context,
# e.g. "oper " inside "developer ".
bypass = ["proper", "selfish", "doper", "cooper", "developer", "trooper", "mirand", "standoffish", "dieter"]

W = 12  # context characters removed on each side of a match (before widening to whitespace)
BYPASS_W = 6  # context characters on each side searched for bypass substrings


def _distractor_pattern(distractor):
    """Return the regex for one distractor.

    A leading/trailing space is matched literally and pins that side. A side
    without one may extend through the rest of the word (\\w*) to a word boundary.
    """
    escaped = re.escape(distractor)
    leading = distractor.startswith(" ")
    trailing = distractor.endswith(" ")
    if leading and trailing:
        return escaped
    if leading:
        return rf"{escaped}\w*\b"
    if trailing:
        return rf"\b\w*{escaped}"
    return rf"\b\w*{escaped}\w*\b"


def _widen_to_whitespace(text, match_start, match_end, window):
    """Return (start, end) covering the match plus `window` chars each side, pushed out to a space/newline.

    `start` is the index of the whitespace before the window (0 if none).
    `end` is the index of the first whitespace after it (len(text) if none).
    """
    left = max(match_start - window, 0)  # clamp: a negative slice bound would wrap around
    start = max(text.rfind(" ", 0, left), text.rfind("\n", 0, left), 0)

    right = match_end + window
    after = [pos for pos in (text.find(" ", right), text.find("\n", right)) if pos != -1]
    end = min(after) if after else len(text)
    return start, end


def _merge_spans(spans):
    """Sort spans and merge those that overlap or touch, keeping the furthest end."""
    merged = []
    for start, end in sorted(spans):
        if merged and start <= merged[-1][1]:
            # max(): a span contained in the previous one must not shrink it
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def _redact(text, spans, marker="... "):
    """Replace each sorted, non-overlapping (start, end) span of `text` with `marker`."""
    pieces = []
    cursor = 0
    for start, end in spans:
        pieces.append(text[cursor:start])
        pieces.append(marker)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


# Match case-insensitively on the original text rather than on text.lower(): lowercasing
# can change string length (e.g. "İ" -> "i̇"), which would misalign match offsets, and it
# left capitalised distractors such as "Helsinki" unmatchable.
distractor_patterns = [re.compile(_distractor_pattern(d), re.IGNORECASE) for d in distractors]

os.makedirs(OUTPUT_DIR, exist_ok=True)

for bookname in booknames:
    if not bookname.endswith(".txt"):
        continue

    with open(os.path.join(BOOK_DIR, bookname), "r", encoding="utf-8") as f:
        text = f.read()

    print("Checking book: " + bookname)
    spans_to_remove = []
    for pattern in distractor_patterns:
        for match in pattern.finditer(text):
            bypass_window = text[max(match.start() - BYPASS_W, 0):match.end() + BYPASS_W].lower()
            if any(b in bypass_window for b in bypass):
                continue

            print("Found match in: " + text[max(match.start() - W, 0):match.end() + W].replace("\n", " "))
            spans_to_remove.append(_widen_to_whitespace(text, match.start(), match.end(), W))

    redacted_book = _redact(text, _merge_spans(spans_to_remove))

    with open(os.path.join(OUTPUT_DIR, bookname), "w", encoding="utf-8") as f:
        f.write(redacted_book)