Commit fda56ad

Add project whitelist (#57)
* Add project whitelist

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Shift project whitelist logic to jobservice

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Fix tests

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Add whitelist project tests

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Fix flaky test

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Use config instead of js client

Signed-off-by: Terence Lim <terencelimxp@gmail.com>

* Remove unnecessary code

Signed-off-by: Terence Lim <terencelimxp@gmail.com>
terryyylim committed Apr 20, 2021
1 parent 5730cf4 commit fda56ad
Showing 4 changed files with 90 additions and 26 deletions.
3 changes: 3 additions & 0 deletions python/feast_spark/constants.py
@@ -157,6 +157,9 @@ class ConfigOptions(metaclass=ConfigMeta):
     #: Log path of EMR cluster
     EMR_LOG_LOCATION: Optional[str] = None
 
+    #: Whitelisted Feast projects
+    WHITELISTED_PROJECTS: Optional[str] = None
+
     def defaults(self):
         return {
             k: getattr(self, k)
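The new option is read as a plain comma-separated string. As a minimal sketch, it can be supplied through the client options dict, mirroring the test fixture later in this commit (the FeastClient import path is an assumption, not shown in this diff):

    # Sketch: configure the whitelist when constructing the client.
    # The "whitelisted_projects" key maps to ConfigOptions.WHITELISTED_PROJECTS.
    from feast import Client as FeastClient  # assumed import; adjust to your setup

    client = FeastClient(
        options={"whitelisted_projects": "default,ride"},
    )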
45 changes: 44 additions & 1 deletion python/feast_spark/job_service.py
@@ -5,7 +5,7 @@
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Tuple, cast
+from typing import Dict, List, Optional, Tuple, cast
 
 import grpc
 from google.api_core.exceptions import FailedPrecondition
@@ -98,10 +98,29 @@ class JobServiceServicer(JobService_pb2_grpc.JobServiceServicer):
     def __init__(self, client: Client):
         self.client = client
 
+    @property
+    def _whitelisted_projects(self) -> Optional[List[str]]:
+        if self.client.config.exists(opt.WHITELISTED_PROJECTS):
+            whitelisted_projects = self.client.config.get(opt.WHITELISTED_PROJECTS)
+            return whitelisted_projects.split(",")
+        return None
+
+    def is_whitelisted(self, project: str):
+        # Whitelisted projects not specified, allow all projects
+        if not self._whitelisted_projects:
+            return True
+        return project in self._whitelisted_projects
+
     def StartOfflineToOnlineIngestionJob(
         self, request: StartOfflineToOnlineIngestionJobRequest, context
     ):
         """Start job to ingest data from offline store into online store"""
+
+        if not self.is_whitelisted(request.project):
+            raise ValueError(
+                f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
+            )
+
         feature_table = self.client.feature_store.get_feature_table(
             request.table_name, request.project
         )
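Note the permissive default: when the option is unset, every project passes the check. A standalone sketch of the same semantics with illustrative values (the helper names below are not part of the jobservice API):

    # Sketch of the whitelist semantics (helper names are illustrative).
    from typing import List, Optional

    def parse_whitelist(raw: Optional[str]) -> Optional[List[str]]:
        # Comma-separated string -> list; unset or empty -> None (no restriction).
        return raw.split(",") if raw else None

    def is_whitelisted(project: str, whitelist: Optional[List[str]]) -> bool:
        # Whitelisted projects not specified, allow all projects.
        if not whitelist:
            return True
        return project in whitelist

    assert is_whitelisted("default", parse_whitelist("default,ride"))
    assert not is_whitelisted("invalid_project", parse_whitelist("default,ride"))
    assert is_whitelisted("anything", parse_whitelist(None))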
@@ -125,6 +144,12 @@ def StartOfflineToOnlineIngestionJob(
 
     def GetHistoricalFeatures(self, request: GetHistoricalFeaturesRequest, context):
         """Produce a training dataset, return a job id that will provide a file reference"""
+
+        if not self.is_whitelisted(request.project):
+            raise ValueError(
+                f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
+            )
+
         job = start_historical_feature_retrieval_job(
             client=self.client,
             project=request.project,
@@ -152,6 +177,11 @@ def StartStreamToOnlineIngestionJob(
     ):
         """Start job to ingest data from stream into online store"""
 
+        if not self.is_whitelisted(request.project):
+            raise ValueError(
+                f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
+            )
+
         feature_table = self.client.feature_store.get_feature_table(
             request.table_name, request.project
         )
@@ -196,6 +226,12 @@ def StartStreamToOnlineIngestionJob(
 
     def ListJobs(self, request, context):
         """List all types of jobs"""
+
+        if not self.is_whitelisted(request.project):
+            raise ValueError(
+                f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
+            )
+
         jobs = list_jobs(
             include_terminated=request.include_terminated,
             project=request.project,
@@ -326,6 +362,13 @@ def ensure_stream_ingestion_jobs(client: Client, all_projects: bool):
         if all_projects
         else [client.feature_store.project]
     )
+    if client.config.exists(opt.WHITELISTED_PROJECTS):
+        whitelisted_projects = client.config.get(opt.WHITELISTED_PROJECTS)
+        if whitelisted_projects:
+            whitelisted_projects = whitelisted_projects.split(",")
+            projects = [
+                project for project in projects if project in whitelisted_projects
+            ]
 
     expected_job_hash_to_tables = _get_expected_job_hash_to_tables(client, projects)
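Unlike the RPC handlers above, the scheduler does not raise for non-whitelisted projects; it silently prunes them before reconciling jobs. A short sketch of that filtering step, using the project names from the tests below:

    # Sketch: the whitelist narrows the scheduler's project list.
    projects = ["default", "ride", "invalid_project"]  # e.g. all registered projects
    whitelisted_projects = "default,ride".split(",")   # parsed WHITELISTED_PROJECTS value

    projects = [project for project in projects if project in whitelisted_projects]
    assert projects == ["default", "ride"]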
67 changes: 42 additions & 25 deletions python/tests/test_streaming_job_scheduling.py
@@ -19,8 +19,11 @@
 
 @pytest.fixture
 def feast_client():
-    c = FeastClient(job_service_pause_between_jobs=0)
-    c.list_projects = Mock(return_value=["default"])
+    c = FeastClient(
+        job_service_pause_between_jobs=0,
+        options={"whitelisted_projects": "default,ride"},
+    )
+    c.list_projects = Mock(return_value=["default", "ride", "invalid_project"])
     c.list_feature_tables = Mock()
 
     yield c
@@ -51,15 +54,18 @@ def feature_table():
 
 
 class SimpleStreamingIngestionJob(StreamIngestionJob):
-    def __init__(self, id: str, feature_table: FeatureTable, status: SparkJobStatus):
+    def __init__(
+        self, id: str, project: str, feature_table: FeatureTable, status: SparkJobStatus
+    ):
         self._id = id
         self._feature_table = feature_table
+        self._project = project
         self._status = status
         self._hash = hash
 
     def get_hash(self) -> str:
         source = _source_to_argument(self._feature_table.stream_source, Config())
-        feature_table = _feature_table_to_argument(None, "default", self._feature_table)  # type: ignore
+        feature_table = _feature_table_to_argument(None, self._project, self._feature_table)  # type: ignore
 
         job_json = json.dumps(
             {"source": source, "feature_table": feature_table}, sort_keys=True,
@@ -90,18 +96,21 @@ def test_new_job_creation(spark_client, feature_table):
 
     ensure_stream_ingestion_jobs(spark_client, all_projects=True)
 
-    spark_client.start_stream_to_online_ingestion.assert_called_once_with(
-        feature_table, [], project="default"
-    )
+    assert spark_client.start_stream_to_online_ingestion.call_count == 2
 
 
 def test_no_changes(spark_client, feature_table):
     """ Feature Table spec is the same """
 
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.IN_PROGRESS
+    )
+    job2 = SimpleStreamingIngestionJob(
+        "", "ride", feature_table, SparkJobStatus.IN_PROGRESS
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [feature_table]
-    spark_client.list_jobs.return_value = [job]
+    spark_client.list_jobs.return_value = [job, job2]
 
     ensure_stream_ingestion_jobs(spark_client, all_projects=True)
 
@@ -114,41 +123,43 @@ def test_update_existing_job(spark_client, feature_table):
 
     new_ft = copy.deepcopy(feature_table)
     new_ft.stream_source._kafka_options.topic = "new_t"
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.IN_PROGRESS
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [new_ft]
     spark_client.list_jobs.return_value = [job]
 
     ensure_stream_ingestion_jobs(spark_client, all_projects=True)
 
     assert job.get_status() == SparkJobStatus.COMPLETED
-    spark_client.start_stream_to_online_ingestion.assert_called_once_with(
-        new_ft, [], project="default"
-    )
+    assert spark_client.start_stream_to_online_ingestion.call_count == 2
 
 
 def test_not_cancelling_starting_job(spark_client, feature_table):
     """ Feature Table spec was updated but previous version is still starting """
 
     new_ft = copy.deepcopy(feature_table)
     new_ft.stream_source._kafka_options.topic = "new_t"
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.STARTING)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.STARTING
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [new_ft]
     spark_client.list_jobs.return_value = [job]
 
     ensure_stream_ingestion_jobs(spark_client, all_projects=True)
 
     assert job.get_status() == SparkJobStatus.STARTING
-    spark_client.start_stream_to_online_ingestion.assert_called_once_with(
-        new_ft, [], project="default"
-    )
+    assert spark_client.start_stream_to_online_ingestion.call_count == 2
 
 
 def test_not_retrying_failed_job(spark_client, feature_table):
     """ Job has failed on previous try """
 
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.FAILED)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.FAILED
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [feature_table]
     spark_client.list_jobs.return_value = [job]
@@ -157,29 +168,33 @@ def test_not_retrying_failed_job(spark_client, feature_table):
 
     spark_client.list_jobs.assert_called_once_with(include_terminated=True)
     assert job.get_status() == SparkJobStatus.FAILED
-    spark_client.start_stream_to_online_ingestion.assert_not_called()
+    spark_client.start_stream_to_online_ingestion.assert_called_once_with(
+        feature_table, [], project="ride"
+    )
 
 
 def test_restarting_completed_job(spark_client, feature_table):
     """ Job has succesfully finished on previous try """
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.COMPLETED)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.COMPLETED
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [feature_table]
     spark_client.list_jobs.return_value = [job]
 
     ensure_stream_ingestion_jobs(spark_client, all_projects=True)
 
-    spark_client.start_stream_to_online_ingestion.assert_called_once_with(
-        feature_table, [], project="default"
-    )
+    assert spark_client.start_stream_to_online_ingestion.call_count == 2
 
 
 def test_stopping_running_job(spark_client, feature_table):
     """ Streaming source was deleted """
     new_ft = copy.deepcopy(feature_table)
     new_ft.stream_source = None
 
-    job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
+    job = SimpleStreamingIngestionJob(
+        "", "default", feature_table, SparkJobStatus.IN_PROGRESS
+    )
 
     spark_client.feature_store.list_feature_tables.return_value = [new_ft]
     spark_client.list_jobs.return_value = [job]
@@ -194,7 +209,9 @@ def test_restarting_failed_jobs(feature_table):
     """ If configured - restart failed jobs """
 
     feast_client = FeastClient(
-        job_service_pause_between_jobs=0, job_service_retry_failed_jobs=True
+        job_service_pause_between_jobs=0,
+        job_service_retry_failed_jobs=True,
+        options={"whitelisted_projects": "default,ride"},
     )
     feast_client.list_projects = Mock(return_value=["default"])
     feast_client.list_feature_tables = Mock()
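The reworked assertions follow from the new fixture: list_projects now returns three projects, the whitelist keeps only default and ride, so scheduling tests that previously expected a single ingestion job for default now expect call_count == 2 (one job per whitelisted project), and test_not_retrying_failed_job verifies that the failed default job is not retried while a fresh job is still started for ride.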
1 change: 1 addition & 0 deletions (file path not captured in this view; the hunk is the Scala test class SparkSpec)
@@ -30,6 +30,7 @@ class SparkSpec extends UnitSpec with BeforeAndAfter {
   val sparkConf = new SparkConf()
     .setMaster("local[4]")
     .setAppName("Testing")
+    .set("spark.driver.bindAddress", "localhost")
     .set("spark.default.parallelism", "8")
     .set(
       "spark.metrics.conf.*.sink.statsd.class",
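Pinning spark.driver.bindAddress to localhost is presumably the "fix flaky test" change from the commit message: it keeps the test Spark driver from binding to whatever host address the machine happens to resolve, which can vary between CI runs.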
