In [238]:
%load_ext autoreload
%autoreload 2
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import clin.llm
import clin.parse
from typing import List, Dict
results_dir = '../results/'
from clin.config import PATH_REPO
import clin.eval.ebm
import clin.eval.eval
from clin.modules import ebm
import joblib
import imodelsx.process_results
from IPython.display import HTML
import clin.viz
import re

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [239]:
# get human spans
df_spans = joblib.load(join(PATH_REPO, 'data', 'ebm', 'ebm_interventions_spans.pkl'))
nums = np.arange(100).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe_spans = df_spans.iloc[nums]

# get predicted evidence
r = imodelsx.process_results.get_results_df(results_dir, use_cached=False)
r = r[r.dataset_name == 'ebm']
row = r[(r.n_shots == 5) * (r.checkpoint == 'text-davinci-003')].iloc[0]

# get common keys across each list
common_keys = [
    set.intersection(
        *[set(r["dict_evidence_ov_pv_ev"].iloc[i][j].keys()) for i in range(len(r))]
    )
    for j in range(100)
]
r['dict_evidence_common'] = r.apply(lambda x: [{k: x['dict_evidence_ov_pv_ev'][i][k] for k in common_keys[i]} for i in range(100)], axis=1)

100%|██████████| 18/18 [00:00<00:00, 1769.83it/s]


In [240]:
def calculate_mean_matches(d_evidence: List[Dict[str, str]], dfe_spans):
    """Finds mean number of times evidence span from llm contains a token from the human span.
    Mean is taken for each document and then averaged over all documents.
    Baseline is probability that any token falls into a human span.
    """

    # single example
    mean_matches = []
    mean_num_tokens = []
    for i in range(len(dfe_spans)):
        span = dfe_spans.iloc[i]
        doc = span['doc'].lower()
        toks = [tok.lower() for tok in span['toks_list']]
        annot = span['annot_list']
        # color_str = clin.viz.colorize(span['toks_list'], span['annot_list'], char_width_max=60, title=str(i) + " " + span['doc_id'])
        # display(HTML(color_str))

        # given reference text and set of tokens, find starting index of each token
        def _find_token_idxs(doc, toks):
            starts = []
            ends = []
            for tok in toks:
                idx = doc.find(tok, ends[-1] if len(ends) > 0 else 0)
                if idx == -1:
                    print('ERROR: token not found:', tok)
                    return None
                starts.append(idx)
                ends.append(idx + len(tok))
            
            # check that idxs are strictly increasing
            for i in range(1, len(starts)):
                if starts[i] <= starts[i - 1]:
                    print('ERROR: idxs not strictly increasing')
                    return None
            return starts, ends
        starts, ends = _find_token_idxs(doc, toks)

        def _get_overlapping_token_idxs(start: int, end: int, starts: List[int], ends: List[int]):
            """
            Given a span [start, end), find the indices of all tokens that overlap with the span.
            """
            idxs = []
            for i in range(len(starts)):
                if start < ends[i] and end > starts[i]:
                    # print(start, end, starts[i], ends[i])
                    idxs.append(i)
            return idxs

        matches = []
        num_toks = []
        for intervention_name, intervention_evidence in d_evidence[i].items():
            s = re.escape(intervention_evidence.lower())
            # s = intervention_name.lower()

            # search over all matches
            idxs_match = [m.start() for m in re.finditer(s, doc)]
            for idx_match in idxs_match:
                tok_idxs = _get_overlapping_token_idxs(idx_match, idx_match + len(s), starts, ends)
                matches.append(np.any(annot[tok_idxs] > 0))
                # num_toks.append(len(tok_idxs))
                num_toks.append(len(s.split()))
            # else:
                # matches.append(0)
                # num_toks.append(len(s.split())
        mean_matches.append(np.nanmean(matches))
        mean_num_tokens.append(np.nanmean(num_toks))
    return np.nanmean(mean_matches), np.nanmean(mean_num_tokens)

In [236]:
r['Span overlap accuracy'] = r.apply(lambda row: calculate_mean_matches(row['dict_evidence_ov_pv_ev'], dfe_spans)[0], axis=1)
r['Span length'] = r.apply(lambda row: calculate_mean_matches(row['dict_evidence_ov_pv_ev'], dfe_spans)[1], axis=1)
print('random baseline', np.concatenate(dfe_spans['annot_list'].values).mean())

  mean_matches.append(np.nanmean(matches))
  mean_num_tokens.append(np.nanmean(num_toks))


random baseline 0.03757215007215007


  mean_matches.append(np.nanmean(matches))
  mean_num_tokens.append(np.nanmean(num_toks))


In [237]:
r.groupby(['checkpoint', 'n_shots'])[['Span overlap accuracy', 'Span length']].mean().round(2)

Unnamed: 0_level_0,Unnamed: 1_level_0,Span overlap accuracy,Span length
checkpoint,n_shots,Unnamed: 2_level_1,Unnamed: 3_level_1
gpt-4-0314,1,0.92,7.98
gpt-4-0314,5,0.92,8.34
text-davinci-003,1,0.85,7.8
text-davinci-003,5,0.85,7.03
