In [1]:
import json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

In [2]:
with open('data.json', 'r') as file:
    documents = json.load(file)

sentences = [document['text'] for document in documents ]
sentences

['A pandemic (from Greek πᾶν, pan, "all" and δῆμος, demos, "people") is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people. A widespread endemic disease with a stable number of infected people is not a pandemic. Widespread endemic diseases with a stable number of infected people such as recurrences of seasonal influenza are generally excluded as they occur simultaneously in large regions of the globe rather than being spread worldwide.\nThroughout human history, there have been a number of pandemics of diseases such as smallpox and tuberculosis. The most fatal pandemic in recorded history was the Black Death (also known as The Plague), which killed an estimated 75–200 million people in the 14th century. The term was not used yet but was for later pandemics including the 1918 influenza pandemic (Spanish flu). Current pandemics include COVID-19 (SARS-CoV-2) and HIV/AIDS.',
 'H

In [11]:
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
embeddings = model.encode(sentences)

vector_size=embeddings.shape[1]
vector_size


384

In [19]:
def encode(text):
    return model.encode(text)

encode('test tset test')

array([-5.07164121e-01,  1.28933340e-01,  2.17811242e-01, -3.82991314e-01,
       -7.57071003e-02, -3.25809211e-01,  8.33333790e-01,  7.28661776e-01,
       -4.27888751e-01, -5.99833310e-01,  2.54839271e-01, -5.28913558e-01,
       -1.66156009e-01, -1.53139159e-01, -3.90681267e-01, -8.19487572e-01,
       -1.33167639e-01, -6.77356541e-01,  5.05305111e-01, -3.60060692e-01,
       -9.37046930e-02, -2.39226401e-01, -1.68991864e-01, -2.23630741e-01,
       -3.15550804e-01, -3.20087940e-01,  5.95741749e-01,  3.26300949e-01,
       -3.09168398e-02,  1.17323458e-01, -9.61494073e-02,  5.28547287e-01,
       -1.03682852e+00, -2.67857403e-01,  1.72391489e-01, -4.38796759e-01,
       -9.62218046e-02, -6.51251137e-01, -6.77947521e-01, -3.99700366e-02,
        2.60363191e-01, -4.79452051e-02, -1.30499855e-01,  5.03033817e-01,
        5.22601008e-01,  3.75526309e-01, -5.17159961e-02, -3.83910947e-02,
       -5.20588219e-01,  3.49465609e-01,  4.33636636e-01, -2.92894244e-01,
        2.56942004e-01, -

In [15]:
# Create a flat Faiss index
index = faiss.IndexIDMap(faiss.IndexFlatIP(vector_size)) # the size of our vector space

# Add the document vectors into the index. They need to be transformed into numpy arrays first
index.add_with_ids(
    embeddings,
    # the IDs will be 0 to len(documents)
    np.array(range(0, len(sentences)))
)

In [25]:
def search(query: str, k=5):
    encoded_query = model.encode(query).reshape((1, -1))
    top_k = index.search(encoded_query, k)
    scores = top_k[0][0]
    results = [documents[_id]['title'] for _id in top_k[1][0]]
    return list(zip(results, scores))

In [26]:
with open('questions.json') as f:
    questions = json.load(f)
    
for question in questions:
    print("Question: ", question)
    answers = search(question, 5)
    for i, answer in enumerate(answers):
        print(f"Answer{i}: ",  answer)

Question:  How many people have died during Black Death?
Answer0:  ('Epidemiology of HIV/AIDS', 13.177227)
Answer1:  ('Spanish flu', 12.015366)
Answer2:  ('Bills of mortality', 10.355619)
Answer3:  ('COVID-19 pandemic', 9.897232)
Answer4:  ('HIV/AIDS in Yunnan', 9.320352)
Question:  Which diseases can be transmitted by animals?
Answer0:  ('Swine influenza', 15.895811)
Answer1:  ('Basic reproduction number', 13.09914)
Answer2:  ('Disease X', 13.017599)
Answer3:  ('Virus', 12.970383)
Answer4:  ('Pandemic', 12.493589)
Question:  Connection between climate change and a likelihood of a pandemic
Answer0:  ('Pandemic', 13.284798)
Answer1:  ('PREDICT (USAID)', 12.460386)
Answer2:  ('COVID-19 pandemic', 11.854155)
Answer3:  ('Pandemic prevention', 10.9414215)
Answer4:  ('Pandemic Severity Assessment Framework', 10.637194)
Question:  What is an example of a latent virus
Answer0:  ('Virus', 18.3499)
Answer1:  ('Viral load', 15.907722)
Answer2:  ('Swine influenza', 13.991492)
Answer3:  ('COVID-19 