In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import sys
import time

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from tqdm.auto import tqdm

import mlflow

load_dotenv()

sys.path.insert(0, "..")

In [3]:
class Args(BaseModel):
    testing: bool = False
    run_name: str = "000-first-attempt"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    top_K: int = 100
    top_k: int = 10

    mlf_model_name: str = "sequence_two_tower"

    batch_recs_fp: str = None

    qdrant_url: str = None
    qdrant_collection_name: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"
        self.qdrant_collection_name = os.getenv("QDRANT_COLLECTION_NAME") + "_item"

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "mlf_model_name": "sequence_two_tower",
  "batch_recs_fp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "138.2.61.6:6333",
  "qdrant_collection_name": "item2vec_item"
}


In [4]:
ann_index = QdrantClient(url=args.qdrant_url)
if not ann_index.collection_exists(args.qdrant_collection_name):
    raise Exception(
        f"Required Qdrant collection {args.qdrant_collection_name} does not exist"
    )

  ann_index = QdrantClient(url=args.qdrant_url)


In [5]:
def get_vector_by_id(id_: int):
    record = ann_index.retrieve(
        collection_name=args.qdrant_collection_name, ids=[id_], with_vectors=True
    )[0]
    return record.vector

In [6]:
len(ann_index.retrieve(
        collection_name=args.qdrant_collection_name, ids=[1,1], with_vectors=True
    ))

1

In [7]:
vector = get_vector_by_id(0)
neighbors = ann_index.search(
    collection_name=args.qdrant_collection_name, query_vector=vector, limit=5
)

  neighbors = ann_index.search(


In [8]:
neighbors

[ScoredPoint(id=0, version=0, score=1.0000001, payload={'average_rating': 4.6, 'parent_asin': '0972683275', 'categories': '["Electronics", "Television & Video", "Accessories", "TV Mounts, Stands & Turntables", "TV Wall & Ceiling Mounts"]', 'main_category': 'All Electronics', 'description': '["The videosecu TV mount is a mounting solution for most 22\\"-47\\" LCD LED Plasma TV and some LED up to 55\\" with VESA 600x400mm (24\\"x16\\"), 400x400mm (16\\"x16\\"),600x300mm(24\\"x12\\"), 400x200mm (16\\"x8\\"),300x300mm (12\\"x12\\"),300x200mm(12\\"x8\\"),200x200mm (8\\"x8\\"),200x100mm (8\\"x4\\") mounting hole pattern .Heavy gauge steel construction provides safety loading up to 66lbs display .It can tilt 15 degree forward or backward and swivel 180 degree. The removable VESA plate can be taken off for easy installation. Post-installation level adjustment allows the TV to perfectly level. The on arm cable management ring system design, guides wires and prevent cable pinching. Standard hard

In [9]:
mlf_client = mlflow.MlflowClient()

In [10]:
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]



In [11]:
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [12]:
sample_input = mlflow.artifacts.load_dict(f"{artifact_uri}/inferrer/input_example.json")
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_ids': ['AE22236AFRRSMQIKGG7TPTB75QEA'],
 'item_sequences': [['0972683275', '1449410243']],
 'item_ids': ['0972683275']}

In [13]:
prediction = model.predict(sample_input)
prediction

  item_sequences = torch.tensor(item_sequences, device=self.device)


{'user_ids': ['AE22236AFRRSMQIKGG7TPTB75QEA'],
 'item_sequences': [['0972683275', '1449410243']],
 'item_ids': ['0972683275'],
 'scores': [0.5657538175582886]}

# Batch recs

In [14]:
idm = model.unwrap_python_model().idm

all_items = list(idm.item_to_index.values())
all_items[0:5]

[0, 1, 2, 3, 4]

In [15]:
len(all_items)

4817

In [16]:
# papermill_description=batch-precompute
recs = []
records = ann_index.retrieve(
    collection_name=args.qdrant_collection_name, ids=all_items, with_vectors=True
)
vectors = [record.vector for record in records]
model_pred_times = []

for indice, query_embedding in tqdm(zip(all_items, vectors), total=len(all_items), desc="Batch recs"):
    neighbor_records = ann_index.search(
        collection_name=args.qdrant_collection_name,
        query_vector=query_embedding,
        limit=args.top_K + 1,
    )
    neighbors = [neighbor.id for neighbor in neighbor_records]
    scores = [neighbor.score for neighbor in neighbor_records]
    
    
    # Remove self-recommendation
    neighbours_indexes = [i for i, neighbor in enumerate(neighbors) if neighbor != indice]
    neighbors = [neighbors[i] for i in neighbours_indexes]
    scores = [scores[i] for i in neighbours_indexes]
    # # Recalculate prediction scores for all neighbors
    # t0 = time.time()
    # scores = (
    #     skipgram_model(torch.tensor([indice] * len(neighbors)), torch.tensor(neighbors))
    #     .detach()
    #     .numpy()
    #     .astype(float)
    # )
    # t1 = time.time()
    # model_pred_times.append(t1 - t0)
    # # Rerank scores based on model output predictions
    # neighbors, scores = zip(
    #     *sorted(zip(neighbors, scores), key=lambda x: x[1], reverse=True)
    # )
    neighbor_ids = [idm.index_to_item[idx] for idx in neighbors]
    id_ = idm.index_to_item[indice]
    recs.append(
        {"target_item": id_, "rec_item_ids": neighbor_ids, "rec_scores": list(scores)}
    )

Batch recs:   0%|          | 0/4817 [00:00<?, ?it/s]

  neighbor_records = ann_index.search(


In [20]:
recs[3]

{'target_item': 'B00000K2YR',
 'rec_item_ids': ['B0BXP3P132',
  'B06WP2ZT5N',
  'B09P4Q7JK4',
  'B0C5829XY7',
  'B07S9THRC5',
  'B07DHLZ7Z2',
  'B07GPGVYGX',
  'B07DNZZCPX',
  'B00BCA40S0',
  'B09ZLL36JF',
  'B07QH5HD3R',
  'B00XBZY0EI',
  'B0BLM2HK53',
  'B07MN67BCR',
  'B06ZZY14LK',
  'B08Y8FSTMT',
  'B09WBKKRFN',
  'B07GTGHQHB',
  'B0B4NCQ3XF',
  'B07ZWJR9GD',
  'B0B7N8S4T5',
  'B077ZVBJH2',
  'B07GT37484',
  'B0762QT7S6',
  'B077SF8KMG',
  'B098RJMJTW',
  'B0C5MBN688',
  'B077ZT29P2',
  'B09VDNQH8B',
  'B07BTHNW9W',
  'B07RS8J6QP',
  'B07RJZPTLX',
  'B01IQEAEDY',
  'B0BTVN2YTV',
  'B074JKT894',
  'B0BMK6DC5W',
  'B09P4FVYK9',
  'B09PRD4T26',
  'B088RBT8RH',
  'B07PDHSLM6',
  'B07FQDMKFT',
  'B0C2TZSCPT',
  'B071WL63HB',
  'B07MNFH1PX',
  'B07P5JV6HT',
  'B0B35BGQ55',
  'B06XKPQ6YZ',
  'B07W371S8F',
  'B077XGL4PG',
  'B01NBTFNVA',
  'B0BMQN7L6B',
  'B07DC4PZC4',
  'B0BKVV246Q',
  'B07H4VQ4BZ',
  'B00XIXCDLA',
  'B08TJ3JC8Y',
  'B01JW0ASNW',
  'B00I8Y6V9E',
  'B07D4734HR',
  'B08D7JP

In [19]:
logger.info(f"Saving batch recs output to {args.batch_recs_fp}...")
with open(args.batch_recs_fp, "w") as f:
    for rec in recs:
        f.write(json.dumps(rec) + "\n")

[32m2025-06-28 22:05:34.565[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mSaving batch recs output to /home/dinhln/Desktop/real_time_recsys/notebooks/data/000-first-attempt/batch_recs.jsonl...[0m
