In [1]:
import os
import sys
from typing import Optional

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

sys.path.insert(0, '..')

from src.id_mapper import IDMapper

load_dotenv()

True

# Controller

In [2]:
class Args(BaseModel):
    """Notebook-wide configuration.

    Construct the model and then call :meth:`init` so that ``qdrant_url`` is
    derived from the environment before any cell talks to Qdrant.
    """

    testing: bool = False
    # No default location is configured, hence Optional.
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41

    # Column names of the source data; only `item_col` is read by the cells
    # visible in this notebook.
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    # Populated by init() as "<QDRANT_HOST>:<QDRANT_PORT>".
    qdrant_url: Optional[str] = None
    # NOTE: despite the "desc" in the name, this notebook stores *title*
    # embeddings in this collection.
    qdrant_collection_name: str = "item_desc_sbert"

    # Number of nearest neighbours requested from the vector index.
    top_k: int = 10

    def init(self):
        """Resolve ``qdrant_url`` from QDRANT_HOST / QDRANT_PORT.

        Returns:
            The same instance, so the call can be chained: ``Args().init()``.

        Raises:
            RuntimeError: if QDRANT_HOST is unset or empty.
        """
        qdrant_host = os.getenv("QDRANT_HOST")
        if not qdrant_host:
            raise RuntimeError("Environment variable QDRANT_HOST is not set.")

        # Fall back to Qdrant's default REST port instead of formatting the
        # literal string "None" into the URL when QDRANT_PORT is missing.
        qdrant_port = os.getenv("QDRANT_PORT") or "6333"
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        return self

args = Args().init()

# print() keeps the JSON pretty-printed; a bare expression would show the
# escaped string repr instead.
print(args.model_dump_json(indent=2))

{
  "testing": false,
  "notebook_persist_dp": null,
  "random_seed": 41,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item_desc_sbert",
  "top_k": 10
}


# Load model

In [4]:
# Named prompt prefixes registered with the model. The prompt selected by
# `default_prompt_name` is applied to every encode() call unless that call
# passes `prompt` / `prompt_name` explicitly.
prompt_templates = {
    "classification": "Classify the following entertainment product description into different characteristics: ",
    "retrieval": "Retrieve semantically similar text: ",
    "clustering": "Identify the topic or theme based on the text: ",
}
model = SentenceTransformer(
    "all-mpnet-base-v2",
    prompts=prompt_templates,
    default_prompt_name="clustering",
)

# 1. A handful of game franchises used as a smoke test for the embeddings
sentences = [
    "Super Mario",
    "Pokemon",
    "Final Fantasy",
    "Diablo 3",
    "World of Warcraft",
    "Dota",
    "League of Legends",
]

# 2. Encode the sentences; one embedding row per sentence
embeddings = model.encode(sentences)
print(embeddings.shape)

# 3. Pairwise similarity of every sentence against every other sentence
similarities = model.similarity(embeddings, embeddings)
print(similarities)

Default prompt name is set to 'clustering'. This prompt will be applied to all `encode()` calls, except if `encode()` is called with `prompt` or `prompt_name` parameters.


(7, 768)
tensor([[1.0000, 0.7523, 0.7229, 0.6226, 0.6043, 0.6279, 0.6148],
        [0.7523, 1.0000, 0.6922, 0.6775, 0.6057, 0.6701, 0.6665],
        [0.7229, 0.6922, 1.0000, 0.7260, 0.6649, 0.6406, 0.6467],
        [0.6226, 0.6775, 0.7260, 1.0000, 0.7900, 0.7606, 0.7804],
        [0.6043, 0.6057, 0.6649, 0.7900, 1.0000, 0.7972, 0.8204],
        [0.6279, 0.6701, 0.6406, 0.7606, 0.7972, 1.0000, 0.9076],
        [0.6148, 0.6665, 0.6467, 0.7804, 0.8204, 0.9076, 1.0000]])


# Load data

In [5]:
# Load the persisted ID mapper. Later cells read `idm.item_to_index` and call
# `idm.get_item_index(...)` to translate item IDs into integer indices.
# The path is relative to the notebook's working directory.
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

In [6]:
# Fetch the raw Video Games product metadata.
# NOTE(review): trust_remote_code=True allows code shipped with the dataset
# repository to run locally; keep it only while the source is trusted.
metadata_raw = load_dataset(
    "McAuley-Lab/Amazon-Reviews-2023",
    "raw_meta_Video_Games",
    trust_remote_code=True,
)

# Restrict the metadata to items the ID mapper knows about and attach each
# item's integer index as `item_indice`.
known_item_ids = list(idm.item_to_index.keys())
metadata_df = (
    metadata_raw["full"]
    .to_pandas()
    .loc[lambda df: df[args.item_col].isin(known_item_ids)]
    .assign(item_indice=lambda df: df[args.item_col].map(idm.get_item_index))
)

# Each item must appear exactly once: `item_indice` is later used as the
# point ID when the embeddings are pushed to Qdrant.
assert not metadata_df[args.item_col].duplicated().any()
metadata_df

Unnamed: 0,main_category,title,average_rating,rating_number,features,description,price,images,videos,store,categories,details,parent_asin,bought_together,subtitle,author,item_indice
2,Video Games,NBA 2K17 - Early Tip Off Edition - PlayStation 4,4.3,223,[The #1 rated NBA video game simulation series...,[Following the record-breaking launch of NBA 2...,58.0,{'hi_res': ['https://m.media-amazon.com/images...,{'title': ['NBA 2K17 - Kobe: Haters vs Players...,2K,"[Video Games, PlayStation 4, Games]","{""Release date"": ""September 16, 2016"", ""Best S...",B00Z9TLVK0,,,,3057
15,Video Games,"Warhammer 40,000 Dawn of War Game of the Year ...",4.0,68,[Real-time strategy game based on the popular ...,"[From the Manufacturer, This Game of The Year ...",29.95,"{'hi_res': [None, 'https://m.media-amazon.com/...","{'title': [], 'url': [], 'user_id': []}",THQ,"[Video Games, PC, Games]","{""Release date"": ""September 20, 2005"", ""Best S...",B001EYUX4Y,,,,1020
46,Video Games,Polk Audio Striker Zx Xbox One Gaming Headset ...,3.9,169,[Powered by 40 years of audio heritage and tun...,[Our ProFit Comfort system creates a lightweig...,,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",Polk Audio,"[Video Games, Xbox One, Accessories, Headsets]","{""Release date"": ""October 5, 2014"", ""Best Sell...",B00OLOQGAY,,,,2876
63,Video Games,The Legend of Heroes: Trails in the Sky - Sony...,4.4,91,[After a brief hiatus since its last release i...,"[Product Description, In the peaceful town of ...",185.0,{'hi_res': ['https://m.media-amazon.com/images...,{'title': ['The Legend of Heroes: Trails in th...,Xseed Games,"[Video Games, Legacy Systems, PlayStation Syst...","{""Release date"": ""March 29, 2011"", ""Best Selle...",B004BV5O0U,,,,1663
65,Video Games,Harry Potter: Goblet of Fire - Sony PSP,3.6,38,[All the Magic of the Movie - Characters model...,"[From the Manufacturer, Be Harry Potter in a n...",19.43,"{'hi_res': [None, 'https://m.media-amazon.com/...","{'title': [], 'url': [], 'user_id': []}",Electronic Arts,"[Purchase Circles, Geography, United States, M...","{""Release date"": ""September 8, 2006"", ""Best Se...",B001ELJEA6,,,,785
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
118437,Video Games,Resident Evil 4,4.6,1133,[],[In Resident Evil 4 you'll know a new type of ...,44.98,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",Capcom,"[Video Games, Legacy Systems, Nintendo Systems...","{""Release date"": ""June 19, 2007"", ""Best Seller...",B000P46NKC,,,,551
118457,All Electronics,ivoler Carrying Storage Case for Nintendo Swit...,4.8,13702,[Deluxe Travel Carrying Case. Specially design...,[],26.39,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",ivoler,"[Video Games, Legacy Systems, Nintendo Systems...","{""Product Dimensions"": ""11.2 x 4 x 9 inches"", ...",B076GYVWRY,,,,3799
118459,Video Games,EastVita New Charger Dock + 4 x Battery for Ni...,3.9,676,[Include: 1x Remote Controller Charger 4 x 280...,[Charger Dock + 4 x Battery for Wii Remote],,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",EastVita,"[Video Games, Legacy Systems, Nintendo Systems...","{""Pricing"": ""The strikethrough price is the Li...",B004Y2VAVS,,,,1795
118510,Video Games,Mario & Luigi: Partners In Time,4.6,607,"[Players use the top screen to study the land,...",[Mario and Luigi: Partners In Time an insane s...,99.99,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",Nintendo,"[Video Games, Legacy Systems, Nintendo Systems...","{""Release date"": ""November 28, 2005"", ""Best Se...",B000B8J7K0,,,,415


In [7]:
# Product titles as a NumPy array, in the same row order as metadata_df.
titles = metadata_df["title"].to_numpy()
titles[:5]

array(['NBA 2K17 - Early Tip Off Edition - PlayStation 4',
       'Warhammer 40,000 Dawn of War Game of the Year - PC',
       'Polk Audio Striker Zx Xbox One Gaming Headset - Black',
       'The Legend of Heroes: Trails in the Sky - Sony PSP',
       'Harry Potter: Goblet of Fire - Sony PSP'], dtype=object)

In [8]:
%%time
# Embed every product title in one call. No prompt is passed here, so the
# model's default prompt ("clustering") is applied to each title.
title_embeddings = model.encode(titles)
# Width of a single embedding vector; reused as the Qdrant vector size.
embedding_dim = title_embeddings.shape[1]

CPU times: user 14 s, sys: 2.39 s, total: 16.4 s
Wall time: 13 s


In [9]:
# Visual sanity check of the embedding matrix (NumPy elides the middle rows/columns).
title_embeddings

array([[ 0.05771705, -0.04336544, -0.02406197, ...,  0.00534384,
        -0.04260172, -0.00983627],
       [ 0.01018465,  0.03246344, -0.00325039, ..., -0.00215585,
        -0.05831467, -0.01222209],
       [ 0.03833957, -0.06638431, -0.00046459, ...,  0.00164709,
        -0.06160014,  0.00236448],
       ...,
       [ 0.07847869, -0.05095147,  0.0042989 , ..., -0.03248303,
        -0.05708611, -0.01929816],
       [ 0.05220518,  0.00836591, -0.02224629, ..., -0.01255159,
         0.01264354, -0.00346215],
       [ 0.03613825, -0.02936018, -0.02450047, ...,  0.006214  ,
        -0.06865519, -0.00285492]], dtype=float32)

In [10]:
# Free-text query for a brute-force similarity check against all titles.
query = 'harry potter'
# Encoded with the same default prompt as the titles, so both live in the
# same prompted embedding space.
query_embedding = model.encode(query)
# NOTE: rebinds `similarities`, which the model smoke-test cell also assigned.
similarities = model.similarity(query_embedding, title_embeddings)

In [11]:
# Inspect the raw scores: a single row with one similarity per title.
similarities

tensor([[0.5050, 0.5411, 0.4574,  ..., 0.4853, 0.5794, 0.5575]])

In [12]:
# Convert the similarity row to a flat NumPy array. The tensor-only calls are
# guarded because this cell rebinds `similarities`: on a second run the name
# already holds an ndarray, which has no `.detach()`.
if isinstance(similarities, torch.Tensor):
    similarities = similarities.detach().cpu().numpy()
similarities = np.asarray(similarities).squeeze()

# Select the top-k titles without fully sorting all scores. k follows the
# shared `args.top_k` setting (the same limit used for the Qdrant search) and
# is capped at the number of available scores.
top_k = min(args.top_k, similarities.shape[-1])
nn_indices = np.argpartition(similarities, -top_k)[-top_k:]
# argpartition leaves the selected block in unspecified order; sort it so the
# result is deterministic (ascending by score, best match last).
nn_indices = nn_indices[np.argsort(similarities[nn_indices])]

In [13]:
# Tabulate the nearest titles next to their scores. The `target` column is
# sized from `nn_indices` rather than a literal, so the frame stays valid if
# the number of neighbours changes.
pd.DataFrame({
    'target': [query] * len(nn_indices),
    'score': similarities[nn_indices],
    'titles': titles[nn_indices]
})

Unnamed: 0,target,score,titles
0,harry potter,0.786267,LEGO Harry Potter: Years 5-7 - Playstation 3
1,harry potter,0.792318,Lego Harry Potter: Years 5-7 - PlayStation Vita
2,harry potter,0.852925,Harry Potter and the Deathly Hallows Part 1 - ...
3,harry potter,0.85498,Harry Potter: Prisoner of Azkaban - Xbox
4,harry potter,0.872698,Harry Potter: Goblet of Fire - Sony PSP
5,harry potter,0.877665,Harry Potter and the Sorcerer's Stone - PC
6,harry potter,0.878635,Harry Potter and the Order of the Phoenix - Ni...
7,harry potter,0.894921,Harry Potter and the Half Blood Prince - Xbox 360
8,harry potter,0.897954,Harry Potter and the Chamber of Secrets
9,harry potter,0.903514,"Harry Potter and the Deathly Hallows, Part 2"


In [14]:
# Full, untruncated titles of the selected neighbours (the table above elides long ones).
titles[nn_indices]

array(['LEGO Harry Potter: Years 5-7 - Playstation 3',
       'Lego Harry Potter: Years 5-7 - PlayStation Vita',
       'Harry Potter and the Deathly Hallows Part 1 - Playstation 3',
       'Harry Potter: Prisoner of Azkaban - Xbox',
       'Harry Potter: Goblet of Fire - Sony PSP',
       "Harry Potter and the Sorcerer's Stone - PC",
       'Harry Potter and the Order of the Phoenix - Nintendo Wii',
       'Harry Potter and the Half Blood Prince - Xbox 360',
       'Harry Potter and the Chamber of Secrets',
       'Harry Potter and the Deathly Hallows, Part 2'], dtype=object)

# Push to Qdrant

In [15]:
# Client for the Qdrant instance at the "<host>:<port>" URL built by Args.init().
ann_index = QdrantClient(url=args.qdrant_url)

In [16]:
target_collection = args.qdrant_collection_name

# Rebuild from scratch: drop any collection left over from an earlier run so
# no previously indexed points survive.
collection_exists = ann_index.collection_exists(target_collection)
if collection_exists:
    logger.info(f"Deleting existing Qdrant collection {target_collection}...")
    ann_index.delete_collection(target_collection)

# Cosine distance, with the vector size taken from the encoded titles.
vectors_config = VectorParams(size=embedding_dim, distance=Distance.COSINE)
create_collection_result = ann_index.create_collection(
    collection_name=target_collection,
    vectors_config=vectors_config,
)

assert create_collection_result

[32m2024-10-26 23:10:24.081[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mDeleting existing Qdrant collection item_desc_sbert...[0m


In [18]:
# Guard against hidden notebook state: the embeddings are matched to metadata
# rows purely by position, so a stale `title_embeddings` (e.g. the metadata
# cell was re-run without re-encoding) would silently mislabel vectors.
assert len(title_embeddings) == len(metadata_df), (
    "title_embeddings is out of sync with metadata_df; re-run the encoding cell"
)

# One Qdrant point per item: the item's integer index is the point ID, the
# title embedding is the vector, and the payload carries the original item ID
# and title for display. Columns are iterated directly instead of iterrows(),
# which builds a Series per row.
points = [
    PointStruct(
        id=int(item_indice),
        vector=embedding.tolist(),
        payload={
            args.item_col: item_id,
            'title': title,
        },
    )
    for item_indice, item_id, title, embedding in zip(
        metadata_df['item_indice'],
        metadata_df[args.item_col],
        metadata_df['title'],
        title_embeddings,
    )
]

In [19]:
batch_size = 32

# Upload the points in fixed-size slices; each slice starts at a multiple of
# batch_size and the final one may be shorter.
batch_starts = range(0, len(points), batch_size)
for start in tqdm(batch_starts):
    upsert_result = ann_index.upsert(
        collection_name=args.qdrant_collection_name,
        points=points[start:start + batch_size],
    )
    # Stop at the first batch that Qdrant does not report as completed.
    assert str(upsert_result.status) == "completed"
upsert_result

  0%|          | 0/145 [00:00<?, ?it/s]

UpdateResult(operation_id=144, status=<UpdateStatus.COMPLETED: 'completed'>)

In [20]:
# Round-trip check: query the index with the first title's own embedding, so
# that item itself should come back as the best hit.
query_vector = title_embeddings[0]

# NOTE(review): newer qdrant-client releases appear to deprecate `search` in
# favour of `query_points`; confirm against the installed client version
# before upgrading.
hits = ann_index.search(
    collection_name=args.qdrant_collection_name,
    query_vector=query_vector,
    limit=args.top_k,
)

In [21]:
# Scored points returned by Qdrant, each with its ID, score and payload.
hits

[ScoredPoint(id=3057, version=0, score=0.9999999, payload={'parent_asin': 'B00Z9TLVK0', 'title': 'NBA 2K17 - Early Tip Off Edition - PlayStation 4'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3490, version=1, score=0.94643676, payload={'parent_asin': 'B01MG6DORB', 'title': 'NBA 2K17 Standard Edition - PlayStation 4'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=968, version=40, score=0.9246316, payload={'parent_asin': 'B001EYUTL6', 'title': 'NBA 2K7 - PlayStation 2'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3019, version=8, score=0.92061067, payload={'parent_asin': 'B00XZQ58AI', 'title': 'NBA 2K16 - PlayStation 3'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=967, version=87, score=0.90808296, payload={'parent_asin': 'B001EYUTKW', 'title': 'NBA 2K6 - PlayStation 2'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=969, version=27, score=0.90576637, payload={'parent_asin': 'B001EYUTMA',