# Result diversification with Elasticsearch
This notebook demonstrates:
1. Loading a fashion product dataset
2. Indexing the products in Elasticsearch with multimodal (text + image) embeddings
3. Searching for items with a broad search term
4. Reranking the results for diversity with the Maximal Marginal Relevance (MMR) algorithm

Check out our blog post on this topic to learn more about result diversification with Elasticsearch.

## 1. Setup and Dependencies

In [1]:
!pip install -r requirements.txt


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import os
import json
import requests
import numpy as np
import kagglehub
from itertools import repeat
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import time
from elasticsearch import Elasticsearch
from IPython.display import HTML, display
from typing import List, Dict, Tuple

## 2. Load Configuration

Create a configuration file `elastic_config.env` in the following format to authenticate with the Jina AI API and your Elastic cluster:
```
ELASTIC_API_KEY=<ELASTIC_KEY>
ELASTIC_HOST=<HOST_URL>
JINA_API_KEY=<JINA_KEY>
```

In [3]:
def load_config(file_path="elastic_config.env"):
    """Load ``KEY=VALUE`` settings from a simple env-style file.

    Blank lines and lines starting with ``#`` are ignored. Whitespace around
    keys and values is stripped, and one pair of matching surrounding quotes
    (``"..."`` or ``'...'``) is removed from values. Only the first ``=`` splits
    a line, so values may themselves contain ``=``.

    Args:
        file_path: Path to the configuration file.

    Returns:
        dict mapping setting names to string values. If the file does not
        exist, a message is printed and an empty dict is returned.
    """
    config = {}
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                key, value = key.strip(), value.strip()
                # Allow KEY="value" / KEY='value', as commonly written in .env files.
                if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                    value = value[1:-1]
                config[key] = value
    except FileNotFoundError:
        print(f"Configuration file not found: {file_path}")
    return config


config = load_config()

# Values from the file take precedence; fall back to environment variables of
# the same name so the notebook can also run without the env file.
elastic_host = config.get("ELASTIC_HOST") or os.environ.get("ELASTIC_HOST")
elastic_api_key = config.get("ELASTIC_API_KEY") or os.environ.get("ELASTIC_API_KEY")
jina_api_key = config.get("JINA_API_KEY") or os.environ.get("JINA_API_KEY")

# Report exactly which settings are absent instead of claiming success
# unconditionally; later cells fail in confusing ways without them.
_missing_settings = [
    name
    for name, value in (
        ("ELASTIC_HOST", elastic_host),
        ("ELASTIC_API_KEY", elastic_api_key),
        ("JINA_API_KEY", jina_api_key),
    )
    if not value
]
if _missing_settings:
    print(f"Missing configuration values: {', '.join(_missing_settings)}")
else:
    print("Configuration loaded successfully")

Configuration loaded successfully


## 3. Load Dataset and Extract ID & Image URLs

In [4]:
# Download the Kaggle fashion product dataset (kagglehub stores it in a local
# cache directory and returns that path).
dataset_path = kagglehub.dataset_download(
    "paramaggarwal/fashion-product-images-dataset"
)
print("Path to dataset files:", dataset_path)

# Folder with the per-product JSON metadata files read by `load_dataset`.
# Pass each path component separately so os.path.join picks the separator.
styles_folder = os.path.join(dataset_path, "fashion-dataset", "styles")


def load_dataset(folder_path):
    """Load the ``data`` payload of every ``.json`` file in ``folder_path``.

    Files are read in sorted filename order. ``os.listdir`` order is
    filesystem-dependent, so sorting keeps the product order (and therefore
    the later fixed-size demo slice) reproducible across machines.

    Files that cannot be read or parsed are reported and skipped, as are
    files whose top-level JSON value is not an object with a ``data`` key.

    Args:
        folder_path: Directory containing the product JSON files.

    Returns:
        list of the ``data`` values, in sorted filename order.
    """
    products = []

    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(folder_path, filename)
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading {filename}: {e}")
            continue
        if isinstance(data, dict) and "data" in data:
            products.append(data["data"])

    return products


products = load_dataset(styles_folder)
print(f"Loaded {len(products)} total products")

# Keep only bottomwear to limit the data volume for this demo.
# `or {}` / `or ""` also cover keys that are present with a null value,
# which `.get(key, default)` alone would pass through and then crash on.
bottomwear_products = [
    product
    for product in products
    if ((product.get("subCategory") or {}).get("typeName") or "").lower()
    == "bottomwear"
]

print(f"\nFiltered to {len(bottomwear_products)} bottomwear products")

products = bottomwear_products

Path to dataset files: /Users/peter/.cache/kagglehub/datasets/paramaggarwal/fashion-product-images-dataset/versions/1
Loaded 44446 total products

Filtered to 2694 bottomwear products


In [5]:
def extract_id_and_image_url(products):
    """Extract the id, a display image URL and basic metadata from each product.

    The 360x480 resolution image is preferred; the default ``imageURL`` is
    used as a fallback. Products without an id or any image URL are skipped.

    Args:
        products: product dicts as loaded from the dataset JSON files.

    Returns:
        list of flat dicts with keys ``id``, ``image_url``, ``product_name``,
        ``brand``, ``color`` and ``article_type`` (missing text fields are "").
    """
    image_data = []

    for product in products:
        product_id = product.get("id")

        # `or {}` also handles keys that are present but null, which
        # `.get(key, {})` would return as None and then crash on `.get`.
        default_image = (product.get("styleImages") or {}).get("default") or {}
        resolutions = default_image.get("resolutions") or {}
        image_url = resolutions.get("360X480") or default_image.get("imageURL") or ""

        if not (product_id and image_url):
            continue

        image_data.append(
            {
                "id": product_id,
                "image_url": image_url,
                "product_name": product.get("productDisplayName", ""),
                "brand": product.get("brandName", ""),
                "color": product.get("baseColour", ""),
                "article_type": (product.get("articleType") or {}).get("typeName", ""),
            }
        )

    return image_data


image_data = extract_id_and_image_url(products)
print(f"Extracted {len(image_data)} products with valid IDs and image URLs")

# Cap the number of products so the embedding step stays light for a demo.
MAX_DEMO_ITEMS = 1000
demo_image_data = image_data[:MAX_DEMO_ITEMS]
print(f"\nLimited to {len(demo_image_data)} items for demo")

# Items are in dataset load order (nothing here sorts them), so label them plainly.
print("\nSample items:")
for item in demo_image_data[:5]:
    print(f"  - {item['product_name']} ({item['article_type']}, {item['color']})")

Extracted 2693 products with valid IDs and image URLs

Limited to 1000 items for demo

Sample items (alphabetically sorted):
  - Femella Women Off White Shorts (Shorts, Off White)
  - Nike Women Strong Poly Black Capri (Capris, Black)
  - Flying Machine Men Blue Jeans (Jeans, Blue)
  - Urban Yoga Men Black Shorts (Shorts, Black)
  - Doodle Girls Lace Bow LT.Pink Leggings (Leggings, Pink)


## 4. Create Multimodal (Text + Image) Embeddings with the Jina API

In [6]:
def get_single_image_embedding(item, jina_api_key):
    """Embed one product's metadata text and image with Jina and fuse the two.

    A single request sends two inputs: the product metadata rendered as text
    and the product image URL. The two returned embeddings are averaged and
    re-normalized by ``to_avg_vector`` into one ``image_vector``.

    Args:
        item: dict with ``id``, ``image_url``, ``product_name``, ``brand``,
            ``color`` and ``article_type``.
        jina_api_key: Bearer token for the Jina embeddings API.

    Returns:
        A new dict with the same six fields plus ``image_vector``, or None if
        the request fails or the response does not contain both embeddings
        (the problem is printed).
    """
    url = "https://api.jina.ai/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jina_api_key}",
    }

    product_data = {
        "product_name": item["product_name"],
        "brand": item["brand"],
        "color": item["color"],
        "article_type": item["article_type"],
    }

    data = {
        "model": "jina-embeddings-v4",
        "dimensions": 1024,
        "normalized": True,
        "task": "retrieval.passage",
        "embedding_type": "float",
        "input": [{"text": f"{product_data}"}, {"image": item["image_url"]}],
    }

    try:
        response = requests.post(url, headers=headers, json=data, timeout=200)
        response.raise_for_status()

        result = response.json()
        embeddings = result.get("data") or []
        # Two inputs were sent, so both embeddings are required. The previous
        # `len > 0` check let a one-element response fail on index [1].
        if len(embeddings) < 2:
            print(f"Error processing {item}: expected 2 embeddings, got {len(embeddings)}")
            return None

        return {
            "id": item["id"],
            "image_url": item["image_url"],
            "product_name": item["product_name"],
            "brand": item["brand"],
            "color": item["color"],
            "article_type": item["article_type"],
            "image_vector": to_avg_vector(
                [embeddings[0]["embedding"], embeddings[1]["embedding"]]
            ),
        }
    except Exception as e:
        print(f"Error processing {item}: {e}")
        return None


# Fuse a product's text and image embeddings into a single vector.
def to_avg_vector(vectors):
    """Average several embeddings and scale the result to unit length.

    The vectors are averaged element-wise and the mean is divided by its L2
    norm. An all-zero mean is returned unchanged to avoid dividing by zero.

    Args:
        vectors: equal-length numeric vectors (here: a text and an image embedding).

    Returns:
        The fused vector as a plain Python list of floats.
    """
    mean_vector = np.array(vectors).mean(axis=0)
    length = np.linalg.norm(mean_vector)
    fused = mean_vector / length if length > 0 else mean_vector
    return fused.tolist()


print("Getting embeddings...")

with ThreadPoolExecutor(max_workers=10) as executor:
    embedding_results = list(
        tqdm(
            executor.map(
                get_single_image_embedding, demo_image_data, repeat(jina_api_key)
            ),
            total=len(demo_image_data),
            desc="Getting embeddings",
        )
    )

# get_single_image_embedding returns None when a request fails. Drop those
# results so the similarity filter and indexing only see items with vectors.
products_with_vectors = [r for r in embedding_results if r is not None]
failed_count = len(embedding_results) - len(products_with_vectors)

print(f"Retrieved {len(products_with_vectors)} embeddings")
if failed_count:
    print(f"Skipped {failed_count} items whose embedding request failed")

Getting embeddings...


Getting embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:25<00:00,  3.08it/s]

Retrieved 1000 embeddings





In [7]:
def _cosine_similarity(X, Y):
    """Compute cosine similarity between two sets of vectors."""
    X = np.array(X)
    Y = np.array(Y)

    if X.ndim == 1:
        X = X.reshape(1, -1)
    if Y.ndim == 1:
        Y = Y.reshape(1, -1)

    # Normalize the vectors
    X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)
    Y_norm = Y / np.linalg.norm(Y, axis=1, keepdims=True)

    return np.dot(X_norm, Y_norm.T)


def filter_out_similar_items(items, threshold=0.98):
    """Greedily drop near-duplicate items, keeping the first of each group.

    Items are visited in order. An item is kept only if its cosine similarity
    to every previously *kept* item is below ``threshold``.

    Args:
        items: dicts that each carry an ``image_vector``.
        threshold: similarity at or above which an item counts as a duplicate.

    Returns:
        The kept items, in their original order.
    """
    kept_items = []
    kept_vectors = []  # numpy copies of the kept items' vectors, converted once

    for item in items:
        vector = np.asarray(item["image_vector"], dtype=float)
        if kept_vectors:
            # One vectorized comparison against all kept vectors replaces the
            # per-pair Python loop. "Any pair >= threshold" is the same test
            # the pairwise loop performed.
            similarities = _cosine_similarity(vector, kept_vectors)
            if np.any(similarities >= threshold):
                continue
        kept_items.append(item)
        kept_vectors.append(vector)

    return kept_items


# Near-duplicate cut-off: an item is dropped when its cosine similarity to an
# already-kept item reaches this value.
DUPLICATE_SIMILARITY_THRESHOLD = 0.98
filtered_products = filter_out_similar_items(
    products_with_vectors, threshold=DUPLICATE_SIMILARITY_THRESHOLD
)

original_count = len(products_with_vectors)
kept_count = len(filtered_products)
print(f"Original products: {original_count}")
print(f"After filtering similar items: {kept_count}")
print(f"Removed {original_count - kept_count} similar items")

Original products: 1000
After filtering similar items: 758
Removed 242 similar items


## 5. Setup Elasticsearch Index

In [8]:
# Initialize Elasticsearch client
es = Elasticsearch(elastic_host, api_key=elastic_api_key)

# Index that stores one document per product, keyed by product id.
index_name = "fashion_images"

# Must match the `dimensions` requested from the Jina embeddings API.
EMBEDDING_DIMS = 1024

# Metadata fields are exact-match keywords. `image_vector` holds the fused
# text+image embedding in a flat (brute-force, no HNSW graph) vector index,
# which is adequate for a dataset of this size.
mapping = {
    "mappings": {
        "properties": {
            "id": {"type": "keyword"},
            "image_url": {"type": "keyword"},
            "product_name": {"type": "keyword"},
            "brand": {"type": "keyword"},
            "color": {"type": "keyword"},
            "article_type": {"type": "keyword"},
            "image_vector": {
                "type": "dense_vector",
                "dims": EMBEDDING_DIMS,
                "index": True,
                "similarity": "cosine",
                "index_options": {"type": "flat"},
            },
        }
    }
}

# Recreate the index from scratch so re-running the notebook starts clean.
# NOTE: this deletes any existing data in `index_name`.
if es.indices.exists(index=index_name):
    es.indices.delete(index=index_name)
    print(f"Deleted existing index '{index_name}'")

# Pass the mappings as an explicit keyword argument (the documented form in the
# 8.x client) rather than through the generic `body` parameter.
es.indices.create(index=index_name, mappings=mapping["mappings"])
print(f"Created index '{index_name}'")

Deleted existing index 'fashion_images'
Created index 'fashion_images'


## 6. Index Documents with Image Vectors

In [9]:
def index_single_image(item):
    """Index one product document into `index_name`, using its id as the document id.

    Returns 1 on success and 0 on failure (the error is printed), so the caller
    can sum the return values to count indexed documents.
    """
    doc_id = item["id"]
    try:
        es.index(index=index_name, id=doc_id, document=item)
    except Exception as e:
        print(f"Error indexing document {doc_id}: {e}")
        return 0
    return 1


print(f"Indexing {len(filtered_products)} documents into '{index_name}'...")

# Index the documents in parallel; each call returns 1 on success, 0 on failure.
with ThreadPoolExecutor(max_workers=10) as executor:
    results = list(
        tqdm(
            executor.map(index_single_image, filtered_products),
            total=len(filtered_products),
            desc="Indexing images",
        )
    )

# Newly indexed documents only become searchable after an index refresh.
# Refresh explicitly so the search cells below see the whole index even when
# they run right after this one (e.g. Restart & Run All).
es.indices.refresh(index=index_name)

indexed_count = sum(results)
print(f"Successfully indexed {indexed_count} documents")
failed_count = len(results) - indexed_count
if failed_count:
    print(f"Failed to index {failed_count} documents")

start


Indexing images: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 758/758 [00:26<00:00, 28.72it/s]

Successfully indexed 758 documents





## 7. Query Images with Text Search

In [13]:
# Deliberately broad query term. A generic query like this is expected to return
# many similar-looking products, which the MMR reranking step then diversifies.
SEARCH_QUERY = "pants"


def get_text_embedding(text, jina_api_key):
    """Embed a search query with the Jina embeddings API.

    Uses the ``retrieval.query`` task (the indexed products use
    ``retrieval.passage``) and requests 1024-dimensional normalized float
    vectors from ``jina-embeddings-v4``.

    Args:
        text: Query text to embed.
        jina_api_key: Bearer token for the Jina API.

    Returns:
        The embedding as a list of floats, or None if the response has no
        embeddings or the request fails (errors are printed).
    """
    payload = {
        "model": "jina-embeddings-v4",
        "dimensions": 1024,
        "normalized": True,
        "embedding_type": "float",
        "task": "retrieval.query",
        "input": [{"text": text}],
    }
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jina_api_key}",
    }

    try:
        resp = requests.post(
            "https://api.jina.ai/v1/embeddings",
            headers=request_headers,
            json=payload,
            timeout=30,
        )
        resp.raise_for_status()
        body = resp.json()

        if "data" not in body or len(body["data"]) == 0:
            return None
        return body["data"][0]["embedding"]
    except Exception as exc:
        print(f"Error getting text embedding: {exc}")
        return None


def search_similar_images(es, index_name, query_vector, k=20):
    """Run a kNN search on ``image_vector`` and return the hits as flat dicts.

    Args:
        es: Elasticsearch client.
        index_name: Index to search.
        query_vector: Query embedding, with the same dimensionality as the
            indexed vectors.
        k: Number of nearest neighbours to retrieve; also used as the page size.

    Returns:
        list of dicts in ranking order. Each dict holds the stored product
        fields, the stored ``image_vector`` (needed for MMR reranking) and the
        Elasticsearch relevance ``score``.
    """
    # Explicit `knn=` / `size=` keywords are the documented 8.x client form
    # (instead of passing a raw request `body`).
    response = es.search(
        index=index_name,
        knn={
            "field": "image_vector",
            "query_vector": query_vector,
            "k": k,
        },
        size=k,
    )

    results = []
    for hit in response["hits"]["hits"]:
        source = hit["_source"]
        results.append(
            {
                "id": source["id"],
                "image_url": source["image_url"],
                "image_vector": source["image_vector"],
                "score": hit["_score"],
                "product_name": source["product_name"],
                "brand": source["brand"],
                "color": source["color"],
                "article_type": source["article_type"],
            }
        )

    return results


print(f"Creating text embedding for: '{SEARCH_QUERY}'")
query_vector = get_text_embedding(SEARCH_QUERY, jina_api_key)

# Define search_results on every path. Otherwise, when the embedding call
# fails, the display and MMR cells below would hit a NameError or silently
# reuse results left over from an earlier run.
search_results = []
if query_vector:
    print(f"\nSearching for items similar to: '{SEARCH_QUERY}'")
    search_results = search_similar_images(es, index_name, query_vector, k=150)
    print(f"Found {len(search_results)} similar images")
else:
    print("Failed to get text embedding")

Creating text embedding for: 'pants'

Searching for items similar to: 'pants'
Found 150 similar images


## 8. Display Search Results

Showing results for text search: **"pants"**

In [15]:
def display_images(images, title="Images", max_per_row=5, max_images=10):
    """Render product results as an HTML image grid in the notebook output.

    Args:
        images: list of result dicts. Each must have ``image_url`` and ``id``
            and may have ``product_name``, ``article_type``, ``color`` and a
            numeric ``score``.
        title: Heading shown above the grid.
        max_per_row: Number of cards per row before a new row is started.
        max_images: Only the first ``max_images`` entries are shown (default
            10); pass None to show all of them.
    """
    import html as html_lib  # local import: the file-level imports do not include `html`

    def _esc(value):
        # Escape data-derived text so quotes, `<` or `&` cannot break the markup.
        return html_lib.escape(str(value), quote=True)

    row_start = '<div style="display: flex; flex-wrap: wrap; gap: 10px;">'
    parts = [f"<h2>{_esc(title)}</h2>", row_start]

    for i, img in enumerate(images[:max_images]):
        score = img.get("score", "N/A")
        score_str = f"{score:.3f}" if isinstance(score, (int, float)) else "N/A"

        full_name = img.get("product_name", "N/A")
        short_name = full_name
        if full_name != "N/A" and len(full_name) > 25:
            short_name = full_name[:25] + "..."

        # `title` shows the untruncated product name as a tooltip; `alt` labels the image.
        parts.append(
            f"""
       <div style="text-align: center; margin-bottom: 20px;">
           <img src="{_esc(img['image_url'])}" alt="{_esc(full_name)}" style="width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
           <p style="margin: 5px 0; font-size: 12px; font-weight: bold;">ID: {_esc(img['id'])}</p>
           <p style="margin: 5px 0; font-size: 11px;" title="{_esc(full_name)}">{_esc(short_name)}</p>
           <p style="margin: 5px 0; font-size: 11px; color: #666;">{_esc(img.get('article_type', ''))} - {_esc(img.get('color', ''))}</p>
           <p style="margin: 5px 0; font-size: 12px; color: #007bff;">Score: {score_str}</p>
       </div>
       """
        )

        # Close the current flex row and open a new one after every `max_per_row` cards.
        if (i + 1) % max_per_row == 0:
            parts.append("</div>" + row_start)

    parts.append("</div>")
    display(HTML("".join(parts)))


display_images(search_results, "Original Search Results")

## 9. Reranking with Maximal Marginal Relevance (MMR)

MMR is a diversity-promoting algorithm that balances:
- **Relevance**: How well items match the query
- **Diversity**: How different items are from each other

The algorithm iteratively selects items that are relevant to the query but different from the items already selected.

In [16]:
# Adapted from elasticsearch-py:
# https://github.com/elastic/elasticsearch-py/blob/main/elasticsearch/helpers/vectorstore/_utils.py#L39
# Changed to track each candidate's maximum similarity to the selected set
# incrementally, instead of recomputing the full similarity matrix (and
# re-converting the input list) on every iteration.
def maximal_marginal_relevance(
    query_embedding: List[float],
    embedding_list: List[List[float]],
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Select up to ``k`` indices from ``embedding_list`` with Maximal Marginal Relevance.

    The first pick is the candidate most similar to the query. Each further
    pick maximizes ``lambda_mult * sim(query, c) - (1 - lambda_mult) * max_s sim(c, s)``
    over the not-yet-selected candidates ``c``, where ``s`` ranges over the
    already-selected items. Ties go to the lowest index.

    Args:
        query_embedding: Query vector (1-D; for a 2-D input the first row is used).
        embedding_list: Candidate vectors, one per row.
        lambda_mult: Weight of relevance vs. diversity. 1.0 ranks purely by
            relevance; 0.0 purely by diversity.
        k: Maximum number of indices to return.

    Returns:
        Indices into ``embedding_list`` in selection order. Empty if ``k <= 0``
        or there are no candidates.
    """
    n_to_select = min(k, len(embedding_list))
    if n_to_select <= 0:
        return []

    embeddings = np.asarray(embedding_list, dtype=float)
    query_embedding_arr = np.array(query_embedding)
    if query_embedding_arr.ndim == 1:
        query_embedding_arr = np.expand_dims(query_embedding_arr, axis=0)
    similarity_to_query = _cosine_similarity(query_embedding_arr, embeddings)[0]

    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    is_selected = np.zeros(len(embeddings), dtype=bool)
    is_selected[most_similar] = True
    # Largest similarity of every candidate to anything selected so far.
    max_sim_to_selected = _cosine_similarity(embeddings, embeddings[most_similar])[:, 0]

    while len(idxs) < n_to_select:
        scores = (
            lambda_mult * similarity_to_query
            - (1 - lambda_mult) * max_sim_to_selected
        )
        # Already-selected items can never be picked again.
        scores[is_selected] = -np.inf
        # np.argmax returns the first maximum, matching the original strict `>` scan.
        idx_to_add = int(np.argmax(scores))
        idxs.append(idx_to_add)
        is_selected[idx_to_add] = True
        # Only the newest pick can raise a candidate's max similarity.
        new_sim = _cosine_similarity(embeddings, embeddings[idx_to_add])[:, 0]
        max_sim_to_selected = np.maximum(max_sim_to_selected, new_sim)

    return idxs


# Rerank the kNN hits. With lambda_mult=0.5, query relevance and
# dissimilarity to already-picked items are weighted equally.
candidate_vectors = [hit["image_vector"] for hit in search_results]
mmr_indices = maximal_marginal_relevance(
    query_embedding=query_vector,
    embedding_list=candidate_vectors,
    lambda_mult=0.5,
    k=100,
)

reranked_results = list(map(search_results.__getitem__, mmr_indices))
display_images(reranked_results, "Reranked Results (MMR)")