In [53]:
!pip install -U --quiet langchain langchain_community chromadb  langchain-google-vertexai langchain_chroma
!pip install --quiet "unstructured[all-docs]" pypdf pillow pydantic lxml pillow matplotlib chromadb tiktoken

In [54]:
# Google Cloud project id and region that are passed to vertexai.init() in a
# later cell. They are left blank here - fill in your own values before running.
PROJECT_ID = ""
REGION = ""

# Authenticate the current Colab user (google.colab is only importable inside
# a Colab runtime).
from google.colab import auth
auth.authenticate_user()

In [55]:
# Initialize the Vertex AI SDK with the project and region configured above.
import vertexai

vertexai.init(project=PROJECT_ID, location=REGION)

In [56]:
# Download and prepare data
import logging
import zipfile
import requests

logging.basicConfig(level=logging.INFO)

data_url = "https://storage.googleapis.com/benchmarks-artifacts/langchain-docs-benchmarking/cj.zip"
filename = "cj.zip"

# The timeout keeps the cell from hanging indefinitely on a stalled connection.
# raise_for_status() stops an HTTP error body from being saved as "cj.zip",
# which would otherwise only surface later as a confusing bad-zip error.
result = requests.get(data_url, timeout=60)
result.raise_for_status()

with open(filename, "wb") as file:
   file.write(result.content)

# Unpack the archive into the current working directory.
with zipfile.ZipFile(filename, "r") as zip_ref:
   zip_ref.extractall()

In [57]:
# Load the PDF that was extracted from "cj.zip"
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader("./cj/cj.pdf")
docs = loader.load()

# No tables are extracted in this notebook, so the table list stays empty.
tables = []
# Keep only the text content of each loaded document.
texts = [page.page_content for page in docs]

In [59]:
#Generate Text summaries
from langchain_google_vertexai import VertexAI , ChatVertexAI , VertexAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

Read about Runnable Lambda here - https://www.pinecone.io/learn/series/langchain/langchain-expression-language/

In [60]:
# Generate summaries of text elements
def generate_text_summaries(texts, tables, summarize_texts=False):
   """
   Produce retrieval-oriented summaries for text and table elements.

   texts: List of str
   tables: List of str
   summarize_texts: Bool; when True the texts are summarized by the model,
       otherwise the texts themselves are returned as their own summaries.

   Returns a (text_summaries, table_summaries) tuple. Each element is an empty
   list when the corresponding input is empty.
   """
   needs_text_summaries = bool(texts) and bool(summarize_texts)
   needs_table_summaries = bool(tables)

   # Nothing has to go through the model: return early so that no prompt,
   # Vertex AI client or chain is constructed for a pure pass-through call.
   if not needs_text_summaries and not needs_table_summaries:
       return (texts if texts else []), []

   # Prompt
   prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
   These summaries will be embedded and used to retrieve the raw text or table elements. \
   Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """
   prompt = PromptTemplate.from_template(prompt_text)

   # Fallback runnable: yields a placeholder message so that a failing model
   # call does not abort the whole batch.
   empty_response = RunnableLambda(
       lambda x: AIMessage(content="Error processing document")
   )

   # Text summary chain
   model = VertexAI(
       temperature=0, model_name="gemini-pro", max_output_tokens=1024
   ).with_fallbacks([empty_response])
   summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

   text_summaries = []
   table_summaries = []

   # Summarize the texts when requested; otherwise pass them through unchanged
   if needs_text_summaries:
       text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
   elif texts:
       text_summaries = texts

   # Tables are always summarized when any are provided
   if needs_table_summaries:
       table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})

   return text_summaries, table_summaries


# Summarize every loaded text (and any tables) for indexing
text_summaries, table_summaries = generate_text_summaries(
   texts=texts, tables=tables, summarize_texts=True
)

#text_summaries[0]

# Code for multi-vector retrieval

In [62]:
import uuid
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_core.documents import Document

In [63]:
# Builds a MultiVectorRetriever, adds the supplied summaries to the vector
# store and the matching raw texts/tables to the docstore, and returns the
# initialized retriever.
def create_multi_vector_retriever(
   vectorstore, text_summaries, texts, table_summaries, tables):
   """
   Create a retriever that indexes summaries but returns the raw texts or tables.

   vectorstore: vector store that receives the summary documents
   text_summaries: List of str, one summary per entry in `texts`
   texts: List of str, raw text elements
   table_summaries: List of str, one summary per entry in `tables`
   tables: List of str, raw table elements

   Returns the populated MultiVectorRetriever.
   Raises ValueError when a non-empty summary list and its content list
   differ in length.
   """
   # Summary i is linked to content i through a shared id, so both lists must
   # line up one-to-one. Validate before creating or filling any store so a
   # mismatch cannot leave a partially populated index behind.
   for label, summaries, contents in (
       ("text", text_summaries, texts),
       ("table", table_summaries, tables),
   ):
       if summaries and len(summaries) != len(contents):
           raise ValueError(
               f"{label} summaries ({len(summaries)}) and contents "
               f"({len(contents)}) must have the same length"
           )

   # Initialize the storage layer
   store = InMemoryStore()
   id_key = "doc_id"

   # Create the multi-vector retriever
   retriever = MultiVectorRetriever(
       vectorstore=vectorstore,
       docstore=store,
       id_key=id_key,
   )

   # Helper function to add documents to the vectorstore and docstore
   def add_documents(retriever, doc_summaries, doc_contents):
       doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
       summary_docs = [
           Document(page_content=summary, metadata={id_key: doc_id})
           for summary, doc_id in zip(doc_summaries, doc_ids)
       ]
       # Summaries go to the vector store ...
       retriever.vectorstore.add_documents(summary_docs)
       # ... and the raw contents go to the docstore under the same ids
       retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

   # Only add a group when it actually has summaries
   if text_summaries:
       add_documents(retriever, text_summaries, texts)
   if table_summaries:
       add_documents(retriever, table_summaries, tables)

   return retriever


# Vector store used to embed and index the summaries
vectorstore = Chroma(
   embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
   collection_name="mm_rag_cj_blog",
)

# Build the retriever and fill it with the summaries and raw texts/tables
retriever_multi_vector_img = create_multi_vector_retriever(
   vectorstore=vectorstore,
   text_summaries=text_summaries,
   texts=texts,
   table_summaries=table_summaries,
   tables=tables,
)


Build the multi-modal RAG pipeline

In [64]:
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import HumanMessage
import io

In [65]:
# convert_Doc_to_Dict() - converts the retriever output into a dict
def convert_Doc_to_Dict(docs):
   """
   Convert the retriever output into the context dict used by the RAG prompt.

   docs: iterable of retrieved items, each either a raw str or a Document.

   Returns {"images": [], "texts": [...]} where "texts" holds each str item
   as-is and the page_content of each Document item, in retrieval order.
   Items of any other type are skipped. "images" is always an empty list
   because image handling is not enabled in this notebook.
   """
   texts = []
   for doc in docs:
       # The docstore in this notebook is filled with raw strings, so plain
       # str items must be collected too - collecting only Document instances
       # would leave the prompt without any context.
       if isinstance(doc, str):
           texts.append(doc)
       elif isinstance(doc, Document):
           texts.append(doc.page_content)
   return {"images": [], "texts": texts}

In [66]:
# Builds the prompt message for the RAG task
def rag_prompt_func(data_dict):
   """
   Build the model input from the retrieved context and the user question.

   data_dict: dict with a "context" entry (itself a dict holding a "texts"
       list of str) and a "question" entry.

   Returns a one-element list containing a HumanMessage whose content is a
   single text part: the instructions, the question, and the context texts
   joined by newlines. No image parts are attached.
   """
   context_text = "\n".join(data_dict["context"]["texts"])
   question = data_dict["question"]

   instructions = (
       "You are financial analyst tasking with providing investment advice.\n"
       "You will be given a text, tables.\n"
       "Use this information to provide investment advice related to the user question. \n"
   )
   prompt_body = (
       f"{instructions}"
       f"User-provided question: {question}\n\n"
       "Text and / or tables:\n"
       f"{context_text}"
   )

   content_parts = [{"type": "text", "text": prompt_body}]
   return [HumanMessage(content=content_parts)]


In [67]:
def multi_modal_rag_chain(retriever):
   """
   Assemble the RAG chain: retrieve context, build the prompt, call the
   model, and parse the reply into a plain string.

   retriever: runnable retriever that supplies the context for a question.

   Returns the composed runnable chain; its input is the question.
   """
   # Chat model (other candidates noted by the author: gemini-1.5-flash, gemma2)
   llm = ChatVertexAI(
       temperature=1, model_name="gemini-pro-vision", max_output_tokens=1024
   )

   # The same input feeds both branches: it is sent to the retriever for
   # context and passed through untouched as the question.
   gather_inputs = {
       "context": retriever | RunnableLambda(convert_Doc_to_Dict),
       "question": RunnablePassthrough(),
   }

   return gather_inputs | RunnableLambda(rag_prompt_func) | llm | StrOutputParser()


# Create the RAG chain on top of the multi-vector retriever built earlier
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

Test retrieval pipeline

In [71]:
query = "What updates do you have about Genrative AI technology space?"
# Retrievers are runnables: invoke() is the supported replacement for the
# deprecated get_relevant_documents().
docs = retriever_multi_vector_img.invoke(query)

# Number of relevant documents returned
len(docs)

1

Calling the RAG pipeline

In [None]:
from IPython.display import Markdown as md

# Run the query through the full RAG chain
result = chain_multimodal_rag.invoke(query)

# Render the model's answer as Markdown in the cell output
md(result)