In [1]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
from typing import Optional

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

from datasets import load_dataset

import mlflow
import numpy as np

# Make modules in the parent directory importable from this notebook.
sys.path.insert(0, "..")
# Load environment variables from a .env file (the QDRANT_* settings read by
# Args.init); the returned bool is what this cell displays.
load_dotenv()

True

In [3]:
class Args(BaseModel):
    """Run configuration for indexing two-tower item embeddings into Qdrant.

    Call ``init()`` after construction to populate the Qdrant connection
    fields from environment variables.
    """

    testing: bool = False
    run_name: str = "000-first-attempt"
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    device: Optional[str] = None

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128

    # Registered MLflow model name; the notebook loads its "@champion" alias.
    mlf_model_name: str = "sequence_two_tower_retrieval"

    batch_recs_fp: Optional[str] = None

    # Filled in by init() from QDRANT_HOST / QDRANT_PORT / QDRANT_COLLECTION_NAME.
    qdrant_url: Optional[str] = None
    qdrant_collection_name: Optional[str] = None

    def init(self):
        """Populate the Qdrant connection settings from the environment.

        Reads ``QDRANT_HOST`` (required), ``QDRANT_PORT`` (optional) and
        ``QDRANT_COLLECTION_NAME`` (optional, left ``None`` when unset).

        Returns:
            self, so construction and initialisation can be chained.

        Raises:
            Exception: if ``QDRANT_HOST`` is unset or empty.
        """
        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception("Environment variable QDRANT_HOST is not set.")

        # Interpolating an unset port used to produce "<host>:None", which is
        # not a usable address. Fall back to the bare host instead
        # (presumably the Qdrant client's default port then applies — confirm).
        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}" if qdrant_port else qdrant_host
        self.qdrant_collection_name = os.getenv("QDRANT_COLLECTION_NAME")

        return self


# Build the run configuration and resolve the Qdrant settings from the environment.
args = Args().init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": null,
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "mlf_model_name": "sequence_two_tower_retrieval",
  "batch_recs_fp": null,
  "qdrant_url": "138.2.61.6:6333",
  "qdrant_collection_name": "item2vec"
}


In [4]:
# MLflow tracking client, used below to look up the run behind the loaded model.
mlf_client = mlflow.MlflowClient()

In [5]:
# Load the registered model currently aliased "champion", configured for CPU inference.
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_model_name}@champion",
    model_config=dict(device="cpu"),
)

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]



In [6]:
# Resolve the MLflow run that produced the loaded model, and that run's artifact root.
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [7]:
# Fetch the input example stored under the run's "inferrer" artifact directory.
sample_input = mlflow.artifacts.load_dict(f"{artifact_uri}/inferrer/input_example.json")
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'item_sequences': [['0972683275', '1449410243']], 'item_ids': ['0972683275']}

In [8]:
# Smoke-test the loaded model on its logged input example.
# NOTE(review): the warning shown below this cell is raised by the
# torch.tensor(item_sequences, ...) call inside the model's predict code,
# not by this cell.
prediction = model.predict(sample_input)
prediction

  item_sequences = torch.tensor(item_sequences, device=self.device)


{'item_sequences': [['0972683275', '1449410243']],
 'item_ids': ['0972683275'],
 'scores': [0.4714652895927429]}

In [9]:
# Confirm which device the underlying torch model's parameters live on.
next(model.unwrap_python_model().model.parameters()).device

device(type='cpu')

## Get item embeddings

In [10]:
# The torch module wrapped by the pyfunc model; displayed to inspect its layers.
two_tower_model = model.unwrap_python_model().model
two_tower_model

SequenceRatingPrediction(
  (item_embedding): Embedding(4818, 128, padding_idx=4817)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
    (linear2): Linear(in_features=128, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.3, inplace=False)
    (dropout2): Dropout(p=0.3, inplace=False)
    (activation): PReLU(num_parameters=1)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_feat

In [11]:
two_tower_model = model.unwrap_python_model().model
# Read the width from the nn.Embedding layer itself instead of running a lookup
# on torch.tensor(0). No autograd graph is built, and this works regardless of
# which device the model is on.
item_embedding_dim = two_tower_model.item_embedding.embedding_dim
item_embedding_dim

128

In [12]:
# Copy the full item-embedding table (one row per item index) into a NumPy array.
# .cpu() keeps this working if the model is ever loaded onto a GPU.
item_embedding = two_tower_model.item_embedding.weight.detach().cpu().numpy()

logger.info(f"item_embedding.shape: {item_embedding.shape}")

[32m2025-07-02 01:32:55.360[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mitem_embedding.shape: (4818, 128)[0m


In [13]:
# Connect to the Qdrant server at the configured URL, preferring the gRPC transport.
ann_index = QdrantClient(url=args.qdrant_url, prefer_grpc=True)

  ann_index = QdrantClient(


In [15]:

# Drop any previous version of the collection, then recreate it empty, so every
# run indexes into a fresh collection.
collection_name = "two_tower_sequence_item_embedding"
embedding = item_embedding
collection_exists = ann_index.collection_exists(collection_name)
if collection_exists:
    logger.info(f"Deleting existing Qdrant collection {collection_name}...")
    ann_index.delete_collection(collection_name)

# One cosine-distance vector per item; its size is the embedding width (last axis).
create_collection_result = ann_index.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(
        size=embedding.shape[-1],
        distance=Distance.COSINE,
    ),
)
assert create_collection_result

[32m2025-07-02 01:36:54.574[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mDeleting existing Qdrant collection two_tower_sequence_item_embedding...[0m


In [54]:
# NOTE(review): dead code kept for reference only. `user_embedding` is not defined
# anywhere in this notebook, so uncommenting this cell as-is would raise NameError.
# for embeddings, name in zip([item_embedding, user_embedding], ["item", "user"]):
#     collection_name = f"{args.qdrant_collection_name}_{name}"
#     upsert_result = ann_index.upsert(
#         collection_name=collection_name,
#         points=[
#             PointStruct(id=idx, vector=vector.tolist(), payload={})
#             for idx, vector in enumerate(embeddings)
#         ],
#     )
#     assert str(upsert_result.status) == "completed"
#     upsert_result

In [16]:
# Item-metadata fields kept for the Qdrant point payloads (selected from the raw frame below).
cols = (
    "parent_asin title average_rating description main_category categories"
).split()

In [17]:
# Load the Electronics item metadata from the McAuley-Lab Amazon-Reviews-2023 dataset.
# NOTE(review): trust_remote_code=True runs the dataset repository's own loading
# code; keep it only if that code is trusted (ideally pinned to a revision).

metadata_raw = load_dataset(
    "McAuley-Lab/Amazon-Reviews-2023", "raw_meta_Electronics", trust_remote_code=True
)
metadata_raw_df = metadata_raw["full"].to_pandas()
metadata_raw_df.head(3)

Unnamed: 0,main_category,title,average_rating,rating_number,features,description,price,images,videos,store,categories,details,parent_asin,bought_together,subtitle,author
0,All Electronics,FS-1051 FATSHARK TELEPORTER V3 HEADSET,3.5,6,[],[Teleporter V3 The “Teleporter V3” kit sets a ...,,"{'hi_res': [None], 'large': ['https://m.media-...","{'title': [], 'url': [], 'user_id': []}",Fat Shark,"[Electronics, Television & Video, Video Glasses]","{""Date First Available"": ""August 2, 2014"", ""Ma...",B00MCW7G9M,,,
1,All Electronics,Ce-H22B12-S1 4Kx2K Hdmi 4Port,5.0,1,"[UPC: 662774021904, Weight: 0.600 lbs]",[HDMI In - HDMI Out],,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': [], 'url': [], 'user_id': []}",SIIG,"[Electronics, Television & Video, Accessories,...","{""Product Dimensions"": ""0.83 x 4.17 x 2.05 inc...",B00YT6XQSE,,,
2,Computers,Digi-Tatoo Decal Skin Compatible With MacBook ...,4.5,246,[WARNING: Please IDENTIFY MODEL NUMBER on the ...,[],19.99,{'hi_res': ['https://m.media-amazon.com/images...,"{'title': ['AL 2Sides Video', 'MacBook Protect...",Digi-Tatoo,"[Electronics, Computers & Accessories, Laptop ...","{""Brand"": ""Digi-Tatoo"", ""Color"": ""Fresh Marble...",B07SM135LS,,,


In [18]:
# Keep only the payload columns. .copy() makes this an independent frame, so the
# later column assignments don't write into a slice of the raw frame
# (pandas SettingWithCopyWarning).
metadata_raw_df = metadata_raw_df[cols].copy()
metadata_raw_df.head(3)

Unnamed: 0,parent_asin,title,average_rating,description,main_category,categories
0,B00MCW7G9M,FS-1051 FATSHARK TELEPORTER V3 HEADSET,3.5,[Teleporter V3 The “Teleporter V3” kit sets a ...,All Electronics,"[Electronics, Television & Video, Video Glasses]"
1,B00YT6XQSE,Ce-H22B12-S1 4Kx2K Hdmi 4Port,5.0,[HDMI In - HDMI Out],All Electronics,"[Electronics, Television & Video, Accessories,..."
2,B07SM135LS,Digi-Tatoo Decal Skin Compatible With MacBook ...,4.5,[],Computers,"[Electronics, Computers & Accessories, Laptop ..."


In [19]:
idm = model.unwrap_python_model().idm
# Map every row index of the embedding table (the padding row included) back to its item ID.
all_item_indices = list(range(item_embedding.shape[0]))
all_item_ids = list(map(idm.get_item_id, all_item_indices))

In [20]:
import json
import numpy as np

def safe_serialize(x):
    """Serialize one cell value to a JSON string for use in a Qdrant payload.

    List-valued cells arrive as NumPy arrays, which ``json`` cannot encode.
    They are converted to plain lists first. Every other value is passed to
    ``json.dumps`` as is.

    Strings are returned unchanged. In the list columns, a string can only
    come from a previous call, so passing it through makes re-running the
    serialization cell idempotent instead of double-encoding values into
    quoted JSON strings.
    """
    if isinstance(x, str):
        return x
    if isinstance(x, np.ndarray):
        x = x.tolist()
    return json.dumps(x)

list_columns = ["description", "categories"]
# Serialize each list-valued column element-wise with Series.map rather than
# DataFrame.applymap, which pandas 2.1 deprecated (presumably the source of the
# warning this line used to emit). assign() returns a new frame instead of
# assigning into a possibly-sliced one.
metadata_raw_df = metadata_raw_df.assign(
    **{col: metadata_raw_df[col].map(safe_serialize) for col in list_columns}
)

  metadata_raw_df[list_columns] = metadata_raw_df[list_columns].applymap(safe_serialize)


In [21]:
# Spot-check that the list-valued columns are now JSON strings.
metadata_raw_df

Unnamed: 0,parent_asin,title,average_rating,description,main_category,categories
0,B00MCW7G9M,FS-1051 FATSHARK TELEPORTER V3 HEADSET,3.5,"[""Teleporter V3 The \u201cTeleporter V3\u201d ...",All Electronics,"[""Electronics"", ""Television & Video"", ""Video G..."
1,B00YT6XQSE,Ce-H22B12-S1 4Kx2K Hdmi 4Port,5.0,"[""HDMI In - HDMI Out""]",All Electronics,"[""Electronics"", ""Television & Video"", ""Accesso..."
2,B07SM135LS,Digi-Tatoo Decal Skin Compatible With MacBook ...,4.5,[],Computers,"[""Electronics"", ""Computers & Accessories"", ""La..."
3,B089CNGZCW,NotoCity Compatible with Vivoactive 4 band 22m...,4.5,[],AMAZON FASHION,"[""Electronics"", ""Wearable Technology"", ""Clips,..."
4,B004E2Z88O,Motorola Droid X Essentials Combo Pack,3.8,"[""all Genuine High Quality Motorola Made Acces...",Cell Phones & Accessories,"[""Electronics"", ""Computers & Accessories"", ""Co..."
...,...,...,...,...,...,...
1610007,B003NUIU9M,"Wintec FileMate Pro USB Flash Drive, 3FMUSB32G...",5.0,"[""--New in retail packaging --Fast USB 2.0 dat...",Computers,"[""Electronics"", ""Computers & Accessories"", ""Da..."
1610008,B0BHVY33TL,Tsugar Noise Reduction Wireless Headphones Blu...,1.0,"[""Description: 100% brand new high quality 1.H...",,"[""Electronics"", ""Headphones, Earbuds & Accesso..."
1610009,B09SQGRFFH,"Hardshell Case for MacBook Pro (16-inch, 2021)...",4.6,[],,"[""Electronics"", ""Computers & Accessories"", ""La..."
1610010,B091JWCSG5,"FYY 12-13.3"" Laptop Sleeve Case Bag, PU Leathe...",4.0,[],Computers,"[""Electronics"", ""Computers & Accessories"", ""La..."


In [22]:
# Build {embedding row index -> metadata record} for items the model knows,
# to be attached as Qdrant point payloads.
payload = (
    metadata_raw_df
    .loc[lambda df: df["parent_asin"].isin(all_item_ids)]
    .assign(item_index=lambda df: df["parent_asin"].map(idm.get_item_index))
    .set_index("item_index")
    .to_dict(orient="index")
)

In [23]:
collection_name = "two_tower_sequence_item_embedding"
# Skip the embedding table's padding row by its declared index rather than
# assuming it is the last row (the old item_embedding[:-1]). When the layer
# has no padding_idx, every row is indexed.
padding_idx = two_tower_model.item_embedding.padding_idx
upsert_result = ann_index.upsert(
    collection_name=collection_name,
    points=[
        # Point IDs are embedding row indices, matching the keys of `payload`.
        PointStruct(id=idx, vector=vector.tolist(), payload=payload.get(idx, {}))
        for idx, vector in enumerate(item_embedding)
        if idx != padding_idx
    ],
)
assert str(upsert_result.status) == "completed", upsert_result
upsert_result

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)