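# Index embeddings into embedding store

Load the `champion` item2vec model from the MLflow model registry, extract its item embeddings, and index them into a Qdrant collection for nearest-neighbour search.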

# Set up

In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
from typing import Optional

import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

import mlflow

# Put the notebook's parent directory first on the import path.
sys.path.insert(0, "..")
# Load variables from a .env file into the process environment; the
# configuration cell below reads them with os.getenv.
load_dotenv()

True

# Controller

In [3]:
class Args(BaseModel):
    """Run-level settings for this notebook.

    Fields that default to ``None`` are placeholders: ``init()`` fills in the
    persist directory, the batch-recommendations path and the Qdrant
    connection settings (the latter from environment variables).
    """

    testing: bool = False
    run_name: str = "000-first-attempt"
    # Absolute path of the per-run working directory; set by init().
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    device: Optional[str] = None

    # Passed as `limit` to the ANN search in this notebook.
    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128

    # Registered MLflow model name; the notebook loads its `champion` alias.
    mlf_model_name: str = "item2vec"

    # Path of the batch recommendations JSONL file; set by init().
    batch_recs_fp: Optional[str] = None

    # "<host>:<port>" and collection name; both resolved from the environment by init().
    qdrant_url: Optional[str] = None
    qdrant_collection_name: Optional[str] = None

    def init(self):
        """Resolve derived paths and Qdrant settings, then return ``self``.

        Creates ``data/<run_name>`` (relative to the current working
        directory) if it does not exist yet.

        Returns:
            This instance, so the call can be chained: ``Args().init()``.

        Raises:
            Exception: If QDRANT_HOST, QDRANT_PORT or QDRANT_COLLECTION_NAME
                is unset or empty.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)
        self.batch_recs_fp = f"{self.notebook_persist_dp}/batch_recs.jsonl"

        # Fail fast on every required variable: a missing port would otherwise
        # produce a URL such as "localhost:None", and a missing collection name
        # would only surface later as an obscure client error.
        qdrant_env = {}
        for var_name in ("QDRANT_HOST", "QDRANT_PORT", "QDRANT_COLLECTION_NAME"):
            value = os.getenv(var_name)
            if not value:
                raise Exception(f"Environment variable {var_name} is not set.")
            qdrant_env[var_name] = value

        self.qdrant_url = f"{qdrant_env['QDRANT_HOST']}:{qdrant_env['QDRANT_PORT']}"
        self.qdrant_collection_name = qdrant_env["QDRANT_COLLECTION_NAME"]

        return self


# Build the configuration, resolve its derived fields in place, then echo the
# final values as JSON so they are recorded alongside the run.
args = Args()
args.init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "mlf_model_name": "item2vec",
  "batch_recs_fp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/000-first-attempt/batch_recs.jsonl",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item2vec"
}


# Load model

In [4]:
# MLflow client; used below to look up the run behind the loaded model.
mlf_client = mlflow.MlflowClient()

In [5]:
# Load the registered model version that currently carries the "champion"
# alias, as a generic pyfunc model.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{args.mlf_model_name}@champion")

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

 - cloudpickle (current: 3.1.0, required: cloudpickle==3.0.0)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.


In [6]:
# Resolve the MLflow run that produced the loaded model so its artifact
# location can be used to fetch files logged with that run.
run_id = model.metadata.run_id
run_info = mlf_client.get_run(run_id).info
artifact_uri = run_info.artifact_uri

In [7]:
# Fetch the input example stored under the run's "inferrer" artifact path and
# parse it into a dict.
sample_input = mlflow.artifacts.load_dict(f"{artifact_uri}/inferrer/input_example.json")
sample_input

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

{'item_1_ids': ['B0015ACX3Q'], 'item_2_ids': ['B009VUHWBA']}

In [8]:
# Smoke-test the loaded model on the logged input example before touching its
# internals.
prediction = model.predict(sample_input)
prediction

{'item_1_ids': ['B0015ACX3Q'],
 'item_2_ids': ['B009VUHWBA'],
 'scores': [0.5049286484718323]}

# Get embeddings

In [9]:
# Reach past the pyfunc wrapper to the underlying model and look up the vector
# for index 0; its length is the vector size used for the Qdrant collection.
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.shape[0]
embedding_0

tensor([ 0.1503,  0.1055, -0.2476, -0.0465, -0.0883,  0.0994,  0.2884, -0.2099,
         0.0748,  0.0424, -0.1090, -0.1652,  0.1983, -0.0623,  0.0731,  0.1268,
        -0.0486,  0.4281, -0.3518,  0.3000, -0.3354, -0.1979, -0.0372,  0.3658,
        -0.0209,  0.1510, -0.1517, -0.0391,  0.0733,  0.0255, -0.2660,  0.1875,
         0.1510,  0.1938,  0.1626, -0.3465,  0.1775, -0.0117,  0.2518,  0.0909,
        -0.4770, -0.0490, -0.2498, -0.3234, -0.0900,  0.2820, -0.0304, -0.1027,
        -0.2447,  0.3076, -0.1769, -0.0833, -0.0449,  0.0054,  0.0577, -0.2451,
         0.1902, -0.0901,  0.4717,  0.1035,  0.2562,  0.1041,  0.0230, -0.0305,
         0.2279, -0.4729,  0.0442, -0.0101, -0.0440, -0.4081,  0.2048, -0.1479,
         0.4262, -0.1546, -0.0713,  0.0224, -0.1713, -0.1386,  0.2682, -0.2696,
        -0.3055,  0.2965, -0.3019, -0.0385,  0.1030,  0.3998,  0.1618,  0.0049,
        -0.1015,  0.5003,  0.0309,  0.2191,  0.0314, -0.0273, -0.0231, -0.1765,
         0.1304, -0.0575, -0.1425, -0.04

In [10]:
# Collect every index value from the wrapper's id -> idx mapping; these are
# the indices fed to the embedding lookup in the next cell.
id_mapping = model.unwrap_python_model().id_mapping
all_items = [*id_mapping["id_to_idx"].values()]
all_items[:5]

[0, 1, 2, 3, 4]

In [11]:
# Look up all item vectors in one call, then detach them from autograd and
# convert to a NumPy array (row i corresponds to all_items[i]).
item_indices = torch.tensor(all_items)
embeddings = skipgram_model.embeddings(item_indices).detach().numpy()
embeddings

array([[ 0.15028489,  0.10554074, -0.24757317, ..., -0.43071437,
         0.19808286,  0.08967826],
       [-0.17599209, -0.09977298, -0.2695489 , ..., -0.18573076,
         0.09876981,  0.23391782],
       [-0.15676837, -0.12917697,  0.12649146, ...,  0.2507824 ,
        -0.32547793, -0.04743848],
       ...,
       [-0.09513666, -0.00402308,  0.00828476, ...,  0.05977607,
        -0.01325929,  0.48805568],
       [-0.11036596, -0.03687766, -0.18476743, ..., -0.23155648,
        -0.08153995,  0.06819829],
       [-0.09447317,  0.16883737,  0.26790723, ...,  0.2716018 ,
         0.02243392,  0.0830389 ]], dtype=float32)

In [12]:
# Display the shape of the embeddings array.
embeddings.shape

(4630, 128)

# Embedding store

In [14]:
# Connect to the Qdrant instance at the URL assembled in Args.init().
ann_index = QdrantClient(url=args.qdrant_url)

In [15]:
# Always start from an empty collection: drop any existing one with the same
# name, then create it again sized to the model's embedding dimension.
collection_exists = ann_index.collection_exists(args.qdrant_collection_name)
if collection_exists:
    logger.info(f"Deleting existing Qdrant collection {args.qdrant_collection_name}...")
    ann_index.delete_collection(args.qdrant_collection_name)

create_collection_result = ann_index.create_collection(
    collection_name=args.qdrant_collection_name,
    vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)

# Truthiness check instead of `== True`, with a message so a failure is
# diagnosable from the traceback alone.
assert create_collection_result, (
    f"Failed to create Qdrant collection {args.qdrant_collection_name}"
)

[32m2024-10-12 09:41:32.747[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mDeleting existing Qdrant collection item2vec...[0m


In [16]:
# Row i of `embeddings` was looked up with index all_items[i], so that index
# (not the row position) is used as the point id. The two only coincide when
# all_items is exactly 0..n-1 in order. int() guarantees a plain Python int.
upsert_result = ann_index.upsert(
    collection_name=args.qdrant_collection_name,
    points=[
        PointStruct(id=int(item_idx), vector=vector.tolist(), payload={})
        for item_idx, vector in zip(all_items, embeddings)
    ],
)
assert str(upsert_result.status) == "completed"
upsert_result

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)

In [17]:
# Sanity check: query the index with the first row of `embeddings`, asking for
# up to args.top_K nearest points.
hits = ann_index.search(
    collection_name=args.qdrant_collection_name,
    query_vector=embeddings[0],
    limit=args.top_K,
)

In [18]:
# Show only the first args.top_k hits rather than the full result list, to
# keep the cell output short; `hits` itself still holds every result.
hits[: args.top_k]

[ScoredPoint(id=0, version=0, score=1.0000001, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=322, version=0, score=0.43110707, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2194, version=0, score=0.41356394, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2765, version=0, score=0.40142947, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=617, version=0, score=0.38005733, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2237, version=0, score=0.37947062, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=2067, version=0, score=0.3756628, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=4088, version=0, score=0.37306264, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1531, version=0, score=0.36884758, payload={}, vector=None, shard_key=None, order_value=None),
 ScoredP