In [None]:
import logging
import os
from pprint import pprint
from typing import Dict, List, Optional, Union

from haystack.document_stores import InMemoryDocumentStore, FAISSDocumentStore
from haystack.nodes import TextConverter, PDFToTextConverter, DocxToTextConverter, PreProcessor, EmbeddingRetriever, DensePassageRetriever
from haystack.nodes import FARMReader, TransformersReader, RAGenerator, Seq2SeqGenerator
from haystack.pipelines import GenerativeQAPipeline
from haystack.schema import Document
from haystack.utils import convert_files_to_docs, print_answers

from transformers import PreTrainedTokenizer, BatchEncoding

%load_ext autoreload
%autoreload 2

In [None]:
# Root logger stays at WARNING; only the "haystack" logger is raised to INFO so its
# progress messages are shown.
LOG_FORMAT = "%(levelname)s - %(name)s -  %(message)s"
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT)

haystack_logger = logging.getLogger("haystack")
haystack_logger.setLevel(logging.INFO)

## Get documents

In [None]:
# Directory holding the source files to convert. Set the KG_PDFS_PATH environment
# variable to point at another corpus without editing the notebook.
PDFS_PATH = os.environ.get("KG_PDFS_PATH", "/data/kg_pdfs_test/")

# Fail fast with a clear message instead of continuing with a missing input directory
# and hitting a less obvious error in a later cell.
if not os.path.isdir(PDFS_PATH):
    raise FileNotFoundError(f"PDFS_PATH is not an existing directory: {PDFS_PATH!r}")

# Convert every supported file found under PDFS_PATH into Haystack documents.
all_docs = convert_files_to_docs(dir_path=PDFS_PATH)

## Preprocessing 

In [None]:
# Cleaning and splitting options, kept together so they are easy to tweak.
preprocessor_options = dict(
    clean_empty_lines=True,
    clean_whitespace=True,
    clean_header_footer=False,
    split_by="word",
    split_length=128,  # open question: do smaller splits work better?
    split_respect_sentence_boundary=True,
)
preprocessor = PreProcessor(**preprocessor_options)

# Clean the converted documents and split them into word-based chunks.
all_docs_process = preprocessor.process(all_docs)

# Compare the number of input documents with the number of chunks produced.
print(f"n_files_input: {len(all_docs)}\nn_docs_output: {len(all_docs_process)}")

In [None]:
# Peek at the first three preprocessed documents.
all_docs_process[:3]

## Document Store 

In [None]:
# Alternative that keeps everything in memory and needs no files on disk:
# document_store = InMemoryDocumentStore()

# The FAISSDocumentStore keeps document text and other metadata in a SQL database
# (here a SQLite file) and indexes the text embeddings in a FAISS index that is
# queried at search time.
SQLITE_DB_FILE = "faiss_document_store_2.db"

# The FAISS index is rebuilt from scratch by this notebook, but the SQLite file
# survives between runs. Remove a leftover file so the SQL database and the new
# index start out in sync on Restart & Run All.
if os.path.exists(SQLITE_DB_FILE):
    os.remove(SQLITE_DB_FILE)

document_store = FAISSDocumentStore(
    sql_url=f"sqlite:///{SQLITE_DB_FILE}",
    faiss_index_factory_str="Flat",
    similarity="dot_product",
    return_embedding=True,  # embeddings are inspected directly in a later cell
)

In [None]:
# Write the preprocessed documents into the document store.
document_store.write_documents(all_docs_process)

In [None]:
# Display the number of documents the store reports after writing.
document_store.get_document_count()

## Retriever


In [None]:
# Dense Passage Retriever: one encoder embeds the questions, a second one embeds the
# passages held in the document store.
DPR_QUERY_MODEL = "facebook/dpr-question_encoder-single-nq-base"
DPR_PASSAGE_MODEL = "facebook/dpr-ctx_encoder-single-nq-base"

dpr_retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model=DPR_QUERY_MODEL,
    passage_embedding_model=DPR_PASSAGE_MODEL,
    embed_title=True,
    use_gpu=True,
)

In [None]:
# Compute embeddings for the stored documents with the DPR retriever and add them
# to the document store's index.

document_store.update_embeddings(retriever=dpr_retriever)

In [None]:
# Sanity check: show the embedding shape of one stored document. The first document
# is used because a fixed index such as [55] raises IndexError whenever the corpus
# yields fewer chunks than that.
indexed_docs = document_store.get_all_documents()
assert indexed_docs, "document store is empty - nothing was indexed"
indexed_docs[0].embedding.shape

In [None]:
# Sample questions used to try out the retriever and the generators by hand.
query, query1, query2 = (
    "What is streaming data?",
    "How is deep learning used in industry?",
    "What is a data mesh?",
)

In [None]:
# Retrieve the five best-matching passages for the sample query and show each
# passage's text together with its metadata.
dpr_hits = dpr_retriever.retrieve(query2, top_k=5)
dpr_ls = [(hit.content, hit.meta) for hit in dpr_hits]

pprint(dpr_ls)

## Generator

#### Retrieval Augmented Generator

In [None]:
# Retrieval-Augmented Generation model used to generate answers from retrieved passages.
RAG_MODEL = "facebook/rag-token-nq"

fb_generator = RAGenerator(
    model_name_or_path=RAG_MODEL,
    # generation settings
    num_beams=2,
    min_length=2,
    max_length=100,
    top_k=1,
    # input handling / hardware
    embed_title=True,
    use_gpu=True,
)

In [None]:
# Generate answers for the sample query from the five best DPR passages.
# fb_generator is configured with num_beams=2 and beam search cannot return more
# sequences than it has beams, so at most 2 answers are requested (the previous
# value of 3 exceeded the beam count).
ans = fb_generator.predict(
    query2,
    documents=dpr_retriever.retrieve(query2, top_k=5),
    top_k=2,
)

pprint(ans)

#### T5-large

In [None]:
class _T5Converter:
    """
    Input converter that turns a query and its supporting documents into one T5 prompt.

    The document contents are concatenated, each one preceded by a ``<P>`` marker, and
    combined with the query as ``"question: <query> context: <documents>"``. The prompt
    is then tokenized into PyTorch tensors and truncated to ``max_source_length`` tokens.

    The prompt layout is adapted from the BART ELI5 long-form QA converter
    (https://huggingface.co/yjernite/bart_eli5, https://yjernite.github.io/lfqa.html).

    :param max_source_length: Maximum number of tokens kept from the prompt; longer
                              prompts are truncated by the tokenizer.
    """

    def __init__(self, max_source_length: int = 512):
        self.max_source_length = max_source_length

    def __call__(self, tokenizer: PreTrainedTokenizer, query: str, documents: List[Document], top_k: Optional[int] = None) -> BatchEncoding:
        """
        Build and tokenize the prompt for a single query.

        :param tokenizer: Tokenizer of the seq2seq model that will consume the prompt.
        :param query: The question to answer.
        :param documents: Supporting documents; only their ``content`` is used.
        :param top_k: Part of the converter call signature but not used here.
        :return: The tokenizer output for a batch containing the single prompt.
        """
        conditioned_doc = "<P> " + " <P> ".join(d.content for d in documents)

        # Concatenate the question and the supporting passages into the model input.
        query_and_docs = f"question: {query} context: {conditioned_doc}"

        return tokenizer(
            [query_and_docs],
            truncation=True,
            padding=True,
            max_length=self.max_source_length,
            return_tensors="pt",
        )

In [None]:
# Checkpoint to load: a local copy of T5-large.
# Alternative noted by the author: google/t5-large-lm-adapt
T5_MODEL_PATH = "/data/t5-large"

t5_generator = Seq2SeqGenerator(
    model_name_or_path=T5_MODEL_PATH,
    input_converter=_T5Converter(),
    # generation settings
    num_beams=3,
    min_length=2,
    max_length=100,
    top_k=1,
    use_gpu=True,
)

In [None]:
# Fetch the five best DPR passages for the sample query, then let T5 generate
# up to three answers from them.
t5_context_docs = dpr_retriever.retrieve(query2, top_k=5)
ans = t5_generator.predict(query2, documents=t5_context_docs, top_k=3)

pprint(ans)

## Pipeline

In [None]:
# Questions that are sent through the generative QA pipeline, one at a time.
QUESTIONS = [
    "What is streaming data?",
    "How is deep learning used in industry?",
    "What is a data mesh?",
    "What do data scientists work on?",
    "How can cloud storage costs be reduced?",
    "What are the advantages of multi cloud?"
]

In [None]:
# End-to-end pipeline: the DPR retriever feeds the T5 generator.
pipe_GQA = GenerativeQAPipeline(retriever=dpr_retriever, generator=t5_generator)

for question in QUESTIONS:
    # Five retrieved passages per question, one generated answer.
    run_params = {"Generator": {"top_k": 1}, "Retriever": {"top_k": 5}}
    res = pipe_GQA.run(query=question, params=run_params)
    print_answers(res, details="all")