In [1]:
import os
import re
import gc
import string
import logging

import torch
import numpy as np
from pathlib import Path
from haystack import Document
from haystack.nodes import PreProcessor       
from haystack.nodes import EmbeddingRetriever
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes.prompt import PromptNode, PromptTemplate

from hpqa import preproc

# Raise these third-party loggers to ERROR so lower-severity records from them
# (and from child loggers that inherit their level) are not emitted.
# NOTE(review): several names look like FARM / older-haystack module paths;
# confirm they still match the installed haystack version.
for logger_name in (
    "farm.utils",
    "farm.infer",
    "haystack.reader.farm.FARMReader",
    "farm.modeling.prediction_head",
    "elasticsearch",
    "haystack.eval",
    "haystack.document_store.base",
    "haystack.retriever.base",
    "farm.data_handler.dataset",
):
    logging.getLogger(logger_name).setLevel(logging.ERROR)

In [2]:
# Target passage length (in words) for the PreProcessor splitter below.
max_number_of_words = 100
index_path = "index.faiss"
# Single source of truth for the embedding model: the query-time retriever must
# use the same model that produced the embeddings stored in the index.
embedding_model = "flax-sentence-embeddings/all_datasets_v3_mpnet-base"


def _make_retriever(store):
    """Create the sentence-transformers EmbeddingRetriever bound to ``store``."""
    return EmbeddingRetriever(
        document_store=store,
        embedding_model=embedding_model,
        model_format="sentence_transformers",
    )


if not os.path.isfile(index_path):
    # No saved index yet: build it from books 1-7 and persist it.
    document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", return_embedding=True)
    # Presumably clears documents left in the store's backing database by an
    # earlier, interrupted run — TODO confirm.
    document_store.delete_documents()

    # One Document per chapter, tagged with its book number and chapter key.
    documents = [
        Document(content=text.strip(), meta=dict(book=book_number, chapter=chapter))
        for book_number in range(1, 8)
        for chapter, text in preproc.get_book_chapters(book_number).items()
    ]

    preprocessor = PreProcessor(
        clean_empty_lines=True,
        clean_whitespace=True,
        clean_header_footer=True,
        split_by="word",
        split_length=max_number_of_words,
        split_respect_sentence_boundary=True,
        split_overlap=10,  # words repeated between consecutive passages
    )
    docs_processed = preprocessor.process(documents)
    document_store.write_documents(docs_processed)

    # Embeddings must exist before the index is saved, so embed then persist.
    retriever = _make_retriever(document_store)
    document_store.update_embeddings(retriever=retriever)
    document_store.save(index_path)
else:
    # Reuse the persisted index; the retriever is built with the same model.
    document_store = FAISSDocumentStore.load(index_path)
    retriever = _make_retriever(document_store)


In [3]:
def empty_cache(*args):
    """Run Python's garbage collector, then release PyTorch's cached CUDA memory.

    Passing objects here does NOT free them: the caller still holds its own
    references. To actually reclaim memory, ``del`` your variables first and
    then call ``empty_cache()``.

    Args:
        *args: Accepted for backward compatibility; ignored.

    Returns:
        None.
    """
    # Drop this frame's reference to the argument tuple before collecting so
    # it does not itself keep the passed objects alive during gc.collect().
    del args
    gc.collect()
    torch.cuda.empty_cache()
    
def batch_it(li, n):
    """Yield consecutive slices of ``li`` holding at most ``n`` items each.

    The last slice is shorter when ``len(li)`` is not a multiple of ``n``.
    Slicing preserves the sequence type (a list yields lists, a str yields strs).

    Args:
        li: Any sliceable sequence supporting ``len()``.
        n: Batch size; must be a positive integer.

    Yields:
        ``li[0:n]``, ``li[n:2*n]``, ... until ``li`` is exhausted.

    Raises:
        ValueError: If ``n < 1``. Because this is a generator, the error is
            raised when iteration starts, not at call time.
    """
    # Without this guard a negative n makes range() empty and silently yields nothing.
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n!r}")
    for start in range(0, len(li), n):
        yield li[start:start + n]

In [4]:

# Generator that turns retrieved passages into answers; the
# "question-answering" template is set as its default prompt.
prompt_node = PromptNode(
    model_name_or_path="google/flan-t5-xl",
    default_prompt_template="question-answering",
)

In [None]:
# Questions sent per retrieve + prompt round-trip.
batch_size = 2

questions = [
    "What is the Dursleys' address?",
    "How many presents does Dudley get on his birthday?",
    "What does the boa constrictor say to Harry?",
    "What time does the Hogwarts Express leave Platform 9 and 3/4?",
    "How are the Dursleys related to Harry?",
    "Why does Hagrid give Dudley a tail?",
    "What is Harry's wand made of?",
    "What famous wizard card does Harry get in his first chocolate frog?",
    "What does Hagrid name his pet dragon?",
    "What did Harry do to get him temporarily expelled from Hogwarts?",
    "How did Harry, Ron, and Hermione communicate to the other members of Dumbledore's Army that they were meeting?",
    "How was Hermione able to take extra lessons?",
    "Why does Snape kill Dumbledore?",
    "Why does Snape protect Harry?",
    "Who would win a fight between Dumbledore and a grizzly bear?",
    "Who would win a fight between a muggle and a grizzly bear?",
    "What is a way to sneak into Hogwarts without being detected?",
    "Why do students make fun of Hermione?"
]

# Flat list of answers, in the same order as `questions`.
responses = []
for question_batch in batch_it(questions, batch_size):
    # Presumably one list of retrieved documents per question in the batch — TODO confirm.
    retrieved_docs = retriever.retrieve_batch(question_batch)
    batch_answers = prompt_node.prompt(
        prompt_template="question-answering",
        documents=retrieved_docs,
        questions=question_batch,
    )
    responses.extend(batch_answers)

# Answers are later paired with questions via zip(), which silently truncates
# on a length mismatch — fail loudly here instead.
assert len(responses) == len(questions), (
    f"expected {len(questions)} answers, got {len(responses)}"
)

In [6]:
# Show each question next to the answer generated for it.
for question, answer in zip(questions, responses):
    print(f"Question: {question}\nAnswer: {answer}\n")

Question: What is the Dursleys' address?
Answer: number four, Privet Drive

Question: How many presents does Dudley get on his birthday?
Answer: 36

Question: What does the boa constrictor say to Harry?
Answer: It has never seen Brazil

Question: What time does the Hogwarts Express leave Platform 9 and 3/4?
Answer: 11 o'clock

Question: How are the Dursleys related to Harry?
Answer: They are his only living relatives

Question: Why does Hagrid give Dudley a tail?
Answer: To curse Dudley

Question: What is Harry's wand made of?
Answer: phoenix-feather core

Question: What famous wizard card does Harry get in his first chocolate frog?
Answer: Albus Dumbledore

Question: What does Hagrid name his pet dragon?
Answer: Norbert

Question: What did Harry do to get him temporarily expelled from Hogwarts?
Answer: performed the Patronus Charm

Question: How did Harry, Ron, and Hermione communicate to the other members of Dumbledore's Army that they were meeting?
Answer: Enchanted coins

Question: