In [None]:
import os, sys, time, math, json, random, threading, hashlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "faiss-cpu", "xxhash"])
import faiss
import xxhash

def now_ms() -> int:
    return int(time.time() * 1000)

def stable_hash64(s: str) -> int:
    return xxhash.xxh3_64_intdigest(s)

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.clip(n, eps, None)

def merge_topk(results: List[Tuple[np.ndarray, np.ndarray]], k: int) -> Tuple[np.ndarray, np.ndarray]:
    scores = np.concatenate([r[0] for r in results], axis=0)
    ids = np.concatenate([r[1] for r in results], axis=0)
    m = ids >= 0
    scores, ids = scores[m], ids[m]
    if len(ids) == 0:
        return np.array([], dtype=np.float32), np.array([], dtype=np.int64)
    idx = np.argpartition(-scores, min(k, len(scores)-1))[:k]
    idx = idx[np.argsort(-scores[idx])]
    return scores[idx].astype(np.float32), ids[idx].astype(np.int64)

def sha1_bytes(b: bytes) -> str:
    return hashlib.sha1(b).hexdigest()

def deterministic_rng(seed: int) -> np.random.Generator:
    return np.random.default_rng(seed)

@dataclass
class VectorRecord:
    key: str
    vec: np.ndarray
    payload: Dict[str, Any]
    version: int
    ts_ms: int

@dataclass
class NodeStatus:
    node_id: str
    is_up: bool = True
    last_seen_ms: int = field(default_factory=now_ms)

In [None]:
class ShardNode:
    def __init__(self, node_id: str, dim: int, metric: str = "cosine"):
        self.node_id = node_id
        self.dim = dim
        self.metric = metric
        self.status = NodeStatus(node_id=node_id)
        self._store: Dict[str, VectorRecord] = {}
        self._next_int_id = 0
        self._key_to_int: Dict[str, int] = {}
        self._int_to_key: Dict[int, str] = {}
        if metric == "cosine":
            self.index = faiss.IndexFlatIP(dim)
        elif metric == "l2":
            self.index = faiss.IndexFlatL2(dim)
        else:
            raise ValueError("metric must be 'cosine' or 'l2'")
        self._vectors: List[np.ndarray] = []
        self._lock = threading.RLock()

    def set_up(self, up: bool):
        with self._lock:
            self.status.is_up = up
            self.status.last_seen_ms = now_ms()

    def is_up(self) -> bool:
        return self.status.is_up

    def _ensure_up(self):
        if not self.status.is_up:
            raise RuntimeError(f"Node {self.node_id} is DOWN")

    def upsert(self, rec: VectorRecord) -> bool:
        with self._lock:
            self._ensure_up()
            prev = self._store.get(rec.key)
            if prev is not None:
                if (rec.version, rec.ts_ms) <= (prev.version, prev.ts_ms):
                    return False
            v = rec.vec.astype(np.float32)
            if v.shape != (self.dim,):
                raise ValueError(f"vec must be shape ({self.dim},), got {v.shape}")
            if self.metric == "cosine":
                v = l2_normalize(v)
            self._store[rec.key] = VectorRecord(
                key=rec.key,
                vec=v,
                payload=dict(rec.payload),
                version=int(rec.version),
                ts_ms=int(rec.ts_ms),
            )
            if rec.key not in self._key_to_int:
                int_id = self._next_int_id
                self._next_int_id += 1
                self._key_to_int[rec.key] = int_id
                self._int_to_key[int_id] = rec.key
                self._vectors.append(v)
                self.index.add(v.reshape(1, -1))
            else:
                int_id = self._key_to_int[rec.key]
                self._vectors[int_id] = v
                self._rebuild_index()
            return True

    def _rebuild_index(self):
        self.index.reset()
        if len(self._vectors) == 0:
            return
        mat = np.stack(self._vectors, axis=0).astype(np.float32)
        self.index.add(mat)

    def get(self, key: str) -> Optional[VectorRecord]:
        with self._lock:
            self._ensure_up()
            rec = self._store.get(key)
            if rec is None:
                return None
            return VectorRecord(
                key=rec.key,
                vec=rec.vec.copy(),
                payload=dict(rec.payload),
                version=rec.version,
                ts_ms=rec.ts_ms,
            )

    def batch_get(self, keys: List[str]) -> Dict[str, VectorRecord]:
        out = {}
        with self._lock:
            self._ensure_up()
            for k in keys:
                r = self._store.get(k)
                if r is not None:
                    out[k] = VectorRecord(
                        key=r.key, vec=r.vec.copy(), payload=dict(r.payload),
                        version=r.version, ts_ms=r.ts_ms
                    )
        return out

    def search(self, q: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        with self._lock:
            self._ensure_up()
            if len(self._vectors) == 0:
                return np.array([], dtype=np.float32), np.array([], dtype=np.int64)
            qq = q.astype(np.float32).reshape(1, -1)
            if qq.shape != (1, self.dim):
                raise ValueError(f"query must be shape ({self.dim},), got {q.shape}")
            if self.metric == "cosine":
                qq = l2_normalize(qq)
            D, I = self.index.search(qq, k)
            D = D.reshape(-1)
            I = I.reshape(-1)
            node_tag = stable_hash64(self.node_id) & 0xFFFFFFFF
            global_ids = np.array([(node_tag << 32) | (int(i) & 0xFFFFFFFF) for i in I], dtype=np.int64)
            return D.astype(np.float32), global_ids

    def decode_global_id(self, gid: int) -> Optional[str]:
        with self._lock:
            node_tag = stable_hash64(self.node_id) & 0xFFFFFFFF
            if ((gid >> 32) & 0xFFFFFFFF) != node_tag:
                return None
            int_id = int(gid & 0xFFFFFFFF)
            return self._int_to_key.get(int_id)

    def digest(self) -> str:
        with self._lock:
            self._ensure_up()
            items = sorted((k, r.version, r.ts_ms) for k, r in self._store.items())
            b = json.dumps(items, separators=(",", ":")).encode("utf-8")
            return sha1_bytes(b)

    def diff_keys(self, other_meta: Dict[str, Tuple[int, int]]) -> List[str]:
        with self._lock:
            self._ensure_up()
            need = []
            for k, (v, t) in other_meta.items():
                mine = self._store.get(k)
                if mine is None or (mine.version, mine.ts_ms) < (v, t):
                    need.append(k)
            return need

    def meta_snapshot(self) -> Dict[str, Tuple[int, int]]:
        with self._lock:
            self._ensure_up()
            return {k: (r.version, r.ts_ms) for k, r in self._store.items()}

In [None]:
class ConsistentHashRing:
    def __init__(self, node_ids: List[str], vnodes: int = 64):
        self.vnodes = int(vnodes)
        self._ring: List[Tuple[int, str]] = []
        for nid in node_ids:
            for v in range(self.vnodes):
                h = stable_hash64(f"{nid}::vn{v}")
                self._ring.append((h, nid))
        self._ring.sort(key=lambda x: x[0])
        self._hashes = [h for h, _ in self._ring]

    def owners(self, key: str, rf: int) -> List[str]:
        h = stable_hash64(key)
        import bisect
        i = bisect.bisect_left(self._hashes, h)
        if i == len(self._hashes):
            i = 0
        owners = []
        seen = set()
        j = i
        while len(owners) < rf and len(seen) < len(set(n for _, n in self._ring)):
            nid = self._ring[j][1]
            if nid not in seen:
                owners.append(nid)
                seen.add(nid)
            j = (j + 1) % len(self._ring)
        return owners

In [None]:
class DistributedVectorDB:
    def __init__(
        self,
        dim: int,
        node_ids: List[str],
        rf: int = 2,
        w: int = 2,
        r: int = 1,
        vnodes: int = 64,
        metric: str = "cosine",
        anti_entropy_interval_s: float = 2.0,
        anti_entropy_enabled: bool = True,
        seed: int = 7,
    ):
        if rf < 1:
            raise ValueError("rf must be >= 1")
        if w < 1 or r < 1:
            raise ValueError("w and r must be >= 1")
        if w > rf or r > rf:
            raise ValueError("w and r must be <= rf")

        self.dim = int(dim)
        self.node_ids = list(node_ids)
        self.rf = int(rf)
        self.w = int(w)
        self.r = int(r)
        self.metric = metric
        self.ring = ConsistentHashRing(self.node_ids, vnodes=vnodes)
        self.nodes: Dict[str, ShardNode] = {nid: ShardNode(nid, dim=self.dim, metric=metric) for nid in self.node_ids}

        self._lamport: Dict[str, int] = {}
        self._lamport_lock = threading.Lock()

        self._rng = deterministic_rng(seed)

        self._anti_entropy_interval_s = float(anti_entropy_interval_s)
        self._anti_entropy_enabled = bool(anti_entropy_enabled)
        self._stop = threading.Event()
        self._bg = None
        if self._anti_entropy_enabled:
            self._bg = threading.Thread(target=self._anti_entropy_loop, daemon=True)
            self._bg.start()

    def stop(self):
        self._stop.set()
        if self._bg is not None:
            self._bg.join(timeout=2)

    def set_node_up(self, node_id: str, up: bool):
        self.nodes[node_id].set_up(up)

    def _owners(self, key: str) -> List[ShardNode]:
        return [self.nodes[nid] for nid in self.ring.owners(key, self.rf)]

    def _next_version(self, key: str) -> int:
        with self._lamport_lock:
            v = self._lamport.get(key, 0) + 1
            self._lamport[key] = v
            return v

    def upsert(self, key: str, vec: np.ndarray, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        payload = payload or {}
        owners = self._owners(key)

        rec = VectorRecord(
            key=key,
            vec=np.asarray(vec, dtype=np.float32),
            payload=dict(payload),
            version=self._next_version(key),
            ts_ms=now_ms(),
        )

        acks = 0
        errors = []
        applied_nodes = []
        for n in owners:
            try:
                ok = n.upsert(rec)
                if ok:
                    acks += 1
                    applied_nodes.append(n.node_id)
            except Exception as e:
                errors.append(f"{n.node_id}: {e}")

        return {
            "key": key,
            "owners": [n.node_id for n in owners],
            "acks": acks,
            "required_w": self.w,
            "applied_nodes": applied_nodes,
            "ok": (acks >= self.w),
            "errors": errors,
            "version": rec.version,
            "ts_ms": rec.ts_ms,
        }

    def get(self, key: str) -> Dict[str, Any]:
        owners = self._owners(key)
        responses: List[Tuple[str, Optional[VectorRecord], Optional[str]]] = []
        good = 0

        for n in owners:
            try:
                rec = n.get(key)
                responses.append((n.node_id, rec, None))
                if rec is not None:
                    good += 1
            except Exception as e:
                responses.append((n.node_id, None, str(e)))

        up_contacts = sum(1 for _, _, err in responses if err is None)
        ok = up_contacts >= self.r

        latest = None
        for _, rec, err in responses:
            if err is not None or rec is None:
                continue
            if latest is None or (rec.version, rec.ts_ms) > (latest.version, latest.ts_ms):
                latest = rec

        repaired = []
        if latest is not None:
            for nid, rec, err in responses:
                if err is not None:
                    continue
                if rec is None or (rec.version, rec.ts_ms) < (latest.version, latest.ts_ms):
                    try:
                        self.nodes[nid].upsert(latest)
                        repaired.append(nid)
                    except Exception:
                        pass

        return {
            "key": key,
            "owners": [n.node_id for n in owners],
            "required_r": self.r,
            "up_contacts": up_contacts,
            "ok": ok,
            "latest": None if latest is None else {
                "version": latest.version,
                "ts_ms": latest.ts_ms,
                "payload": latest.payload,
                "vec_preview": latest.vec[:6].tolist(),
            },
            "read_repair_applied_to": repaired,
            "responses": [
                {
                    "node": nid,
                    "has": (rec is not None),
                    "version": None if rec is None else rec.version,
                    "ts_ms": None if rec is None else rec.ts_ms,
                    "err": err,
                }
                for nid, rec, err in responses
            ],
        }

In [None]:
    def knn_search(self, q: np.ndarray, k: int = 5, fanout: Optional[int] = None) -> Dict[str, Any]:
        q = np.asarray(q, dtype=np.float32).reshape(-1)
        if q.shape != (self.dim,):
            raise ValueError(f"q must be shape ({self.dim},), got {q.shape}")

        node_list = list(self.nodes.values())
        if fanout is not None:
            fanout = min(int(fanout), len(node_list))
            node_list = list(self._rng.choice(node_list, size=fanout, replace=False))

        partials = []
        per_node = []
        for n in node_list:
            try:
                D, I = n.search(q, k)
                partials.append((D, I))
                per_node.append({"node": n.node_id, "count": int(len(I)), "ok": True})
            except Exception as e:
                per_node.append({"node": n.node_id, "count": 0, "ok": False, "err": str(e)})

        scores, gids = merge_topk(partials, k=k)

        hits = []
        for score, gid in zip(scores.tolist(), gids.tolist()):
            key = None
            for n in self.nodes.values():
                key = n.decode_global_id(gid)
                if key is not None:
                    break
            if key is None:
                continue
            doc = self.get(key)
            payload = None if doc["latest"] is None else doc["latest"]["payload"]
            hits.append({"score": float(score), "key": key, "payload": payload})

        return {
            "k": int(k),
            "fanout": None if fanout is None else int(fanout),
            "metric": self.metric,
            "per_node": per_node,
            "hits": hits,
        }

    def _anti_entropy_loop(self):
        while not self._stop.is_set():
            try:
                self._anti_entropy_round()
            except Exception:
                pass
            self._stop.wait(self._anti_entropy_interval_s)

    def _anti_entropy_round(self):
        nids = self.node_ids[:]
        self._rng.shuffle(nids)
        pairs = list(zip(nids[::2], nids[1::2]))
        for a, b in pairs:
            na, nb = self.nodes[a], self.nodes[b]
            if not (na.is_up() and nb.is_up()):
                continue
            da, db = na.digest(), nb.digest()
            if da == db:
                continue

            meta_a = na.meta_snapshot()
            meta_b = nb.meta_snapshot()

            need_a = na.diff_keys(meta_b)
            need_b = nb.diff_keys(meta_a)

            if need_a:
                got = nb.batch_get(need_a)
                for rec in got.values():
                    na.upsert(rec)
            if need_b:
                got = na.batch_get(need_b)
                for rec in got.values():
                    nb.upsert(rec)

    def bulk_upsert(self, items: List[Tuple[str, np.ndarray, Dict[str, Any]]]) -> Dict[str, Any]:
        ok = 0
        bad = 0
        details = []
        for key, vec, payload in items:
            res = self.upsert(key, vec, payload)
            if res["ok"]:
                ok += 1
            else:
                bad += 1
            details.append({"key": key, "ok": res["ok"], "acks": res["acks"], "owners": res["owners"], "errors": res["errors"]})
        return {"ok": ok, "bad": bad, "total": ok + bad, "details_preview": details[:5]}

In [2]:
print("============================================================")
print("Distributed Vector DB demo: sharding + replication + quorums")
print("============================================================")

DIM = 64
NODES = ["node-a", "node-b", "node-c", "node-d"]
RF = 2
W = 2
R = 1

db = DistributedVectorDB(
    dim=DIM,
    node_ids=NODES,
    rf=RF,
    w=W,
    r=R,
    vnodes=128,
    metric="cosine",
    anti_entropy_interval_s=1.5,
    anti_entropy_enabled=True,
    seed=42,
)

rng = np.random.default_rng(123)
num_items = 500
base = rng.normal(size=(num_items, DIM)).astype(np.float32)

items = []
for i in range(num_items):
    key = f"doc:{i:04d}"
    vec = base[i]
    payload = {"title": f"Document {i}", "tag": f"topic-{i%10}", "i": i}
    items.append((key, vec, payload))

print("\nIngesting data (bulk upsert)...")
ing = db.bulk_upsert(items)
print(json.dumps({k: ing[k] for k in ["ok", "bad", "total"]}, indent=2))

target_idx = 133
q = base[target_idx] + 0.01 * rng.normal(size=(DIM,)).astype(np.float32)

print("\nKNN search (fanout all nodes)...")
res = db.knn_search(q, k=8, fanout=None)
print("Top hits:")
for h in res["hits"][:8]:
    print(f"  score={h['score']:.4f}  key={h['key']}  tag={h['payload']['tag'] if h['payload'] else None}")

print("\nShard ownership examples (consistent hashing + RF=2):")
for key in ["doc:0001", "doc:0133", "doc:0499"]:
    owners = [n.node_id for n in db._owners(key)]
    print(f"  {key} -> owners={owners}")

print("\n--- Simulating failure: bringing DOWN one owner and testing writes ---")
test_key = "doc:0001"
owners = [n.node_id for n in db._owners(test_key)]
down_node = owners[0]
print(f"Key {test_key} owners={owners} -> taking DOWN {down_node}")
db.set_node_up(down_node, False)

print("\nUpserting an update for that key (W=2, RF=2): expect failure if one replica down.")
update_vec = (base[1] + 0.2 * rng.normal(size=(DIM,)).astype(np.float32))
wr = db.upsert(test_key, update_vec, {"title": "Document 1 (UPDATED)", "tag": "topic-updated", "i": 1})
print(json.dumps(wr, indent=2))

print("\nReading that key (R=1): should still succeed if at least one replica up.")
gr = db.get(test_key)
print(json.dumps({k: gr[k] for k in ["key", "owners", "ok", "latest", "read_repair_applied_to"]}, indent=2))

print("\nBringing node back UP and letting anti-entropy repair state...")
db.set_node_up(down_node, True)
time.sleep(3.2)

print("\nReading again: should show consistent latest state across replicas (and possibly repaired).")
gr2 = db.get(test_key)
print(json.dumps({k: gr2[k] for k in ["key", "owners", "ok", "latest", "read_repair_applied_to"]}, indent=2))

print("\n--- Forcing staleness + read-repair demo ---")
key2 = "doc:0420"
owners2 = [n.node_id for n in db._owners(key2)]
print(f"{key2} owners={owners2}")

db.set_node_up(owners2[1], False)
wr2 = db.upsert(key2, base[420] + 0.3 * rng.normal(size=(DIM,)).astype(np.float32), {"title": "Doc 420 NEW", "tag": "topic-new", "i": 420})
db.set_node_up(owners2[1], True)

print("Write result:")
print(json.dumps(wr2, indent=2))

print("\nNow GET (R=1) should return latest from any replica and read-repair the stale one if needed:")
gr3 = db.get(key2)
print(json.dumps({k: gr3[k] for k in ["key", "owners", "ok", "latest", "read_repair_applied_to"]}, indent=2))

print("\nKNN search after repairs (fanout=2 random nodes, showing robustness):")
res2 = db.knn_search(q, k=6, fanout=2)
for h in res2["hits"]:
    print(f"  score={h['score']:.4f}  key={h['key']}  title={h['payload']['title'] if h['payload'] else None}")

db.stop()
print("\nWe now have a working simulated distributed vector DB with sharding + replication + quorum + repair.")

Distributed Vector DB demo: sharding + replication + quorums

Ingesting data (bulk upsert)...
{
  "ok": 500,
  "bad": 0,
  "total": 500
}

KNN search (fanout all nodes)...
Top hits:
  score=0.9999  key=doc:0133  tag=topic-3
  score=0.4128  key=doc:0033  tag=topic-3
  score=0.3697  key=doc:0137  tag=topic-7
  score=0.3572  key=doc:0214  tag=topic-4
  score=0.3321  key=doc:0267  tag=topic-7
  score=0.3168  key=doc:0155  tag=topic-5
  score=0.2767  key=doc:0462  tag=topic-2
  score=0.2767  key=doc:0462  tag=topic-2

Shard ownership examples (consistent hashing + RF=2):
  doc:0001 -> owners=['node-b', 'node-d']
  doc:0133 -> owners=['node-a', 'node-c']
  doc:0499 -> owners=['node-c', 'node-d']

--- Simulating failure: bringing DOWN one owner and testing writes ---
Key doc:0001 owners=['node-b', 'node-d'] -> taking DOWN node-b

Upserting an update for that key (W=2, RF=2): expect failure if one replica down.
{
  "key": "doc:0001",
  "owners": [
    "node-b",
    "node-d"
  ],
  "acks": 1,
 