In [40]:
import torch
import random
import nltk
from nltk.corpus import stopwords, wordnet
from transformers import AutoTokenizer

In [41]:
nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [67]:
def get_synonym(word):
    synonyms = []
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            if lemma.name() != word and "_" not in lemma.name():
                synonyms.append(lemma.name())

    if not synonyms:
        return word
    return random.choice(synonyms)

In [69]:
def token_to_word(token, tokenizer):
    return tokenizer.decode([token]).strip()

In [70]:
def word_to_token(word, tokenizer, device):
    return tokenizer(word, return_tensors="pt", add_special_tokens=False).to(device)

In [71]:
def replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.15):
    stop_words = set(stopwords.words("english"))

    input_ids = inputs["input_ids"].clone()

    for i in range(input_ids.shape[0]):
        for j in range(input_ids.shape[1]):
            if random.random() < replacement_prob:
                token_id = input_ids[i, j].item()
                word = token_to_word(token_id, tokenizer)

                if (
                    word.lower() in stop_words
                    or word.startswith("##")
                    or not word.isalpha()
                ):
                    continue

                synonym = get_synonym(word)

                synonym_tokens = word_to_token(synonym, tokenizer, device)

                if synonym_tokens["input_ids"].shape[1] == 1:
                    input_ids[i, j] = synonym_tokens["input_ids"][0, 0]

    inputs["input_ids"] = input_ids
    return inputs

In [72]:
sentence = "The quick brown fox jumps over the lazy dog."
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [73]:
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    add_special_tokens=False,
).to(device)

In [81]:
modified_inputs = replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.5)
modified_sentence = tokenizer.decode(modified_inputs['input_ids'][0])
print(f"Original: {sentence}")
print(f"Modified: {modified_sentence}")

Original: The quick brown fox jumps over the lazy dog.
Modified: the inspire brown fox jump over the lazy dog.
