# Batch pre-computed recommendations based on ANN search

# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import sys
import time
from typing import Optional

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from tqdm.auto import tqdm

import mlflow

# Load variables from a .env file into the process environment; the Qdrant
# settings used by this notebook are read via os.getenv.
load_dotenv()

# Put the parent directory first on the module search path.
sys.path.insert(0, "..")

# Controller

In [3]:
class Args(BaseModel):
    """Run configuration for this notebook.

    Fields that default to ``None`` and are assigned in :meth:`init`
    (``notebook_persist_dp``, ``batch_recs_fp``, ``qdrant_url``,
    ``qdrant_collection_name``) are only meaningful after ``init()`` has run.
    """

    testing: bool = False
    run_name: str = "000-first-attempt"
    # Absolute path of this run's output directory; set by init().
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    device: Optional[str] = None

    # Number of ANN neighbours kept per item (the batch loop queries
    # top_K + 1 hits and drops the query item itself).
    top_K: int = 100
    top_k: int = 10

    # Name of the registered MLflow model to load.
    mlf_model_name: str = "item2vec"

    # Path of the JSONL file the recommendations are written to; set by init().
    batch_recs_fp: Optional[str] = None

    # Qdrant connection settings; read from the environment by init().
    qdrant_url: Optional[str] = None
    qdrant_collection_name: Optional[str] = None

    def init(self):
        """Resolve output paths and Qdrant settings, then return ``self``.

        Creates ``data/<run_name>`` under the current working directory if it
        does not exist, and reads ``QDRANT_HOST``, ``QDRANT_PORT`` and
        ``QDRANT_COLLECTION_NAME`` from the environment.

        Returns:
            This instance, so the call can be chained: ``Args().init()``.

        Raises:
            Exception: If ``QDRANT_HOST`` or ``QDRANT_COLLECTION_NAME`` is
                unset or empty.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception("Environment variable QDRANT_HOST is not set.")

        # Fall back to Qdrant's default REST port instead of building the
        # invalid URL "<host>:None" when QDRANT_PORT is missing.
        qdrant_port = os.getenv("QDRANT_PORT") or "6333"
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        # Fail here rather than later with an opaque client error on a
        # collection name of None.
        if not (collection_name := os.getenv("QDRANT_COLLECTION_NAME")):
            raise Exception(
                "Environment variable QDRANT_COLLECTION_NAME is not set."
            )
        self.qdrant_collection_name = collection_name

        return self


# Build the resolved configuration and echo it as JSON so the effective
# settings of this run are recorded in the notebook output.
args = Args().init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "mlf_model_name": "item2vec",
  "batch_recs_fp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item2vec"
}


# Load ANN Index

In [4]:
# Connect to Qdrant and stop immediately if the collection this notebook
# depends on is missing.
ann_index = QdrantClient(url=args.qdrant_url)
if not ann_index.collection_exists(args.qdrant_collection_name):
    raise Exception(
        f"Required Qdrant collection {args.qdrant_collection_name} does not exist"
    )

In [5]:
def get_vector_by_id(id_: int):
    """Return the stored vector of a single point in the configured collection.

    Args:
        id_: Point id to look up in ``args.qdrant_collection_name``.

    Returns:
        The ``vector`` attribute of the retrieved record.

    Raises:
        IndexError: If Qdrant returns no record for ``id_``.
    """
    records = ann_index.retrieve(
        collection_name=args.qdrant_collection_name, ids=[id_], with_vectors=True
    )
    if not records:
        # Same exception type as the bare ``[0]`` lookup would raise, but with
        # a message that says which id and collection were involved.
        raise IndexError(
            f"No point with id {id_} found in Qdrant collection "
            f"{args.qdrant_collection_name}"
        )
    return records[0].vector

In [6]:
# Sanity check: fetch the vector of point 0 and ask the index for its 5
# closest points.
vector = get_vector_by_id(0)
neighbors = ann_index.search(
    collection_name=args.qdrant_collection_name, query_vector=vector, limit=5
)

In [7]:
neighbors

[ScoredPoint(id=0, version=0, score=1.0, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3668, version=0, score=0.57748246, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2115, version=0, score=0.56080717, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1005, version=0, score=0.5573507, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=691, version=0, score=0.54987663, payload={}, vector=None, shard_key=None, order_value=None)]

# Load model

In [8]:
# MLflow client used to look up the run behind the loaded model.
mlf_client = mlflow.MlflowClient()

In [9]:
# Load the registered model version carrying the "champion" alias.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [10]:
# Resolve the artifact location of the run that produced the loaded model.
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [11]:
# Load the input example stored with the run's "inferrer" artifacts and display it.
sample_input = mlflow.artifacts.load_dict(f"{artifact_uri}/inferrer/input_example.json")
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'item_1_ids': ['B00CMQTUSS'], 'item_2_ids': ['B001EYURK4']}

In [12]:
# Smoke-test the loaded model on the stored input example and display the result.
prediction = model.predict(sample_input)
prediction

{'item_1_ids': ['B00CMQTUSS'],
 'item_2_ids': ['B001EYURK4'],
 'scores': [0.42891860008239746]}

## Batch recommend for all items

In [13]:
# Unwrap the pyfunc wrapper once and pull out the pieces the batch loop needs:
# the underlying model and the id <-> index mapping.
python_model = model.unwrap_python_model()
skipgram_model = python_model.model
id_mapping = python_model.id_mapping

# Every item index known to the model; show the first few as a quick check.
all_items = list(id_mapping["id_to_idx"].values())
all_items[:5]

[0, 1, 2, 3, 4]

In [15]:
# papermill_description=batch-precompute
# For every item: fetch its ANN neighbours from Qdrant, re-score them with the
# model, and store the neighbours ordered by model score (best first).
recs = []
records = ann_index.retrieve(
    collection_name=args.qdrant_collection_name, ids=all_items, with_vectors=True
)
# Pair each vector with its item via the record's own id. Zipping `all_items`
# positionally against the response would silently attach vectors to the wrong
# items if records come back in a different order or if some ids are missing.
vector_by_idx = {record.id: record.vector for record in records}
missing_items = [idx for idx in all_items if idx not in vector_by_idx]
if missing_items:
    logger.warning(
        f"{len(missing_items)} item(s) have no vector in Qdrant collection "
        f"{args.qdrant_collection_name} and will be skipped"
    )
indexed_items = [idx for idx in all_items if idx in vector_by_idx]
vectors = [vector_by_idx[idx] for idx in indexed_items]
model_pred_times = []
# TODO: Optimize this loop when it runs inside Docker. On a macOS host it runs
# at about 350 it/s, but inside Docker it only reaches about 20 it/s.
# The skipgram_model call has been identified as taking about 10 times longer
# in Docker (macOS with Apple Silicon).
# The problem shows up on both macOS and Ubuntu host machines 🤔
for item_idx, query_embedding in tqdm(
    zip(indexed_items, vectors), total=len(indexed_items)
):
    id_ = id_mapping["idx_to_id"][str(item_idx)]
    neighbor_records = ann_index.search(
        collection_name=args.qdrant_collection_name,
        query_vector=query_embedding,
        # One extra hit, because the query item is normally returned as its
        # own nearest neighbour and is dropped below.
        limit=args.top_K + 1,
    )
    # Remove self-recommendation, then cap at top_K in case the query item was
    # not among the hits (which would otherwise leave top_K + 1 neighbours).
    neighbors = [
        neighbor.id for neighbor in neighbor_records if neighbor.id != item_idx
    ][: args.top_K]
    if not neighbors:
        # Nothing to score or rank; keep an empty entry so the item still
        # appears in the output.
        recs.append({"target_item": id_, "rec_item_ids": [], "rec_scores": []})
        continue
    # Recalculate prediction scores for all neighbors
    t0 = time.perf_counter()
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        scores = (
            skipgram_model(
                torch.tensor([item_idx] * len(neighbors)), torch.tensor(neighbors)
            )
            .detach()
            .numpy()
            .astype(float)
        )
    t1 = time.perf_counter()
    model_pred_times.append(t1 - t0)
    # Rerank neighbours by model score, highest first
    ranked = sorted(zip(neighbors, scores), key=lambda pair: pair[1], reverse=True)
    neighbor_ids = [id_mapping["idx_to_id"][str(idx)] for idx, _ in ranked]
    rec_scores = [float(score) for _, score in ranked]
    recs.append(
        {"target_item": id_, "rec_item_ids": neighbor_ids, "rec_scores": rec_scores}
    )

0it [00:00, ?it/s]

In [16]:
# Mean duration of one re-scoring call across all timed items, logged in
# milliseconds.
total_model_seconds = sum(model_pred_times)
avg_model_inference_seconds = total_model_seconds / len(model_pred_times)
logger.info(
    f"Average model inference time: {avg_model_inference_seconds * 1000} milliseconds"
)

0.00010992105784735464

In [17]:
recs[0]

{'target_item': 'B00CMQTUSS',
 'rec_item_ids': ['B00W9DHUBS',
  'B00BN5T30E',
  'B00KSQHX1K',
  'B00DTY9B0O',
  'B09V5R5LSZ',
  'B00BHRD4BM',
  'B00CMQTVK0',
  'B00C1TTF86',
  'B00BGA9X9W',
  'B00DJRLDMU',
  'B005GISQX4',
  'B00CJ9OTNE',
  'B00D5SZ04K',
  'B00HM1XPN4',
  'B00DJRLAZ0',
  'B00CMQTVUA',
  'B00GN67PJ4',
  'B00IAVDOS6',
  'B007CM0K86',
  'B00K5HTPR2',
  'B00BQVXUOA',
  'B00BZS9JV2',
  'B00DTWEOZ8',
  'B07WZS4CTC',
  'B07BLRF329',
  'B00DBDPOZ4',
  'B00J48C36S',
  'B00KIFM28A',
  'B07KXFB1P8',
  'B003NSLGW2',
  'B00EN9Q8G4',
  'B00CEGCN76',
  'B0088TN5FM',
  'B00IAVDPSA',
  'B00KSRV19E',
  'B00YJJ0OQS',
  'B00DBLBMBQ',
  'B07YBX4VQN',
  'B00FM5IY0Q',
  'B00TY9KYKE',
  'B00CMQTUY2',
  'B00WNO6YKG',
  'B00CISMP8M',
  'B00CES8EFY',
  'B00YM7AKLG',
  'B0088MVPFQ',
  'B00G2EVF3E',
  'B00IIHU44E',
  'B00EQNP8F4',
  'B00KR2C0RC',
  'B00LV416KC',
  'B07WPFV4ZW',
  'B00IAVDQCK',
  'B00DB2BI00',
  'B00KPZKS8E',
  'B00BAQXJMO',
  'B00DYJGDYQ',
  'B017QU5G1O',
  'B00HRH79H6',
  'B00BIXY

# Persist

In [18]:
# Write the recommendations as JSON Lines: one JSON object per target item.
logger.info(f"Saving batch recs output to {args.batch_recs_fp}...")
with open(args.batch_recs_fp, "w") as f:
    f.writelines(json.dumps(rec) + "\n" for rec in recs)

[32m2024-10-12 11:16:34.489[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mSaving batch recs output to /Users/dvq/frostmourne/recsys-mvp/notebooks/data/000-first-attempt/batch_recs.jsonl...[0m
