# Setup Milvus collection and ingest embeddings + metadata
This notebook walks through the steps necessary to create a collection in Milvus for the one percent data sample, and to ingest the metadata and reduced embeddings.

### Pre-requisites: 
You will need the reduced embeddings generated in the previous notebook: `2_train_apply_pca.ipynb`

**Author:** Leo Thomas - leo@developmentseed.org\
**Last updated:** 2023/06/15

In [8]:
import json
import numpy as np
import os
import concurrent.futures
from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection,
    utility
)
from tqdm.notebook import tqdm

### 1.0. Instantiate a connection through kubectl to the Milvus instance

In [2]:
# Connect to the remote Milvus instance by using kubectl port forwarding:
# gcloud auth login
# gcloud container clusters get-credentials bioacoustics-devseed-staging-cluster --region=us-central1-f  --project dulcet-clock-385511
# kubectl port-forward service/milvus 9091:9091 & kubectl port-forward service/milvus 19530:19530 &
#
# HOST/PORT default to the local end of the port-forward above; set the
# MILVUS_HOST / MILVUS_PORT environment variables to target another instance.
HOST = os.environ.get("MILVUS_HOST", "127.0.0.1")
PORT = int(os.environ.get("MILVUS_PORT", "19530"))
connections.connect(host=HOST, port=PORT)
print("Connections: ", connections.list_connections())

Connections:  [('default', <pymilvus.client.grpc_handler.GrpcHandler object at 0x10cbf1c40>)]


In [3]:
# Name of the Milvus collection this notebook (re)creates and fills.
COLLECTION_NAME: str = "a2o_bioacoustics"

### 2.0. Drop collection if exists. 

# WARNING: this will delete the deployed collection (and all of its data) — proceed with CAUTION
The rest of the notebook walks through the steps to load data into the Milvus instance and create the necessary indexes. The process may take quite a while, and requires uploading ~20 GB of reduced embeddings and metadata.

In [None]:
# When a collection with this name is already deployed, drop it (together with
# all of its data) so it can be recreated from scratch in the next section.
already_deployed = utility.has_collection(COLLECTION_NAME)
if already_deployed:
    print("Collection exists, dropping...")
    collection = Collection(COLLECTION_NAME)
    collection.drop()

### 3.0. Define schema for the new collection and create

In [None]:
# Define the collection fields. The order in which the fields are listed in
# `schema` below is the order in which column data is supplied on insert
# (the auto-generated "id" column excluded).

# Primary key. auto_id=True: ids are generated by Milvus, which is why no "id"
# column appears in the data inserted later.
# (Fixed: the kwarg was misspelled `descrition`, which FieldSchema silently
# accepted via **kwargs, so the description was never set.)
id_field = FieldSchema(
    name="id", dtype=DataType.INT64, description="primary field", is_primary=True, auto_id=True
)

embedding_field = FieldSchema(
    name="embedding",
    dtype=DataType.FLOAT_VECTOR,
    description="Float32 vector with dim 256 (reduced using PCA from original vector dim:1280)",
    dim=256,
    is_primary=False
)

file_timestamp_field = FieldSchema(
    name="file_timestamp",
    dtype=DataType.INT64,
    description="UTC File timestamp (in seconds since 1970-01-01T00:00:00)"
)

file_seconds_since_midnight_field = FieldSchema(
    name="file_seconds_since_midnight",
    dtype=DataType.INT64,
    description="Number of seconds since 00h00 (within same timezone)"
)

# NOTE(review): data for this field is read from the metadata key
# "recording_offset_in_file" (see FIELD_NAMES); insert is positional, so the
# differing names are harmless, but confirm they refer to the same value.
clip_offset_in_file_field = FieldSchema(
    name="clip_offset_in_file",
    dtype=DataType.INT64,
    description="Offset (in seconds) from start of file where embedding window starts"
)

# max len found in 1% sample = 4
site_id_field = FieldSchema(
    name="site_id", dtype=DataType.VARCHAR, description="Site ID", max_length=8
)

# max len found in 1% sample = 50
site_name_field = FieldSchema(
    name="site_name", dtype=DataType.VARCHAR, description="Site name", max_length=100
)

# (one of: Dry-A, Dry-B, Wet-A, Wet-B) -> at most 5 characters
subsite_name_field = FieldSchema(
    name="subsite_name", dtype=DataType.VARCHAR, description="Subsite name", max_length=5
)

# sequence id can be converted to int
file_seq_id_field = FieldSchema(
    name="file_seq_id", dtype=DataType.INT64, description="File sequence ID",
)
filename_field = FieldSchema(
    name="filename", dtype=DataType.VARCHAR, max_length=500
)

schema = CollectionSchema(
    fields=[
        id_field,
        embedding_field,
        file_timestamp_field,
        file_seconds_since_midnight_field,
        clip_offset_in_file_field,
        site_id_field,
        site_name_field,
        subsite_name_field,
        file_seq_id_field,
        filename_field
    ],
    description="Collection for A20 bird call embeddings"
)


# Create the (initially empty) collection from the schema above.
# A TTL of 0 disables data expiration for the collection.
collection = Collection(
    name=COLLECTION_NAME,
    schema=schema,
    data=None,
    properties={"collection.ttl.seconds": 0},
)

print(f"Collections: {utility.list_collections()}")

### 4.0. Create index on the embedding field in the collection
The index will greatly reduce search time and further reduce the memory footprint, at the cost of a slight decrease in accuracy. See the index evaluation notebook for a thorough comparison of the various indexing strategies.

In [None]:
# Index configuration for the "embedding" vector field (see the index
# evaluation notebook for how this choice compares to other strategies).
index_params = {
    "index_type": "IVF_SQ8",
    "params": {"nlist": 4096},
    "metric_type": "L2",
}

# Build the message from index_params so the log can never drift out of sync
# with the parameters actually used.
print(
    f"Creating new index: {index_params['index_type']} "
    f"with params nlist:{index_params['params']['nlist']}"
)
collection.create_index("embedding", index_params)

### 5.0. Define constants
The `BATCH_SIZE` parameter defines how many data entities are inserted at once. Very large batches can overload the machine's memory, while very small batches can lead to longer ingestion times.

In [None]:
# Number of entities sent to Milvus per insert call.
BATCH_SIZE = 1_000

# Keys read from each (embedding + metadata) record, in the same order as the
# collection's schema fields (minus the auto-generated "id"): insert receives
# one column per entry of this tuple.
# NOTE(review): the schema names the offset field "clip_offset_in_file" while
# the key read here is "recording_offset_in_file" — confirm it matches the
# metadata JSON.
FIELD_NAMES = (
    "embedding",
    "file_timestamp", "file_seconds_since_midnight", "recording_offset_in_file",
    "site_id", "site_name", "subsite_name",
    "file_seq_id", "filename",
)

# Input locations; the reduced embeddings come from `2_train_apply_pca.ipynb`.
DATA_DIR = os.path.abspath("./one_percent_data_sample")
METADATA_DIR, REDUCED_EMBEDDINGS_DIR = (
    os.path.join(DATA_DIR, subdir) for subdir in ("metadata", "reduced_embeddings")
)

In [5]:
# helper function to batch data: 
def split_into_batches(data, n=10_000):
    """Yield consecutive slices of ``data`` holding at most ``n`` items each.

    Args:
        data: any sequence supporting ``len()`` and slicing (e.g. a list).
        n: maximum batch size; must be a positive integer.

    Yields:
        ``data[i:i + n]`` for i = 0, n, 2n, ...; the last slice may be shorter.

    Raises:
        ValueError: if ``n`` is smaller than 1. Without this check a negative
            ``n`` silently yields nothing (dropping all data) and ``n == 0``
            fails with an obscure ``range()`` error.
    """
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n!r}")
    for start in range(0, len(data), n):
        yield data[start:start + n]

### 6.0. Load the data
The insert operation expects a list of lists, where each sub-list contains the values corresponding to one field, 
e.g.: 
```python
[
    [ embedding_1, embedding_2, ..., embedding_n],
    [ file_timestamp_1, file_timestamp_2, ..., file_timestamp_n],
    ...
    [filename_1, filename_2, ..., filename_n]
]
```

In [10]:
%%time
def load_data(metadata_file):
    """Insert one metadata file and its matching reduced embeddings into Milvus.

    The embeddings are loaded from ``REDUCED_EMBEDDINGS_DIR/<stem>.npy``, where
    ``<stem>`` is the metadata file name without its extension. Each entity is
    the embedding row merged with the metadata record at the same index, and
    entities are inserted in batches of ``BATCH_SIZE`` as one column per name
    in ``FIELD_NAMES``.

    Args:
        metadata_file: path to a JSON file holding a list of metadata dicts,
            one per embedding row, keyed by the names in ``FIELD_NAMES``.

    Returns:
        The number of entities inserted.

    Raises:
        ValueError: if the number of embeddings and metadata records differ.
    """
    # basename/splitext instead of split("/") + replace(".json", ...): portable
    # across path separators and only touches the trailing extension.
    stem, _ = os.path.splitext(os.path.basename(metadata_file))
    embeddings_file = os.path.join(REDUCED_EMBEDDINGS_DIR, stem + ".npy")
    _embeddings = np.load(embeddings_file)

    with open(metadata_file, "r") as f:
        _metadata = json.load(f)

    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if len(_embeddings) != len(_metadata):
        raise ValueError(
            f"{metadata_file}: {len(_metadata)} metadata records but "
            f"{len(_embeddings)} embeddings in {embeddings_file}"
        )

    data = [
        {"embedding": embedding, **record}
        for embedding, record in zip(_embeddings, _metadata)
    ]

    collection = Collection(COLLECTION_NAME)

    for batch in split_into_batches(data, BATCH_SIZE):
        # Column-oriented insert: one list of values per field, in FIELD_NAMES order.
        collection.insert(
            [[entity[fieldname] for entity in batch] for fieldname in FIELD_NAMES]
        )

    # Flush once per file instead of after every batch. Milvus advises against
    # frequent flush() calls (each one seals the current segments), and flushing
    # per batch slows ingestion considerably.
    collection.flush()

    return len(_embeddings)

# Only pick up JSON metadata files (skips stray entries such as .DS_Store that
# load_data cannot parse); sorted for a reproducible submission order.
metadata_files = sorted(
    os.path.join(METADATA_DIR, file)
    for file in os.listdir(METADATA_DIR)
    if file.endswith(".json") and os.path.isfile(os.path.join(METADATA_DIR, file))
)

# Number of metadata files ingested concurrently.
MAX_WORKERS = 10

with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # load_data returns the number of entities inserted for each file.
    results = list(
        tqdm(
            executor.map(load_data, metadata_files),
            total=len(metadata_files)
        )
    )

print(f"Inserted {sum(results)} entities from {len(metadata_files)} metadata files")
print(f"Collection {COLLECTION_NAME} currently loaded with {collection.num_entities} entities")