In [1]:
import csv
import os
import re
import pandas as pd

In [2]:
def find_all_mention_spans(csv_file, texts_folder, case_insensitive=False):
    """Locate every occurrence of each listed mention in its document text.

    The CSV is read with ``csv.DictReader`` and must provide ``document_id``
    and ``mention`` columns. For each document id the text is loaded from
    ``<texts_folder>/<document_id>.txt`` and searched for literal
    (regex-escaped) occurrences of each mention.

    Args:
        csv_file: Path to the UTF-8 CSV listing the mentions per document.
        texts_folder: Directory containing one ``<document_id>.txt`` per document.
        case_insensitive: When True, search with ``re.IGNORECASE``.

    Returns:
        A list of dicts with keys ``doc_id``, ``mention``, ``start`` and
        ``end``. ``start``/``end`` are ``match.start()``/``match.end()``
        character offsets into the text. A mention that never occurs yields
        one record with both offsets set to -1. Every
        (doc_id, mention, start, end) combination appears at most once.
        Documents whose text file is missing are skipped after printing a
        warning. Rows whose mention is empty or absent are ignored.
    """
    # Loop-invariant: the flag choice does not depend on the mention.
    flags = re.IGNORECASE if case_insensitive else 0

    # doc_id -> insertion-ordered set of mentions (dict keys keep CSV order
    # while dropping repeated rows, so each mention is searched only once).
    mentions_by_doc = {}
    with open(csv_file, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            doc_mentions = mentions_by_doc.setdefault(row['document_id'], {})
            mention = row['mention']
            # An empty pattern matches at every offset of the text and a
            # short row yields None (which re.escape rejects); neither is a
            # real mention, so skip them.
            if not mention:
                continue
            doc_mentions[mention] = None

    results = []
    for doc_id, mentions in mentions_by_doc.items():
        text_path = os.path.join(texts_folder, f"{doc_id}.txt")
        if not os.path.isfile(text_path):
            print(f"Warning: Text file for document_id {doc_id} not found.")
            continue

        with open(text_path, encoding='utf-8') as f:
            text = f.read()

        for mention in mentions:
            found = False
            for match in re.finditer(re.escape(mention), text, flags):
                found = True
                results.append({
                    'doc_id': doc_id,
                    'mention': mention,
                    'start': match.start(),
                    'end': match.end()
                })

            if not found:
                # Keep a sentinel record so unlocatable mentions stay visible
                # downstream instead of silently disappearing.
                print(f"Warning: Mention '{mention}' not found in document {doc_id}")
                results.append({
                    'doc_id': doc_id,
                    'mention': mention,
                    'start': -1,
                    'end': -1
                })

    return results




In [3]:
# Inputs: the CSV of mentions to locate and the folder of per-document texts.
csv_file = 'D:/prompting/one_shot/one_pred.csv'
texts_folder = 'D:/prompting/ncbi_texts'

# Case-sensitive search for every occurrence of each mention.
spans = find_all_mention_spans(
    csv_file=csv_file,
    texts_folder=texts_folder,
    case_insensitive=False,
)





In [4]:
# Display the span records produced by find_all_mention_spans (a bare trailing
# expression is rendered as the cell's output).
spans

[{'doc_id': '932197',
  'mention': 'hereditary deficiency of the fifth component of complement',
  'start': 156,
  'end': 214},
 {'doc_id': '932197',
  'mention': 'systemic lupus erythematosus',
  'start': 279,
  'end': 307},
 {'doc_id': '932197', 'mention': 'lupus', 'start': 288, 'end': 293},
 {'doc_id': '932197', 'mention': 'lupus', 'start': 528, 'end': 533},
 {'doc_id': '932197', 'mention': 'exacerbations', 'start': 577, 'end': 590},
 {'doc_id': '932197',
  'mention': 'underlying disease',
  'start': 626,
  'end': 644},
 {'doc_id': '932197', 'mention': 'C5 deficiency', 'start': 983, 'end': 996},
 {'doc_id': '932197', 'mention': 'C5D', 'start': 1069, 'end': 1072},
 {'doc_id': '932197', 'mention': 'C5D', 'start': 1257, 'end': 1260},
 {'doc_id': '932197', 'mention': 'C5D', 'start': 1426, 'end': 1429},
 {'doc_id': '932197', 'mention': 'C5D', 'start': 1916, 'end': 1919},
 {'doc_id': '932197', 'mention': 'C5-deficient', 'start': 1055, 'end': 1067},
 {'doc_id': '932197', 'mention': 'C5D se

In [5]:
# One row per span record, with columns doc_id, mention, start, end.
df = pd.DataFrame(spans)
# Write to the exact path the evaluation step reads back. A bare
# 'predictions.tsv' lands in the current working directory, so the later
# read could silently pick up a stale file from D:/prompting/one_shot.
df.to_csv('D:/prompting/one_shot/predictions.tsv', sep="\t", index=False)

In [6]:
# Both tables are tab-separated: the predicted spans and the gold annotations.
pred_path = "D:/prompting/one_shot/predictions.tsv"
gold_path = "D:/OGER/OGER/data/ncbi_annotations/gold_annotations.tsv"

pred_df = pd.read_csv(pred_path, sep="\t")
gold_df = pd.read_csv(gold_path, sep="\t")

In [7]:
def exact_match(pred, gold):
    """Return a truthy value when both spans share doc_id, start and end.

    ``pred`` and ``gold`` are any mappings/rows indexable by the keys
    ``'doc_id'``, ``'start'`` and ``'end'``.
    """
    if pred['doc_id'] != gold['doc_id']:
        return False
    if pred['start'] != gold['start']:
        return False
    return pred['end'] == gold['end']

def partial_match(pred, gold):
    """Return a truthy value when both spans are in the same doc and overlap.

    Spans are treated as half-open ``[start, end)`` ranges: a prediction that
    ends exactly where the gold span starts (or starts exactly where it ends)
    does not count as overlapping.
    """
    if pred['doc_id'] != gold['doc_id']:
        return False
    # Prediction lies entirely before the gold span.
    if pred['end'] <= gold['start']:
        return False
    # Prediction lies entirely after the gold span.
    if pred['start'] >= gold['end']:
        return False
    return True


In [8]:
def evaluate(pred_df, gold_df, match_func):
    """Score predictions against gold rows using greedy one-to-one matching.

    Predictions are visited in row order; each one claims the first gold row
    (in row order) that is still unclaimed and for which
    ``match_func(pred_row, gold_row)`` is truthy. A gold row can be claimed
    by at most one prediction.

    Args:
        pred_df: DataFrame of predictions, one row per predicted span.
        gold_df: DataFrame of gold annotations, one row per span.
        match_func: Callable taking ``(pred_row, gold_row)`` (both pandas
            Series as produced by ``iterrows``) and returning a truthy value
            when the two rows match.

    Returns:
        Tuple ``(precision, recall, f1, fp, fn, tp)``. Each ratio is 0.0
        when its denominator is zero.
    """
    # Build the gold rows once: calling iterrows() inside the prediction loop
    # re-created a Series for every gold row on every prediction.
    gold_rows = [row for _, row in gold_df.iterrows()]

    # Track claimed gold rows by position rather than index label, so a gold
    # frame with repeated labels cannot make distinct rows look claimed.
    matched_gold = set()
    tp = 0

    for _, pred_row in pred_df.iterrows():
        for pos, gold_row in enumerate(gold_rows):
            if pos in matched_gold:
                continue
            if match_func(pred_row, gold_row):
                tp += 1
                matched_gold.add(pos)
                break

    fp = len(pred_df) - tp
    fn = len(gold_df) - tp

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return precision, recall, f1, fp, fn, tp


In [9]:
# Score the predictions under both matching criteria.
exact_scores = evaluate(pred_df, gold_df, exact_match)
partial_scores = evaluate(pred_df, gold_df, partial_match)

# evaluate() returns (precision, recall, f1, fp, fn, tp), in that order.
(exact_precision, exact_recall, exact_f1,
 exact_fp, exact_fn, exact_tp) = exact_scores
(partial_precision, partial_recall, partial_f1,
 partial_fp, partial_fn, partial_tp) = partial_scores

print(f"Exact Match: P={exact_precision:.6f}, R={exact_recall:.6f}, F1={exact_f1:.6f}, fp={exact_fp:.6f}, fn={exact_fn:.6f}, tp={exact_tp:.6f}")
print(f"Partial Match: P={partial_precision:.6f}, R={partial_recall:.6f}, F1={partial_f1:.6f}, fp={partial_fp:.6f}, fn={partial_fn:.6f}, tp={partial_tp:.6f}")

Exact Match: P=0.558431, R=0.696875, F1=0.620019, fp=529.000000, fn=291.000000, tp=669.000000
Partial Match: P=0.661102, R=0.825000, F1=0.734013, fp=406.000000, fn=168.000000, tp=792.000000
