In [1]:
import base64
import copy
import hashlib
import io
import os
import uuid

os.chdir("/app")

from ipywidgets import HTML

from apps.inners.exceptions import use_case_exception
from apps.inners.models.dtos.document_category import DocumentCategory

import shutil
from pathlib import Path

from langchain import retrievers
from langchain_community.storage.redis import RedisStore
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangchainDocument
from litellm import Router

from apps.inners.models.dtos.element_category import ElementCategory

from uuid import UUID

from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.datastructures import State

from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.use_cases.managements.document_management import DocumentManagement
from apps.inners.use_cases.managements.file_document_management import FileDocumentManagement
from apps.inners.use_cases.managements.text_document_management import TextDocumentManagement
from apps.inners.use_cases.managements.web_document_management import WebDocumentManagement
from typing import List
from typing import TypedDict, Tuple, Dict, Any

import dotenv
import litellm
from datasets import load_dataset
from dotenv import find_dotenv
from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_community.vectorstores.milvus import Milvus
from langchain_core.runnables.base import RunnableLike, RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from ragas import evaluate
from unstructured.documents.elements import Element, Table, Image, Text
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import PartitionStrategy

from apps.inners.use_cases.embeddings.hugging_face_e5_instruct_embedding import HuggingFaceE5InstructEmbeddings
from apps.outers.datastores.four_datastore import FourDatastore
from apps.outers.datastores.one_datastore import OneDatastore
from apps.outers.datastores.three_datastore import ThreeDatastore
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.repositories.file_document_repository import FileDocumentRepository
from apps.outers.repositories.text_document_repository import TextDocumentRepository
from apps.outers.repositories.web_document_repository import WebDocumentRepository
from tests.containers.test_container import TestContainer
from tests.seeders.all_seeder import AllSeeder


In [2]:
# import tensorflow
# 
# tensorflow.config.list_physical_devices('GPU')

In [3]:
# import torch
# 
# torch.cuda.is_available()

In [4]:
dotenv.load_dotenv(find_dotenv())


True

In [5]:
test_container = TestContainer()

one_datastore: OneDatastore = test_container.applications.datastores.one()
two_datastore: TwoDatastore = test_container.applications.datastores.two()
three_datastore: ThreeDatastore = test_container.applications.datastores.three()
four_datastore: FourDatastore = test_container.applications.datastores.four()
temp_datastore: ThreeDatastore = test_container.applications.datastores.temp()

file_document_repository: FileDocumentRepository = test_container.applications.repositories.file_document()
text_document_repository: TextDocumentRepository = test_container.applications.repositories.text_document()
web_document_repository: WebDocumentRepository = test_container.applications.repositories.web_document()

document_management: DocumentManagement = test_container.applications.use_cases.managements.document()
file_document_management: FileDocumentManagement = test_container.applications.use_cases.managements.file_document()
text_document_management: TextDocumentManagement = test_container.applications.use_cases.managements.text_document()
web_document_management: WebDocumentManagement = test_container.applications.use_cases.managements.web_document()

all_seeder: AllSeeder = test_container.seeders.all()

In [6]:
await all_seeder.up()

In [None]:
await all_seeder.down()

In [5]:
await two_datastore.client.set("test", "test", ex=10)

True

In [7]:
# loading the V2 dataset
amnesty_qa = load_dataset("explodinggradients/amnesty_qa", "english_v2", trust_remote_code=True)



In [8]:
amnesty_qa

DatasetDict({
    eval: Dataset({
        features: ['question', 'ground_truth', 'answer', 'contexts'],
        num_rows: 20
    })
})

In [9]:
litellm.set_verbose = False
model_list: List[Dict] = [
    {
        "model_name": "claude-3-haiku",
        "litellm_params": {
            "model": "claude-3-haiku-20240307",
            "api_key": os.getenv("ANTHROPIC_API_KEY"),
        }
    },
    {
        "model_name": "claude-3-opus",
        "litellm_params": {
            "model": "claude-3-opus-20240229",
            "api_key": os.getenv("ANTHROPIC_API_KEY"),
        }
    }
]
router: Router = Router(model_list=model_list)
model: ChatLiteLLMRouter = ChatLiteLLMRouter(
    router=router,
    model_name="claude-3-haiku",
    streaming=True,
    temperature=0,
)

embeddings: HuggingFaceE5InstructEmbeddings = HuggingFaceE5InstructEmbeddings(
    model_name="intfloat/multilingual-e5-large-instruct",
    model_kwargs={'device': 'cuda'},
    encode_kwargs={'normalize_embeddings': True},
    query_instruction="Given the question, retrieve the answer from the context."
)

In [10]:
class PartitionDocumentProcessor:
    def __init__(
            self,
            document_management: DocumentManagement,
            file_document_management: FileDocumentManagement,
            text_document_management: TextDocumentManagement,
            web_document_management: WebDocumentManagement,
    ):
        self.document_management = document_management
        self.file_document_management = file_document_management
        self.text_document_management = text_document_management
        self.web_document_management = web_document_management

    async def _partition_file(self, state: State, found_document: Document) -> List[Element]:
        found_file_document: FileDocumentResponse = await self.file_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        file_data: bytes = self.file_document_management.file_document_repository.get_object_data(
            object_name=found_file_document.file_name
        )
        extract_image_path: Path = self.file_document_management.file_document_repository.file_path / found_file_document.file_data_hash
        extract_image_path.mkdir(exist_ok=True)
        shutil.rmtree(extract_image_path)
        elements: List[Element] = partition(
            metadata_filename=found_file_document.file_name,
            file=io.BytesIO(file_data),
            extract_images_in_pdf=True,
            extract_image_block_output_dir=str(extract_image_path),
            strategy=PartitionStrategy.AUTO,
            hi_res_model_name="yolox"
        )

        return elements

    async def _partition_text(self, state: State, found_document: Document) -> List[Element]:
        found_text_document: TextDocumentResponse = await self.text_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        elements: List[Element] = partition(
            text=found_text_document.text_content
        )

        return elements

    async def _partition_web(self, state: State, found_document: Document) -> List[Element]:
        found_web_document: WebDocumentResponse = await self.web_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        elements: List[Element] = partition(
            url=found_web_document.web_url
        )

        return elements

    async def partition(self, state: State, document_id: UUID) -> List[Element]:
        found_document: Document = await self.document_management.find_one_by_id_with_authorization(
            state=state,
            id=document_id
        )
        if found_document.document_type_id == "file":
            elements: List[Element] = await self._partition_file(
                state=state,
                found_document=found_document
            )
        elif found_document.document_type_id == "text":
            elements: List[Element] = await self._partition_text(
                state=state,
                found_document=found_document
            )
        elif found_document.document_type_id == "web":
            elements: List[Element] = await self._partition_web(
                state=state,
                found_document=found_document
            )
        else:
            raise use_case_exception.DocumentTypeNotSupported()

        return elements


class CategoryDocumentProcessor:
    def __init__(self):
        pass

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        element_category: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )

        for element in elements:
            if any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Text", "unstructured.documents.elements.NarrativeText"]
            ):
                element_category.texts.append(element)
            elif any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Table"]
            ):
                element_category.tables.append(element)
            elif any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Image"]
            ):
                file_io = open(element.metadata.image_path, "rb")
                element.metadata.image_mime_type = "image/jpeg"
                element.metadata.image_base64 = base64.b64encode(file_io.read()).decode("utf-8")
                file_io.close()
                element_category.images.append(element)
            else:
                print(f"BaseDocumentProcessor.categorize_elements: Ignoring element type {type(element)}.")

        return element_category


class SummaryDocumentProcessor:
    def __init__(self):
        pass

    def summarize_tables(self, tables: List[Table], model: BaseChatModel) -> List[str]:
        prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
        These summaries will be embedded and used to retrieve the table. \
        Give a concise passage summary of the table that is well optimized for retrieval. \
        Make sure the output only the summary without re-explaining. \
        Table : {table} """
        prompt: ChatPromptTemplate = ChatPromptTemplate.from_template(prompt_text)
        chain: RunnableSerializable = {"table": lambda table: table.text} | prompt | model | StrOutputParser()
        summaries: List[str] = chain.batch(tables)

        return summaries

    def _get_message_from_image(self, model: BaseChatModel, prompt_text: str, image: Image) -> BaseMessage:
        message: BaseMessage = model.invoke([
            ChatMessage(
                role="user",
                content=[
                    {
                        "type": "text",
                        "text": prompt_text
                    },
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": image.metadata.image_mime_type,
                            "data": image.metadata.image_base64
                        }
                    }
                ]
            )
        ])

        return message

    def summarize_images(self, images: List[Image], model: BaseChatModel) -> List[str]:
        prompt_text = """You are an assistant tasked with summarizing images for retrieval. \
        These summaries will be embedded and used to retrieve the image. \
        Give a concise passage summary of the image that is well optimized for retrieval. \
        Make sure the output only the summary without re-explaining. \
        """
        summaries: List[str] = []
        for image in images:
            message: BaseMessage = self._get_message_from_image(
                model=model,
                prompt_text=prompt_text,
                image=image
            )
            summaries.append(message.content)

        return summaries


class MainDocumentProcessor:
    def __init__(
            self,
            summary_document_processor: "SummaryDocumentProcessor",
    ):
        self.summary_document_processor = summary_document_processor

    def split_texts(self, texts: List[Text], chunk_size: int, chunk_overlap: int) -> List[str]:
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        text: str = " ".join([text.text for text in texts])
        splitted_text: List[str] = text_splitter.split_text(
            text=text
        )

        return splitted_text

    def get_documents(
            self,
            element_category: ElementCategory,
            summarization_model: BaseChatModel,
            is_summarize_tables: bool = False,
            is_summarize_images: bool = False,
            chunk_size: int = 400,
            chunk_overlap: int = int(400 * 0.1),
            id_key: str = "id"
    ) -> DocumentCategory:
        document_category: DocumentCategory = DocumentCategory(
            texts=[],
            tables=[],
            images=[],
            id_key=id_key
        )
        splitted_texts: List[str] = self.split_texts(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            texts=element_category.texts
        )
        for text in splitted_texts:
            document_category.texts.append(LangchainDocument(
                page_content=text,
                metadata={
                    id_key: str(uuid.uuid4())
                }
            ))

        if is_summarize_tables:
            summarized_tables: List[str] = self.summary_document_processor.summarize_tables(
                tables=element_category.tables,
                model=summarization_model
            )
            for table in summarized_tables:
                document_category.tables.append(LangchainDocument(
                    page_content=table,
                    metadata={
                        id_key: str(uuid.uuid4())
                    }
                ))

        if is_summarize_images:
            summarized_images: List[str] = self.summary_document_processor.summarize_images(
                images=element_category.images,
                model=summarization_model
            )
            for image, summarized_image in zip(element_category.images, summarized_images):
                document_category.images.append(LangchainDocument(
                    page_content=summarized_image,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "image": {
                            "mime_type": image.metadata.image_mime_type,
                            "base64": image.metadata.image_base64
                        }
                    }
                ))

        return document_category


partition_document_processor: PartitionDocumentProcessor = PartitionDocumentProcessor(
    document_management=document_management,
    file_document_management=file_document_management,
    text_document_management=text_document_management,
    web_document_management=web_document_management,
)

category_document_processor: CategoryDocumentProcessor = CategoryDocumentProcessor()
summary_document_processor: SummaryDocumentProcessor = SummaryDocumentProcessor()
main_document_processor: MainDocumentProcessor = MainDocumentProcessor(
    summary_document_processor=summary_document_processor
)


In [11]:
elements: List[Element] = None


async def handler(session: AsyncSession):
    global elements
    state: State = State()
    state.session = session
    state.authorized_session = all_seeder.session_seeder.session_mock.data[0]
    elements = await partition_document_processor.partition(
        state=state,
        document_id=all_seeder.file_document_seeder.file_document_mock.data[0].id
    )


await one_datastore.retryable(handler)

elements

[<unstructured.documents.elements.Image at 0x7f7b986b87c0>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b8580>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b8a90>,
 <unstructured.documents.elements.Title at 0x7f7b986b8910>,
 <unstructured.documents.elements.Text at 0x7f7b9af4d240>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b8700>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b92d0>,
 <unstructured.documents.elements.ListItem at 0x7f7b986b9300>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b9240>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b84c0>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986ba110>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986ba380>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b9cc0>,
 <unstructured.documents.elements.Text at 0x7f7b9af4f340>,
 <unstructured.documents.elements.NarrativeText at 0x7f7b986b9bd0>,
 <unstructured.docu

In [12]:
element_category: ElementCategory = await category_document_processor.categorize_elements(
    elements=elements
)
element_category.tables = element_category.tables[:1]
element_category.images = element_category.images[:1]
[x.text for x in element_category.texts]

BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Title'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.ListItem'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.FigureCaption'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Title'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class '

['Received April 5, 2020, accepted April 14, 2020, date of publication April 17, 2020, date of current version May 5, 2020.',
 'Digital Object Identifier 10.1109/ACCESS.2020.2988510',
 'LIJIA CHEN1, PINGPING CHEN 2,4, (Member, IEEE), AND ZHIJIAN LIN 3, (Member, IEEE) 1School of Design, Yango University, Fuzhou 350015, China 2School of Advanced Manufacturing, Science Park of Fuzhou University, Jinjiang 362251, China 3School of Information, Fuzhou University, Fuzhou 35008, China',
 'Corresponding author: Pingping Chen (ppchen.xm@gmail.com)',
 'This work was supported in part by the Humanities and Social Science Planning Funds of Fujian Province under Grant 275 JAS19453, and in part by the Distinguished Scholar Grant of Educational Commission of Fujian Province.',
 'INDEX TERMS Education, artiﬁcial intelligence, leaner.',
 'I. INTRODUCTION As illustrated by Henry Ford in the analogy, innovation does not mean working that the society should work only with what has been the norm, such as ﬁn

In [13]:
chunk_size: int = 400
chunk_overlap: int = int(chunk_size * 0.1)
document_category: DocumentCategory = main_document_processor.get_documents(
    element_category=element_category,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
    summarization_model=model,
    is_summarize_tables=True,
    is_summarize_images=True
)
document_category



In [14]:
image_html = f'<img src="data:{document_category.images[0].metadata["image"]["mime_type"]};base64,{document_category.images[0].metadata["image"]["base64"]}" />'
display(HTML(image_html))
# image_html

HTML(value='<img src="data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSE…

In [15]:
class MultiVectorRetriever:
    def __init__(
            self,
            two_datastore: TwoDatastore,
            four_datastore: FourDatastore,
    ):
        self.two_datastore = two_datastore
        self.four_datastore = four_datastore

    def get_retriever(
            self,
            embedding_function: Embeddings,
            collection_name: str,
            **kwargs: Any
    ) -> retrievers.MultiVectorRetriever:
        document_store: RedisStore = RedisStore(
            redis_url=self.two_datastore.two_datastore_setting.URL,
        )
        vector_store: Milvus = self.four_datastore.get_client(
            embedding_function=embedding_function,
            collection_name=collection_name,
        )
        return retrievers.MultiVectorRetriever(
            docstore=document_store,
            vectorstore=vector_store,
            **kwargs
        )


multi_vector_retriever: MultiVectorRetriever = MultiVectorRetriever(
    two_datastore=two_datastore,
    four_datastore=four_datastore,
)

In [28]:
class GraphState(TypedDict):
    data: Dict[str, Any]


class GraphLongFormQa:
    def __init__(
            self,
            multi_vector_retriever: MultiVectorRetriever,
    ):
        self.multi_vector_retriever = multi_vector_retriever

    def _retrieve(self, input_state: GraphState) -> GraphState:
        output_state: GraphState = copy.deepcopy(input_state)
        retriever: retrievers.MultiVectorRetriever = self.multi_vector_retriever.get_retriever(
            embedding_function=input_state["data"]["retriever_setting"]["embedding_function"],
            collection_name=input_state["data"]["retriever_setting"]["collection_name"],
            id_key=input_state["data"]["document_category"].id_key,
            search_kwargs={
                "k": input_state["data"]["retriever_setting"]["top_k"]
            }
        )
        documents: List[LangchainDocument] = (
                input_state["data"]["document_category"].texts +
                input_state["data"]["document_category"].tables +
                input_state["data"]["document_category"].images
        )
        document_contents: List[str] = []
        document_metadatas: List[Dict[str, Any]] = []
        document_ids: List[str] = []
        for document in documents:
            document_contents.append(document.page_content)
            document_metadatas.append(document.metadata)
            document_ids.append(document.metadata["id"])
            
        retriever.vectorstore.add_texts(
            texts=document_contents,
            metadatas=document_metadatas,
            ids=document_ids
        )
        retriever.docstore.mset(key_value_pairs=list(zip(document_ids, documents)))

        output_state["data"]["retrieved_documents"] = retriever.get_relevant_documents(
            query=input_state["data"]["question"]
        )
        
        return output_state

    def _get_node(self, node: RunnableLike) -> Tuple[str, RunnableLike]:
        return node.__name__, node

    def compile(self) -> CompiledGraph:
        graph: StateGraph = StateGraph(GraphState)
        graph.add_node(*self._get_node(self._retrieve))
        graph.set_entry_point(self._retrieve.__name__)
        graph.set_finish_point(self._retrieve.__name__)
        compiled_graph: CompiledGraph = graph.compile()
        return compiled_graph


graph_lfqa = GraphLongFormQa(
    multi_vector_retriever=multi_vector_retriever
)
compiled_graph_lfqa = graph_lfqa.compile()

data: Dict[str, Any] = {
    "retriever_setting": {
        "embedding_function": embeddings,
        "top_k": 5,
    },
    "question": "what is artificial intelligence?",
    "document_category": document_category,
}
data["retriever_setting"]["collection_name"] = f"lfqa_{hashlib.sha256(str(data).encode()).hexdigest()}"

input_state: GraphState = GraphState(
    data=data
)
output_state: GraphState = compiled_graph_lfqa.invoke(input_state)


DataError: Invalid input of type: 'Document'. Convert to a bytes, string, int or float first.

In [25]:
output_state["data"]["retrieved_documents"]

[b'as being arti\xef\xac\x81cial intelli- gence. Embedded computers, sensors, and other emerging technologies have facilitated the transfer of arti\xef\xac\x81cial intel- ligence to machines and other items, such as buildings and robots [11]. Indeed, Chassignol et al. provides a two-faceted de\xef\xac\x81nition and description of AI. They de\xef\xac\x81ne AI as a \xef\xac\x81eld and a theory. As a \xef\xac\x81eld of study, they de\xef\xac\x81ne AI as a study area in computer science whose pursuits are aimed at solving different cognitive problems commonly associated with the human intelligence, such as learning, problem solving, and pattern recognition, and subsequently adapting [11]. As a the- ory, Chassignol et al. de\xef\xac\x81ned AI as a theoretical framework guiding the development and use of computer systems with the capabilities of human beings, more particularly, intelli- gence and the ability to perform tasks that require human intelligence, including visual perception, speec

In [30]:
eval_data = amnesty_qa["eval"].select(range(1))
eval_data

Dataset({
    features: ['question', 'ground_truth', 'answer', 'contexts'],
    num_rows: 1
})

In [21]:
result = evaluate(
    eval_data,
    llm=llm,
    embeddings=embeddings,
    # metrics=[
    #     metrics.faithfulness,
    #     metrics.answer_relevancy, 
    #     metrics.context_recall,
    #     metrics.context_precision,
    #     metrics.answer_correctness,
    #     metrics.context_relevancy,
    #     metrics.context_entity_recall,
    # ],
)

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.


Invalid JSON response. Expected dictionary with key 'Attributed'
  value = np.nanmean(self.scores[cn])


In [22]:
result

{'answer_relevancy': 0.9599, 'context_precision': 1.0000, 'faithfulness': 0.5000, 'context_recall': nan}