In [2]:
import concurrent.futures
import json
import os
import time

import numpy as np
from milvus import default_server
from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection,
    utility
)
from tqdm.notebook import tqdm

from utils import extract_sample_data, split_into_batches

In [3]:
# Connect to the remote Milvus instance through kubectl port forwarding:
# gcloud auth login
# gcloud container clusters get-credentials bioacoustics-devseed-staging-cluster --region=us-central1-f
# kubectl port-forward service/milvus 9091:9091 & \
# kubectl port-forward service/milvus 19530:19530 &
#
# The defaults point at the local end of that port forward; set MILVUS_HOST / MILVUS_PORT
# in the environment to target a different instance without editing the notebook.

HOST = os.environ.get("MILVUS_HOST", "127.0.0.1")
PORT = int(os.environ.get("MILVUS_PORT", "19530"))
connections.connect(host=HOST, port=PORT)
print("Connections: ", connections.list_connections())

COLLECTION_NAME = "a2o_bioacoustics"


Connections:  [('default', <pymilvus.client.grpc_handler.GrpcHandler object at 0x110e24fa0>)]


In [4]:
# Set up the collection from scratch.

# Remove any collection left behind by a previous run so the schema below is applied fresh.
already_exists = utility.has_collection(COLLECTION_NAME)
if already_exists:
    print("Collection exists, dropping...")
    collection = Collection(COLLECTION_NAME)
    collection.drop()

# Collection schema: an auto-generated INT64 primary key, the reduced embedding vector, and
# per-clip file/site metadata stored alongside it.
id_field = FieldSchema(
    name="id",
    dtype=DataType.INT64,
    description="primary field",
    is_primary=True,
    auto_id=True,
)

embedding_field = FieldSchema(
    name="embedding",
    dtype=DataType.FLOAT_VECTOR,
    description="Float32 vector with dim 256 (reduced using PCA from original vector dim:1280)",
    dim=256,
    is_primary=False,
)
file_timestamp_field = FieldSchema(
    name="file_timestamp",
    dtype=DataType.INT64,
    description="UTC File timestamp (in seconds since 1970-01-01T00:00:00)",
)
file_seconds_since_midnight_field = FieldSchema(
    name="file_seconds_since_midnight",
    dtype=DataType.INT64,
    description="Number of seconds since 00h00 (within same timezone)",
)
clip_offset_in_file_field = FieldSchema(
    name="clip_offset_in_file",
    dtype=DataType.INT64,
    description="Offset (in seconds) from start of file where embedding window starts",
)
site_id_field = FieldSchema(
    name="site_id",
    dtype=DataType.VARCHAR,
    description="Site ID",
    max_length=8,  # max len found in 1% sample = 4
)
site_name_field = FieldSchema(
    name="site_name",
    dtype=DataType.VARCHAR,
    description="Site name",
    max_length=100,  # max len found in 1% sample = 50
)
subsite_name_field = FieldSchema(
    name="subsite_name",
    dtype=DataType.VARCHAR,
    description="Subsite name",
    max_length=5,  # (one of: Dry-A, Dry-B, Wet-A, Wet-B)
)
file_seq_id_field = FieldSchema(
    name="file_seq_id",
    dtype=DataType.INT64,
    description="File sequence ID",  # sequence id can be converted to int
)
filename_field = FieldSchema(
    name="filename",
    dtype=DataType.VARCHAR,
    max_length=500,
)

# Field order matters: column-list inserts are matched to these fields by position
# (the auto_id primary key is skipped).
schema = CollectionSchema(
    fields=[
        id_field,
        embedding_field,
        file_timestamp_field,
        file_seconds_since_midnight_field,
        clip_offset_in_file_field,
        site_id_field,
        site_name_field,
        subsite_name_field,
        file_seq_id_field,
        filename_field,
    ],
    description="Collection for searching A2O bird embeddings",
)
# Create the collection from the schema above. A TTL of 0 disables automatic expiry.
collection_properties = {"collection.ttl.seconds": 0}
collection = Collection(
    name=COLLECTION_NAME,
    schema=schema,
    data=None,
    properties=collection_properties,
)
print(f"Collections: {utility.list_collections()}")
print(f"Collection {COLLECTION_NAME} instantiated with {collection.num_entities} entities")

# IVF_SQ8 index on the embedding field with nlist=4096, using L2 distance.
index_params = {
    "index_type": "IVF_SQ8",
    "params": {"nlist": 4096},
    "metric_type": "L2",
}

# Clear out any existing index before building the new one.
if len(collection.indexes):
    current_index = collection.index()
    print(f"Dropping current index on field: {current_index.field_name} -> {current_index.params}")
    collection.drop_index()
    print(f"Indexes remaining after dropping: {collection.indexes}")

# The log line is derived from index_params so the two cannot drift apart.
print(f"Creating new index: {index_params['index_type']} with params nlist:{index_params['params']['nlist']}")
collection.create_index("embedding", index_params)

print(f"Collection {COLLECTION_NAME} index created...")
collection.flush()

Collection exists, dropping...
Collections: ['a2o_bioacoustics']
Collection a2o_bioacoustics instantiated with 0 entities
Creating new index: IVF_SQ8 with params nlist:4096
Collection a2o_bioacoustics index created...


In [5]:
BATCH_SIZE = 1000

# Metadata keys to insert, one column each, listed in the order of the collection's
# non-primary schema fields. Columns are handed to insert() as a plain list, i.e. matched
# to fields by position. That is why the metadata key "recording_offset_in_file" ends up
# in the schema field "clip_offset_in_file".
FIELD_NAMES = (
    "embedding",
    "file_timestamp",
    "file_seconds_since_midnight",
    "recording_offset_in_file",
    "site_id",
    "site_name",
    "subsite_name",
    "file_seq_id",
    "filename",
)


def load_data(metadata_file):
    """Insert one shard of embeddings and their metadata into ``COLLECTION_NAME``.

    The embeddings come from the ``.npy`` file paired with ``metadata_file``: the same path
    with "metadata" replaced by "numpy_reduced" and ".json" replaced by ".npy".

    Args:
        metadata_file: Path to a JSON file holding a list with one metadata dict per
            embedding, in the same order as the rows of the paired ``.npy`` array.

    Returns:
        The number of embeddings in the shard.

    Raises:
        AssertionError: if the embedding and metadata record counts differ.
    """
    embeddings_file = metadata_file.replace("metadata", "numpy_reduced").replace(".json", ".npy")
    embeddings = np.load(embeddings_file)

    with open(metadata_file, "r") as f:
        metadata = json.load(f)

    assert len(embeddings) == len(metadata), (
        f"{metadata_file}: {len(embeddings)} embeddings vs {len(metadata)} metadata records"
    )

    # One row dict per embedding; metadata keys are merged after "embedding".
    data = [{"embedding": vector, **record} for vector, record in zip(embeddings, metadata)]
    if not data:
        return 0

    collection = Collection(COLLECTION_NAME)
    for batch in split_into_batches(data, BATCH_SIZE):
        # Transpose the row dicts into one column list per field, in FIELD_NAMES order.
        collection.insert([[row[name] for row in batch] for name in FIELD_NAMES])

    # Flush once per shard rather than after every batch: each flush seals the growing
    # segments, so flushing every 1000 rows is slow and produces many tiny segments.
    collection.flush()

    return len(embeddings)

   

In [6]:
%%time
# One metadata JSON per shard; load_data derives each shard's .npy path from it.
N_SHARDS = 374
metadata_filenames = [
    f"./one_percent_embeddings_metadata/a2o_sample_embeddings-{i:05}-of-{N_SHARDS:05}.json"
    for i in range(N_SHARDS)
]

# load_data presumably spends most of its time waiting on file reads and Milvus RPCs
# (the recorded CPU time is far below the wall time), so a thread pool overlaps that waiting.
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    results = list(
        tqdm(executor.map(load_data, metadata_filenames), total=len(metadata_filenames))
    )

# load_data returns the row count of each shard.
print(f"Inserted {sum(results)} embeddings from {len(results)} files")
print(f"Collection {COLLECTION_NAME} currently loaded with {collection.num_entities} entities")
  

  0%|          | 0/374 [00:00<?, ?it/s]

Collection a2o_bioacoustics currently loaded with 14412192 entities
CPU times: user 23min 32s, sys: 4min 42s, total: 28min 15s
Wall time: 5h 48min 34s


In [None]:
%%time
# Rebuild the IVF_SQ8 index (nlist=4096, L2) now that the collection is populated.
index_params = {
    "index_type": "IVF_SQ8",
    "params": {"nlist": 4096},
    "metric_type": "L2",
}

if len(collection.indexes):
    current_index = collection.index()
    print(f"Dropping current index on field: {current_index.field_name} -> {current_index.params}")
    collection.drop_index()
    print(f"Indexes remaining after dropping: {collection.indexes}")

print(f"Creating new index: {index_params['index_type']} with params nlist:{index_params['params']['nlist']}")
# Left as the cell's last expression so its return value is displayed.
collection.create_index("embedding", index_params)

In [64]:
# Query vector: one randomly chosen embedding from shard 1. A seeded generator makes the
# choice reproducible when the notebook is re-run.
SEARCH_SEED = 42
search_rng = np.random.default_rng(SEARCH_SEED)

_embeddings = np.load(metadata_filenames[1].replace("metadata", "numpy_reduced").replace(".json", ".npy"))

# choice(n, size=1) returns a length-1 index array, so search_vectors keeps the
# 2-D (1, dim) shape that collection.search expects for a batch of queries.
search_vectors = _embeddings[search_rng.choice(len(_embeddings), size=1)]

In [104]:
%%time
collection.load()


def _search_request(limit, **extra_param):
    """Build kwargs for ``collection.search``: L2 / nprobe=16 over ``embedding`` for ``search_vectors``.

    Any ``extra_param`` entries (e.g. ``offset``) are added to the ``param`` dict. A fresh
    dict is built on each call, so separate requests never share mutable state.
    """
    return {
        "data": search_vectors,
        "anns_field": "embedding",
        "param": {"metric_type": "L2", "params": {"nprobe": 16}, **extra_param},
        "limit": limit,
    }


# Top-20 search vs. a second page fetched with offset=10 / limit=10: the second page
# should match hits 11-20 of the first search.
search_param_1 = _search_request(20)
search_results_1 = collection.search(**search_param_1)

search_param_2 = _search_request(10, offset=10)
search_results_2 = collection.search(**search_param_2)

first_page_tail = [hit.id for hit in search_results_1[0]][10:]
second_page = [hit.id for hit in search_results_2[0]]
first_page_tail == second_page

CPU times: user 5.18 ms, sys: 2.6 ms, total: 7.78 ms
Wall time: 1.53 s


True

In [99]:
# Hit ids for the first (only) query vector of each search.
id1, id2 = ([hit.id for hit in hits[0]] for hits in (search_results_1, search_results_2))
print(id1)
print(id2)

[441649888454387164, 441649888466767199, 441649888468615140, 441649888466356568, 441649888468615139, 441649888455936804, 441649888464992819, 441649888455252561, 441649888455287219, 441649888462164921, 441649888457944054, 441649888461832523, 441649888461609117, 441649888460221637, 441649888463097280, 441649888461609114, 441649888465594860, 441649888464556168, 441649888457944049, 441649888457561596]
[441649888457944054, 441649888461832523, 441649888461609117, 441649888460221637, 441649888463097280, 441649888461609114, 441649888465594860, 441649888464556168, 441649888457944049, 441649888457561596]


In [None]:
# Persist the query vector(s) as JSON: a list with one list of floats per vector.
# (json is imported in the notebook's first cell.)
with open("sample_search_vectors.json", "w") as f:
    json.dump([vector.tolist() for vector in search_vectors], f)


In [None]:
# Materialise the per-query hits of the first search for inspection.
list(search_results_1)