Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions photomap/backend/cluster_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import hashlib
import importlib.resources
import logging
import threading
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -66,6 +67,25 @@
VOCAB_BATCH_PHRASES = 32


# Single-flight guard for vocab cache builds. Both /cluster_labels and
# /image_label dispatch through `asyncio.to_thread`, so concurrent FastAPI
# requests run `get_or_build_vocab_embeddings` in independent worker threads.
# On a cold cache, every caller would otherwise re-load the encoder and
# re-encode the full vocabulary before the first writer's atomic rename
# lands. Per-encoder locks let unrelated specs build in parallel.
# Guards every read and write of the registry below.
_VOCAB_BUILD_LOCKS_MUTEX = threading.Lock()
# Encoder spec -> the lock that serializes vocab-cache builds for that spec.
_VOCAB_BUILD_LOCKS: dict[str, threading.Lock] = {}


def _vocab_build_lock(encoder_spec: str) -> threading.Lock:
    """Return the build lock dedicated to ``encoder_spec``.

    The lock is created on first request and stored in ``_VOCAB_BUILD_LOCKS``;
    later calls with the same spec get that same object back. The lookup and
    the insertion both happen while ``_VOCAB_BUILD_LOCKS_MUTEX`` is held, so
    two threads asking for the same spec cannot end up with different locks.

    Args:
        encoder_spec: Encoder identifier used as the registry key.

    Returns:
        The ``threading.Lock`` registered for ``encoder_spec``.
    """
    with _VOCAB_BUILD_LOCKS_MUTEX:
        existing = _VOCAB_BUILD_LOCKS.get(encoder_spec)
        if existing is not None:
            return existing
        created = threading.Lock()
        _VOCAB_BUILD_LOCKS[encoder_spec] = created
        return created


def vocab_file_path() -> Path:
"""Filesystem path to the bundled `cluster_vocab.txt`."""
resource = importlib.resources.files(VOCAB_PACKAGE) / VOCAB_FILENAME
Expand Down Expand Up @@ -315,27 +335,37 @@ def get_or_build_vocab_embeddings(
if cached is not None:
return cached

logger.info("Building vocab embeddings cache at %s", cache_path)
phrases = load_vocab_phrases(vocab_path)
encoder = get_cached_encoder(encoder_spec, cache_dir=cache_dir)
embeddings = _encode_phrases_ensembled(encoder, phrases)

cache_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = cache_path.with_name(cache_path.name + ".tmp")
# Open via file handle so numpy doesn't second-guess the suffix.
with tmp_path.open("wb") as fh:
np.savez(
fh,
encoder_spec=np.array(encoder_spec),
n_templates=np.array(len(PROMPT_TEMPLATES)),
phrases=np.array(phrases),
embeddings=embeddings,
# Serialize concurrent builds for the same encoder. Re-check inside the
# lock so the second waiter picks up the first builder's atomic rename
# instead of redundantly re-encoding the full vocabulary.
with _vocab_build_lock(encoder_spec):
cached = _read_cached_vocab(cache_path, vocab_path, encoder_spec)
if cached is not None:
return cached

logger.info("Building vocab embeddings cache at %s", cache_path)
phrases = load_vocab_phrases(vocab_path)
encoder = get_cached_encoder(encoder_spec, cache_dir=cache_dir)
embeddings = _encode_phrases_ensembled(encoder, phrases)

cache_path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = cache_path.with_name(cache_path.name + ".tmp")
# Open via file handle so numpy doesn't second-guess the suffix.
with tmp_path.open("wb") as fh:
np.savez(
fh,
encoder_spec=np.array(encoder_spec),
n_templates=np.array(len(PROMPT_TEMPLATES)),
phrases=np.array(phrases),
embeddings=embeddings,
)
tmp_path.replace(cache_path)
logger.info(
"Vocab embeddings cached: %d phrases, dim=%d",
len(phrases),
embeddings.shape[1] if len(phrases) else 0,
)
tmp_path.replace(cache_path)
logger.info(
"Vocab embeddings cached: %d phrases, dim=%d", len(phrases), embeddings.shape[1] if len(phrases) else 0
)
return phrases, embeddings
return phrases, embeddings


# ---------------------------------------------------------------------------
Expand Down
80 changes: 80 additions & 0 deletions tests/backend/test_cluster_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,86 @@ def test_build_caches_and_reuses(tiny_vocab, isolated_cache, fake_encoder):
assert encoder.encode_calls == 1 # unchanged — cache hit


def test_concurrent_builds_are_serialized(tiny_vocab, isolated_cache, monkeypatch):
    """Several threads hitting a cold vocab cache must produce a single build.

    Mirrors the production race: /cluster_labels and /image_label both
    dispatch through asyncio.to_thread, so near-simultaneous requests on a
    cold cache could otherwise each run their own encoder build. The first
    thread to reach ``encode_text`` is parked on an event while the remaining
    threads queue up behind the build lock; once it is released, the test
    checks that the encoder ran exactly once and that every caller received
    the same phrases and embeddings.
    """
    import threading
    import time

    spec = "fake:concurrent"
    n_workers = 4
    entered = threading.Event()
    release = threading.Event()
    call_lock = threading.Lock()
    call_count = 0

    def _unit_vector(text, dim):
        # Seed from the text so repeated calls within this process agree.
        seed = abs(hash(text)) % (2**32)
        vec = np.random.default_rng(seed).standard_normal(dim).astype(np.float32)
        vec /= np.linalg.norm(vec)
        return vec

    class GatedEncoder:
        embedding_dim = 16

        def encode_text(self, texts):
            nonlocal call_count
            with call_lock:
                call_count += 1
                is_first_call = call_count == 1
            if is_first_call:
                entered.set()
                # Park here until the test opens the gate, so the other
                # threads have time to stack up on the build lock.
                assert release.wait(timeout=5.0), "release event never fired"
            return np.stack([_unit_vector(text, self.embedding_dim) for text in texts])

    encoder = GatedEncoder()

    def _fake_get_cached_encoder(spec, *, cache_dir=None, device=None):
        return encoder

    monkeypatch.setattr(cluster_labels, "get_cached_encoder", _fake_get_cached_encoder)
    # Start from an empty lock registry so nothing leaks in from earlier tests.
    monkeypatch.setattr(cluster_labels, "_VOCAB_BUILD_LOCKS", {})

    results: list[tuple[list[str], np.ndarray]] = []
    errors: list[BaseException] = []

    def worker():
        try:
            outcome = cluster_labels.get_or_build_vocab_embeddings(spec)
        except BaseException as err:
            # Collected here so the main thread can report worker failures.
            errors.append(err)
        else:
            results.append(outcome)

    threads = [threading.Thread(target=worker) for _ in range(n_workers)]
    for thread in threads:
        thread.start()

    assert entered.wait(timeout=5.0), "first thread never reached encode_text"
    # Nothing public reveals when the others are waiting on the build lock,
    # so pause briefly before opening the gate. The call_count assertion
    # below is what actually detects a broken guard.
    time.sleep(0.1)
    release.set()
    for thread in threads:
        thread.join(timeout=5.0)
        assert not thread.is_alive(), "worker thread hung"

    assert not errors, f"worker raised: {errors!r}"
    assert call_count == 1, f"encoder was called {call_count} times; guard failed"
    assert len(results) == len(threads)
    # Every caller must come back with the same phrases and embeddings.
    reference_phrases, reference_embeddings = results[0]
    for other_phrases, other_embeddings in results[1:]:
        assert other_phrases == reference_phrases
        np.testing.assert_array_equal(other_embeddings, reference_embeddings)


def test_cache_invalidates_on_vocab_edit(tiny_vocab, isolated_cache, fake_encoder):
spec = "fake:test-encoder"
cluster_labels.get_or_build_vocab_embeddings(spec)
Expand Down
Loading