# Imports

In [None]:
from typing import List, Dict, Any
import json

# Annotate text with XML tags based on entity spans

In [None]:
from typing import List, Dict, Any


def tag_text_with_entities(text: str,
                           entities: List[Dict[str, Any]]) -> str:
    """Render *text* with each entity span wrapped in ``<label>...</label>`` tags.

    Each entity dict must provide ``"id"``, ``"label"``, ``"start"`` and
    ``"end"``.  ``"start"``/``"end"`` are character offsets into *text*
    (end-exclusive).  For a discontinuous entity they are parallel lists with
    one pair per segment, and every segment is tagged separately.

    Tags at the same offset are ordered so that nested spans produce nested
    tags:
    - closing tags come before opening tags;
    - shorter spans close first and longer spans open first;
    - spans with identical extent close in the reverse of their opening
      order.

    A zero-length span is rendered as an empty element ``<label></label>``.

    Some entities get an ``ent_id="E<k>"`` attribute on every opening tag,
    numbered from 1 in sorted-id order.  This applies to:
    - entities with a span that overlaps another span of the same label;
    - discontinuous entities (more than one segment), so their segments can
      be linked back together.

    Note: neither *text* nor the labels are XML-escaped.  Spans that cross
    (partially overlap without nesting) cannot be expressed as well-formed
    XML; their tags are still emitted at the given offsets.

    Args:
        text: The text to annotate.
        entities: Entity dicts as described above.

    Returns:
        The annotated string.

    Raises:
        ValueError: if an entity has a different number of start and end
            offsets, or a span that does not satisfy
            ``0 <= start <= end <= len(text)``.
    """
    n = len(text)
    segs = []
    discontinuous = set()
    for ent in entities:
        starts = ent["start"] if isinstance(ent["start"], (list, tuple)) else [ent["start"]]
        ends = ent["end"] if isinstance(ent["end"], (list, tuple)) else [ent["end"]]
        if len(starts) != len(ends):
            # zip() would silently drop the unmatched offsets.
            raise ValueError(
                f"entity {ent.get('id')!r}: {len(starts)} start offset(s) "
                f"but {len(ends)} end offset(s)")
        if len(starts) > 1:
            discontinuous.add(ent["id"])
        for s, e in zip(starts, ends):
            # Out-of-range or inverted spans would yield unbalanced tags.
            if not 0 <= s <= e <= n:
                raise ValueError(
                    f"entity {ent.get('id')!r}: span [{s}, {e}) is outside "
                    f"text of length {n} or has start > end")
            segs.append({"start": s, "end": e, "len": e - s,
                         "label": ent["label"], "orig_id": ent["id"]})

    # ent_id is needed where the tag name alone cannot tell spans apart
    # (same-label overlaps) and where several segments form one entity.
    needing_attr = _find_overlapping_entity_ids(segs) | discontinuous
    short_id_map = {oid: f"E{k}"
                    for k, oid in enumerate(sorted(needing_attr), start=1)}

    opens, closes = {}, {}
    for seg in segs:
        opens.setdefault(seg["start"], []).append(seg)
        # Zero-length segments are closed right after they are opened (below).
        # Registering them here would emit "</X>" before "<X>".
        if seg["len"]:
            closes.setdefault(seg["end"], []).append(seg)

    out = []
    prev = 0
    for pos in sorted(opens.keys() | closes.keys()):
        out.append(text[prev:pos])
        # Shortest span closes first.  Equal-length spans ending here also
        # started together and were opened in ascending id order, so they
        # must close in descending id order to stay nested.
        for seg in sorted(closes.get(pos, ()),
                          key=lambda s: (-s["len"], s["orig_id"]),
                          reverse=True):
            out.append(f"</{seg['label']}>")
        # Longest span opens first so that it encloses the shorter ones.
        for seg in sorted(opens.get(pos, ()),
                          key=lambda s: (-s["len"], s["orig_id"])):
            label = seg["label"]
            oid = seg["orig_id"]
            if oid in short_id_map:
                out.append(f"<{label} ent_id=\"{short_id_map[oid]}\">")
            else:
                out.append(f"<{label}>")
            if not seg["len"]:
                out.append(f"</{label}>")
        prev = pos
    out.append(text[prev:])
    return "".join(out)


# Overlap detection


In [None]:
def _find_overlapping_entity_ids(segs):
    from collections import defaultdict
    need = set()
    by_lab = defaultdict(list)

    for s in segs:
        by_lab[s["label"]].append((s["start"], s["end"], s["orig_id"]))

    for label, ivs in by_lab.items():
        ivs.sort(key=lambda t: t[0])
        active = []
        for start, end, oid in ivs:
            active = [(e, i) for (e, i) in active if e > start]
            for _, aid in active:
                need.add(oid)
                need.add(aid)
            active.append((end, oid))
    return need


In [None]:
def process_entries(input_file: str, output_file: str) -> None:
    """Add an XML-tagged rendering of each entry's text to a JSON dataset.

    Reads *input_file*, which is expected to hold a JSON array of objects.
    For each object, the result of ``tag_text_with_entities`` is stored
    under a new ``"tagged_text"`` key.  The whole array is then written to
    *output_file* as UTF-8, with non-ASCII characters kept as-is and a
    2-space indent.

    A missing or ``null`` ``"text"`` is treated as ``""``, and a missing or
    ``null`` ``"entities"`` as ``[]``.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path of the JSON file to write (overwritten).

    Raises:
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if *input_file* does not contain valid JSON.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for entry in data:
        # `or` also covers keys that are present but null in the JSON, which
        # .get(key, default) alone would pass through as None.
        text = entry.get('text') or ''
        entities = entry.get('entities') or []
        entry['tagged_text'] = tag_text_with_entities(text, entities)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)




In [None]:
import json

# Run the tagging pipeline on the hard-coded dataset paths only when executed
# directly. __name__ is also "__main__" inside a notebook kernel, so the cell
# still runs there; merely importing this code no longer reads or writes files.
if __name__ == "__main__":
    process_entries('Input-file.json', 'Output-file.json')