In [40]:
import torch
import random
import nltk
from nltk.corpus import stopwords, wordnet
from transformers import AutoTokenizer

In [41]:
nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [67]:
def get_synonym(word):
    synonyms = []
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            if lemma.name() != word and "_" not in lemma.name():
                synonyms.append(lemma.name())

    if not synonyms:
        return word
    return random.choice(synonyms)

In [69]:
def token_to_word(token, tokenizer):
    return tokenizer.decode([token]).strip()

In [121]:
def replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.15):
    stop_words = set(stopwords.words('english'))
    
    input_ids = inputs['input_ids'].clone()
    
    for i in range(input_ids.shape[0]):
        for j in range(input_ids.shape[1]):
            if random.random() < replacement_prob:
                token_id = input_ids[i, j].item()
                print(token_id)
                word = token_to_word(token_id, tokenizer)
                print(word)
                
                if word.lower() in stop_words or word.startswith('##') or not word.isalpha():
                    continue
                
                synonym = get_synonym(word)
                print(synonym)
                
                synonym_tokens = tokenizer(
                    synonym, 
                    return_tensors="pt", 
                    add_special_tokens=False
                ).to(device)
                print(synonym_tokens)
                print(synonym_tokens['input_ids'].shape)
                
                if synonym_tokens['input_ids'].shape[1] == 1:
                    input_ids[i, j] = synonym_tokens['input_ids'][0, 0]
    
    inputs['input_ids'] = input_ids
    return inputs

In [122]:
sentence = "The quick brown fox jumps over the lazy dog."
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [123]:
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    add_special_tokens=False,
).to(device)

In [128]:
modified_inputs = replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.9)
modified_sentence = tokenizer.decode(modified_inputs['input_ids'][0])
print(f"Original: {sentence}")
print(f"Modified: {modified_sentence}")

1996
the
2855
quickly
apace
{'input_ids': tensor([[ 9706, 10732]]), 'token_type_ids': tensor([[0, 0]]), 'attention_mask': tensor([[1, 1]])}
torch.Size([1, 2])
19437
brownish
brown
{'input_ids': tensor([[2829]]), 'token_type_ids': tensor([[0]]), 'attention_mask': tensor([[1]])}
torch.Size([1, 1])
10055
toss
thrash
{'input_ids': tensor([[27042]]), 'token_type_ids': tensor([[0]]), 'attention_mask': tensor([[1]])}
torch.Size([1, 1])
3863
exchange
commutation
{'input_ids': tensor([[ 4012, 28120,  3370]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}
torch.Size([1, 3])
2058
over
13971
lazy
faineant
{'input_ids': tensor([[26208, 22084,  3372]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}
torch.Size([1, 3])
3899
dog
heel
{'input_ids': tensor([[12073]]), 'token_type_ids': tensor([[0]]), 'attention_mask': tensor([[1]])}
torch.Size([1, 1])
1012
.
Original: The quick brown fox jumps over the lazy dog.
Modified: the quickly brown t

In [103]:
print(inputs)

{'input_ids': tensor([[ 1996, 18708, 19437,  5466,  5376,  2058,  1996, 13971,  2892,  1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [104]:
print(modified_inputs)

{'input_ids': tensor([[ 1996, 18708, 19437,  5466,  5376,  2058,  1996, 13971,  2892,  1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
