# Set up

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
from typing import Optional

sys.path.insert(0, "..")

import mlflow
import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

load_dotenv()

True

# Controller

In [3]:
class Args(BaseModel):
    """Run configuration for this notebook.

    Static settings are declared as fields with defaults. Fields that default
    to ``None`` are derived values (paths, Qdrant connection settings) that are
    only filled in once :meth:`init` has been called.
    """

    testing: bool = False
    run_name: str = "000-first-attempt"
    # Absolute path of the per-run output directory; set by init().
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    device: Optional[str] = None

    # Number of neighbours requested from the ANN index (used as the search limit).
    top_K: int = 100
    # NOTE(review): presumably the size of the final recommendation list — confirm with callers.
    top_k: int = 10

    embedding_dim: int = 128

    # Name of the registered MLflow model to load.
    mlf_model_name: str = "item2vec"

    # Path of the batch recommendations file inside the run directory; set by init().
    batch_recs_fp: Optional[str] = None

    # Qdrant connection settings; read from the environment by init().
    qdrant_url: Optional[str] = None
    qdrant_collection_name: Optional[str] = None

    def init(self):
        """Fill in the derived fields and return ``self`` so calls can be chained.

        Creates ``data/<run_name>`` (resolved against the current working
        directory) if it does not exist, and reads the Qdrant settings from the
        ``QDRANT_HOST``, ``QDRANT_PORT`` and ``QDRANT_COLLECTION_NAME``
        environment variables.

        Returns:
            Args: this instance, with the derived fields populated.

        Raises:
            ValueError: if ``QDRANT_HOST`` or ``QDRANT_COLLECTION_NAME`` is
                unset or empty.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise ValueError("Environment variable QDRANT_HOST is not set.")

        # Fall back to Qdrant's default REST port instead of producing the
        # unusable URL "<host>:None" when QDRANT_PORT is missing.
        qdrant_port = os.getenv("QDRANT_PORT") or "6333"
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        # Fail fast: later cells delete and recreate the collection by this name.
        if not (qdrant_collection_name := os.getenv("QDRANT_COLLECTION_NAME")):
            raise ValueError("Environment variable QDRANT_COLLECTION_NAME is not set.")
        self.qdrant_collection_name = qdrant_collection_name

        return self


# Build the configuration, then resolve its derived fields in place
# (init() mutates the instance and returns it).
args = Args()
args.init()

# Echo the resolved configuration so the run is self-describing.
print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/mnt/d/projects/recsys/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "mlf_model_name": "item2vec",
  "batch_recs_fp": "/mnt/d/projects/recsys/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item2vec"
}


# Load Model

In [4]:
# Tracking client, used later to look up the run behind the loaded model.
mlf_client = mlflow.MlflowClient()

In [5]:
# Load the registered model version addressed by the "champion" alias.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]



In [6]:
# Trace the loaded model back to the MLflow run recorded in its metadata and
# read that run's artifact location.
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [10]:
# Display the artifact location of the run that produced the loaded model.
artifact_uri

'mlflow-artifacts:/1/4b208545a85b433d85b0f61e55e85f16/artifacts'

In [13]:
# Fetch input_example.json from the champion model's artifacts to use as a
# ready-made payload for a prediction smoke test.
sample_input = mlflow.artifacts.load_dict(
    f"models:/{args.mlf_model_name}@champion/input_example.json"
)
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'item_1_ids': ['0375869026'], 'item_2_ids': ['9625990674']}

In [14]:
# Smoke test: run the loaded model on the example payload and show the result.
prediction = model.predict(sample_input)
prediction

{'item_1_ids': ['0375869026'],
 'item_2_ids': ['9625990674'],
 'scores': [0.4312632381916046]}

# Get Embedding

In [15]:
# Reach through the pyfunc wrapper to the model object it carries.
skipgram_model = model.unwrap_python_model().model

# Look up the vector stored at index 0, both to eyeball it and to read the
# embedding width off its first (only) dimension.
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.shape[0]
embedding_0

tensor([-0.0020, -0.0184,  0.0391,  0.0642,  0.0376,  0.0486,  0.1148,  0.0446,
        -0.2252,  0.0908, -0.0712,  0.0655,  0.1153,  0.0209,  0.2188, -0.1144,
        -0.3277,  0.0307, -0.3318, -0.0807,  0.1248,  0.0999,  0.0016, -0.1236,
         0.0700,  0.2429, -0.1043,  0.0345, -0.1657,  0.1284, -0.2426, -0.0023,
        -0.0329,  0.0266,  0.1605,  0.1208,  0.0235, -0.0150,  0.0780,  0.0047,
         0.2346,  0.0712,  0.2373,  0.0247, -0.0358,  0.0112, -0.0911,  0.0224,
        -0.0531, -0.1322,  0.1567, -0.1460, -0.2184, -0.0824, -0.1033, -0.3565,
        -0.1095, -0.0798, -0.0802,  0.0030, -0.0210,  0.1618, -0.0508, -0.0799,
         0.0604, -0.0424, -0.2881, -0.1340,  0.1210,  0.0142, -0.0659,  0.3046,
         0.0363, -0.1055, -0.2015, -0.0947,  0.0137,  0.0554, -0.0403, -0.3724,
         0.0957,  0.2741,  0.0576,  0.3932,  0.0308, -0.0980, -0.0256, -0.1675,
         0.0578, -0.0371, -0.2635,  0.1179, -0.2434,  0.1774,  0.1601,  0.1509,
        -0.1095,  0.0459, -0.2084,  0.08

In [18]:
# The wrapped python model carries an id mapping; the values of its
# "id_to_idx" entry are the indices used for embedding lookups.
# NOTE(review): presumably keys are raw item ids and values are contiguous
# 0..n-1 indices — confirm against the training code.
id_mapping = model.unwrap_python_model().id_mapping
all_items = list(id_mapping["id_to_idx"].values())
all_items[:5]

[0, 1, 2, 3, 4]

In [20]:
# Look up the vectors for every item index in one batched call, detach them
# from autograd and convert to a NumPy array (row i corresponds to all_items[i]).
# NOTE(review): .numpy() requires a CPU tensor — assumes the model was loaded on CPU.
embeddings = skipgram_model.embeddings(torch.tensor(all_items)).detach().numpy()
embeddings

array([[-0.001965  , -0.01844763,  0.0390514 , ...,  0.0789619 ,
        -0.04411918,  0.08246701],
       [ 0.01574272, -0.1106883 ,  0.04423822, ...,  0.03968391,
         0.00295376,  0.03199833],
       [-0.09473783,  0.11698911,  0.06254044, ..., -0.07346609,
         0.219918  , -0.26206523],
       ...,
       [-0.01189691, -0.16669825, -0.01239961, ..., -0.08236234,
        -0.00691404,  0.04842342],
       [-0.03206192, -0.08211985, -0.0188489 , ..., -0.04013699,
         0.00213255, -0.03188387],
       [ 0.03446096, -0.00518576, -0.04457873, ..., -0.01205611,
        -0.04385417, -0.04071712]], dtype=float32)

In [21]:
# Check the matrix dimensions: one row per entry of all_items.
embeddings.shape

(4630, 128)

# Qdrant

In [22]:
# Client for the Qdrant instance whose URL was resolved in args.
ann_index = QdrantClient(url=args.qdrant_url)

In [23]:
# Rebuild the collection from scratch so that re-running the notebook never
# leaves stale vectors from a previous model behind.
collection_exists = ann_index.collection_exists(args.qdrant_collection_name)
if collection_exists:
    logger.info(f"Deleting existing Qdrant collection {args.qdrant_collection_name}...")
    ann_index.delete_collection(args.qdrant_collection_name)

# Vector size comes from the embedding inspected earlier; similarity is cosine.
create_collection_result = ann_index.create_collection(
    collection_name=args.qdrant_collection_name,
    vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)

assert (
    create_collection_result
), f"Failed to create Qdrant collection {args.qdrant_collection_name}"

In [24]:
# Use the index each vector was looked up with (all_items) as its point ID,
# not its position in the array: embeddings[i] was produced from all_items[i],
# so the two only coincide when the mapping happens to be enumerated as
# 0..n-1 in order.
upsert_result = ann_index.upsert(
    collection_name=args.qdrant_collection_name,
    points=[
        PointStruct(id=int(item_idx), vector=vector.tolist(), payload={})
        for item_idx, vector in zip(all_items, embeddings)
    ],
)
assert (
    str(upsert_result.status) == "completed"
), f"Unexpected upsert status: {upsert_result.status}"
upsert_result

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)

In [25]:
# Nearest-neighbour lookup for the first embedding row, limited to top_K results.
# NOTE(review): newer qdrant-client releases deprecate `search` in favour of
# `query_points` — confirm against the installed client version.
hits = ann_index.search(
    collection_name=args.qdrant_collection_name,
    query_vector=embeddings[0],
    limit=args.top_K,
)

In [26]:
# Show only the first top_k hits; the full list can hold up to top_K entries
# and would swamp the notebook output.
hits[: args.top_k]

[ScoredPoint(id=0, version=0, score=0.9999999, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=4306, version=0, score=0.39393944, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3263, version=0, score=0.3846603, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3335, version=0, score=0.3804885, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2586, version=0, score=0.37781808, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2496, version=0, score=0.3719725, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=4226, version=0, score=0.36930364, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2520, version=0, score=0.3658317, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2453, version=0, score=0.36191103, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPo