In [1]:
import base64
import hashlib
import io
import json
import os
import uuid

import numpy
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings

os.chdir("/app")

from fastapi.encoders import jsonable_encoder
from langchain.retrievers import MultiVectorRetriever
from pymilvus.orm import utility
from unstructured.partition.html import partition_html
from unstructured.partition.text import partition_text

from tools import dict_tool

import gc

from tools.cache_tool import cacher
from apps.inners.exceptions import use_case_exception
from apps.inners.models.dtos.document_category import DocumentCategory

import shutil
from pathlib import Path

from apps.outers.settings.one_llm_setting import OneLlmSetting
from langchain_community.storage.redis import RedisStore
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangChainDocument
from litellm import Router

from apps.inners.models.dtos.element_category import ElementCategory
from uuid import UUID

from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.datastructures import State

from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.use_cases.managements.document_management import DocumentManagement
from apps.inners.use_cases.managements.file_document_management import FileDocumentManagement
from apps.inners.use_cases.managements.text_document_management import TextDocumentManagement
from apps.inners.use_cases.managements.web_document_management import WebDocumentManagement
from typing import List, TypedDict
from typing import Tuple, Dict, Any

import dotenv
from datasets import load_dataset
from dotenv import find_dotenv
from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_community.vectorstores.milvus import Milvus
from langchain_core.runnables.base import RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from ragas import evaluate
from unstructured.documents.elements import Element, Table, Image, Text
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import PartitionStrategy

from apps.inners.use_cases.embeddings.hugging_face_e5_instruct_embedding import HuggingFaceE5InstructEmbeddings
from apps.outers.datastores.four_datastore import FourDatastore
from apps.outers.datastores.one_datastore import OneDatastore
from apps.outers.datastores.three_datastore import ThreeDatastore
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.repositories.file_document_repository import FileDocumentRepository
from apps.outers.repositories.text_document_repository import TextDocumentRepository
from apps.outers.repositories.web_document_repository import WebDocumentRepository
from tests.containers.test_container import TestContainer
from tests.seeders.all_seeder import AllSeeder


In [2]:
keys = {"data": {"embedding": {"model_name", "query_instruction"}}}
# keys = {"data": {"x"}}

d = {
    "data": {
        "embedding": {
            "model_name": "intfloat/multilingual-e5-large-instruct",
            "query_instruction": "Given the question, retrieve the answer from the context."
        },
        "x": {
            "y": 1
        },
    }
}

kwargs = {
    "d": d
}
_kwargs_include_keys = ["d"]
# _kwargs_include_keys: Set[Any] = set([])
# _kwargs_include_keys = _kwargs_include_keys.union(set(kwargs.keys()))

args = (1, 2, 3, d)
dict_args: Dict[Any, Any] = {}
for key, arg in enumerate(args):
    dict_args[key] = arg

_args_include_keys = [0, {3: {"data": {"embedding": ["model_name"]}}}]

dict_tool.filter_by_keys(dict_args, _args_include_keys)

x = dict_tool.filter_by_keys(kwargs, _kwargs_include_keys)
dict_tool.replace_end_value_to_string(x)
# kwargs
# 

# cache_tool.clear_cache()
# 
# 
# # @cacher(kwargs_include_keys=keys)
# @cacher()
# def testx(x=None):
#     return x
# 
# 
# class testc:
#     def __init__(self):
#         pass
# 
#     @cacher(args_include_keys=[0, {1: {"data": ["embedding"]}}])
#     def testx(self, x=None):
#         return x
# 
# 
# testc().testx(d)
# 
# testx(x=d)
# 
# cache_tool.get_cache()

{'d': {'data': {'embedding': {'model_name': 'intfloat/multilingual-e5-large-instruct',
    'query_instruction': 'Given the question, retrieve the answer from the context.'},
   'x': {'y': '1'}}}}

In [4]:
# import tensorflow
# 
# tensorflow.config.list_physical_devices('GPU')

In [5]:
import torch

torch.cuda.is_available()

True

In [6]:
dotenv.load_dotenv(find_dotenv())


True

In [7]:
test_container = TestContainer()

one_llm_setting: OneLlmSetting = test_container.applications.settings.one_llm()

one_datastore: OneDatastore = test_container.applications.datastores.one()
two_datastore: TwoDatastore = test_container.applications.datastores.two()
three_datastore: ThreeDatastore = test_container.applications.datastores.three()
four_datastore: FourDatastore = test_container.applications.datastores.four()
temp_datastore: ThreeDatastore = test_container.applications.datastores.temp()

file_document_repository: FileDocumentRepository = test_container.applications.repositories.file_document()
text_document_repository: TextDocumentRepository = test_container.applications.repositories.text_document()
web_document_repository: WebDocumentRepository = test_container.applications.repositories.web_document()

document_management: DocumentManagement = test_container.applications.use_cases.managements.document()
file_document_management: FileDocumentManagement = test_container.applications.use_cases.managements.file_document()
text_document_management: TextDocumentManagement = test_container.applications.use_cases.managements.text_document()
web_document_management: WebDocumentManagement = test_container.applications.use_cases.managements.web_document()

all_seeder: AllSeeder = test_container.seeders.all()

In [8]:
await all_seeder.up()

In [None]:
await all_seeder.down()

In [5]:
await two_datastore.client.set("test", "test", ex=10)

True

In [7]:
# loading the V2 dataset
amnesty_qa = load_dataset("explodinggradients/amnesty_qa", "english_v2", trust_remote_code=True)



In [8]:
amnesty_qa

DatasetDict({
    eval: Dataset({
        features: ['question', 'ground_truth', 'answer', 'contexts'],
        num_rows: 20
    })
})

In [9]:
class MainDocumentProcessor:
    def __init__(self):
        pass

    def split_texts(self, texts: List[Text], chunk_size: int, chunk_overlap: int) -> List[str]:
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        text: str = " ".join([text.text for text in texts])
        splitted_text: List[str] = text_splitter.split_text(
            text=text
        )

        return splitted_text


class PartitionDocumentProcessor:
    def __init__(
            self,
            document_management: DocumentManagement,
            file_document_management: FileDocumentManagement,
            text_document_management: TextDocumentManagement,
            web_document_management: WebDocumentManagement,
    ):
        self.document_management = document_management
        self.file_document_management = file_document_management
        self.text_document_management = text_document_management
        self.web_document_management = web_document_management

    async def _partition_file(self, state: State, found_document: Document) -> List[Element]:
        found_file_document: FileDocumentResponse = await self.file_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        file_data: bytes = self.file_document_management.file_document_repository.get_object_data(
            object_name=found_file_document.file_name
        )
        extract_image_path: Path = self.file_document_management.file_document_repository.file_path / found_file_document.file_data_hash
        extract_image_path.mkdir(exist_ok=True)
        shutil.rmtree(extract_image_path)
        elements: List[Element] = partition(
            metadata_filename=found_file_document.file_name,
            file=io.BytesIO(file_data),
            extract_images_in_pdf=True,
            extract_image_block_output_dir=str(extract_image_path),
            strategy=PartitionStrategy.AUTO,
            hi_res_model_name="yolox"
        )

        return elements

    async def _partition_text(self, state: State, found_document: Document) -> List[Element]:
        found_text_document: TextDocumentResponse = await self.text_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        elements: List[Element] = partition_text(
            text=found_text_document.text_content
        )

        return elements

    async def _partition_web(self, state: State, found_document: Document) -> List[Element]:
        found_web_document: WebDocumentResponse = await self.web_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        elements: List[Element] = partition_html(
            url=found_web_document.web_url,
            ssl_verify=False
        )

        return elements

    async def partition(self, state: State, document_id: UUID) -> List[Element]:
        found_document: Document = await self.document_management.find_one_by_id_with_authorization(
            state=state,
            id=document_id
        )
        if found_document.document_type_id == "file":
            elements: List[Element] = await self._partition_file(
                state=state,
                found_document=found_document
            )
        elif found_document.document_type_id == "text":
            elements: List[Element] = await self._partition_text(
                state=state,
                found_document=found_document
            )
        elif found_document.document_type_id == "web":
            elements: List[Element] = await self._partition_web(
                state=state,
                found_document=found_document
            )
        else:
            raise use_case_exception.DocumentTypeNotSupported()

        return elements


class SummaryDocumentProcessor:
    def __init__(self):
        pass

    def summarize_tables(self, tables: List[Table], model: BaseChatModel) -> List[str]:
        prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
        These summaries will be embedded and used to retrieve the table. \
        Give a concise passage summary of the table that is well optimized for retrieval. \
        Make sure the output only the summary without re-explaining. \
        Table : {table} """
        prompt: ChatPromptTemplate = ChatPromptTemplate.from_template(prompt_text)
        chain: RunnableSerializable = {"table": lambda table: table.text} | prompt | model | StrOutputParser()
        summaries: List[str] = chain.batch(tables)

        return summaries

    def _get_message_from_image(self, model: BaseChatModel, prompt_text: str, image: Image) -> BaseMessage:
        message: BaseMessage = model.invoke([
            ChatMessage(
                role="user",
                content=[
                    {
                        "type": "text",
                        "text": prompt_text
                    },
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": image.metadata.image_mime_type,
                            "data": image.metadata.image_base64
                        }
                    }
                ]
            )
        ])

        return message

    def summarize_images(self, images: List[Image], model: BaseChatModel) -> List[str]:
        prompt_text = """You are an assistant tasked with summarizing images for retrieval. \
        These summaries will be embedded and used to retrieve the image. \
        Give a concise passage summary of the image that is well optimized for retrieval. \
        Make sure the output only the summary without re-explaining. \
        """
        summaries: List[str] = []
        for image in images:
            message: BaseMessage = self._get_message_from_image(
                model=model,
                prompt_text=prompt_text,
                image=image
            )
            summaries.append(message.content)

        return summaries


class CategoryDocumentProcessor:
    def __init__(
            self,
            main_document_processor: MainDocumentProcessor,
            summary_document_processor: SummaryDocumentProcessor,
    ):
        self.main_document_processor = main_document_processor
        self.summary_document_processor = summary_document_processor

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        categorized_elements: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )

        for element in elements:
            if any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Text", "unstructured.documents.elements.NarrativeText"]
            ):
                categorized_elements.texts.append(element)
            elif any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Table"]
            ):
                categorized_elements.tables.append(element)
            elif any(
                    element_type in str(type(element)) for element_type in
                    ["unstructured.documents.elements.Image"]
            ):
                file_io = open(element.metadata.image_path, "rb")
                element.metadata.image_mime_type = "image/jpeg"
                element.metadata.image_base64 = base64.b64encode(file_io.read()).decode("utf-8")
                file_io.close()
                categorized_elements.images.append(element)
            else:
                print(f"BaseDocumentProcessor.categorize_elements: Ignoring element type {type(element)}.")

        return categorized_elements

    def get_categorized_documents(
            self,
            categorized_elements: ElementCategory,
            summarization_model: BaseChatModel,
            is_include_tables: bool = False,
            is_include_images: bool = False,
            chunk_size: int = 400,
            chunk_overlap: int = int(400 * 0.1),
            id_key: str = "id"
    ) -> DocumentCategory:
        document_category: DocumentCategory = DocumentCategory(
            texts=[],
            tables=[],
            images=[],
            id_key=id_key
        )
        splitted_texts: List[str] = self.main_document_processor.split_texts(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            texts=categorized_elements.texts
        )
        for text in splitted_texts:
            document_category.texts.append(LangChainDocument(
                page_content=text,
                metadata={
                    id_key: str(uuid.uuid4())
                }
            ))

        if is_include_tables:
            summarized_tables: List[str] = self.summary_document_processor.summarize_tables(
                tables=categorized_elements.tables,
                model=summarization_model
            )
            for table in summarized_tables:
                document_category.tables.append(LangChainDocument(
                    page_content=table,
                    metadata={
                        id_key: str(uuid.uuid4())
                    }
                ))

        if is_include_images:
            summarized_images: List[str] = self.summary_document_processor.summarize_images(
                images=categorized_elements.images,
                model=summarization_model
            )
            for image, summarized_image in zip(categorized_elements.images, summarized_images):
                document_category.images.append(LangChainDocument(
                    page_content=summarized_image,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "image": {
                            "mime_type": image.metadata.image_mime_type,
                            "base64": image.metadata.image_base64
                        }
                    }
                ))

        return document_category


partition_document_processor: PartitionDocumentProcessor = PartitionDocumentProcessor(
    document_management=document_management,
    file_document_management=file_document_management,
    text_document_management=text_document_management,
    web_document_management=web_document_management,
)

main_document_processor: MainDocumentProcessor = MainDocumentProcessor()
summary_document_processor: SummaryDocumentProcessor = SummaryDocumentProcessor()
category_document_processor: CategoryDocumentProcessor = CategoryDocumentProcessor(
    main_document_processor=main_document_processor,
    summary_document_processor=summary_document_processor
)


In [None]:
class GraphState(TypedDict):
    data: Dict[str, Any]


class GraphLongFormQa:
    def __init__(
            self,
            one_llm_setting: OneLlmSetting,
            two_datastore: TwoDatastore,
            four_datastore: FourDatastore,
            category_document_processor: CategoryDocumentProcessor,
    ):
        self.one_llm_setting = one_llm_setting
        self.two_datastore = two_datastore
        self.four_datastore = four_datastore
        self.category_document_processor = category_document_processor

    def node_get_model(self, input_state: GraphState) -> GraphState:
        output_state: GraphState = input_state
        model_list: List[Dict] = [
            {
                "model_name": "claude-3-haiku",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE,
                }
            },
            {
                "model_name": "claude-3-opus",
                "litellm_params": {
                    "model": "claude-3-opus-20240229",
                    "api_key": self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE,
                }
            }
        ]
        router: Router = Router(model_list=model_list)
        model: ChatLiteLLMRouter = ChatLiteLLMRouter(
            router=router,
            model_name=input_state["data"]["llm"]["model_name"],
            streaming=True,
            temperature=0,
        )
        output_state["data"]["llm"]["model"] = model

        return output_state

    @cacher(args_include_keys=[], kwargs_include_keys=["model_name"])
    def _get_embedding_model(self, model_name: str) -> Embeddings:
        if model_name == "intfloat/multilingual-e5-large-instruct":
            model: HuggingFaceEmbeddings = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={'device': 'cuda'},
                encode_kwargs={'normalize_embeddings': True},
            )
        else:
            raise use_case_exception.EmbeddingModelNameNotSupported()

        return model

    def node_get_embeddings(self, input_state: GraphState) -> GraphState:
        output_state: GraphState = input_state
        model: Embeddings = self._get_embedding_model(
            model_name=input_state["data"]["embedding"]["model_name"]
        )

        output_state["data"]["embedding"]["model"] = model

        return output_state

    async def node_get_categorized_documents(self, input_state: GraphState) -> GraphState:
        output_state: GraphState = input_state
        document_id: UUID = input_state["data"]["document_id"]

        categorized_document_hash: str = self._get_categorized_document_hash(
            document_id=document_id,
            preprocessor_setting=input_state["data"]["preprocessor_setting"]
        )
        input_state["data"]["categorized_document_hash"] = categorized_document_hash
        existing_categorized_document_hash: int = await self.two_datastore.client.exists(categorized_document_hash)
        if existing_categorized_document_hash == 0:
            is_categorized_document_exist: bool = False
        elif existing_categorized_document_hash == 1:
            is_categorized_document_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedDocumentHashInvalid

        if is_categorized_document_exist is False:
            elements: List[Element] = await partition_document_processor.partition(
                state=input_state["data"]["state"],
                document_id=document_id
            )
            categorized_elements: ElementCategory = await self.category_document_processor.categorize_elements(
                elements=elements
            )
            categorized_documents: DocumentCategory = self.category_document_processor.get_categorized_documents(
                categorized_elements=categorized_elements,
                summarization_model=input_state["data"]["llm"]["model"],
                is_include_tables=input_state["data"]["preprocessor_setting"]["is_include_tables"],
                is_include_images=input_state["data"]["preprocessor_setting"]["is_include_images"],
                chunk_size=input_state["data"]["preprocessor_setting"]["chunk_size"],
                chunk_overlap=input_state["data"]["preprocessor_setting"]["chunk_overlap"],
            )
            await self.two_datastore.client.set(
                name=categorized_document_hash,
                value=json.dumps(categorized_documents.dict(), default=jsonable_encoder)
            )
        else:
            found_categorized_document_bytes: bytes = await self.two_datastore.client.get(categorized_document_hash)
            categorized_documents: DocumentCategory = DocumentCategory(**json.loads(found_categorized_document_bytes))

        output_state["data"]["categorized_documents"] = categorized_documents

        return output_state

    def _get_categorized_document_hash(self, document_id: UUID, preprocessor_setting: Dict[str, Any]) -> str:
        data: Dict[str, Any] = {
            "document_id": document_id,
            "preprocessor_setting": preprocessor_setting,
        }
        hashed_data: str = hashlib.sha256(
            string=json.dumps(data, sort_keys=True, default=jsonable_encoder).encode()
        ).hexdigest()

        return hashed_data

    def _get_collection_name_hash(self, categorized_document_hash: str, embedding_model_name: str,
                                  prefix: str = "lfqa") -> str:
        data: Dict[str, Any] = {
            "categorized_document_hash": categorized_document_hash,
            "embedding_model_name": embedding_model_name,
        }
        hashed_data: str = hashlib.sha256(
            string=json.dumps(data, sort_keys=True, default=jsonable_encoder).encode()
        ).hexdigest()
        collection_name: str = f"{prefix}_{hashed_data}"

        return collection_name

    async def node_retrieve(self, input_state: GraphState) -> GraphState:
        output_state: GraphState = input_state
        categorized_documents: DocumentCategory = input_state["data"]["categorized_documents"]
        documents: List[LangChainDocument] = (
                categorized_documents.texts +
                categorized_documents.tables +
                categorized_documents.images
        )
        document_contents: List[str] = []
        document_meta_datas: List[Dict[str, Any]] = []
        document_ids: List[str] = []
        document_key_value_pairs: List[Tuple[Any, Any]] = []
        for document in documents:
            document_contents.append(document.page_content)
            document_meta_datas.append(document.metadata)
            document_ids.append(document.metadata[categorized_documents.id_key])
            document_key_value_pairs.append(
                (document.metadata[categorized_documents.id_key],
                 bytes(json.dumps(document.dict(), default=jsonable_encoder).encode()))
            )

        embedding_model: Embeddings = input_state["data"]["embedding"]["model"]
        collection_name: str = self._get_collection_name_hash(
            categorized_document_hash=input_state["data"]["categorized_document_hash"],
            embedding_model_name=input_state["data"]["embedding"]["model_name"]
        )
        document_store: RedisStore = RedisStore(
            redis_url=self.two_datastore.two_datastore_setting.URL,
        )
        vector_store: Milvus = self.four_datastore.get_client(
            embedding_function=embedding_model,
            collection_name=collection_name,
        )
        retriever: MultiVectorRetriever = MultiVectorRetriever(
            vectorstore=vector_store,
            docstore=document_store,
            collection_name=collection_name,
            id_key=input_state["data"]["categorized_documents"].id_key,
            search_kwargs={
                "k": input_state["data"]["retriever_setting"]["top_k"]
            }
        )

        is_collection_exists: bool = utility.has_collection(collection_name, using=vector_store.alias)
        is_force_refresh_embedding: bool = input_state["data"]["retriever_setting"]["is_force_refresh_embedding"]
        if is_collection_exists is False or is_force_refresh_embedding is True:
            utility.drop_collection(collection_name, using=vector_store.alias)
            await retriever.vectorstore.aadd_texts(
                texts=document_contents,
                metadatas=document_meta_datas,
                ids=document_ids
            )
            await retriever.docstore.amset(key_value_pairs=document_key_value_pairs)

        query: str = HuggingFaceE5InstructEmbeddings.get_detailed_instruct(
            task_description=input_state["data"]["embedding"]["query_instruction"],
            query=input_state["data"]["question"]
        )
        vector_store_retrieved_documents: List[
            Tuple[LangChainDocument, float]
        ] = await retriever.vectorstore.asimilarity_search_with_score(
            query=query,
            **retriever.search_kwargs
        )

        vector_store_retrieved_document_ids: List[str] = []
        for vector_store_retrieved_document in vector_store_retrieved_documents:
            vector_store_retrieved_document_ids.append(
                vector_store_retrieved_document[0].metadata[categorized_documents.id_key])

        doc_store_retrieved_documents: List[LangChainDocument | None] = await retriever.docstore.amget(
            keys=vector_store_retrieved_document_ids
        )

        decoded_retrieved_documents: List[LangChainDocument] = []
        for vector_store_retrieved_document, doc_store_retrieved_documents in zip(
                vector_store_retrieved_documents, doc_store_retrieved_documents
        ):
            decoded_retrieved_document: LangChainDocument = LangChainDocument(
                **json.loads(doc_store_retrieved_documents.decode())
            )
            decoded_retrieved_document.metadata["score"] = vector_store_retrieved_document[1]
            decoded_retrieved_documents.append(decoded_retrieved_document)

        decoded_retrieved_documents.sort(
            key=lambda x: x.metadata["score"],
            reverse=True
        )
        output_state["data"]["retrieved_documents"] = decoded_retrieved_documents

        return output_state

    def compile(self) -> CompiledGraph:
        graph: StateGraph = StateGraph(GraphState)

        graph.add_node(self.node_get_model.__name__, self.node_get_model)
        graph.add_node(self.node_get_embeddings.__name__, self.node_get_embeddings)
        graph.add_node(self.node_get_categorized_documents.__name__, self.node_get_categorized_documents)
        graph.add_node(self.node_retrieve.__name__, self.node_retrieve)

        graph.set_entry_point(self.node_get_model.__name__)

        graph.add_edge(self.node_get_model.__name__, self.node_get_embeddings.__name__)
        graph.add_edge(self.node_get_embeddings.__name__, self.node_get_categorized_documents.__name__)
        graph.add_edge(self.node_get_categorized_documents.__name__, self.node_retrieve.__name__)

        graph.set_finish_point(self.node_retrieve.__name__)

        compiled_graph: CompiledGraph = graph.compile()

        return compiled_graph


output_state: GraphState


async def handler(session: AsyncSession):
    global output_state

    state: State = State()
    state.authorized_session = all_seeder.session_seeder.session_mock.data[0]
    state.session = session

    graph_lfqa = GraphLongFormQa(
        one_llm_setting=one_llm_setting,
        two_datastore=two_datastore,
        four_datastore=four_datastore,
        category_document_processor=category_document_processor
    )
    compiled_graph_lfqa = graph_lfqa.compile()

    data: Dict[str, Any] = {
        "state": state,
        "document_id": all_seeder.file_document_seeder.file_document_mock.data[0].id,
        "llm": {
            "model_name": "claude-3-haiku"
        },
        "embedding": {
            "model_name": "intfloat/multilingual-e5-large-instruct",
            "query_instruction": "Given the question, retrieve the answer from the context.",
        },
        "preprocessor_setting": {
            "chunk_size": 50,
            "chunk_overlap": numpy.floor(50 * 0.1),
            "is_include_tables": False,
            "is_include_images": False,
        },
        "retriever_setting": {
            "top_k": 3,
            "is_force_refresh_embedding": False,
        },
        "question": "what is lorem ipsum?",
    }

    input_state: GraphState = GraphState(
        data=data
    )
    output_state = await compiled_graph_lfqa.ainvoke(input_state)


await one_datastore.retryable(handler)

torch.cuda.empty_cache()
gc.collect()

In [15]:
# len(output_state["data"]["categorized_documents"].texts[0].page_content)
output_state

{'data': {'state': <starlette.datastructures.State at 0x7f181578aad0>,
  'document_id': UUID('0c92d4aa-50c5-44a3-9959-fac3bca6d15f'),
  'llm': {'model_name': 'claude-3-haiku',
   'model': ChatLiteLLMRouter(client=<module 'litellm' from '/usr/local/lib/python3.10/dist-packages/litellm/__init__.py'>, model_name='claude-3-haiku', openai_api_key='', azure_api_key='', anthropic_api_key='', replicate_api_key='', cohere_api_key='', openrouter_api_key='', streaming=True, temperature=0.0, router=<litellm.router.Router object at 0x7f17aefdff10>, huggingface_api_key='', together_ai_api_key='')},
  'embedding': {'model_name': 'intfloat/multilingual-e5-large-instruct',
   'query_instruction': 'Given the question, retrieve the answer from the context.',
   'model': HuggingFaceEmbeddings(client=SentenceTransformer(
     (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: XLMRobertaModel 
     (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token

In [13]:
output_state["data"]["retrieved_documents"]

[Document(page_content='sheets. [1] Lorem ipsum was introduced to the digital world in the mid-1980s, when Aldus employed it in graphic and word-processing templates for its desktop publishing program PageMaker. Other popular word processors, including Pages', metadata={'id': 'fd6c0392-0ddb-45a4-8ece-42eb671060c2', 'score': 0.19670191407203674}),
 Document(page_content='the design. Lorem ipsum is typically a corrupted version of De finibus bonorum et malorum, a 1st-century BC text by the Roman statesman and philosopher Cicero, with words altered, added, and removed to make', metadata={'id': '00501abf-09ab-404e-aa6f-4fcda0f7f041', 'score': 0.1857132613658905}),
 Document(page_content='In publishing and graphic design, Lorem ipsum is a placeholder text commonly used to demonstrate the visual form of a documents or a typeface without relying on meaningful content. Lorem ipsum may be used as a placeholder before final copy is', metadata={'id': 'd8a5b28c-598e-4e37-a00a-60378311ebd6', 'score

In [24]:
eval_data = amnesty_qa["eval"].select(range(1))
eval_data

NameError: name 'amnesty_qa' is not defined

In [21]:
result = evaluate(
    eval_data,
    llm=llm,
    embeddings=embeddings,
    # metrics=[
    #     metrics.faithfulness,
    #     metrics.answer_relevancy, 
    #     metrics.context_recall,
    #     metrics.context_precision,
    #     metrics.answer_correctness,
    #     metrics.context_relevancy,
    #     metrics.context_entity_recall,
    # ],
)

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.


Invalid JSON response. Expected dictionary with key 'Attributed'
  value = np.nanmean(self.scores[cn])


In [22]:
result

{'answer_relevancy': 0.9599, 'context_precision': 1.0000, 'faithfulness': 0.5000, 'context_recall': nan}