In [None]:
!pip install qdrant-client

In [3]:
import pandas as pd
from pathlib import Path
from qdrant_client import models, QdrantClient
from nlp_models.multi_task_model.mtl import MTLInference

# Folder holding the Google QUEST challenge files; train.csv is read from here.
# Both paths are relative, so they resolve against the notebook's working directory.
DATA_FOLDER = Path('../data/0_external/google-quest-challenge')
# Parent folder of the saved multi-task model runs (one sub-folder per run).
MODEL_DIR = Path('../models/multi-task-model/')

In [4]:
# Locate the artifacts of the fine-tuned run: the tokenizer folder and the
# weights file live side by side inside the same run directory.
output_dir = MODEL_DIR / 'multi-task-model-finetuned-classification-layer-20230609'
tokenizer_dir, model_file = output_dir / 'tokenizer', output_dir / 'mtl.bin'

# Inference wrapper used by every later cell to embed and classify questions.
# NOTE(review): pretrained_model=False presumably makes the wrapper load the
# fine-tuned weights from `model_file` instead of a stock checkpoint - confirm
# against MTLInference.
mtl = MTLInference(
    tokenizer_dir,
    model_file,
    pretrained_model=False,
)

In [5]:
# Training split of the Google QUEST data.
df_train = pd.read_csv(DATA_FOLDER / 'train.csv')

# Class index -> category name, numbered in order of first appearance in train.csv.
# dict(enumerate(...)) builds the mapping directly, without an intermediate list of tuples.
# NOTE(review): this mapping is only correct if the model's classification head was
# trained with the same index order - verify against the training code.
label_dict = dict(enumerate(df_train['category'].unique()))

In [6]:
# Local in-memory Qdrant instance (':memory:'), so the index has to be rebuilt
# by re-running the cells below after every kernel restart.
qdrant = QdrantClient(':memory:')

# Vector size is read from the encoder's config so the collection always
# matches the embedding width produced by `mtl`.
# NOTE(review): similarity is the raw dot product; if the embeddings are not
# normalised, vector magnitude influences the ranking - confirm this is intended.
vector_params = models.VectorParams(
    size=mtl.mtl_model.base_model.config.hidden_size,
    distance=models.Distance.DOT,
)

# (Re)create the 'qb' collection; kept as the last expression so its result is displayed.
qdrant.recreate_collection(collection_name='qb', vectors_config=vector_params)

True

In [7]:
# Index every training row: the vector is the model's embedding of the question
# title (second element returned by `mtl.predict`), and the payload is the whole
# row so that searches can filter on and return any column.
# NOTE(review): this calls `mtl.predict` once per row, which may be slow on the
# full training set; batch it if MTLInference supports batched input.
qdrant.upload_records(
    collection_name='qb',
    records=[
        models.Record(
            id=idx,
            vector=mtl.predict(rec['question_title'])[1].squeeze().tolist(),
            # iterrows() yields a pandas Series; hand Qdrant a plain dict instead
            # of relying on the Series being coerced into a payload mapping.
            payload=rec.to_dict()
        ) for idx, rec in df_train.iterrows()
    ]
)

In [None]:
# Plain semantic search: embed the question with the same model that produced
# the indexed vectors, then fetch its three closest neighbours from 'qb'.
query = 'What helmet to use for biking in winter?'
query_embedding = mtl.predict(query)[1].squeeze().tolist()

matches = qdrant.search(
    collection_name='qb',
    query_vector=query_embedding,
    limit=3,
)
matches

In [None]:
# Category-aware search: classify the question first, then restrict the
# nearest-neighbour search to indexed rows whose 'category' matches the prediction.
query = 'what helmet to use for biking in winter?'
outputs = mtl.predict(query)

# outputs[0] carries the per-category scores, outputs[1] the embedding.
# Take the best-scoring class with argmax: thresholding at 0.5 followed by
# `.nonzero().item()` raises whenever no score or more than one score exceeds
# the threshold, while argmax yields the same index in the single-hit case and
# still returns a category otherwise.
# NOTE(review): assumes outputs[0] is a tensor of scores for one query - confirm.
predicted_category = label_dict[outputs[0].squeeze().argmax().item()]

matches = qdrant.search(
    collection_name='qb',
    query_vector=outputs[1].squeeze().tolist(),
    query_filter=models.Filter(
        must=[
            models.FieldCondition(
                key='category',
                match=models.MatchValue(value=predicted_category)
            )
        ]
    )
)