In [None]:
!pip install translators

In [None]:
import json
import random
import time
import random
import nltk
import hashlib
from multiprocessing import Pool
from tqdm import tqdm
import translators as ts
from nltk.corpus import wordnet

In [None]:
# Tokenizer models used by nltk.word_tokenize. NLTK >= 3.8.2 loads the
# "punkt_tab" resource instead of the pickled "punkt" model, so fetch both
# to work on old and new NLTK releases alike.
nltk.download("punkt")
nltk.download("punkt_tab")
# WordNet lexicon used to expand a homonym into its synonym set.
nltk.download("wordnet")

In [None]:
class ParallelBackTranslator:
    """Paraphrase English texts by back-translation through a pivot language.

    Every text is translated ``en -> lang -> en`` with
    ``ts.translate_text``; :meth:`run` spreads the texts of one field over a
    ``multiprocessing.Pool``.

    Args:
        api: Translator backend name, passed as ``translator=`` to
            ``ts.translate_text``.
        lang: Pivot language code.
        n_processes: Size of the worker pool used by :meth:`run`.
    """

    def __init__(self, api="google", lang="de", n_processes=4):
        self.api = api
        self.lang = lang
        self.n_processes = n_processes

    def _worker(self, args):
        """Back-translate one ``(text, field_name)`` pair.

        Translation failures are logged and never propagated, so the caller
        always receives one string per input.

        Returns:
            ``""`` for ``None``; the stripped text unchanged when it has fewer
            than three whitespace-separated words; otherwise the
            back-translation, falling back to the stripped text when either
            translation leg raises or comes back empty/``None``.
        """
        text, field_name = args
        text_str = str(text).strip() if text is not None else ""

        # Too short to paraphrase meaningfully. This also covers empty fields
        # (e.g. a missing "ending"), which therefore stay "".
        if len(text_str.split()) < 3:
            return text_str

        try:
            inter = ts.translate_text(
                text_str,
                translator=self.api,
                from_language="en",
                to_language=self.lang,
            )

            # Treat an empty result like None: translating "" back would
            # replace a real field with an empty string.
            if not inter:
                print(
                    f"[WARN] inter={inter!r} | lang={self.lang} | field={field_name} | "
                    f"text='{text_str[:60]}...'"
                )
                return text_str

            final = ts.translate_text(
                inter,
                translator=self.api,
                from_language=self.lang,
                to_language="en",
            )

            if not final:
                print(
                    f"[WARN] final={final!r} | lang={self.lang} | field={field_name} | "
                    f"text='{text_str[:60]}...'"
                )
                return text_str

            return final

        except Exception as e:
            # Best effort by design: one failed request (rate limit, network
            # error, ...) must not abort the whole pool run.
            print(
                f"[ERROR] exception | lang={self.lang} | field={field_name} | "
                f"{type(e).__name__}: {e}"
            )
            return text_str

    def run(self, texts, field_name):
        """Back-translate *texts* in parallel, preserving input order.

        Args:
            texts: Any iterable of texts; it is materialized once.
            field_name: Field label used in log lines and the progress bar.

        Returns:
            List of back-translated strings aligned with *texts*.
        """
        args_list = [(t, field_name) for t in texts]
        with Pool(self.n_processes) as p:
            return list(
                tqdm(
                    # imap yields results in submission order.
                    p.imap(self._worker, args_list),
                    # len(args_list), not len(texts): texts may be a generator.
                    total=len(args_list),
                    desc=f"BT {field_name} [{self.lang}]",
                )
            )

In [None]:
class HomonymSynonymValidator:
    """Find a homonym, or one of its WordNet synonyms, shared by two texts."""

    def tokenize(self, text):
        """Return the NLTK word tokens of *text*, lower-cased."""
        return nltk.word_tokenize(text.lower())

    def synonyms(self, word):
        """Return *word* (lower-cased) plus all lemma names of its WordNet synsets.

        Lemma names are lower-cased and underscores become spaces, so
        multi-word lemmas are included but can never equal a single token
        produced by :meth:`tokenize`.
        """
        syns = {
            lemma.name().replace("_", " ").lower()
            for syn in wordnet.synsets(word)
            for lemma in syn.lemmas()
        }
        return syns | {word.lower()}

    def shared_token(self, s1, s2, homonym):
        """Pick a token present in both *s1* and *s2* that is *homonym* or a synonym.

        Returns:
            One randomly chosen matching token, or ``None`` if there is none.
        """
        t1 = set(self.tokenize(s1))
        t2 = set(self.tokenize(s2))
        valid = self.synonyms(homonym)

        shared = t1 & t2 & valid
        if not shared:
            return None

        # Sort before choosing: iteration order of a set of str depends on
        # PYTHONHASHSEED, so choosing from the raw set order would not be
        # reproducible across runs even with `random` seeded.
        return random.choice(sorted(shared))

In [None]:
def sample_fingerprint(sample):
    """Return an MD5 hex digest of *sample*'s content, ignoring ``sample_id``.

    Two samples that differ only in their ``sample_id`` get the same digest.
    Keys are serialized in sorted order (``sort_keys=True``), so dict
    insertion order does not affect the result.
    """
    content = {key: value for key, value in sample.items() if key != "sample_id"}
    payload = json.dumps(content, sort_keys=True).encode("utf-8")
    digest = hashlib.md5(payload)
    return digest.hexdigest()

In [None]:
def normalize_text(s):
    """Return *s* lower-cased with every whitespace run collapsed to one space.

    ``None`` normalizes to ``""``, the same treatment the back-translation
    worker gives missing fields. Non-string values are converted with
    ``str()`` first.
    """
    if s is None:
        return ""
    return " ".join(str(s).lower().split())

In [None]:
class BTGenerator:
    """Create back-translation (BT) augmentations of a JSON dataset.

    The input JSON is read as an object whose keys are integer-like sample
    IDs and whose values are sample dicts. Each sample must provide the
    ``precontext``, ``sentence``, ``ending``, ``example_sentence`` and
    ``homonym`` keys read below.
    """

    def __init__(self, input_path):
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(input_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.validator = HomonymSynonymValidator()
        # Augmented samples get IDs after the largest existing one. An empty
        # dataset starts at 0 instead of crashing on max() of nothing.
        self.start_id = max((int(k) for k in self.data), default=-1) + 1
        # Fingerprints of augmented samples already emitted. Never reset, so
        # deduplication also spans repeated generateAndSave calls.
        self.seen_fingerprints = set()

        # Case/whitespace-normalized originals, used to discard round trips
        # that reproduced every source field verbatim.
        self.norm_originals = {
            sid: {
                "precontext": normalize_text(s["precontext"]),
                "sentence": normalize_text(s["sentence"]),
                "ending": normalize_text(s["ending"]),
                "example": normalize_text(s["example_sentence"]),
            }
            for sid, s in self.data.items()
        }

    def generateAndSave(self, output_path, api, langs, n_workers):
        """Back-translate every sample via each pivot language and save the results.

        For each language in *langs*, the four text fields of every sample
        are round-tripped through ParallelBackTranslator. A back-translated
        sample is kept only if all of the following hold:

        * the back-translated sentence and example sentence share the
          homonym or one of its WordNet synonyms. That token becomes the new
          ``homonym`` and the old value is stored as ``original_homonym``;
        * at least one field differs from its original after normalization;
        * its content fingerprint has not been produced before.

        Kept samples get consecutive string IDs starting at ``self.start_id``.
        Only the augmented samples (not the originals) are written to
        *output_path*, as an indented JSON object keyed by those IDs.

        Args:
            output_path: Destination JSON file.
            api: Translator backend name forwarded to ParallelBackTranslator.
            langs: Iterable of pivot language codes.
            n_workers: Number of worker processes per back-translation run.
        """
        final_dataset = {}
        next_id = self.start_id
        added_count = 0

        ids = list(self.data)

        for lang in langs:
            print(f"\n=== Processing language: {lang} ===")

            bt_engine = ParallelBackTranslator(
                api=api,
                lang=lang,
                n_processes=n_workers,
            )

            pre_texts = [self.data[sid]["precontext"] for sid in ids]
            sent_texts = [self.data[sid]["sentence"] for sid in ids]
            end_texts = [self.data[sid]["ending"] for sid in ids]
            ex_texts = [self.data[sid]["example_sentence"] for sid in ids]

            # One pooled BT pass per field; each result list is aligned with `ids`.
            pre_bts = bt_engine.run(pre_texts, "precontext")
            sent_bts = bt_engine.run(sent_texts, "sentence")
            end_bts = bt_engine.run(end_texts, "ending")
            ex_bts = bt_engine.run(ex_texts, "example_sentence")

            for i, sid in enumerate(ids):
                sample = self.data[sid]
                homonym = sample["homonym"]

                pre_bt = pre_bts[i]
                sent_bt = sent_bts[i]
                end_bt = end_bts[i]
                ex_bt = ex_bts[i]

                # The homonym (or a synonym) must survive in both the story
                # sentence and the example sentence, otherwise the sample's
                # label no longer lines up with its text.
                anchor = self.validator.shared_token(sent_bt, ex_bt, homonym)
                if not anchor:
                    continue

                # Skip round trips that reproduced every field verbatim.
                norm = self.norm_originals[sid]
                if (
                    normalize_text(pre_bt) == norm["precontext"]
                    and normalize_text(sent_bt) == norm["sentence"]
                    and normalize_text(end_bt) == norm["ending"]
                    and normalize_text(ex_bt) == norm["example"]
                ):
                    continue

                # Shallow copy: untouched fields share objects with the
                # original sample, which is fine since they are not mutated.
                new_sample = sample.copy()

                new_sample["homonym"] = anchor
                new_sample["original_homonym"] = homonym

                new_sample["precontext"] = pre_bt
                new_sample["sentence"] = sent_bt
                new_sample["ending"] = end_bt
                new_sample["example_sentence"] = ex_bt
                new_sample["sample_id"] = str(next_id)

                # Deduplicate on content; sample_fingerprint ignores sample_id.
                fp = sample_fingerprint(new_sample)
                if fp in self.seen_fingerprints:
                    continue

                self.seen_fingerprints.add(fp)
                final_dataset[str(next_id)] = new_sample

                added_count += 1
                if added_count % 10 == 0:
                    print(
                        f"[INFO] {added_count} augmented samples added so far "
                        f"(current lang={lang})"
                    )

                # IDs advance only for kept samples, so output IDs are contiguous.
                next_id += 1

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(final_dataset, f, indent=4)

        print(f"\nDone. Generated {len(final_dataset)} augmented samples.")

In [None]:
# Run configuration for the augmentation job.
CONFIG = dict(
    api="google",  # translator backend name
    langs=["de"],  # pivot languages; one augmentation pass per entry
    n_workers=4,  # worker processes per back-translation pool
    input="/kaggle/input/ambistory-raw/train.json",
    output="train_bt_augmented.json",
)

In [None]:
# Load the raw split named in CONFIG and write its back-translated augmentations.
generator = BTGenerator(CONFIG["input"])
generator.generateAndSave(
    output_path=CONFIG["output"],
    langs=CONFIG["langs"],
    api=CONFIG["api"],
    n_workers=CONFIG["n_workers"],
)
