In [1]:
# Make the benchmark helper modules importable.
# NOTE(review): assumes `dataloaders` (imported further down) lives in ../xmen/benchmarks — confirm.
import sys
sys.path.append('../xmen/benchmarks')

# Entity Simplification with OpenAI / GPT-4

In [2]:
from pathlib import Path
import pandas as pd
import datasets
import numpy as np

In [3]:
# Cache directory for SympTEMIST artefacts; the KB file and the linker index are read from here.
base_path = Path.home() / '.cache' / 'xmen' / 'symptemist'

In [13]:
# NOTE(review): `dataloaders` is not referenced directly below; presumably imported for side effects — confirm.
import dataloaders
# The dataset is defined by a local loading script. `datasets` warns that
# trust_remote_code=True will become mandatory for script-based datasets, so
# pass it explicitly.
dataset = datasets.load_dataset(
    '../biomedical/bigbio/hub/hub_repos/symptemist/symptemist.py',
    'symptemist_entities_bigbio_kb',
    trust_remote_code=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [14]:
# Number of annotated entity mentions per split, followed by the overall total.
total_mentions = 0
for split_docs in dataset.values():
    split_mentions = sum(len(doc['entities']) for doc in split_docs)
    print(split_mentions)
    total_mentions += split_mentions
print(total_mentions)

9092
3104
12196


In [None]:
# Load the SympTEMIST knowledge base from the cache directory.
# NOTE(review): `kb` is not referenced again in this notebook — possibly leftover.
from xmen import load_kb
kb = load_kb(base_path / 'symptemist.jsonl')

In [None]:
from xmen.evaluation import error_analysis, evaluate, evaluate_at_k

In [None]:
# Candidate generator built from the index stored under base_path; runs on CPU (cuda=False).
from xmen.linkers import default_ensemble
linker = default_ensemble(base_path / 'index', cuda=False)

In [None]:
# Fixed few-shot (mention, simplified mention) pairs for the GPT prompt; handed to
# GPTSimplifier as `fixed_few_shot_examples` (via `prompt`). English glosses in comments.
SYMPTEMIST_FEW_SHOT_EXAMPLES = [
    # afebrile -> normal body temperature
    ("afebril", "temperatura corporal normal"),
    # induration of the tunica vaginalis -> testicular disorder
    ("induración de la vaginal testicular", "trastorno de testículo"),
    # mamelonated formations inside, at upper-lobe level -> lung lesion
    ("formaciones mamelonadas en su interior a nivel de lóbulo superior", "lesión de pulmón"),
    # dyspnea -> dyspnea (already simple, left unchanged)
    ("disnea", "disnea"),
]

In [None]:
def count_differences(simplified, original):
    """Count entity mentions whose text differs between two parallel datasets.

    Documents, and the entities inside them, are paired positionally with
    ``zip``; surplus items on the longer side are ignored.

    Returns:
        ``(n_changed, n_compared)``.
    """
    entity_pairs = [
        (new_ent, old_ent)
        for new_doc, old_doc in zip(simplified, original)
        for new_ent, old_ent in zip(new_doc['entities'], old_doc['entities'])
    ]
    n_changed = sum(1 for new_ent, old_ent in entity_pairs if new_ent['text'] != old_ent['text'])
    return n_changed, len(entity_pairs)

In [None]:
# OpenAI model used by the GPT simplifier.
MODEL = 'gpt-4-0125-preview'
# Pre-computed for SympTEMIST to save API calls
# (pickled table passed to GPTSimplifier as `table`; presumably maps mentions to
# cached model outputs — verify). Relative path, resolved against the working directory.
table_file = 'lookup_gpt-4-0125-preview_20240214-205237_prompt1.pkl'
# Few-shot examples used in the simplification prompt.
prompt = SYMPTEMIST_FEW_SHOT_EXAMPLES

# Empty default; replaced when the pickled table is loaded in the following cell.
lookup_table = {}

In [None]:
import pickle

# Load the pre-computed lookup table; `with` closes the file handle deterministically.
# SECURITY: pickle.load can execute arbitrary code — only load files you trust.
with open(table_file, 'rb') as table_fh:
    lookup_table = pickle.load(table_fh)

In [None]:
from xmen.data.simplification import GPTSimplifier, EntitySimplification, SimplifierWrapper

In [None]:
import os

# GPT-based mention simplifier. The API key comes from the OPENAI_API_KEY
# environment variable instead of being pasted into the notebook. The old
# placeholder is the fallback, so behaviour is unchanged when the variable is unset.
# Cached responses are supplied through `table`.
text_simplifier = GPTSimplifier(
    model=MODEL,
    open_ai_api_key=os.environ.get('OPENAI_API_KEY', 'insert_api_key'),
    fixed_few_shot_examples=prompt,
    table=lookup_table,
)

In [None]:
# Applies `text_simplifier` to the entity mentions of a dataset (used via `transform_batch` below).
# NOTE(review): exact meaning of set_long_form=True is defined in xmen — confirm.
simplifier = EntitySimplification(text_simplifier, set_long_form=True)

In [None]:
# Inspect the full prompt that would be built for an example mention.
print(text_simplifier.prompt_template.format('aumento de densidad en lóbulo inferior'))

In [None]:
# Simplify one example mention (presumably answered from `lookup_table` when cached — verify).
text_simplifier.simplify('aumento de densidad en lóbulo inferior')

# Generate Candidates

In [None]:
# Top-64 candidates for every split, using the original (unsimplified) mentions.
candidates = linker.predict_batch(dataset, top_k=64, batch_size=128)

In [None]:
# Recall@k on the train split without simplification.
_ = evaluate_at_k(dataset['train'], candidates['train'])

In [None]:
# Recall@k on the validation split without simplification. The dataset was
# generated with train and test splits only, so guard the lookup instead of
# failing with a KeyError.
if 'validation' in dataset:
    _ = evaluate_at_k(dataset['validation'], candidates['validation'])
else:
    print("No 'validation' split; skipping validation evaluation.")

In [None]:
# Recall@k on the test split without simplification.
_ = evaluate_at_k(dataset['test'], candidates['test'])

In [None]:
# Simplify every mention (no confidence threshold) and regenerate top-64 candidates.
# The cutoff sweep below compares these against `candidates`.
simplified_ds = simplifier.transform_batch(candidates)
simplified_candidates = linker.predict_batch(simplified_ds, top_k=64, batch_size=128)

# Determine Optimal Cutoff

In [None]:
# The confidence cutoff is tuned on the train split.
# eval_candidates: candidates for the original mentions (read by `select_candidates`);
# eval_ds: gold annotations; eval_simplified_ds: candidates for the simplified mentions.
SPLIT = 'train'
eval_candidates = candidates[SPLIT]
eval_ds = dataset[SPLIT]
eval_simplified_ds = simplified_candidates[SPLIT]

def select_candidates(d, idx, filter_fn=lambda e: True, reference=None):
    """Merge original and simplified-mention candidates for one document.

    For each entity position, keep the original entry ``ec`` (taken from
    ``reference[idx]``) when ``filter_fn(ec)`` is true; otherwise use the
    corresponding entry of ``d``. Entities are paired positionally with ``zip``.

    Args:
        d: row from the simplified-candidate dataset (has ``'entities'``).
        idx: row index of ``d``; selects the matching row of ``reference``.
        filter_fn: predicate evaluated on the original candidate entity.
        reference: indexable rows of original candidates. Defaults to the
            notebook-global ``eval_candidates``, looked up at call time, which
            preserves the previous behaviour while allowing explicit input.

    Returns:
        ``{'entities': [...]}``, suitable as a ``Dataset.map`` result.
    """
    if reference is None:
        reference = eval_candidates
    original_entities = reference[idx]['entities']
    merged = [ec if filter_fn(ec) else ei for ei, ec in zip(d['entities'], original_entities)]
    return {'entities': merged}

In [None]:
# Recall@k of the unsimplified train candidates; reference point for the deltas plotted below.
baseline = evaluate_at_k(eval_ds, eval_candidates, silent=True)

In [None]:
cutoff_eval = []

# Sweep the confidence threshold. np.arange with a float step accumulates drift
# (values such as 0.6000000000000001), which would silently exclude scores equal
# to the labelled threshold; rounding keeps the same grid of cutoffs, exactly.
for cutoff in np.round(np.arange(0.55, 1.05, 0.05), 2):
    print(cutoff)
    # Keep the original candidates when their top score reaches the threshold,
    # otherwise take the simplified mention's candidates. `cutoff` is bound as a
    # default argument so the predicate never tracks later changes of the loop variable.
    fn = lambda e, cutoff=cutoff: e['normalized'][0]['score'] >= cutoff
    best_candidates = eval_simplified_ds.map(lambda d, i: select_candidates(d, i, fn), with_indices=True)
    diffs = count_differences(best_candidates, eval_candidates)
    print(diffs)
    eval_res = evaluate_at_k(eval_ds, best_candidates, silent=True)
    cutoff_eval_i = {'cutoff' : cutoff, 'num_changed' : diffs[0], 'num_all' : diffs[1]}
    # One column per k holding strict recall@k.
    for i, es in eval_res.items():
        cutoff_eval_i[i] = es['strict']['recall']
    cutoff_eval.append(cutoff_eval_i)

In [None]:
# Sweep results as a DataFrame (one row per cutoff; integer-named columns hold strict recall@k), rounded to 3 decimals.
cutoff_eval = pd.DataFrame(cutoff_eval).round(3)

In [None]:
# Display the cutoff sweep table.
cutoff_eval

In [None]:
from matplotlib import pyplot as plt
import matplotlib

matplotlib.rcParams.update({'font.size': 14, 'font.family' : 'serif'})

fig, axs = plt.subplots(2, 1, figsize=(10,5), sharex=True, gridspec_kw={'height_ratios': [2, 1]}, squeeze=True)

# Top panel: change in strict recall@1 / recall@64 relative to the unsimplified baseline.
ax1 = axs[0]

recall_by_cutoff = cutoff_eval.set_index('cutoff')
x_range = np.arange(0, len(cutoff_eval))
width = 0.3

# Raw strings: '\D' is an invalid escape sequence in a normal string literal
# (SyntaxWarning on Python 3.12+); the rendered label text is unchanged.
ax1.bar(x_range - 0.5 * width,
        (recall_by_cutoff[1] - baseline[1]['strict']['recall']).values,
        width, color='darkblue', linewidth=0.5, edgecolor='black', label=r'$\Delta$ Recall@1 (pp.)')
ax1.bar(x_range + 0.5 * width,
        (recall_by_cutoff[64] - baseline[64]['strict']['recall']).values,
        width, color='lightblue', linewidth=0.5, edgecolor='black', label=r'$\Delta$ Recall@64 (pp.)')

ax1.grid(axis='y')
ax1.set_yticks(np.arange(-0.005, 0.03, 0.005))
# Tick labels in percentage points, with an explicit '+' for gains.
ax1.set_yticklabels([f'+{t * 100:.1f}' if t > 0 else f'{t * 100:.1f}' for t in ax1.get_yticks()])

ax1.set_xticks(x_range)
ax1.set_xticklabels(cutoff_eval.cutoff.values)

# Bottom panel: share of mentions replaced by their simplification at each cutoff.
ax0 = axs[1]
ax0.grid(axis='y')
ax0.plot(x_range, cutoff_eval.num_changed / cutoff_eval.num_all, color='red', label='Simplified Mentions (%)', marker='o')
ax0.set_yticks(np.arange(0,1.02,0.2))
ax0.set_yticklabels([f"{int(t * 100)}%" if t >= 0 and t <=1 else "" for t in ax0.get_yticks()])

ax0.set_xlabel('Confidence Threshold for Entity Simplification')
ax0.legend()

ax0.set_xticks(x_range)
ax0.set_xticklabels(cutoff_eval.cutoff.values)

ax1.legend()

#plt.savefig('diff_cutoff.png', dpi=1200, bbox_inches='tight')

plt.show()

In [None]:
# Threshold chosen from the sweep above: mentions whose unsimplified top candidate
# scores below it get simplified. NOTE(review): hard-coded by inspection — revisit if the sweep changes.
best_cutoff = 0.85

In [None]:
# we use a SimplifierWrapper so that the original spans can be restored after candidate generation
# Only mentions whose current top candidate scores below best_cutoff are simplified
# (filter_fn is presumably evaluated on each entity of `candidates` — verify in xmen).
simplifier_wrapper = SimplifierWrapper(
    linker,
    text_simplifier, 
    filter_fn=lambda c: c['normalized'][0]['score'] < best_cutoff, 
    set_long_form=True
)
simplified_candidates_cutoff = simplifier_wrapper.predict_batch(candidates, top_k=64, batch_size=128)

In [None]:
# (changed mentions, total mentions) per split relative to the original dataset.
# Only visit splits that exist: the dataset was generated with train and test
# splits only, so a hard-coded 'validation' lookup raises KeyError.
for split in ('train', 'validation', 'test'):
    if split in simplified_candidates_cutoff:
        print(count_differences(simplified_candidates_cutoff[split], dataset[split]))

In [None]:
# Validation recall@k after thresholded simplification; skipped when the dataset
# has no validation split (only train/test were generated when loading it).
if 'validation' in dataset:
    _ = evaluate_at_k(dataset['validation'], simplified_candidates_cutoff['validation'])
else:
    print("No 'validation' split; skipping validation evaluation.")

In [None]:
# Test recall@k after thresholded simplification.
_ = evaluate_at_k(dataset['test'], simplified_candidates_cutoff['test'])

# Save

In [None]:
# Persist the thresholded candidates (relative path, resolved against the working directory).
simplified_candidates_cutoff.save_to_disk('candidates_simplified_cutoff')

# Analysis

In [None]:
from xmen.evaluation import error_analysis, evaluate_at_k

In [None]:
# Test recall@k without simplification, repeated here for side-by-side comparison.
_ = evaluate_at_k(dataset['test'], candidates['test'])

In [None]:
# Test recall@k with thresholded simplification.
_ = evaluate_at_k(dataset['test'], simplified_candidates_cutoff['test'])

In [None]:
# Simplify every mention of the original dataset (no threshold).
# Rebinds `simplified_ds`, which earlier held the simplified *candidates*.
simplified_ds = simplifier.transform_batch(dataset)

In [None]:
# Per-mention error analysis on the test split: baseline vs. thresholded simplification.
ea_before_simple_test = error_analysis(dataset['test'], candidates['test'])
ea_after_simple_test = error_analysis(simplified_ds['test'], simplified_candidates_cutoff['test'])

In [None]:
# Error analysis when every mention is simplified (no threshold).
ea_after_no_thresh_test = error_analysis(simplified_ds['test'], simplified_candidates['test'])

In [None]:
# Switch the evaluation globals to the test split. `select_candidates`, defined
# earlier in this notebook, reads `eval_candidates` when called, so it picks up
# these reassignments; an identical re-definition here is unnecessary.
SPLIT = 'test'
eval_candidates = candidates[SPLIT]
eval_ds = dataset[SPLIT]
eval_simplified_ds = simplified_candidates[SPLIT]

In [None]:
# Apply the chosen threshold on the test split: keep the original candidates where
# their top score is >= best_cutoff, otherwise use the simplified-mention candidates.
fn = lambda e: e['normalized'][0]['score'] >= best_cutoff
best_candidates = eval_simplified_ds.map(lambda d, i: select_candidates(d, i, fn), with_indices=True)

In [None]:
def get_len(t):
    """Return the number of tokens in ``t`` when split on single spaces.

    Consecutive spaces produce empty tokens, which are counted as well.
    """
    return len(t.split(' '))


def _ratio(num, den):
    """``num / den``, or NaN when ``den`` is 0 (avoids divide-by-zero warnings for empty bins)."""
    return num / den if den else float('nan')


def get_stats(ea_df, name, ds, bin_starts=None, bin_width=0.1):
    """Summarise an error-analysis frame: mention lengths, recall, and recall per top-score bin.

    Args:
        ea_df: error-analysis DataFrame with ``pred_index`` (as used in this
            notebook: 0 = ground truth ranked first, -1 = not among candidates)
            and ``pred_top_score`` columns.
        name: label stored under ``'name'``.
        ds: iterable of documents with ``'entities'``; the first ``'text'``
            segment of each entity is used for word lengths.
        bin_starts: lower edges of the top-score bins (default 0.5, 0.6, ..., 1.0).
        bin_width: width of each half-open bin ``[begin, begin + bin_width)``.

    Returns:
        dict of statistics. Bin keys are prefixed ``"{begin}-{end}"``. Edges are
        rounded so keys read e.g. ``"0.7-0.8"`` rather than
        ``"0.7-0.7999999999999999"``. Recall of an empty bin is NaN.
    """
    if bin_starts is None:
        bin_starts = np.linspace(0.5, 1.0, 6)
    res = {'name': name}
    word_lens = pd.Series([get_len(e['text'][0]) for d in ds for e in d['entities']])
    res['max_length'] = int(word_lens.max())
    res['mean_length'] = word_lens.mean().round(2)
    res['recall_1'] = _ratio((ea_df.pred_index == 0).sum(), len(ea_df))
    res['recall_64'] = _ratio((ea_df.pred_index >= 0).sum(), len(ea_df))
    res['mean_score'] = ea_df.pred_top_score.mean()
    for start in bin_starts:
        # Round the edges: float drift (0.7 + 0.1 == 0.7999999999999999) would
        # otherwise leak into the keys and leave tiny gaps between adjacent bins.
        begin = round(float(start), 6)
        end = round(begin + bin_width, 6)
        sub_ea = ea_df[(ea_df.pred_top_score >= begin) & (ea_df.pred_top_score < end)]
        prefix = f"{begin}-{end}"
        res[f'{prefix}_len'] = len(sub_ea)
        res[f'{prefix}_recall_1'] = _ratio((sub_ea.pred_index == 0).sum(), len(sub_ea))
        res[f'{prefix}_recall_64'] = _ratio((sub_ea.pred_index >= 0).sum(), len(sub_ea))
    return res

In [None]:
# Test-split summary side by side: no simplification, thresholded simplification,
# and simplification of every mention (statistics as rows after the transpose).
df = pd.DataFrame([
    get_stats(ea_before_simple_test, 'no_simpl', candidates['test']),
    get_stats(ea_after_simple_test, 'simpl_thresh', best_candidates),
    get_stats(ea_after_no_thresh_test, 'simpl_no_thresh', simplified_ds['test']),
]).set_index('name').T
# Net change of thresholded simplification over the baseline.
df['delta'] = df.simpl_thresh - df.no_simpl
df.round(3)[['no_simpl', 'simpl_thresh', 'delta', 'simpl_no_thresh']]

In [None]:
# Mentions whose ground truth was not among the candidates before simplification
# but is ranked first afterwards (first 5 rows). Uses the test-split error analyses
# computed above; `ea_before_simpl` / `ea_after_simple` are not defined in this notebook.
ea_after_simple_test[(ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index == 0)].head(5)

In [None]:
# Mean `_word_len` before vs. after simplification for mentions ranked first in both
# runs (test-split error analyses; the previously used names were undefined).
both_correct = (ea_before_simple_test.pred_index == 0) & (ea_after_simple_test.pred_index == 0)
print(ea_before_simple_test[both_correct]._word_len.mean(), ea_after_simple_test[both_correct]._word_len.mean())

In [None]:
# Mean `_word_len` before vs. after for mentions whose ground truth only appears among
# the candidates after simplification (test-split error analyses).
better_recall = (ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index >= 0)
print(ea_before_simple_test[better_recall]._word_len.mean(), ea_after_simple_test[better_recall]._word_len.mean())

In [None]:
# Same comparison restricted to mentions whose ground truth becomes rank 1 after simplification.
better_recall_1 = (ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index == 0)
print(ea_before_simple_test[better_recall_1]._word_len.mean(), ea_after_simple_test[better_recall_1]._word_len.mean())

In [None]:
# Mean `_word_len` (after simplification) of mentions that went from "not found" to rank 1.
ea_after_simple_test[(ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index == 0)]._word_len.mean()

In [None]:
# Mean `_word_len` (after simplification) of mentions that went from "not found" to rank 1.
# NOTE(review): duplicate of the previous cell; perhaps ea_before_simple_test was intended here.
ea_after_simple_test[(ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index == 0)]._word_len.mean()

## Error Analysis by Mention Length

In [None]:
# Per mention-length counts comparing the test-split error analyses before
# (ea_before_simple_test) and after (ea_after_simple_test) thresholded
# simplification. Buckets use the pre-simplification `_word_len`.
index = []
labels = []
match_any = []
match_lost = []
better_ranking = []
worse_ranking = []
n_candidates = []

# Lengths 1..end get their own bucket; all longer mentions share a final ">end" bucket.
end = 10

def _length_mask(i, open_ended=False):
    """Mask of mentions with `_word_len` == i, or >= i when open_ended."""
    word_len = ea_before_simple_test._word_len
    return (word_len >= i) if open_ended else (word_len == i)

def get_new_candidates(i, open_ended=False):
    """Ground truth absent from the candidates before simplification, present after."""
    return ea_after_simple_test[_length_mask(i, open_ended) & (ea_before_simple_test.pred_index == -1) & (ea_after_simple_test.pred_index != -1)]

def get_better_ranking(i, open_ended=False):
    """Ground truth found after simplification with a smaller pred_index than before."""
    return ea_after_simple_test[_length_mask(i, open_ended) & (ea_after_simple_test.pred_index != -1) & (ea_after_simple_test.pred_index < ea_before_simple_test.pred_index)]

def get_worse_ranking(i, open_ended=False):
    """Ground truth found before simplification with a larger pred_index after."""
    return ea_after_simple_test[_length_mask(i, open_ended) & (ea_before_simple_test.pred_index != -1) & (ea_after_simple_test.pred_index > ea_before_simple_test.pred_index)]

def get_lost_candidates(i, open_ended=False):
    """Ground truth present before simplification, absent after."""
    return ea_after_simple_test[_length_mask(i, open_ended) & (ea_before_simple_test.pred_index != -1) & (ea_after_simple_test.pred_index == -1)]

def get_n_candidates(i, open_ended=False):
    """All mentions in the bucket."""
    return ea_after_simple_test[_length_mask(i, open_ended)]

def _append_bucket(i, open_ended=False):
    """Append one bucket's counts; losses are stored as negative numbers for plotting."""
    n_candidates.append(len(get_n_candidates(i, open_ended)))
    index.append(i)
    match_any.append(len(get_new_candidates(i, open_ended)))
    better_ranking.append(len(get_better_ranking(i, open_ended)))
    match_lost.append(-len(get_lost_candidates(i, open_ended)))
    worse_ranking.append(-len(get_worse_ranking(i, open_ended)))

for i in range(1, end + 1):
    _append_bucket(i)
    labels.append(str(i))

# Final bucket: every mention longer than `end` (>= end + 1), matching its ">end" label.
labels.append(f'>{end}')
i = end + 1
_append_bucket(i, open_ended=True)

In [None]:
def get_rel(m, totals=None):
    """Divide each count in ``m`` by the matching bucket size.

    Args:
        m: per-bucket counts (list or array).
        totals: per-bucket sizes; defaults to the notebook-global ``n_candidates``,
            looked up at call time.

    Returns:
        List of ratios. Buckets with zero mentions yield NaN instead of raising
        ZeroDivisionError.
    """
    if totals is None:
        totals = n_candidates
    return [mi / ci if ci else float('nan') for mi, ci in zip(m, totals)]

In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.font_manager as font_manager


fig, axs = plt.subplots(3, 1, figsize=(8,12), gridspec_kw={'height_ratios': [0.4, 1, 1]})
font_size_lg = 14
font_size_sm = 12

matplotlib.rcParams.update({'font.size': font_size_lg, 'font.family' : 'serif'})

width = 0.2
x_positions = np.array(index)
bar_style = dict(width=width, edgecolor='black', linewidth=0.5)


def _draw_gain_loss(ax, gains, losses, gain_label, loss_label):
    """Draw relative gain, loss and net-difference bars side by side at each mention length."""
    net = np.array(gains) + np.array(losses)
    ax.bar(x_positions - width, get_rel(gains), label=gain_label, color='lightgreen', **bar_style)
    ax.bar(x_positions, get_rel(losses), label=loss_label, color='wheat', **bar_style)
    ax.bar(x_positions + width, get_rel(net), label='Difference', color='blue', **bar_style)


# (a) Number of mentions per length bucket.
top = axs[0]
top.bar(x_positions, n_candidates, label='No. Mentions', color='grey', **bar_style)
top.grid(axis='y')
top.set_title('(a) Total number of mentions', size=font_size_lg)
top.set_yticks(range(0,3000,500))

# (b) Recall@64 gained / lost through simplification, relative to bucket size.
middle = axs[1]
middle.grid(axis='y')
_draw_gain_loss(middle, match_any, match_lost, '↑ Recall', '↓ Recall')
y_ticks_b = np.arange(-0.1, 0.18, 0.05)
middle.set_yticks(y_ticks_b)
middle.set_yticklabels([f'{"+" if t > 0 else ""}{round(t * 100)}pp.' for t in y_ticks_b])
middle.set_ylim(-0.12,0.14)
middle.set_title('(b) Difference in recall@64 due to entity simplification', size=font_size_lg)
middle.legend(loc='upper left', ncol=3, fontsize=font_size_sm)

# (c) Share of mentions whose ground-truth rank improved / worsened.
bottom = axs[2]
_draw_gain_loss(bottom, better_ranking, worse_ranking, '↑ Ranking', '↓ Ranking')
y_ticks_c = np.arange(-0.2,0.3,0.05)
bottom.set_yticks(y_ticks_c)
# Magnitudes only: direction is conveyed by the legend.
bottom.set_yticklabels([f'{abs(round(t * 100))}%' for t in y_ticks_c])
bottom.legend(loc='upper left', ncol=3, fontsize=font_size_sm)
bottom.set_title('(c) Proportion of mentions with increased / decreased ranking\nof the ground-truth concept due to entity simplification', size=font_size_lg)
bottom.grid(axis='y')
bottom.set_xlabel('Mention length (tokens)')

# Align all panels on the x-range of panel (a) and label the length buckets.
for ax in axs:
    ax.set_xlim(axs[0].get_xlim())
    ax.set_xticks(index)
    ax.set_xticklabels(labels)

fig.tight_layout()

#plt.savefig('gain_vs_length.png', dpi=1200, bbox_inches='tight')
plt.show()