## Method 1: Every actor holds a full replica of the dataset
Routing is handled by the actor pool.

In [0]:
import ray
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import lancedb

# Number of random query vectors generated below and fed to the search actors.
num_test_vectors = 1_000_000

class LanceDBActor:
    """Callable class for Ray Data ``map_batches`` backed by a local LanceDB table.

    Construction ingests the Parquet data at ``parquet_path`` into a LanceDB
    table under ``/tmp/lancedb`` and builds a vector index on ``list_col``;
    every call then runs one search against that table.

    The table name is read from the module-level ``lance_table_name``, which is
    expected to be defined elsewhere in the notebook.
    """

    def __init__(self, parquet_path, pyarrow_schema):
        """Create the LanceDB table from Parquet and index its vector column.

        Args:
            parquet_path: Location of the Parquet data to ingest.
            pyarrow_schema: Arrow schema used both to read the Parquet data
                and to declare the LanceDB table.
        """
        # Local import: the original block used `os` without importing it, and
        # a function-scope import keeps the class self-contained.
        import os

        # TODO: consider parameterizing
        num_partitions = 200
        num_sub_vectors = 5
        lance_db_uri = "/tmp/lancedb"
        # assuming 96 core CPU...
        os.environ["LANCE_CPU_THREADS"] = "96"
        os.environ["LANCE_IO_THREADS"] = "96"

        db = lancedb.connect(lance_db_uri)
        # The helper is a staticmethod, so it has to be reached through `self`
        # (a bare name lookup raises NameError inside a method).
        batches = self._get_batches_from_parquet_with_progress(
            parquet_path,
            pyarrow_schema,
            batch_size=200_000,
        )
        self.table_arrow = db.create_table(
            lance_table_name,
            data=batches,
            # The data is a generator, so the schema is passed explicitly
            # instead of relying on inference from the stream.
            schema=pyarrow_schema,
            mode="overwrite",
        )

        self.table_arrow.create_index(
            metric="l2",
            vector_column_name="list_col",
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
        )
        print("Index loaded.")

    @staticmethod
    def _get_batches_from_parquet_with_progress(parquet_path: str,
                                                schema: "pa.Schema",
                                                batch_size: int = 4096):
        """
        Reads a Parquet file in chunks and yields PyArrow RecordBatches,
        displaying progress using tqdm.

        Args:
            parquet_path: Location of the Parquet data.
            schema: Arrow schema applied when opening the dataset.
            batch_size: Batch size handed to the dataset scanner.

        Yields:
            One RecordBatch per scanner batch.
        """
        # tqdm was used here without being imported anywhere in the cell.
        from tqdm import tqdm

        dataset = ds.dataset(parquet_path, format="parquet", schema=schema)

        total_rows = dataset.count_rows()
        # Integer ceiling division so the progress bar total is an int, not a float.
        total_batches = (total_rows + batch_size - 1) // batch_size
        scanner = dataset.scanner(batch_size=batch_size)

        # Wrap the scanner.to_batches() with tqdm
        # We use `total_batches` for tqdm's 'total' argument.
        with tqdm(total=total_batches, unit="batch", desc="Ingesting Parquet Batches") as pbar:
            for batch in scanner.to_batches():
                yield batch
                pbar.update(1)  # Manually update progress for each yielded batch
                pbar.set_postfix({"rows_in_batch": len(batch)})

    def __call__(self, batch: np.ndarray, limit: int):
        """Search the table for ``batch`` and return the hits as a pandas DataFrame.

        Args:
            batch: Query vector(s). A dict with an ``"item"`` column is also
                accepted and unwrapped.
            limit: Value passed to the query's ``limit``.
        """
        # NOTE(review): Ray Data is expected to hand map_batches UDFs a dict of
        # columns, with from_items storing bare items under "item" - confirm
        # against the Ray version in use.
        queries = batch["item"] if isinstance(batch, dict) else batch
        return self.table_arrow.search(queries).limit(limit).to_pandas()


# TODO: Can actors read from Volumes?! Do we just need some S3 bucket instead?
# Source Parquet location; `catalog`, `schema` and `lance_table_name` are
# expected to be defined elsewhere in the notebook.
audio_parquet_path = f'/Volumes/{catalog}/{schema}/{lance_table_name}'

# Arrow schema of the source data: an int64 id plus a fixed-size list of
# 35 float16 values.
pyarrow_schema = pa.schema([
    ("id", pa.int64()),
    ("list_col", pa.list_(pa.float16(), 35)),
])

def create_arrays(n, dimensions):
    """Return ``n`` random float16 vectors, each of length ``dimensions``.

    Values are integers drawn from [0, 256) and cast to float16. All vectors
    are drawn with a single NumPy call rather than one call per vector, which
    matters when ``n`` is in the millions.

    Returns:
        A list of ``n`` one-dimensional arrays (the rows of one 2-D array).
    """
    matrix = np.random.randint(0, 256, size=(n, dimensions)).astype(np.float16)
    return list(matrix)



# Make Ray Data dataset and inference with it
large_query_batch = ray.data.from_items(create_arrays(num_test_vectors, dimensions=35))
# Keyword arguments go through fn_constructor_kwargs / fn_kwargs; the *_args
# variants take positional iterables, so a dict there would pass its keys.
# A class UDF also needs `concurrency` (number of actors).
# TODO: raise `concurrency` to the number of 96-core workers available.
# NOTE: map_batches is lazy - nothing runs until the returned dataset is
# consumed (e.g. written out or materialized).
large_query_batch.map_batches(
    LanceDBActor,
    fn_constructor_kwargs={'parquet_path': audio_parquet_path,
                           'pyarrow_schema': pyarrow_schema},
    fn_kwargs={'limit': 1},
    concurrency=1,
    num_cpus=96,
    memory=1.5e+12,
)






In [0]:
import ray
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import lancedb

@ray.remote
class LanceDBActor:
    """Ray actor that builds a local LanceDB table and answers vector searches.

    Construction ingests the Parquet data at ``parquet_path`` into a LanceDB
    table under ``/tmp/lancedb`` and builds a vector index on ``list_col``;
    ``search`` then queries that table.

    The table name is read from the module-level ``lance_table_name``, which is
    expected to be defined elsewhere in the notebook.
    """

    def __init__(self, parquet_path, pyarrow_schema):
        """Create the LanceDB table from Parquet and index its vector column.

        Args:
            parquet_path: Location of the Parquet data to ingest.
            pyarrow_schema: Arrow schema used both to read the Parquet data
                and to declare the LanceDB table.
        """
        # Local import: the original block used `os` without importing it, and
        # a function-scope import keeps the actor class self-contained.
        import os

        # TODO: consider parameterizing
        num_partitions = 200
        num_sub_vectors = 5
        lance_db_uri = "/tmp/lancedb"
        # assuming 96 core CPU...
        os.environ["LANCE_CPU_THREADS"] = "80"
        os.environ["LANCE_IO_THREADS"] = "32"

        db = lancedb.connect(lance_db_uri)
        # The helper is a staticmethod, so it has to be reached through `self`
        # (a bare name lookup raises NameError inside a method).
        batches = self._get_batches_from_parquet_with_progress(
            parquet_path,
            pyarrow_schema,
            batch_size=200_000,
        )
        self.table_arrow = db.create_table(
            lance_table_name,
            data=batches,
            # The data is a generator, so the schema is passed explicitly
            # instead of relying on inference from the stream.
            schema=pyarrow_schema,
            mode="overwrite",
        )

        self.table_arrow.create_index(
            metric="l2",
            vector_column_name="list_col",
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
        )
        print("Index loaded.")

    @staticmethod
    def _get_batches_from_parquet_with_progress(parquet_path: str,
                                                schema: "pa.Schema",
                                                batch_size: int = 4096):
        """
        Reads a Parquet file in chunks and yields PyArrow RecordBatches,
        displaying progress using tqdm.

        Args:
            parquet_path: Location of the Parquet data.
            schema: Arrow schema applied when opening the dataset.
            batch_size: Batch size handed to the dataset scanner.

        Yields:
            One RecordBatch per scanner batch.
        """
        # tqdm was used here without being imported anywhere in the cell.
        from tqdm import tqdm

        dataset = ds.dataset(parquet_path, format="parquet", schema=schema)

        total_rows = dataset.count_rows()
        # Integer ceiling division so the progress bar total is an int, not a float.
        total_batches = (total_rows + batch_size - 1) // batch_size
        scanner = dataset.scanner(batch_size=batch_size)

        # Wrap the scanner.to_batches() with tqdm
        # We use `total_batches` for tqdm's 'total' argument.
        with tqdm(total=total_batches, unit="batch", desc="Ingesting Parquet Batches") as pbar:
            for batch in scanner.to_batches():
                yield batch
                pbar.update(1)  # Manually update progress for each yielded batch
                pbar.set_postfix({"rows_in_batch": len(batch)})

    def search(self, query_batch: np.ndarray, limit: int):
        """Search the table for ``query_batch`` and return the hits as a pandas DataFrame.

        Args:
            query_batch: Query vector(s) passed to the table's ``search``.
            limit: Value passed to the query's ``limit``.
        """
        return self.table_arrow.search(query_batch).limit(limit).to_pandas()

# --- Main Application ---
# 1. Create the actor pool (fixed size, no autoscaling)
num_replicas = 64

# TODO: Can actors read from Volumes?! Do we just need some S3 bucket instead?
# `catalog`, `schema` and `lance_table_name` are expected to be defined
# elsewhere in the notebook.
audio_parquet_path = f'/Volumes/{catalog}/{schema}/{lance_table_name}'

# Arrow schema of the source data: an int64 id plus a fixed-size list of
# 35 float16 values.
pyarrow_schema = pa.schema([
    ("id", pa.int64()),
    ("list_col", pa.list_(pa.float16(), 35)),
])

# NOTE(review): every replica creates its table under the same local
# /tmp/lancedb path with mode="overwrite"; replicas placed on the same node
# would presumably clobber each other - confirm placement before scaling up.
actors = [
    LanceDBActor.remote(audio_parquet_path, pyarrow_schema)
    for _ in range(num_replicas)
]
pool = ray.util.ActorPool(actors)

# 2. Divide a large batch of queries and submit
def create_arrays(n, dimensions):
    """Return ``n`` random float16 vectors, each of length ``dimensions``.

    Values are integers drawn from [0, 256) and cast to float16. All vectors
    are drawn with a single NumPy call rather than one call per vector, which
    matters when ``n`` is in the millions.

    Returns:
        A list of ``n`` one-dimensional arrays (the rows of one 2-D array).
    """
    matrix = np.random.randint(0, 256, size=(n, dimensions)).astype(np.float16)
    return list(matrix)

num_test_vectors = 1_000_000
# Number of query vectors sent to an actor per task.
# TODO: tune; smaller chunks balance load better, larger ones cut overhead.
queries_per_task = 10_000

large_query_batch = create_arrays(num_test_vectors, 35)
# Split the queries into chunks so the pool can spread them over its actors.
# Wrapping the whole batch in a one-element list (as before) produced a single
# task, i.e. one busy actor and the rest idle.
query_chunks = [
    large_query_batch[start:start + queries_per_task]
    for start in range(0, len(large_query_batch), queries_per_task)
]

# map_unordered is efficient, yielding results as they complete
results_generator = pool.map_unordered(
    lambda actor, batch: actor.search.remote(batch, limit=1),
    query_chunks,
)

for result in results_generator:

    # write results to spark
    pass
# Make Ray Data dataset.
# large_query_batch =  ray.data.from_items(create_arrays(num_test_vectors, dimensions=35))
# large_query_batch.map_batches(LanceDBActor, concurrency=10)






## Method 2: Routing queries to shards

In [0]:
import ray
from qdrant_client import QdrantClient
# Assume 30 shard endpoints are configured (placeholder names host1..host30).
# Built programmatically: a literal `...` inside the list is Python's Ellipsis
# object, not an elision, and would leave a 4-item list with a non-string entry.
QDRANT_ENDPOINTS = [f"host{shard_number}:6333" for shard_number in range(1, 31)]

@ray.remote
class QueryRouter:
    """Ray actor that forwards each query to one Qdrant shard.

    One client is created per entry in ``QDRANT_ENDPOINTS``; the shard for a
    query is picked from the ``user_id`` in its metadata.
    """

    def __init__(self):
        """Create one async Qdrant client per configured shard endpoint."""
        # search() awaits the client call, so the async client is required:
        # the synchronous QdrantClient.search returns a plain list, which
        # cannot be awaited.
        from qdrant_client import AsyncQdrantClient

        # Endpoints are "host:port" strings, so they are passed as a URL
        # rather than through `host=`, which takes a bare host name.
        self.clients = [
            AsyncQdrantClient(url=f"http://{endpoint}")
            for endpoint in QDRANT_ENDPOINTS
        ]
        print("Initialized connections to all shards.")

    def _get_shard_index(self, query_metadata):
        """Return the index into ``self.clients`` for this query.

        Uses ``user_id`` from ``query_metadata`` (0 when absent) modulo the
        number of clients.
        """
        # Simple example: determine shard from a 'user_id'
        # A more robust implementation would use consistent hashing
        user_id = query_metadata.get("user_id", 0)
        return user_id % len(self.clients)

    async def search(self, vector, query_metadata, top_k):
        """Search the shard selected for ``query_metadata``.

        Args:
            vector: Query vector, passed as ``query_vector``.
            query_metadata: Mapping read by ``_get_shard_index``.
            top_k: Passed as the search ``limit``.

        Returns:
            Whatever the client's ``search`` call returns for ``my_collection``.
        """
        shard_index = self._get_shard_index(query_metadata)
        client = self.clients[shard_index]

        # Awaiting the async client leaves the actor free to serve other
        # requests while this shard responds.
        return await client.search(
            collection_name="my_collection",
            query_vector=vector,
            limit=top_k,
        )

In [0]:
import ray
import numpy as np
# Assume Faiss is used for the index
# import faiss

@ray.remote
class ShardActor:
    """Ray actor standing in for one shard of a sharded vector index.

    Index loading and the real search are still commented-out placeholders;
    ``search`` currently returns random data of the right shape.
    """

    def __init__(self, shard_id: int, total_shards: int):
        """Record the shard id and log which index file this shard would load."""
        self.shard_id = shard_id
        # Each actor loads only its assigned shard of the data
        index_path = f"s3://my-bucket/shard_{shard_id}_of_{total_shards}.faiss"
        print(f"Actor {shard_id} loading {index_path}...")
        # self.index = faiss.read_index(index_path)
        print(f"Actor {shard_id} ready.")

    def search(self, query_vector: np.ndarray, k: int):
        """Return placeholder ``(distances, indices)`` arrays.

        Both arrays have shape ``(query_vector.shape[0], k)``.
        """
        # Real implementation (search the local shard index):
        # distances, local_indices = self.index.search(query_vector, k)
        # You must map local_indices back to global IDs
        # global_indices = self.map_local_to_global(local_indices)
        # return (distances, global_indices)
        # Placeholder for demonstration
        result_shape = (query_vector.shape[0], k)
        placeholder_distances = np.random.rand(*result_shape)
        placeholder_ids = np.random.randint(0, 3_000_000_000, size=result_shape)
        return (placeholder_distances, placeholder_ids)


# --- Main Application ---
# 1. Create one actor for each shard
num_shards = 30
shard_actors = [
    ShardActor.remote(shard_id, num_shards) for shard_id in range(num_shards)
]

# 2. Broadcast a single query to ALL shards
k = 10
query_vector = np.random.rand(1, 768).astype('float32')

# Fan-out the query to all actors
results_futures = [shard.search.remote(query_vector, k) for shard in shard_actors]
# One (distances, indices) pair per shard
all_shard_results = ray.get(results_futures)

# 3. Merge results from all shards
# This is a critical step. You need to collect all candidates and re-rank them.
# Row 0 is taken from each array because only a single query was sent.
all_distances = np.concatenate(
    [shard_distances[0] for shard_distances, _ in all_shard_results]
)
all_indices = np.concatenate(
    [shard_indices[0] for _, shard_indices in all_shard_results]
)

# Sort by distance and take the top k
top_k_indices = np.argsort(all_distances)[:k]
final_distances = all_distances[top_k_indices]
final_indices = all_indices[top_k_indices]

print("Final Top-K Indices:", final_indices)