# Setup

In [3]:
import os
import torch
from groq import Groq
from typing import cast
from ultralytics import YOLO
from pymilvus import MilvusClient
from transformers.utils.import_utils import is_flash_attn_2_available
from retriever_class import MilvusColbertRetriever, MilvusBasicRetriever
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from sentence_transformers import SentenceTransformer

# Define the device
device = "cuda:0"

# Define knowledge base sources and target image directory
document_source_dir = "document_sources"
img_dir = "image_database"
os.makedirs(img_dir, exist_ok=True) # Ensure the directory exists

# YOLO-12L-Doclaynet
yolo_model = YOLO("pretrained_models/yolo-doclaynet/yolov12l-doclaynet.pt")
yolo_model = yolo_model.to(device)

# ColQwen2.5-Colpali
colpali_model = ColQwen2_5.from_pretrained(
        "pretrained_models/colqwen2.5-v0.2",
        torch_dtype=torch.bfloat16,
        device_map=device,  # or "mps" if on Apple Silicon
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval()
colpali_processor = ColQwen2_5_Processor.from_pretrained("pretrained_models/colqwen2.5-v0.2")
processor = cast(
    ColPaliProcessor, 
    colpali_processor)

# Mxbai-embed-large-v1
embed_model = SentenceTransformer("pretrained_models/mxbai-embed-large-v1",device=device)

# Groq API-Llama4
os.environ["GROQ_API_KEY"] = "<your-api-key>"
client_groq = Groq()

# Milvus Client
client = MilvusClient("milvus_file.db")
colbert_retriever = MilvusColbertRetriever(collection_name="colbert", milvus_client=client,img_dir=img_dir)
basic_retriever = MilvusBasicRetriever(collection_name="basic", milvus_client=client)

# Define Entity Colors
ENTITIES_COLORS = {
    "Picture": (255, 72, 88),
    "Table": (128, 0, 128)
}

print("FINISH SETUP...")

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 5641.30it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
  from pkg_resources import DistributionNotFound, get_distribution


FINISH SETUP...


In [None]:
import os
from groq import Groq

# Groq API-Llama4
os.environ["GROQ_API_KEY"] = "<your-api-key>"
client_groq = Groq()

# Helper Function

In [4]:
def url_conversion(img_base64):
    return f"data:image/jpeg;base64,{img_base64}"

def llama4_inference(messages, token=1024):
    completion = client_groq.chat.completions.create(
        model="meta-llama/llama-4-maverick-17b-128e-instruct",
        messages=messages,
        temperature=0.1,
        max_completion_tokens=token,
        top_p=1,
        stream=True,
        stop=None,
    )
    inference_result = ""
    for chunk in completion:
        chunk_inference = chunk.choices[0].delta.content or ""
        inference_result += chunk_inference
    text = inference_result
    return text

# Chat Inference

### Define User Query

In [5]:
user_query = "I want to know the payout"

### Retrieve Relevant Context

In [6]:
batch_query = colpali_processor.process_queries([user_query]).to(device)
embeddings_query = torch.unbind(colpali_model(**batch_query).to("cpu"))[0].float().numpy()
colbert_retriever_result = colbert_retriever.search(embeddings_query, topk=3)
basic_retriever_result = basic_retriever.search(embed_model.encode(user_query), topk=3)

### Create System Instruction

In [7]:
system_instruction = """
You are a helpful assistant designed to answer user queries based on document-related content.

You will be provided with two types of context:
1. Text-based context — extracted textual content from documents.
2. Image-based context — visual content (e.g., figures, tables, or screenshots) extracted from documents.

Your tasks are:
- Analyze the user query and determine the appropriate response using the available context.
- Decide whether the answer requires information from the image-based context.

If the image context is necessary to answer the query:
- Set "need_image" to True.
- Set "image_index" to the appropriate index of the image used (e.g., 0 for the first image, 1 for the second, and so on).
- Include a clear explanation or reasoning in the response.

If the image context is **not** needed:
- Set "need_image" to False.
- Set "image_index" to -1.

All responses **must be returned in strict JSON format**:
{"response": <string>, "need_image": <true|false>, "image_index": <int>}

If you are unsure or cannot answer based on the given context, clearly state that you do not know.

Examples:
{"response": "The chart in image 1 shows the revenue trend.", "need_image": true, "image_index": 1}
{"response": "The policy details are outlined in the text section.", "need_image": false, "image_index": -1}
"""

### Create Content Payload

In [10]:
payload_content = [{
                    "type": "text",
                    "text": f"User Query: {user_query}"
                    }]
for i in range(len(colbert_retriever_result)):
    img_payload = {
        "type": "image_url",
        "image_url": {"url":url_conversion(colbert_retriever_result[i]["content"])}
    }
    payload_content.append(img_payload)
for i in range(len(basic_retriever_result)):
    txt_payload = {
        "type": "text",
        "text": f"Text-based Context #{i+1}:\n{basic_retriever_result[i]['content']}"
    }
    payload_content.append(txt_payload)
messages = [
    {
        "role": "system",
        "content": system_instruction
    },
    {
        "role": "user",
        "content": payload_content
    }
]

### Model Inference

In [11]:
import json
import re

chat_result = llama4_inference(messages)
chat_result = re.findall(r'\{[^{}]+\}', chat_result)
chat_result = json.loads(chat_result[-1])
chat_result

{'response': 'The payout varies based on the performance metric. For Relative TSR Percentile Rank, Delta ROCE, and FCF Margin, the payouts are illustrated in the provided graphs. For example, at a Relative TSR Percentile Rank of 60%, the payout is 60%; at a Delta ROCE of 0 bps, the payout is 100%; and at an FCF Margin of 10%, the payout is 100%.',
 'need_image': True,
 'image_index': 0}

### Structurize the output response

In [12]:
if chat_result["need_image"]:
    img_content = colbert_retriever_result[chat_result['image_index']]["content"]
else:
    img_content = ""

In [13]:
output_response = {
        "response":chat_result["response"],
        "need_image":chat_result["need_image"],
        "img_base64":img_content
        }
output_response

{'response': 'The payout varies based on the performance metric. For Relative TSR Percentile Rank, Delta ROCE, and FCF Margin, the payouts are illustrated in the provided graphs. For example, at a Relative TSR Percentile Rank of 60%, the payout is 60%; at a Delta ROCE of 0 bps, the payout is 100%; and at an FCF Margin of 10%, the payout is 100%.',
 'need_image': True,
 'img_base64': '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAOMBVsDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBU