feat: Adding saved dataset capabilities for Postgres (#3070)
* feat: Adding saved dataset capabilities for Postgres

Signed-off-by: Danny Chiao <danny@tecton.ai>
Signed-off-by: alex.eijssen <alex.eijssen@energyessentials.nl>
Co-authored-by: Danny Chiao <danny@tecton.ai>
Co-authored-by: alex.eijssen <alex.eijssen@energyessentials.nl>
3 people committed Aug 11, 2022
1 parent 36747aa commit d3253c3
Showing 5 changed files with 177 additions and 52 deletions.
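For orientation, the capability added here can be exercised roughly as follows. A minimal sketch, assuming a feature repo configured with the Postgres offline store and Feast's standard create_saved_dataset API; the entity query, feature reference, and dataset/table names are hypothetical:

    from feast import FeatureStore
    from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
        SavedDatasetPostgreSQLStorage,
    )

    store = FeatureStore(repo_path=".")

    # Build a historical retrieval job; entity rows come from a SQL query here.
    job = store.get_historical_features(
        entity_df="SELECT driver_id, event_timestamp FROM entity_rows",
        features=["driver_hourly_stats:conv_rate"],
    )

    # New in this commit: persist the job's result to a Postgres table as a
    # saved dataset (this is what drives RetrievalJob.persist in the diff below).
    store.create_saved_dataset(
        from_=job,
        name="my_training_dataset",
        storage=SavedDatasetPostgreSQLStorage(table_ref="my_training_dataset"),
    )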
47 changes: 35 additions & 12 deletions Makefile
@@ -164,24 +164,47 @@ test-python-universal-athena:
not s3_registry" \
sdk/python/tests



-test-python-universal-postgres:
+test-python-universal-postgres-offline:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
FEAST_USAGE=False \
IS_TEST=True \
-python -m pytest -x --integration \
--k "not test_historical_retrieval_fails_on_validation and \
-not test_historical_retrieval_with_validation and \
+python -m pytest -n 8 --integration \
+-k "not test_historical_retrieval_with_validation and \
not test_historical_features_persisting and \
+not test_historical_retrieval_fails_on_validation and \
-not test_universal_cli and \
-not test_go_feature_server and \
-not test_feature_logging and \
-not test_universal_types" \
-sdk/python/tests
+not test_universal_cli and \
+not test_go_feature_server and \
+not test_feature_logging and \
+not test_reorder_columns and \
+not test_logged_features_validation and \
+not test_lambda_materialization_consistency and \
+not test_offline_write and \
+not test_push_features_to_offline_store and \
+not gcs_registry and \
+not s3_registry and \
+not test_universal_types" \
+sdk/python/tests

test-python-universal-postgres-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
FEAST_USAGE=False \
IS_TEST=True \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
not test_feature_logging and \
not test_reorder_columns and \
not test_logged_features_validation and \
not test_lambda_materialization_consistency and \
not test_offline_write and \
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_universal_types" \
sdk/python/tests
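
Either suite can be run on its own, e.g. make test-python-universal-postgres-offline or make test-python-universal-postgres-online; the two targets differ mainly in which FULL_REPO_CONFIGS_MODULE they point at.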

test-python-universal-cassandra:
PYTHONPATH='.' \
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py
@@ -5,6 +5,7 @@
Any,
Callable,
ContextManager,
Dict,
Iterator,
KeysView,
List,
@@ -13,6 +14,7 @@
Union,
)

import numpy as np
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
@@ -24,6 +26,9 @@
from feast.errors import InvalidEntityType
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
SavedDatasetPostgreSQLStorage,
)
from feast.infra.offline_stores.offline_store import (
OfflineStore,
RetrievalJob,
@@ -112,24 +117,24 @@ def get_historical_features(
project: str,
full_feature_names: bool = False,
) -> RetrievalJob:

+        entity_schema = _get_entity_schema(entity_df, config)
+
+        entity_df_event_timestamp_col = (
+            offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
+        )
+
+        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
+            entity_df,
+            entity_df_event_timestamp_col,
+            config,
+        )

@contextlib.contextmanager
def query_generator() -> Iterator[str]:
-            table_name = None
-            if isinstance(entity_df, pd.DataFrame):
-                table_name = offline_utils.get_temp_entity_table_name()
-                entity_schema = df_to_postgres_table(
-                    config.offline_store, entity_df, table_name
-                )
-                df_query = table_name
-            elif isinstance(entity_df, str):
-                df_query = f"({entity_df}) AS sub"
-                entity_schema = get_query_schema(config.offline_store, df_query)
-            else:
-                raise TypeError(entity_df)
-
-            entity_df_event_timestamp_col = (
-                offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
-            )
+            table_name = offline_utils.get_temp_entity_table_name()
+
+            _upload_entity_df(config, entity_df, table_name)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
@@ -139,13 +144,6 @@ def query_generator() -> Iterator[str]:
entity_schema, expected_join_keys, entity_df_event_timestamp_col
)

-            entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
-                entity_df,
-                entity_df_event_timestamp_col,
-                config,
-                df_query,
-            )

query_context = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
@@ -165,7 +163,7 @@ def query_generator() -> Iterator[str]:
try:
yield build_point_in_time_query(
query_context_dict,
-                left_table_query_string=df_query,
+                left_table_query_string=table_name,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
@@ -189,6 +187,12 @@ def query_generator() -> Iterator[str]:
on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
feature_refs, project, registry
),
metadata=RetrievalMetadata(
features=feature_refs,
keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}),
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
)

@staticmethod
@@ -294,14 +298,19 @@ def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def persist(self, storage: SavedDatasetStorage):
-        pass
+        assert isinstance(storage, SavedDatasetPostgreSQLStorage)
+
+        df_to_postgres_table(
+            config=self.config.offline_store,
+            df=self.to_df(),
+            table_name=storage.postgres_options._table,
+        )


def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
config: RepoConfig,
-    table_name: str,
) -> Tuple[datetime, datetime]:
if isinstance(entity_df, pd.DataFrame):
entity_df_event_timestamp = entity_df.loc[
@@ -312,15 +321,15 @@ def _get_entity_df_event_timestamp_range(
entity_df_event_timestamp, utc=True
)
entity_df_event_timestamp_range = (
-            entity_df_event_timestamp.min(),
-            entity_df_event_timestamp.max(),
+            entity_df_event_timestamp.min().to_pydatetime(),
+            entity_df_event_timestamp.max().to_pydatetime(),
)
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), determine range
# from table
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM {table_name}"
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
),
res = cur.fetchone()
entity_df_event_timestamp_range = (res[0], res[1])
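
The switch to .to_pydatetime() above converts pandas Timestamp objects into plain datetime.datetime values, matching the datetime-typed fields of the RetrievalMetadata added in this commit. A small illustration of the difference, with a hypothetical timestamp:

    import pandas as pd

    ts = pd.to_datetime(pd.Series(["2022-08-01T00:00:00Z"]), utc=True)
    print(type(ts.min()))                  # pandas Timestamp
    print(type(ts.min().to_pydatetime()))  # datetime.datetime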
@@ -374,6 +383,34 @@ def build_point_in_time_query(
return query


def _upload_entity_df(
config: RepoConfig, entity_df: Union[pd.DataFrame, str], table_name: str
):
if isinstance(entity_df, pd.DataFrame):
# If the entity_df is a pandas dataframe, upload it to Postgres
df_to_postgres_table(config.offline_store, entity_df, table_name)
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), create a Postgres table out of it
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(f"CREATE TABLE {table_name} AS ({entity_df})")
else:
raise InvalidEntityType(type(entity_df))


def _get_entity_schema(
entity_df: Union[pd.DataFrame, str],
config: RepoConfig,
) -> Dict[str, np.dtype]:
if isinstance(entity_df, pd.DataFrame):
return dict(zip(entity_df.columns, entity_df.dtypes))

elif isinstance(entity_df, str):
df_query = f"({entity_df}) AS sub"
return get_query_schema(config.offline_store, df_query)
else:
raise InvalidEntityType(type(entity_df))
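
For reference, the DataFrame branch of the new _get_entity_schema helper just pairs column names with pandas dtypes (the SQL branch delegates to get_query_schema). A small sketch with a hypothetical entity frame:

    import pandas as pd

    entity_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002],
            "event_timestamp": pd.to_datetime(["2022-08-01", "2022-08-02"], utc=True),
        }
    )

    # Same expression as the DataFrame branch above:
    schema = dict(zip(entity_df.columns, entity_df.dtypes))
    # {'driver_id': dtype('int64'), 'event_timestamp': datetime64[ns, UTC]}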


# Copied from the Feast Redshift offline store implementation
# Note: Keep this in sync with sdk/python/feast/infra/offline_stores/redshift.py:
# MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py
@@ -1,27 +1,42 @@
import json
from typing import Callable, Dict, Iterable, Optional, Tuple

from typeguard import typechecked

from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.infra.utils.postgres.connection_utils import _get_conn
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
SavedDatasetStorage as SavedDatasetStorageProto,
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import pg_type_code_to_pg_type, pg_type_to_feast_value_type
from feast.value_type import ValueType


@typechecked
class PostgreSQLSource(DataSource):
def __init__(
self,
-        name: str,
-        query: str,
+        name: Optional[str] = None,
+        query: Optional[str] = None,
+        table: Optional[str] = None,
timestamp_field: Optional[str] = "",
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
):
-        self._postgres_options = PostgreSQLOptions(name=name, query=query)
+        self._postgres_options = PostgreSQLOptions(name=name, query=query, table=table)

# If no name, use the table as the default name.
if name is None and table is None:
raise DataSourceNoNameException()
name = name or table
assert name

super().__init__(
name=name,
@@ -55,9 +70,11 @@ def from_proto(data_source: DataSourceProto):
assert data_source.HasField("custom_options")

postgres_options = json.loads(data_source.custom_options.configuration)

return PostgreSQLSource(
name=postgres_options["name"],
query=postgres_options["query"],
table=postgres_options["table"],
field_mapping=dict(data_source.field_mapping),
timestamp_field=data_source.timestamp_field,
created_timestamp_column=data_source.created_timestamp_column,
@@ -102,26 +119,60 @@ def get_table_column_names_and_types(
)

def get_table_query_string(self) -> str:
return f"({self._postgres_options._query})"

if self._postgres_options._table:
return f"{self._postgres_options._table}"
else:
return f"({self._postgres_options._query})"


class PostgreSQLOptions:
-    def __init__(self, name: str, query: Optional[str]):
-        self._name = name
-        self._query = query
+    def __init__(
+        self,
+        name: Optional[str],
+        query: Optional[str],
+        table: Optional[str],
+    ):
+        self._name = name or ""
+        self._query = query or ""
+        self._table = table or ""

@classmethod
def from_proto(cls, postgres_options_proto: DataSourceProto.CustomSourceOptions):
config = json.loads(postgres_options_proto.configuration.decode("utf8"))
-        postgres_options = cls(name=config["name"], query=config["query"])
+        postgres_options = cls(
+            name=config["name"], query=config["query"], table=config["table"]
+        )

return postgres_options

def to_proto(self) -> DataSourceProto.CustomSourceOptions:
postgres_options_proto = DataSourceProto.CustomSourceOptions(
configuration=json.dumps(
{"name": self._name, "query": self._query}
{"name": self._name, "query": self._query, "table": self._table}
).encode()
)

return postgres_options_proto


class SavedDatasetPostgreSQLStorage(SavedDatasetStorage):
_proto_attr_name = "custom_storage"

postgres_options: PostgreSQLOptions

def __init__(self, table_ref: str):
self.postgres_options = PostgreSQLOptions(
table=table_ref, name=None, query=None
)

@staticmethod
def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
return SavedDatasetPostgreSQLStorage(
table_ref=PostgreSQLOptions.from_proto(storage_proto.custom_storage)._table
)

def to_proto(self) -> SavedDatasetStorageProto:
return SavedDatasetStorageProto(custom_storage=self.postgres_options.to_proto())

def to_data_source(self) -> DataSource:
return PostgreSQLSource(table=self.postgres_options._table)
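
Taken together, the changes in this file allow a table-only source and a saved-dataset storage that round-trips to one. A brief sketch; the table names are hypothetical:

    from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
        PostgreSQLSource,
        SavedDatasetPostgreSQLStorage,
    )

    # With the new optional `table` argument, the source name defaults to the table.
    source = PostgreSQLSource(
        table="driver_hourly_stats",
        timestamp_field="event_timestamp",
    )

    storage = SavedDatasetPostgreSQLStorage(table_ref="saved_training_data")
    readable = storage.to_data_source()  # a PostgreSQLSource over the saved table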
sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py
@@ -1,7 +1,11 @@
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
PostgreSQLDataSourceCreator,
)
from tests.integration.feature_repos.repo_configuration import REDIS_CONFIG
from tests.integration.feature_repos.universal.online_store.redis import (
RedisOnlineStoreCreator,
)

AVAILABLE_OFFLINE_STORES = [("local", PostgreSQLDataSourceCreator)]

AVAILABLE_ONLINE_STORES = {"postgres": (None, PostgreSQLDataSourceCreator)}
AVAILABLE_ONLINE_STORES = {"redis": (REDIS_CONFIG, RedisOnlineStoreCreator)}
sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py
@@ -0,0 +1,10 @@
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
PostgreSQLDataSourceCreator,
)
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
]
