Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Refactor testing and sort out unit and integration tests #2975

Merged
merged 30 commits into from Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 4 additions & 16 deletions sdk/python/tests/conftest.py
Expand Up @@ -14,8 +14,6 @@
import logging
import multiprocessing
import os
import socket
from contextlib import closing
from datetime import datetime, timedelta
from multiprocessing import Process
from sys import platform
Expand Down Expand Up @@ -45,6 +43,7 @@
from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402
FileDataSourceCreator,
)
from tests.utils.http_server import check_port_open, free_port # noqa: E402

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -327,7 +326,7 @@ def feature_server_endpoint(environment):
yield environment.feature_store.get_feature_server_endpoint()
return

port = _free_port()
port = free_port()

proc = Process(
target=start_test_local_server,
Expand All @@ -340,7 +339,7 @@ def feature_server_endpoint(environment):
proc.start()
# Wait for server to start
wait_retry_backoff(
lambda: (None, _check_port_open("localhost", port)),
lambda: (None, check_port_open("localhost", port)),
timeout_secs=10,
)

Expand All @@ -353,23 +352,12 @@ def feature_server_endpoint(environment):
wait_retry_backoff(
lambda: (
None,
not _check_port_open("localhost", environment.get_local_server_port()),
not check_port_open("localhost", environment.get_local_server_port()),
),
timeout_secs=30,
)


def _check_port_open(host, port) -> bool:
    """Report whether a TCP connection to ``(host, port)`` succeeds.

    ``connect_ex`` returns an error code instead of raising, so a result of
    0 means the connection was accepted.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        # Always release the probing socket, whatever the outcome.
        probe.close()
    return status == 0


def _free_port():
    """Return a TCP port number that was unused at the time of the call.

    Binding to port 0 lets the operating system pick an unused port. The
    probing socket is closed before returning so the caller can bind the
    port itself; another process may still claim the port in between.
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


@pytest.fixture
def universal_data_sources(environment) -> TestData:
    """Fixture: the result of ``construct_universal_test_data`` for ``environment``."""
    test_data = construct_universal_test_data(environment)
    return test_data
Expand Down
215 changes: 91 additions & 124 deletions sdk/python/tests/integration/e2e/test_go_feature_server.py
@@ -1,7 +1,5 @@
import socket
import threading
import time
from contextlib import closing
from datetime import datetime
from typing import List

Expand All @@ -11,10 +9,10 @@
import pytz
import requests

from feast import FeatureService, FeatureView, ValueType
from feast.embedded_go.online_features_service import EmbeddedOnlineFeatureServer
from feast.feast_object import FeastObject
from feast.feature_logging import LoggingConfig
from feast.feature_service import FeatureService
from feast.infra.feature_servers.base_config import FeatureLoggingConfig
from feast.protos.feast.serving.ServingService_pb2 import (
FieldStatus,
Expand All @@ -24,6 +22,7 @@
from feast.protos.feast.serving.ServingService_pb2_grpc import ServingServiceStub
from feast.protos.feast.types.Value_pb2 import RepeatedValue
from feast.type_map import python_values_to_proto_values
from feast.value_type import ValueType
from feast.wait import wait_retry_backoff
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
Expand All @@ -33,94 +32,8 @@
driver,
location,
)


@pytest.fixture
def initialized_registry(environment, universal_data_sources):
    """Fixture: apply the universal feature repo objects and materialize them.

    Builds a ``driver_features`` feature service (logging config with
    ``sample_rate=1.0``), then applies it together with every universal
    feature view and the objects returned by ``driver()``, ``customer()``
    and ``location()``. Finally materializes the environment's
    ``start_date``..``end_date`` range.
    """
    store = environment.feature_store

    # Only the third element of the test-data tuple is needed here.
    _, _, sources = universal_data_sources
    views = construct_universal_feature_views(sources)

    service = FeatureService(
        name="driver_features",
        features=[views.driver],
        logging_config=LoggingConfig(
            destination=environment.data_source_creator.create_logged_features_destination(),
            sample_rate=1.0,
        ),
    )

    objects: List[FeastObject] = [
        service,
        *views.values(),
        driver(),
        customer(),
        location(),
    ]

    store.apply(objects)
    store.materialize(environment.start_date, environment.end_date)


def server_port(environment, server_type: str):
    """Generator that runs an embedded Go feature server and yields its port.

    Skips the test unless the repo config enables Go feature serving. The
    server is started on 127.0.0.1 in a background thread with
    ``enable_logging=True`` and is stopped once the generator is resumed.

    Raises:
        ValueError: if ``server_type`` is neither ``"grpc"`` nor ``"http"``.
    """
    if not environment.test_repo_config.go_feature_serving:
        pytest.skip("Only for Go path")

    store = environment.feature_store

    server = EmbeddedOnlineFeatureServer(
        repo_path=str(store.repo_path.absolute()),
        repo_config=store.config,
        feature_store=store,
    )
    port = free_port()

    is_grpc = server_type == "grpc"
    if not is_grpc and server_type != "http":
        raise ValueError("Server Type must be either 'http' or 'grpc'")
    start = server.start_grpc_server if is_grpc else server.start_http_server

    logging_options = FeatureLoggingConfig(
        enabled=True,
        queue_capacity=100,
        write_to_disk_interval_secs=1,
        flush_interval_secs=1,
        emit_timeout_micro_secs=10000,
    )
    worker = threading.Thread(
        target=start,
        args=("127.0.0.1", port),
        kwargs={"enable_logging": True, "logging_options": logging_options},
    )
    worker.start()

    # Poll the port before handing it to the test.
    wait_retry_backoff(
        lambda: (None, check_port_open("127.0.0.1", port)), timeout_secs=15
    )

    yield port

    if is_grpc:
        server.stop_grpc_server()
    else:
        server.stop_http_server()

    # wait for graceful stop
    time.sleep(5)


@pytest.fixture
def grpc_server_port(environment, initialized_registry):
    """Fixture: port yielded by ``server_port`` for the ``"grpc"`` server type."""
    port_source = server_port(environment, "grpc")
    yield from port_source


@pytest.fixture
def http_server_port(environment, initialized_registry):
    """Fixture: port yielded by ``server_port`` for the ``"http"`` server type."""
    port_source = server_port(environment, "http")
    yield from port_source


@pytest.fixture
def grpc_client(grpc_server_port):
    """Fixture: a ``ServingServiceStub`` bound to the local gRPC server port.

    The underlying channel is closed during fixture teardown so it does not
    outlive the test that used it.
    """
    channel = grpc.insecure_channel(f"localhost:{grpc_server_port}")
    try:
        yield ServingServiceStub(channel)
    finally:
        channel.close()
from tests.utils.http_server import check_port_open, free_port
from tests.utils.test_log_creator import generate_expected_logs, get_latest_rows


@pytest.mark.integration
Expand Down Expand Up @@ -254,43 +167,97 @@ def retrieve():
pd.testing.assert_frame_equal(expected_logs, persisted_logs, check_dtype=False)


def free_port():
    """Return a TCP port number that was unused at the time of the call.

    Binding to port 0 lets the operating system pick an unused port. The
    probing socket is closed before returning so the caller can bind the
    port itself; another process may still claim the port in between.
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
"""
Start the Go feature server over either HTTP or gRPC, based on the repo configuration, for testing.
"""


def check_port_open(host, port) -> bool:
    """Report whether a TCP connection to ``(host, port)`` succeeds.

    ``connect_ex`` returns an error code instead of raising, so a result of
    0 means the connection was accepted.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        # Always release the probing socket, whatever the outcome.
        probe.close()
    return status == 0
def _server_port(environment, server_type: str):
if not environment.test_repo_config.go_feature_serving:
pytest.skip("Only for Go path")

fs = environment.feature_store

def get_latest_rows(df, join_key, entity_values):
rows = df[df[join_key].isin(entity_values)]
return rows.loc[rows.groupby(join_key)["event_timestamp"].idxmax()]
embedded = EmbeddedOnlineFeatureServer(
repo_path=str(fs.repo_path.absolute()),
repo_config=fs.config,
feature_store=fs,
)
port = free_port()
if server_type == "grpc":
target = embedded.start_grpc_server
elif server_type == "http":
target = embedded.start_http_server
else:
raise ValueError("Server Type must be either 'http' or 'grpc'")

t = threading.Thread(
target=target,
args=("127.0.0.1", port),
kwargs=dict(
enable_logging=True,
logging_options=FeatureLoggingConfig(
enabled=True,
queue_capacity=100,
write_to_disk_interval_secs=1,
flush_interval_secs=1,
emit_timeout_micro_secs=10000,
),
),
)
t.start()

wait_retry_backoff(
lambda: (None, check_port_open("127.0.0.1", port)), timeout_secs=15
)

def generate_expected_logs(
df: pd.DataFrame,
feature_view: FeatureView,
features: List[str],
join_keys: List[str],
timestamp_column: str,
):
logs = pd.DataFrame()
for join_key in join_keys:
logs[join_key] = df[join_key]

for feature in features:
col = f"{feature_view.name}__{feature}"
logs[col] = df[feature]
logs[f"{col}__timestamp"] = df[timestamp_column]
logs[f"{col}__status"] = FieldStatus.PRESENT
if feature_view.ttl:
logs[f"{col}__status"] = logs[f"{col}__status"].mask(
df[timestamp_column]
< datetime.utcnow().replace(tzinfo=pytz.UTC) - feature_view.ttl,
FieldStatus.OUTSIDE_MAX_AGE,
)
yield port
if server_type == "grpc":
embedded.stop_grpc_server()
else:
embedded.stop_http_server()

return logs.sort_values(by=join_keys).reset_index(drop=True)
# wait for graceful stop
time.sleep(5)


# Go test fixtures


@pytest.fixture
def initialized_registry(environment, universal_data_sources):
    """Fixture: apply the universal feature repo objects and materialize them.

    Builds a ``driver_features`` feature service (logging config with
    ``sample_rate=1.0``), then applies it together with every universal
    feature view and the objects returned by ``driver()``, ``customer()``
    and ``location()``. Finally materializes the environment's
    ``start_date``..``end_date`` range.
    """
    store = environment.feature_store

    # Only the third element of the test-data tuple is needed here.
    _, _, sources = universal_data_sources
    views = construct_universal_feature_views(sources)

    service = FeatureService(
        name="driver_features",
        features=[views.driver],
        logging_config=LoggingConfig(
            destination=environment.data_source_creator.create_logged_features_destination(),
            sample_rate=1.0,
        ),
    )

    objects: List[FeastObject] = [
        service,
        *views.values(),
        driver(),
        customer(),
        location(),
    ]

    store.apply(objects)
    store.materialize(environment.start_date, environment.end_date)


@pytest.fixture
def grpc_server_port(environment, initialized_registry):
    """Fixture: port yielded by ``_server_port`` for the ``"grpc"`` server type."""
    port_source = _server_port(environment, "grpc")
    yield from port_source


@pytest.fixture
def http_server_port(environment, initialized_registry):
    """Fixture: port yielded by ``_server_port`` for the ``"http"`` server type."""
    port_source = _server_port(environment, "http")
    yield from port_source


@pytest.fixture
def grpc_client(grpc_server_port):
    """Fixture: a ``ServingServiceStub`` bound to the local gRPC server port.

    The underlying channel is closed during fixture teardown so it does not
    outlive the test that used it.
    """
    channel = grpc.insecure_channel(f"localhost:{grpc_server_port}")
    try:
        yield ServingServiceStub(channel)
    finally:
        channel.close()
10 changes: 7 additions & 3 deletions sdk/python/tests/integration/e2e/test_python_feature_server.py
Expand Up @@ -58,7 +58,9 @@ def test_get_online_features(python_fs_client):
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_push(python_fs_client):
initial_temp = get_temperatures(python_fs_client, location_ids=[1])[0]
initial_temp = _get_temperatures_from_feature_server(
python_fs_client, location_ids=[1]
)[0]
json_data = json.dumps(
{
"push_source_name": "location_stats_push_source",
Expand All @@ -77,10 +79,12 @@ def test_push(python_fs_client):

# Check new pushed temperature is fetched
assert response.status_code == 200
assert get_temperatures(python_fs_client, location_ids=[1]) == [initial_temp * 100]
assert _get_temperatures_from_feature_server(
python_fs_client, location_ids=[1]
) == [initial_temp * 100]


def get_temperatures(client, location_ids: List[int]):
def _get_temperatures_from_feature_server(client, location_ids: List[int]):
get_request_data = {
"features": ["pushable_location_stats:temperature"],
"entities": {"location_id": location_ids},
Expand Down