In [2]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

In [3]:
from dotenv import load_dotenv

load_dotenv(override=True)


True

In [4]:
from typing import List, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from sqlalchemy import Select

from configs.common_config import settings
from core.databases.db import engine
from core.databases.models import Restaurant
from core.enumerate import DatabaseSchema, EmbeddingModel



In [35]:


class ShopeeKnowledgeBase:
    def __init__(self) -> None:
        # SQL loader
        self.sql_database = SQLDatabase(
            engine=engine,
            schema=DatabaseSchema.SHOPEE.value,
        )
        # embeddings function
        self.embeddings = OpenAIEmbeddings(
            model=EmbeddingModel.EMBEDDING_ADA_V2.value
        )
        # qdrant setting
        self.collection_name = "shopee"
        self.qdrant_url = settings.QDRANT_URL
        self.qdrant_port = settings.QDRANT_PORT
        self.qdrant_grpc_port = settings.QDRANT_GRPC_PORT
        self.qdrant_api_key = settings.QDRANT_API_KEY
        self.qdrant_client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key,
            port=self.qdrant_port,
            grpc_port=self.qdrant_grpc_port,
            # prefer_grpc=True,
            # https=None
        )
        self.retriever_search_kwargs = {
            "k": 5,
        }
        self._vector_db = self.init_vector_db()

    def init_vector_db(self):
        if self.qdrant_client.collection_exists(self.collection_name):
            return QdrantVectorStore.from_existing_collection(
                collection_name=self.collection_name,
                embedding=self.embeddings,
                url=self.qdrant_url,
                port=self.qdrant_port,
                api_key=self.qdrant_api_key,
                # prefer_grpc=True,
            )
        else:
            # create new collection with empty documents
            return QdrantVectorStore.from_documents(
                documents=[],
                embedding=self.embeddings,
                url=self.qdrant_url,
                port=self.qdrant_port,
                api_key=self.qdrant_api_key,
                collection_name=self.collection_name,
                # prefer_grpc=True,
            )

    def load_data(self) -> None:
        loader = SQLDatabaseLoader(
            query=Select(Restaurant.name, Restaurant.address, Restaurant.url),
            db=self.sql_database,
        )
        loaded_documents = loader.load()
        return loaded_documents
        

    def split_documents(
        self,
        loaded_docs,
        chunk_size: Optional[int] = 500,
        chunk_overlap: Optional[int] = 20,
    ):
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunked_docs = splitter.split_documents(loaded_docs)
        return chunked_docs

    def initiate_document_injetion_pipeline(self):
        loaded_docs = self.load_data()
        chunked_docs = self.split_documents(loaded_docs)
        self._vector_db = self._vector_db.from_documents(
            documents=chunked_docs,
            embedding=self.embeddings,
            url=self.qdrant_url,
            port=self.qdrant_port,
            api_key=self.qdrant_api_key,
            collection_name=self.collection_name,
            # prefer_grpc=True,
        )
        return self._vector_db
    
    @property
    def vector_store(self):
        return self._vector_db
    
    @property
    def retriever(self):
        return self._vector_db.as_retriever(
            search_kwargs=self.retriever_search_kwargs
        )


In [36]:
shopee_kb = ShopeeKnowledgeBase()


In [37]:
retriver = shopee_kb.retriever

In [38]:
data = retriver.invoke("Trà Sữa Bobapop")

In [39]:
for i in data:
    print(i.page_content)
    print("\n")

In [34]:
print("\n\n".join([document.page_content for document in data]))

name: Trà Sữa Mambo Tea - 94 Phan Thanh
address: 94 Phan Thanh, P. Thạc Gián, Quận Thanh Khê, Đà Nẵng
url: https://shopeefood.vn/da-nang/tra-sua-mambo-tea-94-phan-thanh

name: Trà Sữa QT - Âu Cơ
address: 77 Âu Cơ, P. Hoà Khánh Bắc, Quận Liên Chiểu, Đà Nẵng
url: https://shopeefood.vn/da-nang/tra-sua-qt-au-co

name: Trà Sữa Maycha - 302 Ông Ích Khiêm
address: 302 Ông Ích Khiêm, P. Tân Chính, Quận Thanh Khê, Đà Nẵng
url: https://shopeefood.vn/da-nang/tra-sua-maycha-302-ong-ich-khiem

name: Tiệm Trà 15K - Núi Thành
address: 616 Núi Thành, P. Hòa Cường Nam, Quận Hải Châu, Đà Nẵng
url: https://shopeefood.vn/da-nang/tiem-tra-15k-nui-thanh


In [42]:
from typing import Optional

from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun
)
from pydantic import BaseModel, Field
from typing import Type
from core.ai.knowledge_base.shopee_kb import shopee_kb


class ShopeeSearchInput(BaseModel):
    """Input for the Shopee tool."""
    query: str = Field(description="search query to look up")


class ShopeeSearch(BaseTool):
    name: str = "shopee_search"
    description: str = "Useful when you need to find a resturant, food and drink on Shopee"
    args_schema: Type[BaseModel] = ShopeeSearchInput
    verbose: bool = True
    # shopee_kb: ShopeeKnowledgeBase = shopee_kb

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shopee_kb = shopee_kb

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Run Shopee search and get restaurant information, food and drink."""
        retriever = shopee_kb.retriever
        documents = retriever.invoke(input=query)
        print("documents: ", documents)
        return "\n\n".join([document.page_content for document in documents])

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return self._run(query=query, run_manager=run_manager.get_sync())


In [43]:
shopee_search = ShopeeSearch()

ValueError: "ShopeeSearch" object has no field "shopee_kb"