# Gradio Demo: Мультимодальный поиск по WikiArt

Демо для трех типов поиска:
1. Image→Image: Поиск похожих изображений
2. Caption→Image: Поиск по текстовому описанию
3. Omni-search: Комбинированный поиск по изображению и тексту


In [1]:
!pip install -q gradio datasets torch torchvision transformers faiss-cpu pillow numpy pandas matplotlib tqdm huggingface-hub sentence-transformers


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/31.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/31.4 MB[0m [31m252.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━[0m [32m17.7/31.4 MB[0m [31m273.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━[0m [32m24.0/31.4 MB[0m [31m213.9 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m30.4/31.4 MB[0m [31m263.8 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m31.4/31.4 MB[0m [31m153.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m31.4/31.4 MB[0m [31m153.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m31.4/31.

In [10]:
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from datasets import load_dataset
import faiss
from sentence_transformers import SentenceTransformer
import warnings
import requests
import io
from matplotlib.gridspec import GridSpec
from tqdm import tqdm

warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Загрузка моделей и индексов


In [7]:
# датасет
dataset = load_dataset("lyubachuba/wikiart_5k", split='train')
captions = load_dataset("lyubachuba/wikiart_5k_captions", split='train')
colors = load_dataset("lyubachuba/wikiart_5k_colors", split='train')

artists = dataset.features['artist'].names
styles = dataset.features['style'].names
genres = dataset.features['genre'].names

extra_info_captions = []

for i in tqdm(range(len(dataset))):
    base_caption = captions[i]['caption']

    artist = " ".join([elem.title() for elem in artists[dataset[i]['artist']].split("-")])
    style = " ".join([elem.lower() for elem in styles[dataset[i]['style']].split("_")])
    genre = " ".join([elem.lower() for elem in styles[dataset[i]['genre']].split("_")])

    color_palette = colors[i]["color_palette"].replace("The dominant color palette is ", "").lower().replace("question: what is the dominant color palette? answer:", "")

    extra_info_captions.append(f"{base_caption} by {artist} in style of {style} in {genre} genre. Color palette: {color_palette}".replace("  ", " "))

dataset = dataset.add_column('enriched_caption', extra_info_captions)

100%|██████████| 5000/5000 [03:29<00:00, 23.87it/s]


In [8]:
# индексы

# можно заменить индексы на свои
IMAGE_INDEX_PATH = "/content/drive/MyDrive/MM_HW1/image_index.bin"
TEXT_INDEX_PATH = "/content/drive/MyDrive/MM_HW1/enriched_caption_index.bin"

image_index = faiss.read_index(IMAGE_INDEX_PATH)
text_index = faiss.read_index(TEXT_INDEX_PATH)

In [9]:
# модели

model_name = "google/siglip2-base-patch16-224"

siglip_processor = AutoProcessor.from_pretrained(model_name)
siglip_model = AutoModel.from_pretrained(model_name, dtype=torch.float16).to(device)
siglip_model.eval()

sentence_encoder = SentenceTransformer('all-mpnet-base-v2', device=device)

preprocessor_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/253 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

## Функции поиска


In [22]:
# ------- КАРТИНКИ -------

def download_image(url):
  try:
    response = requests.get(url, stream=True, timeout=10)
    response.raise_for_status()

    image_data = response.content
    pil_image = Image.open(io.BytesIO(image_data))

    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    return pil_image

  except Exception as e:
    return None

def get_image_embedding(image, model, processor, device):
    inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)

    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
        embedding /= embedding.norm(dim=-1, keepdim=True)

    return embedding.cpu().numpy()

def search_similar_images(query_image, index, model,  processor, device, top_k=5):

    query_embedding = get_image_embedding(query_image, model, processor, device)
    distances, indices = index.search(query_embedding.astype('float32'), top_k)

    return distances[0], indices[0]

# ------- ТЕКСТЫ -------

def search_similar_captions(query, sentence_processor, index, top_k=10):

    query_embedding = sentence_processor.encode([query], convert_to_numpy=True)
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)

    distances, indices = index.search(query_embedding.astype('float32'), top_k)

    return distances[0], indices[0]

In [30]:
def image_search(query_url, top_k=10):
    query_image = download_image(query_url)

    if query_image is None:
        print("Невозможно скачать изображение.")
        return

    distances, indices = search_similar_images(query_image, image_index, siglip_model, siglip_processor, device, top_k)

    results = []
    for i, (idx, dist) in enumerate(zip(indices, distances)):
        img = dataset[int(idx)]['image']
        caption = dataset[int(idx)]['enriched_caption']
        results.append((img, f"#{i+1} | {caption}"))

    return results

def text_search(text_query, top_k=10):

  distances, indices = search_similar_captions(text_query, sentence_encoder, text_index, top_k)

  results = []
  for i, (idx, dist) in enumerate(zip(indices, distances)):
      img = dataset[int(idx)]['image']
      caption = dataset[int(idx)]['enriched_caption']
      results.append((img, f"#{i+1} | {caption}"))

  return results


def combined_search(
    query_image_url = None,
    query_text = None,
    top_k  = 10,
    text_weight = 0.5, # переписываю под gradio
    image_weight = 0.5
):

    len_dataset = len(dataset)
    scores = np.zeros(len(dataset))

    # поиск по картинке
    if query_image_url is not None and image_weight > 0:
        query_image = download_image(query_image_url)
        distances, indices = search_similar_images(query_image, image_index, siglip_model, siglip_processor, device, len_dataset)
        for i, (idx, dist) in enumerate(zip(indices, distances)):
            scores[int(idx)] += image_weight * dist

    # Поиск по тексту
    if query_text is not None and text_weight > 0:
        distances, indices = search_similar_captions(query_text, sentence_encoder, text_index, len_dataset)
        for i, (idx, dist) in enumerate(zip(indices, distances)):
            scores[int(idx)] += text_weight * dist

    sorted_indices = np.argsort(scores)[::-1]
    filtered_indices = [idx for idx in sorted_indices if scores[idx] > 0][:top_k]

    results = []

    for i, idx in enumerate(filtered_indices):
        img = dataset[int(idx)]['image']
        caption = dataset[int(idx)]['enriched_caption']
        results.append((img, f"#{i+1} | {caption}"))

    return results

In [36]:
with gr.Blocks(title="WikiArt Search") as demo:
    gr.Markdown("# Мультимодальный поиск произведений искусства")

    with gr.Tabs():
        with gr.Tab("Image→Image"):
            gr.Markdown("### Поиск похожих изображений")

            with gr.Row():
                with gr.Column(scale=1):

                    img_url = gr.Textbox(label="Url картинки")
                    img_top_k = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Количество результатов")
                    img_search_btn = gr.Button("Поиск", variant="primary")

                with gr.Column(scale=2):
                    img_output = gr.Gallery(label="Результаты", columns=5, height="auto")

            img_search_btn.click(
                fn=image_search,
                inputs=[img_url, img_top_k],
                outputs=img_output
            )


        with gr.Tab("Caption→Image"):
            gr.Markdown("### Поиск изобюражений по описанию")

            with gr.Row():
                with gr.Column(scale=1):
                    text_query = gr.Textbox(label="Описание", placeholder="traditional japanese painting of waves")
                    text_top_k = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Количество результатов")
                    text_search_btn = gr.Button("Поиск", variant="primary")

                with gr.Column(scale=2):
                    text_output = gr.Gallery(label="Результаты", columns=5, height="auto")

            text_search_btn.click(
                fn=text_search,
                inputs=[text_query, text_top_k],
                outputs=text_output
            )

            # Примеры запросов
            gr.Examples(
                examples=[
                  "sunflowers in vase",
                  "a landscape in style of Claude Monet",
                  "russian icon painting",
                  "The Renaissance Madonna",
                  "charcoal sketch of a nude figure",
                  "red flowers"
                ],
                inputs=text_query
            )

        with gr.Tab("Omni→Image Retrieval"):
            gr.Markdown("### Комбинированный поиск")

            with gr.Row():
                with gr.Column(scale=1):
                    omni_img_url = gr.Textbox(label="Url картинки")
                    omni_text_input = gr.Textbox(label="Текстовый запрос")

                    with gr.Row():
                        omni_img_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Вес изображения")
                        omni_text_weight = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Вес текста")

                    omni_top_k = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Количество результатов")
                    omni_search_btn = gr.Button("Поиск", variant="primary")

                with gr.Column(scale=2):
                    omni_output = gr.Gallery(label="Результаты", columns=5, height="auto")
            omni_search_btn.click(
                fn=combined_search,
                inputs=[omni_img_url,
                        omni_text_input,
                        omni_top_k,
                        omni_text_weight,
                        omni_img_weight],
                outputs=omni_output
            )

demo.launch(share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://b3059f00ab0f098437.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/gradio/queueing.py", line 759, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/gradio/route_utils.py", line 354, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/gradio/blocks.py", line 2116, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/gradio/blocks.py", line 1623, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://b3059f00ab0f098437.gradio.live


