In [0]:
%pip install ray[all]

In [0]:
# Import here too: this cell runs before the cell that does `import ray`,
# so without it `ray.init` fails with NameError when cells run in order.
import ray

# ignore_reinit_error=True lets this cell be re-run in the same kernel
# without Ray raising on a second ray.init() call.
ray.init(num_cpus=4, ignore_reinit_error=True)  # Assuming at least 4 CPUs for this example

In [0]:
import ray

@ray.remote
class ShardActor:
    """Ray actor holding one shard's data and answering lookups against it.

    Multiple instances may be created for the same ``shard_id`` so that
    requests for a busy shard can be spread across replicas.

    Args:
        shard_id: identifier of the shard this actor serves.
        data_path: location of the shard's data. Currently unused; the
            data below is a placeholder string.
    """

    def __init__(self, shard_id, data_path):
        self.shard_id = shard_id
        node_ip = ray.util.get_node_ip_address()
        print(f"ShardActor {shard_id} initialized on {node_ip}")
        # Placeholder for real shard loading from `data_path`
        # (for vector search this is where an HNSW index etc. would be loaded).
        self.data = f"Data for shard {shard_id}"

    def lookup(self, query):
        """Return a placeholder result string for ``query`` on this shard.

        A real implementation would search ``self.data``.
        """
        return f"Result from shard {self.shard_id} for {query}"

# Rather than a single ShardActor per shard, create a small group of replica
# actors for every shard so lookups on a hot shard can be spread across them.
num_shards = 4
num_actors_per_shard = 2  # Or more for hot shards, less for cold

# shard id -> list of ShardActor handles that all serve that shard.
shard_actor_groups = {
    shard_id: [
        ShardActor.remote(shard_id, f"path_to_shard_{shard_id}")
        for _ in range(num_actors_per_shard)
    ]
    for shard_id in range(num_shards)
}

@ray.remote
class Router:
    """Route each query to one replica actor of the shard it belongs to.

    ``route_request`` is a coroutine, so Ray runs this as an async actor.
    While one request awaits its shard lookup, the router keeps accepting
    and dispatching others. A blocking ``ray.get`` inside a synchronous
    actor would instead serialize every request through the router, even
    across different replica actors.

    Args:
        shard_actor_groups: mapping of shard id -> list of actor handles
            serving that shard; each handle must expose ``lookup.remote(query)``.
    """

    def __init__(self, shard_actor_groups):
        self.shard_actor_groups = shard_actor_groups
        # Round-robin position per shard.
        self.counters = {shard_id: 0 for shard_id in shard_actor_groups}

    async def route_request(self, query):
        """Return the lookup result for ``query`` from one replica of its shard.

        Replicas of a shard are chosen round-robin.

        Raises:
            KeyError: if ``_determine_shard`` returns an id that is not a key
                of ``shard_actor_groups``.
            ValueError: if the selected shard has no actors.
        """
        shard_id = self._determine_shard(query)  # Implement your sharding logic
        replicas = self.shard_actor_groups[shard_id]
        if not replicas:
            # Previously this surfaced as an opaque ZeroDivisionError from `%`.
            raise ValueError(f"No actors registered for shard {shard_id}")

        # Round-robin or intelligent load balancing among actors for the same shard.
        # Select and advance the counter before awaiting, so requests that
        # interleave on the actor's event loop each get the next replica.
        actor_idx = self.counters[shard_id] % len(replicas)
        self.counters[shard_id] += 1

        selected_actor = replicas[actor_idx]
        # `await` on the ObjectRef yields to the event loop instead of
        # blocking the whole actor the way ray.get() would.
        return await selected_actor.lookup.remote(query)

    def _determine_shard(self, query):
        """Map ``query`` to a shard id via a substring hint (demo logic only).

        Checks for "shard_1", "shard_0", "shard_2" in that order; anything
        else goes to shard 3.
        """
        # Simple example: Assuming query contains a shard_id hint
        # In real scenario, this would be based on vector content or ID
        if "shard_1" in query:
            return 1
        elif "shard_0" in query:
            return 0
        elif "shard_2" in query:
            return 2
        else:  # Default for demonstration
            return 3


# --- Demo ---
router = Router.remote(shard_actor_groups)

# Skew the load towards shard 1: ten requests for it, then two rounds of one
# request each for shards 0, 2 and 3.
print("\nSimulating skewed load...")
hot_queries = ["query for shard_1"] * 10
cold_queries = [
    query
    for _ in range(2)
    for query in ("query for shard_0", "query for shard_2", "query for shard_3")
]
results = [router.route_request.remote(query) for query in hot_queries + cold_queries]

ray.get(results)  # block until every routed lookup has finished
print("All lookups complete.")

# The Router round-robins shard 1's requests across that shard's replica
# actors. Note the returned strings only name the shard, not the replica.

# ray.shutdown()

In [0]:
# Fetch the routed lookup results submitted in the demo cell. As the cell's
# final expression, the returned list of result strings is displayed by the notebook.
ray.get(results)