In [1]:
import pandas as pd
from pymilvus import connections, MilvusClient, DataType, Collection
from embedders import get_embedder
from database.database import Database



In [3]:
# Open a Milvus client and bind an ORM handle to the existing "contributions" collection.
# NOTE(review): the explicit `connections.connect` call below is disabled, so
# `Collection(...)` presumably reuses the connection that `MilvusClient(alias="default")`
# registers under the "default" alias — confirm against the installed pymilvus version.
# connections.connect("default", host="localhost", port="19530")
client = MilvusClient(alias="default")
print(f"Client: {client}")
collection = Collection(name="contributions")
print(collection)

Client: <pymilvus.milvus_client.milvus_client.MilvusClient object at 0x104ef54d0>
<Collection>:
-------------
<name>: contributions
<description>: 
<schema>: {'auto_id': True, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 2048}}, {'name': 'doi', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 64}}, {'name': 'pubdate', 'description': '', 'type': <DataType.INT64: 5>}, {'name': 'astrollama', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 4096}}], 'enable_dynamic_field': True}



In [None]:
# Instantiate pg database client
# The global `db` is used by the ingest cells through `db.conn.cursor()`.
# NOTE(review): `test_connection()` presumably fails loudly when the database is
# unreachable — confirm in database/database.py.
db = Database()
db.test_connection()

In [None]:
# Ensure the "contributions" Milvus collection exists, then bind a handle to it.

COLLECTION_NAME = "contributions"

collection_missing = not client.has_collection(collection_name=COLLECTION_NAME)
if collection_missing:
    print(f"Collection '{COLLECTION_NAME}' does not exist. Creating it.")
    schema = client.create_schema(auto_id=True, enable_dynamic_field=True)

    # One entry per field: (field name, datatype, extra keyword arguments).
    field_specs = [
        ("id", DataType.INT64, {"is_primary": True}),
        ("text", DataType.VARCHAR, {"max_length": 2048}),
        ("doi", DataType.VARCHAR, {"max_length": 64}),
        ("pubdate", DataType.INT64, {}),  # Milvus has no date type
        ("astrollama", DataType.FLOAT_VECTOR, {"dim": 4096}),
    ]
    for field_name, datatype, extra in field_specs:
        schema.add_field(field_name=field_name, datatype=datatype, **extra)

    client.create_collection(collection_name=COLLECTION_NAME, schema=schema)

# Runs whether or not the collection had to be created.
print(f"Collections on client: {client.list_collections()}")
collection = Collection(COLLECTION_NAME)

In [None]:
# Instantiate astrollama embedder
# The global `embedder` is what `rows_to_df` calls to fill the "astrollama" column.
# Embeddings are requested un-normalized (normalize=False).
# NOTE(review): device="mps" presumably requires an Apple-silicon host — confirm before
# running elsewhere.
embedder = get_embedder("UniverseTBD/astrollama", device="mps", normalize=False)
print(f"Embedder: {embedder}")

In [None]:
# Display the free-text description stored on the collection schema.
collection.schema.description

In [None]:
# Rebuild a standalone copy of the current collection schema, field by field.
from pymilvus import FieldSchema, CollectionSchema

old_schema = collection.schema
old_fields = old_schema.fields
new_fields = []

for field in old_fields:
    # `field.params` holds only the type parameters that apply to this field
    # (e.g. max_length for VARCHAR, dim for vector fields). Forwarding it as-is
    # avoids passing explicit `max_length=None` / `dim=None` to fields that do
    # not take them, and also covers vector types not enumerated by hand.
    type_params = dict(field.params or {})
    new_fields.append(
        FieldSchema(
            name=field.name,
            dtype=field.dtype,
            description=field.description,
            is_primary=field.is_primary,
            auto_id=field.auto_id,
            **type_params,
        )
    )

new_schema = CollectionSchema(
    fields=new_fields,
    description=old_schema.description,
    # Carry the dynamic-field setting over; without it the copy silently
    # disables dynamic fields even though the source schema enables them.
    enable_dynamic_field=old_schema.enable_dynamic_field,
)

In [None]:
def rows_to_df(rows: list[tuple], embed_fn=None) -> pd.DataFrame:
    """
    Build an insert-ready DataFrame from database rows.

    Args:
        rows: list of tuples (text, doi, pubdate). Each pubdate must provide
            `strftime` (e.g. datetime.date / datetime.datetime).
        embed_fn: optional callable that maps the "text" column to one
            embedding per row; its result must provide `.tolist()`. When
            omitted, the notebook-level `embedder` is used, which is the
            original behavior.

    Returns:
        DataFrame with columns "text", "doi", "pubdate", "astrollama", where
        "pubdate" is an integer YYYYMMDD and "astrollama" holds one list of
        floats per row.

    Raises:
        AttributeError: if a pubdate value has no `strftime` (e.g. None).
    """
    df = pd.DataFrame(rows, columns=["text", "doi", "pubdate"])

    # Nothing to embed: return the expected columns without invoking the model.
    if df.empty:
        df["astrollama"] = []
        return df

    # Resolve the global lazily so the function can be defined before the
    # embedder cell has run and can be exercised with a stand-in embedder.
    if embed_fn is None:
        embed_fn = embedder

    df["astrollama"] = embed_fn(df["text"]).tolist()  # Convert array to list of lists
    df["pubdate"] = df["pubdate"].apply(lambda x: int(x.strftime("%Y%m%d")))  # Convert date to int YYYYMMDD
    return df
    

In [None]:
# Resume ingesting the DB `contributions` table into the Milvus collection.
BATCH_SIZE = 16

# `num_entities` only reflects flushed data, so flush first; otherwise rows
# inserted by a previous (unflushed) run are not counted and get re-inserted.
collection.flush()
OFFSET = collection.num_entities

with db.conn.cursor() as cursor:
    # OFFSET is only a stable resume position when the row order is fixed, so
    # order by the selected columns (rows that tie on all three are identical).
    # NOTE(review): if the table has a unique key, order by that instead. A
    # collection partially filled by an earlier unordered run may not line up
    # with this ordering — verify before resuming such a run.
    cursor.execute(
        "SELECT text, doi, pubdate FROM contributions ORDER BY doi, pubdate, text OFFSET %s",
        (OFFSET,),
    )
    while True:
        rows = cursor.fetchmany(BATCH_SIZE)
        if not rows:
            break
        df = rows_to_df(rows)

        # Insert into collection
        result = collection.insert(df)
        print(result)

# Persist the inserted rows so a later resume sees an accurate entity count.
collection.flush()



In [None]:
# Inspect the first field schema of the collection and the Python type that represents it.
prop = collection.schema.fields[0]
print(prop)
print(type(prop))

In [7]:
# Show the raw field schemas, then derive the field names without the primary key.
schema_fields = collection.schema.fields
print(schema_fields)
print("Get just the names:")
names = list(map(lambda schema_field: schema_field.name, schema_fields))
names.remove("id")  # raises ValueError if the schema has no "id" field
print(names)

[{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 2048}}, {'name': 'doi', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 64}}, {'name': 'pubdate', 'description': '', 'type': <DataType.INT64: 5>}, {'name': 'astrollama', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 4096}}]
Get just the names:
['text', 'doi', 'pubdate', 'astrollama']


In [8]:
# Display the names of all collections visible to the client.
client.list_collections()

['contributions', 'chunks']

In [None]:
from tqdm import tqdm

# Migrate "contributions" into a new collection that stores both the existing
# astrollama vectors and freshly computed Qwen embeddings, then swap it in.
NEW_COLLECTION_NAME = "contributions_with_qwen"

if not client.has_collection(collection_name=NEW_COLLECTION_NAME):
    print(f"Creating new collection '{NEW_COLLECTION_NAME}' with both vector fields.")
    schema = client.create_schema(auto_id=True, enable_dynamic_field=True)
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2048)
    schema.add_field(field_name="doi", datatype=DataType.VARCHAR, max_length=64)
    schema.add_field(field_name="pubdate", datatype=DataType.INT64)
    schema.add_field(field_name="astrollama", datatype=DataType.FLOAT_VECTOR, dim=4096)
    schema.add_field(field_name="qwen8b", datatype=DataType.FLOAT_VECTOR, dim=4096)  # New field

    client.create_collection(collection_name=NEW_COLLECTION_NAME, schema=schema)

new_collection = Collection(NEW_COLLECTION_NAME)
old_collection = Collection("contributions")

# Load embedder
qwen_embedder = get_embedder("Qwen/Qwen3-Embedding-8B", device="mps", normalize=True)
print(f"Qwen Embedder: {qwen_embedder}")

# Migrate data with new embeddings
BATCH_SIZE = 8
total_entities = old_collection.num_entities
processed = 0

print(f"Migrating {total_entities} entities from 'contributions' to '{NEW_COLLECTION_NAME}'")

# Walk the old collection with a query iterator rather than limit/offset
# paging: Milvus caps offset + limit per query, so offset paging fails partway
# through a large collection.
source_iterator = old_collection.query_iterator(
    batch_size=BATCH_SIZE,
    output_fields=["text", "doi", "pubdate", "astrollama"],
)
try:
    with tqdm(total=total_entities, desc="Migrating with Qwen embeddings") as pbar:
        while True:
            entities = source_iterator.next()
            if not entities:
                break

            # Extract data for new collection
            texts = [entity["text"] for entity in entities]
            dois = [entity["doi"] for entity in entities]
            pubdates = [entity["pubdate"] for entity in entities]
            astrollama_embeddings = [entity["astrollama"] for entity in entities]

            # Generate new Qwen embeddings
            qwen_embeddings = qwen_embedder(texts)

            # Prepare data for insertion
            insert_data = {
                "text": texts,
                "doi": dois,
                "pubdate": pubdates,
                "astrollama": astrollama_embeddings,
                "qwen8b": qwen_embeddings.tolist(),
            }

            # Insert into new collection
            result = new_collection.insert(insert_data)
            processed += len(entities)
            pbar.update(len(entities))
finally:
    # Release the iterator even if embedding or insertion fails mid-way.
    source_iterator.close()

# Flush so the inserted rows are persisted and counted by num_entities.
new_collection.flush()
migrated_entities = new_collection.num_entities

print(f"Migration complete! {processed} entities migrated.")

# Create indexes on the new collection
# NOTE(review): astrollama is indexed with metric "IP" here, while the separate
# index cell for the original collection uses "L2" and the astrollama embedder
# is created with normalize=False — confirm which metric is intended.
new_collection.create_index(
    field_name="astrollama",
    index_params={
        "index_type": "FLAT",
        "metric_type": "IP",
    },
)

new_collection.create_index(
    field_name="qwen8b",
    index_params={
        "index_type": "FLAT",
        "metric_type": "COSINE",
    },
)

# Dropping the old collection is irreversible, so only swap once every entity
# is accounted for. A mismatch also catches a re-run on top of a partially
# filled new collection, which would otherwise promote duplicates.
if processed != total_entities or migrated_entities != total_entities:
    raise RuntimeError(
        f"Migration verification failed: expected {total_entities} entities, "
        f"processed {processed}, '{NEW_COLLECTION_NAME}' holds {migrated_entities}. "
        "The 'contributions' collection was NOT dropped."
    )

# Drop old collection and rename new one
client.drop_collection("contributions")
client.rename_collection(old_name=NEW_COLLECTION_NAME, new_name="contributions")

In [None]:
# Index the collection currently bound to `collection`: a FLAT vector index on
# the embedding field plus a sorted scalar index on the publication date.
# NOTE(review): the metric here is "L2", whereas the migration cell indexes the
# same field with "IP" — confirm which one is intended.
astrollama_index = {"index_type": "FLAT", "metric_type": "L2"}
collection.create_index(field_name="astrollama", index_params=astrollama_index)

# Scalar index, built through the client-side index-params helper.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="pubdate",
    index_type="STL_SORT",
    index_name="contributions_pubdate_index",
)
client.create_index(collection_name=collection.name, index_params=index_params)

In [None]:
# Ensure the "chunks" Milvus collection exists, rebind `collection` to it and
# report how many entities it currently holds.

COLLECTION_NAME = "chunks"

collection_missing = not client.has_collection(collection_name=COLLECTION_NAME)
if collection_missing:
    print(f"Collection '{COLLECTION_NAME}' does not exist. Creating it.")
    schema = client.create_schema(auto_id=True, enable_dynamic_field=True)

    # One entry per field: (field name, datatype, extra keyword arguments).
    field_specs = [
        ("id", DataType.INT64, {"is_primary": True}),
        ("text", DataType.VARCHAR, {"max_length": 2048}),
        ("doi", DataType.VARCHAR, {"max_length": 64}),
        ("pubdate", DataType.INT64, {}),  # Milvus has no date type
        ("astrollama", DataType.FLOAT_VECTOR, {"dim": 4096}),
    ]
    for field_name, datatype, extra in field_specs:
        schema.add_field(field_name=field_name, datatype=datatype, **extra)

    client.create_collection(collection_name=COLLECTION_NAME, schema=schema)

print(f"Collections on client: {client.list_collections()}")
collection = Collection(COLLECTION_NAME)
print(f"{collection.num_entities} entities in collection {COLLECTION_NAME}")

In [None]:
# Count the rows of the DB `chunks` table so the ingest loop can report progress.
with db.conn.cursor() as cursor:
    cursor.execute("SELECT COUNT(*) FROM chunks")
    (total_chunks,) = cursor.fetchone()  # single-column result row

processed = 0
print(f"Total chunks to process: {total_chunks}")


In [None]:
# db.conn.rollback()

In [None]:
# Embed every row of the DB `chunks` table and insert it into the Milvus
# collection currently bound to `collection`.
# NOTE(review): the query has no resume offset, so re-running this cell inserts
# every chunk again — confirm the target collection is empty before a re-run.
BATCH_SIZE = 8

# Start the counter at zero on every run of this cell; relying on the value
# left behind by another cell makes the progress line wrong after a re-run.
processed = 0

with db.conn.cursor() as cursor:
    cursor.execute("SELECT text, doi, pubdate FROM chunks")
    while True:
        rows = cursor.fetchmany(BATCH_SIZE)
        if not rows:
            break
        df = rows_to_df(rows)

        # Insert into collection
        result = collection.insert(df)
        processed += len(rows)
        print(f"\rProcessed {processed}/{total_chunks} total chunks.", end="", flush=True)

# Terminate the carriage-return progress line so later output starts cleanly.
print()

# Persist the inserted rows so they are counted by `num_entities`.
collection.flush()