In [13]:
import copy
import io
import os
import shutil
from pathlib import Path

# from langchain import retrievers
from langchain_community.storage.redis import RedisStore
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from litellm import Router

from apps.inners.models.dtos.element_category import ElementCategory
from apps.outers.exceptions import use_case_exception

# Switch the working directory to /app before the remaining imports run.
# NOTE(review): presumably this is what makes the project packages importable; if so, the
# `apps.*` imports earlier in this cell execute before the directory change — confirm they
# resolve on a freshly started kernel.
os.chdir("/app")
from uuid import UUID

from IPython.lib.display import IFrame
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.datastructures import State

from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.use_cases.managements.document_management import DocumentManagement
from apps.inners.use_cases.managements.file_document_management import FileDocumentManagement
from apps.inners.use_cases.managements.text_document_management import TextDocumentManagement
from apps.inners.use_cases.managements.web_document_management import WebDocumentManagement
from typing import List
from typing import TypedDict, Tuple, Dict, Any

import dotenv
import litellm
from datasets import load_dataset
from dotenv import find_dotenv
from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_community.vectorstores.milvus import Milvus
from langchain_core.runnables.base import RunnableLike, RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph
from ragas import evaluate
from unstructured.documents.elements import Element, Table, Image, Text
from unstructured.partition.auto import partition
from unstructured.partition.utils.constants import PartitionStrategy

from apps.inners.use_cases.embeddings.hugging_face_e5_instruct_embedding import HuggingFaceE5InstructEmbeddings
from apps.outers.datastores.four_datastore import FourDatastore
from apps.outers.datastores.one_datastore import OneDatastore
from apps.outers.datastores.three_datastore import ThreeDatastore
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.repositories.file_document_repository import FileDocumentRepository
from apps.outers.repositories.text_document_repository import TextDocumentRepository
from apps.outers.repositories.web_document_repository import WebDocumentRepository
from tests.containers.test_container import TestContainer
from tests.seeders.all_seeder import AllSeeder

!pip show sqlalchemy


Name: SQLAlchemy
Version: 1.4.52
Summary: Database Abstraction Library
Home-page: https://www.sqlalchemy.org
Author: Mike Bayer
Author-email: mike_mp@zzzcomputing.com
License: MIT
Location: /usr/local/lib/python3.10/dist-packages
Requires: greenlet
Required-by: fastapi-utils, langchain, langchain-community, sqlalchemy-cockroachdb, sqlmodel


In [None]:
# import tensorflow
# 
# tensorflow.config.list_physical_devices('GPU')

In [2]:
# import torch
# 
# torch.cuda.is_available()

In [3]:
# Locate a .env file with `find_dotenv` and load the variables it declares.
dotenv.load_dotenv(dotenv_path=find_dotenv())


def Frame(src, width=700, height=500):
    """Return an ``IFrame`` that embeds ``src`` in the notebook output.

    Args:
        src: Source passed straight through to ``IFrame``.
        width: Frame width; defaults to the previously hard-coded 700.
        height: Frame height; defaults to the previously hard-coded 500.

    Returns:
        The ``IFrame`` display object.
    """
    return IFrame(src, width=width, height=height)


NameError: name 'dotenv' is not defined

In [64]:
# Build the test container once, then resolve each collaborator from its provider group.
test_container = TestContainer()

# Datastores.
datastore_providers = test_container.applications.datastores
one_datastore: OneDatastore = datastore_providers.one()
two_datastore: TwoDatastore = datastore_providers.two()
three_datastore: ThreeDatastore = datastore_providers.three()
four_datastore: FourDatastore = datastore_providers.four()
temp_datastore: ThreeDatastore = datastore_providers.temp()

# Repositories, one per document kind.
repository_providers = test_container.applications.repositories
file_document_repository: FileDocumentRepository = repository_providers.file_document()
text_document_repository: TextDocumentRepository = repository_providers.text_document()
web_document_repository: WebDocumentRepository = repository_providers.web_document()

# Management use cases.
management_providers = test_container.applications.use_cases.managements
document_management: DocumentManagement = management_providers.document()
file_document_management: FileDocumentManagement = management_providers.file_document()
text_document_management: TextDocumentManagement = management_providers.text_document()
web_document_management: WebDocumentManagement = management_providers.web_document()

# Seeder that exposes the mock fixtures referenced by later cells.
all_seeder: AllSeeder = test_container.seeders.all()

In [65]:
# Run the seeders' `up` step (presumably inserts the mock fixtures that later cells read — confirm).
await all_seeder.up()

In [8]:
# Run the seeders' `down` step.
# NOTE(review): in a top-to-bottom run this executes immediately after `up()`; if `down()`
# removes the seeded fixtures (confirm), the later cells that read `all_seeder` mock data
# will no longer find them in the datastores. Consider running this only as the final step.
await all_seeder.down()

In [5]:
# Connectivity check: write a throwaway "test" key through the `two_datastore` client with `ex=10`.
await two_datastore.client.set("test", "test", ex=10)

True

In [118]:
# Load the "english_v2" configuration of the explodinggradients/amnesty_qa dataset.
# `trust_remote_code=True` allows the dataset repository's own loading script to run locally,
# so only keep it for sources you trust.
amnesty_qa = load_dataset("explodinggradients/amnesty_qa", "english_v2", trust_remote_code=True)

In [119]:
amnesty_qa

DatasetDict({
    eval: Dataset({
        features: ['question', 'ground_truth', 'answer', 'contexts'],
        num_rows: 20
    })
})

In [71]:
class PartitionDocumentProcessor:
    """Partition a stored document into ``unstructured`` elements.

    The document is looked up by id and then handed to the partitioner that
    matches its ``document_type_id`` ("file", "text" or "web").
    """

    def __init__(
            self,
            document_management: DocumentManagement,
            file_document_management: FileDocumentManagement,
            text_document_management: TextDocumentManagement,
            web_document_management: WebDocumentManagement,
    ):
        self.document_management = document_management
        self.file_document_management = file_document_management
        self.text_document_management = text_document_management
        self.web_document_management = web_document_management

    async def _partition_file(self, state: State, found_document: Document) -> List[Element]:
        """Partition a file document with the hi-res strategy, extracting its images.

        Images are written to ``<repository file_path>/<file_data_hash>``; any
        output left there by a previous run is removed first.
        """
        found_file_document: FileDocumentResponse = await self.file_document_management.find_one_by_id(
            state=state,
            id=found_document.id
        )
        repository = self.file_document_management.file_document_repository
        file_data: bytes = repository.get_object_data(
            object_name=found_file_document.file_name
        )
        extract_image_path: Path = repository.file_path / found_file_document.file_data_hash
        # Drop stale output only when it exists. The previous create-then-delete
        # sequence raised FileNotFoundError whenever the parent directory was missing.
        if extract_image_path.exists():
            shutil.rmtree(extract_image_path)
        elements = partition(
            metadata_filename=found_file_document.file_name,
            file=io.BytesIO(file_data),
            extract_images_in_pdf=True,
            extract_image_block_output_dir=str(extract_image_path),
            strategy=PartitionStrategy.HI_RES,
            hi_res_model_name="yolox"
        )

        return elements

    async def _partition_text(self, state: State, found_document: Document) -> List[Element]:
        """Partition the ``text_content`` of a text document."""
        found_text_document: TextDocumentResponse = await self.text_document_management.find_one_by_id(
            state=state,
            id=found_document.id
        )
        elements = partition(
            text=found_text_document.text_content
        )

        return elements

    async def _partition_web(self, state: State, found_document: Document) -> List[Element]:
        """Partition the page behind the ``web_url`` of a web document."""
        found_web_document: WebDocumentResponse = await self.web_document_management.find_one_by_id(
            state=state,
            id=found_document.id
        )
        elements = partition(
            url=found_web_document.web_url
        )

        return elements

    async def partition(self, state: State, document_id: UUID) -> List[Element]:
        """Partition the document identified by ``document_id``.

        Args:
            state: Request state forwarded to the management use cases.
            document_id: Id of the document to partition.

        Returns:
            The elements produced by ``unstructured`` for that document.

        Raises:
            use_case_exception.DocumentTypeNotSupported: If the document's
                ``document_type_id`` is not "file", "text" or "web".
        """
        found_document: Document = await self.document_management.find_one_by_id(
            state=state,
            id=document_id
        )
        # One partitioner per supported document type.
        partitioners = {
            "file": self._partition_file,
            "text": self._partition_text,
            "web": self._partition_web,
        }
        partitioner = partitioners.get(found_document.document_type_id)
        if partitioner is None:
            raise use_case_exception.DocumentTypeNotSupported()

        elements: List[Element] = await partitioner(
            state=state,
            found_document=found_document
        )

        return elements

class CategoryDocumentProcessor:
    """Sort partitioned elements into text, table and image buckets."""

    def __init__(self):
        pass

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        """Group ``elements`` by kind.

        Plain ``Text`` and ``NarrativeText`` elements go to ``texts``, tables to
        ``tables`` and images to ``images``. Every other element type is
        reported with ``print`` and left out.

        Args:
            elements: Elements produced by partitioning a document.

        Returns:
            An ``ElementCategory`` holding the three buckets.
        """
        # Local import: the notebook's import cell does not bring NarrativeText into scope.
        from unstructured.documents.elements import NarrativeText

        element_category: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )

        for element in elements:
            # Exact type match on purpose: other element classes (Title, ListItem,
            # Header, ...) derive from Text and must keep being ignored, as before.
            if type(element) in (Text, NarrativeText):
                element_category.texts.append(element)
            elif isinstance(element, Table):
                element_category.tables.append(element)
            elif isinstance(element, Image):
                element_category.images.append(element)
            else:
                # The prefix now names this class instead of a non-existent BaseDocumentProcessor.
                print(f"CategoryDocumentProcessor.categorize_elements: Ignoring element type {type(element)}.")

        return element_category


class SplitDocumentProcessor:
    """Chunk the text carried by ``unstructured`` text elements."""

    def __init__(self):
        pass

    def split_texts(self, texts: List[Text], chunk_size: int = 4000, chunk_overlap: int = 0) -> List[str]:
        """Join the elements' text with single spaces and split it into chunks.

        Args:
            texts: Elements whose ``text`` attribute is concatenated.
            chunk_size: Chunk size handed to the tiktoken-based splitter.
            chunk_overlap: Overlap between consecutive chunks.

        Returns:
            The chunks produced by ``RecursiveCharacterTextSplitter``.
        """
        splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        joined_text: str = " ".join(element.text for element in texts)

        return splitter.split_text(text=joined_text)


class SummaryDocumentProcessor:
    """Produce retrieval-oriented summaries of tables and images with a chat model."""

    def __init__(self):
        pass

    def summarize_tables(self, tables: List[Table], model: BaseChatModel) -> List[str]:
        """Summarize each table's text with ``model``.

        Args:
            tables: Table elements; their ``text`` attribute is what gets summarized.
            model: Chat model that writes the summaries.

        Returns:
            One summary string per table, in input order.
        """
        prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
        These summaries will be embedded and used to retrieve the tables. \
        Give a concise summary of the tables that is well optimized for retrieval. Table : {table} """
        prompt: ChatPromptTemplate = ChatPromptTemplate.from_template(prompt_text)
        chain: RunnableSerializable = {"table": lambda table: table.text} | prompt | model | StrOutputParser()
        summaries: List[str] = chain.batch(tables)

        return summaries

    def _get_message_from_image(self, model: BaseChatModel, prompt_text: str, image: Image) -> BaseMessage:
        """Send ``prompt_text`` together with the image to ``model`` and return its reply.

        The base64 payload is taken from ``image.metadata.image_base64``. When
        that field is empty but the image was written to disk, the file at
        ``image.metadata.image_path`` is read and encoded instead.

        Raises:
            ValueError: If the element carries neither a base64 payload nor an image path.
        """
        # Function-scope imports keep this block self-contained.
        import base64
        import mimetypes

        metadata = image.metadata
        image_data = metadata.image_base64
        media_type = metadata.image_mime_type
        if image_data is None and metadata.image_path is not None:
            # Images extracted to a directory have a path but no inline payload.
            image_data = base64.b64encode(Path(metadata.image_path).read_bytes()).decode("utf-8")
            if media_type is None:
                media_type = mimetypes.guess_type(metadata.image_path)[0]
        if image_data is None:
            raise ValueError("Image element has neither image_base64 nor image_path metadata.")

        message = model.invoke([
            HumanMessage(
                content=[
                    {
                        "type": "text",
                        "text": prompt_text
                    },
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": image_data
                        }
                    }
                ]
            )
        ])

        return message

    def summarize_images(self, images: List[Image], model: BaseChatModel) -> List[str]:
        """Summarize each image with ``model``, one request per image.

        Args:
            images: Image elements to describe.
            model: Chat model that accepts image content.

        Returns:
            The ``content`` of the model's reply for each image, in input order.
        """
        prompt_text = """You are an assistant tasked with summarizing images for retrieval. \
        These summaries will be embedded and used to retrieve the images. \
        Give a concise summary of the images that is well optimized for retrieval."""
        summaries: List[str] = []
        for image in images:
            message: BaseMessage = self._get_message_from_image(
                model=model,
                prompt_text=prompt_text,
                image=image
            )
            summaries.append(message.content)

        return summaries


# Instantiate one processor per pipeline stage: categorize, split, summarize, partition.
category_document_processor: CategoryDocumentProcessor = CategoryDocumentProcessor()
split_document_processor: SplitDocumentProcessor = SplitDocumentProcessor()
summary_document_processor: SummaryDocumentProcessor = SummaryDocumentProcessor()

# Only the partitioner needs collaborators: the management use cases resolved earlier.
partition_document_processor: PartitionDocumentProcessor = PartitionDocumentProcessor(
    web_document_management=web_document_management,
    text_document_management=text_document_management,
    file_document_management=file_document_management,
    document_management=document_management,
)


In [None]:
elements: List[Element] = None


async def handler(session: AsyncSession):
    global elements
    state: State = State()
    state.session = session
    state.authorized_session = all_seeder.session_seeder.session_mock.data[0]
    elements = await partition_document_processor.partition(
        state=state,
        document_id=all_seeder.file_document_seeder.file_document_mock.data[0].id
    )


await one_datastore.retryable(handler)

elements

In [73]:
element_category: ElementCategory = await category_document_processor.categorize_elements(
    elements=elements
)

BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Title'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.ListItem'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.FigureCaption'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Header'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class 'unstructured.documents.elements.Title'>.
BaseDocumentProcessor.categorize_elements: Ignoring element type <class '

In [76]:
litellm.set_verbose = False

# Single-deployment router; the API key is read from the environment.
model_list: List[Dict] = [
    {
        "model_name": "claude-3-haiku",
        "litellm_params": {
            "model": "claude-3-haiku-20240307",
            "api_key": os.getenv("ANTHROPIC_API_KEY"),
        }
    }
]
router: Router = Router(model_list=model_list)
model: ChatLiteLLMRouter = ChatLiteLLMRouter(
    router=router,
    model_name="claude-3-haiku",
    streaming=True,
    temperature=0,
)

# The model location and device used to be fixed to one machine's layout. They can now be
# overridden through EMBEDDING_MODEL_PATH / EMBEDDING_DEVICE; the defaults are the old values.
embeddings: HuggingFaceE5InstructEmbeddings = HuggingFaceE5InstructEmbeddings(
    model_name=os.getenv(
        "EMBEDDING_MODEL_PATH",
        "/mnt/c/Data/Apps/research-assistant-infrastructure/data/models/infloat/multilingual-e5-large-instruct",
    ),
    model_kwargs={'device': os.getenv("EMBEDDING_DEVICE", "cuda")},
    encode_kwargs={'normalize_embeddings': True},
    query_instruction="Given the question, retrieve the answer from the context."
)

ImportError: Could not import sentence_transformers python package. Please install it with `pip install sentence-transformers`.

ImportError: Could not import sentence_transformers python package. Please install it with `pip install sentence-transformers`.

In [ ]:
# Summarize the tables collected by the categorization cell. The variable defined there is
# `element_category`; the previous `element_categories` was never assigned.
summarized_tables: List[str] = summary_document_processor.summarize_tables(
    tables=element_category.tables,
    model=model
)

In [ ]:
# Summarize the images collected by the categorization cell. The variable defined there is
# `element_category`; the previous `element_categories` was never assigned.
summarized_images: List[str] = summary_document_processor.summarize_images(
    images=element_category.images,
    model=model
)

In [0]:
# Split the categorized text elements into small chunks. The variable defined by the
# categorization cell is `element_category`; `element_categories` was never assigned.
splitted_texts: List[str] = split_document_processor.split_texts(
    chunk_size=10,
    chunk_overlap=0,
    texts=element_category.texts
)
# Show the chunk count next to the first chunks. A bare `len(...)` line in mid-cell is never
# displayed, and slicing avoids an IndexError when fewer than two chunks exist.
len(splitted_texts), splitted_texts[:2]

In [ ]:
class MultiVectorRetriever:
    """Factory for LangChain multi-vector retrievers backed by the project's datastores.

    ``two_datastore`` supplies the client wrapped in a ``RedisStore``;
    ``four_datastore`` supplies the Milvus vector store.
    """

    def __init__(
            self,
            two_datastore: TwoDatastore,
            four_datastore: FourDatastore,
    ):
        self.two_datastore = two_datastore
        self.four_datastore = four_datastore

    def get_retriever(
            self,
            embedding_function: Embeddings,
            collection_name: str,
            **kwargs: Any
    ) -> "retrievers.MultiVectorRetriever":
        """Build a retriever for one collection.

        Args:
            embedding_function: Embeddings used by the vector store.
            collection_name: Name of the vector store collection.
            **kwargs: Extra keyword arguments forwarded to the LangChain retriever.

        Returns:
            A ``langchain.retrievers.MultiVectorRetriever`` wired to both stores.
        """
        # The module-level `from langchain import retrievers` is commented out, so the name
        # was undefined and the old unquoted return annotation raised NameError as soon as
        # the class was defined. Import it here and keep the annotation as a string.
        from langchain import retrievers

        document_store: RedisStore = RedisStore(
            client=self.two_datastore.client,
        )
        vector_store: Milvus = self.four_datastore.get_client(
            embedding_function=embedding_function,
            collection_name=collection_name,
        )
        # NOTE(review): RedisStore is a byte store; confirm whether it should be passed as
        # `byte_store=` rather than `docstore=`.
        return retrievers.MultiVectorRetriever(
            docstore=document_store,
            vectorstore=vector_store,
            **kwargs
        )


In [0]:
# State passed between graph nodes: a single "data" mapping that holds every value.
GraphState = TypedDict("GraphState", {"data": Dict[str, Any]})


class GraphLongFormQa:
    """Single-node LangGraph pipeline for long-form question answering.

    The graph currently consists of the ``_retrieve`` node only, which serves
    as both entry and finish point.
    """

    def __init__(
            self,
            four_datastore: FourDatastore,
    ):
        self.four_datastore = four_datastore

    def _retrieve(self, input_state: GraphState) -> GraphState:
        """Open the vector store described by the state's retriever settings.

        Reads ``embedding_function`` and ``collection_name`` from
        ``input_state["data"]["retriever_setting"]`` and returns a deep copy of
        the input state.
        """
        output_state: GraphState = copy.deepcopy(input_state)
        # GraphState nests every value under "data"; reading "retriever_setting" from the
        # top level raised KeyError.
        retriever_setting: Dict[str, Any] = input_state["data"]["retriever_setting"]
        # TODO: the client is obtained but not queried yet; the retrieval step is still a stub.
        retriever: Milvus = self.four_datastore.get_client(
            embedding_function=retriever_setting["embedding_function"],
            collection_name=retriever_setting["collection_name"],
        )
        return output_state

    def _get_node(self, node: RunnableLike) -> Tuple[str, RunnableLike]:
        """Return ``(name, node)`` so the pair can be unpacked into ``add_node``."""
        return node.__name__, node

    def compile(self) -> CompiledGraph:
        """Assemble and compile the graph with ``_retrieve`` as its only node."""
        graph: StateGraph = StateGraph(GraphState)
        graph.add_node(*self._get_node(self._retrieve))
        graph.set_entry_point(self._retrieve.__name__)
        graph.set_finish_point(self._retrieve.__name__)
        compiled_graph: CompiledGraph = graph.compile()
        return compiled_graph


graph_lfqa = GraphLongFormQa(
    four_datastore=four_datastore,
)
compiled_graph_lfqa = graph_lfqa.compile()
input_state: GraphState = GraphState(
    data={
        "retriever_setting": {
            # Pass the configured embeddings instance built earlier; the class object itself
            # cannot embed anything.
            "embedding_function": embeddings,
            # NOTE(review): confirm this id is acceptable as a vector store collection name.
            "collection_name": all_seeder.account_seeder.account_mock.data[0].id,
        },
        "question": "what is artificial intelligence?"
    }
)
compiled_graph_lfqa.invoke(input_state)


In [30]:
# Keep only the first row of the "eval" split so the evaluation stays cheap.
eval_data = amnesty_qa["eval"].select(range(1))
eval_data

Dataset({
    features: ['question', 'ground_truth', 'answer', 'contexts'],
    num_rows: 1
})

In [21]:
# Evaluate with the chat model and embeddings configured earlier. The previous `llm` name is
# not defined anywhere in this notebook, so the cell only worked with leftover kernel state.
result = evaluate(
    eval_data,
    llm=model,
    embeddings=embeddings,
    # metrics=[
    #     metrics.faithfulness,
    #     metrics.answer_relevancy, 
    #     metrics.context_recall,
    #     metrics.context_precision,
    #     metrics.answer_correctness,
    #     metrics.context_relevancy,
    #     metrics.context_entity_recall,
    # ],
)

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.


Invalid JSON response. Expected dictionary with key 'Attributed'
  value = np.nanmean(self.scores[cn])


In [22]:
result

{'answer_relevancy': 0.9599, 'context_precision': 1.0000, 'faithfulness': 0.5000, 'context_recall': nan}