# In [None]:
import argparse
import json
import random
from copy import deepcopy
from collections import defaultdict

import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from transformers import MarianMTModel, MarianTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer


def load_json(path):
    """Read the JSON document stored at ``path`` and return the parsed object.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(obj, path):
    """Serialize ``obj`` as pretty-printed JSON and write it to ``path``.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8; relying on the locale default could
    raise ``UnicodeEncodeError`` or produce a file other tools cannot read.

    Args:
        obj: Any JSON-serializable value.
        path: Destination path; an existing file is overwritten.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If ``obj`` contains values ``json`` cannot serialize.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4, ensure_ascii=False)


def get_max_numeric_id(*dicts):
    """Return the largest integer-like key found across the given mappings.

    Each argument is either a mapping or ``None`` (ignored). Keys are
    converted with ``int(str(key))``; keys that do not parse are skipped.
    The result is never negative: 0 is returned when no non-negative numeric
    key exists.
    """
    largest = 0
    for mapping in dicts:
        if mapping is None:
            continue
        for key in mapping:
            try:
                value = int(str(key))
            except ValueError:
                # Non-numeric ids simply do not take part in the maximum.
                continue
            largest = max(largest, value)
    return largest


def build_story(sample):
    """Concatenate precontext, sentence and ending into one story string.

    Each field is read with a default of "" and stripped; empty fields are
    dropped and the remaining ones are joined with single spaces.
    """
    pieces = []
    for field in ("precontext", "sentence", "ending"):
        text = sample.get(field, "").strip()
        if text:
            pieces.append(text)
    return " ".join(pieces)


def build_sense_text(sample):
    """Build the sense description: judged meaning followed by the example sentence.

    Both fields default to "" and are stripped; empty ones are omitted and
    the rest are joined with a single space.
    """
    stripped = (
        sample.get(field, "").strip()
        for field in ("judged_meaning", "example_sentence")
    )
    return " ".join(text for text in stripped if text)


class MarianTranslator:
    """Batched translation with a single Marian MT checkpoint."""

    def __init__(self, model_name, device=None):
        """Load tokenizer and model for ``model_name`` and put the model in eval mode.

        When ``device`` is None, "cuda" is chosen if available, else "cpu".
        """
        self.tokenizer = MarianTokenizer.from_pretrained(model_name)
        self.model = MarianMTModel.from_pretrained(model_name)
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def translate_batch(self, texts, max_length=256, batch_size=16, desc=""):
        """Translate ``texts`` in chunks of ``batch_size`` and return one string per input.

        Inputs are truncated to ``max_length`` tokens and decoded with 4-beam
        search capped at ``max_length`` tokens. ``desc`` labels the progress
        bar. The result always has exactly ``len(texts)`` entries.
        """
        total = len(texts)
        translated = []
        for start in tqdm(range(0, total, batch_size), desc=desc):
            chunk = texts[start:start + batch_size]
            if chunk:
                encoded = self.tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(self.device)
                generated = self.model.generate(
                    **encoded,
                    max_length=max_length,
                    num_beams=4,
                    early_stopping=True,
                )
                translated.extend(
                    self.tokenizer.batch_decode(generated, skip_special_tokens=True)
                )
        # Normalize the length so callers can index results by input position.
        shortfall = total - len(translated)
        if shortfall > 0:
            translated.extend([""] * shortfall)
        return translated[:total]


class HFBackTranslator:
    """
    Round-trip translation (en -> lang -> en) built from two Helsinki-NLP
    Marian checkpoints.
    """

    def __init__(self, lang_code, device=None):
        """Create the en->``lang_code`` and ``lang_code``->en translators."""
        self.lang_code = lang_code
        forward_name = "Helsinki-NLP/opus-mt-en-{}".format(lang_code)
        reverse_name = "Helsinki-NLP/opus-mt-{}-en".format(lang_code)
        self.en_to_lang = MarianTranslator(forward_name, device=device)
        self.lang_to_en = MarianTranslator(reverse_name, device=device)

    def back_translate_list(self, texts, max_length=256, batch_size=16):
        """Translate ``texts`` to the pivot language and back to English.

        ``max_length`` and ``batch_size`` are forwarded unchanged to both
        translation passes. Returns the list produced by the second pass.
        """
        shared = {"max_length": max_length, "batch_size": batch_size}
        pivot = self.en_to_lang.translate_batch(
            texts,
            desc=f"BT en->{self.lang_code}",
            **shared,
        )
        return self.lang_to_en.translate_batch(
            pivot,
            desc=f"BT {self.lang_code}->en",
            **shared,
        )


class MPNetHelper:
    """Embedding index over the training set using all-mpnet-base-v2.

    On construction every training sample is embedded twice: once as a story
    (``build_story``) and once as a sense description (``build_sense_text``).
    The two embedding sets back the neighbour lookups below, which rely on
    ``util.semantic_search``.
    """

    def __init__(self, train_dict, device=None, batch_size=64):
        """Load the encoder and embed all stories and senses of ``train_dict``.

        Args:
            train_dict: Mapping of sample id -> sample dict.
            device: Device for the encoder; defaults to "cuda" when available,
                otherwise "cpu".
            batch_size: Batch size passed to ``SentenceTransformer.encode``.
        """
        self.train = train_dict
        # Row i of both embedding tensors belongs to self.ids[i].
        self.ids = list(train_dict.keys())
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.model = SentenceTransformer(
            "sentence-transformers/all-mpnet-base-v2",
            device=self.device,
        )

        print("Encoding train stories with all-mpnet-base-v2...")
        self.story_texts = [build_story(train_dict[sid]) for sid in self.ids]
        self.story_embeddings = self.model.encode(
            self.story_texts,
            batch_size=batch_size,
            convert_to_tensor=True,
            show_progress_bar=True,
        )

        print("Encoding train senses with all-mpnet-base-v2...")
        self.sense_texts = [build_sense_text(train_dict[sid]) for sid in self.ids]
        self.sense_embeddings = self.model.encode(
            self.sense_texts,
            batch_size=batch_size,
            convert_to_tensor=True,
            show_progress_bar=True,
        )

        self.id_to_idx = {sid: i for i, sid in enumerate(self.ids)}

    def story_neighbors(self, sid, top_k=10):
        """Return up to ``top_k`` training stories most similar to sample ``sid``.

        Args:
            sid: Id of a sample present in the training dict.
            top_k: Maximum number of neighbours to return.

        Returns:
            A list of ``(other_id, score)`` tuples in the order produced by
            ``util.semantic_search``; ``sid`` itself is never included.

        Raises:
            KeyError: If ``sid`` is not a known training id.
        """
        idx = self.id_to_idx[sid]
        q = self.story_embeddings[idx].unsqueeze(0)
        # The query is part of the corpus, so request one extra hit to make
        # room for dropping the self-match.
        hits = util.semantic_search(q, self.story_embeddings, top_k=top_k + 1)[0]
        out = []
        for h in hits:
            other_id = self.ids[h["corpus_id"]]
            if other_id == sid:
                continue
            out.append((other_id, float(h["score"])))
        # If the self-match was not among the hits (e.g. many samples share an
        # identical story text), top_k + 1 entries survive; honour the
        # requested limit.
        return out[:top_k]

    def sense_neighbors_for_vector(self, sense_vec, top_k=10):
        """Return up to ``top_k`` training senses closest to ``sense_vec``.

        Args:
            sense_vec: A single embedding tensor; a leading batch dimension is
                added before searching.
            top_k: Maximum number of neighbours to return.

        Returns:
            A list of ``(train_id, score)`` tuples in search order.
        """
        hits = util.semantic_search(
            sense_vec.unsqueeze(0),
            self.sense_embeddings,
            top_k=top_k,
        )[0]
        return [(self.ids[h["corpus_id"]], float(h["score"])) for h in hits]


class T5Helper:
    """Batched text generation with a T5 checkpoint."""

    def __init__(self, model_name="t5-base", device=None):
        """Load tokenizer and model for ``model_name`` and put the model in eval mode.

        When ``device`` is None, "cuda" is chosen if available, else "cpu".
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def generate_batch(
        self,
        prompts,
        max_length=80,
        num_beams=4,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        batch_size=8,
        desc="T5 generation",
        max_input_length=128,
    ):
        """Generate one output string per prompt.

        Args:
            prompts: List of input strings.
            max_length: Maximum number of generated tokens per output.
            num_beams: Beam width passed to ``generate``.
            do_sample: Whether to sample instead of decoding greedily per beam.
            temperature: Sampling temperature; only used when ``do_sample``.
            top_p: Nucleus sampling threshold; only used when ``do_sample``.
            batch_size: Number of prompts processed per forward pass.
            desc: Label for the progress bar.
            max_input_length: Prompts are truncated to this many tokens
                (previously fixed at 128).

        Returns:
            A list with exactly ``len(prompts)`` decoded strings.
        """
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
        }
        # Only pass options that apply to the active decoding mode; passing
        # sampling or beam-only flags otherwise makes transformers complain
        # about an inconsistent generation config.
        if num_beams > 1:
            gen_kwargs["early_stopping"] = True
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        out = []
        for i in tqdm(range(0, len(prompts), batch_size), desc=desc):
            batch = prompts[i:i + batch_size]
            enc = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_input_length,
            ).to(self.device)
            gen = self.model.generate(**enc, **gen_kwargs)
            out.extend(self.tokenizer.batch_decode(gen, skip_special_tokens=True))

        # Callers zip the result against their inputs, so keep the length
        # contract explicit even if decoding ever yields fewer items.
        if len(out) < len(prompts):
            out.extend([""] * (len(prompts) - len(out)))
        return out[:len(prompts)]


class AugmentationPipeline:
    def __init__(
        self,
        train_data,
        dev_data=None,
        device=None,
        seed=42,
        enable_t5=True,
        bt_langs=("de", "fr"),
        bt_batch_size=16,
        bt_max_length=256,
    ):
        """Set up shared state and helper models for all augmentation stages.

        Args:
            train_data: Mapping of sample id -> sample dict (the training set).
            dev_data: Optional mapping of dev sample id -> sample dict.
            device: Device for all models; defaults to "cuda" when available,
                otherwise "cpu".
            seed: Seed applied to both Python's ``random`` and torch's RNG.
            enable_t5: When False, no T5 model is loaded and the T5-based
                stages skip themselves.
            bt_langs: Pivot language codes used for back-translation.
            bt_batch_size: Batch size for back-translation.
            bt_max_length: Token limit for back-translation input and output.
        """
        self.train = train_data
        self.dev = dev_data or {}
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Seed both RNGs: sample selection uses ``random``, while T5 decoding
        # samples (do_sample=True by default) from torch's generator. Seeding
        # only ``random`` left the generated text non-reproducible.
        random.seed(seed)
        torch.manual_seed(seed)

        # Output dataset: starts as a shallow copy of train, grows via _add_sample.
        self.combined = dict(train_data)
        # New ids continue after the largest numeric id seen in train or dev.
        self.next_id = get_max_numeric_id(train_data, dev_data) + 1

        self.mpnet = MPNetHelper(self.train, device=self.device)
        self.t5 = T5Helper(device=self.device) if enable_t5 else None

        self.bt_langs = bt_langs
        self.bt_batch_size = bt_batch_size
        self.bt_max_length = bt_max_length

        # Homonyms present in train; used to detect dev-only homonyms.
        self.train_homs = {
            s.get("homonym", "").strip() for s in self.train.values()
        }

    def _add_sample(self, sample):
        sid = str(self.next_id)
        self.combined[sid] = sample
        self.next_id += 1
        return sid


    def augment_backtranslation(self):
        """Add one back-translated copy of every training sample per pivot language.

        For each language in ``self.bt_langs`` the precontext and ending of
        every training sample are translated en -> lang -> en. The sentence
        (and therefore the homonym) is left untouched. Each new sample records
        ``augment_type``, ``bt_lang`` and ``parent_id``.

        A field that is empty in the original stays empty: translation models
        can emit arbitrary text for empty input. If the round trip returns an
        empty string for a non-empty field, the original text is kept.
        """
        langs = list(self.bt_langs)
        lang_label = " / ".join(str(lang).upper() for lang in langs)
        print(f"\n=== Back-translation ({lang_label}) ===")

        orig_ids = list(self.train.keys())
        precontexts = [self.train[sid].get("precontext", "") for sid in orig_ids]
        endings = [self.train[sid].get("ending", "") for sid in orig_ids]

        def pick(original, translated):
            # Empty in, empty out; otherwise prefer the translation unless it is blank.
            if not original.strip():
                return ""
            return translated if translated.strip() else original

        for lang in langs:
            bt_engine = HFBackTranslator(lang, device=self.device)

            bt_pre = bt_engine.back_translate_list(
                precontexts,
                max_length=self.bt_max_length,
                batch_size=self.bt_batch_size,
            )
            bt_end = bt_engine.back_translate_list(
                endings,
                max_length=self.bt_max_length,
                batch_size=self.bt_batch_size,
            )

            # Drop this language's two models before the next pair is loaded
            # so they do not pile up in (GPU) memory.
            del bt_engine
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            for i, sid in enumerate(orig_ids):
                new_sample = deepcopy(self.train[sid])
                new_sample["precontext"] = pick(precontexts[i], bt_pre[i])
                new_sample["ending"] = pick(endings[i], bt_end[i])

                new_sample["augment_type"] = f"bt_{lang}"
                new_sample["bt_lang"] = lang
                new_sample["parent_id"] = sid
                self._add_sample(new_sample)

    def augment_cross_homonym_swap(
        self,
        top_k_neighbors=10,
        min_story_sim=0.7,
        keep_prob=0.5,
    ):
        """Give a sample the precontext of a similar story with a different homonym.

        Each training sample is considered with probability ``keep_prob``.
        Among its ``top_k_neighbors`` story neighbours, the first one whose
        homonym is non-empty, differs from the sample's own and whose
        similarity is not below ``min_story_sim`` donates its precontext.
        Everything else is copied from the original sample.
        """
        print("\n=== Cross-homonym swap (different homonyms, mixed context) ===")
        for sid, base in tqdm(self.train.items()):
            # Draw first so every sample consumes exactly one random number.
            if random.random() > keep_prob:
                continue
            own_homonym = base.get("homonym", "").strip()
            if not own_homonym:
                continue

            candidates = self.mpnet.story_neighbors(sid, top_k=top_k_neighbors)
            for donor_id, donor_sim in candidates:
                donor_homonym = self.train[donor_id].get("homonym", "").strip()
                if (
                    donor_homonym
                    and donor_homonym != own_homonym
                    and not donor_sim < min_story_sim
                ):
                    break
            else:
                # No neighbour qualified as a context donor.
                continue

            swapped = deepcopy(base)
            swapped.update(
                {
                    "precontext": self.train[donor_id].get("precontext", ""),
                    "augment_type": "cross_homonym_swap",
                    "parent_id": sid,
                    "context_from_id": donor_id,
                    "story_sim_score": donor_sim,
                }
            )
            self._add_sample(swapped)

    def augment_context_variation(
        self,
        keep_prob=0.5,
        max_len_ctx=80,
        max_len_sent=40,
    ):
        """
        Change up precontext and ending via paraphrasing, and optionally
        the sentence, while keeping the homonym unchanged.
        """
        if self.t5 is None:
            print("\n=== Context variation: T5 disabled, skipping ===")
            return

        print("\n=== Context variation (T5 paraphrase of precontext / ending / sentence) ===")
        candidates = [sid for sid in self.train.keys() if random.random() < keep_prob]

        pre_prompts = []
        end_prompts = []
        sent_prompts = []
        for sid in candidates:
            s = self.train[sid]
            pre = s.get("precontext", "").strip()
            end = s.get("ending", "").strip()
            hom = s.get("homonym", "").strip()
            sent = s.get("sentence", "").strip()

            pre_prompts.append(f"Paraphrase this context in English:\n{pre}" if pre else "")
            end_prompts.append(f"Paraphrase this ending in English:\n{end}" if end else "")

            if hom and sent:
                sent_prompts.append(
                    f"Paraphrase the following sentence while keeping the word '{hom}' unchanged:\n{sent}"
                )
            else:
                sent_prompts.append("")

        pre_out = self.t5.generate_batch(
            pre_prompts,
            max_length=max_len_ctx,
            desc="T5 paraphrase precontext",
        )
        end_out = self.t5.generate_batch(
            end_prompts,
            max_length=max_len_ctx,
            desc="T5 paraphrase ending",
        )
        sent_out = self.t5.generate_batch(
            sent_prompts,
            max_length=max_len_sent,
            desc="T5 paraphrase sentence",
        )

        for sid, new_pre, new_end, new_sent in zip(candidates, pre_out, end_out, sent_out):
            orig = self.train[sid]
            hom = orig.get("homonym", "").strip()

            new_sample = deepcopy(orig)

            if orig.get("precontext", "").strip() and new_pre.strip():
                new_sample["precontext"] = new_pre.strip()

            if orig.get("ending", "").strip() and new_end.strip():
                new_sample["ending"] = new_end.strip()

            new_sent = new_sent.strip()
            if hom and new_sent and hom.lower() in new_sent.lower():
                new_sample["sentence"] = new_sent

            new_sample["augment_type"] = "context_variation_t5"
            new_sample["parent_id"] = sid
            new_sample["t5_model"] = "t5-base"

            if (
                new_sample["precontext"] == orig.get("precontext", "") and
                new_sample["ending"] == orig.get("ending", "") and
                new_sample["sentence"] == orig.get("sentence", "")
            ):
                continue

            self._add_sample(new_sample)


    def augment_synthetic_dev_homonyms(
        self,
        keep_prob=1.0,
        max_len=80,
        top_k_label_neighbors=5,
    ):
        """Generate training stories for homonyms that occur only in dev.

        For every dev sample whose homonym is absent from train (selected with
        probability ``keep_prob``), T5 is prompted for a three-line story:
        precontext, sentence containing the homonym, ending. The output is
        split on newlines; line 1 is the precontext, line 2 the sentence and
        all remaining lines together form the ending. Outputs without a
        sentence line, or whose sentence lacks the homonym, are discarded.

        Label fields (``choices``, ``average``, ``stdev``, ``nonsensical``)
        are copied from the training sample whose sense text is the closest
        match to the dev sense.

        NOTE(review): this feeds dev-set homonyms and meanings into the
        training data — confirm this is acceptable for the evaluation setup.

        Args:
            keep_prob: Probability of selecting each dev-only sample.
            max_len: Generation length limit for the story.
            top_k_label_neighbors: Number of sense neighbours requested; only
                the first one is used as label donor.
        """
        if self.t5 is None:
            print("\n=== Synthetic dev homonyms: T5 disabled, skipping ===")
            return
        if not self.dev:
            print("\n=== Synthetic dev homonyms: no dev provided, skipping ===")
            return

        print("\n=== Synthetic dev homonyms (dev-only homonyms, T5-generated stories) ===")

        dev_only_ids = []
        for did, d in self.dev.items():
            hom = d.get("homonym", "").strip()
            if hom and hom not in self.train_homs and random.random() < keep_prob:
                dev_only_ids.append(did)

        if not dev_only_ids:
            print("No dev-only homonyms found or none selected by keep_prob.")
            return

        prompts = []
        dev_sense_texts = []
        for did in dev_only_ids:
            d = self.dev[did]
            hom = d.get("homonym", "").strip()
            meaning = d.get("judged_meaning", "").strip()
            if not hom or not meaning:
                # Placeholders keep both lists aligned with dev_only_ids.
                prompts.append("")
                dev_sense_texts.append("")
                continue

            prompt = (
                "Write three lines.\n"
                "Line 1: a short precontext before a sentence.\n"
                f"Line 2: a sentence that contains the word '{hom}' used in the sense '{meaning}'.\n"
                "Line 3: an ending that continues the story."
            )
            prompts.append(prompt)
            dev_sense_texts.append(build_sense_text(d))

        gen = self.t5.generate_batch(
            prompts,
            max_length=max_len,
            desc="T5 synthetic dev homonym stories",
        )

        dev_sense_embs = []
        for txt in dev_sense_texts:
            if txt.strip():
                emb = self.mpnet.model.encode(txt, convert_to_tensor=True)
            else:
                emb = None
            dev_sense_embs.append(emb)

        for did, out_text, sense_emb in zip(dev_only_ids, gen, dev_sense_embs):
            d = self.dev[did]
            hom = d.get("homonym", "").strip()
            meaning = d.get("judged_meaning", "").strip()
            ex_sent = d.get("example_sentence", "").strip()

            if not hom or not meaning:
                continue

            # NOTE(review): if the T5 tokenizer cannot emit newlines, every
            # output is a single line and this stage yields nothing — verify.
            lines = [ln.strip() for ln in out_text.split("\n") if ln.strip()]
            if len(lines) < 2:
                # No sentence line was produced.
                continue
            pre, sent = lines[0], lines[1]
            # Everything after the sentence belongs to the ending; extra lines
            # used to be dropped silently.
            end = " ".join(lines[2:])

            if hom.lower() not in sent.lower():
                continue

            new_sample = {
                "homonym": hom,
                "judged_meaning": meaning,
                "precontext": pre,
                "sentence": sent,
                "ending": end,
                "example_sentence": ex_sent,
            }

            if sense_emb is not None:
                neighbors = self.mpnet.sense_neighbors_for_vector(
                    sense_emb,
                    top_k=top_k_label_neighbors,
                )
                if neighbors:
                    # Borrow labels from the closest training sense.
                    label_id, sim = neighbors[0]
                    donor = self.train[label_id]
                    for fld in ["choices", "average", "stdev", "nonsensical"]:
                        if fld in donor:
                            new_sample[fld] = deepcopy(donor[fld])
                    new_sample["label_from_id"] = label_id
                    new_sample["label_from_sim"] = sim

            new_sample["augment_type"] = "synthetic_dev_homonym_t5"
            new_sample["parent_id"] = f"dev_{did}"
            new_sample["t5_model"] = "t5-base"

            self._add_sample(new_sample)


    def augment_rating_preserving(
        self,
        top_k_neighbors=10,
        min_story_sim=0.8,
        max_rating_diff=0.3,
        max_per_sample=1,
    ):
        """Swap in a neighbour's precontext when both samples are rated alike.

        For each training sample, story neighbours are scanned in order. A
        neighbour qualifies when its ``average`` differs from the sample's by
        at most ``max_rating_diff`` and its similarity is not below
        ``min_story_sim``. The new sample is a copy of the original with the
        neighbour's precontext and the mean of the two ratings. At most
        ``max_per_sample`` samples are created per original. A missing
        ``average`` is read as 0.0; an unparsable one skips the sample or
        neighbour.
        """
        print("\n=== Rating-preserving (neighbor context tweak) ===")
        for sid, base in tqdm(self.train.items()):
            try:
                base_rating = float(base.get("average", 0.0))
            except Exception:
                continue

            made = 0
            for other_id, score in self.mpnet.story_neighbors(sid, top_k=top_k_neighbors):
                if made >= max_per_sample:
                    break
                other = self.train[other_id]
                try:
                    other_rating = float(other.get("average", 0.0))
                except Exception:
                    continue
                rating_gap_too_wide = abs(other_rating - base_rating) > max_rating_diff
                if rating_gap_too_wide or score < min_story_sim:
                    continue

                variant = deepcopy(base)
                variant.update(
                    {
                        "precontext": other.get("precontext", ""),
                        "average": float((base_rating + other_rating) / 2.0),
                        "augment_type": "rating_preserving",
                        "parent_id": sid,
                        "rating_neighbor_id": other_id,
                        "story_sim_score": score,
                    }
                )
                self._add_sample(variant)
                made += 1

    def augment_same_context_different_homonym(
        self,
        top_k_neighbors=10,
        min_story_sim=0.7,
        max_per_sample=1,
    ):
        """
        Keep one sample's context shell and plug in another sample's core.

        For each training sample A with a non-empty homonym:
            - scan A's story neighbours for a sample B whose homonym is
              non-empty and different, with similarity not below
              ``min_story_sim``;
            - copy B (sentence, homonym, judged meaning, labels) and overwrite
              its precontext and ending with A's.
        At most ``max_per_sample`` new samples are created per A.
        """
        print("\n=== Same context, different homonym ===")
        for sid, shell in tqdm(self.train.items()):
            shell_homonym = shell.get("homonym", "").strip()
            if not shell_homonym:
                continue

            made = 0
            for core_id, score in self.mpnet.story_neighbors(sid, top_k=top_k_neighbors):
                if made >= max_per_sample:
                    break
                core = self.train[core_id]
                core_homonym = core.get("homonym", "").strip()
                if not core_homonym or core_homonym == shell_homonym:
                    continue
                if score < min_story_sim:
                    continue

                merged = deepcopy(core)
                merged.update(
                    {
                        "precontext": shell.get("precontext", ""),
                        "ending": shell.get("ending", ""),
                        "augment_type": "same_context_different_homonym",
                        "parent_id": core_id,
                        "context_from_id": sid,
                        "story_sim_score": score,
                    }
                )
                self._add_sample(merged)
                made += 1


    def run_all(self):
        self.augment_backtranslation()
        self.augment_cross_homonym_swap()
        self.augment_context_variation()
        self.augment_synthetic_dev_homonyms()
        self.augment_rating_preserving()
        self.augment_same_context_different_homonym()
        return self.combined


def parse_args():
    """Parse the command-line options of the augmentation script.

    Unknown arguments are ignored (``parse_known_args``), so the function also
    works when the host process passes its own flags.

    Returns:
        The namespace with ``train_path``, ``dev_path``, ``output_path``,
        ``device`` and ``disable_t5``.
    """
    parser = argparse.ArgumentParser(
        description="AmbiStory augmentation (MPNet + T5 + dev-based synthetic homonyms)."
    )
    string_options = (
        ("--train_path", "/kaggle/input/ambistory-raw/train.json"),
        ("--dev_path", "/kaggle/input/ambistory-raw/dev.json"),
        ("--output_path", "train_augmented_all.json"),
        ("--device", None),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument("--disable_t5", action="store_true")
    known, _unknown = parser.parse_known_args()
    return known


def main():
    """Command-line entry point: parse the arguments and run the full pipeline.

    Loading, augmentation, reporting and saving are delegated to
    ``run_pipeline``, which performs the same steps for programmatic callers.
    """
    cli = parse_args()
    run_pipeline(
        train_path=cli.train_path,
        dev_path=cli.dev_path,
        output_path=cli.output_path,
        device=cli.device,
        enable_t5=not cli.disable_t5,
    )


def run_pipeline(
    train_path,
    dev_path,
    output_path,
    device=None,
    enable_t5=True,
):
    """Load train and dev JSON, run all augmentations and save the result.

    Args:
        train_path: Path of the training JSON file.
        dev_path: Path of the dev JSON file.
        output_path: Where the combined dataset is written.
        device: Device forwarded to ``AugmentationPipeline``.
        enable_t5: Forwarded to ``AugmentationPipeline``.
    """
    print(f"Loading train from {train_path} ...")
    train_data = load_json(train_path)
    print(f"Loading dev from {dev_path} ...")
    dev_data = load_json(dev_path)

    combined = AugmentationPipeline(
        train_data=train_data,
        dev_data=dev_data,
        device=device,
        enable_t5=enable_t5,
    ).run_all()

    total = len(combined)
    original = len(train_data)
    print(f"\nFinal dataset size: {total} (original: {original}, new: {total - original})")
    print(f"Saving to {output_path} ...")
    save_json(combined, output_path)
    print("Done.")


# Script entry point: run the command-line pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()