# Multimodal RAG

Reference: [https://github.com/FlagOpen/FlagEmbedding/tree/master](https://github.com/FlagOpen/FlagEmbedding/tree/master)

## Data Preparation

---

In the project root (milvus-examples folder path), run the following:

```
$ just multi-modal-dataset
```

## Load Embedding Model & Build Encoder

---

In [1]:
import torch
from transformers import AutoModel

In [27]:
class Encoder:
    def __init__(self, model_name: str):
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True) # You must set trust_remote_code=True
        self.model.set_processor(model_name)
        self.model.eval()

    def encode_query(self, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(text=text)
        return query_emb.tolist()[0]

    def encode_image(self, image_path: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(images = [image_path])
        return query_emb.tolist()[0]

In [28]:
model_name = "BAAI/BGE-VL-base"
encoder = Encoder(model_name)

## Load Data

---

### Generate embeddings

In [4]:
import os
from tqdm import tqdm
from glob import glob

In [16]:
# Generate embeddings for the image dataset
data_dir = (
    "../../amazon_reviews_2023_subset/images_folder"  # Change to your own value if using a different data directory
)
image_list = glob(
    os.path.join(data_dir, "images", "*.jpg")
)  # We will only use images ending with ".jpg"
image_dict = {}

for image_path in tqdm(image_list, desc="Generating image embeddings: "):
    try:
        image_dict[image_path] = encoder.encode_image(image_path)
    except Exception as e:
        print(f"Failed to generate embedding for {image_path}. Skipped.")
        continue

print("Number of encoded images:", len(image_dict))

Generating image embeddings: 100%|██████████| 900/900 [00:45<00:00, 19.82it/s]

Number of encoded images: 900





### Insert into Milvus

In [17]:
from pymilvus import MilvusClient

In [18]:
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)

dim = len(list(image_dict.values())[0])
collection_name = "multimodal_rag_demo"

# Create Milvus Collection
# By default, the vector field name is "vector"
milvus_client.create_collection(
    collection_name=collection_name,
    auto_id=True,
    dimension=dim,
    enable_dynamic_field=True,
)

# Insert data into collection
milvus_client.insert(
    collection_name=collection_name,
    data=[{"image_path": k, "vector": v} for k, v in image_dict.items()],
)

{'insert_count': 900, 'ids': [459664593296574549, 459664593296574550, 459664593296574551, 459664593296574552, 459664593296574553, 459664593296574554, 459664593296574555, 459664593296574556, 459664593296574557, 459664593296574558, 459664593296574559, 459664593296574560, 459664593296574561, 459664593296574562, 459664593296574563, 459664593296574564, 459664593296574565, 459664593296574566, 459664593296574567, 459664593296574568, 459664593296574569, 459664593296574570, 459664593296574571, 459664593296574572, 459664593296574573, 459664593296574574, 459664593296574575, 459664593296574576, 459664593296574577, 459664593296574578, 459664593296574579, 459664593296574580, 459664593296574581, 459664593296574582, 459664593296574583, 459664593296574584, 459664593296574585, 459664593296574586, 459664593296574587, 459664593296574588, 459664593296574589, 459664593296574590, 459664593296574591, 459664593296574592, 459664593296574593, 459664593296574594, 459664593296574595, 459664593296574596, 4596645932

## Multimodal Search with Generative Reranker

---

### Run search

In [30]:
query_image = os.path.join(
    "../../amazon_reviews_2023_subset/images_folder", "leopard.jpg"
)  # Change to your own query image path
query_text = "phone case with this image theme"

query_vec = encoder.encode_query(text=query_text)
print(query_vec)

[0.05479687452316284, 0.023166052997112274, 0.0167881827801466, -0.04697072133421898, 0.011066324077546597, -0.02828969806432724, 0.06310860812664032, -0.018622776493430138, -0.006653296295553446, 0.04232127591967583, -0.010570806451141834, -0.0627618357539177, 0.00604519248008728, -0.03184321150183678, 0.019238580018281937, 0.025734692811965942, -0.008408325724303722, -0.017258351668715477, -0.04112744331359863, 0.022215425968170166, 0.024058213457465172, -0.09258005023002625, -0.02045830339193344, -0.02199678122997284, 0.03615334630012512, 0.02627875842154026, 0.0031771014910191298, 0.007986736483871937, -0.03374848887324333, 0.024627046659588814, 0.03626956790685654, 0.038993291556835175, -0.013830902986228466, 0.016696704551577568, -0.02896762639284134, 0.0034529822878539562, 0.01837325096130371, -0.001882597105577588, -0.07863987237215042, 0.08217803388834, -0.03704100474715233, 0.02073206752538681, 0.0010560451773926616, 0.0017355632735416293, 0.01463480107486248, -0.025465548038

In [31]:
query_image = os.path.join(
    "../../amazon_reviews_2023_subset/images_folder", "leopard.jpg"
)  # Change to your own query image path
query_text = "phone case with this image theme"

query_vec = encoder.encode_query(text=query_text)

search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_vec],
    output_fields=["image_path"],
    limit=9,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)[0]

retrieved_images = [hit.get("entity").get("image_path") for hit in search_results]
print(retrieved_images)

['../../amazon_reviews_2023_subset/images_folder/images/51D1Zk43qzL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/419wqNrteJL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/41kb5Sk2PVL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/51cnOn0jyZL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/51Wqge9HySL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/51QZyaiYNPL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/51DHv2g4g9L._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/51RvUc3SJsL._AC_.jpg', '../../amazon_reviews_2023_subset/images_folder/images/41VP6KIgeJL._AC_.jpg']


### Rerank with GPT-4o

#### 1. Create a panoramic view

In [34]:
import numpy as np
import cv2
from PIL import Image

In [35]:
img_height = 300
img_width = 300
row_count = 3

In [36]:
def create_panoramic_view(query_image_path: str, retrieved_images: list) -> np.ndarray:
    """
    creates a 5x5 panoramic view image from a list of images

    args:
        images: list of images to be combined

    returns:
        np.ndarray: the panoramic view image
    """
    panoramic_width = img_width * row_count
    panoramic_height = img_height * row_count
    panoramic_image = np.full(
        (panoramic_height, panoramic_width, 3), 255, dtype=np.uint8
    )

    # create and resize the query image with a blue border
    query_image_null = np.full((panoramic_height, img_width, 3), 255, dtype=np.uint8)
    query_image = Image.open(query_image_path).convert("RGB")
    query_array = np.array(query_image)[:, :, ::-1]
    resized_image = cv2.resize(query_array, (img_width, img_height))

    border_size = 10
    blue = (255, 0, 0)  # blue color in BGR
    bordered_query_image = cv2.copyMakeBorder(
        resized_image,
        border_size,
        border_size,
        border_size,
        border_size,
        cv2.BORDER_CONSTANT,
        value=blue,
    )

    query_image_null[img_height * 2 : img_height * 3, 0:img_width] = cv2.resize(
        bordered_query_image, (img_width, img_height)
    )

    # add text "query" below the query image
    text = "query"
    font_scale = 1
    font_thickness = 2
    text_org = (10, img_height * 3 + 30)
    cv2.putText(
        query_image_null,
        text,
        text_org,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        blue,
        font_thickness,
        cv2.LINE_AA,
    )

    # combine the rest of the images into the panoramic view
    retrieved_imgs = [
        np.array(Image.open(img).convert("RGB"))[:, :, ::-1] for img in retrieved_images
    ]
    for i, image in enumerate(retrieved_imgs):
        image = cv2.resize(image, (img_width - 4, img_height - 4))
        row = i // row_count
        col = i % row_count
        start_row = row * img_height
        start_col = col * img_width

        border_size = 2
        bordered_image = cv2.copyMakeBorder(
            image,
            border_size,
            border_size,
            border_size,
            border_size,
            cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
        )
        panoramic_image[
        start_row : start_row + img_height, start_col : start_col + img_width
        ] = bordered_image

        # add red index numbers to each image
        text = str(i)
        org = (start_col + 50, start_row + 30)
        (font_width, font_height), baseline = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2
        )

        top_left = (org[0] - 48, start_row + 2)
        bottom_right = (org[0] - 48 + font_width + 5, org[1] + baseline + 5)

        cv2.rectangle(
            panoramic_image, top_left, bottom_right, (255, 255, 255), cv2.FILLED
        )
        cv2.putText(
            panoramic_image,
            text,
            (start_col + 10, start_row + 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2,
            cv2.LINE_AA,
        )

    # combine the query image with the panoramic view
    panoramic_image = np.hstack([query_image_null, panoramic_image])
    return panoramic_image

In [37]:
combined_image_path = os.path.join("../../amazon_reviews_2023_subset/images_folder", "combined_image.jpg")
panoramic_image = create_panoramic_view(query_image, retrieved_images)
cv2.imwrite(combined_image_path, panoramic_image)

combined_image = Image.open(combined_image_path)
show_combined_image = combined_image.resize((300, 300))
show_combined_image.show()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


#### 2. Rerank and explain

In [38]:
import requests
import base64

In [39]:
#openai_api_key = "sk-***"  # Change to your OpenAI API Key
openai_api_key = "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"

In [40]:
def generate_ranking_explanation(
    combined_image_path: str,
    caption: str,
    infos: dict = None
) -> tuple[list[int], str]:
    with open(combined_image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode("utf-8")

    information = (
        "You are responsible for ranking results for a Composed Image Retrieval. "
        "The user retrieves an image with an 'instruction' indicating their retrieval intent. "
        "For example, if the user queries a red car with the instruction 'change this car to blue,' a similar type of car in blue would be ranked higher in the results. "
        "Now you would receive instruction and query image with blue border. Every item has its red index number in its top left. Do not misunderstand it. "
        f"User instruction: {caption} \n\n"
    )

    # add additional information for each image
    if infos:
        for i, info in enumerate(infos["product"]):
            information += f"{i}. {info}\n"

    information += (
        "Provide a new ranked list of indices from most suitable to least suitable, followed by an explanation for the top 1 most suitable item only. "
        "The format of the response has to be 'Ranked list: []' with the indices in brackets as integers, followed by 'Reasons:' plus the explanation why this most fit user's query intent."
    )

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}",
    }

    payload = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": information},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload
    )
    result = response.json()["choices"][0]["message"]["content"]

    # parse the ranked indices from the response
    start_idx = result.find("[")
    end_idx = result.find("]")
    ranked_indices_str = result[start_idx + 1 : end_idx].split(",")
    ranked_indices = [int(index.strip()) for index in ranked_indices_str]

    # extract explanation
    explanation = result[end_idx + 1 :].strip()

    return ranked_indices, explanation

In [41]:
ranked_indices, explanation = generate_ranking_explanation(combined_image_path, query_text)

#### 3. Display the best result with explanation

In [43]:
print(explanation)

best_index = ranked_indices[0]
best_img = Image.open(retrieved_images[best_index])
best_img = best_img.resize((150, 150))
best_img.show()

Reasons: 
Item 0 is most suitable because the image theme of the leopard closely matches the query image given. A phone case with a similar theme of leopards fits the user's instruction effectively.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
