diff --git a/sdk/python/feast/job.py b/sdk/python/feast/job.py index 414139e0e5a..25396213e47 100644 --- a/sdk/python/feast/job.py +++ b/sdk/python/feast/job.py @@ -1,10 +1,8 @@ -import tempfile from typing import List from urllib.parse import urlparse import fastavro import pandas as pd -from google.cloud import storage from google.protobuf.json_format import MessageToJson from feast.constants import CONFIG_TIMEOUT_KEY @@ -23,9 +21,17 @@ from feast.serving.ServingService_pb2 import Job as JobProto from feast.serving.ServingService_pb2_grpc import ServingServiceStub from feast.source import Source +from feast.staging.storage_client import get_staging_client from feast.wait import wait_retry_backoff from tensorflow_metadata.proto.v0 import statistics_pb2 +# Maximum no of seconds to wait until the retrieval jobs status is DONE in Feast +# Currently set to the maximum query execution time limit in BigQuery +DEFAULT_TIMEOUT_SEC: int = 21600 + +# Maximum no of seconds to wait before reloading the job status in Feast +MAX_WAIT_INTERVAL_SEC: int = 60 + class RetrievalJob: """ @@ -42,8 +48,6 @@ def __init__( """ self.job_proto = job_proto self.serving_stub = serving_stub - # TODO: abstract away GCP depedency - self.gcs_client = storage.Client(project=None) @property def id(self): @@ -117,16 +121,7 @@ def result(self, timeout_sec: int = int(defaults[CONFIG_TIMEOUT_KEY])): """ uris = self.get_avro_files(timeout_sec) for file_uri in uris: - if file_uri.scheme == "gs": - file_obj = tempfile.TemporaryFile() - self.gcs_client.download_blob_to_file(file_uri.geturl(), file_obj) - elif file_uri.scheme == "file": - file_obj = open(file_uri.path, "rb") - else: - raise Exception( - f"Could not identify file URI {file_uri}. 
Only gs:// and file:// supported" - ) - + file_obj = get_staging_client(file_uri.scheme).download_file(file_uri) file_obj.seek(0) avro_reader = fastavro.reader(file_obj) diff --git a/sdk/python/feast/loaders/file.py b/sdk/python/feast/loaders/file.py index 5fffd62ea32..cad82cbced1 100644 --- a/sdk/python/feast/loaders/file.py +++ b/sdk/python/feast/loaders/file.py @@ -13,18 +13,18 @@ # limitations under the License. import os -import re import shutil import tempfile import uuid from datetime import datetime from typing import List, Optional, Tuple, Union -from urllib.parse import ParseResult, urlparse +from urllib.parse import urlparse import pandas as pd -from google.cloud import storage from pandavro import to_avro +from feast.staging.storage_client import get_staging_client + def export_source_to_staging_location( source: Union[pd.DataFrame, str], staging_location_uri: str @@ -44,12 +44,14 @@ def export_source_to_staging_location( * Pandas DataFrame * Local Avro file * GCS Avro file + * S3 Avro file staging_location_uri (str): Remote staging location where DataFrame should be written. 
Examples: * gs://bucket/path/ + * s3://bucket/path/ * file:///data/subfolder/ Returns: @@ -66,28 +68,24 @@ def export_source_to_staging_location( uri_path = None # type: Optional[str] if uri.scheme == "file": uri_path = uri.path - # Remote gs staging location provided by serving dir_path, file_name, source_path = export_dataframe_to_local( df=source, dir_path=uri_path ) - elif urlparse(source).scheme in ["", "file"]: - # Local file provided as a source - dir_path = "" - file_name = os.path.basename(source) - source_path = os.path.abspath( - os.path.join(urlparse(source).netloc, urlparse(source).path) - ) - elif urlparse(source).scheme == "gs": - # Google Cloud Storage path provided - input_source_uri = urlparse(source) - if "*" in source: - # Wildcard path - return _get_files( - bucket=str(input_source_uri.hostname), uri=input_source_uri + elif isinstance(source, str): + source_uri = urlparse(source) + if source_uri.scheme in ["", "file"]: + # Local file provided as a source + dir_path = "" + file_name = os.path.basename(source) + source_path = os.path.abspath( + os.path.join(source_uri.netloc, source_uri.path) ) else: - return [source] + # gs, s3 file provided as a source. + return get_staging_client(source_uri.scheme).list_files( + bucket=source_uri.hostname, path=source_uri.path + ) else: raise Exception( f"Only string and DataFrame types are allowed as a " @@ -95,23 +93,12 @@ def export_source_to_staging_location( ) # Push data to required staging location - if uri.scheme == "gs": - # Staging location is a Google Cloud Storage path - upload_file_to_gcs( - source_path, str(uri.hostname), str(uri.path).strip("/") + "/" + file_name - ) - elif uri.scheme == "file": - # Staging location is a file path - # Used for end-to-end test - pass - else: - raise Exception( - f"Staging location {staging_location_uri} does not have a " - f"valid URI. Only gs:// and file:// uri scheme are supported." 
- ) + get_staging_client(uri.scheme).upload_file( + source_path, uri.hostname, str(uri.path).strip("/") + "/" + file_name, + ) # Clean up, remove local staging file - if isinstance(source, pd.DataFrame) and len(str(dir_path)) > 4: + if dir_path and isinstance(source, pd.DataFrame) and len(dir_path) > 4: shutil.rmtree(dir_path) return [staging_location_uri.rstrip("/") + "/" + file_name] @@ -162,70 +149,6 @@ def export_dataframe_to_local( return dir_path, file_name, dest_path -def upload_file_to_gcs(local_path: str, bucket: str, remote_path: str) -> None: - """ - Upload a file from the local file system to Google Cloud Storage (GCS). - - Args: - local_path (str): - Local filesystem path of file to upload. - - bucket (str): - GCS bucket destination to upload to. - - remote_path (str): - Path within GCS bucket to upload file to, includes file name. - - Returns: - None: - None - """ - - storage_client = storage.Client(project=None) - bucket_storage = storage_client.get_bucket(bucket) - blob = bucket_storage.blob(remote_path) - blob.upload_from_filename(local_path) - - -def _get_files(bucket: str, uri: ParseResult) -> List[str]: - """ - List all available files within a Google storage bucket that matches a wild - card path. - - Args: - bucket (str): - Google Storage bucket to reference. - - uri (urllib.parse.ParseResult): - Wild card uri path containing the "*" character. - Example: - * gs://feast/staging_location/* - * gs://feast/staging_location/file_*.avro - - Returns: - List[str]: - List of all available files matching the wildcard path. 
- """ - - storage_client = storage.Client(project=None) - bucket_storage = storage_client.get_bucket(bucket) - path = uri.path - - if "*" in path: - regex = re.compile(path.replace("*", ".*?").strip("/")) - blob_list = bucket_storage.list_blobs( - prefix=path.strip("/").split("*")[0], delimiter="/" - ) - # File path should not be in path (file path must be longer than path) - return [ - f"{uri.scheme}://{uri.hostname}/{file}" - for file in [x.name for x in blob_list] - if re.match(regex, file) and file not in path - ] - else: - raise Exception(f"{path} is not a wildcard path") - - - def _get_file_name() -> str: """ Create a random file name. diff --git a/sdk/python/feast/staging/__init__.py b/sdk/python/feast/staging/__init__.py new file mode 100644 index 00000000000..e69de29bb2d
diff --git a/sdk/python/feast/staging/storage_client.py b/sdk/python/feast/staging/storage_client.py
new file mode 100644
index 00000000000..a71a3b2b5d9
--- /dev/null
+++ b/sdk/python/feast/staging/storage_client.py
@@ -0,0 +1,257 @@
+#
+# Copyright 2020 The Feast Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import re
+from abc import ABC, abstractmethod
+from tempfile import TemporaryFile
+from typing import IO, List
+from urllib.parse import ParseResult
+
+GS = "gs"
+S3 = "s3"
+LOCAL_FILE = "file"
+
+
+class AbstractStagingClient(ABC):
+    """
+    Client used to stage files in order to upload or download datasets into a historical store.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def download_file(self, uri: ParseResult) -> IO[bytes]:
+        """
+        Downloads a file from an object store and returns a binary file object
+        """
+        pass
+
+    @abstractmethod
+    def list_files(self, bucket: str, path: str) -> List[str]:
+        """
+        Lists all the files under a directory in an object store.
+        """
+        pass
+
+    @abstractmethod
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        """
+        Uploads a file to an object store.
+        """
+        pass
+
+
+class GCSClient(AbstractStagingClient):
+    """
+    Implementation of AbstractStagingClient for google cloud storage
+    """
+
+    def __init__(self):
+        try:
+            from google.cloud import storage
+        except ImportError:
+            raise ImportError(
+                "Install package google-cloud-storage==1.20.* for gcs staging support. "
+                "Run ```pip install google-cloud-storage==1.20.*```"
+            )
+        self.gcs_client = storage.Client(project=None)
+
+    def download_file(self, uri: ParseResult) -> IO[bytes]:
+        """
+        Downloads a file from google cloud storage and returns a TemporaryFile object
+
+        Args:
+            uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("gs://bucket/file.avro")
+
+        Returns:
+            TemporaryFile object
+        """
+        url = uri.geturl()
+        file_obj = TemporaryFile()
+        self.gcs_client.download_blob_to_file(url, file_obj)
+        return file_obj
+
+    def list_files(self, bucket: str, path: str) -> List[str]:
+        """
+        Lists all the files under a directory in google cloud storage if path has wildcard(*) character.
+        Otherwise returns the gs:// uri of path itself.
+
+        Args:
+            bucket (str): google cloud storage bucket name
+            path (str): object location in google cloud storage.
+
+        Returns:
+            List[str]: A list containing the full path to the file(s) in the
+                    remote staging location.
+        """
+
+        gs_bucket = self.gcs_client.get_bucket(bucket)
+
+        if "*" in path:
+            regex = re.compile(path.replace("*", ".*?").strip("/"))
+            blob_list = gs_bucket.list_blobs(
+                prefix=path.strip("/").split("*")[0], delimiter="/"
+            )
+            # File path should not be in path (file path must be longer than path)
+            return [
+                f"{GS}://{bucket}/{file}"
+                for file in [x.name for x in blob_list]
+                if re.match(regex, file) and file not in path
+            ]
+        else:
+            return [f"{GS}://{bucket}/{path.lstrip('/')}"]
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        """
+        Uploads file to google cloud storage.
+
+        Args:
+            local_path (str): Path to the local file that needs to be uploaded/staged
+            bucket (str): gs Bucket name
+            remote_path (str): object path within the bucket to upload to, including the file name
+        """
+        gs_bucket = self.gcs_client.get_bucket(bucket)
+        blob = gs_bucket.blob(remote_path)
+        blob.upload_from_filename(local_path)
+
+
+class S3Client(AbstractStagingClient):
+    """
+    Implementation of AbstractStagingClient for Aws S3 storage
+    """
+
+    def __init__(self):
+        try:
+            import boto3
+        except ImportError:
+            raise ImportError(
+                "Install package boto3 for s3 staging support. "
+                "Run ```pip install boto3```"
+            )
+        self.s3_client = boto3.client("s3")
+
+    def download_file(self, uri: ParseResult) -> IO[bytes]:
+        """
+        Downloads a file from AWS s3 storage and returns a TemporaryFile object
+
+        Args:
+            uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("s3://bucket/file.avro")
+        Returns:
+            TemporaryFile object
+        """
+        url = uri.path.lstrip("/")
+        bucket = uri.hostname
+        file_obj = TemporaryFile()
+        self.s3_client.download_fileobj(bucket, url, file_obj)
+        return file_obj
+
+    def list_files(self, bucket: str, path: str) -> List[str]:
+        """
+        Lists all the files under a directory in s3 if path has wildcard(*) character.
+        Otherwise returns the s3:// uri of path itself.
+
+        Args:
+            bucket (str): s3 bucket name.
+            path (str): Object location in s3.
+
+        Returns:
+            List[str]: A list containing the full path to the file(s) in the
+                    remote staging location.
+        """
+
+        if "*" in path:
+            regex = re.compile(path.replace("*", ".*?").strip("/"))
+            blob_list = self.s3_client.list_objects(
+                Bucket=bucket, Prefix=path.strip("/").split("*")[0], Delimiter="/"
+            )
+            # The response has no "Contents" key when no object matches the prefix
+            # File path should not be in path (file path must be longer than path)
+            return [
+                f"{S3}://{bucket}/{file}"
+                for file in [x["Key"] for x in blob_list.get("Contents", [])]
+                if re.match(regex, file) and file not in path
+            ]
+        else:
+            return [f"{S3}://{bucket}/{path.lstrip('/')}"]
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        """
+        Uploads file to s3.
+
+        Args:
+            local_path (str): Path to the local file that needs to be uploaded/staged
+            bucket (str): s3 Bucket name
+            remote_path (str): object key within the bucket to upload to, including the file name
+        """
+        with open(local_path, "rb") as file:
+            self.s3_client.upload_fileobj(file, bucket, remote_path)
+
+
+class LocalFSClient(AbstractStagingClient):
+    """
+    Implementation of AbstractStagingClient for local file
+    Note: This is used for E2E tests.
+    """
+
+    def __init__(self):
+        pass
+
+    def download_file(self, uri: ParseResult) -> IO[bytes]:
+        """
+        Reads a local file from the disk
+
+        Args:
+            uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("file://folder/file.avro")
+        Returns:
+            File object opened for reading in binary mode
+        """
+        url = uri.path
+        file_obj = open(url, "rb")
+        return file_obj
+
+    def list_files(self, bucket: str, path: str) -> List[str]:
+        raise NotImplementedError("list files not implemented for Local file")
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        pass  # For test cases
+
+
+storage_clients = {GS: GCSClient, S3: S3Client, LOCAL_FILE: LocalFSClient}
+
+
+def get_staging_client(scheme):
+    """
+    Initialization of a specific client object(GCSClient, S3Client etc.)
+
+    Args:
+        scheme (str): uri scheme: s3, gs or file
+
+    Returns:
+        An object of concrete implementation of AbstractStagingClient
+
+    Raises:
+        Exception: If scheme is not one of gs, s3 or file
+    """
+    try:
+        client_class = storage_clients[scheme]
+    except KeyError:
+        raise Exception(
+            f"Could not identify file scheme {scheme}. Only gs://, file:// and s3:// are supported"
+        )
+    # Instantiate outside the try so errors raised by the client constructor propagate unchanged
+    return client_class()
diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt index 03abbb57c37..726abb794fe 100644 --- a/sdk/python/requirements-ci.txt +++ b/sdk/python/requirements-ci.txt @@ -11,5 +11,6 @@ pytest-ordering==0.6.* pandas==0.* mock==2.0.0 pandavro==1.5.* +moto mypy -mypy-protobuf +mypy-protobuf \ No newline at end of file diff --git a/sdk/python/requirements-dev.txt b/sdk/python/requirements-dev.txt index f24141fb491..ca341d001b1 100644 --- a/sdk/python/requirements-dev.txt +++ b/sdk/python/requirements-dev.txt @@ -35,4 +35,6 @@ mypy mypy-protobuf pre-commit flake8 -black \ No newline at end of file +black +boto3 +moto \ No newline at end of file diff --git a/sdk/python/tests/loaders/__init__.py b/sdk/python/tests/loaders/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/tests/loaders/test_file.py b/sdk/python/tests/loaders/test_file.py new file mode 100644 index 00000000000..9d02447ab35 --- /dev/null +++ b/sdk/python/tests/loaders/test_file.py @@ -0,0 +1,93 @@ +# +# Copyright 2020 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +import tempfile +from unittest.mock import patch +from urllib.parse import urlparse + +import boto3 +import fastavro +import pandas as pd +import pandavro +from moto import mock_s3 +from pandas.testing import assert_frame_equal +from pytest import fixture + +from feast.loaders.file import export_source_to_staging_location + +BUCKET = "test_bucket" +FOLDER_NAME = "test_folder" +FILE_NAME = "test.avro" + +LOCAL_FILE = "file://tmp/tmp" +S3_LOCATION = f"s3://{BUCKET}/{FOLDER_NAME}" + +TEST_DATA_FRAME = pd.DataFrame( + { + "driver": [1001, 1002, 1003], + "transaction": [1001, 1002, 1003], + "driver_id": [1001, 1002, 1003], + } +) + + +@fixture +def avro_data_path(): + final_results = tempfile.mktemp() + pandavro.to_avro(file_path_or_buffer=final_results, df=TEST_DATA_FRAME) + return final_results + + +@patch("feast.loaders.file._get_file_name", return_value=FILE_NAME) +def test_export_source_to_staging_location_local_file_should_pass(get_file_name): + source = export_source_to_staging_location(TEST_DATA_FRAME, LOCAL_FILE) + assert source == [f"{LOCAL_FILE}/{FILE_NAME}"] + assert get_file_name.call_count == 1 + + +@mock_s3 +@patch("feast.loaders.file._get_file_name", return_value=FILE_NAME) +def test_export_source_to_staging_location_dataframe_to_s3_should_pass(get_file_name): + s3_client = boto3.client("s3") + s3_client.create_bucket(Bucket=BUCKET) + source = export_source_to_staging_location(TEST_DATA_FRAME, S3_LOCATION) + file_obj = tempfile.TemporaryFile() + uri = urlparse(source[0]) + s3_client.download_fileobj(uri.hostname, uri.path[1:], file_obj) + file_obj.seek(0) + avro_reader = fastavro.reader(file_obj) + retrived_df = pd.DataFrame.from_records([r for r in avro_reader]) + assert_frame_equal(retrived_df, TEST_DATA_FRAME, check_like=True) + assert get_file_name.call_count == 1 + + +def test_export_source_to_staging_location_s3_file_as_source_should_pass(): + source = export_source_to_staging_location(S3_LOCATION, None) + assert source == [S3_LOCATION] + + 
+@mock_s3 +def test_export_source_to_staging_location_s3_wildcard_as_source_should_pass( + avro_data_path, +): + s3_client = boto3.client("s3") + s3_client.create_bucket(Bucket=BUCKET) + with open(avro_data_path, "rb") as data: + s3_client.upload_fileobj(data, BUCKET, f"{FOLDER_NAME}/file1.avro") + with open(avro_data_path, "rb") as data: + s3_client.upload_fileobj(data, BUCKET, f"{FOLDER_NAME}/file2.avro") + sources = export_source_to_staging_location(f"{S3_LOCATION}/*", None) + assert sources == [f"{S3_LOCATION}/file1.avro", f"{S3_LOCATION}/file2.avro"] diff --git a/sdk/python/tests/test_job.py b/sdk/python/tests/test_job.py new file mode 100644 index 00000000000..a9a25fcee3f --- /dev/null +++ b/sdk/python/tests/test_job.py @@ -0,0 +1,144 @@ +# +# Copyright 2020 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import tempfile + +import boto3 +import grpc +import pandas as pd +import pandavro +import pytest +from moto import mock_s3 +from pandas.testing import assert_frame_equal +from pytest import fixture, raises + +import feast.serving.ServingService_pb2_grpc as Serving +from feast.job import JobProto, RetrievalJob +from feast.serving.ServingService_pb2 import DataFormat, GetJobResponse +from feast.serving.ServingService_pb2 import Job as BatchRetrievalJob +from feast.serving.ServingService_pb2 import JobStatus, JobType + +BUCKET = "test_bucket" + +TEST_DATA_FRAME = pd.DataFrame( + { + "driver": [1001, 1002, 1003], + "transaction": [1001, 1002, 1003], + "driver_id": [1001, 1002, 1003], + } +) + + +class TestRetrievalJob: + @fixture + def retrieve_job(self): + + serving_service_stub = Serving.ServingServiceStub(grpc.insecure_channel("")) + job_proto = JobProto( + id="123", + type=JobType.JOB_TYPE_DOWNLOAD, + status=JobStatus.JOB_STATUS_RUNNING, + ) + return RetrievalJob(job_proto, serving_service_stub) + + @fixture + def avro_data_path(self): + final_results = tempfile.mktemp() + pandavro.to_avro(file_path_or_buffer=final_results, df=TEST_DATA_FRAME) + return final_results + + def test_to_dataframe_local_file_staging_should_pass( + self, retrieve_job, avro_data_path, mocker + ): + mocker.patch.object( + retrieve_job.serving_stub, + "GetJob", + return_value=GetJobResponse( + job=BatchRetrievalJob( + id="123", + type=JobType.JOB_TYPE_DOWNLOAD, + status=JobStatus.JOB_STATUS_DONE, + file_uris=[f"file://{avro_data_path}"], + data_format=DataFormat.DATA_FORMAT_AVRO, + ) + ), + ) + retrived_df = retrieve_job.to_dataframe() + assert_frame_equal(TEST_DATA_FRAME, retrived_df, check_like=True) + + @mock_s3 + def test_to_dataframe_s3_file_staging_should_pass( + self, retrieve_job, avro_data_path, mocker + ): + s3_client = boto3.client("s3") + target = "test_proj/test_features.avro" + s3_client.create_bucket(Bucket=BUCKET) + with open(avro_data_path, "rb") as data: + 
s3_client.upload_fileobj(data, BUCKET, target) + + mocker.patch.object( + retrieve_job.serving_stub, + "GetJob", + return_value=GetJobResponse( + job=BatchRetrievalJob( + id="123", + type=JobType.JOB_TYPE_DOWNLOAD, + status=JobStatus.JOB_STATUS_DONE, + file_uris=[f"s3://{BUCKET}/{target}"], + data_format=DataFormat.DATA_FORMAT_AVRO, + ) + ), + ) + retrived_df = retrieve_job.to_dataframe() + assert_frame_equal(TEST_DATA_FRAME, retrived_df, check_like=True) + + @pytest.mark.parametrize( + "job_proto,exception", + [ + ( + GetJobResponse( + job=BatchRetrievalJob( + id="123", + type=JobType.JOB_TYPE_DOWNLOAD, + status=JobStatus.JOB_STATUS_DONE, + data_format=DataFormat.DATA_FORMAT_AVRO, + error="Testing job failure", + ) + ), + Exception, + ), + ( + GetJobResponse( + job=BatchRetrievalJob( + id="123", + type=JobType.JOB_TYPE_DOWNLOAD, + status=JobStatus.JOB_STATUS_DONE, + data_format=DataFormat.DATA_FORMAT_INVALID, + ) + ), + Exception, + ), + ], + ids=["when_retrieve_job_fails", "when_data_format_is_not_avro"], + ) + def test_to_dataframe_s3_file_staging_should_raise( + self, retrieve_job, mocker, job_proto, exception + ): + mocker.patch.object( + retrieve_job.serving_stub, "GetJob", return_value=job_proto, + ) + with raises(exception): + retrieve_job.to_dataframe()