In [10]:
from pymilvus import (
    connections,
    FieldSchema,
    DataType,
    CollectionSchema,
    Collection,
    utility,
)
import random

connections.connect("default", host="localhost", port="19530")

VEC_DIM = 768
MAX_TITLE_LEN = 200
TEST_N = 3000
COLLECTION_NAME = "playlists"

fields = [
    FieldSchema(name="pid", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=MAX_TITLE_LEN),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=VEC_DIM)
]


def build_collection():
    schema = CollectionSchema(fields, "Playlist embeddings")
    collection = Collection(COLLECTION_NAME, schema)

    entities = [
        [i for i in range(TEST_N)],
        [f"playlist {i}" for i in range(TEST_N)],
        [[random.random() for _ in range(VEC_DIM)] for _ in range(TEST_N)],
    ]

    collection.insert(entities)
    collection.flush()

    index = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }

    collection.create_index("embedding", index)


if not utility.has_collection(COLLECTION_NAME):
    print("Building new collection...")
    build_collection()
else:
    print(f"Collection '{COLLECTION_NAME}' already exists")

Building new collection...


In [8]:
collection = Collection(COLLECTION_NAME)
collection.load()
vectors_to_search = entities[-1][-2:]
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10}
}
result = collection.search(vectors_to_search, "embedding", search_params, limit=3, output_fields=["pid", "title"])
result.

<pymilvus.orm.search.SearchResult at 0x7f3b54254040>