In [172]:
import os
import json
import uuid
import requests
import base64
import io
import tqdm

from PIL import Image
from dotenv import load_dotenv

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain.schema import Document, HumanMessage, AIMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.messages import HumanMessage, AIMessage
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables import RunnableLambda

import gradio as gr

# Load environment variables
load_dotenv()

## Reading data from file

In [2]:
# Load the scraped articles. The flattening loop further down treats the result
# as a mapping of tag -> list of article dicts.
# The explicit encoding keeps the read independent of the platform's default locale.
with open('raw_data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

In [3]:
def encode_image(image_url: str, timeout: float = 30.0) -> str:
    """Download an image and return its raw bytes as a base64 string.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a single stalled connection would block the notebook
            indefinitely.

    Returns:
        The response body, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return base64.b64encode(response.content).decode('utf-8')

In [4]:
# Per-article accumulators, filled together by the loading loop.
text_elements = list()    # article body text
image_elements = list()   # base64-encoded feature images
image_links = list()      # feature image URLs
link_elements = list()    # article URLs
title_elements = list()   # article titles

In [5]:
# Flatten the tag -> articles mapping into the per-article lists.
for articles in data.values():
    for article in articles:
        # Read every field and download the image BEFORE appending anything:
        # if a key is missing or the request fails, no list has been touched
        # for this article, so all lists stay the same length and index-aligned.
        text = article['text']
        image_url = article['feature_image']
        link = article['link']
        title = article['title']
        encoded_image = encode_image(image_url)

        text_elements.append(text)
        image_links.append(image_url)
        image_elements.append(encoded_image)
        link_elements.append(link)
        title_elements.append(title)
        

In [136]:
# Sanity check: every per-article list should report the same length.
length_report = [
    ("The length of text elements are :", text_elements),
    ("The length of image elements are :", image_elements),
    ("The length of link elements are :", link_elements),
    ("The length of title elements are :", title_elements),
]
for message, elements in length_report:
    print(message, len(elements))

The length of text elements are : 1259
The length of image elements are : 1259
The length of link elements are : 1259
The length of title elements are : 1259


## Preparing data for vector db

In [6]:
# Splitter settings: maximum chunk length and the overlap shared by
# neighbouring chunks (measured with the splitter's default length function).
CHUNK_SIZE = 1150
CHUNK_OVERLAP = 150

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

In [7]:
# Split every article into chunks, remembering which article each chunk came from.
# Each record is [title, link, chunk]; all_chunks holds just the chunk texts in the same order.
original_data_chunks = [
    [title, link, chunk]
    for title, article, link in zip(title_elements, text_elements, link_elements)
    for chunk in text_splitter.split_text(article)
]
all_chunks = [record[2] for record in original_data_chunks]

In [251]:
# Chat model that summarize_image invokes to describe images.
# Despite the name this is a bare model, not a chain; max_tokens caps the length of each reply.
chain_gpt= ChatOpenAI(model="gpt-4o-mini", max_tokens=1024)

In [252]:
# Function for image summaries
def summarize_image(encoded_image, mime_type="image/jpeg"):
    """Ask the vision model for a textual description of an image.

    Args:
        encoded_image: Base64-encoded image bytes (as produced by encode_image).
        mime_type: MIME type placed in the data URL. Defaults to "image/jpeg",
            the value that was previously hard-coded.

    Returns:
        The model's description of the image (``response.content``).
    """
    # Imported locally: the notebook's import cell only brings in
    # HumanMessage and AIMessage.
    from langchain_core.messages import SystemMessage

    prompt = [
        # The role instruction belongs in a system message; sending it as an
        # AIMessage would present it as a previous assistant reply.
        SystemMessage(content="You are a bot that is good at analyzing images."),
        HumanMessage(content=[
            {"type": "text", "text": "Describe the contents of this image."},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{encoded_image}"
                },
            },
        ])
    ]
    response = chain_gpt.invoke(prompt)
    return response.content

In [12]:
# Path for saving progress
save_path_summaries = 'image_summaries.json'

# Start with no summaries; replace them with the cached ones when the file exists.
image_summaries = []
if os.path.exists(save_path_summaries):
    with open(save_path_summaries, 'r') as f:
        image_summaries = json.load(f)

In [13]:
# Number of cached image summaries; the image-indexing cell pairs summary i with
# image i, so this should equal the number of downloaded images.
print(len(image_summaries))

1259


In [None]:
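# This cell is commented out because image_summaries has already been generated
# and cached in image_summaries.json.
# NOTE before re-enabling: the notebook does `import tqdm`, so `tqdm(...)` would call
# the module (use `from tqdm import tqdm`); and `enumerate(..., start=...)` only
# offsets the counter, it does not skip images that were already summarized.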
# def save_progress(summaries):
#     with open(save_path_summaries, 'w') as f:
#         json.dump(summaries, f)
#     
# 
# save_interval = 10
# 
# # Processing images with progress tracking
# for i, image in tqdm(enumerate(image_elements, start=len(image_summaries)), 
#                      total=len(image_elements), desc="Processing images"):
#     try:
#         image_summary = summarize_image(image)
#         image_summaries.append(image_summary)
#         
#         # Auto-save after a specified interval
#         if (i + 1) % save_interval == 0:
#             save_progress(image_summaries)
#             print(f"Progress saved at image {i + 1}")
# 
#     except Exception as e:
#         print(f"Error processing image {i + 1}: {e}")
# 
# # Final save after the loop ends
# save_progress(image_summaries)
# print("Final progress saved.")


## Creating vector database


In [8]:
# Initialize the vector store and storage layer
# Metadata key under which each embedded document records the id of its docstore entry.
id_key = "article_id"
# Docstore for the original documents (text chunks and encoded images).
store = InMemoryStore()
# Vector index over the embedded texts (chunks and image summaries).
vectorstore = Chroma(
    collection_name="sources",
    embedding_function=OpenAIEmbeddings(),
)

In [9]:
# Retriever combining the two stores: documents embedded in `vectorstore` carry
# metadata[id_key], and the indexing cells store the matching original document
# in `store` under that same id.
retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key=id_key)

In [10]:
# Add text chunks
# One fresh id per chunk links the embedded document to its docstore entry.
text_ids = [str(uuid.uuid4()) for _ in all_chunks]

# Documents that get embedded: chunk text plus the linking id.
text_docs = [
    Document(page_content=chunk, metadata={id_key: chunk_id})
    for chunk_id, chunk in zip(text_ids, all_chunks)
]

# Documents handed back by the retriever: chunk text plus its source article.
original_text_docs = [
    Document(
        page_content=chunk,
        metadata={"type": 'text', "title": title, "link": link},
    )
    for title, link, chunk in original_data_chunks
]

# Add the documents to the vectorstore
retriever.vectorstore.add_documents(text_docs)
# Store the original chunks (document content) in the docstore
retriever.docstore.mset(list(zip(text_ids, original_text_docs)))

In [14]:
#Add images
# Summaries come from the on-disk cache while the images were downloaded in this
# session. They are paired by position, so a length mismatch would either raise
# IndexError or silently attach the wrong summary to an image.
if len(image_summaries) != len(image_elements):
    raise ValueError(
        f"Expected one summary per image, got {len(image_summaries)} summaries "
        f"for {len(image_elements)} images"
    )

# One fresh id per image links the embedded summary to the stored image.
image_ids = [str(uuid.uuid4()) for _ in image_summaries]

# Documents that get embedded: the textual summary plus the linking id.
image_docs = [
    Document(
        page_content=summary,  # Image summary or description
        metadata={id_key: image_id},
    )
    for image_id, summary in zip(image_ids, image_summaries)
]

# Documents handed back by the retriever: the base64 image and its source URL.
original_image_docs = [
    Document(
        page_content=image,
        metadata={"type": 'image', 'link': link},
    )
    for image, link in zip(image_elements, image_links)
]

# Add the documents to the vectorstore
retriever.vectorstore.add_documents(image_docs)
# Store the original images in the docstore
retriever.docstore.mset(list(zip(image_ids, original_image_docs)))

## Implementing RAG

In [97]:
def prepare_context(retrieved_docs: list, max_chars: int = 128000) -> str:
    """Build the prompt context from the text documents among the retrieved ones.

    Args:
        retrieved_docs: Retrieved documents; each exposes ``page_content`` and a
            ``metadata`` dict. Only those with ``metadata['type'] == 'text'``
            contribute (image documents are skipped).
        max_chars: Maximum length of the returned context, in characters.

    Returns:
        The text documents' contents, each followed by a blank line, truncated
        to ``max_chars`` characters.
    """
    context = ''.join(
        doc.page_content + '\n\n'
        for doc in retrieved_docs
        if doc.metadata.get('type') == 'text'
    )
    # Truncate to the same limit that triggers truncation; the previous code
    # checked against 128000 but cut down to 12800.
    return context[:max_chars]


In [159]:
# Prompt for the answering step. {context} receives the text assembled by
# prepare_context and {question} the raw user query. The quoted fallback sentence
# is what the model is told to reply when the context does not contain the answer,
# so it is user-facing text.
template = """ You are part of a system that answers questions and searches for relevant articles and images. 

Your task: Answer the question based only on the following context, which can include only text.\
           If you don't know, then answer: "It is better to look at retrieved articles or images."\n\nContext: 
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# temperature=0 for the answering model.
model = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# RAG pipeline: the query goes to the retriever, whose documents are turned into
# the context string, and is also passed through unchanged as the question.
chain = (
    {"context": retriever | RunnableLambda(prepare_context) , "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)


In [160]:
# Wrapper function for Gradio UI
def wrapper_func(query):
    """Answer a query and collect the retrieved sources for display.

    Args:
        query: The user's question.

    Returns:
        A tuple ``(response, text_output, images)``: the model's answer, a
        Markdown string listing the retrieved text chunks (title, link and
        citation, separated by horizontal rules), and the retrieved images as
        PIL images with duplicates removed.
    """
    # Retrieve once and reuse the documents for both the answer and the display.
    # Calling chain.invoke(query) would run the retriever a second time, costing
    # another embedding request and allowing the shown sources to differ from
    # the ones the answer was based on.
    retrieved_docs = retriever.invoke(query)
    answer_chain = prompt | model | StrOutputParser()
    response = answer_chain.invoke(
        {"context": prepare_context(retrieved_docs), "question": query}
    )

    text_results = []
    images = []
    seen_images = set()

    for doc in retrieved_docs:
        doc_type = doc.metadata.get('type')
        if doc_type == 'text':
            title = doc.metadata.get('title', 'Untitled')
            link = doc.metadata.get('link', 'No link available')
            content = doc.page_content
            text_results.append(f"**Title:** {title}\n\n**Link:** {str(link)}\n\n**Citation:** {content}")
        elif doc_type == 'image':
            # Deduplicate on the base64 text so repeated images are not decoded again.
            encoded_image = doc.page_content
            if encoded_image in seen_images:
                continue
            seen_images.add(encoded_image)
            image_data = base64.b64decode(encoded_image)
            images.append(Image.open(io.BytesIO(image_data)))

    text_output = "\n\n---\n\n".join(text_results)
    return response, text_output, images

In [170]:
# Building UI with Gradio
# Components are created in the order they should appear on the page.
with gr.Blocks(theme=gr.themes.Ocean()) as ui:
    gr.Markdown("# Multimodal RAG System Demo 🤖")
    
    # Input area: the question box and the button that triggers the query.
    query_input = gr.Textbox(label="Enter your question", lines=2, placeholder="Ask me something...")
    submit_btn = gr.Button("Submit")

    # Output area: one component per value returned by wrapper_func.
    answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)
    text_output = gr.Markdown(label="Retrieved Articles")
    image_output = gr.Gallery(label="Retrieved Images", elem_id="gallery")

    # wrapper_func returns (response, text_output, images), matched positionally
    # to the three output components listed here.
    submit_btn.click(
        fn=wrapper_func,
        inputs=query_input,
        outputs=[answer_output, text_output, image_output]
    )
    
# Start the local Gradio server for the demo.
ui.launch()

* Running on local URL:  http://127.0.0.1:7921

To create a public link, set `share=True` in `launch()`.




In [127]:
# Smoke test: run the RAG chain directly, without going through the UI.
query = 'What is O1-engineer ?'
answer = chain.invoke(query)
print(answer)

O1-engineer is a command-line tool that uses OpenAI’s API to assist developers with code generation, file management, project planning, and code review. It features an interactive console, conversation history management, and enhanced file and folder operations to help streamline development workflows. Additionally, O1-engineer can automate routine tasks and provide intelligent support throughout the development process.


In [93]:

# Inspect what the retriever returns for the query defined in the previous cell.
# `invoke` is the Runnable entry point that replaces the deprecated
# `get_relevant_documents`.
retrieved_docs = retriever.invoke(query)
print('Type: ', type(retrieved_docs))
print('len: ', len(retrieved_docs))