# Convert JSON output to BRAT 

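This notebook is an experimental converter that turns the custom JSON output back into BRAT standoff format.

The conversion is heuristic: annotation offsets are recovered by searching for each predicted string in the original note text, so the result is approximate.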

In [None]:
# Import libraries
import ast
import json
import pathlib
import re
from collections import defaultdict, OrderedDict, Counter
from typing import List, Dict, Tuple, Set

import pandas as pd

## Convert JSON to BRAT

In [None]:

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Directory from which load_note() reads the original "<letter_id>.txt" notes.
BASE_DATA_DIR = pathlib.Path("/prj/doctoral_letters/data/i2b2_2018track2/test")

# Pipe-separated CSV that is read with the columns text | gold | pred.
CSV_PATH = pathlib.Path(
    "/home/prichter/research/MIEQA/output_results_llama3_8b_i2b2_pydantic.csv"
)

# Output root, and the sub-directory that holds the converted .ann files.
OUT_DIR = pathlib.Path("json2brat")
ANN_DIR = pathlib.Path("json2brat/system")
OUT_DIR.mkdir(parents=True, exist_ok=True)
ANN_DIR.mkdir(parents=True, exist_ok=True)

# Attribute keys of a medication dict that the validation cell counts.
ENTITY_KEYS = [
    "strength",
    "dosage",
    "form",
    "frequency",
    "route",
    "reason",
    "ade",
    "duration",
]

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Matches a header such as "Admission Date: [**2122-1-14**]" (case-insensitive)
# and captures the date between the "[**" / "**]" markers.
_ADM_RE = re.compile(
    r"Admission\s+Date:\s*\[\*\*([0-9]{4}-[0-9]{1,2}-[0-9]{1,2})\*\*\]",
    flags=re.IGNORECASE,
)

def _build_letter_map(csv_path: pathlib.Path) -> Dict[str, List[dict]]:
    """Group the predicted medication dicts by the letter they belong to.

    The CSV is read with ``|`` as separator and the column names ``text``,
    ``gold`` and ``pred``; its first line is treated as a header and replaced
    by those names.  Rows are walked in file order.  Whenever a ``text`` cell
    contains an ``Admission Date: [**YYYY-M-D**]`` header, that date token
    becomes the current key, and the ``medications`` entries of that row and
    of every following row are collected under it until the next header is
    seen.  Rows that come before the first header are ignored.

    A row is skipped when its ``pred`` cell is not a string (e.g. an empty
    cell), cannot be evaluated, or does not evaluate to a dict holding a
    list/tuple under ``medications``.  Only dict entries of that list are
    kept, so downstream code can rely on every item supporting ``.get``.

    Args:
        csv_path: Location of the pipe-separated results CSV.

    Returns:
        A mapping such as ``{'[**2122-1-14**]': [med_dict, ...], ...}``.
    """
    df = pd.read_csv(csv_path, sep="|", names=["text", "gold", "pred"], header=0)

    letter_map: Dict[str, List[dict]] = defaultdict(list)
    current_token = None

    # Only two columns are needed, so zip them instead of using iterrows().
    for text_cell, pred_cell in zip(df["text"], df["pred"]):
        m = _ADM_RE.search(str(text_cell))
        if m:
            current_token = f"[**{m.group(1)}**]"

        if current_token is None:
            continue

        if not isinstance(pred_cell, str):
            # Empty cells are read as NaN (float); there is nothing to parse.
            continue

        try:
            # SECURITY NOTE(review): eval() executes whatever expression the
            # CSV cell contains.  It is kept because the cells may hold
            # reprs that ast.literal_eval cannot parse (presumably
            # OrderedDict(...) -- TODO confirm); only run this on trusted
            # result files.
            pred_od = eval(pred_cell)
        except Exception:
            # Best effort: a cell that cannot be evaluated is skipped.
            continue

        if not isinstance(pred_od, dict):
            continue
        meds = pred_od.get("medications", [])
        if not isinstance(meds, (list, tuple)):
            # e.g. a plain string would otherwise be added char by char.
            continue

        letter_map[current_token].extend(
            med for med in meds if isinstance(med, dict)
        )

    return letter_map


# Built once at import time; convert_one_letter() looks predictions up here.
LETTER_PRED_MAP = _build_letter_map(CSV_PATH)


def load_note(letter_id: str) -> str:
    """Return the UTF-8 text of ``<letter_id>.txt`` inside ``BASE_DATA_DIR``."""
    note_path = BASE_DATA_DIR / f"{letter_id}.txt"
    return note_path.read_text(encoding="utf-8")

def extract_admission_token(note_text: str) -> str:
    """Return the first admission-date token of a note as ``[**<date>**]``.

    Raises:
        ValueError: If ``_ADM_RE`` does not match anywhere in *note_text*.
    """
    match = _ADM_RE.search(note_text)
    if match is not None:
        return f"[**{match.group(1)}**]"
    raise ValueError("Admission Date token not found in original note")

def normalise_spaces(s: str) -> str:
    """Collapse every run of whitespace in *s* to one space and trim the ends.

    Args:
        s: Text to normalise.

    Returns:
        *s* without leading/trailing whitespace and with each internal
        whitespace run replaced by a single space (``""`` when *s* is empty
        or whitespace only).
    """
    # str.split() without arguments splits on whitespace runs and drops the
    # empty leading/trailing pieces, so no regex (and no per-call import of
    # ``re``) is needed.
    return " ".join(s.split())

def find_offset(text: str, substring: str, used_spans: Set[Tuple[int, int]],
                start_pos: int = 0) -> Tuple[int, int]:
    """Locate the first occurrence of *substring* that overlaps no used span.

    Args:
        text: Text to search in.
        substring: Exact string to look for; must not be empty.
        used_spans: ``(start, end)`` spans (end exclusive) already taken.
        start_pos: Index in *text* at which the search begins.

    Returns:
        ``(start, end)`` of the first occurrence at or after *start_pos*
        that does not intersect any span in *used_spans*.

    Raises:
        ValueError: If *substring* is empty or no such occurrence exists.
    """
    if not substring:
        raise ValueError("Empty substring")

    width = len(substring)
    cursor = text.find(substring, start_pos)
    while cursor != -1:
        lo, hi = cursor, cursor + width
        # Two half-open spans intersect when the larger start lies before
        # the smaller end.
        clashes = any(max(a, lo) < min(b, hi) for a, b in used_spans)
        if not clashes:
            return (lo, hi)
        # Try the next occurrence, which may begin one character later.
        cursor = text.find(substring, cursor + 1)

    raise ValueError(f"Could not locate unique span for: {substring!r}")

# JSON key -> (BRAT entity label, BRAT relation label).
# Every attribute's relation label is "<entity label>-Drug"; the drug itself
# carries "Drug" in both positions.
ENTITY_MAP = {
    "medication": ("Drug", "Drug"),
    **{
        json_key: (label, f"{label}-Drug")
        for json_key, label in (
            ("strength", "Strength"),
            ("dosage", "Dosage"),
            ("form", "Form"),
            ("frequency", "Frequency"),
            ("route", "Route"),
            ("reason", "Reason"),
            ("ade", "ADE"),
            ("duration", "Duration"),
        )
    },
}

def convert_one_letter(letter_id: str) -> pathlib.Path:
    """Write the BRAT ``.ann`` file for one letter and return its path.

    The note text is loaded, its admission-date token is looked up in
    ``LETTER_PRED_MAP`` and every predicted medication is anchored in the
    note by exact string search (``find_offset``):

    * the drug name (whitespace-normalised, a trailing ``(<digits>)`` suffix
      removed) becomes a ``Drug`` text-bound annotation;
    * every non-empty attribute value whose key is listed in ``ENTITY_MAP``
      becomes its own text-bound annotation plus a relation line
      ``<Type>-Drug Arg1:<attribute> Arg2:<drug>``.

    A mention that cannot be located without overlapping an already emitted
    span is dropped; if the drug itself cannot be located, the whole
    medication is dropped.  Keys missing from ``ENTITY_MAP`` and values that
    are not strings are skipped instead of aborting the conversion.

    Args:
        letter_id: File stem of the note (``<letter_id>.txt``).

    Returns:
        Path of the written file, ``ANN_DIR / "<letter_id>.ann"``.

    Raises:
        ValueError: If the note contains no admission-date token, or no
            predictions are stored for that token.
    """
    note_text = load_note(letter_id)
    adm_token = extract_admission_token(note_text)

    meds = LETTER_PRED_MAP.get(adm_token, [])
    if not meds:
        raise ValueError(f"No prediction rows found for admission token {adm_token}")

    ann_lines: List[str] = []
    used: Set[Tuple[int, int]] = set()
    tid = rid = 1

    def add_textbound(label: str, mention: str):
        """Emit a T-line for *mention*; return its id, or None if not found."""
        nonlocal tid
        try:
            start, end = find_offset(note_text, mention, used)
        except ValueError:
            # Empty mention, or no occurrence left that is free of overlaps.
            return None
        used.add((start, end))
        ident = f"T{tid}"
        tid += 1
        ann_lines.append(f"{ident}\t{label} {start} {end}\t{mention}")
        return ident

    # NOTE(review): every mention is searched from the start of the note, so
    # an attribute may be anchored far away from its drug -- this is part of
    # the heuristic nature of the conversion.
    for med in meds:
        if not isinstance(med, dict):
            continue
        raw_name = med.get("medication") or ""
        if not isinstance(raw_name, str):
            continue
        # Drop a trailing "(<digits>)" suffix before searching the note.
        drug_name = re.sub(r"\s*\(\d+\)\s*$", "", normalise_spaces(raw_name))
        drug_tid = add_textbound("Drug", drug_name)
        if drug_tid is None:
            continue

        for key, val in med.items():
            if key == "medication" or not val:
                continue
            if key not in ENTITY_MAP:
                # Unexpected key in the model output: nothing to map it to.
                continue
            ent_type, rel_type = ENTITY_MAP[key]
            vals = val if isinstance(val, list) else [val]
            for v in vals:
                if not isinstance(v, str):
                    # Only strings can be matched against the note text.
                    continue
                ent_tid = add_textbound(ent_type, normalise_spaces(v))
                if ent_tid is None:
                    continue
                ann_lines.append(
                    f"R{rid}\t{rel_type} Arg1:{ent_tid} Arg2:{drug_tid}"
                )
                rid += 1

    # Same directory the validation step reads the converted files from.
    out_path = ANN_DIR / f"{letter_id}.ann"
    out_path.write_text("\n".join(ann_lines), encoding="utf-8")
    return out_path

# Convert every note in BASE_DATA_DIR; the file stem (e.g. "107515") is the
# letter id handed to convert_one_letter().
outs = [
    convert_one_letter(txt_path.stem)
    for txt_path in BASE_DATA_DIR.glob("*.txt")
]

print(f"Wrote {len(outs)} .ann files to {OUT_DIR}")



## Validation

In [None]:
# ---------- helpers ----------------------------------------------------------
def parse_ann_to_counters(path):
    """Count the entity and relation lines of one BRAT ``.ann`` file.

    Returns a pair of Counters:

    - entities, keyed by ``(label, text)`` for every ``T`` line that has
      exactly three tab-separated fields;
    - relations, keyed by ``(relation_type, "", "")`` for every ``R`` line
      that has exactly two tab-separated fields.

    All other lines are ignored.
    """
    entities = Counter()
    relations = Counter()

    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            kind = raw[:1]
            if kind not in ("T", "R"):
                continue

            fields = raw.rstrip("\n").split("\t")
            if kind == "T":
                if len(fields) == 3:
                    # fields[1] is "<label> <start> <end>"; keep the label.
                    header, mention = fields[1], fields[2]
                    entities[(header.split()[0], mention)] += 1
            elif len(fields) == 2:
                # fields[1] is "<type> Arg1:.. Arg2:.."; keep the type.
                relations[(fields[1].split()[0], "", "")] += 1

    return entities, relations

# ---------------------------------------------------------------------------
# Gold counts from the CSV
# ---------------------------------------------------------------------------
gold_ent     = Counter()   # (label, text)          -> count
gold_rel     = Counter()   # (relation type, '', '') -> count
gold_ent_cls = Counter()   # label                  -> count
gold_rel_cls = Counter()   # relation type          -> count
gold_rows_skipped = 0      # rows whose gold cell could not be parsed

# header=0 matches _build_letter_map(): the first line is the header and is
# replaced by `names`.  The previous header=1 treated the *second* line as
# the header and therefore silently dropped the first data row.
# NOTE(review): if the file really carries two header lines, the extra one
# is not parseable and is skipped (and counted) by the loop below.
df = pd.read_csv(CSV_PATH, sep="|", names=["text", "gold", "pred"], header=0)
for g in df["gold"]:
    try:
        # SECURITY NOTE(review): eval() executes whatever expression the CSV
        # cell contains; only run this on trusted result files.
        od = eval(g)
        gold_meds = od.get("medications", [])
    except Exception:
        gold_rows_skipped += 1
        continue

    for m in gold_meds:
        drug = m.get("medication", "").strip()
        if drug:
            gold_ent[("Drug", drug)] += 1
            gold_ent_cls["Drug"] += 1
        for k in ENTITY_KEYS:
            val = m.get(k, '')
            if not val:
                continue
            # Same labels the converter writes into the .ann files.
            lbl, rel_lbl = ENTITY_MAP[k]
            for v in (val if isinstance(val, list) else [val]):
                v = v.strip()
                if not v:
                    continue
                gold_ent[(lbl, v)] += 1
                gold_ent_cls[lbl] += 1
                # Relations are compared per type only, hence the coarse
                # key (relType, '', '') that parse_ann_to_counters() uses.
                gold_rel[(rel_lbl, '', '')] += 1
                gold_rel_cls[rel_lbl] += 1

# NOTE(review): gold texts are only stripped, whereas the converter
# normalises inner whitespace and removes a trailing "(n)" from drug names,
# so some (label, text) keys may differ for that reason alone -- confirm.

# ---------------------------------------------------------------------------
# Parse all converted .ann files
# ---------------------------------------------------------------------------
pred_ent = Counter()
pred_rel = Counter()
for ann in ANN_DIR.glob("*.ann"):
    e_cnt, r_cnt = parse_ann_to_counters(ann)
    pred_ent += e_cnt
    pred_rel += r_cnt

# corpus-level numbers
tot_gold_e = sum(gold_ent.values())
tot_pred_e = sum(pred_ent.values())
tot_gold_r = sum(gold_rel.values())
tot_pred_r = sum(pred_rel.values())

# ---------------------------------------------------------------------------
# Diff counters: Counter subtraction keeps only the positive differences,
# i.e. exactly the surplus on the left-hand side.
# ---------------------------------------------------------------------------
# entities
missing_ent_keys = dict(gold_ent - pred_ent)   # in gold, not (often enough) in ann
add_ent_keys     = dict(pred_ent - gold_ent)   # in ann, not (often enough) in gold

# relations  (coarse keys = (relType,'',''))
counter = 1  # unused in the visible cells; kept in case a later cell reads it
missing_rel_keys = dict(gold_rel - pred_rel)
add_rel_keys     = dict(pred_rel - gold_rel)

# corpus-level numbers
missing_e = sum(missing_ent_keys.values())
add_e     = sum(add_ent_keys.values())
missing_r = sum(missing_rel_keys.values())
add_r     = sum(add_rel_keys.values())

print("\n=== CSV-gold  vs  converted-BRAT (duplicate-aware) ===")
print(f"Entities gold : {tot_gold_e}")
print(f"Entities ann  : {tot_pred_e}")
print(f"  missing     : {missing_e}")
print(f"  additional  : {add_e}\n")
print(f"Relations gold: {tot_gold_r}")
print(f"Relations ann : {tot_pred_r}")
print(f"  missing     : {missing_r}")
print(f"  additional  : {add_r}\n")
if gold_rows_skipped:
    print(f"(skipped {gold_rows_skipped} CSV row(s) with an unparseable gold cell)\n")

In [None]:
# Compare gold-JSON vs converted-BRAT relation counts per class.

# Collapse the coarse (relType, '', '') keys to the bare relation type.
pred_rel_cls = Counter()
for (rel_type, _, _), cnt in pred_rel.items():
    pred_rel_cls[rel_type] += cnt

print("--- Relation counts: CSV-gold  vs  BRAT ---")
print(f"{'Relation':<18}{'gold':>8}{'ann':>8}{'Δ':>8}{'%Δ':>8}")
# The columns are 18 + 4 * 8 = 50 characters wide, so the rule is too.
print("-" * 50)
# Iterate over the union so that relation types occurring only in the
# converted files are listed as well (they used to be left out).
for rel in sorted(set(gold_rel_cls) | set(pred_rel_cls)):
    gold_n = gold_rel_cls.get(rel, 0)
    pred_n = pred_rel_cls.get(rel, 0)
    delta  = pred_n - gold_n
    # No meaningful percentage without gold instances; report 0 in that case.
    pct    = (delta / gold_n * 100) if gold_n else 0
    print(f"{rel:<18}{gold_n:>8}{pred_n:>8}{delta:>8}{pct:>7.1f}%")
