In [None]:
! pip install "unstructured[all-docs]" pillow pydantic lxml matplotlib unstructured-pytesseract tesseract-ocr

In [None]:
! pip install langchain_core langchain_openai langchain chromadb

In [None]:
! sudo apt-get update

In [None]:
! sudo apt-get install poppler-utils

In [None]:
# python-pil is the Python 2 package and is not available on current Ubuntu images; use python3-pil.
# -y answers the confirmation prompt, since the notebook shell is non-interactive.
! sudo apt-get install -y libleptonica-dev tesseract-ocr libtesseract-dev python3-pil tesseract-ocr-eng tesseract-ocr-script-latn

In [None]:
from unstructured.partition.pdf import partition_pdf
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import os
from google.colab import userdata
import base64
from langchain_core.messages import HumanMessage
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
import io
import re
from IPython.display import HTML, display
from PIL import Image
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

In [None]:
# Read the OpenAI key from Colab's user secrets (google.colab.userdata) and
# export it as the OPENAI_API_KEY environment variable.
OPENAI_API_TOKEN = userdata.get("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = OPENAI_API_TOKEN

In [None]:
# Partition /content/data/cj.pdf into elements. Image and Table blocks are
# requested via extract_image_block_types, with "extracted_data" given as the
# output directory for them.
raw_pdf_elements = partition_pdf(
    filename="/content/data/cj.pdf",  # Mandatory
    strategy="hi_res",   # mandatory to use "hi_res" strategy
    extract_images_in_pdf = True,   # mandatory to set as "True"
    extract_image_block_types = ["Image", "Table"],   # Optional
    extract_image_block_to_payload = False,   # Optional: False, so blocks are not embedded in the element payload
    extract_image_block_output_dir = "extracted_data"
)

In [None]:
# One list per text-bearing element class; each collects str(element) for the
# elements of that class found in raw_pdf_elements.
Header = []
Footer = []
Title = []
NarrativeText = []
Text = []
ListItem = []

for element in raw_pdf_elements:
    # str(type(element)) contains the fully qualified class name, so a
    # substring check identifies the element's class. Computed once per element.
    element_type = str(type(element))

    if "unstructured.documents.elements.Header" in element_type:
        Header.append(str(element))

    elif "unstructured.documents.elements.Footer" in element_type:
        Footer.append(str(element))

    elif "unstructured.documents.elements.Title" in element_type:
        Title.append(str(element))

    elif "unstructured.documents.elements.NarrativeText" in element_type:
        NarrativeText.append(str(element))

    elif "unstructured.documents.elements.Text" in element_type:
        Text.append(str(element))

    elif "unstructured.documents.elements.ListItem" in element_type:
        ListItem.append(str(element))
        

In [None]:
# String form of every Image element found in raw_pdf_elements.
img = [
    str(element)
    for element in raw_pdf_elements
    if "unstructured.documents.elements.Image" in str(type(element))
]

In [None]:
# Partition the second PDF the same way as the first one. Image and Table
# blocks are requested via extract_image_block_types, with "extracted_data_2"
# given as the output directory for them.
raw_pdf_elements_2 = partition_pdf(
    # Absolute path, matching the other /content/... paths used in this notebook.
    filename = "/content/data2/Retrieval-Augmented-Generation-for-NLP.pdf",
    strategy = "hi_res",
    extract_images_in_pdf = True,
    # Correctly spelled keyword: a misspelled one is not recognised, so the
    # Image/Table selection would not be applied.
    extract_image_block_types = ["Image", "Table"],
    extract_image_block_to_payload = False,
    extract_image_block_output_dir = "extracted_data_2"
)

In [None]:
# String form of every Image element found in raw_pdf_elements_2.
img_2 = [
    str(element)
    for element in raw_pdf_elements_2
    if "unstructured.documents.elements.Image" in str(type(element))
]

In [None]:
# String form of every Table element of the second PDF.
# Iterates raw_pdf_elements_2 (the "_2" document), not the first PDF's elements.
table_2 = []

for element in raw_pdf_elements_2:
    if "unstructured.documents.elements.Table" in str(type(element)):
        table_2.append(str(element))

In [None]:
# String form of every NarrativeText element of the second PDF.
# Iterates raw_pdf_elements_2; reading raw_pdf_elements here would only
# duplicate the NarrativeText list built from the first PDF.
NarrativeText_2 = []

for element in raw_pdf_elements_2:
    if "unstructured.documents.elements.NarrativeText" in str(type(element)):
        NarrativeText_2.append(str(element))

In [None]:
# String form of every ListItem element of the second PDF.
# Iterates raw_pdf_elements_2; reading raw_pdf_elements here would only
# duplicate the ListItem list built from the first PDF.
ListItem_2 = []

for element in raw_pdf_elements_2:
    if "unstructured.documents.elements.ListItem" in str(type(element)):
        ListItem_2.append(str(element))

Prompt

In [None]:
# Summarization prompt. The {element} placeholder is where the raw text or
# table is substituted; without it the model never sees the content it is
# asked to summarize.
prompt_text = """ You are an assistant tasked with summarizing texts or tables for retrieval. \
    These summaries will be embedded and used to retrieve the raw text or table elements. \
    Give a concise summary of the text or table that is well optimized for retrieval. \
    Text or table: {element} """

In [None]:
# Build the chat prompt template from prompt_text.
prompt = ChatPromptTemplate.from_template(prompt_text)

Text Summary chain

In [None]:
# Chat model used by the text/table summarizer chain (temperature 0, "gpt-4").
model = ChatOpenAI(temperature = 0, model="gpt-4")

In [None]:
# Chain: wrap the input under the "element" key -> prompt -> model -> plain string.
# StrOutputParser must be instantiated; piping the bare class does not parse
# the model message into a string.
summarizer_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

In [None]:
# Summarize every entry of Text via summarizer_chain.batch, with
# max_concurrency set to 5.
text_summaries = []
text_summaries = summarizer_chain.batch(Text, {"max_concurrency":5})

In [None]:
# Start with an empty list of table summaries.
table_summaries = []

In [None]:
# Summarize every entry of table_2 via summarizer_chain.batch, with
# max_concurrency set to 5.
table_summaries = summarizer_chain.batch(table_2, {"max_concurrency":5})

In [None]:
# Show the first entry of table_2 (raises IndexError if no table was collected).
table_2[0]

In [None]:
# Show the first entry of table_summaries, for comparison with table_2[0].
table_summaries[0]

In [None]:
def encoder_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is read in binary mode, base64-encoded, and the encoded bytes are
    decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()

    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

In [None]:
def image_summarize(img_base64, prompt, model_name="gpt-4-vision-preview"):
    """Ask a chat model to describe a base64-encoded image.

    Args:
        img_base64: Base64 string of the image; it is sent as a JPEG data URL.
        prompt: Instruction text sent in the same message as the image.
        model_name: Name of the OpenAI chat model to call. Defaults to the
            previously hard-coded "gpt-4-vision-preview".
            NOTE(review): this preview model is believed to have been retired
            by OpenAI - confirm, and pass a current vision-capable model if so.

    Returns:
        The ``content`` attribute of the message returned by the model.
    """
    chat = ChatOpenAI(model=model_name, max_tokens=1024)

    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        # The payload follows the comma directly; a data URL
                        # must not contain whitespace before the base64 data.
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )

    return msg.content

In [None]:
def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images.

    path: Either a single image file, or a directory holding the .jpg files
          extracted by unstructured. For a directory, every file whose name
          ends in ".jpg" (case-insensitive) is processed in sorted name order.

    Returns a tuple ``(img_base64_list, image_summaries)`` of two lists with
    one entry per processed image, in the same order.
    """

    # Prompt
    prompt = """ 

    You are an assistant tasked with summarizing images for retrieval. \
    These summaries will be embedded and used to retrieve the raw image. \
    Give a concise summary of the image that is well optimized for retrieval.
    """

    # Resolve the argument to a list of image files; a plain file path keeps
    # the original single-image behavior.
    if os.path.isdir(path):
        image_paths = [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if name.lower().endswith(".jpg")
        ]
    else:
        image_paths = [path]

    # Store base64 encoded images
    img_base64_list = []

    # Store image summaries
    image_summaries = []

    for image_path in image_paths:
        base64_image = encoder_image(image_path)
        img_base64_list.append(base64_image)
        image_summaries.append(image_summarize(base64_image, prompt))

    return img_base64_list, image_summaries

In [None]:
# Path of the single image file passed to generate_img_summaries.
file_path = "/content/extracted_data_2/figure-17-4.jpg"

In [None]:
# Encode the image at file_path and summarize it; both results are lists.
img_base64_list, image_summaries = generate_img_summaries(file_path)

In [None]:
# Print the summary generated for the first image.
print(image_summaries[0])

Creating MultiVector Retriever

In [None]:
def create_multi_vector_retriever(vector_store, text_summaries, texts, table_summaries, tables, image_summaries, images):
    """Create a retriever that indexes summaries but returns raw texts, tables or images.

    Each summary is added to ``vector_store`` as a Document whose metadata
    carries a generated id under the "doc_id" key; the matching raw content is
    stored in an in-memory docstore under that same id.

    Args:
        vector_store: Vector store that receives the summary documents.
        text_summaries, texts: Summaries and the raw texts they describe.
        table_summaries, tables: Summaries and the raw tables they describe.
        image_summaries, images: Summaries and the raw images they describe.
            Each summaries list is paired by position with its contents list.

    Returns:
        The populated MultiVectorRetriever.
    """

    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Creating the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vector_store,
        docstore=store,
        id_key=id_key,
    )

    # Helper function to add documents to the vector store and doc_store
    def add_documents(retriever, doc_summaries, doc_contents):
        # One fresh id per raw content item; summary i is linked to content i.
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]

        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]

        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables and images. These calls belong to the outer function
    # (not inside the helper) and use the function's own parameters, not
    # notebook globals. Empty summary lists are skipped.
    if text_summaries:
        add_documents(retriever, text_summaries, texts)

    if table_summaries:
        add_documents(retriever, table_summaries, tables)

    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever



In [None]:
# Chroma collection "mm_rag", using OpenAIEmbeddings as the embedding function.
vector_store = Chroma(
    collection_name="mm_rag", embedding_function=OpenAIEmbeddings()
)

In [None]:
# Create the retriever. Arguments are passed as (summaries, raw contents)
# pairs: Text for texts, table_2 for tables, img_base64_list for images.
retriever_multi_vector_img = create_multi_vector_retriever(
    vector_store,
    text_summaries,
    Text,
    table_summaries,
    table_2,
    image_summaries,
    img_base64_list
)

In [None]:
def plt_img_based64(img_based64):
    """Display a base64 encoded string as an image.

    Args:
        img_based64: Base64 string of the image; rendered as a JPEG data URL.
    """
    # Create an HTML img tag with the base64 string as the source
    image_html = f'<img src="data:image/jpeg;base64,{img_based64}" />'

    # Display the image by rendering the HTML
    display(HTML(image_html))

In [None]:
# img_base64_list was built from the single file in file_path, so the only
# valid index is 0.
plt_img_based64(img_base64_list[0])

In [None]:
# image_summaries was built from the single file in file_path, so the only
# valid index is 0.
image_summaries[0]

In [None]:
def looks_like_base64(sb):
    """Check if the string looks like base64.

    Returns True when ``sb`` consists of one or more characters from the
    standard base64 alphabet (A-Z, a-z, 0-9, '+', '/') followed by at most two
    '=' padding characters, and False otherwise.
    """
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None

In [2]:
def is_image_data(b64data):
    """Check if the base64 data is an image by looking at the start of the data.

    The input is base64-decoded and its first 8 bytes are compared against
    known file signatures.

    Returns:
        True if the decoded data starts with one of the signatures below,
        False if it does not or if the input cannot be decoded.
    """

    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",  # JPEG start-of-image marker is FF D8 FF
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",  # "RIFF" container prefix
    }

    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
    except Exception:
        # Undecodable input is treated as "not an image" rather than an error.
        return False

    # Always return a real bool, including when nothing matches.
    return any(header.startswith(sig) for sig in image_signatures)

In [None]:
def resize_base64_image(base64_string, size=(128, 128)):
    """Resize an image encoded as a Base64 string.

    Args:
        base64_string: Base64 string of the source image.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        Base64 string of the resized image, saved in the source image's format.
    """

    # Decode the Base64 string
    img_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(img_data))

    # Resize the image. Both arguments are positional: a positional argument
    # after a keyword argument is a syntax error.
    resized_img = image.resize(size, Image.LANCZOS)

    # Save the resized image to a bytes buffer
    buffered = io.BytesIO()
    resized_img.save(buffered, format=image.format)

    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

In [4]:
def split_image_text_types(docs):
    """Split base64-encoded images and texts.

    Args:
        docs: Iterable of items that are either Document objects (their
            ``page_content`` is used) or raw values.

    Returns:
        A dict ``{"images": [...], "texts": [...]}``. Items that look like
        base64 and decode to a known image signature are resized to
        (1300, 600) and placed under "images"; all other items go under
        "texts".
    """

    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)

    # No debug printing here: the image entries are full base64 payloads.
    return {"images": b64_images, "texts": texts}

In [9]:
def img_prompt_func(data_dict):
    """
    Build the multi-modal prompt message from retrieved context and a question.

    Args:
        data_dict: Dict with a "question" entry and a "context" entry; the
            context is a dict holding "texts" (list of strings) and "images"
            (list of base64 strings).

    Returns:
        A one-element list containing a HumanMessage whose content is one
        image part per context image, followed by one text part that holds
        the instructions, the question and the joined context texts.
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding image(s) to the messages if present
    for image in data_dict["context"]["images"]:
        image_message = {
            "type": "image_url",
            # The payload follows the comma directly; no whitespace in the data URL.
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        messages.append(image_message)

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            "You are a helpful assistant. \n"
            "You will be given a mixed info(s) .\n"
            "Use this information to provide relevant information to the user question. \n"
            f"User-provided question: {data_dict['question']}\n\n"
            "Text and / or tables: \n"
            f"{formatted_texts}"
        ),
    }

    messages.append(text_message)

    return [HumanMessage(content=messages)]

In [11]:
def multi_modal_rag_chain(retriever, model_name="gpt-4-vision-preview"):
    """Multi-modal RAG chain.

    Args:
        retriever: Retriever whose output is split into images and texts by
            split_image_text_types.
        model_name: Name of the OpenAI chat model to call. Defaults to the
            previously hard-coded "gpt-4-vision-preview".
            NOTE(review): this preview model is believed to have been retired
            by OpenAI - confirm, and pass a current vision-capable model if so.

    Returns:
        A chain that takes the question as input, retrieves context, builds
        the prompt with img_prompt_func, calls the model and parses the reply
        into a string.
    """

    # Multi-Modal LLM
    model = ChatOpenAI(temperature=0, model=model_name, max_tokens=1024)

    # RAG pipeline: the input question feeds both the retriever ("context")
    # and the passthrough ("question").
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

    return chain

Create RAG chain

In [None]:
# Build the multi-modal RAG chain on top of the multi-vector retriever.
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

In [None]:
# Question for the RAG chain. The two literals are joined inside parentheses,
# and the first one ends with a space so the words at the seam do not run together.
query = (
    "Explain any images / figures in the paper with Left: NQ performance as more documents are retrieved. "
    "Center: Retrieval recall performance in NQ, Right: MS-MARCO Bleu-1 and Rouge-L as more documents are retrieved"
)

In [None]:
# Run the RAG chain on the query; the cell output is the chain's return value.
chain_multimodal_rag.invoke(query)