In [40]:
import torch
import random
import nltk
from nltk.corpus import stopwords, wordnet
from transformers import AutoTokenizer

In [41]:
# Fetch the NLTK corpora used below: the English stop-word list and WordNet.
nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [67]:
def get_synonym(word):
    """Return a random single-word WordNet synonym for ``word``.

    Candidates are the lemma names of every synset WordNet lists for ``word``.
    A lemma is discarded when it is the word itself (compared
    case-insensitively, so "China" is not offered as a synonym of "china"),
    when it is a multi-word expression (WordNet joins those with "_"), or when
    it has already been collected from another synset.

    Args:
        word: The word to look up.

    Returns:
        One candidate chosen uniformly at random, or ``word`` unchanged when
        no candidate exists.
    """
    target = word.lower()
    seen = set()
    candidates = []
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            name = lemma.name()
            key = name.lower()
            if key == target or "_" in name or key in seen:
                continue
            # Record each lemma once: the same lemma often appears in several
            # synsets and would otherwise be over-represented in the draw.
            seen.add(key)
            candidates.append(name)

    if not candidates:
        return word
    return random.choice(candidates)

In [69]:
def token_to_word(token, tokenizer):
    """Decode a single token id to text and trim surrounding whitespace."""
    decoded = tokenizer.decode([token])
    return decoded.strip()

In [None]:
def replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.15):
    """Randomly swap single-token words in a tokenized batch for synonyms.

    Every position is selected independently with probability
    ``replacement_prob``. A selected token is replaced only when it decodes to
    a purely alphabetic string that is not an English stop word, and the
    synonym returned by ``get_synonym`` encodes to exactly one token that is
    not the tokenizer's unknown token. Replacements are therefore always
    one-for-one and the tensor shape never changes.

    Args:
        inputs: Mapping holding a 2-D ``input_ids`` tensor plus any companion
            entries (``attention_mask``, ``token_type_ids``, ...).
        tokenizer: Object exposing ``decode(list_of_ids)`` and
            ``encode(text, add_special_tokens=False)``.
        device: Retained for backward compatibility. Replacement ids are
            written as plain ints into the cloned ``input_ids`` tensor, so no
            separate device placement is needed.
        replacement_prob: Per-token probability of attempting a replacement.

    Returns:
        A new dict with the same keys as ``inputs``. Tensor values are cloned,
        so the caller's batch is never modified.
    """
    stop_words = set(stopwords.words('english'))
    unk_id = getattr(tokenizer, 'unk_token_id', None)

    # Carry every entry through (not just input_ids/attention_mask) so that
    # companion tensors such as token_type_ids are not silently dropped.
    result = {
        key: (value.clone() if torch.is_tensor(value) else value)
        for key, value in inputs.items()
    }
    input_ids = result['input_ids']

    # One bulk transfer to Python ints instead of an .item() call per position.
    for i, row in enumerate(input_ids.tolist()):
        for j, token_id in enumerate(row):
            if random.random() >= replacement_prob:
                continue

            word = tokenizer.decode([token_id]).strip()
            # isalpha() also rejects "##" continuation pieces, bracketed
            # special tokens, digits and punctuation.
            if not word.isalpha() or word.lower() in stop_words:
                continue

            synonym = get_synonym(word)
            if synonym == word:
                # No synonym available; nothing to change.
                continue

            synonym_tokens = tokenizer.encode(synonym, add_special_tokens=False)

            # Only replace if the synonym tokenizes to a single known token.
            if len(synonym_tokens) != 1 or synonym_tokens[0] == unk_id:
                continue

            input_ids[i, j] = synonym_tokens[0]

    return result

In [130]:
# Demo fixtures: example sentence, uncased BERT tokenizer, and compute device.
sentence = "The quick brown fox jumps over the lazy dog."
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [131]:
# Encode the demo sentence as PyTorch tensors, without special tokens, on `device`.
inputs = tokenizer(sentence, add_special_tokens=False, return_tensors="pt").to(device)

In [132]:
# Seed the RNG so re-running this cell yields the same replacements: both the
# per-token coin flip and the synonym choice draw from the `random` module.
random.seed(0)
modified_inputs = replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.9)
modified_sentence = tokenizer.decode(modified_inputs['input_ids'][0])
print(f"Original: {sentence}")
print(f"Modified: {modified_sentence}")

Original: The quick brown fox jumps over the lazy dog.
Modified: the prompt brown fox jump over the lazy dog.


In [103]:
# Dump the original encoding so it can be compared with the modified one.
print(inputs)

{'input_ids': tensor([[ 1996, 18708, 19437,  5466,  5376,  2058,  1996, 13971,  2892,  1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [104]:
# Dump the encoding returned by replace_tokens_with_synonyms.
print(modified_inputs)

{'input_ids': tensor([[ 1996, 18708, 19437,  5466,  5376,  2058,  1996, 13971,  2892,  1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
