In [1]:
import sys
sys.path.append('../xmen/benchmarks')

# Entity Simplification with OpenAI / GPT-4

In [27]:
from pathlib import Path
import pandas as pd
import datasets
import numpy as np

In [3]:
base_path = Path.home() / '.cache' / 'xmen' / 'symptemist'

In [4]:
import dataloaders
dataset = dataloaders.load_symptemist()[0]

In [5]:
from xmen import load_kb
kb = load_kb(base_path / 'symptemist.jsonl')

In [6]:
from xmen.evaluation import error_analysis, evaluate, evaluate_at_k

In [7]:
from xmen.linkers import default_ensemble
linker = default_ensemble(base_path / 'index')

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [8]:
SYMPTEMIST_FEW_SHOT_EXAMPLES = [
    ("afebril", "temperatura corporal normal"),
    ("induración de la vaginal testicular", "trastorno de testículo"),
    ("formaciones mamelonadas en su interior a nivel de lóbulo superior", "lesión de pulmón"),
    ("disnea", "disnea"),
]

In [30]:
def count_differences(simplified, original):
    cnt_all = 0
    cnt_diff = 0
    for ds, do in zip(simplified, original):
        for es, eo in zip(ds['entities'], do['entities']):
            cnt_all += 1
            if es['text'] != eo['text']:
                cnt_diff +=1
    return cnt_diff, cnt_all

In [10]:
MODEL = 'gpt-4-0125-preview'
table_file = 'lookup_gpt-4-0125-preview_20240214-205237_prompt1.pkl'
prompt = SYMPTEMIST_FEW_SHOT_EXAMPLES

lookup_table = {}

In [11]:
import pickle
lookup_table = pickle.load(open(table_file, 'rb')) 

In [14]:
from xmen.data.simplification import GPTSimplifier, EntitySimplification, SimplifierWrapper

In [15]:
text_simplifier = GPTSimplifier(
    model=MODEL, 
    open_ai_api_key="insert_api_key", 
    fixed_few_shot_examples = prompt,
    table=lookup_table
)

In [16]:
simplifier = EntitySimplification(text_simplifier, set_long_form=True)

In [17]:
print(text_simplifier.prompt_template.format('aumento de densidad en lóbulo inferior'))

Your task is to simplify expressions, such that the simplified version is more suitable for looking up concepts in a medical terminology. If the input is already simple enough, just return the input. 

Here are some examples:
Input: ```afebril```
Simplified: ```temperatura corporal normal```

Input: ```induración de la vaginal testicular```
Simplified: ```trastorno de testículo```

Input: ```formaciones mamelonadas en su interior a nivel de lóbulo superior```
Simplified: ```lesión de pulmón```

Input: ```disnea```
Simplified: ```disnea```

Input: ```aumento de densidad en lóbulo inferior```
Simplified:


In [18]:
text_simplifier.simplify('aumento de densidad en lóbulo inferior')

'lesión de pulmón'

# Generate Candidates

In [19]:
candidates = linker.predict_batch(dataset, top_k=64, batch_size=128)

Map:   0%|          | 0/608 [00:00<?, ? examples/s]

Map:   0%|          | 0/246 [00:00<?, ? examples/s]

Map:   0%|          | 0/136 [00:00<?, ? examples/s]

In [20]:
_ = evaluate_at_k(dataset['train'], candidates['train'])

Recall@1 0.4388426128890837
Recall@2 0.5610112523746895
Recall@4 0.6226801110623995
Recall@8 0.6786497150372643
Recall@16 0.7219055969603975
Recall@32 0.7558088557650153
Recall@64 0.7837205903843344


In [21]:
_ = evaluate_at_k(dataset['validation'], candidates['validation'])

Recall@1 0.4647592463363573
Recall@2 0.5757152826238661
Recall@4 0.6427076064200977
Recall@8 0.6929518492672715
Recall@16 0.7264480111653873
Recall@32 0.7571528262386602
Recall@64 0.782274947662247


In [22]:
_ = evaluate_at_k(dataset['test'], candidates['test'])

Recall@1 0.4474885844748858
Recall@2 0.5602388479100808
Recall@4 0.6199508254302775
Recall@8 0.6656129258868985
Recall@16 0.7007376185458377
Recall@32 0.7323498419388831
Recall@64 0.7632595714787496


In [24]:
simplified_candidates = simplifier.transform_batch(candidates)

Map:   0%|          | 0/608 [00:00<?, ? examples/s]

Map:   0%|          | 0/246 [00:00<?, ? examples/s]

Map:   0%|          | 0/136 [00:00<?, ? examples/s]

# Determine Optimal Cutoff

In [28]:
SPLIT = 'train'
eval_candidates = candidates[SPLIT]
eval_ds = dataset[SPLIT]
eval_simplified_ds = simplified_candidates[SPLIT]

def select_candidates(d, idx, filter_fn=lambda e: True):
    result = []
    for ei, ec, eg in zip(d['entities'], eval_candidates[idx]['entities'], eval_ds[idx]['entities']):
        if filter_fn(ec):
            result.append(ec)
        else:
            result.append(ei)
    return { 'entities' : result }

In [31]:
cutoff_eval = []

for cutoff in np.arange(0.55, 1.05, 0.05):
    print(cutoff)
    fn = lambda e: e['normalized'][0]['score'] >= cutoff
    best_candidates = eval_simplified_ds.map(lambda d, i: select_candidates(d, i, fn), with_indices=True)
    diffs = count_differences(best_candidates, eval_candidates)
    print(diffs)
    eval_res = evaluate_at_k(eval_ds, best_candidates, silent=True)
    cutoff_eval_i = {'cutoff' : cutoff, 'num_changed' : diffs[0], 'num_all' : diffs[1]}
    for i, es in eval_res.items():
        cutoff_eval_i[i] = es['strict']['recall']
    cutoff_eval.append(cutoff_eval_i)

0.55
(9, 6843)
0.6000000000000001


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(104, 6843)
0.6500000000000001


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(469, 6843)
0.7000000000000002


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(1091, 6843)
0.7500000000000002


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(1871, 6843)
0.8000000000000003


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(2586, 6843)
0.8500000000000003


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(3239, 6843)
0.9000000000000004


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(3864, 6843)
0.9500000000000004


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(4357, 6843)
1.0000000000000004


Map:   0%|          | 0/608 [00:00<?, ? examples/s]

(5841, 6843)


In [None]:
cutoff_eval = pd.DataFrame(cutoff_eval).round(3)

In [None]:
baseline = evaluate_at_k(eval_ds, evalcandidates, silent=True)

In [None]:
from matplotlib import pyplot as plt
import matplotlib

matplotlib.rcParams.update({'font.size': 14, 'font.family' : 'Times New Roman'})

fig, axs = plt.subplots(2, 1, figsize=(10,5), sharex=True, gridspec_kw={'height_ratios': [2, 1]}, squeeze=True)

ax1 = axs[0]

x_range = np.arange(0, len(cutoff_eval))
width = 0.3       

ax1.bar(x_range - 0.5 * width, 
        (cutoff_eval.set_index('cutoff')[1] - baseline[1]['strict']['recall']).values, 
        width, color='darkblue', linewidth=0.5, edgecolor='black', label='$\Delta$ Recall@1 (pp.)')
ax1.bar(x_range + 0.5 * width, 
        (cutoff_eval.set_index('cutoff')[64] - baseline[64]['strict']['recall']).values, 
        width, color='lightblue', linewidth=0.5, edgecolor='black', label='$\Delta$ Recall@64 (pp.)')

ax1.grid(axis='y')
ax1.set_yticks(np.arange(-0.005, 0.03, 0.005))
ax1.set_yticklabels([f'+{t * 100:.1f}' if t > 0 else f'{t * 100:.1f}' for t in ax1.get_yticks()])

ax1.set_xticks(x_range)
ax1.set_xticklabels(cutoff_eval.cutoff.values)

ax0 = axs[1]
ax0.grid(axis='y')
ax0.plot(x_range, cutoff_eval.num_changed / cutoff_eval.num_all, color='red', label='Simplified Mentions (%)', marker='o')
ax0.set_yticks(np.arange(0,1.02,0.2))
ax0.set_yticklabels([f"{int(t * 100)}%" if t >= 0 and t <=1 else "" for t in ax0.get_yticks()])

ax0.set_xlabel('Confidence Threshold for Entity Simplication')
ax0.legend()

ax0.set_xticks(x_range)
ax0.set_xticklabels(cutoff_eval.cutoff.values)

ax1.legend()

#plt.savefig('diff_cutoff.png', dpi=1200, bbox_inches='tight')

plt.show()

In [None]:
best_cutoff = 0.85

In [None]:
simplifier = SimplifierWrapper(
    linker,
    text_simplifier, 
    filter_fn=lambda c: c['normalized'][0]['score'] < best_cutoff, 
    set_long_form=True
)
simplified_candidates_cutoff = simplifier.predict_batch(candidates, top_k=64, batch_size=128)

In [None]:
print(count_differences(simplified_candidates_cutoff['train'], dataset['train']))
print(count_differences(simplified_candidates_cutoff['validation'], dataset['validation']))
print(count_differences(simplified_candidates_cutoff['test'], dataset['test']))

In [None]:
_ = evaluate_at_k(dataset['validation'], simplified_candidates_cutoff['validation'])

In [None]:
_ = evaluate_at_k(dataset['test'], simplified_candidates_cutoff['test'])

# Save

In [None]:
simplified_candidates_cutoff.save_to_disk('candidates_simplified_cutoff')

# Analysis

In [None]:
from xmen.evaluation import error_analysis, evaluate_at_k

In [None]:
simplified_candidates_cutoff = datasets.load_from_disk('candidates_simplified_cutoff')

In [None]:
_ = evaluate_at_k(dataset['test'], baseline_candidates['test'])

In [None]:
_ = evaluate_at_k(dataset['test'], simplified_candidates_cutoff['test'])

In [None]:
simplified_ds = simplifier.transform_batch(dataset)

In [None]:
ea_before_simpl = error_analysis(dataset['train'], baseline_candidates['train'])
ea_after_simple = error_analysis(simplified_ds['train'], simplified_candidates_cutoff['train'])

In [None]:
ds = dataset['train']
cands = simplified_candidates_cutoff['train']

def show_example(i, return_doc=False):
    # print(ents[i])
    eai = ea_before_simpl.loc[i]
    eai_after = ea_after_simple.loc[i]
    d = cands.filter(lambda d: d['document_id'] == eai.document_id)
    assert len(d) == 1
    d = d[0]
    ent = [e for e in d['entities'] if e['offsets'][0][0] == eai.pred_start and e['offsets'][0][1] == eai.pred_end]
    assert len(ent) == 1
    ent = ent[0]
    #print(ent)
    
    print('## Before', eai.gt_text, eai.pred_top_score, '-->', eai.pred_index)    
    print('## Simplified', eai_after.gt_text, eai_after.pred_top_score, '-->', eai_after.pred_index)
    print(eai.gold_concept['db_id'])
    print(kb.cui_to_entity[eai.gold_concept['db_id']])
    
    if return_doc:
        return d, ent

In [None]:
index = []
labels = []
match_any = []
match_lost = []
better_ranking = []
worse_ranking = []
n_candidates = []

end = 10

def get_new_candidates(i):
    return ea_after_simple[(ea_before_simpl._word_len == i) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index != -1)]

def get_better_ranking(i):
    return ea_after_simple[(ea_before_simpl._word_len == i) & (ea_after_simple.pred_index != -1) & (ea_after_simple.pred_index < ea_before_simpl.pred_index)]

def get_worse_ranking(i):
    return ea_after_simple[(ea_before_simpl._word_len == i) & (ea_before_simpl.pred_index != -1) & (ea_after_simple.pred_index > ea_before_simpl.pred_index)]

def get_lost_candidates(i):
    return ea_after_simple[(ea_before_simpl._word_len == i) & (ea_before_simpl.pred_index != -1) & (ea_after_simple.pred_index == -1)]

def get_n_candidates(i):
    return ea_after_simple[(ea_before_simpl._word_len == i)]

for i in range(1, end + 1):
    #ea0 = ea_after_simple[(ea_before_simpl._word_len == i) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index == 0)]
    n_candidates_i = get_n_candidates(i)
    eaany = get_new_candidates(i)
    eabetter = get_better_ranking(i)
    ealost = get_lost_candidates(i)
    eaworse = get_worse_ranking(i)
    
    n_candidates.append(len(n_candidates_i))
    index.append(i)
    labels.append(str(i))
    match_any.append(len(eaany))
    better_ranking.append(len(eabetter))
    match_lost.append(-len(ealost))
    worse_ranking.append(-len(eaworse))
    
labels.append(str(f'>{ i }'))

i += 1

#ea0 = ea_after_simple[(ea_before_simpl._word_len >= i) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index == 0)]
n_candidates_i = get_n_candidates(i)
eaany = get_new_candidates(i)
eabetter = get_better_ranking(i)
ealost = get_lost_candidates(i)
eaworse = get_worse_ranking(i)

n_candidates.append(len(n_candidates_i))
index.append(i)
match_any.append(len(eaany))
better_ranking.append(len(eabetter))
match_lost.append(-len(ealost))
worse_ranking.append(-len(eaworse))

In [None]:
def get_rel(m):
    return [mi / ci for mi, ci in zip(m, n_candidates)]

In [None]:
n_candidates

In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.font_manager as font_manager


fig, axs = plt.subplots(3, 1, figsize=(8,12), gridspec_kw={'height_ratios': [0.4, 1, 1]})
font_size_lg = 14
font_size_sm = 12

matplotlib.rcParams.update({'font.size': font_size_lg, 'font.family' : 'Times New Roman'})

width = 0.2

ax0 = axs[0]
ax0.bar(np.array(index), n_candidates, width=width, label='No. Mentions', color='grey', edgecolor='black', linewidth=0.5)
ax0.grid(axis='y')
#ax0.set_ylabel('Number of mentions')
ax0.set_title('(a) Total number of mentions', size=font_size_lg)
ax0.set_yticks(range(0,3000,500))

ax1 = axs[1]

ax1.grid(axis='y')

#ax.bar(np.array(index) - width, match_0, width=width, label='')
ax1.bar(np.array(index) - width, get_rel(match_any), width=width, label='↑ Recall', color='lightgreen', edgecolor='black', linewidth=0.5)
ax1.bar(np.array(index), get_rel(match_lost), width=width, label='↓ Recall', color='wheat', edgecolor='black', linewidth=0.5)
ax1.bar(np.array(index) + width, get_rel(np.array(match_any) + np.array(match_lost)), width=width, label='Difference', color='blue', edgecolor='black', linewidth=0.5)

y_range = np.arange(-0.1, 0.18, 0.05)
ax1.set_yticks(y_range)
ax1.set_yticklabels([f'{"+" if i > 0 else ""}{round(i * 100)}pp.' for i in y_range])
ax1.set_ylim(-0.12,0.14)
ax1.set_title('(b) Difference in recall@64 due to entity simplification', size=font_size_lg)
ax1.legend(loc='upper left', 
           #bbox_to_anchor=(1, 0.9), 
           ncol=3, fontsize=font_size_sm)    

ax2 = axs[2]
    
ax2.bar(np.array(index) - width, get_rel(better_ranking), width=width, label='↑ Ranking', color='lightgreen', edgecolor='black', linewidth=0.5)
ax2.bar(np.array(index), get_rel(worse_ranking), width=width, label='↓ Ranking', color='wheat', edgecolor='black', linewidth=0.5)
ax2.bar(np.array(index) + width, get_rel(np.array(better_ranking) + np.array(worse_ranking)), width=width, label='Difference', color='blue', edgecolor='black', linewidth=0.5)
    
y_range_2 = np.arange(-0.2,0.3,0.05)
ax2.set_yticks(y_range_2)
ax2.set_yticklabels([f'{"" if i > 0 else ""}{abs(round(i * 100))}%' for i in y_range_2])
 
ax2.legend(loc='upper left', ncol=3, fontsize=font_size_sm)
ax2.set_title('(c) Proportion of mentions with increased / decreased ranking\nof the ground-truth concept due to entity simplification', size=font_size_lg)
    
ax2.grid(axis='y')
ax2.set_xlabel('Mention length (tokens)')

for ax in axs:
    ax.set_xlim(axs[0].get_xlim())
    ax.set_xticks(index)
    ax.set_xticklabels(labels)

fig.tight_layout()
    
plt.savefig('gain_vs_length.png', dpi=1200, bbox_inches='tight')
plt.show()

In [None]:
ea_after_simple[ea_after_simple.pred_text.map(lambda l: 'homo' in l[0])]

In [None]:
ea = ea_after_simple[(ea_before_simpl._word_len == 6) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index == 0)]
print(len(ea))
ea.head(5)

In [None]:
di, ei = show_example(2704, return_doc=True)

In [None]:
ei

In [None]:
ea = ea_after_simple[(ea_before_simpl._word_len == 1) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index != 0)]
print(len(ea))
ea.head(5)

In [None]:
show_example(2971)

In [None]:
ea_after_simple[(ea_before_simpl._word_len == 2) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index == 0)].head(5)

In [None]:
ea_after_simple[(ea_before_simpl._word_len == 5) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index == 0)].head(5)

In [None]:
ea_after_simple[(ea_before_simpl._word_len > 7) & (ea_before_simpl.pred_index == -1) & (ea_after_simple.pred_index != -1)]