In [None]:
import json
import os
from dataclasses import dataclass
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter
import json
from dataclasses import dataclass
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langgraph.graph import MessagesState, StateGraph
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver


# Input: product listings stored as JSON Lines (one JSON object per line).
file_path = "listings/metadata/listings_0.json"

In [None]:
# Gemini API key read once from the environment; None when GEMINI_API_KEY is unset.
api_key = os.environ.get('GEMINI_API_KEY')

# Model names that can be passed to the Gemini chat client are listed at:
# https://ai.google.dev/gemini-api/docs/models
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models

In [None]:
@dataclass
class FileAndMeatadata:
    """One parsed listing: its searchable text plus identifiers kept as metadata.

    The class name's "Meatadata" spelling is kept because ``VectorStore``
    refers to it by this name.
    """
    # Searchable text built from the listing's remaining string values.
    file_string: str
    # Listing identifier; None when the source record has no ``item_id``.
    item_id: "str | None"
    # Primary image identifier; None when the source record has none.
    main_image_id: "str | None"
    # Additional image identifier(s); None when absent.
    # NOTE(review): the source data may store a list of ids here — confirm.
    other_image_id: "str | None"

class VectorStore:
    """FAISS similarity index over a JSON Lines file of product listings.

    Each non-blank line of the input file is parsed as one JSON object. The
    ``item_id``, ``main_image_id`` and ``other_image_id`` fields are kept as
    chunk metadata. ``model_number``, ``marketplace`` and ``domain_name`` are
    discarded. All remaining string values are lower-cased, joined with spaces,
    chunked and embedded.
    """

    def __init__(self, file_path,
                 embedding_model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                 chunk_size=1000,
                 chunk_overlap=200):
        """Load, chunk and embed ``file_path`` eagerly.

        Args:
            file_path: Path to a JSON Lines file (one JSON object per line).
            embedding_model_name: Model name passed to ``HuggingFaceEmbeddings``.
            chunk_size: Maximum chunk length, in characters.
            chunk_overlap: Number of characters shared by consecutive chunks.
        """
        self._embedding_model_name = embedding_model_name
        self._file_path = file_path
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
        self._file_and_metadata_list = self._load_and_process_json()
        self.vectorstore = self._chunk_and_embed()

    def perform_search(self, query: str, top_k=1):
        """Return the ``top_k`` indexed chunks most similar to ``query``.

        Each result's ``metadata`` holds ``item_idx``, ``item_id``,
        ``main_image_id`` and ``other_image_id`` of the listing it came from.
        """
        # TODO: truncate queries longer than the embedding model's input limit
        # (needs a tokenizer for self._embedding_model_name).
        return self.vectorstore.similarity_search(query, k=top_k)

    def _load_and_process_json(self) -> "list[FileAndMeatadata]":
        """Parse ``self._file_path`` as JSON Lines into ``FileAndMeatadata`` records.

        Blank or whitespace-only lines are skipped instead of failing in ``json.loads``.
        """
        # Read the path given to the constructor, not the notebook-level
        # ``file_path`` global, so each instance indexes its own file.
        with open(self._file_path, 'r', encoding='utf-8') as f:
            return [self._json_to_str(json.loads(line)) for line in f if line.strip()]

    def _json_to_str(self, file: dict) -> "FileAndMeatadata":
        """Split one listing into searchable text and identifier metadata.

        The caller's dict is left unmodified.
        """
        record = dict(file)  # shallow copy: the pops below must not mutate the input
        item_id = record.pop('item_id', None)
        main_image_id = record.pop('main_image_id', None)
        other_image_id = record.pop('other_image_id', None)

        # Fields that are not useful as search text.
        for unused_key in ('model_number', 'marketplace', 'domain_name'):
            record.pop(unused_key, None)

        return FileAndMeatadata(file_string=self._flatten_json(record),
                                item_id=item_id,
                                main_image_id=main_image_id,
                                other_image_id=other_image_id)

    def _flatten_json(self, y):
        """Collect every string leaf of nested dicts/lists, lower-cased and space-joined.

        Keys are dropped, as are non-string leaves (numbers, booleans, None).
        Leaves appear in traversal order.
        """
        out = []

        def flatten(x):
            if isinstance(x, dict):
                for v in x.values():
                    flatten(v)
            elif isinstance(x, list):
                for item in x:
                    flatten(item)
            elif isinstance(x, str):
                out.append(x.lower())

        flatten(y)
        return " ".join(out)

    def _chunk_and_embed(self) -> "FAISS":
        """Chunk every record's text, embed the chunks and build a FAISS index.

        Every chunk carries the metadata of the listing it came from.

        Raises:
            ValueError: if the input file yielded no records.
        """
        if not self._file_and_metadata_list:
            raise ValueError(f"No records found in {self._file_path!r}; nothing to index.")

        # file_string is space-joined (see _flatten_json), so split on spaces.
        # With the default "\n\n" separator, long texts were usually left as one
        # oversized chunk.
        text_splitter = CharacterTextSplitter(separator=" ",
                                              chunk_size=self._chunk_size,
                                              chunk_overlap=self._chunk_overlap)
        texts = []
        metadatas = []

        # TODO: chunk by token count instead of character count.
        for item_idx, item in enumerate(self._file_and_metadata_list):
            metadata = {
                "item_idx": item_idx,
                "item_id": item.item_id,
                "main_image_id": item.main_image_id,
                "other_image_id": item.other_image_id,
            }
            if len(item.file_string) > self._chunk_size:
                chunks = text_splitter.split_text(item.file_string)
            else:
                chunks = [item.file_string]
            for chunk in chunks:
                texts.append(chunk)
                metadatas.append(dict(metadata))  # independent dict per chunk

        return FAISS.from_texts(texts, self._embedding_model, metadatas=metadatas)

In [None]:
# Build the index: loads, chunks and embeds every listing in file_path.
vector_store = VectorStore(file_path=file_path)

In [None]:
query = "iphone cover"
# Single best match (top_k=1 is also perform_search's default).
results = vector_store.perform_search(query, top_k=1)
# Display the text of the best-matching chunk.
results[0].page_content

In [None]:
# Metadata stored with the best-matching chunk (item_idx, item_id, main_image_id, other_image_id).
results[0].metadata

In [None]:
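# NOTE: earlier draft built on ConversationalRetrievalChain. Its dependencies
# (ConversationSummaryBufferMemory, ChatPromptTemplate, ConversationalRetrievalChain)
# are not imported above, so it will not run if uncommented. The LangGraph
# pipeline below is the active implementation.
# class RAGShoppingAssistant: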
#     def __init__(self, vectorstore):
#         self.vectorstore = vectorstore
#         self._llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=os.getenv('GEMINI_API_KEY'))
#         self._memory = ConversationSummaryBufferMemory.from_llm(llm=self._llm, memory_key="chat_history", return_messages=True)

#         self._rag_chain = self._setup_rag_chain()

#     def _setup_rag_chain(self):
#         """Sets up the RAG chain with conversation memory."""
#         template = """Use the following pieces of context to answer the question at the end.
#         If you don't know the answer, just say you don't know, don't try to make up an answer.
#         Also, keep the conversation going, and remember the previous questions and answers.
#         {context}
#         Question: {question}
#         {chat_history}
#         Answer:"""
#         prompt = ChatPromptTemplate.from_template(template)
#         retriever = self.vectorstore.as_retriever()
#         rag_chain = ConversationalRetrievalChain.from_llm(
#             self._llm,
#             retriever,
#             memory=self._memory,
#             return_source_documents=True,
#         )
#         return rag_chain

#     def chat_with_assistant(self, question):
#         result = self._rag_chain({"question": question})
#         print(f"Answer: {result['answer']}")
#         for doc in result['source_documents']:
#             print(f"  - {doc.page_content} (Metadata: {doc.metadata})")
#         return result['question']


# assistant = RAGShoppingAssistant(vector_store)

In [None]:
# RAG references (LangChain tutorials):
# - Background reading: https://python.langchain.com/docs/tutorials/rag/
# - Pattern followed below (RAG with chat history): https://python.langchain.com/docs/tutorials/qa_chat_history/

In [None]:
# Chat model used for both tool selection and answer generation.
# Reuses the key read once in the config cell instead of re-querying the environment.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001", api_key=api_key)

In [None]:
graph_builder = StateGraph(state_schema=MessagesState)  # state = a growing list of chat messages

@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Search the product-listing catalog and return descriptions of the items that best match the query."""
    # NOTE: the docstring above is the tool description the LLM sees, so it
    # states what is being searched.
    # The string is the message content shown to the LLM; the raw documents are
    # returned as the tool artifact.
    retrieved_docs = vector_store.perform_search(query, top_k=2)
    # Prefix each hit with its item id so the model can tell items apart.
    serialized = "\n\n".join(
        f"Item ID: {doc.metadata.get('item_id')}\nContent: {doc.page_content}"
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs

# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
    """Let the tool-enabled LLM either answer directly or request a ``retrieve`` call."""
    tool_aware_llm = llm.bind_tools([retrieve])
    ai_message = tool_aware_llm.invoke(state["messages"])
    # The MessagesState reducer appends returned messages rather than replacing the list.
    return {"messages": [ai_message]}


# Step 2: Execute the retrieval.
# The node keeps ToolNode's default name, "tools", which the graph edges refer to.
tools = ToolNode(tools=[retrieve])


# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Answer the latest question using the newest tool results plus the dialogue so far."""
    history = state["messages"]

    # The contiguous run of ToolMessages at the end of the history holds the
    # results of the retrieval that just ran; gather it newest-first.
    newest_first = []
    for msg in reversed(history):
        if msg.type != "tool":
            break
        newest_first.append(msg)

    # Restore chronological order when building the context block.
    retrieved_context = "\n\n".join(msg.content for msg in reversed(newest_first))
    system_prompt = (
        "You are a shopping assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        f"{retrieved_context}"
    )

    # Keep user/system turns and final AI answers; drop AI messages that only
    # requested tool calls, along with the tool messages themselves.
    dialogue = [
        msg
        for msg in history
        if msg.type in ("human", "system")
        or (msg.type == "ai" and not msg.tool_calls)
    ]

    reply = llm.invoke([SystemMessage(system_prompt), *dialogue])
    return {"messages": [reply]}

In [None]:
# Assemble the graph: query_or_respond -> END, or query_or_respond -> tools -> generate -> END.
# A fresh builder is created here so this cell can be re-run on its own:
# adding a node that already exists on a builder raises an error.
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
# tools_condition routes to "tools" when the last AI message requests a tool call, otherwise to END.
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# In-memory checkpointer: history is kept per thread_id; re-running this cell starts empty.
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# All demo turns below share this thread, so later turns see earlier ones.
config = {"configurable": {"thread_id": "abc123"}}

In [None]:
# from IPython.display import Image, display

# display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
input_message = "Hello"

user_turn = {"messages": [{"role": "user", "content": input_message}]}
# stream_mode="values" yields the full state after each step; show its newest message.
for snapshot in graph.stream(user_turn, config=config, stream_mode="values"):
    newest = snapshot["messages"][-1]
    newest.pretty_print()

In [None]:
input_message = "Can you help me to find a phone cover for an iphone?"

user_turn = {"messages": [{"role": "user", "content": input_message}]}
# Expected to trigger a retrieve tool call before the final answer.
for snapshot in graph.stream(user_turn, config=config, stream_mode="values"):
    newest = snapshot["messages"][-1]
    newest.pretty_print()

In [None]:
input_message = "Is it available in red?"

user_turn = {"messages": [{"role": "user", "content": input_message}]}
# Follow-up on the same thread_id, so "it" can be resolved from the earlier turns.
for snapshot in graph.stream(user_turn, config=config, stream_mode="values"):
    newest = snapshot["messages"][-1]
    newest.pretty_print()

In [None]:
# TODO: return the matching product images alongside the text answer (each chunk's metadata already stores main_image_id and other_image_id).

In [None]:
# Next step: RAG over both text and images.
# TODO: replace the text-only embedding model with a multimodal model that embeds both text and image inputs.