In [None]:
import pandas as pd
from pymilvus import connections, MilvusClient, DataType, Collection
from embedders import get_embedder
from database.database import Database

# connections.connect("default", host="localhost", port="19530")
client = MilvusClient("default")
print(f"Client: {client}")

db = Database()
db.test_connection()

embedder = get_embedder("UniverseTBD/astrollama", device="mps", normalize=False)
print(f"Embedder: {embedder}")


Client: <pymilvus.milvus_client.milvus_client.MilvusClient object at 0x34368f990>
Database         User             Host                             Port            
citelinedb       bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.5 (Homebrew) on aarch64-apple-darwin24.4.0, compiled by Apple clang version 17.0.0 (clang-1700.0.13.3), 64-bit',)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Embedder: UniverseTBD/astrollama, device=mps, normalize=False


In [28]:
# Create the collection if it doesn't exist

COLLECTION_NAME = "contributions"
if not client.has_collection(collection_name=COLLECTION_NAME):
    print(f"Collection '{COLLECTION_NAME}' does not exist. Creating it.")
    schema = client.create_schema(
        auto_id=True,
        enable_dynamic_field=True
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2048)
    schema.add_field(field_name="doi", datatype=DataType.VARCHAR, max_length=64)
    schema.add_field(field_name="pubdate", datatype=DataType.INT64)  # Milvus has no date type
    schema.add_field(field_name="astrollama", datatype=DataType.FLOAT_VECTOR, dim=4096)

    client.create_collection(collection_name=COLLECTION_NAME, schema=schema)
print(f"Collections on client: {client.list_collections()}")
collection = Collection(COLLECTION_NAME)

Collections on client: ['contributions', 'chunks', 'basic_collection']


In [4]:
def rows_to_df(rows: list[tuple]) -> pd.DataFrame:
    """
    Expects a list of tuples (text, doi, pubdate)
    Embeds the text using the embedder to create a new column 'astrollama'
    Converts the pubdate in datetime.date format to an integer YYYYMMDD
    """
    df = pd.DataFrame(rows, columns=["text", "doi", "pubdate"])
    df["astrollama"] = embedder(df["text"]).tolist()  # Convert numpy array to list of lists
    df["pubdate"] = df["pubdate"].apply(lambda x: int(x.strftime("%Y%m%d")))  # Convert date to int YYYYMMDD
    return df
    

In [5]:
# Iterate over DB contributions table and fetch data
BATCH_SIZE = 16
OFFSET = collection.num_entities
with db.conn.cursor() as cursor:
    cursor.execute("SELECT text, doi, pubdate FROM contributions OFFSET %s", (OFFSET,))
    while True:
        rows = cursor.fetchmany(BATCH_SIZE)
        if not rows:
            break
        df = rows_to_df(rows)

        # Insert into collection
        result = collection.insert(df)
        print(result)



(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053158988087298, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053159354564609, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053159656030209, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053159931281409, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053160258961409, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053160639070209, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053161032286210, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 460053161412395011, success count: 16, err count: 0
(insert count: 16, delete count: 0, upsert count: 0, timestamp: 46005316

In [31]:
collection.indexes

[]

In [33]:
collection.create_index(
    field_name="astrollama",
    index_params={
        "index_type": "FLAT",
        "metric_type": "L2",
    },
)

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="pubdate",
    index_type="STL_SORT",
    index_name="contributions_pubdate_index",
)

client.create_index(
    collection_name=collection.name,
    index_params=index_params
)

In [25]:
# Create the collection if it doesn't exist

COLLECTION_NAME = "chunks"
if not client.has_collection(collection_name=COLLECTION_NAME):
    print(f"Collection '{COLLECTION_NAME}' does not exist. Creating it.")
    schema = client.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2048)
    schema.add_field(field_name="doi", datatype=DataType.VARCHAR, max_length=64)
    schema.add_field(field_name="pubdate", datatype=DataType.INT64)  # Milvus has no date type
    schema.add_field(field_name="astrollama", datatype=DataType.FLOAT_VECTOR, dim=4096)

    client.create_collection(collection_name=COLLECTION_NAME, schema=schema)
print(f"Collections on client: {client.list_collections()}")
collection = Collection(COLLECTION_NAME)
print(f"{collection.num_entities} entities in collection {COLLECTION_NAME}")

Collection 'chunks' does not exist. Creating it.
Collections on client: ['chunks', 'basic_collection', 'contributions']
0 entities in collection chunks


In [26]:
# Iterate over DB contributions table and fetch data
with db.conn.cursor() as cursor:
    cursor.execute("SELECT COUNT(*) FROM chunks")
    total_chunks = cursor.fetchone()[0]

processed = 0
print(f"Total chunks to process: {total_chunks}")


Total chunks to process: 463202


In [None]:
# db.conn.rollback()

In [27]:
BATCH_SIZE = 8
with db.conn.cursor() as cursor:
    cursor.execute("SELECT text, doi, pubdate FROM chunks")
    while True:
        rows = cursor.fetchmany(BATCH_SIZE)
        # print(f"Fetched {len(rows)} rows.")
        if not rows:
            break
        df = rows_to_df(rows)
        # print(f"Converted rows to DataFrame with {len(df)} entries.")

        # Insert into collection
        result = collection.insert(df)
        processed += len(rows)
        print(f"\rProcessed {processed}/{total_chunks} total chunks.", end="", flush=True)

Processed 920/463202 total chunks.

KeyboardInterrupt: 