# Evaluating Elasticsearch BM25 top-k accuracy on original & baseline-redacted documents

First, we need to connect to Elasticsearch and add all the profiles (as strings) to indices:

In [27]:
import os
from urllib.parse import quote

import urllib3
from elasticsearch import Elasticsearch

# Credentials are read from the environment so that no secret is stored in the
# notebook. ES_PASSWORD is required; ES_USERNAME and ES_HOST fall back to the
# values this notebook has always used.
username = os.environ.get("ES_USERNAME", "elastic")
password = os.environ["ES_PASSWORD"]
host = os.environ.get("ES_HOST", "rush-compute-01.tech.cornell.edu:9200")

# Percent-encode the userinfo part so reserved characters in the credentials
# (e.g. '@', ':' or '/') cannot corrupt the URL.
url = f"https://{quote(username, safe='')}:{quote(password, safe='')}@{host}"

es = Elasticsearch(
    url,
    # NOTE(review): TLS certificate verification is switched off here.
    verify_certs=False
)

# With verify_certs=False every request triggers an InsecureRequestWarning;
# silence only that warning class.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)



In [2]:
es

<Elasticsearch([{'host': 'rush-compute-01.tech.cornell.edu', 'port': 9200, 'use_ssl': True, 'http_auth': 'elastic:<redacted>'}])>

In [3]:
# Delete an existing index
# es.indices.delete(index='val_5_profile_str', ignore=[400, 404])
# for idx in [idx for idx in es.indices.get_alias().keys() if not idx.startswith('.')]:
    # print('deleting', idx)
    # es.indices.delete(index=idx, ignore=[400, 404])

In [4]:
import datasets

from elasticsearch import helpers
from elasticsearch_dsl import Index


def create_index_from_profiles(index_name: str, dataset_split: str, b: float = 0.9, k1: float = 4.5):
    """Create the index `index_name` and bulk-insert wiki_bio profiles into it.

    Args:
        index_name: name of the Elasticsearch index to create.
        dataset_split: split expression passed to `datasets.load_dataset`
            (e.g. 'val[:100%]').
        b: BM25 `b` parameter of the index's default similarity.
        k1: BM25 `k1` parameter of the index's default similarity.

    Returns:
        Whatever `helpers.bulk` returns for the insert.
    """
    bm25_similarity = {"type": "BM25", "b": b, "k1": k1}
    mapping_options = {'ignore_malformed': True, 'total_fields.limit': 20_000}

    profile_index = Index(index_name, es)
    profile_index.settings(
        # One shard and zero replicas, for consistent scoring:
        # https://www.elastic.co/guide/en/elasticsearch/reference/current/consistent-scoring.html
        number_of_shards=1,
        number_of_replicas=0,
        index={
            'mapping': mapping_options,
            "similarity": {"default": bm25_similarity},
            # (A custom analyzer with a whitespace tokenizer and a "stop" filter
            # was sketched here at one point; it is not enabled.)
        },
    )
    profile_index.create()

    dataset = datasets.load_dataset('wiki_bio', split=dataset_split, version='1.2.0')

    def make_prof_table(prof):
        """Flatten one example's infobox table into 'key : value' lines."""
        table = prof['input_text']['table']
        raw_pairs = dict(zip(table['column_header'], table['content']))

        # Trim whitespace and the characters . | < > from both ends of keys and values.
        cleaned = {}
        for header, content in raw_pairs.items():
            cleaned[header.strip().strip('.|<>')] = content.strip().strip('.|<>')

        # fix for one weird error: this key is stored without the inner dots
        if 'no.of.children' in cleaned:
            cleaned['no of children'] = cleaned['no.of.children']
            del cleaned['no.of.children']

        # Keep only the pairs where neither side is empty.
        return ''.join(
            f'{key} : {value}' + '\n'
            for key, value in cleaned.items()
            if (len(key) and len(value))
        )

    prof_data = list(map(make_prof_table, dataset))

    print('inserting', len(prof_data), 'profiles')

    # The document id is the row's position in the split.
    bulk_actions = []
    for idx, profile_str in enumerate(prof_data):
        bulk_actions.append({'_id': idx, 'body': {'profile': profile_str, 'id': idx}})
    return helpers.bulk(es, bulk_actions, index=index_name)

In [5]:
# create_index_from_profiles('val_100_profile_str', 'val[:100%]')
# create_index_from_profiles('test_100_profile_str', 'test[:100%]')
# create_index_from_profiles('train_100_profile_str', 'train[:100%]')

Now that the indices are created, we can iterate over documents and compute the top-K accuracy.

In [6]:
import os
import sys

# Directory that provides the project-local `datamodule` package. Override it
# with the DEID_REPO_DIR environment variable; the default is the original
# hard-coded location.
repo_dir = os.environ.get(
    'DEID_REPO_DIR',
    '/home/jxm3/research/deidentification/unsupervised-deidentification',
)
if repo_dir not in sys.path:  # keep re-runs of this cell from stacking duplicates
    sys.path.append(repo_dir)

from datamodule import WikipediaDataModule

# os.sched_getaffinity is not available on every platform; fall back to the
# machine's CPU count when it is missing.
if hasattr(os, 'sched_getaffinity'):
    num_cpus = len(os.sched_getaffinity(0))
else:
    num_cpus = os.cpu_count() or 1

dm = WikipediaDataModule(
    document_model_name_or_path = 'roberta-base',
    profile_model_name_or_path = 'google/tapas-base',
    dataset_name='wiki_bio',
    dataset_train_split='train[:100%]',
    dataset_val_split='val[:100%]',
    dataset_test_split='test[:100%]',
    dataset_version='1.2.0',
    num_workers=num_cpus,
    train_batch_size=256,
    eval_batch_size=256,
    max_seq_length=128,
    sample_spans=False,
)
dm.setup("fit")

Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:100%]
loading wiki_bio[1.2.0] split val[:100%]
loading wiki_bio[1.2.0] split test[:100%]
                        

In [114]:
# Print the document count of every index whose name does not start with '.'.
for index_name in filter(lambda name: not name.startswith('.'), list(es.indices.get_alias().keys())):
    doc_count = es.count(index=index_name)
    print(index_name, doc_count)

val_100__analyzer_nostopfilter_profile_str {'count': 72831, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}
val_100_profile_str {'count': 72831, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}
val_100__analyzer_yesstopfilter_profile_str {'count': 72831, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}
test_100_profile_str {'count': 72831, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}
train_100_profile_str {'count': 582659, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}




In [115]:
from typing import List

from elasticsearch import Elasticsearch

import json

def msearch(
        es: Elasticsearch,
        max_hits: int,
        query_strings: List[str],
        index: str = 'train_100_profile_str,test_100_profile_str,val_100_profile_str',
        include_source: bool = False,
    ):
    """Send one `query_string` search per entry of `query_strings` in a single msearch call.

    Args:
        es: client whose `msearch` method is called.
        max_hits: value of `size` for every search.
        query_strings: one query per search, in order.
        index: index name(s) placed in each search header.
        include_source: value of `_source` for every search.

    Returns:
        The response of `es.msearch`.
    """
    payload_lines = []
    for query_text in query_strings:
        # Each search is two lines: a header naming the index, then the body.
        header = {'index': index}
        search_body = {
            "query": {"query_string": {"query": query_text}},
            'size': max_hits,
            '_source': include_source,
        }
        payload_lines.append(json.dumps(header))
        payload_lines.append(json.dumps(search_body))

    # The index travels in the per-search headers, so only `body` is passed.
    return es.msearch(body=' \n'.join(payload_lines))

msearch_results = msearch(es, query_strings=['Butch Cassidy', 'Marlon Evans'], max_hits=2, include_source=True)
msearch_results

{'took': 2,
 'responses': [{'took': 2,
   'timed_out': False,
   '_shards': {'total': 3, 'successful': 3, 'skipped': 0, 'failed': 0},
   'hits': {'total': {'value': 408, 'relation': 'eq'},
    'max_score': 36.159836,
    'hits': [{'_index': 'train_100_profile_str',
      '_id': '286083',
      '_score': 36.159836,
      '_ignored': ['body.profile.keyword'],
      '_source': {'body': {'profile': "image_caption : fort worth , texas , 1900\nparents : maximillian parker ann campbell gillies\nnationality : american\ncause : gunshot\nconviction : imprisoned for horse theft in the state prison in laramie , wyoming\ndeath_place : near san vicente , bolivia\nimage_name : butch cassidy with bowler hat.jpg\nallegiance : butch cassidy 's wild bunch\ndeath_date : 7 november 1908\nbirth_date : 13 april 1866\narticle_title : butch cassidy\nbirth_name : robert leroy parker\nname : butch cassidy\nimage_size : 250px\nconviction_penalty : served 1896 18 months of 2-year sentence ; released january\nalias

In [89]:
def msearch_by_id(
        es: Elasticsearch,
        query_strings: List[str],
        ids: List[int],
        index: str = 'train_100_profile_str,test_100_profile_str,val_100_profile_str'
    ):
    """Run each query restricted to a single document id, in one msearch call.

    `query_strings[n]` is paired with `ids[n]`; the two lists must have the
    same length. Each search combines the `query_string` query with an `ids`
    filter holding that one id, asks for one hit, tracks total hits and omits
    `_source`.

    Returns:
        The response of `es.msearch`.
    """
    assert len(ids) == len(query_strings)

    def build_search_body(query_text, doc_id):
        # `must` supplies the query; `filter` limits matches to the given id.
        bool_clause = {
            "must": [{"query_string": {"query": query_text}}],
            "filter": {"ids": {"values": [doc_id]}},
        }
        return {
            "query": {"bool": bool_clause},
            'size': 1,
            'track_total_hits': True,
            '_source': False,
        }

    payload_lines = []
    for query_text, doc_id in zip(query_strings, ids):
        payload_lines.append(json.dumps({'index': index}))
        payload_lines.append(json.dumps(build_search_body(query_text, doc_id)))

    # The index travels in the per-search headers, so only `body` is passed.
    return es.msearch(body=' \n'.join(payload_lines))

# Sanity check: score each query against one specific document id.
# 286083 -> butch cassidy, (ranked #1)
# 398154 -> patrick cassidy  (ranked #2)
test_msearch_by_id_responses = msearch_by_id(
    es=es,
    query_strings=['butch cassidy', 'marlon evans'],
    ids=[398154, 13],
)
test_msearch_by_id_responses

{'took': 3,
 'responses': [{'took': 2,
   'timed_out': False,
   '_shards': {'total': 3, 'successful': 3, 'skipped': 0, 'failed': 0},
   'hits': {'total': {'value': 1, 'relation': 'eq'},
    'max_score': 29.086437,
    'hits': [{'_index': 'train_100_profile_str',
      '_id': '398154',
      '_score': 29.086437,
      '_ignored': ['body.profile.keyword']}]},
   'status': 200},
  {'took': 2,
   'timed_out': False,
   '_shards': {'total': 3, 'successful': 3, 'skipped': 0, 'failed': 0},
   'hits': {'total': {'value': 1, 'relation': 'eq'},
    'max_score': 36.832985,
    'hits': [{'_index': 'val_100_profile_str',
      '_id': '13',
      '_score': 36.832985,
      '_ignored': ['body.profile.keyword']}]},
   'status': 200}]}

In [91]:
test_msearch_by_id_responses['responses'][0]['hits']

{'total': {'value': 1, 'relation': 'eq'},
 'max_score': 29.086437,
 'hits': [{'_index': 'train_100_profile_str',
   '_id': '398154',
   '_score': 29.086437,
   '_ignored': ['body.profile.keyword']}]}

In [92]:
test_msearch_by_id_responses['responses'][0]['hits']['max_score']

29.086437

In [106]:
def msearch_total_hits_by_min_score(
        es: Elasticsearch,
        query_strings: List[str],
        min_scores: List[float],
        index: str = 'train_100_profile_str,test_100_profile_str,val_100_profile_str'
    ):
    """Run each query with its own `min_score` threshold, in one msearch call.

    `query_strings[n]` is paired with `min_scores[n]`; the two lists must have
    the same length. Each search sets `track_total_hits` to true, requests one
    hit and omits `_source`.

    Background (https://stackoverflow.com/a/60857312/2287177): on
    Elasticsearch 7.0+, a _search with `track_total_hits: true` reports the
    same total as _count.

    Returns:
        The response of `es.msearch`.
    """
    assert len(query_strings) == len(min_scores)

    payload_lines = []
    for query_text, threshold in zip(query_strings, min_scores):
        search_body = {
            "query": {
                "bool": {
                    "must": [{"query_string": {"query": query_text}}],
                },
            },
            "min_score": threshold,
            "track_total_hits": True,
            'size': 1,
            '_source': False,
        }
        payload_lines.extend((json.dumps({'index': index}), json.dumps(search_body)))

    # The index travels in the per-search headers, so only `body` is passed.
    return es.msearch(body=' \n'.join(payload_lines))

sample_msearch_total_hits_by_min_score = msearch_total_hits_by_min_score(
    es=es,
    query_strings=['butch cassidy', 'marlon evans'], 
    # 286083 -> butch cassidy, (ranked #1)
    # 398154 -> patrick cassidy  (ranked #2)
    # ids=[398154, 13]
    min_scores=[26.159836, 30.0]
)
sample_msearch_total_hits_by_min_score

{'took': 2,
 'responses': [{'took': 2,
   'timed_out': False,
   '_shards': {'total': 3, 'successful': 3, 'skipped': 0, 'failed': 0},
   'hits': {'total': {'value': 9, 'relation': 'eq'},
    'max_score': 36.159836,
    'hits': [{'_index': 'train_100_profile_str',
      '_id': '286083',
      '_score': 36.159836,
      '_ignored': ['body.profile.keyword']}]},
   'status': 200},
  {'took': 2,
   'timed_out': False,
   '_shards': {'total': 3, 'successful': 3, 'skipped': 0, 'failed': 0},
   'hits': {'total': {'value': 1, 'relation': 'eq'},
    'max_score': 36.832985,
    'hits': [{'_index': 'val_100_profile_str',
      '_id': '13',
      '_score': 36.832985,
      '_ignored': ['body.profile.keyword']}]},
   'status': 200}]}

In [107]:
sample_msearch_total_hits_by_min_score['responses'][0]['hits']

{'total': {'value': 9, 'relation': 'eq'},
 'max_score': 36.159836,
 'hits': [{'_index': 'train_100_profile_str',
   '_id': '286083',
   '_score': 36.159836,
   '_ignored': ['body.profile.keyword']}]}

In [109]:
sample_msearch_total_hits_by_min_score['responses'][0]['hits']['hits'][0]['_id']

'286083'

In [108]:
sample_msearch_total_hits_by_min_score['responses'][0]['hits']['total']['value']

9

In [113]:
msearch_results['responses'][0]['hits']['hits'][0]['_score']

36.159836

In [35]:
msearch_results.keys()

dict_keys(['took', 'responses'])

In [45]:
msearch_results['responses'][1]['hits']['hits'][0]

{'_index': 'train_100_profile_str',
 '_id': '286083',
 '_score': 36.159836,
 '_ignored': ['body.profile.keyword'],
 '_source': {'body': {'profile': "image_caption : fort worth , texas , 1900\nparents : maximillian parker ann campbell gillies\nnationality : american\ncause : gunshot\nconviction : imprisoned for horse theft in the state prison in laramie , wyoming\ndeath_place : near san vicente , bolivia\nimage_name : butch cassidy with bowler hat.jpg\nallegiance : butch cassidy 's wild bunch\ndeath_date : 7 november 1908\nbirth_date : 13 april 1866\narticle_title : butch cassidy\nbirth_name : robert leroy parker\nname : butch cassidy\nimage_size : 250px\nconviction_penalty : served 1896 18 months of 2-year sentence ; released january\nalias : butch jim lowe , santiago maxwell , cassidy , mike cassidy , george cassidy ,\npartner : harry , matt warner longabaugh , aka sundance kid , elzy lay\nbirth_place : beaver , utah\ncharge : horse robbery theft , cattle rustling , bank and train\noc

In [7]:
import re

def preprocess_doc(doc: str) -> str:
    """Turn a document into text that is safe to use as a search query.

    Keeps the first 500 space-separated words, maps the bracket tokens
    '-lrb-' / '-rrb-' to '(' / ')', then replaces every character that is
    neither a word character nor whitespace with a single space.

    Args:
        doc: raw document text.

    Returns:
        The cleaned text.
    """
    # limit 500 words
    doc = ' '.join(doc.split(' ')[:500])
    # Map the bracket tokens to real parentheses first, so the substitution
    # below blanks them out instead of leaving the words 'lrb' / 'rrb' behind.
    doc = doc.replace('-lrb-', '(').replace('-rrb-', ')')
    # Inside a character class '|' is a literal pipe, not alternation; the old
    # pattern [^\w|\s] therefore let pipes through. They are now removed too.
    return re.sub(r'[^\w\s]', ' ', doc)

def search_results_for_query_by_index(query: str, index: str, max_hits: int = 10):
    """Search `index` for `query` with the module-level `es` client.

    The search uses `search_type='dfs_query_then_fetch'` and asks for at most
    `max_hits` hits.

    Returns:
        A `(num_hits, hits)` tuple: the reported total hit count and the list
        of returned hits.
    """
    response = es.search(
        index=index,
        q=query,
        size=max_hits,
        search_type='dfs_query_then_fetch',
    )
    hits_section = response["hits"]
    total_hits = hits_section["total"]["value"]
    return total_hits, hits_section["hits"]

def index_of_doc_id_in_results_list(doc: str, doc_id: int, max_hits=100):
    """Searches for test doc in all three indices. Returns index of doc in results if found.

    The query is `preprocess_doc(doc)`. The returned position is counted over
    the combined result list from the val, test and train indices, and only a
    hit from 'test_100_profile_str' whose `_id` equals `doc_id` counts as a
    match.

    Args:
        doc: document text to query with.
        doc_id: id of the matching profile in the test index.
        max_hits: number of results to request.

    Returns:
        The 0-based position of the matching hit, or `float('inf')` when it is
        not among the returned results.
    """
    _, results = search_results_for_query_by_index(
        query=preprocess_doc(doc),
        index="val_100_profile_str,test_100_profile_str,train_100_profile_str",
        max_hits=max_hits
    )

    for result_idx, result in enumerate(results):
        if (result['_index'] == 'test_100_profile_str') and (int(result['_id']) == doc_id):
            return result_idx
    return float('inf')

for i in range(4):
    print(index_of_doc_id_in_results_list(dm.test_dataset[i]['document'], i))

60
inf
0
inf


In [8]:
import collections
import tqdm

# Top-k accuracy when querying with the `document` field of the first
# `total` test examples.
k_values = [1, 10, 100]
total = 1000
total_correct_by_k_doc = collections.defaultdict(int)

for j in tqdm.trange(total):
    result_idx = index_of_doc_id_in_results_list(dm.test_dataset[j]['document'], j)
    # A result at 0-based position p counts as correct for every k with p < k.
    for k in (cutoff for cutoff in k_values if result_idx < cutoff):
        total_correct_by_k_doc[k] += 1

for k in k_values:
    acc = total_correct_by_k_doc[k] / total
    acc_str = f'Top-{k} accuracy = {acc*100.0:.2f}'
    print(acc_str)

100%|██████████| 1000/1000 [03:03<00:00,  5.44it/s]

Top-1 accuracy = 71.30
Top-10 accuracy = 85.10
Top-100 accuracy = 93.00





In [None]:
# Top-k accuracy when querying with the `document_redact_lexical` field of
# the first `total` test examples.
total = 1000
total_correct_by_k_lex = collections.defaultdict(int)

for j in tqdm.trange(total):
    result_idx = index_of_doc_id_in_results_list(dm.test_dataset[j]['document_redact_lexical'], j)
    # A result at 0-based position p counts as correct for every k with p < k.
    for k in (cutoff for cutoff in k_values if result_idx < cutoff):
        total_correct_by_k_lex[k] += 1

for k in k_values:
    acc = total_correct_by_k_lex[k] / total
    acc_str = f'Top-{k} accuracy = {acc*100.0:.2f}'
    print(acc_str)

100%|██████████| 1000/1000 [02:03<00:00,  8.12it/s]

Top-1 accuracy = 0.00
Top-10 accuracy = 0.10
Top-100 accuracy = 0.10





In [None]:
# Top-k accuracy when querying with the `document_redact_ner_bert` field of
# the first `total` test examples.
total = 1000
total_correct_by_k_ner = collections.defaultdict(int)

for j in tqdm.trange(total):
    result_idx = index_of_doc_id_in_results_list(dm.test_dataset[j]['document_redact_ner_bert'], j)
    # A result at 0-based position p counts as correct for every k with p < k.
    for k in (cutoff for cutoff in k_values if result_idx < cutoff):
        total_correct_by_k_ner[k] += 1

for k in k_values:
    acc = total_correct_by_k_ner[k] / total
    acc_str = f'Top-{k} accuracy = {acc*100.0:.2f}'
    print(acc_str)

100%|██████████| 1000/1000 [02:33<00:00,  6.52it/s]

Top-1 accuracy = 0.10
Top-10 accuracy = 0.60
Top-100 accuracy = 11.90





In [48]:
import pandas as pd

import glob
import re


# Collect one small frame per results CSV and concatenate once at the end;
# calling pd.concat inside the loop re-copies all earlier rows on every
# iteration.
adv_frames = []
for model_name in ['model_3_1', 'model_3_2', 'model_3_3', 'model_3_4']:
    csv_filenames = glob.glob(f'../adv_csvs_full/{model_name}*/results__b_1__k_1__n_1000.csv')
    print(model_name, csv_filenames)
    for filename in csv_filenames:
        df = pd.read_csv(filename)
        # The model name is the directory directly under adv_csvs_full/.
        df['model_name'] = re.search(r'adv_csvs_full/(model_\d.*)/.+.csv', filename).group(1)
        # Row position within the CSV, kept as an explicit column.
        df['i'] = df.index
        adv_frames.append(df[['perturbed_text', 'model_name', 'i']])

# As before, adv_df stays None when no CSV was found.
adv_df = pd.concat(adv_frames, axis=0) if adv_frames else None

model_3_1 ['../adv_csvs_full/model_3_1/results__b_1__k_1__n_1000.csv']
model_3_2 ['../adv_csvs_full/model_3_2/results__b_1__k_1__n_1000.csv', '../adv_csvs_full/model_3_2__idf/results__b_1__k_1__n_1000.csv']
model_3_3 ['../adv_csvs_full/model_3_3__placeholder/results__b_1__k_1__n_1000.csv']
model_3_4 ['../adv_csvs_full/model_3_4/results__b_1__k_1__n_1000.csv', '../adv_csvs_full/model_3_4__idf/results__b_1__k_1__n_1000.csv']


In [50]:
# First 1000 examples of the *test* dataset (despite the "val" in the name).
mini_val_dataset = dm.test_dataset[:1000]

# One frame per baseline, with the same three columns as the adversarial
# frames: the redacted text, a model label and the row position `i`.
ner_df = (
    pd.DataFrame(data=mini_val_dataset['document_redact_ner_bert'], columns=['perturbed_text'])
    .assign(model_name='named_entity')
    .assign(i=lambda frame: frame.index)
)

lex_df = (
    pd.DataFrame(data=mini_val_dataset['document_redact_lexical'], columns=['perturbed_text'])
    .assign(model_name='lexical')
    .assign(i=lambda frame: frame.index)
)

# Lexical rows first, then named-entity rows.
baseline_df = pd.concat([lex_df, ner_df], axis=0)

In [51]:
full_df = pd.concat((adv_df, baseline_df), axis=0)
full_df['model_name'].value_counts()

model_3_1                 1000
model_3_2                 1000
model_3_2__idf            1000
model_3_3__placeholder    1000
model_3_4                 1000
model_3_4__idf            1000
lexical                   1000
named_entity              1000
Name: model_name, dtype: int64

In [52]:
# this line puts newlines back
full_df['perturbed_text'] = full_df['perturbed_text'].apply(lambda s: s.replace('<SPLIT>', '\n'))

# this line replaces BERT-style masks (from PMLM) with roberta-style ones, so we can
# count them in a single command
full_df['perturbed_text'] = full_df['perturbed_text'].apply(lambda s: s.replace('[MASK]', '<mask>'))

In [53]:
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')

def truncate_text(text: str, max_length=128) -> str:
    """Truncate `text` to at most `max_length` tokens of the module-level tokenizer.

    The text is tokenized with truncation, decoded back to a string, and then
    tidied: each '<mask>' gets a single space on either side, and the '<s>' /
    '</s>' markers and surrounding whitespace are removed.

    Args:
        text: text to truncate.
        max_length: maximum number of tokens to keep (previously ignored in
            favour of a hard-coded 128).

    Returns:
        The decoded, truncated text.
    """
    input_ids = tokenizer(text, truncation=True, max_length=max_length)['input_ids']
    reconstructed_text = (
        tokenizer
            .decode(input_ids)
            # pad every mask with spaces, then collapse any doubled spaces
            .replace('<mask>', ' <mask> ')
            .replace('  <mask>', ' <mask>')
            .replace('<mask>  ', '<mask> ')
            # drop the sequence start/end markers
            .replace('<s>', '')
            .replace('</s>', '')
            .strip()
    )
    return reconstructed_text

full_df['perturbed_text_truncated'] = full_df['perturbed_text'].apply(truncate_text)

In [54]:
def count_percent_masks(s):
    """Return the '<mask>' count of `s` divided by its number of space-separated pieces.

    Despite the name, the result is a fraction, not a percentage.
    """
    mask_count = s.count('<mask>')
    piece_count = len(s.split(' '))
    return mask_count / piece_count

# Apply to the text column directly instead of building a row object per record.
full_df['percent_masks'] = full_df['perturbed_text_truncated'].apply(count_percent_masks)
# Select the column before aggregating: taking the mean of the whole frame
# would also try to average the text columns.
full_df.groupby('model_name')['percent_masks'].mean()

model_name
lexical                   0.282113
model_3_1                 0.089857
model_3_2                 0.148340
model_3_2__idf            0.142445
model_3_3__placeholder    0.164442
model_3_4                 0.125632
model_3_4__idf            0.119776
named_entity              0.245707
Name: percent_masks, dtype: float64

In [55]:
def count_masks(s):
    """Return how many times the '<mask>' token occurs in `s`."""
    mask_token = '<mask>'
    return s.count(mask_token)

# Apply to the text column directly instead of building a row object per record.
full_df['num_masks'] = full_df['perturbed_text_truncated'].apply(count_masks)
# Select the column before aggregating: taking the mean of the whole frame
# would also try to average the text columns.
full_df.groupby('model_name')['num_masks'].mean()

model_name
lexical                   15.465
model_3_1                  4.410
model_3_2                  8.546
model_3_2__idf             7.044
model_3_3__placeholder     8.944
model_3_4                  5.736
model_3_4__idf             5.240
named_entity              14.738
Name: num_masks, dtype: float64

In [56]:
import zlib

original_text_truncated = [truncate_text(d) for d in mini_val_dataset['document']]

def count_compressed_bytes(s: str) -> int:
    """Return the length in bytes of `s` after encoding and zlib compression."""
    encoded = s.encode()
    compressed = zlib.compress(encoded)
    return len(compressed)

original_total_bytes = count_compressed_bytes('\n'.join(original_text_truncated))

1 - full_df.groupby('model_name').apply(lambda s: count_compressed_bytes('\n'.join(s['perturbed_text_truncated']))) / original_total_bytes

model_name
lexical                   0.150269
model_3_1                 0.047802
model_3_2                 0.092045
model_3_2__idf            0.065432
model_3_3__placeholder    0.092021
model_3_4                 0.070358
model_3_4__idf            0.051691
named_entity              0.207786
dtype: float64

In [57]:
full_df[full_df['i'] == 999]

Unnamed: 0,perturbed_text,model_name,i,perturbed_text_truncated,percent_masks,num_masks
999,<mask> joseph `` <mask> '' <mask> ( <mask> -- ...,model_3_1,999,<mask> joseph `` <mask> '' <mask> ( <mask> -- ...,0.054348,5
999,<mask> <mask> <mask> <mask> '' <mask> ( <mask>...,model_3_2,999,<mask> <mask> <mask> <mask> '' <mask> ( <mask>...,0.076087,7
999,<mask> <mask> <mask> <mask> '' <mask> ( 1913 -...,model_3_2__idf,999,<mask> <mask> <mask> <mask> '' <mask> ( 1913 -...,0.09375,9
999,<mask> <mask> <mask> <mask> '' <mask> ( <mask>...,model_3_3__placeholder,999,<mask> <mask> <mask> <mask> '' <mask> ( <mask>...,0.154639,15
999,<mask> <mask> <mask> <mask> '' <mask> ( 1913 -...,model_3_4,999,<mask> <mask> <mask> <mask> '' <mask> ( 1913 -...,0.091837,9
999,<mask> joseph <mask> bus '' <mask> ( <mask> --...,model_3_4__idf,999,<mask> joseph <mask> bus '' <mask> ( <mask> --...,0.054348,5
999,<mask> <mask> `` <mask> '' <mask> ( <mask> -- ...,lexical,999,<mask> <mask> `` <mask> '' <mask> ( <mask> -- ...,0.092784,9
999,<mask> <mask> `` <mask> '' <mask> ( 1913 -- se...,named_entity,999,<mask> <mask> `` <mask> '' <mask> ( 1913 -- se...,0.177083,17


In [None]:
full_df['bm25_test_guess_result_index'] = full_df.apply(lambda row: index_of_doc_id_in_results_list(row['perturbed_text_truncated'], row['i']), axis=1)

In [None]:
# Map a sentinel rank of -1 to infinity in the rank column only. Assigning
# through a bare boolean mask would overwrite every column of the matching
# rows (text and model name included).
full_df.loc[full_df['bm25_test_guess_result_index'] == -1, 'bm25_test_guess_result_index'] = float('inf')

In [None]:
full_df['bm25_was_correct'] = (full_df['bm25_test_guess_result_index'] == 0)

In [73]:
full_df.groupby('model_name')['bm25_was_correct'].mean()

model_name
lexical                   0.000
model_3_1                 0.138
model_3_2                 0.043
model_3_2__idf            0.045
model_3_3__placeholder    0.036
model_3_4                 0.067
model_3_4__idf            0.117
named_entity              0.001
Name: bm25_was_correct, dtype: float64

In [64]:
# Average only the numeric columns; the frame also holds text columns, which
# cannot be averaged.
full_df.groupby('model_name').mean(numeric_only=True)

Unnamed: 0_level_0,i,percent_masks,num_masks
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
lexical,499.5,0.282113,15.465
model_3_1,499.5,0.089857,4.41
model_3_2,499.5,0.14834,8.546
model_3_2__idf,499.5,0.142445,7.044
model_3_3__placeholder,499.5,0.164442,8.944
model_3_4,499.5,0.125632,5.736
model_3_4__idf,499.5,0.119776,5.24
named_entity,499.5,0.245707,14.738
