-
Notifications
You must be signed in to change notification settings - Fork 948
/
test_python_feature_server.py
137 lines (123 loc) · 4.58 KB
/
test_python_feature_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
from datetime import datetime
from typing import List
import pytest
from fastapi.testclient import TestClient
from feast.feast_object import FeastObject
from feast.feature_server import get_app
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
)
from tests.integration.feature_repos.universal.entities import (
customer,
driver,
location,
)
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_get_online_features(python_fs_client):
    """Request three ``driver_stats`` features for two drivers and check the reply.

    The test posts to ``/get-online-features`` and verifies that:

    * the response metadata lists the entity key plus each requested feature,
    * every result entry carries two statuses, all equal to ``"PRESENT"``,
    * the values reported for ``driver_id`` echo the ids that were requested.
    """
    driver_ids = [5001, 5002]
    payload = {
        "features": [
            "driver_stats:conv_rate",
            "driver_stats:acc_rate",
            "driver_stats:avg_daily_trips",
        ],
        "entities": {"driver_id": driver_ids},
    }
    reply = python_fs_client.post("/get-online-features", data=json.dumps(payload))
    body = json.loads(reply.text)

    # The entity key and all requested features must show up in the metadata.
    assert "metadata" in body
    feature_names = body["metadata"]["feature_names"]
    expected_names = ["driver_id", "conv_rate", "acc_rate", "avg_daily_trips"]
    assert len(feature_names) == len(expected_names)
    assert all(name in feature_names for name in expected_names)

    assert "results" in body
    rows = body["results"]
    for row in rows:
        statuses = row["statuses"]
        assert len(statuses) == 2  # one status per requested driver
        assert all(status == "PRESENT" for status in statuses)

    # This lookup relies on the results being ordered like the metadata
    # feature names, so the position of "driver_id" selects its result entry.
    driver_id_position = feature_names.index("driver_id")
    assert rows[driver_id_position]["values"] == driver_ids
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_push(python_fs_client):
    """Push a new temperature for location 1 and confirm it is served afterwards.

    The current temperature is read first so that the pushed value (one
    hundred times the current one) is derived from what the server returns.
    """
    baseline_temp = _get_temperatures_from_feature_server(
        python_fs_client, location_ids=[1]
    )[0]
    pushed_temp = baseline_temp * 100

    # Two separate clock reads, one per timestamp column.
    event_time = str(datetime.utcnow())
    created_time = str(datetime.utcnow())
    push_request = {
        "push_source_name": "location_stats_push_source",
        "df": {
            "location_id": [1],
            "temperature": [pushed_temp],
            "event_timestamp": [event_time],
            "created": [created_time],
        },
    }
    push_reply = python_fs_client.post("/push", data=json.dumps(push_request))

    # The push must be accepted, and the next read must return the pushed value.
    assert push_reply.status_code == 200
    served_temps = _get_temperatures_from_feature_server(
        python_fs_client, location_ids=[1]
    )
    assert served_temps == [pushed_temp]
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_push_source_does_not_exist(python_fs_client):
    """Pushing to a push source name that is not registered must yield HTTP 422."""
    baseline_temp = _get_temperatures_from_feature_server(
        python_fs_client, location_ids=[1]
    )[0]
    pushed_temp = baseline_temp * 100

    # Two separate clock reads, one per timestamp column.
    event_time = str(datetime.utcnow())
    created_time = str(datetime.utcnow())
    push_request = {
        "push_source_name": "push_source_does_not_exist",
        "df": {
            "location_id": [1],
            "temperature": [pushed_temp],
            "event_timestamp": [event_time],
            "created": [created_time],
        },
    }
    push_reply = python_fs_client.post("/push", data=json.dumps(push_request))

    assert push_reply.status_code == 422
def _get_temperatures_from_feature_server(client, location_ids: List[int]):
    """Return the ``temperature`` values served for the given locations.

    Posts a ``/get-online-features`` request for
    ``pushable_location_stats:temperature`` through ``client`` and extracts
    the ``values`` list of the result entry that sits at the same position as
    ``"temperature"`` in the response's metadata feature names.

    Args:
        client: Object exposing ``post(url, data=...)`` whose return value has
            a ``text`` attribute holding the JSON response body.
        location_ids: Location ids to look up.

    Returns:
        The ``values`` entry of the temperature result.

    Raises:
        AssertionError: If the response lacks ``metadata`` or ``results``.
        ValueError: If ``"temperature"`` is not among the feature names.
    """
    payload = json.dumps(
        {
            "features": ["pushable_location_stats:temperature"],
            "entities": {"location_id": location_ids},
        }
    )
    body = json.loads(client.post("/get-online-features", data=payload).text)

    assert "metadata" in body
    feature_names = body["metadata"]["feature_names"]
    assert "results" in body
    rows = body["results"]

    # The result entry is selected by the position of "temperature" in the
    # metadata feature names.
    temperature_position = feature_names.index("temperature")
    return rows[temperature_position]["values"]
@pytest.fixture
def python_fs_client(environment, universal_data_sources, request):
    """Yield a ``TestClient`` wrapping the feature server app for the test store.

    Before the client is handed out, the universal feature views and the
    driver, customer and location entities are applied to the environment's
    feature store, and the store is materialized between the environment's
    start and end dates.
    """
    store = environment.feature_store
    # Only the data sources (third element) are needed here.
    _, _, sources = universal_data_sources

    objects_to_apply: List[FeastObject] = [
        *construct_universal_feature_views(sources).values(),
        driver(),
        customer(),
        location(),
    ]
    store.apply(objects_to_apply)
    store.materialize(environment.start_date, environment.end_date)

    yield TestClient(get_app(store))