<a href="https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/multi_modal/multi_modal_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Modal Retrieval using GPT text embedding and CLIP image embedding for Wikipedia Articles

In this notebook, we show how to build a Multi-Modal retrieval system using LlamaIndex.

Wikipedia Text embedding index: Generate GPT text embeddings from OpenAI for texts

Wikipedia Images embedding index: [CLIP](https://github.com/openai/CLIP) embeddings from OpenAI for images


Query encoder:
* Encode the query text for the text index using the GPT embedding
* Encode the query text for the image index using the CLIP embedding

Framework: [LlamaIndex](https://github.com/run-llama/llama_index)

Steps:
1. Download texts and images raw files for Wikipedia articles
2. Build text index for vector store using GPT embeddings
3. Build image index for vector store using CLIP embeddings
4. Retrieve relevant text and image simultaneously using different query encoding embeddings and vector stores

In [None]:
%pip install llama_index ftfy regex tqdm
%pip install git+https://github.com/openai/CLIP.git
%pip install torch torchvision
%pip install matplotlib scikit-image
%pip install -U qdrant_client

## Load and Download Multi-Modal datasets including texts and images from Wikipedia
Parse Wikipedia articles and save them into a local folder.

In [None]:
from pathlib import Path
import requests

# Wikipedia page titles whose plain-text extracts are downloaded below.
wiki_titles = [
    "batman",
    "Vincent van Gogh",
    "San Francisco",
    "iPhone",
    "Tesla Model S",
    "BTS",
]


data_path = Path("data_wiki")
# Create the output folder once, up front; exist_ok makes the cell re-runnable.
data_path.mkdir(exist_ok=True)

for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "format": "json",
            "titles": title,
            "prop": "extracts",
            "explaintext": True,
        },
        # Do not let a stalled connection hang the notebook forever.
        timeout=30,
    )
    # Fail loudly on HTTP errors instead of trying to parse an error page.
    response.raise_for_status()

    page = next(iter(response.json()["query"]["pages"].values()))
    wiki_text = page.get("extract")
    if wiki_text is None:
        # The API answered but returned no extract for this title.
        print(f"No text extract returned for Wikipedia page: {title}")
        continue

    # Write UTF-8 explicitly: the platform default encoding may not be able
    # to represent every character in an article.
    with open(data_path / f"{title}.txt", "w", encoding="utf-8") as fp:
        fp.write(wiki_text)

## Parse Wikipedia Images and texts. Load into local folder

In [None]:
import wikipedia
import urllib.request

image_path = Path("data_wiki")
image_uuid = 0
# image_metadata_dict maps image uuid -> {"filename": ..., "img_path": ...}
image_metadata_dict = {}
MAX_IMAGES_PER_WIKI = 10

wiki_titles = [
    "San Francisco",
    "Batman",
    "Vincent van Gogh",
    "iPhone",
    "Tesla Model S",
    "BTS band",
]

# Images are stored in the same folder as the article text files, which is
# the directory the index-building cells read from.
image_path.mkdir(exist_ok=True)


# Download images for each wiki page and assign a running integer uuid to each
for title in wiki_titles:
    images_per_wiki = 0
    try:
        page_py = wikipedia.page(title)
        list_img_urls = page_py.images
        for url in list_img_urls:
            if not url.endswith((".jpg", ".png")):
                continue

            image_uuid += 1
            image_file_name = title + "_" + url.split("/")[-1]
            # Every file is saved with a .jpg suffix, whatever the source
            # extension was.
            local_path = image_path / f"{image_uuid}.jpg"
            urllib.request.urlretrieve(url, local_path)

            # Record metadata only after the download succeeded, so the dict
            # never points at a file that was not written.
            # img_path could be s3 path pointing to the raw image file in the future
            image_metadata_dict[image_uuid] = {
                "filename": image_file_name,
                "img_path": "./" + str(local_path),
            }

            images_per_wiki += 1
            # Limit the number of images downloaded per wiki page
            if images_per_wiki >= MAX_IMAGES_PER_WIKI:
                break
    except Exception as err:
        # Best effort: skip the rest of this page but report the real cause
        # (lookup failure, network error, ...) instead of hiding it.
        print(f"No images found for Wikipedia page: {title} ({err!r})")

In [None]:
import qdrant_client

#from qdrant_client.http import models
#from qdrant_client.http.models import Distance, VectorParams

from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
)
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index import VectorStoreIndex, StorageContext
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

# Connect to a Qdrant server over HTTP (the URL points at localhost:6333).
# To use an embedded on-disk store instead of a server, swap in:
#   client = qdrant_client.QdrantClient(path="qdrant_db")
client = qdrant_client.QdrantClient(url="http://localhost:6333")

# One collection for text embeddings and one for image embeddings
text_store = QdrantVectorStore(
    client=client, collection_name="text_collection"
)
image_store = QdrantVectorStore(
    client=client, collection_name="image_collection"
)
storage_context = StorageContext.from_defaults(
    vector_store=text_store, image_store=image_store
)

# NOTE: the multi-modal index is built in the next cell, which loads
# ./data_wiki/ and calls MultiModalVectorStoreIndex.from_documents against
# these same two collections. Building it here as well would ingest every
# document twice on a top-to-bottom run, so this cell only sets up the stores.

In [None]:
import qdrant_client

from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
)
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index import VectorStoreIndex, StorageContext
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

# Qdrant client talking to the server over HTTP
client = qdrant_client.QdrantClient(url="http://localhost:6333")

# Separate vector stores (and collections) for text and image embeddings
text_store = QdrantVectorStore(
    collection_name="text_collection", client=client
)
image_store = QdrantVectorStore(
    collection_name="image_collection", client=client
)
storage_context = StorageContext.from_defaults(
    image_store=image_store, vector_store=text_store
)

# Load everything under ./data_wiki/ and build the multi-modal index from it
documents = SimpleDirectoryReader("./data_wiki/").load_data()
index = MultiModalVectorStoreIndex.from_documents(
    documents, storage_context=storage_context
)

# Report vector counts for both collections
for collection_name in ("text_collection", "image_collection"):
    collection_info = client.get_collection(collection_name=collection_name)
    print(
        f"{collection_name} (vectors,indexed): (",
        collection_info.vectors_count,
        ",",
        collection_info.indexed_vectors_count,
        ")",
    )

### Plot downloaded Images from Wikipedia

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import os


def plot_images(image_metadata_dict):
    """Draw the downloaded images as a grid of thumbnails.

    Args:
        image_metadata_dict: mapping of image id to a dict holding at least
            an ``"img_path"`` entry. Entries whose path is not an existing
            file are skipped.

    At most 8 x 8 = 64 images are drawn; any further entries are ignored.
    """
    rows, cols = 8, 8
    max_images = rows * cols

    # Open a dedicated, large figure so 64 thumbnails stay legible and the
    # grid never lands on a figure left over from earlier plotting.
    plt.figure(figsize=(16, 16))

    images_shown = 0
    for image_id in image_metadata_dict:
        img_path = image_metadata_dict[image_id]["img_path"]
        if not os.path.isfile(img_path):
            continue

        image = Image.open(img_path).convert("RGB")

        # subplot indices are 1-based
        plt.subplot(rows, cols, images_shown + 1)
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

        images_shown += 1
        if images_shown >= max_images:
            break

    plt.tight_layout()


# Show a thumbnail grid of the images recorded in image_metadata_dict.
plot_images(image_metadata_dict)

### Build a separate CLIP image embedding index under a different collection, `CLIP image collection`

In [None]:
def plot_images(image_paths):
    """Plot images given by file path on a 2 x 3 grid.

    Note: this redefines the earlier ``plot_images`` that took the metadata
    dict; from this cell on the name refers to this path-based version.

    Args:
        image_paths: iterable of image file paths. Paths that are not
            existing files are skipped.

    At most 6 images are drawn, because that is all the 2 x 3 grid can hold;
    asking ``plt.subplot`` for a 7th cell raises ``ValueError``.
    """
    rows, cols = 2, 3
    max_images = rows * cols

    images_shown = 0
    plt.figure(figsize=(16, 9))
    for img_path in image_paths:
        if not os.path.isfile(img_path):
            continue

        image = Image.open(img_path)

        # subplot indices are 1-based
        plt.subplot(rows, cols, images_shown + 1)
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

        images_shown += 1
        if images_shown >= max_images:
            break

In [None]:
import torch
import clip
import numpy as np

# Load the ViT-B/32 CLIP model together with its image preprocessing function
model, preprocess = clip.load("ViT-B/32")

# Attributes of the loaded model that are reported below:
# expected input image resolution, maximum text length, vocabulary size
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Total parameter count: product of each parameter's shape, summed
param_count = np.sum([int(np.prod(p.shape)) for p in model.parameters()])

# Print a short summary of the model
print("Model parameters:", f"{param_count:,}")
for label, value in (
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(label, value)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)


# img_emb_dict maps image id (the key of image_metadata_dict) -> CLIP image
# embedding tensor returned by model.encode_image
img_emb_dict = {}

# Inference only: no gradients are needed
with torch.no_grad():
    for image_id, image_meta in image_metadata_dict.items():
        img_file_path = image_meta["img_path"]
        # Skip metadata entries whose file is not on disk
        if not os.path.isfile(img_file_path):
            continue

        # Open the file in a context manager so its handle is released as
        # soon as preprocessing has produced the tensor.
        with Image.open(img_file_path) as pil_image:
            # preprocess with CLIP's transform, add a batch dimension,
            # then move the tensor to the selected device
            image = preprocess(pil_image).unsqueeze(0).to(device)

        # Extract image features with CLIP's image encoder and store them
        img_emb_dict[image_id] = model.encode_image(image)

# Print a summary rather than every embedding tensor, which would flood the
# cell output.
print(f"Computed CLIP embeddings for {len(img_emb_dict)} images")

In [None]:
from llama_index.schema import ImageDocument


# Build one ImageDocument per image that has a CLIP embedding
img_documents = []
for image_id, image_meta in image_metadata_dict.items():
    # Only images present in img_emb_dict were embedded
    if image_id not in img_emb_dict:
        continue

    filename = image_meta["filename"]
    filepath = image_meta["img_path"]
    print(filepath)

    # The document text is the image file name; the path goes into metadata
    image_doc = ImageDocument(text=filename, metadata={"filepath": filepath})

    # Attach the precomputed CLIP embedding (first row of the batch)
    image_doc.embedding = img_emb_dict[image_id].tolist()[0]
    img_documents.append(image_doc)


# Dedicated Qdrant collection, named "CLIP image collection", for these embeddings
wikipedia_store = QdrantVectorStore(
    collection_name="CLIP image collection", client=client
)

# Storage context backed by that collection
storage_context = StorageContext.from_defaults(vector_store=wikipedia_store)

# Index the image documents
image_index = VectorStoreIndex.from_documents(
    img_documents, storage_context=storage_context
)

# Report vector counts for the new collection
collection_info = client.get_collection(
    collection_name="CLIP image collection"
)
print(
    "CLIP image collection (vectors,indexed): (",
    collection_info.vectors_count,
    ",",
    collection_info.indexed_vectors_count,
    ")",
)

In [None]:
from llama_index.vector_stores import VectorStoreQuery


def retrieve_results_from_image_index(query, top_k=2):
    """Query the CLIP image vector store with a text query.

    Args:
        query: text to search for.
        top_k: number of images to request from the vector store
            (defaults to 2, the previously hard-coded value).

    Returns:
        Whatever ``wikipedia_store.query`` returns for the built
        ``VectorStoreQuery``.
    """
    # Tokenize the text query and move the token tensor to the device
    text = clip.tokenize(query).to(device)

    # Encode the tokens with CLIP's text encoder to get the query embedding.
    # This is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        query_embedding = model.encode_text(text).tolist()[0]

    # Build the vector store query
    image_vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding,
        similarity_top_k=top_k,
        mode="default",
    )

    # Query the image vector store
    image_retrieval_results = wikipedia_store.query(image_vector_store_query)
    return image_retrieval_results




def plot_image_retrieve_results(image_retrieval_results):
    """Plot each retrieved image with its similarity score as the title.

    Args:
        image_retrieval_results: object exposing parallel ``nodes`` and
            ``similarities`` sequences. Each node provides ``text`` and
            ``metadata["filepath"]``.

    The grid is 3 columns wide and at least 2 rows tall (the original fixed
    layout); rows are added when there are more than 6 results so that
    ``plt.subplot`` is never asked for a cell outside the grid.
    """
    scored_nodes = list(
        zip(image_retrieval_results.nodes, image_retrieval_results.similarities)
    )

    cols = 3
    # Ceiling division, but never fewer than the original 2 rows
    rows = max(2, (len(scored_nodes) + cols - 1) // cols)

    # New 16 x 5 inch figure for the retrieval results
    plt.figure(figsize=(16, 5))

    # One subplot per result: the image, titled with its score
    for img_cnt, (returned_image, score) in enumerate(scored_nodes):
        img_path = returned_image.metadata["filepath"]
        image = Image.open(img_path).convert("RGB")

        # subplot indices are 1-based
        plt.subplot(rows, cols, img_cnt + 1)
        plt.title("{:.4f}".format(score))

        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])




def image_query(query):
    """Retrieve images for a text query from the image index and plot them."""
    plot_image_retrieve_results(retrieve_results_from_image_index(query))


## Get Multi-Modal retrieval results for some example queries

In [None]:
# Question asked of both the image index and the text index
query = "Who is the main character of Batman?"

# Query engine over the multi-modal index, used for the text answer
text_query_engine = index.as_query_engine()

# Image side: retrieve and plot matching images
image_query(query)

# Text side: run the query and print the response
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)

In [None]:
# Question asked of both the image index and the text index
query = "Which are Van Gogh's most famous paintings?"

# Query engine over the multi-modal index, used for the text answer
text_query_engine = index.as_query_engine()

# Image side: retrieve and plot matching images
image_query(query)

# Text side: run the query and print the response
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)

In [None]:
# Question asked of both the image index and the text index
query = "What are the most popular tourist attractions in San Francisco"

# Query engine over the multi-modal index, used for the text answer
text_query_engine = index.as_query_engine()

# Image side: retrieve and plot matching images
image_query(query)

# Text side: run the query and print the response
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)

In [None]:
from llama_index.response.notebook_utils import display_source_node
from llama_index.schema import ImageNode

# `retrieval_results` must be produced here: nothing earlier in the notebook
# defines it, so on a fresh kernel the loop below would raise NameError.
# Retrieve text and image nodes together from the multi-modal index, reusing
# the `query` string set in the previous cell.
retriever = index.as_retriever(similarity_top_k=3, image_similarity_top_k=5)
retrieval_results = retriever.retrieve(query)

# Split the results: collect file paths of image nodes, render every other
# node inline.
retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)

# Plot the retrieved images with the path-based plot_images defined earlier
plot_images(retrieved_image)