# In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# Run-wide settings read by the functions in this file.
CONFIG = {
    # JSON file holding the original instances; main() sorts its keys with
    # key=int, so the keys must be integer-like strings.
    "input_path": "/kaggle/input/ambistory-raw/train.json",

    # Hugging Face model id used for both the tokenizer and the causal LM.
    "model_name": "microsoft/Phi-3-mini-4k-instruct",

    # Number of instances packed into a single prompt by main().
    "batch_size": 4,
    # Generation budget per prompt, i.e. for the whole batch's JSON answer.
    "max_new_tokens": 1200,

    # "temperature" is passed on the first attempt for a batch,
    # "fallback_temperature" on every later attempt.
    "temperature": 0.7,
    "fallback_temperature": 1.0,
    "top_p": 0.9,

    # When true, main() wraps the batch loop in tqdm.
    "progress_bar": True,
    # Positions (after sorting the keys numerically) of the slice to process.
    # main() treats a falsy slice_end as "up to the last key".
    "slice_start": 0,
    "slice_end": 1140,

    # Total number of generation attempts per batch: the retry loop is
    # range(max_retries), so 2 means one initial try plus one retry.
    "max_retries": 2,
    # When true, the exception that triggered a retry is printed.
    "debug": True,
    # Only referenced by the commented-out raw-output print in generate_once.
    "debug_char_limit": 500
}


def make_output_path():
    """Build the output file name for the slice configured in CONFIG.

    The name has the form ``train_aug_slice_<start>_<end>.json``; when
    ``slice_end`` is None the literal word ``end`` is used in its place.
    """
    start, end = CONFIG["slice_start"], CONFIG["slice_end"]
    if end is None:
        end_str = "end"
    else:
        end_str = str(end)
    return f"train_aug_slice_{start}_{end_str}.json"


# Resolved once, at import time, from the slice settings in CONFIG.
OUTPUT_PATH = make_output_path()

# Tokenizer and model are module-level globals used by generate_once().
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])

# Load the weights in half precision and move them to the GPU.
# NOTE(review): .to("cuda") is unconditional, so on a machine without CUDA
# this presumably fails before the hardware report below is ever printed.
model = AutoModelForCausalLM.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=torch.float16
).to("cuda")

# Inference only: put the module in evaluation mode.
model.eval()

# One-off report of where the model lives and where results will be written.
print("\n===== HARDWARE DEBUG =====")
print("Model device:", next(model.parameters()).device)
print("CUDA available:", torch.cuda.is_available())
print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")
print("Output path:", OUTPUT_PATH)
print("==========================\n")


# System turn sent with every request (see generate_once).
SYSTEM_PROMPT = (
    "You are a careful linguistic annotator. "
    "You strictly follow instructions and output only valid JSON."
)

# User turn template. generate_once() fills it with str.format, so
# "{instances}" is the only placeholder; any literal brace added to this text
# would have to be doubled ("{{" / "}}") to survive formatting.
USER_PROMPT = """
You are given multiple training instances from a word sense plausibility task.

For EACH instance, create EXACTLY ONE augmented version.

Rules:
- Rewrite the precontext. All three sentences must be lexically and syntactically different from the original.
- Rewrite the ambiguous sentence so it remains equally ambiguous and fully compatible with BOTH original word senses.
- The ambiguous sentence MUST NOT contain any technical, domain-specific, or sense-specific modifiers that were not present in the original.
- The ambiguous sentence must remain natural and valid under both interpretations, without favoring either.
- Rewrite the judged meaning using a semantically equivalent formulation.
- Rewrite the example sentence so it exemplifies the rewritten judged meaning to the same degree of implicitness as the original.
- If the ending is NON-EMPTY, rewrite it with the same degree of bias toward the judged meaning (no stronger, no weaker).
- If the ending is EMPTY, keep it EMPTY.
- DO NOT clarify the word sense.
- DO NOT introduce explicit sense markers.
- DO NOT increase or decrease overall plausibility.
- Keep each story coherent and natural.
- Do NOT include any text outside the JSON list.

Before producing the final JSON, silently verify that the ambiguous sentence still supports both original senses. If it does not, rewrite it again.

Return ONLY a JSON LIST of augmented instances, in the SAME ORDER as the input.

Instances:
{instances}
"""


def extract_first_json_list(text):
    """Return the first bracket-balanced ``[...]`` span found in *text*.

    Scanning starts at the first ``[`` and stops at the ``]`` that closes it.
    Brackets that occur inside double-quoted JSON strings are ignored, so a
    value such as ``"see [1]"`` or ``"x]y"`` does not cut the list short.
    Backslash escapes inside strings (e.g. ``\\"``) are honoured.

    Args:
        text: Raw model output that may contain prose around a JSON list.

    Returns:
        The substring from the opening ``[`` to its matching ``]`` inclusive,
        or None when there is no ``[`` or the list is never closed. The
        substring is not parsed here; callers still need ``json.loads``.
    """
    start = text.find("[")
    if start == -1:
        return None

    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text[start:], start):
        if in_string:
            # Inside a string only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue

        if ch == '"':
            in_string = True
        elif ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None


@torch.inference_mode()
def generate_once(instances_json, temperature, do_sample):
    """Run one generation pass for a batch of instances.

    Args:
        instances_json: JSON text of the batch, substituted into USER_PROMPT.
        temperature: Temperature forwarded to ``model.generate``.
        do_sample: Whether ``model.generate`` should sample.

    Returns:
        A ``(decoded, extracted)`` pair: the decoded completion (prompt
        tokens removed, special tokens skipped) and the result of
        ``extract_first_json_list`` on it, which may be None.
    """
    user_turn = USER_PROMPT.format(instances=instances_json)
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_turn},
    ]

    # assumes apply_chat_template returns a tensor of token ids here, since
    # the result is moved with .to() and measured with .shape -- TODO confirm
    prompt_ids = tokenizer.apply_chat_template(
        chat,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    prompt_ids = prompt_ids.to(model.device)

    # NOTE(review): eos_token_id is passed explicitly; confirm that
    # tokenizer.eos_token_id is the end-of-turn token for this chat model.
    sequences = model.generate(
        prompt_ids,
        max_new_tokens=CONFIG["max_new_tokens"],
        temperature=temperature,
        top_p=CONFIG["top_p"],
        do_sample=do_sample,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Keep only what was generated after the prompt.
    prompt_len = prompt_ids.shape[-1]
    completion_ids = sequences[0][prompt_len:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)

    return decoded, extract_first_json_list(decoded)


def validate_aug_batch(aug_batch, expected_len):
    """Check that *aug_batch* is a list of exactly *expected_len* dicts.

    Only the container shape is checked; the keys of each dict are not
    inspected.
    """
    if not isinstance(aug_batch, list):
        return False
    if len(aug_batch) != expected_len:
        return False
    for entry in aug_batch:
        if not isinstance(entry, dict):
            return False
    return True


def is_modified(original, augmented):
    """Tell whether *augmented* differs from *original* in any compared field.

    The compared fields are ``precontext``, ``sentence``, ``judged_meaning``
    and ``example_sentence``. Values are stringified, stripped and
    lower-cased before comparison, and a missing field counts as an empty
    string. Other keys (for example ``ending``) are not looked at.
    """
    def _normalised(record, field):
        return str(record.get(field, "")).strip().lower()

    for field in ("precontext", "sentence", "judged_meaning", "example_sentence"):
        if _normalised(original, field) != _normalised(augmented, field):
            return True
    return False


def generate_batch_with_retry(instances_json, expected_len):
    """Generate an augmented batch, retrying until one passes validation.

    Up to ``CONFIG["max_retries"]`` attempts are made in total. The first
    attempt samples at ``CONFIG["temperature"]``; later attempts pass
    ``CONFIG["fallback_temperature"]`` with ``do_sample=False``.

    Args:
        instances_json: JSON text of the batch to augment.
        expected_len: Number of instances the model must return.

    Returns:
        The parsed list of augmented dicts, or None when every attempt
        failed (no JSON list, unparsable JSON, wrong shape, or an error
        raised during generation).
    """
    for attempt in range(CONFIG["max_retries"]):
        first_attempt = attempt == 0
        try:
            # NOTE(review): do_sample is False on every attempt after the
            # first, so the fallback temperature is presumably ignored by
            # generate() and repeated fallbacks would decode identically.
            # Confirm whether sampling on retries was intended.
            _decoded, extracted = generate_once(
                instances_json,
                CONFIG["temperature"] if first_attempt else CONFIG["fallback_temperature"],
                do_sample=first_attempt,
            )

            if extracted is None:
                raise ValueError("No JSON list found")

            aug = json.loads(extracted)
            # Route shape failures through the same logged retry path as the
            # other failures instead of retrying without a trace.
            if not validate_aug_batch(aug, expected_len):
                raise ValueError(
                    f"Model output is not a list of {expected_len} JSON objects"
                )
            return aug

        # Deliberately broad: any failure of an attempt (including errors
        # raised by generation itself) is treated as "try again".
        except Exception as e:
            if CONFIG["debug"]:
                print(f"[DEBUG RETRY] {e}")

    return None


def main():
    """Augment the configured slice of the input and write it to OUTPUT_PATH.

    Steps:
      1. Load the input JSON and order its keys numerically.
      2. Take the ``slice_start:slice_end`` range of those keys (a falsy
         ``slice_end`` means "through the last key").
      3. Send the instances to the model in groups of ``batch_size``, with
         ``sample_id`` removed from each instance before prompting.
      4. Keep every returned instance that ``is_modified`` accepts and give
         it a fresh ``sample_id``.
      5. Dump all kept instances, keyed by their new id, as one JSON file.

    Batches for which no valid output was obtained are skipped with a
    message. Nothing is written until every batch has been processed.
    """
    with open(CONFIG["input_path"], "r", encoding="utf-8") as src:
        dataset = json.load(src)

    ordered_keys = list(dataset.keys())
    ordered_keys.sort(key=int)
    start = CONFIG["slice_start"]
    end = CONFIG["slice_end"] or len(ordered_keys)
    slice_keys = ordered_keys[start:end]

    # New ids continue after the largest id in the whole input, shifted by
    # the slice start -- presumably so that slices processed in separate
    # runs receive disjoint id ranges.
    highest_existing_id = max(int(rec["sample_id"]) for rec in dataset.values())
    first_new_id = highest_existing_id + CONFIG["slice_start"] + 1
    accepted = 0

    augmented_data = {}

    batch_starts = range(0, len(slice_keys), CONFIG["batch_size"])
    if CONFIG["progress_bar"]:
        batch_starts = tqdm(batch_starts, desc=f"Augmenting slice {start}:{end}")

    for i in batch_starts:
        originals = [dataset[k] for k in slice_keys[i:i + CONFIG["batch_size"]]]

        # The model never sees the ids; work on copies so the loaded
        # records keep theirs.
        batch_for_prompt = []
        for record in originals:
            stripped = dict(record)
            del stripped["sample_id"]
            batch_for_prompt.append(stripped)

        prompt_payload = json.dumps(batch_for_prompt, ensure_ascii=False)
        aug_batch = generate_batch_with_retry(prompt_payload, len(batch_for_prompt))

        if aug_batch is None:
            print(f"[SKIP BATCH {i}] Failed")
            continue

        for orig, aug in zip(originals, aug_batch):
            if is_modified(orig, aug):
                new_id = str(first_new_id + accepted)
                aug["sample_id"] = new_id
                augmented_data[new_id] = aug
                accepted += 1
            else:
                print("\n[REJECTED AS IDENTICAL]")

    with open(OUTPUT_PATH, "w", encoding="utf-8") as sink:
        json.dump(augmented_data, sink, ensure_ascii=False, indent=2)

    print(f"Done. Augmented instances written: {len(augmented_data)}")
    print(f"Output file: {OUTPUT_PATH}")


# Start the augmentation run only when this file is executed as a script.
if __name__ == "__main__":
    main()
