[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/cir_with_milvus.ipynb)

# Composed Image Retrieval with Milvus 🖼️

This notebook shows how to combine Milvus with [MagicLens](https://open-vision-language.github.io/MagicLens/) for image search guided by user instructions. You upload an image and provide a text instruction; MagicLens's composed retrieval model encodes the pair into a single embedding, and Milvus uses that embedding to retrieve candidate images from the collection.

### Dependencies and Environment

In [None]:
!pip install --upgrade pymilvus numpy Pillow opencv-python datasets

### Prepare MagicLens Model

More detailed information can be found at <https://github.com/google-deepmind/magiclens>.

- Setup

1. Create the Conda environment:

In [None]:
!conda create --name magic_lens python=3.9 -y

2. Activate the environment. `conda activate` does not carry over between notebook `!` cells, so run this command and the subsequent setup commands manually in a terminal:

In [None]:
!conda activate magic_lens

3. Clone the repository and navigate to the directory:

In [None]:
!git clone https://github.com/google-research/scenic.git
%cd scenic

4. Install the dependencies:

In [None]:
!pip install .
!pip install -r scenic/projects/baselines/clip/requirements.txt

You may need to install the GPU version of JAX that matches your CUDA version, following <https://jax.readthedocs.io/en/latest/installation.html>.

In [None]:
# CUDA 12 installation
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# CUDA 11 installation
!pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

- Model Download

In [None]:
# Return to the main folder of the demo.
%cd ..
# You may need to run `gcloud auth login` for access; any Gmail account should work.
!gsutil cp -R gs://gresearch/magiclens/models ./

Download the [categories.txt](https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/apps/cir_with_milvus/categories.txt) file and keep it in the same directory as this notebook.

### Build Retriever

In [None]:
from typing import Dict
from magiclens.model import MagicLens
import jax
import jax.numpy as jnp
import pickle
from flax import serialization
from scenic.projects.baselines.clip import tokenizer as clip_tokenizer
from PIL import Image
import numpy as np

jax.config.update("jax_platform_name", "gpu")


def load_model(model_size: str, model_path: str) -> tuple:
    """Load and initialize the model."""
    model = MagicLens(model_size)
    rng = jax.random.PRNGKey(0)
    dummy_input = {
        "ids": jnp.ones((1, 1, 77), dtype=jnp.int32),
        "image": jnp.ones((1, 224, 224, 3), dtype=jnp.float32),
    }
    params = model.init(rng, dummy_input)
    print("Model initialized")

    with open(model_path, "rb") as f:
        model_bytes = pickle.load(f)
    params = serialization.from_bytes(params, model_bytes)
    print("Model loaded")
    return model, params


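# Path to the checkpoint copied by the gsutil command above; adjust it if you
# downloaded the models to a different location.
model, model_params = load_model(
    "large", "./models/magic_lens_clip_large.pkl"
)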


@jax.jit
def apply_model(params, image, ids):
    return model.apply(params, {"ids": ids, "image": image})


def process_img(image_path: str, size: int) -> np.ndarray:
    """Process a single image to the desired size and normalize."""
    img = Image.open(image_path).convert("RGB")
    img = img.resize((size, size), Image.BILINEAR)
    img = np.array(img) / 255.0  # Normalize to [0, 1]
    img = img[np.newaxis, ...]  # Add batch dimension
    return img


class Retriever:
    def __init__(self):
        self.model = model
        self.tokenizer = clip_tokenizer.build_tokenizer()
        self.model_params = model_params

    def encode_query(self, img_path, text):
        img = process_img(img_path, 224)
        tokens = self.tokenizer(text)
        res = apply_model(self.model_params, img, tokens)
        return np.array(res["multimodal_embed"])


retriever = Retriever()
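
The preprocessing step can be checked without loading the model. The following is a minimal sketch, assuming a synthetic in-memory image in place of a file on disk; `preprocess_pil` is a hypothetical helper that mirrors the steps of `process_img` and is not part of MagicLens or Scenic.

```python
import numpy as np
from PIL import Image


def preprocess_pil(img: Image.Image, size: int = 224) -> np.ndarray:
    """Hypothetical helper: same steps as process_img, for an in-memory image."""
    img = img.convert("RGB").resize((size, size), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0  # scale pixel values to [0, 1]
    return arr[np.newaxis, ...]  # add a leading batch dimension


# A synthetic 640x480 solid-colour image stands in for a real photo.
demo = Image.new("RGB", (640, 480), color=(255, 128, 0))
batch = preprocess_pil(demo)
print(batch.shape, batch.dtype)
```

The resulting array has the same layout as the `image` entry of the dummy input used in `load_model`: one batch item, 224 by 224 pixels, three channels.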

### Prepare Data and Create Collection

We use a subset of <https://github.com/hyp1231/AmazonReviews2023> that includes approximately 5000 images in 33 categories, such as appliances, beauty and personal care, clothing, and sports and outdoors.

Create a collection, then download, encode, and insert the images so that they can be searched.

In [None]:
from datasets import load_dataset
from pymilvus import MilvusClient
import os
import json

# Define the image folder
image_folder = "/home/data3/david/magiclens/data/images/{}"

# Initialize the encoder
encoder = Retriever()

# Initialize Milvus client
client = MilvusClient("./milvus_demo.db")
client.create_collection(
    collection_name="cir_demo_large",
    overwrite=True,
    auto_id=True,
    dimension=768,
    enable_dynamic_field=True,
)

# Read the categories from the file
with open("categories.txt") as fw:
    lines = fw.readlines()

# Loop through each category, download images, and insert data into Milvus
for line in lines:
    category = line.strip()
    meta_dataset = load_dataset(
        "McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{category}", split="full"
    )

    for i in range(100):
        if len(meta_dataset[i]["images"]["large"]) > 0:
            img_url = meta_dataset[i]["images"]["large"][0]
            img_name = os.path.basename(img_url)
            img_path = image_folder.format(img_name)

            # Download the image
            os.system(
                f"wget {img_url} -P {os.path.dirname(img_path)} --no-check-certificate"
            )

            if os.path.exists(img_path):
                # Encode the image
                feat = encoder.encode_query(img_path, "")
                # Create the metadata spec
                spec = json.dumps(meta_dataset[i])
                # Insert the data into Milvus
                res = client.insert(
                    collection_name="cir_demo_large",
                    data={
                        "vector": np.array(feat.flatten()),
                        "spec": spec,
                        "name": f"{category}_{i}",
                    },
                )
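
Because the collection is created with `enable_dynamic_field=True`, each inserted row can carry extra keys beyond `vector`. The sketch below shows one way to assemble such a row. `build_record` is a hypothetical helper, the `filename` key is an assumed extra field that the loop above does not store, and the values are toy data.

```python
import json
from typing import Any, Dict, List


def build_record(
    vector: List[float],
    meta: Dict[str, Any],
    category: str,
    index: int,
    img_path: str,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble one row in the shape passed to client.insert()."""
    return {
        "vector": vector,
        "spec": json.dumps(meta),  # item metadata serialized as a JSON string
        "name": f"{category}_{index}",
        "filename": img_path,  # assumed extra dynamic field, not in the original loop
    }


# Toy values stand in for a real embedding and a real dataset row.
record = build_record([0.0] * 768, {"title": "demo item"}, "Appliances", 0, "images/demo.jpg")
print(sorted(record))
```

Keeping the row construction in one function makes it easier to collect several rows in a list and insert them together, if that suits your workload.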

### Upload Query Image

In [3]:
import io
from PIL import Image
from IPython.display import display
import ipywidgets as widgets

# Create an upload widget
upload_widget = widgets.FileUpload(
    accept="image/*",  # Accept only image files
    multiple=False,  # Accept only a single file
)
display(upload_widget)

# Create a display output area
display_output = widgets.Output()
display(display_output)

uploaded_image = None


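def on_upload_change(change):
    global uploaded_image
    # Clear previous output
    display_output.clear_output()

    with display_output:
        for uploaded_file in upload_widget.value:
            content = uploaded_file["content"]
            uploaded_image = Image.open(io.BytesIO(content))
            # Save a copy to disk; the search step reads the query image from "temp.jpg"
            uploaded_image.convert("RGB").save("temp.jpg")
            # Display the image
            display(uploaded_image)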


# Attach the handler to the upload widget
upload_widget.observe(on_upload_change, names="value")


### Enter Text Instruction

In [9]:
text = input("Enter your instruction: ")
print(text)

an earphone with the theme of the image


### Encode Query and Run the Search

In [None]:
emb = retriever.encode_query("temp.jpg", text)

search_results = client.search(
    collection_name="cir_demo_large",
    data=[emb.flatten()],
    output_fields=["spec"],
    limit=100,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)
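
For intuition about the `COSINE` metric, the sketch below ranks a few toy vectors against a query by brute force with NumPy. It illustrates only the scoring rule, not how Milvus performs the search internally; `cosine_top_k` and the 2-dimensional vectors are made up for the example.

```python
import numpy as np


def cosine_top_k(query: np.ndarray, candidates: np.ndarray, k: int):
    """Brute-force cosine search: returns (indices, scores), best match first."""
    q = query / np.linalg.norm(query)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    scores = c @ q  # cosine similarity of each candidate with the query
    order = np.argsort(-scores)[:k]  # indices of the k highest scores
    return order, scores[order]


candidates = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = np.array([2.0, 0.0])
idx, scores = cosine_top_k(query, candidates, k=2)
print(idx, scores)
```

Cosine similarity ignores vector length, so the query `[2.0, 0.0]` matches the candidate `[1.0, 0.0]` with a score of 1.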

### Display Search Results
The top 25 retrieved images are displayed in a 5x5 grid.

In [None]:
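images = []
path = "/home/data3/david/magiclens/data/images"
for result in search_results:
    for hit in result[:25]:
        # Only "spec" is stored and returned, so rebuild the local file path
        # from the image URL recorded in the item metadata.
        spec = json.loads(hit["entity"]["spec"])
        filename = os.path.join(path, os.path.basename(spec["images"]["large"][0]))
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)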

width = 150 * 5
height = 150 * 5
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("results")
display(concatenated_image)
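
The grid placement relies on integer arithmetic: the column is `idx % 5` and the row is `idx // 5`. The sketch below isolates that mapping; `grid_offsets` is a hypothetical helper written for this illustration.

```python
def grid_offsets(n_images: int, cols: int = 5, tile: int = 150):
    """Return the (x, y) paste position of each tile, filled row by row."""
    return [((i % cols) * tile, (i // cols) * tile) for i in range(n_images)]


offsets = grid_offsets(7)
print(offsets)
```

With seven images, the first five fill the top row and the sixth wraps to the start of the second row at `(0, 150)`.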

### Quick Deploy

To learn how to start an online demo based on this tutorial, refer to [the example application](https://github.com/milvus-io/bootcamp/tree/master/bootcamp/tutorials/quickstart/apps/cir_with_milvus).

<img src="https://raw.githubusercontent.com/milvus-io/bootcamp/master/bootcamp/tutorials/quickstart/apps/cir_with_milvus/pics/cir_demo.jpg"/>