# Aligning the LLM output with the original text

In [None]:
from difflib import SequenceMatcher


def align_entities_to_original(original_text, llm_output):
    """Project inline entity tags from an LLM's output onto the original text.

    The LLM is expected to echo the input text with inline tags such as
    ``<PER>John</PER>`` or ``<PER ent_id="e1">John</PER>``, but it may have
    altered the surrounding text slightly. This function strips the tags from
    the LLM output, aligns the remaining characters with ``original_text``
    using :class:`difflib.SequenceMatcher`, and re-inserts every tag at the
    corresponding position of the original text.

    Args:
        original_text: The untouched source text.
        llm_output: The LLM response. If it contains a Markdown code fence
            (```), only the first fenced section is used, and a leading
            ``xml`` language marker is removed.

    Returns:
        ``original_text`` with the alignable tags inserted. Opening tags are
        re-emitted verbatim (attributes included); closing tags use the bare
        element name. An entity is dropped when it is empty, or when its first
        or last character has no counterpart in ``original_text``.
    """
    content = _strip_code_fence(llm_output)
    clean_llm_text, entity_spans = _parse_tagged_spans(content)

    # autojunk=False: with the default heuristic, every character making up
    # more than ~1% of a text of 200+ characters (spaces, common letters) is
    # "popular" and cannot anchor a match, which wrecks character-level
    # alignment of natural-language text.
    matcher = SequenceMatcher(None, clean_llm_text, original_text, autojunk=False)
    position_map = {}
    for block in matcher.get_matching_blocks():
        for offset in range(block.size):
            position_map[block.a + offset] = block.b + offset

    aligned_spans = []
    for start, end, open_tag, name in entity_spans:
        # Both boundary characters must have been matched to be placeable.
        if start < end and start in position_map and (end - 1) in position_map:
            aligned_spans.append(
                (position_map[start], position_map[end - 1] + 1, open_tag, name)
            )
    return _render_tags(original_text, aligned_spans)


def _strip_code_fence(llm_output):
    """Return the first ```-fenced section of *llm_output*, or all of it if unfenced."""
    if "```" not in llm_output:
        return llm_output
    content = llm_output.split("```")[1]
    if content.startswith("xml"):
        content = content[3:].strip()
    return content


def _parse_tagged_spans(content):
    """Strip tags from *content*; return ``(plain_text, spans)``.

    Each span is ``(start, end, open_tag, name)``. ``start``/``end``
    (end-exclusive) index into ``plain_text``. ``open_tag`` is the full text
    between ``<`` and ``>`` of the opening tag, and ``name`` is its first word.
    A closing tag only closes the innermost open tag when the names match;
    otherwise it is removed without producing a span.
    """
    chars = []
    stack = []  # (open tag body incl. attributes, element name, start offset)
    spans = []
    i = 0
    n = len(content)
    while i < n:
        ch = content[i]
        if ch != "<":
            chars.append(ch)
            i += 1
            continue
        close = content.find(">", i + 1)
        if close == -1:
            # No '>' anywhere after it: a literal '<' in the text, not a tag.
            chars.append(ch)
            i += 1
            continue
        body = content[i + 1:close]
        i = close + 1
        if body.startswith("/"):
            name = body[1:].strip()
            if stack and stack[-1][1] == name:
                open_body, _, start = stack.pop()
                spans.append((start, len(chars), open_body, name))
        elif body.endswith("/"):
            continue  # self-closing tag such as <br/>: carries no span
        elif body.strip():
            stack.append((body, body.split()[0], len(chars)))
    return "".join(chars), spans


def _render_tags(text, spans):
    """Insert ``<open_tag>`` ... ``</name>`` into *text* for each ``(start, end, open_tag, name)``.

    All insertions are sorted and applied in one pass, so earlier insertions
    never shift later ones. At a shared offset, closing tags come before
    opening tags. Inner entities (later start) close before outer ones, and
    outer entities (later end) open before inner ones, which keeps nested
    entities well formed.
    """
    events = []
    for idx, (start, end, open_tag, name) in enumerate(spans):
        # Sort key: (offset, 0 = close / 1 = open, nesting tie-breakers)
        events.append(((end, 0, -start, -idx), f"</{name}>"))
        events.append(((start, 1, -end, idx), f"<{open_tag}>"))
    events.sort(key=lambda event: event[0])

    pieces = []
    cursor = 0
    for (pos, *_), tag in events:
        pieces.append(text[cursor:pos])
        pieces.append(tag)
        cursor = pos
    pieces.append(text[cursor:])
    return "".join(pieces)

In [None]:
import json

# Align each entry's LLM prediction with its source text and save the result
# under a new "aligned" key.
with open('input-file.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

for entry in data:
    # `or ''` also covers keys that are present but null in the JSON; with
    # `.get(key, '')` a null value would reach the aligner as None and crash.
    text = entry.get('text') or ''
    print(text)
    pred = entry.get('prediction') or ''
    print(pred)
    entry['aligned'] = align_entities_to_original(text, pred)
    print(entry['aligned'])

with open('input-file_aligned.json', 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=2)


# Convert the tagged text to JSON entity records (de-tagging)

In [None]:
import uuid

def _auto_id():
    """Return a fresh random entity id: ``"U"`` followed by 8 lowercase hex digits."""
    hex_suffix = uuid.uuid4().hex[:8]
    return "U" + hex_suffix

import re, uuid
from typing import List, Dict, Any

# Opening tag: <LABEL> or <LABEL ent_id="ID">, LABEL = uppercase letters/underscores.
OPEN_RE  = re.compile(r'<([A-Z_]+)(?:\s+ent_id="([^"]+)")?>')
# Closing tag: </LABEL>.
CLOSE_RE = re.compile(r'</([A-Z_]+)>')

def detag_to_entities(tagged: str) -> List[Dict[str, Any]]:
    """Convert inline-tagged text into a list of entity records.

    Opening tags are matched by ``OPEN_RE`` and closing tags by ``CLOSE_RE``.
    Any other character, including tags those patterns do not recognise, is
    copied into the plain text unchanged. A closing tag closes the innermost
    open tag with the same label. Closing tags with no matching open tag are
    ignored, and opening tags never closed produce no entity.

    Args:
        tagged: Text containing the inline tags.

    Returns:
        One dict per entity id, in the order each id was first closed, with keys:

        - ``text``: the tagged parts joined by single spaces, in order of start.
        - ``start`` / ``end``: lists of end-exclusive offsets into the
          de-tagged text, one per part, in the same order as ``text``.
        - ``id``: the ``ent_id`` attribute, or a random ``"U" + 8 hex`` id
          when absent. Parts sharing an ``ent_id`` are merged.
        - ``label``: the label of the first part closed for that id.
    """
    plain_chars: List[str] = []
    # Open tags not yet closed: (label, ent_id or None, start offset).
    stack: List[tuple] = []
    by_id: Dict[str, Dict[str, Any]] = {}

    def _auto_id():
        return f"U{uuid.uuid4().hex[:8]}"

    i = 0
    n = len(tagged)
    while i < n:
        mo = OPEN_RE.match(tagged, i)
        if mo:
            stack.append((mo.group(1), mo.group(2), len(plain_chars)))
            i = mo.end()
            continue

        mc = CLOSE_RE.match(tagged, i)
        if mc:
            i = mc.end()
            label = mc.group(1)
            # Remove only the innermost matching open tag. Popping everything
            # above it (or the whole stack when there is no match) would throw
            # away other still-open entities on overlapping or stray tags.
            idx = next(
                (k for k in range(len(stack) - 1, -1, -1) if stack[k][0] == label),
                None,
            )
            if idx is None:
                continue
            _, eid, start = stack.pop(idx)
            end = len(plain_chars)
            if eid is None:
                eid = _auto_id()
            rec = by_id.setdefault(
                eid,
                {"label": label, "id": eid,
                 "text_parts": [], "start": [], "end": []})
            rec["text_parts"].append("".join(plain_chars[start:end]))
            rec["start"].append(start)
            rec["end"].append(end)
            continue

        plain_chars.append(tagged[i])
        i += 1

    out = []
    for rec in by_id.values():
        # Parts are recorded in closing order; report them sorted by start so
        # that text, start and end all list the parts in the same order.
        order = sorted(range(len(rec["start"])), key=lambda k: rec["start"][k])
        out.append({
            "text":  " ".join(rec["text_parts"][k] for k in order),
            "start": [rec["start"][k] for k in order],
            "end":   [rec["end"][k] for k in order],
            "id":    rec["id"],
            "label": rec["label"],
        })
    return out



In [None]:
import json

# Turn each entry's aligned, tagged text into entity records stored under a
# new "entities" key.
with open('input-file_aligned.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

for entry in data:
    # `or ''` also covers an "aligned" key that is present but null, which
    # would otherwise crash detag_to_entities with a TypeError.
    entry['entities'] = detag_to_entities(entry.get('aligned') or '')

with open('input-file_detagged.json', 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=2)
