In [None]:
!pip install faiss-cpu
!pip install sentence_transformers

In [130]:
import json
from pprint import pprint
import faiss
import torch
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer

In [3]:
# Load the corpus from disk. Later cells read the 'text', 'title' and 'url'
# keys of each record. The encoding is given explicitly so that the result
# does not depend on the platform's default locale encoding.
with open('data.json', 'r', encoding='utf-8') as f:
  dataset = json.load(f)

In [108]:
# Pull the raw text of every record, preserving dataset order so that
# position i in `sentences` corresponds to dataset[i].
sentences = list(map(lambda record: record['text'], dataset))

In [9]:
# Sentence encoder used both for the corpus and for the search queries.
model = SentenceTransformer(
    model_name_or_path='distilbert-base-nli-stsb-mean-tokens'
)

Downloading:   0%|          | 0.00/345 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.01k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/555 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/505 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

In [110]:
# Encode the whole corpus in one call; row i is the embedding of sentences[i].
sentence_embeddings = model.encode(sentences=sentences)

In [112]:
# Dimensionality of the embedding space, read off the first encoded sentence.
vector_space_size = len(sentence_embeddings[0])

# Exact (brute-force) inner-product index, wrapped in an ID map so that
# vectors can be stored under explicit integer ids.
# NOTE(review): the embeddings are not normalised anywhere in this notebook,
# so the scores are raw dot products rather than cosine similarities.
index = faiss.IndexIDMap(
    faiss.IndexFlatIP(vector_space_size)
)

In [113]:
# Store every embedding under the position of its record in `dataset`, so a
# returned id can be used directly as `dataset[id]`.
# FAISS ids are 64-bit integers; the dtype is set explicitly because NumPy's
# default integer type is platform dependent (it can be int32).
index.add_with_ids(
    sentence_embeddings,
    np.arange(len(dataset), dtype=np.int64)
)

In [114]:
def search(query: str, k=1):
  """Return the best-matching documents for a free-text query.

  The query is encoded with the notebook-level `model` and looked up in the
  notebook-level FAISS `index`; ids returned by the index are used as
  positions in `dataset`.

  Args:
    query: The text to search for.
    k: Maximum number of hits to return.

  Returns:
    A list of `(document, score)` tuples in the order returned by the index,
    where `document` is a dict restricted to the 'title' and 'url' keys of
    the matching record and `score` is the value reported by the index.
    The list may be shorter than `k` when the index holds fewer vectors.
  """
  encoded_query = np.array([model.encode(query)])
  scores, ids = index.search(encoded_query, k)

  hits = []
  for doc_id, score in zip(ids[0], scores[0]):
    # FAISS pads the result with id -1 when fewer than k vectors exist.
    # Indexing `dataset[-1]` would silently return the last record instead.
    if doc_id < 0:
      continue
    record = dataset[doc_id]
    summary = {key: value for key, value in record.items()
               if key in ('title', 'url')}
    hits.append((summary, score))
  return hits

In [115]:
# Sample queries used to exercise the search function, one per line.
questions = [
    "How many people have died during Black Death?",
    "Which diseases can be transmitted by animals?",
    "Connection between climate change and a likelihood of a pandemic",
    "What is an example of a latent virus",
    "Viruses in nanotechnology",
    "Giant viruses classification",
    "What are the notable pandemic prevention organizations?",
    "How many leprosy outbreaks are known to happen?",
    "What are the geographic areas with the highest transmission of malaria?",
    "How to prevent the spread of viral infections?",
]

In [128]:
# Run every sample question through the index and print its top three hits
# (title, score and URL), with blank lines separating the queries.
for question in questions:
  print('query: %s\n' % question)
  for document, score in search(question, k=3):
    hit_fields = (document['title'], score, document['url'])
    print('\t title: %s, score: %s, url: %s' % hit_fields)
  print('\n\n')

query: How many people have died during Black Death?

	 title: Pandemic, score: 44.64847, url: https://en.wikipedia.org/wiki/Pandemic
	 title: Spanish flu, score: 42.816513, url: https://en.wikipedia.org/wiki/Spanish_flu
	 title: Pandemic prevention, score: 40.31983, url: https://en.wikipedia.org/wiki/Pandemic_prevention



query: Which diseases can be transmitted by animals?

	 title: Swine influenza, score: 89.7691, url: https://en.wikipedia.org/wiki/Swine_influenza
	 title: Targeted immunization strategies, score: 78.84324, url: https://en.wikipedia.org/wiki/Targeted_immunization_strategies
	 title: Pandemic prevention, score: 71.77265, url: https://en.wikipedia.org/wiki/Pandemic_prevention



query: Connection between climate change and a likelihood of a pandemic

	 title: Pandemic prevention, score: 117.75107, url: https://en.wikipedia.org/wiki/Pandemic_prevention
	 title: PREDICT (USAID), score: 108.271996, url: https://en.wikipedia.org/wiki/PREDICT_(USAID)
	 title: Pandemic seve

In [134]:
# Persist the populated index to disk with pickle.
# NOTE(review): whether a FAISS index can be pickled depends on the installed
# faiss build; `faiss.write_index` is the library's own serialiser. Confirm
# this file can be loaded back before relying on it.
with open('index.pickle', mode='wb') as index_file:
  pickle.dump(index, index_file)