From aa7907085db05211b1cb506cd7741c60fc799269 Mon Sep 17 00:00:00 2001 From: RutujaPathade <73137503+RutujaPathade@users.noreply.github.com> Date: Sun, 29 Mar 2026 20:22:35 -0700 Subject: [PATCH] Removed unnecessary lazy imports Signed-off-by: RutujaPathade <73137503+RutujaPathade@users.noreply.github.com> --- .../infra/online_stores/faiss_online_store.py | 122 ++++++--- .../feast/infra/online_stores/online_store.py | 3 +- .../test_faiss_online_store_versioning.py | 243 ++++++++++++++++++ 3 files changed, 329 insertions(+), 39 deletions(-) create mode 100644 sdk/python/tests/unit/test_faiss_online_store_versioning.py diff --git a/sdk/python/feast/infra/online_stores/faiss_online_store.py b/sdk/python/feast/infra/online_stores/faiss_online_store.py index 3e3d92cde6d..31c0b499b7c 100644 --- a/sdk/python/feast/infra/online_stores/faiss_online_store.py +++ b/sdk/python/feast/infra/online_stores/faiss_online_store.py @@ -43,16 +43,30 @@ def teardown(self): self.entity_keys = {} +def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str: + """Compute the table key, including version suffix when versioning is enabled.""" + name = table.name + if enable_versioning: + # Prefer version_tag from the projection (set by version-qualified refs like @v2) + # over current_version_number (the FV's active version in metadata). + version = getattr(table.projection, "version_tag", None) + if version is None: + version = getattr(table, "current_version_number", None) + if version is not None and version > 0: + name = f"{table.name}_v{version}" + return f"{project}_{name}" + + class FaissOnlineStore(OnlineStore): - _index: Optional[faiss.IndexIVFFlat] = None - _in_memory_store: InMemoryStore = InMemoryStore() + _indices: Dict[str, faiss.IndexIVFFlat] = {} + _in_memory_stores: Dict[str, InMemoryStore] = {} _config: Optional[FaissOnlineStoreConfig] = None _logger: logging.Logger = logging.getLogger(__name__) - def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat: - if self._index is None or self._config is None: - raise ValueError("Index is not initialized") - return self._index + def _get_index( + self, table_key: str + ) -> Optional[faiss.IndexIVFFlat]: + return self._indices.get(table_key) def update( self, @@ -63,23 +77,33 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - feature_views = tables_to_keep - if not feature_views: - return - - feature_names = [f.name for f in feature_views[0].features] - dimension = len(feature_names) - self._config = FaissOnlineStoreConfig(**config.online_store.dict()) - if self._index is None or not partial: - quantizer = faiss.IndexFlatL2(dimension) - self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist) - self._index.train( - np.random.rand(self._config.nlist * 100, dimension).astype(np.float32) - ) - self._in_memory_store = InMemoryStore() + versioning = config.registry.enable_online_feature_view_versioning + + for table in tables_to_delete: + table_key = _table_id(config.project, table, versioning) + self._indices.pop(table_key, None) + self._in_memory_stores.pop(table_key, None) + + for table in tables_to_keep: + table_key = _table_id(config.project, table, versioning) + feature_names = [f.name for f in table.features] + dimension = len(feature_names) + + if table_key not in self._indices or not partial: + quantizer = faiss.IndexFlatL2(dimension) + index = faiss.IndexIVFFlat( + quantizer, dimension, self._config.nlist + ) + index.train( + np.random.rand(self._config.nlist * 100, dimension).astype( + np.float32 + ) + ) + self._indices[table_key] = index + self._in_memory_stores[table_key] = InMemoryStore() - self._in_memory_store.update(feature_names, {}) + self._in_memory_stores[table_key].update(feature_names, {}) def teardown( self, @@ -87,8 +111,13 @@ def teardown( tables: Sequence[FeatureView], entities: Sequence[Entity], ): - self._index = None - self._in_memory_store.teardown() + versioning = config.registry.enable_online_feature_view_versioning + for table in tables: + table_key = _table_id(config.project, table, versioning) + self._indices.pop(table_key, None) + store = self._in_memory_stores.pop(table_key, None) + if store is not None: + store.teardown() def online_read( self, @@ -97,7 +126,12 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - if self._index is None: + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + in_memory_store = self._in_memory_stores.get(table_key) + + if index is None or in_memory_store is None: return [(None, None)] * len(entity_keys) results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = [] @@ -105,15 +139,15 @@ def online_read( serialized_key = serialize_entity_key( entity_key, config.entity_key_serialization_version ).hex() - idx = self._in_memory_store.entity_keys.get(serialized_key, -1) + idx = in_memory_store.entity_keys.get(serialized_key, -1) if idx == -1: results.append((None, None)) else: - feature_vector = self._index.reconstruct(int(idx)) + feature_vector = index.reconstruct(int(idx)) feature_dict = { name: ValueProto(double_val=value) for name, value in zip( - self._in_memory_store.feature_names, feature_vector + in_memory_store.feature_names, feature_vector ) } results.append((None, feature_dict)) @@ -128,8 +162,16 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: - if self._index is None: - self._logger.warning("Index is not initialized. Skipping write operation.") + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + in_memory_store = self._in_memory_stores.get(table_key) + + if index is None or in_memory_store is None: + self._logger.warning( + "Index for table '%s' is not initialized. Skipping write operation.", + table_key, + ) return feature_vectors = [] @@ -142,7 +184,7 @@ def online_write_batch( feature_vector = np.array( [ feature_dict[name].double_val - for name in self._in_memory_store.feature_names + for name in in_memory_store.feature_names ], dtype=np.float32, ) @@ -153,21 +195,21 @@ def online_write_batch( feature_vectors_array = np.array(feature_vectors) existing_indices = [ - self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys + in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys ] mask = np.array(existing_indices) != -1 if np.any(mask): - self._index.remove_ids( + index.remove_ids( np.array([idx for idx in existing_indices if idx != -1]) ) new_indices = np.arange( - self._index.ntotal, self._index.ntotal + len(feature_vectors_array) + index.ntotal, index.ntotal + len(feature_vectors_array) ) - self._index.add(feature_vectors_array) + index.add(feature_vectors_array) for sk, idx in zip(serialized_keys, new_indices): - self._in_memory_store.entity_keys[sk] = idx + in_memory_store.entity_keys[sk] = idx if progress: progress(len(data)) @@ -189,12 +231,16 @@ def retrieve_online_documents( Optional[ValueProto], ] ]: - if self._index is None: + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + + if index is None: self._logger.warning("Index is not initialized. Returning empty result.") return [] query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1) - distances, indices = self._index.search(query_vector, top_k) + distances, indices = index.search(query_vector, top_k) results: List[ Tuple[ @@ -209,7 +255,7 @@ def retrieve_online_documents( if idx == -1: continue - feature_vector = self._index.reconstruct(int(idx)) + feature_vector = index.reconstruct(int(idx)) timestamp = Timestamp() timestamp.GetCurrentTime() diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 4913046470c..b1263080cd8 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -256,9 +256,10 @@ def get_online_features( def _check_versioned_read_support(self, grouped_refs): """Raise an error if versioned reads are attempted on unsupported stores.""" + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore from feast.infra.online_stores.sqlite import SqliteOnlineStore - if isinstance(self, SqliteOnlineStore): + if isinstance(self, (SqliteOnlineStore, FaissOnlineStore)): return for table, _ in grouped_refs: version_tag = getattr(table.projection, "version_tag", None) diff --git a/sdk/python/tests/unit/test_faiss_online_store_versioning.py b/sdk/python/tests/unit/test_faiss_online_store_versioning.py new file mode 100644 index 00000000000..fcb41245407 --- /dev/null +++ b/sdk/python/tests/unit/test_faiss_online_store_versioning.py @@ -0,0 +1,243 @@ +"""Unit tests for versioned feature view support in FaissOnlineStore.""" +import sys +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from feast.entity import Entity +from feast.feature_view import FeatureView +from feast.field import Field +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.types import Float32 + +# --------------------------------------------------------------------------- +# Helpers to build lightweight FeatureView fixtures without faiss +# --------------------------------------------------------------------------- + + +def _make_feature_view(name: str, version_number=None, version_tag=None): + entity = Entity(name="driver_id", join_keys=["driver_id"]) + fv = FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[Field(name="feature_a", dtype=Float32)], + ) + if version_number is not None: + fv.current_version_number = version_number + if version_tag is not None: + fv.projection.version_tag = version_tag + return fv + + +# --------------------------------------------------------------------------- +# _table_id tests (no faiss needed — we mock the import at module level) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _mock_faiss(): + """Inject a minimal faiss mock so faiss_online_store can be imported.""" + faiss_mock = MagicMock() + with patch.dict(sys.modules, {"faiss": faiss_mock}): + # Remove cached module so the patched version is used + sys.modules.pop("feast.infra.online_stores.faiss_online_store", None) + yield faiss_mock + sys.modules.pop("feast.infra.online_stores.faiss_online_store", None) + + +class TestFaissTableId: + def test_no_versioning(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view("driver_stats") + assert _table_id("my_project", fv) == "my_project_driver_stats" + + def test_versioning_disabled_ignores_version_number(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view("driver_stats", version_number=3) + assert _table_id("my_project", fv, enable_versioning=False) == "my_project_driver_stats" + + def test_versioning_enabled_v0_no_suffix(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view("driver_stats", version_number=0) + # version 0 should NOT produce a suffix + assert _table_id("my_project", fv, enable_versioning=True) == "my_project_driver_stats" + + def test_versioning_enabled_version_number(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view("driver_stats", version_number=2) + assert ( + _table_id("my_project", fv, enable_versioning=True) + == "my_project_driver_stats_v2" + ) + + def test_versioning_enabled_version_tag_takes_precedence(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + # version_tag on the projection takes precedence over current_version_number + fv = _make_feature_view("driver_stats", version_number=1, version_tag=3) + assert ( + _table_id("my_project", fv, enable_versioning=True) + == "my_project_driver_stats_v3" + ) + + def test_versioning_enabled_no_version_set(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view("driver_stats") + # No version information — falls back to bare name + assert ( + _table_id("my_project", fv, enable_versioning=True) + == "my_project_driver_stats" + ) + + +# --------------------------------------------------------------------------- +# FaissOnlineStore versioned write / read tests (faiss is mocked) +# --------------------------------------------------------------------------- + + +def _make_config(project: str = "test_project", versioning: bool = False): + """Build a minimal RepoConfig-like mock.""" + config = MagicMock() + config.project = project + config.entity_key_serialization_version = 2 + config.online_store.dict.return_value = { + "dimension": 1, + "index_path": "/tmp/test.index", + "index_type": "IVFFlat", + "nlist": 10, + } + config.registry.enable_online_feature_view_versioning = versioning + return config + + +def _make_entity_key(driver_id: int = 1): + return EntityKeyProto( + join_keys=["driver_id"], + entity_values=[ValueProto(int64_val=driver_id)], + ) + + +class TestFaissOnlineStoreVersionedWrite: + def _make_store(self, faiss_mock, nlist: int = 10): + """Create a FaissOnlineStore with a real-enough faiss mock.""" + # Make the index mock respond correctly + index_mock = MagicMock() + index_mock.ntotal = 0 + + def add_side_effect(vectors): + index_mock.ntotal += len(vectors) + + index_mock.add.side_effect = add_side_effect + + def reconstruct_side_effect(idx): + return np.array([float(idx)], dtype=np.float32) + + index_mock.reconstruct.side_effect = reconstruct_side_effect + + faiss_mock.IndexFlatL2.return_value = MagicMock() + faiss_mock.IndexIVFFlat.return_value = index_mock + + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore + + store = FaissOnlineStore() + # Reset class-level dicts to avoid test pollution + store._indices = {} + store._in_memory_stores = {} + return store, index_mock + + def test_write_and_read_without_versioning(self, _mock_faiss): + store, index_mock = self._make_store(_mock_faiss) + config = _make_config(versioning=False) + fv = _make_feature_view("driver_stats") + + store.update(config, [], [fv], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=42) + data = [ + (entity_key, {"feature_a": ValueProto(double_val=1.5)}, None, None) + ] + store.online_write_batch(config, fv, data, None) + + results = store.online_read(config, fv, [entity_key]) + assert len(results) == 1 + ts, feature_dict = results[0] + assert feature_dict is not None + assert "feature_a" in feature_dict + + def test_write_and_read_with_versioning(self, _mock_faiss): + store, index_mock = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv_v2 = _make_feature_view("driver_stats", version_number=2) + + store.update(config, [], [fv_v2], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=7) + data = [ + (entity_key, {"feature_a": ValueProto(double_val=2.0)}, None, None) + ] + store.online_write_batch(config, fv_v2, data, None) + + results = store.online_read(config, fv_v2, [entity_key]) + assert len(results) == 1 + _, feature_dict = results[0] + assert feature_dict is not None + + def test_versioned_namespaces_are_isolated(self, _mock_faiss): + """Data written under v1 must not be visible when reading under v2.""" + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + + fv_v1 = _make_feature_view("driver_stats", version_number=1) + fv_v2 = _make_feature_view("driver_stats", version_number=2) + + store.update(config, [], [fv_v1, fv_v2], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=99) + data = [ + (entity_key, {"feature_a": ValueProto(double_val=9.9)}, None, None) + ] + # Write only to v1 + store.online_write_batch(config, fv_v1, data, None) + + # Reading from v2 should return (None, None) + results_v2 = store.online_read(config, fv_v2, [entity_key]) + assert results_v2 == [(None, None)] + + # Reading from v1 should return data + results_v1 = store.online_read(config, fv_v1, [entity_key]) + assert results_v1[0][1] is not None + + def test_missing_index_returns_none(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv = _make_feature_view("driver_stats", version_number=5) + # No update called — index not initialised + entity_key = _make_entity_key(driver_id=1) + results = store.online_read(config, fv, [entity_key]) + assert results == [(None, None)] + + def test_teardown_removes_versioned_index(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv = _make_feature_view("driver_stats", version_number=3) + + store.update(config, [], [fv], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=1) + data = [(entity_key, {"feature_a": ValueProto(double_val=3.0)}, None, None)] + store.online_write_batch(config, fv, data, None) + + store.teardown(config, [fv], []) + + # After teardown, reads return None + results = store.online_read(config, fv, [entity_key]) + assert results == [(None, None)]