In [2]:
import torch
import random
import requests
import nltk
from nltk.corpus import stopwords, wordnet
from transformers import AutoTokenizer
from deep_translator import GoogleTranslator
from functools import lru_cache

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/nilsgrunefeld/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [4]:
ISO_639_1_TO_3 = {
    "en": "eng",
    "de": None,
    "es": "spa",
    "fr": "fra",
    "it": "ita",
    "ko": None,
    "pt": "por",
    "ru": None,
    "zh": "cmn",
}

@lru_cache(maxsize=10000)
def cached_translate(text, source, target):
    if source == "zh":
        source = "zh-CN"
    if target == "zh":
        target = "zh-CN"
    return GoogleTranslator(source=source, target=target).translate(text)


def get_german_synonyms(word):
    try:
        url = "https://www.openthesaurus.de/synonyme/search"
        params = {"q": word, "format": "application/json"}
        response = requests.get(url, params=params)
        data = response.json()
        synonyms = set()
        for synset in data.get("synsets", []):
            for term in synset.get("terms", []):
                if term["term"].lower() != word.lower():
                    synonyms.add(term["term"])
        return list(synonyms) if synonyms else None
    except Exception as e:
        print(f"Error fetching German synonyms: {e}")
        return None


def get_omw_synonyms(word, lang):
    omw_lang = ISO_639_1_TO_3.get(lang)
    if not omw_lang:
        return None
    try:
        synsets = wordnet.synsets(word, lang=omw_lang)
        synonyms = set()
        for syn in synsets:
            for lemma in syn.lemmas(lang=omw_lang):
                synonym = lemma.name().replace("_", " ")
                if synonym.lower() != word.lower():
                    synonyms.add(synonym)
        return list(synonyms) if synonyms else None
    except:
        return None


def get_synonym(word, lang="en", tokenizer=None):
    supported_languages = ["en", "de", "es", "fr", "it", "ko", "pt", "ru", "zh"]

    if lang not in supported_languages:
        raise ValueError(f"Unsupported language: {lang}")

    try:
        # Step 1: Try native synonym lookup
        if lang == "de":
            synonyms = get_german_synonyms(word)
        else:
            synonyms = get_omw_synonyms(word, lang)

        # Step 2: If no native synonyms, fall back to English-based method
        if not synonyms:
            word_en = word if lang == "en" else cached_translate(word, lang, "en")
            synsets = wordnet.synsets(word_en)
            synonym_candidates = set()
            for syn in synsets:
                for lemma in syn.lemmas():
                    synonym = lemma.name().replace("_", " ")
                    if synonym.lower() != word_en.lower():
                        synonym_candidates.add(synonym)

            if not synonym_candidates:
                return None

            if tokenizer:
                valid_synonyms = []
                random.shuffle(list(synonym_candidates))
                for syn in synonym_candidates:
                    if lang == "en":
                        if len(tokenizer.tokenize(syn)) == 1:
                            valid_synonyms.append(syn)
                    else:
                        translated = cached_translate(syn, "en", lang)
                        if len(tokenizer.tokenize(translated)) == 1:
                            valid_synonyms.append(syn)
                if not valid_synonyms:
                    return None
                chosen_syn = random.choice(valid_synonyms)
            else:
                chosen_syn = random.choice(list(synonym_candidates))

            return (
                chosen_syn if lang == "en" else cached_translate(chosen_syn, "en", lang)
            )

        # Step 3: If we do have native synonyms, filter if tokenizer provided
        if tokenizer:
            filtered = [s for s in synonyms if len(tokenizer.tokenize(s)) == 1]
            if not filtered:
                return None
            return random.choice(filtered)

        return random.choice(synonyms)

    except Exception as e:
        print(f"Error: {e}")
        return None


def token_to_word(token, tokenizer):
    return tokenizer.decode([token]).strip()


def replace_tokens_with_synonyms(
    inputs, tokenizer, device, lang="en", replacement_prob=0.15
):
    stop_words = set(stopwords.words("english"))

    input_ids = inputs["input_ids"].clone()

    for i in range(input_ids.shape[0]):
        for j in range(input_ids.shape[1]):
            if random.random() < replacement_prob:
                token_id = input_ids[i, j].item()
                word = token_to_word(token_id, tokenizer)

                if (
                    word.lower() in stop_words
                    or word.startswith("##")
                    or not word.isalpha()
                ):
                    continue

                synonym = get_synonym(word, lang=lang, tokenizer=tokenizer)
                if not synonym:
                    synonym = word

                synonym_tokens = tokenizer(
                    synonym, return_tensors="pt", add_special_tokens=False
                ).to(device)

                if synonym_tokens["input_ids"].shape[1] == 1:
                    if synonym_tokens["input_ids"][0, 0] != token_id:
                        input_ids[i, j] = synonym_tokens["input_ids"][0, 0]

    return input_ids

In [5]:
get_german_synonyms("Schlappe")

['Dämpfer', 'Rückschlag', 'schallende Ohrfeige']

In [6]:
get_omw_synonyms("president", lang="en")

['chair',
 'Chief Executive',
 'chairperson',
 'chairwoman',
 'chairman',
 'United States President',
 'prexy',
 'President of the United States']

In [7]:
sentence = "The quick brown fox jumps over the lazy dog."
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [8]:
inputs = tokenizer(
    sentence,
    return_tensors="pt",
    add_special_tokens=False,
).to(device)

In [10]:
modified_input_ids = replace_tokens_with_synonyms(inputs, tokenizer, device, replacement_prob=0.5)
modified_sentence = tokenizer.decode(modified_input_ids[0])
print(f"Original: {sentence}")
print(f"Modified: {modified_sentence}")

Original: The quick brown fox jumps over the lazy dog.
Modified: the fast brownish trick jumps over the lazy dog.


In [11]:
sample_it = "Il video mostra un gruppo di ballerini che esegue una coreografia di danza Jazz in un ambiente chiuso, probabilmente uno studio di danza."
sample_zh = "该视频展示了一群舞者在一个封闭的环境中执行爵士舞编舞，可能是一个舞蹈工作室。"
sample_de = "Das Video zeigt eine Gruppe von Tänzern, die in einer geschlossenen Umgebung, wahrscheinlich einem Tanzstudio, eine Jazz-Choreografie ausführen."
sample_fr = "La vidéo montre un groupe de danseurs exécutant une chorégraphie de danse jazz dans un environnement clos, probablement un studio de danse."
sample_es = "El video muestra a un grupo de bailarines realizando una coreografía de danza jazz en un entorno cerrado, probablemente un estudio de danza."
sample_pt = "O vídeo mostra um grupo de dançarinos executando uma coreografia de dança jazz em um ambiente fechado, provavelmente um estúdio de dança."
sample_ru = "В видео показана группа танцоров, исполняющих джазовую хореографию в закрытом помещении, вероятно, в танцевальной студии."
sample_ko = "이 비디오는 아마도 댄스 스튜디오에서 닫힌 환경에서 재즈 댄스 안무를 수행하는 무용수 그룹을 보여줍니다."
sample_en = "The video shows a group of dancers performing a Jazz dance choreography in an enclosed environment, probably a dance studio."

lang_samples = {
    "it": sample_it,
    "zh": sample_zh,
    "de": sample_de,
    "fr": sample_fr,
    "es": sample_es,
    "pt": sample_pt,
    "ru": sample_ru,
    "ko": sample_ko,
    "en": sample_en,
}

In [13]:
lang_choice = "de"

inputs = tokenizer(
    lang_samples[lang_choice],
    return_tensors="pt",
    add_special_tokens=False,
).to(device)

modified_input_ids = replace_tokens_with_synonyms(inputs, tokenizer, device, lang=lang_choice, replacement_prob=1)
modified_sentence = tokenizer.decode(modified_input_ids[0])
print(f"Original: {lang_samples[lang_choice]}")
print(f"Modified: {modified_sentence}")

Original: Das Video zeigt eine Gruppe von Tänzern, die in einer geschlossenen Umgebung, wahrscheinlich einem Tanzstudio, eine Jazz-Choreografie ausführen.
Modified: per video zeigt eine gruppe von tanzern, pro in einer geschlossenen circagebung, wahrscheinlich einem tanzstudio, eine jazz - choreografie obfuhren.
