In [1]:
import os
import uuid
from typing import List

import h5py
import torch
from open_clip import create_model_from_pretrained, get_tokenizer
from pydantic import PrivateAttr
from qdrant_client import QdrantClient

from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.qdrant import QdrantVectorStore

In [2]:
class CLIPEmbedding(BaseEmbedding):
    """LlamaIndex embedding adapter backed by an open_clip CLIP text encoder.

    Only the text tower (``encode_text``) is used. In this notebook it embeds
    search queries that are compared against pre-computed frame features stored
    in Qdrant (presumably produced by the same CLIP model -- confirm, otherwise
    the two embedding spaces will not line up).
    """

    _model = PrivateAttr()
    _preprocess = PrivateAttr()  # image transform returned by open_clip; not used by this class
    _tokenizer = PrivateAttr()
    _device = PrivateAttr()

    def __init__(
        self,
        model_name: str = "hf-hub:apple/DFN2B-CLIP-ViT-B-16",
        device: str = "cpu",
        tokenizer_name: str = "ViT-B-16",
    ):
        """Load the CLIP model and its matching tokenizer.

        Args:
            model_name: identifier passed to ``open_clip.create_model_from_pretrained``
                (e.g. an ``hf-hub:`` reference).
            device: torch device string (e.g. ``"cpu"`` or ``"cuda"``) that the model
                and the token tensors are moved to.
            tokenizer_name: identifier passed to ``open_clip.get_tokenizer``. It must
                match the architecture of ``model_name``; the default matches the
                default ViT-B/16 model.
        """
        # Forward model_name so the public BaseEmbedding.model_name field reports the
        # actual model rather than the base class's default.
        super().__init__(model_name=model_name)
        self._device = device
        model, preprocess = create_model_from_pretrained(model_name)
        # Inference-only wrapper: move to the target device and disable train-mode layers.
        self._model = model.to(device).eval()
        self._preprocess = preprocess
        self._tokenizer = get_tokenizer(tokenizer_name)

    def _encode_text(self, text: str) -> List[float]:
        """Tokenize ``text`` and return its CLIP text embedding as a list of floats."""
        tokens = self._tokenizer([text]).to(self._device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            emb = self._model.encode_text(tokens)
        # The tokenizer was given a batch of one, so take row 0 and bring it to host memory.
        return emb[0].cpu().numpy().tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a search query (queries and documents share the same text encoder)."""
        return self._encode_text(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text with the CLIP text encoder."""
        return self._encode_text(text)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant; delegates to the synchronous implementation."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant; delegates to the synchronous implementation."""
        return self._get_text_embedding(text)

# Run CLIP on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

embed_model = CLIPEmbedding(device=device)

In [3]:
# Connection settings come from the environment so the API key is never stored in
# the notebook. NOTE(review): the key that used to be hard-coded here was kept in
# plain text and should be treated as leaked and rotated.
QDRANT_URL = os.environ.get(
    "QDRANT_URL",
    "https://09a6d049-00c4-4b77-8e95-1dcc9ea5df34.eu-west-1-0.aws.cloud.qdrant.io:6333",
)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
if not QDRANT_API_KEY:
    raise RuntimeError("Set the QDRANT_API_KEY environment variable before running this cell.")

qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

In [None]:
# Build one pre-embedded TextNode per HDF5 dataset, using the middle frame's
# feature vector as that key's representative embedding.
# NOTE(review): presumably each dataset is (num_frames, dim) -- confirm.
nodes = []
skipped_keys = []
with h5py.File("features/frame_features.hdf5", "r") as fs:
    for key in fs.keys():
        dataset = fs[key]
        num_frames = len(dataset)
        if num_frames == 0:
            # No frame to pick; skip rather than fail with an IndexError.
            skipped_keys.append(key)
            continue
        # Read only the middle row instead of loading every frame of the dataset.
        middle_vector = dataset[num_frames // 2]

        node = TextNode(
            # Deterministic id derived from the key, so re-running ingestion targets
            # the same Qdrant point ids instead of creating duplicates under new
            # random ids (Qdrant upserts by point id).
            id_=str(uuid.uuid5(uuid.NAMESPACE_URL, key)),
            text=f"Frame {key}",
            metadata={"id": key},
            embedding=middle_vector.tolist(),
        )
        nodes.append(node)

if skipped_keys:
    print(f"Skipped {len(skipped_keys)} empty feature dataset(s): {skipped_keys[:5]}")

In [5]:
# Ingest the pre-embedded frame nodes into the "image" Qdrant collection.
# The progress bar reports 0 iterations, presumably because every node already
# carries its embedding and so nothing needs to be embedded here.
collection_name = "image"
vector_store = QdrantVectorStore(
    collection_name=collection_name,
    client=qdrant_client,
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

index = VectorStoreIndex(
    nodes,
    embed_model=embed_model,
    storage_context=storage_context,
    show_progress=True,
)

Generating embeddings: 0it [00:00, ?it/s]

In [6]:
# Reconnect to the existing collection (no re-ingestion) and run a sample text query.
vector_store = QdrantVectorStore(client=qdrant_client, collection_name=collection_name)
index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
retriever = index.as_retriever(similarity_top_k=10)
query = "a plane flying in the sky"
# Keep retrieval hits under their own name: reusing `nodes` would overwrite the
# TextNode list that the ingestion cell reads if it is re-run afterwards.
results = retriever.retrieve(query)

# Labels are Vietnamese: "Kết quả" = "Result", "Nội dung" = "Content".
for rank, result in enumerate(results, start=1):
    print(f"🔎 Kết quả {rank}:")
    print(f"Score: {result.score:.4f}")
    print("Metadata:", result.metadata)
    print(f"Nội dung: {result.get_content()}\n")

🔎 Kết quả 1:
Score: 0.2782
Metadata: {'id': '0bSz70pYAP0_5_15'}
Nội dung: Frame 0bSz70pYAP0_5_15

🔎 Kết quả 2:
Score: 0.2406
Metadata: {'id': 'VxM96IYzw0Q_2_15'}
Nội dung: Frame VxM96IYzw0Q_2_15

🔎 Kết quả 3:
Score: 0.2397
Metadata: {'id': 'ZbzDGXEwtGc_6_15'}
Nội dung: Frame ZbzDGXEwtGc_6_15

🔎 Kết quả 4:
Score: 0.2222
Metadata: {'id': 'Eamd2wMKixs_48_72'}
Nội dung: Frame Eamd2wMKixs_48_72

🔎 Kết quả 5:
Score: 0.2221
Metadata: {'id': 'DN7jwyL1Xgg_1_19'}
Nội dung: Frame DN7jwyL1Xgg_1_19

🔎 Kết quả 6:
Score: 0.2179
Metadata: {'id': '4MjTb5A68VA_111_118'}
Nội dung: Frame 4MjTb5A68VA_111_118

🔎 Kết quả 7:
Score: 0.2141
Metadata: {'id': 'vz71JKcpeUU_0_10'}
Nội dung: Frame vz71JKcpeUU_0_10

🔎 Kết quả 8:
Score: 0.2116
Metadata: {'id': '3chNlP5TeO8_0_10'}
Nội dung: Frame 3chNlP5TeO8_0_10

🔎 Kết quả 9:
Score: 0.2099
Metadata: {'id': 'Gn4Iv5ARIXc_37_40'}
Nội dung: Frame Gn4Iv5ARIXc_37_40

🔎 Kết quả 10:
Score: