[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/multimodal_rag_with_milvus.ipynb)

# Multimodal RAG with Milvus 🖼️

This notebook shows how to combine Milvus with [MagicLens](https://open-vision-language.github.io/MagicLens/) for instruction-guided image search. A user uploads an image and enters an edit instruction; MagicLens's composed retrieval model encodes both into a single query embedding, and Milvus retrieves the candidate images closest to it. GPT-4o then acts as a reranker, selecting the most suitable image and explaining the rationale behind the choice.

### Install Dependencies

In [None]:
!pip install --upgrade pymilvus Pillow numpy opencv-python datasets openai

> If you are using Google Colab, you may need to **restart the runtime** to enable the dependencies you just installed (click on the "Runtime" menu at the top of the screen, and select "Restart session" from the dropdown menu).

### Set up Environment Variables

We will use OpenAI as the LLM in this example. You should prepare the [API key](https://platform.openai.com/docs/quickstart) and set it as the environment variable `OPENAI_API_KEY`.

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "sk-***********"
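
Hard-coding the key in a notebook makes it easy to leak when the notebook is shared. As a minimal sketch of an alternative, the helper below reads the variable if it is already set and prompts for it otherwise. The name `get_openai_key` is hypothetical and not part of the original tutorial.

```python
import os
from getpass import getpass


def get_openai_key(env_var: str = "OPENAI_API_KEY") -> str:
    """Return the API key from the environment, prompting only if it is missing.

    `get_openai_key` is a hypothetical helper used for illustration.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        # getpass hides the typed value, so the key is not echoed into the cell output
        key = getpass(f"Enter {env_var}: ").strip()
        os.environ[env_var] = key
    return key
```

Calling `get_openai_key()` once before building the reranker leaves `OPENAI_API_KEY` set for the rest of the session.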

### Prepare MagicLens Model

More detailed information can be found at <https://github.com/google-deepmind/magiclens>.

- Setup

1. Create the Conda environment:

In [None]:
!conda create --name magic_lens python=3.9 -y

2. Activate the environment. `conda activate` does not persist across notebook shell commands, so run this and the subsequent commands manually in a terminal:

In [None]:
!conda activate magic_lens

3. Clone the repository and navigate to the directory:

In [None]:
!git clone https://github.com/google-research/scenic.git
%cd scenic

4. Install the dependencies:

In [None]:
!pip install .
!pip install -r scenic/projects/baselines/clip/requirements.txt

You may need to install the GPU build of JAX that matches your CUDA version, following <https://jax.readthedocs.io/en/latest/installation.html>.

In [None]:
# CUDA 12 installation
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# CUDA 11 installation
!pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

- Model Download

In [None]:
# Return to the main folder of the demo.
%cd ..
# You may need to run `gcloud auth login` for access; any Gmail account should work.
!gsutil cp -R gs://gresearch/magiclens/models ./

Download the [categories.txt](https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/apps/cir_with_milvus/categories.txt) file and keep it in the same directory as this notebook.
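
The file is read later with one category per line, and each category is turned into a dataset configuration name of the form `raw_meta_{category}`. A small sketch of that step, assuming hypothetical helper names `read_categories` and `dataset_config_name`, also skips blank lines so that a trailing newline does not produce an empty category:

```python
from pathlib import Path


def read_categories(path: str = "categories.txt") -> list:
    """Return the non-empty, stripped category names, one per line of the file."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def dataset_config_name(category: str) -> str:
    """Build the configuration name in the same way as the ingestion loop in this notebook."""
    return f"raw_meta_{category}"
```

A `FileNotFoundError` from `read_categories()` is a quick signal that the file was not placed next to the notebook.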

### Build Retriever

In [None]:
from typing import Dict
from magiclens.model import MagicLens
import jax
import jax.numpy as jnp
import pickle
from flax import serialization
from scenic.projects.baselines.clip import tokenizer as clip_tokenizer
from PIL import Image
import numpy as np

jax.config.update("jax_platform_name", "gpu")


def load_model(model_size: str, model_path: str) -> tuple:
    """Load and initialize the model; returns (model, params)."""
    model = MagicLens(model_size)
    rng = jax.random.PRNGKey(0)
    dummy_input = {
        "ids": jnp.ones((1, 1, 77), dtype=jnp.int32),
        "image": jnp.ones((1, 224, 224, 3), dtype=jnp.float32),
    }
    params = model.init(rng, dummy_input)
    print("Model initialized")

    with open(model_path, "rb") as f:
        model_bytes = pickle.load(f)
    params = serialization.from_bytes(params, model_bytes)
    print("Model loaded")
    return model, params


model, model_params = load_model(
    "large", "/home/data3/david/magiclens/models/magic_lens_clip_large.pkl"
)


@jax.jit
def apply_model(params, image, ids):
    return model.apply(params, {"ids": ids, "image": image})


def process_img(image_path: str, size: int) -> np.ndarray:
    """Process a single image to the desired size and normalize."""
    img = Image.open(image_path).convert("RGB")
    img = img.resize((size, size), Image.BILINEAR)
    img = np.array(img) / 255.0  # Normalize to [0, 1]
    img = img[np.newaxis, ...]  # Add batch dimension
    return img


class Retriever:
    def __init__(self):
        self.model = model
        self.tokenizer = clip_tokenizer.build_tokenizer()
        self.model_params = model_params

    def encode_query(self, img_path, text):
        img = process_img(img_path, 224)
        tokens = self.tokenizer(text)
        res = apply_model(self.model_params, img, tokens)
        return np.array(res["multimodal_embed"])


retriever = Retriever()

### Prepare Data and Create Collection

We are using a subset of <https://github.com/hyp1231/AmazonReviews2023> which includes approximately 5000 images in 33 different categories, such as appliances, beauty and personal care, clothing, sports and outdoors, etc. <br>
Create a collection and load the image data from the dataset to build the knowledge base.

In [None]:
from datasets import load_dataset
from pymilvus import MilvusClient
import os
import json

# Define the image folder
image_folder = "/home/data3/david/magiclens/data/images/{}"

# Initialize the encoder
encoder = Retriever()

# Initialize Milvus client
client = MilvusClient("./milvus_demo.db")
client.create_collection(
    collection_name="cir_demo_large",
    overwrite=True,
    auto_id=True,
    dimension=768,
    enable_dynamic_field=True,
)

# Read the categories from the file
with open("categories.txt") as f:
    # Skip blank lines so they do not produce an empty category name
    lines = [line for line in f if line.strip()]

# Loop through each category, download images, and insert data into Milvus
for line in lines:
    category = line.strip()
    meta_dataset = load_dataset(
        "McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{category}", split="full"
    )

    for i in range(100):
        if len(meta_dataset[i]["images"]["large"]) > 0:
            img_url = meta_dataset[i]["images"]["large"][0]
            img_name = os.path.basename(img_url)
            img_path = image_folder.format(img_name)

            # Download the image
            os.system(
                f"wget {img_url} -P {os.path.dirname(img_path)} --no-check-certificate"
            )

            if os.path.exists(img_path):
                # Encode the image
                feat = encoder.encode_query(img_path, "")
                # Create the metadata spec
                spec = json.dumps(meta_dataset[i])
                # Insert the data into Milvus
                res = client.insert(
                    collection_name="cir_demo_large",
                    data={
                        "vector": np.array(feat.flatten()),
                        "spec": spec,
                        "name": f"{category}_{i}",
                    },
                )
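
The loop above shells out to `wget` and then derives the local file name from the URL's basename. As a rough sketch of the same idea using only the standard library, the helpers below map a URL to its local path and download the file only when it is missing. The names `local_image_path` and `download_image` are hypothetical, and unlike the `wget --no-check-certificate` call, `urllib` verifies TLS certificates by default.

```python
import os
import urllib.request
from urllib.parse import urlparse


def local_image_path(img_url: str, image_dir: str) -> str:
    """Map an image URL to the local path it is stored under (same basename)."""
    return os.path.join(image_dir, os.path.basename(urlparse(img_url).path))


def download_image(img_url: str, image_dir: str, timeout: float = 10.0) -> str:
    """Download img_url into image_dir unless it is already there; return the local path."""
    os.makedirs(image_dir, exist_ok=True)
    path = local_image_path(img_url, image_dir)
    if not os.path.exists(path):
        with urllib.request.urlopen(img_url, timeout=timeout) as resp:
            data = resp.read()
        with open(path, "wb") as out:
            out.write(data)
    return path
```

Because the local path depends only on the URL, the same mapping can be reused at search time to find the file that belongs to a hit's stored `spec`.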

### Upload Query Image

In [3]:
import io
from PIL import Image
from IPython.display import display
import ipywidgets as widgets

# Create an upload widget
upload_widget = widgets.FileUpload(
    accept="image/*",  # Accept only image files
    multiple=False,  # Accept only a single file
)
display(upload_widget)

# Create a display output area
display_output = widgets.Output()
display(display_output)

uploaded_image = None


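def on_upload_change(change):
    # Rebind the module-level variable so later cells can use the uploaded image
    global uploaded_image
    # Clear previous output
    display_output.clear_output()

    with display_output:
        for uploaded_file in upload_widget.value:
            content = uploaded_file["content"]
            uploaded_image = Image.open(io.BytesIO(content))
            # Save a copy for the encoding step, which reads "temp.jpg"
            uploaded_image.convert("RGB").save("temp.jpg")
            # Display the image
            display(uploaded_image)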


# Attach the handler to the upload widget
upload_widget.observe(on_upload_change, names="value")

FileUpload(value=(), accept='image/*', description='Upload')

Output()

### Enter Text Instruction

In [9]:
text = input("Enter your instruction: ")
print(text)

an earphone with the theme of the image


### Encode Query and Run the Search

In [None]:
emb = retriever.encode_query("temp.jpg", text)

search_results = client.search(
    collection_name="cir_demo_large",
    data=[emb.flatten()],
    output_fields=["spec"],
    limit=100,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)
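
The `COSINE` metric ranks candidates by the cosine of the angle between the query embedding and each stored vector, so only direction matters, not magnitude. Milvus computes this internally; the plain-Python function below is only an illustration of the formula, and its name is our own.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors, in [-1, 1]."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)
```

Scaling either vector leaves the score unchanged, which is why the embeddings do not need to be normalized before insertion when this metric is used.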

### Display Search Results
The top 25 retrieved images are displayed in a 5x5 grid.

In [None]:
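images = []
path = "/home/data3/david/magiclens/data/images"
for result in search_results:
    for hit in result[:25]:
        # The collection stores no "filename" field; rebuild the local path
        # from the image URL in the stored `spec`, as during ingestion
        spec = json.loads(hit["entity"]["spec"])
        filename = os.path.join(path, os.path.basename(spec["images"]["large"][0]))
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)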

width = 150 * 5
height = 150 * 5
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("results")
display(concatenated_image)
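
The grid placement above is row-major: the column is `idx % 5` and the row is `idx // 5`, each multiplied by the tile size. A tiny sketch of that arithmetic, with a hypothetical helper name `grid_origin`:

```python
def grid_origin(idx: int, cols: int = 5, tile: int = 150) -> tuple:
    """Top-left pixel (x, y) of tile `idx` in a row-major grid of square tiles."""
    if idx < 0:
        raise ValueError("idx must be non-negative")
    return (idx % cols) * tile, (idx // cols) * tile
```

For example, `grid_origin(7)` gives `(300, 150)`, the third tile of the second row. If the search returns fewer than 25 hits, the remaining tiles keep the black background that `Image.new("RGB", ...)` starts with.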

### Build Generative Reranker

This reranker picks the most suitable image for the combined image and text query. It places the query image next to a 5x5 grid of the top 25 retrieved images to form a single panoramic view, then asks an LLM to rank the candidates.

In [None]:
import json
import os
import cv2
import numpy as np
import base64
import requests


class GenerativeReranker:

    def __init__(
        self,
        rowCount: int = 5,
        dim: tuple = (300, 300),
        cache_file: str = "cache.json",
    ) -> None:
        self.rowCount = rowCount
        self.dim = dim
        self.combined_image = None
        self.cache_file = cache_file
        self.cache = self.load_cache()
        self.api_key = os.getenv("OPENAI_API_KEY")

    def __call__(self, query, caption, images, infos) -> list[int]:
        cache_key = self.generate_cache_key(caption, infos)
        if cache_key in self.cache:
            ranked_indices, _ = self.cache[cache_key]
        else:
            self.combined_image = self.create_panoramic_view([query] + images)
            cv2.imwrite("combined_image.jpg", self.combined_image)
            # Only the OpenAI chat completions path is implemented in this notebook
            ranked_indices, explanation = self.generate_ranking_explanation(
                caption, infos
            )

            self.cache[cache_key] = (ranked_indices, explanation)
            self.save_cache()
        return list(ranked_indices)

    def get_best_item(self, query, caption, images, infos) -> int:
        """
        returns:
            index of the most matched image
        """
        ranked_indices = self.__call__(query, caption, images, infos)
        return ranked_indices[0]

    def explain(self, query, caption, images, infos) -> str:
        """
        provides an explanation of why the best item is chosen based on the query, caption, images, and infos
        """
        cache_key = self.generate_cache_key(caption, infos)
        if cache_key in self.cache:
            _, explanation = self.cache[cache_key]
        else:
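            if self.combined_image is None:
                self.combined_image = self.create_panoramic_view([query] + images)
                # generate_ranking_explanation reads the image back from this file
                cv2.imwrite("combined_image.jpg", self.combined_image)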
            ranked_indices, explanation = self.generate_ranking_explanation(
                caption, infos
            )
            self.cache[cache_key] = (ranked_indices, explanation)
            self.save_cache()

        return explanation

    def create_panoramic_view(self, images: list) -> np.ndarray:
        """
        creates a 5x5 panoramic view image from a list of images
        """
        img_width, img_height = self.dim
        panoramic_width = img_width * self.rowCount
        panoramic_height = img_height * self.rowCount
        panoramic_image = np.full(
            (panoramic_height, panoramic_width, 3), 255, dtype=np.uint8
        )

        # create and resize the query image with a blue border
        query_image = np.full((panoramic_height, img_width, 3), 255, dtype=np.uint8)
        resized_image = cv2.resize(images[0], (img_width, img_height))

        border_size = 10
        blue = (255, 0, 0)  # blue color in BGR
        bordered_query_image = cv2.copyMakeBorder(
            resized_image,
            border_size,
            border_size,
            border_size,
            border_size,
            cv2.BORDER_CONSTANT,
            value=blue,
        )

        query_image[img_height * 2 : img_height * 3, 0:img_width] = cv2.resize(
            bordered_query_image, (img_width, img_height)
        )

        # add text "query" below the query image
        text = "query"
        font_scale = 1
        font_thickness = 2
        text_org = (10, img_height * 3 + 30)
        cv2.putText(
            query_image,
            text,
            text_org,
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            blue,
            font_thickness,
            cv2.LINE_AA,
        )

        # combine the rest of the images into the panoramic view
        for i, image in enumerate(images[1:]):
            image = cv2.resize(image, (img_width - 4, img_height - 4))
            row = i // self.rowCount
            col = i % self.rowCount
            start_row = row * img_height
            start_col = col * img_width

            border_size = 2
            bordered_image = cv2.copyMakeBorder(
                image,
                border_size,
                border_size,
                border_size,
                border_size,
                cv2.BORDER_CONSTANT,
                value=(0, 0, 0),
            )
            panoramic_image[
                start_row : start_row + img_height, start_col : start_col + img_width
            ] = bordered_image

            # add red index numbers to each image
            text = str(i)
            org = (start_col + 50, start_row + 30)
            (font_width, font_height), baseline = cv2.getTextSize(
                text, cv2.FONT_HERSHEY_SIMPLEX, 1, 2
            )

            top_left = (org[0] - 48, start_row + 2)
            bottom_right = (org[0] - 48 + font_width + 5, org[1] + baseline + 5)

            cv2.rectangle(
                panoramic_image, top_left, bottom_right, (255, 255, 255), cv2.FILLED
            )
            cv2.putText(
                panoramic_image,
                text,
                (start_col + 10, start_row + 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 255),
                2,
                cv2.LINE_AA,
            )

        # combine the query image with the panoramic view
        panoramic_image = np.hstack([query_image, panoramic_image])
        return panoramic_image

    def generate_ranking_explanation(
        self, caption: str, infos: dict
    ) -> tuple[list[int], str]:
        """
        uses an LLM to rank images and generate an explanation based on the combined panoramic view image, caption, and infos
        """
        base64_image = self.encode_image("combined_image.jpg")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        information = (
            "You are responsible for ranking results for a Composed Image Retrieval. "
            "The user retrieves an image with an 'instruction' indicating their retrieval intent. "
            "For example, if the user queries a red car with the instruction 'change this car to blue,' a similar type of car in blue would be ranked higher in the results. "
            "Now you would receive instruction and query image with blue border. Every item has its red index number in its top left. Do not misunderstand it. "
            f"User instruction: {caption} \n\n"
        )

        # add additional information for each image
        for i, info in enumerate(infos["product"]):
            information += f"{i}. {info}\n"

        information += (
            "Provide a new ranked list of indices from most suitable to least suitable, followed by an explanation for the top 1 most suitable item only. "
            "The format of the response has to be 'Ranked list: []' with the indices in brackets as integers, followed by 'Reasons:' plus the explanation why this most fit user's query intent."
        )

        payload = {
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": information},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        },
                    ],
                }
            ],
            "max_tokens": 300,
        }

        response = requests.post(
            "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
        )
        result = response.json()["choices"][0]["message"]["content"]

        # parse the ranked indices from the response
        start_idx = result.find("[")
        end_idx = result.find("]")
        ranked_indices_str = result[start_idx + 1 : end_idx].split(",")
        ranked_indices = [int(index.strip()) for index in ranked_indices_str]

        # extract explanation
        explanation = result[end_idx + 1 :].strip()

        return ranked_indices, explanation

    def encode_image(self, image_path):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def load_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "r") as file:
                return json.load(file)
        return {}

    def save_cache(self):
        with open(self.cache_file, "w") as file:
            json.dump(self.cache, file)

    def generate_cache_key(self, caption, infos):
        return f"{caption}_{json.dumps(infos, sort_keys=True)}"
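
The cache key above is the caption concatenated with the JSON dump of `infos`, which can get long because each entry carries the full product metadata. One possible variant, shown here only as a sketch, hashes that string into a fixed-length key. `hashed_cache_key` is a hypothetical name and is not used by the class above.

```python
import hashlib
import json


def hashed_cache_key(caption: str, infos: dict) -> str:
    """Return a 64-character SHA-256 digest of the caption plus the sorted JSON of infos."""
    # sort_keys=True makes the dump independent of dict insertion order
    raw = f"{caption}_{json.dumps(infos, sort_keys=True)}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
```

The trade-off is that hashed keys are no longer human-readable when inspecting `cache.json`.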

### Display the Combined Image of Top 25 Results

In [None]:
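reranker = GenerativeReranker()

# Reuse the 25 PIL images loaded for the grid above, converted to the BGR arrays OpenCV expects
top_images = [cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR) for img in images]
# create_panoramic_view expects the query image first, followed by the candidates
query_array = cv2.cvtColor(np.array(uploaded_image.convert("RGB")), cv2.COLOR_RGB2BGR)

combined_image = reranker.create_panoramic_view([query_array] + top_images)
display(Image.fromarray(cv2.cvtColor(combined_image, cv2.COLOR_BGR2RGB)))

### Show Best Item Index and Explanation

In [None]:
top_infos = {"product": [], "instruction": ""}

# Collect the stored metadata of the top 25 hits for the single query
for hit in search_results[0][:25]:
    top_infos["product"].append(hit["entity"])

top_infos["instruction"] = text

best_index = reranker.get_best_item(query_array, text, top_images, top_infos)

explanation = reranker.explain(query_array, text, top_images, top_infos)
print("Best index:", best_index)

# Display the best item and explanation
best_item_image = cv2.cvtColor(top_images[best_index], cv2.COLOR_BGR2RGB)
display(Image.fromarray(best_item_image))
print("Explanation:", explanation)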
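
The best index and explanation come from parsing the model's free-text reply, which the prompt asks to follow the shape `Ranked list: [...]` followed by `Reasons: ...`. Models do not always comply exactly, so a slightly more defensive parser can help. The sketch below is one possible approach, not the reranker's actual implementation: `parse_ranking` is a hypothetical name, and unlike the class above it drops duplicate or out-of-range indices and strips the leading `Reasons:` label.

```python
import re


def parse_ranking(reply: str, num_items: int) -> tuple:
    """Extract (ranked_indices, explanation) from a reply shaped like
    'Ranked list: [3, 0, 1] Reasons: ...'."""
    match = re.search(r"\[([^\]]*)\]", reply)
    if match is None:
        raise ValueError("no bracketed list found in the reply")

    ranked, seen = [], set()
    for token in re.findall(r"-?\d+", match.group(1)):
        idx = int(token)
        # Keep only the first occurrence of each valid index
        if 0 <= idx < num_items and idx not in seen:
            seen.add(idx)
            ranked.append(idx)
    if not ranked:
        raise ValueError("the reply contains no valid index")

    explanation = reply[match.end():].strip()
    explanation = re.sub(r"^Reasons:\s*", "", explanation)
    return ranked, explanation
```

With 25 candidates, `num_items` would be 25, so an index such as 99 in the reply is discarded rather than causing an `IndexError` when the best image is looked up.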

### Quick Deploy

To learn how to start an online demo with this tutorial, please refer to [the example application](https://github.com/milvus-io/bootcamp/tree/master/bootcamp/tutorials/quickstart/apps/multimodal_rag_with_milvus).

<img src="https://raw.githubusercontent.com/milvus-io/bootcamp/master/bootcamp/tutorials/quickstart/apps/multimodal_rag_with_milvus/pics/cir_demo.jpg"/>