# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import sys
import time
from typing import Optional

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from tqdm.auto import tqdm

import mlflow

# Load environment variables via python-dotenv; Args.init() below reads the
# QDRANT_* variables from the process environment.
load_dotenv()

# Put the parent directory first on the import path.
# NOTE(review): no project-local module is imported in the visible cells — confirm
# this is still needed.
sys.path.insert(0, "..")

# Controller

In [3]:
class Args(BaseModel):
    """Run configuration for this notebook.

    Fields that default to ``None`` are either filled in by :meth:`init` (paths and
    Qdrant settings) or left for the caller to set.
    """

    testing: bool = False
    # Names the per-run output directory (data/<run_name>).
    run_name: str = "000-first-attempt"
    # Absolute path of the per-run output directory; set by init().
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    device: Optional[str] = None

    # top_K: number of ANN candidates kept per item before re-scoring.
    top_K: int = 100
    top_k: int = 10

    # Registered MLflow model name; loaded through the "champion" alias.
    mlf_model_name: str = "item2vec"

    # Output JSONL path for the batch recommendations; set by init().
    batch_recs_fp: Optional[str] = None

    # "<host>:<port>" of the Qdrant server and the collection to query; set by init().
    qdrant_url: Optional[str] = None
    qdrant_collection_name: Optional[str] = None

    def init(self):
        """Resolve derived paths and Qdrant settings, then return ``self``.

        Creates ``data/<run_name>`` under the current working directory if needed and
        reads ``QDRANT_HOST``, ``QDRANT_PORT`` and ``QDRANT_COLLECTION_NAME`` from the
        environment.

        Returns:
            This instance, so the call can be chained as ``Args().init()``.

        Raises:
            Exception: If ``QDRANT_HOST`` or ``QDRANT_COLLECTION_NAME`` is unset or empty.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception("Environment variable QDRANT_HOST is not set.")

        # Fall back to Qdrant's default HTTP port rather than building "<host>:None"
        # when QDRANT_PORT is missing.
        qdrant_port = os.getenv("QDRANT_PORT") or "6333"
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        # Fail here with a clear message instead of later with a confusing
        # "collection None does not exist".
        if not (qdrant_collection_name := os.getenv("QDRANT_COLLECTION_NAME")):
            raise Exception("Environment variable QDRANT_COLLECTION_NAME is not set.")
        self.qdrant_collection_name = qdrant_collection_name

        return self


# Build the run configuration; init() creates the output directory and reads the
# QDRANT_* environment variables.
args = Args().init()

# Echo the resolved configuration so it is recorded in the notebook output.
print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/mnt/d/projects/recsys/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "mlf_model_name": "item2vec",
  "batch_recs_fp": "/mnt/d/projects/recsys/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item2vec"
}


# Load ANN Index

In [4]:
# Connect to the Qdrant server configured in `args` and fail fast if the collection
# queried by the following cells is missing.
ann_index = QdrantClient(url=args.qdrant_url)
if not ann_index.collection_exists(args.qdrant_collection_name):
    raise Exception(
        f"Required Qdrant collection {args.qdrant_collection_name} does not exist"
    )

In [5]:
def get_vector_by_id(id_: int):
    """Return the stored vector of one point in the configured Qdrant collection.

    Args:
        id_: Point id to look up.

    Returns:
        The ``vector`` attribute of the retrieved record.

    Raises:
        IndexError: If the collection returns no record for ``id_``.
    """
    records = ann_index.retrieve(
        collection_name=args.qdrant_collection_name, ids=[id_], with_vectors=True
    )
    if not records:
        # Same exception type a bare ``[0]`` on the empty result would raise, but
        # with a message that names the missing id and the collection.
        raise IndexError(
            f"No point with id {id_!r} in collection {args.qdrant_collection_name!r}"
        )
    return records[0].vector

In [6]:
# Smoke test of the index: fetch the stored vector of point 0 and search the
# collection with it, keeping the 5 best-scoring hits.
vector = get_vector_by_id(0)
neighbors = ann_index.search(
    collection_name=args.qdrant_collection_name, query_vector=vector, limit=5
)

In [7]:
neighbors

[ScoredPoint(id=0, version=0, score=0.9999999, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=4306, version=0, score=0.39393944, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3263, version=0, score=0.3846603, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3335, version=0, score=0.3804885, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2586, version=0, score=0.37781808, payload={}, vector=None, shard_key=None, order_value=None)]

# Load model

In [8]:
# MLflow client handle. NOTE(review): no visible cell references it — confirm it is
# still needed.
mlf_client = mlflow.MlflowClient()

In [9]:
# Load the registered model through its "champion" alias as a pyfunc model.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]



In [10]:
# Fetch the input example stored alongside the champion model; it is used as a
# smoke-test payload for model.predict().
sample_input = mlflow.artifacts.load_dict(
    f"models:/{args.mlf_model_name}@champion/input_example.json"
)
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'item_1_ids': ['0375869026'], 'item_2_ids': ['9625990674']}

In [11]:
# Sanity check: run the example item pair through the pyfunc wrapper.
prediction = model.predict(sample_input)
prediction

{'item_1_ids': ['0375869026'],
 'item_2_ids': ['9625990674'],
 'scores': [0.4312632381916046]}

# Batch recommend for all items

In [13]:
# Reach inside the pyfunc wrapper for the underlying scoring model and the id
# mapping, so the batch loop can call the model on index tensors directly.
skipgram_model = model.unwrap_python_model().model
id_mapping = model.unwrap_python_model().id_mapping
# The values of id_to_idx are the item indices; the batch cell passes them to Qdrant
# as point ids and maps them back to item ids through idx_to_id.
all_items = list(id_mapping["id_to_idx"].values())
all_items[:5]

[0, 1, 2, 3, 4]

In [14]:
# papermill_description=batch-precompute
# For every item: fetch its stored vector, take its nearest neighbours from the ANN
# index, re-score each (item, neighbour) pair with the model, and keep the neighbours
# ordered by model score.
recs = []
records = ann_index.retrieve(
    collection_name=args.qdrant_collection_name, ids=all_items, with_vectors=True
)
# Key the vectors by point id instead of pairing `records` with `all_items` by
# position: a positional zip silently assigns embeddings to the wrong item if the
# server returns points in a different order or omits ids it does not have.
vector_by_idx = {record.id: record.vector for record in records}
items_with_vectors = [idx for idx in all_items if idx in vector_by_idx]
if len(items_with_vectors) < len(all_items):
    logger.warning(
        f"{len(all_items) - len(items_with_vectors)} items have no vector in "
        f"collection {args.qdrant_collection_name}; skipping them"
    )
vectors = [vector_by_idx[idx] for idx in items_with_vectors]
model_pred_times = []

# `total` lets tqdm render real progress; a bare zip has no length.
for indice, query_embedding in tqdm(
    zip(items_with_vectors, vectors), total=len(items_with_vectors)
):
    neighbor_records = ann_index.search(
        collection_name=args.qdrant_collection_name,
        query_vector=query_embedding,
        limit=args.top_K + 1,
    )
    # One extra hit is requested because the item is normally its own best match.
    # Remove self-recommendation, then cap at top_K in case the item itself was not
    # among the hits.
    neighbors = [
        neighbor.id for neighbor in neighbor_records if neighbor.id != indice
    ][: args.top_K]
    id_ = id_mapping["idx_to_id"][str(indice)]
    if not neighbors:
        # Nothing to re-rank; zip(*sorted(...)) below cannot unpack an empty list.
        recs.append({"target_item": id_, "rec_item_ids": [], "rec_scores": []})
        continue
    # Recalculate prediction scores for all neighbors
    t0 = time.perf_counter()
    with torch.no_grad():  # inference only: skip building the autograd graph
        scores = (
            skipgram_model(
                torch.tensor([indice] * len(neighbors)), torch.tensor(neighbors)
            )
            .detach()
            .reshape(-1)  # one score per neighbour, even for a single neighbour
            .tolist()  # plain Python floats, directly JSON-serialisable
        )
    t1 = time.perf_counter()
    model_pred_times.append(t1 - t0)
    # Rerank scores based on model output predictions
    neighbors, scores = zip(
        *sorted(zip(neighbors, scores), key=lambda x: x[1], reverse=True)
    )
    # idx_to_id is keyed by the string form of the index.
    neighbor_ids = [id_mapping["idx_to_id"][str(idx)] for idx in neighbors]
    recs.append(
        {"target_item": id_, "rec_item_ids": neighbor_ids, "rec_scores": list(scores)}
    )

0it [00:00, ?it/s]

In [15]:
# Mean duration of the timed model calls. Falls back to NaN when nothing was timed,
# so an empty run is reported instead of raising ZeroDivisionError.
avg_model_inference_seconds = (
    sum(model_pred_times) / len(model_pred_times) if model_pred_times else float("nan")
)
logger.info(
    f"Average model inference time: {avg_model_inference_seconds * 1000} milliseconds"
)

[32m2025-11-13 01:50:57.080[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mAverage model inference time: 0.4826636345288666 milliseconds[0m


In [16]:
recs[0]

{'target_item': '0375869026',
 'rec_item_ids': ['B0053BCML6',
  'B00SZ1DQFM',
  'B00Z9TKLFQ',
  'B00NETKFJU',
  'B01G309132',
  'B00ZHQ39F0',
  'B07NQDVZP7',
  'B00X8Y18U6',
  'B011JGTTJQ',
  'B00CXTX2YW',
  'B019TYEXC8',
  'B00DOWCQ0I',
  'B0719KC7HN',
  'B07WTBZNTN',
  'B00L4VYOIS',
  'B087SLTR2B',
  'B00JS9XVNM',
  'B00S7O6R4Y',
  'B07CRC2X77',
  'B00KU1GVGG',
  'B00EO6GSY8',
  'B002BSC5HA',
  'B0199OXR0W',
  'B07DKT3WJ2',
  'B008A27UMG',
  'B01K1OO5PU',
  'B00KM66UFQ',
  'B00U1WN17G',
  'B00HWMP0OU',
  'B010R2RHGU',
  'B00PU12IDG',
  'B007NUQICE',
  'B00E4QOE30',
  'B0845KDQFZ',
  'B000OFSBL6',
  'B06Y2LGTW3',
  'B005VBVRGY',
  'B081FKL957',
  'B07BMRGKX2',
  'B07YBXCXYX',
  'B00ZM5ON88',
  'B07RBZMPZ2',
  'B0171RL3P0',
  'B07GHWHFR5',
  'B00UBL0ZQC',
  'B015TL6PGM',
  'B00F6YISHM',
  'B00VMB5VFK',
  'B073ZR63P8',
  'B08MBG5254',
  'B07V3G6C1F',
  'B00KWEH61U',
  'B00GY4OB8S',
  'B09WL5QLJF',
  'B012E58DFC',
  'B074FF955H',
  'B08S5YCZ6Q',
  'B00FA1CKKW',
  'B01FSO3VGC',
  'B07HC4Z

# Persist batch recommendations

In [18]:
# Dump the recommendations as JSON Lines: one serialized record per line.
logger.info(f"Saving batch recs output to {args.batch_recs_fp}...")
with open(args.batch_recs_fp, "w") as out_file:
    out_file.writelines(json.dumps(rec) + "\n" for rec in recs)

[32m2025-11-13 01:57:55.783[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mSaving batch recs output to /mnt/d/projects/recsys/notebooks/data/000-first-attempt/batch_recs.jsonl...[0m
