feat: Adding saved dataset capabilities for Postgres #3070

Merged · 6 commits · Aug 11, 2022

Changes from all commits
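What this PR adds, as a minimal usage sketch. This assumes a feature store configured with the contrib Postgres offline store; the entity table and feature view names (`entity_rows`, `driver_hourly_stats`) are hypothetical:

```python
from feast import FeatureStore
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
    SavedDatasetPostgreSQLStorage,
)

store = FeatureStore(repo_path=".")

# Build a retrieval job against the Postgres offline store.
job = store.get_historical_features(
    entity_df="SELECT driver_id, event_timestamp FROM entity_rows",
    features=["driver_hourly_stats:conv_rate"],
)

# New in this PR: persist the result as a saved dataset backed by a Postgres table.
store.create_saved_dataset(
    from_=job,
    name="driver_training_data",
    storage=SavedDatasetPostgreSQLStorage(table_ref="driver_training_data"),
)
```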
Makefile (47 changes: 35 additions & 12 deletions)
@@ -164,24 +164,47 @@ test-python-universal-athena:
 	not s3_registry" \
 	sdk/python/tests


-test-python-universal-postgres:
+test-python-universal-postgres-offline:
 	PYTHONPATH='.' \
 	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
 	PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
 	FEAST_USAGE=False \
 	IS_TEST=True \
-	python -m pytest -x --integration \
-	-k "not test_historical_retrieval_fails_on_validation and \
-	not test_historical_retrieval_with_validation and \
+	python -m pytest -n 8 --integration \
+	-k "not test_historical_retrieval_with_validation and \
 	not test_historical_features_persisting and \
+	not test_historical_retrieval_fails_on_validation and \
-	not test_universal_cli and \
-	not test_go_feature_server and \
-	not test_feature_logging and \
-	not test_universal_types" \
-	sdk/python/tests
+	not test_universal_cli and \
+	not test_go_feature_server and \
+	not test_feature_logging and \
+	not test_reorder_columns and \
+	not test_logged_features_validation and \
+	not test_lambda_materialization_consistency and \
+	not test_offline_write and \
+	not test_push_features_to_offline_store and \
+	not gcs_registry and \
+	not s3_registry and \
+	not test_universal_types" \
+	sdk/python/tests

+test-python-universal-postgres-online:
+	PYTHONPATH='.' \
+	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
+	PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
+	FEAST_USAGE=False \
+	IS_TEST=True \
+	python -m pytest -n 8 --integration \
+	-k "not test_universal_cli and \
+	not test_go_feature_server and \
+	not test_feature_logging and \
+	not test_reorder_columns and \
+	not test_logged_features_validation and \
+	not test_lambda_materialization_consistency and \
+	not test_offline_write and \
+	not test_push_features_to_offline_store and \
+	not gcs_registry and \
+	not s3_registry and \
+	not test_universal_types" \
+	sdk/python/tests

 test-python-universal-cassandra:
 	PYTHONPATH='.' \
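The former `test-python-universal-postgres` target is split in two: `test-python-universal-postgres-offline` keeps Postgres as the offline store, while the new `test-python-universal-postgres-online` runs the universal suite with Postgres as the online store. Both switch from `-x` (fail fast) to `-n 8` (pytest-xdist parallelism) and skip additional tests. Assuming a reachable Postgres test instance, they would be invoked as `make test-python-universal-postgres-offline` and `make test-python-universal-postgres-online`.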
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py

@@ -5,6 +5,7 @@
Any,
Callable,
ContextManager,
+    Dict,
Iterator,
KeysView,
List,
@@ -13,6 +14,7 @@
Union,
)

+import numpy as np
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
@@ -24,6 +26,9 @@
from feast.errors import InvalidEntityType
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
+from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
+    SavedDatasetPostgreSQLStorage,
+)
from feast.infra.offline_stores.offline_store import (
OfflineStore,
RetrievalJob,
@@ -112,24 +117,24 @@ def get_historical_features(
project: str,
full_feature_names: bool = False,
) -> RetrievalJob:

+        entity_schema = _get_entity_schema(entity_df, config)

+        entity_df_event_timestamp_col = (
+            offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
+        )

+        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
+            entity_df,
+            entity_df_event_timestamp_col,
+            config,
+        )

         @contextlib.contextmanager
         def query_generator() -> Iterator[str]:
-            table_name = None
-            if isinstance(entity_df, pd.DataFrame):
-                table_name = offline_utils.get_temp_entity_table_name()
-                entity_schema = df_to_postgres_table(
-                    config.offline_store, entity_df, table_name
-                )
-                df_query = table_name
-            elif isinstance(entity_df, str):
-                df_query = f"({entity_df}) AS sub"
-                entity_schema = get_query_schema(config.offline_store, df_query)
-            else:
-                raise TypeError(entity_df)

-            entity_df_event_timestamp_col = (
-                offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
-            )
+            table_name = offline_utils.get_temp_entity_table_name()

+            _upload_entity_df(config, entity_df, table_name)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
@@ -139,13 +144,6 @@ def query_generator() -> Iterator[str]:
entity_schema, expected_join_keys, entity_df_event_timestamp_col
)

-            entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
-                entity_df,
-                entity_df_event_timestamp_col,
-                config,
-                df_query,
-            )

query_context = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
@@ -165,7 +163,7 @@ def query_generator() -> Iterator[str]:
try:
yield build_point_in_time_query(
query_context_dict,
-                left_table_query_string=df_query,
+                left_table_query_string=table_name,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
@@ -189,6 +187,12 @@ def query_generator() -> Iterator[str]:
on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
feature_refs, project, registry
),
+            metadata=RetrievalMetadata(
+                features=feature_refs,
+                keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}),
+                min_event_timestamp=entity_df_event_timestamp_range[0],
+                max_event_timestamp=entity_df_event_timestamp_range[1],
+            ),
)

@staticmethod
@@ -294,14 +298,19 @@ def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata

def persist(self, storage: SavedDatasetStorage):
-        pass
+        assert isinstance(storage, SavedDatasetPostgreSQLStorage)

+        df_to_postgres_table(
+            config=self.config.offline_store,
+            df=self.to_df(),
+            table_name=storage.postgres_options._table,
+        )


def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
config: RepoConfig,
-    table_name: str,
) -> Tuple[datetime, datetime]:
if isinstance(entity_df, pd.DataFrame):
entity_df_event_timestamp = entity_df.loc[
@@ -312,15 +321,15 @@ def _get_entity_df_event_timestamp_range(
entity_df_event_timestamp, utc=True
)
entity_df_event_timestamp_range = (
-            entity_df_event_timestamp.min(),
-            entity_df_event_timestamp.max(),
+            entity_df_event_timestamp.min().to_pydatetime(),
+            entity_df_event_timestamp.max().to_pydatetime(),
)
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), determine range
# from table
with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
cur.execute(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM {table_name}"
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
),
res = cur.fetchone()
entity_df_event_timestamp_range = (res[0], res[1])
@@ -374,6 +383,34 @@ def build_point_in_time_query(
return query


+def _upload_entity_df(
+    config: RepoConfig, entity_df: Union[pd.DataFrame, str], table_name: str
+):
+    if isinstance(entity_df, pd.DataFrame):
+        # If the entity_df is a pandas dataframe, upload it to Postgres
+        df_to_postgres_table(config.offline_store, entity_df, table_name)
+    elif isinstance(entity_df, str):
+        # If the entity_df is a string (SQL query), create a Postgres table out of it
+        with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
+            cur.execute(f"CREATE TABLE {table_name} AS ({entity_df})")
+    else:
+        raise InvalidEntityType(type(entity_df))


+def _get_entity_schema(
+    entity_df: Union[pd.DataFrame, str],
+    config: RepoConfig,
+) -> Dict[str, np.dtype]:
+    if isinstance(entity_df, pd.DataFrame):
+        return dict(zip(entity_df.columns, entity_df.dtypes))

+    elif isinstance(entity_df, str):
+        df_query = f"({entity_df}) AS sub"
+        return get_query_schema(config.offline_store, df_query)
+    else:
+        raise InvalidEntityType(type(entity_df))


# Copied from the Feast Redshift offline store implementation
# Note: Keep this in sync with sdk/python/feast/infra/offline_stores/redshift.py:
# MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN
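Net effect of the postgres.py changes: entity schema and timestamp-range inference are hoisted out of `query_generator` (via the new `_get_entity_schema` and the relocated `_get_entity_df_event_timestamp_range` call), both accepted `entity_df` forms are funneled through `_upload_entity_df` into a temporary table, and retrieval jobs now carry `RetrievalMetadata` plus a working `persist()`. A minimal sketch of the caller-visible behavior, assuming a repo configured with this offline store (feature and table names are hypothetical):

```python
import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes a Postgres offline store in feature_store.yaml

# Path 1: a DataFrame entity_df is uploaded with df_to_postgres_table().
entity_frame = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(["2022-08-01", "2022-08-02"], utc=True),
    }
)

# Path 2: a SQL-string entity_df is materialized server-side via
# CREATE TABLE <temp_table> AS (<query>), avoiding a client-side round trip.
entity_query = "SELECT driver_id, event_timestamp FROM entity_rows"

for entity_df in (entity_frame, entity_query):
    job = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_hourly_stats:conv_rate"],
    )
    # RetrievalMetadata is now populated (join keys, min/max event timestamps).
    print(job.metadata)
```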
sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py

@@ -1,27 +1,42 @@
import json
from typing import Callable, Dict, Iterable, Optional, Tuple

+from typeguard import typechecked

from feast.data_source import DataSource
+from feast.errors import DataSourceNoNameException
from feast.infra.utils.postgres.connection_utils import _get_conn
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
+from feast.protos.feast.core.SavedDataset_pb2 import (
+    SavedDatasetStorage as SavedDatasetStorageProto,
+)
from feast.repo_config import RepoConfig
+from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import pg_type_code_to_pg_type, pg_type_to_feast_value_type
from feast.value_type import ValueType


+@typechecked
class PostgreSQLSource(DataSource):
def __init__(
self,
-        name: str,
-        query: str,
+        name: Optional[str] = None,
+        query: Optional[str] = None,
+        table: Optional[str] = None,
timestamp_field: Optional[str] = "",
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
):
-        self._postgres_options = PostgreSQLOptions(name=name, query=query)
+        self._postgres_options = PostgreSQLOptions(name=name, query=query, table=table)

+        # If no name, use the table as the default name.
+        if name is None and table is None:
+            raise DataSourceNoNameException()
+        name = name or table
+        assert name

super().__init__(
name=name,
@@ -55,9 +70,11 @@ def from_proto(data_source: DataSourceProto):
assert data_source.HasField("custom_options")

postgres_options = json.loads(data_source.custom_options.configuration)

return PostgreSQLSource(
name=postgres_options["name"],
query=postgres_options["query"],
+            table=postgres_options["table"],
field_mapping=dict(data_source.field_mapping),
timestamp_field=data_source.timestamp_field,
created_timestamp_column=data_source.created_timestamp_column,
@@ -102,26 +119,60 @@ def get_table_column_names_and_types(
)

def get_table_query_string(self) -> str:
return f"({self._postgres_options._query})"

if self._postgres_options._table:
return f"{self._postgres_options._table}"
else:
return f"({self._postgres_options._query})"


class PostgreSQLOptions:
-    def __init__(self, name: str, query: Optional[str]):
-        self._name = name
-        self._query = query
+    def __init__(
+        self,
+        name: Optional[str],
+        query: Optional[str],
+        table: Optional[str],
+    ):
+        self._name = name or ""
+        self._query = query or ""
+        self._table = table or ""

@classmethod
def from_proto(cls, postgres_options_proto: DataSourceProto.CustomSourceOptions):
config = json.loads(postgres_options_proto.configuration.decode("utf8"))
-        postgres_options = cls(name=config["name"], query=config["query"])
+        postgres_options = cls(
+            name=config["name"], query=config["query"], table=config["table"]
+        )

return postgres_options

def to_proto(self) -> DataSourceProto.CustomSourceOptions:
postgres_options_proto = DataSourceProto.CustomSourceOptions(
configuration=json.dumps(
{"name": self._name, "query": self._query}
{"name": self._name, "query": self._query, "table": self._table}
).encode()
)

return postgres_options_proto


+class SavedDatasetPostgreSQLStorage(SavedDatasetStorage):
+    _proto_attr_name = "custom_storage"

+    postgres_options: PostgreSQLOptions

+    def __init__(self, table_ref: str):
+        self.postgres_options = PostgreSQLOptions(
+            table=table_ref, name=None, query=None
+        )

+    @staticmethod
+    def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
+        return SavedDatasetPostgreSQLStorage(
+            table_ref=PostgreSQLOptions.from_proto(storage_proto.custom_storage)._table
+        )

+    def to_proto(self) -> SavedDatasetStorageProto:
+        return SavedDatasetStorageProto(custom_storage=self.postgres_options.to_proto())

+    def to_data_source(self) -> DataSource:
+        return PostgreSQLSource(table=self.postgres_options._table)
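Taken together, `PostgreSQLSource` can now be table-backed as well as query-backed, with the table name doubling as the default source name. A short sketch of the new constructor contract (table and field names are illustrative):

```python
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
    PostgreSQLSource,
)

# Query-backed source: an explicit name is still required.
query_source = PostgreSQLSource(
    name="driver_stats_source",
    query="SELECT * FROM driver_stats",
    timestamp_field="event_timestamp",
)

# Table-backed source: the name defaults to the table name.
table_source = PostgreSQLSource(
    table="driver_stats",
    timestamp_field="event_timestamp",
)
assert table_source.name == "driver_stats"

# Passing neither name nor table raises DataSourceNoNameException.
```

Round-tripping through `to_proto()`/`from_proto()` now preserves the `table` field alongside `name` and `query`.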
sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py

@@ -1,7 +1,11 @@
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
PostgreSQLDataSourceCreator,
)
+from tests.integration.feature_repos.repo_configuration import REDIS_CONFIG
+from tests.integration.feature_repos.universal.online_store.redis import (
+    RedisOnlineStoreCreator,
+)

AVAILABLE_OFFLINE_STORES = [("local", PostgreSQLDataSourceCreator)]

AVAILABLE_ONLINE_STORES = {"postgres": (None, PostgreSQLDataSourceCreator)}
AVAILABLE_ONLINE_STORES = {"redis": (REDIS_CONFIG, RedisOnlineStoreCreator)}
sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py (new file)

@@ -0,0 +1,10 @@
+from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
+    PostgreSQLDataSourceCreator,
+)
+from tests.integration.feature_repos.integration_test_repo_config import (
+    IntegrationTestRepoConfig,
+)
+
+FULL_REPO_CONFIGS = [
+    IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
+]
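These two repo-configuration modules are what the Makefile targets above load: `FULL_REPO_CONFIGS_MODULE` points the universal integration-test harness at a module, from which it reads names such as `FULL_REPO_CONFIGS`, `AVAILABLE_OFFLINE_STORES`, and `AVAILABLE_ONLINE_STORES` to build its test matrix. Net effect: the offline-store configuration now pairs the Postgres offline store with Redis as the online store, while the new online-store module runs the suite with Postgres as the online store via `PostgreSQLDataSourceCreator`.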