In [1]:
import json
from elasticsearch import Elasticsearch
# Explicit URL: newer elasticsearch-py clients no longer default to localhost:9200.
ES_URL = "http://localhost:9200"
es = Elasticsearch(ES_URL)

In [2]:
# Sanity check: confirm the Elasticsearch cluster answers before querying it.
cluster_info = es.info()
cluster_info



{'name': 'es01',
 'cluster_name': 'es-docker-cluster',
 'cluster_uuid': 'esmw9d6xTkay_4OD0qAXAA',
 'version': {'number': '7.14.0',
  'build_flavor': 'default',
  'build_type': 'docker',
  'build_hash': 'dd5a0a2acaa2045ff9624f3729fc8a6f40835aa1',
  'build_date': '2021-07-29T20:49:32.864135063Z',
  'build_snapshot': False,
  'lucene_version': '8.9.0',
  'minimum_wire_compatibility_version': '6.8.0',
  'minimum_index_compatibility_version': '6.0.0-beta1'},
 'tagline': 'You Know, for Search'}

In [3]:
# Check that the documents loaded in milestone 1 are still in the index.
# The count API returns the exact number directly; a match_all search would
# also fetch hits and its hits.total is capped at 10,000 by default.
doc_count = es.count(index="cdc")["count"]

print("number of documents in index = ", doc_count)

number of documents in index =  401


In [16]:
# Full-text retrieval helper: return the top matching sections for a query.
def find_documents(query, top=50, index="cdc", client=None):
    """Run a full-text `match` query on the `text` field and return the hits.

    Args:
        query: Free-text query string.
        top: Maximum number of hits to return (sent as the search `size`).
        index: Elasticsearch index to search.
        client: Elasticsearch client to use; defaults to the notebook-level `es`.

    Returns:
        A list of the `_source` dicts of the hits, in the order Elasticsearch
        returned them.
    """
    # Resolve `es` only when no client is injected.
    client = es if client is None else client
    response = client.search(
        index=index,
        body={"query": {"match": {"text": query}}},
        # Was hard-coded to 50, silently ignoring `top`.
        size=top,
    )
    return [hit['_source'] for hit in response['hits']['hits']]


# Smoke test: show only the first few hits; each one carries a full section text.
find_documents("Pandemic")[:3]

[{'section_title': 'See also',
  'main_section': 'See also',
  'article_title': 'Pandemic severity index',
  'source_url': 'https://en.wikipedia.org/wiki/Pandemic_severity_index',
  'page_id': 9291245,
  'tags': 'Pandemic severity index,See also',
  'section_number': 5},
 {'section_title': 'Stages',
  'text': 'Pandemic,Assessment,Stages\nThe World Health Organization (WHO) previously applied a six-stage classification to describe the process by which a novel influenza virus moves from the first few infections in humans through to a pandemic. It starts when mostly animals are infected with a virus and a few cases where animals infect people, then moves to the stage where the virus begins to be transmitted directly between people and ends with the stage when infections in humans from the virus have spread worldwide. In February 2020, a WHO spokesperson clarified that "there is no official category [for a pandemic]".\n\nIn a virtual press conference in May 2009 on the influenza pandemic, 

In [32]:
# Load the evaluation questions (a JSON list of question strings).
# JSON is UTF-8 by spec; don't rely on the platform's locale encoding.
with open("../questions.json", encoding="utf-8") as f:
    questions = json.load(f)

# questions are pretty short
questions

['How many people have died during Black Death?',
 'Which diseases can be transmitted by animals?',
 'Connection between climate change and a likelihood of a pandemic',
 'What is an example of a latent virus',
 'Viruses in nanotechnology',
 'Giant viruses classification',
 'What are the notable pandemic prevention organizations?',
 'How many leprosy outbreaks are known to happen?',
 'What are the geographic areas with the highest transmission of malaria?',
 'How to prevent the spread of viral infections?']

In [40]:
# Load the sentence-embedding model that `search` uses to re-rank hits.
from sentence_transformers import SentenceTransformer, util

# Alternative worth comparing: 'msmarco-distilbert-base-tas-b'.
MODEL_NAME = 'sentence-transformers/distilbert-base-nli-stsb-mean-tokens'
model = SentenceTransformer(MODEL_NAME)

# Smoke test: encode() should return one embedding row per input sentence.
sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
assert embeddings.shape[0] == len(sentences)
print(embeddings.shape)


(2, 768)


In [43]:
def search(question, k=10):
    """Re-rank Elasticsearch hits for `question` by embedding similarity.

    Retrieves candidate sections with `find_documents`, embeds their text and
    the question with the notebook-level `model`, and ranks them with
    `util.semantic_search` (cosine similarity by default).

    Args:
        question: Free-text query.
        k: Maximum number of ranked results to return.

    Returns:
        A list of at most `k` dicts with keys 'score', 'article_title' and
        'section_title', ordered by descending score. Empty if no hits.
    """
    # Skip hits without text to embed; defensive, since the index mapping is
    # not visible here (NOTE(review): confirm every section has 'text').
    documents = [doc for doc in find_documents(question) if doc.get('text')]
    if not documents:
        return []

    doc_embeddings = model.encode([doc['text'] for doc in documents])
    question_embedding = model.encode([question])

    # top_k must be passed by keyword: the third positional parameter of
    # semantic_search is query_chunk_size, so a positional `k` never limited
    # the number of results.
    hits = util.semantic_search(question_embedding, doc_embeddings, top_k=k)[0]
    return [
        {
            'score': hit['score'],
            'article_title': documents[hit['corpus_id']]['article_title'],
            'section_title': documents[hit['corpus_id']]['section_title'],
        }
        for hit in hits
    ]

# Run every evaluation question through the re-ranker and pair it with its results.
responses = [
    {'question': question, 'results': search(question)}
    for question in questions
]
    
    

In [45]:
# Peek at the response format using only the first entry; the full list is long.
print(json.dumps(responses[:1], indent=2))

# Write all responses to responses.json.
with open("responses.json", "w", encoding="utf-8") as f:
    json.dump(responses, f)


[{'question': 'How many people have died during Black Death?', 'results': [{'score': 0.3272359371185303, 'article_title': 'Pandemic', 'section_title': 'Tuberculosis'}, {'score': 0.26712003350257874, 'article_title': 'HIV/AIDS', 'section_title': 'Stigma'}, {'score': 0.26525598764419556, 'article_title': 'Virus', 'section_title': 'Epidemics and pandemics'}, {'score': 0.23919685184955597, 'article_title': 'HIV/AIDS', 'section_title': 'Epidemiology'}, {'score': 0.2061607837677002, 'article_title': 'Pandemic', 'section_title': 'Notable outbreaks'}, {'score': 0.20557813346385956, 'article_title': 'Pandemic severity index', 'section_title': 'Guidelines'}, {'score': 0.20442703366279602, 'article_title': 'Pandemic', 'section_title': 'Summary'}, {'score': 0.19627705216407776, 'article_title': '1929–1930 psittacosis pandemic', 'section_title': 'Origin and global spread'}, {'score': 0.19061675667762756, 'article_title': 'Epidemiology of HIV/AIDS', 'section_title': 'Caribbean'}, {'score': 0.1871424