In [1]:
from __future__ import annotations

import base64
import io
import json
import logging
import os
import uuid
from abc import ABC, abstractmethod

from langchain_anthropic import ChatAnthropic
from langchain_anthropic.output_parsers import ToolsOutputParser
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore
from milvus_model.base import RerankResult
from milvus_model.hybrid import BGEM3EmbeddingFunction
from milvus_model.reranker import BGERerankFunction
from pymilvus import FieldSchema, DataType, CollectionSchema, Collection, RRFRanker, AnnSearchRequest, \
    SearchResult, Hits
from pymilvus.client.types import LoadState
from ragas.testset import TestsetGenerator, evolutions
from unstructured.chunking.basic import chunk_elements
from unstructured.documents.html import HTMLText, HTMLTable

os.chdir("/app")

from pydantic.v1 import Field as FieldV1

from apps.inners.models.base_model import BaseModelV1

from langchain_community.embeddings.infinity import InfinityEmbeddings

from apps.outers.settings.one_embedding_setting import OneEmbeddingSetting

from fastapi.encoders import jsonable_encoder
from pymilvus.orm import utility
from unstructured.partition.html import partition_html
from unstructured.partition.text import partition_text

from tools import cache_tool, dict_tool

import gc

from apps.inners.exceptions import use_case_exception
from apps.inners.models.dtos.document_category import DocumentCategory

import shutil
from pathlib import Path

from apps.outers.settings.one_llm_setting import OneLlmSetting
from langchain_community.storage.redis import RedisStore
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document as LangChainDocument
from litellm import Router

from apps.inners.models.dtos.element_category import ElementCategory
from uuid import UUID

from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.datastructures import State

from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.use_cases.managements.document_management import DocumentManagement
from apps.inners.use_cases.managements.file_document_management import FileDocumentManagement
from apps.inners.use_cases.managements.text_document_management import TextDocumentManagement
from apps.inners.use_cases.managements.web_document_management import WebDocumentManagement
from typing import List, TypedDict, Optional, Union
from typing import Tuple, Dict, Any

import dotenv
from datasets import load_dataset
from dotenv import find_dotenv
from langchain_core.runnables.base import RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph, END
from ragas import evaluate, metrics
from unstructured.documents.elements import Element, Table, Image, Text, NarrativeText
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import PartitionStrategy

from apps.inners.use_cases.embeddings.hugging_face_e5_instruct_embedding import HuggingFaceE5InstructEmbeddings
from apps.outers.datastores.four_datastore import FourDatastore
from apps.outers.datastores.one_datastore import OneDatastore
from apps.outers.datastores.three_datastore import ThreeDatastore
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.repositories.file_document_repository import FileDocumentRepository
from apps.outers.repositories.text_document_repository import TextDocumentRepository
from apps.outers.repositories.web_document_repository import WebDocumentRepository
from tests.containers.test_container import TestContainer
from tests.seeders.all_seeder import AllSeeder

2024-04-10 15:04:00.447454: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-10 15:04:00.508574: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-10 15:04:00.508623: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-10 15:04:00.511840: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-10 15:04:00.570558: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# !pip show flagembedding
# !pip show langchain-anthropic
# !pip show pymilvus
# !pip show opencv-python

In [3]:
# Environment sanity check: list the GPU devices TensorFlow can see.
# The cell's last expression is displayed as the cell output.
import tensorflow

tensorflow.config.list_physical_devices('GPU')

2024-04-10 15:04:07.732891: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2024-04-10 15:04:07.733729: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-04-10 15:04:07.733752: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.


In [4]:
# Environment sanity check: report whether PyTorch can use CUDA.
import torch

torch.cuda.is_available()

True

In [5]:
# Load environment variables from the .env file located by find_dotenv().
dotenv.load_dotenv(find_dotenv())


True

In [6]:
# Resolve every collaborator used by the notebook from the test dependency-injection container.
test_container = TestContainer()

# Settings providers.
one_llm_setting: OneLlmSetting = test_container.applications.settings.one_llm()
one_embedding_setting: OneEmbeddingSetting = test_container.applications.settings.one_embedding()

# Datastore providers.
# NOTE(review): `temp_datastore` is annotated as ThreeDatastore but comes from the `temp` provider —
# confirm the annotation matches what the container actually returns.
one_datastore: OneDatastore = test_container.applications.datastores.one()
two_datastore: TwoDatastore = test_container.applications.datastores.two()
three_datastore: ThreeDatastore = test_container.applications.datastores.three()
four_datastore: FourDatastore = test_container.applications.datastores.four()
temp_datastore: ThreeDatastore = test_container.applications.datastores.temp()

# Repositories, one per document type.
file_document_repository: FileDocumentRepository = test_container.applications.repositories.file_document()
text_document_repository: TextDocumentRepository = test_container.applications.repositories.text_document()
web_document_repository: WebDocumentRepository = test_container.applications.repositories.web_document()

# Management use cases (generic document plus one per document type).
document_management: DocumentManagement = test_container.applications.use_cases.managements.document()
file_document_management: FileDocumentManagement = test_container.applications.use_cases.managements.file_document()
text_document_management: TextDocumentManagement = test_container.applications.use_cases.managements.text_document()
web_document_management: WebDocumentManagement = test_container.applications.use_cases.managements.web_document()

# Seeder used by the following cells to set up / tear down test data.
all_seeder: AllSeeder = test_container.seeders.all()

In [7]:
# Run the seeder's `up` step (presumably inserts the test fixtures — confirm in AllSeeder).
await all_seeder.up()

In [19]:
# Run the seeder's `down` step (presumably removes what `up()` created — confirm in AllSeeder).
# NOTE(review): the recorded execution count of this cell is out of order; when the notebook is run
# top to bottom this executes immediately after `up()`, before any cell that may need the seeded
# data. Consider moving it to the end of the notebook.
await all_seeder.down()

In [5]:
# Smoke test of two_datastore's async client: set key "test" to "test" with ex=10
# (presumably a 10-second expiry on a Redis-style client — confirm against TwoDatastore).
# NOTE(review): the recorded execution count of this cell is out of order relative to its position.
await two_datastore.async_client.set("test", "test", ex=10)

True

In [8]:
class BaseReranker(ABC):
    """Abstract interface for components that rerank candidate texts against a query."""

    @abstractmethod
    def rerank(self, query: str, texts: List[str], top_k: int) -> Any:
        """Rerank ``texts`` for ``query``.

        :param query: The query the texts are compared with.
        :param texts: Candidate texts to rerank.
        :param top_k: Number of results requested from the reranker.
        :return: Implementation-defined result; the base implementation returns ``None``.
        """


class BgeReranker(BaseReranker):
    """
    Reranker that delegates scoring to ``BGERerankFunction``.

    For normal reranker only (bge-reranker-base/bge-reranker-large/bge-reranker-v2-m3).
    """

    def __init__(
            self,
            model_name: str = "BAAI/bge-reranker-v2-m3",
            use_fp16: bool = True,
            batch_size: int = 32,
            normalize: bool = True,
            device: Optional[str] = None,
    ):
        """Remember the configuration and build the underlying rerank function from it."""
        self.model_name = model_name
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device

        # The model receives exactly the values stored on the instance above.
        self._reranker_model: BGERerankFunction = BGERerankFunction(
            model_name=model_name,
            use_fp16=use_fp16,
            batch_size=batch_size,
            normalize=normalize,
            device=device
        )

    def rerank(self, query: str, texts: List[str], top_k: int) -> List[Dict[str, Any]]:
        """Rerank ``texts`` against ``query`` and return the results as dictionaries.

        :param query: Query passed to the rerank function.
        :param texts: Candidate documents passed to the rerank function.
        :param top_k: Forwarded to the rerank function as ``top_k``.
        :return: ``to_dict()`` of every result, ordered by ``"score"`` from highest to lowest.
        """
        ranked: List[RerankResult] = self._reranker_model(query=query, documents=texts, top_k=top_k)

        return sorted(
            (entry.to_dict() for entry in ranked),
            key=lambda entry_dict: entry_dict["score"],
            reverse=True
        )

In [9]:
class BaseEmbedding(ABC):
    """Abstract interface for embedding backends with separate document and query encoders."""

    @abstractmethod
    def encode_documents(self, texts: List[str]) -> Union[List[Any], Dict[str, Any]]:
        """Encode document ``texts``; the result shape is defined by the implementation."""

    @abstractmethod
    def encode_queries(self, texts: List[str]) -> Union[List[Any], Dict[str, Any]]:
        """Encode query ``texts``; the result shape is defined by the implementation."""


class BgeM3Embedding(BaseEmbedding):
    """Embedding backend that delegates every call to ``BGEM3EmbeddingFunction``."""

    def __init__(
            self,
            model_name: str = "BAAI/bge-m3",
            batch_size: int = 16,
            device: Optional[str] = None,
            normalize_embeddings: bool = True,
            use_fp16: bool = True,
            return_dense: bool = True,
            return_sparse: bool = True,
            return_colbert_vecs: bool = True,
    ):
        """Remember the configuration and build the underlying embedding function from it."""
        self.model_name = model_name
        self.batch_size = batch_size
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        self.use_fp16 = use_fp16
        self.return_dense = return_dense
        self.return_sparse = return_sparse
        self.return_colbert_vecs = return_colbert_vecs

        # The model receives exactly the values stored on the instance above.
        self._embedding_model: BGEM3EmbeddingFunction = BGEM3EmbeddingFunction(
            model_name=model_name,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            return_dense=return_dense,
            return_sparse=return_sparse,
            return_colbert_vecs=return_colbert_vecs,
        )

    def encode_documents(self, texts: List[str]) -> Dict[str, Any]:
        """Return the underlying model's ``encode_documents`` output for ``texts``."""
        encoded: Dict[str, Any] = self._embedding_model.encode_documents(texts)
        return encoded

    def encode_queries(self, texts: List[str]) -> Dict[str, Any]:
        """Return the underlying model's ``encode_queries`` output for ``texts``."""
        encoded: Dict[str, Any] = self._embedding_model.encode_queries(texts)
        return encoded

    @property
    def dimensions(self) -> Dict[str, Any]:
        """The ``dim`` attribute reported by the underlying embedding function."""
        return self._embedding_model.dim

In [10]:
# Module-level logger named after the current module (no use of it is visible in this part of the notebook).
logger = logging.getLogger(__name__)


class BaseMilvusVectorStore(ABC):
    """Common lifecycle handling for a Milvus collection used as a vector store.

    Subclasses define the schema (``_create_collection``), the indexes (``_create_index``),
    how texts are embedded and inserted (``embed_texts``) and how a query is searched (``search``).
    """

    def __init__(
            self,
            collection_name: str,
            vector_field_dimensions: Dict[str, Any],
            alias: Optional[str] = None,
            consistency_level: str = "Strong",
            collection_properties: Optional[Dict[str, Any]] = None,
            drop_old_collection: bool = False,
            id_field_name: str = "id",
    ):
        """Store the configuration; no Milvus call is made here.

        :param collection_name: Name of the Milvus collection.
        :param vector_field_dimensions: Mapping of vector field name to its dimension.
        :param alias: Connection alias forwarded as ``using=`` to pymilvus calls.
        :param consistency_level: Stored for subclasses to use when creating the collection.
        :param collection_properties: Optional properties applied with ``set_properties``.
        :param drop_old_collection: When true, an existing collection is dropped and recreated
            by ``initialize_collection``.
        :param id_field_name: Name of the primary-key field.
        """
        self.collection_name = collection_name
        self.vector_field_dimensions = vector_field_dimensions
        # NOTE(review): the default is None; confirm pymilvus accepts ``using=None``, otherwise
        # callers must always pass an explicit alias (e.g. "default").
        self.alias = alias
        self.consistency_level = consistency_level
        self.collection_properties = collection_properties
        self.drop_old_collection = drop_old_collection
        self.id_field_name = id_field_name
        # Default metric per index type: inner product for the sparse index, L2 for the others.
        self._default_search_params = {
            "SPARSE_INVERTED_INDEX": {"metric_type": "IP"},
            "IVF_FLAT": {"metric_type": "L2"},
            "IVF_SQ8": {"metric_type": "L2"},
            "IVF_PQ": {"metric_type": "L2"},
            "HNSW": {"metric_type": "L2"},
            "RHNSW_FLAT": {"metric_type": "L2"},
            "RHNSW_SQ": {"metric_type": "L2"},
            "RHNSW_PQ": {"metric_type": "L2"},
            "IVF_HNSW": {"metric_type": "L2"},
            "ANNOY": {"metric_type": "L2"},
            "SCANN": {"metric_type": "L2"},
            "AUTOINDEX": {"metric_type": "L2"},
            "GPU_CAGRA": {"metric_type": "L2"},
            "GPU_IVF_FLAT": {"metric_type": "L2"},
            "GPU_IVF_PQ": {"metric_type": "L2"},
        }
        self.collection: Optional[Collection] = None

    def initialize_collection(self):
        """Bind ``self.collection`` to a ready-to-use collection.

        An existing collection is reused, or dropped first when ``drop_old_collection`` is set.
        A missing (or just dropped) collection is created through ``_create_collection``.
        Indexes are then created through ``_create_index`` and the collection is loaded if
        Milvus reports it as not loaded.
        """
        self.collection = None

        if self.has_collection():
            self.collection = Collection(
                self.collection_name,
                using=self.alias,
            )

            if self.drop_old_collection:
                # No point in updating the properties of a collection that is dropped right away.
                self.drop_collection()
            elif self.collection_properties is not None:
                self.collection.set_properties(self.collection_properties)

        if self.collection is None:
            self.collection = self._create_collection()

        self._create_index()

        if utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad:
            self.collection.load()

    def drop_collection(self):
        """Drop the bound collection, if any, and unbind it.

        Calling this while no collection is bound is a no-op instead of an ``AttributeError``.
        """
        if self.collection is not None:
            self.collection.drop()
        self.collection = None

    def has_collection(self) -> bool:
        """Return whether Milvus reports a collection named ``collection_name`` on ``alias``."""
        return utility.has_collection(self.collection_name, using=self.alias)

    @abstractmethod
    def _create_index(self):
        """Create the indexes of ``self.collection``; called by ``initialize_collection``."""
        pass

    @abstractmethod
    def _create_collection(self):
        """Create and return the collection; called by ``initialize_collection`` when none is bound."""
        pass

    @abstractmethod
    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Embed ``texts`` and store them under ``ids``, inserting ``batch_size`` rows at a time."""
        pass

    @abstractmethod
    def search(self, query: str, top_k: int) -> Hits:
        """Search the collection for ``query`` and return up to ``top_k`` hits."""
        pass


class MilvusBgeM3VectorStore(BaseMilvusVectorStore):
    """Milvus vector store holding BGE-M3 sparse and dense embeddings, searched with hybrid search."""

    def __init__(
            self,
            embedding_model: BgeM3Embedding,
            *args: Any,
            sparse_vector_field_name: str = "sparse_vector",
            dense_vector_field_name: str = "dense_vector",
            sparse_vector_index_type: str = "SPARSE_INVERTED_INDEX",
            dense_vector_index_type: str = "GPU_CAGRA",
            **kwargs: Any
    ):
        """Configure the store and immediately initialize the collection.

        :param embedding_model: Model used for documents and queries; its ``dimensions`` must
            provide the ``"sparse"`` and ``"dense"`` entries.
        :param sparse_vector_field_name: Name of the sparse vector field.
        :param dense_vector_field_name: Name of the dense vector field.
        :param sparse_vector_index_type: Index type of the sparse field (key of the default params).
        :param dense_vector_index_type: Index type of the dense field (key of the default params).
        :param args: Positional arguments forwarded to ``BaseMilvusVectorStore``.
        :param kwargs: Keyword arguments forwarded to ``BaseMilvusVectorStore``;
            ``vector_field_dimensions`` is overwritten with the model's dimensions.
        """
        vector_field_dimensions: Dict[str, Any] = {
            sparse_vector_field_name: embedding_model.dimensions["sparse"],
            dense_vector_field_name: embedding_model.dimensions["dense"]
        }
        kwargs["vector_field_dimensions"] = vector_field_dimensions
        super().__init__(*args, **kwargs)
        self.embedding_model = embedding_model
        self.sparse_vector_field_name = sparse_vector_field_name
        self.dense_vector_field_name = dense_vector_field_name
        self.sparse_vector_index_type = sparse_vector_index_type
        self.dense_vector_index_type = dense_vector_index_type
        self.initialize_collection()

    def _create_index(self):
        """Create the sparse index first, then the dense index, from the default params."""
        index_targets = (
            (self.sparse_vector_field_name, self.sparse_vector_index_type),
            (self.dense_vector_field_name, self.dense_vector_index_type),
        )
        for field_name, index_type in index_targets:
            # Build a fresh dict: writing "index_type" into the shared default entry would leak
            # it into every later search request that reuses the same entry as its search param.
            index_params: Dict[str, Any] = {
                **self._default_search_params[index_type],
                "index_type": index_type,
            }
            self.collection.create_index(
                field_name=field_name,
                index_params=index_params
            )

    def _create_collection(self):
        """Create and return the collection with a VARCHAR primary key, a sparse and a dense vector field."""
        fields: List[FieldSchema] = [
            FieldSchema(
                name=self.id_field_name,
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=65535
            ),
            FieldSchema(
                name=self.sparse_vector_field_name,
                dtype=DataType.SPARSE_FLOAT_VECTOR,
            ),
            FieldSchema(
                name=self.dense_vector_field_name,
                dtype=DataType.FLOAT_VECTOR,
                dim=self.vector_field_dimensions[self.dense_vector_field_name]
            )
        ]

        schema = CollectionSchema(
            fields=fields,
        )

        collection: Collection = Collection(
            name=self.collection_name,
            schema=schema,
            using=self.alias,
            consistency_level=self.consistency_level,
        )
        if self.collection_properties is not None:
            # Use the freshly created collection: ``self.collection`` is still None while
            # ``initialize_collection`` is waiting for this method to return.
            collection.set_properties(self.collection_properties)

        return collection

    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Embed ``texts`` and insert them with their ``ids`` in batches.

        :param texts: Texts to embed.
        :param ids: Primary keys, one per text and in the same order.
        :param batch_size: Maximum number of rows per insert call.
        :raises ValueError: If ``texts`` and ``ids`` differ in length.
        """
        if len(texts) != len(ids):
            raise ValueError(
                f"texts and ids must have the same length, got {len(texts)} and {len(ids)}."
            )
        if not texts:
            # Nothing to insert; skip the embedding call entirely.
            return

        embeddings: Dict[str, Any] = self.embedding_model.encode_documents(texts)

        total_count: int = len(ids)
        for start_index in range(0, total_count, batch_size):
            end_index: int = min(start_index + batch_size, total_count)
            # Column order follows the schema: id, sparse vector, dense vector.
            data: List[Any] = [
                ids[start_index:end_index],
                embeddings["sparse"][start_index:end_index],
                embeddings["dense"][start_index:end_index],
            ]
            self.collection.insert(data)

    def search(self, query: str, top_k: int) -> Hits:
        """Run a hybrid (sparse + dense) search for ``query``, fused with ``RRFRanker``.

        :param query: Query text; encoded with the embedding model's query encoder.
        :param top_k: Limit of each individual request and of the fused result.
        :return: Hits of the single query, with the id field as the only output field.
        """
        embeddings: Dict[str, Any] = self.embedding_model.encode_queries(texts=[query])

        output_fields = [
            self.id_field_name,
        ]
        search_requests: List[AnnSearchRequest] = [
            AnnSearchRequest(
                data=embeddings["sparse"],
                anns_field=self.sparse_vector_field_name,
                limit=top_k,
                param=self._default_search_params[self.sparse_vector_index_type]
            ),
            AnnSearchRequest(
                data=embeddings["dense"],
                anns_field=self.dense_vector_field_name,
                limit=top_k,
                param=self._default_search_params[self.dense_vector_index_type]
            )
        ]
        search_result: SearchResult = self.collection.hybrid_search(
            reqs=search_requests,
            output_fields=output_fields,
            rerank=RRFRanker(),
            limit=top_k,
        )
        # Only one query was sent, so only the first result set is relevant.
        outputs: Hits = search_result[0]

        return outputs


class MilvusInfinityVectorStore(BaseMilvusVectorStore):
    """Milvus vector store holding dense embeddings produced by an ``InfinityEmbeddings`` model."""

    def __init__(
            self,
            embedding_model: InfinityEmbeddings,
            embedding_dimension: int,
            *args: Any,
            dense_vector_field_name: str = "dense_vector",
            dense_vector_index_type: str = "GPU_CAGRA",
            **kwargs: Any
    ):
        """Configure the store and immediately initialize the collection.

        :param embedding_model: Model used to embed documents and queries.
        :param embedding_dimension: Dimension of the dense vector field.
        :param dense_vector_field_name: Name of the dense vector field.
        :param dense_vector_index_type: Index type of the dense field (key of the default params).
        :param args: Positional arguments forwarded to ``BaseMilvusVectorStore``.
        :param kwargs: Keyword arguments forwarded to ``BaseMilvusVectorStore``;
            ``vector_field_dimensions`` is overwritten with ``embedding_dimension``.
        """
        vector_field_dimensions: Dict[str, Any] = {
            dense_vector_field_name: embedding_dimension
        }
        kwargs["vector_field_dimensions"] = vector_field_dimensions
        super().__init__(*args, **kwargs)
        self.embedding_model = embedding_model
        self.dense_vector_field_name = dense_vector_field_name
        self.dense_vector_index_type = dense_vector_index_type
        self.initialize_collection()

    def _create_index(self):
        """Create the dense vector index from the default params of the configured index type."""
        # Build a fresh dict: writing "index_type" into the shared default entry would leak it
        # into every later search request that reuses the same entry as its search param.
        dense_vector_field_index_params: Dict[str, Any] = {
            **self._default_search_params[self.dense_vector_index_type],
            "index_type": self.dense_vector_index_type,
        }
        self.collection.create_index(
            field_name=self.dense_vector_field_name,
            index_params=dense_vector_field_index_params
        )

    def _create_collection(self):
        """Create and return the collection with a VARCHAR primary key and a dense vector field."""
        fields: List[FieldSchema] = [
            FieldSchema(
                name=self.id_field_name,
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=65535
            ),
            FieldSchema(
                name=self.dense_vector_field_name,
                dtype=DataType.FLOAT_VECTOR,
                dim=self.vector_field_dimensions[self.dense_vector_field_name]
            )
        ]

        schema = CollectionSchema(
            fields=fields,
        )

        collection: Collection = Collection(
            name=self.collection_name,
            schema=schema,
            using=self.alias,
            consistency_level=self.consistency_level,
        )
        if self.collection_properties is not None:
            # Use the freshly created collection: ``self.collection`` is still None while
            # ``initialize_collection`` is waiting for this method to return.
            collection.set_properties(self.collection_properties)

        return collection

    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Embed ``texts`` and insert them with their ``ids`` in batches.

        :param texts: Texts to embed.
        :param ids: Primary keys, one per text and in the same order.
        :param batch_size: Maximum number of rows per insert call.
        :raises ValueError: If ``texts`` and ``ids`` differ in length.
        """
        if len(texts) != len(ids):
            raise ValueError(
                f"texts and ids must have the same length, got {len(texts)} and {len(ids)}."
            )
        if not texts:
            # Nothing to insert; skip the embedding call entirely.
            return

        embeddings: List[List[float]] = self.embedding_model.embed_documents(
            texts=texts
        )
        total_count: int = len(ids)
        for start_index in range(0, total_count, batch_size):
            end_index: int = min(start_index + batch_size, total_count)
            # Column order follows the schema: id, dense vector.
            data: List[Any] = [
                ids[start_index:end_index],
                embeddings[start_index:end_index],
            ]
            self.collection.insert(data)

    def search(self, query: str, top_k: int) -> Hits:
        """Run a dense vector search for ``query``.

        :param query: Query text; embedded with the model's ``embed_query``.
        :param top_k: Maximum number of hits.
        :return: Hits of the single query, with the id field as the only output field.
        """
        embeddings: List[float] = self.embedding_model.embed_query(
            text=query,
        )

        output_fields = [
            self.id_field_name,
        ]
        search_result: SearchResult = self.collection.search(
            data=[embeddings],
            anns_field=self.dense_vector_field_name,
            limit=top_k,
            param=self._default_search_params[self.dense_vector_index_type],
            output_fields=output_fields,
        )
        # Only one query vector was sent, so only the first result set is relevant.
        outputs: Hits = search_result[0]

        return outputs



In [11]:
class MilvusHybridRetriever(BaseRetriever):
    """Retriever that searches a Milvus vector store for ids and loads the documents from a key-value store."""

    # Key-value store holding the JSON-serialized documents, keyed by the vector store ids.
    document_store: BaseStore[str, LangChainDocument]
    # Vector store searched for the ids of the relevant documents.
    vector_store: BaseMilvusVectorStore
    # Extra keyword arguments forwarded to ``vector_store.search`` (e.g. ``top_k``).
    search_kwargs: Dict[str, Any]
    # Field holding the document id in a hit; defaults to the vector store's id field name.
    id_key: Optional[str] = None

    def __init__(self, **kwargs: Any):
        """Initialize the retriever and default ``id_key`` to the vector store's id field name."""
        super().__init__(**kwargs)
        if self.id_key is None:
            self.id_key = self.vector_store.id_field_name

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[LangChainDocument]:
        """Return the documents matching ``query``, most relevant first.

        Each returned document gets the hit score stored in ``metadata["relevancy_score"]``.

        :param query: Query text forwarded to the vector store.
        :param run_manager: Callback manager supplied by the retriever base class (unused here).
        :raises DocumentStoreRetrieveError: If a retrieved id has no entry in the document store;
            all retrieved ids are removed from both stores before raising.
        """
        vector_store_retrieved_documents: Hits = self.vector_store.search(
            query=query,
            **self.search_kwargs
        )
        vector_store_retrieved_document_ids: List[str] = [
            hit.get(self.id_key) for hit in vector_store_retrieved_documents
        ]
        if not vector_store_retrieved_document_ids:
            # No hits: avoid asking the document store for an empty key list.
            return []

        document_store_retrieved_documents: List[Optional[bytes]] = self.document_store.mget(
            keys=vector_store_retrieved_document_ids
        )
        decoded_retrieved_documents: List[LangChainDocument] = []

        for vector_store_retrieved_document, document_store_retrieved_document in zip(
                vector_store_retrieved_documents, document_store_retrieved_documents, strict=True
        ):
            if document_store_retrieved_document is None:
                # The two stores are out of sync. Filter on the configured id field rather than
                # a hard-coded "id", so a custom ``id_field_name`` is cleaned up correctly.
                # NOTE(review): this removes every retrieved id, not only the missing one —
                # confirm that purging the healthy entries as well is intended.
                self.vector_store.collection.delete(
                    expr=f"{self.id_key} in {vector_store_retrieved_document_ids}"
                )
                self.document_store.mdelete(keys=vector_store_retrieved_document_ids)
                raise use_case_exception.DocumentStoreRetrieveError()

            decoded_retrieved_document: LangChainDocument = LangChainDocument(
                **json.loads(document_store_retrieved_document)
            )
            decoded_retrieved_document.metadata["relevancy_score"] = vector_store_retrieved_document.score
            decoded_retrieved_documents.append(decoded_retrieved_document)

        # NOTE(review): descending order assumes a higher score means more relevant; verify this
        # for vector stores whose score is a distance (e.g. the L2 metric).
        decoded_retrieved_documents.sort(
            key=lambda x: x.metadata["relevancy_score"],
            reverse=True
        )

        return decoded_retrieved_documents


In [12]:
class PartitionDocumentProcessor:
    """Turns a stored document (file, text or web) into a list of ``unstructured`` elements."""

    def __init__(
            self,
            document_management: DocumentManagement,
            file_document_management: FileDocumentManagement,
            text_document_management: TextDocumentManagement,
            web_document_management: WebDocumentManagement,
    ):
        """Keep references to the management use cases used to look documents up."""
        self.document_management = document_management
        self.file_document_management = file_document_management
        self.text_document_management = text_document_management
        self.web_document_management = web_document_management

    async def _partition_file(self, state: State, found_document: Document) -> List[Element]:
        """Partition the file behind ``found_document``, extracting PDF images next to the repository files."""
        file_document: FileDocumentResponse = await self.file_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )
        raw_bytes: bytes = self.file_document_management.file_document_repository.get_object_data(
            object_name=file_document.file_name
        )

        # Image output directory is named after the file hash; it is created if missing and then
        # removed, so no directory from an earlier run is left when partitioning starts.
        repository_path = self.file_document_management.file_document_repository.file_path
        image_output_dir: Path = repository_path / file_document.file_data_hash
        image_output_dir.mkdir(exist_ok=True)
        shutil.rmtree(image_output_dir)

        return partition(
            metadata_filename=file_document.file_name,
            file=io.BytesIO(raw_bytes),
            extract_images_in_pdf=True,
            extract_image_block_output_dir=str(image_output_dir),
            strategy=PartitionStrategy.AUTO,
            hi_res_model_name="yolox",
        )

    async def _partition_text(self, state: State, found_document: Document) -> List[Element]:
        """Partition the text content of the text document behind ``found_document``."""
        text_document: TextDocumentResponse = await self.text_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )

        return partition_text(
            text=text_document.text_content,
        )

    async def _partition_web(self, state: State, found_document: Document) -> List[Element]:
        """Partition the HTML fetched from the URL of the web document behind ``found_document``."""
        web_document: WebDocumentResponse = await self.web_document_management.find_one_by_id_with_authorization(
            state=state,
            id=found_document.id
        )

        # NOTE(review): ``ssl_verify=False`` is passed to the partitioner — confirm that skipping
        # certificate verification is intended here.
        return partition_html(
            url=web_document.web_url,
            ssl_verify=False
        )

    async def partition(self, state: State, document_id: UUID) -> List[Element]:
        """Look the document up and partition it according to its ``document_type_id``.

        :param state: Request state forwarded to the management use cases.
        :param document_id: Id of the document to partition.
        :return: Elements produced by the partitioner of the document's type.
        :raises DocumentTypeNotSupported: If the type is not ``"file"``, ``"text"`` or ``"web"``.
        """
        found_document: Document = await self.document_management.find_one_by_id_with_authorization(
            state=state,
            id=document_id
        )

        if found_document.document_type_id == "file":
            return await self._partition_file(state=state, found_document=found_document)

        if found_document.document_type_id == "text":
            return await self._partition_text(state=state, found_document=found_document)

        if found_document.document_type_id == "web":
            return await self._partition_web(state=state, found_document=found_document)

        raise use_case_exception.DocumentTypeNotSupported()


class SummaryDocumentProcessor:
    """Generates retrieval-oriented summaries of tables and images with a chat model."""

    def __init__(self):
        """The processor keeps no state."""
        pass

    async def _generate_summaries(
            self,
            batch_messages: List[List[BaseMessage]],
            llm_model: BaseChatModel
    ) -> List[str]:
        """Send every message list to ``llm_model`` in one batch and return the outputs as strings."""
        chain: RunnableSerializable = llm_model | StrOutputParser()
        generated_summaries: List[str] = await chain.abatch(
            inputs=batch_messages
        )

        return generated_summaries

    async def summarize_tables(self, tables: List[Table], llm_model: BaseChatModel) -> List[str]:
        """Summarize each table from its ``text``.

        :param tables: Tables to summarize.
        :param llm_model: Chat model that produces the summaries.
        :return: One summary per table, in input order; ``[]`` for empty input without
            building a prompt or calling the model.
        """
        if not tables:
            return []

        prompt: PromptTemplate = PromptTemplate(
            template="""Instruction: Give a concise passage summary of the table that is well optimized for retrieval. These summary will be embedded and used to retrieve the table. Ensure the output is only the summary without re-explain the instruction.
            Table: {table}""",
            input_variables=["table"]
        )

        batch_messages: List[List[BaseMessage]] = []
        for table in tables:
            text: str = prompt.format(
                table=table.text
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            batch_messages.append(messages)

        return await self._generate_summaries(
            batch_messages=batch_messages,
            llm_model=llm_model
        )

    async def summarize_images(self, images: List[Image], llm_model: BaseChatModel) -> List[str]:
        """Summarize each image from the base64 payload stored in its metadata.

        Every image must carry ``metadata.image_mime_type`` and ``metadata.image_base64``; they
        are sent to the model as a data URL.

        :param images: Images to summarize.
        :param llm_model: Chat model that produces the summaries.
        :return: One summary per image, in input order; ``[]`` for empty input without
            calling the model.
        """
        if not images:
            return []

        prompt_text = """Instruction: Give a concise passage summary of the image that is well optimized for retrieval. These summary will be embedded and used to retrieve the image. Ensure the output is only the summary without re-explain the instruction.
        Image:"""
        batch_messages: List[List[BaseMessage]] = []
        for image in images:
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": prompt_text
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{image.metadata.image_mime_type};base64,{image.metadata.image_base64}",
                            }
                        }
                    ]
                )
            ]
            batch_messages.append(messages)

        return await self._generate_summaries(
            batch_messages=batch_messages,
            llm_model=llm_model
        )


class CategoryDocumentProcessor:
    def __init__(
            self,
            summary_document_processor: SummaryDocumentProcessor,
    ):
        self.summary_document_processor = summary_document_processor

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        categorized_elements: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )

        for element in elements:
            if any(
                    element_type == element.__class__.__name__ for element_type in
                    [Text.__name__, NarrativeText.__name__, HTMLText.__name__]

            ):
                categorized_elements.texts.append(element)
            elif any(
                    element_type == element.__class__.__name__ for element_type in
                    [Table.__name__, HTMLTable.__name__]
            ):
                categorized_elements.tables.append(element)
            elif any(
                    element_type == element.__class__.__name__ for element_type in
                    [Image.__name__]
            ):
                file_io = open(element.metadata.image_path, "rb")
                element.metadata.image_mime_type = "image/jpeg"
                element.metadata.image_base64 = base64.b64encode(file_io.read()).decode()
                file_io.close()
                categorized_elements.images.append(element)
            else:
                print(f"BaseDocumentProcessor.categorize_elements: Ignoring element type {element.__class__.__name__}.")

        return categorized_elements

    async def get_categorized_documents(
            self,
            categorized_elements: ElementCategory,
            summarization_model: BaseChatModel,
            is_include_tables: bool = False,
            is_include_images: bool = False,
            chunk_size: int = 400,
            id_key: str = "id",
            metadata: Dict[str, Any] = {},
    ) -> DocumentCategory:
        document_category: DocumentCategory = DocumentCategory(
            texts=[],
            tables=[],
            images=[],
            id_key=id_key
        )
        chunked_texts: List[Element] = chunk_elements(
            elements=categorized_elements.texts,
            include_orig_elements=True,
            max_characters=chunk_size
        )
        for text in chunked_texts:
            text_metadata: Dict[str, Any] = text.metadata.to_dict()
            filtered_metadata: Dict[str, Any] = dict_tool.filter_by_keys(
                text_metadata,
                [key for key in metadata.keys() if key != "orig_elements"]
            )
            filtered_metadata["orig_metadata"] = [orig_element.metadata.to_dict() for orig_element in
                                                  text.metadata.orig_elements]
            filtered_metadata["category"] = "text"
            document: LangChainDocument = LangChainDocument(
                page_content=text.text,
                metadata={
                    id_key: str(uuid.uuid4()),
                    **filtered_metadata,
                    **metadata,
                }
            )
            document_category.texts.append(document)

        if is_include_tables:
            summarized_tables: List[str] = await self.summary_document_processor.summarize_tables(
                tables=categorized_elements.tables,
                llm_model=summarization_model
            )
            for table in summarized_tables:
                document_category.tables.append(LangChainDocument(
                    page_content=table,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "table",
                        **table.metadata.to_dict(),
                        **metadata
                    }
                ))

        if is_include_images:
            summarized_images: List[str] = await self.summary_document_processor.summarize_images(
                images=categorized_elements.images,
                llm_model=summarization_model
            )
            for image, summarized_image in zip(categorized_elements.images, summarized_images, strict=True):
                document_category.images.append(LangChainDocument(
                    page_content=summarized_image,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "image",
                        **image.metadata.to_dict(),
                        **metadata
                    }
                ))

        return document_category


# Module-level wiring: one shared instance of each processor for the cells below.
# The *_document_management objects are not defined in this cell; they must already
# be present in the kernel namespace when it runs.
partition_document_processor: PartitionDocumentProcessor = PartitionDocumentProcessor(
    document_management=document_management,
    file_document_management=file_document_management,
    text_document_management=text_document_management,
    web_document_management=web_document_management,
)

# The summary processor is injected into the category processor.
summary_document_processor: SummaryDocumentProcessor = SummaryDocumentProcessor()
category_document_processor: CategoryDocumentProcessor = CategoryDocumentProcessor(
    summary_document_processor=summary_document_processor
)

In [13]:
class GraphState(TypedDict):
    """
    State object handed from one graph node to the next.

    The nodes in this notebook read and write everything under the single ``data`` key.
    """
    data: Dict[str, Any]


class GraphPreparation:
    """
    Graph nodes that pick the LLM and build a cached ``DocumentCategory`` for every
    requested document.

    Every node mutates ``input_state["data"]`` in place and hands the same state object back.
    """

    def __init__(
            self,
            one_llm_setting: OneLlmSetting,
            two_datastore: TwoDatastore,
            category_document_processor: CategoryDocumentProcessor,
    ):
        """
        :param one_llm_setting: Supplies ``LLM_ONE_ANTHROPIC_API_KEY_ONE``.
        :param two_datastore: Its ``async_client`` stores the categorized documents
            (``exists`` / ``get`` / ``set``).
        :param category_document_processor: Categorizes elements and builds documents.
        """
        self.one_llm_setting = one_llm_setting
        self.two_datastore = two_datastore
        self.category_document_processor = category_document_processor

    def node_get_llm_model(self, input_state: GraphState) -> GraphState:
        """
        Resolve ``data["llm"]["model_name"]`` to a deployment and store the chat model
        under ``data["llm"]["model"]``.

        :param input_state: Graph state; reads ``data["llm"]["model_name"]`` and ``["max_token"]``.
        :return: The same state with ``data["llm"]["model"]`` set.
        :raises use_case_exception.LlmProviderNotSupported: when the deployment's provider
            is not ``"anthropic"``.
        """
        output_state: GraphState = input_state

        model_list: List[Dict] = [
            {
                "model_name": "claude-3-haiku-20240307",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE,
                    "provider": "anthropic"
                }
            },
            {
                "model_name": "claude-3-opus-20240229",
                "litellm_params": {
                    "model": "claude-3-opus-20240229",
                    "api_key": self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE,
                    "provider": "anthropic"
                }
            }
        ]
        router: Router = Router(model_list=model_list)
        deployment: Dict[str, Any] = router.get_available_deployment(
            model=input_state["data"]["llm"]["model_name"]
        )
        provider: str = deployment["litellm_params"]["provider"]
        if provider == "anthropic":
            llm_model: ChatAnthropic = ChatAnthropic(
                anthropic_api_key=deployment["litellm_params"]["api_key"],
                model=deployment["litellm_params"]["model"],
                max_tokens=input_state["data"]["llm"]["max_token"],
                streaming=True,
                temperature=0
            )
        else:
            raise use_case_exception.LlmProviderNotSupported()

        output_state["data"]["llm"]["model"] = llm_model

        return output_state

    async def node_get_categorized_documents(self, input_state: GraphState) -> GraphState:
        """
        Produce the ``DocumentCategory`` for ``data["next_document_id"]`` using two cache layers.

        1. Categorized elements are looked up through ``cache_tool``; on a miss, or when
           ``is_force_refresh_categorized_element`` is ``True``, the document is partitioned
           and categorized again and the cache entry is overwritten.
        2. The categorized document is looked up in the ``two_datastore`` async client; it is
           rebuilt when absent or when either force-refresh flag is ``True``, otherwise the
           stored JSON is loaded.

        The hashes used as cache keys are recorded per document id under
        ``data["categorized_element_hashes"]`` and ``data["categorized_document_hashes"]``.

        :param input_state: Graph state; ``data["categorized_documents"]`` must already exist.
        :return: The same state with ``data["categorized_documents"][document_id]`` set.
        :raises use_case_exception.ExistingCategorizedDocumentHashInvalid: when ``exists``
            returns something other than 0 or 1.
        """
        output_state: GraphState = input_state

        document_id: UUID = input_state["data"]["next_document_id"]

        categorized_element_hash: str = self._get_categorized_element_hash(
            document_id=document_id
        )
        categorized_document_hash: str = self._get_categorized_document_hash(
            categorized_element_hash=categorized_element_hash,
            summarization_model_name=input_state["data"]["llm"]["model_name"],
            is_include_tables=input_state["data"]["preprocessor_setting"]["is_include_tables"],
            is_include_images=input_state["data"]["preprocessor_setting"]["is_include_images"],
            chunk_size=input_state["data"]["preprocessor_setting"]["chunk_size"],
        )

        # --- layer 1: categorized elements -------------------------------------------------
        categorized_element_hashes: Optional[Dict[UUID, str]] = input_state["data"].get(
            "categorized_element_hashes",
            None
        )
        if categorized_element_hashes is None:
            output_state["data"]["categorized_element_hashes"] = {}
        output_state["data"]["categorized_element_hashes"][document_id] = categorized_element_hash
        is_categorized_element_exist: bool = cache_tool.is_key_in_cache(
            key=categorized_element_hash
        )
        is_force_refresh_categorized_element: bool = input_state["data"]["preprocessor_setting"][
            "is_force_refresh_categorized_element"]
        if is_categorized_element_exist is False or is_force_refresh_categorized_element is True:
            # Uses the module-level partition_document_processor instance.
            elements: List[Element] = await partition_document_processor.partition(
                state=input_state["data"]["state"],
                document_id=document_id
            )
            categorized_elements: ElementCategory = await self.category_document_processor.categorize_elements(
                elements=elements
            )
            cache_tool.set_cache(
                key=categorized_element_hash,
                value=categorized_elements
            )
        else:
            categorized_elements: ElementCategory = cache_tool.get_cache(
                key=categorized_element_hash
            )

        # --- layer 2: categorized document -------------------------------------------------
        categorized_document_hashes: Optional[Dict[UUID, str]] = input_state["data"].get(
            "categorized_document_hashes",
            None
        )
        if categorized_document_hashes is None:
            output_state["data"]["categorized_document_hashes"] = {}
        output_state["data"]["categorized_document_hashes"][document_id] = categorized_document_hash
        existing_categorized_document_hash: int = await self.two_datastore.async_client.exists(
            categorized_document_hash
        )
        if existing_categorized_document_hash == 0:
            is_categorized_document_exist: bool = False
        elif existing_categorized_document_hash == 1:
            is_categorized_document_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedDocumentHashInvalid()

        is_force_refresh_categorized_document: bool = input_state["data"]["preprocessor_setting"][
            "is_force_refresh_categorized_document"]
        # Refreshing the elements also invalidates the document built from them.
        if is_categorized_document_exist is False or is_force_refresh_categorized_document is True or is_force_refresh_categorized_element is True:
            categorized_document: DocumentCategory = await self.category_document_processor.get_categorized_documents(
                categorized_elements=categorized_elements,
                summarization_model=input_state["data"]["llm"]["model"],
                is_include_tables=input_state["data"]["preprocessor_setting"]["is_include_tables"],
                is_include_images=input_state["data"]["preprocessor_setting"]["is_include_images"],
                chunk_size=input_state["data"]["preprocessor_setting"]["chunk_size"],
                metadata={
                    "document_id": document_id
                }
            )
            await self.two_datastore.async_client.set(
                name=categorized_document_hash,
                value=json.dumps(categorized_document.dict(), default=jsonable_encoder).encode()
            )
        else:
            found_categorized_document_bytes: bytes = await self.two_datastore.async_client.get(
                categorized_document_hash
            )
            categorized_document: DocumentCategory = DocumentCategory(**json.loads(found_categorized_document_bytes))

        output_state["data"]["categorized_documents"][document_id] = categorized_document

        return output_state

    def _get_categorized_element_hash(
            self,
            document_id: UUID,
    ) -> str:
        """
        Build the cache key for a document's categorized elements.

        :param document_id: Id of the document; the only input to the hash.
        :return: ``"categorized_element-"`` followed by ``cache_tool.hash_by_dict`` of the id.
        """
        data: Dict[str, Any] = {
            "document_id": document_id,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_element-{hashed_data}"

        return hashed_data

    def _get_categorized_document_hash(
            self,
            categorized_element_hash: str,
            summarization_model_name: str,
            is_include_tables: bool,
            is_include_images: bool,
            chunk_size: int,
    ) -> str:
        """
        Build the cache key for a categorized document.

        The key covers every input that shapes the document: the element hash, the
        summarization model name, both include flags and the chunk size.

        :return: ``"categorized_document-"`` followed by ``cache_tool.hash_by_dict`` of those inputs.
        """
        data: Dict[str, Any] = {
            "categorized_element_hash": categorized_element_hash,
            "summarization_model_name": summarization_model_name,
            "is_include_tables": is_include_tables,
            "is_include_images": is_include_images,
            "chunk_size": chunk_size,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_document-{hashed_data}"

        return hashed_data

    async def node_prepare_get_categorized_documents(self, input_state: GraphState) -> GraphState:
        """
        Choose the next requested document that has no categorized result yet.

        Creates ``data["categorized_documents"]`` when it is missing and stores the
        chosen id under ``data["next_document_id"]``.

        :param input_state: Graph state; reads ``data["document_ids"]``.
        :return: The same state with ``data["next_document_id"]`` set.
        :raises IndexError: when no requested document is pending.
        """
        output_state: GraphState = input_state

        document_ids: List[UUID] = input_state["data"]["document_ids"]

        categorized_documents: Optional[Dict[UUID, DocumentCategory]] = input_state["data"].get(
            "categorized_documents",
            None
        )
        if categorized_documents is None:
            categorized_documents = {}
            output_state["data"]["categorized_documents"] = categorized_documents

        next_document_ids: List[UUID] = list(set(document_ids) - set(categorized_documents.keys()))
        next_document_id: UUID = next_document_ids.pop()

        output_state["data"]["next_document_id"] = next_document_id

        return output_state

    async def node_decide_get_categorized_documents_or_embed(self, input_state: GraphState) -> str:
        """
        Return the routing label for the step after categorization.

        :param input_state: Graph state; reads ``data["document_ids"]`` and
            ``data["categorized_documents"]``.
        :return: ``"EMBED"`` once every requested document id has a categorized result
            (``data["next_document_id"]`` is then reset to ``None``), otherwise
            ``"GET_CATEGORIZED_DOCUMENTS"``.
        """
        output_state: GraphState = input_state

        document_ids: List[UUID] = input_state["data"]["document_ids"]

        categorized_documents: Dict[UUID, DocumentCategory] = input_state["data"]["categorized_documents"]

        # Subset rather than equality: entries left in the state for documents that are no
        # longer requested must not route back to the prepare node, which would then find
        # nothing pending and fail on ``pop()``.
        if set(document_ids).issubset(categorized_documents.keys()):
            output_state["data"]["next_document_id"] = None
            return "EMBED"

        return "GET_CATEGORIZED_DOCUMENTS"

In [14]:
class GraphLongFormQa(GraphPreparation):
    def __init__(
            self,
            one_embedding_setting: OneEmbeddingSetting,
            four_datastore: FourDatastore,
            *args: Any,
            **kwargs: Any
    ):
        """
        Keep the embedding settings and the datastore whose ``alias`` is handed to the
        vector stores; all remaining arguments go untouched to ``GraphPreparation``.

        :param one_embedding_setting: Supplies ``URL`` for the Infinity embedding endpoint.
        :param four_datastore: Supplies ``alias`` for the Milvus-backed vector stores.
        :param args: Positional arguments forwarded to the parent constructor.
        :param kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Dependencies specific to the long-form QA graph.
        self.four_datastore = four_datastore
        self.one_embedding_setting = one_embedding_setting

    def _get_embedding_query(self, embedding_model_name: str, question: str,
                             query_instruction: Optional[str] = None) -> str:
        if embedding_model_name == "BAAI/bge-m3":
            query: str = question
        elif embedding_model_name == "intfloat/multilingual-e5-large-instruct":
            if query_instruction is None:
                raise use_case_exception.QueryInstructionNotProvided()

            query: str = HuggingFaceE5InstructEmbeddings.get_detailed_instruct(
                task_description=query_instruction,
                query=question
            )
        else:
            raise use_case_exception.EmbeddingModelNameNotSupported()

        return query

    async def node_prepare_embed(self, input_state: GraphState):
        output_state: GraphState = input_state

        categorized_documents: Dict[UUID, DocumentCategory] = input_state["data"]["categorized_documents"]
        categorized_document_ids: List[UUID] = list(categorized_documents.keys())
        embedded_document_ids: Optional[List[UUID]] = input_state["data"].get("embedded_document_ids", None)
        if embedded_document_ids is None:
            embedded_document_ids = []
            output_state["data"]["embedded_document_ids"] = embedded_document_ids

        next_document_ids: List[UUID] = list(set(categorized_document_ids) - set(embedded_document_ids))
        next_document_id: UUID = next_document_ids.pop()

        output_state["data"]["next_document_id"] = next_document_id
        output_state["data"]["next_categorized_document"] = categorized_documents[next_document_id]

        return output_state

    @cache_tool.cacher(args_include_keys=[])
    def _get_bge_m3_embedding_model(self) -> BgeM3Embedding:
        """
        Create the BGE-M3 embedding model with fp16 and embedding normalization switched off.

        The call is wrapped by ``cache_tool.cacher`` with an empty ``args_include_keys`` list.

        :return: A ``BgeM3Embedding`` instance.
        """
        return BgeM3Embedding(use_fp16=False, normalize_embeddings=False)

    def _get_vector_store(self, embedding_model_name: str, collection_name: str, alias: str) -> BaseMilvusVectorStore:
        """
        Build the Milvus vector store matching the embedding model.

        :param embedding_model_name: ``"BAAI/bge-m3"`` or ``"intfloat/multilingual-e5-large-instruct"``.
        :param collection_name: Name of the collection the store works on.
        :param alias: Connection alias passed to the store.
        :return: ``MilvusBgeM3VectorStore`` for bge-m3, or ``MilvusInfinityVectorStore``
            (embedding dimension 1024, Infinity endpoint from ``one_embedding_setting.URL``)
            for the E5 instruct model.
        :raises use_case_exception.EmbeddingModelNameNotSupported: any other model name.
        """
        if embedding_model_name == "BAAI/bge-m3":
            return MilvusBgeM3VectorStore(
                embedding_model=self._get_bge_m3_embedding_model(),
                alias=alias,
                collection_name=collection_name,
            )

        if embedding_model_name == "intfloat/multilingual-e5-large-instruct":
            infinity_embeddings: InfinityEmbeddings = InfinityEmbeddings(
                model=embedding_model_name,
                infinity_api_url=self.one_embedding_setting.URL,
            )
            return MilvusInfinityVectorStore(
                embedding_model=infinity_embeddings,
                embedding_dimension=1024,
                alias=alias,
                collection_name=collection_name,
            )

        raise use_case_exception.EmbeddingModelNameNotSupported()

    async def node_decide_embed_or_get_relevant_documents(self, input_state: GraphState) -> str:
        output_state: GraphState = input_state

        categorized_documents: Dict[UUID, DocumentCategory] = input_state["data"]["categorized_documents"]
        categorized_document_ids: List[UUID] = list(categorized_documents.keys())
        embedded_document_ids: List[UUID] = input_state["data"]["embedded_document_ids"]

        if set(categorized_document_ids) == set(embedded_document_ids):
            output_state["data"]["next_document_id"] = None
            output_state["data"]["next_categorized_document"] = None
            return "GET_RELEVANT_DOCUMENTS"

        return "EMBED"

    async def node_embed(self, input_state: GraphState) -> GraphState:
        """
        Embed the document prepared under ``data["next_categorized_document"]`` and record
        it as embedded.

        All text, table and image documents of the categorized document are gathered. When
        there is at least one:

        * the vector store collection is created if missing; the texts are (re-)embedded
          unless every id is already returned by the collection query and
          ``is_force_refresh_embedding`` is not ``True``;
        * the serialized documents are (re-)written to the document store unless every id
          already exists in the ``two_datastore`` async client and
          ``is_force_refresh_document`` is not ``True``.

        The vector store and document store are saved under ``data["retriever_setting"]`` and
        ``data["next_document_id"]`` is appended to ``data["embedded_document_ids"]``.

        :param input_state: Graph state; mutated in place.
        :return: The same state object.
        """
        output_state: GraphState = input_state

        categorized_document: DocumentCategory = input_state["data"]["next_categorized_document"]
        document_contents: List[str] = []
        document_ids: List[str] = []
        document_key_value_pairs: List[Tuple[Any, Any]] = []
        documents: List[
            LangChainDocument
        ] = categorized_document.texts + categorized_document.tables + categorized_document.images
        for document in documents:
            document_contents.append(document.page_content)
            document_ids.append(document.metadata[categorized_document.id_key])
            document_key_value_pairs.append(
                (
                    document.metadata[categorized_document.id_key],
                    json.dumps(document.dict(), default=jsonable_encoder).encode()
                )
            )

        # The collection name is derived from the hashes of all categorized documents plus
        # the embedding model name.
        collection_name: str = self._get_collection_name_hash(
            categorized_document_hashes=input_state["data"]["categorized_document_hashes"],
            embedding_model_name=input_state["data"]["embedder_setting"]["model_name"]
        )
        document_store: RedisStore = RedisStore(
            client=self.two_datastore.sync_client
        )
        vector_store: BaseMilvusVectorStore = self._get_vector_store(
            embedding_model_name=input_state["data"]["embedder_setting"]["model_name"],
            collection_name=collection_name,
            alias=self.four_datastore.alias
        )

        if len(document_ids) > 0:
            is_collection_exist: bool = vector_store.has_collection()
            is_entity_exist: bool = False
            if is_collection_exist is True:
                existing_entity_ids: List[Dict] = vector_store.collection.query(
                    expr=f"id in {document_ids}"
                )
                if len(existing_entity_ids) == len(document_ids):
                    is_entity_exist = True
            else:
                vector_store.initialize_collection()

            is_force_refresh_embedding: bool = input_state["data"]["embedder_setting"]["is_force_refresh_embedding"]
            if is_entity_exist is False or is_force_refresh_embedding is True:
                # Delete first so a partial or forced refresh does not leave duplicates.
                vector_store.collection.delete(
                    expr=f"id in {document_ids}"
                )
                vector_store.embed_texts(
                    texts=document_contents,
                    ids=document_ids
                )

            # ``exists`` with several keys returns how many of them are present.
            existing_document_ids: int = await self.two_datastore.async_client.exists(
                *document_ids
            )
            is_document_exist: bool = existing_document_ids == len(document_ids)

            is_force_refresh_document: bool = input_state["data"]["embedder_setting"]["is_force_refresh_document"]
            if is_document_exist is False or is_force_refresh_document is True:
                await document_store.amdelete(keys=document_ids)
                await document_store.amset(key_value_pairs=document_key_value_pairs)

        output_state["data"]["retriever_setting"]["vector_store"] = vector_store
        output_state["data"]["retriever_setting"]["document_store"] = document_store

        # Create the bookkeeping list when absent, then ALWAYS record the current document;
        # previously a missing list was created but the document was not added to it.
        embedded_document_ids: Optional[List[Any]] = output_state["data"].get("embedded_document_ids", None)
        if embedded_document_ids is None:
            embedded_document_ids = []
            output_state["data"]["embedded_document_ids"] = embedded_document_ids
        embedded_document_ids.append(input_state["data"]["next_document_id"])

        return output_state

    def _get_collection_name_hash(self, categorized_document_hashes: Dict[UUID, str], embedding_model_name: str) -> str:
        """
        Derive the vector store collection name.

        :param categorized_document_hashes: Categorized-document hash per document id; the ids
            are converted to strings before hashing.
        :param embedding_model_name: Name of the embedding model.
        :return: ``"lfqa_"`` followed by ``cache_tool.hash_by_dict`` of both inputs.
        """
        hashes_by_text_id: Dict[str, str] = {
            str(document_id): document_hash
            for document_id, document_hash in categorized_document_hashes.items()
        }
        digest: str = cache_tool.hash_by_dict(
            data={
                "categorized_document_hashes": hashes_by_text_id,
                "embedding_model_name": embedding_model_name,
            }
        )

        return f"lfqa_{digest}"

    async def node_get_relevant_documents(self, input_state: GraphState) -> GraphState:
        """
        Retrieve the documents relevant to the question, reusing a stored result when allowed.

        A ``MilvusHybridRetriever`` is built from the vector store and document store found
        under ``data["retriever_setting"]`` and saved there as ``"retriever"``. The result is
        kept in the ``two_datastore`` async client under a key derived from ``top_k``, the
        collection name and the embedding query. Retrieval runs again when that key is absent
        or when ``is_force_refresh_relevant_document`` or ``is_force_refresh_embedding`` is
        ``True``; otherwise the stored JSON is decoded back into documents.

        :param input_state: Graph state; mutated in place.
        :return: The same state with ``relevant_documents`` and ``relevant_document_hash`` set.
        :raises use_case_exception.ExistingRelevantDocumentHashInvalid: when ``exists`` returns
            something other than 0 or 1.
        """
        state_data: Dict[str, Any] = input_state["data"]
        embedder_setting: Dict[str, Any] = state_data["embedder_setting"]

        embedding_query: str = self._get_embedding_query(
            embedding_model_name=embedder_setting["model_name"],
            query_instruction=embedder_setting["query_instruction"],
            question=state_data["question"]
        )

        retriever_setting: Dict[str, Any] = state_data["retriever_setting"]
        milvus_store: BaseMilvusVectorStore = retriever_setting["vector_store"]
        redis_store: RedisStore = retriever_setting["document_store"]
        hybrid_retriever: MilvusHybridRetriever = MilvusHybridRetriever(
            vector_store=milvus_store,
            document_store=redis_store,
            collection_name=milvus_store.collection_name,
            search_kwargs={"top_k": retriever_setting["top_k"]}
        )
        retriever_setting["retriever"] = hybrid_retriever

        cache_key: str = self._get_relevant_document_hash(
            top_k=retriever_setting["top_k"],
            collection_name=hybrid_retriever.vector_store.collection_name,
            query=embedding_query,
        )

        # A single-key ``exists`` must answer 0 or 1; anything else is treated as invalid.
        key_count: int = await self.two_datastore.async_client.exists(cache_key)
        if key_count not in (0, 1):
            raise use_case_exception.ExistingRelevantDocumentHashInvalid()
        is_cached: bool = key_count == 1

        refresh_relevant: bool = retriever_setting["is_force_refresh_relevant_document"]
        refresh_embedding: bool = embedder_setting["is_force_refresh_embedding"]
        if is_cached and refresh_relevant is not True and refresh_embedding is not True:
            cached_payload: bytes = await self.two_datastore.async_client.get(cache_key)
            found_documents: List[LangChainDocument] = [
                LangChainDocument(**fields) for fields in json.loads(cached_payload)
            ]
        else:
            found_documents: List[LangChainDocument] = hybrid_retriever.get_relevant_documents(
                query=embedding_query
            )
            await self.two_datastore.async_client.set(
                name=cache_key,
                value=json.dumps(
                    obj=[found.dict() for found in found_documents],
                    default=jsonable_encoder
                ).encode()
            )

        state_data["relevant_documents"] = found_documents
        state_data["relevant_document_hash"] = cache_key

        return input_state

    def _get_relevant_document_hash(self, top_k: int, collection_name: str, query: str) -> str:
        """
        Build the cache key for a retrieval result.

        :param top_k: Number of documents requested from the retriever.
        :param collection_name: Vector store collection that is searched.
        :param query: Embedding query text.
        :return: ``"relevant_document-"`` followed by ``cache_tool.hash_by_dict`` of the inputs.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "top_k": top_k,
                "collection_name": collection_name,
                "query": query,
            }
        )

        return f"relevant_document-{digest}"

    def _get_reranker_model(self, model_name: str) -> BaseReranker:
        """
        Instantiate the reranker for ``model_name``.

        :param model_name: Only ``"BAAI/bge-reranker-v2-m3"`` is accepted.
        :return: A ``BgeReranker`` created with that model name.
        :raises use_case_exception.RerankerModelNameNotSupported: any other model name.
        """
        if model_name != "BAAI/bge-reranker-v2-m3":
            raise use_case_exception.RerankerModelNameNotSupported()

        return BgeReranker(model_name=model_name)

    async def node_get_re_ranked_documents(self, input_state: GraphState) -> GraphState:
        """
        Re-rank the relevant documents against the question, reusing a stored result when allowed.

        The result is kept in the ``two_datastore`` async client under a key derived from the
        relevant-document hash, the reranker model name and the reranker ``top_k``. Re-ranking
        runs when that key is absent or ``is_force_refresh_re_ranked_document`` is ``True``.
        Each re-ranked document takes its text from the reranker result and its metadata from
        the source document, extended with ``re_ranked_score``.

        :param input_state: Graph state; mutated in place.
        :return: The same state with ``data["re_ranked_documents"]`` set.
        :raises use_case_exception.ExistingReRankedDocumentHashInvalid: when ``exists`` returns
            something other than 0 or 1.
        """
        state_data: Dict[str, Any] = input_state["data"]

        source_hash: str = state_data["relevant_document_hash"]
        cache_key: str = self._get_re_ranked_document_hash(
            relevant_document_hash=source_hash,
            reranker_model_name=state_data["reranker_setting"]["model_name"],
            top_k=state_data["reranker_setting"]["top_k"]
        )

        # A single-key ``exists`` must answer 0 or 1; anything else is treated as invalid.
        key_count: int = await self.two_datastore.async_client.exists(cache_key)
        if key_count not in (0, 1):
            raise use_case_exception.ExistingReRankedDocumentHashInvalid()
        is_cached: bool = key_count == 1

        force_refresh: bool = state_data["reranker_setting"]["is_force_refresh_re_ranked_document"]
        if is_cached and force_refresh is not True:
            cached_payload: bytes = await self.two_datastore.async_client.get(cache_key)
            ranked_documents: List[LangChainDocument] = [
                LangChainDocument(**fields) for fields in json.loads(cached_payload)
            ]
        else:
            candidates: List[LangChainDocument] = state_data["relevant_documents"]
            reranker: BaseReranker = self._get_reranker_model(
                model_name=state_data["reranker_setting"]["model_name"]
            )
            ranking: List[Dict[str, Any]] = reranker.rerank(
                query=state_data["question"],
                texts=[candidate.page_content for candidate in candidates],
                top_k=state_data["reranker_setting"]["top_k"]
            )

            ranked_documents: List[LangChainDocument] = []
            for hit in ranking:
                # ``index`` points back into the candidate list handed to the reranker.
                source_document: Optional[LangChainDocument] = candidates[hit["index"]]
                ranked_documents.append(
                    LangChainDocument(
                        page_content=hit["text"],
                        metadata=dict(
                            re_ranked_score=hit["score"],
                            **source_document.metadata
                        )
                    )
                )

            await self.two_datastore.async_client.set(
                name=cache_key,
                value=json.dumps(
                    obj=[ranked.dict() for ranked in ranked_documents],
                    default=jsonable_encoder
                ).encode()
            )

        state_data["re_ranked_documents"] = ranked_documents

        return input_state

    def _get_re_ranked_document_hash(
            self,
            relevant_document_hash: str,
            reranker_model_name: str,
            top_k: int
    ) -> str:
        """
        Build the cache key for a re-ranking result.

        :param relevant_document_hash: Cache key of the retrieval result being re-ranked.
        :param reranker_model_name: Name of the reranker model.
        :param top_k: Number of documents kept by the reranker.
        :return: ``"re_ranked_document-"`` followed by ``cache_tool.hash_by_dict`` of the inputs.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "relevant_document_hash": relevant_document_hash,
                "reranker_model_name": reranker_model_name,
                "top_k": top_k
            }
        )

        return f"re_ranked_document-{digest}"

    async def node_generate_answer(self, input_state: GraphState) -> GraphState:
        """
        Generate the answer from the re-ranked documents, reusing a stored answer when allowed.

        The answer is kept in the ``two_datastore`` async client under a key derived from the
        re-ranked document ids, the question, the LLM model name, the prompt text and
        ``max_token``. The LLM is called when that key is absent or
        ``is_force_refresh_generated_answer`` is ``True``; the prompt is the jinja2 template
        from ``data["generator_setting"]["prompt_text"]`` rendered with ``passages`` and
        ``question``.

        :param input_state: Graph state; mutated in place.
        :return: The same state with ``generated_answer`` and ``generated_answer_hash`` set.
        :raises use_case_exception.ExistingGeneratedAnswerHashInvalid: when ``exists`` returns
            something other than 0 or 1.
        """
        output_state: GraphState = input_state

        re_ranked_documents: List[LangChainDocument] = input_state["data"]["re_ranked_documents"]
        retriever: MilvusHybridRetriever = input_state["data"]["retriever_setting"]["retriever"]
        re_ranked_document_ids: List[str] = [document.metadata[retriever.id_key] for document in re_ranked_documents]
        generated_answer_hash: str = self._get_generated_answer_hash(
            re_ranked_document_ids=re_ranked_document_ids,
            question=input_state["data"]["question"],
            llm_model_name=input_state["data"]["llm"]["model_name"],
            prompt_text=input_state["data"]["generator_setting"]["prompt_text"],
            max_token=input_state["data"]["llm"]["max_token"],
        )
        existing_generated_answer_hash: int = await self.two_datastore.async_client.exists(generated_answer_hash)
        if existing_generated_answer_hash == 0:
            is_generated_answer_exist: bool = False
        elif existing_generated_answer_hash == 1:
            is_generated_answer_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerHashInvalid()

        is_force_refresh_generated_answer: bool = input_state["data"]["generator_setting"][
            "is_force_refresh_generated_answer"]
        if is_generated_answer_exist is False or is_force_refresh_generated_answer is True:
            prompt: PromptTemplate = PromptTemplate(
                template=input_state["data"]["generator_setting"]["prompt_text"],
                template_format="jinja2",
                input_variables=["passages", "question"]
            )
            text: str = prompt.format(
                passages=re_ranked_documents,
                question=input_state["data"]["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: BaseChatModel = input_state["data"]["llm"]["model"]
            chain: RunnableSerializable = llm_model | StrOutputParser()
            # Awaited async call: the synchronous ``invoke`` would block the event loop for
            # the whole LLM round trip inside this coroutine.
            generated_answer: str = await chain.ainvoke(
                input=messages
            )
            await self.two_datastore.async_client.set(
                name=generated_answer_hash,
                value=generated_answer.encode()
            )
        else:
            generated_answer_byte: bytes = await self.two_datastore.async_client.get(generated_answer_hash)
            generated_answer: str = generated_answer_byte.decode()

        output_state["data"]["generated_answer"] = generated_answer
        output_state["data"]["generated_answer_hash"] = generated_answer_hash

        return output_state

    def _get_generated_answer_hash(
            self,
            re_ranked_document_ids: List[str],
            question: str,
            llm_model_name: str,
            prompt_text: str,
            max_token: int,
    ) -> str:
        """Build the cache key under which a generated answer is stored.

        The key is ``"generated_answer-"`` followed by ``cache_tool.hash_by_dict``
        of the document ids, question, LLM model name, prompt text and max token.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "re_ranked_document_ids": re_ranked_document_ids,
                "question": question,
                "llm_model_name": llm_model_name,
                "prompt_text": prompt_text,
                "max_token": max_token,
            }
        )

        return f"generated_answer-{digest}"

    async def node_grade_hallucination(self, input_state: GraphState) -> GraphState:
        """Grade whether the generated answer is unsupported by the relevant documents.

        Derives a cache key from the ids of ``input_state["data"]["relevant_documents"]``
        and the generated answer hash. When the key is already present in
        ``self.two_datastore`` and
        ``is_force_refresh_generated_hallucination_grade_hash`` is not ``True``, the
        stored grade is reused. Otherwise the LLM is asked, through the ``GradeTool``
        tool schema, whether the answer is supported by the passages; the stored
        grade is the string form of the *negated* support score, so ``"True"`` means
        "hallucinated".

        Writes ``generated_hallucination_grade`` and
        ``generated_hallucination_grade_hash`` into ``input_state["data"]`` and
        returns the same state object.

        Raises:
            use_case_exception.ExistingGeneratedHallucinationGradeHashInvalid: when the
                datastore ``exists`` call returns a value other than 0 or 1.
        """
        state_data = input_state["data"]

        passages: List[LangChainDocument] = state_data["relevant_documents"]

        class GradeTool(BaseModelV1):
            """Binary score for support check."""
            binary_score: bool = FieldV1(
                description="Is supported binary score, either True if supported or False if not supported."
            )

        retriever: MilvusHybridRetriever = state_data["retriever_setting"]["retriever"]
        passage_ids: List[str] = []
        for passage in passages:
            passage_ids.append(passage.metadata[retriever.id_key])

        grade_key: str = self._get_generated_hallucination_grade_hash(
            retrieved_document_ids=passage_ids,
            generated_answer_hash=state_data["generated_answer_hash"]
        )

        # `exists` is expected to report 0 or 1 for a single key; anything else is rejected.
        key_count: int = await self.two_datastore.async_client.exists(grade_key)
        if key_count not in (0, 1):
            raise use_case_exception.ExistingGeneratedHallucinationGradeHashInvalid()
        is_cached: bool = key_count == 1

        force_refresh: bool = state_data["generator_setting"][
            "is_force_refresh_generated_hallucination_grade_hash"]
        if is_cached and force_refresh is not True:
            # Cache hit and no forced refresh: reuse the stored grade.
            cached_bytes: bytes = await self.two_datastore.async_client.get(grade_key)
            hallucination_grade: str = cached_bytes.decode()
        else:
            grade_prompt: PromptTemplate = PromptTemplate(
                template="""Instruction: Assess whether an Large Language Model generated answer is supported by a set of retrieved passages. Give a binary score of "True" or "False". "True" means that the answer is supported by the set of retrieved passages. "False" means that the answer is not supported by the set of retrieved passages.
                Passages:
                {% for passage in passages %}
                [{{ loop.index }}]={{ passage.page_content }}
                {% endfor %}
                Generated Answer: {{ generated_answer }}
                """,
                template_format="jinja2",
                input_variables=["passages", "generated_answer"]
            )
            rendered_prompt: str = grade_prompt.format(
                passages=passages,
                generated_answer=state_data["generated_answer"]
            )
            human_message: HumanMessage = HumanMessage(
                content=[{"type": "text", "text": rendered_prompt}]
            )
            llm_model: BaseChatModel = state_data["llm"]["model"]
            grading_chain: RunnableSerializable = llm_model.bind_tools(tools=[GradeTool]) | ToolsOutputParser(
                pydantic_schemas=[GradeTool]
            )
            tool_results: List[GradeTool] = grading_chain.invoke(input=[human_message])
            # The tool reports "supported"; the stored grade is its negation.
            hallucination_grade: str = str(not tool_results[0].binary_score)
            await self.two_datastore.async_client.set(
                name=grade_key,
                value=hallucination_grade.encode()
            )

        state_data["generated_hallucination_grade"] = hallucination_grade
        state_data["generated_hallucination_grade_hash"] = grade_key

        return input_state

    def _get_generated_hallucination_grade_hash(
            self,
            retrieved_document_ids: List[str],
            generated_answer_hash: str,
    ) -> str:
        """Build the cache key under which a hallucination grade is stored.

        The key is ``"generated_hallucination_grade-"`` followed by
        ``cache_tool.hash_by_dict`` of the document ids and the answer hash.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "retrieved_document_ids": retrieved_document_ids,
                "generated_answer_hash": generated_answer_hash,
            }
        )

        return f"generated_hallucination_grade-{digest}"

    async def node_grade_answer_relevancy(self, input_state: GraphState) -> GraphState:
        """Grade whether the generated answer resolves the question.

        Derives a cache key from the question and the generated answer hash. When
        the key is already present in ``self.two_datastore`` and
        ``is_force_refresh_generated_answer_relevancy_grade_hash`` is not ``True``,
        the stored grade is reused. Otherwise the LLM is asked, through the
        ``GradeTool`` tool schema, whether the answer resolves the question, and the
        string form of that boolean is stored under the key.

        Writes ``generated_answer_relevancy_grade`` and
        ``generated_answer_relevancy_grade_hash`` into ``input_state["data"]`` and
        returns the same state object.

        Raises:
            use_case_exception.ExistingGeneratedAnswerRelevancyGradeHashInvalid: when
                the datastore ``exists`` call returns a value other than 0 or 1.
        """
        output_state: GraphState = input_state

        class GradeTool(BaseModelV1):
            """Binary score for resolution check."""
            binary_score: bool = FieldV1(
                description="Is resolved binary score, either True if resolved or False if not resolved."
            )

        generated_answer_relevancy_grade_hash: str = self._get_generated_answer_relevancy_grade_hash(
            question=input_state["data"]["question"],
            generated_answer_hash=input_state["data"]["generated_answer_hash"]
        )
        existing_generated_answer_relevancy_grade_hash: int = await self.two_datastore.async_client.exists(
            generated_answer_relevancy_grade_hash)
        if existing_generated_answer_relevancy_grade_hash == 0:
            is_generated_answer_relevancy_grade_hash_exist: bool = False
        elif existing_generated_answer_relevancy_grade_hash == 1:
            is_generated_answer_relevancy_grade_hash_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerRelevancyGradeHashInvalid()

        is_force_refresh_generated_answer_relevancy_grade_hash: bool = input_state["data"]["generator_setting"][
            "is_force_refresh_generated_answer_relevancy_grade_hash"]
        if is_generated_answer_relevancy_grade_hash_exist is False or is_force_refresh_generated_answer_relevancy_grade_hash is True:
            # The template uses jinja2 placeholders, so the format must be declared
            # explicitly: under the default "f-string" format `{{ ... }}` is an escaped
            # brace pair and the answer/question would never be substituted.
            # NOTE(review): grades cached before this fix were computed from the
            # unrendered template; force a refresh to recompute them.
            prompt: PromptTemplate = PromptTemplate(
                template="""Instruction: Assess whether an Large Language Model generated answer resolves a question. Give a binary score of "True" or "False". "True" means that the answer resolves the question. "False" means that the answer does not resolve the question.
                Generated Answer: {{ generated_answer }}
                Question: {{ question }}
                """,
                template_format="jinja2",
                input_variables=["generated_answer", "question"]
            )
            text: str = prompt.format(
                generated_answer=input_state["data"]["generated_answer"],
                question=input_state["data"]["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: BaseChatModel = input_state["data"]["llm"]["model"]
            chain: RunnableSerializable = llm_model.bind_tools(tools=[GradeTool]) | ToolsOutputParser(
                pydantic_schemas=[GradeTool]
            )
            # NOTE(review): `invoke` is synchronous inside an async node — consider
            # `ainvoke` once confirmed to work with the configured model.
            generated_tools: List[GradeTool] = chain.invoke(
                input=messages
            )
            generated_answer_relevancy_grade: str = str(generated_tools[0].binary_score)
            await self.two_datastore.async_client.set(
                name=generated_answer_relevancy_grade_hash,
                value=generated_answer_relevancy_grade.encode()
            )
        else:
            generated_answer_relevancy_grade_byte: bytes = await self.two_datastore.async_client.get(
                generated_answer_relevancy_grade_hash
            )
            generated_answer_relevancy_grade: str = generated_answer_relevancy_grade_byte.decode()

        output_state["data"]["generated_answer_relevancy_grade"] = generated_answer_relevancy_grade
        output_state["data"]["generated_answer_relevancy_grade_hash"] = generated_answer_relevancy_grade_hash

        return output_state

    def _get_generated_answer_relevancy_grade_hash(
            self,
            question: str,
            generated_answer_hash: str,
    ) -> str:
        """Build the cache key under which an answer-relevancy grade is stored.

        The key is ``"generated_answer_relevancy_grade-"`` followed by
        ``cache_tool.hash_by_dict`` of the question and the answer hash.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "question": question,
                "generated_answer_hash": generated_answer_hash,
            }
        )

        return f"generated_answer_relevancy_grade-{digest}"

    def node_decide_transform_question_or_grade_answer_relevancy(self, input_state: GraphState) -> str:
        """Choose the branch to follow after the hallucination grade.

        Returns:
            ``"GRADE_ANSWER_RELEVANCY"`` when the hallucination grade is the string
            ``"False"``; ``"MAX_RETRY"`` when the current retry count has reached
            ``transform_question_max_retry``; otherwise ``"TRANSFORM_QUESTION"``,
            after incrementing ``transform_question_current_retry`` in the state.
        """
        state_data = input_state["data"]

        if state_data["generated_hallucination_grade"] == "False":
            return "GRADE_ANSWER_RELEVANCY"

        retry_limit: int = state_data["transform_question_max_retry"]
        # The retry counter is created lazily on the first failed grade.
        retries_so_far: int = state_data.setdefault("transform_question_current_retry", 0)
        if retries_so_far >= retry_limit:
            return "MAX_RETRY"

        state_data["transform_question_current_retry"] += 1

        return "TRANSFORM_QUESTION"

    def node_decide_transform_question_or_provide_answer(self, input_state: GraphState) -> str:
        """Choose the branch to follow after the answer-relevancy grade.

        Returns:
            ``"PROVIDE_ANSWER"`` when the relevancy grade is the string ``"True"``;
            ``"MAX_RETRY"`` when the current retry count has reached
            ``transform_question_max_retry``; otherwise ``"TRANSFORM_QUESTION"``,
            after incrementing ``transform_question_current_retry`` in the state.
        """
        state_data = input_state["data"]

        if state_data["generated_answer_relevancy_grade"] == "True":
            return "PROVIDE_ANSWER"

        retry_limit: int = state_data["transform_question_max_retry"]
        # The retry counter is created lazily on the first failed grade.
        retries_so_far: int = state_data.setdefault("transform_question_current_retry", 0)
        if retries_so_far >= retry_limit:
            return "MAX_RETRY"

        state_data["transform_question_current_retry"] += 1

        return "TRANSFORM_QUESTION"

    async def node_transform_question(self, input_state: GraphState) -> GraphState:
        """Rewrite the current question into a retrieval-friendlier one.

        Derives a cache key from the current question. When the key is already
        present in ``self.two_datastore`` and ``is_force_refresh_generated_question``
        is not ``True``, the stored rewrite is reused; otherwise the LLM is asked to
        rewrite the question and the result is stored under the key.

        Overwrites ``input_state["data"]["question"]`` with the rewritten question
        and sets ``question_hash`` to the cache key, which was computed from the
        question as it was *before* the rewrite. Returns the same state object.

        Raises:
            use_case_exception.ExistingGeneratedQuestionHashInvalid: when the datastore
                ``exists`` call returns a value other than 0 or 1.
        """
        state_data = input_state["data"]

        question_key: str = self._get_transformed_question_hash(
            question=state_data["question"]
        )

        # `exists` is expected to report 0 or 1 for a single key; anything else is rejected.
        key_count: int = await self.two_datastore.async_client.exists(
            question_key
        )
        if key_count not in (0, 1):
            raise use_case_exception.ExistingGeneratedQuestionHashInvalid()
        is_cached: bool = key_count == 1

        force_refresh: bool = state_data["generator_setting"][
            "is_force_refresh_generated_question"]
        if is_cached and force_refresh is not True:
            # Cache hit and no forced refresh: reuse the stored rewrite.
            cached_bytes: bytes = await self.two_datastore.async_client.get(
                question_key
            )
            rewritten_question: str = cached_bytes.decode()
        else:
            rewrite_prompt: PromptTemplate = PromptTemplate(
                template="""Instruction: Converts the question to a better version that is optimized for vector store retrieval. Observe the question and try to reason about underlying semantics. Ensure the output is only the question without re-explain the instruction.
                Question: {question}""",
                input_variables=["question"]
            )
            rendered_prompt: str = rewrite_prompt.format(
                question=state_data["question"]
            )
            human_message: HumanMessage = HumanMessage(
                content=[{"type": "text", "text": rendered_prompt}]
            )
            llm_model: BaseChatModel = state_data["llm"]["model"]
            rewrite_chain: RunnableSerializable = llm_model | StrOutputParser()
            rewritten_question: str = rewrite_chain.invoke(input=[human_message])
            await self.two_datastore.async_client.set(
                name=question_key,
                value=rewritten_question.encode()
            )

        state_data["question"] = rewritten_question
        state_data["question_hash"] = question_key

        return input_state

    def _get_transformed_question_hash(
            self,
            question: str,
    ) -> str:
        """Build the cache key under which a rewritten question is stored.

        The key is ``"transformed_question-"`` followed by
        ``cache_tool.hash_by_dict`` of the question.
        """
        digest: str = cache_tool.hash_by_dict(
            data={
                "question": question,
            }
        )

        return f"transformed_question-{digest}"

    def compile(self) -> CompiledGraph:
        """Assemble and compile the long-form QA state graph.

        Every node is registered under its method's ``__name__``. The flow is:
        get LLM -> prepare/get categorized documents (looping until the decision
        function returns ``"EMBED"``) -> prepare embed/embed (looping until it
        returns ``"GET_RELEVANT_DOCUMENTS"``) -> get relevant documents -> re-rank
        -> generate answer -> grade hallucination -> grade answer relevancy, with
        both graders able to route to the question transformer (which feeds back
        into retrieval) or to ``END``.

        Returns:
            The compiled graph.
        """
        graph: StateGraph = StateGraph(GraphState)

        # Register the nodes in pipeline order, keyed by method name.
        node_actions = (
            self.node_get_llm_model,
            self.node_prepare_get_categorized_documents,
            self.node_get_categorized_documents,
            self.node_prepare_embed,
            self.node_embed,
            self.node_get_relevant_documents,
            self.node_get_re_ranked_documents,
            self.node_generate_answer,
            self.node_grade_hallucination,
            self.node_grade_answer_relevancy,
            self.node_transform_question,
        )
        for node_action in node_actions:
            graph.add_node(key=node_action.__name__, action=node_action)

        get_llm_model: str = self.node_get_llm_model.__name__
        prepare_categorize: str = self.node_prepare_get_categorized_documents.__name__
        categorize: str = self.node_get_categorized_documents.__name__
        prepare_embed: str = self.node_prepare_embed.__name__
        embed: str = self.node_embed.__name__
        retrieve: str = self.node_get_relevant_documents.__name__
        re_rank: str = self.node_get_re_ranked_documents.__name__
        generate: str = self.node_generate_answer.__name__
        grade_hallucination: str = self.node_grade_hallucination.__name__
        grade_relevancy: str = self.node_grade_answer_relevancy.__name__
        transform_question: str = self.node_transform_question.__name__

        graph.set_entry_point(key=get_llm_model)
        graph.add_edge(start_key=get_llm_model, end_key=prepare_categorize)
        graph.add_edge(start_key=prepare_categorize, end_key=categorize)
        graph.add_conditional_edges(
            start_key=categorize,
            condition=self.node_decide_get_categorized_documents_or_embed,
            conditional_edge_mapping={
                "GET_CATEGORIZED_DOCUMENTS": prepare_categorize,
                "EMBED": prepare_embed
            }
        )
        graph.add_edge(start_key=prepare_embed, end_key=embed)
        graph.add_conditional_edges(
            start_key=embed,
            condition=self.node_decide_embed_or_get_relevant_documents,
            conditional_edge_mapping={
                "EMBED": prepare_embed,
                "GET_RELEVANT_DOCUMENTS": retrieve
            }
        )
        graph.add_edge(start_key=retrieve, end_key=re_rank)
        graph.add_edge(start_key=re_rank, end_key=generate)
        graph.add_edge(start_key=generate, end_key=grade_hallucination)
        graph.add_conditional_edges(
            start_key=grade_hallucination,
            condition=self.node_decide_transform_question_or_grade_answer_relevancy,
            conditional_edge_mapping={
                "MAX_RETRY": END,
                "GRADE_ANSWER_RELEVANCY": grade_relevancy,
                "TRANSFORM_QUESTION": transform_question
            }
        )
        graph.add_conditional_edges(
            start_key=grade_relevancy,
            condition=self.node_decide_transform_question_or_provide_answer,
            conditional_edge_mapping={
                "MAX_RETRY": END,
                "PROVIDE_ANSWER": END,
                "TRANSFORM_QUESTION": transform_question
            }
        )
        # A rewritten question re-enters the pipeline at retrieval.
        graph.add_edge(start_key=transform_question, end_key=retrieve)

        return graph.compile()


# Bare annotation only: declares the module-level name without binding a value.
# `handler` below assigns it through `global output_state`.
output_state: GraphState


async def handler(session: AsyncSession):
    """Run the full long-form QA graph once over three seeded documents.

    Builds a request ``State`` around ``session`` and the first seeded session,
    compiles ``GraphLongFormQa``, prints the compiled graph as a Mermaid diagram,
    invokes it with the configuration below and stores the final graph state in
    the module-level ``output_state``.
    """
    global output_state

    request_state: State = State()
    request_state.authorized_session = all_seeder.session_seeder.session_fake.data[0]
    request_state.session = session

    lfqa: GraphLongFormQa = GraphLongFormQa(
        one_embedding_setting=one_embedding_setting,
        one_llm_setting=one_llm_setting,
        two_datastore=two_datastore,
        four_datastore=four_datastore,
        category_document_processor=category_document_processor
    )
    lfqa_app: CompiledGraph = lfqa.compile()

    seeded_documents = all_seeder.document_seeder.document_fake.data
    graph_input: Dict[str, Any] = {
        "state": request_state,
        # The first three seeded documents form the corpus for this run.
        "document_ids": [seeded_documents[index].id for index in range(3)],
        "llm": {
            "model_name": "claude-3-haiku-20240307",
            "max_token": 500,
        },
        "preprocessor_setting": {
            "is_force_refresh_categorized_element": False,
            "is_force_refresh_categorized_document": False,
            "chunk_size": 500,
            "is_include_tables": False,
            "is_include_images": False,
        },
        "embedder_setting": {
            "is_force_refresh_embedding": False,
            "is_force_refresh_document": False,
            # "model_name": "intfloat/multilingual-e5-large-instruct",
            "model_name": "BAAI/bge-m3",
            "query_instruction": "Given the question, retrieve passage that answer the question.",
        },
        "retriever_setting": {
            "is_force_refresh_relevant_document": False,
            "top_k": 50,
        },
        "reranker_setting": {
            "model_name": "BAAI/bge-reranker-v2-m3",
            "is_force_refresh_re_ranked_document": False,
            "top_k": 5,
        },
        "question": "what is political science?",
        "generator_setting": {
            "is_force_refresh_generated_answer": False,
            "is_force_refresh_generated_question": False,
            "is_force_refresh_generated_hallucination_grade_hash": False,
            "is_force_refresh_generated_answer_relevancy_grade_hash": False,
            "prompt_text": """Instruction: Create a concise and informative answer for a given question based solely on the given passages. You must only use information from the given passages. Use an unbiased and journalistic tone. Do not repeat text. Cite at least one passage in each sentence. Cite the passages using passage number notation like "[number]". If multiple passages contain the answer, cite those passages like "[number, number, etc.]". If the passages do not contain the answer to the question, then say that answering is not possible given the available information with the explanation. Ensure the output is only the answer without re-explain the instruction.
            Passages:
            {% for passage in passages %}
            [{{ loop.index }}]={{ passage.page_content }}
            {% endfor %}
            Question: {{ question }}
            Answer:"""
        },
        # Zero retries: a failed grade ends the run instead of rewriting the question.
        "transform_question_max_retry": 0
    }

    print(lfqa_app.get_graph().draw_mermaid())

    output_state = await lfqa_app.ainvoke(
        GraphState(
            data=graph_input
        )
    )

In [15]:
# Uncomment to start from an empty cache before the run.
# cache_tool.clear_cache()
# Execute `handler` through the datastore wrapper.
# NOTE(review): presumably `retryable` supplies the AsyncSession and retries on failure — confirm.
await one_datastore.retryable(handler)
# Free GPU memory cached by PyTorch and collect unreachable Python objects after the run.
torch.cuda.empty_cache()
gc.collect()
# Last expression of the cell: display the current cache contents.
cache_tool.get_cache()

%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
	__start__[__start__]:::startclass;
	__end__[__end__]:::endclass;
	node_get_llm_model([node_get_llm_model]):::otherclass;
	node_prepare_get_categorized_documents([node_prepare_get_categorized_documents]):::otherclass;
	node_get_categorized_documents([node_get_categorized_documents]):::otherclass;
	node_prepare_embed([node_prepare_embed]):::otherclass;
	node_embed([node_embed]):::otherclass;
	node_get_relevant_documents([node_get_relevant_documents]):::otherclass;
	node_get_re_ranked_documents([node_get_re_ranked_documents]):::otherclass;
	node_generate_answer([node_generate_answer]):::otherclass;
	node_grade_hallucination([node_grade_hallucination]):::otherclass;
	node_grade_answer_relevancy([node_grade_answer_relevancy]):::otherclass;
	node_transform_question([node_transform_question]):::otherclass;
	node_get_categorized_documents_node_decide_get_categorized_documents_or_embed([node_get_categorized_documents_node_decide_get_cate



BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocumentProcessor.categorize_elements: Ignoring element type HTMLTitle.
BaseDocument

Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BaseDocumentProcessor.categorize_elements: Ignoring element type Title.
BaseDocumentProcessor.categorize_elements: Ignoring element type ListItem.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type FigureCaption.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Title.
BaseDocumentProcessor.categorize_elements: Ignoring element type FigureCaption.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elements: Ignoring element type Header.
BaseDocumentProcessor.categorize_elem

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  warn_beta(


{'categorized_element-5a3a764cdb39b8347c290d93afe667eb010da0f2fb2b578620f40170ed4d540b': ElementCategory(texts=[], tables=[<unstructured.documents.html.HTMLTable object at 0x7ff6f05fbfd0>], images=[]),
 'categorized_element-b2899a0c3bbe2f36b2ed5fbd3b6858dcf9fe3787f034075040bf5249e483efce': ElementCategory(texts=[<unstructured.documents.elements.NarrativeText object at 0x7ff6f0519f00>], tables=[], images=[]),
 'categorized_element-3a3efcb381b2a4c32f4ee1fcbe00bddb7fb422bf4bc67b4ce2ec2bcc0c5c6646': ElementCategory(texts=[<unstructured.documents.elements.NarrativeText object at 0x7ff67474bd90>, <unstructured.documents.elements.NarrativeText object at 0x7ff67474bd30>, <unstructured.documents.elements.Text object at 0x7ff677ac1c60>, <unstructured.documents.elements.NarrativeText object at 0x7ff67474b220>, <unstructured.documents.elements.NarrativeText object at 0x7ff67474b490>, <unstructured.documents.elements.NarrativeText object at 0x7ff67474afe0>, <unstructured.documents.elements.Narrativ

In [16]:
# Display the full final graph state produced by the run.
output_state

{'data': {'state': <starlette.datastructures.State at 0x7ff6f058fca0>,
  'document_ids': [UUID('4b04540d-4706-4870-a806-9e7399510650'),
   UUID('60d29c9a-5178-4630-b2c5-409d683f1a2b'),
   UUID('d27ee1dc-abd2-4f4e-b94b-61b0526250e5')],
  'llm': {'model_name': 'claude-3-haiku-20240307',
   'max_token': 500,
   'model': ChatAnthropic(model='claude-3-haiku-20240307', max_tokens=500, temperature=0.0, anthropic_api_key=SecretStr('**********'), streaming=True, _client=<anthropic.Anthropic object at 0x7ff6f05b8dc0>, _async_client=<anthropic.AsyncAnthropic object at 0x7ff6f05bb250>)},
  'preprocessor_setting': {'is_force_refresh_categorized_element': False,
   'is_force_refresh_categorized_document': False,
   'chunk_size': 500,
   'is_include_tables': False,
   'is_include_images': False},
  'embedder_setting': {'is_force_refresh_embedding': False,
   'is_force_refresh_document': False,
   'model_name': 'BAAI/bge-m3',
   'query_instruction': 'Given the question, retrieve passage that answer th

In [17]:
# Display the question held in the final state.
output_state["data"]["question"]

'what is political science?'

In [18]:
# Display the documents stored under "relevant_documents" in the final state.
output_state["data"]["relevant_documents"]

[Document(page_content='Political science is the scientific study of politics. It is a social science dealing with systems of governance and power, and the analysis of political activities, political thoughts, political behavior, and political structures.', metadata={'id': '21a82e79-b284-451d-81ba-7f93a1c9b142', 'orig_metadata': [{'languages': ['eng'], 'filetype': 'text/plain'}], 'category': 'text', 'document_id': '60d29c9a-5178-4630-b2c5-409d683f1a2b', 'relevancy_score': 0.032786883413791656}),
 Document(page_content='other emerging technologies have facilitated the transfer of artiﬁcial intel- ligence to machines and other items, such as buildings and robots [11]. Indeed, Chassignol et al. provides a two-faceted deﬁnition and description of AI. They deﬁne AI as a ﬁeld and a theory. As a ﬁeld of study, they deﬁne AI as a study area in computer science whose pursuits are aimed at solving different cognitive problems commonly associated with the human intelligence, such as learning, pro

In [19]:
# Display the documents stored under "re_ranked_documents" in the final state.
output_state["data"]["re_ranked_documents"]

[Document(page_content='Political science is the scientific study of politics. It is a social science dealing with systems of governance and power, and the analysis of political activities, political thoughts, political behavior, and political structures.', metadata={'re_ranked_score': 0.9999151889582765, 'id': '21a82e79-b284-451d-81ba-7f93a1c9b142', 'orig_metadata': [{'languages': ['eng'], 'filetype': 'text/plain'}], 'category': 'text', 'document_id': '60d29c9a-5178-4630-b2c5-409d683f1a2b', 'relevancy_score': 0.032786883413791656}),
 Document(page_content='intelligence as the study of intelligence behavior in human beings, animals, and machines and endeavoring to engineer such behavior into an artifact, such as computers and computer-related technolo- gies [5] (p.1). Drawing from these deﬁnitions, it is evident that artiﬁcial intelligence is the culmination of computers, computer-related technologies, machines, and information communication technology innovations and developments, giv

In [20]:
# Display the final generated answer.
output_state["data"]["generated_answer"]

'Political science is the scientific study of politics. It is a social science dealing with systems of governance and power, and the analysis of political activities, political thoughts, political behavior, and political structures [1].'

In [68]:
# Id of the first seeded file document; the handler defined below passes it as the only entry of "document_ids".
document_id = all_seeder.file_document_seeder.file_document_fake.data[0].id
# Start from an empty graph state; the handler below declares this name `global`.
output_state: GraphState = GraphState(
    data={}
)


async def handler(session: AsyncSession):
    """Build and run the document-categorization sub-graph for ``document_id``.

    The graph is wired as
    ``node_get_llm_model -> node_prepare_get_categorized_documents ->
    node_get_categorized_documents``. After the last node a conditional edge
    either loops back to the prepare node ("GET_CATEGORIZED_DOCUMENTS") or
    stops the graph at ``END`` ("EMBED").

    Args:
        session: Database session attached to the request-like ``State``
            object handed to the graph.

    Side effects:
        Rebinds the notebook-global ``output_state`` to the graph's result.
    """
    global output_state

    # Request-like state carrying the seeded authorized session and the DB session.
    request_state: State = State()
    request_state.authorized_session = all_seeder.session_seeder.session_fake.data[0]
    request_state.session = session

    lfqa: GraphLongFormQa = GraphLongFormQa(
        one_embedding_setting=one_embedding_setting,
        one_llm_setting=one_llm_setting,
        two_datastore=two_datastore,
        four_datastore=four_datastore,
        category_document_processor=category_document_processor
    )

    llm_node = lfqa.node_get_llm_model
    prepare_node = lfqa.node_prepare_get_categorized_documents
    fetch_node = lfqa.node_get_categorized_documents

    builder: StateGraph = StateGraph(GraphState)

    # Register every node under its own method name.
    for node in (llm_node, prepare_node, fetch_node):
        builder.add_node(
            key=node.__name__,
            action=node
        )

    builder.set_entry_point(
        key=llm_node.__name__
    )

    # Linear part of the pipeline.
    for upstream, downstream in ((llm_node, prepare_node), (prepare_node, fetch_node)):
        builder.add_edge(
            start_key=upstream.__name__,
            end_key=downstream.__name__
        )

    # Either fetch the next document or stop where the full graph would embed.
    builder.add_conditional_edges(
        start_key=fetch_node.__name__,
        condition=lfqa.node_decide_get_categorized_documents_or_embed,
        conditional_edge_mapping={
            "GET_CATEGORIZED_DOCUMENTS": prepare_node.__name__,
            "EMBED": END
        }
    )
    runnable_graph = builder.compile()

    payload: Dict[str, Any] = {
        "state": request_state,
        "document_ids": [document_id],
        "llm": {
            "model_name": "claude-3-haiku-20240307",
            "max_token": 500,
        },
        "preprocessor_setting": {
            "is_force_refresh_categorized_element": False,
            "is_force_refresh_categorized_document": False,
            "chunk_size": 50,
            "is_include_tables": False,
            "is_include_images": False,
        },
    }

    output_state = await runnable_graph.ainvoke(
        input=GraphState(
            data=payload
        )
    )


# Execute `handler` through the datastore's retry wrapper
# (presumably this is what supplies the AsyncSession argument — confirm against `retryable`).
await one_datastore.retryable(handler)
# Display the state that `handler` published through the `output_state` global.
output_state

{'data': {'state': <starlette.datastructures.State at 0x7ff5d6a7ca00>,
  'document_ids': [UUID('94b64c60-2c9d-4bb1-91e4-5d4f0ed7300f')],
  'llm': {'model_name': 'claude-3-haiku-20240307',
   'max_token': 500,
   'model': ChatAnthropic(model='claude-3-haiku-20240307', max_tokens=500, temperature=0.0, anthropic_api_key=SecretStr('**********'), streaming=True, _client=<anthropic.Anthropic object at 0x7ff5d6a7f6d0>, _async_client=<anthropic.AsyncAnthropic object at 0x7ff5d6a7c640>)},
  'preprocessor_setting': {'is_force_refresh_categorized_element': False,
   'is_force_refresh_categorized_document': False,
   'chunk_size': 50,
   'is_include_tables': False,
   'is_include_images': False},
  'next_document_id': None,
  'categorized_element_hashes': {UUID('94b64c60-2c9d-4bb1-91e4-5d4f0ed7300f'): '5abdf19ad8bbcff78dd01362ad21ddaec65e7193651087e3aefa5eff4acb74b9'},
  'categorized_document_hashes': {UUID('94b64c60-2c9d-4bb1-91e4-5d4f0ed7300f'): '999557ea1a0af986a8f8241197134ceaddbaa53347f18b1fe

In [19]:
# Build a synthetic evaluation test set with ragas from the categorized documents
# produced by the previous graph run.
# NOTE(review): the recorded run of this cell ended in ragas' ExceptionInRunner, so
# `test_set` was never bound and the following cells failed with NameError.
# The root cause is in the (truncated) runner traceback — investigate before re-running.
documents = output_state["data"]["categorized_documents"][document_id].get_all()

# LLM passed to ragas as the generator model.
generator_llm = ChatAnthropic(
    model="claude-3-haiku-20240307",
    anthropic_api_key=one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE
)
# LLM passed to ragas as the critic; also reused as the judge LLM in the `evaluate` cell.
critic_llm = ChatAnthropic(
    model="claude-3-opus-20240229",
    anthropic_api_key=one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE
)
# Embedding client pointed at the configured Infinity endpoint; also reused by `evaluate`.
embeddings = InfinityEmbeddings(
    model="intfloat/multilingual-e5-large-instruct",
    infinity_api_url=one_embedding_setting.URL
)

generator = TestsetGenerator.from_langchain(
    generator_llm=generator_llm,
    critic_llm=critic_llm,
    embeddings=embeddings
)

# Request a single test sample; the distribution weights sum to 1.0
# (50% simple, 25% reasoning, 25% multi-context evolutions).
test_set = generator.generate_with_langchain_docs(
    documents=documents,
    test_size=1,
    distributions={
        evolutions.simple: 0.5,
        evolutions.reasoning: 0.25,
        evolutions.multi_context: 0.25
    }
)

embedding nodes:   0%|          | 0/708 [00:00<?, ?it/s]

Exception in thread Thread-11:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/ragas/executor.py", line 96, in run
    results = self.loop.run_until_complete(self._aresults())
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/dist-packages/ragas/executor.py", line 84, in _aresults
    raise e
  File "/usr/local/lib/python3.10/dist-packages/ragas/executor.py", line 79, in _aresults
    r = await future
  File "/usr/lib/python3.10/asyncio/tasks.py", line 571, in _wait_for_one
    return f.result()  # May raise f.exception().
  File "/usr/local/lib/python3.10/dist-packages/ragas/executor.py", line 38, in sema_coro
    return await coro
  File "/usr/local/lib/python3.10/dist-packages/ragas/executor.py", line 112, in wrapped_callable_async
    return counter, await callable(

ExceptionInRunner: The runner thread which was running the jobs raised an exeception. Read the traceback above to debug it. You can also pass `raise_exceptions=False` incase you want to show only a warning message instead.

In [27]:
# Convert the generated ragas test set into a Hugging Face `Dataset` and expose the
# reference answer under the "ground_truth" column name.
eval_set = test_set.to_dataset()
# `Dataset.rename_column` does not modify the dataset in place: it returns a new
# dataset, so the result must be bound back to `eval_set` or the rename is lost.
# NOTE(review): this assumes the generated test set has an "answer" column; some ragas
# versions may already name it "ground_truth" — confirm against the installed version.
eval_set = eval_set.rename_column(
    original_column_name="answer",
    new_column_name="ground_truth"
)
eval_set

NameError: name 'test_set' is not defined

In [166]:
for index, eval in enumerate(eval_set):
    async def handler(session: AsyncSession):
        global output_state

        state: State = State()
        state.authorized_session = all_seeder.session_seeder.session_fake.data[0]
        state.session = session

        graph_lfqa: GraphLongFormQa = GraphLongFormQa(
            one_embedding_setting=one_embedding_setting,
            one_llm_setting=one_llm_setting,
            two_datastore=two_datastore,
            four_datastore=four_datastore,
            category_document_processor=category_document_processor
        )
        compiled_graph_lfqa: CompiledGraph = graph_lfqa.compile()

        data: Dict[str, Any] = {
            "state": state,
            "document_ids": [document_id],
            "llm": {
                "model_name": "claude-3-haiku-20240307",
                "max_token": 500,
            },
            "preprocessor_setting": {
                "is_force_refresh_categorized_document": False,
                "chunk_size": 50,
                "is_include_tables": False,
                "is_include_images": False,
            },
            "embedder_setting": {
                "is_force_refresh_embedding": False,
                "is_force_refresh_document": False,
                "model_name": "BAAI/bge-m3",
                "query_instruction": "Given the question, retrieve passage that answer the question.",
            },
            "retriever_setting": {
                "is_force_refresh_relevant_document": False,
                "top_k": 50,
            },
            "reranker_setting": {
                "model_name": "BAAI/bge-reranker-v2-m3",
                "is_force_refresh_re_ranked_document": False,
                "top_k": 5,
            },
            "question": "what is political science?",
            "generator_setting": {
                "is_force_refresh_generated_answer": False,
                "is_force_refresh_generated_question": False,
                "is_force_refresh_generated_hallucination_grade_hash": False,
                "is_force_refresh_generated_answer_relevancy_grade_hash": False,
                "prompt_text": """Instruction: Create a concise and informative answer for a given question based solely on the given passages. You must only use information from the given passages. Use an unbiased and journalistic tone. Do not repeat text. Cite at least one passage in each sentence. Cite the passages using passage number notation like "[number]". If multiple passages contain the answer, cite those passages like "[number, number, etc.]". If the passages do not contain the answer to the question, then say that answering is not possible given the available information with the explanation. Ensure the output is only the answer without re-explain the instruction.
                Passages:
                {% for passage in passages %}
                [{{ loop.index }}]={{ passage.page_content }}
                {% endfor %}
                Question: {{ question }}
                Answer:"""
            },
            "transform_question_max_retry": 3
        }

        input_state: GraphState = GraphState(
            data=data
        )
        output_state = await compiled_graph_lfqa.ainvoke(input_state)

        eval_set[index]["contexts"] = [document.page_content for document in
                                       output_state["data"]["categorized_documents"][document_id].get_all()]
        eval_set[index]["answer"] = output_state["data"]["generated_answer"]

NameError: name 'eval_set' is not defined

In [23]:
# Load the "english_v2" configuration of the Amnesty QA dataset.
# NOTE: trust_remote_code=True allows the dataset repository's loading script to run locally.
amnesty_qa = load_dataset(
    path="explodinggradients/amnesty_qa",
    name="english_v2",
    trust_remote_code=True,
)

Repo card metadata block was not found. Setting CardData to empty.


In [24]:
# Keep only the first row of the "eval" split so the evaluation below stays small.
eval_set_2 = amnesty_qa["eval"].select(range(1))
eval_set_2

Dataset({
    features: ['question', 'ground_truth', 'answer', 'contexts'],
    num_rows: 1
})

In [25]:
# Score the one-row Amnesty QA sample with four metrics, using `critic_llm` as the
# judge LLM and `embeddings` as the embedding model (both defined in the test-set
# generation cell above).
result = evaluate(
    dataset=eval_set_2,
    llm=critic_llm,
    embeddings=embeddings,
    metrics=[
        metrics.faithfulness,
        metrics.answer_relevancy,
        metrics.context_recall,
        metrics.context_precision,
        # Currently disabled metrics:
        #     metrics.answer_correctness,
        #     metrics.context_relevancy,
        #     metrics.context_entity_recall,
    ],
)

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]

  for attr in assigned:
Task was destroyed but it is pending!
task: <Task pending name='Task-369' coro=<as_completed.<locals>.sema_coro() running at /usr/local/lib/python3.10/dist-packages/ragas/executor.py:37> wait_for=<Future pending cb=[Task.task_wakeup()]> cb=[as_completed.<locals>._on_completion() at /usr/lib/python3.10/asyncio/tasks.py:558]>
Task was destroyed but it is pending!
task: <Task pending name='Task-52' coro=<as_completed.<locals>.sema_coro() running at /usr/local/lib/python3.10/dist-packages/ragas/executor.py:38> wait_for=<Future pending cb=[Task.task_wakeup()]> cb=[as_completed.<locals>._on_completion() at /usr/lib/python3.10/asyncio/tasks.py:558]>
Task was destroyed but it is pending!
task: <Task pending name='Task-55' coro=<as_completed.<locals>.sema_coro() running at /usr/local/lib/python3.10/dist-packages/ragas/executor.py:38> wait_for=<Future pending cb=[Task.task_wakeup()]> cb=[as_completed.<locals>._on_completion() at /usr/lib/python3.10/asyncio/tasks.py:558]>


In [26]:
# Display the metric scores returned by `evaluate`.
result

{'faithfulness': 0.5714, 'answer_relevancy': 1.0000, 'context_recall': 1.0000, 'context_precision': 1.0000}