import threading

# Single-flight guard for vocab cache builds.
#
# /cluster_labels and /image_label both dispatch through `asyncio.to_thread`,
# so concurrent FastAPI requests execute `get_or_build_vocab_embeddings` on
# independent worker threads. Without a guard, every caller hitting a cold
# cache would load the encoder and encode the whole vocabulary again before
# the first writer's atomic rename lands. Locks are kept per encoder spec so
# builds for unrelated specs can still proceed in parallel.
_VOCAB_BUILD_LOCKS_MUTEX = threading.Lock()
_VOCAB_BUILD_LOCKS: dict[str, threading.Lock] = {}


def _vocab_build_lock(encoder_spec: str) -> threading.Lock:
    """Return the build lock dedicated to ``encoder_spec``.

    The lock is created on first request and the same object is handed back
    on every later call for that spec. The registry itself is only touched
    while holding ``_VOCAB_BUILD_LOCKS_MUTEX``, so two threads asking for a
    new spec at the same time cannot end up with different lock objects.
    """
    with _VOCAB_BUILD_LOCKS_MUTEX:
        try:
            return _VOCAB_BUILD_LOCKS[encoder_spec]
        except KeyError:
            created = threading.Lock()
            _VOCAB_BUILD_LOCKS[encoder_spec] = created
            return created
- with tmp_path.open("wb") as fh: - np.savez( - fh, - encoder_spec=np.array(encoder_spec), - n_templates=np.array(len(PROMPT_TEMPLATES)), - phrases=np.array(phrases), - embeddings=embeddings, + # Serialize concurrent builds for the same encoder. Re-check inside the + # lock so the second waiter picks up the first builder's atomic rename + # instead of redundantly re-encoding the full vocabulary. + with _vocab_build_lock(encoder_spec): + cached = _read_cached_vocab(cache_path, vocab_path, encoder_spec) + if cached is not None: + return cached + + logger.info("Building vocab embeddings cache at %s", cache_path) + phrases = load_vocab_phrases(vocab_path) + encoder = get_cached_encoder(encoder_spec, cache_dir=cache_dir) + embeddings = _encode_phrases_ensembled(encoder, phrases) + + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_name(cache_path.name + ".tmp") + # Open via file handle so numpy doesn't second-guess the suffix. + with tmp_path.open("wb") as fh: + np.savez( + fh, + encoder_spec=np.array(encoder_spec), + n_templates=np.array(len(PROMPT_TEMPLATES)), + phrases=np.array(phrases), + embeddings=embeddings, + ) + tmp_path.replace(cache_path) + logger.info( + "Vocab embeddings cached: %d phrases, dim=%d", + len(phrases), + embeddings.shape[1] if len(phrases) else 0, ) - tmp_path.replace(cache_path) - logger.info( - "Vocab embeddings cached: %d phrases, dim=%d", len(phrases), embeddings.shape[1] if len(phrases) else 0 - ) - return phrases, embeddings + return phrases, embeddings # --------------------------------------------------------------------------- diff --git a/tests/backend/test_cluster_labels.py b/tests/backend/test_cluster_labels.py index ba8cc828..ae1c24d1 100644 --- a/tests/backend/test_cluster_labels.py +++ b/tests/backend/test_cluster_labels.py @@ -362,6 +362,86 @@ def test_build_caches_and_reuses(tiny_vocab, isolated_cache, fake_encoder): assert encoder.encode_calls == 1 # unchanged — cache hit +def 
def test_concurrent_builds_are_serialized(tiny_vocab, isolated_cache, monkeypatch):
    """Concurrent first-time callers must not redundantly re-encode the vocab.

    Simulates the production race: /cluster_labels and /image_label both
    dispatch through asyncio.to_thread, so a cold cache plus a few near-
    simultaneous requests would otherwise trigger N parallel encoder builds.

    The first thread blocks inside encode_text. The gate is only released
    once every worker has asked for the build lock, i.e. once every worker
    has already missed the cache on the unlocked fast path. After release we
    assert that exactly one build ran, which can only hold if the waiters
    re-check the cache while holding the lock.
    """
    import threading

    spec = "fake:concurrent"
    entered = threading.Event()
    release = threading.Event()
    # Released once per worker that requests the build lock.
    arrived = threading.Semaphore(0)
    call_count = 0
    call_lock = threading.Lock()

    class GatedEncoder:
        embedding_dim = 16

        def encode_text(self, texts):
            nonlocal call_count
            with call_lock:
                call_count += 1
                first = call_count == 1
            if first:
                entered.set()
                # Hold the build open until the test has seen every worker
                # reach the build lock.
                assert release.wait(timeout=5.0), "release event never fired"
            rows = []
            for t in texts:
                rng = np.random.default_rng(abs(hash(t)) % (2**32))
                v = rng.standard_normal(self.embedding_dim).astype(np.float32)
                v /= np.linalg.norm(v)
                rows.append(v)
            return np.stack(rows)

    encoder = GatedEncoder()
    monkeypatch.setattr(
        cluster_labels, "get_cached_encoder", lambda spec, *, cache_dir=None, device=None: encoder
    )
    # Ensure no stale lock from a previous test run.
    monkeypatch.setattr(cluster_labels, "_VOCAB_BUILD_LOCKS", {})

    # Observe arrivals at the build lock instead of sleeping: a worker only
    # asks for the lock after its unlocked cache check has missed.
    real_vocab_build_lock = cluster_labels._vocab_build_lock

    def announcing_vocab_build_lock(encoder_spec):
        arrived.release()
        return real_vocab_build_lock(encoder_spec)

    monkeypatch.setattr(cluster_labels, "_vocab_build_lock", announcing_vocab_build_lock)

    results: list[tuple[list[str], np.ndarray]] = []
    errors: list[BaseException] = []

    def worker():
        try:
            results.append(cluster_labels.get_or_build_vocab_embeddings(spec))
        except BaseException as err:
            errors.append(err)

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()

    try:
        assert entered.wait(timeout=5.0), "first thread never reached encode_text"
        # The cache stays cold until `release` fires, so every worker must
        # miss the fast path and request the build lock.
        for _ in threads:
            assert arrived.acquire(timeout=5.0), "a worker never reached the build lock"
    finally:
        # Always open the gate so a failed assertion above does not leave the
        # builder thread blocked for its full timeout.
        release.set()

    for t in threads:
        t.join(timeout=5.0)
        assert not t.is_alive(), "worker thread hung"

    assert not errors, f"worker raised: {errors!r}"
    assert call_count == 1, f"encoder was called {call_count} times; guard failed"
    assert len(results) == len(threads)
    # All callers see the same phrases and identical embeddings.
    phrases0, emb0 = results[0]
    for phrases, emb in results[1:]:
        assert phrases == phrases0
        np.testing.assert_array_equal(emb, emb0)