In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
from typing import Optional

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

import mlflow

# Put the parent directory (relative to the current working directory) first
# on the module search path so code living one level up can be imported.
sys.path.insert(0, "..")
# Load variables from a .env file into the process environment; the Args
# configuration below reads QDRANT_HOST / QDRANT_PORT from the environment.
load_dotenv()

True

In [3]:
class Args(BaseModel):
    """Configuration for this notebook run.

    Fields that default to ``None`` are either populated by :meth:`init`
    (``notebook_persist_dp``, ``batch_recs_fp``, ``qdrant_url``) or are left
    unset unless the caller provides them (``device``), so they are declared
    as optional strings.
    """

    testing: bool = False
    run_name: str = "000-first-attempt"
    notebook_persist_dp: Optional[str] = None  # set by init(): absolute path of data/<run_name>
    random_seed: int = 41
    device: Optional[str] = None

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128

    mlf_model_name: str = "two-tower"

    batch_recs_fp: Optional[str] = None  # set by init(): <notebook_persist_dp>/batch_recs.jsonl

    qdrant_url: Optional[str] = None  # set by init(): "<QDRANT_HOST>:<QDRANT_PORT>"
    qdrant_collection_name: str = "two_tower"

    def init(self):
        """Resolve the derived settings and return ``self`` for chaining.

        Creates the run's persistence directory (``data/<run_name>`` under the
        current working directory) if it does not exist, derives the batch
        recommendations file path from it, and builds the Qdrant URL from the
        ``QDRANT_HOST`` and ``QDRANT_PORT`` environment variables.

        Returns:
            This ``Args`` instance, with the derived fields filled in.

        Raises:
            RuntimeError: If ``QDRANT_HOST`` is unset or empty.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise RuntimeError("Environment variable QDRANT_HOST is not set.")

        # An unset/empty QDRANT_PORT used to yield the literal URL "<host>:None".
        # Fall back to 6333, Qdrant's default HTTP port, instead.
        qdrant_port = os.getenv("QDRANT_PORT") or "6333"
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        return self


# Build the configuration, resolve its derived fields (init() returns the same
# instance it was called on), then echo the result as JSON.
args = Args()
args = args.init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "mlf_model_name": "two-tower",
  "batch_recs_fp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "138.2.61.6:6333",
  "qdrant_collection_name": "two_tower"
}


In [4]:
# MLflow client used below to look up run metadata for the loaded model.
# No tracking URI is passed here, so MLflow resolves it from its own
# configuration.
mlf_client = mlflow.MlflowClient()

In [5]:
# Load, as a generic pyfunc model, the registered model version that currently
# carries the "champion" alias under the name configured in args.mlf_model_name.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]



In [6]:
# Trace the loaded model back to the MLflow run that produced it so that the
# run's artifact location can be used to fetch files logged with the model.
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [7]:
# Fetch the input example stored under the run's "inferrer" artifact path and
# display it; it is used as a smoke-test payload for the model in the next cell.
sample_input = mlflow.artifacts.load_dict(f"{artifact_uri}/inferrer/input_example.json")
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_id': ['AGQ5ERLI2VUZVYLQV5WYJ5TLGVYA'], 'item_id': ['B0C2P7CNWG']}

In [8]:
# Smoke test: score the stored input example with the loaded model and display
# the returned prediction.
prediction = model.predict(sample_input)
prediction

{'user_id': ['AGQ5ERLI2VUZVYLQV5WYJ5TLGVYA'],
 'item_id': ['B0C2P7CNWG'],
 'scores': [0.466325581073761]}

In [9]:
# Reach through the pyfunc wrapper to the underlying model object
# (presumably the trained torch two-tower module — confirm against the wrapper).
two_tower_model = model.unwrap_python_model().model
# Look up the embedding for item index 0 and record its length.
item_embedding_0 = two_tower_model.item_embedding(torch.tensor(0))
item_embedding_dim = item_embedding_0.shape[0]
item_embedding_0

tensor([ 4.0572e-02,  4.2762e-03, -7.4350e-02, -8.3112e-02,  8.6084e-02,
         2.9183e-02,  4.6144e-02,  6.2827e-02, -6.3556e-02, -3.9409e-03,
         2.0402e-02, -5.7586e-02, -2.0329e-02, -8.8266e-02, -2.2466e-03,
         3.2479e-03, -4.8516e-02, -1.7956e-01, -2.6697e-02, -2.7128e-03,
         1.9480e-02,  2.3961e-02,  4.7660e-02, -5.0532e-02, -4.7238e-02,
        -8.5903e-02,  3.3863e-02,  4.5857e-02, -3.1144e-02, -2.1075e-02,
         2.1401e-02,  1.4887e-01, -7.2768e-02,  6.0694e-02,  1.1149e-01,
        -1.6692e-01,  1.7919e-01,  8.9292e-03, -8.9551e-02, -8.8214e-03,
        -7.3826e-02,  2.0690e-02,  6.7625e-02, -1.0031e-01,  7.8096e-02,
        -8.1078e-02, -9.6833e-02,  1.2702e-03, -3.4069e-02, -8.6842e-03,
        -3.5066e-02, -1.4946e-02, -2.8548e-03,  2.4422e-03,  1.6283e-01,
         1.0191e-01, -1.3933e-01,  5.3638e-02,  2.6835e-03,  4.7542e-02,
        -6.4982e-02, -8.3644e-02, -1.8135e-02, -3.3626e-02, -5.9312e-02,
        -2.8482e-02,  4.9219e-02,  1.1423e-01, -3.9

In [10]:
# Export the full item and user embedding tables as NumPy arrays.
# .detach() drops the autograd graph; .cpu() makes the export work when the
# weights live on a non-CPU device (Tensor.numpy() only supports CPU tensors)
# and returns the tensor unchanged when it is already on the CPU.
item_embedding = two_tower_model.item_embedding.weight.detach().cpu().numpy()

user_embedding = two_tower_model.user_embedding.weight.detach().cpu().numpy()

logger.info(f"item_embedding.shape: {item_embedding.shape}")
logger.info(f"user_embedding.shape: {user_embedding.shape}")

[32m2025-05-13 18:44:47.999[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mitem_embedding.shape: (4817, 128)[0m
[32m2025-05-13 18:44:48.000[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1muser_embedding.shape: (16407, 128)[0m


In [11]:
# Client for the Qdrant instance at the URL resolved in args; gRPC is preferred
# as the transport.
ann_index = QdrantClient(url=args.qdrant_url, prefer_grpc=True)

In [12]:
# Recreate one Qdrant collection per embedding table ("<base>_item" and
# "<base>_user"). An existing collection is dropped first, so re-running this
# cell always starts from an empty collection.
embedding_type = ["item", "user"]

# The loop variable is deliberately not called `type`: in a notebook it would
# stay bound after the loop and shadow the builtin for every later cell.
for embedding_name in embedding_type:
    collection_name = f"{args.qdrant_collection_name}_{embedding_name}"
    embedding = item_embedding if embedding_name == "item" else user_embedding
    collection_exists = ann_index.collection_exists(collection_name)
    if collection_exists:
        logger.info(f"Deleting existing Qdrant collection {collection_name}...")
        ann_index.delete_collection(collection_name)

    # Vector size is taken from the embedding table's second dimension;
    # similarity is measured with cosine distance.
    create_collection_result = ann_index.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=embedding.shape[1], distance=Distance.COSINE),
    )

    assert create_collection_result, f"Failed to create Qdrant collection {collection_name}"

[32m2025-05-13 18:44:48.584[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mDeleting existing Qdrant collection two_tower_item...[0m
[32m2025-05-13 18:44:49.005[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mDeleting existing Qdrant collection two_tower_user...[0m


In [13]:
# Upload each embedding table to its collection. Every row becomes one point
# whose id is the row index in the table and whose payload is empty.
for embeddings, name in zip([item_embedding, user_embedding], ["item", "user"]):
    collection_name = f"{args.qdrant_collection_name}_{name}"
    upsert_result = ann_index.upsert(
        collection_name=collection_name,
        points=[
            PointStruct(id=idx, vector=vector.tolist(), payload={})
            for idx, vector in enumerate(embeddings)
        ],
    )
    # Fail loudly, naming the collection, if the write did not complete.
    assert str(upsert_result.status) == "completed", (
        f"Upsert into {collection_name} ended with status {upsert_result.status}"
    )

In [14]:
# Display the result of the last upsert issued by the loop above (the "user"
# collection, which is processed second).
upsert_result

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)