In [1]:
import html
import os
import sys

import dotenv
import nltk
import pandas as pd
import torch
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

# Add the parent directory to the module search path
# (presumably so project-local modules can be imported -- none are imported in the visible cells).
sys.path.append("..")
# Load variables from a .env file into the environment; the Elasticsearch
# connection settings are read from the environment with os.getenv().
dotenv.load_dotenv()

# Ensure nltk punkt is available
# NOTE(review): nothing in the visible cells uses nltk -- confirm this download is still needed.
nltk.download("punkt")

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
# Run-control flags for the guarded cells further down.
# IS_CREATING_INDEX: call create_index(), which deletes an existing index of the same name first.
# IS_INDEXING_DATA: call index_data_images(), which embeds and indexes every row of the dataset.
IS_CREATING_INDEX = False
IS_INDEXING_DATA = False

In [2]:
# Create an Elasticsearch client instance.
# All connection settings come from environment variables; os.getenv() returns
# None for a variable that is not set, so make sure the .env file defines them.
es = Elasticsearch(
    [os.getenv("ELASTICSEARCH_URL")],  # Full URL of the Elasticsearch node
    http_auth=(
        os.getenv("ELASTICSEARCH_USERNAME"),
        os.getenv("ELASTICSEARCH_PASSWORD"),
    ),  # Username and password
    # NOTE(review): recent elasticsearch-py releases deprecate `http_auth` in favour of
    # `basic_auth` -- confirm against the installed client version.
    verify_certs=False,  # Skip TLS certificate verification (only has an effect for https:// URLs)
)

# Check if the connection is successful
if es.ping():
    print("Connected to Elasticsearch")
else:
    print("Could not connect to Elasticsearch")

Connected to Elasticsearch


  es = Elasticsearch(


In [10]:
# Name of the Elasticsearch index used by the index-creation, indexing and search cells.
index_name = "hm_ecommerce_search"

In [11]:
# Index settings and mappings
def create_index(index_name):
    """Create ``index_name`` from scratch.

    If an index with that name already exists it is deleted first, so any
    documents it holds are lost. The new index has a single shard, no replicas,
    an analysed text field for the product name and two indexed 512-dimensional
    cosine ``dense_vector`` fields (``image_vector`` and ``text_vector``).
    """
    already_there = es.indices.exists(index=index_name)
    if already_there:
        print(f"Index {index_name} already exists")
        # Drop the old index so the mapping below is applied to a clean one
        es.indices.delete(index=index_name)
        print(f"Index {index_name} deleted")

    # Both embedding fields share the same definition; adjust "dims" if the
    # embedding model changes.
    vector_field = {
        "type": "dense_vector",
        "dims": 512,
        "index": True,
        "similarity": "cosine",
    }
    settings = {
        "number_of_shards": 1,
        "number_of_replicas": 0,
    }
    mappings = {
        "properties": {
            "productDisplayName": {"type": "text", "analyzer": "standard"},
            "image_vector": dict(vector_field),
            "text_vector": dict(vector_field),
        }
    }

    es.indices.create(
        index=index_name,
        body={"settings": settings, "mappings": mappings},
    )
    print(f"Index {index_name} created successfully")


# Only (re)create the index when explicitly enabled: create_index() deletes an
# existing index of the same name before creating the new one.
if IS_CREATING_INDEX:
    create_index(index_name)

Index hm_ecommerce_search already exists
Index hm_ecommerce_search deleted
Index hm_ecommerce_search created successfully


In [12]:
# Load dataset

DATASET_PATH = "../datasets/kaggle_hm/articles.csv"
IMAGES_PATH = "../datasets/kaggle_hm/images/"

# article_id is read as a string so it is not parsed as a number; it is later
# sliced (first three characters) to build image paths.
metadata = pd.read_csv(DATASET_PATH, dtype={"article_id": str})

# Load CLIP model
# The same model embeds both product images and text in the cells below.
# NOTE(review): its output size must match "dims": 512 in create_index() -- confirm if the model is changed.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32", device=device)

cuda


In [13]:
metadata.head()

Unnamed: 0,article_id,product_code,prod_name,product_type_no,product_type_name,product_group_name,graphical_appearance_no,graphical_appearance_name,colour_group_code,colour_group_name,...,department_name,index_code,index_name,index_group_no,index_group_name,section_no,section_name,garment_group_no,garment_group_name,detail_desc
0,108775015,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,9,Black,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
1,108775044,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,10,White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
2,108775051,108775,Strap top (1),253,Vest top,Garment Upper body,1010017,Stripe,11,Off White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
3,110065001,110065,OP T-shirt (Idro),306,Bra,Underwear,1010016,Solid,9,Black,...,Clean Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Microfibre T-shirt bra with underwired, moulde..."
4,110065002,110065,OP T-shirt (Idro),306,Bra,Underwear,1010016,Solid,10,White,...,Clean Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Microfibre T-shirt bra with underwired, moulde..."


In [None]:
def _load_image(path):
    """Open and fully decode the image at ``path``; return ``None`` if that fails."""
    try:
        # Image.open() is lazy. Decoding inside the context manager makes corrupt
        # files fail here (skipping one image instead of the whole batch) and
        # closes the file handle immediately; the pixel data stays in memory.
        with Image.open(path) as img:
            img.load()
        return img
    except Exception as e:
        print(f"Error loading image {path}: {str(e)}")
        return None


def _bulk_index(actions, context):
    """Send ``actions`` to Elasticsearch and report failures instead of dropping them."""
    try:
        # With raise_on_error=False, per-document failures come back in the
        # second item of the returned tuple rather than as an exception.
        _, errors = bulk(es, actions, raise_on_error=False)
    except Exception as e:
        print(f"Error in {context}: {str(e)}")
        return
    if errors:
        print(
            f"{len(errors)} document(s) failed in {context}; first error: {errors[0]}"
        )


def index_data_images():
    """Embed every product in ``metadata`` and index it into ``index_name``.

    The dataset is processed in batches of 200 rows. For each row the product
    image is loaded from ``IMAGES_PATH``; rows whose image cannot be loaded are
    skipped. The image and the ``prod_name`` text are both embedded with
    ``model`` and stored as ``image_vector`` / ``text_vector`` next to the full
    metadata row. The document ``_id`` is the row's position in ``metadata``.

    Errors are reported with ``print`` and never abort the run: a failing batch
    is skipped and indexing continues with the next one.
    """
    actions = []
    batch_size = 200

    for i in tqdm(range(0, len(metadata), batch_size)):
        batch = metadata.iloc[i : i + batch_size]
        img_paths = [
            f"{IMAGES_PATH}{str(article_id)[:3]}/{article_id}.jpg"
            for article_id in batch["article_id"]
        ]
        img_batch = [_load_image(path) for path in img_paths]

        try:
            # Positions (within the batch) of the images that loaded successfully
            valid_indices = [pos for pos, img in enumerate(img_batch) if img is not None]
            if not valid_indices:
                continue

            valid_img_batch = [img_batch[pos] for pos in valid_indices]
            valid_batch = batch.iloc[valid_indices]

            dense_embeds_image = model.encode(valid_img_batch).tolist()
            dense_embeds_text = model.encode(valid_batch["prod_name"].tolist()).tolist()

            meta_dict = valid_batch.to_dict(orient="records")
            for pos, dense_image, dense_text, row in zip(
                valid_indices, dense_embeds_image, dense_embeds_text, meta_dict
            ):
                # Missing CSV values arrive as NaN, which is not valid JSON;
                # store them as null instead.
                row = {
                    key: (None if pd.isna(value) else value)
                    for key, value in row.items()
                }
                actions.append(
                    {
                        "_index": index_name,
                        "_id": f"{i + pos}",  # Unique ID: row position in `metadata`
                        "_source": {
                            "productDisplayName": row["prod_name"],
                            "image_vector": dense_image,
                            "text_vector": dense_text,
                            "metadata": row,
                        },
                    }
                )

            if len(actions) >= batch_size:
                _bulk_index(actions, "bulk indexing")
                actions = []

        except Exception as e:
            print(f"Error processing batch starting at index {i}: {str(e)}")
            continue

    # Flush whatever is left after the last batch
    if actions:
        _bulk_index(actions, "final bulk indexing")

    print("Indexing completed")


# Embeds and indexes every row of `metadata`; runs only when explicitly enabled.
if IS_INDEXING_DATA:
    index_data_images()

In [15]:
def create_id_mapping():
    """Build a lookup from ``article_id`` to the expected path of its image.

    Paths have the form ``{IMAGES_PATH}{first 3 chars of id}/{id}.jpg``. The
    files are not checked for existence. ``metadata`` is walked in chunks of
    200 rows so the progress bar advances once per chunk.
    """
    mapping = {}
    chunk_len = 200

    for start in tqdm(range(0, len(metadata), chunk_len)):
        chunk_ids = metadata.iloc[start : start + chunk_len]["article_id"]
        for article_id in chunk_ids:
            # article_id is already a string, so it is used directly as the key
            mapping[article_id] = (
                f"{IMAGES_PATH}{str(article_id)[:3]}/{article_id}.jpg"
            )

    return mapping


id_to_image = create_id_mapping()

100%|██████████| 528/528 [00:02<00:00, 240.29it/s]


## Hybrid search (BM25 + image/text vectors)

In [51]:
# Hybrid Search Function
def search_elasticsearch(query, weights=(0.1, 0.3, 0.6), top_k=10):
    """Hybrid search: BM25 on the product name plus image/text vector similarity.

    Args:
        query: Free-text search string. It is matched against
            ``productDisplayName`` and also embedded with ``model`` to score
            both vector fields.
        weights: Sequence whose first three items are the boosts for
            ``(BM25, image_vector similarity, text_vector similarity)``.
        top_k: Maximum number of hits to return.

    Returns:
        A list of ``(article_id, productDisplayName, score)`` tuples in the
        order Elasticsearch returned the hits.

    Scoring: each clause is a ``should`` clause, so the final score is the sum
    of the weighted clause scores. BM25 scores are divided by
    ``max(top BM25 score, 1)``; cosine similarities are rescaled with
    ``(cos + 1) / 2``.
    """
    bm25_weight, image_weight, text_weight = weights[0], weights[1], weights[2]
    dense_vector = model.encode([query])[0].tolist()

    # Step 1: fetch the best BM25 score so the BM25 clause can be scaled
    # relative to it. With a zero BM25 weight the clause contributes 0 whatever
    # the divisor is, so the extra round trip is skipped.
    max_bm25_score = 1  # Fallback when nothing matches (avoids division by zero)
    if bm25_weight:
        top_hits = es.search(
            index=index_name,
            body={
                "size": 1,  # Get only the top-scoring document
                "query": {"match": {"productDisplayName": query}},
            },
        )["hits"]["hits"]
        if top_hits:
            max_bm25_score = top_hits[0]["_score"]

    # Step 2: hybrid query with normalized BM25 scores
    bm25_query = {
        "script_score": {
            "query": {"match": {"productDisplayName": query}},
            "script": {
                "source": "( _score / params.max_score )",  # Normalize BM25
                "params": {
                    "max_score": max(max_bm25_score, 1)
                },  # Never divide by less than 1
            },
        }
    }

    def _cosine_clause(field):
        """Score every document by its rescaled cosine similarity on ``field``."""
        return {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": f"(cosineSimilarity(params.query_vector, '{field}') + 1.0) / 2.0",
                    "params": {"query_vector": dense_vector},
                },
            }
        }

    weighted_clauses = (
        (bm25_query, bm25_weight),
        (_cosine_clause("image_vector"), image_weight),
        (_cosine_clause("text_vector"), text_weight),
    )
    hybrid_query = {
        "size": top_k,
        "query": {
            "bool": {
                "should": [
                    {"function_score": {"query": clause, "boost": weight}}
                    for clause, weight in weighted_clauses
                ]
            }
        },
    }

    response = es.search(index=index_name, body=hybrid_query)
    return [
        (
            hit["_source"]["metadata"]["article_id"],
            hit["_source"]["productDisplayName"],
            hit["_score"],
        )
        for hit in response["hits"]["hits"]
    ]

In [56]:
from IPython.core.display import HTML


def print_results(results):
    """Print one ``article_id: product name: score`` line per result tuple."""
    for result in results:
        article_id, product, score = result
        print(f"{article_id}: {product}: {score}")


# function to display product images

# NOTE(review): these repeat the DATASET_PATH / IMAGES_PATH values already assigned in the
# dataset-loading cell; keep the two in sync or drop this copy.
DATASET_PATH = "../datasets/kaggle_hm/articles.csv"
IMAGES_PATH = "../datasets/kaggle_hm/images/"


def display_result(image_batch):
    """Render image files side by side as an HTML gallery.

    Args:
        image_batch: Iterable of image file paths (``str`` or ``os.PathLike``).
            Items of any other type, and paths that do not exist, are reported
            with ``print`` and skipped.

    Returns:
        An ``HTML`` display object: a flex-wrapped ``<div>`` with one ``<img>``
        per existing file, or a "No images to display" placeholder when nothing
        could be shown.
    """
    img_tags = []
    for img in image_batch:
        # Only real paths are accepted. Integers are rejected on purpose:
        # os.path.exists() would interpret them as open file descriptors.
        if not isinstance(img, (str, os.PathLike)):
            print("Only file paths supported for now")
            continue

        path = os.fsdecode(img)
        if not os.path.exists(path):
            print(f"Image file not found: {path}")
            continue

        # Escape the path so characters such as quotes or '&' cannot break the markup
        img_tags.append(
            f'<img src="{html.escape(path, quote=True)}" style="margin:5px;height:200px">'
        )

    if not img_tags:
        return HTML(data="<div>No images to display</div>")

    return HTML(
        data=f"<div style='display:flex;flex-wrap:wrap'>{''.join(img_tags)}</div>"
    )

In [94]:
# Example query
# Example query: the BM25 weight is 0, so ranking uses image (0.8) and text (0.2) vectors only
query = "dark blue french connection jeans for men"
results = search_elasticsearch(query, weights=[0, 0.8, 0.2], top_k=20)

print_results(results)
# Show the product image of every hit (last expression => rendered as the cell output)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

0502315001: MY lined denim trouser: 0.70847577
0256151014: Superstretch Fancy denim: 0.7081977
0554126003: Xihibt Original Jeans: 0.706441
0664647011: Bruce Skinny Denim: 0.7063439
0843596002: Straight fit denim: 0.70609945
0653145008: &DENIM Skinny RW Chic: 0.70607054
0890381003: Leo denim trouser: 0.70588964
0863564004: TVP Slim waist denim: 0.7057018
0664647001: Bruce Skinny Denim Trs: 0.7056441
0644024001: KIM denim trouser: 0.70553154
0664813001: KIM denim trouser: 0.70545495
0746314003: TVP Slim denim: 0.7053627
0843596001: Straight fit denim: 0.705232
0664647007: Bruce Skinny Denim: 0.70511323
0808318003: TVP Slim denim: 0.7049047
0256151015: Superstretch Fancy denim: 0.7048614
0664647010: Bruce Skinny Denim: 0.70485455
0664122002: &DENIM Shaping bootcut: 0.70485175
0664072001: Cool Cropped denim straight HW: 0.70475763
0569498002: Alala Denim: 0.70459175


In [95]:
# Example query
query = "light blue french connection jeans for women"
results = search_elasticsearch(query, weights=[0, 0.8, 0.2], top_k=20)

print_results(results)
# display the images
display_result([id_to_image[id] for id, product, score in results])

0300024017: Super skinny denim: 0.7096269
0554126003: Xihibt Original Jeans: 0.70901024
0844463001: Presley Denim Trousers: 0.7085048
0640373001: DIV Farah flared denim: 0.7081666
0653145002: &DENIM Skinny RW Chic: 0.70773643
0708679002: LNY JADE HIGH DENIM: 0.70770156
0252298015: Danae jeans: 0.70765245
0549262007: Boyfriend LW denim: 0.707644
0754852002: PE ANNA WIDE DENIM: 0.70755684
0881244003: &DENIM+ Curvy jegging HW: 0.7075362
0300024056: Super skinny denim: 0.70727074
0617193001: Ace Skinny 5-Pocket Denim: 0.7069459
0469562055: Skinny denim (1): 0.7067404
0669700001: Capri denim: 0.70671844
0853735001: Cropped denims: 0.7067069
0614644001: Loretta denim slim HW: 0.7066469
0849490001: Ava HW wide denim trousers: 0.7066099
0689390003: Cashew denim: 0.70635134
0777040001: Mother NW wide denim: 0.706326
0469562002: Skinny denim (1): 0.7061925


In [96]:
# Example query
# Example query: the BM25 weight is 0, so ranking uses image (0.8) and text (0.2) vectors only
query = " yellow t-shirt with stripes"
results = search_elasticsearch(query, weights=[0, 0.8, 0.2], top_k=20)

print_results(results)
# Show the product image of every hit (last expression => rendered as the cell output)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

0691761001: RESORT CORN TEE: 0.71413386
0695322011: Lance striped shirt: 0.71163714
0679232002: UNI SUN ss tee stripe: 0.7115141
0614460012: CORN CLASSIC TEE: 0.71061134
0585806003: Simon Retro Stripe Tee: 0.7095337
0614460006: CORN CLASSIC TEE: 0.70771974
0792301009: BABY rib tee: 0.7062561
0643217002: CRISPY STRIPE SWEATER: 0.7059886
0727846003: Crash short sleeve: 0.70376796
0256096023: Basic SS henley t-shirt: 0.7034553
0717490011: Cat Tee.: 0.70325667
0878045001: Tobias Henley tee: 0.70275867
0599773015: Sally Stripe Tee: 0.7024925
0599773012: Sally Stripe Tee: 0.7024099
0906114001: Rio Striped tee: 0.70203245
0714340001: Basic LS Roll up t-shirt: 0.70096606
0708753001: EQ CAMDEN TEE: 0.7007703
0669386005: RC BORIS SHIRT.: 0.69960064
0624486050: Brit Baby Tee: 0.69917303
0783543020: R-Neck SS Slim ONLINE: 0.69874763


In [98]:
# Example query
# Example query: the BM25 weight is 0, so ranking uses image (0.7) and text (0.3) vectors only
query = "birkin hermes"
results = search_elasticsearch(query, weights=[0, 0.7, 0.3], top_k=20)

print_results(results)
# Show the product image of every hit (last expression => rendered as the cell output)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

0647982002: Dag shoulder tote: 0.70094013
0475762003: Marais tote bag: 0.7007505
0639448009: Case tote: 0.70061195
0731344001: Amico Shopper: 0.69995254
0739587001: Amber shopper: 0.69969594
0510618001: Christie tote bag: 0.6992969
0510618004: Christie tote bag: 0.6977566
0723691001: Bag Harald Satchel: 0.6974496
0858078001: Dominique shopper: 0.69713706
0817053006: HANNA SMALL BAG: 0.6967028
0690814001: Lander perforated shopper: 0.6966033
0556669002: Meja tote bag: 0.69644773
0562777001: Katherine tote: 0.6963093
0639448003: Case tote: 0.6960014
0647462002: Fedex tote: 0.6959723
0451704003: Gary suede micro tote: 0.69572407
0639448001: Day tote: 0.69513965
0806931002: Seoul tote: 0.6951164
0690734002: Ofelia bag: 0.6950636
0603367001: Manhattan tote: 0.6950156


## Image search (query by example image)

In [99]:
def image_search(image, top_k=10):
    """Return the ``top_k`` indexed products most similar to ``image``.

    ``image`` is embedded with ``model`` and every document is scored by the
    cosine similarity between that embedding and its ``image_vector``, rescaled
    with ``(cos + 1) / 2``.

    Returns:
        A list of ``(article_id, productDisplayName, score)`` tuples in the
        order Elasticsearch returned the hits.
    """
    query_embedding = model.encode(image).tolist()

    similarity_script = {
        "source": "(cosineSimilarity(params.query_vector, 'image_vector') + 1.0) / 2.0",
        "params": {"query_vector": query_embedding},
    }
    search_body = {
        "size": top_k,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": similarity_script,
            }
        },
    }

    hits = es.search(index=index_name, body=search_body)["hits"]["hits"]

    results = []
    for hit in hits:
        source = hit["_source"]
        results.append(
            (
                source["metadata"]["article_id"],
                source["productDisplayName"],
                hit["_score"],
            )
        )
    return results


def display_image(image):
    """Show a single image by handing a one-item batch to ``display_result``."""
    single_item_batch = [image]
    return display_result(single_item_batch)


# Use the product at row 300 of `metadata` as the query image.
article_id = metadata.iloc[300].article_id
image_path = f"{IMAGES_PATH}{article_id[:3]}/{article_id}.jpg"
# `image` (the opened PIL image) is the search input for the next cell;
# the path is what gets displayed here.
image = Image.open(image_path)
display_image(image_path)

In [100]:
# Nearest neighbours of the query image; their product images are rendered as the cell output
results = image_search(image, top_k=10)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

## Combined image and text search

In [101]:
def image_and_query_search(image, query, top_k=10, color_filter=None):
    """Search by example image plus text, optionally restricted to one colour.

    Candidates must match ``query`` on ``productDisplayName`` (at least 30% of
    the query terms) and, when ``color_filter`` is given, on
    ``metadata.perceived_colour_master_name``. The text-match score is then
    replaced (``boost_mode="replace"``) by the sum of two terms:

    * ``0.2 * text similarity``  if the text similarity is >= 0.87, else 0
    * ``0.8 * image similarity`` if the image similarity is >= 0.88, else 0

    where each similarity is the cosine similarity rescaled with
    ``(cos + 1) / 2``. Hits scoring below 0.5 are dropped; because the text
    term is at most 0.2, a hit can only survive if its image similarity clears
    the 0.88 threshold.

    Args:
        image: Query image, embedded with ``model``.
        query: Query text, used both for the name match and as a text embedding.
        top_k: Maximum number of hits to return.
        color_filter: Optional colour name; ``None`` disables the colour filter.

    Returns:
        A list of ``(article_id, productDisplayName, score)`` tuples in the
        order Elasticsearch returned the hits.
    """
    image_vector = model.encode(image).tolist()
    text_vector = model.encode(query).tolist()

    color_clauses = []
    if color_filter is not None:
        # create_index() defines no explicit mapping for the metadata fields, so
        # (assuming the index was built by this notebook) this is an analysed
        # text field. A `match` query analyses the filter value the same way,
        # which makes the filter case-insensitive and lets multi-word colours
        # match; a `term` query would compare the raw value against the tokens.
        color_clauses.append(
            {
                "match": {
                    "metadata.perceived_colour_master_name": {
                        "query": color_filter,
                        "operator": "and",
                    }
                }
            }
        )

    query_filter = {
        "size": top_k,
        "min_score": 0.5,
        "query": {
            "function_score": {
                "query": {
                    "bool": {
                        "must": [
                            {
                                "match": {
                                    "productDisplayName": {
                                        "query": query,
                                        "minimum_should_match": "30%",
                                    }
                                }
                            }
                        ],
                        "filter": color_clauses,
                    }
                },
                "functions": [
                    {
                        # Text similarity: contributes only when similarity >= 0.87
                        "script_score": {
                            "script": {
                                "source": """
                                    double similarity = (cosineSimilarity(params.query_vector, 'text_vector') + 1.0) / 2.0;
                                    return similarity >= 0.87 ? similarity * 0.2 : 0;
                                """,
                                "params": {"query_vector": text_vector},
                            }
                        }
                    },
                    {
                        # Image similarity: contributes only when similarity >= 0.88
                        "script_score": {
                            "script": {
                                "source": """
                                    double similarity = (cosineSimilarity(params.query_vector, 'image_vector') + 1.0) / 2.0;
                                    return similarity >= 0.88 ? similarity * 0.8 : 0;
                                """,
                                "params": {"query_vector": image_vector},
                            }
                        }
                    },
                ],
                "score_mode": "sum",  # Final score = text term + image term
                "boost_mode": "replace",  # Discard the name-match score itself
            }
        },
    }

    response = es.search(index=index_name, body=query_filter)
    return [
        (
            hit["_source"]["metadata"]["article_id"],
            hit["_source"]["productDisplayName"],
            hit["_score"],
        )
        for hit in response["hits"]["hits"]
    ]

In [102]:
# Use the product at row 300 of `metadata` as the query image.
article_id = metadata.iloc[300].article_id
image_path = f"{IMAGES_PATH}{article_id[:3]}/{article_id}.jpg"
# `image` (the opened PIL image) is the search input for the next cell;
# the path is what gets displayed here.
image = Image.open(image_path)
display_image(image_path)

In [128]:
# Image + text query restricted to purple products
results = image_and_query_search(image, "soft top wear", top_k=20, color_filter="purple")
print_results(results)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

0855944002: Fiffi top: 0.8937304
0892070001: PQ CUBA CASH SL TOP: 0.7374304
0877222001: Lemon top: 0.7348361
0555353008: Nancy rib top: 0.7295621
0568065004: BAMBAM loose top (1): 0.7294848
0485973034: Nilsson Top: 0.7289509
0906304002: Cadillac l/s top: 0.72777236
0621336002: PE Embla top: 0.7272623
0660327001: Artur shell top: 0.7271081
0516656017: TLC Top: 0.72229075
0621363002: PE Bella bikini top 2: 0.7218301
0665399002: J BORGO TOP EQ: 0.71979237
0634009002: CERISE top: 0.7158413
0887659005: Carla top: 0.7138442
0616697002: CINNAMON strap top: 0.7128319
0832362005: Paulina top: 0.7121215
0574605002: Bolt top: 0.7117668
0897732001: DIV+ Mando top: 0.71130496
0748574002: Khloe Seamless padded soft bra: 0.70932585
0907696004: Jinny smock top: 0.7091646


In [106]:
# Use the product at row 20000 of `metadata` as the query image.
article_id = metadata.iloc[20000].article_id
image_path = f"{IMAGES_PATH}{article_id[:3]}/{article_id}.jpg"
# `image` (the opened PIL image) is the search input for the next cell;
# the path is what gets displayed here.
image = Image.open(image_path)
display_image(image_path)

In [118]:
# Image + text query restricted to pink products
results = image_and_query_search(image, "pink color bottomwear", top_k=20, color_filter="pink")
print_results(results)
display_result([id_to_image[article_id] for article_id, _product, _score in results])

0615289001: SKINNY SUNBEAM PINK: 0.9438619
0576910002: SKINNY SUNBEAM HEART PINK: 0.93690586
0576910001: Skinny Sunbeam Pink 79: 0.933331
0548738001: SKINNY SUNBEAM PINK: 0.9283832
0767799016: JUST PINK DRESS(1): 0.9021936
0771235004: LOGG Just pink dress: 0.8986049
0767799017: JUST PINK DRESS(1): 0.89476854
0740370003: PINK beanie: 0.7100506
