In [1]:
%pip install -Uq datasets==2.12.0 qdrant-client==1.2.0 sentence-transformers==2.2.2 torch==2.0.1

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
from datasets import load_dataset
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm.auto import tqdm
from typing import List
from qdrant_client.http.models import Distance, VectorParams

In [3]:
client = QdrantClient("localhost", port=6333)

In [4]:
# load the duorc dataset into a pandas dataframe
df = load_dataset("duorc", "ParaphraseRC", split="train").to_pandas()
df = df[["title", "plot"]]  # select only title and plot column
print(f"Before removing duplicates: {len(df)}")

df = df.drop_duplicates(subset="plot")  # drop rows containing duplicate plot passages, if any
print(f"Unique Plots: {len(df)}")
df.head()

Found cached dataset duorc (/Users/balamurali/.cache/huggingface/datasets/duorc/ParaphraseRC/1.0.0/7a96356b7615d573abcd03a9328292c38348547971989538a771c32089bff199)


Before removing duplicates: 69524
Unique Plots: 5133


Unnamed: 0,title,plot
0,Ghosts of Mars,"Set in the second half of the 22nd century, Ma..."
15,Noriko's Dinner Table,"The film starts on December 12th, 2001 with a ..."
34,Gutterballs,A brutally sadistic rape leads to a series of ...
83,An Innocent Man,Jimmie Rainwood (Tom Selleck) is a respected m...
105,The Sorcerer's Apprentice,"Every hundred years, the evil Morgana (Kelly L..."


In [5]:
collection_name = "extractive-question-answering"

collections = client.get_collections()
print(collections)

# only create collection if it doesn't exist
if collection_name not in [c.name for c in collections.collections]:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,
            distance=models.Distance.COSINE,
        ),
    )
collections = client.get_collections()
print(collections)

collections=[CollectionDescription(name='extractive-question-answering')]
collections=[CollectionDescription(name='extractive-question-answering')]


In [6]:
# set device to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the retriever model from huggingface model hub
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
retriever

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
%%time

batch_size = 512  # specify batch size according to your RAM and compute, higher batch size = more RAM usage

for index in tqdm(range(0, len(df), batch_size)):
    i_end = min(index + batch_size, len(df))  # find end of batch
    batch = df.iloc[index:i_end]  # extract batch
    
    emb = retriever.encode((batch["title"]+batch["plot"]).tolist()).tolist()  # generate embeddings for batch
    meta = batch.to_dict(orient="records")  # get metadata
    ids = list(range(index, i_end))  # create unique IDs

    # upsert to qdrant
    client.upsert(
        collection_name=collection_name,
        points=models.Batch(ids=ids, vectors=emb, payloads=meta),
    )

collection_vector_count = client.get_collection(collection_name=collection_name).vectors_count
print(f"Vector count in collection: {collection_vector_count}")
assert collection_vector_count == len(df)

  0%|          | 0/11 [00:00<?, ?it/s]

Vector count in collection: 5133
CPU times: user 13min 47s, sys: 3min 21s, total: 17min 8s
Wall time: 3min 43s


In [9]:
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"

# load the reader model into a question-answering pipeline
reader = pipeline("question-answering", model=model_name, tokenizer=model_name)
print(reader.model, reader)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

In [10]:
def get_relevant_plot(question: str, top_k: int) -> List[str]:
    """
    Get the relevant plot for a given question

    Args:
        question (str): What do we want to know?
        top_k (int): Top K results to return

    Returns:
        context (List[str]):
    """
    try:
        encoded_query = retriever.encode(question).tolist()  # generate embeddings for the question

        result = client.search(
            collection_name=collection_name,
            query_vector=encoded_query,
            limit=top_k,
        )  # search qdrant collection for context passage with the answer

        context = [
            [x.payload["title"], x.payload["plot"]] for x in result
        ]  # extract title and payload from result
        return context

    except Exception as e:
        print({e})

In [11]:
def extract_answer(question: str, context: List[str]):
    """
    Extract the answer from the context for a given question

    Args:
        question (str): _description_
        context (list[str]): _description_
    """
    results = []
    for c in context:
        # feed the reader the question and contexts to extract answers
        answer = reader(question=question, context=c[1])

        # add the context to answer dict for printing both together, we print only first 500 characters of plot
        answer["title"] = c[0]
        results.append(answer)

    # sort the result based on the score from reader model
    sorted_result = sorted(results, key=lambda x: x["score"], reverse=True)
    for i in range(len(sorted_result)):
        print(f"{i+1}", end=" ")
        print(
            "Answer: ",
            sorted_result[i]["answer"],
            "\n  Title: ",
            sorted_result[i]["title"],
            "\n  score: ",
            sorted_result[i]["score"],
        )


question = "In the movie 3 Idiots, what is the name of the college where the main characters Rancho, Farhan, and Raju study"
context = get_relevant_plot(question, top_k=1)
context

[['Three Idiots',

In [12]:
extract_answer(question, context)

1 Answer:  Imperial College of Engineering 
  Title:  Three Idiots 
  score:  0.9049272537231445


In [14]:
question="In the movie Three Idiots, whos is virus?"

In [15]:
context = get_relevant_plot(question, top_k=1)
context

[['Three Idiots',

In [16]:
extract_answer(question, context)

1 Answer:  Professor Viru Sahastrabudhhe 
  Title:  Three Idiots 
  score:  0.9276266694068909
