In [10]:
%pip install -Uq datasets==2.12.0 qdrant-client==1.2.0 sentence-transformers==2.2.2 torch==2.0.1

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
from datasets import load_dataset
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm.auto import tqdm
from typing import List

In [2]:
# load the duorc dataset into a pandas dataframe
df = load_dataset("duorc", "ParaphraseRC", split="train").to_pandas()
df = df[["title", "plot"]]  # select only title and plot column
print(f"Before removing duplicates: {len(df)}")

df = df.drop_duplicates(subset="plot")  # drop rows containing duplicate plot passages, if any
print(f"Unique Plots: {len(df)}")
df.head()

Downloading builder script:   0%|          | 0.00/5.21k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/4.77k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.13k [00:00<?, ?B/s]

Downloading and preparing dataset duorc/ParaphraseRC to /Users/balamurali/.cache/huggingface/datasets/duorc/ParaphraseRC/1.0.0/7a96356b7615d573abcd03a9328292c38348547971989538a771c32089bff199...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/16.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.49M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.84M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/69524 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/15591 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/15857 [00:00<?, ? examples/s]

Dataset duorc downloaded and prepared to /Users/balamurali/.cache/huggingface/datasets/duorc/ParaphraseRC/1.0.0/7a96356b7615d573abcd03a9328292c38348547971989538a771c32089bff199. Subsequent calls will reuse this data.
Before removing duplicates: 69524
Unique Plots: 5133


Unnamed: 0,title,plot
0,Ghosts of Mars,"Set in the second half of the 22nd century, Ma..."
15,Noriko's Dinner Table,"The film starts on December 12th, 2001 with a ..."
34,Gutterballs,A brutally sadistic rape leads to a series of ...
83,An Innocent Man,Jimmie Rainwood (Tom Selleck) is a respected m...
105,The Sorcerer's Apprentice,"Every hundred years, the evil Morgana (Kelly L..."


In [3]:
client = QdrantClient(":memory:")

Now we create a new collection called extractive-question-answering — we can name the collection anything we want.

We specify the metric type as "cosine" and dimension or size as 384 because the retriever we use to generate context embeddings is optimized for cosine similarity and outputs 384-dimension vectors.

In [4]:
collection_name = "extractive-question-answering"

collections = client.get_collections()
print(collections)

# only create collection if it doesn't exist
if collection_name not in [c.name for c in collections.collections]:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,
            distance=models.Distance.COSINE,
        ),
    )
collections = client.get_collections()
print(collections)

collections=[]
collections=[CollectionDescription(name='extractive-question-answering')]


In [5]:
# set device to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the retriever model from huggingface model hub
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
retriever

Downloading (…)5fedf/.gitattributes:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)2cb455fedf/README.md:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

Downloading (…)b455fedf/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)edf/data_config.json:   0%|          | 0.00/25.5k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)5fedf/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

Downloading (…)fedf/train_script.py:   0%|          | 0.00/13.8k [00:00<?, ?B/s]

Downloading (…)2cb455fedf/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)455fedf/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

## Generate Embeddings -> Store in Qdrant

Next, we need to generate embeddings for the context passages. We will use the `retriever.encode` for that. 

When passing the documents to Qdrant, we need an:
1. id (a unique integer value), 
2. context embedding, and 
3. payload for each document representing context passages in the dataset. The payload is a dictionary containing data relevant to our embeddings, such as the title, plot etc.

In [6]:
%%time

batch_size = 512  # specify batch size according to your RAM and compute, higher batch size = more RAM usage

for index in tqdm(range(0, len(df), batch_size)):
    i_end = min(index + batch_size, len(df))  # find end of batch
    batch = df.iloc[index:i_end]  # extract batch
    emb = retriever.encode(batch["plot"].tolist()).tolist()  # generate embeddings for batch
    meta = batch.to_dict(orient="records")  # get metadata
    ids = list(range(index, i_end))  # create unique IDs

    # upsert to qdrant
    client.upsert(
        collection_name=collection_name,
        points=models.Batch(ids=ids, vectors=emb, payloads=meta),
    )

collection_vector_count = client.get_collection(collection_name=collection_name).vectors_count
print(f"Vector count in collection: {collection_vector_count}")
assert collection_vector_count == len(df)

  0%|          | 0/11 [00:00<?, ?it/s]

Vector count in collection: 5133
CPU times: user 13min 33s, sys: 3min 24s, total: 16min 57s
Wall time: 3min 41s


## Initialize Reader

We use the `bert-large-uncased-whole-word-masking-finetuned-squad` model from the HuggingFace model hub as our reader model. This is finetuned on the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). It is trained to extract an answer from a given context. This special mechanism is why we can use this model to extract answers from our context passages. 

This is our (encoder) component which uses the contexts to extract an answer.

In [7]:
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"

# load the reader model into a question-answering pipeline
reader = pipeline("question-answering", model=model_name, tokenizer=model_name)
print(reader.model, reader)

Downloading (…)lve/main/config.json:   0%|          | 0.00/443 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

In [8]:
def get_relevant_plot(question: str, top_k: int) -> List[str]:
    """
    Get the relevant plot for a given question

    Args:
        question (str): What do we want to know?
        top_k (int): Top K results to return

    Returns:
        context (List[str]):
    """
    try:
        encoded_query = retriever.encode(question).tolist()  # generate embeddings for the question

        result = client.search(
            collection_name=collection_name,
            query_vector=encoded_query,
            limit=top_k,
        )  # search qdrant collection for context passage with the answer

        context = [
            [x.payload["title"], x.payload["plot"]] for x in result
        ]  # extract title and payload from result
        return context

    except Exception as e:
        print({e})

In [9]:
def extract_answer(question: str, context: List[str]):
    """
    Extract the answer from the context for a given question

    Args:
        question (str): _description_
        context (list[str]): _description_
    """
    results = []
    for c in context:
        # feed the reader the question and contexts to extract answers
        answer = reader(question=question, context=c[1])

        # add the context to answer dict for printing both together, we print only first 500 characters of plot
        answer["title"] = c[0]
        results.append(answer)

    # sort the result based on the score from reader model
    sorted_result = sorted(results, key=lambda x: x["score"], reverse=True)
    for i in range(len(sorted_result)):
        print(f"{i+1}", end=" ")
        print(
            "Answer: ",
            sorted_result[i]["answer"],
            "\n  Title: ",
            sorted_result[i]["title"],
            "\n  score: ",
            sorted_result[i]["score"],
        )


question = "In the movie 3 Idiots, what is the name of the college where the main characters Rancho, Farhan, and Raju study"
context = get_relevant_plot(question, top_k=1)
context

[['Three Idiots',

In [10]:
extract_answer(question, context)

1 Answer:  Imperial College of Engineering 
  Title:  Three Idiots 
  score:  0.9049272537231445


In [23]:
question = "Explain movie Jeepers Creepers II"

In [27]:
context = get_relevant_plot(question, top_k=2)
context

[['Jeepers Creepers II',
  'Three days after the events of the first film, a young boy named Billy Taggart helps his father, Jack Taggart Sr., erect scarecrows in a cornfield. As Billy makes his way through the field, the Creeper, disguised as one of the scarecrows, abducts him in front of Taggart and Billy\'s older brother, Jack Jr. The following day, a school bus carrying a high school basketball team and cheerleaders suffers a blowout. The chaperones inspect the tire and find it torn apart by a hand-crafted shuriken seemingly constructed from fragments of bone. Back on the Taggart farm, Jack finds a dagger dropped by the Creeper. Upon showing it to his father, the weapon inexplicably flies out of his hand on its own accord.\nOn the bus, cheerleader Minxie has a vision of Billy and Darry Jenner, the Creeper\'s victim from the first film, who both attempt to warn her about the Creeper. The Creeper then blows out another tire, disabling the bus. With the party stranded, the Creeper att

In [25]:
extract_answer(question, context)

1 Answer:  Jenny her insane patient. Later, they make their way to the crypt 
  Title:  Vampires Vs. Zombies 
  score:  0.0009691474842838943
2 Answer:  first film 
  Title:  Jeepers Creepers II 
  score:  1.1143445590278134e-05
