In [None]:
import csv
import os
import re

import numpy as np
from pymilvus import connections, db, Collection, utility, FieldSchema, DataType, CollectionSchema

In [None]:
# Raw caption export and the destination for the cleaned copy.
PATH = "/home/marco/image-viz/backend/src/caption_model/captions/captions_wikiart.csv"
PATH_CLEANED = "/home/marco/image-viz/backend/src/caption_model/captions/captions_wikiart_cleaned.csv"

# Marker that precedes the part of each raw caption that is kept.
ASSISTANT_MARKER = "ASSISTANT:"

# Rows that could not be cleaned and were therefore left out of the output.
skipped_rows = 0

# newline="" is what the csv module asks for: it lets the reader handle line
# breaks embedded in quoted caption fields itself and keeps the writer from
# producing extra "\r" characters on platforms that translate line endings.
with open(PATH, "r", newline="") as file, open(PATH_CLEANED, "w", newline="") as file_cleaned:
    reader = csv.reader(file)
    writer = csv.writer(file_cleaned)
    # Copy the header row through unchanged.
    writer.writerow(next(reader))
    for row in reader:
        # Blank or truncated lines have no caption column to clean.
        if len(row) < 2:
            skipped_rows += 1
            continue
        # Split once, at the first "ASSISTANT:" marker.
        split_caption = row[1].split(ASSISTANT_MARKER, 1)
        if len(split_caption) < 2:
            # No marker in this caption: the row is not written to the cleaned file.
            skipped_rows += 1
            continue
        # Collapse every run of line breaks into one space, then trim the ends.
        # "\r" is included because newline="" no longer translates "\r\n" to "\n".
        cleaned_caption = re.sub(r"[\r\n]+", " ", split_caption[1]).strip()
        writer.writerow([row[0], cleaned_caption])

# Make dropped rows visible instead of losing them silently.
if skipped_rows:
    print(f"Skipped {skipped_rows} row(s) without a caption column or without {ASSISTANT_MARKER!r}")

In [None]:
# Dataset name; it is also used as the Milvus collection name further down.
DATASET = "wikiart"
# NOTE(review): this path is the raw export, not the "captions_wikiart_cleaned.csv"
# file written by the cleaning cell above — confirm which file is meant to be loaded.
PATH = f"/home/marco/image-viz/backend/src/caption_model/captions/captions_{DATASET}.csv"

# Maps the integer in the first CSV column to the caption text in the second column.
captions = {}
# newline="" lets the csv module handle line breaks inside quoted fields itself.
with open(PATH, "r", newline="") as file:
    reader = csv.reader(file)
    # Consume the header row so it is not parsed as data.
    header = next(reader)
    for row in reader:
        # csv.reader yields an empty list for a blank line; there is nothing to store.
        if not row:
            continue
        captions[int(row[0])] = row[1]

# Preview one caption as the cell output.
captions[123]

In [None]:
# Create milvus connection.
# Each setting can be overridden through an environment variable so that
# credentials do not have to be edited into the notebook; the fallbacks are the
# values this notebook previously hard-coded.
connections.connect(
    host=os.environ.get("MILVUS_HOST", "0.0.0.0"),
    port=int(os.environ.get("MILVUS_PORT", "19530")),
    user=os.environ.get("MILVUS_USER", "root"),
    password=os.environ.get("MILVUS_PASSWORD", "Milvus"),
)
# Use database
db.using_database("aiplusart")
# Use collection (named after the dataset)
collection = Collection(DATASET)

In [None]:
 # Fetch vectors
# Width of the "index" range requested per query.
FETCH_BATCH_SIZE = 16384

entities = []
try:
    # Read the entity count once; it only serves as the loop bound.
    total_entities = collection.num_entities
    for start in range(0, total_entities, FETCH_BATCH_SIZE):
        # Half-open range [start, start + FETCH_BATCH_SIZE): selects the same
        # integer keys as listing each one in an `in [...]` filter, without
        # building a 16384-element expression string for every batch.
        entities += collection.query(
            expr=f"index >= {start} and index < {start + FETCH_BATCH_SIZE}",
            output_fields=["*"],
        )
except Exception as e:
    print(e.__str__())
    print("Error in update_metadata. Update failed!")
    # Stop here: the later cells rebuild the collection from `entities`, so
    # continuing with a partial list would replace it with incomplete data.
    raise

# Preview one fetched entity as the cell output.
entities[123]

In [None]:
# Display how many entities were collected in `entities`
len(entities)

In [None]:
# Put the fetched entities in ascending order of their "index" value
entities = sorted(entities, key=lambda entity: entity["index"])
# Exactly one caption is expected per entity; this compares the two counts only
assert len(captions) == len(entities)

In [None]:
# Attach the matching caption to every entity
for position, entity in enumerate(entities):
    # After sorting, each entity's "index" must equal its position in the list
    assert (entity["index"] == position)
    entity["caption"] = captions[entity["index"]]

# Preview one entity, now carrying its caption
entities[123]

In [None]:
# Name given to the vector field in the collection schema.
EMBEDDING_VECTOR_FIELD_NAME = "embedding"
# Value passed as "metric_type" in the index parameters.
COSINE_METRIC = "COSINE"
# Value passed as "index_type" in the index parameters.
INDEX_TYPE = "FLAT"

def embeddings_collection(collection_name: str, dim: int = 512) -> Collection:
    """Build the embeddings schema, instantiate the collection and index its vector field.

    The schema has an INT64 primary key ``index``, two FLOAT fields ``x`` and
    ``y`` that default to NaN, and a FLOAT_VECTOR field named by
    ``EMBEDDING_VECTOR_FIELD_NAME``. Dynamic fields are enabled on the schema.

    Args:
        collection_name: Name passed to ``Collection``.
        dim: Dimension of the embedding vectors. Defaults to 512, the value
            that was previously fixed inside this function.

    Returns:
        The ``Collection`` object, after ``create_index`` has been called on
        the embedding field.
    """
    # Create fields for collection
    index_field = FieldSchema(
        name="index",
        dtype=DataType.INT64,
        is_primary=True
    )
    x_field = FieldSchema(
        name="x",
        dtype=DataType.FLOAT,
        default_value=np.nan
    )
    y_field = FieldSchema(
        name="y",
        dtype=DataType.FLOAT,
        default_value=np.nan
    )
    embedding_field = FieldSchema(
        name=EMBEDDING_VECTOR_FIELD_NAME,
        dtype=DataType.FLOAT_VECTOR,
        dim=dim
    )

    # Create collection schema (field order kept as before)
    schema = CollectionSchema(
        fields=[embedding_field, x_field, y_field, index_field],
        description="embeddings",
        enable_dynamic_field=True
    )

    # Create collection
    collection = Collection(
        name=collection_name,
        schema=schema,
        shards_num=1  # type: ignore
    )

    # Create index for embedding field to make similarity search faster
    index_params = {
        "metric_type": COSINE_METRIC,
        "index_type": INDEX_TYPE,
        "params": {}
    }

    # Use the same constant as the schema so the indexed field name cannot
    # drift away from the field that was actually declared.
    collection.create_index(
        field_name=EMBEDDING_VECTOR_FIELD_NAME,
        index_params=index_params
    )

    return collection

In [None]:
# Number of entities sent per insert call.
INSERT_BATCH_SIZE = 16384
# Staging collection that receives the captioned entities.
temp_collection_name = "temp_" + DATASET

try:
    # Create cluster collection
    new_collection = embeddings_collection(temp_collection_name)
    # Insert in batches to avoid resource exhaustion
    for i in range(0, len(entities), INSERT_BATCH_SIZE):
        # Slicing stops at the end of the list, so the last batch may be shorter.
        new_collection.insert(data=entities[i:i + INSERT_BATCH_SIZE])
        new_collection.flush()
except Exception as e:
    print(e.__str__())
    print("Error in update_metadata. Update failed!")
    # Remove the partially filled staging collection, if it got created at all.
    if utility.has_collection(temp_collection_name):
        utility.drop_collection(temp_collection_name)
    # Re-raise so a "Run All" stops here; the next cell drops the original
    # collection and must not run after a failed insert.
    raise

In [None]:
# Drop old collection and rename new collection
temp_collection_name = "temp_" + DATASET
try:
    # Refuse to drop the original unless the staging collection exists ...
    if not utility.has_collection(temp_collection_name):
        raise RuntimeError(
            f"Collection {temp_collection_name!r} does not exist; keeping {DATASET!r} untouched"
        )
    # ... and holds as many entities as were prepared for insertion.
    # Assumes the staged inserts were flushed so num_entities is current — TODO confirm.
    staged_count = Collection(temp_collection_name).num_entities
    if staged_count != len(entities):
        raise RuntimeError(
            f"Collection {temp_collection_name!r} holds {staged_count} entities, "
            f"expected {len(entities)}; keeping {DATASET!r} untouched"
        )
    # Drop old collection
    utility.drop_collection(DATASET)
    # Rename new collection
    utility.rename_collection(temp_collection_name, DATASET, "aiplusart")
except Exception as e:
    print(e.__str__())
    print("Error in update_metadata. Update failed!")
    # Surface the failure instead of letting the notebook appear to succeed.
    raise