In [11]:
from elasticsearch import Elasticsearch

import os
from getpass import getpass

# Credentials come from the environment (or an interactive prompt), never from the notebook
# itself. NOTE(review): a password was previously committed here in plain text — rotate it.
username = os.environ.get("ELASTIC_USERNAME", "elastic")
password = os.environ.get("ELASTIC_PASSWORD") or getpass(f"Elasticsearch password for {username}: ")

# Credentials are passed via http_auth rather than embedded in the URL, so passwords
# containing URL-reserved characters (':', '@', '/', ...) cannot corrupt the host string.
url = "https://rush-compute-01.tech.cornell.edu:9200"

es = Elasticsearch(
    url,
    http_auth=(username, password),
    use_ssl=True,
    # NOTE(review): certificate verification is disabled; prefer ca_certs=<path to CA bundle> when available.
    verify_certs=False,
)

import urllib3

# verify_certs=False makes urllib3 warn on every request; silence that specific warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

In [177]:
# List existing indices, skipping system indices (names starting with '.')
all_index_names = es.indices.get_alias().keys()
user_indices = [name for name in all_index_names if not name.startswith('.')]
user_indices

['train_100_profile_str',
 'val_10_profile_str',
 'test_100_profile_str',
 'val_100_profile_str',
 'val_5_profile_str',
 'val_50_profile_str',
 'val_1_profile_str',
 'val_n1000_profile_str']

In [178]:
# Delete the indices this notebook builds (names ending in PROFILE_INDEX_SUFFIX) so they can be
# recreated from scratch. Restricting to that suffix avoids wiping unrelated indices that other
# users may have on the same cluster; system indices (leading '.') are always skipped.
PROFILE_INDEX_SUFFIX = '_profile_str'
stale_indices = [
    name for name in es.indices.get_alias().keys()
    if not name.startswith('.') and name.endswith(PROFILE_INDEX_SUFFIX)
]
for idx in stale_indices:
    print('deleting', idx)
    # 400/404 are ignored so a concurrently-removed or already-missing index does not abort the loop.
    es.indices.delete(index=idx, ignore=[400, 404])

deleting test_100_profile_str
deleting val_1_profile_str
deleting val_n1000_profile_str
deleting train_100_profile_str
deleting val_5_profile_str
deleting val_50_profile_str
deleting val_100_profile_str
deleting val_10_profile_str


In [179]:
import datasets

from elasticsearch import helpers
from elasticsearch_dsl import Index


def _profile_to_string(prof) -> str:
    """Flatten one wiki_bio example's infobox table into newline-terminated 'key : value' lines.

    Keys and values are stripped of surrounding whitespace and then of the characters
    '.', '|', '<', '>' at either end; pairs whose key or value ends up empty are dropped.
    """
    table = prof['input_text']['table']
    # Dedupe raw headers first (last value wins), then clean them — same order as before.
    raw_pairs = dict(zip(table['column_header'], table['content']))
    cleaned = {k.strip().strip('.|<>'): v.strip().strip('.|<>') for k, v in raw_pairs.items()}
    if 'no.of.children' in cleaned:
        # fix for one malformed header in the dataset
        cleaned['no of children'] = cleaned.pop('no.of.children')
    return ''.join(f'{k} : {v}\n' for k, v in cleaned.items() if k and v)


def create_index_from_profiles(index_name: str, dataset_split: str):
    """Create `index_name` and bulk-insert one document per wiki_bio profile in `dataset_split`.

    Each document's `_id` (and `body.id`) is the example's position within `dataset_split`,
    and `body.profile` holds the flattened infobox produced by `_profile_to_string`.

    Args:
        index_name: name of the Elasticsearch index to create (creation fails if it exists).
        dataset_split: a `datasets` split spec such as 'val[:1000]' or 'train[:100%]'.

    Returns:
        The (success_count, errors) tuple returned by `helpers.bulk`.
    """
    index = Index(index_name, es)
    index.settings(
        # Scoring statistics are per shard, so a single shard keeps BM25 scores comparable across the corpus.
        number_of_shards=1,
        number_of_replicas=2,
        index={
            'mapping': {
                'ignore_malformed': True,
                'total_fields.limit': 20_000
            },
            "similarity": {
                "default": {
                    "type": "BM25",
                    # NOTE(review): with k1=0 the BM25 tf/length-normalization factor is constant,
                    # so every matching term contributes only its idf and `b` has no effect.
                    "b": 0.5,
                    "k1": 0
                }
            }
        }
    )
    index.create()

    dataset = datasets.load_dataset('wiki_bio', split=dataset_split, version='1.2.0')
    print('inserting', len(dataset), 'profiles')

    # Stream bulk actions instead of materializing every profile string and action dict up
    # front (the full train split has ~580k profiles). Keys other than metadata such as `_id`
    # become the document source, i.e. {'body': {...}}.
    actions = (
        {'_id': idx, 'body': {'profile': _profile_to_string(prof), 'id': idx}}
        for idx, prof in enumerate(dataset)
    )
    return helpers.bulk(es, actions, index=index_name)

In [180]:
# First 1,000 validation profiles.
create_index_from_profiles(index_name='val_n1000_profile_str', dataset_split='val[:1000]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 1000 profiles


(1000, [])

In [181]:
# First 1% of the validation split.
create_index_from_profiles(index_name='val_1_profile_str', dataset_split='val[:1%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 728 profiles


(728, [])

In [182]:
# First 5% of the validation split.
create_index_from_profiles(index_name='val_5_profile_str', dataset_split='val[:5%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 3642 profiles


(3642, [])

In [183]:
# First 10% of the validation split.
create_index_from_profiles(index_name='val_10_profile_str', dataset_split='val[:10%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 7283 profiles


(7283, [])

In [184]:
# First half of the validation split.
create_index_from_profiles(index_name='val_50_profile_str', dataset_split='val[:50%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 36416 profiles


(36416, [])

In [185]:
# Entire validation split.
create_index_from_profiles(index_name='val_100_profile_str', dataset_split='val[:100%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 72831 profiles


(72831, [])

In [186]:
# Entire test split.
create_index_from_profiles(index_name='test_100_profile_str', dataset_split='test[:100%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 72831 profiles


(72831, [])

In [187]:
# Entire training split.
create_index_from_profiles(index_name='train_100_profile_str', dataset_split='train[:100%]')

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


inserting 582659 profiles


(582659, [])

In [188]:
def search_results_for_query_by_index(query: str, max_hits: int = 10, index: str="val_100_profile_str"):
    """Run a Lucene query-string search for `query` against `index`.

    Args:
        query: query text, interpreted with query_string syntax (so operators such as
            ':', '(', '"' are significant — callers strip punctuation beforehand).
        max_hits: number of top hits to return.
        index: index name, or a comma-separated list of index names.

    Returns:
        (num_hits, hits): `hits.total.value` and the list of top hit dicts. Elasticsearch
        only counts total hits up to 10,000 by default, so num_hits is a lower bound
        once it reaches that value.
    """
    # The query goes in the request body rather than the `q=` URL parameter: whole
    # documents are used as queries, and a long URL can exceed the server's HTTP
    # request-line limit (http.max_initial_line_length, 4kb by default).
    search_results = es.search(
        index=index,
        body={'query': {'query_string': {'query': query}}},
        size=max_hits,
    )
    hits = search_results["hits"]
    return hits["total"]["value"], hits["hits"]

# Sanity check with a short, name-only query.
search_results_for_query_by_index(query="gianluca farina", max_hits=3)

(11,
 [{'_index': 'val_100_profile_str',
   '_id': '50001',
   '_score': 18.672169,
   '_source': {'body': {'profile': 'birth_date : 15 december 1962\narticle_title : gianluca farina\nnationality : italy\nname : gianluca farina\n',
     'id': 50001}}},
  {'_index': 'val_100_profile_str',
   '_id': '19619',
   '_score': 9.491162,
   '_ignored': ['body.profile.keyword'],
   '_source': {'body': {'profile': 'plays : right-handed -lrb- one-handed backhand -rrb-\ndoublesrecord : 269 -- 255\nretired : 24 october 2005\nsinglesrecord : 469 -- 370\nhighestdoublesranking : no. 24 -lrb- 21 june 1999 -rrb-\nusopenresult : 4r -lrb- 2002 -rrb-\nheight : 1.72\nwimbledonresult : qf -lrb- 2003 -rrb-\nfrenchopenresult : 4r -lrb- 2001 , 2002 -rrb-\nbirth_date : 27 april 1972\narticle_title : silvia farina elia\nresidence : rome , italy\nturnedpro : 1988\ndoublestitles : 8\nhighestsinglesranking : no. 11 -lrb- 20 may 2002 -rrb-\nname : silvia farina elia\nsinglestitles : 3\naustralianopenresult : 4r -lrb- 

In [189]:
# Biography texts of the whole validation split; position i lines up with `_id` i in the
# val_* indices, which were built from prefixes of the same split.
val_documents = datasets.load_dataset(path='wiki_bio', split='val[:100%]', version='1.2.0')['target_text']

Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


In [190]:
import re

# Use one validation bio (row 50,001) as a query against the profile index.
# -lrb-/-rrb- tokens become parentheses, then every character other than a word char,
# '|' or whitespace is blanked — presumably to keep query_string syntax out of the query.
gianluca_farina = (
    val_documents[50_001]
    .replace('-lrb-', '(')
    .replace('-rrb-', ')')
)
gianluca_farina = re.sub(r'[^\w|\s]', ' ', gianluca_farina)
print(gianluca_farina)

search_results_for_query_by_index(gianluca_farina, max_hits=3)

gianluca farina   born 15 december 1962   is an italian competition rower and olympic champion  
he received a gold medal in    quadruple sculls    at the 1988 summer olympics in seoul   together with agostino abbagnale   davide tizzano   and piero poli  
he received a bronze medal in    quadruple sculls    at the 1992 summer olympics in barcelona  



(10000,
 [{'_index': 'val_100_profile_str',
   '_id': '32538',
   '_score': 69.55131,
   '_ignored': ['body.profile.keyword'],
   '_source': {'body': {'profile': "caption : waff on the intrepid sea on december 7 , 2012 . addresses guests during the pearl harbor day ceremony\nalt : army general of 99th regional support command , u.s. army reserve maj. , addresses guests , during the pearl harbor day ceremony gen. on the intrepid sea , air and space 121207-a-bw524-117 william d. razz waff , commanding\nname : william d. razz waff\nhonorific_prefix : major general\nallegiance : american\nimage : army general of 99th regional support command , u.s. army reserve maj. , addresses guests , during the pearl harbor day ceremony gen. on the intrepid sea , air and space 121207-a-bw524-117 . william jpg d. razz waff , commanding\nrank : major general\nawards : distinguished oak additionally he has received the horatio gates gold medal and leaf the theodore roosevelt medal from the adjutant general

In [191]:
import tqdm

def count_correct_val_predictions(num_total: int, index_name: str, val_index_prefix: str = 'val_'):
    """Measure top-1 retrieval accuracy of validation bios against their own profiles.

    Each of the first `num_total` bios in `val_documents` is cleaned and used as a query
    against `index_name` (a single index or a comma-separated list). A query counts as
    correct when its top hit is document `idx` *from a validation index*.

    Args:
        num_total: number of validation bios to query.
        index_name: index name, or comma-separated index names, to search.
        val_index_prefix: name prefix of the validation indices. Document `_id`s are
            positions within each split (all starting at 0), so an `_id` only identifies
            a validation bio when the hit comes from one of these indices.

    Returns:
        (num_correct, total_num_documents), where total_num_documents is the document
        count `_cat/count` reports for `index_name`.
    """
    import re  # imported here so the function does not rely on an earlier cell's import

    print(num_total, index_name)

    def preprocess_doc(doc: str) -> str:
        # -lrb-/-rrb- -> parentheses, then blank out all but word chars, '|' and whitespace.
        doc = doc.replace('-lrb-', '(').replace('-rrb-', ')')
        return re.sub(r'[^\w|\s]', ' ', doc)

    num_correct = 0
    for idx, raw_doc in enumerate(tqdm.tqdm(val_documents[:num_total])):
        doc = preprocess_doc(raw_doc)
        _, results = search_results_for_query_by_index(doc, max_hits=3, index=index_name)
        if not results:
            # No hit at all is a miss, not an IndexError on results[0].
            continue
        top_result = results[0]
        # When searching e.g. 'val_...,train_...', a train/test doc with _id == idx is a
        # different person; only a hit from a validation index can be the right profile.
        from_val_index = top_result['_index'].startswith(val_index_prefix)
        if from_val_index and int(top_result['_id']) == idx:
            num_correct += 1

    # Keyword arguments: the positional form emitted a client warning on this call.
    count_rows = es.cat.count(index=index_name, format='json')
    total_num_documents = int(count_rows[0]['count'])
    print(f'Correct: {num_correct} / {num_total} \t {total_num_documents}')

    return num_correct, total_num_documents

In [192]:
# List existing indices (system indices included)
aliases_by_index = es.indices.get_alias()
aliases_by_index.keys()

dict_keys(['val_100_profile_str', 'test_100_profile_str', 'val_50_profile_str', 'val_10_profile_str', 'train_100_profile_str', 'val_n1000_profile_str', 'val_5_profile_str', 'val_1_profile_str', '.security-7'])

In [None]:
# Indices to evaluate; comma-separated entries search several indices in one request.
EVAL_INDEX_NAMES = [
    'val_1_profile_str',
    'train_100_profile_str',
    'test_100_profile_str',
    'val_100_profile_str,train_100_profile_str,test_100_profile_str',
    'val_100_profile_str,train_100_profile_str',
    'val_100_profile_str',
    'val_50_profile_str',
    'val_10_profile_str',
    'val_n1000_profile_str',
    'val_5_profile_str',
]

# Accumulate (index, correct, total) rows as we go so partial results survive an interruption.
results = []
for eval_index in EVAL_INDEX_NAMES:
    correct, total = count_correct_val_predictions(200, eval_index)
    results.append((eval_index, correct, total))


200 val_1_profile_str


100%|██████████| 200/200 [00:01<00:00, 160.55it/s]
  total_num_documents = int(es.cat.count(index_name, params={'format': 'json'})[0]['count'])


Correct: 172 / 200 	 728
200 train_100_profile_str


100%|██████████| 200/200 [00:21<00:00,  9.49it/s]


Correct: 0 / 200 	 582659
200 test_100_profile_str


100%|██████████| 200/200 [00:05<00:00, 39.01it/s]


Correct: 0 / 200 	 72831
200 val_100_profile_str,train_100_profile_str,test_100_profile_str


 94%|█████████▍| 189/200 [00:12<00:00, 15.20it/s]

In [None]:
import pandas as pd

# Keep only runs whose index list contains a validation index.
val_results = [(name, correct, total) for name, correct, total in results if 'val_' in name]
df = pd.DataFrame(val_results, columns=['index', 'correct', 'total'])
df.head()

In [None]:
import matplotlib.pyplot as plt

# Number of queries per index; must equal the num_total passed to
# count_correct_val_predictions in the evaluation loop.
NUM_EVAL_QUERIES = 200
df['tests'] = NUM_EVAL_QUERIES
df['percent_correct'] = df['correct'] / df['tests'] * 100

# Index sizes span roughly three orders of magnitude (hundreds to hundreds of thousands
# of profiles), so a log-scaled x axis keeps the small indices from collapsing together.
fig, ax = plt.subplots()
df.sort_values(by='total').plot(
    x='total', y='percent_correct', ax=ax, logx=True, marker='o', legend=False
)
ax.set_xlabel('total number of profiles')
ax.set_ylabel('top-1 accuracy (%)')
ax.set_title('BM-25 accuracy vs dataset size');