## Load dataset

In [53]:
from datasets import load_dataset

qrels = load_dataset("BeIR/trec-covid-qrels")
corpus = load_dataset("BeIR/trec-covid", "corpus")
queries = load_dataset("BeIR/trec-covid", "queries")

Downloading readme:   0%|          | 0.00/14.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/981k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/66336 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/113M [00:00<?, ?B/s]

Generating corpus split:   0%|          | 0/171332 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/4.46k [00:00<?, ?B/s]

Generating queries split:   0%|          | 0/50 [00:00<?, ? examples/s]

## Examine the corpus

In [54]:
print(f"The document dataset structure is: {corpus}")
print(f"The first few IDs are: {corpus['corpus']['_id'][0:5]}")

first_documents = corpus['corpus']['text'][0:5]
for i, document in enumerate(first_documents):
    print(f"Document {i} is: {document}")

The document dataset structure is: DatasetDict({
    corpus: Dataset({
        features: ['_id', 'title', 'text'],
        num_rows: 171332
    })
})
The first few IDs are: ['ug7v899j', '02tnwd4m', 'ejv2xln0', '2b73a28n', '9785vg6d']
Document 0 is: OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than

## Examine queries

In [55]:
print(f"The query dataset structure is: {queries}")
print(f"The first few IDs are: {queries['queries']['_id'][0:5]}")

first_queries = queries['queries']['text'][0:5]
for i, query in enumerate(first_queries):
    print(f"Query {i} is: {query}")

The query dataset structure is: DatasetDict({
    queries: Dataset({
        features: ['_id', 'title', 'text'],
        num_rows: 50
    })
})
The first few IDs are: ['1', '2', '3', '4', '5']
Query 0 is: what is the origin of COVID-19
Query 1 is: how does the coronavirus respond to changes in the weather
Query 2 is: will SARS-CoV2 infected people develop immunity? Is cross protection possible?
Query 3 is: what causes death from Covid-19?
Query 4 is: what drugs have been active against SARS-CoV or SARS-CoV-2 in animal studies?


## Examine reference judgements

In [57]:
print(f"The qrels dataset structure is: {qrels}")
print(f"The query-ids look like: {qrels['test']['query-id'][0:5]}")
print(f"The corpus-ids look like: {qrels['test']['corpus-id'][0:5]}")
print(f"The total number of queries is {len(qrels['test']['query-id'])}. The number of unique queries is {len(set(qrels['test']['query-id']))}.") # Use dedup trick of converting list to set

The qrels dataset structure is: DatasetDict({
    test: Dataset({
        features: ['query-id', 'corpus-id', 'score'],
        num_rows: 66336
    })
})
The query-ids look like: [1, 1, 1, 1, 1]
The corpus-ids look like: ['005b2j4b', '00fmeepz', 'g7dhmyyo', '0194oljo', '021q9884']
The total number of queries is 66336. The number of unique queries is 50.


## Create search database
Elastic search installation is easy. You can follow the steps here: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/brew.html. If you get Java errors you may need to cleanup your Java installation, but it is farily easy. Ask a language model for assistance. I needed to run: <br>
`brew link --force --overwrite openjdk@11` <br>
`brew reinstall openjdk@11` <br>
`export JAVA_HOME="/opt/homebrew/opt/openjdk@11/libexec/openjdk.jdk/Contents/Home" >> ~/.zshrc` <br>
`export PATH="$JAVA_HOME/bin:$PATH" >> ~/.zshrc` <br>
`source ~/.zshrc` <br>

Then you can simply use: <br>
Start Elasticsearch: `brew services start elastic/tap/elasticsearch-full` <br>
Check status: `brew services list` <br>
Stop Elasticsearch: `brew services stop elastic/tap/elasticsearch-full` <br>
To stop Elasticsearch: press Ctrl+C in the terminal where Elasticsearch is running. <br>
Or... <br>
Start manually: `/opt/homebrew/opt/elasticsearch-full/bin/elasticsearch` # add & at end to run in background <br>
To stop it from runnning manually: `pkill -f elasticsearch` <br>
To check if it is running: `curl http://localhost:9200` <br>

In [151]:
from elasticsearch import Elasticsearch, helpers
import random

# Connect to your local Elasticsearch instance
es = Elasticsearch(["http://localhost:9200"])

# Check if the connection was successful
if es.ping():
    print("Connected to Elasticsearch")
else:
    print("Could not connect to Elasticsearch")

def get_random_doc_indicies(document_list_length, range_size):
    return random.sample(range(document_list_length), range_size)

def get_document_subset(full_document_list, document_indicies):
    return [full_document_list[index] for index in document_indicies]

def generate_index(documents, index_name):
    for doc in documents:
        doc_id = doc.pop("_id")
        yield {
            "_index": index_name,
            "_id": doc_id,
            "_source": doc
        }

def bulk_index_documents(es_client, documents, index_name):
    try:
        success, _ = helpers.bulk(es_client, generate_index(documents, index_name))
        print(f"Successfully indexed {success} documents")
    except Exception as e:
        print(f"Encountered error: {e}")

# Delete indicies to start fresh
for index in es.indices.get_alias(index="*"):
    es.indices.delete(index=index)

mappings = {
    "properties": {
        "title": {
            "type": "text"
        },
        "text": {
            "type": "text"
        }
    }
}

random.seed(11)
index_size = 100 # start small
document_list_length = len(corpus['corpus'])
doc_indicies = get_random_doc_indicies(document_list_length, index_size)
documents = get_document_subset(corpus['corpus'], doc_indicies)
es.indices.create(index="path_index", mappings=mappings)
bulk_index_documents(es, documents, "path_index")

  if es.ping():
  for index in es.indices.get_alias(index="*"):
  es.indices.create(index="path_index", mappings=mappings)


Connected to Elasticsearch
Successfully indexed 100 documents
There were [] errors


  success, error = helpers.bulk(es_client, generate_index(documents, index_name))


# Run a test query
We can now run a test query. Here, I just made up a simple query since this dataset contains information on COVID-19. Note that the Elasticsearch documnetation recommends creating a pretty-print function as the search results response is a hard-to-read JSON object. See quickstart for an example: https://github.com/elastic/elasticsearch-labs/blob/main/notebooks/search/00-quick-start.ipynb. Elasticsearch has numerous search times, but BM25 is the default so there are no additional fields required.

In [156]:
# The documentation recommends createing a pretty print function
# See quickstart for an example: https://github.com/elastic/elasticsearch-labs/blob/main/notebooks/search/00-quick-start.ipynb
def pretty_print(results):
    if len(results["hits"]["hits"]) == 0:
        print(f"Serach returned no results.")
    else:
        for hit in results["hits"]["hits"]:
            id = hit["_id"]
            title = hit["_source"]["title"]
            text = hit["_source"]["text"]
            score = hit["_score"]
            pretty_output = f"\nID: {id}\n Title:{title}\n Text: {text}\n score: {score}"
            print(pretty_output)
            print("------")

query_text = "What is the best defense against COVID?"

# No need to specify BM25 as that is the default for elastic search
query = {
    "match": {
        "text": query_text
    }
}

num_search_results = 10
results = es.search(index="path_index", query=query, size=num_search_results)
pretty_print(results) # It works!


ID: yybunc6z
 Title:An in-silico evaluation of different Saikosaponins for their potency against SARS-CoV-2 using NSP15 and fusion spike glycoprotein as targets
 Text: The Public Health Emergency of International Concern declared the widespread outbreak of SARS-CoV-2 as a global pandemic emergency, which has resulted in 1,773,086 confirmed cases including 111,652 human deaths, as on 13 April 2020, as reported to World Health Organization. As of now, there are no vaccines or antiviral drugs declared to be officially useful against the infection. Saikosaponin is a group of oleanane derivatives reported in Chinese medicinal plants and are described for their anti-viral, anti-tumor, anti-inflammatory, anticonvulsant, antinephritis and hepatoprotective activities. They have also been known to have anti-coronaviral property by interfering the early stage of viral replication including absorption and penetration of the virus. Thus, the present study was undertaken to screen and evaluate the 

  results = es.search(index="path_index", query=query, size=num_search_results)


# Basic integration with language model
We can now test having a language model look at a document in the corpus, create a synthetic query based on that document, search for documents using that query, and then review the results.

In [169]:
import os
import json
from openai import OpenAI

# Set your API key
client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

index_size = 1
document_list_length = len(corpus['corpus'])
doc_index = get_random_doc_indicies(document_list_length, index_size)
document = get_document_subset(corpus['corpus'], doc_index)
document_text = document[0]["text"]
print(f"Document text is: {document_text}")

prompt = f"""
Please generate a query for which the following text would be returned:
{document_text}

Return the associated query in a JSON:
{{
    "query": {{query}}
}}
"""

response = client.chat.completions.create(
    model="gpt-4-turbo",
    response_format={ "type": "json_object" },
    messages=[
        {"role": "user", "content": prompt}
    ]
)

gpt_response = json.loads(response.choices[0].message.content)
print(gpt_response)

query = {
    "match": {
        "text": gpt_response["query"]
    }
}

num_search_results = 10
results = es.search(index="path_index", query=query, size=num_search_results)
pretty_print(results) # It works!

Document text is: BACKGROUND: Ultrasonography allows for non-invasive examination of the liver and spleen and can further our understanding of schistosomiasis morbidity. METHODOLOGY/PRINCIPAL FINDINGS: We followed 578 people in Southwest China for up to five years. Participants were tested for Schistosoma japonicum infection in stool and seven standard measures of the liver and spleen were obtained using ultrasound to evaluate the relationship between schistosomiasis infection and ultrasound-detectable pathology, and the impact of targeted treatment on morbidity. Parenchymal fibrosis, a network pattern of the liver unique to S. japonicum, was associated with infection at the time of ultrasound (OR 1.40, 95% CI: 1.03–1.90) and infection intensity (test for trend, p = 0.002), adjusting for age, sex and year, and more strongly associated with prior infection status and intensity (adjusted OR 1.84, 95% CI: 1.30–2.60; test for trend: p<0.001 respectively), despite prompt treatment of infect

  results = es.search(index="path_index", query=query, size=num_search_results)


# Test with DistilBERT

In [176]:
import torch
from transformers import DistilBertTokenizer, DistilBertModel
from torch import nn

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Examine tokenization on sample document
print(tokenizer.tokenize(document_text)) # Use document text from previous step

# Simple check for any unknown tokens
tokens = tokenizer.tokenize(document_text)
ids = tokenizer.convert_tokens_to_ids(tokens)
reconstructed  = tokenizer.convert_ids_to_tokens(ids)
print(reconstructed)

unknown_token_count = reconstructed.count("UNK")
if unknown_token_count  == 0:
    print("All document tokens in vocabulary")
else:
    print(f"{unknown_token_count} unknown tokens in document")


['background', ':', 'ultra', '##son', '##ography', 'allows', 'for', 'non', '-', 'invasive', 'examination', 'of', 'the', 'liver', 'and', 'sp', '##leen', 'and', 'can', 'further', 'our', 'understanding', 'of', 'sc', '##his', '##tos', '##omi', '##asi', '##s', 'mor', '##bid', '##ity', '.', 'methodology', '/', 'principal', 'findings', ':', 'we', 'followed', '57', '##8', 'people', 'in', 'southwest', 'china', 'for', 'up', 'to', 'five', 'years', '.', 'participants', 'were', 'tested', 'for', 'sc', '##his', '##tos', '##oma', 'ja', '##pon', '##icum', 'infection', 'in', 'stool', 'and', 'seven', 'standard', 'measures', 'of', 'the', 'liver', 'and', 'sp', '##leen', 'were', 'obtained', 'using', 'ultrasound', 'to', 'evaluate', 'the', 'relationship', 'between', 'sc', '##his', '##tos', '##omi', '##asi', '##s', 'infection', 'and', 'ultrasound', '-', 'detect', '##able', 'pathology', ',', 'and', 'the', 'impact', 'of', 'targeted', 'treatment', 'on', 'mor', '##bid', '##ity', '.', 'par', '##en', '##chy', '##mal

In [199]:
class DistilBERTReranker(nn.Module):
    def __init__(self):
        super(DistilBERTReranker, self).__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.classifer = nn.Linear(self.distilbert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.distilbert(input_ids = input_ids, attention_mask=attention_mask)
        pooled_output = outputs[0][:, 0]
        return self.classifer(pooled_output)

def prepare_input(query, document):
    return tokenizer.encode_plus(
        query,
        document,
        add_special_tokens = True,
        max_length = 512,
        truncation = True,
        padding = 'max_length',
        return_tensors = 'pt'
    )

# Test class methods
query_text = "What is the best defense against COVID?" # Same as before
prepared_input = prepare_input(query_text, document_text)
input_ids, attention_mask = prepared_input['input_ids'], prepared_input['attention_mask']

reranker = DistilBERTReranker()
reranking_score = reranker.forward(input_ids, attention_mask)
reranking_score.item()


# Test reranking results

In [217]:
results = es.search(index="path_index", query=query, size=num_search_results)

unranked_results = {}

for i, hit in enumerate(results["hits"]["hits"], start=1):
    document_text = hit["_source"]["text"]
    with torch.no_grad():
        prepared_input = prepare_input(query_text, document_text)
        input_ids, attention_mask = prepared_input['input_ids'], prepared_input['attention_mask']
        reranking_score = reranker.forward(input_ids, attention_mask)
        unranked_results[i] = {"score": reranking_score, "document_text": document_text}

print(unranked_results)
reranked_results = sorted(unranked_results.items(), key=lambda item: item[1]["score"].item(), reverse = True)

  results = es.search(index="path_index", query=query, size=num_search_results)
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


{1: {'score': tensor([[-0.1967]]), 'document_text': 'Seoul virus (SEOV) is a zoonotic orthohantavirus carried by black and brown rats, and can cause hemorrhagic fever with renal syndrome in humans. Human cases of SEOV virus infection have most recently been reported in the USA, United Kingdom, France and the Netherlands and were primarily associated with contact with pet rats and feeder rats. Infection of rats results in an asymptomatic but persistent infection. Little is known about the cell tropism of SEOV in its reservoir and most available data is based on experimental infection studies in which rats were inoculated via a route which does not recapitulate virus transmission in nature. Here we report the histopathological analysis of SEOV cell tropism in key target organs following natural infection of a cohort of feeder rats, comprising 19 adults and 11 juveniles. All adult rats in this study were positive for SEOV specific antibodies and viral RNA in their tissues. One juvenile ra

# Test constructing hard negatives

In [227]:
import os
import json
from openai import OpenAI

# Set your API key
client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

index_size = 1
document_list_length = len(corpus['corpus'])
doc_index = get_random_doc_indicies(document_list_length, index_size)
positive_document = get_document_subset(corpus['corpus'], doc_index)
positive_document_text = positive_document[0]["text"]
positive_document_id = positive_document[0]["_id"]
print(f"Document text is: {document_text}")

prompt = f"""
Please generate a query for which the following text would be returned:
{document_text}

Return the associated query in a JSON:
{{
    "query": {{query}}
}}

Your response:

{{"query": }}
"""

response = client.chat.completions.create(
    model="gpt-4-turbo",
    response_format={ "type": "json_object"},
    messages=[
        {"role": "user", "content": prompt}
    ]
)

gpt_response = json.loads(response.choices[0].message.content)
print(f"GPT-4o reponse is: {gpt_response}")
query_text = gpt_response["query"]

query = {
    "match": {
        "text": query_text
    }
}

num_search_results = 20
hard_negative_start = 5
num_hard_negatives = 3
results = es.search(index="path_index", query=query, size=num_search_results)

tuple = {"query": query_text, "positive_document": document_text}
for i, hit in enumerate(results["hits"]["hits"][hard_negative_start:(hard_negative_start + num_hard_negatives)], start=1):
    if hit["_id"] == positive_document_id:
        continue
    else:
        tuple[f"hard_negative_{i}"] = hit["_source"]["text"]

tuple

Document text is: OBJECTIVE: We are presently going through a historic and unprecedented crisis for humanity with SARS-CoV-2 causing immense damage to life and world economics. It has been 3 months, since we had the first cluster in China and we felt the need to look into certain regional patterns of transmission of the virus with respect to some distinctive living conditions, incidence of malaria, the genomics of different strains, and its impact on severity. MATERIAL AND METHODS: Data for 107 countries was compiled and correlation analysis was done between incidence of malaria and number of SARS-CoV-2 cases. Possibility of genetic similarity between SARS-CoV-2 and reported zoonotic RNA viruses found associated previously with some Plasmodium species was explored by utilizing NCBI database. RESULTS: We found a significant inverse correlation between SARS-CoV-2 disease burden and incidence of Malaria. Our analysis also showed that a 12 base pair region encoding a part of surface glycop

  results = es.search(index="path_index", query=query, size=num_search_results)


{'query': 'What are the findings of a study exploring the correlation between malaria incidence and SARS-CoV-2 cases, and the potential genetic similarities between SARS-CoV-2 and zoonotic RNA viruses related to Plasmodium species?',
 'positive_document': 'OBJECTIVE: We are presently going through a historic and unprecedented crisis for humanity with SARS-CoV-2 causing immense damage to life and world economics. It has been 3 months, since we had the first cluster in China and we felt the need to look into certain regional patterns of transmission of the virus with respect to some distinctive living conditions, incidence of malaria, the genomics of different strains, and its impact on severity. MATERIAL AND METHODS: Data for 107 countries was compiled and correlation analysis was done between incidence of malaria and number of SARS-CoV-2 cases. Possibility of genetic similarity between SARS-CoV-2 and reported zoonotic RNA viruses found associated previously with some Plasmodium species