# Milvus in Python

1. Define the connection and instantiate a client.
1. Define a schema, then create a table from it (called a *collection* in Milvus).


In [1]:
from pymilvus import connections, MilvusClient, DataType, Collection

# Two pymilvus APIs are used in this notebook: the ORM-style `connections` /
# `Collection(...)` objects, and the newer `MilvusClient` (used for schema/collection
# management). Both are pointed at the same connection alias.
MILVUS_ALIAS = "default"
connections.connect(alias=MILVUS_ALIAS, host="localhost", port="19530")
# NOTE(review): MilvusClient is usually configured via `uri=`; confirm that the
# installed pymilvus version honours `alias` here.
client = MilvusClient(alias=MILVUS_ALIAS)
print(client)

<pymilvus.milvus_client.milvus_client.MilvusClient object at 0x1285ce9d0>


In [None]:
# Schema for the demo collection. With `auto_id=True`, Milvus generates the INT64
# primary key, so rows passed to `insert` must NOT contain an "id" column.
schema = client.create_schema(
    auto_id=True,
    # Lets inserted entities carry extra scalar keys not declared below (kept in a
    # hidden dynamic JSON field). It does NOT let you add schema fields later:
    # vector fields in particular must be declared up front, as "astrollama" is.
    enable_dynamic_field=True,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2048)
schema.add_field(field_name="pubdate", datatype=DataType.INT64)  # Milvus has no date type; stored as YYYYMMDD int
schema.add_field(field_name="doi", datatype=DataType.VARCHAR, max_length=64)
schema.add_field(field_name="astrollama", datatype=DataType.FLOAT_VECTOR, dim=4096)

# Guard so "Restart & Run All" works when the collection already exists.
# If you change the schema above, drop the old collection first
# (client.drop_collection), otherwise the existing one is silently reused.
if not client.has_collection(collection_name="basic_collection"):
    client.create_collection(collection_name="basic_collection", schema=schema)

In [2]:
# Confirm the collection created above is visible on the server.
existing_collections = client.list_collections()
existing_collections

['basic_collection']

In [3]:
# ORM handle on the same collection (resolved through the "default" connection alias).
collection = Collection(name="basic_collection")
# Field names in schema order, including the auto-generated "id" primary key;
# used later to select/order the DataFrame columns for insertion.
collection_fields = collection.schema.fields
fields_to_keep = [fld.name for fld in collection_fields]
print(fields_to_keep)

['id', 'text', 'pubdate', 'doi', 'astrollama']


In [None]:
import pandas as pd

sample_path = "data/dataset/split/small_train.jsonl"  # JSON Lines: one record per line
sample_data = pd.read_json(sample_path, lines=True)

In [4]:
from embedders import get_embedder

embedder = get_embedder(
    "UniverseTBD/astrollama",
    device="mps",  # Apple-silicon GPU backend; change for CUDA/CPU machines
    # presumably disables normalisation of output vectors — confirm in embedders.get_embedder;
    # this matters for the L2 distances used in the search cells
    normalize=False,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Check what container the embedder will receive (a pandas column).
type(sample_data["sent_no_cit"])

In [None]:
# Exploratory run: embed every sentence and check the resulting shape.
sentences = sample_data["sent_no_cit"]
embeddings = embedder(sentences)
print(embeddings.shape)

In [None]:
# Inspect the available columns, to choose which source columns map onto the schema fields.
sample_data.columns

In [None]:
# Build the frame to insert: rename source columns to the schema's field names,
# encode dates as INT64 and attach one embedding per row.
# NOTE(review): this re-embeds the same texts as the exploratory embedding cell above;
# drop one of the two runs if embedding is slow.
texts = sample_data["sent_no_cit"]
embeddings = embedder(texts)

df = sample_data.rename(columns={"sent_no_cit": "text", "source_doi": "doi"})
# Milvus has no date type: "YYYY-MM-DD" strings -> INT64 YYYYMMDD.
# Vectorised, literal "-" removal (regex=False) instead of a per-row lambda.
df["pubdate"] = df["pubdate"].str.replace("-", "", regex=False).astype("int64")
# Catch partial/malformed dates (e.g. "2014-08" -> 201408) that would otherwise
# silently break date-range filtering on the INT64 field.
assert df["pubdate"].between(10_000_101, 99_991_231).all(), "pubdate is not 8-digit YYYYMMDD"
# presumably a 2-D numpy array (one row per text); .tolist() gives the list-of-lists Milvus accepts
df["astrollama"] = embeddings.tolist()

# Keep only the schema's fields, in schema order; "id" is generated by Milvus (auto_id=True).
columns_in_order = [col for col in fields_to_keep if col != "id"]
df = df[columns_in_order]
df.head()

In [None]:
# Insert the prepared rows. NOTE: primary keys are auto-generated, so re-running
# this cell inserts every row again as new entities (duplicates).
result = collection.insert(data=df)
print(result)

### Searching

In [None]:
# Searching requires an index on the vector field, even a brute-force FLAT one.
# The metric here must match the metric_type used in the search cells below.
flat_index = {"index_type": "FLAT", "metric_type": "L2"}
collection.create_index(field_name="astrollama", index_params=flat_index)


In [5]:
# Put the collection into memory so that the search cells below can query it.
collection.load()

In [6]:
query_text = "I want to know more about gravitational effects"
# The embedder works on a batch of texts; keep the single row for our one query.
query_batch = embedder([query_text])
query_vector = query_batch[0]
print(query_vector.shape)
print(type(query_vector))

(4096,)
<class 'numpy.ndarray'>


In [7]:
# Top-10 nearest neighbours by L2 distance (must match the index metric),
# returning the scalar fields alongside each hit.
l2_search_params = {"metric_type": "L2"}
results = collection.search(
    data=[query_vector],
    anns_field="astrollama",
    param=l2_search_params,
    limit=10,
    output_fields=["text", "pubdate", "doi"],
)

In [8]:
top_hits = results[0]  # hits for the first (and only) query vector
for hit in top_hits:
    print(hit)

{'id': 460044358029246518, 'distance': 3008.390625, 'entity': {'text': '[Within purely stellar radiation or energetic X-ray photons, either the total number of ionizing photons produced or the total radiated energy, respectively, is what matters for reionization. This is because, in a largely neutral medium, each photoionization produces a host of secondary collisional ionizations, with approximately one hydrogen secondary ionization for every 37 eV of energy in the primary photoelectron ( [REF] ). As the medium becomes more ionized, however, an increasing fraction of this energy is deposited as heat.] Figure 16 depicts the quantity at z >6 according to our best-fit SFH for the range of stellar metallicities 0 Z ∗Z ⊙ .', 'pubdate': 20140801, 'doi': '10.1146/annurev-astro-081811-125615'}}
{'id': 460044358029246519, 'distance': 5980.67578125, 'entity': {'text': 'Using the same COS-Halos sample with new COS spectra covering the Lyman limit, and taking a nonparametric approach with a robus

In [None]:
# Same top-10 L2 search as above, but the query vector is passed as a plain
# Python list instead of a numpy array.
query_as_list = query_vector.tolist()
results = collection.search(
    data=[query_as_list],
    anns_field="astrollama",
    param=dict(metric_type="L2"),
    limit=10,
    output_fields=["text", "pubdate", "doi"],
)

for hit in results[0]:
    print(hit)