In [None]:
# fix_mindmate_text_jsonl.py
import json, re, sys
from pathlib import Path
from transformers import AutoTokenizer

# Command-line contract: three positional arguments are required. Fail with a
# readable usage line instead of an opaque IndexError when any is missing.
# Extra arguments are ignored.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} IN.jsonl OUT.jsonl MODEL_DIR")

IN = sys.argv[1]          # input JSONL path, e.g., train.jsonl
OUT = sys.argv[2]         # output JSONL path, e.g., train.fixed.jsonl
MODEL_DIR = sys.argv[3]   # tokenizer dir (same as model)

# Load the tokenizer only to discover its end-of-sequence token string.
tok = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
# Fall back to "</s>" when the tokenizer does not define an EOS token.
EOS = tok.eos_token or "</s>"

def split_turns(text: str):
    """Split *text* on the chat role tags, keeping the tags in the result.

    The pattern uses a capturing group, so the returned list alternates
    between plain text and role tags, e.g.
    ``["", "<|user|>", " hi", "<|assistant|>", " Hello"]``. The first element
    is whatever precedes the first tag (possibly the empty string).
    """
    role_tag = re.compile(r'(<\|user\|>|<\|assistant\|>)')
    return role_tag.split(text)

def merge_consecutive_user(parts):
    """Collapse runs of adjacent ``<|user|>`` turns into a single user turn.

    Args:
        parts: List of role tags and text chunks, in the alternating layout
            returned by ``split_turns``.

    Returns:
        A new list in which every run of consecutive user turns is replaced
        by one ``"<|user|>"`` tag followed by one payload string. The payload
        is the individual user texts, each stripped, joined with newlines and
        terminated by a single ``"\\n"``. All other elements are copied
        through unchanged.
    """
    # Single source of truth for the tag strings, so the payload scan and the
    # merge scan can never disagree on what counts as a role tag.
    role_tags = ("<|user|>", "<|assistant|>")

    merged = []
    total = len(parts)
    i = 0
    while i < total:
        seg = parts[i]
        if seg != "<|user|>":
            merged.append(seg)
            i += 1
            continue

        payloads = []
        j = i + 1
        # Text chunk(s) belonging to this user tag.
        while j < total and parts[j] not in role_tags:
            payloads.append(parts[j].lstrip("\n"))
            j += 1
        # If another <|user|> immediately follows, absorb it and its text.
        while j < total and parts[j] == "<|user|>":
            j += 1
            if j < total and parts[j] not in role_tags:
                payloads.append(parts[j].lstrip("\n"))
                j += 1

        body = "\n".join(p.strip() for p in payloads).strip()
        merged.extend(["<|user|>", body + "\n"])
        i = j
    return merged

def ensure_eos_for_assistant(parts):
    """Re-join *parts* into text, making every assistant turn end with EOS.

    Args:
        parts: List of role tags and text chunks, in the alternating layout
            returned by ``split_turns``.

    Returns:
        A ``(text, fixed)`` tuple: the concatenated text, and the number of
        assistant turns to which the module-level ``EOS`` string was appended.
        Every assistant payload is emitted with exactly one trailing newline.
    """
    fixed = 0
    out = []
    total = len(parts)
    i = 0
    while i < total:
        seg = parts[i]
        out.append(seg)
        if seg != "<|assistant|>":
            i += 1
            continue

        # The assistant payload is the chunk right after the tag; it may be
        # empty, or absent when the tag is the last element.
        content = parts[i + 1] if i + 1 < total else ""
        c = content.rstrip("\n")
        # Ignore any trailing whitespace (not just newlines) when looking for
        # EOS; otherwise "...</s> " would be given a second EOS.
        if not c.rstrip().endswith(EOS):
            c = (c + " " + EOS).strip()
            fixed += 1
        out.append(c + "\n")
        # Skip the payload chunk: it has already been emitted above.
        i += 2
    return "".join(out), fixed

fixed_total = 0
asst_total = 0

# Opening OUT in "w" mode truncates it immediately, so running with the same
# path for input and output would destroy the data before it is read.
if Path(IN).resolve() == Path(OUT).resolve():
    sys.exit(f"Refusing to overwrite the input file in place: {IN}")

# Open the input first so a missing/unreadable IN does not leave behind an
# empty (or truncated) OUT file.
with open(IN, "r", encoding="utf-8") as fin, open(OUT, "w", encoding="utf-8") as fout:
    for line in fin:
        # Blank lines (e.g. a trailing newline at end of file) are not valid
        # JSON and would make json.loads raise; skip them.
        if not line.strip():
            continue
        ex = json.loads(line)
        txt = ex.get("text", "")
        parts = split_turns(txt)
        # Optional: merge consecutive user turns (comment out if you want to keep them)
        parts = merge_consecutive_user(parts)
        # Count assistant turns
        asst_total += parts.count("<|assistant|>")
        new_text, fixed = ensure_eos_for_assistant(parts)
        ex["text"] = new_text
        fixed_total += fixed
        fout.write(json.dumps(ex, ensure_ascii=False) + "\n")

print(f"Assistant turns found: {asst_total}")
print(f"Assistant turns with EOS appended: {fixed_total}")
print(f"EOS used: {EOS}")
print(f"Wrote: {OUT}")
