Add unit test for historical retrieval with pandas dataframe #1073

Merged · 1 commit · Oct 22, 2020
5 changes: 3 additions & 2 deletions sdk/python/feast/client.py
@@ -827,8 +827,9 @@ def get_historical_features(
         feature_tables = self._get_feature_tables_from_feature_refs(
             feature_refs, project
         )
-        output_location = self._config.get(
-            CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION
+        output_location = os.path.join(
+            self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION),
+            str(uuid.uuid4()),
         )
         output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)

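This change makes each retrieval job write under a fresh uuid4 subdirectory of the configured base location, so concurrent or repeated jobs no longer overwrite one another's output. A minimal sketch of the effect, where the base path is illustrative and stands in for the configured CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION value:

    import os
    import uuid

    # Illustrative base path standing in for the configured output location.
    base_output_location = "file:///tmp/historical_feature_output"

    # Every call yields a distinct subdirectory, so two retrieval jobs that
    # share the same configured base location cannot clobber each other.
    output_location = os.path.join(base_output_location, str(uuid.uuid4()))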
1 change: 1 addition & 0 deletions sdk/python/feast/pyspark/launchers/standalone/local.py
@@ -124,6 +124,7 @@ def get_output_file_uri(self, timeout_sec: int = None):
         with self._process as p:
             try:
                 p.wait(timeout_sec)
+                return self._output_file_uri
             except Exception:
                 p.kill()
                 raise SparkJobFailure("Timeout waiting for subprocess to return")
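Previously the method waited on the subprocess and then fell through without returning, so callers of get_output_file_uri always received None; the added return hands back the URI once the job exits. A minimal sketch of the wait-and-return pattern, with SparkJobFailure and the URI value as stand-ins for the launcher's internals:

    import subprocess

    class SparkJobFailure(Exception):
        pass

    def wait_for_output(process: subprocess.Popen, output_file_uri: str, timeout_sec=None):
        try:
            process.wait(timeout=timeout_sec)
            return output_file_uri  # without this return, callers always saw None
        except Exception:
            process.kill()
            raise SparkJobFailure("Timeout waiting for subprocess to return")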
2 changes: 1 addition & 1 deletion sdk/python/feast/staging/storage_client.py
@@ -270,7 +270,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
         raise NotImplementedError("list files not implemented for Local file")

     def upload_file(self, local_path: str, bucket: str, remote_path: str):
-        dest_fpath = "/" + remote_path
+        dest_fpath = remote_path if remote_path.startswith("/") else "/" + remote_path
         os.makedirs(os.path.dirname(dest_fpath), exist_ok=True)
         shutil.copy(local_path, dest_fpath)

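The old code unconditionally prepended "/", so an already-absolute remote_path such as "/tmp/out" became "//tmp/out"; the new expression only adds the slash when it is missing. A quick sketch of the behaviour:

    def dest_path(remote_path: str) -> str:
        # Only prefix "/" when the path is relative; absolute paths pass through.
        return remote_path if remote_path.startswith("/") else "/" + remote_path

    assert dest_path("tmp/out") == "/tmp/out"
    assert dest_path("/tmp/out") == "/tmp/out"  # previously became "//tmp/out"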
178 changes: 138 additions & 40 deletions sdk/python/tests/test_historical_feature_retrieval.py
@@ -6,10 +6,14 @@
 from contextlib import closing
 from datetime import datetime
 from typing import List, Tuple
+from urllib.parse import urlparse

 import grpc
+import numpy as np
+import pandas as pd
 import pytest
 from google.protobuf.duration_pb2 import Duration
+from pandas.util.testing import assert_frame_equal
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import (
     BooleanType,
@@ -19,6 +23,7 @@
     StructType,
     TimestampType,
 )
+from pytz import utc

 from feast import Client, Entity, Feature, FeatureTable, FileSource, ValueType
 from feast.core import CoreService_pb2_grpc as Core
@@ -82,6 +87,26 @@ def client(server):
     return Client(core_url=f"localhost:{free_port}")


+@pytest.yield_fixture()
+def client_with_local_spark(tmpdir):
+    import pyspark
+
+    spark_staging_location = f"file://{os.path.join(tmpdir, 'staging')}"
+    historical_feature_output_location = (
+        f"file://{os.path.join(tmpdir, 'historical_feature_retrieval_output')}"
+    )
+
+    return Client(
+        core_url=f"localhost:{free_port}",
+        spark_launcher="standalone",
+        spark_standalone_master="local",
+        spark_home=os.path.dirname(pyspark.__file__),
+        spark_staging_location=spark_staging_location,
+        historical_feature_output_location=historical_feature_output_location,
+        historical_feature_output_format="parquet",
+    )
+
+
 @pytest.fixture()
 def driver_entity(client):
     return client.apply_entity(Entity("driver_id", "description", ValueType.INT32))
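The new fixture builds a Client wired to the standalone Spark launcher, staging files and writing parquet output under pytest's tmpdir. Passing os.path.dirname(pyspark.__file__) as spark_home works because, for a pip-installed PySpark, the package directory itself ships the Spark launcher scripts; a quick check of that assumption:

    import os
    import pyspark

    # For pip-installed PySpark, the package directory contains bin/spark-submit
    # and can therefore serve as SPARK_HOME without a separate Spark install.
    pkg_dir = os.path.dirname(pyspark.__file__)
    print(os.path.exists(os.path.join(pkg_dir, "bin", "spark-submit")))  # True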
@@ -116,36 +141,36 @@ def transactions_feature_table(spark, client):
     df_data = [
         (
             1001,
-            datetime(year=2020, month=9, day=1),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             50.0,
             True,
         ),
         (
             1001,
-            datetime(year=2020, month=9, day=1),
-            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
             100.0,
             True,
         ),
         (
             2001,
-            datetime(year=2020, month=9, day=1),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             400.0,
             False,
         ),
         (
             1001,
-            datetime(year=2020, month=9, day=2),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             200.0,
             False,
         ),
         (
             1001,
-            datetime(year=2020, month=9, day=4),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=4, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             300.0,
             False,
         ),
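The fixtures now pin every timestamp to UTC (the same change repeats in the bookings fixtures below). Naive datetimes are interpreted in the Spark session's local timezone when converted to Spark timestamps, so unqualified values can shift by the host's UTC offset and make assertions machine-dependent; attaching tzinfo=utc keeps the test deterministic. The distinction in brief:

    from datetime import datetime
    from pytz import utc

    naive = datetime(2020, 9, 1)              # interpreted in the session timezone
    aware = datetime(2020, 9, 1, tzinfo=utc)  # fixed point in time, host-independent
    print(naive.tzinfo, aware.tzinfo)         # None UTC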
@@ -180,20 +205,20 @@ def bookings_feature_table(spark, client):
     df_data = [
         (
             8001,
-            datetime(year=2020, month=9, day=1),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             100,
         ),
         (
             8001,
-            datetime(year=2020, month=9, day=2),
-            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
             150,
         ),
         (
             8002,
-            datetime(year=2020, month=9, day=2),
-            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
             200,
         ),
     ]
@@ -225,20 +250,20 @@ def bookings_feature_table_with_mapping(spark, client):
     df_data = [
         (
             8001,
-            datetime(year=2020, month=9, day=1),
-            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
+            datetime(year=2020, month=9, day=1, tzinfo=utc),
             100,
         ),
         (
             8001,
-            datetime(year=2020, month=9, day=2),
-            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
             150,
         ),
         (
             8002,
-            datetime(year=2020, month=9, day=2),
-            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
+            datetime(year=2020, month=9, day=2, tzinfo=utc),
             200,
         ),
     ]
@@ -273,12 +298,12 @@ def test_historical_feature_retrieval_from_local_spark_session(
         ]
     )
     df_data = [
-        (1001, 8001, datetime(year=2020, month=9, day=1),),
-        (2001, 8001, datetime(year=2020, month=9, day=2),),
-        (2001, 8002, datetime(year=2020, month=9, day=1),),
-        (1001, 8001, datetime(year=2020, month=9, day=2),),
-        (1001, 8001, datetime(year=2020, month=9, day=3),),
-        (1001, 8001, datetime(year=2020, month=9, day=4),),
+        (1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+        (2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+        (2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+        (1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+        (1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc)),
+        (1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc)),
     ]
     temp_dir, file_uri = create_temp_parquet_file(
         spark, "customer_driver_pair", schema, df_data
@@ -300,12 +325,12 @@
         ]
     )
     expected_joined_df_data = [
-        (1001, 8001, datetime(year=2020, month=9, day=1), 100.0, 100),
-        (2001, 8001, datetime(year=2020, month=9, day=2), 400.0, 150),
-        (2001, 8002, datetime(year=2020, month=9, day=1), 400.0, None),
-        (1001, 8001, datetime(year=2020, month=9, day=2), 200.0, 150),
-        (1001, 8001, datetime(year=2020, month=9, day=3), 200.0, 150),
-        (1001, 8001, datetime(year=2020, month=9, day=4), 300.0, None),
+        (1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc), 100.0, 100),
+        (2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 400.0, 150),
+        (2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc), 400.0, None),
+        (1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 200.0, 150),
+        (1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc), 200.0, 150),
+        (1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc), 300.0, None),
     ]
     expected_joined_df = spark.createDataFrame(
         spark.sparkContext.parallelize(expected_joined_df_data),
@@ -325,9 +350,9 @@ def test_historical_feature_retrieval_with_field_mappings_from_local_spark_session(
         ]
     )
     df_data = [
-        (8001, datetime(year=2020, month=9, day=1)),
-        (8001, datetime(year=2020, month=9, day=2)),
-        (8002, datetime(year=2020, month=9, day=1)),
+        (8001, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+        (8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+        (8002, datetime(year=2020, month=9, day=1, tzinfo=utc)),
     ]
     temp_dir, file_uri = create_temp_parquet_file(spark, "drivers", schema, df_data)
     entity_source = FileSource(
@@ -344,13 +369,86 @@ def test_historical_feature_retrieval_with_field_mappings_from_local_spark_session(
         ]
     )
     expected_joined_df_data = [
-        (8001, datetime(year=2020, month=9, day=1), 100),
-        (8001, datetime(year=2020, month=9, day=2), 150),
-        (8002, datetime(year=2020, month=9, day=1), None),
+        (8001, datetime(year=2020, month=9, day=1, tzinfo=utc), 100),
+        (8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 150),
+        (8002, datetime(year=2020, month=9, day=1, tzinfo=utc), None),
     ]
     expected_joined_df = spark.createDataFrame(
         spark.sparkContext.parallelize(expected_joined_df_data),
         expected_joined_df_schema,
     )
     assert_dataframe_equal(joined_df, expected_joined_df)
     shutil.rmtree(temp_dir)
+
+
+@pytest.mark.usefixtures(
+    "driver_entity",
+    "customer_entity",
+    "bookings_feature_table",
+    "transactions_feature_table",
+)
+def test_historical_feature_retrieval_with_pandas_dataframe_input(
+    client_with_local_spark,
+):
+    customer_driver_pairs_pandas_df = pd.DataFrame(
+        np.array(
+            [
+                [1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc)],
+                [2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)],
+                [2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc)],
+                [1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)],
+                [1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc)],
+                [1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc)],
+            ]
+        ),
+        columns=["customer_id", "driver_id", "event_timestamp"],
+    )
+    customer_driver_pairs_pandas_df = customer_driver_pairs_pandas_df.astype(
+        {"customer_id": "int32", "driver_id": "int32"}
+    )
+
+    job_output = client_with_local_spark.get_historical_features(
+        ["transactions:total_transactions", "bookings:total_completed_bookings"],
+        customer_driver_pairs_pandas_df,
+    )
+
+    output_dir = job_output.get_output_file_uri()
+    joined_df = pd.read_parquet(urlparse(output_dir).path)
+
+    expected_joined_df = pd.DataFrame(
+        np.array(
+            [
+                [1001, 8001, datetime(year=2020, month=9, day=1), 100.0, 100],
+                [2001, 8001, datetime(year=2020, month=9, day=2), 400.0, 150],
+                [2001, 8002, datetime(year=2020, month=9, day=1), 400.0, None],
+                [1001, 8001, datetime(year=2020, month=9, day=2), 200.0, 150],
+                [1001, 8001, datetime(year=2020, month=9, day=3), 200.0, 150],
+                [1001, 8001, datetime(year=2020, month=9, day=4), 300.0, None],
+            ]
+        ),
+        columns=[
+            "customer_id",
+            "driver_id",
+            "event_timestamp",
+            "transactions__total_transactions",
+            "bookings__total_completed_bookings",
+        ],
+    )
+    expected_joined_df = expected_joined_df.astype(
+        {
+            "customer_id": "int32",
+            "driver_id": "int32",
+            "transactions__total_transactions": "float64",
+            "bookings__total_completed_bookings": "float64",
+        }
+    )
+
+    assert_frame_equal(
+        joined_df.sort_values(
+            by=["customer_id", "driver_id", "event_timestamp"]
+        ).reset_index(drop=True),
+        expected_joined_df.sort_values(
+            by=["customer_id", "driver_id", "event_timestamp"]
+        ).reset_index(drop=True),
+    )
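Spark makes no guarantee about the row order of the written parquet output, so the new test sorts both frames on the entity keys and event timestamp and resets the index before the strict comparison. The same pattern as a small reusable helper (the PR imports assert_frame_equal from pandas.util.testing, which was current at the time; modern pandas exposes it under pandas.testing):

    import pandas as pd
    from pandas.testing import assert_frame_equal  # pandas.util.testing in older pandas

    def assert_frames_equal_ignoring_order(left: pd.DataFrame, right: pd.DataFrame, keys):
        # Sort on the join keys and drop the old index so only content is compared.
        assert_frame_equal(
            left.sort_values(by=keys).reset_index(drop=True),
            right.sort_values(by=keys).reset_index(drop=True),
        )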