In [1]:
import pandas as pd
import numpy as np

## Semantic Search with PubMed

### Setup

In [None]:
'''
National Library of Medicine (NLM) releases every year MEDLINE, a snapshot of all the currently available 
PubMed papers in its library (www.nlm.nih.gov/databases/download/pubmed_medline.html).

The latest snapshot for 2022 includes 33.4M paper abstracts and metadata. A lot of these are not so useful,
so following the approach of https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1.full , all the papers 
with empty abstracts, with unfinished abstracts, and with non-English abstracts, were removed.

This yields 20.6M paper instances. It can be downloaded from: https://zenodo.org/records/7849020

There is a separate file for abstracts and another one for metadata (authors, journal, date...).

'''

pubmed_data = pd.read_csv("./data/pubmed_landscape_data.csv")
pubmed_data

abstracts = pd.read_csv("./data/pubmed_landscape_abstracts.csv")
abstracts


### Datasets

In [None]:
'''
Code for loading the four datasets.
'''

#HealthFC
healthfc_df = pd.read_csv("healthFC_annotated.csv")
healthfc_claims = healthfc_df.en_claim.tolist()
healthfc_labels = healthfc_df.label.tolist()
healthfc_yesno_Claims = healthfc_df[healthfc_df.label != 1].en_claim.tolist() #include only positive/negative claims
healthfc_yesno_labels = healthfc_df[healthfc_df.label != 1].label.tolist()

#SciFact
scifact_df = pd.read_csv("scifact_no-nei_dataset.csv", index_col=[0])
scifact_claims = scifact_df.claim.tolist()
scifact_labels = scifact_df.label.tolist()

#PubMedQA
pubmedqa_df = pd.read_json("pubmedqa.json")
pubmedqa_claims = pubmedqa_df.transpose().QUESTION.tolist()
pubmedqa_labels = pubmedqa_df.transpose().final_decision.tolist()


#CoVERT
covert_df = pd.read_json("CoVERT_FC_annotations.jsonl", lines=True)
covert_claims = covert_df.claim.tolist()
covert_claims = [c.replace("@username", "").replace("\n", " ") for c in covert_claims]

covert_labels = covert_df.label.tolist()
mapper = {'REFUTES':0, 'SUPPORTS':1, 'NOT ENOUGH INFO':2}
covert_labels = [mapper[l] for l in labels]
covert_yesno_indices = np.where(np.array(covert_labels) != 2)[0]
covert_yesno_claims = np.array(covert_yesno_indices)[np.array(covert_labels) != 2]


### Encode and store all the documents

In [None]:
from sentence_transformers import SentenceTransformer, util

'''
While there are many versions of S-BERT to use for sentence embeddings (https://www.sbert.net/docs/pretrained_models.html),
there are not so many available for the biomedical domain.

Some relevant ones I found were:
1) https://huggingface.co/kamalkraj/BioSimCSE-BioLinkBERT-BASE
2) https://huggingface.co/pritamdeka/S-BioBert-snli-multinli-stsb 

I opted for the first one.

I encoded all the abstracts and saved them on my disk. These experiments were started before vector databases became 
a very popular thing, so probably a better idea nowadays would be to store them in a vector DB, instead of a disk.
Still, this also works.

'''
DOCUMENT_EMBEDDING_MODEL = "kamalkraj/BioSimCSE-BioLinkBERT-BASE"
QUERY_EMBEDDING_MODEL = "kamalkraj/BioSimCSE-BioLinkBERT-BASE" #needs to be the same model so the results make sense

sentence_model = SentenceTransformer(DOCUMENT_EMBEDDING_MODEL, device="cuda:0")

        
# Generate embeddings iteratively for each 100k documents and save as "npy" (numpy format).
# This took many hours and I ran it overnight.
for step in range(1,205):
    abstracts_slice = abstracts[step*100000:(step+1)*100000].AbstractText.tolist()
    encoded_slice = sentence_model.encode(abstracts_slice)

    with open("../PubMed/embeddings/step" + str(step) + ".npy",'wb') as f:
        np.save(f, encoded_slice)
        

# Load all the documents and sort. This potentially requires a lot of RAM.
# If using a vector DB, then they can be persistend on a disk.
import glob
npfiles = glob.glob("/home/ubuntu/PubMed/embeddings/*.npy")
npfiles.sort()
npfiles          

In [None]:
import torch
from sentence_transformers import SentenceTransformer, util
import glob

# Load all the PubMed abstracts into one variable.
npfiles = glob.glob("/home/ubuntu/PubMed/embeddings/*.npy")
npfiles.sort()

all_arrays = list()
for npfile in npfiles:
    print(npfile)
    all_arrays.append(np.load(npfile))
    
stacked = np.vstack(all_arrays)
stacked.shape


# Use the same embedding model to embed the query.
sentence_model = SentenceTransformer(QUERY_EMBEDDING_MODEL, device="cpu")

#The query you wish to embed...
query = "Can oregano oil relieve discomfort or disease?"
query_embedding = sentence_model.encode(query, convert_to_tensor=True)

#Load all numpy arrays with embeddings
npfiles = glob.glob("/mnt/mydrive/PubMed/embeddings/*.npy")
npfiles.sort()
all_np_arrays = [np.load(npfile) for npfile in npfiles] 
all_np_arrays = np.array(all_np_arrays)
document_embeddings = all_np_arrays.reshape(-1, 768)

#Calculate all the cosine similarities 
similarity_values = util.cos_sim(query_embedding, document_embeddings)

#PubMed IDs (PMIDs) of the most similar documents.
most_similar_ids = torch.topk(similarity_values, 10).indices[0].tolist()

#Cosine similarity values between the query and most similar documents.
most_similar_values = torch.topk(similarity_values, 10).values[0].tolist()


*Example output*
    
PMIDs: [11668665,
  1898212,
  6998039,
  9409196,
  6434369,
  10998745,
  18133290,
  7897606,
  15190016,
  7321689]
  
  
Similarities: [0.6098220944404602,
  0.5907589197158813,
  0.5672659277915955,
  0.5657244920730591,
  0.5642852187156677,
  0.5642617344856262,
  0.562914252281189,
  0.5624894499778748,
  0.5568264722824097,
  0.5562151074409485]
 


### Find top 10 most similar docs

In [None]:
'''
Now find the top 10 most similar documents for every query from a claim verification / fact-checking dataset.

Let's use SciFact for example (https://aclanthology.org/2020.emnlp-main.609.pdf).

'''

from sentence_transformers import SentenceTransformer, util
import glob
import torch 

claims = scifact_claims

#Load all numpy arrays with embeddings
npfiles = glob.glob("/mnt/mydrive/PubMed/embeddings/*.npy")
npfiles.sort()
all_np_arrays = list()
for npfile in npfiles:
    all_np_arrays.append(np.load(npfile))    
all_np_arrays = np.array(all_np_arrays)

#This will created an array with shape (20M, 768), with all document embeddings. 
document_embeddings = all_np_arrays.reshape(-1, 768)

#Load query embedding model.
sentence_model = SentenceTransformer(QUERY_EMBEDDING_MODEL, device="cpu")

with open("pubmed_scifact_pmids.txt", "w") as f:
    for query in claims:
        query_embedding = sentence_model.encode(query, convert_to_tensor=True)
        
        sims = util.cos_sim(query_embedding, document_embeddings)
        f.write(query)
        f.write("\n")
        f.write(str(torch.topk(sims, 10).values[0].tolist()))
        f.write("\n")
        f.write(str(torch.topk(sims, 10).indices[0].tolist()))
        f.write("\n")
        f.write("\n")

  

### Select the evidence sentences

In [None]:
'''
Once we selected top 10 documents for each claim, next step in the fact-checking pipeline is to find the most relevant
sentences from these abstracts to use as "evidence sentences". These sentences are then concated with the claim and
saved again. This format ("claim [SEP] evidence1 evidence2 ... evidenceN") will then be fed to an NLI model in the
last step, to get a prediction (entailment, contradiction, neutral => supported, refuted, not enough info).

The sentence embedding model used to select evidence sentences is: https://huggingface.co/copenlu/spiced , which was 
shown to perform well on the task of selecting evidence for scientific claims (https://aclanthology.org/2022.emnlp-main.117.pdf).
'''  

import ast
from nltk.tokenize import sent_tokenize, word_tokenize
import torch
from sentence_transformers import SentenceTransformer, util


# Load all claims and document PMIDs from the file.
all_claims = list()
all_ids = list()
with open("scifact_results.txt", "r") as f:
    lines = [line.rstrip() for line in f]
    
    idx = 0
    for line in lines[:10]:
        if idx%4==0:
            claim = line
            all_claims.append(claim)
        elif idx%4==1:
            scores = ast.literal_eval(line)
        elif idx%4==2:
            ids = ast.literal_eval(line)
            all_ids.append(ids)
        
        idx += 1

#Load all PubMed abstracts.
abstracts = pd.read_csv("/mnt/mydrive/PubMed/pubmed_landscape_abstracts.csv")
abstracts_text = abstracts.AbstractText.tolist()
print("loaded abstracts!")


#Load all sentences of all 10 abstracts for each claim into a big list of lists.
claim_sentences = list()
for ids in all_ids:
    all_sentences = list()
    
    for doc_id in ids:
        abstract = abstracts_text[doc_id]
        sentences = sent_tokenize(abstract)
        all_sentences.extend(sentences)
        all_sentences = [s.lower() for s in all_sentences]   
    claim_sentences.append(all_sentences)
#print("collected all sentences from abstracts!")   

#Load sentence transformer for selecting evidence sentences.
SENTENCE_EMBEDDING_MODEL = 'copenlu/spiced'
model = SentenceTransformer(SENTENCE_EMBEDDING_MODEL)
#print("loaded sentence model!")


#Find top 10 sentences for each claim (top 10 performed well, but top 5 could maybe bring less noise)
top_sentences = list()
for idx in range(len(all_claims)):
    claim = all_claims[idx]
    sents = claim_sentences[idx]
    
    sents_embeddings = model.encode(sents, convert_to_tensor=True)
    claim_embedding = model.encode(claim, convert_to_tensor=True)
    cos_scores = util.cos_sim(claim_embedding, sents_embeddings)[0]
    top_results = torch.topk(cos_scores, k=10)
    
    np_results = top_results[1].detach().cpu().numpy()
    top_sentences.append(np_results)


selected_sentences = list()
for idx in range(len(all_claims)):
    top = top_sentences[idx]
    top = np.sort(top)
    sents = np.array(claim_sentences[idx])[top]    

    selected_sentences.append(sents)
  
 # Create a joint list of concatenated claims and evidence, in form of "claim [SEP] evidence1 evidence2 ... evidenceN"
joint_list = list()
for idx in range(len(all_claims)):
    joint = all_claims[idx] + " [SEP] "
    for s in selected_sentences[idx]:
        joint += s
        joint += " "
    joint_list.append(joint)

    
#Save this in a file before the final step.
with open("scifact_joint_lines.txt", "w") as f:
	for example in joint_list:
		f.write(example)
		f.write("\n")



### Predict the veracity labels

In [None]:
'''
This is the final step, where the claim-evidence pairs get a verdict prediction from an NLI model.
These predictions are compared to the gold labels from manual annotators and we get some F1 metrics.

The NLI model used is DeBERTa-v3, fine-tuned on some popular NLI datasets:
https://huggingface.co/MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli
This is probably the best encoder-only model for NLI tasks (see the GLUE leaderboard for some others).

'''

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader

NLI_MODEL = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
   
tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL, model_max_length=1024)
model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)

#The dataset class, consiting of encoding of the joint line and its gold label.
class CtDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


#Creates a NumPy array with results predictions.
def get_result(joint_list, indices, model, tokenizer):
    nli_test = joint_list
    nli_encoded = tokenizer(nli_test, return_tensors='pt',
                             truncation_strategy='only_first', add_special_tokens=True, padding=True)
    nli_dataset = CtDataset(nli_encoded, np.zeros(len(nli_test)))

    test_loader = DataLoader(nli_dataset, batch_size=16,
                             drop_last=False, shuffle=False, num_workers=4)

    model.eval()
    model = model.to("cuda")
    
    result = np.zeros(len(test_loader.dataset))    
    index = 0

    with torch.no_grad():
        for batch_num, instances in enumerate(test_loader):
            input_ids = instances["input_ids"].to("cuda")
            attention_mask = instances["attention_mask"].to("cuda")
            logits = model(input_ids=input_ids,
                                          attention_mask=attention_mask)[0]
            probs = logits.softmax(dim=1)

            #If the entailment score was bigger than contradiction score, predict "SUPPORTED" (positive). 
            #Otherwise, predict "REFUTED" (negative).
            pred = probs[:,0] > probs[:,2]
            pred = np.array(pred.cpu()).astype(int)

            result[index : index + pred.shape[0]] = pred.flatten()
            index += pred.shape[0]

    return result


def print_scores(actual_values, predicted_values):
    # Calculate precision
    precision = precision_score(actual_values, predicted_values, average = "binary")

    # Calculate recall
    recall = recall_score(actual_values, predicted_values, average = "binary")

    # Calculate F1 score
    f1 = f1_score(actual_values, predicted_values, average = "binary")

    # Calculate accuracy
    accuracy = accuracy_score(actual_values, predicted_values)

    # Print the results
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)
    print("Accuracy:", accuracy)

#Prediction
prediction_result = get_result(scifact_claims, model, tokenizer)

#Predict scores
print_scores(scifact_labels, prediction_result)

*Results for SciFact*

Precision: 0.7373737373737373

Recall: 0.8004385964912281

F1 Score: 0.7676130389064143

Accuracy: 0.6810966810966811

In [None]:
'''
That's it!

The results would probably be better if there was some more optimization (like selecting different number of top k 
documents or sentences, using a re-ranker, etc.), and more fine-tuning of models on some other fact-checking datasets,
or joint training of evidence selection & verdict prediction (with a shared loss function). This is left for future work.

But for a zero-shot, out-of-the-box solution, this is a nice pipeline and performance close to gold-evidence settings.

'''

## BM25 Search with PubMed

In [None]:
'''
Other than semantic search, the process can also be done with the classic BM25 search.

There is a library called 'retriv' that provides a simple function to build an inverted index out of your document corpus.
It took almost two hours to construct the index, but after that it can be pickled and used easily. 
Batch search of 1000 claims at once takes milliseconds over the index.

'''

from retriv import SparseRetriever

#Create the SparseRetriever object that will be used for BM25 search.
sr = SparseRetriever(
  index_name="pubmed-index",
  model="bm25",
  min_df=10,
  tokenizer="whitespace",bm25
    
  stemmer="english",
  stopwords="english",
  do_lowercasing=True,
  do_ampersand_normalization=True,
  do_special_chars_normalization=True,
  do_acronyms_normalization=True,
  do_punctuation_removal=True,
)

corpus_path = "/mnt/mydrive/PubMed/pubmed_landscape_abstracts.csv"


#Construct the inverted index.
import time
start = time.time()

sr = sr.index_file(
  path=corpus_path,  # File kind is automatically inferred
  show_progress=True,         # Default value
  callback=lambda doc: {      # Callback defaults to None.
    "id": doc["PMID"],
    "text": doc["AbstractText"],          
    }
  )

duration = time.time() - start
print(duration)
#Duration: 5772.615013837814


#Pickle the file.
import pickle
file = open('/mnt/mydrive/pickled_sr', 'wb')
pickle.dump(sr, file)


In [None]:
#Search for a single query
sr.search(
  query="colorectal cancer high fiber diet",    # What to search for        
  return_docs=True,          # Default value, return the text of the documents
  cutoff=10,                # Default value, number of results to return
)


#Batch search for the whole dataset
query_list = list()

idx = 0
for c in claims:
    c = c.lower()
    d = dict()
    d["id"] = str(idx)
    d["text"] = c
    query_list.append(d)
    idx += 1

results = sr.msearch(
  queries=query_list,
  cutoff=10,
)
print(results)

#Print all the results
with open("bm25_scifact_pmids.txt", "w") as f:
    f.write(str(results))

#### Retrieval part changed, the rest of the pipeline (evidence selection and verdict prediction) is the same as before




Top 10 sentences for each claim are selected and then the verdict predicted with the NLI model.

*Results for SciFact with BM25 were*

Precision: 0.7995169082125604
    
Recall: 0.7258771929824561
    
F1 Score: 0.7609195402298851
    
Accuracy: 0.6998556998556998
    
The overall F1 score is a bit lower, but precision is a lot better. On the other hand, semantic search wins in recall.

Kind of intuitively expected, but still a nice result! :)

#### For Wikipedia, the same process of semantic search and BM25 search is followed, but using the enwiki dump from the Wikimedia website.

https://dumps.wikimedia.org/backup-index.html

## Google Search

In [None]:
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from typing import Final

import string 
import time 


'''
Google Search done using the Google Custom Search API (https://developers.google.com/custom-search/v1/overview).

The search is queried using each of the claims. Top 10 results are retrieved and their preview snippets taken as "evidence sentences".
The final lines will be in form "claim [SEP] evidence1 evidence2 ... evidenceN", i.e., "claim [SEP] snippet1 snippet2 ... snippetN"

These evidence sentences are then concatenated with the claim and verified with the same fact-checking (NLI) workflow as before.
'''

GOOGLE_API_KEY = "YOUR_API_KEY"
GOOGLE_CSE_ID = "YOUR_CSE_ID"


search_results = {"engine": "Google", "results": []}
    
# Query parameters list: https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
def search(search_term: str, **kwargs):
    service = build("customsearch", "v1", developerKey=GOOGLE_API_KEY)
    default_options = {
        "c2coff": "1",
        "fields": "items(title,formattedUrl)",
    }
    default_options.update(kwargs)
    try:
        response = (
            service.cse()
            .list(q=search_term, cx=GOOGLE_CSE_ID)
            .execute()
        )
        """    
        pages_list = response["items"]
        if pages_list:
            formated_response = [
                {"title": page["title"], "url": page["formattedUrl"]}
                for page in pages_list
            ]
        search_results = {"results": formated_response, "pages_list":pages_list}
        """ 
        return response
    except HttpError as e:
        print(e)


idx = 0
all_links = list()
all_titles = list()
all_snippets = list()

#Load the claim of your dataset to the 'claims' variable.
for claim in claims:
    time.sleep(0.1)
    google_result_list = search(claim)

    links = list()
    titles = list()
    snippets = list()
    
    if "items" in google_result_list:
        for item in google_result_list["items"]:
            links.append(item["link"])
            titles.append(item["title"])
            snippets.append(item["snippet"])

    all_links.append(links)
    all_titles.append(titles)
    all_snippets.append(snippets)
    idx += 1


with open("scifact_google_results.txt", "a") as f:
    for idx in range(len(all_links)):
        ls = all_links[idx]
        ts = all_titles[idx]
        ss = all_snippets[idx]
    
        if idx != 0:
            f.write("CLAIM")
            f.write("\n")
        f.write(all_claims[idx])
        f.write("\n")
        f.write("EVIDENCE")
        f.write("\n")
        
        if len(ts)==0:
            f.write("No results found!")
            f.write("\n")
            f.write("\n")
            continue
        
        for i in range(len(ts)):
            f.write(ts[i])
            f.write("\n")
            f.write(ls[i])
            f.write("\n")
            f.write(ss[i])
            f.write("\n")
        
        f.write("\n")
            
                