In [None]:
!pip install faiss-cpu
!pip install transformers
!pip install sentence_transformers

In [2]:
import json
import pickle

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

In [4]:
# Prebuilt FAISS index plus the passage records it points into: ids returned
# by `index.search` are used as list positions into `dataset` (see `retriever`).
# NOTE(review): pickle.load can execute arbitrary code -- only load an
# index.pickle you produced yourself. faiss.write_index / faiss.read_index is
# the native, non-pickle format if the index can be re-exported.
with open('index.pickle', 'rb') as f:
  index = pickle.load(f)

# Explicit UTF-8 (the JSON standard encoding) so any non-ASCII passage text
# decodes identically on every platform instead of using the locale default.
with open('data.json', 'r', encoding='utf-8') as f:
  dataset = json.load(f)

In [5]:
# Sentence-embedding model used to encode queries for the dense retriever.
# NOTE(review): presumably `index` was built with this same model; query and
# passage vectors must come from one model, so keep the two in sync.
# sentence-transformers marks the *-nli-stsb-mean-tokens models as deprecated;
# switching models would require re-encoding the index.
retriever_model = SentenceTransformer(
  'distilbert-base-nli-stsb-mean-tokens'
)

Downloading:   0%|          | 0.00/345 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.01k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/555 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/505 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

In [6]:
# Hugging Face hub id of the extractive QA "reader" (a SQuAD2-tuned MiniLM,
# judging by the id). Tokenizer and model must come from the same checkpoint.
model_name = 'deepset/minilm-uncased-squad2'

reader_tokenizer = AutoTokenizer.from_pretrained(model_name)
reader_model = AutoModelForQuestionAnswering.from_pretrained(model_name)

Downloading:   0%|          | 0.00/477 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/127M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/107 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [8]:
def retriever(query: str, k=1):
  """Return the ``k`` passages nearest to ``query`` in the FAISS index.

  Args:
    query: natural-language text, embedded with ``retriever_model``.
    k: number of neighbours to request from ``index.search``.

  Returns:
    A list of ``(passage, score)`` tuples in the order FAISS returns them,
    where ``passage`` is the matching record from ``dataset``. The list is
    shorter than ``k`` when the index cannot supply ``k`` neighbours.
  """
  # FAISS searches a 2-D float32 matrix of shape (n_queries, dim).
  encoded_query = np.asarray([retriever_model.encode(query)], dtype=np.float32)
  scores, ids = index.search(encoded_query, k)
  # FAISS pads missing neighbours with id -1; indexing the list with -1 would
  # silently return the *last* record, so those slots are dropped.
  return [
    (dataset[int(doc_id)], score)
    for doc_id, score in zip(ids[0], scores[0])
    if doc_id != -1
  ]

In [9]:
# Sample question for the single-shot reader demo below.
q = "What are the notable pandemic prevention organizations?"
# Context = text of the first of the three (record, score) hits returned.
top_hit = retriever(q, k=3)[0][0]
r = top_hit['text']

In [11]:
# Extractive question-answering pipeline wrapping the reader model/tokenizer
# loaded earlier; it returns the answer span plus its score and offsets.
reader = pipeline('question-answering', model=reader_model, tokenizer=reader_tokenizer)

In [13]:
# Single question against the single retrieved passage.
reader(dict(question=q, context=r))

{'answer': 'pandemic preparedness or pandemic mitigation',
 'end': 310,
 'score': 0.6223045587539673,
 'start': 266}

In [37]:
def retriver_reader(query: str, k=1):
  """Retrieve ``k`` passages for ``query`` and extract an answer from each.

  Args:
    query: the question; used both for retrieval and for the QA reader.
    k: number of passages to retrieve.

  Yields:
    ``(passage, score)`` tuples as produced by ``retriever``, except that
    ``passage`` is a copy of the dataset record with an extra ``'answer'`` key
    holding the span the ``reader`` extracted from its ``'text'``.
  """
  for passage, score in retriever(query, k):
    # Use `query`, not the notebook-global `q`, so each call answers its own question.
    answer = reader({'question': query, 'context': passage['text']})['answer']
    # Copy rather than `update` so the shared records in `dataset` stay untouched.
    yield {**passage, 'answer': answer}, score


# Correctly spelled alias; `retriver_reader` is kept for existing callers.
retriever_reader = retriver_reader

# End-to-end demo: top-5 passages, each annotated with the reader's answer.
[hit for hit in retriver_reader('What are the notable pandemic prevention organizations?', k=5)]

[({'answer': 'pandemic preparedness or pandemic mitigation',
   'text': 'Pandemic prevention is the organization and management of preventive measures against pandemics. Those include measures to reduce causes of new infectious diseases and measures to prevent outbreaks and epidemics from becoming pandemics.\nIt is not to be mistaken for pandemic preparedness or pandemic mitigation which largely seek to mitigate the magnitude of negative effects of pandemics and may overlap with pandemic prevention in some respects.',
   'title': 'Pandemic prevention',
   'url': 'https://en.wikipedia.org/wiki/Pandemic_prevention'},
  143.19728),
 ({'answer': 'healthcare practices and the administration of vaccines',
   'text': 'Targeted immunization strategies are approaches designed to increase the immunization level of populations and decrease the chances of epidemic outbreaks. Though often in regards to use in healthcare practices and the administration of vaccines to prevent biological epidemic out