diff --git a/feathr_project/docs/conf.py b/feathr_project/docs/conf.py index 8ca60b7d6..eaa63e242 100644 --- a/feathr_project/docs/conf.py +++ b/feathr_project/docs/conf.py @@ -14,19 +14,20 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- -project = 'Feathr Feature Store' -copyright = '2023, Feathr Community' -author = 'Feathr Community' +project = "Feathr Feature Store" +copyright = "2023, Feathr Community" +author = "Feathr Community" # The short X.Y version -version = '1.0' +version = "1.0" # The full version, including alpha/beta/rc tags -release = '1.0.0' +release = "1.0.0" # -- General configuration --------------------------------------------------- @@ -39,30 +40,30 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx_rtd_theme', - 'sphinx.ext.napoleon', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinx_rtd_theme", + "sphinx.ext.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -74,7 +75,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'setup.py'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "setup.py"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -85,7 +86,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -96,7 +97,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -112,7 +113,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'feathrdoc' +htmlhelp_basename = "feathrdoc" # -- Options for LaTeX output ------------------------------------------------ @@ -121,15 +122,12 @@ # The paper size ('letterpaper' or 'a4paper'). 
# # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -139,8 +137,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'feathr.tex', 'feathr Documentation', - 'Feathr Community', 'manual'), + (master_doc, "feathr.tex", "feathr Documentation", "Feathr Community", "manual"), ] @@ -148,10 +145,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'feathr', 'feathr Documentation', - [author], 1) -] +man_pages = [(master_doc, "feathr", "feathr Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -160,9 +154,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'feathr', 'feathr Documentation', - author, 'feathr', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "feathr", + "feathr Documentation", + author, + "feathr", + "One line description of project.", + "Miscellaneous", + ), ] @@ -181,7 +181,7 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- @@ -189,7 +189,7 @@ # -- Options for intersphinx extension --------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} # -- Options for todo extension ---------------------------------------------- diff --git a/feathr_project/feathr/__init__.py b/feathr_project/feathr/__init__.py index 5c279b7d5..3301eef43 100644 --- a/feathr_project/feathr/__init__.py +++ b/feathr_project/feathr/__init__.py @@ -31,49 +31,49 @@ __all__ = [ - 'FeatureJoinJobParams', - 'FeatureGenerationJobParams', - 'FeathrClient', - 'DerivedFeature', - 'FeatureAnchor', - 'Feature', - 'ValueType', - 'WindowAggTransformation', - 'TypedKey', - 'DUMMYKEY', - 'BackfillTime', - 'MaterializationSettings', - 'MonitoringSettings', - 'RedisSink', - 'HdfsSink', - 'MonitoringSqlSink', - 'AerospikeSink', - 'FeatureQuery', - 'LookupFeature', - 'Aggregation', - 'get_result_df', - 'AvroJsonSchema', - 'Source', - 'InputContext', - 'HdfsSource', - 'SnowflakeSource', - 'KafkaConfig', - 'KafKaSource', - 'ValueType', - 'BooleanFeatureType', - 'Int32FeatureType', - 'Int64FeatureType', - 'FloatFeatureType', - 'DoubleFeatureType', - 'StringFeatureType', - 'BytesFeatureType', - 'FloatVectorFeatureType', - 'Int32VectorFeatureType', - 'Int64VectorFeatureType', - 'DoubleVectorFeatureType', - 'FeatureNameValidationError', - 'ObservationSettings', - 'FeaturePrinter', - 'SparkExecutionConfiguration', + "FeatureJoinJobParams", + "FeatureGenerationJobParams", + "FeathrClient", + "DerivedFeature", + "FeatureAnchor", + "Feature", + "ValueType", + "WindowAggTransformation", + "TypedKey", + "DUMMYKEY", + "BackfillTime", + "MaterializationSettings", + "MonitoringSettings", + "RedisSink", + "HdfsSink", + "MonitoringSqlSink", + "AerospikeSink", + "FeatureQuery", + "LookupFeature", + "Aggregation", + "get_result_df", + "AvroJsonSchema", + "Source", + 
"InputContext", + "HdfsSource", + "SnowflakeSource", + "KafkaConfig", + "KafKaSource", + "ValueType", + "BooleanFeatureType", + "Int32FeatureType", + "Int64FeatureType", + "FloatFeatureType", + "DoubleFeatureType", + "StringFeatureType", + "BytesFeatureType", + "FloatVectorFeatureType", + "Int32VectorFeatureType", + "Int64VectorFeatureType", + "DoubleVectorFeatureType", + "FeatureNameValidationError", + "ObservationSettings", + "FeaturePrinter", + "SparkExecutionConfiguration", __version__, - ] +] diff --git a/feathr_project/feathr/chat/__init__.py b/feathr_project/feathr/chat/__init__.py index 60bc05cfd..8d9b4623b 100644 --- a/feathr_project/feathr/chat/__init__.py +++ b/feathr_project/feathr/chat/__init__.py @@ -1,4 +1,3 @@ - from IPython.core.magic import Magics, line_magic, magics_class from feathr.chat.feathr_chat import FeathrChat @@ -7,32 +6,35 @@ chatBot = FeathrChat() dsl_learned = False + + @magics_class class FeathrMagic(Magics): - @line_magic def feathr(self, question): - scope = get_ipython().get_local_scope(stack_depth=2) - client = scope.get('client') + scope = get_ipython().get_local_scope(stack_depth=2) + client = scope.get("client") global dsl_learned if client: prompt_generator = PromptGenerator(client) if not dsl_learned: ask_to_teach = "I am going to teach you a DSL, could you learn it?" chatBot.ask_llm_in_notebook(ask_to_teach) - dsl = '\n Feathr DSL: \n' + prompt_generator.get_feathr_dsl_prompts() + dsl = "\n Feathr DSL: \n" + prompt_generator.get_feathr_dsl_prompts() chatBot.ask_llm_in_notebook(dsl) - + ask_to_teach = "Do you want to see the full source code for the APIs?" chatBot.ask_llm_in_notebook(ask_to_teach) - dsl = '\n Feathr DSL: \n' + prompt_generator.get_full_dsl_source_code() + dsl = "\n Feathr DSL: \n" + prompt_generator.get_full_dsl_source_code() chatBot.ask_llm_in_notebook(dsl) - + dsl_learned = True question_with_prompt = prompt_generator.process_question(question) chatBot.ask_llm_in_notebook(question_with_prompt) else: - print("'client' is not defined in the notebook. Please create a FeathrClient instance named as 'client' before using Feathr chat. e.g. client = FeathrClient('/path/to/your/workspace') ") + print( + "'client' is not defined in the notebook. Please create a FeathrClient instance named as 'client' before using Feathr chat. e.g. client = FeathrClient('/path/to/your/workspace') " + ) def load_ipython_extension(ipython): diff --git a/feathr_project/feathr/chat/feathr_chat.py b/feathr_project/feathr/chat/feathr_chat.py index 46a73b5ec..6dc91b190 100644 --- a/feathr_project/feathr/chat/feathr_chat.py +++ b/feathr_project/feathr/chat/feathr_chat.py @@ -3,6 +3,7 @@ from revChatGPT.V3 import Chatbot from feathr.chat.notebook_utils import * + class FeathrChat(object): def __init__(self): key = self.get_api_key() @@ -19,7 +20,9 @@ def ask_llm_in_notebook(self, question_with_prompt: str): if not self.chat_bot: key = self.get_api_key() if not key: - raise RuntimeError("Please set environment variable CHATGPT_API_KEY before using Feathr Chat. You can get your API key for ChatGPT at https://platform.openai.com/account/api-keys. For example, run: os.environ['CHATGPT_API_KEY'] = 'your api key' and retry.") + raise RuntimeError( + "Please set environment variable CHATGPT_API_KEY before using Feathr Chat. You can get your API key for ChatGPT at https://platform.openai.com/account/api-keys. For example, run: os.environ['CHATGPT_API_KEY'] = 'your api key' and retry." 
+ ) self.chat_bot = Chatbot(key) content = self.chat_bot.ask(question_with_prompt) @@ -31,6 +34,6 @@ def ask_llm_in_notebook(self, question_with_prompt: str): create_new_cell(code) else: print(content) - + def is_a_code_gen_question(self, question_with_prompt): return "explain" not in question_with_prompt.lower() and "what" not in question_with_prompt.lower() diff --git a/feathr_project/feathr/chat/notebook_utils.py b/feathr_project/feathr/chat/notebook_utils.py index 20f0a745d..69642c7a7 100644 --- a/feathr_project/feathr/chat/notebook_utils.py +++ b/feathr_project/feathr/chat/notebook_utils.py @@ -1,19 +1,20 @@ from IPython.core.getipython import get_ipython import re + def create_new_cell(contents): shell = get_ipython() payload = dict( - source='set_next_input', + source="set_next_input", text=contents, replace=False, ) shell.payload_manager.write_payload(payload, single=False) - -def extract_code_from_string(input_str, lang='python'): + +def extract_code_from_string(input_str, lang="python"): """Extract the code block for a given language""" - pattern = fr'```({lang}|\s*)\n(.*?)\n```' + pattern = rf"```({lang}|\s*)\n(.*?)\n```" # Use the re.findall() function to extract the code block match = re.search(pattern, input_str, re.DOTALL) # Check if any code block was found diff --git a/feathr_project/feathr/chat/prompt_generator.py b/feathr_project/feathr/chat/prompt_generator.py index 95f842a29..2377bfd9a 100644 --- a/feathr_project/feathr/chat/prompt_generator.py +++ b/feathr_project/feathr/chat/prompt_generator.py @@ -3,6 +3,7 @@ from feathr.chat.source_code_utils import read_source_code_compact + class PromptGenerator(object): def __init__(self, client: FeathrClient) -> None: self.client = client @@ -67,16 +68,18 @@ def my_source_udf_name(df: DataFrame) -> DataFrame: """ prompts += self.get_materialization_prompts() prompts += self.get_test_prompts() - prompts += self.get_join_dsl_prompts() + prompts += self.get_join_dsl_prompts() prompts += self.get_features_prompts() return prompts def get_full_dsl_source_code(self): prompts = "" - client_source_code = '\n Feathr client API: \n' + read_source_code_compact(self.module_path, self.module_path + '/../client.py') + client_source_code = "\n Feathr client API: \n" + read_source_code_compact( + self.module_path, self.module_path + "/../client.py" + ) prompts += client_source_code - - definition_source_code = read_source_code_compact(self.module_path, self.module_path + '/../definition') + + definition_source_code = read_source_code_compact(self.module_path, self.module_path + "/../definition") prompts += f"""The full classes for the Feathr DSL is here: {definition_source_code}""" return prompts @@ -117,6 +120,7 @@ def get_join_dsl_prompts(self): """ + def get_test_prompts(self): return f""" The API to test a feature anchor is: @@ -161,19 +165,22 @@ def get_features_prompts(self): print (fv) ) """ - + def get_metadata_prompts(self): # Get from registry # TODO return metadata from registry - #features = self.client.registry.list_registered_features(self.client.project_name) - #return f""" registered features are {features} """ + # features = self.client.registry.list_registered_features(self.client.project_name) + # return f""" registered features are {features} """ return "" - + def process_question(self, question: str): - prompts = " My question is: " + question + ". \n Requirement: Please use the provided Feathr DSL if you're able to, do not use API or concept from other feature engineering related solutions. 
You can assume an instance of FeathrClient is already created and named as 'client'. If your anwser has code, combine all your code in a block." + prompts = ( + " My question is: " + + question + + ". \n Requirement: Please use the provided Feathr DSL if you're able to, do not use APIs or concepts from other feature engineering related solutions. You can assume an instance of FeathrClient is already created and named as 'client'. If your answer has code, combine all your code in a block." + ) prompts += "\n Context Information:\n" - prompts += self.get_metadata_prompts() + prompts += self.get_metadata_prompts() if "train" in question.lower(): prompts = prompts + ". Do not use event timestamp related columns in model training." return prompts - diff --git a/feathr_project/feathr/chat/source_code_utils.py b/feathr_project/feathr/chat/source_code_utils.py index bcbb2d463..35e0d6a23 100644 --- a/feathr_project/feathr/chat/source_code_utils.py +++ b/feathr_project/feathr/chat/source_code_utils.py @@ -2,21 +2,22 @@ import os import re + def read_source_code_compact(module_path, source_file_or_dir_path): # Get a list of all Python files in the specified directory and its subdirectories # If it is a file, directly read the file. python_files = [] - if source_file_or_dir_path.endswith('.py'): - python_files.append(source_file_or_dir_path) + if source_file_or_dir_path.endswith(".py"): + python_files.append(source_file_or_dir_path) for root, dirs, files in os.walk(source_file_or_dir_path): for file in files: - if file.endswith('.py'): + if file.endswith(".py"): python_files.append(os.path.join(root, file)) # Concatenate the contents of all Python files into a single string - concatenated_contents = '' + concatenated_contents = "" for file_path in python_files: - with open(file_path, 'r') as f: + with open(file_path, "r") as f: # Parse the file into an abstract syntax tree (AST) tree = ast.parse(f.read()) @@ -27,22 +28,24 @@ def read_source_code_compact(module_path, source_file_or_dir_path): # Convert the AST back to Python code code = ast.unparse(tree) relative_path = os.path.relpath(file_path, module_path) - concatenated_contents += '\n' + 'In file: ' + relative_path + '\n' + remove_comments(code) + concatenated_contents += "\n" + "In file: " + relative_path + "\n" + remove_comments(code) return concatenated_contents + def remove_comments(source): # Remove block comments - source = re.sub(r'""".*?"""', '', source, flags=re.DOTALL) + source = re.sub(r'""".*?"""', "", source, flags=re.DOTALL) # Remove line comments - source = re.sub(r'#.*?\n', '', source) + source = re.sub(r"#.*?\n", "", source) # Remove empty lines - source = re.sub(r'^\s*\n', '', source, flags=re.MULTILINE) + source = re.sub(r"^\s*\n", "", source, flags=re.MULTILINE) return source + # Define a function to remove the function implementations def remove_function_implementations(node): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): # Replace the function body with a single 'pass' statement node.body = [ast.Pass()] - ast.fix_missing_locations(node) \ No newline at end of file + ast.fix_missing_locations(node) diff --git a/feathr_project/feathr/client.py b/feathr_project/feathr/client.py index 1be4238ea..5424bc54b 100644 --- a/feathr_project/feathr/client.py +++ b/feathr_project/feathr/client.py @@ -42,7 +42,6 @@ import importlib.util - class FeathrClient(object): """Feathr client. 
@@ -59,9 +58,10 @@ class FeathrClient(object): RuntimeError: Fail to create the client since necessary environment variables are not set for Redis client creation. """ + def __init__( self, - config_path:str = "./feathr_config.yaml", + config_path: str = "./feathr_config.yaml", local_workspace_dir: str = None, credential: Any = None, project_registry_tag: Dict[str, str] = None, @@ -80,8 +80,8 @@ def __init__( """ self.logger = logging.getLogger(__name__) # Redis key separator - self._KEY_SEPARATOR = ':' - self._COMPOSITE_KEY_SEPARATOR = '#' + self._KEY_SEPARATOR = ":" + self._COMPOSITE_KEY_SEPARATOR = "#" self.env_config = EnvConfigReader(config_path=config_path) if local_workspace_dir: self.local_workspace_dir = local_workspace_dir @@ -91,107 +91,86 @@ def __init__( self.local_workspace_dir = tem_dir_obj.name if not os.path.exists(config_path): - self.logger.warning('No Configuration file exist at the user provided config_path or the default config_path (./feathr_config.yaml), you need to set the environment variables explicitly. For all the environment variables that you need to set, please refer to https://github.com/feathr-ai/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml') + self.logger.warning( + "No configuration file exists at the user-provided config_path or the default config_path (./feathr_config.yaml), so you need to set the environment variables explicitly. For all the environment variables that you need to set, please refer to https://github.com/feathr-ai/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml" + ) # Load all configs from yaml at initialization # DO NOT load any configs from yaml during runtime. - self.project_name = self.env_config.get( - 'project_config__project_name') + self.project_name = self.env_config.get("project_config__project_name") # Redis configs. This is optional unless users have configured Redis host. - if self.env_config.get('online_store__redis__host'): + if self.env_config.get("online_store__redis__host"): # For illustrative purposes. spec = importlib.util.find_spec("redis") if spec is None: - self.logger.warning('You have configured Redis host, but there is no local Redis client package. Install the package using "pip install redis". ') + self.logger.warning( + 'You have configured Redis host, but there is no local Redis client package. Install the package using "pip install redis". 
' + ) + self.redis_host = self.env_config.get("online_store__redis__host") + self.redis_port = self.env_config.get("online_store__redis__port") + self.redis_ssl_enabled = self.env_config.get("online_store__redis__ssl_enabled") self._construct_redis_client() # Offline store enabled configs; false by default - self.s3_enabled = self.env_config.get( - 'offline_store__s3__s3_enabled') - self.adls_enabled = self.env_config.get( - 'offline_store__adls__adls_enabled') - self.wasb_enabled = self.env_config.get( - 'offline_store__wasb__wasb_enabled') - self.jdbc_enabled = self.env_config.get( - 'offline_store__jdbc__jdbc_enabled') - self.snowflake_enabled = self.env_config.get( - 'offline_store__snowflake__snowflake_enabled') - if not (self.s3_enabled or self.adls_enabled or self.wasb_enabled or self.jdbc_enabled or self.snowflake_enabled): + self.s3_enabled = self.env_config.get("offline_store__s3__s3_enabled") + self.adls_enabled = self.env_config.get("offline_store__adls__adls_enabled") + self.wasb_enabled = self.env_config.get("offline_store__wasb__wasb_enabled") + self.jdbc_enabled = self.env_config.get("offline_store__jdbc__jdbc_enabled") + self.snowflake_enabled = self.env_config.get("offline_store__snowflake__snowflake_enabled") + if not ( + self.s3_enabled or self.adls_enabled or self.wasb_enabled or self.jdbc_enabled or self.snowflake_enabled + ): self.logger.warning("No offline storage enabled.") # S3 configs if self.s3_enabled: - self.s3_endpoint = self.env_config.get( - 'offline_store__s3__s3_endpoint') + self.s3_endpoint = self.env_config.get("offline_store__s3__s3_endpoint") # spark configs - self.output_num_parts = self.env_config.get( - 'spark_config__spark_result_output_parts') - self.spark_runtime = self.env_config.get( - 'spark_config__spark_cluster') + self.output_num_parts = self.env_config.get("spark_config__spark_result_output_parts") + self.spark_runtime = self.env_config.get("spark_config__spark_cluster") self.credential = credential - if self.spark_runtime not in {'azure_synapse', 'databricks', 'local'}: + if self.spark_runtime not in {"azure_synapse", "databricks", "local"}: raise RuntimeError( - f'{self.spark_runtime} is not supported. Only \'azure_synapse\', \'databricks\' and \'local\' are currently supported.') - elif self.spark_runtime == 'azure_synapse': + f"{self.spark_runtime} is not supported. Only 'azure_synapse', 'databricks' and 'local' are currently supported." + ) + elif self.spark_runtime == "azure_synapse": # Feathr is a spark-based application so the feathr jar compiled from source code will be used in the # Spark job submission. The feathr jar hosted in cloud saves the time users needed to upload the jar from # their local env. 
- self._FEATHR_JOB_JAR_PATH = \ - self.env_config.get( - 'spark_config__azure_synapse__feathr_runtime_location') + self._FEATHR_JOB_JAR_PATH = self.env_config.get("spark_config__azure_synapse__feathr_runtime_location") if self.credential is None: self.credential = DefaultAzureCredential(exclude_interactive_browser_credential=False) self.feathr_spark_launcher = _FeathrSynapseJobLauncher( - synapse_dev_url=self.env_config.get( - 'spark_config__azure_synapse__dev_url'), - pool_name=self.env_config.get( - 'spark_config__azure_synapse__pool_name'), - datalake_dir=self.env_config.get( - 'spark_config__azure_synapse__workspace_dir'), - executor_size=self.env_config.get( - 'spark_config__azure_synapse__executor_size'), - executors=self.env_config.get( - 'spark_config__azure_synapse__executor_num'), - credential=self.credential + synapse_dev_url=self.env_config.get("spark_config__azure_synapse__dev_url"), + pool_name=self.env_config.get("spark_config__azure_synapse__pool_name"), + datalake_dir=self.env_config.get("spark_config__azure_synapse__workspace_dir"), + executor_size=self.env_config.get("spark_config__azure_synapse__executor_size"), + executors=self.env_config.get("spark_config__azure_synapse__executor_num"), + credential=self.credential, ) - elif self.spark_runtime == 'databricks': + elif self.spark_runtime == "databricks": # Feathr is a spark-based application so the feathr jar compiled from source code will be used in the # Spark job submission. The feathr jar hosted in cloud saves the time users needed to upload the jar from # their local env. - self._FEATHR_JOB_JAR_PATH = \ - self.env_config.get( - 'spark_config__databricks__feathr_runtime_location') + self._FEATHR_JOB_JAR_PATH = self.env_config.get("spark_config__databricks__feathr_runtime_location") self.feathr_spark_launcher = _FeathrDatabricksJobLauncher( - workspace_instance_url=self.env_config.get( - 'spark_config__databricks__workspace_instance_url'), - token_value=self.env_config.get_from_env_or_akv( - 'DATABRICKS_WORKSPACE_TOKEN_VALUE'), - config_template=self.env_config.get( - 'spark_config__databricks__config_template'), - databricks_work_dir=self.env_config.get( - 'spark_config__databricks__work_dir') + workspace_instance_url=self.env_config.get("spark_config__databricks__workspace_instance_url"), + token_value=self.env_config.get_from_env_or_akv("DATABRICKS_WORKSPACE_TOKEN_VALUE"), + config_template=self.env_config.get("spark_config__databricks__config_template"), + databricks_work_dir=self.env_config.get("spark_config__databricks__work_dir"), ) - elif self.spark_runtime == 'local': - self._FEATHR_JOB_JAR_PATH = \ - self.env_config.get( - 'spark_config__local__feathr_runtime_location') + elif self.spark_runtime == "local": + self._FEATHR_JOB_JAR_PATH = self.env_config.get("spark_config__local__feathr_runtime_location") self.feathr_spark_launcher = _FeathrLocalSparkJobLauncher( - workspace_path = self.env_config.get('spark_config__local__workspace'), - master = self.env_config.get('spark_config__local__master') - ) - + workspace_path=self.env_config.get("spark_config__local__workspace"), + master=self.env_config.get("spark_config__local__master"), + ) self.secret_names = [] @@ -200,18 +179,31 @@ def __init__( # initialize registry self.registry = None - registry_endpoint = self.env_config.get('feature_registry__api_endpoint') - azure_purview_name = self.env_config.get('feature_registry__purview__purview_name') + registry_endpoint = self.env_config.get("feature_registry__api_endpoint") + azure_purview_name = 
self.env_config.get("feature_registry__purview__purview_name") if registry_endpoint: - self.registry = _FeatureRegistry(self.project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential) + self.registry = _FeatureRegistry( + self.project_name, endpoint=registry_endpoint, project_tags=project_registry_tag, credential=credential + ) elif azure_purview_name: - registry_delimiter = self.env_config.get('feature_registry__purview__delimiter') + registry_delimiter = self.env_config.get("feature_registry__purview__delimiter") # initialize the registry no matter whether we set purview name or not, given some of the methods are used there. - self.registry = _PurviewRegistry(self.project_name, azure_purview_name, registry_delimiter, project_registry_tag, config_path = config_path, credential=credential) - logger.warning("FEATURE_REGISTRY__PURVIEW__PURVIEW_NAME will be deprecated soon. Please use FEATURE_REGISTRY__API_ENDPOINT instead.") + self.registry = _PurviewRegistry( + self.project_name, + azure_purview_name, + registry_delimiter, + project_registry_tag, + config_path=config_path, + credential=credential, + ) + logger.warning( + "FEATURE_REGISTRY__PURVIEW__PURVIEW_NAME will be deprecated soon. Please use FEATURE_REGISTRY__API_ENDPOINT instead." + ) else: # no registry configured - logger.info("Feathr registry is not configured. Consider setting the Feathr registry component for richer feature store experience.") + logger.info( + "Feathr registry is not configured. Consider setting the Feathr registry component for richer feature store experience." + ) logger.info(f"Feathr client {get_version()} initialized successfully.") @@ -221,10 +213,12 @@ def _check_required_environment_variables_exist(self): Some required information has to be set via environment variables so the client can work. """ props = self.secret_names - for required_field in (self.required_fields + props): + for required_field in self.required_fields + props: if required_field not in os.environ: - raise RuntimeError(f'{required_field} is not set in environment variable. All required environment ' - f'variables are: {self.required_fields}.') + raise RuntimeError( + f"{required_field} is not set in environment variable. All required environment " + f"variables are: {self.required_fields}." 
+ ) def register_features(self, from_context: bool = True): """Registers features based on the current workspace @@ -236,15 +230,27 @@ def register_features(self, from_context: bool = True): if from_context: # make sure those items are in `self` - if 'anchor_list' in dir(self) and 'derived_feature_list' in dir(self): - self.config_helper.save_to_feature_config_from_context(self.anchor_list, self.derived_feature_list, self.local_workspace_dir) - self.registry.register_features(self.local_workspace_dir, from_context=from_context, anchor_list=self.anchor_list, derived_feature_list=self.derived_feature_list) + if "anchor_list" in dir(self) and "derived_feature_list" in dir(self): + self.config_helper.save_to_feature_config_from_context( + self.anchor_list, self.derived_feature_list, self.local_workspace_dir + ) + self.registry.register_features( + self.local_workspace_dir, + from_context=from_context, + anchor_list=self.anchor_list, + derived_feature_list=self.derived_feature_list, + ) else: raise RuntimeError("Please call FeathrClient.build_features() first in order to register features") else: self.registry.register_features(self.local_workspace_dir, from_context=from_context) - def build_features(self, anchor_list: List[FeatureAnchor] = [], derived_feature_list: List[DerivedFeature] = [], verbose: bool = False): + def build_features( + self, + anchor_list: List[FeatureAnchor] = [], + derived_feature_list: List[DerivedFeature] = [], + verbose: bool = False, + ): """Build features based on the current workspace. all actions that triggers a spark job will be based on the result of this action. """ @@ -254,18 +260,24 @@ def build_features(self, anchor_list: List[FeatureAnchor] = [], derived_feature_ source_names = {} for anchor in anchor_list: if anchor.name in anchor_names: - raise RuntimeError(f"Anchor name should be unique but there are duplicate anchor names in your anchor " - f"definitions. Anchor name of {anchor} is already defined in {anchor_names[anchor.name]}") + raise RuntimeError( + f"Anchor name should be unique but there are duplicate anchor names in your anchor " + f"definitions. Anchor name of {anchor} is already defined in {anchor_names[anchor.name]}" + ) else: anchor_names[anchor.name] = anchor if anchor.source.name in source_names and (anchor.source is not source_names[anchor.source.name]): - raise RuntimeError(f"Source name should be unique but there are duplicate source names in your source " - f"definitions. Source name of {anchor.source} is already defined in {source_names[anchor.source.name]}") + raise RuntimeError( + f"Source name should be unique but there are duplicate source names in your source " + f"definitions. Source name of {anchor.source} is already defined in {source_names[anchor.source.name]}" + ) else: source_names[anchor.source.name] = anchor.source _PreprocessingPyudfManager.build_anchor_preprocessing_metadata(anchor_list, self.local_workspace_dir) - self.config_helper.save_to_feature_config_from_context(anchor_list, derived_feature_list, self.local_workspace_dir) + self.config_helper.save_to_feature_config_from_context( + anchor_list, derived_feature_list, self.local_workspace_dir + ) self.anchor_list = anchor_list self.derived_feature_list = derived_feature_list @@ -337,7 +349,7 @@ def get_online_features(self, feature_table: str, key: Any, feature_names: List[ [None, None, None, None]. If a feature doesn't exist, then a None is returned for that feature. For example: [None, b'4.0', b'31.0', b'23.0']. 
- """ + """ redis_key = self._construct_redis_key(feature_table, key) res = self.redis_client.hmget(redis_key, *feature_names) return self._decode_proto(res) @@ -386,41 +398,59 @@ def _decode_proto(self, feature_list): feature_value = FeatureValue() decoded = base64.b64decode(raw_feature) feature_value.ParseFromString(decoded) - if feature_value.WhichOneof('FeatureValueOneOf') == 'boolean_value': + if feature_value.WhichOneof("FeatureValueOneOf") == "boolean_value": typed_result.append(feature_value.boolean_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'string_value': + elif feature_value.WhichOneof("FeatureValueOneOf") == "string_value": typed_result.append(feature_value.string_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'float_value': + elif feature_value.WhichOneof("FeatureValueOneOf") == "float_value": typed_result.append(feature_value.float_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'double_value': + elif feature_value.WhichOneof("FeatureValueOneOf") == "double_value": typed_result.append(feature_value.double_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'int_value': + elif feature_value.WhichOneof("FeatureValueOneOf") == "int_value": typed_result.append(feature_value.int_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'long_value': + elif feature_value.WhichOneof("FeatureValueOneOf") == "long_value": typed_result.append(feature_value.long_value) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'int_array': + elif feature_value.WhichOneof("FeatureValueOneOf") == "int_array": typed_result.append(feature_value.int_array.integers) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'string_array': + elif feature_value.WhichOneof("FeatureValueOneOf") == "string_array": typed_result.append(feature_value.string_array.strings) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'float_array': + elif feature_value.WhichOneof("FeatureValueOneOf") == "float_array": typed_result.append(feature_value.float_array.floats) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'double_array': + elif feature_value.WhichOneof("FeatureValueOneOf") == "double_array": typed_result.append(feature_value.double_array.doubles) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'boolean_array': + elif feature_value.WhichOneof("FeatureValueOneOf") == "boolean_array": typed_result.append(feature_value.boolean_array.booleans) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'sparse_string_array': - typed_result.append((feature_value.sparse_string_array.index_integers, feature_value.sparse_string_array.value_strings)) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'sparse_bool_array': - typed_result.append((feature_value.sparse_bool_array.index_integers, feature_value.sparse_bool_array.value_booleans)) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'sparse_float_array': - typed_result.append((feature_value.sparse_float_array.index_integers, feature_value.sparse_float_array.value_floats)) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'sparse_double_array': - typed_result.append((feature_value.sparse_double_array.index_integers, feature_value.sparse_double_array.value_doubles)) - elif feature_value.WhichOneof('FeatureValueOneOf') == 'sparse_long_array': - typed_result.append((feature_value.sparse_long_array.index_integers, feature_value.sparse_long_array.value_longs)) + elif feature_value.WhichOneof("FeatureValueOneOf") == "sparse_string_array": + typed_result.append( + ( + 
feature_value.sparse_string_array.index_integers, + feature_value.sparse_string_array.value_strings, + ) + ) + elif feature_value.WhichOneof("FeatureValueOneOf") == "sparse_bool_array": + typed_result.append( + (feature_value.sparse_bool_array.index_integers, feature_value.sparse_bool_array.value_booleans) + ) + elif feature_value.WhichOneof("FeatureValueOneOf") == "sparse_float_array": + typed_result.append( + (feature_value.sparse_float_array.index_integers, feature_value.sparse_float_array.value_floats) + ) + elif feature_value.WhichOneof("FeatureValueOneOf") == "sparse_double_array": + typed_result.append( + ( + feature_value.sparse_double_array.index_integers, + feature_value.sparse_double_array.value_doubles, + ) + ) + elif feature_value.WhichOneof("FeatureValueOneOf") == "sparse_long_array": + typed_result.append( + (feature_value.sparse_long_array.index_integers, feature_value.sparse_long_array.value_longs) + ) else: - self.logger.debug("Fail to load the feature type. Maybe a new type that is not supported by this " - "client version") + self.logger.debug( + "Fail to load the feature type. Maybe a new type that is not supported by this " + "client version" + ) self.logger.debug(f"The raw feature is {raw_feature}.") self.logger.debug(f"The loaded feature is {feature_value}") typed_result.append(None) @@ -441,9 +471,9 @@ def delete_feature_from_redis(self, feature_table, key, feature_name) -> None: redis_key = self._construct_redis_key(feature_table, key) if self.redis_client.hexists(redis_key, feature_name): self.redis_client.delete(redis_key, feature_name) - print(f'Deletion successful. {feature_name} is deleted from Redis.') + print(f"Deletion successful. {feature_name} is deleted from Redis.") else: - raise RuntimeError(f'Deletion failed. {feature_name} not found in Redis.') + raise RuntimeError(f"Deletion failed. {feature_name} not found in Redis.") def _clean_test_data(self, feature_table): """ @@ -454,12 +484,11 @@ def _clean_test_data(self, feature_table): Args: feature_table: str, feature_table i.e your prefix before the separator in the Redis database. """ - cursor = '0' - ns_keys = feature_table + '*' + cursor = "0" + ns_keys = feature_table + "*" while cursor != 0: # 5000 count at a scan seems reasonable faster for our testing data - cursor, keys = self.redis_client.scan( - cursor=cursor, match=ns_keys, count=5000) + cursor, keys = self.redis_client.scan(cursor=cursor, match=ns_keys, count=5000) if keys: self.redis_client.delete(*keys) @@ -468,15 +497,16 @@ def _construct_redis_key(self, feature_table, key): key = self._COMPOSITE_KEY_SEPARATOR.join(key) return feature_table + self._KEY_SEPARATOR + key - def _str_to_bool(self, s: str, variable_name = None): - """Define a function to detect convert string to bool, since Redis client sometimes require a bool and sometimes require a str - """ - if (isinstance(s, str) and s.casefold() == 'True'.casefold()) or s == True: + def _str_to_bool(self, s: str, variable_name=None): + """Define a function to detect convert string to bool, since Redis client sometimes require a bool and sometimes require a str""" + if (isinstance(s, str) and s.casefold() == "True".casefold()) or s == True: return True - elif (isinstance(s, str) and s.casefold() == 'False'.casefold()) or s == False: + elif (isinstance(s, str) and s.casefold() == "False".casefold()) or s == False: return False else: - self.logger.warning(f'{s} is not a valid Bool value. 
Maybe you want to double check if it is set correctly for {variable_name}.') + self.logger.warning( + f"{s} is not a valid Bool value. Maybe you want to double check if it is set correctly for {variable_name}." + ) return s def _construct_redis_client(self): @@ -488,21 +518,20 @@ def _construct_redis_client(self): port = self.redis_port ssl_enabled = self.redis_ssl_enabled self.redis_client = redis.Redis( - host=host, - port=port, - password=password, - ssl=self._str_to_bool(ssl_enabled, "ssl_enabled")) - self.logger.info('Redis connection is successful and completed.') - - def get_offline_features(self, - observation_settings: ObservationSettings, - feature_query: Union[FeatureQuery, List[FeatureQuery]], - output_path: Union[str, Sink], - execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, - config_file_name:str = "feature_join_conf/feature_join.conf", - dataset_column_names: Set[str] = None, - verbose: bool = False - ): + host=host, port=port, password=password, ssl=self._str_to_bool(ssl_enabled, "ssl_enabled") + ) + self.logger.info("Redis connection is successful and completed.") + + def get_offline_features( + self, + observation_settings: ObservationSettings, + feature_query: Union[FeatureQuery, List[FeatureQuery]], + output_path: Union[str, Sink], + execution_configurations: Union[SparkExecutionConfiguration, Dict[str, str]] = {}, + config_file_name: str = "feature_join_conf/feature_join.conf", + dataset_column_names: Set[str] = None, + verbose: bool = False, + ): """ Get offline features for the observation dataset Args: @@ -518,12 +547,22 @@ def get_offline_features(self, for feature_query in feature_queries: for feature_name in feature_query.feature_list: feature_names.append(feature_name) - + if len(feature_names) > 0 and observation_settings.conflicts_auto_correction is None: import feathr.utils.job_utils as job_utils - dataset_column_names_from_path = job_utils.get_cloud_file_column_names(self, observation_settings.observation_path, observation_settings.file_format,observation_settings.is_file_path) - if (dataset_column_names_from_path is None or len(dataset_column_names_from_path) == 0) and dataset_column_names is None: - self.logger.warning(f"Feathr is unable to read the Observation data from {observation_settings.observation_path} due to permission issue or invalid path. Please either grant the permission or supply the observation column names in the filed: observation_column_names.") + + dataset_column_names_from_path = job_utils.get_cloud_file_column_names( + self, + observation_settings.observation_path, + observation_settings.file_format, + observation_settings.is_file_path, + ) + if ( + dataset_column_names_from_path is None or len(dataset_column_names_from_path) == 0 + ) and dataset_column_names is None: + self.logger.warning( + f"Feathr is unable to read the Observation data from {observation_settings.observation_path} due to permission issue or invalid path. Please either grant the permission or supply the observation column names in the field: observation_column_names." 
+ ) else: if dataset_column_names_from_path is not None and len(dataset_column_names_from_path) > 0: dataset_column_names = dataset_column_names_from_path @@ -534,11 +573,12 @@ def get_offline_features(self, if len(conflict_names) != 0: conflict_names = ",".join(conflict_names) raise RuntimeError(f"Feature names exist conflicts with dataset column names: {conflict_names}") - + udf_files = _PreprocessingPyudfManager.prepare_pyspark_udf_files(feature_names, self.local_workspace_dir) # produce join config - tm = Template(""" + tm = Template( + """ {{observation_settings.to_feature_config()}} featureList: [ {% for list in feature_lists %} @@ -546,15 +586,20 @@ def get_offline_features(self, {% endfor %} ] outputPath: "{{output_path}}" - """) - config = tm.render(feature_lists=feature_queries, observation_settings=observation_settings, output_path=output_path) + """ + ) + config = tm.render( + feature_lists=feature_queries, observation_settings=observation_settings, output_path=output_path + ) config_file_path = os.path.join(self.local_workspace_dir, config_file_name) # make sure `FeathrClient.build_features()` is called before getting offline features/materialize features # otherwise users will be confused on what are the available features # in build_features it will assign anchor_list and derived_feature_list variable, hence we are checking if those two variables exist to make sure the above condition is met - if 'anchor_list' in dir(self) and 'derived_feature_list' in dir(self): - self.config_helper.save_to_feature_config_from_context(self.anchor_list, self.derived_feature_list, self.local_workspace_dir) + if "anchor_list" in dir(self) and "derived_feature_list" in dir(self): + self.config_helper.save_to_feature_config_from_context( + self.anchor_list, self.derived_feature_list, self.local_workspace_dir + ) else: raise RuntimeError("Please call FeathrClient.build_features() first in order to get offline features") @@ -563,107 +608,120 @@ def get_offline_features(self, FeaturePrinter.pretty_print_feature_query(feature_query) write_to_file(content=config, full_file_name=config_file_path) - return self._get_offline_features_with_config(config_file_path, - output_path=output_path, - execution_configurations=execution_configurations, - udf_files=udf_files) - - def _get_offline_features_with_config(self, - feature_join_conf_path='feature_join_conf/feature_join.conf', - output_path: Union[str, Sink] = "", - execution_configurations: Dict[str,str] = {}, - udf_files=[]): + return self._get_offline_features_with_config( + config_file_path, + output_path=output_path, + execution_configurations=execution_configurations, + udf_files=udf_files, + ) + + def _get_offline_features_with_config( + self, + feature_join_conf_path="feature_join_conf/feature_join.conf", + output_path: Union[str, Sink] = "", + execution_configurations: Dict[str, str] = {}, + udf_files=[], + ): """Joins the features to your offline observation dataset based on the join config. Args: feature_join_conf_path: Relative path to your feature join config file. 
""" - cloud_udf_paths = [self.feathr_spark_launcher.upload_or_get_cloud_path(udf_local_path) for udf_local_path in udf_files] + cloud_udf_paths = [ + self.feathr_spark_launcher.upload_or_get_cloud_path(udf_local_path) for udf_local_path in udf_files + ] feathr_feature = ConfigFactory.parse_file(feature_join_conf_path) - feature_join_job_params = FeatureJoinJobParams(join_config_path=os.path.abspath(feature_join_conf_path), - observation_path=feathr_feature['observationPath'], - feature_config=os.path.join(self.local_workspace_dir, 'feature_conf/'), - job_output_path=output_path) - job_tags = { OUTPUT_PATH_TAG: feature_join_job_params.job_output_path } + feature_join_job_params = FeatureJoinJobParams( + join_config_path=os.path.abspath(feature_join_conf_path), + observation_path=feathr_feature["observationPath"], + feature_config=os.path.join(self.local_workspace_dir, "feature_conf/"), + job_output_path=output_path, + ) + job_tags = {OUTPUT_PATH_TAG: feature_join_job_params.job_output_path} # set output format in job tags if it's set by user, so that it can be used to parse the job result in the helper function if execution_configurations is not None and OUTPUT_FORMAT in execution_configurations: job_tags[OUTPUT_FORMAT] = execution_configurations[OUTPUT_FORMAT] else: job_tags[OUTPUT_FORMAT] = "avro" - ''' + """ - Job tags are for job metadata and it's not passed to the actual spark job (i.e. not visible to spark job), more like a platform related thing that Feathr want to add (currently job tags only have job output URL and job output format, ). They are carried over with the job and is visible to every Feathr client. Think this more like some customized metadata for the job which would be weird to be put in the spark job itself. - Job arguments (or sometimes called job parameters)are the arguments which are command line arguments passed into the actual spark job. This is usually highly related with the spark job. In Feathr it's like the input to the scala spark CLI. They are usually not spark specific (for example if we want to specify the location of the feature files, or want to - Job configuration are like "configurations" for the spark job and are usually spark specific. For example, we want to control the no. of write parts for spark Job configurations and job arguments (or sometimes called job parameters) have quite some overlaps (i.e. you can achieve the same goal by either using the job arguments/parameters vs. job configurations). But the job tags should just be used for metadata purpose. 
- ''' + """ # submit the jars return self.feathr_spark_launcher.submit_feathr_job( - job_name=self.project_name + '_feathr_feature_join_job', + job_name=self.project_name + "_feathr_feature_join_job", main_jar_path=self._FEATHR_JOB_JAR_PATH, python_files=cloud_udf_paths, job_tags=job_tags, main_class_name=JOIN_CLASS_NAME, - arguments= [ - '--join-config', self.feathr_spark_launcher.upload_or_get_cloud_path( - feature_join_job_params.join_config_path), - '--input', feature_join_job_params.observation_path, - '--output', feature_join_job_params.job_output_path, - '--feature-config', self.feathr_spark_launcher.upload_or_get_cloud_path( - feature_join_job_params.feature_config), - '--num-parts', self.output_num_parts - ]+self._get_offline_storage_arguments(), + arguments=[ + "--join-config", + self.feathr_spark_launcher.upload_or_get_cloud_path(feature_join_job_params.join_config_path), + "--input", + feature_join_job_params.observation_path, + "--output", + feature_join_job_params.job_output_path, + "--feature-config", + self.feathr_spark_launcher.upload_or_get_cloud_path(feature_join_job_params.feature_config), + "--num-parts", + self.output_num_parts, + ] + + self._get_offline_storage_arguments(), reference_files_path=[], configuration=execution_configurations, - properties=self._collect_secrets(feature_join_job_params.secrets) + properties=self._collect_secrets(feature_join_job_params.secrets), ) def _get_offline_storage_arguments(self): arguments = [] if self.s3_enabled: - arguments.append('--s3-config') + arguments.append("--s3-config") arguments.append(self._get_s3_config_str()) if self.adls_enabled: - arguments.append('--adls-config') + arguments.append("--adls-config") arguments.append(self._get_adls_config_str()) if self.wasb_enabled: - arguments.append('--blob-config') + arguments.append("--blob-config") arguments.append(self._get_blob_config_str()) if self.jdbc_enabled: - arguments.append('--sql-config') + arguments.append("--sql-config") arguments.append(self._get_sql_config_str()) if self.snowflake_enabled: - arguments.append('--snowflake-config') + arguments.append("--snowflake-config") arguments.append(self._get_snowflake_config_str()) return arguments def get_job_result_uri(self, block=True, timeout_sec=300) -> str: - """Gets the job output URI - """ + """Gets the job output URI""" if not block: return self.feathr_spark_launcher.get_job_result_uri() # Block the API by pooling the job status and wait for complete if self.feathr_spark_launcher.wait_for_completion(timeout_sec): return self.feathr_spark_launcher.get_job_result_uri() else: - raise RuntimeError( - 'Spark job failed so output cannot be retrieved.') + raise RuntimeError("Spark job failed so output cannot be retrieved.") def get_job_tags(self) -> Dict[str, str]: - """Gets the job tags - """ + """Gets the job tags""" return self.feathr_spark_launcher.get_job_tags() def wait_job_to_finish(self, timeout_sec: int = 300): - """Waits for the job to finish in a blocking way unless it times out - """ + """Waits for the job to finish in a blocking way unless it times out""" if self.feathr_spark_launcher.wait_for_completion(timeout_sec): return else: - raise RuntimeError('Spark job failed.') + raise RuntimeError("Spark job failed.") - def monitor_features(self, settings: MonitoringSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False): + def monitor_features( + self, + settings: MonitoringSettings, + execution_configurations: Union[SparkExecutionConfiguration, Dict[str, 
str]] = {}, + verbose: bool = False, + ): """Create a offline job to generate statistics to monitor feature data Args: @@ -677,16 +735,18 @@ def monitor_features(self, settings: MonitoringSettings, execution_configuration # Return related keys(key_column list) or None if cannot find the feature def _get_feature_key(self, feature_name: str): features = [] - if 'derived_feature_list' in dir(self): + if "derived_feature_list" in dir(self): features += self.derived_feature_list - if 'anchor_list' in dir(self): + if "anchor_list" in dir(self): for anchor in self.anchor_list: features += anchor.features for feature in features: if feature.name == feature_name: keys = feature.key return set(key.key_column for key in keys) - self.logger.warning(f"Invalid feature name: {feature_name}. Please call FeathrClient.build_features() first in order to materialize the features.") + self.logger.warning( + f"Invalid feature name: {feature_name}. Please call FeathrClient.build_features() first in order to materialize the features." + ) return None # Validation on feature keys: @@ -697,10 +757,12 @@ def _valid_materialize_keys(self, features: List[str], allow_empty_key=False): for feature in features: new_keys = self._get_feature_key(feature) if new_keys is None: - self.logger.error(f"Key of feature: {feature} is empty. Please confirm the feature is defined. In addition, if this feature is not from INPUT_CONTEXT, you might want to double check on the feature definition to see whether the key is empty or not.") + self.logger.error( + f"Key of feature: {feature} is empty. Please confirm the feature is defined. In addition, if this feature is not from INPUT_CONTEXT, you might want to double check on the feature definition to see whether the key is empty or not." + ) return False # If only get one key and it's "NOT_NEEDED", it means the feature has an empty key. - if ','.join(new_keys) == "NOT_NEEDED" and not allow_empty_key: + if ",".join(new_keys) == "NOT_NEEDED" and not allow_empty_key: self.logger.error(f"Empty feature key is not allowed for features: {features}") return False if keys is None: @@ -715,7 +777,13 @@ def _valid_materialize_keys(self, features: List[str], allow_empty_key=False): return False return True - def materialize_features(self, settings: MaterializationSettings, execution_configurations: Union[SparkExecutionConfiguration ,Dict[str,str]] = {}, verbose: bool = False, allow_materialize_non_agg_feature: bool = False): + def materialize_features( + self, + settings: MaterializationSettings, + execution_configurations: Union[SparkExecutionConfiguration, Dict[str, str]] = {}, + verbose: bool = False, + allow_materialize_non_agg_feature: bool = False, + ): """Materialize feature data Args: @@ -725,14 +793,18 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf """ feature_list = settings.feature_names if len(feature_list) > 0: - if 'anchor_list' in dir(self): + if "anchor_list" in dir(self): anchors = [anchor for anchor in self.anchor_list if isinstance(anchor.source, InputContext)] anchor_feature_names = set(feature.name for anchor in anchors for feature in anchor.features) for feature in feature_list: if feature in anchor_feature_names: - raise RuntimeError(f"Materializing features that are defined on INPUT_CONTEXT is not supported. {feature} is defined on INPUT_CONTEXT so you should remove it from the feature list in MaterializationSettings.") + raise RuntimeError( + f"Materializing features that are defined on INPUT_CONTEXT is not supported. 
{feature} is defined on INPUT_CONTEXT so you should remove it from the feature list in MaterializationSettings." + ) if not self._valid_materialize_keys(feature_list): - raise RuntimeError(f"Invalid materialization features: {feature_list}, since they have different keys or they are not defined. Currently Feathr only supports materializing features of the same keys.") + raise RuntimeError( + f"Invalid materialization features: {feature_list}, since they have different keys or they are not defined. Currently Feathr only supports materializing features of the same keys." + ) if not allow_materialize_non_agg_feature: # Check if there are non-aggregation features in the list @@ -741,11 +813,15 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf for anchor in self.anchor_list: for feature in anchor.features: if feature.name == fn and not isinstance(feature.transform, WindowAggTransformation): - raise RuntimeError(f"Feature {fn} is not an aggregation feature. Currently Feathr only supports materializing aggregation features. If you want to materialize {fn}, please set allow_materialize_non_agg_feature to True.") + raise RuntimeError( + f"Feature {fn} is not an aggregation feature. Currently Feathr only supports materializing aggregation features. If you want to materialize {fn}, please set allow_materialize_non_agg_feature to True." + ) # Check over derived features for feature in self.derived_feature_list: if feature.name == fn and not isinstance(feature.transform, WindowAggTransformation): - raise RuntimeError(f"Feature {fn} is not an aggregation feature. Currently Feathr only supports materializing aggregation features. If you want to materialize {fn}, please set allow_materialize_non_agg_feature to True.") + raise RuntimeError( + f"Feature {fn} is not an aggregation feature. Currently Feathr only supports materializing aggregation features. If you want to materialize {fn}, please set allow_materialize_non_agg_feature to True." + ) # Collect secrets from sinks. Get output_path as well if the sink is offline sink (HdfsSink) for later use. 
secrets = [] @@ -769,12 +845,18 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf # make sure `FeathrClient.build_features()` is called before getting offline features/materialize features in the python SDK # otherwise users will be confused on what are the available features # in build_features it will assign anchor_list and derived_feature_list variable, hence we are checking if those two variables exist to make sure the above condition is met - if 'anchor_list' in dir(self) and 'derived_feature_list' in dir(self): - self.config_helper.save_to_feature_config_from_context(self.anchor_list, self.derived_feature_list, self.local_workspace_dir) + if "anchor_list" in dir(self) and "derived_feature_list" in dir(self): + self.config_helper.save_to_feature_config_from_context( + self.anchor_list, self.derived_feature_list, self.local_workspace_dir + ) else: - raise RuntimeError("Please call FeathrClient.build_features() first in order to materialize the features") + raise RuntimeError( + "Please call FeathrClient.build_features() first in order to materialize the features" + ) - udf_files = _PreprocessingPyudfManager.prepare_pyspark_udf_files(settings.feature_names, self.local_workspace_dir) + udf_files = _PreprocessingPyudfManager.prepare_pyspark_udf_files( + settings.feature_names, self.local_workspace_dir + ) # CLI will directly call this so the experience won't be broken result = self._materialize_features_with_config( feature_gen_conf_path=config_file_path, @@ -783,7 +865,7 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf secrets=secrets, output_path=output_path, ) - if os.path.exists(config_file_path) and self.spark_runtime != 'local': + if os.path.exists(config_file_path) and self.spark_runtime != "local": os.remove(config_file_path) results.append(result) @@ -795,8 +877,8 @@ def materialize_features(self, settings: MaterializationSettings, execution_conf def _materialize_features_with_config( self, - feature_gen_conf_path: str = 'feature_gen_conf/feature_gen.conf', - execution_configurations: Dict[str,str] = {}, + feature_gen_conf_path: str = "feature_gen_conf/feature_gen.conf", + execution_configurations: Dict[str, str] = {}, udf_files: List = [], secrets: List = [], output_path: str = None, @@ -811,12 +893,15 @@ def _materialize_features_with_config( secrets: Secrets to access sinks. output_path: The output path of the materialized features when using an offline sink. """ - cloud_udf_paths = [self.feathr_spark_launcher.upload_or_get_cloud_path(udf_local_path) for udf_local_path in udf_files] + cloud_udf_paths = [ + self.feathr_spark_launcher.upload_or_get_cloud_path(udf_local_path) for udf_local_path in udf_files + ] # Read all features conf generation_config = FeatureGenerationJobParams( generation_config_path=os.path.abspath(feature_gen_conf_path), - feature_config=os.path.join(self.local_workspace_dir, "feature_conf/")) + feature_config=os.path.join(self.local_workspace_dir, "feature_conf/"), + ) # When using offline sink (i.e. output_path is not None) job_tags = {} @@ -827,30 +912,35 @@ def _materialize_features_with_config( job_tags[OUTPUT_FORMAT] = execution_configurations[OUTPUT_FORMAT] else: job_tags[OUTPUT_FORMAT] = "avro" - ''' + """ - Job tags are for job metadata and it's not passed to the actual spark job (i.e. not visible to spark job), more like a platform related thing that Feathr want to add (currently job tags only have job output URL and job output format, ). 
They are carried over with the job and is visible to every Feathr client. Think this more like some customized metadata for the job which would be weird to be put in the spark job itself. - Job arguments (or sometimes called job parameters)are the arguments which are command line arguments passed into the actual spark job. This is usually highly related with the spark job. In Feathr it's like the input to the scala spark CLI. They are usually not spark specific (for example if we want to specify the location of the feature files, or want to - Job configuration are like "configurations" for the spark job and are usually spark specific. For example, we want to control the no. of write parts for spark Job configurations and job arguments (or sometimes called job parameters) have quite some overlaps (i.e. you can achieve the same goal by either using the job arguments/parameters vs. job configurations). But the job tags should just be used for metadata purpose. - ''' + """ optional_params = [] - if self.env_config.get_from_env_or_akv('KAFKA_SASL_JAAS_CONFIG'): - optional_params = optional_params + ['--kafka-config', self._get_kafka_config_str()] - arguments = [ - '--generation-config', self.feathr_spark_launcher.upload_or_get_cloud_path( - generation_config.generation_config_path), + if self.env_config.get_from_env_or_akv("KAFKA_SASL_JAAS_CONFIG"): + optional_params = optional_params + ["--kafka-config", self._get_kafka_config_str()] + arguments = ( + [ + "--generation-config", + self.feathr_spark_launcher.upload_or_get_cloud_path(generation_config.generation_config_path), # Local Config, comma seperated file names - '--feature-config', self.feathr_spark_launcher.upload_or_get_cloud_path( - generation_config.feature_config), - '--redis-config', self._getRedisConfigStr(), - ] + self._get_offline_storage_arguments()+optional_params + "--feature-config", + self.feathr_spark_launcher.upload_or_get_cloud_path(generation_config.feature_config), + "--redis-config", + self._getRedisConfigStr(), + ] + + self._get_offline_storage_arguments() + + optional_params + ) monitoring_config_str = self._get_monitoring_config_str() if monitoring_config_str: - arguments.append('--monitoring-config') + arguments.append("--monitoring-config") arguments.append(monitoring_config_str) return self.feathr_spark_launcher.submit_feathr_job( - job_name=self.project_name + '_feathr_feature_materialization_job', + job_name=self.project_name + "_feathr_feature_materialization_job", main_jar_path=self._FEATHR_JOB_JAR_PATH, python_files=cloud_udf_paths, job_tags=job_tags, @@ -858,16 +948,15 @@ def _materialize_features_with_config( arguments=arguments, reference_files_path=[], configuration=execution_configurations, - properties=self._collect_secrets(secrets) + properties=self._collect_secrets(secrets), ) def wait_job_to_finish(self, timeout_sec: int = 300): - """Waits for the job to finish in a blocking way unless it times out - """ + """Waits for the job to finish in a blocking way unless it times out""" if self.feathr_spark_launcher.wait_for_completion(timeout_sec): return else: - raise RuntimeError('Spark job failed.') + raise RuntimeError("Spark job failed.") def _getRedisConfigStr(self): """Construct the Redis config string. 
The host, port, credential and other parameters can be set via environment @@ -881,7 +970,9 @@ def _getRedisConfigStr(self): REDIS_HOST: "{REDIS_HOST}" REDIS_PORT: {REDIS_PORT} REDIS_SSL_ENABLED: {REDIS_SSL_ENABLED} - """.format(REDIS_PASSWORD=password, REDIS_HOST=host, REDIS_PORT=port, REDIS_SSL_ENABLED=str(ssl_enabled)) + """.format( + REDIS_PASSWORD=password, REDIS_HOST=host, REDIS_PORT=port, REDIS_SSL_ENABLED=str(ssl_enabled) + ) return self._reshape_config_str(config_str) def _get_s3_config_str(self): @@ -890,53 +981,59 @@ def _get_s3_config_str(self): endpoint = self.s3_endpoint # if s3 endpoint is set in the feathr_config, then we need other environment variables # keys can't be only accessed through environment - access_key = self.env_config.get_from_env_or_akv('S3_ACCESS_KEY') - secret_key = self.env_config.get_from_env_or_akv('S3_SECRET_KEY') + access_key = self.env_config.get_from_env_or_akv("S3_ACCESS_KEY") + secret_key = self.env_config.get_from_env_or_akv("S3_SECRET_KEY") # HOCON format will be parsed by the Feathr job config_str = """ S3_ENDPOINT: {S3_ENDPOINT} S3_ACCESS_KEY: "{S3_ACCESS_KEY}" S3_SECRET_KEY: "{S3_SECRET_KEY}" - """.format(S3_ENDPOINT=endpoint, S3_ACCESS_KEY=access_key, S3_SECRET_KEY=secret_key) + """.format( + S3_ENDPOINT=endpoint, S3_ACCESS_KEY=access_key, S3_SECRET_KEY=secret_key + ) return self._reshape_config_str(config_str) def _get_adls_config_str(self): """Construct the ADLS config string for abfs(s). The Account, access key and other parameters can be set via environment variables.""" - account = self.env_config.get_from_env_or_akv('ADLS_ACCOUNT') + account = self.env_config.get_from_env_or_akv("ADLS_ACCOUNT") # if ADLS Account is set in the feathr_config, then we need other environment variables # keys can't be only accessed through environment - key = self.env_config.get_from_env_or_akv('ADLS_KEY') + key = self.env_config.get_from_env_or_akv("ADLS_KEY") # HOCON format will be parsed by the Feathr job config_str = """ ADLS_ACCOUNT: {ADLS_ACCOUNT} ADLS_KEY: "{ADLS_KEY}" - """.format(ADLS_ACCOUNT=account, ADLS_KEY=key) + """.format( + ADLS_ACCOUNT=account, ADLS_KEY=key + ) return self._reshape_config_str(config_str) def _get_blob_config_str(self): """Construct the Blob config string for wasb(s). The Account, access key and other parameters can be set via environment variables.""" - account = self.env_config.get_from_env_or_akv('BLOB_ACCOUNT') + account = self.env_config.get_from_env_or_akv("BLOB_ACCOUNT") # if BLOB Account is set in the feathr_config, then we need other environment variables # keys can't be only accessed through environment - key = self.env_config.get_from_env_or_akv('BLOB_KEY') + key = self.env_config.get_from_env_or_akv("BLOB_KEY") # HOCON format will be parsed by the Feathr job config_str = """ BLOB_ACCOUNT: {BLOB_ACCOUNT} BLOB_KEY: "{BLOB_KEY}" - """.format(BLOB_ACCOUNT=account, BLOB_KEY=key) + """.format( + BLOB_ACCOUNT=account, BLOB_KEY=key + ) return self._reshape_config_str(config_str) def _get_sql_config_str(self): """Construct the SQL config string for jdbc. 
The dbtable (query), user, password and other parameters can be set via environment variables.""" - table = self.env_config.get_from_env_or_akv('JDBC_TABLE') - user = self.env_config.get_from_env_or_akv('JDBC_USER') - password = self.env_config.get_from_env_or_akv('JDBC_PASSWORD') - driver = self.env_config.get_from_env_or_akv('JDBC_DRIVER') - auth_flag = self.env_config.get_from_env_or_akv('JDBC_AUTH_FLAG') - token = self.env_config.get_from_env_or_akv('JDBC_TOKEN') + table = self.env_config.get_from_env_or_akv("JDBC_TABLE") + user = self.env_config.get_from_env_or_akv("JDBC_USER") + password = self.env_config.get_from_env_or_akv("JDBC_PASSWORD") + driver = self.env_config.get_from_env_or_akv("JDBC_DRIVER") + auth_flag = self.env_config.get_from_env_or_akv("JDBC_AUTH_FLAG") + token = self.env_config.get_from_env_or_akv("JDBC_TOKEN") # HOCON format will be parsed by the Feathr job config_str = """ JDBC_TABLE: {JDBC_TABLE} @@ -945,33 +1042,42 @@ def _get_sql_config_str(self): JDBC_DRIVER: {JDBC_DRIVER} JDBC_AUTH_FLAG: {JDBC_AUTH_FLAG} JDBC_TOKEN: {JDBC_TOKEN} - """.format(JDBC_TABLE=table, JDBC_USER=user, JDBC_PASSWORD=password, JDBC_DRIVER = driver, JDBC_AUTH_FLAG = auth_flag, JDBC_TOKEN = token) + """.format( + JDBC_TABLE=table, + JDBC_USER=user, + JDBC_PASSWORD=password, + JDBC_DRIVER=driver, + JDBC_AUTH_FLAG=auth_flag, + JDBC_TOKEN=token, + ) return self._reshape_config_str(config_str) def _get_monitoring_config_str(self): """Construct monitoring-related config string.""" - url = self.env_config.get('monitoring__database__sql__url') - user = self.env_config.get('monitoring__database__sql__user') - password = self.env_config.get_from_env_or_akv('MONITORING_DATABASE_SQL_PASSWORD') + url = self.env_config.get("monitoring__database__sql__url") + user = self.env_config.get("monitoring__database__sql__user") + password = self.env_config.get_from_env_or_akv("MONITORING_DATABASE_SQL_PASSWORD") if url: # HOCON format will be parsed by the Feathr job config_str = """ MONITORING_DATABASE_SQL_URL: "{url}" MONITORING_DATABASE_SQL_USER: {user} MONITORING_DATABASE_SQL_PASSWORD: {password} - """.format(url=url, user=user, password=password) + """.format( + url=url, user=user, password=password + ) return self._reshape_config_str(config_str) else: - "" + """""" def _get_snowflake_config_str(self): """Construct the Snowflake config string for jdbc. The url, user, role and other parameters can be set via yaml config. 
Password can be set via environment variables.""" - sf_url = self.env_config.get('offline_store__snowflake__url') - sf_user = self.env_config.get('offline_store__snowflake__user') - sf_role = self.env_config.get('offline_store__snowflake__role') - sf_warehouse = self.env_config.get('offline_store__snowflake__warehouse') - sf_password = self.env_config.get_from_env_or_akv('JDBC_SF_PASSWORD') + sf_url = self.env_config.get("offline_store__snowflake__url") + sf_user = self.env_config.get("offline_store__snowflake__user") + sf_role = self.env_config.get("offline_store__snowflake__role") + sf_warehouse = self.env_config.get("offline_store__snowflake__warehouse") + sf_password = self.env_config.get_from_env_or_akv("JDBC_SF_PASSWORD") # HOCON format will be parsed by the Feathr job config_str = """ JDBC_SF_URL: {JDBC_SF_URL} @@ -979,17 +1085,25 @@ def _get_snowflake_config_str(self): JDBC_SF_ROLE: {JDBC_SF_ROLE} JDBC_SF_WAREHOUSE: {JDBC_SF_WAREHOUSE} JDBC_SF_PASSWORD: {JDBC_SF_PASSWORD} - """.format(JDBC_SF_URL=sf_url, JDBC_SF_USER=sf_user, JDBC_SF_PASSWORD=sf_password, JDBC_SF_ROLE=sf_role, JDBC_SF_WAREHOUSE=sf_warehouse) + """.format( + JDBC_SF_URL=sf_url, + JDBC_SF_USER=sf_user, + JDBC_SF_PASSWORD=sf_password, + JDBC_SF_ROLE=sf_role, + JDBC_SF_WAREHOUSE=sf_warehouse, + ) return self._reshape_config_str(config_str) def _get_kafka_config_str(self): """Construct the Kafka config string. The endpoint, access key, secret key, and other parameters can be set via environment variables.""" - sasl = self.env_config.get_from_env_or_akv('KAFKA_SASL_JAAS_CONFIG') + sasl = self.env_config.get_from_env_or_akv("KAFKA_SASL_JAAS_CONFIG") # HOCON format will be parsed by the Feathr job config_str = """ KAFKA_SASL_JAAS_CONFIG: "{sasl}" - """.format(sasl=sasl) + """.format( + sasl=sasl + ) return self._reshape_config_str(config_str) def _collect_secrets(self, additional_secrets=[]): @@ -1000,7 +1114,9 @@ def _collect_secrets(self, additional_secrets=[]): prop_and_value[prop] = self.env_config.get(prop) return prop_and_value - def get_features_from_registry(self, project_name: str, return_keys: bool = False, verbose: bool = False) -> Union[Dict[str, FeatureBase], Tuple[Dict[str, FeatureBase], Dict[str, Union[TypedKey, List[TypedKey]]]]]: + def get_features_from_registry( + self, project_name: str, return_keys: bool = False, verbose: bool = False + ) -> Union[Dict[str, FeatureBase], Tuple[Dict[str, FeatureBase], Dict[str, Union[TypedKey, List[TypedKey]]]]]: """ Get feature from registry by project name. The features got from registry are automatically built. 
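        A minimal usage sketch, assuming `client` is an already-configured FeathrClient and
        "my_project" is a placeholder name of a project that exists in the feature registry:

            feature_dict = client.get_features_from_registry("my_project")
            feature_dict, key_dict = client.get_features_from_registry("my_project", return_keys=True)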
""" @@ -1020,17 +1136,16 @@ def get_features_from_registry(self, project_name: str, return_keys: bool = Fals if verbose and registry_derived_feature_list: logger.info("Get derived features from registry: ") for feature in registry_derived_feature_list: - feature_dict[feature.name] = feature - key_dict[feature.name] = feature.key - if verbose: - logger.info(json.dumps(derived_feature_to_def(feature), indent=2)) + feature_dict[feature.name] = feature + key_dict[feature.name] = feature.key + if verbose: + logger.info(json.dumps(derived_feature_to_def(feature), indent=2)) if return_keys: return feature_dict, key_dict return feature_dict - def _reshape_config_str(self, config_str:str): - if self.spark_runtime == 'local': + def _reshape_config_str(self, config_str: str): + if self.spark_runtime == "local": return "'{" + config_str + "}'" else: return config_str - \ No newline at end of file diff --git a/feathr_project/feathr/constants.py b/feathr_project/feathr/constants.py index 31e64ad25..3b51c2071 100644 --- a/feathr_project/feathr/constants.py +++ b/feathr_project/feathr/constants.py @@ -1,33 +1,33 @@ OUTPUT_PATH_TAG = "output_path" # spark config for output format setting OUTPUT_FORMAT = "spark.feathr.outputFormat" -REDIS_PASSWORD = 'REDIS_PASSWORD' +REDIS_PASSWORD = "REDIS_PASSWORD" # 1MB = 1024*1024 MB_BYTES = 1048576 -INPUT_CONTEXT="PASSTHROUGH" +INPUT_CONTEXT = "PASSTHROUGH" RELATION_CONTAINS = "CONTAINS" RELATION_BELONGSTO = "BELONGSTO" RELATION_CONSUMES = "CONSUMES" RELATION_PRODUCES = "PRODUCES" -# For use in registry. -# For type definition, think it's like a schema of a table. -# This version field is mainly to smooth possible future upgrades, +# For use in registry. +# For type definition, think it's like a schema of a table. +# This version field is mainly to smooth possible future upgrades, # for example, backward incompatible changes should be introduced in v2, to make sure that features registered with v1 schema can still be used -REGISTRY_TYPEDEF_VERSION="v1" +REGISTRY_TYPEDEF_VERSION = "v1" -TYPEDEF_SOURCE=f'feathr_source_{REGISTRY_TYPEDEF_VERSION}' +TYPEDEF_SOURCE = f"feathr_source_{REGISTRY_TYPEDEF_VERSION}" # TODO: change the name from feathr_workspace_ to feathr_project_ -TYPEDEF_FEATHR_PROJECT=f'feathr_workspace_{REGISTRY_TYPEDEF_VERSION}' -TYPEDEF_DERIVED_FEATURE=f'feathr_derived_feature_{REGISTRY_TYPEDEF_VERSION}' -TYPEDEF_ANCHOR=f'feathr_anchor_{REGISTRY_TYPEDEF_VERSION}' -TYPEDEF_ANCHOR_FEATURE=f'feathr_anchor_feature_{REGISTRY_TYPEDEF_VERSION}' +TYPEDEF_FEATHR_PROJECT = f"feathr_workspace_{REGISTRY_TYPEDEF_VERSION}" +TYPEDEF_DERIVED_FEATURE = f"feathr_derived_feature_{REGISTRY_TYPEDEF_VERSION}" +TYPEDEF_ANCHOR = f"feathr_anchor_{REGISTRY_TYPEDEF_VERSION}" +TYPEDEF_ANCHOR_FEATURE = f"feathr_anchor_feature_{REGISTRY_TYPEDEF_VERSION}" -TYPEDEF_ARRAY_ANCHOR=f"array" -TYPEDEF_ARRAY_DERIVED_FEATURE=f"array" -TYPEDEF_ARRAY_ANCHOR_FEATURE=f"array" +TYPEDEF_ARRAY_ANCHOR = f"array" +TYPEDEF_ARRAY_DERIVED_FEATURE = f"array" +TYPEDEF_ARRAY_ANCHOR_FEATURE = f"array" -JOIN_CLASS_NAME="com.linkedin.feathr.offline.job.FeatureJoinJob" -GEN_CLASS_NAME="com.linkedin.feathr.offline.job.FeatureGenJob" \ No newline at end of file +JOIN_CLASS_NAME = "com.linkedin.feathr.offline.job.FeatureJoinJob" +GEN_CLASS_NAME = "com.linkedin.feathr.offline.job.FeatureGenJob" diff --git a/feathr_project/feathr/datasets/constants.py b/feathr_project/feathr/datasets/constants.py index 5c138e0ec..e8372ffc6 100644 --- a/feathr_project/feathr/datasets/constants.py +++ 
b/feathr_project/feathr/datasets/constants.py @@ -20,21 +20,13 @@ # Product recommendation sample datasets. # Ref: -PRODUCT_RECOMMENDATION_USER_OBSERVATION_URL = ( - "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_observation_mock_data.csv" -) +PRODUCT_RECOMMENDATION_USER_OBSERVATION_URL = "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_observation_mock_data.csv" -PRODUCT_RECOMMENDATION_USER_PROFILE_URL = ( - "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_profile_mock_data.csv" -) +PRODUCT_RECOMMENDATION_USER_PROFILE_URL = "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_profile_mock_data.csv" -PRODUCT_RECOMMENDATION_USER_PURCHASE_HISTORY_URL = ( - "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_purchase_history_mock_data.csv" -) +PRODUCT_RECOMMENDATION_USER_PURCHASE_HISTORY_URL = "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/user_purchase_history_mock_data.csv" -PRODUCT_RECOMMENDATION_PRODUCT_DETAIL_URL = ( - "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/product_detail_mock_data.csv" -) +PRODUCT_RECOMMENDATION_PRODUCT_DETAIL_URL = "https://azurefeathrstorage.blob.core.windows.net/public/sample_data/product_recommendation_sample/product_detail_mock_data.csv" # Hotel review sample datasets. # Ref: https://www.kaggle.com/datasets/datafiniti/hotel-reviews diff --git a/feathr_project/feathr/definition/_materialization_utils.py b/feathr_project/feathr/definition/_materialization_utils.py index f4e862a8b..6f9821518 100644 --- a/feathr_project/feathr/definition/_materialization_utils.py +++ b/feathr_project/feathr/definition/_materialization_utils.py @@ -1,10 +1,11 @@ -from jinja2 import Template +from jinja2 import Template from feathr.definition.materialization_settings import MaterializationSettings def _to_materialization_config(settings: MaterializationSettings): # produce materialization config - tm = Template(""" + tm = Template( + """ operational: { name: {{ settings.name }} endTime: "{{ settings.backfill_time.end.strftime('%Y-%m-%d %H:%M:%S') }}" @@ -20,6 +21,7 @@ def _to_materialization_config(settings: MaterializationSettings): ] } features: [{{','.join(settings.feature_names)}}] - """) + """ + ) msg = tm.render(settings=settings) - return msg \ No newline at end of file + return msg diff --git a/feathr_project/feathr/definition/aggregation.py b/feathr_project/feathr/definition/aggregation.py index 4521a90d8..c014308c1 100644 --- a/feathr_project/feathr/definition/aggregation.py +++ b/feathr_project/feathr/definition/aggregation.py @@ -5,6 +5,7 @@ class Aggregation(Enum): """ The built-in aggregation functions for LookupFeature """ + # No operation NOP = 0 # Average @@ -21,4 +22,4 @@ class Aggregation(Enum): # Pick the latest value according to its timestamp LATEST = 10 # Pick the first value from the looked up values (non-deterministic) - FIRST = 11 \ No newline at end of file + FIRST = 11 diff --git a/feathr_project/feathr/definition/anchor.py b/feathr_project/feathr/definition/anchor.py index 17c8728cf..9ebdba29e 100644 --- a/feathr_project/feathr/definition/anchor.py +++ b/feathr_project/feathr/definition/anchor.py @@ -9,30 +9,33 @@ class FeatureAnchor(HoconConvertible): """ - A feature anchor defines a 
set of features on top of a data source, a.k.a. a set of features anchored to a source. + A feature anchor defines a set of features on top of a data source, a.k.a. a set of features anchored to a source. - The feature producer writes multiple anchors for a feature, exposing the same feature name for the feature - consumer to reference it. + The feature producer writes multiple anchors for a feature, exposing the same feature name for the feature + consumer to reference it. - Attributes: - name: Unique name of the anchor. - source: data source that the features are anchored to. Should be either of `INPUT_CONTEXT` or `feathr.source.Source` - features: list of features defined within this anchor. - registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. - For example, you can use {"deprecated": "true"} to indicate this anchor is deprecated, etc. + Attributes: + name: Unique name of the anchor. + source: data source that the features are anchored to. Should be either of `INPUT_CONTEXT` or `feathr.source.Source` + features: list of features defined within this anchor. + registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. + For example, you can use {"deprecated": "true"} to indicate this anchor is deprecated, etc. """ - def __init__(self, - name: str, - source: Source, - features: List[Feature], - registry_tags: Optional[Dict[str, str]] = None, - **kwargs): + + def __init__( + self, + name: str, + source: Source, + features: List[Feature], + registry_tags: Optional[Dict[str, str]] = None, + **kwargs, + ): self.name = name self.features = features self.source = source - self.registry_tags=registry_tags + self.registry_tags = registry_tags # Add a hidden option to skip validation, Anchor could be half-constructed during the loading from registry - if not kwargs.get("__no_validate", False) : + if not kwargs.get("__no_validate", False): self.validate_features() def validate_features(self): @@ -41,11 +44,14 @@ def validate_features(self): if self.source != INPUT_CONTEXT: for feature in self.features: if feature.key == [DUMMY_KEY]: - raise RuntimeError(f"For anchors of non-INPUT_CONTEXT source, key of feature {feature.name} " - f"should be explicitly specified and not left blank.") - + raise RuntimeError( + f"For anchors of non-INPUT_CONTEXT source, key of feature {feature.name} " + f"should be explicitly specified and not left blank." 
+ ) + def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{anchor_name}}: { source: {{source.name}} key.sqlExpr: [{{key_list}}] @@ -55,12 +61,10 @@ def to_feature_config(self) -> str: {% endfor %} } } - """) - key_list = ','.join((key for key in self.features[0].key_alias) if len(self.features)!=0 else []) - return tm.render(anchor_name = self.name, - key_list = key_list, - features = self.features, - source = self.source) + """ + ) + key_list = ",".join((key for key in self.features[0].key_alias) if len(self.features) != 0 else []) + return tm.render(anchor_name=self.name, key_list=key_list, features=self.features, source=self.source) def __str__(self): return self.to_feature_config() diff --git a/feathr_project/feathr/definition/config_helper.py b/feathr_project/feathr/definition/config_helper.py index a2e63e977..5ac84222c 100644 --- a/feathr_project/feathr/definition/config_helper.py +++ b/feathr_project/feathr/definition/config_helper.py @@ -3,12 +3,11 @@ from feathr.utils._file_utils import write_to_file from feathr.definition.anchor import FeatureAnchor from feathr.constants import * -from feathr.definition.feature import Feature, FeatureType,FeatureBase +from feathr.definition.feature import Feature, FeatureType, FeatureBase from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.repo_definitions import RepoDefinitions from feathr.definition.source import HdfsSource, InputContext, JdbcSource, Source -from feathr.definition.transformation import (ExpressionTransformation, Transformation, - WindowAggTransformation) +from feathr.definition.transformation import ExpressionTransformation, Transformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey from feathr.registry.feature_registry import FeathrRegistry from feathr.definition.repo_definitions import RepoDefinitions @@ -19,62 +18,57 @@ import importlib import os + class FeathrConfigHelper(object): def __init__(self) -> None: pass + def _get_py_files(self, path: Path) -> List[Path]: """Get all Python files under path recursively, excluding __init__.py""" py_files = [] - for item in path.glob('**/*.py'): + for item in path.glob("**/*.py"): if "__init__.py" != item.name: py_files.append(item) return py_files def _convert_to_module_path(self, path: Path, workspace_path: Path) -> str: """Convert a Python file path to its module path so that we can import it later""" - prefix = os.path.commonprefix( - [path.resolve(), workspace_path.resolve()]) + prefix = os.path.commonprefix([path.resolve(), workspace_path.resolve()]) resolved_path = str(path.resolve()) - module_path = resolved_path[len(prefix): -len(".py")] + module_path = resolved_path[len(prefix) : -len(".py")] # Convert features under nested folder to module name # e.g. 
/path/to/pyfile will become path.to.pyfile - return ( - module_path - .lstrip('/') - .replace("/", ".") - ) + return module_path.lstrip("/").replace("/", ".") def _extract_features_from_context(self, anchor_list, derived_feature_list, result_path: Path) -> RepoDefinitions: """Collect feature definitions from the context instead of python files""" definitions = RepoDefinitions( - sources=set(), - features=set(), - transformations=set(), - feature_anchors=set(), - derived_features=set() + sources=set(), features=set(), transformations=set(), feature_anchors=set(), derived_features=set() ) for derived_feature in derived_feature_list: if isinstance(derived_feature, DerivedFeature): definitions.derived_features.add(derived_feature) - definitions.transformations.add( - vars(derived_feature)["transform"]) + definitions.transformations.add(vars(derived_feature)["transform"]) else: - raise RuntimeError(f"Please make sure you pass a list of `DerivedFeature` objects to the `derived_feature_list` argument. {str(type(derived_feature))} is detected.") + raise RuntimeError( + f"Please make sure you pass a list of `DerivedFeature` objects to the `derived_feature_list` argument. {str(type(derived_feature))} is detected." + ) for anchor in anchor_list: # obj is `FeatureAnchor` definitions.feature_anchors.add(anchor) # add the source section of this `FeatureAnchor` object - definitions.sources.add(vars(anchor)['source']) - for feature in vars(anchor)['features']: + definitions.sources.add(vars(anchor)["source"]) + for feature in vars(anchor)["features"]: # get the transformation object from `Feature` or `DerivedFeature` if isinstance(feature, Feature): # feature is of type `Feature` definitions.features.add(feature) definitions.transformations.add(vars(feature)["transform"]) else: - - raise RuntimeError(f"Please make sure you pass a list of `Feature` objects. {str(type(feature))} is detected.") + raise RuntimeError( + f"Please make sure you pass a list of `Feature` objects. {str(type(feature))} is detected." 
+ ) return definitions @@ -84,11 +78,7 @@ def _extract_features(self, workspace_path: Path) -> RepoDefinitions: # Add workspace path to system path so that we can load features defined in Python via import_module sys.path.append(str(workspace_path)) definitions = RepoDefinitions( - sources=set(), - features=set(), - transformations=set(), - feature_anchors=set(), - derived_features=set() + sources=set(), features=set(), transformations=set(), feature_anchors=set(), derived_features=set() ) for py_file in self._get_py_files(workspace_path): module_path = self._convert_to_module_path(py_file, workspace_path) @@ -116,8 +106,7 @@ def save_to_feature_config(self, workspace_path: Path, config_save_dir: Path): def save_to_feature_config_from_context(self, anchor_list, derived_feature_list, local_workspace_dir: Path): """Save feature definition within the workspace into HOCON feature config files from current context, rather than reading from python files""" - repo_definitions = self._extract_features_from_context( - anchor_list, derived_feature_list, local_workspace_dir) + repo_definitions = self._extract_features_from_context(anchor_list, derived_feature_list, local_workspace_dir) self._save_request_feature_config(repo_definitions, local_workspace_dir) self._save_anchored_feature_config(repo_definitions, local_workspace_dir) self._save_derived_feature_config(repo_definitions, local_workspace_dir) @@ -137,11 +126,9 @@ def _save_request_feature_config(self, repo_definitions: RepoDefinitions, local_ """ ) - request_feature_configs = tm.render( - feature_anchors=repo_definitions.feature_anchors) + request_feature_configs = tm.render(feature_anchors=repo_definitions.feature_anchors) config_file_path = os.path.join(local_workspace_dir, config_file_name) - write_to_file(content=request_feature_configs, - full_file_name=config_file_path) + write_to_file(content=request_feature_configs, full_file_name=config_file_path) @classmethod def _save_anchored_feature_config(self, repo_definitions: RepoDefinitions, local_workspace_dir="./"): @@ -166,11 +153,11 @@ def _save_anchored_feature_config(self, repo_definitions: RepoDefinitions, local } """ ) - anchored_feature_configs = tm.render(feature_anchors=repo_definitions.feature_anchors, - sources=repo_definitions.sources) + anchored_feature_configs = tm.render( + feature_anchors=repo_definitions.feature_anchors, sources=repo_definitions.sources + ) config_file_path = os.path.join(local_workspace_dir, config_file_name) - write_to_file(content=anchored_feature_configs, - full_file_name=config_file_path) + write_to_file(content=anchored_feature_configs, full_file_name=config_file_path) @classmethod def _save_derived_feature_config(self, repo_definitions: RepoDefinitions, local_workspace_dir="./"): @@ -185,9 +172,6 @@ def _save_derived_feature_config(self, repo_definitions: RepoDefinitions, local_ } """ ) - derived_feature_configs = tm.render( - derived_features=repo_definitions.derived_features) + derived_feature_configs = tm.render(derived_features=repo_definitions.derived_features) config_file_path = os.path.join(local_workspace_dir, config_file_name) - write_to_file(content=derived_feature_configs, - full_file_name=config_file_path) - + write_to_file(content=derived_feature_configs, full_file_name=config_file_path) diff --git a/feathr_project/feathr/definition/dtype.py b/feathr_project/feathr/definition/dtype.py index 7211c455a..472bf5e8e 100644 --- a/feathr_project/feathr/definition/dtype.py +++ b/feathr_project/feathr/definition/dtype.py @@ -2,6 +2,7 @@ from 
typing import List from feathr.definition.feathrconfig import HoconConvertible + class ValueType(enum.Enum): """Data type to describe feature keys or observation keys. @@ -15,6 +16,7 @@ class ValueType(enum.Enum): STRING: key data type is string, for example, a user name, 'user_joe' BYTES: key data type is bytes. """ + UNSPECIFIED = 0 BOOL = 1 INT32 = 2 @@ -24,6 +26,7 @@ class ValueType(enum.Enum): STRING = 6 BYTES = 7 + def value_type_to_str(v: ValueType) -> str: return { ValueType.UNSPECIFIED: "UNSPECIFIED", @@ -63,22 +66,32 @@ def str_to_value_type(v: str) -> ValueType: "7": ValueType.BYTES, }[v.upper()] + class FeatureType(HoconConvertible): """Base class for all feature types""" - def __init__(self, val_type: ValueType, dimension_type: List[ValueType] = [], tensor_category: str = "DENSE", type: str = "TENSOR"): + + def __init__( + self, + val_type: ValueType, + dimension_type: List[ValueType] = [], + tensor_category: str = "DENSE", + type: str = "TENSOR", + ): self.val_type = val_type self.dimension_type = dimension_type self.tensor_category = tensor_category self.type = type def __eq__(self, o) -> bool: - return self.val_type == o.val_type \ - and self.dimension_type == o.dimension_type \ - and self.tensor_category == o.tensor_category \ + return ( + self.val_type == o.val_type + and self.dimension_type == o.dimension_type + and self.tensor_category == o.tensor_category and self.type == o.type + ) def to_feature_config(self) -> str: - return fr""" + return rf""" type: {{ type: {self.type} tensorCategory: {self.tensor_category} @@ -87,74 +100,82 @@ def to_feature_config(self) -> str: }} """ + class BooleanFeatureType(FeatureType): - """Boolean feature value, either true or false. - """ + """Boolean feature value, either true or false.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.BOOL) + FeatureType.__init__(self, val_type=ValueType.BOOL) + class Int32FeatureType(FeatureType): - """32-bit integer feature value, for example, 123, 98765. - """ + """32-bit integer feature value, for example, 123, 98765.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.INT32) + FeatureType.__init__(self, val_type=ValueType.INT32) + class Int64FeatureType(FeatureType): - """64-bit integer(a.k.a. Long in some system) feature value, for example, 123, 98765 but stored in 64-bit integer. - """ + """64-bit integer(a.k.a. Long in some system) feature value, for example, 123, 98765 but stored in 64-bit integer.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.INT64) + FeatureType.__init__(self, val_type=ValueType.INT64) + class FloatFeatureType(FeatureType): - """Float feature value, for example, 1.3f, 2.4f. - """ + """Float feature value, for example, 1.3f, 2.4f.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.FLOAT) + FeatureType.__init__(self, val_type=ValueType.FLOAT) + class DoubleFeatureType(FeatureType): - """Double feature value, for example, 1.3d, 2.4d. Double has better precision than float. - """ + """Double feature value, for example, 1.3d, 2.4d. Double has better precision than float.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.DOUBLE) + FeatureType.__init__(self, val_type=ValueType.DOUBLE) + class StringFeatureType(FeatureType): - """String feature value, for example, 'apple', 'orange'. 
- """ + """String feature value, for example, 'apple', 'orange'.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.STRING) + FeatureType.__init__(self, val_type=ValueType.STRING) + class BytesFeatureType(FeatureType): - """Bytes feature value. - """ + """Bytes feature value.""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.BYTES) + FeatureType.__init__(self, val_type=ValueType.BYTES) + class FloatVectorFeatureType(FeatureType): - """Float vector feature value, for example, [1,3f, 2.4f, 3.9f] - """ + """Float vector feature value, for example, [1,3f, 2.4f, 3.9f]""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.FLOAT, dimension_type = [ValueType.INT32]) + FeatureType.__init__(self, val_type=ValueType.FLOAT, dimension_type=[ValueType.INT32]) class Int32VectorFeatureType(FeatureType): - """32-bit integer vector feature value, for example, [1, 3, 9] - """ + """32-bit integer vector feature value, for example, [1, 3, 9]""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.INT32, dimension_type = [ValueType.INT32]) + FeatureType.__init__(self, val_type=ValueType.INT32, dimension_type=[ValueType.INT32]) class Int64VectorFeatureType(FeatureType): - """64-bit integer vector feature value, for example, [1, 3, 9] - """ + """64-bit integer vector feature value, for example, [1, 3, 9]""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.INT64, dimension_type = [ValueType.INT32]) + FeatureType.__init__(self, val_type=ValueType.INT64, dimension_type=[ValueType.INT32]) class DoubleVectorFeatureType(FeatureType): - """Double vector feature value, for example, [1.3d, 3.3d, 9.3d] - """ + """Double vector feature value, for example, [1.3d, 3.3d, 9.3d]""" + def __init__(self): - FeatureType.__init__(self, val_type = ValueType.DOUBLE, dimension_type = [ValueType.INT32]) + FeatureType.__init__(self, val_type=ValueType.DOUBLE, dimension_type=[ValueType.INT32]) # tensor dimension/axis @@ -174,4 +195,4 @@ def __init__(self, shape: int, dType: ValueType = ValueType.INT32): FLOAT_VECTOR = FloatVectorFeatureType() INT32_VECTOR = Int32VectorFeatureType() INT64_VECTOR = Int64VectorFeatureType() -DOUBLE_VECTOR = DoubleVectorFeatureType() \ No newline at end of file +DOUBLE_VECTOR = DoubleVectorFeatureType() diff --git a/feathr_project/feathr/definition/feathrconfig.py b/feathr_project/feathr/definition/feathrconfig.py index 8c1e90636..52490e72f 100644 --- a/feathr_project/feathr/definition/feathrconfig.py +++ b/feathr_project/feathr/definition/feathrconfig.py @@ -2,9 +2,9 @@ class HoconConvertible(ABC): - """Represent classes that can convert into Feathr HOCON config. - """ + """Represent classes that can convert into Feathr HOCON config.""" + @abstractmethod def to_feature_config(self) -> str: """Convert the feature anchor definition into internal HOCON format. (For internal use ony)""" - pass \ No newline at end of file + pass diff --git a/feathr_project/feathr/definition/feature.py b/feathr_project/feathr/definition/feature.py index 0720aced7..587b0dc22 100644 --- a/feathr_project/feathr/definition/feature.py +++ b/feathr_project/feathr/definition/feature.py @@ -22,22 +22,24 @@ class FeatureBase(HoconConvertible): transform: A transformation used to produce its feature value. e.g. amount * 10 registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this feature is deprecated, etc. 
""" - def __init__(self, - name: str, - feature_type: FeatureType, - transform: Optional[Union[str, Transformation]] = None, - key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], - registry_tags: Optional[Dict[str, str]] = None, - ): + + def __init__( + self, + name: str, + feature_type: FeatureType, + transform: Optional[Union[str, Transformation]] = None, + key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], + registry_tags: Optional[Dict[str, str]] = None, + ): FeatureBase.validate_feature_name(name) # Validate the feature type if not isinstance(feature_type, FeatureType): - raise KeyError(f'Feature type must be a FeatureType class, like INT32, but got {feature_type}') + raise KeyError(f"Feature type must be a FeatureType class, like INT32, but got {feature_type}") self.name = name self.feature_type = feature_type - self.registry_tags=registry_tags + self.registry_tags = registry_tags self.key = key if isinstance(key, List) else [key] # feature_alias: Rename the derived feature to `feature_alias`. Default to feature name. self.feature_alias = name @@ -60,18 +62,21 @@ def validate_feature_name(cls, feature_name: str) -> bool: This is because some compute engines, such as Spark, will consider them as operators in feature name. """ if not feature_name: - raise Exception('Feature name rule violation: empty feature name detected') + raise Exception("Feature name rule violation: empty feature name detected") - feature_validator = re.compile(r"""^ # from the start of the string + feature_validator = re.compile( + r"""^ # from the start of the string [a-zA-Z_]{1} # first character can only be a letter or underscore [a-zA-Z0-9_]+ # as many letters, numbers, or underscores as you like - $""", # to the end of the string - re.X) + $""", # to the end of the string + re.X, + ) if not feature_validator.match(feature_name): raise Exception( - 'Feature name rule violation: only letters, numbers, and underscores are allowed in the name, ' + - f'and the name cannot start with a number. name={feature_name}') + "Feature name rule violation: only letters, numbers, and underscores are allowed in the name, " + + f"and the name cannot start with a number. name={feature_name}" + ) return True @@ -79,7 +84,7 @@ def with_key(self, key_alias: Union[str, List[str]]): """Rename the feature key with the alias. This is useful in derived features that depends on the same feature with different keys.""" cleaned_key_alias = [key_alias] if isinstance(key_alias, str) else key_alias - assert(len(cleaned_key_alias) == len(self.key)) + assert len(cleaned_key_alias) == len(self.key) new_key = [] for i in range(0, len(cleaned_key_alias)): typed_key = deepcopy(self.key[i]) @@ -111,22 +116,24 @@ class Feature(FeatureBase): transform: A row transformation used to produce its feature value. e.g. amount * 10 registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this feature is deprecated, etc. 
""" - def __init__(self, - name: str, - feature_type: FeatureType, - key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], - transform: Optional[Union[str, Transformation]] = None, - registry_tags: Optional[Dict[str, str]] = None, - ): - super(Feature, self).__init__(name, feature_type, transform, key, registry_tags) + def __init__( + self, + name: str, + feature_type: FeatureType, + key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], + transform: Optional[Union[str, Transformation]] = None, + registry_tags: Optional[Dict[str, str]] = None, + ): + super(Feature, self).__init__(name, feature_type, transform, key, registry_tags) def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{feature.name}}: { {{feature.transform.to_feature_config()}} {{feature.feature_type.to_feature_config()}} } - """) + """ + ) return tm.render(feature=self) - diff --git a/feathr_project/feathr/definition/feature_derivations.py b/feathr_project/feathr/definition/feature_derivations.py index 9205685ce..39adcb357 100644 --- a/feathr_project/feathr/definition/feature_derivations.py +++ b/feathr_project/feathr/definition/feature_derivations.py @@ -20,34 +20,46 @@ class DerivedFeature(FeatureBase): registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can For example, you can use {"deprecated": "true"} to indicate this feature is deprecated, etc. """ - def __init__(self, - name: str, - feature_type: FeatureType, - input_features: Union[FeatureBase, List[FeatureBase]], - transform: Union[str, RowTransformation], - key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], - registry_tags: Optional[Dict[str, str]] = None, - **kwargs): - super(DerivedFeature, self).__init__(name, feature_type, key=key, transform=transform, registry_tags=registry_tags) + def __init__( + self, + name: str, + feature_type: FeatureType, + input_features: Union[FeatureBase, List[FeatureBase]], + transform: Union[str, RowTransformation], + key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], + registry_tags: Optional[Dict[str, str]] = None, + **kwargs, + ): + super(DerivedFeature, self).__init__( + name, feature_type, key=key, transform=transform, registry_tags=registry_tags + ) self.input_features = input_features if isinstance(input_features, List) else [input_features] # Add a hidden option to skip validation, Anchor could be half-constructed during the loading from registry - if not kwargs.get("__no_validate", False) : + if not kwargs.get("__no_validate", False): self.validate_feature() def validate_feature(self): """Validate the derived feature is valid""" - + input_feature_key_alias = [] - # for new entity in Purview, the attributes are Camel cases, while the old logic works as snake cases. + # for new entity in Purview, the attributes are Camel cases, while the old logic works as snake cases. # Modify the conversion to work with both schema. 
for feature in self.input_features: - input_feature_key_alias.extend([x['keyColumnAlias'] for x in feature['attributes']['key']] if isinstance(feature,dict) else feature.key_alias) + input_feature_key_alias.extend( + [x["keyColumnAlias"] for x in feature["attributes"]["key"]] + if isinstance(feature, dict) + else feature.key_alias + ) for key_alias in self.key_alias: - assert key_alias in input_feature_key_alias, "key alias {} in derived feature {} must come from " \ - "its input features key alias list {}".format(key_alias, self.name, input_feature_key_alias) + assert ( + key_alias in input_feature_key_alias + ), "key alias {} in derived feature {} must come from " "its input features key alias list {}".format( + key_alias, self.name, input_feature_key_alias + ) def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{derived_feature.name}}: { key: [{{','.join(derived_feature.key_alias)}}] inputs: { @@ -61,5 +73,6 @@ def to_feature_config(self) -> str: definition.sqlExpr: {{derived_feature.transform.to_feature_config(False)}} {{derived_feature.feature_type.to_feature_config()}} } - """) + """ + ) return tm.render(derived_feature=self) diff --git a/feathr_project/feathr/definition/lookup_feature.py b/feathr_project/feathr/definition/lookup_feature.py index 2f1b80ccd..d0c71732c 100644 --- a/feathr_project/feathr/definition/lookup_feature.py +++ b/feathr_project/feathr/definition/lookup_feature.py @@ -10,6 +10,7 @@ from feathr.definition.typed_key import DUMMY_KEY, TypedKey from feathr.definition.aggregation import Aggregation + class LookupFeature(DerivedFeature): """A lookup feature is a feature defined on top of two other features, i.e. using the feature value of the base feature as key, to lookup the feature value from the expansion feature. e.g. a lookup feature user_purchased_item_avg_price could be key-ed by user_id, and computed by: @@ -28,23 +29,31 @@ class LookupFeature(DerivedFeature): e.g. feature value is an array and each value in the array is used once as a lookup key. 
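        A hedged construction sketch for the user_purchased_item_avg_price example above.
        `user_purchased_item_ids` and `item_price` are assumed to be existing Feature objects
        (keyed by user id and item id respectively), and AVG is assumed to be the "Average"
        member of the Aggregation enum; in practice a TypedKey for the user id would also be
        passed via `key=`:

            from feathr.definition.aggregation import Aggregation
            from feathr.definition.dtype import FloatFeatureType
            from feathr.definition.lookup_feature import LookupFeature

            user_purchased_item_avg_price = LookupFeature(
                name="user_purchased_item_avg_price",
                feature_type=FloatFeatureType(),
                base_feature=user_purchased_item_ids,  # its value provides the lookup keys
                expansion_feature=item_price,          # looked up per item, then aggregated
                aggregation=Aggregation.AVG,
            )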
""" - def __init__(self, - name: str, - feature_type: FeatureType, - base_feature: FeatureBase, - expansion_feature: FeatureBase, - aggregation: Aggregation, - key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], - registry_tags: Optional[Dict[str, str]] = None, - ): - super(LookupFeature, self).__init__(name, feature_type, input_features=[base_feature, expansion_feature], - transform="", key=key, registry_tags=registry_tags) + def __init__( + self, + name: str, + feature_type: FeatureType, + base_feature: FeatureBase, + expansion_feature: FeatureBase, + aggregation: Aggregation, + key: Optional[Union[TypedKey, List[TypedKey]]] = [DUMMY_KEY], + registry_tags: Optional[Dict[str, str]] = None, + ): + super(LookupFeature, self).__init__( + name, + feature_type, + input_features=[base_feature, expansion_feature], + transform="", + key=key, + registry_tags=registry_tags, + ) self.base_feature = base_feature self.expansion_feature = expansion_feature self.aggregation = aggregation def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{lookup_feature.name}}: { key: [{{','.join(lookup_feature.key_alias)}}] join: { @@ -60,5 +69,6 @@ def to_feature_config(self) -> str: aggregation: {{lookup_feature.aggregation.name}} {{lookup_feature.feature_type.to_feature_config()}} } - """) + """ + ) return tm.render(lookup_feature=self) diff --git a/feathr_project/feathr/definition/materialization_settings.py b/feathr_project/feathr/definition/materialization_settings.py index d275b7eb3..90ddb2f1c 100644 --- a/feathr_project/feathr/definition/materialization_settings.py +++ b/feathr_project/feathr/definition/materialization_settings.py @@ -12,6 +12,7 @@ class BackfillTime: end: end time of the backfill, inclusive. step: duration of each backfill step. e.g. if you want to materialize features on daily basis, use timedelta(days=1) """ + def __init__(self, start: datetime, end: datetime, step: timedelta): self.start = start self.end = end @@ -30,14 +31,23 @@ class MaterializationSettings: If 'DAILY', output paths should be: yyyy/MM/dd; Otherwise would be: yyyy/MM/dd/HH """ - def __init__(self, name: str, sinks: List[Sink], feature_names: List[str], backfill_time: Optional[BackfillTime] = None, resolution: str = "DAILY"): + + def __init__( + self, + name: str, + sinks: List[Sink], + feature_names: List[str], + backfill_time: Optional[BackfillTime] = None, + resolution: str = "DAILY", + ): if resolution not in ["DAILY", "HOURLY"]: - raise RuntimeError( - f'{resolution} is not supported. Only \'DAILY\' and \'HOURLY\' are currently supported.') + raise RuntimeError(f"{resolution} is not supported. Only 'DAILY' and 'HOURLY' are currently supported.") self.resolution = resolution self.name = name now = datetime.now() - self.backfill_time = backfill_time if backfill_time else BackfillTime(start=now, end=now, step=timedelta(days=1)) + self.backfill_time = ( + backfill_time if backfill_time else BackfillTime(start=now, end=now, step=timedelta(days=1)) + ) for sink in sinks: if isinstance(sink, HdfsSink): self.has_hdfs_sink = True @@ -48,18 +58,20 @@ def __init__(self, name: str, sinks: List[Sink], feature_names: List[str], backf self.feature_names = feature_names def get_backfill_cutoff_time(self) -> List[datetime]: - """Get the backfill cutoff time points for materialization. - E.g. 
for `BackfillTime(start=datetime(2022, 3, 1), end=datetime(2022, 3, 5), step=timedelta(days=1))`, - it returns cutoff time list as `[2022-3-1, 2022-3-2, 2022-3-3, 2022-3-4, 2022-3-5]`, - for `BackfillTime(start=datetime(2022, 3, 1, 1), end=datetime(2022, 3, 1, 5), step=timedelta(hours=1))`, + """Get the backfill cutoff time points for materialization. + E.g. for `BackfillTime(start=datetime(2022, 3, 1), end=datetime(2022, 3, 5), step=timedelta(days=1))`, + it returns cutoff time list as `[2022-3-1, 2022-3-2, 2022-3-3, 2022-3-4, 2022-3-5]`, + for `BackfillTime(start=datetime(2022, 3, 1, 1), end=datetime(2022, 3, 1, 5), step=timedelta(hours=1))`, it returns cutoff time list as `[2022-3-1 01:00:00, 2022-3-1 02:00:00, 2022-3-1 03:00:00, 2022-3-1 04:00:00, 2022-3-1 05:00:00]` - """ + """ start_time = self.backfill_time.start end_time = self.backfill_time.end step_in_seconds = self.backfill_time.step.total_seconds() - assert start_time <= end_time, "Start time {} must be earlier or equal to end time {}".format(start_time, end_time) + assert start_time <= end_time, "Start time {} must be earlier or equal to end time {}".format( + start_time, end_time + ) assert step_in_seconds > 0, "Step in time range should be greater than 0, but got {}".format(step_in_seconds) num_delta = (self.backfill_time.end - self.backfill_time.start).total_seconds() / step_in_seconds num_delta = math.floor(num_delta) + 1 - return [end_time - timedelta(seconds=n*step_in_seconds) for n in reversed(range(num_delta))] \ No newline at end of file + return [end_time - timedelta(seconds=n * step_in_seconds) for n in reversed(range(num_delta))] diff --git a/feathr_project/feathr/definition/monitoring_settings.py b/feathr_project/feathr/definition/monitoring_settings.py index ee39f84d5..e7fa9e4e7 100644 --- a/feathr_project/feathr/definition/monitoring_settings.py +++ b/feathr_project/feathr/definition/monitoring_settings.py @@ -4,5 +4,4 @@ # it's completely the same as MaterializationSettings. But we renamed it to improve usability. # In the future, we may want to rely a separate system other than MaterializationSettings to generate stats. class MonitoringSettings(MaterializationSettings): - """Settings about monitoring features. 
- """ + """Settings about monitoring features.""" diff --git a/feathr_project/feathr/definition/query_feature_list.py b/feathr_project/feathr/definition/query_feature_list.py index a667e77d0..10015e797 100644 --- a/feathr_project/feathr/definition/query_feature_list.py +++ b/feathr_project/feathr/definition/query_feature_list.py @@ -5,6 +5,7 @@ from feathr.definition.typed_key import TypedKey from feathr.definition.feathrconfig import HoconConvertible + class FeatureQuery(HoconConvertible): """A FeatureQuery contains a list of features @@ -12,8 +13,14 @@ class FeatureQuery(HoconConvertible): feature_list: a list of feature names key: key of `feature_list`, all features must share the same key override_time_delay [Optional]: to simulate time delay of features - """ - def __init__(self, feature_list: List[str], key: Optional[Union[TypedKey, List[TypedKey]]] = None, override_time_delay: Optional[str] = None) -> None: + """ + + def __init__( + self, + feature_list: List[str], + key: Optional[Union[TypedKey, List[TypedKey]]] = None, + override_time_delay: Optional[str] = None, + ) -> None: self.key = key if isinstance(key, TypedKey): self.key = [key] @@ -22,7 +29,8 @@ def __init__(self, feature_list: List[str], key: Optional[Union[TypedKey, List[T self.overrideTimeDelay = override_time_delay def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ { key: [{{key_columns}}] featureList: [{{feature_names}}] @@ -30,7 +38,8 @@ def to_feature_config(self) -> str: overrideTimeDelay: "{{query.overrideTimeDelay}}" {% endif %} } - """) + """ + ) key_columns = ", ".join(k.key_column for k in self.key) if self.key else "NOT_NEEDED" feature_list = ", ".join(f for f in self.feature_list) - return tm.render(key_columns = key_columns, feature_names = feature_list, query=self) + return tm.render(key_columns=key_columns, feature_names=feature_list, query=self) diff --git a/feathr_project/feathr/definition/repo_definitions.py b/feathr_project/feathr/definition/repo_definitions.py index 51e780f60..5db4a73a6 100644 --- a/feathr_project/feathr/definition/repo_definitions.py +++ b/feathr_project/feathr/definition/repo_definitions.py @@ -8,14 +8,17 @@ class RepoDefinitions: """A list of shareable Feathr objects defined in the project.""" - def __init__(self, - sources: Set[Source], - features: Set[Feature], - transformations: Set[Transformation], - feature_anchors: Set[FeatureAnchor], - derived_features: Set[DerivedFeature]) -> None: + + def __init__( + self, + sources: Set[Source], + features: Set[Feature], + transformations: Set[Transformation], + feature_anchors: Set[FeatureAnchor], + derived_features: Set[DerivedFeature], + ) -> None: self.sources = sources self.features = features self.transformations = transformations self.feature_anchors = feature_anchors - self.derived_features = derived_features \ No newline at end of file + self.derived_features = derived_features diff --git a/feathr_project/feathr/definition/settings.py b/feathr_project/feathr/definition/settings.py index 6711c592c..ca158d7a1 100644 --- a/feathr_project/feathr/definition/settings.py +++ b/feathr_project/feathr/definition/settings.py @@ -3,32 +3,37 @@ from loguru import logger from feathr.definition.feathrconfig import HoconConvertible -class ConflictsAutoCorrection(): - """Conflicts auto-correction handler settings. - Used in feature join when some conflicts exist + +class ConflictsAutoCorrection: + """Conflicts auto-correction handler settings. 
+ Used in feature join when some conflicts exist between feature names and observation dataset columns. - + Attributes: rename_features: rename feature names when solving conflicts. Default by 'False' which means to rename observation dataset columns. suffix: customized suffix to be applied to conflicts names. Default by "1" eg. The conflicted column name 'field' will become 'field_1' if suffix is "1" """ + def __init__(self, rename_features: bool = False, suffix: str = "1") -> None: self.rename_features = rename_features self.suffix = suffix - + def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {% if auto_correction.rename_features %} renameFeatures: True {% else %} renameFeatures: False {% endif %} suffix: {{auto_correction.suffix}} - """) + """ + ) return tm.render(auto_correction=self) + class ObservationSettings(HoconConvertible): """Time settings of the observation data. Used in feature join. @@ -42,29 +47,36 @@ class ObservationSettings(HoconConvertible): - Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html). file_format: format of the dataset file. Default as "csv" is_file_path: if the 'observation_path' is a path of file (instead of a directory). Default as 'True' - conflicts_auto_correction: settings about auto-correct feature names conflicts. - Default as None which means do not enable it. + conflicts_auto_correction: settings about auto-correct feature names conflicts. + Default as None which means do not enable it. """ - def __init__(self, - observation_path: str, - event_timestamp_column: Optional[str] = None, - simulate_time_delay: Optional[str] = None, - timestamp_format: str = "epoch", - conflicts_auto_correction: ConflictsAutoCorrection = None, - file_format: str = "csv", - is_file_path: bool = True) -> None: + + def __init__( + self, + observation_path: str, + event_timestamp_column: Optional[str] = None, + simulate_time_delay: Optional[str] = None, + timestamp_format: str = "epoch", + conflicts_auto_correction: ConflictsAutoCorrection = None, + file_format: str = "csv", + is_file_path: bool = True, + ) -> None: self.event_timestamp_column = event_timestamp_column self.simulate_time_delay = simulate_time_delay self.timestamp_format = timestamp_format self.observation_path = observation_path if observation_path.startswith("http"): - logger.warning("Your observation_path {} starts with http, which is not supported. Consider using paths starting with wasb[s]/abfs[s]/s3.", observation_path) + logger.warning( + "Your observation_path {} starts with http, which is not supported. Consider using paths starting with wasb[s]/abfs[s]/s3.", + observation_path, + ) self.file_format = file_format self.is_file_path = is_file_path self.conflicts_auto_correction = conflicts_auto_correction - + def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {% if setting.event_timestamp_column is not none %} settings: { joinTimeSettings: { @@ -84,6 +96,6 @@ def to_feature_config(self) -> str: } {% endif %} observationPath: "{{setting.observation_path}}" - """) + """ + ) return tm.render(setting=self) - diff --git a/feathr_project/feathr/definition/sink.py b/feathr_project/feathr/definition/sink.py index 71c406561..5c145a729 100644 --- a/feathr_project/feathr/definition/sink.py +++ b/feathr_project/feathr/definition/sink.py @@ -7,55 +7,59 @@ class Sink(HoconConvertible): - """A data sink. 
- """ - + """A data sink.""" + @abstractmethod def support_offline(self) -> bool: pass - + @abstractmethod def support_online(self) -> bool: pass - + @abstractmethod def to_argument(self): pass - + def __str__(self) -> str: return "DUMMY" + class MonitoringSqlSink(Sink): """SQL-based sink that stores feature monitoring results. Attributes: table_name: output table name """ + def __init__(self, table_name: str) -> None: self.table_name = table_name def to_feature_config(self) -> str: """Produce the config used in feature monitoring""" - tm = Template(""" + tm = Template( + """ { name: MONITORING params: { table_name: "{{source.table_name}}" } } - """) + """ + ) msg = tm.render(source=self) return msg def support_offline(self) -> bool: return False - + def support_online(self) -> bool: return True - + def to_argument(self): raise TypeError("MonitoringSqlSink cannot be used as output argument") + class RedisSink(Sink): """Redis-based sink use to store online feature data, can be used in batch job or streaming job. @@ -64,14 +68,16 @@ class RedisSink(Sink): streaming: whether it is used in streaming mode streamingTimeoutMs: maximum running time for streaming mode. It is not used in batch mode. """ - def __init__(self, table_name: str, streaming: bool=False, streamingTimeoutMs: Optional[int]=None) -> None: + + def __init__(self, table_name: str, streaming: bool = False, streamingTimeoutMs: Optional[int] = None) -> None: self.table_name = table_name self.streaming = streaming self.streamingTimeoutMs = streamingTimeoutMs def to_feature_config(self) -> str: """Produce the config used in feature materialization""" - tm = Template(""" + tm = Template( + """ { name: REDIS params: { @@ -87,36 +93,39 @@ def to_feature_config(self) -> str: {% endif %} } } - """) + """ + ) msg = tm.render(source=self) return msg def support_offline(self) -> bool: return False - + def support_online(self) -> bool: return True - + def to_argument(self): raise TypeError("RedisSink cannot be used as output argument") class HdfsSink(Sink): """Offline Hadoop HDFS-compatible(HDFS, delta lake, Azure blog storage etc) sink that is used to store feature data. - The result is in AVRO format. + The result is in AVRO format. - Incremental aggregation is enabled by default when using HdfsSink. Use incremental aggregation will significantly expedite the WindowAggTransformation feature calculation. - For example, aggregation sum of a feature F within a 180-day window at day T can be expressed as: F(T) = F(T - 1)+DirectAgg(T-1)-DirectAgg(T - 181). - Once a SNAPSHOT of the first day is generated, the calculation for the following days can leverage it. + Incremental aggregation is enabled by default when using HdfsSink. Use incremental aggregation will significantly expedite the WindowAggTransformation feature calculation. + For example, aggregation sum of a feature F within a 180-day window at day T can be expressed as: F(T) = F(T - 1)+DirectAgg(T-1)-DirectAgg(T - 181). + Once a SNAPSHOT of the first day is generated, the calculation for the following days can leverage it. Attributes: output_path: output path - store_name: the folder name under the base "path". Used especially for the current dataset to support 'Incremental' aggregation. - + store_name: the folder name under the base "path". Used especially for the current dataset to support 'Incremental' aggregation. 
+ """ - def __init__(self, output_path: str, store_name: Optional[str]="df0") -> None: + + def __init__(self, output_path: str, store_name: Optional[str] = "df0") -> None: self.output_path = output_path self.store_name = store_name + # Sample generated HOCON config: # operational: { # name: testFeatureGen @@ -139,7 +148,8 @@ def __init__(self, output_path: str, store_name: Optional[str]="df0") -> None: # features: [mockdata_a_ct_gen, mockdata_a_sample_gen] def to_feature_config(self) -> str: """Produce the config used in feature materialization""" - tm = Template(""" + tm = Template( + """ { name: HDFS outputFormat: RAW_DATA @@ -153,19 +163,21 @@ def to_feature_config(self) -> str: {% endif %} } } - """) + """ + ) hocon_config = tm.render(sink=self) return hocon_config def support_offline(self) -> bool: return True - + def support_online(self) -> bool: return True - + def to_argument(self): return self.output_path + class JdbcSink(Sink): def __init__(self, name: str, url: str, dbtable: str, auth: Optional[str] = None) -> None: self.name = name @@ -174,8 +186,7 @@ def __init__(self, name: str, url: str, dbtable: str, auth: Optional[str] = None if auth is not None: self.auth = auth.upper() if self.auth not in ["USERPASS", "TOKEN"]: - raise ValueError( - "auth must be None or one of following values: ['userpass', 'token']") + raise ValueError("auth must be None or one of following values: ['userpass', 'token']") def get_required_properties(self): if not hasattr(self, "auth"): @@ -187,13 +198,14 @@ def get_required_properties(self): def support_offline(self) -> bool: return True - + def support_online(self) -> bool: return True - + def to_feature_config(self) -> str: """Produce the config used in feature materialization""" - tm = Template(""" + tm = Template( + """ { name: HDFS params: { @@ -210,7 +222,8 @@ def to_feature_config(self) -> str: {% endif %} } } - """) + """ + ) sink = copy.copy(self) sink.name = self.name.upper() hocon_config = tm.render(sink=sink) @@ -233,32 +246,31 @@ def to_argument(self): else: d["anonymous"] = True return json.dumps(d) - + + class GenericSink(Sink): """ This class is corresponding to 'GenericLocation' in Feathr core, but only be used as Sink. The class is not meant to be used by user directly, user should use its subclasses like `CosmosDbSink` """ + def __init__(self, format: str, mode: Optional[str] = None, options: Dict[str, str] = {}) -> None: self.format = format self.mode = mode self.options = dict([(o.replace(".", "__"), options[o]) for o in options]) - + def to_feature_config(self) -> str: - ret = { - "name": "HDFS", - "params": self._to_dict() - } + ret = {"name": "HDFS", "params": self._to_dict()} return json.dumps(ret, indent=4) - + def _to_dict(self) -> Dict[str, str]: ret = self.options.copy() ret["type"] = "generic" ret["format"] = self.format if self.mode: ret["mode"] = self.mode - return ret - + return ret + def get_required_properties(self): ret = [] for option in self.options: @@ -266,7 +278,7 @@ def get_required_properties(self): if start >= 0: end = option[start:].find("}") if end >= 0: - ret.append(option[start+2:start+end]) + ret.append(option[start + 2 : start + end]) return ret def to_argument(self): @@ -274,44 +286,46 @@ def to_argument(self): One-line JSON string, used by job submitter """ return json.dumps(self._to_dict()) - + + class CosmosDbSink(GenericSink): """ CosmosDbSink is a sink that is used to store online feature data in CosmosDB. 
Even it's possible, but we shouldn't use it as offline store as CosmosDb requires records to have unique keys, why offline feature job cannot generate unique keys. """ - def __init__(self, name: str, endpoint: str, database: str, container: str): - super().__init__(format = "cosmos.oltp", mode="APPEND", options={ - "spark.cosmos.accountEndpoint": endpoint, - 'spark.cosmos.accountKey': "${%s_KEY}" % name.upper(), - "spark.cosmos.database": database, - "spark.cosmos.container": container - }) + + def __init__(self, name: str, endpoint: str, database: str, container: str): + super().__init__( + format="cosmos.oltp", + mode="APPEND", + options={ + "spark.cosmos.accountEndpoint": endpoint, + "spark.cosmos.accountKey": "${%s_KEY}" % name.upper(), + "spark.cosmos.database": database, + "spark.cosmos.container": container, + }, + ) self.name = name self.endpoint = endpoint self.database = database self.container = container - + def support_offline(self) -> bool: return False - + def support_online(self) -> bool: return True - + def get_required_properties(self) -> List[str]: return [self.name.upper() + "_KEY"] + class ElasticSearchSink(GenericSink): """ Use ElasticSearch as the data sink. """ - def __init__(self, - name: str, - host: str, - index: str, - ssl: bool = True, - auth: bool = True, - mode = 'OVERWRITE'): + + def __init__(self, name: str, host: str, index: str, ssl: bool = True, auth: bool = True, mode="OVERWRITE"): """ name: The name of the sink. host: ElasticSearch node, can be `hostname` or `hostname:port`, default port is 9200. @@ -322,9 +336,9 @@ def __init__(self, """ self.auth = auth options = { - 'es.nodes': host, - 'es.ssl': str(ssl).lower(), - 'es.resource': index, + "es.nodes": host, + "es.ssl": str(ssl).lower(), + "es.resource": index, } if auth: """ @@ -332,12 +346,9 @@ def __init__(self, ElasticSearch Spark connector also supports PKI auth but that needs to setup keystore on each driver node, which seems to be too complicated for managed Spark cluster. """ - options["es.net.http.auth.user"] = "${%s_USER}" % name.upper(), - options["es.net.http.auth.pass"] = "${%s_PASSWORD}" % name.upper(), - super().__init__(name, - format='org.elasticsearch.spark.sql', - mode=mode, - options=options) + options["es.net.http.auth.user"] = ("${%s_USER}" % name.upper(),) + options["es.net.http.auth.pass"] = ("${%s_PASSWORD}" % name.upper(),) + super().__init__(name, format="org.elasticsearch.spark.sql", mode=mode, options=options) def support_offline(self) -> bool: """ @@ -345,32 +356,37 @@ def support_offline(self) -> bool: the output dataset is accessible in other ways, like full-text search or time-series with a timestamp field. 
""" return True - + def support_online(self) -> bool: return True - + def get_required_properties(self) -> List[str]: if self.auth: return [self.name.upper() + "_USER", self.name.upper() + "_PASSWORD"] return [] + class AerospikeSink(GenericSink): - def __init__(self,name:str,seedhost:str,port:int,namespace:str,setname:str): - super().__init__(format="aerospike",mode="APPEND",options = { - "aerospike.seedhost":seedhost, - "aerospike.port":str(port), - "aerospike.namespace":namespace, - "aerospike.user":"${%s_USER}" % name.upper(), - "aerospike.password":"${%s_PASSWORD}" % name.upper(), - "aerospike.set":setname - }) + def __init__(self, name: str, seedhost: str, port: int, namespace: str, setname: str): + super().__init__( + format="aerospike", + mode="APPEND", + options={ + "aerospike.seedhost": seedhost, + "aerospike.port": str(port), + "aerospike.namespace": namespace, + "aerospike.user": "${%s_USER}" % name.upper(), + "aerospike.password": "${%s_PASSWORD}" % name.upper(), + "aerospike.set": setname, + }, + ) self.name = name def support_offline(self) -> bool: return False - + def support_online(self) -> bool: return True - + def get_required_properties(self) -> List[str]: return [self.name.upper() + "_USER", self.name.upper() + "_PASSWORD"] diff --git a/feathr_project/feathr/definition/source.py b/feathr_project/feathr/definition/source.py index 47a633932..d0b831369 100644 --- a/feathr_project/feathr/definition/source.py +++ b/feathr_project/feathr/definition/source.py @@ -1,4 +1,3 @@ - from abc import abstractmethod import copy from typing import Callable, Dict, List, Optional @@ -22,12 +21,14 @@ def __init__(self, schemaStr: str): def to_feature_config(self): """Convert the feature anchor definition into internal HOCON format.""" - tm = Template(""" + tm = Template( + """ schema: { type = "avro" avroJson:{{avroJson}} } - """) + """ + ) avroJson = json.dumps(self.schemaStr) msg = tm.render(schema=self, avroJson=avroJson) return msg @@ -44,12 +45,13 @@ class Source(HoconConvertible): For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc. """ - def __init__(self, - name: str, - event_timestamp_column: Optional[str] = "0", - timestamp_format: Optional[str] = "epoch", - registry_tags: Optional[Dict[str, str]] = None, - ) -> None: + def __init__( + self, + name: str, + event_timestamp_column: Optional[str] = "0", + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ) -> None: self.name = name self.event_timestamp_column = event_timestamp_column self.timestamp_format = timestamp_format @@ -65,7 +67,7 @@ def __hash__(self): def __str__(self): return self.to_feature_config() - + @abstractmethod def to_argument(self): pass @@ -77,6 +79,7 @@ class InputContext(Source): can be transformed from the observation data table t1 itself, like geo location, then you can define that feature on top of the InputContext. """ + __SOURCE_NAME = "PASSTHROUGH" def __init__(self) -> None: @@ -104,33 +107,45 @@ class HdfsSource(Source): registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc. time_partition_pattern(Optional[str]): Format of the time partitioned feature data. e.g. yyyy/MM/DD. All formats defined in dateTimeFormatter are supported. 
config: - timeSnapshotHdfsSource: - { - location: - { - path: "/data/somePath/daily/" - } - timePartitionPattern: "yyyy/MM/dd" + timeSnapshotHdfsSource: + { + location: + { + path: "/data/somePath/daily/" + } + timePartitionPattern: "yyyy/MM/dd" } - Given the above HDFS path: /data/somePath/daily, + Given the above HDFS path: /data/somePath/daily, then the expectation is that the following sub directorie(s) should exist: /data/somePath/daily/{yyyy}/{MM}/{dd} postfix_path(Optional[str]): postfix path followed by the 'time_partition_pattern'. Given above config, if we have 'postfix_path' defined all contents under paths of the pattern '{path}/{yyyy}/{MM}/{dd}/{postfix_path}' will be visited. """ - def __init__(self, name: str, path: str, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None, time_partition_pattern: Optional[str] = None, postfix_path: Optional[str] = None) -> None: - super().__init__(name, event_timestamp_column, - timestamp_format, registry_tags=registry_tags) + def __init__( + self, + name: str, + path: str, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + time_partition_pattern: Optional[str] = None, + postfix_path: Optional[str] = None, + ) -> None: + super().__init__(name, event_timestamp_column, timestamp_format, registry_tags=registry_tags) self.path = path self.preprocessing = preprocessing self.time_partition_pattern = time_partition_pattern self.postfix_path = postfix_path if path.startswith("http"): logger.warning( - "Your input path {} starts with http, which is not supported. Consider using paths starting with wasb[s]/abfs[s]/s3.", path) + "Your input path {} starts with http, which is not supported. Consider using paths starting with wasb[s]/abfs[s]/s3.", + path, + ) def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { location: {path: "{{source.path}}"} {% if source.time_partition_pattern %} @@ -146,16 +161,18 @@ def to_feature_config(self) -> str: } {% endif %} } - """) + """ + ) msg = tm.render(source=self) return msg def __str__(self): - return str(self.preprocessing) + '\n' + self.to_feature_config() + return str(self.preprocessing) + "\n" + self.to_feature_config() def to_argument(self): return self.path + class SnowflakeSource(Source): """ A data source for Snowflake @@ -175,10 +192,21 @@ class SnowflakeSource(Source): - Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html). registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc. 
""" - def __init__(self, name: str, database: str, schema: str, dbtable: Optional[str] = None, query: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None: - super().__init__(name, event_timestamp_column, - timestamp_format, registry_tags=registry_tags) - self.preprocessing=preprocessing + + def __init__( + self, + name: str, + database: str, + schema: str, + dbtable: Optional[str] = None, + query: Optional[str] = None, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ) -> None: + super().__init__(name, event_timestamp_column, timestamp_format, registry_tags=registry_tags) + self.preprocessing = preprocessing if dbtable is not None and query is not None: raise RuntimeError("Both dbtable and query are specified. Can only specify one..") if dbtable is None and query is None: @@ -199,7 +227,7 @@ def _get_snowflake_path(self, dbtable: Optional[str] = None, query: Optional[str return f"snowflake://snowflake_account/?sfDatabase={self.database}&sfSchema={self.schema}&dbtable={dbtable}" else: return f"snowflake://snowflake_account/?sfDatabase={self.database}&sfSchema={self.schema}&query={query}" - + def parse_snowflake_path(url: str) -> Dict[str, str]: """ Parses snowflake path into dictionary of components for registry. @@ -208,9 +236,10 @@ def parse_snowflake_path(url: str) -> Dict[str, str]: parsed_queries = parse_qs(parse_result.query) updated_dict = {key: parsed_queries[key][0] for key in parsed_queries} return updated_dict - + def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { type: SNOWFLAKE location: { @@ -231,18 +260,31 @@ def to_feature_config(self) -> str: } {% endif %} } - """) + """ + ) msg = tm.render(source=self) return msg def __str__(self): - return str(self.preprocessing) + '\n' + self.to_feature_config() + return str(self.preprocessing) + "\n" + self.to_feature_config() def to_argument(self): return self.path + class JdbcSource(Source): - def __init__(self, name: str, url: str = "", dbtable: Optional[str] = None, query: Optional[str] = None, auth: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None: + def __init__( + self, + name: str, + url: str = "", + dbtable: Optional[str] = None, + query: Optional[str] = None, + auth: Optional[str] = None, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ) -> None: super().__init__(name, event_timestamp_column, timestamp_format, registry_tags) self.preprocessing = preprocessing self.url = url @@ -253,8 +295,7 @@ def __init__(self, name: str, url: str = "", dbtable: Optional[str] = None, quer if auth is not None: self.auth = auth.upper() if self.auth not in ["USERPASS", "TOKEN"]: - raise ValueError( - "auth must be None or one of following values: ['userpass', 'token']") + raise ValueError("auth must be None or one of following values: ['userpass', 'token']") def get_required_properties(self): if not hasattr(self, "auth"): @@ -265,7 +306,8 @@ def get_required_properties(self): return ["%s_TOKEN" % 
self.name.upper()] def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { location: { type: "jdbc" @@ -295,14 +337,15 @@ def to_feature_config(self) -> str: } {% endif %} } - """) + """ + ) source = copy.copy(self) source.name = self.name.upper() msg = tm.render(source=source) return msg def __str__(self): - return str(self.preprocessing) + '\n' + self.to_feature_config() + return str(self.preprocessing) + "\n" + self.to_feature_config() def to_argument(self): d = { @@ -324,6 +367,7 @@ def to_argument(self): d["anonymous"] = True return json.dumps(d) + class KafkaConfig: """Kafka config for a streaming source @@ -331,7 +375,7 @@ class KafkaConfig: brokers: broker/server address topics: Kafka topics schema: Kafka message schema - """ + """ def __init__(self, brokers: List[str], topics: List[str], schema: SourceSchema): self.brokers = brokers @@ -342,12 +386,13 @@ def __init__(self, brokers: List[str], topics: List[str], schema: SourceSchema): class KafKaSource(Source): """A kafka source object. Used in streaming feature ingestion.""" - def __init__(self, name: str, kafkaConfig: KafkaConfig, registry_tags: Optional[Dict[str, str]] = None): + def __init__(self, name: str, kafkaConfig: KafkaConfig, registry_tags: Optional[Dict[str, str]] = None): super().__init__(name, registry_tags=registry_tags) self.config = kafkaConfig def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { type: KAFKA config: { @@ -356,18 +401,29 @@ def to_feature_config(self) -> str: {{source.config.schema.to_feature_config()}} } } - """) - brokers = '"'+'","'.join(self.config.brokers)+'"' - topics = ','.join(self.config.topics) + """ + ) + brokers = '"' + '","'.join(self.config.brokers) + '"' + topics = ",".join(self.config.topics) msg = tm.render(source=self, brokers=brokers, topics=topics) return msg def to_argument(self): raise TypeError("KafKaSource cannot be used as observation source") + class SparkSqlSource(Source): - def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None: - """ SparkSqlSource can use either a sql query or a table name as the source for Feathr job. + def __init__( + self, + name: str, + sql: Optional[str] = None, + table: Optional[str] = None, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ) -> None: + """SparkSqlSource can use either a sql query or a table name as the source for Feathr job. name: name of the source sql: sql query to use as the source, either sql or table must be specified table: table name to use as the source, either sql or table must be specified @@ -379,9 +435,8 @@ def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = - Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html). registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc. 
""" - super().__init__(name, event_timestamp_column, - timestamp_format, registry_tags=registry_tags) - self.source_type = 'sparksql' + super().__init__(name, event_timestamp_column, timestamp_format, registry_tags=registry_tags) + self.source_type = "sparksql" if sql is None and table is None: raise ValueError("Either `sql` or `table` must be specified") if sql is not None and table is not None: @@ -393,7 +448,8 @@ def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = self.preprocessing = preprocessing def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { location: { type: "sparksql" @@ -410,7 +466,8 @@ def to_feature_config(self) -> str: } {% endif %} } - """) + """ + ) msg = tm.render(source=self) return msg @@ -424,24 +481,34 @@ def to_dict(self) -> Dict[str, str]: ret["sql"] = self.sql elif hasattr(self, "table"): ret["table"] = self.table - return ret - + return ret + def to_argument(self): """ One-line JSON string, used by job submitter """ return json.dumps(self.to_dict()) - + class GenericSource(Source): """ This class is corresponding to 'GenericLocation' in Feathr core, but only be used as Source. The class is not meant to be used by user directly, user should use its subclasses like `CosmosDbSource` """ - def __init__(self, name: str, format: str, mode: Optional[str] = None, options: Dict[str, str] = {}, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None: - super().__init__(name, event_timestamp_column, - timestamp_format, registry_tags=registry_tags) - self.source_type = 'generic' + + def __init__( + self, + name: str, + format: str, + mode: Optional[str] = None, + options: Dict[str, str] = {}, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ) -> None: + super().__init__(name, event_timestamp_column, timestamp_format, registry_tags=registry_tags) + self.source_type = "generic" self.format = format self.mode = mode self.preprocessing = preprocessing @@ -450,7 +517,8 @@ def __init__(self, name: str, format: str, mode: Optional[str] = None, options: self.options = dict([(key.replace(".", "__"), options[key]) for key in options]) def to_feature_config(self) -> str: - tm = Template(""" + tm = Template( + """ {{source.name}}: { location: { type: "generic" @@ -469,7 +537,8 @@ def to_feature_config(self) -> str: } {% endif %} } - """) + """ + ) msg = tm.render(source=self) return msg @@ -480,7 +549,7 @@ def get_required_properties(self): if start >= 0: end = option[start:].find("}") if end >= 0: - ret.append(option[start+2:start+end]) + ret.append(option[start + 2 : start + end]) return ret def to_dict(self) -> Dict[str, str]: @@ -489,8 +558,8 @@ def to_dict(self) -> Dict[str, str]: ret["format"] = self.format if self.mode: ret["mode"] = self.mode - return ret - + return ret + def to_argument(self): """ One-line JSON string, used by job submitter @@ -502,43 +571,53 @@ class CosmosDbSource(GenericSource): """ Use CosmosDb as the data source """ - def __init__(self, - name: str, - endpoint: str, - database: str, - container: str, - preprocessing: Optional[Callable] = None, - event_timestamp_column: Optional[str] = None, - timestamp_format: Optional[str] = "epoch", - registry_tags: Optional[Dict[str, str]] = None): + + def __init__( + self, + 
name: str, + endpoint: str, + database: str, + container: str, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ): options = { - 'spark.cosmos.accountEndpoint': endpoint, - 'spark.cosmos.accountKey': "${%s_KEY}" % name.upper(), - 'spark.cosmos.database': database, - 'spark.cosmos.container': container + "spark.cosmos.accountEndpoint": endpoint, + "spark.cosmos.accountKey": "${%s_KEY}" % name.upper(), + "spark.cosmos.database": database, + "spark.cosmos.container": container, } - super().__init__(name, - format='cosmos.oltp', - mode="APPEND", - options=options, - preprocessing=preprocessing, - event_timestamp_column=event_timestamp_column, timestamp_format=timestamp_format, - registry_tags=registry_tags) + super().__init__( + name, + format="cosmos.oltp", + mode="APPEND", + options=options, + preprocessing=preprocessing, + event_timestamp_column=event_timestamp_column, + timestamp_format=timestamp_format, + registry_tags=registry_tags, + ) + class ElasticSearchSource(GenericSource): """ Use ElasticSearch as the data source """ - def __init__(self, - name: str, - host: str, - index: str, - ssl: bool = True, - auth: bool = True, - preprocessing: Optional[Callable] = None, - event_timestamp_column: Optional[str] = None, - timestamp_format: Optional[str] = "epoch", - registry_tags: Optional[Dict[str, str]] = None): + + def __init__( + self, + name: str, + host: str, + index: str, + ssl: bool = True, + auth: bool = True, + preprocessing: Optional[Callable] = None, + event_timestamp_column: Optional[str] = None, + timestamp_format: Optional[str] = "epoch", + registry_tags: Optional[Dict[str, str]] = None, + ): """ name: The name of the sink. host: ElasticSearch node, can be `hostname` or `hostname:port`, default port is 9200. 
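# A minimal usage sketch for the ElasticSearchSource constructor shown above, assuming the
# module path feathr.definition.source; the host, index, and name values are hypothetical.
# With auth=True, the generated connector options carry "${CLICKLOGS_USER}" /
# "${CLICKLOGS_PASSWORD}" style placeholders, intended to be resolved from secrets at
# job-submission time rather than hard-coded credentials.
from feathr.definition.source import ElasticSearchSource

click_source = ElasticSearchSource(
    name="clicklogs",
    host="es.internal.example.com:9200",
    index="click-events",
    ssl=True,
    auth=True,
    event_timestamp_column="event_time",
    timestamp_format="epoch",
)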
@@ -548,19 +627,23 @@ def __init__(self, preprocessing/event_timestamp_column/timestamp_format/registry_tags: See `HdfsSource` """ options = { - 'es.nodes': host, - 'es.ssl': str(ssl).lower(), - 'es.resource': index, + "es.nodes": host, + "es.ssl": str(ssl).lower(), + "es.resource": index, } if auth: - options["es.net.http.auth.user"] = "${%s_USER}" % name.upper(), - options["es.net.http.auth.pass"] = "${%s_PASSWORD}" % name.upper(), - super().__init__(name, - format='org.elasticsearch.spark.sql', - mode="APPEND", - options=options, - preprocessing=preprocessing, - event_timestamp_column=event_timestamp_column, timestamp_format=timestamp_format, - registry_tags=registry_tags) + options["es.net.http.auth.user"] = ("${%s_USER}" % name.upper(),) + options["es.net.http.auth.pass"] = ("${%s_PASSWORD}" % name.upper(),) + super().__init__( + name, + format="org.elasticsearch.spark.sql", + mode="APPEND", + options=options, + preprocessing=preprocessing, + event_timestamp_column=event_timestamp_column, + timestamp_format=timestamp_format, + registry_tags=registry_tags, + ) + INPUT_CONTEXT = InputContext() diff --git a/feathr_project/feathr/definition/transformation.py b/feathr_project/feathr/definition/transformation.py index 1aa6864be..f02cc67d4 100644 --- a/feathr_project/feathr/definition/transformation.py +++ b/feathr_project/feathr/definition/transformation.py @@ -4,13 +4,16 @@ from jinja2 import Template from feathr.definition.feathrconfig import HoconConvertible + class Transformation(HoconConvertible): """Base class for all transformations that produce feature values.""" + pass class RowTransformation(Transformation): """Base class for all row-level transformations.""" + pass @@ -20,18 +23,21 @@ class ExpressionTransformation(RowTransformation): Attributes: expr: expression that transforms the raw value into a new value, e.g. amount * 10. """ + def __init__(self, expr: str) -> None: super().__init__() self.expr = expr def to_feature_config(self, with_def_field_name: Optional[bool] = True) -> str: - tm = Template(""" + tm = Template( + """ {% if with_def_field_name %} def.sqlExpr: "{{expr}}" {% else %} "{{expr}}" {% endif %} - """) + """ + ) return tm.render(expr=self.expr, with_def_field_name=with_def_field_name) @@ -45,7 +51,16 @@ class WindowAggTransformation(Transformation): group_by: Feathr expressions applied after the `agg_expr` transformation as groupby field, before aggregation, same as 'group by' in SQL filter: Feathr expression applied to each row as a filter before aggregation. This should be a string and a valid Spark SQL Expression. For example: filter = 'age > 3'. 
This is similar to PySpark filter operation and more details can be learned here: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.filter.html """ - def __init__(self, agg_expr: str, agg_func: str, window: str, group_by: Optional[str] = None, filter: Optional[str] = None, limit: Optional[int] = None) -> None: + + def __init__( + self, + agg_expr: str, + agg_func: str, + window: str, + group_by: Optional[str] = None, + filter: Optional[str] = None, + limit: Optional[int] = None, + ) -> None: super().__init__() self.def_expr = agg_expr self.agg_func = agg_func @@ -55,7 +70,8 @@ def __init__(self, agg_expr: str, agg_func: str, window: str, group_by: Optional self.limit = limit def to_feature_config(self, with_def_field_name: Optional[bool] = True) -> str: - tm = Template(""" + tm = Template( + """ def:"{{windowAgg.def_expr}}" window: {{windowAgg.window}} aggregation: {{windowAgg.agg_func}} @@ -68,8 +84,9 @@ def to_feature_config(self, with_def_field_name: Optional[bool] = True) -> str: {% if windowAgg.limit is not none %} limit: {{windowAgg.limit}} {% endif %} - """) - return tm.render(windowAgg = self) + """ + ) + return tm.render(windowAgg=self) class UdfTransform(Transformation): @@ -78,6 +95,7 @@ class UdfTransform(Transformation): Attributes: name: name of the user defined function """ + def __init__(self, name: str) -> None: """ @@ -85,5 +103,6 @@ def __init__(self, name: str) -> None: """ super().__init__() self.name = name + def to_feature_config(self) -> str: - pass \ No newline at end of file + pass diff --git a/feathr_project/feathr/definition/typed_key.py b/feathr_project/feathr/definition/typed_key.py index c2732a476..c647501e5 100644 --- a/feathr_project/feathr/definition/typed_key.py +++ b/feathr_project/feathr/definition/typed_key.py @@ -7,23 +7,26 @@ class TypedKey: """The key of a feature. A feature is typically keyed by some id(s). e.g. product id, user id - Attributes: - key_column: The id column name of this key. e.g. 'product_id'. - key_column_type: Types of the key_column - full_name: Unique name of the key. Recommend using [project_name].[key_name], e.g. ads.user_id - description: Documentation for the key. - key_column_alias: Used in some advanced derived features. Default to the key_column. + Attributes: + key_column: The id column name of this key. e.g. 'product_id'. + key_column_type: Types of the key_column + full_name: Unique name of the key. Recommend using [project_name].[key_name], e.g. ads.user_id + description: Documentation for the key. + key_column_alias: Used in some advanced derived features. Default to the key_column. """ - def __init__(self, - key_column: str, - key_column_type: ValueType, - full_name: Optional[str] = None, - description: Optional[str] = None, - key_column_alias: Optional[str] = None) -> None: + + def __init__( + self, + key_column: str, + key_column_type: ValueType, + full_name: Optional[str] = None, + description: Optional[str] = None, + key_column_alias: Optional[str] = None, + ) -> None: # Validate the key_column type if not isinstance(key_column_type, ValueType): - raise KeyError(f'key_column_type must be a ValueType, like Value.INT32, but got {key_column_type}') - + raise KeyError(f"key_column_type must be a ValueType, like Value.INT32, but got {key_column_type}") + self.key_column = key_column self.key_column_type = key_column_type self.full_name = full_name @@ -32,7 +35,7 @@ def __init__(self, def as_key(self, key_column_alias: str) -> TypedKey: """Rename the key alias. 
This is useful in derived features that depends on the same feature - with different keys. + with different keys. """ new_key = deepcopy(self) new_key.key_column_alias = key_column_alias @@ -42,7 +45,9 @@ def as_key(self, key_column_alias: str) -> TypedKey: # passthrough/request feature do not need keys, as they are just a transformation defined on top of the request data # They do not necessarily describe the value of keyed entity, e.g. dayofweek(timestamp) is transform on a request # field without key -DUMMY_KEY = TypedKey(key_column="NOT_NEEDED", - key_column_type=ValueType.UNSPECIFIED, - full_name="feathr.dummy_typedkey", - description="A dummy typed key for passthrough/request feature.") \ No newline at end of file +DUMMY_KEY = TypedKey( + key_column="NOT_NEEDED", + key_column_type=ValueType.UNSPECIFIED, + full_name="feathr.dummy_typedkey", + description="A dummy typed key for passthrough/request feature.", +) diff --git a/feathr_project/feathr/protobuf/featureValue_pb2.py b/feathr_project/feathr/protobuf/featureValue_pb2.py index b5aa643b0..7b3f24a29 100644 --- a/feathr_project/feathr/protobuf/featureValue_pb2.py +++ b/feathr_project/feathr/protobuf/featureValue_pb2.py @@ -7,159 +7,214 @@ from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x66\x65\x61tureValue.proto\x12\x08protobuf\"\xce\x06\n\x0c\x46\x65\x61tureValue\x12\x17\n\rboolean_value\x18\x01 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x02 \x01(\tH\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\x04 \x01(\x01H\x00\x12\x13\n\tint_value\x18\x05 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x06 \x01(\x03H\x00\x12/\n\rboolean_array\x18\n \x01(\x0b\x32\x16.protobuf.BooleanArrayH\x00\x12-\n\x0cstring_array\x18\x0b \x01(\x0b\x32\x15.protobuf.StringArrayH\x00\x12+\n\x0b\x66loat_array\x18\x0c \x01(\x0b\x32\x14.protobuf.FloatArrayH\x00\x12-\n\x0c\x64ouble_array\x18\r \x01(\x0b\x32\x15.protobuf.DoubleArrayH\x00\x12+\n\tint_array\x18\x0e \x01(\x0b\x32\x16.protobuf.IntegerArrayH\x00\x12)\n\nlong_array\x18\x0f \x01(\x0b\x32\x13.protobuf.LongArrayH\x00\x12*\n\nbyte_array\x18\x10 \x01(\x0b\x32\x14.protobuf.BytesArrayH\x00\x12:\n\x13sparse_string_array\x18\x14 \x01(\x0b\x32\x1b.protobuf.SparseStringArrayH\x00\x12\x36\n\x11sparse_bool_array\x18\x15 \x01(\x0b\x32\x19.protobuf.SparseBoolArrayH\x00\x12<\n\x14sparse_integer_array\x18\x16 \x01(\x0b\x32\x1c.protobuf.SparseIntegerArrayH\x00\x12\x36\n\x11sparse_long_array\x18\x17 \x01(\x0b\x32\x19.protobuf.SparseLongArrayH\x00\x12:\n\x13sparse_double_array\x18\x18 \x01(\x0b\x32\x1b.protobuf.SparseDoubleArrayH\x00\x12\x38\n\x12sparse_float_array\x18\x19 \x01(\x0b\x32\x1a.protobuf.SparseFloatArrayH\x00\x42\x13\n\x11\x46\x65\x61tureValueOneOf\" \n\x0c\x42ooleanArray\x12\x10\n\x08\x62ooleans\x18\x01 \x03(\x08\"\x1e\n\x0bStringArray\x12\x0f\n\x07strings\x18\x01 \x03(\t\"\x1e\n\x0b\x44oubleArray\x12\x0f\n\x07\x64oubles\x18\x01 \x03(\x01\"\x1c\n\nFloatArray\x12\x0e\n\x06\x66loats\x18\x01 \x03(\x02\" \n\x0cIntegerArray\x12\x10\n\x08integers\x18\x01 \x03(\x05\"\x1a\n\tLongArray\x12\r\n\x05longs\x18\x01 \x03(\x03\"\x1b\n\nBytesArray\x12\r\n\x05\x62ytes\x18\x01 \x03(\x0c\"B\n\x11SparseStringArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x15\n\rvalue_strings\x18\x02 
\x03(\t\"A\n\x0fSparseBoolArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x16\n\x0evalue_booleans\x18\x02 \x03(\x08\"D\n\x12SparseIntegerArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x16\n\x0evalue_integers\x18\x02 \x03(\x05\">\n\x0fSparseLongArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x13\n\x0bvalue_longs\x18\x02 \x03(\x03\"B\n\x11SparseDoubleArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x15\n\rvalue_doubles\x18\x02 \x03(\x01\"@\n\x10SparseFloatArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x14\n\x0cvalue_floats\x18\x02 \x03(\x02\x42+\n)com.linkedin.feathr.common.types.protobufb\x06proto3') - - - -_FEATUREVALUE = DESCRIPTOR.message_types_by_name['FeatureValue'] -_BOOLEANARRAY = DESCRIPTOR.message_types_by_name['BooleanArray'] -_STRINGARRAY = DESCRIPTOR.message_types_by_name['StringArray'] -_DOUBLEARRAY = DESCRIPTOR.message_types_by_name['DoubleArray'] -_FLOATARRAY = DESCRIPTOR.message_types_by_name['FloatArray'] -_INTEGERARRAY = DESCRIPTOR.message_types_by_name['IntegerArray'] -_LONGARRAY = DESCRIPTOR.message_types_by_name['LongArray'] -_BYTESARRAY = DESCRIPTOR.message_types_by_name['BytesArray'] -_SPARSESTRINGARRAY = DESCRIPTOR.message_types_by_name['SparseStringArray'] -_SPARSEBOOLARRAY = DESCRIPTOR.message_types_by_name['SparseBoolArray'] -_SPARSEINTEGERARRAY = DESCRIPTOR.message_types_by_name['SparseIntegerArray'] -_SPARSELONGARRAY = DESCRIPTOR.message_types_by_name['SparseLongArray'] -_SPARSEDOUBLEARRAY = DESCRIPTOR.message_types_by_name['SparseDoubleArray'] -_SPARSEFLOATARRAY = DESCRIPTOR.message_types_by_name['SparseFloatArray'] -FeatureValue = _reflection.GeneratedProtocolMessageType('FeatureValue', (_message.Message,), { - 'DESCRIPTOR' : _FEATUREVALUE, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.FeatureValue) -}) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x12\x66\x65\x61tureValue.proto\x12\x08protobuf"\xce\x06\n\x0c\x46\x65\x61tureValue\x12\x17\n\rboolean_value\x18\x01 \x01(\x08H\x00\x12\x16\n\x0cstring_value\x18\x02 \x01(\tH\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x12\x16\n\x0c\x64ouble_value\x18\x04 \x01(\x01H\x00\x12\x13\n\tint_value\x18\x05 \x01(\x05H\x00\x12\x14\n\nlong_value\x18\x06 \x01(\x03H\x00\x12/\n\rboolean_array\x18\n \x01(\x0b\x32\x16.protobuf.BooleanArrayH\x00\x12-\n\x0cstring_array\x18\x0b \x01(\x0b\x32\x15.protobuf.StringArrayH\x00\x12+\n\x0b\x66loat_array\x18\x0c \x01(\x0b\x32\x14.protobuf.FloatArrayH\x00\x12-\n\x0c\x64ouble_array\x18\r \x01(\x0b\x32\x15.protobuf.DoubleArrayH\x00\x12+\n\tint_array\x18\x0e \x01(\x0b\x32\x16.protobuf.IntegerArrayH\x00\x12)\n\nlong_array\x18\x0f \x01(\x0b\x32\x13.protobuf.LongArrayH\x00\x12*\n\nbyte_array\x18\x10 \x01(\x0b\x32\x14.protobuf.BytesArrayH\x00\x12:\n\x13sparse_string_array\x18\x14 \x01(\x0b\x32\x1b.protobuf.SparseStringArrayH\x00\x12\x36\n\x11sparse_bool_array\x18\x15 \x01(\x0b\x32\x19.protobuf.SparseBoolArrayH\x00\x12<\n\x14sparse_integer_array\x18\x16 \x01(\x0b\x32\x1c.protobuf.SparseIntegerArrayH\x00\x12\x36\n\x11sparse_long_array\x18\x17 \x01(\x0b\x32\x19.protobuf.SparseLongArrayH\x00\x12:\n\x13sparse_double_array\x18\x18 \x01(\x0b\x32\x1b.protobuf.SparseDoubleArrayH\x00\x12\x38\n\x12sparse_float_array\x18\x19 \x01(\x0b\x32\x1a.protobuf.SparseFloatArrayH\x00\x42\x13\n\x11\x46\x65\x61tureValueOneOf" \n\x0c\x42ooleanArray\x12\x10\n\x08\x62ooleans\x18\x01 \x03(\x08"\x1e\n\x0bStringArray\x12\x0f\n\x07strings\x18\x01 \x03(\t"\x1e\n\x0b\x44oubleArray\x12\x0f\n\x07\x64oubles\x18\x01 
\x03(\x01"\x1c\n\nFloatArray\x12\x0e\n\x06\x66loats\x18\x01 \x03(\x02" \n\x0cIntegerArray\x12\x10\n\x08integers\x18\x01 \x03(\x05"\x1a\n\tLongArray\x12\r\n\x05longs\x18\x01 \x03(\x03"\x1b\n\nBytesArray\x12\r\n\x05\x62ytes\x18\x01 \x03(\x0c"B\n\x11SparseStringArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x15\n\rvalue_strings\x18\x02 \x03(\t"A\n\x0fSparseBoolArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x16\n\x0evalue_booleans\x18\x02 \x03(\x08"D\n\x12SparseIntegerArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x16\n\x0evalue_integers\x18\x02 \x03(\x05">\n\x0fSparseLongArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x13\n\x0bvalue_longs\x18\x02 \x03(\x03"B\n\x11SparseDoubleArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x15\n\rvalue_doubles\x18\x02 \x03(\x01"@\n\x10SparseFloatArray\x12\x16\n\x0eindex_integers\x18\x01 \x03(\x05\x12\x14\n\x0cvalue_floats\x18\x02 \x03(\x02\x42+\n)com.linkedin.feathr.common.types.protobufb\x06proto3' +) + + +_FEATUREVALUE = DESCRIPTOR.message_types_by_name["FeatureValue"] +_BOOLEANARRAY = DESCRIPTOR.message_types_by_name["BooleanArray"] +_STRINGARRAY = DESCRIPTOR.message_types_by_name["StringArray"] +_DOUBLEARRAY = DESCRIPTOR.message_types_by_name["DoubleArray"] +_FLOATARRAY = DESCRIPTOR.message_types_by_name["FloatArray"] +_INTEGERARRAY = DESCRIPTOR.message_types_by_name["IntegerArray"] +_LONGARRAY = DESCRIPTOR.message_types_by_name["LongArray"] +_BYTESARRAY = DESCRIPTOR.message_types_by_name["BytesArray"] +_SPARSESTRINGARRAY = DESCRIPTOR.message_types_by_name["SparseStringArray"] +_SPARSEBOOLARRAY = DESCRIPTOR.message_types_by_name["SparseBoolArray"] +_SPARSEINTEGERARRAY = DESCRIPTOR.message_types_by_name["SparseIntegerArray"] +_SPARSELONGARRAY = DESCRIPTOR.message_types_by_name["SparseLongArray"] +_SPARSEDOUBLEARRAY = DESCRIPTOR.message_types_by_name["SparseDoubleArray"] +_SPARSEFLOATARRAY = DESCRIPTOR.message_types_by_name["SparseFloatArray"] +FeatureValue = _reflection.GeneratedProtocolMessageType( + "FeatureValue", + (_message.Message,), + { + "DESCRIPTOR": _FEATUREVALUE, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.FeatureValue) + }, +) _sym_db.RegisterMessage(FeatureValue) -BooleanArray = _reflection.GeneratedProtocolMessageType('BooleanArray', (_message.Message,), { - 'DESCRIPTOR' : _BOOLEANARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.BooleanArray) -}) +BooleanArray = _reflection.GeneratedProtocolMessageType( + "BooleanArray", + (_message.Message,), + { + "DESCRIPTOR": _BOOLEANARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.BooleanArray) + }, +) _sym_db.RegisterMessage(BooleanArray) -StringArray = _reflection.GeneratedProtocolMessageType('StringArray', (_message.Message,), { - 'DESCRIPTOR' : _STRINGARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.StringArray) -}) +StringArray = _reflection.GeneratedProtocolMessageType( + "StringArray", + (_message.Message,), + { + "DESCRIPTOR": _STRINGARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.StringArray) + }, +) _sym_db.RegisterMessage(StringArray) -DoubleArray = _reflection.GeneratedProtocolMessageType('DoubleArray', (_message.Message,), { - 'DESCRIPTOR' : _DOUBLEARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.DoubleArray) -}) +DoubleArray = _reflection.GeneratedProtocolMessageType( + "DoubleArray", + 
(_message.Message,), + { + "DESCRIPTOR": _DOUBLEARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.DoubleArray) + }, +) _sym_db.RegisterMessage(DoubleArray) -FloatArray = _reflection.GeneratedProtocolMessageType('FloatArray', (_message.Message,), { - 'DESCRIPTOR' : _FLOATARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.FloatArray) -}) +FloatArray = _reflection.GeneratedProtocolMessageType( + "FloatArray", + (_message.Message,), + { + "DESCRIPTOR": _FLOATARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.FloatArray) + }, +) _sym_db.RegisterMessage(FloatArray) -IntegerArray = _reflection.GeneratedProtocolMessageType('IntegerArray', (_message.Message,), { - 'DESCRIPTOR' : _INTEGERARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.IntegerArray) -}) +IntegerArray = _reflection.GeneratedProtocolMessageType( + "IntegerArray", + (_message.Message,), + { + "DESCRIPTOR": _INTEGERARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.IntegerArray) + }, +) _sym_db.RegisterMessage(IntegerArray) -LongArray = _reflection.GeneratedProtocolMessageType('LongArray', (_message.Message,), { - 'DESCRIPTOR' : _LONGARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.LongArray) -}) +LongArray = _reflection.GeneratedProtocolMessageType( + "LongArray", + (_message.Message,), + { + "DESCRIPTOR": _LONGARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.LongArray) + }, +) _sym_db.RegisterMessage(LongArray) -BytesArray = _reflection.GeneratedProtocolMessageType('BytesArray', (_message.Message,), { - 'DESCRIPTOR' : _BYTESARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.BytesArray) -}) +BytesArray = _reflection.GeneratedProtocolMessageType( + "BytesArray", + (_message.Message,), + { + "DESCRIPTOR": _BYTESARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.BytesArray) + }, +) _sym_db.RegisterMessage(BytesArray) -SparseStringArray = _reflection.GeneratedProtocolMessageType('SparseStringArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSESTRINGARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseStringArray) -}) +SparseStringArray = _reflection.GeneratedProtocolMessageType( + "SparseStringArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSESTRINGARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseStringArray) + }, +) _sym_db.RegisterMessage(SparseStringArray) -SparseBoolArray = _reflection.GeneratedProtocolMessageType('SparseBoolArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEBOOLARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseBoolArray) -}) +SparseBoolArray = _reflection.GeneratedProtocolMessageType( + "SparseBoolArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSEBOOLARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseBoolArray) + }, +) _sym_db.RegisterMessage(SparseBoolArray) -SparseIntegerArray = _reflection.GeneratedProtocolMessageType('SparseIntegerArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEINTEGERARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseIntegerArray) -}) +SparseIntegerArray = 
_reflection.GeneratedProtocolMessageType( + "SparseIntegerArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSEINTEGERARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseIntegerArray) + }, +) _sym_db.RegisterMessage(SparseIntegerArray) -SparseLongArray = _reflection.GeneratedProtocolMessageType('SparseLongArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSELONGARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseLongArray) -}) +SparseLongArray = _reflection.GeneratedProtocolMessageType( + "SparseLongArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSELONGARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseLongArray) + }, +) _sym_db.RegisterMessage(SparseLongArray) -SparseDoubleArray = _reflection.GeneratedProtocolMessageType('SparseDoubleArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEDOUBLEARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseDoubleArray) -}) +SparseDoubleArray = _reflection.GeneratedProtocolMessageType( + "SparseDoubleArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSEDOUBLEARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseDoubleArray) + }, +) _sym_db.RegisterMessage(SparseDoubleArray) -SparseFloatArray = _reflection.GeneratedProtocolMessageType('SparseFloatArray', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEFLOATARRAY, - '__module__' : 'featureValue_pb2' - # @@protoc_insertion_point(class_scope:protobuf.SparseFloatArray) -}) +SparseFloatArray = _reflection.GeneratedProtocolMessageType( + "SparseFloatArray", + (_message.Message,), + { + "DESCRIPTOR": _SPARSEFLOATARRAY, + "__module__": "featureValue_pb2" + # @@protoc_insertion_point(class_scope:protobuf.SparseFloatArray) + }, +) _sym_db.RegisterMessage(SparseFloatArray) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n)com.linkedin.feathr.common.types.protobuf' - _FEATUREVALUE._serialized_start=33 - _FEATUREVALUE._serialized_end=879 - _BOOLEANARRAY._serialized_start=881 - _BOOLEANARRAY._serialized_end=913 - _STRINGARRAY._serialized_start=915 - _STRINGARRAY._serialized_end=945 - _DOUBLEARRAY._serialized_start=947 - _DOUBLEARRAY._serialized_end=977 - _FLOATARRAY._serialized_start=979 - _FLOATARRAY._serialized_end=1007 - _INTEGERARRAY._serialized_start=1009 - _INTEGERARRAY._serialized_end=1041 - _LONGARRAY._serialized_start=1043 - _LONGARRAY._serialized_end=1069 - _BYTESARRAY._serialized_start=1071 - _BYTESARRAY._serialized_end=1098 - _SPARSESTRINGARRAY._serialized_start=1100 - _SPARSESTRINGARRAY._serialized_end=1166 - _SPARSEBOOLARRAY._serialized_start=1168 - _SPARSEBOOLARRAY._serialized_end=1233 - _SPARSEINTEGERARRAY._serialized_start=1235 - _SPARSEINTEGERARRAY._serialized_end=1303 - _SPARSELONGARRAY._serialized_start=1305 - _SPARSELONGARRAY._serialized_end=1367 - _SPARSEDOUBLEARRAY._serialized_start=1369 - _SPARSEDOUBLEARRAY._serialized_end=1435 - _SPARSEFLOATARRAY._serialized_start=1437 - _SPARSEFLOATARRAY._serialized_end=1501 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n)com.linkedin.feathr.common.types.protobuf" + _FEATUREVALUE._serialized_start = 33 + _FEATUREVALUE._serialized_end = 879 + _BOOLEANARRAY._serialized_start = 881 + _BOOLEANARRAY._serialized_end = 913 + _STRINGARRAY._serialized_start = 915 + _STRINGARRAY._serialized_end = 945 + _DOUBLEARRAY._serialized_start = 
947 + _DOUBLEARRAY._serialized_end = 977 + _FLOATARRAY._serialized_start = 979 + _FLOATARRAY._serialized_end = 1007 + _INTEGERARRAY._serialized_start = 1009 + _INTEGERARRAY._serialized_end = 1041 + _LONGARRAY._serialized_start = 1043 + _LONGARRAY._serialized_end = 1069 + _BYTESARRAY._serialized_start = 1071 + _BYTESARRAY._serialized_end = 1098 + _SPARSESTRINGARRAY._serialized_start = 1100 + _SPARSESTRINGARRAY._serialized_end = 1166 + _SPARSEBOOLARRAY._serialized_start = 1168 + _SPARSEBOOLARRAY._serialized_end = 1233 + _SPARSEINTEGERARRAY._serialized_start = 1235 + _SPARSEINTEGERARRAY._serialized_end = 1303 + _SPARSELONGARRAY._serialized_start = 1305 + _SPARSELONGARRAY._serialized_end = 1367 + _SPARSEDOUBLEARRAY._serialized_start = 1369 + _SPARSEDOUBLEARRAY._serialized_end = 1435 + _SPARSEFLOATARRAY._serialized_start = 1437 + _SPARSEFLOATARRAY._serialized_end = 1501 # @@protoc_insertion_point(module_scope) diff --git a/feathr_project/feathr/registry/_feathr_registry_client.py b/feathr_project/feathr/registry/_feathr_registry_client.py index a70cd5e20..378a81b8a 100644 --- a/feathr_project/feathr/registry/_feathr_registry_client.py +++ b/feathr_project/feathr/registry/_feathr_registry_client.py @@ -14,32 +14,72 @@ from jinja2 import Template import requests -from feathr.constants import INPUT_CONTEXT, TYPEDEF_ANCHOR, TYPEDEF_ANCHOR_FEATURE, TYPEDEF_DERIVED_FEATURE, TYPEDEF_SOURCE +from feathr.constants import ( + INPUT_CONTEXT, + TYPEDEF_ANCHOR, + TYPEDEF_ANCHOR_FEATURE, + TYPEDEF_DERIVED_FEATURE, + TYPEDEF_SOURCE, +) from feathr.definition.anchor import FeatureAnchor from feathr.definition.dtype import FeatureType, str_to_value_type, value_type_to_str from feathr.definition.feature import Feature, FeatureBase from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.repo_definitions import RepoDefinitions -from feathr.definition.source import GenericSource, HdfsSource, InputContext, JdbcSource, SnowflakeSource, Source, SparkSqlSource, KafKaSource, KafkaConfig, AvroJsonSchema +from feathr.definition.source import ( + GenericSource, + HdfsSource, + InputContext, + JdbcSource, + SnowflakeSource, + Source, + SparkSqlSource, + KafKaSource, + KafkaConfig, + AvroJsonSchema, +) from feathr.definition.transformation import ExpressionTransformation, Transformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey from feathr.registry.feature_registry import FeathrRegistry from feathr.utils._file_utils import write_to_file -from feathr.registry.registry_utils import topological_sort, to_camel,source_to_def, anchor_to_def, transformation_to_def, feature_type_to_def, typed_key_to_def, feature_to_def, derived_feature_to_def, _correct_function_indentation +from feathr.registry.registry_utils import ( + topological_sort, + to_camel, + source_to_def, + anchor_to_def, + transformation_to_def, + feature_type_to_def, + typed_key_to_def, + feature_to_def, + derived_feature_to_def, + _correct_function_indentation, +) + class _FeatureRegistry(FeathrRegistry): - def __init__(self, project_name: str, endpoint: str, project_tags: Dict[str, str] = None, credential=None, config_path=None): + def __init__( + self, project_name: str, endpoint: str, project_tags: Dict[str, str] = None, credential=None, config_path=None + ): self.project_name = project_name self.project_tags = project_tags self.endpoint = endpoint # TODO: expand to more credential provider # If FEATHR_SANDBOX is set in the environment variable, don't do auth - self.credential = DefaultAzureCredential( 
- exclude_interactive_browser_credential=False) if credential is None and not os.environ.get("FEATHR_SANDBOX") else credential + self.credential = ( + DefaultAzureCredential(exclude_interactive_browser_credential=False) + if credential is None and not os.environ.get("FEATHR_SANDBOX") + else credential + ) self.project_id = None - def register_features(self, workspace_path: Optional[Path] = None, from_context: bool = True, anchor_list: List[FeatureAnchor]=[], derived_feature_list=[]): - """Register Features for the specified workspace. + def register_features( + self, + workspace_path: Optional[Path] = None, + from_context: bool = True, + anchor_list: List[FeatureAnchor] = [], + derived_feature_list=[], + ): + """Register Features for the specified workspace. Args: workspace_path (str, optional): path to a workspace. Defaults to None, not used in this implementation. from_context: whether the feature is from context (i.e. end users has to callFeathrClient.build_features()) or the feature is from a pre-built config file. Currently Feathr only supports register features from context. @@ -48,7 +88,8 @@ def register_features(self, workspace_path: Optional[Path] = None, from_context: """ if not from_context: raise RuntimeError( - "Currently Feathr only supports registering features from context (i.e. you must call FeathrClient.build_features() before calling this function).") + "Currently Feathr only supports registering features from context (i.e. you must call FeathrClient.build_features() before calling this function)." + ) # Before starting, create the project self.project_id = self._create_project() @@ -65,13 +106,12 @@ def register_features(self, workspace_path: Optional[Path] = None, from_context: # 3. Create all features on the registry for feature in anchor.features: if not hasattr(feature, "_registry_id"): - feature._registry_id = self._create_anchor_feature( - anchor._registry_id, feature) + feature._registry_id = self._create_anchor_feature(anchor._registry_id, feature) # 4. 
Create all derived features on the registry for df in topological_sort(derived_feature_list): if not hasattr(df, "_registry_id"): df._registry_id = self._create_derived_feature(df) - url = '/'.join(self.endpoint.split('/')[:3]) + url = "/".join(self.endpoint.split("/")[:3]) logging.info(f"Check project lineage by this link: {url}/projects/{self.project_name}/lineage") def list_registered_features(self, project_name: str) -> List[str]: @@ -80,23 +120,29 @@ def list_registered_features(self, project_name: str) -> List[str]: """ resp = self._get(f"/projects/{project_name}/features") # In V1 API resp should be an array, will be changed in V2 API - return [{ - "name": r["attributes"]["name"], - "id": r["guid"], - "qualifiedName": r["attributes"]["qualifiedName"], - } for r in resp] - + return [ + { + "name": r["attributes"]["name"], + "id": r["guid"], + "qualifiedName": r["attributes"]["qualifiedName"], + } + for r in resp + ] + def list_dependent_entities(self, qualified_name: str): """ Returns list of dependent entities for provided entity """ resp = self._get(f"/dependent/{qualified_name}") - return [{ - "name": r["attributes"]["name"], - "id": r["guid"], - "qualifiedName": r["attributes"]["qualifiedName"], - } for r in resp] - + return [ + { + "name": r["attributes"]["name"], + "id": r["guid"], + "qualifiedName": r["attributes"]["qualifiedName"], + } + for r in resp + ] + def delete_entity(self, qualified_name: str): """ Deletes entity if it has no dependent entities @@ -119,32 +165,28 @@ def _create_project(self) -> UUID: return self.project_id def _create_source(self, s: Source) -> UUID: - r = self._post( - f"/projects/{self.project_id}/datasources", source_to_def(s)) + r = self._post(f"/projects/{self.project_id}/datasources", source_to_def(s)) id = UUID(r["guid"]) s._registry_id = id s._qualified_name = f"{self.project_name}__{s.name}" return id def _create_anchor(self, s: FeatureAnchor) -> UUID: - r = self._post( - f"/projects/{self.project_id}/anchors", anchor_to_def(s)) + r = self._post(f"/projects/{self.project_id}/anchors", anchor_to_def(s)) id = UUID(r["guid"]) s._registry_id = id s._qualified_name = f"{self.project_name}__{s.name}" return id def _create_anchor_feature(self, anchor_name: str, s: Feature) -> UUID: - r = self._post( - f"/projects/{self.project_id}/anchors/{anchor_name}/features", feature_to_def(s)) + r = self._post(f"/projects/{self.project_id}/anchors/{anchor_name}/features", feature_to_def(s)) id = UUID(r["guid"]) s._registry_id = id s._qualified_name = f"{self.project_name}__{anchor_name}__{s.name}" return id def _create_derived_feature(self, s: DerivedFeature) -> UUID: - r = self._post( - f"/projects/{self.project_id}/derivedfeatures", derived_feature_to_def(s)) + r = self._post(f"/projects/{self.project_id}/derivedfeatures", derived_feature_to_def(s)) id = UUID(r["guid"]) s._registry_id = id s._qualified_name = f"{self.project_name}__{s.name}" @@ -153,7 +195,7 @@ def _create_derived_feature(self, s: DerivedFeature) -> UUID: def _get(self, path: str) -> dict: logging.debug("PATH: ", path) return check(requests.get(f"{self.endpoint}{path}", headers=self._get_auth_header())).json() - + def _delete(self, path: str) -> dict: logging.debug("PATH: ", path) return check(requests.delete(f"{self.endpoint}{path}", headers=self._get_auth_header())).json() @@ -166,17 +208,19 @@ def _post(self, path: str, body: dict) -> dict: def _get_auth_header(self) -> dict: # if the environment is sandbox, don't do auth # TODO: expand to more credential providers - return 
{"Authorization": f'Bearer {self.credential.get_token("https://management.azure.com/.default").token}'} if not os.environ.get("FEATHR_SANDBOX") else None - + return ( + {"Authorization": f'Bearer {self.credential.get_token("https://management.azure.com/.default").token}'} + if not os.environ.get("FEATHR_SANDBOX") + else None + ) + def check(r): if not r.ok: - raise RuntimeError( - f"Failed to call registry API, status is {r.status_code}, error is {r.text}") + raise RuntimeError(f"Failed to call registry API, status is {r.status_code}, error is {r.text}") return r - def dict_to_source(v: dict) -> Source: id = UUID(v["guid"]) type = v["attributes"]["type"] @@ -184,48 +228,51 @@ def dict_to_source(v: dict) -> Source: if type == INPUT_CONTEXT: source = InputContext() elif type == "sparksql": - source = SparkSqlSource(name=v["attributes"]["name"], - sql=v["attributes"].get("sql"), - table=v["attributes"].get("table"), - preprocessing=_correct_function_indentation( - v["attributes"].get("preprocessing")), - event_timestamp_column=v["attributes"].get( - "eventTimestampColumn"), - timestamp_format=v["attributes"].get( - "timestampFormat"), - registry_tags=v["attributes"].get("tags", {})) + source = SparkSqlSource( + name=v["attributes"]["name"], + sql=v["attributes"].get("sql"), + table=v["attributes"].get("table"), + preprocessing=_correct_function_indentation(v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get("eventTimestampColumn"), + timestamp_format=v["attributes"].get("timestampFormat"), + registry_tags=v["attributes"].get("tags", {}), + ) elif type == "jdbc": - source = JdbcSource(name=v["attributes"]["name"], - url=v["attributes"].get("url"), - dbtable=v["attributes"].get("dbtable"), - query=v["attributes"].get("query"), - auth=v["attributes"].get("auth"), - preprocessing=_correct_function_indentation( - v["attributes"].get("preprocessing")), - event_timestamp_column=v["attributes"].get( - "eventTimestampColumn"), - timestamp_format=v["attributes"].get( - "timestampFormat"), - registry_tags=v["attributes"].get("tags", {})) + source = JdbcSource( + name=v["attributes"]["name"], + url=v["attributes"].get("url"), + dbtable=v["attributes"].get("dbtable"), + query=v["attributes"].get("query"), + auth=v["attributes"].get("auth"), + preprocessing=_correct_function_indentation(v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get("eventTimestampColumn"), + timestamp_format=v["attributes"].get("timestampFormat"), + registry_tags=v["attributes"].get("tags", {}), + ) elif type == "SNOWFLAKE": snowflake_path = v["attributes"]["path"] snowflake_parameters = SnowflakeSource.parse_snowflake_path(snowflake_path) - source = SnowflakeSource(name=v["attributes"]["name"], - dbtable=snowflake_parameters.get("dbtable", None), - query=snowflake_parameters.get("query", None), - database=snowflake_parameters["sfDatabase"], - schema=snowflake_parameters["sfSchema"], - preprocessing=_correct_function_indentation( - v["attributes"].get("preprocessing")), - event_timestamp_column=v["attributes"].get( - "eventTimestampColumn"), - timestamp_format=v["attributes"].get( - "timestampFormat"), - registry_tags=v["attributes"].get("tags", {})) + source = SnowflakeSource( + name=v["attributes"]["name"], + dbtable=snowflake_parameters.get("dbtable", None), + query=snowflake_parameters.get("query", None), + database=snowflake_parameters["sfDatabase"], + schema=snowflake_parameters["sfSchema"], + 
preprocessing=_correct_function_indentation(v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get("eventTimestampColumn"), + timestamp_format=v["attributes"].get("timestampFormat"), + registry_tags=v["attributes"].get("tags", {}), + ) elif type == "kafka": # print('v["attributes"]', v["attributes"]) - kafka_config = KafkaConfig(brokers=v["attributes"].get("brokers", []), topics=v["attributes"].get ("topics", []), schema=AvroJsonSchema(schemaStr= v["attributes"].get("schemaStr", ""))) - source = KafKaSource(name=v["attributes"]["name"], kafkaConfig=kafka_config, registry_tags=v["attributes"].get("tags", {})) + kafka_config = KafkaConfig( + brokers=v["attributes"].get("brokers", []), + topics=v["attributes"].get("topics", []), + schema=AvroJsonSchema(schemaStr=v["attributes"].get("schemaStr", "")), + ) + source = KafKaSource( + name=v["attributes"]["name"], kafkaConfig=kafka_config, registry_tags=v["attributes"].get("tags", {}) + ) elif type == "generic": options = v["attributes"].copy() # These are not options @@ -249,23 +296,20 @@ def dict_to_source(v: dict) -> Source: format=v["attributes"]["format"], mode=v["attributes"].get("mode", ""), options=options, - preprocessing=_correct_function_indentation( - v["attributes"].get("preprocessing")), - event_timestamp_column=v["attributes"].get( - "eventTimestampColumn"), - timestamp_format=v["attributes"].get( - "timestampFormat"), - registry_tags=v["attributes"].get("tags", {})) + preprocessing=_correct_function_indentation(v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get("eventTimestampColumn"), + timestamp_format=v["attributes"].get("timestampFormat"), + registry_tags=v["attributes"].get("tags", {}), + ) elif v["attributes"].get("path"): - source = HdfsSource(name=v["attributes"]["name"], - path=v["attributes"]["path"], - preprocessing=_correct_function_indentation( - v["attributes"].get("preprocessing")), - event_timestamp_column=v["attributes"].get( - "eventTimestampColumn"), - timestamp_format=v["attributes"].get( - "timestampFormat"), - registry_tags=v["attributes"].get("tags", {})) + source = HdfsSource( + name=v["attributes"]["name"], + path=v["attributes"]["path"], + preprocessing=_correct_function_indentation(v["attributes"].get("preprocessing")), + event_timestamp_column=v["attributes"].get("eventTimestampColumn"), + timestamp_format=v["attributes"].get("timestampFormat"), + registry_tags=v["attributes"].get("tags", {}), + ) else: raise ValueError(f"Invalid source format {type}") source._registry_id = id @@ -274,11 +318,13 @@ def dict_to_source(v: dict) -> Source: def dict_to_anchor(v: dict) -> FeatureAnchor: - ret = FeatureAnchor(name=v["attributes"]["name"], - source=None, - features=[], - registry_tags=v["attributes"].get("tags", {}), - __no_validate=True) + ret = FeatureAnchor( + name=v["attributes"]["name"], + source=None, + features=[], + registry_tags=v["attributes"].get("tags", {}), + __no_validate=True, + ) ret._source_id = UUID(v["attributes"]["source"]["guid"]) ret._features = [UUID(f["guid"]) for f in v["attributes"]["features"]] ret._qualified_name = v["attributes"]["qualifiedName"] @@ -286,23 +332,22 @@ def dict_to_anchor(v: dict) -> FeatureAnchor: return ret - - - def dict_to_transformation(v: dict) -> Transformation: if v is None: return None v = to_camel(v) - if 'transformExpr' in v: + if "transformExpr" in v: # it's ExpressionTransformation - return ExpressionTransformation(v['transformExpr']) - elif 'defExpr' in v: - return 
WindowAggTransformation(agg_expr=v['defExpr'], - agg_func=v.get('aggFunc'), - window=v.get('window'), - group_by=v.get('groupBy'), - filter=v.get('filter'), - limit=v.get('limit')) + return ExpressionTransformation(v["transformExpr"]) + elif "defExpr" in v: + return WindowAggTransformation( + agg_expr=v["defExpr"], + agg_func=v.get("aggFunc"), + window=v.get("window"), + group_by=v.get("groupBy"), + filter=v.get("filter"), + limit=v.get("limit"), + ) raise ValueError(f"Invalid transformation format {v}") @@ -316,18 +361,16 @@ def feature_type_to_def(v: FeatureType) -> dict: def dict_to_feature_type(v: dict) -> FeatureType: - return FeatureType(val_type=str_to_value_type(v["valType"]), - dimension_type=[str_to_value_type( - s) for s in v["dimensionType"]], - tensor_category=v["tensorCategory"], - type=v["type"]) + return FeatureType( + val_type=str_to_value_type(v["valType"]), + dimension_type=[str_to_value_type(s) for s in v["dimensionType"]], + tensor_category=v["tensorCategory"], + type=v["type"], + ) def typed_key_to_def(v: TypedKey) -> dict: - ret = { - "keyColumn": v.key_column, - "keyColumnType": value_type_to_str(v.key_column_type) - } + ret = {"keyColumn": v.key_column, "keyColumnType": value_type_to_str(v.key_column_type)} if v.full_name: ret["fullName"] = v.full_name if v.description: @@ -339,22 +382,23 @@ def typed_key_to_def(v: TypedKey) -> dict: def dict_to_typed_key(v: dict) -> TypedKey: v = to_camel(v) - return TypedKey(key_column=v["keyColumn"], - key_column_type=str_to_value_type(v["keyColumnType"]), - full_name=v.get("fullName"), - description=v.get("description"), - key_column_alias=v.get("keyColumnAlias")) - - + return TypedKey( + key_column=v["keyColumn"], + key_column_type=str_to_value_type(v["keyColumnType"]), + full_name=v.get("fullName"), + description=v.get("description"), + key_column_alias=v.get("keyColumnAlias"), + ) def dict_to_feature(v: dict) -> Feature: - ret = Feature(name=v["attributes"]["name"], - feature_type=dict_to_feature_type(v["attributes"]["type"]), - key=[dict_to_typed_key(k) for k in v["attributes"]["key"]], - transform=dict_to_transformation( - v["attributes"].get("transformation")), - registry_tags=v["attributes"].get("tags", {})) + ret = Feature( + name=v["attributes"]["name"], + feature_type=dict_to_feature_type(v["attributes"]["type"]), + key=[dict_to_typed_key(k) for k in v["attributes"]["key"]], + transform=dict_to_transformation(v["attributes"].get("transformation")), + registry_tags=v["attributes"].get("tags", {}), + ) ret._qualified_name = v["attributes"]["qualifiedName"] ret._registry_id = UUID(v["guid"]) return ret @@ -387,20 +431,17 @@ def derived_feature_to_def(v: DerivedFeature) -> dict: def dict_to_derived_feature(v: dict) -> DerivedFeature: v["attributes"] = to_camel(v["attributes"]) - ret = DerivedFeature(name=v["attributes"]["name"], - feature_type=dict_to_feature_type( - v["attributes"]["type"]), - input_features=[], - key=[dict_to_typed_key(k) - for k in v["attributes"]["key"]], - transform=dict_to_transformation( - v["attributes"]["transformation"]), - registry_tags=v["attributes"].get("tags", {}), - __no_validate=True) - ret._input_anchor_features = [ - UUID(f["guid"]) for f in v["attributes"]["inputAnchorFeatures"]] - ret._input_derived_features = [ - UUID(f["guid"]) for f in v["attributes"]["inputDerivedFeatures"]] + ret = DerivedFeature( + name=v["attributes"]["name"], + feature_type=dict_to_feature_type(v["attributes"]["type"]), + input_features=[], + key=[dict_to_typed_key(k) for k in v["attributes"]["key"]], + 
transform=dict_to_transformation(v["attributes"]["transformation"]), + registry_tags=v["attributes"].get("tags", {}), + __no_validate=True, + ) + ret._input_anchor_features = [UUID(f["guid"]) for f in v["attributes"]["inputAnchorFeatures"]] + ret._input_derived_features = [UUID(f["guid"]) for f in v["attributes"]["inputDerivedFeatures"]] ret._qualified_name = v["attributes"]["qualifiedName"] ret._registry_id = UUID(v["guid"]) return ret @@ -416,14 +457,26 @@ def __init__(self, v: dict): def dict_to_project(v: dict) -> Tuple[List[FeatureAnchor], List[DerivedFeature]]: entities = v["guidEntityMap"] # Step 1, Extract each entity - sources = dict([(UUID(k), dict_to_source(entities[k])) - for k in entities if entities[k]["typeName"] == "feathr_source_v1"]) - anchors = dict([(UUID(k), dict_to_anchor(entities[k])) - for k in entities if entities[k]["typeName"] == "feathr_anchor_v1"]) - features = dict([(UUID(k), dict_to_feature(entities[k])) - for k in entities if entities[k]["typeName"] == "feathr_anchor_feature_v1"]) - derived_features = dict([(UUID(k), dict_to_derived_feature(entities[k])) - for k in entities if entities[k]["typeName"] == "feathr_derived_feature_v1"]) + sources = dict( + [(UUID(k), dict_to_source(entities[k])) for k in entities if entities[k]["typeName"] == "feathr_source_v1"] + ) + anchors = dict( + [(UUID(k), dict_to_anchor(entities[k])) for k in entities if entities[k]["typeName"] == "feathr_anchor_v1"] + ) + features = dict( + [ + (UUID(k), dict_to_feature(entities[k])) + for k in entities + if entities[k]["typeName"] == "feathr_anchor_feature_v1" + ] + ) + derived_features = dict( + [ + (UUID(k), dict_to_derived_feature(entities[k])) + for k in entities + if entities[k]["typeName"] == "feathr_derived_feature_v1" + ] + ) # Step 2, Setup connections between extracted entities # Step 2-1, Set up anchors for k in anchors: @@ -433,11 +486,7 @@ def dict_to_project(v: dict) -> Tuple[List[FeatureAnchor], List[DerivedFeature]] # Step 2-1, Set up derived features for k in derived_features: df = derived_features[k] - input_anchor_features = [features[id] - for id in df._input_anchor_features] - input_derived_features = [derived_features[id] - for id in df._input_derived_features] + input_anchor_features = [features[id] for id in df._input_anchor_features] + input_derived_features = [derived_features[id] for id in df._input_derived_features] df.input_features = input_anchor_features + input_derived_features return (list(anchors.values()), list(derived_features.values())) - - diff --git a/feathr_project/feathr/registry/_feature_registry_purview.py b/feathr_project/feathr/registry/_feature_registry_purview.py index d47105a37..a8778f87e 100644 --- a/feathr_project/feathr/registry/_feature_registry_purview.py +++ b/feathr_project/feathr/registry/_feature_registry_purview.py @@ -14,13 +14,16 @@ from azure.identity import DefaultAzureCredential from loguru import logger from pyapacheatlas.auth.azcredential import AzCredentialWrapper -from pyapacheatlas.core import (AtlasClassification, AtlasEntity, AtlasProcess, - PurviewClient, TypeCategory) -from pyapacheatlas.core.typedef import (AtlasAttributeDef, - AtlasRelationshipEndDef, Cardinality, - EntityTypeDef, RelationshipTypeDef) - -from pyapacheatlas.core.util import GuidTracker,AtlasException +from pyapacheatlas.core import AtlasClassification, AtlasEntity, AtlasProcess, PurviewClient, TypeCategory +from pyapacheatlas.core.typedef import ( + AtlasAttributeDef, + AtlasRelationshipEndDef, + Cardinality, + EntityTypeDef, + 
RelationshipTypeDef, +) + +from pyapacheatlas.core.util import GuidTracker, AtlasException from pyhocon import ConfigFactory from feathr.definition.dtype import * @@ -31,13 +34,13 @@ from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.repo_definitions import RepoDefinitions from feathr.definition.source import HdfsSource, InputContext, JdbcSource, SnowflakeSource, Source -from feathr.definition.transformation import (ExpressionTransformation, Transformation, - WindowAggTransformation) +from feathr.definition.transformation import ExpressionTransformation, Transformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey from feathr.registry.feature_registry import FeathrRegistry from feathr.constants import * + def _to_snake(d, level: int = 0): """ Convert `string`, `list[string]`, or all keys in a `dict` into snake case @@ -48,13 +51,15 @@ def _to_snake(d, level: int = 0): raise ValueError("Too many nested levels") if isinstance(d, str): d = d[:100] - return re.sub(r'(?<!^)(?=[A-Z])', '_', d).lower() + return re.sub(r"(?<!^)(?=[A-Z])", "_", d).lower() if isinstance(d, list): return [_to_snake(i, level + 1) if isinstance(i, (dict, list)) else i for i in d] if isinstance(d, dict): if len(d) > 100: raise ValueError("Dict has too many keys") - return {_to_snake(a, level + 1): _to_snake(b, level + 1) if isinstance(b, (dict, list)) else b for a, b in d.items()} + return { + _to_snake(a, level + 1): _to_snake(b, level + 1) if isinstance(b, (dict, list)) else b for a, b in d.items() + } class _PurviewRegistry(FeathrRegistry): @@ -64,18 +69,26 @@ class _PurviewRegistry(FeathrRegistry): - Initialize an Azure Purview Client - Initialize the GUID tracker, project name, etc. """ - def __init__(self, project_name: str, azure_purview_name: str, registry_delimiter: str, project_tags: Dict[str, str] = None, credential=None, config_path=None,): + + def __init__( + self, + project_name: str, + azure_purview_name: str, + registry_delimiter: str, + project_tags: Dict[str, str] = None, + credential=None, + config_path=None, + ): self.project_name = project_name self.registry_delimiter = registry_delimiter self.azure_purview_name = azure_purview_name self.project_tags = project_tags - self.credential = DefaultAzureCredential(exclude_interactive_browser_credential=False) if credential is None else credential + self.credential = ( + DefaultAzureCredential(exclude_interactive_browser_credential=False) if credential is None else credential + ) self.oauth = AzCredentialWrapper(credential=self.credential) - self.purview_client = PurviewClient( - account_name=self.azure_purview_name, - authentication=self.oauth - ) + self.purview_client = PurviewClient(account_name=self.azure_purview_name, authentication=self.oauth) self.guid = GuidTracker(starting=-1000) self.entity_batch_queue = [] @@ -93,40 +106,27 @@ def _register_feathr_feature_types(self): name=TYPEDEF_FEATHR_PROJECT, attributeDefs=[ # TODO: this should be called "anchors" rather than "anchor_features" to make it less confusing.
+ AtlasAttributeDef(name="anchor_features", typeName=TYPEDEF_ARRAY_ANCHOR, cardinality=Cardinality.SET), AtlasAttributeDef( - name="anchor_features", typeName=TYPEDEF_ARRAY_ANCHOR, cardinality=Cardinality.SET), - AtlasAttributeDef( - name="derived_features", typeName=TYPEDEF_ARRAY_DERIVED_FEATURE, cardinality=Cardinality.SET), - AtlasAttributeDef(name="tags", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), + name="derived_features", typeName=TYPEDEF_ARRAY_DERIVED_FEATURE, cardinality=Cardinality.SET + ), + AtlasAttributeDef(name="tags", typeName="map<string,string>", cardinality=Cardinality.SINGLE), ], superTypes=["DataSet"], - ) type_feathr_sources = EntityTypeDef( name=TYPEDEF_SOURCE, attributeDefs=[ - - AtlasAttributeDef( - name="path", typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef( - name="url", typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef( - name="dbtable", typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef( - name="query", typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef( - name="auth", typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="event_timestamp_column", - typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="timestamp_format", - typeName="string", cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="type", typeName="string", - cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="preprocessing", typeName="string", - cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="tags", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="path", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="url", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="dbtable", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="query", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="auth", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="event_timestamp_column", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="timestamp_format", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="type", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="preprocessing", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="tags", typeName="map<string,string>", cardinality=Cardinality.SINGLE), ], superTypes=["DataSet"], ) @@ -134,14 +134,10 @@ def _register_feathr_feature_types(self): type_feathr_anchor_features = EntityTypeDef( name=TYPEDEF_ANCHOR_FEATURE, attributeDefs=[ - AtlasAttributeDef(name="type", typeName="string", - cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="key", typeName="array<map<string,string>>", - cardinality=Cardinality.SET), - AtlasAttributeDef(name="transformation", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="tags", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="type", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="key", typeName="array<map<string,string>>", cardinality=Cardinality.SET), + AtlasAttributeDef(name="transformation", typeName="map<string,string>", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="tags", typeName="map<string,string>", cardinality=Cardinality.SINGLE), ], superTypes=["DataSet"], ) @@ -149,19 +145,16 @@ def _register_feathr_feature_types(self): type_feathr_derived_features = EntityTypeDef( name=TYPEDEF_DERIVED_FEATURE,
attributeDefs=[ - AtlasAttributeDef(name="type", typeName="string", - cardinality=Cardinality.SINGLE), - - AtlasAttributeDef(name="input_anchor_features", typeName=TYPEDEF_ARRAY_ANCHOR_FEATURE, - cardinality=Cardinality.SET), - AtlasAttributeDef(name="input_derived_features", typeName=TYPEDEF_ARRAY_DERIVED_FEATURE, - cardinality=Cardinality.SET), - AtlasAttributeDef(name="key", typeName="array<map<string,string>>", - cardinality=Cardinality.SET), - AtlasAttributeDef(name="transformation", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), - AtlasAttributeDef(name="tags", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="type", typeName="string", cardinality=Cardinality.SINGLE), + AtlasAttributeDef( + name="input_anchor_features", typeName=TYPEDEF_ARRAY_ANCHOR_FEATURE, cardinality=Cardinality.SET + ), + AtlasAttributeDef( + name="input_derived_features", typeName=TYPEDEF_ARRAY_DERIVED_FEATURE, cardinality=Cardinality.SET + ), + AtlasAttributeDef(name="key", typeName="array<map<string,string>>", cardinality=Cardinality.SET), + AtlasAttributeDef(name="transformation", typeName="map<string,string>", cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="tags", typeName="map<string,string>", cardinality=Cardinality.SINGLE), ], superTypes=["DataSet"], ) @@ -169,20 +162,23 @@ def _register_feathr_feature_types(self): type_feathr_anchors = EntityTypeDef( name=TYPEDEF_ANCHOR, attributeDefs=[ - AtlasAttributeDef( - name="source", typeName=TYPEDEF_SOURCE, cardinality=Cardinality.SINGLE), - AtlasAttributeDef( - name="features", typeName=TYPEDEF_ARRAY_ANCHOR_FEATURE, cardinality=Cardinality.SET), - AtlasAttributeDef(name="tags", typeName="map<string,string>", - cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="source", typeName=TYPEDEF_SOURCE, cardinality=Cardinality.SINGLE), + AtlasAttributeDef(name="features", typeName=TYPEDEF_ARRAY_ANCHOR_FEATURE, cardinality=Cardinality.SET), + AtlasAttributeDef(name="tags", typeName="map<string,string>", cardinality=Cardinality.SINGLE), ], superTypes=["DataSet"], ) def_result = self.purview_client.upload_typedefs( - entityDefs=[type_feathr_anchor_features, type_feathr_anchors, - type_feathr_derived_features, type_feathr_sources, type_feathr_project], - force_update=True) + entityDefs=[ + type_feathr_anchor_features, + type_feathr_anchors, + type_feathr_derived_features, + type_feathr_sources, + type_feathr_project, + ], + force_update=True, + ) logger.info("Feathr Feature Type System Initialized.") def _parse_anchor_features(self, anchor: FeatureAnchor) -> List[AtlasEntity]: @@ -197,8 +193,13 @@ def _parse_anchor_features(self, anchor: FeatureAnchor) -> List[AtlasEntity]: for anchor_feature in anchor.features: key_list = [] for individual_key in anchor_feature.key: - key_dict = {"key_column": individual_key.key_column, "key_column_type": individual_key.key_column_type.value, - "full_name": individual_key.full_name, "description": individual_key.description, "key_column_alias": individual_key.key_column_alias} + key_dict = { + "key_column": individual_key.key_column, + "key_column_type": individual_key.key_column_type.value, + "full_name": individual_key.full_name, + "description": individual_key.description, + "key_column_alias": individual_key.key_column_alias, + } key_list.append(key_dict) # define a dict to save all the transformation schema @@ -217,8 +218,11 @@ def _parse_anchor_features(self, anchor: FeatureAnchor) -> List[AtlasEntity]: anchor_feature_entity = AtlasEntity( name=anchor_feature.name, - qualified_name=self.project_name + self.registry_delimiter + - anchor.name + self.registry_delimiter +
anchor_feature.name, + qualified_name=self.project_name + + self.registry_delimiter + + anchor.name + + self.registry_delimiter + + anchor_feature.name, attributes={ "type": anchor_feature.feature_type.to_feature_config(), "key": key_list, @@ -244,17 +248,17 @@ def _parse_anchors(self, anchor_list: List[FeatureAnchor]) -> List[AtlasEntity]: anchor_feature_entities = self._parse_anchor_features(anchor) # then parse the source of that anchor source_entity = self._parse_source(anchor.source) - anchor_fully_qualified_name = self.project_name+self.registry_delimiter+anchor.name - original_id = self.get_feature_id(anchor_fully_qualified_name, type=TYPEDEF_ANCHOR ) + anchor_fully_qualified_name = self.project_name + self.registry_delimiter + anchor.name + original_id = self.get_feature_id(anchor_fully_qualified_name, type=TYPEDEF_ANCHOR) original_anchor = self.get_feature_by_guid(original_id) if original_id else None - merged_elements = self._merge_anchor(original_anchor,anchor_feature_entities) + merged_elements = self._merge_anchor(original_anchor, anchor_feature_entities) anchor_entity = AtlasEntity( name=anchor.name, - qualified_name=anchor_fully_qualified_name , + qualified_name=anchor_fully_qualified_name, attributes={ "source": source_entity.to_json(minimum=True), "features": merged_elements, - "tags": anchor.registry_tags + "tags": anchor.registry_tags, }, typeName=TYPEDEF_ANCHOR, guid=self.guid.get_guid(), @@ -264,9 +268,14 @@ def _parse_anchors(self, anchor_list: List[FeatureAnchor]) -> List[AtlasEntity]: lineage = AtlasProcess( name=anchor_feature_entity.name + " to " + anchor.name, typeName="Process", - qualified_name=self.registry_delimiter + "PROCESS" + self.registry_delimiter + self.project_name + - self.registry_delimiter + anchor.name + self.registry_delimiter + - anchor_feature_entity.name, + qualified_name=self.registry_delimiter + + "PROCESS" + + self.registry_delimiter + + self.project_name + + self.registry_delimiter + + anchor.name + + self.registry_delimiter + + anchor_feature_entity.name, inputs=[anchor_feature_entity], outputs=[anchor_entity], guid=self.guid.get_guid(), @@ -277,9 +286,14 @@ def _parse_anchors(self, anchor_list: List[FeatureAnchor]) -> List[AtlasEntity]: anchor_source_lineage = AtlasProcess( name=source_entity.name + " to " + anchor.name, typeName="Process", - qualified_name=self.registry_delimiter + "PROCESS" + self.registry_delimiter + self.project_name + - self.registry_delimiter + anchor.name + self.registry_delimiter + - source_entity.name, + qualified_name=self.registry_delimiter + + "PROCESS" + + self.registry_delimiter + + self.project_name + + self.registry_delimiter + + anchor.name + + self.registry_delimiter + + source_entity.name, inputs=[source_entity], outputs=[anchor_entity], guid=self.guid.get_guid(), @@ -289,10 +303,10 @@ def _parse_anchors(self, anchor_list: List[FeatureAnchor]) -> List[AtlasEntity]: anchors_batch.append(anchor_entity) return anchors_batch - def _merge_anchor(self,original_anchor:Dict, new_anchor:Dict)->List[Dict[str,any]]: - ''' + def _merge_anchor(self, original_anchor: Dict, new_anchor: Dict) -> List[Dict[str, any]]: + """ Merge the new anchors defined locally with the anchors that is defined in the centralized registry. - ''' + """ # TODO: This will serve as a quick fix, full fix will work with MVCC, and is in progress. 
new_anchor_json_repr = [s.to_json(minimum=True) for s in new_anchor] if not original_anchor: @@ -300,19 +314,19 @@ def _merge_anchor(self,original_anchor:Dict, new_anchor:Dict)->List[Dict[str,any # sample : [{'guid':'GUID_OF_ANCHOR','typeName':'','qualifiedName':'QUALIFIED_NAME'} return new_anchor_json_repr else: - original_anchor_elements = [x for x in original_anchor['entity']['attributes']['features']] + original_anchor_elements = [x for x in original_anchor["entity"]["attributes"]["features"]] transformed_original_elements = { - x['uniqueAttributes']['qualifiedName']: - { - 'guid':x['guid'], - 'typeName':x['typeName'], - 'qualifiedName':x['uniqueAttributes']['qualifiedName'] + x["uniqueAttributes"]["qualifiedName"]: { + "guid": x["guid"], + "typeName": x["typeName"], + "qualifiedName": x["uniqueAttributes"]["qualifiedName"], } - for x in original_anchor_elements} + for x in original_anchor_elements + } for elem in new_anchor_json_repr: - transformed_original_elements.setdefault(elem['qualifiedName'],elem) + transformed_original_elements.setdefault(elem["qualifiedName"], elem) return list(transformed_original_elements.values()) - + def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeSource]) -> AtlasEntity: """ parse the input sources @@ -320,9 +334,9 @@ def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeS input_context = False if isinstance(source, InputContext): input_context = True - + # only set preprocessing if it's available in the object and is not None - if 'preprocessing' in dir(source) and source.preprocessing is not None: + if "preprocessing" in dir(source) and source.preprocessing is not None: preprocessing_func = inspect.getsource(source.preprocessing) else: preprocessing_func = None @@ -335,7 +349,7 @@ def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeS "timestamp_format": source.timestamp_format, "event_timestamp_column": source.event_timestamp_column, "tags": source.registry_tags, - "preprocessing": preprocessing_func # store the UDF as a string + "preprocessing": preprocessing_func, # store the UDF as a string } if source.auth is not None: attrs["auth"] = source.auth @@ -351,7 +365,7 @@ def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeS "timestamp_format": source.timestamp_format, "event_timestamp_column": source.event_timestamp_column, "tags": source.registry_tags, - "preprocessing": preprocessing_func # store the UDF as a string + "preprocessing": preprocessing_func, # store the UDF as a string } if source.dbtable is not None: attrs["dbtable"] = source.dbtable @@ -364,7 +378,7 @@ def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeS "timestamp_format": source.timestamp_format, "event_timestamp_column": source.event_timestamp_column, "tags": source.registry_tags, - "preprocessing": preprocessing_func # store the UDF as a string + "preprocessing": preprocessing_func, # store the UDF as a string } source_entity = AtlasEntity( name=source.name, @@ -376,7 +390,7 @@ def _parse_source(self, source: Union[Source, HdfsSource, JdbcSource, SnowflakeS self.entity_batch_queue.append(source_entity) return source_entity - def _add_all_derived_features(self, derived_features: List[DerivedFeature], ts:TopologicalSorter ) -> None: + def _add_all_derived_features(self, derived_features: List[DerivedFeature], ts: TopologicalSorter) -> None: """iterate thru all the dependencies of the derived feature and return a derived feature list in a topological 
sorted way (the result list only has derived features, without their anchor features) Args: @@ -401,11 +415,10 @@ def _add_all_derived_features(self, derived_features: List[DerivedFeature], ts:T # `input_feature` is predecessor of `derived_feature` ts.add(derived_feature, input_feature) # if any of the input feature is a derived feature, have this recursive call - # use this for code simplicity. + # use this for code simplicity. # if the amount of features is huge, consider only add the derived features into the function call self._add_all_derived_features(input_feature.input_features, ts) - def _parse_derived_features(self, derived_features: List[DerivedFeature]) -> List[AtlasEntity]: """parse derived feature @@ -421,16 +434,22 @@ def _parse_derived_features(self, derived_features: List[DerivedFeature]) -> Lis self._add_all_derived_features(derived_features, ts) # topological sort the derived features to make sure that we can correctly refer to them later in the registry toposorted_derived_feature_list: List[DerivedFeature] = list(ts.static_order()) - + for derived_feature in toposorted_derived_feature_list: # get the corresponding Atlas entity by searching feature name # Since this list is topological sorted, so you can always find the corresponding name input_feature_entity_list: List[AtlasEntity] = [ - self.global_feature_entity_dict[f.name] for f in derived_feature.input_features] + self.global_feature_entity_dict[f.name] for f in derived_feature.input_features + ] key_list = [] for individual_key in derived_feature.key: - key_dict = {"key_column": individual_key.key_column, "key_column_type": individual_key.key_column_type.value, - "full_name": individual_key.full_name, "description": individual_key.description, "key_column_alias": individual_key.key_column_alias} + key_dict = { + "key_column": individual_key.key_column, + "key_column_type": individual_key.key_column_type.value, + "full_name": individual_key.full_name, + "description": individual_key.description, + "key_column_alias": individual_key.key_column_alias, + } key_list.append(key_dict) # define a dict to save all the transformation schema @@ -446,25 +465,31 @@ def _parse_derived_features(self, derived_features: List[DerivedFeature]) -> Lis "filter": derived_feature.transform.filter, "limit": derived_feature.transform.limit, } - + derived_feature_entity = AtlasEntity( name=derived_feature.name, - qualified_name=self.project_name + - self.registry_delimiter + derived_feature.name, + qualified_name=self.project_name + self.registry_delimiter + derived_feature.name, attributes={ "type": derived_feature.feature_type.to_feature_config(), "key": key_list, - "input_anchor_features": [f.to_json(minimum=True) for f in input_feature_entity_list if f.typeName==TYPEDEF_ANCHOR_FEATURE], - "input_derived_features": [f.to_json(minimum=True) for f in input_feature_entity_list if f.typeName==TYPEDEF_DERIVED_FEATURE], + "input_anchor_features": [ + f.to_json(minimum=True) + for f in input_feature_entity_list + if f.typeName == TYPEDEF_ANCHOR_FEATURE + ], + "input_derived_features": [ + f.to_json(minimum=True) + for f in input_feature_entity_list + if f.typeName == TYPEDEF_DERIVED_FEATURE + ], "transformation": transform_dict, "tags": derived_feature.registry_tags, - }, typeName=TYPEDEF_DERIVED_FEATURE, guid=self.guid.get_guid(), ) - # Add the feature entity in the global dict so that it can be referenced further. + # Add the feature entity in the global dict so that it can be referenced further. 
self.global_feature_entity_dict[derived_feature.name] = derived_feature_entity for input_feature_entity in input_feature_entity_list: @@ -472,9 +497,14 @@ def _parse_derived_features(self, derived_features: List[DerivedFeature]) -> Lis derived_feature_feature_lineage = AtlasProcess( name=input_feature_entity.name + " to " + derived_feature.name, typeName="Process", - qualified_name=self.registry_delimiter + "PROCESS" + self.registry_delimiter + self.project_name + - self.registry_delimiter + derived_feature.name + self.registry_delimiter + - input_feature_entity.name, + qualified_name=self.registry_delimiter + + "PROCESS" + + self.registry_delimiter + + self.project_name + + self.registry_delimiter + + derived_feature.name + + self.registry_delimiter + + input_feature_entity.name, inputs=[input_feature_entity], outputs=[derived_feature_entity], guid=self.guid.get_guid(), @@ -496,14 +526,14 @@ def _parse_features_from_context(self, workspace_path: str, anchor_list, derived if anchor_list: anchor_entities = self._parse_anchors(anchor_list) - project_attributes = {"anchor_features": [ - s.to_json(minimum=True) for s in anchor_entities], "tags": self.project_tags} + project_attributes = { + "anchor_features": [s.to_json(minimum=True) for s in anchor_entities], + "tags": self.project_tags, + } # add derived feature if it's there if derived_feature_list: - derived_feature_entities = self._parse_derived_features( - derived_feature_list) - project_attributes["derived_features"] = [ - s.to_json(minimum=True) for s in derived_feature_entities] + derived_feature_entities = self._parse_derived_features(derived_feature_list) + project_attributes["derived_features"] = [s.to_json(minimum=True) for s in derived_feature_entities] # define project in Atlas entity feathr_project_entity = AtlasEntity( @@ -516,13 +546,16 @@ def _parse_features_from_context(self, workspace_path: str, anchor_list, derived # add lineage from anchor to project for individual_anchor_entity in anchor_entities: - lineage_process = AtlasProcess( name=individual_anchor_entity.name + " to " + self.project_name, typeName="Process", # fqdn: PROCESS+PROJECT_NAME+ANCHOR_NAME - qualified_name=self.registry_delimiter + "PROCESS" + self.registry_delimiter + self.project_name + - self.registry_delimiter + individual_anchor_entity.name, + qualified_name=self.registry_delimiter + + "PROCESS" + + self.registry_delimiter + + self.project_name + + self.registry_delimiter + + individual_anchor_entity.name, inputs=[individual_anchor_entity], outputs=[feathr_project_entity], guid=self.guid.get_guid(), @@ -535,144 +568,174 @@ def _parse_features_from_context(self, workspace_path: str, anchor_list, derived name=derived_feature_entity.name + " to " + self.project_name, typeName="Process", # fqdn: PROCESS+PROJECT_NAME+DERIVATION_NAME - qualified_name=self.registry_delimiter + "PROCESS" + self.registry_delimiter + self.project_name + - self.registry_delimiter + derived_feature_entity.name, + qualified_name=self.registry_delimiter + + "PROCESS" + + self.registry_delimiter + + self.project_name + + self.registry_delimiter + + derived_feature_entity.name, inputs=[derived_feature_entity], outputs=[feathr_project_entity], guid=self.guid.get_guid(), ) self.entity_batch_queue.append(lineage_process) - + self.entity_batch_queue.append(feathr_project_entity) self.entity_batch_queue.extend(anchor_entities) self.entity_batch_queue.extend(derived_feature_entities) def _create_project(self) -> UUID: - ''' + """ create a project entity - ''' + """ predefined_guid = 
self.guid.get_guid() feathr_project_entity = AtlasEntity( name=self.project_name, qualified_name=self.project_name, typeName=TYPEDEF_FEATHR_PROJECT, - guid=predefined_guid) + guid=predefined_guid, + ) guid = self.upload_single_entity_to_purview(feathr_project_entity) return guid - def upload_single_entity_to_purview(self,entity:Union[AtlasEntity,AtlasProcess]): - ''' - Upload a single entity to purview, could be a process entity or AtlasEntity. + def upload_single_entity_to_purview(self, entity: Union[AtlasEntity, AtlasProcess]): + """ + Upload a single entity to purview, could be a process entity or AtlasEntity. Since this is used for migration existing project, ignore Atlas PreconditionFail (412) If the entity already exists, return the existing entity's GUID. Otherwise, return the new entity GUID. The entity itself will also be modified, fill the GUID with real GUID in Purview. In order to avoid having concurrency issue, and provide clear guidance, this method only allows entity uploading once at a time. - ''' + """ try: """ Try to find existing entity/process first, if found, return the existing entity's GUID """ - response = self.purview_client.get_entity(qualifiedName=entity.qualifiedName)['entities'][0] + response = self.purview_client.get_entity(qualifiedName=entity.qualifiedName)["entities"][0] j = entity.to_json() if j["typeName"] == response["typeName"]: if j["typeName"] == "Process": if response["attributes"]["qualifiedName"] != j["attributes"]["qualifiedName"]: - raise RuntimeError("The requested entity %s conflicts with the existing entity in PurView" % j["attributes"]["qualifiedName"]) + raise RuntimeError( + "The requested entity %s conflicts with the existing entity in PurView" + % j["attributes"]["qualifiedName"] + ) else: - if "type" in response['attributes'] and response["typeName"] in (TYPEDEF_ANCHOR_FEATURE, TYPEDEF_DERIVED_FEATURE): - conf = ConfigFactory.parse_string(response['attributes']['type']) - response['attributes']['type'] = dict(conf) + if "type" in response["attributes"] and response["typeName"] in ( + TYPEDEF_ANCHOR_FEATURE, + TYPEDEF_DERIVED_FEATURE, + ): + conf = ConfigFactory.parse_string(response["attributes"]["type"]) + response["attributes"]["type"] = dict(conf) keys = set([_to_snake(key) for key in j["attributes"].keys()]) - set(["qualified_name"]) keys.add("qualifiedName") for k in keys: if response["attributes"][k] != j["attributes"][k]: - raise RuntimeError("The requested entity %s conflicts with the existing entity in PurView" % j["attributes"]["qualifiedName"]) + raise RuntimeError( + "The requested entity %s conflicts with the existing entity in PurView" + % j["attributes"]["qualifiedName"] + ) return response["guid"] else: - raise RuntimeError("The requested entity %s conflicts with the existing entity in PurView" % j["attributes"]["qualifiedName"]) + raise RuntimeError( + "The requested entity %s conflicts with the existing entity in PurView" + % j["attributes"]["qualifiedName"] + ) except AtlasException as e: pass - + try: - entity.lastModifiedTS="0" + entity.lastModifiedTS = "0" result = self.purview_client.upload_entities([entity]) - entity.guid = result['guidAssignments'][entity.guid] + entity.guid = result["guidAssignments"][entity.guid] print(f"Successfully created {entity.typeName} -- {entity.qualifiedName}") except AtlasException as e: if "PreConditionCheckFailed" in e.args[0]: - entity.guid = self.purview_client.get_entity(qualifiedName=entity.qualifiedName,typeName = entity.typeName)['entities'][0]['guid'] + entity.guid = 
self.purview_client.get_entity( + qualifiedName=entity.qualifiedName, typeName=entity.typeName + )["entities"][0]["guid"] print(f"Found existing entity {entity.guid}, {entity.typeName} -- {entity.qualifiedName}") return UUID(entity.guid) - - def _generate_relation_pairs(self, from_entity:dict, to_entity:dict, relation_type): + + def _generate_relation_pairs(self, from_entity: dict, to_entity: dict, relation_type): type_lookup = {RELATION_CONTAINS: RELATION_BELONGSTO, RELATION_CONSUMES: RELATION_PRODUCES} - forward_relation = AtlasProcess( + forward_relation = AtlasProcess( name=str(from_entity["guid"]) + " to " + str(to_entity["guid"]), typeName="Process", qualified_name=self.registry_delimiter.join( - [relation_type,str(from_entity["guid"]), str(to_entity["guid"])]), + [relation_type, str(from_entity["guid"]), str(to_entity["guid"])] + ), inputs=[self.to_min_repr(from_entity)], outputs=[self.to_min_repr(to_entity)], - guid=self.guid.get_guid()) - + guid=self.guid.get_guid(), + ) + backward_relation = AtlasProcess( name=str(to_entity["guid"]) + " to " + str(from_entity["guid"]), typeName="Process", qualified_name=self.registry_delimiter.join( - [type_lookup[relation_type], str(to_entity["guid"]), str(from_entity["guid"])]), + [type_lookup[relation_type], str(to_entity["guid"]), str(from_entity["guid"])] + ), inputs=[self.to_min_repr(to_entity)], outputs=[self.to_min_repr(from_entity)], - guid=self.guid.get_guid()) - return [forward_relation,backward_relation] - - def to_min_repr(self,entity:dict) -> dict: + guid=self.guid.get_guid(), + ) + return [forward_relation, backward_relation] + + def to_min_repr(self, entity: dict) -> dict: return { - 'qualifiedName':entity['attributes']["qualifiedName"], - 'guid':str(entity["guid"]), - 'typeName':str(entity['typeName']), + "qualifiedName": entity["attributes"]["qualifiedName"], + "guid": str(entity["guid"]), + "typeName": str(entity["typeName"]), } def _create_source(self, s: Source) -> UUID: - ''' + """ create a data source under a project. 
this will create the data source entity, together with the relation entity - ''' - project_entity = self.purview_client.get_entity(qualifiedName=self.project_name,typeName=TYPEDEF_FEATHR_PROJECT)['entities'][0] + """ + project_entity = self.purview_client.get_entity( + qualifiedName=self.project_name, typeName=TYPEDEF_FEATHR_PROJECT + )["entities"][0] attrs = source_to_def(s) - qualified_name = self.registry_delimiter.join([project_entity['attributes']['qualifiedName'],attrs['name']]) + qualified_name = self.registry_delimiter.join([project_entity["attributes"]["qualifiedName"], attrs["name"]]) source_entity = AtlasEntity( - name=attrs['name'], + name=attrs["name"], qualified_name=qualified_name, - attributes= {k:v for k,v in attrs.items() if k !="name"}, + attributes={k: v for k, v in attrs.items() if k != "name"}, typeName=TYPEDEF_SOURCE, guid=self.guid.get_guid(), ) source_id = self.upload_single_entity_to_purview(source_entity) - + # change from AtlasEntity to Entity - source_entity = self.purview_client.get_entity(source_id)['entities'][0] + source_entity = self.purview_client.get_entity(source_id)["entities"][0] # Project contains source, source belongs to project project_contains_source_relation = self._generate_relation_pairs( - project_entity, source_entity, RELATION_CONTAINS) + project_entity, source_entity, RELATION_CONTAINS + ) [self.upload_single_entity_to_purview(x) for x in project_contains_source_relation] return source_id def _create_anchor(self, s: FeatureAnchor) -> UUID: - ''' + """ Create anchor under project ,and based on the data source - This will also create two relation pairs - ''' - project_entity = self.purview_client.get_entity(qualifiedName=self.project_name,typeName=TYPEDEF_FEATHR_PROJECT)['entities'][0] - source_entity = self.purview_client.get_entity(qualifiedName=self.registry_delimiter.join([self.project_name,s.source.name]),typeName=TYPEDEF_SOURCE)['entities'][0] + This will also create two relation pairs + """ + project_entity = self.purview_client.get_entity( + qualifiedName=self.project_name, typeName=TYPEDEF_FEATHR_PROJECT + )["entities"][0] + source_entity = self.purview_client.get_entity( + qualifiedName=self.registry_delimiter.join([self.project_name, s.source.name]), typeName=TYPEDEF_SOURCE + )["entities"][0] attrs = anchor_to_def(s) - qualified_name = self.registry_delimiter.join([self.project_name,attrs['name']]) + qualified_name = self.registry_delimiter.join([self.project_name, attrs["name"]]) anchor_entity = AtlasEntity( name=s.name, qualified_name=qualified_name, - attributes= {k:v for k,v in attrs.items() if k not in ['name','qualifiedName']}, + attributes={k: v for k, v in attrs.items() if k not in ["name", "qualifiedName"]}, typeName=TYPEDEF_ANCHOR, guid=self.guid.get_guid(), ) @@ -680,99 +743,120 @@ def _create_anchor(self, s: FeatureAnchor) -> UUID: anchor_id = self.upload_single_entity_to_purview(anchor_entity) # change from AtlasEntity to Entity - anchor_entity = self.purview_client.get_entity(anchor_id)['entities'][0] + anchor_entity = self.purview_client.get_entity(anchor_id)["entities"][0] - # project contains anchor, anchor belongs to project. 
project_contains_anchor_relation = self._generate_relation_pairs( - project_entity, anchor_entity, RELATION_CONTAINS) - anchor_consumes_source_relation = self._generate_relation_pairs( - anchor_entity,source_entity, RELATION_CONSUMES) - [self.upload_single_entity_to_purview(x) for x in project_contains_anchor_relation + anchor_consumes_source_relation] + project_entity, anchor_entity, RELATION_CONTAINS + ) + anchor_consumes_source_relation = self._generate_relation_pairs(anchor_entity, source_entity, RELATION_CONSUMES) + [ + self.upload_single_entity_to_purview(x) + for x in project_contains_anchor_relation + anchor_consumes_source_relation + ] return anchor_id - def _create_anchor_feature(self, anchor_id: str, source:Source,s: Feature) -> UUID: - ''' - Create anchor feature under anchor. + def _create_anchor_feature(self, anchor_id: str, source: Source, s: Feature) -> UUID: + """ + Create anchor feature under anchor. This will also create three relation pairs - ''' - project_entity = self.purview_client.get_entity(qualifiedName=self.project_name,typeName=TYPEDEF_FEATHR_PROJECT)['entities'][0] - anchor_entity = self.purview_client.get_entity(anchor_id)['entities'][0] - source_entity = self.purview_client.get_entity(qualifiedName=self.registry_delimiter.join([self.project_name,source.name]),typeName=TYPEDEF_SOURCE)['entities'][0] + """ + project_entity = self.purview_client.get_entity( + qualifiedName=self.project_name, typeName=TYPEDEF_FEATHR_PROJECT + )["entities"][0] + anchor_entity = self.purview_client.get_entity(anchor_id)["entities"][0] + source_entity = self.purview_client.get_entity( + qualifiedName=self.registry_delimiter.join([self.project_name, source.name]), typeName=TYPEDEF_SOURCE + )["entities"][0] attrs = feature_to_def(s) - attrs['type'] = attrs['featureType'] - qualified_name = self.registry_delimiter.join([self.project_name, - anchor_entity["attributes"]["name"], - attrs['name']]) + attrs["type"] = attrs["featureType"] + qualified_name = self.registry_delimiter.join( + [self.project_name, anchor_entity["attributes"]["name"], attrs["name"]] + ) anchor_feature_entity = AtlasEntity( name=s.name, qualified_name=qualified_name, - attributes= {k:v for k,v in attrs.items() if k not in ['name','qualifiedName']}, + attributes={k: v for k, v in attrs.items() if k not in ["name", "qualifiedName"]}, typeName=TYPEDEF_ANCHOR_FEATURE, - guid=self.guid.get_guid()) + guid=self.guid.get_guid(), + ) anchor_feature_id = self.upload_single_entity_to_purview(anchor_feature_entity) # change from AtlasEntity to Entity - anchor_feature_entity = self.purview_client.get_entity(anchor_feature_id)['entities'][0] + anchor_feature_entity = self.purview_client.get_entity(anchor_feature_id)["entities"][0] # Project contains AnchorFeature, AnchorFeature belongs to Project project_contains_feature_relation = self._generate_relation_pairs( - project_entity, anchor_feature_entity, RELATION_CONTAINS) + project_entity, anchor_feature_entity, RELATION_CONTAINS + ) anchor_contains_feature_relation = self._generate_relation_pairs( - anchor_entity, anchor_feature_entity, RELATION_CONTAINS) + anchor_entity, anchor_feature_entity, RELATION_CONTAINS + ) feature_consumes_source_relation = self._generate_relation_pairs( - anchor_feature_entity, source_entity, RELATION_CONSUMES) + anchor_feature_entity, source_entity, RELATION_CONSUMES + ) + + [ + self.upload_single_entity_to_purview(x) + for x in project_contains_feature_relation + + anchor_contains_feature_relation + + feature_consumes_source_relation + ] - 
[self.upload_single_entity_to_purview(x) for x in - project_contains_feature_relation - + anchor_contains_feature_relation - + feature_consumes_source_relation] - return anchor_feature_id def _create_derived_feature(self, s: DerivedFeature) -> UUID: - ''' + """ Create DerivedFeature. This will also create multiple relations. - ''' - input_features = [self.purview_client.get_entity(x._registry_id)['entities'][0] for x in s.input_features] + """ + input_features = [self.purview_client.get_entity(x._registry_id)["entities"][0] for x in s.input_features] attrs = derived_feature_to_def(s) - attrs['type'] = attrs['featureType'] + attrs["type"] = attrs["featureType"] - project_entity = self.purview_client.get_entity(qualifiedName=self.project_name,typeName=TYPEDEF_FEATHR_PROJECT)['entities'][0] - qualified_name = self.registry_delimiter.join([self.project_name,attrs['name']]) + project_entity = self.purview_client.get_entity( + qualifiedName=self.project_name, typeName=TYPEDEF_FEATHR_PROJECT + )["entities"][0] + qualified_name = self.registry_delimiter.join([self.project_name, attrs["name"]]) derived_feature_entity = AtlasEntity( name=s.name, qualified_name=qualified_name, - attributes={k:v for k,v in attrs.items() if k not in ['name','qualifiedName']}, + attributes={k: v for k, v in attrs.items() if k not in ["name", "qualifiedName"]}, typeName=TYPEDEF_DERIVED_FEATURE, - guid=self.guid.get_guid()) + guid=self.guid.get_guid(), + ) derived_feature_id = self.upload_single_entity_to_purview(derived_feature_entity) - + # change from AtlasEntity to Entity - derived_feature_entity = self.purview_client.get_entity(derived_feature_id)['entities'][0] + derived_feature_entity = self.purview_client.get_entity(derived_feature_id)["entities"][0] # Project contains DerivedFeature, DerivedFeature belongs to Project. feature_project_contain_belong_pairs = self._generate_relation_pairs( - project_entity, derived_feature_entity, RELATION_CONTAINS) + project_entity, derived_feature_entity, RELATION_CONTAINS + ) consume_produce_pairs = [] # Each input feature produces DerivedFeature, DerivedFeatures consumes from each input feature. for input_feature in input_features: consume_produce_pairs += self._generate_relation_pairs( - derived_feature_entity, input_feature,RELATION_CONSUMES) - - [self.upload_single_entity_to_purview(x) for x in - feature_project_contain_belong_pairs - + consume_produce_pairs] - + derived_feature_entity, input_feature, RELATION_CONSUMES + ) + + [self.upload_single_entity_to_purview(x) for x in feature_project_contain_belong_pairs + consume_produce_pairs] + return derived_feature_id - def register_features(self, workspace_path: Optional[Path] = None, from_context: bool = True, anchor_list:List[FeatureAnchor]=[], derived_feature_list:List[DerivedFeature]=[]): - """Register Features for the specified workspace. + def register_features( + self, + workspace_path: Optional[Path] = None, + from_context: bool = True, + anchor_list: List[FeatureAnchor] = [], + derived_feature_list: List[DerivedFeature] = [], + ): + """Register Features for the specified workspace. Args: workspace_path (str, optional): path to a workspace. Defaults to None. @@ -782,8 +866,10 @@ def register_features(self, workspace_path: Optional[Path] = None, from_context: """ if not from_context: - raise RuntimeError("Currently Feathr only supports registering features from context (i.e. 
you must call FeathrClient.build_features() before calling this function).") - # Before starting, create the project + raise RuntimeError( + "Currently Feathr only supports registering features from context (i.e. you must call FeathrClient.build_features() before calling this function)." + ) + # Before starting, create the project self._register_feathr_feature_types() self.project_id = self._create_project() @@ -800,15 +886,12 @@ def register_features(self, workspace_path: Optional[Path] = None, from_context: # 3. Create all features on the registry for feature in anchor.features: if not hasattr(feature, "_registry_id"): - feature._registry_id = self._create_anchor_feature( - anchor._registry_id, anchor.source,feature) + feature._registry_id = self._create_anchor_feature(anchor._registry_id, anchor.source, feature) # 4. Create all derived features on the registry for df in topological_sort(derived_feature_list): if not hasattr(df, "_registry_id"): df._registry_id = self._create_derived_feature(df) - logger.info( - "Finished registering features.") - + logger.info("Finished registering features.") def _purge_feathr_registry(self): """ @@ -817,24 +900,22 @@ def _purge_feathr_registry(self): self._delete_all_feathr_entities() self._delete_all_feathr_types() - def _delete_all_feathr_types(self): """ Delete all the corresponding type definitions for feathr registry. For internal use only """ typedefs = self.purview_client.get_all_typedefs() - relationshipdef_list=[] - for relationshipdef in typedefs['relationshipDefs']: - if "feathr" in relationshipdef['name']: + relationshipdef_list = [] + for relationshipdef in typedefs["relationshipDefs"]: + if "feathr" in relationshipdef["name"]: relationshipdef_list.append(relationshipdef) self.purview_client.delete_typedefs(relationshipDefs=relationshipdef_list) - - entitydef_list=[] - for typedef in typedefs['entityDefs']: - if "feathr" in typedef['name']: - entitydef_list.append(typedef ) + entitydef_list = [] + for typedef in typedefs["entityDefs"]: + if "feathr" in typedef["name"]: + entitydef_list.append(typedef) self.purview_client.delete_typedefs(entityDefs=entitydef_list) logger.info("Deleted all the Feathr related definitions.") @@ -851,21 +932,20 @@ def _delete_all_feathr_entities(self): # use the `query` API so that it can return immediately (don't use the search_entity API as it will try to return all the results in a single request) while True: - result = self.purview_client.discovery.query( - "feathr", limit=batch_delete_size) - logger.info("Total number of entities:",result['@search.count'] ) + result = self.purview_client.discovery.query("feathr", limit=batch_delete_size) + logger.info("Total number of entities:", result["@search.count"]) # if no results, break: - if result['@search.count'] == 0: + if result["@search.count"] == 0: break - entities = result['value'] + entities = result["value"] guid_list = [entity["id"] for entity in entities] self.purview_client.delete_entity(guid=guid_list) logger.info("{} feathr entities deleted", batch_delete_size) # sleep here, otherwise backend might throttle # process the next batch after sleep sleep(1) - + @classmethod def _get_registry_client(self): """ @@ -873,7 +953,7 @@ def _get_registry_client(self): """ return self.purview_client - def list_registered_features(self, project_name: str, limit=1000, starting_offset=0) -> List[Dict[str,str]]: + def list_registered_features(self, project_name: str, limit=1000, starting_offset=0) -> List[Dict[str, str]]: """ List all the already registered features. 
If project_name is not provided or is None, it will return all the registered features; otherwise it will only return only features under this project @@ -889,30 +969,24 @@ def list_registered_features(self, project_name: str, limit=1000, starting_offse # see syntax here: https://docs.microsoft.com/en-us/rest/api/purview/catalogdataplane/discovery/query#discovery_query_andornested query_filter = { "and": [ - { - "or": - [ - {"entityType": TYPEDEF_DERIVED_FEATURE}, - {"entityType": TYPEDEF_ANCHOR_FEATURE} - ] - }, + {"or": [{"entityType": TYPEDEF_DERIVED_FEATURE}, {"entityType": TYPEDEF_ANCHOR_FEATURE}]}, { "attributeName": "qualifiedName", "operator": "startswith", - "attributeValue": project_name + self.registry_delimiter - } + "attributeValue": project_name + self.registry_delimiter, + }, ] } result = self.purview_client.discovery.query(filter=query_filter) - entities = result['value'] + entities = result["value"] # entities = self.purview_client.discovery.search_entities(query = None, search_filter=query_filter, limit=limit) for entity in entities: - feature_list.append({"name":entity["name"],'id':entity['id'],"qualifiedName":entity['qualifiedName']}) + feature_list.append({"name": entity["name"], "id": entity["id"], "qualifiedName": entity["qualifiedName"]}) return feature_list - + def list_dependent_entities(self, qualified_name: str): """ Returns list of dependent entities for provided entity @@ -924,34 +998,34 @@ def delete_entity(self, qualified_name: str): Deletes entity if it has no dependent entities """ raise NotImplementedError("Delete functionality supported through API") - + def get_feature_by_fqdn_type(self, qualifiedName, typeName): """ Get a single feature by it's QualifiedName and Type Returns the feature else throws an AtlasException with 400 error code """ response = self.purview_client.get_entity(qualifiedName=qualifiedName, typeName=typeName) - entities = response.get('entities') + entities = response.get("entities") for entity in entities: - if entity.get('typeName') == typeName and entity.get('attributes').get('qualifiedName') == qualifiedName: + if entity.get("typeName") == typeName and entity.get("attributes").get("qualifiedName") == qualifiedName: return entity - + def get_feature_by_fqdn(self, qualifiedName): """ Get feature by qualifiedName Returns the feature else throws an AtlasException with 400 error code - """ + """ id = self.get_feature_id(qualifiedName) return self.get_feature_by_guid(id) - + def get_feature_by_guid(self, guid): """ Get a single feature by it's GUID Returns the feature else throws an AtlasException with 400 error code - """ + """ response = self.purview_client.get_single_entity(guid=guid) return response - + def get_feature_lineage(self, guid): """ Get feature's lineage by it's GUID @@ -962,7 +1036,7 @@ def get_feature_lineage(self, guid): def get_feature_id(self, qualifiedName, type: str): """ Get guid of a feature given its qualifiedName - """ + """ # the search term should be full qualified name # TODO: need to update the calling functions to add `type` field to make it more performant # purview_client.get_entity(qualifiedName=qualifiedName) might not work here since it requires an additonal typeName parameter @@ -972,45 +1046,53 @@ def get_feature_id(self, qualifiedName, type: str): # get the corresponding features belongs to a certain project. 
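For reference, a minimal usage sketch of the lookup APIs in this hunk (separate from the patch itself): it assumes a Purview-backed registry object already constructed with the methods shown above, and the helper name, arguments, and qualified-name value are illustrative only.

def lookup_feature(registry, project_name: str, feature_qualified_name: str):
    # Each entry returned by the query above carries "name", "id" and "qualifiedName".
    for f in registry.list_registered_features(project_name=project_name, limit=100):
        print(f["name"], f["id"], f["qualifiedName"])
    # Resolve one qualified name to its GUID; the `type` argument is accepted but,
    # per the TODO in this hunk, not yet used to narrow the filter.
    guid = registry.get_feature_id(feature_qualified_name, type=None)
    # Fetch the full Purview entity if the name resolved.
    return registry.get_feature_by_guid(guid) if guid is not None else None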
# note that we need to use "eq" to filter exactly this qualified name # see syntax here: https://docs.microsoft.com/en-us/rest/api/purview/catalogdataplane/discovery/query#discovery_query_andornested - query_filter = { - "attributeName": "qualifiedName", - "operator": "eq", - "attributeValue": qualifiedName - } - result = self.purview_client.discovery.query(keywords = None, filter=query_filter) - entities = result['value'] + query_filter = {"attributeName": "qualifiedName", "operator": "eq", "attributeValue": qualifiedName} + result = self.purview_client.discovery.query(keywords=None, filter=query_filter) + entities = result["value"] # There should be exactly one result, but we don't enforce the check here for entity in entities: - if entity.get('qualifiedName') == qualifiedName: - return entity.get('id') + if entity.get("qualifiedName") == qualifiedName: + return entity.get("id") def search_features(self, searchTerm): """ Search the registry for the given query term For a ride hailing company few examples could be - "taxi", "passenger", "fare" etc. It's a keyword search on the registry metadata - """ + """ entities = self.purview_client.discovery.search_entities(searchTerm) return entities - - def _list_registered_entities_with_details(self, project_name: str, entity_type: Union[str, List[str]] = None, limit=1000, starting_offset=0,) -> List[Dict]: + + def _list_registered_entities_with_details( + self, + project_name: str, + entity_type: Union[str, List[str]] = None, + limit=1000, + starting_offset=0, + ) -> List[Dict]: """ List all the already registered entities. entity_type should be one of: SOURCE, DERIVED_FEATURE, ANCHOR, ANCHOR_FEATURE, FEATHR_PROJECT, or a list of those values limit: a maximum 1000 will be enforced at the underlying API - + returns a list of the result entities. """ - entity_type_list = [entity_type] if isinstance( - entity_type, str) else entity_type + entity_type_list = [entity_type] if isinstance(entity_type, str) else entity_type for i in entity_type_list: - if i not in {TYPEDEF_SOURCE, TYPEDEF_DERIVED_FEATURE, TYPEDEF_ANCHOR, TYPEDEF_ANCHOR_FEATURE, TYPEDEF_FEATHR_PROJECT}: + if i not in { + TYPEDEF_SOURCE, + TYPEDEF_DERIVED_FEATURE, + TYPEDEF_ANCHOR, + TYPEDEF_ANCHOR_FEATURE, + TYPEDEF_FEATHR_PROJECT, + }: raise RuntimeError( - f'only SOURCE, DERIVED_FEATURE, ANCHOR, ANCHOR_FEATURE, FEATHR_PROJECT are supported when listing the registered entities, {entity_type} is not one of them.') + f"only SOURCE, DERIVED_FEATURE, ANCHOR, ANCHOR_FEATURE, FEATHR_PROJECT are supported when listing the registered entities, {entity_type} is not one of them." 
+ ) if project_name is None: raise RuntimeError("You need to specify a project_name") - # the search grammar: + # the search grammar: # https://docs.microsoft.com/en-us/azure/purview/how-to-search-catalog#search-query-syntax # https://docs.microsoft.com/en-us/rest/api/datacatalog/data-catalog-search-syntax-reference @@ -1023,28 +1105,33 @@ def _list_registered_entities_with_details(self, project_name: str, entity_type: # Hence if TYPEDEF_FEATHR_PROJECT is in the `entity_type` input, we need to search for that specifically # and finally "OR" the result to union them query_filter = { - "or": - [{ - "and": [{ - # this is a list of the entity types that you want to query - "or": [{"entityType": e} for e in entity_type_list] + "or": [ + { + "and": [ + { + # this is a list of the entity types that you want to query + "or": [{"entityType": e} for e in entity_type_list] + }, + { + "attributeName": "qualifiedName", + "operator": "startswith", + # use `project_name + self.registry_delimiter` to limit the search results + "attributeValue": project_name + self.registry_delimiter, + }, + ] }, - { - "attributeName": "qualifiedName", - "operator": "startswith", - # use `project_name + self.registry_delimiter` to limit the search results - "attributeValue": project_name + self.registry_delimiter - }]}, # if we are querying TYPEDEF_FEATHR_PROJECT, then "union" the result by using this query { - "and": [{ - "or": [{"entityType": TYPEDEF_FEATHR_PROJECT}] if TYPEDEF_FEATHR_PROJECT in entity_type_list else None + "and": [ + { + "or": [{"entityType": TYPEDEF_FEATHR_PROJECT}] + if TYPEDEF_FEATHR_PROJECT in entity_type_list + else None + }, + {"attributeName": "qualifiedName", "operator": "startswith", "attributeValue": project_name}, + ] }, - { - "attributeName": "qualifiedName", - "operator": "startswith", - "attributeValue": project_name - }]}] + ] } # Important properties returned includes: @@ -1053,13 +1140,18 @@ def _list_registered_entities_with_details(self, project_name: str, entity_type: # TODO: it might be throttled in the backend and wait for the `pyapacheatlas` to fix this # https://github.com/wjohnson/pyapacheatlas/issues/206 # `pyapacheatlas` needs a bit optimization to avoid additional calls. - result_entities = self.purview_client.discovery.search_entities(query=None, search_filter=query_filter, limit = limit) - + result_entities = self.purview_client.discovery.search_entities( + query=None, search_filter=query_filter, limit=limit + ) + # append the guid list. Since we are using project_name + delimiter to search, all the result will be valid. guid_list = [entity["id"] for entity in result_entities] - entity_res = [] if guid_list is None or len(guid_list)==0 else self.purview_client.get_entity( - guid=guid_list)["entities"] + entity_res = ( + [] + if guid_list is None or len(guid_list) == 0 + else self.purview_client.get_entity(guid=guid_list)["entities"] + ) return entity_res def get_features_from_registry(self, project_name: str) -> Tuple[List[FeatureAnchor], List[DerivedFeature]]: @@ -1068,88 +1160,151 @@ def get_features_from_registry(self, project_name: str) -> Tuple[List[FeatureAnc Args: project_name (str): project name. 
""" - project_entity = self.purview_client.get_entity(qualifiedName=project_name,typeName=TYPEDEF_FEATHR_PROJECT)['entities'][0] - - single_direction_process = [entity for _,entity in self.purview_client.get_entity_lineage(project_entity['guid'])['guidEntityMap'].items() \ - if entity['typeName']=='Process' and \ - (entity['attributes']['qualifiedName'].startswith(RELATION_CONTAINS) \ - or entity['attributes']['qualifiedName'].startswith(RELATION_CONSUMES))] - contain_relations = [x['displayText'].split(' to ') for x in single_direction_process if x['attributes']['qualifiedName'].startswith(RELATION_CONTAINS)] - - consume_relations = [x['displayText'].split(' to ') for x in single_direction_process if x['attributes']['qualifiedName'].startswith(RELATION_CONSUMES)] - - entities_under_project = [self.purview_client.get_entity(x[1])['entities'][0] for x in contain_relations if x[0]== project_entity['guid']] + project_entity = self.purview_client.get_entity(qualifiedName=project_name, typeName=TYPEDEF_FEATHR_PROJECT)[ + "entities" + ][0] + + single_direction_process = [ + entity + for _, entity in self.purview_client.get_entity_lineage(project_entity["guid"])["guidEntityMap"].items() + if entity["typeName"] == "Process" + and ( + entity["attributes"]["qualifiedName"].startswith(RELATION_CONTAINS) + or entity["attributes"]["qualifiedName"].startswith(RELATION_CONSUMES) + ) + ] + contain_relations = [ + x["displayText"].split(" to ") + for x in single_direction_process + if x["attributes"]["qualifiedName"].startswith(RELATION_CONTAINS) + ] + + consume_relations = [ + x["displayText"].split(" to ") + for x in single_direction_process + if x["attributes"]["qualifiedName"].startswith(RELATION_CONSUMES) + ] + + entities_under_project = [ + self.purview_client.get_entity(x[1])["entities"][0] + for x in contain_relations + if x[0] == project_entity["guid"] + ] if not entities_under_project: # if the result is empty return (None, None) - entities_dict = {x['guid']:x for x in entities_under_project} + entities_dict = {x["guid"]: x for x in entities_under_project} derived_feature_list = [] - for derived_feature_entity in [x for x in entities_under_project if x['typeName']==TYPEDEF_DERIVED_FEATURE]: + for derived_feature_entity in [x for x in entities_under_project if x["typeName"] == TYPEDEF_DERIVED_FEATURE]: # this will be used to generate DerivedFeature instance derived_feature_key_list = [] - + for key in derived_feature_entity["attributes"]["key"]: - derived_feature_key_list.append(TypedKey(key_column=key["keyColumn"], key_column_type=key["keyColumnType"], full_name=key["fullName"], description=key["description"], key_column_alias=key["keyColumnAlias"])) - - def search_for_input_feature(elem, full_relations,full_entities): - matching_relations = [x for x in full_relations if x[0]==elem['guid']] + derived_feature_key_list.append( + TypedKey( + key_column=key["keyColumn"], + key_column_type=key["keyColumnType"], + full_name=key["fullName"], + description=key["description"], + key_column_alias=key["keyColumnAlias"], + ) + ) + + def search_for_input_feature(elem, full_relations, full_entities): + matching_relations = [x for x in full_relations if x[0] == elem["guid"]] target_entities = [full_entities[x[1]] for x in matching_relations] - input_features = [x for x in target_entities if x['typeName']==TYPEDEF_ANCHOR_FEATURE \ - or x['typeName']==TYPEDEF_DERIVED_FEATURE] - result_features=[] + input_features = [ + x + for x in target_entities + if x["typeName"] == TYPEDEF_ANCHOR_FEATURE or x["typeName"] == 
TYPEDEF_DERIVED_FEATURE + ] + result_features = [] for feature_entity in input_features: - key_list=[] + key_list = [] for key in feature_entity["attributes"]["key"]: - key_list.append(TypedKey(key_column=key["keyColumn"], key_column_type=key["keyColumnType"], full_name=key["fullName"], description=key["description"], key_column_alias=key["keyColumnAlias"])) - result_features.append(Feature(name=feature_entity["attributes"]["name"], - feature_type=self._get_feature_type_from_hocon(feature_entity["attributes"]["type"]), # stored as a hocon string, can be parsed using pyhocon - transform=self._get_transformation_from_dict(feature_entity["attributes"]['transformation']), #transform attributes are stored in a dict fashion , can be put in a WindowAggTransformation - key=key_list, # since all features inside an anchor should share the same key, pick the first one. - registry_tags=feature_entity["attributes"]["tags"])) + key_list.append( + TypedKey( + key_column=key["keyColumn"], + key_column_type=key["keyColumnType"], + full_name=key["fullName"], + description=key["description"], + key_column_alias=key["keyColumnAlias"], + ) + ) + result_features.append( + Feature( + name=feature_entity["attributes"]["name"], + feature_type=self._get_feature_type_from_hocon( + feature_entity["attributes"]["type"] + ), # stored as a hocon string, can be parsed using pyhocon + transform=self._get_transformation_from_dict( + feature_entity["attributes"]["transformation"] + ), # transform attributes are stored in a dict fashion , can be put in a WindowAggTransformation + key=key_list, # since all features inside an anchor should share the same key, pick the first one. + registry_tags=feature_entity["attributes"]["tags"], + ) + ) return result_features + all_input_features = search_for_input_feature(derived_feature_entity, consume_relations, entities_dict) + derived_feature_list.append( + DerivedFeature( + name=derived_feature_entity["attributes"]["name"], + feature_type=self._get_feature_type_from_hocon(derived_feature_entity["attributes"]["type"]), + transform=self._get_transformation_from_dict( + derived_feature_entity["attributes"]["transformation"] + ), + key=derived_feature_key_list, + input_features=all_input_features, + registry_tags=derived_feature_entity["attributes"]["tags"], + ) + ) - all_input_features = search_for_input_feature(derived_feature_entity,consume_relations,entities_dict) - derived_feature_list.append(DerivedFeature(name=derived_feature_entity["attributes"]["name"], - feature_type=self._get_feature_type_from_hocon(derived_feature_entity["attributes"]["type"]), - transform=self._get_transformation_from_dict(derived_feature_entity["attributes"]['transformation']), - key=derived_feature_key_list, - input_features= all_input_features, - registry_tags=derived_feature_entity["attributes"]["tags"])) - # anchor_result = self.purview_client.get_entity(guid=anchor_guid)["entities"] anchor_list = [] - for anchor_entity in [x for x in entities_under_project if x['typeName']==TYPEDEF_ANCHOR]: - consume_items_under_anchor = [entities_dict[x[1]] for x in consume_relations if x[0]==anchor_entity['guid']] - source_entity = [x for x in consume_items_under_anchor if x['typeName']==TYPEDEF_SOURCE][0] - - contain_items_under_anchor = [entities_dict[x[1]] for x in contain_relations if x[0]==anchor_entity['guid']] - anchor_features_guid = [x['guid'] for x in contain_items_under_anchor if x['typeName']==TYPEDEF_ANCHOR_FEATURE] - - anchor_list.append(FeatureAnchor(name=anchor_entity["attributes"]["name"], - 
source=HdfsSource( - name=source_entity["attributes"]["name"], - event_timestamp_column=source_entity["attributes"]["event_timestamp_column"], - timestamp_format=source_entity["attributes"]["timestamp_format"], - preprocessing=self._correct_function_indentation(source_entity["attributes"]["preprocessing"]), - path=source_entity["attributes"]["path"], - registry_tags=source_entity["attributes"]["tags"] - ), - features=self._get_features_by_guid_or_entities(guid_list = anchor_features_guid, entity_list=entities_under_project), - registry_tags=anchor_entity["attributes"]["tags"])) + for anchor_entity in [x for x in entities_under_project if x["typeName"] == TYPEDEF_ANCHOR]: + consume_items_under_anchor = [ + entities_dict[x[1]] for x in consume_relations if x[0] == anchor_entity["guid"] + ] + source_entity = [x for x in consume_items_under_anchor if x["typeName"] == TYPEDEF_SOURCE][0] + + contain_items_under_anchor = [ + entities_dict[x[1]] for x in contain_relations if x[0] == anchor_entity["guid"] + ] + anchor_features_guid = [ + x["guid"] for x in contain_items_under_anchor if x["typeName"] == TYPEDEF_ANCHOR_FEATURE + ] + + anchor_list.append( + FeatureAnchor( + name=anchor_entity["attributes"]["name"], + source=HdfsSource( + name=source_entity["attributes"]["name"], + event_timestamp_column=source_entity["attributes"]["event_timestamp_column"], + timestamp_format=source_entity["attributes"]["timestamp_format"], + preprocessing=self._correct_function_indentation(source_entity["attributes"]["preprocessing"]), + path=source_entity["attributes"]["path"], + registry_tags=source_entity["attributes"]["tags"], + ), + features=self._get_features_by_guid_or_entities( + guid_list=anchor_features_guid, entity_list=entities_under_project + ), + registry_tags=anchor_entity["attributes"]["tags"], + ) + ) return (anchor_list, derived_feature_list) - def search_input_anchor_features(self,derived_guids,feature_entity_guid_mapping) ->List[str]: - ''' + def search_input_anchor_features(self, derived_guids, feature_entity_guid_mapping) -> List[str]: + """ Iterate all derived features and its parent links, extract and aggregate all inputs - ''' + """ stack = [x for x in derived_guids] result = [] - while len(stack)>0: + while len(stack) > 0: current_derived_guid = stack.pop() current_input = feature_entity_guid_mapping[current_derived_guid] new_derived_features = [x["guid"] for x in current_input["attributes"]["input_derived_features"]] @@ -1160,7 +1315,6 @@ def search_input_anchor_features(self,derived_guids,feature_entity_guid_mapping) result = list(set(result)) return result - def _correct_function_indentation(self, user_func: str) -> str: """ The function read from registry might have the wrong indentation. We need to correct those indentation. 
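To make the traversal in search_input_anchor_features easier to follow outside of the Purview entity shapes, here is a small standalone sketch of the same stack-based walk (separate from the patch). Only "input_derived_features" is visible in this hunk; the "input_anchor_features" attribute name and the helper name are assumptions for illustration.

def collect_anchor_inputs(derived_guids, guid_to_entity):
    # Walk derived-feature links with an explicit stack and aggregate every
    # anchor-feature input reachable from the starting derived features.
    stack = list(derived_guids)
    anchors = set()  # the set mirrors the list(set(result)) dedup in the method above
    while stack:
        attrs = guid_to_entity[stack.pop()]["attributes"]
        # Derived inputs are pushed back so their own inputs get expanded too...
        stack.extend(x["guid"] for x in attrs.get("input_derived_features", []))
        # ...while anchor inputs are terminal and collected directly (assumed attribute name).
        anchors.update(x["guid"] for x in attrs.get("input_anchor_features", []))
    return list(anchors)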
@@ -1178,33 +1332,31 @@ def feathr_udf2(df) if user_func is None: return None # if user_func is a string, turn it into a list of strings so that it can be used below - temp_udf_source_code = user_func.split('\n') + temp_udf_source_code = user_func.split("\n") # assuming the first line is the function name leading_space_num = len(temp_udf_source_code[0]) - len(temp_udf_source_code[0].lstrip()) # strip the lines to make sure the function has the correct indentation udf_source_code_striped = [line[leading_space_num:] for line in temp_udf_source_code] # append '\n' back since it was deleted due to the previous split - udf_source_code = [line+'\n' for line in udf_source_code_striped] + udf_source_code = [line + "\n" for line in udf_source_code_striped] return " ".join(udf_source_code) def _get_source_by_guid(self, guid, entity_list) -> Source: - """give a entity list and the target GUID for the source entity, return a python `Source` object. - """ + """give a entity list and the target GUID for the source entity, return a python `Source` object.""" # TODO: currently return HDFS source by default. For JDBC source, it's currently implemented using HDFS Source so we should split in the future # there should be only one entity available - source_entity = [x for x in entity_list if x['guid'] == guid][0] + source_entity = [x for x in entity_list if x["guid"] == guid][0] # if source_entity["attributes"]["path"] is INPUT_CONTEXT, it will also be assigned to this returned object - return HdfsSource(name=source_entity["attributes"]["name"], - event_timestamp_column=source_entity["attributes"]["event_timestamp_column"], - timestamp_format=source_entity["attributes"]["timestamp_format"], - preprocessing=self._correct_function_indentation(source_entity["attributes"]["preprocessing"]), - path=source_entity["attributes"]["path"], - registry_tags=source_entity["attributes"]["tags"] - ) - - + return HdfsSource( + name=source_entity["attributes"]["name"], + event_timestamp_column=source_entity["attributes"]["event_timestamp_column"], + timestamp_format=source_entity["attributes"]["timestamp_format"], + preprocessing=self._correct_function_indentation(source_entity["attributes"]["preprocessing"]), + path=source_entity["attributes"]["path"], + registry_tags=source_entity["attributes"]["tags"], + ) def _get_feature_type_from_hocon(self, input_str: str) -> FeatureType: """Get Feature types from a HOCON config, given that we stored the feature type in a plain string. @@ -1218,56 +1370,58 @@ def _get_feature_type_from_hocon(self, input_str: str) -> FeatureType: if not input_str: return None conf = ConfigFactory.parse_string(input_str) - valType = conf.get_string('valType') if 'valType' in conf else conf.get_string('type.valType') - dimensionType = conf.get_string('dimensionType') if 'dimensionType' in conf else conf.get_string('type.dimensionType') - if dimensionType == '[INT]': + valType = conf.get_string("valType") if "valType" in conf else conf.get_string("type.valType") + dimensionType = ( + conf.get_string("dimensionType") if "dimensionType" in conf else conf.get_string("type.dimensionType") + ) + if dimensionType == "[INT]": # if it's not empty, i.e. 
[INT], indicating it's vectors - if valType == 'DOUBLE': + if valType == "DOUBLE": return DoubleVectorFeatureType() - elif valType == 'LONG': + elif valType == "LONG": return Int64VectorFeatureType() - elif valType == 'INT': + elif valType == "INT": return Int32VectorFeatureType() - elif valType == 'FLOAT': + elif valType == "FLOAT": return FloatVectorFeatureType() else: logger.error("{} cannot be parsed.", valType) else: - if valType == 'STRING': + if valType == "STRING": return StringFeatureType() - elif valType == 'BYTES': + elif valType == "BYTES": return BytesFeatureType() - elif valType == 'DOUBLE': + elif valType == "DOUBLE": return DoubleFeatureType() - elif valType == 'FLOAT': + elif valType == "FLOAT": return FloatFeatureType() - elif valType == 'LONG': + elif valType == "LONG": return Int64FeatureType() - elif valType == 'INT': + elif valType == "INT": return Int32FeatureType() - elif valType == 'BOOLEAN': + elif valType == "BOOLEAN": return BooleanFeatureType() else: logger.error("{} cannot be parsed.", valType) def _get_transformation_from_dict(self, input: Dict) -> FeatureType: - if 'transformExpr' in input: + if "transformExpr" in input: # it's ExpressionTransformation - return ExpressionTransformation(input['transformExpr']) - elif 'def_expr' in input or 'defExpr' in input: - agg_expr=input['def_expr'] if 'def_expr' in input else (input['defExpr'] if 'defExpr' in input else None) - agg_func=input['agg_func']if 'agg_func' in input else (input['aggFunc'] if 'aggFunc' in input else None) - window=input['window']if 'window' in input else None - group_by=input['group_by']if 'group_by' in input else (input['groupBy'] if 'groupBy' in input else None) - filter=input['filter']if 'filter' in input else None - limit=input['limit']if 'limit' in input else None + return ExpressionTransformation(input["transformExpr"]) + elif "def_expr" in input or "defExpr" in input: + agg_expr = input["def_expr"] if "def_expr" in input else (input["defExpr"] if "defExpr" in input else None) + agg_func = input["agg_func"] if "agg_func" in input else (input["aggFunc"] if "aggFunc" in input else None) + window = input["window"] if "window" in input else None + group_by = input["group_by"] if "group_by" in input else (input["groupBy"] if "groupBy" in input else None) + filter = input["filter"] if "filter" in input else None + limit = input["limit"] if "limit" in input else None return WindowAggTransformation(agg_expr, agg_func, window, group_by, filter, limit) else: # no transformation function observed return None def _get_features_by_guid_or_entities(self, guid_list, entity_list) -> List[FeatureAnchor]: - """return a python list of the features that are referenced by a list of guids. + """return a python list of the features that are referenced by a list of guids. If entity_list is provided, use entity_list to reconstruct those features This is for "anchor feature" only. """ @@ -1275,34 +1429,48 @@ def _get_features_by_guid_or_entities(self, guid_list, entity_list) -> List[Feat feature_entities = self.purview_client.get_entity(guid=guid_list)["entities"] else: guid_set = set(guid_list) - feature_entities = [x for x in entity_list if x['guid'] in guid_set] + feature_entities = [x for x in entity_list if x["guid"] in guid_set] # raise error if we cannot find all the guid if len(feature_entities) != len(guid_list): - raise RuntimeError("Number of `feature_entities` is less than provided GUID list for search. 
The project might be broken.") + raise RuntimeError( + "Number of `feature_entities` is less than provided GUID list for search. The project might be broken." + ) - feature_list=[] - - ''' + feature_list = [] + + """ The assumption here is , a feture could have multiple keys, and features inside an anchor should share the same set of keys. So we will take any one of the feature, extract its keys , dedup them by full name, and use them to generate the key list. - ''' + """ first_feature_keys = feature_entities[0]["attributes"]["key"] deduped_keys = dict() for key in first_feature_keys: - if key['fullName'] not in deduped_keys: - deduped_keys.setdefault(key['fullName'],key) + if key["fullName"] not in deduped_keys: + deduped_keys.setdefault(key["fullName"], key) key_list = [ - TypedKey(key_column=key["keyColumn"], key_column_type=key["keyColumnType"], full_name=key["fullName"], description=key["description"], key_column_alias=key["keyColumnAlias"])\ - for key in list(deduped_keys.values()) - ] - for feature_entity in feature_entities: + TypedKey( + key_column=key["keyColumn"], + key_column_type=key["keyColumnType"], + full_name=key["fullName"], + description=key["description"], + key_column_alias=key["keyColumnAlias"], + ) + for key in list(deduped_keys.values()) + ] + for feature_entity in feature_entities: # after get keys, put them in features - feature_list.append(Feature(name=feature_entity["attributes"]["name"], - feature_type=self._get_feature_type_from_hocon(feature_entity["attributes"]["type"]), # stored as a hocon string, can be parsed using pyhocon - transform=self._get_transformation_from_dict(feature_entity["attributes"]['transformation']), #transform attributes are stored in a dict fashion , can be put in a WindowAggTransformation - key=key_list, + feature_list.append( + Feature( + name=feature_entity["attributes"]["name"], + feature_type=self._get_feature_type_from_hocon( + feature_entity["attributes"]["type"] + ), # stored as a hocon string, can be parsed using pyhocon + transform=self._get_transformation_from_dict( + feature_entity["attributes"]["transformation"] + ), # transform attributes are stored in a dict fashion , can be put in a WindowAggTransformation + key=key_list, registry_tags=feature_entity["attributes"]["tags"], - - )) - return feature_list + ) + ) + return feature_list diff --git a/feathr_project/feathr/registry/feature_registry.py b/feathr_project/feathr/registry/feature_registry.py index 2bea40653..8af000b54 100644 --- a/feathr_project/feathr/registry/feature_registry.py +++ b/feathr_project/feathr/registry/feature_registry.py @@ -4,21 +4,20 @@ from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.anchor import FeatureAnchor + class FeathrRegistry(ABC): - """This is the abstract class for all the feature registries. All the feature registries should implement those interfaces. - """ + """This is the abstract class for all the feature registries. 
All the feature registries should implement those interfaces.""" @abstractmethod - def register_features(self, anchor_list: List[FeatureAnchor] =[], derived_feature_list: List[DerivedFeature]=[]): + def register_features(self, anchor_list: List[FeatureAnchor] = [], derived_feature_list: List[DerivedFeature] = []): """Registers features based on the current workspace - Args: - anchor_list: List of FeatureAnchors - derived_feature_list: List of DerivedFeatures + Args: + anchor_list: List of FeatureAnchors + derived_feature_list: List of DerivedFeatures """ pass - @abstractmethod def list_registered_features(self, project_name: str) -> List[str]: """List all the already registered features under the given project. @@ -51,5 +50,3 @@ def get_features_from_registry(self, project_name: str) -> Tuple[List[FeatureAnc bool: Returns true if the job completed successfully, otherwise False """ pass - - diff --git a/feathr_project/feathr/registry/registry_utils.py b/feathr_project/feathr/registry/registry_utils.py index 7f347d278..6a85863ff 100644 --- a/feathr_project/feathr/registry/registry_utils.py +++ b/feathr_project/feathr/registry/registry_utils.py @@ -8,17 +8,29 @@ from feathr.definition.feature import Feature from feathr.definition.feature_derivations import DerivedFeature from feathr.definition.source import HdfsSource, JdbcSource, Source, SnowflakeSource -from pyapacheatlas.core import AtlasProcess,AtlasEntity -from feathr.definition.source import GenericSource, HdfsSource, JdbcSource, SnowflakeSource, Source, SparkSqlSource, KafKaSource, CosmosDbSource, ElasticSearchSource +from pyapacheatlas.core import AtlasProcess, AtlasEntity +from feathr.definition.source import ( + GenericSource, + HdfsSource, + JdbcSource, + SnowflakeSource, + Source, + SparkSqlSource, + KafKaSource, + CosmosDbSource, + ElasticSearchSource, +) from feathr.definition.transformation import ExpressionTransformation, Transformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey + + def to_camel(s): if not s: return s if isinstance(s, str): if "_" in s: s = sub(r"(_)+", " ", s).title().replace(" ", "") - return ''.join([s[0].lower(), s[1:]]) + return "".join([s[0].lower(), s[1:]]) return s elif isinstance(s, list): return [to_camel(i) for i in s] @@ -26,7 +38,7 @@ def to_camel(s): return dict([(to_camel(k), s[k]) for k in s]) -# TODO: need to update the other sources to make the code cleaner +# TODO: need to update the other sources to make the code cleaner def source_to_def(source: Source) -> dict: ret = {} if source.name == INPUT_CONTEXT: @@ -45,9 +57,9 @@ def source_to_def(source: Source) -> dict: ret = { "name": source.name, "type": "kafka", - "brokers":source.config.brokers, - "topics":source.config.topics, - "schemaStr":source.config.schema.schemaStr + "brokers": source.config.brokers, + "topics": source.config.topics, + "schemaStr": source.config.schema.schemaStr, } print("ret is", ret) elif isinstance(source, SnowflakeSource): @@ -76,8 +88,7 @@ def source_to_def(source: Source) -> dict: ret["name"] = source.name else: raise ValueError(f"Unsupported source type {source.__class__}") - - + if hasattr(source, "preprocessing") and source.preprocessing: ret["preprocessing"] = inspect.getsource(source.preprocessing) if source.event_timestamp_column: @@ -90,10 +101,10 @@ def source_to_def(source: Source) -> dict: ret["tags"] = source.registry_tags return ret - + def anchor_to_def(v: FeatureAnchor) -> dict: - # Note that after this method, attributes are Camel cased (eventTimestampColumn). 
- # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. + # Note that after this method, attributes are Camel cased (eventTimestampColumn). + # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. source_id = v.source._registry_id ret = { "name": v.name, @@ -121,22 +132,19 @@ def feathr_udf2(df) if user_func is None: return None # if user_func is a string, turn it into a list of strings so that it can be used below - temp_udf_source_code = user_func.split('\n') + temp_udf_source_code = user_func.split("\n") # assuming the first line is the function name - leading_space_num = len( - temp_udf_source_code[0]) - len(temp_udf_source_code[0].lstrip()) + leading_space_num = len(temp_udf_source_code[0]) - len(temp_udf_source_code[0].lstrip()) # strip the lines to make sure the function has the correct indentation - udf_source_code_striped = [line[leading_space_num:] - for line in temp_udf_source_code] + udf_source_code_striped = [line[leading_space_num:] for line in temp_udf_source_code] # append '\n' back since it was deleted due to the previous split - udf_source_code = [line+'\n' for line in udf_source_code_striped] + udf_source_code = [line + "\n" for line in udf_source_code_striped] return " ".join(udf_source_code) + def transformation_to_def(v: Transformation) -> dict: if isinstance(v, ExpressionTransformation): - return { - "transformExpr": v.expr - } + return {"transformExpr": v.expr} elif isinstance(v, WindowAggTransformation): ret = { "defExpr": v.def_expr, @@ -154,9 +162,10 @@ def transformation_to_def(v: Transformation) -> dict: return ret raise ValueError("Unsupported Transformation type") + def feature_type_to_def(v: FeatureType) -> dict: - # Note that after this method, attributes are Camel cased (eventTimestampColumn). - # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. + # Note that after this method, attributes are Camel cased (eventTimestampColumn). + # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. return { "type": v.type, "tensorCategory": v.tensor_category, @@ -164,11 +173,9 @@ def feature_type_to_def(v: FeatureType) -> dict: "valType": value_type_to_str(v.val_type), } + def typed_key_to_def(v: TypedKey) -> dict: - ret = { - "keyColumn": v.key_column, - "keyColumnType": value_type_to_str(v.key_column_type) - } + ret = {"keyColumn": v.key_column, "keyColumnType": value_type_to_str(v.key_column_type)} if v.full_name: ret["fullName"] = v.full_name if v.description: @@ -177,6 +184,7 @@ def typed_key_to_def(v: TypedKey) -> dict: ret["keyColumnAlias"] = v.key_column_alias return ret + def feature_to_def(v: Feature) -> dict: ret = { "name": v.name, @@ -184,15 +192,15 @@ def feature_to_def(v: Feature) -> dict: "key": [typed_key_to_def(k) for k in v.key], } if v.transform: - ret["transformation"] = transformation_to_def( - v.transform) + ret["transformation"] = transformation_to_def(v.transform) if v.registry_tags: ret["tags"] = v.registry_tags return ret + def derived_feature_to_def(v: DerivedFeature) -> dict: - # Note that after this method, attributes are Camel cased (eventTimestampColumn). - # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. + # Note that after this method, attributes are Camel cased (eventTimestampColumn). + # If the old logic works with snake case (event_timestamp_column), make sure you handle them manually. 
ret = { "name": v.name, "featureType": feature_type_to_def(v.feature_type), @@ -204,6 +212,7 @@ def derived_feature_to_def(v: DerivedFeature) -> dict: ret["transformation"] = transformation_to_def(v.transform) return ret + def topological_sort(derived_feature_list: List[DerivedFeature]) -> List[DerivedFeature]: """ In the current registry implementation, we need to make sure all upstream are registered before registering one derived feature @@ -212,32 +221,32 @@ def topological_sort(derived_feature_list: List[DerivedFeature]) -> List[Derived ret = [] # We don't want to destroy the input list input = derived_feature_list.copy() - + # Each round add the most downstream features into `ret`, so `ret` is in reversed order while input: # Process all remaining features current = input.copy() - + # In Python you should not alter content while iterating current_copy = current.copy() - + # Go over all remaining features to see if some feature depends on others for f in current_copy: for i in f.input_features: if i in current: # Someone depends on feature `i`, so `i` is **not** the most downstream current.remove(i) - + # Now `current` contains only the most downstream features in this round ret.extend(current) - + # Remove one level of dependency from input for f in current: input.remove(f) - + # The ret was in a reversed order when it's generated ret.reverse() - - if len(set(ret)) != len (set(derived_feature_list)): + + if len(set(ret)) != len(set(derived_feature_list)): raise ValueError("Cyclic dependency detected") - return ret \ No newline at end of file + return ret diff --git a/feathr_project/feathr/secrets/akv_client.py b/feathr_project/feathr/secrets/akv_client.py index cdec01e12..064a3993a 100644 --- a/feathr_project/feathr/secrets/akv_client.py +++ b/feathr_project/feathr/secrets/akv_client.py @@ -3,6 +3,7 @@ from loguru import logger from azure.core.exceptions import ResourceNotFoundError + class AzureKeyVaultClient: def __init__(self, akv_name: str): self.akv_name = akv_name @@ -16,16 +17,15 @@ def get_feathr_akv_secret(self, secret_name: str): """ if self.secret_client is None: self.secret_client = SecretClient( - vault_url = f"https://{self.akv_name}.vault.azure.net", - credential=DefaultAzureCredential() + vault_url=f"https://{self.akv_name}.vault.azure.net", credential=DefaultAzureCredential() ) try: # replace '_' with '-' since Azure Key Vault doesn't support it - variable_replaced = secret_name.replace('_','-') #.upper() - logger.info('Fetching the secret {} from Key Vault {}.', variable_replaced, self.akv_name) + variable_replaced = secret_name.replace("_", "-") # .upper() + logger.info("Fetching the secret {} from Key Vault {}.", variable_replaced, self.akv_name) secret = self.secret_client.get_secret(variable_replaced) - logger.info('Secret {} fetched from Key Vault {}.', variable_replaced, self.akv_name) + logger.info("Secret {} fetched from Key Vault {}.", variable_replaced, self.akv_name) return secret.value except ResourceNotFoundError as e: logger.error(f"Secret {secret_name} cannot be found in Key Vault {self.akv_name}.") - raise \ No newline at end of file + raise diff --git a/feathr_project/feathr/spark_provider/_abc.py b/feathr_project/feathr/spark_provider/_abc.py index c91fdf5c1..afbe3952a 100644 --- a/feathr_project/feathr/spark_provider/_abc.py +++ b/feathr_project/feathr/spark_provider/_abc.py @@ -3,8 +3,7 @@ class SparkJobLauncher(ABC): - """This is the abstract class for all the spark launchers. 
All the Spark launcher should implement those interfaces - """ + """This is the abstract class for all the spark launchers. All the Spark launcher should implement those interfaces""" @abstractmethod def upload_or_get_cloud_path(self, local_path_or_http_path: str): @@ -16,9 +15,17 @@ def upload_or_get_cloud_path(self, local_path_or_http_path: str): pass @abstractmethod - def submit_feathr_job(self, job_name: str, main_jar_path: str, main_class_name: str, arguments: List[str], - reference_files_path: List[str], job_tags: Dict[str, str] = None, - configuration: Dict[str, str] = {}, properties: Dict[str, str] = None): + def submit_feathr_job( + self, + job_name: str, + main_jar_path: str, + main_class_name: str, + arguments: List[str], + reference_files_path: List[str], + job_tags: Dict[str, str] = None, + configuration: Dict[str, str] = {}, + properties: Dict[str, str] = None, + ): """ Submits the feathr job diff --git a/feathr_project/feathr/spark_provider/_databricks_submission.py b/feathr_project/feathr/spark_provider/_databricks_submission.py index bfa76e3e7..3cd92d748 100644 --- a/feathr_project/feathr/spark_provider/_databricks_submission.py +++ b/feathr_project/feathr/spark_provider/_databricks_submission.py @@ -61,30 +61,29 @@ def __init__( self.auth_headers["Accept"] = "application/json" self.auth_headers["Authorization"] = f"Bearer {token_value}" self.databricks_work_dir = databricks_work_dir - self.api_client = ApiClient( - host=self.workspace_instance_url, token=token_value) + self.api_client = ApiClient(host=self.workspace_instance_url, token=token_value) def upload_or_get_cloud_path(self, local_path_or_cloud_src_path: str, tar_dir_path: Optional[str] = None): """ Supports transferring file from an http path to cloud working storage, or upload directly from a local storage. or copying files from a source dbfs directory to a target dbfs directory """ - if local_path_or_cloud_src_path.startswith('dbfs') and tar_dir_path is not None: - if not tar_dir_path.startswith('dbfs'): + if local_path_or_cloud_src_path.startswith("dbfs") and tar_dir_path is not None: + if not tar_dir_path.startswith("dbfs"): raise RuntimeError( f"Failed to copy files from dbfs directory: {local_path_or_cloud_src_path}. {tar_dir_path} is not a valid target directory path" ) if not self.cloud_dir_exists(local_path_or_cloud_src_path): raise RuntimeError( - f"Source folder:{local_path_or_cloud_src_path} doesn't exist. Please make sure it's a valid path") + f"Source folder:{local_path_or_cloud_src_path} doesn't exist. Please make sure it's a valid path" + ) if self.cloud_dir_exists(tar_dir_path): - logger.warning( - 'Target cloud directory {} already exists. Please use another one.', tar_dir_path) + logger.warning("Target cloud directory {} already exists. 
Please use another one.", tar_dir_path) return tar_dir_path - DbfsApi(self.api_client).cp(recursive=True, overwrite=False, - src=local_path_or_cloud_src_path, dst=tar_dir_path) - logger.info('{} is copied to location: {}', - local_path_or_cloud_src_path, tar_dir_path) + DbfsApi(self.api_client).cp( + recursive=True, overwrite=False, src=local_path_or_cloud_src_path, dst=tar_dir_path + ) + logger.info("{} is copied to location: {}", local_path_or_cloud_src_path, tar_dir_path) return tar_dir_path src_parse_result = urlparse(local_path_or_cloud_src_path) @@ -93,22 +92,26 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_path: str, tar_dir_pa # dbfs:/feathrazure_cijob_snowflake_9_30_157692\auto_generated_derived_features.conf, where the path sep is mixed, and won't be able to be parsed by databricks. # so we force the path to be Linux style here. cloud_dest_path = self.databricks_work_dir + "/" + file_name - if src_parse_result.scheme.startswith('http'): + if src_parse_result.scheme.startswith("http"): with urlopen(local_path_or_cloud_src_path) as f: # use REST API to avoid local temp file data = f.read() files = {"file": data} # for DBFS APIs, see: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/dbfs - r = requests.post(url=self.workspace_instance_url+'/api/2.0/dbfs/put', - headers=self.auth_headers, files=files, data={'overwrite': 'true', 'path': cloud_dest_path}) - logger.info('{} is downloaded and then uploaded to location: {}', - local_path_or_cloud_src_path, cloud_dest_path) - elif src_parse_result.scheme.startswith('dbfs'): + r = requests.post( + url=self.workspace_instance_url + "/api/2.0/dbfs/put", + headers=self.auth_headers, + files=files, + data={"overwrite": "true", "path": cloud_dest_path}, + ) + logger.info( + "{} is downloaded and then uploaded to location: {}", local_path_or_cloud_src_path, cloud_dest_path + ) + elif src_parse_result.scheme.startswith("dbfs"): # passed a cloud path - logger.info( - 'Skip uploading file {} as the file starts with dbfs:/', local_path_or_cloud_src_path) + logger.info("Skip uploading file {} as the file starts with dbfs:/", local_path_or_cloud_src_path) cloud_dest_path = local_path_or_cloud_src_path - elif src_parse_result.scheme.startswith(('wasb', 's3', 'gs')): + elif src_parse_result.scheme.startswith(("wasb", "s3", "gs")): # if the path starts with a location that's not a local path logger.error( "File {} cannot be downloaded. 
Please upload the file to dbfs manually.", local_path_or_cloud_src_path @@ -119,17 +122,14 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_path: str, tar_dir_pa else: # else it should be a local file path or dir if os.path.isdir(local_path_or_cloud_src_path): - logger.info("Uploading folder {}", - local_path_or_cloud_src_path) + logger.info("Uploading folder {}", local_path_or_cloud_src_path) dest_paths = [] - for item in Path(local_path_or_cloud_src_path).glob('**/*.conf'): - cloud_dest_path = self._upload_local_file_to_workspace( - item.resolve()) + for item in Path(local_path_or_cloud_src_path).glob("**/*.conf"): + cloud_dest_path = self._upload_local_file_to_workspace(item.resolve()) dest_paths.extend([cloud_dest_path]) - cloud_dest_path = ','.join(dest_paths) + cloud_dest_path = ",".join(dest_paths) else: - cloud_dest_path = self._upload_local_file_to_workspace( - local_path_or_cloud_src_path) + cloud_dest_path = self._upload_local_file_to_workspace(local_path_or_cloud_src_path) return cloud_dest_path def _upload_local_file_to_workspace(self, local_path: str) -> str: @@ -144,11 +144,11 @@ def _upload_local_file_to_workspace(self, local_path: str) -> str: # `local_path_or_http_path` will be either string or PathLib object, so normalize it to string local_path = str(local_path) try: - DbfsApi(self.api_client).cp(recursive=True, overwrite=True, - src=local_path, dst=cloud_dest_path) + DbfsApi(self.api_client).cp(recursive=True, overwrite=True, src=local_path, dst=cloud_dest_path) except RuntimeError as e: raise RuntimeError( - f"The source path: {local_path}, or the destination path: {cloud_dest_path}, is/are not valid.") from e + f"The source path: {local_path}, or the destination path: {cloud_dest_path}, is/are not valid." + ) from e return cloud_dest_path def submit_feathr_job( @@ -191,8 +191,7 @@ def submit_feathr_job( submission_params["run_name"] = job_name cfg = configuration.copy() if "existing_cluster_id" in submission_params: - logger.info( - "Using an existing general purpose cluster to run the feathr job...") + logger.info("Using an existing general purpose cluster to run the feathr job...") if cfg: logger.warning( "Spark execution configuration will be ignored. To use job-specific spark configs, please use a new job cluster or set the configs via Databricks UI." 
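The upload logic above branches on the path scheme; a compact restatement (separate from the patch, with illustrative function name and return labels) may help when reading the hunk:

from urllib.parse import urlparse

def classify_job_file(path: str) -> str:
    # Mirror the scheme dispatch in upload_or_get_cloud_path for Databricks.
    scheme = urlparse(path).scheme
    if scheme.startswith("http"):
        return "download-then-upload"  # fetched, then pushed via the DBFS put API
    if scheme.startswith("dbfs"):
        return "already-on-dbfs"       # used as-is, upload is skipped
    if scheme.startswith(("wasb", "s3", "gs")):
        return "unsupported"           # caller must stage the file on DBFS manually
    return "upload-local"              # local file or folder copied to the Databricks work dir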
@@ -211,8 +210,7 @@ def submit_feathr_job( submission_params["new_cluster"]["spark_conf"] = cfg if job_tags: - custom_tags = submission_params["new_cluster"].get( - "custom_tags", {}) + custom_tags = submission_params["new_cluster"].get("custom_tags", {}) for tag, value in job_tags.items(): custom_tags[tag] = value @@ -225,10 +223,8 @@ def submit_feathr_job( # the feathr main jar file is anyway needed regardless it's pyspark or scala spark if not main_jar_path: - logger.info( - f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven") - submission_params['libraries'][0]['maven'] = { - "coordinates": get_maven_artifact_fullname()} + logger.info(f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven") + submission_params["libraries"][0]["maven"] = {"coordinates": get_maven_artifact_fullname()} # Add json-schema dependency # TODO: find a proper way deal with unresolved dependencies # Since we are adding another entry to the config, make sure that the spark config passed as part of execution also contains a libraries array of atleast size 2 @@ -236,8 +232,7 @@ def submit_feathr_job( # Example from feathr_config.yaml - # config_template: {"run_name":"FEATHR_FILL_IN",.....,"libraries":[{}, {}],".......} else: - submission_params["libraries"][0]["jar"] = self.upload_or_get_cloud_path( - main_jar_path) + submission_params["libraries"][0]["jar"] = self.upload_or_get_cloud_path(main_jar_path) # see here for the submission parameter definition https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--request-structure-6 if python_files: # this is a pyspark job. definition here: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--sparkpythontask @@ -248,8 +243,7 @@ def submit_feathr_job( } # indicates this is a pyspark job # `setdefault` method will get the value of the "spark_python_task" item, if the "spark_python_task" item does not exist, insert "spark_python_task" with the value "param_and_file_dict": - submission_params.setdefault( - "spark_python_task", param_and_file_dict) + submission_params.setdefault("spark_python_task", param_and_file_dict) else: # this is a scala spark job submission_params["spark_jar_task"]["parameters"] = arguments @@ -268,8 +262,7 @@ def submit_feathr_job( result = RunsApi(self.api_client).get_run(self.res_job_id) self.job_url = result["run_page_url"] - logger.info( - "Feathr job Submitted Successfully. View more details here: {}", self.job_url) + logger.info("Feathr job Submitted Successfully. View more details here: {}", self.job_url) # return ID as the submission result return self.res_job_id @@ -286,12 +279,10 @@ def wait_for_completion(self, timeout_seconds: Optional[int] = 600) -> bool: if status in {"SUCCESS"}: return True elif status in {"INTERNAL_ERROR", "FAILED", "TIMEDOUT", "CANCELED"}: - result = RunsApi(self.api_client).get_run_output( - self.res_job_id) + result = RunsApi(self.api_client).get_run_output(self.res_job_id) # See here for the returned fields: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--response-structure-8 # print out logs and stack trace if the job has failed - logger.error( - "Feathr job has failed. Please visit this page to view error message: {}", self.job_url) + logger.error("Feathr job has failed. 
Please visit this page to view error message: {}", self.job_url) if "error" in result: logger.error("Error Code: {}", result["error"]) if "error_trace" in result: @@ -307,8 +298,7 @@ def get_status(self) -> str: result = RunsApi(self.api_client).get_run(self.res_job_id) # first try to get result state. it might not be available, and if that's the case, try to get life_cycle_state # see result structure: https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/2.0/jobs#--response-structure-6 - res_state = result["state"].get( - "result_state") or result["state"]["life_cycle_state"] + res_state = result["state"].get("result_state") or result["state"]["life_cycle_state"] assert res_state is not None return res_state @@ -333,8 +323,7 @@ def get_job_tags(self) -> Dict[str, str]: result = RunsApi(self.api_client).get_run(self.res_job_id) if "new_cluster" in result["cluster_spec"]: - custom_tags = result["cluster_spec"]["new_cluster"].get( - "custom_tags") + custom_tags = result["cluster_spec"]["new_cluster"].get("custom_tags") return custom_tags else: # this is not a new cluster; it's an existing cluster. @@ -354,14 +343,15 @@ def download_result(self, result_path: str, local_folder: str, is_file_path: boo recursive = True if not is_file_path else False DbfsApi(self.api_client).cp(recursive=recursive, overwrite=True, src=result_path, dst=local_folder) - + def cloud_dir_exists(self, dir_path: str): """ Check if a directory of hdfs already exists """ - if not dir_path.startswith('dbfs'): + if not dir_path.startswith("dbfs"): raise RuntimeError( - 'Currently only paths starting with dbfs is supported. The paths should start with \"dbfs:\" .') + 'Currently only paths starting with dbfs is supported. The paths should start with "dbfs:" .' + ) try: DbfsApi(self.api_client).list_files(DbfsPath(dir_path)) diff --git a/feathr_project/feathr/spark_provider/_localspark_submission.py b/feathr_project/feathr/spark_provider/_localspark_submission.py index 4daefa493..926a6b6ef 100644 --- a/feathr_project/feathr/spark_provider/_localspark_submission.py +++ b/feathr_project/feathr/spark_provider/_localspark_submission.py @@ -86,7 +86,12 @@ def submit_feathr_job( maven_dependency = f"{cfg.pop('spark.jars.packages', self.packages)},{get_maven_artifact_fullname()}" spark_args = self._init_args(job_name=job_name, confs=cfg) # Add additional repositories - spark_args.extend(["--repositories", "https://repository.mulesoft.org/nexus/content/repositories/public/,https://linkedin.jfrog.io/artifactory/open-source/"]) + spark_args.extend( + [ + "--repositories", + "https://repository.mulesoft.org/nexus/content/repositories/public/,https://linkedin.jfrog.io/artifactory/open-source/", + ] + ) if not main_jar_path: # We don't have the main jar, use Maven @@ -94,7 +99,9 @@ def submit_feathr_job( # This is a JAR job # Azure Synapse/Livy doesn't allow JAR job starts from Maven directly, we must have a jar file uploaded. # so we have to use a dummy jar as the main file. 
- logger.info(f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven") + logger.info( + f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven" + ) # Use the no-op jar as the main file # This is a dummy jar which contains only one `org.example.Noop` class with one empty `main` function # which does nothing @@ -120,7 +127,6 @@ def submit_feathr_job( spark_args.extend(["--py-files", ",".join(python_files[1:])]) spark_args.append(python_files[0]) - if arguments: spark_args.extend(arguments) @@ -216,9 +222,7 @@ def wait_for_completion(self, timeout_seconds: Optional[float] = 500) -> bool: # If the logs gives out hint that the job is finished, even the poll result is None (indicating the process is still running) we will still terminate it. # by calling `proc.terminate()` # if the process is terminated with this way, the return code will be 143. We assume this will still be a successful run. - logger.info( - f"Spark job with pid {self.latest_spark_proc.pid} finished in: {int(job_duration)} seconds." - ) + logger.info(f"Spark job with pid {self.latest_spark_proc.pid} finished in: {int(job_duration)} seconds.") return True else: logger.info( diff --git a/feathr_project/feathr/spark_provider/_synapse_submission.py b/feathr_project/feathr/spark_provider/_synapse_submission.py index 02d7d4c99..8804ff47c 100644 --- a/feathr_project/feathr/spark_provider/_synapse_submission.py +++ b/feathr_project/feathr/spark_provider/_synapse_submission.py @@ -11,9 +11,13 @@ from os.path import basename from enum import Enum import tempfile -from azure.identity import (ChainedTokenCredential, DefaultAzureCredential, - DeviceCodeCredential, EnvironmentCredential, - ManagedIdentityCredential) +from azure.identity import ( + ChainedTokenCredential, + DefaultAzureCredential, + DeviceCodeCredential, + EnvironmentCredential, + ManagedIdentityCredential, +) from azure.storage.filedatalake import DataLakeServiceClient, DataLakeDirectoryClient from azure.synapse.spark import SparkClient from azure.synapse.spark.models import SparkBatchJobOptions @@ -25,8 +29,9 @@ from feathr.constants import * from feathr.version import get_maven_artifact_fullname + class LivyStates(Enum): - """ Adapt LivyStates over to relax the dependency for azure-synapse-spark pacakge. + """Adapt LivyStates over to relax the dependency for azure-synapse-spark pacakge. Definition is here: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/synapse/azure-synapse-spark/azure/synapse/spark/models/_spark_client_enums.py#L38 """ @@ -49,14 +54,22 @@ class _FeathrSynapseJobLauncher(SparkJobLauncher): Submits spark jobs to a Synapse spark cluster. """ - def __init__(self, synapse_dev_url: str, pool_name: str, datalake_dir: str, executor_size: str, executors: int, credential=None): + def __init__( + self, + synapse_dev_url: str, + pool_name: str, + datalake_dir: str, + executor_size: str, + executors: int, + credential=None, + ): # use DeviceCodeCredential if EnvironmentCredential is not available self.credential = credential # use the same credential for authentication to avoid further login. 
self._api = _SynapseJobRunner( - synapse_dev_url, pool_name, executor_size=executor_size, executors=executors, credential=self.credential) - self._datalake = _DataLakeFiler( - datalake_dir, credential=self.credential) + synapse_dev_url, pool_name, executor_size=executor_size, executors=executors, credential=self.credential + ) + self._datalake = _DataLakeFiler(datalake_dir, credential=self.credential) # Save Synapse parameters to retrieve driver log self._synapse_dev_url = synapse_dev_url self._pool_name = pool_name @@ -66,32 +79,31 @@ def upload_or_get_cloud_path(self, local_path_or_cloud_src_path: str, tar_dir_pa Supports transferring file from an http path to cloud working storage, or upload directly from a local storage, or copying files from a source datalake directory to a target datalake directory """ - if local_path_or_cloud_src_path.startswith('abfs') or local_path_or_cloud_src_path.startswith('wasb'): - if tar_dir_path is None or not (tar_dir_path.startswith('abfs') or tar_dir_path.startswith('wasb')): + if local_path_or_cloud_src_path.startswith("abfs") or local_path_or_cloud_src_path.startswith("wasb"): + if tar_dir_path is None or not (tar_dir_path.startswith("abfs") or tar_dir_path.startswith("wasb")): raise RuntimeError( - f"Failed to copy files from dbfs directory: {local_path_or_cloud_src_path}. {tar_dir_path} is not a valid target directory path" - ) + f"Failed to copy files from dbfs directory: {local_path_or_cloud_src_path}. {tar_dir_path} is not a valid target directory path" + ) [_, source_exist] = self._datalake._dir_exists(local_path_or_cloud_src_path) if not source_exist: - raise RuntimeError(f"Source folder:{local_path_or_cloud_src_path} doesn't exist. Please make sure it's a valid path") + raise RuntimeError( + f"Source folder:{local_path_or_cloud_src_path} doesn't exist. Please make sure it's a valid path" + ) [dir_client, target_exist] = self._datalake._dir_exists(tar_dir_path) if target_exist: - logger.warning('Target cloud directory {} already exists. Please use another one.', tar_dir_path) + logger.warning("Target cloud directory {} already exists. 
Please use another one.", tar_dir_path) return tar_dir_path dir_client.create_directory() tem_dir_obj = tempfile.TemporaryDirectory() self._datalake.download_file(local_path_or_cloud_src_path, tem_dir_obj.name) self._datalake.upload_file_to_workdir(tem_dir_obj.name, tar_dir_path, dir_client) - logger.info('{} is uploaded to location: {}', - local_path_or_cloud_src_path, tar_dir_path) + logger.info("{} is uploaded to location: {}", local_path_or_cloud_src_path, tar_dir_path) return tar_dir_path - - logger.info('Uploading {} to cloud..', local_path_or_cloud_src_path) - res_path = self._datalake.upload_file_to_workdir( - local_path_or_cloud_src_path) - logger.info('{} is uploaded to location: {}', - local_path_or_cloud_src_path, res_path) + logger.info("Uploading {} to cloud..", local_path_or_cloud_src_path) + res_path = self._datalake.upload_file_to_workdir(local_path_or_cloud_src_path) + + logger.info("{} is uploaded to location: {}", local_path_or_cloud_src_path, res_path) return res_path def download_result(self, result_path: str, local_folder: str, is_file_path: bool = False): @@ -99,24 +111,32 @@ def download_result(self, result_path: str, local_folder: str, is_file_path: boo Supports downloading files from the result folder """ if is_file_path: - paths = result_path.rsplit('/',1) + paths = result_path.rsplit("/", 1) if len(paths) != 2: raise RuntimeError(f"Invalid single file path: {result_path}") - return self._datalake.download_file(paths[0]+'/', local_folder, paths[1]) + return self._datalake.download_file(paths[0] + "/", local_folder, paths[1]) return self._datalake.download_file(result_path, local_folder, None) - - + def cloud_dir_exists(self, dir_path: str) -> bool: """ Checks if a directory already exists in the datalake """ - + [_, exists] = self._datalake._dir_exists(dir_path) return exists - def submit_feathr_job(self, job_name: str, main_jar_path: str = None, main_class_name: str = None, arguments: List[str] = None, - python_files: List[str]= None, reference_files_path: List[str] = None, job_tags: Dict[str, str] = None, - configuration: Dict[str, str] = {}, properties: Dict[str, str] = {}): + def submit_feathr_job( + self, + job_name: str, + main_jar_path: str = None, + main_class_name: str = None, + arguments: List[str] = None, + python_files: List[str] = None, + reference_files_path: List[str] = None, + job_tags: Dict[str, str] = None, + configuration: Dict[str, str] = {}, + properties: Dict[str, str] = {}, + ): """ Submits the feathr job Refer to the Apache Livy doc for more details on the meaning of the parameters: @@ -152,8 +172,7 @@ def submit_feathr_job(self, job_name: str, main_jar_path: str = None, main_clas # Add Maven dependency to the job configuration logger.info(f"Main JAR file is not set, using default package '{get_maven_artifact_fullname()}' from Maven") if "spark.jars.packages" in cfg: - cfg["spark.jars.packages"] = ",".join( - [cfg["spark.jars.packages"], get_maven_artifact_fullname()]) + cfg["spark.jars.packages"] = ",".join([cfg["spark.jars.packages"], get_maven_artifact_fullname()]) else: cfg["spark.jars.packages"] = get_maven_artifact_fullname() @@ -171,16 +190,13 @@ def submit_feathr_job(self, job_name: str, main_jar_path: str = None, main_clas main_jar_cloud_path = None if main_jar_path: # Now we have a main jar, either feathr or noop - if main_jar_path.startswith('abfs'): + if main_jar_path.startswith("abfs"): main_jar_cloud_path = main_jar_path - logger.info( - 'Cloud path {} is used for running the job: {}', main_jar_path, job_name) + 
logger.info("Cloud path {} is used for running the job: {}", main_jar_path, job_name) else: - logger.info('Uploading jar from {} to cloud for running job: {}', - main_jar_path, job_name) + logger.info("Uploading jar from {} to cloud for running job: {}", main_jar_path, job_name) main_jar_cloud_path = self._datalake.upload_file_to_workdir(main_jar_path) - logger.info('{} is uploaded to {} for running job: {}', - main_jar_path, main_jar_cloud_path, job_name) + logger.info("{} is uploaded to {} for running job: {}", main_jar_path, main_jar_cloud_path, job_name) else: # We don't have the main Jar, and this is a PySpark job so we don't use `noop.jar` either # Keep `main_jar_cloud_path` as `None` as we already added maven package into cfg @@ -188,40 +204,44 @@ def submit_feathr_job(self, job_name: str, main_jar_path: str = None, main_clas reference_file_paths = [] for file_path in reference_files_path: - reference_file_paths.append( - self._datalake.upload_file_to_workdir(file_path)) - - self.current_job_info = self._api.create_spark_batch_job(job_name=job_name, - main_file=main_jar_cloud_path, - class_name=main_class_name, - python_files=python_files, - arguments=arguments, - reference_files=reference_files_path, - tags=job_tags, - configuration=cfg) - logger.info('See submitted job here: https://web.azuresynapse.net/en-us/monitoring/sparkapplication') + reference_file_paths.append(self._datalake.upload_file_to_workdir(file_path)) + + self.current_job_info = self._api.create_spark_batch_job( + job_name=job_name, + main_file=main_jar_cloud_path, + class_name=main_class_name, + python_files=python_files, + arguments=arguments, + reference_files=reference_files_path, + tags=job_tags, + configuration=cfg, + ) + logger.info("See submitted job here: https://web.azuresynapse.net/en-us/monitoring/sparkapplication") return self.current_job_info def wait_for_completion(self, timeout_seconds: Optional[float]) -> bool: """ Returns true if the job completed successfully - """ + """ start_time = time.time() while (timeout_seconds is None) or (time.time() - start_time < timeout_seconds): status = self.get_status() - logger.info('Current Spark job status: {}', status) + logger.info("Current Spark job status: {}", status) if status in {LivyStates.SUCCESS.value}: return True elif status in {LivyStates.ERROR.value, LivyStates.DEAD.value, LivyStates.KILLED.value}: logger.error("Feathr job has failed.") - error_msg = self._api.get_driver_log(self.current_job_info.id).decode('utf-8') + error_msg = self._api.get_driver_log(self.current_job_info.id).decode("utf-8") logger.error(error_msg) - logger.error("The size of the whole error log is: {}. The logs might be truncated in some cases (such as in Visual Studio Code) so only the top a few lines of the error message is displayed. If you cannot see the whole log, you may want to extend the setting for output size limit.", len(error_msg)) + logger.error( + "The size of the whole error log is: {}. The logs might be truncated in some cases (such as in Visual Studio Code) so only the top a few lines of the error message is displayed. 
If you cannot see the whole log, you may want to extend the setting for output size limit.", + len(error_msg), + ) return False else: time.sleep(30) else: - raise TimeoutError('Timeout waiting for job to complete') + raise TimeoutError("Timeout waiting for job to complete") def get_status(self) -> str: """Get current job status @@ -251,28 +271,29 @@ def get_job_tags(self) -> Dict[str, str]: """ return self._api.get_spark_batch_job(self.current_job_info.id).tags + class _SynapseJobRunner(object): """ Class to interact with Synapse Spark cluster """ - def __init__(self, synapse_dev_url, spark_pool_name, credential=None, executor_size='Small', executors=2): + + def __init__(self, synapse_dev_url, spark_pool_name, credential=None, executor_size="Small", executors=2): self._synapse_dev_url = synapse_dev_url self._spark_pool_name = spark_pool_name if credential is None: - logger.warning('No valid Azure credential detected. Using DefaultAzureCredential') + logger.warning("No valid Azure credential detected. Using DefaultAzureCredential") credential = DefaultAzureCredential() self._credential = credential - self.client = SparkClient( - credential=credential, - endpoint=synapse_dev_url, - spark_pool_name=spark_pool_name - ) + self.client = SparkClient(credential=credential, endpoint=synapse_dev_url, spark_pool_name=spark_pool_name) self._executor_size = executor_size self._executors = executors - self.EXECUTOR_SIZE = {'Small': {'Cores': 4, 'Memory': '28g'}, 'Medium': {'Cores': 8, 'Memory': '56g'}, - 'Large': {'Cores': 16, 'Memory': '112g'}} + self.EXECUTOR_SIZE = { + "Small": {"Cores": 4, "Memory": "28g"}, + "Medium": {"Cores": 8, "Memory": "56g"}, + "Large": {"Cores": 16, "Memory": "112g"}, + } def _categorized_files(self, reference_files: List[str]): """categorize files to make sure they are in the ready to submissio format @@ -290,13 +311,13 @@ def _categorized_files(self, reference_files: List[str]): jars = [] for file in reference_files: file = file.strip() - if file.endswith('.jar'): + if file.endswith(".jar"): jars.append(file) else: files.append(file) return files, jars - def get_spark_batch_job(self, job_id:int): + def get_spark_batch_job(self, job_id: int): """ Get the job object by searching a certain ID """ @@ -310,24 +331,34 @@ def get_spark_batch_jobs(self): return self.client.spark_batch.get_spark_batch_jobs(detailed=True) - def cancel_spark_batch_job(self, job_id:int): + def cancel_spark_batch_job(self, job_id: int): """ Cancel a job by searching a certain ID """ return self.client.spark_batch.cancel_spark_batch_job(job_id) - def create_spark_batch_job(self, job_name, main_file, class_name=None, - arguments=None, python_files=None, reference_files=None, archives=None, configuration=None, tags=None): + def create_spark_batch_job( + self, + job_name, + main_file, + class_name=None, + arguments=None, + python_files=None, + reference_files=None, + archives=None, + configuration=None, + tags=None, + ): """ Submit a spark job to a certain cluster """ files, jars = self._categorized_files(reference_files) - driver_cores = self.EXECUTOR_SIZE[self._executor_size]['Cores'] - driver_memory = self.EXECUTOR_SIZE[self._executor_size]['Memory'] - executor_cores = self.EXECUTOR_SIZE[self._executor_size]['Cores'] - executor_memory = self.EXECUTOR_SIZE[self._executor_size]['Memory'] + driver_cores = self.EXECUTOR_SIZE[self._executor_size]["Cores"] + driver_memory = self.EXECUTOR_SIZE[self._executor_size]["Memory"] + executor_cores = self.EXECUTOR_SIZE[self._executor_size]["Cores"] + executor_memory = 
self.EXECUTOR_SIZE[self._executor_size]["Memory"] # If we have a main jar, it needs to be added as dependencies for pyspark job # Otherwise it's a PySpark job with Feathr JAR from Maven @@ -356,14 +387,20 @@ def create_spark_batch_job(self, job_name, main_file, class_name=None, driver_cores=driver_cores, executor_memory=executor_memory, executor_cores=executor_cores, - executor_count=self._executors) + executor_count=self._executors, + ) return self.client.spark_batch.create_spark_batch_job(spark_batch_job_options, detailed=True) def get_driver_log(self, job_id) -> str: # @see: https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/connect-monitor-azure-synapse-spark-application-level-metrics app_id = self.get_spark_batch_job(job_id).app_id - url = "%s/sparkhistory/api/v1/sparkpools/%s/livyid/%s/applications/%s/driverlog/stderr/?isDownload=true" % (self._synapse_dev_url, self._spark_pool_name, job_id, app_id) + url = "%s/sparkhistory/api/v1/sparkpools/%s/livyid/%s/applications/%s/driverlog/stderr/?isDownload=true" % ( + self._synapse_dev_url, + self._spark_pool_name, + job_id, + app_id, + ) token = self._credential.get_token("https://dev.azuresynapse.net/.default").token req = urllib.request.Request(url=url, headers={"authorization": "Bearer %s" % token}) resp = urllib.request.urlopen(req) @@ -374,6 +411,7 @@ class _DataLakeFiler(object): """ Class to interact with Azure Data Lake Storage. """ + def __init__(self, datalake_dir, credential=None): # A datalake path would be something like this: # "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/frame_getting_started" after this @@ -383,7 +421,7 @@ def __init__(self, datalake_dir, credential=None): # container name), datalake_path_split[3] should be the full name of this target path, # and datalake_path_split[3:] would be all the directory in this particular container split datalake names by # "/" or "@" - datalake_path_split = list(filter(None, re.split('/|@', datalake_dir))) + datalake_path_split = list(filter(None, re.split("/|@", datalake_dir))) assert len(datalake_path_split) >= 3 if credential is None: @@ -392,23 +430,25 @@ def __init__(self, datalake_dir, credential=None): account_url = "https://" + datalake_path_split[2] self.file_system_client = DataLakeServiceClient( - credential=credential, - account_url=account_url + credential=credential, account_url=account_url ).get_file_system_client(datalake_path_split[1]) if len(datalake_path_split) > 3: # directory exists in datalake path - self.dir_client = self.file_system_client.get_directory_client( - '/'.join(datalake_path_split[3:])) + self.dir_client = self.file_system_client.get_directory_client("/".join(datalake_path_split[3:])) self.dir_client.create_directory() else: # otherwise use root folder instead - self.dir_client = self.file_system_client.get_directory_client('/') + self.dir_client = self.file_system_client.get_directory_client("/") - self.datalake_dir = datalake_dir + \ - '/' if datalake_dir[-1] != '/' else datalake_dir + self.datalake_dir = datalake_dir + "/" if datalake_dir[-1] != "/" else datalake_dir - def upload_file_to_workdir(self, src_file_path: str, tar_dir_path: Optional[str] = "", tar_dir_client: Optional[DataLakeDirectoryClient] = None) -> str: + def upload_file_to_workdir( + self, + src_file_path: str, + tar_dir_path: Optional[str] = "", + tar_dir_client: Optional[DataLakeDirectoryClient] = None, + ) -> str: """ Handles file upload to the corresponding datalake storage. 
If a path starts with "wasb" or "abfs", it will skip uploading and return the original path; otherwise it will upload the source file to the working @@ -416,7 +456,7 @@ def upload_file_to_workdir(self, src_file_path: str, tar_dir_path: Optional[str] """ src_parse_result = urlparse(src_file_path) - if src_parse_result.scheme.startswith('http'): + if src_parse_result.scheme.startswith("http"): file_name = basename(src_file_path) file_client = self.dir_client.create_file(file_name) # returned paths for the uploaded file @@ -425,7 +465,7 @@ def upload_file_to_workdir(self, src_file_path: str, tar_dir_path: Optional[str] data = f.read() file_client.upload_data(data, overwrite=True) logger.info("{} is downloaded and then uploaded to location: {}", src_file_path, returned_path) - elif src_parse_result.scheme.startswith('abfs') or src_parse_result.scheme.startswith('wasb'): + elif src_parse_result.scheme.startswith("abfs") or src_parse_result.scheme.startswith("wasb"): # passed a cloud path logger.info("Skip uploading file {} as it's already in the cloud", src_file_path) returned_path = src_file_path @@ -436,30 +476,33 @@ def upload_file_to_workdir(self, src_file_path: str, tar_dir_path: Optional[str] dest_paths = [] if tar_dir_client is not None: # Only supports uploading local files/dir to datalake dir for now - for item in Path(src_file_path).iterdir(): + for item in Path(src_file_path).iterdir(): returned_path = self.upload_file(item.resolve(), tar_dir_path, tar_dir_client) - dest_paths.extend([returned_path]) + dest_paths.extend([returned_path]) else: - for item in Path(src_file_path).glob('**/*.conf'): + for item in Path(src_file_path).glob("**/*.conf"): returned_path = self.upload_file(item.resolve()) dest_paths.extend([returned_path]) - returned_path = ','.join(dest_paths) + returned_path = ",".join(dest_paths) else: returned_path = self.upload_file(src_file_path) return returned_path - def upload_file(self, src_file_path, tar_dir_path: Optional[str]="", tar_dir_client: Optional[DataLakeDirectoryClient] = None)-> str: + def upload_file( + self, src_file_path, tar_dir_path: Optional[str] = "", tar_dir_client: Optional[DataLakeDirectoryClient] = None + ) -> str: file_name = basename(src_file_path) logger.info("Uploading file {}", file_name) # TODO: add handling for only tar_dir_client or tar_dir_path is provided - file_client = self.dir_client.create_file(file_name) if tar_dir_client is None else tar_dir_client.create_file(file_name) + file_client = ( + self.dir_client.create_file(file_name) if tar_dir_client is None else tar_dir_client.create_file(file_name) + ) returned_path = self.datalake_dir + file_name if tar_dir_path == "" else tar_dir_path + file_name - with open(src_file_path, 'rb') as f: + with open(src_file_path, "rb") as f: data = f.read() file_client.upload_data(data, overwrite=True) logger.info("{} is uploaded to location: {}", src_file_path, returned_path) return returned_path - def download_file(self, target_adls_directory: str, local_dir_cache: str, file_name: str = None): """ @@ -471,71 +514,72 @@ def download_file(self, target_adls_directory: str, local_dir_cache: str, file_n local_dir_cache (str): local cache to store local results file_name (str): only download the file with name 'file_name' under the target directory if it's provided (default as None) """ - logger.info('Beginning reading of results from {}', - target_adls_directory) + logger.info("Beginning reading of results from {}", target_adls_directory) parse_result = urlparse(target_adls_directory) - if 
parse_result.path == '': - parse_result.path = '/' - directory_client = self.file_system_client.get_directory_client( - parse_result.path) - + if parse_result.path == "": + parse_result.path = "/" + directory_client = self.file_system_client.get_directory_client(parse_result.path) + if file_name is not None: local_paths = [os.path.join(local_dir_cache, file_name)] self._download_file_list(local_paths, [file_name], directory_client) - logger.info('Finish downloading file {} from {} to {}.', - file_name, target_adls_directory, local_dir_cache) + logger.info("Finish downloading file {} from {} to {}.", file_name, target_adls_directory, local_dir_cache) return - + # returns the paths to all the files in the target director in ADLS # get all the paths that are not under a directory - result_paths = [basename(file_path.name) for file_path in self.file_system_client.get_paths( - path=parse_result.path, recursive=False) if not file_path.is_directory] + result_paths = [ + basename(file_path.name) + for file_path in self.file_system_client.get_paths(path=parse_result.path, recursive=False) + if not file_path.is_directory + ] # get all the paths that are directories and download them - result_folders = [file_path.name for file_path in self.file_system_client.get_paths( - path=parse_result.path) if file_path.is_directory] + result_folders = [ + file_path.name + for file_path in self.file_system_client.get_paths(path=parse_result.path) + if file_path.is_directory + ] # list all the files under the certain folder, and download them preserving the hierarchy for folder in result_folders: folder_name = basename(folder) - file_in_folder = [os.path.join(folder_name, basename(file_path.name)) for file_path in self.file_system_client.get_paths( - path=folder, recursive=False) if not file_path.is_directory] - local_paths = [os.path.join(local_dir_cache, file_name) - for file_name in file_in_folder] + file_in_folder = [ + os.path.join(folder_name, basename(file_path.name)) + for file_path in self.file_system_client.get_paths(path=folder, recursive=False) + if not file_path.is_directory + ] + local_paths = [os.path.join(local_dir_cache, file_name) for file_name in file_in_folder] self._download_file_list(local_paths, file_in_folder, directory_client) # download files that are in the result folder - local_paths = [os.path.join(local_dir_cache, file_name) - for file_name in result_paths] + local_paths = [os.path.join(local_dir_cache, file_name) for file_name in result_paths] self._download_file_list(local_paths, result_paths, directory_client) - logger.info('Finish downloading files from {} to {}.', - target_adls_directory, local_dir_cache) - + logger.info("Finish downloading files from {} to {}.", target_adls_directory, local_dir_cache) + def _download_file_list(self, local_paths: List[str], result_paths, directory_client): - ''' + """ Download filelist to local - ''' - for idx, file_to_write in enumerate(tqdm(result_paths,desc="Downloading result files: ")): + """ + for idx, file_to_write in enumerate(tqdm(result_paths, desc="Downloading result files: ")): try: os.makedirs(os.path.dirname(local_paths[idx]), exist_ok=True) - local_file = open(local_paths[idx], 'wb') + local_file = open(local_paths[idx], "wb") file_client = directory_client.get_file_client(file_to_write) download = file_client.download_file() downloaded_bytes = download.readall() local_file.write(downloaded_bytes) - local_file.close() + local_file.close() except Exception as e: - logger.error(e) - - def _dir_exists(self, dir_path:str): - ''' + 
logger.error(e) + + def _dir_exists(self, dir_path: str): + """ Check if a directory in datalake already exists. Will also return the directory client - ''' - datalake_path_split = list(filter(None, re.split('/|@', dir_path))) + """ + datalake_path_split = list(filter(None, re.split("/|@", dir_path))) if len(datalake_path_split) <= 3: raise RuntimeError("Invalid directory path for datalake: {dir_path}") - dir_client = self.file_system_client.get_directory_client( - '/'.join(datalake_path_split[3:])) + dir_client = self.file_system_client.get_directory_client("/".join(datalake_path_split[3:])) return [dir_client, dir_client.exists()] - \ No newline at end of file diff --git a/feathr_project/feathr/spark_provider/feathr_configurations.py b/feathr_project/feathr/spark_provider/feathr_configurations.py index cc8c60824..56735c607 100644 --- a/feathr_project/feathr/spark_provider/feathr_configurations.py +++ b/feathr_project/feathr/spark_provider/feathr_configurations.py @@ -8,5 +8,6 @@ class SparkExecutionConfiguration: Returns: dict[str, str] """ - def __new__(cls, spark_execution_configuration = Dict[str, str]) -> Dict[str, str]: + + def __new__(cls, spark_execution_configuration=Dict[str, str]) -> Dict[str, str]: return spark_execution_configuration diff --git a/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py b/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py index c4f102566..e400995ff 100644 --- a/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py +++ b/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py @@ -12,21 +12,20 @@ # Some metadata that are only needed by Feathr -FEATHR_PYSPARK_METADATA = 'generated_feathr_pyspark_metadata' +FEATHR_PYSPARK_METADATA = "generated_feathr_pyspark_metadata" # UDFs that are provided by users and persisted by Feathr into this file. # It will be uploaded into the Pyspark cluster as a dependency for execution -FEATHR_CLIENT_UDF_FILE_NAME = 'client_udf_repo.py' +FEATHR_CLIENT_UDF_FILE_NAME = "client_udf_repo.py" # Pyspark driver code that is executed by the Pyspark driver -FEATHR_PYSPARK_DRIVER_FILE_NAME = 'feathr_pyspark_driver.py' -FEATHR_PYSPARK_DRIVER_TEMPLATE_FILE_NAME = 'feathr_pyspark_driver_template.py' +FEATHR_PYSPARK_DRIVER_FILE_NAME = "feathr_pyspark_driver.py" +FEATHR_PYSPARK_DRIVER_TEMPLATE_FILE_NAME = "feathr_pyspark_driver_template.py" # Feathr provided imports for pyspark UDFs all go here -PROVIDED_IMPORTS = ['\nfrom pyspark.sql import SparkSession, DataFrame\n'] + \ - ['from pyspark.sql.functions import *\n'] +PROVIDED_IMPORTS = ["\nfrom pyspark.sql import SparkSession, DataFrame\n"] + ["from pyspark.sql.functions import *\n"] class _PreprocessingPyudfManager(object): - """This class manages Pyspark UDF preprocessing related artifacts, like UDFs from users, the pyspark_client etc. - """ + """This class manages Pyspark UDF preprocessing related artifacts, like UDFs from users, the pyspark_client etc.""" + @staticmethod def build_anchor_preprocessing_metadata(anchor_list: List[FeatureAnchor], local_workspace_dir): """When the client build features, UDFs and features that need preprocessing will be stored as metadata. 
Those @@ -43,7 +42,7 @@ def build_anchor_preprocessing_metadata(anchor_list: List[FeatureAnchor], local_ pyspark_driver_path = os.path.join(local_workspace_dir, FEATHR_PYSPARK_DRIVER_FILE_NAME) # delete the file if it already exists to avoid caching previous results - for f in [client_udf_repo_path, metadata_path, pyspark_driver_path]: + for f in [client_udf_repo_path, metadata_path, pyspark_driver_path]: if os.path.exists(f): os.remove(f) @@ -57,9 +56,11 @@ def build_anchor_preprocessing_metadata(anchor_list: List[FeatureAnchor], local_ feature_names = [feature.name for feature in anchor.features] features_with_preprocessing = features_with_preprocessing + feature_names feature_names.sort() - string_feature_list = ','.join(feature_names) + string_feature_list = ",".join(feature_names) if isinstance(anchor.source.preprocessing, str): - feature_names_to_func_mapping[string_feature_list] = _PreprocessingPyudfManager._parse_function_str_for_name(anchor.source.preprocessing) + feature_names_to_func_mapping[ + string_feature_list + ] = _PreprocessingPyudfManager._parse_function_str_for_name(anchor.source.preprocessing) else: # it's a callable function feature_names_to_func_mapping[string_feature_list] = anchor.source.preprocessing.__name__ @@ -67,16 +68,17 @@ def build_anchor_preprocessing_metadata(anchor_list: List[FeatureAnchor], local_ if not features_with_preprocessing: return - _PreprocessingPyudfManager.write_feature_names_to_udf_name_file(feature_names_to_func_mapping, local_workspace_dir) + _PreprocessingPyudfManager.write_feature_names_to_udf_name_file( + feature_names_to_func_mapping, local_workspace_dir + ) # Save necessary preprocessing-related metadata locally in your workspace # Typically it's used as a metadata for join/gen job to figure out if there is preprocessing UDF # exist for the features requested feathr_pyspark_metadata_abs_path = os.path.join(local_workspace_dir, FEATHR_PYSPARK_METADATA) - with open(feathr_pyspark_metadata_abs_path, 'wb') as file: + with open(feathr_pyspark_metadata_abs_path, "wb") as file: pickle.dump(features_with_preprocessing, file) - @staticmethod def _parse_function_str_for_name(fn_str: str) -> str: """Use AST to parse the function string and get the name out. @@ -100,7 +102,6 @@ def _parse_function_str_for_name(fn_str: str) -> str: # Get the function name from the function definition. return tree.body[0].name - @staticmethod def persist_pyspark_udf_to_file(user_func, local_workspace_dir): """persist the pyspark UDF to a file in `local_workspace_dir` for later usage. 
@@ -116,7 +117,7 @@ def persist_pyspark_udf_to_file(user_func, local_workspace_dir): # Some basic imports will be provided lines = lines + PROVIDED_IMPORTS lines = lines + udf_source_code - lines.append('\n') + lines.append("\n") client_udf_repo_path = os.path.join(local_workspace_dir, FEATHR_CLIENT_UDF_FILE_NAME) @@ -134,13 +135,15 @@ def write_feature_names_to_udf_name_file(feature_names_to_func_mapping, local_wo """ # indent in since python needs correct indentation # Don't change the indentation - tm = Template(""" + tm = Template( + """ feature_names_funcs = { {% for key, value in func_maps.items() %} "{{key}}" : {{value}}, {% endfor %} } - """) + """ + ) new_file = tm.render(func_maps=feature_names_to_func_mapping) full_file_name = os.path.join(local_workspace_dir, FEATHR_CLIENT_UDF_FILE_NAME) @@ -163,7 +166,7 @@ def prepare_pyspark_udf_files(feature_names: List[str], local_workspace_dir): # if the preprocessing metadata file doesn't exist or is empty, then we just skip if not Path(feathr_pyspark_metadata_abs_path).is_file(): return py_udf_files - with open(feathr_pyspark_metadata_abs_path, 'rb') as pyspark_metadata_file: + with open(feathr_pyspark_metadata_abs_path, "rb") as pyspark_metadata_file: features_with_preprocessing = pickle.load(pyspark_metadata_file) # if there is not features that needs preprocessing, just return. if not features_with_preprocessing: @@ -180,22 +183,24 @@ def prepare_pyspark_udf_files(feature_names: List[str], local_workspace_dir): if has_py_udf_preprocessing: pyspark_driver_path = os.path.join(local_workspace_dir, FEATHR_PYSPARK_DRIVER_FILE_NAME) - pyspark_driver_template_abs_path = str(Path(Path(__file__).parent / FEATHR_PYSPARK_DRIVER_TEMPLATE_FILE_NAME).absolute()) + pyspark_driver_template_abs_path = str( + Path(Path(__file__).parent / FEATHR_PYSPARK_DRIVER_TEMPLATE_FILE_NAME).absolute() + ) client_udf_repo_path = os.path.join(local_workspace_dir, FEATHR_CLIENT_UDF_FILE_NAME) # write pyspark_driver_template_abs_path and then client_udf_repo_path filenames = [pyspark_driver_template_abs_path, client_udf_repo_path] - with open(pyspark_driver_path, 'w') as outfile: + with open(pyspark_driver_path, "w") as outfile: for fname in filenames: with open(fname) as infile: for line in infile: outfile.write(line) lines = [ - '\n', + "\n", 'print("pyspark_client.py: Preprocessing via UDFs and submit Spark job.")\n', - 'submit_spark_job(feature_names_funcs)\n', + "submit_spark_job(feature_names_funcs)\n", 'print("pyspark_client.py: Feathr Pyspark job completed.")\n', - '\n', + "\n", ] with open(pyspark_driver_path, "a") as handle: print("".join(lines), file=handle) diff --git a/feathr_project/feathr/udf/feathr_pyspark_driver_template.py b/feathr_project/feathr/udf/feathr_pyspark_driver_template.py index e964cd12d..f5e0362d8 100644 --- a/feathr_project/feathr/udf/feathr_pyspark_driver_template.py +++ b/feathr_project/feathr/udf/feathr_pyspark_driver_template.py @@ -1,4 +1,3 @@ - from pyspark.sql import SparkSession, DataFrame, SQLContext import sys from pyspark.sql.functions import * @@ -6,12 +5,11 @@ # This is executed in Spark driver # The logger doesn't work in Pyspark so we just use print print("Feathr Pyspark job started.") -spark = SparkSession.builder.appName('FeathrPyspark').getOrCreate() +spark = SparkSession.builder.appName("FeathrPyspark").getOrCreate() def to_java_string_array(arr): - """Convert a Python string list to a Java String array. 
- """ + """Convert a Python string list to a Java String array.""" jarr = spark._sc._gateway.new_array(spark._sc._jvm.java.lang.String, len(arr)) for i in range(len(arr)): jarr[i] = arr[i] @@ -34,15 +32,16 @@ def submit_spark_job(feature_names_funcs): # For example: ['pyspark_client.py', '--join-config', 'abfss://...', ...] has_gen_config = False has_join_config = False - if '--generation-config' in sys.argv: + if "--generation-config" in sys.argv: has_gen_config = True - if '--join-config' in sys.argv: + if "--join-config" in sys.argv: has_join_config = True py4j_feature_job = None if has_gen_config and has_join_config: - raise RuntimeError("Both FeatureGenConfig and FeatureJoinConfig are provided. " - "Only one of them should be provided.") + raise RuntimeError( + "Both FeatureGenConfig and FeatureJoinConfig are provided. " "Only one of them should be provided." + ) elif has_gen_config: py4j_feature_job = spark._jvm.com.linkedin.feathr.offline.job.FeatureGenJob print("FeatureGenConfig is provided. Executing FeatureGenJob.") @@ -50,8 +49,9 @@ def submit_spark_job(feature_names_funcs): py4j_feature_job = spark._jvm.com.linkedin.feathr.offline.job.FeatureJoinJob print("FeatureJoinConfig is provided. Executing FeatureJoinJob.") else: - raise RuntimeError("None of FeatureGenConfig and FeatureJoinConfig are provided. " - "One of them should be provided.") + raise RuntimeError( + "None of FeatureGenConfig and FeatureJoinConfig are provided. " "One of them should be provided." + ) job_param_java_array = to_java_string_array(sys.argv) print("submit_spark_job: feature_names_funcs: ") @@ -61,7 +61,9 @@ def submit_spark_job(feature_names_funcs): print("submit_spark_job: Load DataFrame from Scala engine.") - dataframeFromSpark = py4j_feature_job.loadSourceDataframe(job_param_java_array, set(feature_names_funcs.keys())) # TODO: Add data handler support here + dataframeFromSpark = py4j_feature_job.loadSourceDataframe( + job_param_java_array, set(feature_names_funcs.keys()) + ) # TODO: Add data handler support here print("Submit_spark_job: dataframeFromSpark: ") print(dataframeFromSpark) @@ -84,4 +86,3 @@ def submit_spark_job(feature_names_funcs): py4j_feature_job.mainWithPreprocessedDataFrame(job_param_java_array, new_preprocessed_df_map) return None - diff --git a/feathr_project/feathr/utils/_env_config_reader.py b/feathr_project/feathr/utils/_env_config_reader.py index 334f2bc93..060b92abb 100644 --- a/feathr_project/feathr/utils/_env_config_reader.py +++ b/feathr_project/feathr/utils/_env_config_reader.py @@ -7,12 +7,14 @@ from azure.core.exceptions import ResourceNotFoundError from feathr.secrets.akv_client import AzureKeyVaultClient + class EnvConfigReader(object): """A utility class to read Feathr environment variables either from os environment variables, the config yaml file or Azure Key Vault. If a key is set in the environment variable, ConfigReader will return the value of that environment variable. """ - akv_name: str = None # Azure Key Vault name to use for retrieving config values. + + akv_name: str = None # Azure Key Vault name to use for retrieving config values. yaml_config: dict = None # YAML config file content. def __init__(self, config_path: str): @@ -57,7 +59,9 @@ def get(self, key: str, default: str = None) -> str: if val is not None: return val else: - logger.info(f"Config {key} is not found in the environment variable, configuration file, or the remote key value store. 
Returning the default value: {default}.") + logger.info( + f"Config {key} is not found in the environment variable, configuration file, or the remote key value store. Returning the default value: {default}." + ) return default def get_from_env_or_akv(self, key: str) -> str: diff --git a/feathr_project/feathr/utils/_file_utils.py b/feathr_project/feathr/utils/_file_utils.py index c81208e79..bbbf9d33f 100644 --- a/feathr_project/feathr/utils/_file_utils.py +++ b/feathr_project/feathr/utils/_file_utils.py @@ -11,4 +11,4 @@ def write_to_file(content: str, full_file_name: str): dir_name = os.path.dirname(full_file_name) Path(dir_name).mkdir(parents=True, exist_ok=True) with open(full_file_name, "w") as handle: - print(content, file=handle) \ No newline at end of file + print(content, file=handle) diff --git a/feathr_project/feathr/utils/config.py b/feathr_project/feathr/utils/config.py index 5f92ccb59..63cb1322c 100644 --- a/feathr_project/feathr/utils/config.py +++ b/feathr_project/feathr/utils/config.py @@ -27,7 +27,7 @@ "port": "6380", "ssl_enabled": "true", } - } + }, } # New databricks job cluster config @@ -162,8 +162,9 @@ def _set_azure_synapse_config( config["spark_config"]["azure_synapse"]["dev_url"] = f"https://{resource_prefix}syws.dev.azuresynapse.net" if not config["spark_config"]["azure_synapse"].get("workspace_dir"): - config["spark_config"]["azure_synapse"]["workspace_dir"] =\ - f"abfss://{resource_prefix}fs@{resource_prefix}dls.dfs.core.windows.net/{project_name}" + config["spark_config"]["azure_synapse"][ + "workspace_dir" + ] = f"abfss://{resource_prefix}fs@{resource_prefix}dls.dfs.core.windows.net/{project_name}" for k, v in DEFAULT_AZURE_SYNAPSE_SPARK_POOL_CONFIG.items(): if not config["spark_config"]["azure_synapse"].get(k): @@ -235,13 +236,13 @@ def _verify_config(config: Dict): if not os.environ.get("ADLS_KEY"): raise ValueError("ADLS_KEY must be set in environment variables") elif ( - not os.environ.get("SPARK_CONFIG__AZURE_SYNAPSE__DEV_URL") and - config["spark_config"]["azure_synapse"].get("dev_url") is None + not os.environ.get("SPARK_CONFIG__AZURE_SYNAPSE__DEV_URL") + and config["spark_config"]["azure_synapse"].get("dev_url") is None ): raise ValueError("Azure Synapse dev endpoint is not provided.") elif ( - not os.environ.get("SPARK_CONFIG__AZURE_SYNAPSE__POOL_NAME") and - config["spark_config"]["azure_synapse"].get("pool_name") is None + not os.environ.get("SPARK_CONFIG__AZURE_SYNAPSE__POOL_NAME") + and config["spark_config"]["azure_synapse"].get("pool_name") is None ): raise ValueError("Azure Synapse pool name is not provided.") @@ -249,8 +250,8 @@ def _verify_config(config: Dict): if not os.environ.get("DATABRICKS_WORKSPACE_TOKEN_VALUE"): raise ValueError("Databricks workspace token is not provided.") elif ( - not os.environ.get("SPARK_CONFIG__DATABRICKS__WORKSPACE_INSTANCE_URL") and - config["spark_config"]["databricks"].get("workspace_instance_url") is None + not os.environ.get("SPARK_CONFIG__DATABRICKS__WORKSPACE_INSTANCE_URL") + and config["spark_config"]["databricks"].get("workspace_instance_url") is None ): raise ValueError("Databricks workspace url is not provided.") diff --git a/feathr_project/feathr/utils/dsl/dsl_generator.py b/feathr_project/feathr/utils/dsl/dsl_generator.py index 7e8effaa9..ed224c08d 100644 --- a/feathr_project/feathr/utils/dsl/dsl_generator.py +++ b/feathr_project/feathr/utils/dsl/dsl_generator.py @@ -1,4 +1,3 @@ - import re from enum import Enum from typing import List @@ -8,145 +7,175 @@ from feathr.definition.transformation import 
ExpressionTransformation, WindowAggTransformation from functions import SUPPORTED_FUNCTIONS + class Token: def __init__(self, name, value): self.name = name self.value = value + def __str__(self): - res = 'Token({}, {})'.format( - self.name, repr(self.value)) + res = "Token({}, {})".format(self.name, repr(self.value)) return res + def __repr__(self): return self.__str__() + def is_identifier(self): return self.name == Tokenizer.identifier.name + def is_number(self): return self.name == Tokenizer.number.name + def is_operator(self): - return hasattr(Operator, self.name) + return hasattr(Operator, self.name) + def is_new_line(self): return self.name == Tokenizer.new_line.name + def is_eof(self): return self.name == Tokenizer.eof.name + class Operator(Enum): # 2 characters - power = '**' - less_equal = '<=' - greater_equal = '>=' - not_equal = '!=' + power = "**" + less_equal = "<=" + greater_equal = ">=" + not_equal = "!=" # 1 character - plus = '+' - minus = '-' - multiply = '*' - divide = '/' - mod = '%' - equal = '=' - less_than = '<' - greater_than = '>' - left_curly = '{' - left_paren = '(' - left_square = '[' - right_curly = '}' - right_paren = ')' - right_square = ']' - comma = ',' + plus = "+" + minus = "-" + multiply = "*" + divide = "/" + mod = "%" + equal = "=" + less_than = "<" + greater_than = ">" + left_curly = "{" + left_paren = "(" + left_square = "[" + right_curly = "}" + right_paren = ")" + right_square = "]" + comma = "," + class Tokenizer(Enum): - comment = r'#[^\r\n]*' - space = r'[ \t]+' - identifier = r'[a-zA-Z_][a-zA-Z_0-9]*' - number = r'[0-9]+(?:\.[0-9]*)?' - operator = r'\*\*|[<>!]=|[-+*/%=<>()[\]{},]' - new_line = r'[\r\n]' - eof = r'$' - error = r'(.+?)' + comment = r"#[^\r\n]*" + space = r"[ \t]+" + identifier = r"[a-zA-Z_][a-zA-Z_0-9]*" + number = r"[0-9]+(?:\.[0-9]*)?" + operator = r"\*\*|[<>!]=|[-+*/%=<>()[\]{},]" + new_line = r"[\r\n]" + eof = r"$" + error = r"(.+?)" + @classmethod def _build_pattern(cls): cls.names = [x.name for x in cls] - cls.regex = '|'.join('({})'.format(x.value) for x in cls) + cls.regex = "|".join("({})".format(x.value) for x in cls) cls.pattern = re.compile(cls.regex) + @classmethod def token_iter(cls, text): - ''' text to token iter. + """text to token iter. Args: text: string for tokenization. Returns: Iteration object of generated tokens. 
- ''' + """ for match in cls.pattern.finditer(text): - name = cls.names[match.lastindex-1] + name = cls.names[match.lastindex - 1] # skip space and comment - if (name == cls.space.name - or name == cls.comment.name): + if name == cls.space.name or name == cls.comment.name: continue # raise error elif name == cls.error.name: - print(text[match.start():]) - raise Exception('Invalid Syntax.') + print(text[match.start() :]) + raise Exception("Invalid Syntax.") value = match.group() # operator name if name == cls.operator.name: name = Operator(value).name token = Token(name, value) yield token + + Tokenizer._build_pattern() + class AST: def __init__(self, token): self.token = token + def __repr__(self): return self.__str__() + def __str__(self): return str(self.token) + class AtomOp(AST): def __init__(self, token, is_func=False): self.token = token self.value = token.value self.is_func = is_func + def __str__(self): if self.token.is_operator(): return self.token.name return self.value + + class FuncOp(AST): def __init__(self, func=None, ops=None): self.func = func self.ops = ops + def __str__(self): - pstr = ', '.join([str(x) for x in self.ops]) - return '{}({})'.format(self.func, pstr) + pstr = ", ".join([str(x) for x in self.ops]) + return "{}({})".format(self.func, pstr) + + class VectorOp(AST): def __init__(self, ops=None): self.ops = ops + def __str__(self): - pstr = ', '.join([str(x) for x in self.ops]) - return '[{}]'.format(pstr) + pstr = ", ".join([str(x) for x in self.ops]) + return "[{}]".format(pstr) + + class SetOp(AST): def __init__(self, ops=None): self.ops = ops + def __str__(self): - pstr = ', '.join([str(x) for x in self.ops]) - return '{'+pstr+'}' + pstr = ", ".join([str(x) for x in self.ops]) + return "{" + pstr + "}" + class Parser: def __init__(self, token_iter): - '''Parser. + """Parser. Args: token_iter: token iterator returned by Tokenizer. - ''' + """ self.token_iter = token_iter self.forward() + def forward(self): - '''Set current token by next(token_iter). ''' + """Set current token by next(token_iter).""" self.current_token = next(self.token_iter) + def error(self, error_code, token): - '''Trigger by unexpected input. ''' + """Trigger by unexpected input.""" raise ValueError( - f'{error_code} -> {token}', + f"{error_code} -> {token}", ) + def parse(self): - """ generate ast. 
""" + """generate ast.""" node = self.expr() if not self.current_token.is_eof(): @@ -157,7 +186,7 @@ def parse(self): return node def expr(self): - """ expr: set_expr|vec_expr|add_expr """ + """expr: set_expr|vec_expr|add_expr""" if self.current_token.name == Operator.left_curly.name: node = self.set_expr() elif self.current_token.name == Operator.left_square.name: @@ -165,8 +194,9 @@ def expr(self): else: node = self.add_expr() return node + def set_expr(self): - """ set_expr: { arglist } """ + """set_expr: { arglist }""" assert self.current_token.name == Operator.left_curly.name self.forward() ops = self.arglist() @@ -174,38 +204,39 @@ def set_expr(self): self.forward() node = SetOp(ops=ops) return node + def vec_expr(self): - """ vec_expr: [ arglist ] """ + """vec_expr: [ arglist ]""" assert self.current_token.name == Operator.left_square.name - self.forward() + self.forward() ops = self.arglist() assert self.current_token.name == Operator.right_square.name self.forward() node = VectorOp(ops=ops) return node + def add_expr(self): - """ add_expr: mul_expr ([+-] mul_expr)* """ + """add_expr: mul_expr ([+-] mul_expr)*""" node = self.mul_expr() - while self.current_token.name in (Operator.plus.name, - Operator.minus.name): + while self.current_token.name in (Operator.plus.name, Operator.minus.name): op = AtomOp(self.current_token, is_func=True) self.forward() node = FuncOp(func=op, ops=[node, self.mul_expr()]) return node + def mul_expr(self): - """ mul_expr: factor ([*/%] factor)* """ + """mul_expr: factor ([*/%] factor)*""" node = self.factor() - while self.current_token.name in (Operator.multiply.name, - Operator.divide.name, Operator.mod.name): + while self.current_token.name in (Operator.multiply.name, Operator.divide.name, Operator.mod.name): op = AtomOp(self.current_token, is_func=True) self.forward() node = FuncOp(func=op, ops=[node, self.factor()]) return node + def factor(self): - """ factor: [+-] factor | power """ + """factor: [+-] factor | power""" token = self.current_token - if (token.name == Operator.minus.name - or token.name == Operator.plus.name): + if token.name == Operator.minus.name or token.name == Operator.plus.name: op = AtomOp(token, is_func=True) self.forward() node = FuncOp(func=op, ops=[self.factor()]) @@ -213,8 +244,9 @@ def factor(self): else: node = self.power() return node + def power(self): - """ power: term [** factor] """ + """power: term [** factor]""" node = self.term() if self.current_token.name == Operator.power.name: op = AtomOp(self.current_token, is_func=True) @@ -222,23 +254,24 @@ def power(self): node = FuncOp(func=op, ops=[node, self.factor()]) return node return node + def term(self): - """ term: function | ( add_expr ) """ + """term: function | ( add_expr )""" if self.current_token.name == Operator.left_paren.name: self.forward() node = self.add_expr() assert self.current_token.name == Operator.right_paren.name self.forward() else: - node = self.function() + node = self.function() return node + def function(self): - """ function: atom ( arglist? ) """ - assert (self.current_token.is_identifier() - or self.current_token.is_number()) + """function: atom ( arglist? 
)""" + assert self.current_token.is_identifier() or self.current_token.is_number() token = self.current_token self.forward() - is_func = (self.current_token.name == Operator.left_paren.name) + is_func = self.current_token.name == Operator.left_paren.name node = AtomOp(token, is_func=is_func) if self.current_token.name == Operator.left_paren.name: self.forward() @@ -249,23 +282,27 @@ def function(self): self.forward() node = FuncOp(func=node, ops=arglist) return node + def arglist(self): - """ arglist: expr (, expr)* """ + """arglist: expr (, expr)*""" res = [self.expr()] while self.current_token.name == Operator.comma.name: self.forward() res.append(self.expr()) return res + def parse(txt) -> AST: - """ parse txt to AST. """ + """parse txt to AST.""" return Parser(Tokenizer.token_iter(txt)).parse() + def get_identifiers(txt) -> list: ast = parse(txt) s = set() return collect_id(ast, s) + def collect_id(ast, s): if isinstance(ast, AtomOp): # if ast.is_func: @@ -286,11 +323,12 @@ def collect_id(ast, s): raise ValueError(f"unknown ast type: {ast}") return s + def gen_dsl(name: str, features: List[Feature]): """Generate a dsl file for the given features""" - + layers = [] - + # Add all upstreams to the current_features current_features = features.copy() while True: @@ -303,9 +341,9 @@ def gen_dsl(name: str, features: List[Feature]): if len(cf) == len(current_features): break current_features = cf - + feature_names = set([f.name for f in current_features]) - + # Topological sort the features while current_features: current_layer = set() @@ -325,7 +363,7 @@ def gen_dsl(name: str, features: List[Feature]): current_features.remove(f) current_layer.add(f) layers.append(current_layer) - + identifiers = set() stages = [] for l in layers: @@ -350,6 +388,8 @@ def gen_dsl(name: str, features: List[Feature]): if unsupported_func: raise NotImplementedError(f"Feature {f.name} uses unsupported function {unsupported_func}") stages.append(f'| project {", ".join(t)}') - stages.append(f'| project-keep {", ".join([f.name for f in features])}', ) + stages.append( + f'| project-keep {", ".join([f.name for f in features])}', + ) schema = f'({", ".join(identifiers)})' - return "\n".join([f'{name}{schema}', "\n".join(stages), ";"]) + return "\n".join([f"{name}{schema}", "\n".join(stages), ";"]) diff --git a/feathr_project/feathr/utils/dsl/functions.py b/feathr_project/feathr/utils/dsl/functions.py index c7cef0787..55d327552 100644 --- a/feathr_project/feathr/utils/dsl/functions.py +++ b/feathr_project/feathr/utils/dsl/functions.py @@ -3,185 +3,185 @@ # Do not edit this file directly. 
###################################################################### SUPPORTED_FUNCTIONS = [ -"abs", -"acos", -"acosh", -"add_months", -"array", -"array_contains", -"array_distinct", -"array_except", -"array_intersect", -"array_join", -"array_max", -"array_min", -"array_position", -"array_remove", -"array_repeat", -"array_size", -"array_union", -"arrays_overlap", -"arrays_zip", -"ascii", -"asin", -"asinh", -"atan", -"atan2", -"atanh", -"bigint", -"bit_and", -"bit_count", -"bit_get", -"bit_length", -"bit_not", -"bit_or", -"bit_xor", -"bool_and", -"bool_or", -"boolean", -"btrim", -"bucket", -"case", -"cbrt", -"ceil", -"ceiling", -"char", -"char_length", -"character_length", -"chr", -"coalesce", -"concat", -"concat_ws", -"contains", -"conv", -"cos", -"cosh", -"cot", -"csc", -"current_date", -"current_timestamp", -"current_timezone", -"date", -"date_add", -"date_diff", -"date_from_unix_date", -"date_sub", -"day", -"dayofmonth", -"dayofweek", -"dayofyear", -"degrees", -"distance", -"double", -"e", -"element_at", -"elt", -"endswith", -"every", -"exp", -"expm1", -"factorial", -"flatten", -"float", -"floor", -"from_utc_timestamp", -"get_json_array", -"get_json_object", -"getbit", -"hour", -"hypot", -"if", -"ifnull", -"instr", -"int", -"isnan", -"isnotnull", -"isnull", -"json_array_length", -"json_object_keys", -"last_day", -"lcase", -"len", -"length", -"levenshtein", -"ln", -"log", -"log10", -"log1p", -"log2", -"lower", -"ltrim", -"make_date", -"make_timestamp", -"map_contains_key", -"map_from_arrays", -"map_keys", -"map_values", -"minute", -"mod", -"month", -"nanvl", -"next_day", -"now", -"nullif", -"nvl", -"nvl2", -"pi", -"positive", -"pow", -"power", -"quarter", -"radians", -"rand", -"random", -"regexp", -"regexp_extract", -"regexp_extract_all", -"regexp_like", -"regexp_replace", -"repeat", -"round", -"rtrim", -"sec", -"second", -"shiftleft", -"shiftright", -"shiftrightunsigned", -"shuffle", -"sign", -"signum", -"sin", -"sinh", -"size", -"slice", -"space", -"split", -"split_part", -"sqrt", -"startswith", -"string", -"substring", -"substring_index", -"tan", -"tanh", -"timestamp", -"timestamp_micros", -"timestamp_millis", -"timestamp_seconds", -"to_json", -"to_unix_timestamp", -"to_utc_timestamp", -"translate", -"trim", -"ucase", -"unix_date", -"unix_micros", -"unix_millis", -"unix_seconds", -"unix_timestamp", -"upper", -"uuid", -"weekday", -"weekofyear", -"year", + "abs", + "acos", + "acosh", + "add_months", + "array", + "array_contains", + "array_distinct", + "array_except", + "array_intersect", + "array_join", + "array_max", + "array_min", + "array_position", + "array_remove", + "array_repeat", + "array_size", + "array_union", + "arrays_overlap", + "arrays_zip", + "ascii", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bigint", + "bit_and", + "bit_count", + "bit_get", + "bit_length", + "bit_not", + "bit_or", + "bit_xor", + "bool_and", + "bool_or", + "boolean", + "btrim", + "bucket", + "case", + "cbrt", + "ceil", + "ceiling", + "char", + "char_length", + "character_length", + "chr", + "coalesce", + "concat", + "concat_ws", + "contains", + "conv", + "cos", + "cosh", + "cot", + "csc", + "current_date", + "current_timestamp", + "current_timezone", + "date", + "date_add", + "date_diff", + "date_from_unix_date", + "date_sub", + "day", + "dayofmonth", + "dayofweek", + "dayofyear", + "degrees", + "distance", + "double", + "e", + "element_at", + "elt", + "endswith", + "every", + "exp", + "expm1", + "factorial", + "flatten", + "float", + "floor", + "from_utc_timestamp", + "get_json_array", + 
"get_json_object", + "getbit", + "hour", + "hypot", + "if", + "ifnull", + "instr", + "int", + "isnan", + "isnotnull", + "isnull", + "json_array_length", + "json_object_keys", + "last_day", + "lcase", + "len", + "length", + "levenshtein", + "ln", + "log", + "log10", + "log1p", + "log2", + "lower", + "ltrim", + "make_date", + "make_timestamp", + "map_contains_key", + "map_from_arrays", + "map_keys", + "map_values", + "minute", + "mod", + "month", + "nanvl", + "next_day", + "now", + "nullif", + "nvl", + "nvl2", + "pi", + "positive", + "pow", + "power", + "quarter", + "radians", + "rand", + "random", + "regexp", + "regexp_extract", + "regexp_extract_all", + "regexp_like", + "regexp_replace", + "repeat", + "round", + "rtrim", + "sec", + "second", + "shiftleft", + "shiftright", + "shiftrightunsigned", + "shuffle", + "sign", + "signum", + "sin", + "sinh", + "size", + "slice", + "space", + "split", + "split_part", + "sqrt", + "startswith", + "string", + "substring", + "substring_index", + "tan", + "tanh", + "timestamp", + "timestamp_micros", + "timestamp_millis", + "timestamp_seconds", + "to_json", + "to_unix_timestamp", + "to_utc_timestamp", + "translate", + "trim", + "ucase", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "unix_timestamp", + "upper", + "uuid", + "weekday", + "weekofyear", + "year", ] diff --git a/feathr_project/feathr/utils/dsl/test_dsl_generator.py b/feathr_project/feathr/utils/dsl/test_dsl_generator.py index eee08e5cd..381571fb5 100644 --- a/feathr_project/feathr/utils/dsl/test_dsl_generator.py +++ b/feathr_project/feathr/utils/dsl/test_dsl_generator.py @@ -6,30 +6,42 @@ import random from datetime import datetime, timedelta -from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32, STRING, - DerivedFeature, Feature, FeatureAnchor, HdfsSource, - TypedKey, ValueType, WindowAggTransformation) +from feathr import ( + BOOLEAN, + FLOAT, + INPUT_CONTEXT, + INT32, + STRING, + DerivedFeature, + Feature, + FeatureAnchor, + HdfsSource, + TypedKey, + ValueType, + WindowAggTransformation, +) from feathr import FeathrClient from feathr.definition.transformation import ExpressionTransformation import dsl_generator -batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") +batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", +) -f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") -f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") -f_is_long_trip_distance = Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), -f_day_of_week = Feature(name="f_day_of_week", - feature_type=INT32, - transform="some_fancy_func(lpep_dropoff_datetime)") +f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") +f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", +) +f_is_long_trip_distance = ( + 
Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), +) +f_day_of_week = Feature(name="f_day_of_week", feature_type=INT32, transform="some_fancy_func(lpep_dropoff_datetime)") features = [ f_trip_distance, f_trip_time_duration, @@ -38,45 +50,47 @@ ] -request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) +request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) -f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") +f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", +) -f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") +f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", +) -location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") -f_location_avg_fare = Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d", - filter="fare_amount > 0" - )) -agg_features = [f_location_avg_fare, - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] +location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", +) +f_location_avg_fare = Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation( + agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d", filter="fare_amount > 0" + ), +) +agg_features = [ + f_location_avg_fare, + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), +] -agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) +agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) # This should work @@ -89,11 +103,11 @@ raise Exception("Should have failed") except NotImplementedError as e: pass - + # This will also fail because we don't support WindowAggTransformation try: dsl_generator.gen_dsl("test_pipeline", [f_location_avg_fare]) raise Exception("Should have failed") except: - pass \ No newline at end of file + pass diff --git a/feathr_project/feathr/utils/feature_printer.py b/feathr_project/feathr/utils/feature_printer.py index fcd9d6c14..869c59eef 100644 --- a/feathr_project/feathr/utils/feature_printer.py +++ b/feathr_project/feathr/utils/feature_printer.py @@ -5,9 +5,10 @@ from feathr.definition.query_feature_list import FeatureQuery from feathr.definition.materialization_settings import MaterializationSettings + class FeaturePrinter: """The class for pretty-printing features""" - + 
@staticmethod def pretty_print_anchors(anchor_list: List[FeatureAnchor]) -> None: """Pretty print features @@ -18,8 +19,7 @@ def pretty_print_anchors(anchor_list: List[FeatureAnchor]) -> None: if all(isinstance(anchor, FeatureAnchor) for anchor in anchor_list): for anchor in anchor_list: - pprint("%s is the achor of %s" % \ - (anchor.name, [feature.name for feature in anchor.features])) + pprint("%s is the anchor of %s" % (anchor.name, [feature.name for feature in anchor.features])) else: raise TypeError("anchor_list must be FeatureAnchor or List[FeatureAnchor]") @@ -46,4 +46,4 @@ def pretty_print_materialize_features(settings: MaterializationSettings) -> None if isinstance(settings, MaterializationSettings): print("Materialization features in settings: %s" % settings.feature_names) else: - raise TypeError("settings must be MaterializationSettings") \ No newline at end of file + raise TypeError("settings must be MaterializationSettings") diff --git a/feathr_project/feathr/utils/job_utils.py b/feathr_project/feathr/utils/job_utils.py index 7633870b4..128c6e9fd 100644 --- a/feathr_project/feathr/utils/job_utils.py +++ b/feathr_project/feathr/utils/job_utils.py @@ -12,6 +12,7 @@ from feathr.utils.platform import is_databricks from feathr.spark_provider._synapse_submission import _DataLakeFiler + def get_result_pandas_df( client: FeathrClient, data_format: str = None, @@ -73,7 +74,7 @@ def get_result_df( local_cache_path: str = None, spark: SparkSession = None, format: str = None, - is_file_path: bool = False + is_file_path: bool = False, ) -> Union[DataFrame, pd.DataFrame]: """Download the job result dataset from cloud as a Spark DataFrame or pandas DataFrame. @@ -106,11 +107,15 @@ def get_result_df( data_format = data_format.lower() if is_databricks() and client.spark_runtime != "databricks": - raise RuntimeError(f"The function is called from Databricks but the client.spark_runtime is {client.spark_runtime}.") + raise RuntimeError( + f"The function is called from Databricks but the client.spark_runtime is {client.spark_runtime}." + ) # TODO Loading Synapse Delta table result into pandas has a bug: https://github.com/delta-io/delta-rs/issues/582 if not spark and client.spark_runtime == "azure_synapse" and data_format == "delta": - raise RuntimeError(f"Loading Delta table result from Azure Synapse into pandas DataFrame is not supported. You maybe able to use spark DataFrame to load the result instead.") + raise RuntimeError( + f"Loading Delta table result from Azure Synapse into pandas DataFrame is not supported. You may be able to use a Spark DataFrame to load the result instead." + ) # use a result url if it's provided by the user, otherwise use the one provided by the job res_url: str = res_url or client.get_job_result_uri(block=True, timeout_sec=1200) @@ -135,9 +140,7 @@ def get_result_df( if is_databricks(): # Check if the function is being called from Databricks if local_cache_path is not None: - logger.warning( - "Result files are already in DBFS and thus `local_cache_path` will be ignored." 
- ) + logger.warning("Result files are already in DBFS and thus `local_cache_path` will be ignored.") local_cache_path = res_url if local_cache_path is None: @@ -145,7 +148,9 @@ def get_result_df( if local_cache_path != res_url: logger.info(f"{res_url} files will be downloaded into {local_cache_path}") - client.feathr_spark_launcher.download_result(result_path=res_url, local_folder=local_cache_path, is_file_path = is_file_path) + client.feathr_spark_launcher.download_result( + result_path=res_url, local_folder=local_cache_path, is_file_path=is_file_path + ) result_df = None try: @@ -156,39 +161,47 @@ def get_result_df( result_df = spark.read.format(data_format).load(local_cache_path) else: result_df = _load_files_to_pandas_df( - dir_path=local_cache_path.replace("dbfs:", "/dbfs"), # replace to python path if spark path is provided. + dir_path=local_cache_path.replace( + "dbfs:", "/dbfs" + ), # replace to python path if spark path is provided. data_format=data_format, ) except Exception as e: logger.error(f"Failed to load result files from {local_cache_path} with format {data_format}.") raise e - + return result_df + def copy_cloud_dir(client: FeathrClient, source_url: str, target_url: str = None): source_url: str = source_url or client.get_job_result_uri(block=True, timeout_sec=1200) if source_url is None: - raise RuntimeError("source_url None. Please make sure either you provide a source_url or make sure the job finished in FeathrClient has a valid result URI.") + raise RuntimeError( + "source_url None. Please make sure either you provide a source_url or make sure the job finished in FeathrClient has a valid result URI." + ) if target_url is None: raise RuntimeError("target_url None. Please make sure you provide a target_url.") client.feathr_spark_launcher.upload_or_get_cloud_path(source_url, target_url) - + + def cloud_dir_exists(client: FeathrClient, dir_path: str) -> bool: return client.feathr_spark_launcher.cloud_dir_exists(dir_path) -def _load_files_to_pandas_df(dir_path: str, data_format: str = "avro") -> pd.DataFrame: +def _load_files_to_pandas_df(dir_path: str, data_format: str = "avro") -> pd.DataFrame: if data_format == "parquet": return pd.read_parquet(dir_path) elif data_format == "delta": from deltalake import DeltaTable + delta = DeltaTable(dir_path) return delta.to_pyarrow_table().to_pandas() elif data_format == "avro": import pandavro as pdx + if Path(dir_path).is_file(): return pdx.read_avro(dir_path) else: @@ -210,16 +223,17 @@ def _load_files_to_pandas_df(dir_path: str, data_format: str = "avro") -> pd.Dat raise ValueError( f"{data_format} is currently not supported in get_result_df. Currently only parquet, delta, avro, and csv are supported, please consider writing a customized function to read the result." 
) - -def get_cloud_file_column_names(client: FeathrClient, path: str, format: str = "csv", is_file_path = True)->Set[str]: + + +def get_cloud_file_column_names(client: FeathrClient, path: str, format: str = "csv", is_file_path=True) -> Set[str]: # Try to load publid cloud files without credential - if path.startswith(("abfss:","wasbs:")): - paths = re.split('/|@', path) + if path.startswith(("abfss:", "wasbs:")): + paths = re.split("/|@", path) if len(paths) < 4: raise RuntimeError(f"invalid cloud path: ", path) - new_path = 'https://'+paths[3]+'/'+paths[2] + '/' + new_path = "https://" + paths[3] + "/" + paths[2] + "/" if len(paths) > 4: - new_path = new_path + '/'.join(paths[4:]) + new_path = new_path + "/".join(paths[4:]) if format == "csv" and is_file_path: try: df = pd.read_csv(new_path) @@ -227,13 +241,17 @@ def get_cloud_file_column_names(client: FeathrClient, path: str, format: str = " except: df = None # TODO: support loading other formats files - + try: - df = get_result_df(client=client, data_format=format, res_url=path, is_file_path = is_file_path) + df = get_result_df(client=client, data_format=format, res_url=path, is_file_path=is_file_path) except: - logger.warning(f"failed to load cloud files from the path: {path} because of lack of permission or invalid path.") + logger.warning( + f"failed to load cloud files from the path: {path} because of lack of permission or invalid path." + ) return None if df is None: - logger.warning(f"failed to load cloud files from the path: {path} because of lack of permission or invalid path.") + logger.warning( + f"failed to load cloud files from the path: {path} because of lack of permission or invalid path." + ) return None return df.columns diff --git a/feathr_project/feathr/utils/spark_job_params.py b/feathr_project/feathr/utils/spark_job_params.py index 80dfd350c..a41dcf19f 100644 --- a/feathr_project/feathr/utils/spark_job_params.py +++ b/feathr_project/feathr/utils/spark_job_params.py @@ -3,6 +3,7 @@ from feathr.definition.sink import Sink from feathr.definition.source import Source + class FeatureJoinJobParams: """Parameters related to feature join job. @@ -13,7 +14,7 @@ class FeatureJoinJobParams: job_output_path: Absolute path in Cloud that you want your output data to be in. """ - def __init__(self, join_config_path, observation_path, feature_config, job_output_path, secrets:List[str]=[]): + def __init__(self, join_config_path, observation_path, feature_config, job_output_path, secrets: List[str] = []): self.secrets = secrets self.join_config_path = join_config_path if isinstance(observation_path, str): @@ -34,6 +35,7 @@ def __init__(self, join_config_path, observation_path, feature_config, job_outpu else: raise TypeError("job_output_path must be a string or a Sink") + class FeatureGenerationJobParams: """Parameters related to feature generation job. 
diff --git a/feathr_project/feathr/version.py b/feathr_project/feathr/version.py index 2c1637058..6cbb139ff 100644 --- a/feathr_project/feathr/version.py +++ b/feathr_project/feathr/version.py @@ -1,10 +1,14 @@ __version__ = "1.0.0" + def get_version(): return __version__ + # Decouple Feathr MAVEN Version from Feathr Python SDK Version import os + + def get_maven_artifact_fullname(): maven_artifact_version = os.environ.get("MAVEN_ARTIFACT_VERSION", __version__) - return f"com.linkedin.feathr:feathr_2.12:{maven_artifact_version}" \ No newline at end of file + return f"com.linkedin.feathr:feathr_2.12:{maven_artifact_version}" diff --git a/feathr_project/feathrcli/cli.py b/feathr_project/feathrcli/cli.py index a29a08a76..3c3ec2514 100644 --- a/feathr_project/feathrcli/cli.py +++ b/feathr_project/feathrcli/cli.py @@ -10,6 +10,7 @@ from feathr.definition.config_helper import FeathrConfigHelper from feathr.registry._feathr_registry_client import _FeatureRegistry + @click.group() @click.pass_context def cli(ctx: click.Context): @@ -27,17 +28,19 @@ def check_user_at_root(): can work correctly. """ # use this file as a anchor point to identify the root of the repo - anchor_file = 'feathr_config.yaml' + anchor_file = "feathr_config.yaml" user_workspace_dir = Path(".") anchor_file_path = user_workspace_dir / anchor_file if not anchor_file_path.exists(): - raise click.UsageError('You are NOT at the root of your user workspace("/feathr_user_workspace"). Please ' - 'execute the command under your user workspace root.') + raise click.UsageError( + 'You are NOT at the root of your user workspace("/feathr_user_workspace"). Please ' + "execute the command under your user workspace root." + ) @cli.command() -@click.option('--name', default="feathr_user_workspace", help='Specify the workspace name.') -@click.option('--git/--no-git', default=False, help='When enabled, a git-based workspace will be created.') +@click.option("--name", default="feathr_user_workspace", help="Specify the workspace name.") +@click.option("--git/--no-git", default=False, help="When enabled, a git-based workspace will be created.") def init(name, git): """ Initializes a Feathr project to create and manage features. A team should share a same Feathr project usually via @@ -48,11 +51,11 @@ def init(name, git): workspace_exist = os.path.isdir(workspace_dir) if workspace_exist: # workspace already exist. Just exit. - raise click.UsageError(f'Feathr workspace ({name}) already exist. Please use a new folder name.') + raise click.UsageError(f"Feathr workspace ({name}) already exist. Please use a new folder name.") - output_str = f'Creating workspace {name} with sample config files and mock data ...' + output_str = f"Creating workspace {name} with sample config files and mock data ..." click.echo(output_str) - default_workspace = str(Path(Path(__file__).parent / 'data' / 'feathr_user_workspace').absolute()) + default_workspace = str(Path(Path(__file__).parent / "data" / "feathr_user_workspace").absolute()) # current feathr_user_workspace directory w.r.t. where the init command is executed pathlib.Path(name).mkdir(parents=True, exist_ok=True) @@ -63,17 +66,24 @@ def init(name, git): # Create a git repo for the workspace if git: os.chdir(workspace_dir) - process = subprocess.Popen(['git', 'init'], stdout=subprocess.PIPE) + process = subprocess.Popen(["git", "init"], stdout=subprocess.PIPE) output = process.communicate()[0] click.echo(output) - click.echo(click.style('Git init completed for your workspace. 
Please read the ' - 'wiki to learn how to manage ' - 'your workspace with git.', fg='green')) - click.echo(click.style('Feathr initialization completed.', fg='green')) + click.echo( + click.style( + "Git init completed for your workspace. Please read the " + "wiki to learn how to manage " + "your workspace with git.", + fg="green", + ) + ) + click.echo(click.style("Feathr initialization completed.", fg="green")) @cli.command() -@click.option('--save_to', default="./", help='Specify the path to save the output HOCON config(relative to current path).') +@click.option( + "--save_to", default="./", help="Specify the path to save the output HOCON config(relative to current path)." +) def hocon(save_to): """ Scan all Python-based feature definitions recursively under current directory, @@ -85,7 +95,7 @@ def hocon(save_to): @cli.command() -@click.argument('filepath', default='feature_join_conf/feature_join.conf', type=click.Path(exists=True)) +@click.argument("filepath", default="feature_join_conf/feature_join.conf", type=click.Path(exists=True)) def join(filepath): """ Creates the offline training dataset with the requested features. @@ -93,7 +103,7 @@ def join(filepath): check_user_at_root() - click.echo(click.style('Batch joining features with config: ' + filepath, fg='green')) + click.echo(click.style("Batch joining features with config: " + filepath, fg="green")) with open(filepath) as f: lines = [] for line in f: @@ -104,19 +114,25 @@ def join(filepath): client = FeathrClient() client._get_offline_features_with_config(filepath) - click.echo(click.style('Feathr feature join job submitted. Visit ' - 'https://ms.web.azuresynapse.net/en-us/monitoring/sparkapplication for detailed job ' - 'result.', fg='green')) + click.echo( + click.style( + "Feathr feature join job submitted. Visit " + "https://ms.web.azuresynapse.net/en-us/monitoring/sparkapplication for detailed job " + "result.", + fg="green", + ) + ) + @cli.command() -@click.argument('filepath', default='feature_gen_conf/feature_gen.conf', type=click.Path(exists=True)) +@click.argument("filepath", default="feature_gen_conf/feature_gen.conf", type=click.Path(exists=True)) def deploy(filepath): """ Deploys the features to online store based on the feature generation config. """ check_user_at_root() - click.echo(click.style('Deploying feature generation config: ' + filepath, fg='green')) + click.echo(click.style("Deploying feature generation config: " + filepath, fg="green")) with open(filepath) as f: lines = [] for line in f: @@ -128,14 +144,19 @@ def deploy(filepath): client = FeathrClient() client._materialize_features_with_config(filepath) click.echo() - click.echo(click.style('Feathr feature deployment submitted. Visit ' - 'https://ms.web.azuresynapse.net/en-us/monitoring/sparkapplication for detailed job ' - 'result.', fg='green')) + click.echo( + click.style( + "Feathr feature deployment submitted. Visit " + "https://ms.web.azuresynapse.net/en-us/monitoring/sparkapplication for detailed job " + "result.", + fg="green", + ) + ) @cli.command() -@click.option('--git/--no-git', default=False, help='If git-enabled, the new changes will be added and committed.') -@click.option('--msg', help='The feature name.') +@click.option("--git/--no-git", default=False, help="If git-enabled, the new changes will be added and committed.") +@click.option("--msg", help="The feature name.") def register(git, msg): """ Register your feature metadata to your metadata registry. 
@@ -144,22 +165,22 @@ def register(git, msg): check_user_at_root() # The register command is not integrated with Azure Atlas yet. - click.echo(click.style('Registering your metadata to metadata service...', fg='green')) + click.echo(click.style("Registering your metadata to metadata service...", fg="green")) if git: - click.echo(click.style('Git: adding all files.', fg='green')) + click.echo(click.style("Git: adding all files.", fg="green")) click.echo(msg) - process = subprocess.Popen(['git', 'add', '-A'], stdout=subprocess.PIPE) + process = subprocess.Popen(["git", "add", "-A"], stdout=subprocess.PIPE) output = process.communicate()[0] click.echo(output) - click.echo(click.style('Git: committing.', fg='green')) - process2 = subprocess.Popen(['git', 'commit', '-m', msg], stdout=subprocess.PIPE) + click.echo(click.style("Git: committing.", fg="green")) + process2 = subprocess.Popen(["git", "commit", "-m", msg], stdout=subprocess.PIPE) output2 = process2.communicate()[0] click.echo(output2) client = FeathrClient() client.register_features() - click.echo(click.style('Feathr registration completed successfully!', fg='green')) + click.echo(click.style("Feathr registration completed successfully!", fg="green")) @cli.command() @@ -172,34 +193,36 @@ def start(): feathr_user_workspace. After the jar is downloaded, the command will run this jar. The jar needs to be running( don't close the terminal) while you want to use 'feathr test'. """ + def run_jar(): - cmd = ['java', '-jar', jar_name] + cmd = ["java", "-jar", jar_name] with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True) as p: # Need to continuously pump the results from jar to terminal for line in p.stdout: - print(line, end='') - + print(line, end="") check_user_at_root() # The jar should be placed under the root of the user workspace - jar_name = 'feathr_local_engine.jar' + jar_name = "feathr_local_engine.jar" # Download the jar if it doesn't exist if not os.path.isfile(jar_name): - url = 'https://azurefeathrstorage.blob.core.windows.net/public/' + jar_name - file_name = url.split('/')[-1] + url = "https://azurefeathrstorage.blob.core.windows.net/public/" + jar_name + file_name = url.split("/")[-1] u = urllib.request.urlopen(url) - f = open(file_name, 'wb') + f = open(file_name, "wb") meta = u.info() - file_size = int(meta.get('Content-Length')) - click.echo(click.style('There is no local feathr engine(jar) in the workspace. Will download the feathr jar.', - fg='green')) - click.echo('Downloading feathr jar for local testing: %s Bytes: %s from %s' % (file_name, file_size, url)) + file_size = int(meta.get("Content-Length")) + click.echo( + click.style( + "There is no local feathr engine(jar) in the workspace. Will download the feathr jar.", fg="green" + ) + ) + click.echo("Downloading feathr jar for local testing: %s Bytes: %s from %s" % (file_name, file_size, url)) file_size_dl = 0 block_sz = 8192 - with click.progressbar(length=file_size, - label='Download feathr local engine jar') as bar: + with click.progressbar(length=file_size, label="Download feathr local engine jar") as bar: while True: buffer = u.read(block_sz) if not buffer: @@ -211,22 +234,27 @@ def run_jar(): f.close() - click.echo(click.style(f'Starting the local feathr engine: {jar_name}.')) - click.echo(click.style(f'Please keep this open and start another terminal to run feathr test. 
This terminal shows ' - f'the debug message.', fg='green')) + click.echo(click.style(f"Starting the local feathr engine: {jar_name}.")) + click.echo( + click.style( + f"Please keep this open and start another terminal to run feathr test. This terminal shows " + f"the debug message.", + fg="green", + ) + ) run_jar() @cli.command() -@click.option('--features', prompt='Your feature names, separated by comma', help='The feature name.') +@click.option("--features", prompt="Your feature names, separated by comma", help="The feature name.") def test(features): """ Tests a single feature definition locally via local spark mode with mock data. Mock data has to be provided by the users. Please execute "feathr start" before "feathr test" to setup the local engine. """ check_user_at_root() - click.echo('\nProducing feature values for requested features ... ') + click.echo("\nProducing feature values for requested features ... ") gateway = JavaGateway() # User should run this command at user workspace dir root @@ -235,5 +263,5 @@ def test(features): # for py4j, it's always named as entry_point stack_entry_point_result = gateway.entry_point.getResult(user_workspace_dir, features) - click.echo('\nFeature computation completed.') + click.echo("\nFeature computation completed.") click.echo(stack_entry_point_result) diff --git a/feathr_project/feathrcli/data/feathr_user_workspace/features/agg_features.py b/feathr_project/feathrcli/data/feathr_user_workspace/features/agg_features.py index aa166a221..c294f83b1 100644 --- a/feathr_project/feathrcli/data/feathr_user_workspace/features/agg_features.py +++ b/feathr_project/feathrcli/data/feathr_user_workspace/features/agg_features.py @@ -5,29 +5,32 @@ from feathr.transformation import WindowAggTransformation from feathr.typed_key import TypedKey -batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") +batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", +) -location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") -agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] +location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", +) +agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), +] -agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) +agg_anchor = 
FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) diff --git a/feathr_project/feathrcli/data/feathr_user_workspace/features/non_agg_features.py b/feathr_project/feathrcli/data/feathr_user_workspace/features/non_agg_features.py index 8d7d7c93b..5d7c5a397 100644 --- a/feathr_project/feathrcli/data/feathr_user_workspace/features/non_agg_features.py +++ b/feathr_project/feathrcli/data/feathr_user_workspace/features/non_agg_features.py @@ -4,24 +4,29 @@ from feathr.typed_key import TypedKey from feathr.source import HdfsSource -batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") +batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", +) -location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") +location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", +) features = [ - Feature(name="f_loc_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30", key=location_id), - Feature(name="f_loc_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)", key=location_id) + Feature( + name="f_loc_is_long_trip_distance", + feature_type=BOOLEAN, + transform="cast_float(trip_distance)>30", + key=location_id, + ), + Feature( + name="f_loc_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)", key=location_id + ), ] -anchor = FeatureAnchor(name="nonAggFeatures", - source=batch_source, - features=features) \ No newline at end of file +anchor = FeatureAnchor(name="nonAggFeatures", source=batch_source, features=features) diff --git a/feathr_project/feathrcli/data/feathr_user_workspace/features/request_features.py b/feathr_project/feathrcli/data/feathr_user_workspace/features/request_features.py index 90b1c7395..5106762aa 100644 --- a/feathr_project/feathrcli/data/feathr_user_workspace/features/request_features.py +++ b/feathr_project/feathrcli/data/feathr_user_workspace/features/request_features.py @@ -5,32 +5,32 @@ from feathr.source import INPUT_CONTEXT f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") -f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") +f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", +) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), - ] + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), +] -request_anchor = 
FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) +request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) -f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") +f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", +) -f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") +f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", +) diff --git a/feathr_project/setup.py b/feathr_project/setup.py index 89a0e805f..a5211d70c 100644 --- a/feathr_project/setup.py +++ b/feathr_project/setup.py @@ -17,18 +17,17 @@ try: exec(open("feathr/version.py").read()) except IOError: - print("Failed to load Feathr version file for packaging.", - file=sys.stderr) + print("Failed to load Feathr version file for packaging.", file=sys.stderr) # Temp workaround for conda build. For long term fix, Jay will need to update manifest.in file. VERSION = "1.0.0" VERSION = __version__ # noqa os.environ["FEATHR_VERSION"] = VERSION -extras_require=dict( +extras_require = dict( dev=[ - "black>=22.1.0", # formatter - "isort", # sort import statements + "black>=22.1.0", # formatter + "isort", # sort import statements "pytest>=7", "pytest-cov", "pytest-xdist", @@ -38,16 +37,16 @@ "azure-cli==2.37.0", "jupyter>=1.0.0", "matplotlib>=3.6.1", - "papermill>=2.1.2,<3", # to test run notebooks + "papermill>=2.1.2,<3", # to test run notebooks "scrapbook>=0.5.0,<1.0.0", # to scrap notebook outputs - "scikit-learn", # for notebook examples - "plotly" # for plotting + "scikit-learn", # for notebook examples + "plotly", # for plotting ], ) extras_require["all"] = list(set(sum([*extras_require.values()], []))) setup( - name='feathr', + name="feathr", version=VERSION, long_description=long_description, long_description_content_type="text/markdown", @@ -99,20 +98,18 @@ # See this for more details: https://github.com/Azure/azure-sdk-for-python/issues/24765 "msrest<=0.6.21", "typing_extensions>=4.2.0", - "ipython", # for chat in notebook - "revChatGPT" + "ipython", # for chat in notebook + "revChatGPT", ], tests_require=[ # TODO: This has been depricated "pytest", ], extras_require=extras_require, - entry_points={ - 'console_scripts': ['feathr=feathrcli.cli:cli'] - }, + entry_points={"console_scripts": ["feathr=feathrcli.cli:cli"]}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], - python_requires=">=3.7" -) \ No newline at end of file + python_requires=">=3.7", +) diff --git a/feathr_project/test/clean_azure_test_data.py b/feathr_project/test/clean_azure_test_data.py index b0bf413cd..9002a588c 100644 --- a/feathr_project/test/clean_azure_test_data.py +++ b/feathr_project/test/clean_azure_test_data.py @@ -1,6 +1,6 @@ - import sys import os + # We have to append user's current path to sys path so the modules can be resolved # Otherwise we will got "no module named feathr" error sys.path.append(os.path.abspath(os.getcwd())) @@ -15,14 +15,14 @@ def 
clean_data(): Remove the test data(feature table: nycTaxiDemoFeature) in Azure. """ client = FeathrClient() - table_name = 'nycTaxiDemoFeature' + table_name = "nycTaxiDemoFeature" client._clean_test_data(table_name) - print('Redis table cleaned: ' + table_name) + print("Redis table cleaned: " + table_name) runner = CliRunner() with runner.isolated_filesystem(): runner.invoke(init, []) # Need to be in the workspace so it won't complain - os.chdir('feathr_user_workspace') + os.chdir("feathr_user_workspace") clean_data() diff --git a/feathr_project/test/conftest.py b/feathr_project/test/conftest.py index c2699e871..10bdaee19 100644 --- a/feathr_project/test/conftest.py +++ b/feathr_project/test/conftest.py @@ -42,12 +42,16 @@ def spark() -> SparkSession: """Generate a spark session for tests.""" # Set ui port other than the default one (4040) so that feathr spark job may not fail. spark_session = ( - SparkSession.builder - .appName("tests") - .config("spark.jars.packages", ",".join([ - "org.apache.spark:spark-avro_2.12:3.3.0", - "io.delta:delta-core_2.12:2.1.1", - ])) + SparkSession.builder.appName("tests") + .config( + "spark.jars.packages", + ",".join( + [ + "org.apache.spark:spark-avro_2.12:3.3.0", + "io.delta:delta-core_2.12:2.1.1", + ] + ), + ) .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") .config("spark.ui.port", "8080") diff --git a/feathr_project/test/prep_azure_kafka_test_data.py b/feathr_project/test/prep_azure_kafka_test_data.py index cfe83b20d..8d9ddaf38 100644 --- a/feathr_project/test/prep_azure_kafka_test_data.py +++ b/feathr_project/test/prep_azure_kafka_test_data.py @@ -8,6 +8,7 @@ from avro.io import BinaryEncoder, DatumWriter from confluent_kafka import Producer from feathr.utils._env_config_reader import EnvConfigReader + """ Produce some sample data for streaming feature using Kafka""" KAFKA_BROKER = "feathrazureci.servicebus.windows.net:9093" @@ -15,22 +16,20 @@ GENERATION_SIZE = 10 + def generate_entities(): return range(GENERATION_SIZE) def generate_trips(entities): df = pd.DataFrame(columns=["driver_id", "trips_today", "datetime", "created"]) - df['driver_id'] = entities - df['trips_today'] = range(GENERATION_SIZE) - df['datetime'] = pd.to_datetime( - np.random.randint( - datetime(2021, 10, 10).timestamp(), - datetime(2022, 10, 30).timestamp(), - size=GENERATION_SIZE), - unit="s" + df["driver_id"] = entities + df["trips_today"] = range(GENERATION_SIZE) + df["datetime"] = pd.to_datetime( + np.random.randint(datetime(2021, 10, 10).timestamp(), datetime(2022, 10, 30).timestamp(), size=GENERATION_SIZE), + unit="s", ) - df['created'] = pd.to_datetime(datetime.now()) + df["created"] = pd.to_datetime(datetime.now()) return df @@ -41,44 +40,40 @@ def send_avro_record_to_kafka(topic, record): encoder = BinaryEncoder(bytes_writer) writer.write(record, encoder) env_config = EnvConfigReader(config_path=None) - sasl = env_config.get_from_env_or_akv('KAFKA_SASL_JAAS_CONFIG') + sasl = env_config.get_from_env_or_akv("KAFKA_SASL_JAAS_CONFIG") conf = { - 'bootstrap.servers': KAFKA_BROKER, - 'security.protocol': 'SASL_SSL', - 'ssl.ca.location': '/usr/local/etc/openssl@1.1/cert.pem', - 'sasl.mechanism': 'PLAIN', - 'sasl.username': '$ConnectionString', - 'sasl.password': '{};EntityPath={}'.format(sasl, topic), - 'client.id': 'python-example-producer' + "bootstrap.servers": KAFKA_BROKER, + "security.protocol": "SASL_SSL", + "ssl.ca.location": 
"/usr/local/etc/openssl@1.1/cert.pem", + "sasl.mechanism": "PLAIN", + "sasl.username": "$ConnectionString", + "sasl.password": "{};EntityPath={}".format(sasl, topic), + "client.id": "python-example-producer", } - producer = Producer({ - **conf - }) + producer = Producer({**conf}) producer.produce(topic=topic, value=bytes_writer.getvalue()) producer.flush() + entities = generate_entities() trips_df = generate_trips(entities) -avro_schema_json = json.dumps({ - "type": "record", - "name": "DriverTrips", - "fields": [ - {"name": "driver_id", "type": "long"}, - {"name": "trips_today", "type": "int"}, - { - "name": "datetime", - "type": {"type": "long", "logicalType": "timestamp-micros"} - } - ] -}) +avro_schema_json = json.dumps( + { + "type": "record", + "name": "DriverTrips", + "fields": [ + {"name": "driver_id", "type": "long"}, + {"name": "trips_today", "type": "int"}, + {"name": "datetime", "type": {"type": "long", "logicalType": "timestamp-micros"}}, + ], + } +) while True: -# This while loop is used to keep the process runinng and producing data stream; -# If no need please remove it - for record in trips_df.drop(columns=['created']).to_dict('record'): - record["datetime"] = ( - record["datetime"].to_pydatetime().replace(tzinfo=pytz.utc) - ) + # This while loop is used to keep the process runinng and producing data stream; + # If no need please remove it + for record in trips_df.drop(columns=["created"]).to_dict("record"): + record["datetime"] = record["datetime"].to_pydatetime().replace(tzinfo=pytz.utc) send_avro_record_to_kafka(topic=KAFKA_TOPIC, record=record) diff --git a/feathr_project/test/prep_azure_test_data.py b/feathr_project/test/prep_azure_test_data.py index 248e6bf29..1a248a4e0 100644 --- a/feathr_project/test/prep_azure_test_data.py +++ b/feathr_project/test/prep_azure_test_data.py @@ -1,6 +1,6 @@ - import sys import os + # We have to append user's current path to sys path so the modules can be resolved # Otherwise we will got "no module named feathr" error sys.path.append(os.path.abspath(os.getcwd())) @@ -9,23 +9,24 @@ from click.testing import CliRunner from feathr import FeathrClient + def initialize_data(): """ Initialize the test data to Azure for testing. WARNING: It will override the existing test data. """ - print('Creating test data. This might override existing test data.') + print("Creating test data. This might override existing test data.") client = FeathrClient() # materialize feature to online store - client._materialize_features_with_config('feature_gen_conf/test_feature_gen_1.conf') - client._materialize_features_with_config('feature_gen_conf/test_feature_gen_2.conf') - client._materialize_features_with_config('feature_gen_conf/test_feature_gen_snowflake.conf') - print('Test data push job has started. It will take some time to complete.') + client._materialize_features_with_config("feature_gen_conf/test_feature_gen_1.conf") + client._materialize_features_with_config("feature_gen_conf/test_feature_gen_2.conf") + client._materialize_features_with_config("feature_gen_conf/test_feature_gen_snowflake.conf") + print("Test data push job has started. 
It will take some time to complete.") runner = CliRunner() with runner.isolated_filesystem(): runner.invoke(init, []) # Need to be in the workspace so it won't complain - os.chdir('feathr_user_workspace') + os.chdir("feathr_user_workspace") initialize_data() diff --git a/feathr_project/test/samples/test_notebooks.py b/feathr_project/test/samples/test_notebooks.py index 4f76dbcfa..e8caa7001 100644 --- a/feathr_project/test/samples/test_notebooks.py +++ b/feathr_project/test/samples/test_notebooks.py @@ -2,6 +2,7 @@ import yaml import pytest + try: import papermill as pm import scrapbook as sb @@ -9,13 +10,10 @@ pass # disable error while collecting tests for non-notebook environments -SAMPLES_DIR = ( - Path(__file__) - .parent # .../samples - .parent # .../test - .parent # .../feathr_project - .parent # .../feathr (root of the repo) - .joinpath("docs", "samples") +SAMPLES_DIR = Path( + __file__ +).parent.parent.parent.parent.joinpath( # .../samples # .../test # .../feathr_project # .../feathr (root of the repo) + "docs", "samples" ) NOTEBOOK_PATHS = { "nyc_taxi_demo": str(SAMPLES_DIR.joinpath("nyc_taxi_demo.ipynb")), @@ -49,10 +47,10 @@ def test__nyc_taxi_demo(config_path, tmp_path): nb = sb.read_notebook(output_notebook_path) outputs = nb.scraps - assert outputs["materialized_feature_values"].data["239"] == pytest.approx([1480., 5707.], abs=1.) - assert outputs["materialized_feature_values"].data["265"] == pytest.approx([4160., 10000.], abs=1.) - assert outputs["rmse"].data == pytest.approx(5., abs=2.) - assert outputs["mae"].data == pytest.approx(2., abs=1.) + assert outputs["materialized_feature_values"].data["239"] == pytest.approx([1480.0, 5707.0], abs=1.0) + assert outputs["materialized_feature_values"].data["265"] == pytest.approx([4160.0, 10000.0], abs=1.0) + assert outputs["rmse"].data == pytest.approx(5.0, abs=2.0) + assert outputs["mae"].data == pytest.approx(2.0, abs=1.0) @pytest.mark.databricks @@ -98,7 +96,7 @@ def test__fraud_detection_demo(config_path, tmp_path): nb = sb.read_notebook(output_notebook_path) outputs = nb.scraps - assert outputs["materialized_feature_values"].data == pytest.approx([False, 0, 9, 239.0, 1, 1, 239.0, 0.0], abs=1.) 
+ assert outputs["materialized_feature_values"].data == pytest.approx([False, 0, 9, 239.0, 1, 1, 239.0, 0.0], abs=1.0) assert outputs["precision"].data > 0.5 assert outputs["recall"].data > 0.5 assert outputs["f1"].data > 0.5 diff --git a/feathr_project/test/test_azure_feature_monitoring_e2e.py b/feathr_project/test/test_azure_feature_monitoring_e2e.py index 08a75ae32..600de26c8 100644 --- a/feathr_project/test/test_azure_feature_monitoring_e2e.py +++ b/feathr_project/test/test_azure_feature_monitoring_e2e.py @@ -3,22 +3,20 @@ from feathr import MonitoringSettings from feathr import MonitoringSqlSink -from test_fixture import (basic_test_setup, get_online_test_table_name) +from test_fixture import basic_test_setup, get_online_test_table_name from test_utils.constants import Constants def test_feature_monitoring(): monitor_sink_table = get_online_test_table_name("nycTaxiCITableMonitoring") - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) monitor_sink = MonitoringSqlSink(table_name=monitor_sink_table) - settings = MonitoringSettings("monitoringSetting", - sinks=[monitor_sink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"]) + settings = MonitoringSettings( + "monitoringSetting", sinks=[monitor_sink], feature_names=["f_location_avg_fare", "f_location_max_fare"] + ) client.monitor_features(settings) # just assume the job is successful without validating the actual result in Redis. Might need to consolidate # this part with the test_feathr_online_store test case diff --git a/feathr_project/test/test_azure_kafka_e2e.py b/feathr_project/test/test_azure_kafka_e2e.py index 1793c2b1c..407f11894 100644 --- a/feathr_project/test/test_azure_kafka_e2e.py +++ b/feathr_project/test/test_azure_kafka_e2e.py @@ -5,8 +5,10 @@ from test_fixture import kafka_test_setup from test_utils.constants import Constants + def test_feathr_kafka_configs(): - schema = AvroJsonSchema(schemaStr=""" + schema = AvroJsonSchema( + schemaStr=""" { "type": "record", "schema_name": "DriverTrips", @@ -14,13 +16,16 @@ def test_feathr_kafka_configs(): {"name": "driver_id", "type": "long"}, ] } - """) - stream_source = KafKaSource(name="kafkaStreamingSource", - kafkaConfig=KafkaConfig(brokers=["feathrazureci.servicebus.windows.net:9093"], - topics=["feathrcieventhub"], - schema=schema)) + """ + ) + stream_source = KafKaSource( + name="kafkaStreamingSource", + kafkaConfig=KafkaConfig( + brokers=["feathrazureci.servicebus.windows.net:9093"], topics=["feathrcieventhub"], schema=schema + ), + ) config_list = stream_source.to_feature_config().split() - config = ''.join([conf.replace('\\"', '').strip('\\n') for conf in config_list]) + config = "".join([conf.replace('\\"', "").strip("\\n") for conf in config_list]) expected_config = """ kafkaStreamingSource: { type: KAFKA @@ -43,9 +48,12 @@ def test_feathr_kafka_configs(): } """ assert config == "".join(expected_config.split()) - -@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') != "azure_synapse", - reason="skip for databricks, as it cannot stop streaming job automatically for now.") + + +@pytest.mark.skipif( + os.environ.get("SPARK_CONFIG__SPARK_CLUSTER") != "azure_synapse", + reason="skip for databricks, as it cannot stop streaming job automatically for now.", +) def test_feathr_kafa_streaming_features(): """ Test FeathrClient() materialize_features can ingest 
streaming feature correctly @@ -54,9 +62,8 @@ def test_feathr_kafa_streaming_features(): client = kafka_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) redisSink = RedisSink(table_name="kafkaSampleDemoFeature", streaming=True, streamingTimeoutMs=10000) - settings = MaterializationSettings(name="kafkaSampleDemo", - sinks=[redisSink], - feature_names=['f_modified_streaming_count'] - ) + settings = MaterializationSettings( + name="kafkaSampleDemo", sinks=[redisSink], feature_names=["f_modified_streaming_count"] + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) diff --git a/feathr_project/test/test_azure_snowflake_e2e.py b/feathr_project/test/test_azure_snowflake_e2e.py index 6811c5d88..43f6a18fb 100644 --- a/feathr_project/test/test_azure_snowflake_e2e.py +++ b/feathr_project/test/test_azure_snowflake_e2e.py @@ -3,16 +3,17 @@ from datetime import datetime, timedelta from pathlib import Path -from feathr import (BackfillTime, MaterializationSettings) +from feathr import BackfillTime, MaterializationSettings from feathr import FeatureQuery from feathr import ObservationSettings from feathr import RedisSink from feathr import TypedKey from feathr import ValueType from feathr.utils.job_utils import get_result_df -from test_fixture import (snowflake_test_setup, get_online_test_table_name) +from test_fixture import snowflake_test_setup, get_online_test_table_name from test_utils.constants import Constants + @pytest.mark.skip(reason="All snowflake tests are skipped for now due to budget restriction.") def test_feathr_online_store_agg_features(): """ @@ -25,33 +26,36 @@ def test_feathr_online_store_agg_features(): online_test_table = get_online_test_table_name("snowflakeSampleDemoFeature") backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) - redisSink = RedisSink(table_name= online_test_table) - settings = MaterializationSettings(name="snowflakeSampleDemoFeature", - sinks=[redisSink], - feature_names=['f_snowflake_call_center_division_name', - 'f_snowflake_call_center_zipcode'], - backfill_time=backfill_time) + redisSink = RedisSink(table_name=online_test_table) + settings = MaterializationSettings( + name="snowflakeSampleDemoFeature", + sinks=[redisSink], + feature_names=["f_snowflake_call_center_division_name", "f_snowflake_call_center_zipcode"], + backfill_time=backfill_time, + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) # just assume the job is successful without validating the actual result in Redis. 
Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, '1', - ['f_snowflake_call_center_division_name', 'f_snowflake_call_center_zipcode']) + res = client.get_online_features( + online_test_table, "1", ["f_snowflake_call_center_division_name", "f_snowflake_call_center_zipcode"] + ) assert len(res) == 2 assert res[0] != None assert res[1] != None - res = client.multi_get_online_features(online_test_table, - ['1', '2'], - ['f_snowflake_call_center_division_name', 'f_snowflake_call_center_zipcode']) - assert res['1'][0] != None - assert res['1'][1] != None - assert res['2'][0] != None - assert res['2'][1] != None + res = client.multi_get_online_features( + online_test_table, ["1", "2"], ["f_snowflake_call_center_division_name", "f_snowflake_call_center_zipcode"] + ) + assert res["1"][0] != None + assert res["1"][1] != None + assert res["2"][0] != None + assert res["2"][1] != None client._clean_test_data(online_test_table) + @pytest.mark.skip(reason="All snowflake tests are skipped for now due to budget restriction.") def test_feathr_get_offline_features(): """ @@ -59,31 +63,42 @@ def test_feathr_get_offline_features(): """ test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" - client = snowflake_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - call_sk_id = TypedKey(key_column="CC_CALL_CENTER_SK", - key_column_type=ValueType.INT32, - description="call center sk", - full_name="snowflake.CC_CALL_CENTER_SK") + call_sk_id = TypedKey( + key_column="CC_CALL_CENTER_SK", + key_column_type=ValueType.INT32, + description="call center sk", + full_name="snowflake.CC_CALL_CENTER_SK", + ) feature_query = FeatureQuery( - feature_list=['f_snowflake_call_center_division_name', 'f_snowflake_call_center_zipcode'], - key=call_sk_id) + feature_list=["f_snowflake_call_center_division_name", "f_snowflake_call_center_zipcode"], key=call_sk_id + ) - observation_path = client.get_snowflake_path(database="SNOWFLAKE_SAMPLE_DATA",schema="TPCDS_SF10TCL",dbtable="CALL_CENTER") - settings = ObservationSettings( - observation_path=observation_path) + observation_path = client.get_snowflake_path( + database="SNOWFLAKE_SAMPLE_DATA", schema="TPCDS_SF10TCL", dbtable="CALL_CENTER" + ) + settings = ObservationSettings(observation_path=observation_path) now = datetime.now() - # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob_snowflake','_', str(now.minute), '_', str(now.second), ".avro"]) + # set output folder based on different runtime + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob_snowflake", "_", str(now.minute), "_", str(now.second), ".avro"] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/snowflake_output','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/snowflake_output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can 
successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -92,6 +107,7 @@ def test_feathr_get_offline_features(): # just assume there are results. assert res.shape[0] > 1 + @pytest.mark.skip(reason="All snowflake tests are skipped for now due to budget restriction.") def test_client_get_snowflake_observation_path(): """ @@ -99,7 +115,6 @@ def test_client_get_snowflake_observation_path(): """ test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" - client = snowflake_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) snowflake_path_actual = client.get_snowflake_path(database="DATABASE", schema="SCHEMA", dbtable="TABLE") snowflake_path_expected = "snowflake://snowflake_account/?sfDatabase=DATABASE&sfSchema=SCHEMA&dbtable=TABLE" diff --git a/feathr_project/test/test_azure_spark_e2e.py b/feathr_project/test/test_azure_spark_e2e.py index 322b186f6..16e73af16 100644 --- a/feathr_project/test/test_azure_spark_e2e.py +++ b/feathr_project/test/test_azure_spark_e2e.py @@ -1,9 +1,20 @@ import os from datetime import datetime, timedelta from pathlib import Path -from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32, STRING, - DerivedFeature, Feature, FeatureAnchor, HdfsSource, - TypedKey, ValueType, WindowAggTransformation) +from feathr import ( + BOOLEAN, + FLOAT, + INPUT_CONTEXT, + INT32, + STRING, + DerivedFeature, + Feature, + FeatureAnchor, + HdfsSource, + TypedKey, + ValueType, + WindowAggTransformation, +) from feathr import FeathrClient from feathr.definition.sink import CosmosDbSink, ElasticSearchSink from feathr.definition.source import HdfsSource @@ -11,44 +22,55 @@ import pytest from click.testing import CliRunner -from feathr import (BackfillTime, MaterializationSettings) +from feathr import BackfillTime, MaterializationSettings from feathr import FeathrClient from feathr import FeatureQuery from feathr import ObservationSettings -from feathr import RedisSink, HdfsSink, JdbcSink,AerospikeSink +from feathr import RedisSink, HdfsSink, JdbcSink, AerospikeSink from feathr import TypedKey from feathr import ValueType from feathr.utils.job_utils import get_result_df from feathrcli.cli import init -from test_fixture import (basic_test_setup, get_online_test_table_name, composite_keys_test_setup) +from test_fixture import basic_test_setup, get_online_test_table_name, composite_keys_test_setup from test_utils.constants import Constants + # make sure you have run the upload feature script before running these tests # the feature configs are from feathr_project/data/feathr_user_workspace def test_feathr_materialize_to_offline(): """ Test FeathrClient() HdfsSink. 
""" - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) offline_sink = HdfsSink(output_path=output_path) - settings = MaterializationSettings("nycTaxiTable", - sinks=[offline_sink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[offline_sink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -58,72 +80,75 @@ def test_feathr_materialize_to_offline(): res_df = get_result_df(client, data_format="avro", res_url=output_path + "/df0/daily/2020/05/20") assert res_df.shape[0] > 0 + def test_feathr_online_store_agg_features(): """ Test FeathrClient() get_online_features and batch_get can get data correctly. """ online_test_table = get_online_test_table_name("nycTaxiCITableSparkE2E") - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) client: FeathrClient = composite_keys_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings("nycTaxiTable", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[redisSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings) # just assume the job is successful without validating the actual result in Redis. 
Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, ["81", "254"], [ - 'f_location_avg_fare', 'f_location_max_fare']) + res = client.get_online_features(online_test_table, ["81", "254"], ["f_location_avg_fare", "f_location_max_fare"]) # just assume there are values. We don't hard code the values for now for testing # the correctness of the feature generation should be guaranteed by feathr runtime. # ID 239 and 265 are available in the `DOLocationID` column in this file: # https://s3.amazonaws.com/nyc-tlc/trip+data/green_tripdata_2020-04.csv # View more details on this dataset: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page assert res != None - res = client.multi_get_online_features(online_test_table, - [["81","254"], ["25","42"]], - ['f_location_avg_fare', 'f_location_max_fare']) - assert res['81#254'] != None - assert res['25#42'] != None - + res = client.multi_get_online_features( + online_test_table, [["81", "254"], ["25", "42"]], ["f_location_avg_fare", "f_location_max_fare"] + ) + assert res["81#254"] != None + assert res["25#42"] != None + client._clean_test_data(online_test_table) + @pytest.mark.skip(reason="Add back when complex types are supported in python API") def test_feathr_online_store_non_agg_features(): """ Test FeathrClient() online_get_features and batch_get can get data correctly. """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - online_test_table = get_online_test_table_name('nycTaxiCITableNonAggFeature') - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + online_test_table = get_online_test_table_name("nycTaxiCITableNonAggFeature") + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings("nycTaxiTable", - sinks=[redisSink], - feature_names=["f_gen_trip_distance", "f_gen_is_long_trip_distance", "f1", "f2", "f3", "f4", "f5", "f6"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[redisSink], + feature_names=["f_gen_trip_distance", "f_gen_is_long_trip_distance", "f1", "f2", "f3", "f4", "f5", "f6"], + backfill_time=backfill_time, + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) # just assume the job is successful without validating the actual result in Redis. Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, '111', ['f_gen_trip_distance', 'f_gen_is_long_trip_distance', - 'f1', 'f2', 'f3', 'f4', 'f5', 'f6']) + res = client.get_online_features( + online_test_table, + "111", + ["f_gen_trip_distance", "f_gen_is_long_trip_distance", "f1", "f2", "f3", "f4", "f5", "f6"], + ) # just assume there are values. We don't hard code the values for now for testing # the correctness of the feature generation should be guaranteed by feathr runtime. 
# ID 239 and 265 are available in the `DOLocationID` column in this file: @@ -135,31 +160,32 @@ def test_feathr_online_store_non_agg_features(): assert res[1] != None # assert constant features _validate_constant_feature(res) - res = client.multi_get_online_features(online_test_table, - ['239', '265'], - ['f_gen_trip_distance', 'f_gen_is_long_trip_distance', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6']) - _validate_constant_feature(res['239']) - assert res['239'][0] != None - assert res['239'][1] != None - _validate_constant_feature(res['265']) - assert res['265'][0] != None - assert res['265'][1] != None - + res = client.multi_get_online_features( + online_test_table, + ["239", "265"], + ["f_gen_trip_distance", "f_gen_is_long_trip_distance", "f1", "f2", "f3", "f4", "f5", "f6"], + ) + _validate_constant_feature(res["239"]) + assert res["239"][0] != None + assert res["239"][1] != None + _validate_constant_feature(res["265"]) + assert res["265"][0] != None + assert res["265"][1] != None + client._clean_test_data(online_test_table) def _validate_constant_feature(feature): assert feature[2] == [10.0, 20.0, 30.0] - assert feature[3] == ['a', 'b', 'c'] - assert feature[4] == ([1, 2, 3], ['10', '20', '30']) + assert feature[3] == ["a", "b", "c"] + assert feature[4] == ([1, 2, 3], ["10", "20", "30"]) assert feature[5] == ([1, 2, 3], [True, False, True]) assert feature[6] == ([1, 2, 3], [1.0, 2.0, 3.0]) assert feature[7] == ([1, 2, 3], [1, 2, 3]) def test_dbfs_path(): - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) if client.spark_runtime.casefold() == "databricks": # expect this raise an error since the result path is not in dbfs: format @@ -174,32 +200,39 @@ def test_feathr_get_offline_features(): runner = CliRunner() with runner.isolated_filesystem(): runner.invoke(init, []) - client = basic_test_setup( - "./feathr_user_workspace/feathr_config.yaml") + client = basic_test_setup("./feathr_user_workspace/feathr_config.yaml") - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare", "f_trip_time_rounded"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare", "f_trip_time_rounded"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - - - 
client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -208,131 +241,142 @@ def test_feathr_get_offline_features(): res_df = get_result_df(client) assert res_df.shape[0] > 0 + def test_feathr_get_offline_features_to_sql(): """ Test get_offline_features() can save data to SQL. """ # runner.invoke(init, []) - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # Set DB user and password before submitting job # os.environ[f"sql1_USER"] = "some_user@feathrtestsql4" # os.environ[f"sql1_PASSWORD"] = "some_password" - output_path = JdbcSink(name="sql1", - url="jdbc:sqlserver://feathrazureci.database.windows.net:1433;database=feathrci;encrypt=true;", - dbtable=f'feathr_ci_materialization_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}', - auth="USERPASS") + output_path = JdbcSink( + name="sql1", + url="jdbc:sqlserver://feathrazureci.database.windows.net:1433;database=feathrci;encrypt=true;", + dbtable=f'feathr_ci_materialization_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}', + auth="USERPASS", + ) - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) + @pytest.mark.skip(reason="Marked as skipped as we need to setup token and enable SQL AAD login for this test") def test_feathr_get_offline_features_to_sql_with_token(): """ Test get_offline_features() can save data to SQL. 
""" # runner.invoke(init, []) - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # Set DB token before submitting job # os.environ[f"SQL1_TOKEN"] = "some_token" os.environ["SQL1_TOKEN"] = client.credential.get_token("https://management.azure.com/.default").token - output_path = JdbcSink(name="sql1", - url="jdbc:sqlserver://feathrazureci.database.windows.net:1433;database=feathrci;encrypt=true;", - dbtable=f'feathr_ci_sql_token_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}', - auth="TOKEN") + output_path = JdbcSink( + name="sql1", + url="jdbc:sqlserver://feathrazureci.database.windows.net:1433;database=feathrci;encrypt=true;", + dbtable=f'feathr_ci_sql_token_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}', + auth="TOKEN", + ) - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) -@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') == "databricks", - reason="Due to package conflicts, the CosmosDB test doesn't work on databricks clusters, refer to https://github.com/feathr-ai/feathr/blob/main/docs/how-to-guides/jdbc-cosmos-notes.md#using-cosmosdb-as-the-online-store for more details") + +@pytest.mark.skipif( + os.environ.get("SPARK_CONFIG__SPARK_CLUSTER") == "databricks", + reason="Due to package conflicts, the CosmosDB test doesn't work on databricks clusters, refer to https://github.com/feathr-ai/feathr/blob/main/docs/how-to-guides/jdbc-cosmos-notes.md#using-cosmosdb-as-the-online-store for more details", +) def test_feathr_materialize_to_cosmosdb(): """ Test FeathrClient() CosmosDbSink. 
""" - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) now = datetime.now() - container = ''.join(['feathrazure_cijob_materialize_','_', str(now.minute), '_', str(now.second), ""]) - sink = CosmosDbSink(name='cosmos1', endpoint='https://feathrazuretest3-cosmosdb.documents.azure.com:443/', database='feathr', container=container) - settings = MaterializationSettings("nycTaxiTable", - sinks=[sink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + container = "".join(["feathrazure_cijob_materialize_", "_", str(now.minute), "_", str(now.second), ""]) + sink = CosmosDbSink( + name="cosmos1", + endpoint="https://feathrazuretest3-cosmosdb.documents.azure.com:443/", + database="feathr", + container=container, + ) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[sink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) + @pytest.mark.skip(reason="Marked as skipped as we need to setup resources for this test") def test_feathr_materialize_to_es(): """ Test FeathrClient() CosmosDbSink. """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) now = datetime.now() - index = ''.join(['feathrazure_cijob_materialize_','_', str(now.minute), '_', str(now.second), ""]) - sink = ElasticSearchSink(name='es1', host='somenode:9200', index=index, ssl=True, auth=True) - settings = MaterializationSettings("nycTaxiTable", - sinks=[sink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + index = "".join(["feathrazure_cijob_materialize_", "_", str(now.minute), "_", str(now.second), ""]) + sink = ElasticSearchSink(name="es1", host="somenode:9200", index=index, ssl=True, auth=True) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[sink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings) # Set user and password before submitting job # os.environ[f"es1_USER"] = "some_user" @@ -340,103 +384,115 @@ def test_feathr_materialize_to_es(): # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) + @pytest.mark.skip(reason="Marked as skipped as we need to setup resources for this test") def test_feathr_materialize_to_aerospike(): """ Test FeathrClient() CosmosDbSink. 
""" - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) now = datetime.now() # set workspace folder by time; make sure we don't have write conflict if there are many CI tests running - os.environ['SPARK_CONFIG__DATABRICKS__WORK_DIR'] = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), '_', str(now.microsecond)]) - os.environ['SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR'] = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci','_', str(now.minute), '_', str(now.second) ,'_', str(now.microsecond)]) + os.environ["SPARK_CONFIG__DATABRICKS__WORK_DIR"] = "".join( + ["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), "_", str(now.microsecond)] + ) + os.environ["SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR"] = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci", + "_", + str(now.minute), + "_", + str(now.second), + "_", + str(now.microsecond), + ] + ) client = FeathrClient(config_path="feathr_config.yaml") - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) + + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) + + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="avgfare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation( + 
agg_expr="cast_float(fare_amount)", + agg_func="AVG", + window="90d", + ), + ), + Feature( + name="maxfare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), + ] - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) - - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") - - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="avgfare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d", - )), - Feature(name="maxfare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) - - client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ - f_trip_time_distance, f_trip_time_rounded]) - - - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) + + client.build_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) + + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) now = datetime.now() os.environ[f"aerospike_USER"] = "feathruser" os.environ[f"aerospike_PASSWORD"] = "feathr" - as_sink = AerospikeSink(name="aerospike",seedhost="20.57.186.153", port=3000, namespace="test", setname="test") - settings = MaterializationSettings("nycTaxiTable", - sinks=[as_sink], - feature_names=[ - "avgfare", "maxfare"], - backfill_time=backfill_time) + as_sink = AerospikeSink(name="aerospike", seedhost="20.57.186.153", port=3000, namespace="test", setname="test") + settings = MaterializationSettings( + "nycTaxiTable", sinks=[as_sink], feature_names=["avgfare", "maxfare"], backfill_time=backfill_time + ) client.materialize_features(settings) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) + if __name__ == "__main__": test_feathr_materialize_to_aerospike() test_feathr_get_offline_features_to_sql() test_feathr_materialize_to_cosmosdb() - diff --git a/feathr_project/test/test_azure_spark_maven_e2e.py b/feathr_project/test/test_azure_spark_maven_e2e.py index a2f214020..a8564141e 100644 --- a/feathr_project/test/test_azure_spark_maven_e2e.py +++ b/feathr_project/test/test_azure_spark_maven_e2e.py @@ -2,74 +2,80 @@ from datetime import datetime, timedelta from pathlib import Path -from feathr import (BackfillTime, MaterializationSettings) +from feathr import BackfillTime, MaterializationSettings + # from feathr import * from feathr.client import FeathrClient from feathr.definition.dtype import ValueType from 
feathr.definition.query_feature_list import FeatureQuery from feathr.definition.settings import ObservationSettings from feathr.definition.typed_key import TypedKey -from test_fixture import (basic_test_setup, get_online_test_table_name) +from test_fixture import basic_test_setup, get_online_test_table_name from test_utils.constants import Constants + def test_feathr_online_store_agg_features(): """ Test FeathrClient() get_online_features and batch_get can get data correctly. """ online_test_table = get_online_test_table_name("nycTaxiCITableMaven") - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) # The `feathr_runtime_location` was commented out in this config file, so feathr should use # Maven package as the dependency and `noop.jar` as the main file client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config_maven.yaml")) - - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) return - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings("TestJobName", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "TestJobName", + sinks=[redisSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) 
client.materialize_features(settings) # just assume the job is successful without validating the actual result in Redis. Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, '265', [ - 'f_location_avg_fare', 'f_location_max_fare']) + res = client.get_online_features(online_test_table, "265", ["f_location_avg_fare", "f_location_max_fare"]) # just assume there are values. We don't hard code the values for now for testing # the correctness of the feature generation should be guaranteed by feathr runtime. # ID 239 and 265 are available in the `DOLocationID` column in this file: @@ -78,10 +84,10 @@ def test_feathr_online_store_agg_features(): assert len(res) == 2 assert res[0] != None assert res[1] != None - res = client.multi_get_online_features(online_test_table, - ['239', '265'], - ['f_location_avg_fare', 'f_location_max_fare']) - assert res['239'][0] != None - assert res['239'][1] != None - assert res['265'][0] != None - assert res['265'][1] != None + res = client.multi_get_online_features( + online_test_table, ["239", "265"], ["f_location_avg_fare", "f_location_max_fare"] + ) + assert res["239"][0] != None + assert res["239"][1] != None + assert res["265"][0] != None + assert res["265"][1] != None diff --git a/feathr_project/test/test_cli.py b/feathr_project/test/test_cli.py index 6fc3cf986..a27eb666e 100644 --- a/feathr_project/test/test_cli.py +++ b/feathr_project/test/test_cli.py @@ -10,23 +10,22 @@ def test_workspace_creation(): """ runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke(init, []) assert result.exit_code == 0 assert os.path.isdir("./feathr_user_workspace") - total_yaml_files = glob.glob('./feathr_user_workspace/*.yaml', recursive=True) + total_yaml_files = glob.glob("./feathr_user_workspace/*.yaml", recursive=True) # we should have exact 1 yaml file assert len(total_yaml_files) == 1 # result = runner.invoke(init, []) - test_folder_name = 'test_folder' - result = runner.invoke(init, ['--name', test_folder_name]) + test_folder_name = "test_folder" + result = runner.invoke(init, ["--name", test_folder_name]) assert result.exit_code == 0 - total_yaml_files = glob.glob(os.path.join(test_folder_name, '*.yaml'), recursive=True) + total_yaml_files = glob.glob(os.path.join(test_folder_name, "*.yaml"), recursive=True) # we should have exact 1 yaml file assert len(total_yaml_files) == 1 @@ -34,5 +33,5 @@ def test_workspace_creation(): assert result.exit_code == 2 # use output for test for now - expected_out = f'Feathr workspace ({test_folder_name}) already exist. Please use a new folder name.\n' + expected_out = f"Feathr workspace ({test_folder_name}) already exist. 
Please use a new folder name.\n" assert expected_out in result.output diff --git a/feathr_project/test/test_config_loading.py b/feathr_project/test/test_config_loading.py index c062af1c2..9346b6f1f 100644 --- a/feathr_project/test/test_config_loading.py +++ b/feathr_project/test/test_config_loading.py @@ -10,22 +10,21 @@ def test_configuration_loading(): """ runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke(init, []) assert result.exit_code == 0 - assert os.path.isdir('./feathr_user_workspace') + assert os.path.isdir("./feathr_user_workspace") - client = FeathrClient(config_path='./feathr_user_workspace/feathr_config.yaml') + client = FeathrClient(config_path="./feathr_user_workspace/feathr_config.yaml") # test the loading is correct even if we are not in that folder assert client._FEATHR_JOB_JAR_PATH is not None - SPARK_RESULT_OUTPUT_PARTS = '4' + SPARK_RESULT_OUTPUT_PARTS = "4" # Use a less impactful config to test, as this config might be impactful for all the tests (since it's setting the envs) - os.environ['SPARK_CONFIG__SPARK_RESULT_OUTPUT_PARTS'] = SPARK_RESULT_OUTPUT_PARTS + os.environ["SPARK_CONFIG__SPARK_RESULT_OUTPUT_PARTS"] = SPARK_RESULT_OUTPUT_PARTS # this should not be error out as we will just give users prompt, though the config is not really here - client = FeathrClient(config_path='./feathr_user_workspace/feathr_config.yaml') + client = FeathrClient(config_path="./feathr_user_workspace/feathr_config.yaml") assert client.output_num_parts == SPARK_RESULT_OUTPUT_PARTS diff --git a/feathr_project/test/test_derived_features.py b/feathr_project/test/test_derived_features.py index ee10cd285..5708dfd0b 100644 --- a/feathr_project/test/test_derived_features.py +++ b/feathr_project/test/test_derived_features.py @@ -5,20 +5,29 @@ from feathr import TypedKey import pytest + def assert_config_equals(one, another): - assert one.translate(str.maketrans('', '', ' \n\t\r')) == another.translate(str.maketrans('', '', ' \n\t\r')) + assert one.translate(str.maketrans("", "", " \n\t\r")) == another.translate(str.maketrans("", "", " \n\t\r")) + def test_single_key_derived_feature_to_config(): """Single key derived feature config generation should work""" - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="An user identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) user_embedding = Feature(name="user_embedding", feature_type=FLOAT_VECTOR, key=user_key) # A derived feature - derived_feature = DerivedFeature(name="user_embemdding_derived", - feature_type=FLOAT, - key=user_key, - input_features=user_embedding, - transform="if_else(user_embedding, user_embedding, [])") + derived_feature = DerivedFeature( + name="user_embemdding_derived", + feature_type=FLOAT, + key=user_key, + input_features=user_embedding, + transform="if_else(user_embedding, user_embedding, [])", + ) derived_feature_config = """ user_embemdding_derived: { @@ -36,20 +45,33 @@ def test_single_key_derived_feature_to_config(): }""" assert_config_equals(derived_feature.to_feature_config(), derived_feature_config) + def test_multikey_derived_feature_to_config(): """Multikey derived feature config generation should work""" - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="An user identifier") - item_key = TypedKey(full_name="mockdata.item", key_column="item_id", 
key_column_type=ValueType.INT32, description="An item identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) + item_key = TypedKey( + full_name="mockdata.item", + key_column="item_id", + key_column_type=ValueType.INT32, + description="An item identifier", + ) user_embedding = Feature(name="user_embedding", feature_type=FLOAT_VECTOR, key=user_key) item_embedding = Feature(name="item_embedding", feature_type=FLOAT_VECTOR, key=item_key) # A derived feature - user_item_similarity = DerivedFeature(name="user_item_similarity", - feature_type=FLOAT, - key=[user_key, item_key], - input_features=[user_embedding, item_embedding], - transform="similarity(user_embedding, item_embedding)") + user_item_similarity = DerivedFeature( + name="user_item_similarity", + feature_type=FLOAT, + key=[user_key, item_key], + input_features=[user_embedding, item_embedding], + transform="similarity(user_embedding, item_embedding)", + ) derived_feature_config = """ user_item_similarity: { @@ -72,14 +94,23 @@ def test_multikey_derived_feature_to_config(): def test_derived_feature_to_config_with_alias(): # More complicated use case, viewer viewee aliasged user key # References the same key feature with different alias - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="An user identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) user_embedding = Feature(name="user_embedding", key=user_key, feature_type=FLOAT_VECTOR) - viewer_viewee_distance = DerivedFeature(name="viewer_viewee_distance", - key=[user_key.as_key("viewer"), user_key.as_key("viewee")], - feature_type=FLOAT, - input_features=[user_embedding.with_key("viewer").as_feature("viewer_embedding"), - user_embedding.with_key("viewee").as_feature("viewee_embedding")], - transform="distance(viewer_embedding, viewee_embedding)") + viewer_viewee_distance = DerivedFeature( + name="viewer_viewee_distance", + key=[user_key.as_key("viewer"), user_key.as_key("viewee")], + feature_type=FLOAT, + input_features=[ + user_embedding.with_key("viewer").as_feature("viewer_embedding"), + user_embedding.with_key("viewee").as_feature("viewee_embedding"), + ], + transform="distance(viewer_embedding, viewee_embedding)", + ) expected_feature_config = """ viewer_viewee_distance: { @@ -102,24 +133,35 @@ def test_derived_feature_to_config_with_alias(): def test_multi_key_derived_feature_to_config_with_alias(): # References the same relation feature key alias with different alias # Note that in this case, it is possible that distance(a, b) != distance(b,a) - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="An user identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) user_embedding = Feature(name="user_embedding", key=user_key, feature_type=FLOAT_VECTOR) - viewer_viewee_distance = DerivedFeature(name="viewer_viewee_distance", - key=[user_key.as_key("viewer"), user_key.as_key("viewee")], - feature_type=FLOAT, - input_features=[user_embedding.with_key("viewer").as_feature("viewer_embedding"), - user_embedding.with_key("viewee").as_feature("viewee_embedding")], - transform="distance(viewer_embedding, viewee_embedding)") - - 
viewee_viewer_combined = DerivedFeature(name = "viewee_viewer_combined_distance", - key=[user_key.as_key("viewer"), user_key.as_key("viewee")], - feature_type=FLOAT, - input_features=[viewer_viewee_distance.with_key(["viewer", "viewee"]) - .as_feature("viewer_viewee_distance"), - viewer_viewee_distance.with_key(["viewee", "viewer"]) - .as_feature("viewee_viewer_distance"),], - transform=ExpressionTransformation("viewer_viewee_distance + viewee_viewer_distance")) + viewer_viewee_distance = DerivedFeature( + name="viewer_viewee_distance", + key=[user_key.as_key("viewer"), user_key.as_key("viewee")], + feature_type=FLOAT, + input_features=[ + user_embedding.with_key("viewer").as_feature("viewer_embedding"), + user_embedding.with_key("viewee").as_feature("viewee_embedding"), + ], + transform="distance(viewer_embedding, viewee_embedding)", + ) + + viewee_viewer_combined = DerivedFeature( + name="viewee_viewer_combined_distance", + key=[user_key.as_key("viewer"), user_key.as_key("viewee")], + feature_type=FLOAT, + input_features=[ + viewer_viewee_distance.with_key(["viewer", "viewee"]).as_feature("viewer_viewee_distance"), + viewer_viewee_distance.with_key(["viewee", "viewer"]).as_feature("viewee_viewer_distance"), + ], + transform=ExpressionTransformation("viewer_viewee_distance + viewee_viewer_distance"), + ) # Note that unlike key features, a relation feature does not need a feature anchor. expected_feature_config = """ @@ -139,19 +181,32 @@ def test_multi_key_derived_feature_to_config_with_alias(): }""" assert_config_equals(viewee_viewer_combined.to_feature_config(), expected_feature_config) + def test_derived_feature_on_multikey_anchored_feature_to_config(): """Multikey derived feature config generation should work""" - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="First part of an user identifier") - user_key2 = TypedKey(full_name="mockdata.user2", key_column="user_id2", key_column_type=ValueType.INT32, description="Second part of an user identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="First part of an user identifier", + ) + user_key2 = TypedKey( + full_name="mockdata.user2", + key_column="user_id2", + key_column_type=ValueType.INT32, + description="Second part of an user identifier", + ) user_embedding = Feature(name="user_embedding", feature_type=FLOAT_VECTOR, key=[user_key, user_key2]) # A derived feature - user_item_derived = DerivedFeature(name="user_item_similarity", - feature_type=FLOAT, - key=[user_key.as_key("viewer"), user_key2.as_key("viewee")], - input_features=user_embedding.with_key(["viewer", "viewee"]), - transform="if_else(user_embedding, user_embedding, [])") + user_item_derived = DerivedFeature( + name="user_item_similarity", + feature_type=FLOAT, + key=[user_key.as_key("viewer"), user_key2.as_key("viewee")], + input_features=user_embedding.with_key(["viewer", "viewee"]), + transform="if_else(user_embedding, user_embedding, [])", + ) derived_feature_config = """ user_item_similarity: { @@ -170,18 +225,25 @@ def test_derived_feature_on_multikey_anchored_feature_to_config(): assert_config_equals(user_item_derived.to_feature_config(), derived_feature_config) - def test_multi_key_derived_feature_to_config_with_wrong_alias(): # References the same relation feature key alias with wrong alias # Should throw exception - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, 
description="An user identifier") + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) user_embedding = Feature(name="user_embedding", key=user_key, feature_type=FLOAT_VECTOR) with pytest.raises(AssertionError): - viewer_viewee_distance = DerivedFeature(name="viewer_viewee_distance", - key=[user_key.as_key("non_exist_alias"), user_key.as_key("viewee")], - feature_type=FLOAT, - input_features=[user_embedding.with_key("viewer").as_feature("viewer_embedding"), - user_embedding.with_key("viewee").as_feature("viewee_embedding")], - transform="distance(viewer_embedding, viewee_embedding)") - + viewer_viewee_distance = DerivedFeature( + name="viewer_viewee_distance", + key=[user_key.as_key("non_exist_alias"), user_key.as_key("viewee")], + feature_type=FLOAT, + input_features=[ + user_embedding.with_key("viewer").as_feature("viewer_embedding"), + user_embedding.with_key("viewee").as_feature("viewee_embedding"), + ], + transform="distance(viewer_embedding, viewee_embedding)", + ) diff --git a/feathr_project/test/test_feature_anchor.py b/feathr_project/test/test_feature_anchor.py index d5e6701b8..6520a5fc1 100644 --- a/feathr_project/test/test_feature_anchor.py +++ b/feathr_project/test/test_feature_anchor.py @@ -6,20 +6,15 @@ from feathr import WindowAggTransformation from feathr import TypedKey + def test_request_feature_anchor_to_config(): features = [ Feature(name="trip_distance", feature_type=FLOAT), - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)") + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] - anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) + anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) expected_non_agg_feature_config = """ request_features: { source: PASSTHROUGH @@ -55,31 +50,36 @@ def test_request_feature_anchor_to_config(): } } """ - assert ''.join(anchor.to_feature_config().split()) == ''.join(expected_non_agg_feature_config.split()) + assert "".join(anchor.to_feature_config().split()) == "".join(expected_non_agg_feature_config.split()) def test_non_agg_feature_anchor_to_config(): - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) features = [ - Feature(name="f_loc_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30", key=location_id), - Feature(name="f_loc_day_of_week", - 
feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)", key=location_id) + Feature( + name="f_loc_is_long_trip_distance", + feature_type=BOOLEAN, + transform="cast_float(trip_distance)>30", + key=location_id, + ), + Feature( + name="f_loc_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)", key=location_id + ), ] - anchor = FeatureAnchor(name="nonAggFeatures", - source=batch_source, - features=features) + anchor = FeatureAnchor(name="nonAggFeatures", source=batch_source, features=features) expected_non_agg_feature_config = """ nonAggFeatures: { source: nycTaxiBatchSource @@ -106,36 +106,39 @@ def test_non_agg_feature_anchor_to_config(): } } """ - assert ''.join(anchor.to_feature_config().split()) == ''.join(expected_non_agg_feature_config.split()) + assert "".join(anchor.to_feature_config().split()) == "".join(expected_non_agg_feature_config.split()) def test_agg_anchor_to_config(): - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), + ] - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) expected_agg_feature_config = """ aggregationFeatures: { @@ -167,4 +170,4 @@ def test_agg_anchor_to_config(): } } """ - assert ''.join(agg_anchor.to_feature_config().split()) == ''.join(expected_agg_feature_config.split()) \ No newline at end of file + assert "".join(agg_anchor.to_feature_config().split()) == "".join(expected_agg_feature_config.split()) diff --git a/feathr_project/test/test_feature_materialization.py b/feathr_project/test/test_feature_materialization.py index ad5a9a02f..d489ce464 100644 --- a/feathr_project/test/test_feature_materialization.py +++ b/feathr_project/test/test_feature_materialization.py @@ -8,8 +8,14 @@ from pyspark.sql.functions import col from feathr import BOOLEAN, FLOAT, INT32, ValueType -from feathr import (BackfillTime, 
MaterializationSettings, FeatureQuery, - ObservationSettings, SparkExecutionConfiguration, ConflictsAutoCorrection) +from feathr import ( + BackfillTime, + MaterializationSettings, + FeatureQuery, + ObservationSettings, + SparkExecutionConfiguration, + ConflictsAutoCorrection, +) from feathr import Feature from feathr import FeatureAnchor from feathr import INPUT_CONTEXT, HdfsSource @@ -22,13 +28,16 @@ from logging import raiseExceptions import pytest + def test_feature_materialization_config(): - backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5,20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name="nycTaxiDemoFeature") - settings = MaterializationSettings("nycTaxiTable", - sinks=[redisSink], - feature_names=["f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[redisSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) config = _to_materialization_config(settings) expected_config = """ operational: { @@ -48,15 +57,20 @@ def test_feature_materialization_config(): } features: [f_location_avg_fare, f_location_max_fare] """ - assert ''.join(config.split()) == ''.join(expected_config.split()) + assert "".join(config.split()) == "".join(expected_config.split()) + def test_feature_materialization_offline_config(): - backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5,20), step=timedelta(days=1)) - offlineSink = HdfsSink(output_path="abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output/hdfs_test.avro") - settings = MaterializationSettings("nycTaxiTable", - sinks=[offlineSink], - feature_names=["f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + offlineSink = HdfsSink( + output_path="abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output/hdfs_test.avro" + ) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[offlineSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) config = _to_materialization_config(settings) expected_config = """ operational: { @@ -80,17 +94,15 @@ def test_feature_materialization_offline_config(): } features: [f_location_avg_fare, f_location_max_fare] """ - assert ''.join(config.split()) == ''.join(expected_config.split()) + assert "".join(config.split()) == "".join(expected_config.split()) + def test_feature_materialization_aerospike_sink_config(): - as_sink = AerospikeSink(name="aerospike",seedhost="20.57.186.153", port=3000, namespace="test", setname="test") - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) - settings = MaterializationSettings("nycTaxiTable", - sinks=[as_sink], - feature_names=[ - "avgfare", "maxfare"], - backfill_time=backfill_time) + as_sink = AerospikeSink(name="aerospike", seedhost="20.57.186.153", port=3000, namespace="test", setname="test") + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + settings = MaterializationSettings( + "nycTaxiTable", sinks=[as_sink], feature_names=["avgfare", "maxfare"], backfill_time=backfill_time + ) 
os.environ[f"aerospike_USER"] = "feathruser" os.environ[f"aerospike_PASSWORD"] = "feathrpwd" expected_config = """ @@ -120,22 +132,22 @@ def test_feature_materialization_aerospike_sink_config(): features: [avgfare, maxfare] """ config = _to_materialization_config(settings) - assert ''.join(config.split()) == ''.join(expected_config.split()) - + assert "".join(config.split()) == "".join(expected_config.split()) + + def test_feature_materialization_daily_schedule(): """Test back fill cutoff time for a daily range""" backfill_time = BackfillTime(start=datetime(2022, 3, 1), end=datetime(2022, 3, 5), step=timedelta(days=1)) settings = MaterializationSettings("", [], [], backfill_time) expected = [datetime(2022, 3, day) for day in range(1, 6)] assert settings.get_backfill_cutoff_time() == expected - - + def test_feature_materialization_hourly_schedule(): """Test back fill cutoff time for a hourly range""" backfill_time = BackfillTime(start=datetime(2022, 3, 1, 1), end=datetime(2022, 3, 1, 5), step=timedelta(hours=1)) settings = MaterializationSettings("", [], [], backfill_time) - expected = [datetime(2022,3, 1, hour) for hour in range(1, 6)] + expected = [datetime(2022, 3, 1, hour) for hour in range(1, 6)] assert settings.get_backfill_cutoff_time() == expected @@ -148,6 +160,7 @@ def test_feature_materialization_now_schedule(): assert expected.month == date.month assert expected.day == date.day + def test_build_feature_verbose(): """ Test verbose for pretty printing features @@ -159,21 +172,16 @@ def test_build_feature_verbose(): # An anchor feature features = [ Feature(name="trip_distance", feature_type=FLOAT), - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)") + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] - anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) + anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) # Check pretty print client.build_features(anchor_list=[anchor], verbose=True) + def test_get_offline_features_verbose(): """ Test verbose for pretty printing feature query @@ -183,97 +191,126 @@ def test_get_offline_features_verbose(): client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32) + location_id = TypedKey(key_column="DOLocationID", key_column_type=ValueType.INT32) feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss" + timestamp_format="yyyy-MM-dd HH:mm:ss", ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".parquet"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".parquet"]) else: - output_path = 
''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".parquet"]) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".parquet", + ] + ) # Check pretty print client.get_offline_features( - observation_settings=settings, - feature_query=feature_query, - output_path=output_path, - execution_configurations=SparkExecutionConfiguration({"spark.feathr.inputFormat": "parquet", "spark.feathr.outputFormat": "parquet"}), - verbose=True - ) - + observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + execution_configurations=SparkExecutionConfiguration( + {"spark.feathr.inputFormat": "parquet", "spark.feathr.outputFormat": "parquet"} + ), + verbose=True, + ) + + def test_get_offline_features_auto_correct_dataset(): test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = conflicts_auto_correction_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - + now = datetime.now() - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - feature_query = FeatureQuery( - feature_list=["tip_amount", "total_amount"], key=location_id) + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + feature_query = FeatureQuery(feature_list=["tip_amount", "total_amount"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss", - conflicts_auto_correction=ConflictsAutoCorrection(rename_features=False, suffix="test")) - # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + conflicts_auto_correction=ConflictsAutoCorrection(rename_features=False, suffix="test"), + ) + # set output folder based on different runtime + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path - ) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) client.wait_job_to_finish(timeout_sec=500) - res_df = get_result_df(client, data_format="avro", res_url = output_path) + res_df = get_result_df(client, data_format="avro", res_url=output_path) assert res_df.shape[0] > 0 - + + def test_get_offline_features_auto_correct_features(): test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = conflicts_auto_correction_setup(os.path.join(test_workspace_dir, 
"feathr_config.yaml")) - + now = datetime.now() - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - feature_query = FeatureQuery( - feature_list=["tip_amount", "total_amount"], key=location_id) + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + feature_query = FeatureQuery(feature_list=["tip_amount", "total_amount"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss", - conflicts_auto_correction=ConflictsAutoCorrection(rename_features=True, suffix="test")) + conflicts_auto_correction=ConflictsAutoCorrection(rename_features=True, suffix="test"), + ) # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path - ) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) client.wait_job_to_finish(timeout_sec=500) - res_df = get_result_df(client, data_format="avro", res_url = output_path) + res_df = get_result_df(client, data_format="avro", res_url=output_path) assert res_df.shape[0] > 0 + def test_materialize_features_verbose(): online_test_table = get_online_test_table_name("nycTaxiCITableMaterializeVerbose") test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" @@ -281,19 +318,22 @@ def test_materialize_features_verbose(): client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings("nycTaxiTable", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[redisSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings, verbose=True) client._clean_test_data(online_test_table) + def add_new_fare_amount(df: DataFrame) -> DataFrame: df = df.withColumn("fare_amount_new", col("fare_amount") + 8000000) return df + def test_delete_feature_from_redis(): """ Test FeathrClient() delete_feature_from_redis to remove feature from Redis. 
@@ -303,53 +343,56 @@ def test_delete_feature_from_redis(): client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - batch_source = HdfsSource(name="nycTaxiBatchSource_add_new_fare_amount", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_fare_amount, - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource_add_new_fare_amount", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_fare_amount, + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - pickup_time_as_id = TypedKey(key_column="lpep_pickup_datetime", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + pickup_time_as_id = TypedKey( + key_column="lpep_pickup_datetime", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) features = [ - Feature(name="f_is_long_trip_distance", - key=pickup_time_as_id, - feature_type=FLOAT, - transform="fare_amount_new"), - Feature(name="f_day_of_week", - key=pickup_time_as_id, - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", key=pickup_time_as_id, feature_type=FLOAT, transform="fare_amount_new"), + Feature( + name="f_day_of_week", + key=pickup_time_as_id, + feature_type=INT32, + transform="dayofweek(lpep_dropoff_datetime)", + ), ] - regular_anchor = FeatureAnchor(name="request_features_add_new_fare_amount", - source=batch_source, - features=features, - ) + regular_anchor = FeatureAnchor( + name="request_features_add_new_fare_amount", + source=batch_source, + features=features, + ) client.build_features(anchor_list=[regular_anchor]) - online_test_table = get_online_test_table_name('nycTaxiCITableDeletion') + online_test_table = get_online_test_table_name("nycTaxiCITableDeletion") - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings(name="py_udf", - sinks=[redisSink], - feature_names=[ - "f_is_long_trip_distance", - "f_day_of_week" - ], - backfill_time=backfill_time) + settings = MaterializationSettings( + name="py_udf", + sinks=[redisSink], + feature_names=["f_is_long_trip_distance", "f_day_of_week"], + backfill_time=backfill_time, + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) - + client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - - res = client.get_online_features(online_test_table, '2020-04-01 07:21:51', [ - 'f_is_long_trip_distance', 'f_day_of_week']) + + res = client.get_online_features( + online_test_table, "2020-04-01 07:21:51", ["f_is_long_trip_distance", "f_day_of_week"] + ) assert len(res) == 2 @@ -357,29 +400,30 @@ def test_delete_feature_from_redis(): assert res[1] != None # Delete online feature stored in Redis - client.delete_feature_from_redis(online_test_table, '2020-04-01 07:21:51', 'f_is_long_trip_distance') - + client.delete_feature_from_redis(online_test_table, "2020-04-01 07:21:51", "f_is_long_trip_distance") + # Check if the online feature is deleted successfully 
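The deletion check above hinges on get_online_features returning None for a feature that has been removed. As a hedged sketch, reusing only the client calls exercised in this test (the helper and its name are illustrative, not part of the Feathr API), the same condition can be wrapped for reuse:

def _is_feature_deleted(client, table_name: str, key: str, feature_name: str) -> bool:
    # get_online_features returns one value per requested feature name;
    # a feature removed via delete_feature_from_redis comes back as None.
    res = client.get_online_features(table_name, key, [feature_name])
    return len(res) == 1 and res[0] is None

The assertions that follow in this hunk check exactly that shape: a single-element result whose value is None.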
- res = client.get_online_features(online_test_table, '265', ['f_location_avg_fare']) + res = client.get_online_features(online_test_table, "265", ["f_location_avg_fare"]) assert len(res) == 1 assert res[0] == None - + client._clean_test_data(online_test_table) + def test_feature_list_on_input_context(): with pytest.raises(RuntimeError) as e_info: test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - online_test_table = get_online_test_table_name('nycTaxiCITableDeletion') + online_test_table = get_online_test_table_name("nycTaxiCITableDeletion") redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings(name="py_udf", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", - "f_day_of_week" - ]) + settings = MaterializationSettings( + name="py_udf", sinks=[redisSink], feature_names=["f_location_avg_fare", "f_day_of_week"] + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) assert e_info is not None - assert e_info.value.args[0] == "Materializing features that are defined on INPUT_CONTEXT is not supported. f_day_of_week is defined on INPUT_CONTEXT so you should remove it from the feature list in MaterializationSettings." \ No newline at end of file + assert ( + e_info.value.args[0] + == "Materializing features that are defined on INPUT_CONTEXT is not supported. f_day_of_week is defined on INPUT_CONTEXT so you should remove it from the feature list in MaterializationSettings." + ) diff --git a/feathr_project/test/test_feature_name_validation.py b/feathr_project/test/test_feature_name_validation.py index 0fd02f9c8..1f22820ab 100644 --- a/feathr_project/test/test_feature_name_validation.py +++ b/feathr_project/test/test_feature_name_validation.py @@ -4,12 +4,11 @@ from pathlib import Path from feathr import FeatureBase -from feathr import (TypedKey, ValueType, FeatureQuery, ObservationSettings) +from feathr import TypedKey, ValueType, FeatureQuery, ObservationSettings from feathr import FeathrClient -@pytest.mark.parametrize('bad_feature_name', - [None, - '']) + +@pytest.mark.parametrize("bad_feature_name", [None, ""]) def test_feature_name_fails_on_empty_name(bad_feature_name: str): with pytest.raises(Exception, match="empty feature name"): FeatureBase.validate_feature_name(bad_feature_name) @@ -21,89 +20,88 @@ def test_feature_name_fails_on_leading_number(): def test_feature_name_fails_on_punctuation_chars(): - for char in set(string.punctuation) - set('_'): - with pytest.raises(Exception, match="only letters, numbers, and underscores are allowed"): + for char in set(string.punctuation) - set("_"): + with pytest.raises(Exception, match="only letters, numbers, and underscores are allowed"): FeatureBase.validate_feature_name(f"feature_{char}_name") -@pytest.mark.parametrize('feature_name', - ["feature_name", - "features4lyfe", - "f_4_feature_", - "_leading_underscores_are_ok", - "CapitalizedFeature" - '']) +@pytest.mark.parametrize( + "feature_name", + ["feature_name", "features4lyfe", "f_4_feature_", "_leading_underscores_are_ok", "CapitalizedFeature" ""], +) def test_feature_name_validates_ok(feature_name: str): assert FeatureBase.validate_feature_name(feature_name) - + + def test_feature_name_conflicts_with_public_dataset_columns(): - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client = 
client = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - feature_query = FeatureQuery( - feature_list=["trip_distance","fare_amount"], key=location_id) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + feature_query = FeatureQuery(feature_list=["trip_distance", "fare_amount"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) output_path = "wasbs://fake_path" with pytest.raises(RuntimeError) as e: - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path - ) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) assert str(e.value) == "Feature names exist conflicts with dataset column names: trip_distance,fare_amount" - + settings = ObservationSettings( observation_path="wasbs://public@fake_file", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) output_path = "wasbs://fakepath" with pytest.raises(RuntimeError) as e: - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path, - dataset_column_names=set(('trip_distance','fare_amount')) + client.get_offline_features( + observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + dataset_column_names=set(("trip_distance", "fare_amount")), ) assert str(e.value) == "Feature names exist conflicts with dataset column names: trip_distance,fare_amount" - + + def test_feature_name_conflicts_with_private_dataset_columns(): - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client = client = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - - if client.spark_runtime == 'databricks': - source_path = 'dbfs:/timePartitionPattern_test/df0/daily/2020/05/01/' + + if client.spark_runtime == "databricks": + source_path = "dbfs:/timePartitionPattern_test/df0/daily/2020/05/01/" else: - source_path = 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_test/df0/daily/2020/05/01/' - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare","f_location_max_fare"], key=location_id) - + source_path = "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_test/df0/daily/2020/05/01/" + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + feature_query = FeatureQuery(feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id) + settings = ObservationSettings( observation_path=source_path, 
event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss", file_format="avro", - is_file_path=False) + is_file_path=False, + ) output_path = "wasbs://fake_path" with pytest.raises(RuntimeError) as e: - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path - ) - assert str(e.value) == "Feature names exist conflicts with dataset column names: f_location_avg_fare,f_location_max_fare" - + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) + assert ( + str(e.value) + == "Feature names exist conflicts with dataset column names: f_location_avg_fare,f_location_max_fare" + ) diff --git a/feathr_project/test/test_feature_registry.py b/feathr_project/test/test_feature_registry.py index 681b443bf..7b1ac408b 100644 --- a/feathr_project/test/test_feature_registry.py +++ b/feathr_project/test/test_feature_registry.py @@ -8,8 +8,7 @@ from click.testing import CliRunner -from feathr import (FeatureQuery, ObservationSettings, TypedKey, - ValueType) +from feathr import FeatureQuery, ObservationSettings, TypedKey, ValueType from feathr.client import FeathrClient from feathr.registry._feathr_registry_client import _FeatureRegistry from feathrcli.cli import init @@ -17,6 +16,7 @@ from test_fixture import registry_test_setup_append, registry_test_setup_partially, registry_test_setup_for_409 from test_utils.constants import Constants + class FeatureRegistryTests(unittest.TestCase): def test_feathr_register_features_e2e(self): """ @@ -37,34 +37,46 @@ def test_feathr_register_features_e2e(self): # set output folder based on different runtime now = datetime.now() - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".parquet"]) + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".parquet"] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".parquet"]) - + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".parquet", + ] + ) client.register_features() # Allow purview to process a bit time.sleep(5) # in CI test, the project name is set by the CI pipeline so we read it here all_features = client.list_registered_features(project_name=client.project_name) - all_feature_names = [x['name'] for x in all_features] + all_feature_names = [x["name"] for x in all_features] - assert 'f_is_long_trip_distance' in all_feature_names # test regular ones - assert 'f_trip_time_rounded' in all_feature_names # make sure derived features are there - assert 'f_location_avg_fare' in all_feature_names # make sure aggregated features are there - assert 'f_trip_time_rounded_plus' in all_feature_names # make sure derived features are there - assert 'f_trip_time_distance' in all_feature_names # make sure derived features are there + assert "f_is_long_trip_distance" in all_feature_names # test regular ones + assert "f_trip_time_rounded" in all_feature_names # make sure derived features are there + assert "f_location_avg_fare" in all_feature_names # make sure aggregated features are there + assert "f_trip_time_rounded_plus" in all_feature_names # make sure derived features are there + 
assert "f_trip_time_distance" in all_feature_names # make sure derived features are there # Sync workspace from registry, will get all conf files back client.get_features_from_registry(client.project_name) - + # Register the same feature with different definition and expect an error. - client: FeathrClient = registry_test_setup_for_409(os.path.join(test_workspace_dir, config_path), client.project_name) + client: FeathrClient = registry_test_setup_for_409( + os.path.join(test_workspace_dir, config_path), client.project_name + ) with pytest.raises(RuntimeError) as exc_info: client.register_features() - + # 30 # update this to trigger 409 conflict with the existing one features = [ - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>10"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>10"), ] - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features, - registry_tags={"for_test_purpose":"true"} - ) + request_anchor = FeatureAnchor( + name="request_features", source=INPUT_CONTEXT, features=features, registry_tags={"for_test_purpose": "true"} + ) client.build_features(anchor_list=[request_anchor]) return client + def get_online_test_table_name(table_name: str): # use different time for testing to avoid write conflicts now = datetime.now() - res_table = '_'.join([table_name, str(now.minute), str(now.second)]) + res_table = "_".join([table_name, str(now.minute), str(now.second)]) print("The online Redis table is", res_table) return res_table -def time_partition_pattern_feature_gen_test_setup(config_path: str, data_source_path: str, local_workspace_dir: str = None, resolution: str = 'DAILY', postfix_path: str = ""): + +def time_partition_pattern_feature_gen_test_setup( + config_path: str, + data_source_path: str, + local_workspace_dir: str = None, + resolution: str = "DAILY", + postfix_path: str = "", +): now = datetime.now() # set workspace folder by time; make sure we don't have write conflict if there are many CI tests running - os.environ['SPARK_CONFIG__DATABRICKS__WORK_DIR'] = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), '_', str(now.microsecond)]) - os.environ['SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR'] = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci','_', str(now.minute), '_', str(now.second) ,'_', str(now.microsecond)]) + os.environ["SPARK_CONFIG__DATABRICKS__WORK_DIR"] = "".join( + ["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), "_", str(now.microsecond)] + ) + os.environ["SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR"] = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci", + "_", + str(now.minute), + "_", + str(now.second), + "_", + str(now.microsecond), + ] + ) client = FeathrClient(config_path=config_path, local_workspace_dir=local_workspace_dir) - if resolution == 'DAILY': + if resolution == "DAILY": if postfix_path != "": - batch_source = HdfsSource(name="testTimePartitionSource", - path=data_source_path, - time_partition_pattern="yyyy/MM/dd", - postfix_path=postfix_path - ) + batch_source = HdfsSource( + name="testTimePartitionSource", + path=data_source_path, + time_partition_pattern="yyyy/MM/dd", + postfix_path=postfix_path, + ) else: - batch_source = HdfsSource(name="testTimePartitionSource", - path=data_source_path, - time_partition_pattern="yyyy/MM/dd" - ) + 
batch_source = HdfsSource( + name="testTimePartitionSource", path=data_source_path, time_partition_pattern="yyyy/MM/dd" + ) else: - batch_source = HdfsSource(name="testTimePartitionSource", - path=data_source_path, - time_partition_pattern="yyyy/MM/dd/HH" - ) - key = TypedKey(key_column="key0", - key_column_type=ValueType.INT32) + batch_source = HdfsSource( + name="testTimePartitionSource", path=data_source_path, time_partition_pattern="yyyy/MM/dd/HH" + ) + key = TypedKey(key_column="key0", key_column_type=ValueType.INT32) agg_features = [ - Feature(name="f_loc_avg_output", + Feature( + name="f_loc_avg_output", key=[key], feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="f_location_avg_fare", - agg_func="AVG", - window="3d")), - Feature(name="f_loc_max_output", + transform=WindowAggTransformation(agg_expr="f_location_avg_fare", agg_func="AVG", window="3d"), + ), + Feature( + name="f_loc_max_output", feature_type=FLOAT, key=[key], - transform=WindowAggTransformation(agg_expr="f_location_max_fare", - agg_func="MAX", - window="3d")), + transform=WindowAggTransformation(agg_expr="f_location_max_fare", agg_func="MAX", window="3d"), + ), ] - agg_anchor = FeatureAnchor(name="testTimePartitionFeatures", - source=batch_source, - features=agg_features) + agg_anchor = FeatureAnchor(name="testTimePartitionFeatures", source=batch_source, features=agg_features) client.build_features(anchor_list=[agg_anchor]) return client -def time_partition_pattern_feature_join_test_setup(config_path: str, data_source_path: str, local_workspace_dir: str = None, resolution: str = 'DAILY', postfix_path: str = ""): + +def time_partition_pattern_feature_join_test_setup( + config_path: str, + data_source_path: str, + local_workspace_dir: str = None, + resolution: str = "DAILY", + postfix_path: str = "", +): now = datetime.now() # set workspace folder by time; make sure we don't have write conflict if there are many CI tests running - os.environ['SPARK_CONFIG__DATABRICKS__WORK_DIR'] = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), '_', str(now.microsecond)]) - os.environ['SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR'] = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci','_', str(now.minute), '_', str(now.second) ,'_', str(now.microsecond)]) + os.environ["SPARK_CONFIG__DATABRICKS__WORK_DIR"] = "".join( + ["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), "_", str(now.microsecond)] + ) + os.environ["SPARK_CONFIG__AZURE_SYNAPSE__WORKSPACE_DIR"] = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_github_ci", + "_", + str(now.minute), + "_", + str(now.second), + "_", + str(now.microsecond), + ] + ) client = FeathrClient(config_path=config_path, local_workspace_dir=local_workspace_dir) - + if postfix_path == "": - if resolution == 'DAILY': - batch_source_tpp = HdfsSource(name="nycTaxiBatchSource", - path=data_source_path, - time_partition_pattern="yyyy/MM/dd" - ) + if resolution == "DAILY": + batch_source_tpp = HdfsSource( + name="nycTaxiBatchSource", path=data_source_path, time_partition_pattern="yyyy/MM/dd" + ) else: - batch_source_tpp = HdfsSource(name="nycTaxiBatchSource", - path=data_source_path, - time_partition_pattern="yyyy/MM/dd/HH" - ) + batch_source_tpp = HdfsSource( + name="nycTaxiBatchSource", path=data_source_path, time_partition_pattern="yyyy/MM/dd/HH" + ) else: - batch_source_tpp = HdfsSource(name="nycTaxiBatchSource", - path=data_source_path, - 
time_partition_pattern="yyyy/MM/dd", - postfix_path=postfix_path - ) - tpp_key = TypedKey(key_column="f_location_max_fare", - key_column_type=ValueType.FLOAT) + batch_source_tpp = HdfsSource( + name="nycTaxiBatchSource", + path=data_source_path, + time_partition_pattern="yyyy/MM/dd", + postfix_path=postfix_path, + ) + tpp_key = TypedKey(key_column="f_location_max_fare", key_column_type=ValueType.FLOAT) tpp_features = [ - Feature(name="key0", + Feature( + name="key0", key=tpp_key, feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="key0", - agg_func="LATEST", - window="3d" - )) + transform=WindowAggTransformation(agg_expr="key0", agg_func="LATEST", window="3d"), + ) ] - tpp_anchor = FeatureAnchor(name="tppFeatures", - source=batch_source_tpp, - features=tpp_features) + tpp_anchor = FeatureAnchor(name="tppFeatures", source=batch_source_tpp, features=tpp_features) client.build_features(anchor_list=[tpp_anchor]) - + feature_query = FeatureQuery(feature_list=["key0"], key=tpp_key) settings = ObservationSettings( - observation_path='wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/tpp_source.csv', + observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/tpp_source.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - return [client, feature_query, settings] \ No newline at end of file + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + return [client, feature_query, settings] diff --git a/feathr_project/test/test_input_output_sources.py b/feathr_project/test/test_input_output_sources.py index ba4b3921a..1a0d9c5a3 100644 --- a/feathr_project/test/test_input_output_sources.py +++ b/feathr_project/test/test_input_output_sources.py @@ -3,7 +3,7 @@ from datetime import datetime from pathlib import Path -from feathr import (FeatureQuery, ObservationSettings, SparkExecutionConfiguration, TypedKey, ValueType) +from feathr import FeatureQuery, ObservationSettings, SparkExecutionConfiguration, TypedKey, ValueType from feathr.client import FeathrClient from feathr.constants import OUTPUT_FORMAT from feathr.utils.job_utils import get_result_df @@ -17,34 +17,54 @@ def test_feathr_get_offline_features_with_parquet(): Test if the program can read and write parquet files """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32) + location_id = TypedKey(key_column="DOLocationID", key_column_type=ValueType.INT32) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second),'_', str(now.microsecond), ".parquet"]) + if client.spark_runtime == "databricks": + output_path = "".join( + [ + "dbfs:/feathrazure_cijob", + "_", + str(now.minute), + 
"_", + str(now.second), + "_", + str(now.microsecond), + ".parquet", + ] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".parquet"]) - - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path, - execution_configurations=SparkExecutionConfiguration({"spark.feathr.inputFormat": "parquet", "spark.feathr.outputFormat": "parquet"}) - ) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".parquet", + ] + ) + + client.get_offline_features( + observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + execution_configurations=SparkExecutionConfiguration( + {"spark.feathr.inputFormat": "parquet", "spark.feathr.outputFormat": "parquet"} + ), + ) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -60,34 +80,43 @@ def test_feathr_get_offline_features_with_delta_lake(): Test if the program can read and write delta lake """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32) + location_id = TypedKey(key_column="DOLocationID", key_column_type=ValueType.INT32) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/feathr_delta_table", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), "_deltalake"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), "_deltalake"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), "_deltalake"]) - - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path, - execution_configurations=SparkExecutionConfiguration({"spark.feathr.inputFormat": "delta", "spark.feathr.outputFormat": "delta"}) - ) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + "_deltalake", + ] + ) + + client.get_offline_features( + observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + execution_configurations=SparkExecutionConfiguration( + {"spark.feathr.inputFormat": "delta", "spark.feathr.outputFormat": "delta"} + ), + ) # assuming the job can successfully run; otherwise it will throw exception 
client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -98,6 +127,6 @@ def test_feathr_get_offline_features_with_delta_lake(): # download result and just assert the returned result is not empty # if users are using delta format in synapse, skip this check, due to issue https://github.com/delta-io/delta-rs/issues/582 result_format: str = client.get_job_tags().get(OUTPUT_FORMAT, "") - if not (client.spark_runtime == 'azure_synapse' and result_format == 'delta'): + if not (client.spark_runtime == "azure_synapse" and result_format == "delta"): res_df = get_result_df(client) assert res_df.shape[0] > 0 diff --git a/feathr_project/test/test_local_spark_e2e.py b/feathr_project/test/test_local_spark_e2e.py index fe7d9b5df..9feee73e8 100644 --- a/feathr_project/test/test_local_spark_e2e.py +++ b/feathr_project/test/test_local_spark_e2e.py @@ -8,22 +8,35 @@ from pathlib import Path import os from datetime import datetime, timedelta -from feathr import FeathrClient, ObservationSettings, FeatureQuery, TypedKey, HdfsSource, Feature, FeatureAnchor, INPUT_CONTEXT, FLOAT, INT32, BOOLEAN, DerivedFeature, WindowAggTransformation, ValueType +from feathr import ( + FeathrClient, + ObservationSettings, + FeatureQuery, + TypedKey, + HdfsSource, + Feature, + FeatureAnchor, + INPUT_CONTEXT, + FLOAT, + INT32, + BOOLEAN, + DerivedFeature, + WindowAggTransformation, + ValueType, +) from feathr import BOOLEAN, FLOAT, INT32, ValueType -from test_utils.udfs import (add_new_dropoff_and_fare_amount_column, add_new_fare_amount) +from test_utils.udfs import add_new_dropoff_and_fare_amount_column, add_new_fare_amount def test_local_spark_get_offline_features(): - #This Test is for Local Spark only + # This Test is for Local Spark only if not _is_local(): return - - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - client = _local_client_setup(test_workspace_dir) + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client = _local_client_setup(test_workspace_dir) output_path, proc1 = _non_udf_features(client) client.wait_job_to_finish() @@ -31,316 +44,344 @@ def test_local_spark_get_offline_features(): df = parse_avro_result(output_path) assert df.shape[0] > 0 - shutil.rmtree('debug') - return + shutil.rmtree("debug") + return + def test_local_spark_pyudf_get_offline_features(): - #This Test is for Local Spark only + # This Test is for Local Spark only if not _is_local(): return client = _local_client_setup() - + output_path, proc = _udf_features(client) client.wait_job_to_finish() df = parse_avro_result(output_path) assert df.shape[0] > 0 - shutil.rmtree('debug') - return + shutil.rmtree("debug") + return + def test_local_spark_materialization(): - #This Test is for Local Spark only + # This Test is for Local Spark only if not _is_local(): return client: FeathrClient = _local_client_setup() results = _feature_gen_test(client) # just assume the job is successful without validating the actual result in Redis. 
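The local-Spark tests in this module return early unless _is_local() (defined just below) reports a local Spark cluster. A hedged sketch of how that guard is satisfied when running the module, assuming only the environment-variable convention the guard itself reads:

import os

# _is_local() checks SPARK_CONFIG__SPARK_CLUSTER; setting it to "local"
# before invoking pytest enables the local-only tests in this module.
# Equivalent shell usage (illustrative):
#   SPARK_CONFIG__SPARK_CLUSTER=local python -m pytest test_local_spark_e2e.py
os.environ["SPARK_CONFIG__SPARK_CLUSTER"] = "local"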
- shutil.rmtree('debug') + shutil.rmtree("debug") return + def _is_local() -> bool: """ to verify if test is running on local spark mode """ - if os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') == 'local': + if os.environ.get("SPARK_CONFIG__SPARK_CLUSTER") == "local": return True else: return False -def _local_client_setup(local_workspace:str = None): + +def _local_client_setup(local_workspace: str = None): if not local_workspace: - local_workspace = Path( - __file__).parent.resolve() / "test_user_workspace" + local_workspace = Path(__file__).parent.resolve() / "test_user_workspace" os.chdir(local_workspace) - client = FeathrClient(os.path.join(local_workspace, "feathr_config_local.yaml"), local_workspace_dir=local_workspace) + client = FeathrClient( + os.path.join(local_workspace, "feathr_config_local.yaml"), local_workspace_dir=local_workspace + ) return client -def _non_udf_features(client:FeathrClient = None): + +def _non_udf_features(client: FeathrClient = None): if not client: client = _local_client_setup() - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="./green_tripdata_2020-04_with_index.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="./green_tripdata_2020-04_with_index.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) + + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) + + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation( + agg_expr="cast_float(fare_amount)", + agg_func="AVG", + window="90d", + ), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), + 
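Preprocessing UDFs such as add_new_fare_amount are plain DataFrame-to-DataFrame functions, so they can be smoke-tested with a small local Spark session before being attached to an HdfsSource. A hedged sketch, assuming the test_utils.udfs version behaves like the add_new_fare_amount definition shown earlier in this diff (it adds fare_amount_new = fare_amount + 8000000); the session setup and sample row are illustrative:

from pyspark.sql import SparkSession

from test_utils.udfs import add_new_fare_amount

spark = SparkSession.builder.master("local[1]").appName("udf_smoke_test").getOrCreate()
df = spark.createDataFrame([(10.0,)], ["fare_amount"])
out = add_new_fare_amount(df)  # expected to add fare_amount_new = fare_amount + 8000000
assert out.select("fare_amount_new").first()[0] == 8000010.0
spark.stop()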
] + + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) + + client.build_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) - - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") - - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d", - )), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) - - client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ - f_trip_time_distance, f_trip_time_rounded]) - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) - settings = ObservationSettings( observation_path="./green_tripdata_2020-04_with_index.csv", - #observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + # observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now().strftime("%Y%m%d%H%M%S") output_path = os.path.join("debug", f"test_output_{now}") - proc = client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path, - config_file_name = "feature_join_conf/feature_join_local.conf", - verbose=False) + proc = client.get_offline_features( + observation_settings=settings, + feature_query=feature_query, + output_path=output_path, + config_file_name="feature_join_conf/feature_join_local.conf", + verbose=False, + ) return output_path, proc -def _udf_features(client:FeathrClient = None): + +def _udf_features(client: FeathrClient = None): if not client: client = _local_client_setup() - - batch_source1 = HdfsSource(name="nycTaxiBatchSource_add_new_dropoff_and_fare_amount_column", - path="./green_tripdata_2020-04_with_index.csv", - preprocessing=add_new_dropoff_and_fare_amount_column, - event_timestamp_column="new_lpep_dropoff_datetime", - # event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") 
- - batch_source2 = HdfsSource(name="nycTaxiBatchSource_add_new_fare_amount", - path="./green_tripdata_2020-04_with_index.csv", - preprocessing=add_new_fare_amount, - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="SUM", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source1, - features=agg_features, - ) - - pickup_time_as_id = TypedKey(key_column="lpep_pickup_datetime", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + + batch_source1 = HdfsSource( + name="nycTaxiBatchSource_add_new_dropoff_and_fare_amount_column", + path="./green_tripdata_2020-04_with_index.csv", + preprocessing=add_new_dropoff_and_fare_amount_column, + event_timestamp_column="new_lpep_dropoff_datetime", + # event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + batch_source2 = HdfsSource( + name="nycTaxiBatchSource_add_new_fare_amount", + path="./green_tripdata_2020-04_with_index.csv", + preprocessing=add_new_fare_amount, + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="SUM", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="MAX", window="90d"), + ), + ] + + agg_anchor = FeatureAnchor( + name="aggregationFeatures", + source=batch_source1, + features=agg_features, + ) + + pickup_time_as_id = TypedKey( + key_column="lpep_pickup_datetime", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) features = [ - Feature(name="f_is_long_trip_distance", - key=pickup_time_as_id, - feature_type=FLOAT, - transform="fare_amount_new"), - Feature(name="f_day_of_week", - key=pickup_time_as_id, - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", key=pickup_time_as_id, feature_type=FLOAT, transform="fare_amount_new"), + Feature( + name="f_day_of_week", + key=pickup_time_as_id, + feature_type=INT32, + transform="dayofweek(lpep_dropoff_datetime)", + ), ] - regular_anchor = FeatureAnchor(name="regular_anchor", - source=batch_source2, - features=features, - ) + regular_anchor = FeatureAnchor( + name="regular_anchor", + source=batch_source2, + features=features, + ) client.build_features(anchor_list=[agg_anchor, regular_anchor]) - feature_query = [FeatureQuery( - feature_list=["f_is_long_trip_distance", "f_day_of_week"], key=pickup_time_as_id), - FeatureQuery( - feature_list=["f_location_avg_fare", "f_location_max_fare"], 
key=location_id) + feature_query = [ + FeatureQuery(feature_list=["f_is_long_trip_distance", "f_day_of_week"], key=pickup_time_as_id), + FeatureQuery(feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id), ] settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now().strftime("%Y%m%d%H%M%S") output_path = os.path.join("debug", f"test_output_{now}") - proc = client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + proc = client.get_offline_features( + observation_settings=settings, feature_query=feature_query, output_path=output_path + ) return output_path, proc -def _feature_gen_test(client:FeathrClient = None): + +def _feature_gen_test(client: FeathrClient = None): if not client: client = _local_client_setup() - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="./green_tripdata_2020-04_with_index.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="./green_tripdata_2020-04_with_index.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), + ] + + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) + + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) + + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation( + agg_expr="cast_float(fare_amount)", + agg_func="AVG", + window="90d", + ), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", 
agg_func="MAX", window="90d"), + ), ] + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) + + client.build_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) - - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") - - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d", - )), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) - - client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ - f_trip_time_distance, f_trip_time_rounded]) - online_test_table = "localSparkTest" - backfill_time = BackfillTime(start=datetime( - 2020, 4, 1), end=datetime(2020, 4, 2), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 4, 1), end=datetime(2020, 4, 2), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings("LocalSparkTest", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "LocalSparkTest", + sinks=[redisSink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) results = client.materialize_features(settings) client.wait_job_to_finish() - res = client.get_online_features(online_test_table, '243', [ - 'f_location_avg_fare', 'f_location_max_fare']) + res = client.get_online_features(online_test_table, "243", ["f_location_avg_fare", "f_location_max_fare"]) return res -def parse_avro_result(output_path:str): + +def parse_avro_result(output_path: str): dataframe_list = [] # assuming the result are in avro format - for file in glob.glob(os.path.join(output_path, '*.avro')): + for file in glob.glob(os.path.join(output_path, "*.avro")): dataframe_list.append(pdx.read_avro(file)) - + vertical_concat_df = pd.concat(dataframe_list, axis=0) - return vertical_concat_df \ No newline at end of file + return vertical_concat_df diff --git a/feathr_project/test/test_lookup_feature.py b/feathr_project/test/test_lookup_feature.py index 82fe385a7..2b2ee5388 100644 --- a/feathr_project/test/test_lookup_feature.py +++ b/feathr_project/test/test_lookup_feature.py @@ -5,24 +5,38 @@ from feathr import FLOAT, FLOAT_VECTOR, ValueType, INT32_VECTOR from feathr import TypedKey + def assert_config_equals(one, another): - assert one.translate(str.maketrans('', '', ' \n\t\r')) == another.translate(str.maketrans('', '', ' \n\t\r')) - + assert one.translate(str.maketrans("", "", " \n\t\r")) == 
another.translate(str.maketrans("", "", " \n\t\r")) + + def test_single_key_lookup_feature_to_config(): """Single key lookup feature config generation should work""" - user_key = TypedKey(full_name="mockdata.user", key_column="user_id", key_column_type=ValueType.INT32, description="An user identifier") - item_key = TypedKey(full_name="mockdata.item", key_column="item_id", key_column_type=ValueType.INT32, description="An item identifier") - + user_key = TypedKey( + full_name="mockdata.user", + key_column="user_id", + key_column_type=ValueType.INT32, + description="An user identifier", + ) + item_key = TypedKey( + full_name="mockdata.item", + key_column="item_id", + key_column_type=ValueType.INT32, + description="An item identifier", + ) + user_item = Feature(name="user_items", feature_type=INT32_VECTOR, key=user_key) item_price = Feature(name="item_price", feature_type=FLOAT_VECTOR, key=item_key) # A lookup feature - lookup_feature = LookupFeature(name="user_avg_item_price", - feature_type=FLOAT, - key=user_key, - base_feature=user_item, - expansion_feature=item_price, - aggregation=Aggregation.AVG) + lookup_feature = LookupFeature( + name="user_avg_item_price", + feature_type=FLOAT, + key=user_key, + base_feature=user_item, + expansion_feature=item_price, + aggregation=Aggregation.AVG, + ) lookup_feature_config = """ user_avg_item_price: { @@ -40,5 +54,4 @@ def test_single_key_lookup_feature_to_config(): } }""" assert_config_equals(lookup_feature.to_feature_config(), lookup_feature_config) - assert(isinstance(lookup_feature, DerivedFeature)) - \ No newline at end of file + assert isinstance(lookup_feature, DerivedFeature) diff --git a/feathr_project/test/test_observation_setting.py b/feathr_project/test/test_observation_setting.py index aa9cd6f72..9197f36ea 100644 --- a/feathr_project/test/test_observation_setting.py +++ b/feathr_project/test/test_observation_setting.py @@ -5,7 +5,8 @@ def test_observation_setting_with_timestamp(): observation_settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) config = observation_settings.to_feature_config() expected_config = """ settings: { @@ -19,15 +20,15 @@ def test_observation_setting_with_timestamp(): observationPath: "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv" """ - assert ''.join(config.split()) == ''.join(expected_config.split()) + assert "".join(config.split()) == "".join(expected_config.split()) def test_observation_setting_without_timestamp(): - observation_settings = ObservationSettings( - observation_path='snowflake://snowflake_account/?dbtable=CALL_CENTER&sfDatabase=SNOWFLAKE_SAMPLE_DATA&sfSchema=TPCDS_SF10TCL') + observation_path="snowflake://snowflake_account/?dbtable=CALL_CENTER&sfDatabase=SNOWFLAKE_SAMPLE_DATA&sfSchema=TPCDS_SF10TCL" + ) config = observation_settings.to_feature_config() expected_config = """ observationPath:"snowflake://snowflake_account/?dbtable=CALL_CENTER&sfDatabase=SNOWFLAKE_SAMPLE_DATA&sfSchema=TPCDS_SF10TCL" """ - assert ''.join(config.split()) == ''.join(expected_config.split()) \ No newline at end of file + assert "".join(config.split()) == "".join(expected_config.split()) diff --git a/feathr_project/test/test_pyduf_preprocessing_e2e.py b/feathr_project/test/test_pyduf_preprocessing_e2e.py index 9d3fae395..c3c08a522 100644 --- 
a/feathr_project/test/test_pyduf_preprocessing_e2e.py +++ b/feathr_project/test/test_pyduf_preprocessing_e2e.py @@ -7,7 +7,7 @@ from pyspark.sql import DataFrame from pyspark.sql.functions import col -from feathr import (BackfillTime, MaterializationSettings) +from feathr import BackfillTime, MaterializationSettings from feathr import Feature from feathr import FeatureAnchor from feathr import FeatureQuery @@ -18,26 +18,29 @@ from feathr import TypedKey from feathr import WindowAggTransformation from feathr.utils.job_utils import get_result_df -from test_fixture import (snowflake_test_setup, get_online_test_table_name, basic_test_setup) +from test_fixture import snowflake_test_setup, get_online_test_table_name, basic_test_setup from test_utils.constants import Constants def trip_distance_preprocessing(df: DataFrame): - df = df.withColumn("trip_distance", df.trip_distance.cast('double') - 90000) - df = df.withColumn("fare_amount", df.fare_amount.cast('double') - 90000) + df = df.withColumn("trip_distance", df.trip_distance.cast("double") - 90000) + df = df.withColumn("fare_amount", df.fare_amount.cast("double") - 90000) return df + def add_new_dropoff_and_fare_amount_column(df: DataFrame): df = df.withColumn("new_lpep_dropoff_datetime", col("lpep_dropoff_datetime")) df = df.withColumn("new_fare_amount", col("fare_amount") + 1000000) return df + def add_new_fare_amount(df: DataFrame) -> DataFrame: df = df.withColumn("fare_amount_new", col("fare_amount") + 8000000) return df + def add_new_surcharge_amount_and_pickup_column(df: DataFrame) -> DataFrame: df = df.withColumn("new_improvement_surcharge", col("improvement_surcharge") + 1000000) df = df.withColumn("new_tip_amount", col("tip_amount") + 1000000) @@ -45,72 +48,78 @@ def add_new_surcharge_amount_and_pickup_column(df: DataFrame) -> DataFrame: return df + def add_old_lpep_dropoff_datetime(df: DataFrame) -> DataFrame: df = df.withColumn("old_lpep_dropoff_datetime", col("lpep_dropoff_datetime")) return df + def feathr_udf_day_calc(df: DataFrame) -> DataFrame: df = df.withColumn("f_day_of_week", dayofweek("lpep_dropoff_datetime")) df = df.withColumn("f_day_of_year", dayofyear("lpep_dropoff_datetime")) return df + def test_non_swa_feature_gen_with_offline_preprocessing(): """ Test non-SWA feature gen with preprocessing """ test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" - client:FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) + client: FeathrClient = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - batch_source = HdfsSource(name="nycTaxiBatchSource_add_new_fare_amount", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_fare_amount, - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + batch_source = HdfsSource( + name="nycTaxiBatchSource_add_new_fare_amount", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_fare_amount, + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - pickup_time_as_id = TypedKey(key_column="lpep_pickup_datetime", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + pickup_time_as_id = TypedKey( + key_column="lpep_pickup_datetime", + key_column_type=ValueType.INT32, + description="location id in NYC", + 
full_name="nyc_taxi.location_id", + ) features = [ - Feature(name="f_is_long_trip_distance", - key=pickup_time_as_id, - feature_type=FLOAT, - transform="fare_amount_new"), - Feature(name="f_day_of_week", - key=pickup_time_as_id, - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", key=pickup_time_as_id, feature_type=FLOAT, transform="fare_amount_new"), + Feature( + name="f_day_of_week", + key=pickup_time_as_id, + feature_type=INT32, + transform="dayofweek(lpep_dropoff_datetime)", + ), ] - regular_anchor = FeatureAnchor(name="request_features_add_new_fare_amount", - source=batch_source, - features=features, - ) + regular_anchor = FeatureAnchor( + name="request_features_add_new_fare_amount", + source=batch_source, + features=features, + ) client.build_features(anchor_list=[regular_anchor]) - online_test_table = get_online_test_table_name('nycTaxiCITableOfflineProcessing') + online_test_table = get_online_test_table_name("nycTaxiCITableOfflineProcessing") - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings(name="py_udf", - sinks=[redisSink], - feature_names=[ - "f_is_long_trip_distance", - "f_day_of_week" - ], - backfill_time=backfill_time) + settings = MaterializationSettings( + name="py_udf", + sinks=[redisSink], + feature_names=["f_is_long_trip_distance", "f_day_of_week"], + backfill_time=backfill_time, + ) client.materialize_features(settings, allow_materialize_non_agg_feature=True) # just assume the job is successful without validating the actual result in Redis. 
Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, '2020-04-01 07:21:51', [ - 'f_is_long_trip_distance', 'f_day_of_week']) + res = client.get_online_features( + online_test_table, "2020-04-01 07:21:51", ["f_is_long_trip_distance", "f_day_of_week"] + ) assert res == [8000006.0, 4] client._clean_test_data(online_test_table) @@ -123,56 +132,59 @@ def test_feature_swa_feature_gen_with_preprocessing(): test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_dropoff_and_fare_amount_column, - event_timestamp_column="new_lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="AVG", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_dropoff_and_fare_amount_column, + event_timestamp_column="new_lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="AVG", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="MAX", window="90d"), + ), + ] + + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) client.build_features(anchor_list=[agg_anchor]) - online_test_table = get_online_test_table_name('nycTaxiCITableSWAFeatureMaterialization') + online_test_table = get_online_test_table_name("nycTaxiCITableSWAFeatureMaterialization") - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) redisSink = RedisSink(table_name=online_test_table) - settings = MaterializationSettings(name="py_udf", - sinks=[redisSink], - feature_names=[ - "f_location_avg_fare", - "f_location_max_fare", - ], - backfill_time=backfill_time) + settings = MaterializationSettings( + name="py_udf", + sinks=[redisSink], + feature_names=[ + "f_location_avg_fare", + "f_location_max_fare", + ], + backfill_time=backfill_time, + ) 
client.materialize_features(settings) # just assume the job is successful without validating the actual result in Redis. Might need to consolidate # this part with the test_feathr_online_store test case client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res = client.get_online_features(online_test_table, '265', ['f_location_avg_fare', 'f_location_max_fare']) + res = client.get_online_features(online_test_table, "265", ["f_location_avg_fare", "f_location_max_fare"]) assert res == [1000041.625, 1000100.0] client._clean_test_data(online_test_table) @@ -186,90 +198,104 @@ def test_feathr_get_offline_features_hdfs_source(): client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - batch_source1 = HdfsSource(name="nycTaxiBatchSource_add_new_dropoff_and_fare_amount_column", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_dropoff_and_fare_amount_column, - event_timestamp_column="new_lpep_dropoff_datetime", - # event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - batch_source2 = HdfsSource(name="nycTaxiBatchSource_add_new_fare_amount", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_fare_amount, - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="SUM", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source1, - features=agg_features, - ) - - pickup_time_as_id = TypedKey(key_column="lpep_pickup_datetime", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + batch_source1 = HdfsSource( + name="nycTaxiBatchSource_add_new_dropoff_and_fare_amount_column", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_dropoff_and_fare_amount_column, + event_timestamp_column="new_lpep_dropoff_datetime", + # event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + batch_source2 = HdfsSource( + name="nycTaxiBatchSource_add_new_fare_amount", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_fare_amount, + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="SUM", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="MAX", 
window="90d"), + ), + ] + + agg_anchor = FeatureAnchor( + name="aggregationFeatures", + source=batch_source1, + features=agg_features, + ) + + pickup_time_as_id = TypedKey( + key_column="lpep_pickup_datetime", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) features = [ - Feature(name="f_is_long_trip_distance", - key=pickup_time_as_id, - feature_type=FLOAT, - transform="fare_amount_new"), - Feature(name="f_day_of_week", - key=pickup_time_as_id, - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", key=pickup_time_as_id, feature_type=FLOAT, transform="fare_amount_new"), + Feature( + name="f_day_of_week", + key=pickup_time_as_id, + feature_type=INT32, + transform="dayofweek(lpep_dropoff_datetime)", + ), ] - regular_anchor = FeatureAnchor(name="regular_anchor", - source=batch_source2, - features=features, - ) + regular_anchor = FeatureAnchor( + name="regular_anchor", + source=batch_source2, + features=features, + ) client.build_features(anchor_list=[agg_anchor, regular_anchor]) - feature_query = [FeatureQuery( - feature_list=["f_is_long_trip_distance", "f_day_of_week"], key=pickup_time_as_id), - FeatureQuery( - feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id) + feature_query = [ + FeatureQuery(feature_list=["f_is_long_trip_distance", "f_day_of_week"], key=pickup_time_as_id), + FeatureQuery(feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id), ] settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -287,105 +313,125 @@ def test_get_offline_feature_two_swa_with_diff_preprocessing(): client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - swa_source_1 = HdfsSource(name="nycTaxiBatchSource1", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_dropoff_and_fare_amount_column, - event_timestamp_column="new_lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in 
NYC", - full_name="nyc_taxi.location_id") - - - agg_features1 = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="SUM", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_fare_amount", - agg_func="MAX", - window="90d")) - ] - - agg_anchor1 = FeatureAnchor(name="aggregationFeatures1", - source=swa_source_1, - features=agg_features1, - ) - - - swa_source_2 = HdfsSource(name="nycTaxiBatchSource2", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - preprocessing=add_new_surcharge_amount_and_pickup_column, - event_timestamp_column="new_lpep_pickup_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - agg_features2 = [Feature(name="f_location_new_tip_amount", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_tip_amount", - agg_func="SUM", - window="90d")), - Feature(name="f_location_max_improvement_surcharge", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="new_improvement_surcharge", - agg_func="SUM", - window="90d")) - ] - agg_anchor2 = FeatureAnchor(name="aggregationFeatures2", - source=swa_source_2, - features=agg_features2, - ) - - swa_source_3 = HdfsSource(name="nycTaxiBatchSource3", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_old.csv", - preprocessing=add_old_lpep_dropoff_datetime, - event_timestamp_column="old_lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - - agg_features3 = [Feature(name="f_location_old_tip_amount", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_double(old_tip_amount)", - agg_func="SUM", - window="90d")) - ] - agg_anchor3 = FeatureAnchor(name="aggregationFeatures3", - source=swa_source_3, - features=agg_features3, - ) + swa_source_1 = HdfsSource( + name="nycTaxiBatchSource1", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_dropoff_and_fare_amount_column, + event_timestamp_column="new_lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + + agg_features1 = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="SUM", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_fare_amount", agg_func="MAX", window="90d"), + ), + ] + + agg_anchor1 = FeatureAnchor( + name="aggregationFeatures1", + source=swa_source_1, + features=agg_features1, + ) + + swa_source_2 = HdfsSource( + name="nycTaxiBatchSource2", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + preprocessing=add_new_surcharge_amount_and_pickup_column, + event_timestamp_column="new_lpep_pickup_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + agg_features2 = [ + Feature( + name="f_location_new_tip_amount", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_tip_amount", agg_func="SUM", window="90d"), + ), + 
Feature( + name="f_location_max_improvement_surcharge", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="new_improvement_surcharge", agg_func="SUM", window="90d"), + ), + ] + agg_anchor2 = FeatureAnchor( + name="aggregationFeatures2", + source=swa_source_2, + features=agg_features2, + ) + + swa_source_3 = HdfsSource( + name="nycTaxiBatchSource3", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_old.csv", + preprocessing=add_old_lpep_dropoff_datetime, + event_timestamp_column="old_lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + agg_features3 = [ + Feature( + name="f_location_old_tip_amount", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_double(old_tip_amount)", agg_func="SUM", window="90d"), + ) + ] + agg_anchor3 = FeatureAnchor( + name="aggregationFeatures3", + source=swa_source_3, + features=agg_features3, + ) client.build_features(anchor_list=[agg_anchor1, agg_anchor2, agg_anchor3]) feature_query = [ - FeatureQuery(feature_list=["f_location_new_tip_amount", "f_location_max_improvement_surcharge"], key=location_id), FeatureQuery( - feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id), - FeatureQuery( - feature_list=["f_location_old_tip_amount"], key=location_id) + feature_list=["f_location_new_tip_amount", "f_location_max_improvement_surcharge"], key=location_id + ), + FeatureQuery(feature_list=["f_location_avg_fare", "f_location_max_fare"], key=location_id), + FeatureQuery(feature_list=["f_location_old_tip_amount"], key=location_id), ] settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -400,6 +446,7 @@ def snowflake_preprocessing(df: DataFrame) -> DataFrame: df = df.withColumn("NEW_CC_ZIP", concat(col("CC_ZIP"), lit("____"), col("CC_ZIP"))) return df + @pytest.mark.skip(reason="All snowflake tests are skipped for now due to budget restriction.") def test_feathr_get_offline_features_from_snowflake(): """ @@ -407,53 +454,76 @@ def test_feathr_get_offline_features_from_snowflake(): """ test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" client = 
snowflake_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - batch_source = SnowflakeSource(name="nycTaxiBatchSource", - database="SNOWFLAKE_SAMPLE_DATA", - schema="TPCDS_SF10TCL", - dbtable="CALL_CENTER", - preprocessing=snowflake_preprocessing, - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - call_sk_id = TypedKey(key_column="CC_CALL_CENTER_SK", - key_column_type=ValueType.STRING, - description="call center sk", - full_name="snowflake.CC_CALL_CENTER_SK") + batch_source = SnowflakeSource( + name="nycTaxiBatchSource", + database="SNOWFLAKE_SAMPLE_DATA", + schema="TPCDS_SF10TCL", + dbtable="CALL_CENTER", + preprocessing=snowflake_preprocessing, + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + call_sk_id = TypedKey( + key_column="CC_CALL_CENTER_SK", + key_column_type=ValueType.STRING, + description="call center sk", + full_name="snowflake.CC_CALL_CENTER_SK", + ) features = [ - Feature(name="f_snowflake_call_center_division_name_with_preprocessing", - key=call_sk_id, - feature_type=STRING, - transform="NEW_CC_DIVISION_NAME"), - Feature(name="f_snowflake_call_center_zipcode_with_preprocessing", - key=call_sk_id, - feature_type=STRING, - transform="NEW_CC_ZIP"), + Feature( + name="f_snowflake_call_center_division_name_with_preprocessing", + key=call_sk_id, + feature_type=STRING, + transform="NEW_CC_DIVISION_NAME", + ), + Feature( + name="f_snowflake_call_center_zipcode_with_preprocessing", + key=call_sk_id, + feature_type=STRING, + transform="NEW_CC_ZIP", + ), ] - feature_anchor = FeatureAnchor(name="snowflake_features", - source=batch_source, - features=features, - ) + feature_anchor = FeatureAnchor( + name="snowflake_features", + source=batch_source, + features=features, + ) client.build_features(anchor_list=[feature_anchor]) feature_query = FeatureQuery( - feature_list=['f_snowflake_call_center_division_name_with_preprocessing', 'f_snowflake_call_center_zipcode_with_preprocessing'], - key=call_sk_id) - - observation_path = client.get_snowflake_path(database="SNOWFLAKE_SAMPLE_DATA", schema="TPCDS_SF10TCL", dbtable="CALL_CENTER") - settings = ObservationSettings( - observation_path=observation_path) + feature_list=[ + "f_snowflake_call_center_division_name_with_preprocessing", + "f_snowflake_call_center_zipcode_with_preprocessing", + ], + key=call_sk_id, + ) + + observation_path = client.get_snowflake_path( + database="SNOWFLAKE_SAMPLE_DATA", schema="TPCDS_SF10TCL", dbtable="CALL_CENTER" + ) + settings = ObservationSettings(observation_path=observation_path) now = datetime.now() # set output folder based on different runtime - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob_snowflake', '_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob_snowflake", "_", str(now.minute), "_", str(now.second), ".avro"] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/snowflake_output','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/snowflake_output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + 
client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) diff --git a/feathr_project/test/test_registry_client.py b/feathr_project/test/test_registry_client.py index 784fbbc29..87cca9657 100644 --- a/feathr_project/test/test_registry_client.py +++ b/feathr_project/test/test_registry_client.py @@ -11,11 +11,18 @@ from feathr.definition.transformation import ExpressionTransformation, WindowAggTransformation from feathr.definition.typed_key import TypedKey -from feathr.registry._feathr_registry_client import _FeatureRegistry, dict_to_source, dict_to_anchor, dict_to_feature, dict_to_derived_feature, dict_to_project +from feathr.registry._feathr_registry_client import ( + _FeatureRegistry, + dict_to_source, + dict_to_anchor, + dict_to_feature, + dict_to_derived_feature, + dict_to_project, +) def test_parse_source(): - s = r'''{ + s = r"""{ "attributes": { "name": "nycTaxiBatchSource", "path": "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", @@ -34,17 +41,19 @@ def test_parse_source(): "qualifiedName": "feathr_getting_started__nycTaxiBatchSource", "status": "Active", "typeName": "feathr_source_v1" - }''' + }""" source = dict_to_source(json.loads(s)) assert isinstance(source, HdfsSource) assert source.name == "nycTaxiBatchSource" - assert source.path == "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv" + assert ( + source.path == "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv" + ) assert source._registry_id == UUID("c4a0ae0f-09cc-43bf-94e9-21ff178fbda6") assert source._qualified_name == "feathr_getting_started__nycTaxiBatchSource" def test_parse_anchor(): - s = r'''{ + s = r"""{ "attributes": { "features": [ { @@ -83,7 +92,7 @@ def test_parse_anchor(): "qualifiedName": "feathr_getting_started__request_features", "status": "Active", "typeName": "feathr_anchor_v1" - }''' + }""" anchor = dict_to_anchor(json.loads(s)) # Parsed anchor is empty, features and source are added later assert anchor.name == "request_features" @@ -92,7 +101,7 @@ def test_parse_anchor(): def test_parse_feature(): - s = r'''{ + s = r"""{ "attributes": { "key": [ { @@ -124,7 +133,7 @@ def test_parse_feature(): "qualifiedName": "feathr_getting_started__request_features__f_is_long_trip_distance", "status": "Active", "typeName": "feathr_anchor_feature_v1" - }''' + }""" f: Feature = dict_to_feature(json.loads(s)) assert f.feature_type == BOOLEAN assert f.name == "f_is_long_trip_distance" @@ -139,7 +148,7 @@ def test_parse_feature(): def test_parse_derived_feature(): - s = r'''{ + s = r"""{ "attributes": { "inputAnchorFeatures": [], "inputDerivedFeatures": [ @@ -181,7 +190,7 @@ def test_parse_derived_feature(): "qualifiedName": "feathr_getting_started__f_trip_time_rounded_plus", "status": "Active", "typeName": "feathr_derived_feature_v1" - }''' + }""" df = dict_to_derived_feature(json.loads(s)) assert df.name == "f_trip_time_rounded_plus" assert df.feature_type == INT32 @@ -191,110 +200,121 @@ def test_parse_derived_feature(): assert df._qualified_name == "feathr_getting_started__f_trip_time_rounded_plus" assert df._registry_id == UUID("479c6306-5fdb-4e06-9008-c18f68db52a4") + def test_parse_project(): filename = os.path.join(os.path.dirname(__file__), "test_registry_lineage.json") - 
f=open(filename, "r") + f = open(filename, "r") (anchors, derived_features) = dict_to_project(json.load(f)) - assert len(anchors)==2 - request_features = [a for a in anchors if a.name=='request_features'][0] + assert len(anchors) == 2 + request_features = [a for a in anchors if a.name == "request_features"][0] assert isinstance(request_features.source, InputContext) - assert len(request_features.features)==4 - aggregationFeatures = [a for a in anchors if a.name=='aggregationFeatures'][0] - assert len(aggregationFeatures.features)==2 + assert len(request_features.features) == 4 + aggregationFeatures = [a for a in anchors if a.name == "aggregationFeatures"][0] + assert len(aggregationFeatures.features) == 2 assert isinstance(aggregationFeatures.source, HdfsSource) - assert aggregationFeatures.source.name=="nycTaxiBatchSource" - assert aggregationFeatures.source.path=="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv" - assert len(derived_features)==3 + assert aggregationFeatures.source.name == "nycTaxiBatchSource" + assert ( + aggregationFeatures.source.path + == "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv" + ) + assert len(derived_features) == 3 + def test_registry_client_list_features(): c = _FeatureRegistry(project_name="p", endpoint="https://feathr-sql-registry.azurewebsites.net/api/v1") f = [e["qualifiedName"] for e in c.list_registered_features("feathr_ci_registry_getting_started")] print(f) - assert len(f)==9 + assert len(f) == 9 for i in f: assert i.startswith("feathr_ci_registry_getting_started__") - + + def test_registry_client_load(): c = _FeatureRegistry(project_name="p", endpoint="https://feathr-sql-registry.azurewebsites.net/api/v1") (anchors, derived_features) = c.get_features_from_registry("feathr_ci_registry_getting_started") - assert len(anchors)==2 - request_features = [a for a in anchors if a.name=='request_features'][0] + assert len(anchors) == 2 + request_features = [a for a in anchors if a.name == "request_features"][0] assert isinstance(request_features.source, InputContext) - assert len(request_features.features)==4 - aggregationFeatures = [a for a in anchors if a.name=='aggregationFeatures'][0] - assert len(aggregationFeatures.features)==3 + assert len(request_features.features) == 4 + aggregationFeatures = [a for a in anchors if a.name == "aggregationFeatures"][0] + assert len(aggregationFeatures.features) == 3 assert isinstance(aggregationFeatures.source, HdfsSource) - assert aggregationFeatures.source.name=="nycTaxiBatchSource" - assert aggregationFeatures.source.path=="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv" - assert len(derived_features)==2 + assert aggregationFeatures.source.name == "nycTaxiBatchSource" + assert ( + aggregationFeatures.source.path + == "wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv" + ) + assert len(derived_features) == 2 + def test_create(): project_name = f"feathr_registry_client_test_{int(time.time())}" c = _FeatureRegistry(project_name="p", endpoint="https://feathr-sql-registry.azurewebsites.net/api/v1") - - batch_source = HdfsSource(name="nycTaxiBatchSource", - path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - f_trip_distance = Feature(name="f_trip_distance", - 
feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + batch_source = HdfsSource( + name="nycTaxiBatchSource", + path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) + + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d"), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), + ] - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d")), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - 
features=agg_features) - - c.register_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded]) + c.register_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) - if __name__ == "__main__": test_parse_source() diff --git a/feathr_project/test/test_secrets_read.py b/feathr_project/test/test_secrets_read.py index 2e5916825..87a68e8d0 100644 --- a/feathr_project/test/test_secrets_read.py +++ b/feathr_project/test/test_secrets_read.py @@ -4,21 +4,30 @@ from unittest import result from click.testing import CliRunner -from feathr import (BOOLEAN, FLOAT, INT32, FeatureQuery, ObservationSettings, - SparkExecutionConfiguration, TypedKey, ValueType) +from feathr import ( + BOOLEAN, + FLOAT, + INT32, + FeatureQuery, + ObservationSettings, + SparkExecutionConfiguration, + TypedKey, + ValueType, +) from feathr.client import FeathrClient from feathr.utils.job_utils import get_result_df from test_fixture import basic_test_setup from feathr.constants import OUTPUT_FORMAT + # test parquet file read/write without an extension name def test_feathr_get_secrets_from_key_vault(): """ Test if the program can read the key vault secrets as expected """ # TODO: need to test get_environment_variable() as well - os.environ['SECRETS__AZURE_KEY_VAULT__NAME'] = 'feathrazuretest3-kv' + os.environ["SECRETS__AZURE_KEY_VAULT__NAME"] = "feathrazuretest3-kv" # the config below doesn't have `ONLINE_STORE__REDIS__HOST` for testing purpose yaml_config = """ @@ -65,4 +74,3 @@ def test_feathr_get_secrets_from_key_vault(): client = FeathrClient(config_path="/tmp/feathr_config.yaml") # `redis_host` should be there since it's not available in the environment variable, and not in the config file, we expect we get it from azure key_vault assert client.redis_host is not None - diff --git a/feathr_project/test/test_spark_sql_source.py b/feathr_project/test/test_spark_sql_source.py index 194e728f1..74ea8f3be 100644 --- a/feathr_project/test/test_spark_sql_source.py +++ b/feathr_project/test/test_spark_sql_source.py @@ -1,9 +1,18 @@ import os from datetime import datetime, timedelta from pathlib import Path -from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32, - DerivedFeature, Feature, FeatureAnchor, - TypedKey, ValueType, WindowAggTransformation) +from feathr import ( + BOOLEAN, + FLOAT, + INPUT_CONTEXT, + INT32, + DerivedFeature, + Feature, + FeatureAnchor, + TypedKey, + ValueType, + WindowAggTransformation, +) import pytest from feathr import FeathrClient @@ -19,8 +28,7 @@ def test_feathr_spark_sql_query_source(): - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" config_path = os.path.join(test_workspace_dir, "feathr_config.yaml") _get_offline_features(config_path, _sql_query_source()) @@ -30,26 +38,37 @@ def test_feathr_spark_sql_query_source(): def _get_offline_features(config_path: str, sql_source: SparkSqlSource): client: FeathrClient = _spark_sql_test_setup(config_path, sql_source) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) 
+ + feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -58,101 +77,119 @@ def _get_offline_features(config_path: str, sql_source: SparkSqlSource): def _materialize_to_offline(config_path: str, sql_source: SparkSqlSource): client: FeathrClient = _spark_sql_test_setup(config_path, sql_source) - backfill_time = BackfillTime(start=datetime( - 2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) + backfill_time = BackfillTime(start=datetime(2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1)) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_sparksql', - '_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_sparksql", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_sparksql', - '_', str(now.minute), '_', str(now.second), ""]) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_sparksql", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) offline_sink = HdfsSink(output_path=output_path) - settings = MaterializationSettings("nycTaxiTable", - sinks=[offline_sink], - feature_names=[ - "f_location_avg_fare", "f_location_max_fare"], - backfill_time=backfill_time) + settings = MaterializationSettings( + "nycTaxiTable", + sinks=[offline_sink], + feature_names=["f_location_avg_fare", "f_location_max_fare"], + backfill_time=backfill_time, + ) client.materialize_features(settings) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) # download result and just assert the returned result is not empty # by default, it will write to a folder appended with date - res_df = 
get_result_df( - client, "avro", output_path + "/df0/daily/2020/05/20") + res_df = get_result_df(client, "avro", output_path + "/df0/daily/2020/05/20") assert res_df.shape[0] > 0 def _spark_sql_test_setup(config_path: str, sql_source: SparkSqlSource): client = FeathrClient(config_path=config_path) - f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), + ] + + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) + + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) + + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation( + agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d", filter="fare_amount > 0" + ), + ), + Feature( + name="f_location_max_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="MAX", window="90d"), + ), ] - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) - - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") - - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d", - filter="fare_amount > 0" - )), - Feature(name="f_location_max_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="MAX", - window="90d")) - ] - - agg_anchor = 
FeatureAnchor(name="aggregationFeatures", - source=sql_source, - features=agg_features) - - client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ - f_trip_time_distance, f_trip_time_rounded]) + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=sql_source, features=agg_features) + + client.build_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) return client def _sql_query_source(): - return SparkSqlSource(name="sparkSqlQuerySource", sql="SELECT * FROM green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss") + return SparkSqlSource( + name="sparkSqlQuerySource", + sql="SELECT * FROM green_tripdata_2020_04_with_index", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) def _sql_table_source(): - return SparkSqlSource(name="sparkSqlTableSource", table="green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss") + return SparkSqlSource( + name="sparkSqlTableSource", + table="green_tripdata_2020_04_with_index", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) diff --git a/feathr_project/test/test_sql_source.py b/feathr_project/test/test_sql_source.py index b584f2667..614482c6c 100644 --- a/feathr_project/test/test_sql_source.py +++ b/feathr_project/test/test_sql_source.py @@ -15,6 +15,7 @@ from test_fixture import get_online_test_table_name from test_utils.constants import Constants + def basic_test_setup(config_path: str): """ Basically this is same as the one in `text_fixture.py` with the same name. @@ -24,70 +25,71 @@ def basic_test_setup(config_path: str): client = FeathrClient(config_path=config_path) # Using database under @windoze account, so this e2e test still doesn't work in CI - batch_source = JdbcSource(name="nycTaxiBatchJdbcSource", - url="jdbc:sqlserver://feathrtestsql4.database.windows.net:1433;database=testsql;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", - dbtable="green_tripdata_2020_04", - auth="USERPASS", - event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") - + batch_source = JdbcSource( + name="nycTaxiBatchJdbcSource", + url="jdbc:sqlserver://feathrtestsql4.database.windows.net:1433;database=testsql;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", + dbtable="green_tripdata_2020_04", + auth="USERPASS", + event_timestamp_column="lpep_dropoff_datetime", + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) - f_trip_distance = Feature(name="f_trip_distance", - feature_type=FLOAT, transform="trip_distance") - f_trip_time_duration = Feature(name="f_trip_time_duration", - feature_type=INT32, - transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60") + f_trip_distance = Feature(name="f_trip_distance", feature_type=FLOAT, transform="trip_distance") + f_trip_time_duration = Feature( + name="f_trip_time_duration", + feature_type=INT32, + transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60", + ) features = [ f_trip_distance, f_trip_time_duration, - Feature(name="f_is_long_trip_distance", - feature_type=BOOLEAN, - transform="cast_float(trip_distance)>30"), - Feature(name="f_day_of_week", - feature_type=INT32, - 
transform="dayofweek(lpep_dropoff_datetime)"), + Feature(name="f_is_long_trip_distance", feature_type=BOOLEAN, transform="cast_float(trip_distance)>30"), + Feature(name="f_day_of_week", feature_type=INT32, transform="dayofweek(lpep_dropoff_datetime)"), ] + request_anchor = FeatureAnchor(name="request_features", source=INPUT_CONTEXT, features=features) + + f_trip_time_distance = DerivedFeature( + name="f_trip_time_distance", + feature_type=FLOAT, + input_features=[f_trip_distance, f_trip_time_duration], + transform="f_trip_distance * f_trip_time_duration", + ) + + f_trip_time_rounded = DerivedFeature( + name="f_trip_time_rounded", + feature_type=INT32, + input_features=[f_trip_time_duration], + transform="f_trip_time_duration % 10", + ) + + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - request_anchor = FeatureAnchor(name="request_features", - source=INPUT_CONTEXT, - features=features) - - f_trip_time_distance = DerivedFeature(name="f_trip_time_distance", - feature_type=FLOAT, - input_features=[ - f_trip_distance, f_trip_time_duration], - transform="f_trip_distance * f_trip_time_duration") - - f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded", - feature_type=INT32, - input_features=[f_trip_time_duration], - transform="f_trip_time_duration % 10") - - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") - # This feature is read from Jdbc data source - agg_features = [Feature(name="f_location_avg_fare", - key=location_id, - feature_type=FLOAT, - transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", - agg_func="AVG", - window="90d")), - ] + agg_features = [ + Feature( + name="f_location_avg_fare", + key=location_id, + feature_type=FLOAT, + transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)", agg_func="AVG", window="90d"), + ), + ] - agg_anchor = FeatureAnchor(name="aggregationFeatures", - source=batch_source, - features=agg_features) + agg_anchor = FeatureAnchor(name="aggregationFeatures", source=batch_source, features=agg_features) - client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[ - f_trip_time_distance, f_trip_time_rounded]) + client.build_features( + anchor_list=[agg_anchor, request_anchor], derived_feature_list=[f_trip_time_distance, f_trip_time_rounded] + ) return client + @pytest.mark.skip(reason="Requires database with test data imported, which doesn't exist in the current CI env") def test_feathr_get_offline_features(): """ @@ -100,30 +102,29 @@ def test_feathr_get_offline_features(): These 2 variables will be passed to the Spark job in `--system-properties` parameter so Spark can access the database """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" # os.chdir(test_workspace_dir) client = basic_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml")) - location_id = TypedKey(key_column="DOLocationID", - key_column_type=ValueType.INT32, - description="location id in NYC", - full_name="nyc_taxi.location_id") + location_id = TypedKey( + key_column="DOLocationID", + key_column_type=ValueType.INT32, + description="location id in NYC", + full_name="nyc_taxi.location_id", + ) - feature_query = FeatureQuery( - feature_list=["f_location_avg_fare"], key=location_id) + 
feature_query = FeatureQuery(feature_list=["f_location_avg_fare"], key=location_id) settings = ObservationSettings( observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv", event_timestamp_column="lpep_dropoff_datetime", - timestamp_format="yyyy-MM-dd HH:mm:ss") + timestamp_format="yyyy-MM-dd HH:mm:ss", + ) now = datetime.now() - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) # assuming the job can successfully run; otherwise it will throw exception client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) @@ -132,6 +133,6 @@ def test_feathr_get_offline_features(): res_df = get_result_df(client) assert res_df.shape[0] > 0 - + if __name__ == "__main__": - test_feathr_get_offline_features() \ No newline at end of file + test_feathr_get_offline_features() diff --git a/feathr_project/test/test_time_partition_pattern_e2e.py b/feathr_project/test/test_time_partition_pattern_e2e.py index b6ada4fbd..14e1c835e 100644 --- a/feathr_project/test/test_time_partition_pattern_e2e.py +++ b/feathr_project/test/test_time_partition_pattern_e2e.py @@ -2,13 +2,18 @@ from datetime import datetime, timedelta from pathlib import Path from feathr import FeathrClient -from feathr import (BackfillTime, MaterializationSettings) +from feathr import BackfillTime, MaterializationSettings from feathr import FeathrClient from feathr import HdfsSink from feathr.utils.job_utils import get_result_df, copy_cloud_dir, cloud_dir_exists -from test_fixture import (basic_test_setup, time_partition_pattern_feature_gen_test_setup, time_partition_pattern_feature_join_test_setup) +from test_fixture import ( + basic_test_setup, + time_partition_pattern_feature_gen_test_setup, + time_partition_pattern_feature_join_test_setup, +) from test_utils.constants import Constants + ''' def setup_module(): """ @@ -60,135 +65,192 @@ def setup_module(): res_df_hourly = get_result_df(client_producer, data_format="avro", res_url=output_hourly_path) assert res_df_hourly.shape[0] > 0 ''' + + def test_feathr_materialize_with_time_partition_pattern(): """ Test FeathrClient() using HdfsSource with 'timePartitionPattern'. 
""" - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client_dummy = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - if client_dummy.spark_runtime == 'databricks': - source_path = 'dbfs:/timePartitionPattern_test/df0/daily/' + if client_dummy.spark_runtime == "databricks": + source_path = "dbfs:/timePartitionPattern_test/df0/daily/" else: - source_path = 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_test/df0/daily/' - - client: FeathrClient = time_partition_pattern_feature_gen_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"), source_path,local_workspace_dir="test_materialize_tpp") - - backfill_time_tpp = BackfillTime(start=datetime( - 2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) + source_path = "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_test/df0/daily/" + + client: FeathrClient = time_partition_pattern_feature_gen_test_setup( + os.path.join(test_workspace_dir, "feathr_config.yaml"), source_path, local_workspace_dir="test_materialize_tpp" + ) + + backfill_time_tpp = BackfillTime(start=datetime(2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path_tpp = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path_tpp = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path_tpp = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + output_path_tpp = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) offline_sink_tpp = HdfsSink(output_path=output_path_tpp) - settings_tpp = MaterializationSettings("nycTaxiTable", - sinks=[offline_sink_tpp], - feature_names=[ - "f_loc_avg_output", "f_loc_max_output"], - backfill_time=backfill_time_tpp) + settings_tpp = MaterializationSettings( + "nycTaxiTable", + sinks=[offline_sink_tpp], + feature_names=["f_loc_avg_output", "f_loc_max_output"], + backfill_time=backfill_time_tpp, + ) client.materialize_features(settings_tpp) client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - + res_df = get_result_df(client, data_format="avro", res_url=output_path_tpp + "/df0/daily/2020/05/02") assert res_df.shape[0] > 0 - + + def test_feathr_materialize_with_time_partition_pattern_postfix_path(): """ Test FeathrClient() using HdfsSource with 'timePartitionPattern' and 'postfixPath'. 
- """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + """ + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client_dummy = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - if client_dummy.spark_runtime == 'databricks': - source_path = 'dbfs:/timePartitionPattern_postfix_test/df0/daily/' + if client_dummy.spark_runtime == "databricks": + source_path = "dbfs:/timePartitionPattern_postfix_test/df0/daily/" else: - source_path = 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_postfix_test/df0/daily/' - - client: FeathrClient = time_partition_pattern_feature_gen_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"), source_path,local_workspace_dir="test_materialize_tpp_postfix", postfix_path='postfixPath') + source_path = "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_postfix_test/df0/daily/" - backfill_time_pf = BackfillTime(start=datetime( - 2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) + client: FeathrClient = time_partition_pattern_feature_gen_test_setup( + os.path.join(test_workspace_dir, "feathr_config.yaml"), + source_path, + local_workspace_dir="test_materialize_tpp_postfix", + postfix_path="postfixPath", + ) + + backfill_time_pf = BackfillTime(start=datetime(2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path_pf = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path_pf = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path_pf = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + output_path_pf = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) offline_sink_pf = HdfsSink(output_path=output_path_pf) - settings_pf = MaterializationSettings("nycTaxiTable", - sinks=[offline_sink_pf], - feature_names=[ - "f_loc_avg_output", "f_loc_max_output"], - backfill_time=backfill_time_pf) + settings_pf = MaterializationSettings( + "nycTaxiTable", + sinks=[offline_sink_pf], + feature_names=["f_loc_avg_output", "f_loc_max_output"], + backfill_time=backfill_time_pf, + ) client.materialize_features(settings_pf) client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) res_df = get_result_df(client, data_format="avro", res_url=output_path_pf + "/df0/daily/2020/05/02") assert res_df.shape[0] > 0 - + + def test_feathr_materialize_with_time_partition_pattern_hourly(): """ Test FeathrClient() using HdfsSource with hourly 'timePartitionPattern'. 
- """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + """ + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client_dummy = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - if client_dummy.spark_runtime == 'databricks': - source_path = 'dbfs:/timePartitionPattern_hourly_test/df0/daily/' + if client_dummy.spark_runtime == "databricks": + source_path = "dbfs:/timePartitionPattern_hourly_test/df0/daily/" else: - source_path = 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_hourly_test/df0/daily/' - - client: FeathrClient = time_partition_pattern_feature_gen_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"), source_path,local_workspace_dir="test_materialize_hourly", resolution='HOURLY') + source_path = "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_hourly_test/df0/daily/" + + client: FeathrClient = time_partition_pattern_feature_gen_test_setup( + os.path.join(test_workspace_dir, "feathr_config.yaml"), + source_path, + local_workspace_dir="test_materialize_hourly", + resolution="HOURLY", + ) - backfill_time_tpp = BackfillTime(start=datetime( - 2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) + backfill_time_tpp = BackfillTime(start=datetime(2020, 5, 2), end=datetime(2020, 5, 2), step=timedelta(days=1)) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path_tpp = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + if client.spark_runtime == "databricks": + output_path_tpp = "".join( + ["dbfs:/feathrazure_cijob_materialize_offline_", "_", str(now.minute), "_", str(now.second), ""] + ) else: - output_path_tpp = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_','_', str(now.minute), '_', str(now.second), ""]) + output_path_tpp = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_", + "_", + str(now.minute), + "_", + str(now.second), + "", + ] + ) offline_sink_tpp = HdfsSink(output_path=output_path_tpp) - settings_tpp = MaterializationSettings("nycTaxiTable", - sinks=[offline_sink_tpp], - feature_names=[ - "f_loc_avg_output", "f_loc_max_output"], - backfill_time=backfill_time_tpp, - resolution = 'HOURLY') + settings_tpp = MaterializationSettings( + "nycTaxiTable", + sinks=[offline_sink_tpp], + feature_names=["f_loc_avg_output", "f_loc_max_output"], + backfill_time=backfill_time_tpp, + resolution="HOURLY", + ) client.materialize_features(settings_tpp) client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) res_df = get_result_df(client, data_format="avro", res_url=output_path_tpp + "/df0/daily/2020/05/02/00") assert res_df.shape[0] > 0 + def test_feathr_get_offline_with_time_partition_pattern_postfix_path(): """ Test FeathrClient() using HdfsSource with 'timePartitionPattern' and 'postfixPath'. 
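The main extra wrinkle in the hourly variant above is the additional hour segment in the partition that the assertion reads back (.../2020/05/02/00 instead of .../2020/05/02). Here is a small sketch of deriving that read-back URL from the backfill date and resolution; it mirrors the folder layout the assertions expect, not any Feathr API.

from datetime import datetime


def materialized_partition_url(output_path: str, day: datetime, resolution: str = "DAILY") -> str:
    """Build the partition folder that get_result_df reads back after materialization."""
    pattern = "%Y/%m/%d/%H" if resolution == "HOURLY" else "%Y/%m/%d"
    return f"{output_path}/df0/daily/{day.strftime(pattern)}"


# materialized_partition_url(output_path_tpp, datetime(2020, 5, 2))            -> .../df0/daily/2020/05/02
# materialized_partition_url(output_path_tpp, datetime(2020, 5, 2), "HOURLY")  -> .../df0/daily/2020/05/02/00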
- """ - test_workspace_dir = Path( - __file__).parent.resolve() / "test_user_workspace" - + """ + test_workspace_dir = Path(__file__).parent.resolve() / "test_user_workspace" + client_dummy = FeathrClient(os.path.join(test_workspace_dir, "feathr_config.yaml")) - if client_dummy.spark_runtime == 'databricks': - source_path = 'dbfs:/timePartitionPattern_postfix_test/df0/daily/' + if client_dummy.spark_runtime == "databricks": + source_path = "dbfs:/timePartitionPattern_postfix_test/df0/daily/" else: - source_path = 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_postfix_test/df0/daily/' - - [client, feature_query, settings] = time_partition_pattern_feature_join_test_setup(os.path.join(test_workspace_dir, "feathr_config.yaml"), source_path, local_workspace_dir="test_offline_tpp",postfix_path='postfixPath') + source_path = "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/timePartitionPattern_postfix_test/df0/daily/" + + [client, feature_query, settings] = time_partition_pattern_feature_join_test_setup( + os.path.join(test_workspace_dir, "feathr_config.yaml"), + source_path, + local_workspace_dir="test_offline_tpp", + postfix_path="postfixPath", + ) now = datetime.now() - if client.spark_runtime == 'databricks': - output_path = ''.join(['dbfs:/feathrazure_cijob','_', str(now.minute), '_', str(now.second), ".avro"]) + if client.spark_runtime == "databricks": + output_path = "".join(["dbfs:/feathrazure_cijob", "_", str(now.minute), "_", str(now.second), ".avro"]) else: - output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output','_', str(now.minute), '_', str(now.second), ".avro"]) - - client.get_offline_features(observation_settings=settings, - feature_query=feature_query, - output_path=output_path) + output_path = "".join( + [ + "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/output", + "_", + str(now.minute), + "_", + str(now.second), + ".avro", + ] + ) + + client.get_offline_features(observation_settings=settings, feature_query=feature_query, output_path=output_path) client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS) - res_df = get_result_df(client, data_format="avro", res_url = output_path) + res_df = get_result_df(client, data_format="avro", res_url=output_path) assert res_df.shape[0] > 0 - diff --git a/feathr_project/test/test_utils/data_generator.py b/feathr_project/test/test_utils/data_generator.py index f8069e151..b884e6de8 100644 --- a/feathr_project/test/test_utils/data_generator.py +++ b/feathr_project/test/test_utils/data_generator.py @@ -28,7 +28,7 @@ # Generate fake products and add them to the list for i in range(NUM_PRODUCTS): - product_id = i+10000 + product_id = i + 10000 product_name = fake.ecommerce_name() category = fake.ecommerce_category() price = round(random.uniform(1, 1000), 2) @@ -38,14 +38,14 @@ discount = round(random.uniform(0, 0.5), 2) product = { - 'product_id': product_id, - 'product_name': product_name, - 'category': category, - 'price': price, - 'quantity': quantity, - 'recent_sold': recent_sold, - 'made_in_state': made_in_state, - 'discount': discount + "product_id": product_id, + "product_name": product_name, + "category": category, + "price": price, + "quantity": quantity, + "recent_sold": recent_sold, + "made_in_state": made_in_state, + "discount": discount, } products.append(product) @@ -59,37 +59,39 @@ # Generate fake data using the profile API and add it to the DataFrame 
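The output paths in the offline-join test above, like those in the materialization tests earlier, are keyed only on the current minute and second, so two jobs started close together can collide. A hedged sketch of a more collision-resistant naming scheme, purely illustrative:

from datetime import datetime, timezone
from uuid import uuid4


def unique_output_path(prefix: str, suffix: str = "") -> str:
    """Build a job output path that cannot collide across runs started in the same minute and second."""
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    return f"{prefix}_{stamp}_{uuid4().hex[:8]}{suffix}"


# e.g. unique_output_path("dbfs:/feathrazure_cijob", ".avro")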
for i in range(NUM_CUSTOMERS): profile = fake.profile() - name = profile['name'] - username = profile['username'] - email = profile['mail'] - job_title = profile['job'] # fix here - company = profile['company'] - address = profile['address'] + name = profile["name"] + username = profile["username"] + email = profile["mail"] + job_title = profile["job"] # fix here + company = profile["company"] + address = profile["address"] phone = fake.phone_number() user_id = i - gender = profile['sex'] + gender = profile["sex"] age = random.randint(18, 80) gift_card_balance = round(random.uniform(0, 1000), 2) number_of_credit_cards = random.randint(1, 5) state = fake.state_abbr() tax_rate = round(random.uniform(0.05, 0.15), 3) - customers.append({ - 'name': name, - 'username': username, - 'email': email, - 'job_title': job_title, - 'company': company, - 'address': address, - 'phone': phone, - 'user_id': user_id, - 'gender': gender, - 'age': age, - 'gift_card_balance': gift_card_balance, - 'number_of_credit_cards': number_of_credit_cards, - 'state': state, - 'tax_rate': tax_rate - }) + customers.append( + { + "name": name, + "username": username, + "email": email, + "job_title": job_title, + "company": company, + "address": address, + "phone": phone, + "user_id": user_id, + "gender": gender, + "age": age, + "gift_card_balance": gift_card_balance, + "number_of_credit_cards": number_of_credit_cards, + "state": state, + "tax_rate": tax_rate, + } + ) # Create a Pandas DataFrame from the list of purchases customers_df = pd.DataFrame(customers) @@ -100,36 +102,36 @@ # Generate fake purchase data and add it to the list for i in range(NUM_PURCHASES): - user_id = random.choice(customers)['user_id'] - purchase_date = fake.date_between(start_date='-1y', end_date='today') + user_id = random.choice(customers)["user_id"] + purchase_date = fake.date_between(start_date="-1y", end_date="today") purchase_amount = round(random.uniform(10, 500), 2) - product_id = random.choice(products)['product_id'] + product_id = random.choice(products)["product_id"] transaction_id = fake.uuid4() price = round(random.uniform(1, 100), 2) discounts = round(random.uniform(0, 20), 2) taxes_and_fees = round(random.uniform(0, 10), 2) total_cost = (price * quantity) - discounts + taxes_and_fees - payment_method = random.choice(['Credit Card', 'PayPal', 'Apple Pay', 'Google Wallet']) + payment_method = random.choice(["Credit Card", "PayPal", "Apple Pay", "Google Wallet"]) shipping_address = fake.address() - status = random.choice(['Pending', 'Complete', 'Refunded']) + status = random.choice(["Pending", "Complete", "Refunded"]) notes = fake.sentence() purchase_quantity = random.randint(1, 100) purchase = { - 'user_id': user_id, - 'purchase_date': purchase_date, - 'purchase_amount': purchase_amount, - 'product_id': product_id, - 'purchase_quantity': purchase_quantity, - 'transaction_id': transaction_id, - 'price': price, - 'discounts': discounts, - 'taxes_and_fees': taxes_and_fees, - 'total_cost': total_cost, - 'payment_method': payment_method, - 'shipping_address': shipping_address, - 'status': status, - 'notes': notes, + "user_id": user_id, + "purchase_date": purchase_date, + "purchase_amount": purchase_amount, + "product_id": product_id, + "purchase_quantity": purchase_quantity, + "transaction_id": transaction_id, + "price": price, + "discounts": discounts, + "taxes_and_fees": taxes_and_fees, + "total_cost": total_cost, + "payment_method": payment_method, + "shipping_address": shipping_address, + "status": status, + "notes": notes, } 
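One detail worth flagging in the purchase loop above: total_cost multiplies price by quantity, but at that point quantity still holds whatever value the earlier product loop left behind, while the per-purchase amount generated here is purchase_quantity. Assuming purchase_quantity is what was intended, the computation would be:

def purchase_total_cost(price: float, purchase_quantity: int, discounts: float, taxes_and_fees: float) -> float:
    """Quantity-scaled price for one purchase record, minus discounts, plus taxes and fees."""
    return price * purchase_quantity - discounts + taxes_and_fees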
purchases.append(purchase) @@ -143,17 +145,12 @@ # Generate fake user observation data and add it to the list for i in range(NUM_OBSERVATION): - user_id = random.choice(customers)['user_id'] - purchase_date = fake.date_between(start_date='-1y', end_date='today') - product_id = random.choice(products)['product_id'] + user_id = random.choice(customers)["user_id"] + purchase_date = fake.date_between(start_date="-1y", end_date="today") + product_id = random.choice(products)["product_id"] browser = fake.user_agent() - observation = { - 'user_id': user_id, - 'purchase_date': purchase_date, - 'product_id': product_id, - 'browser': browser - } + observation = {"user_id": user_id, "purchase_date": purchase_date, "product_id": product_id, "browser": browser} observations.append(observation) @@ -168,13 +165,13 @@ print(observation_df) # Save the products DataFrame to a CSV file -products_df.to_csv('product_detail_mock_data.csv', index=False) +products_df.to_csv("product_detail_mock_data.csv", index=False) # Save the purchases DataFrame to a CSV file -purchase_df.to_csv('user_purchase_history_mock_data.csv', index=False) +purchase_df.to_csv("user_purchase_history_mock_data.csv", index=False) # Save the purchases DataFrame to a CSV file -customers_df.to_csv('user_profile_mock_data.csv', index=False) +customers_df.to_csv("user_profile_mock_data.csv", index=False) # Save the observation_df DataFrame to a CSV file -observation_df.to_csv('user_observation_mock_data.csv', index=False) +observation_df.to_csv("user_observation_mock_data.csv", index=False) diff --git a/feathr_project/test/test_utils/query_sql.py b/feathr_project/test/test_utils/query_sql.py index 68412ebff..eaacdc1d0 100644 --- a/feathr_project/test/test_utils/query_sql.py +++ b/feathr_project/test/test_utils/query_sql.py @@ -3,6 +3,7 @@ # script to query SQL database for debugging purpose + def show_table(cursor, table_name): cursor.execute("select * from " + table_name + ";") print(cursor.fetchall()) @@ -23,7 +24,7 @@ def show_table(cursor, table_name): dbname = "postgres" user = "demo" env_config = EnvConfigReader(config_path=None) -password = env_config.get_from_env_or_akv('SQL_TEST_PASSWORD') +password = env_config.get_from_env_or_akv("SQL_TEST_PASSWORD") sslmode = "require" # Construct connection string diff --git a/feathr_project/test/test_utils/udfs.py b/feathr_project/test/test_utils/udfs.py index d3a8f0e6f..f515c98bb 100644 --- a/feathr_project/test/test_utils/udfs.py +++ b/feathr_project/test/test_utils/udfs.py @@ -1,12 +1,14 @@ from pyspark.sql import DataFrame from pyspark.sql.functions import col + def add_new_dropoff_and_fare_amount_column(df: DataFrame): df = df.withColumn("new_lpep_dropoff_datetime", col("lpep_dropoff_datetime")) df = df.withColumn("new_fare_amount", col("fare_amount") + 1000000) return df + def add_new_fare_amount(df: DataFrame) -> DataFrame: df = df.withColumn("fare_amount_new", col("fare_amount") + 8000000) - return df \ No newline at end of file + return df diff --git a/feathr_project/test/unit/datasets/test_dataset_utils.py b/feathr_project/test/unit/datasets/test_dataset_utils.py index 2aabaa9a1..f3e60240d 100644 --- a/feathr_project/test/unit/datasets/test_dataset_utils.py +++ b/feathr_project/test/unit/datasets/test_dataset_utils.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize( # 3924447 is the nyc_taxi sample data's bytes - "expected_bytes", [3924447, None] + "expected_bytes", + [3924447, None], ) def test__maybe_download(expected_bytes: int): """Test maybe_download utility function w/ nyc_taxi data 
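The preprocessing UDFs above are plain DataFrame-in, DataFrame-out functions, so they can be smoke-tested with a local SparkSession without submitting a Feathr job. A minimal sketch follows; the sample rows and column values are made up.

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col


def add_new_fare_amount(df: DataFrame) -> DataFrame:
    # Same shape as the preprocessing UDFs above: derive a new column from an existing one.
    return df.withColumn("fare_amount_new", col("fare_amount") + 8000000)


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").appName("udf_smoke_test").getOrCreate()
    df = spark.createDataFrame([(1, 12.5), (2, 30.0)], ["trip_id", "fare_amount"])
    add_new_fare_amount(df).show()  # adds fare_amount_new without touching the original columns
    spark.stop()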
cached at Azure blob.""" diff --git a/feathr_project/test/unit/datasets/test_datasets.py b/feathr_project/test/unit/datasets/test_datasets.py index 8d3bece28..7198f470a 100644 --- a/feathr_project/test/unit/datasets/test_datasets.py +++ b/feathr_project/test/unit/datasets/test_datasets.py @@ -45,7 +45,8 @@ def test__nyc_taxi__get_pandas_df( @pytest.mark.parametrize( - "local_cache_path", [ + "local_cache_path", + [ NYC_TAXI_FILE_PATH, # full filepath str(Path(NYC_TAXI_FILE_PATH).parent), # directory ], @@ -62,25 +63,36 @@ def test__nyc_taxi__get_spark_df( df = nyc_taxi.get_spark_df(spark=spark, local_cache_path=local_cache_path) assert df.count() == 35612 - mocked_maybe_download.assert_called_once_with( - src_url=nyc_taxi.NYC_TAXI_SMALL_URL, dst_filepath=NYC_TAXI_FILE_PATH - ) + mocked_maybe_download.assert_called_once_with(src_url=nyc_taxi.NYC_TAXI_SMALL_URL, dst_filepath=NYC_TAXI_FILE_PATH) @pytest.mark.parametrize( - "local_cache_path, expected_python_cache_path, expected_spark_cache_path", [ + "local_cache_path, expected_python_cache_path, expected_spark_cache_path", + [ # With file path ("test_dir/test.csv", "/dbfs/test_dir/test.csv", "dbfs:/test_dir/test.csv"), # With directory path - ("test_dir", "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv"), + ( + "test_dir", + "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", + "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv", + ), # With databricks python file path ("/dbfs/test_dir/test.csv", "/dbfs/test_dir/test.csv", "dbfs:/test_dir/test.csv"), # With databricks python directory path - ("/dbfs/test_dir", "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv"), + ( + "/dbfs/test_dir", + "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", + "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv", + ), # With databricks spark file path ("dbfs:/test_dir/test.csv", "/dbfs/test_dir/test.csv", "dbfs:/test_dir/test.csv"), # With databricks spark directory path - ("dbfs:/test_dir", "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv"), + ( + "dbfs:/test_dir", + "/dbfs/test_dir/green_tripdata_2020-04_with_index.csv", + "dbfs:/test_dir/green_tripdata_2020-04_with_index.csv", + ), ], ) def test__nyc_taxi__get_spark_df__with_databricks( diff --git a/feathr_project/test/unit/spark_provider/test_localspark_submission.py b/feathr_project/test/unit/spark_provider/test_localspark_submission.py index 992f2015e..eb9130280 100644 --- a/feathr_project/test/unit/spark_provider/test_localspark_submission.py +++ b/feathr_project/test/unit/spark_provider/test_localspark_submission.py @@ -17,10 +17,11 @@ def local_spark_job_launcher(tmp_path) -> _FeathrLocalSparkJobLauncher: @pytest.mark.parametrize( - "job_tags,expected_result_uri", [ + "job_tags,expected_result_uri", + [ (None, None), ({OUTPUT_PATH_TAG: "output"}, "output"), - ] + ], ) def test__local_spark_job_launcher__submit_feathr_job( mocker: MockerFixture, @@ -51,9 +52,7 @@ def test__local_spark_job_launcher__submit_feathr_job( assert local_spark_job_launcher.get_job_result_uri() == expected_result_uri -@pytest.mark.parametrize( - "confs", [{}, {"spark.feathr.outputFormat": "parquet"}] -) +@pytest.mark.parametrize("confs", [{}, {"spark.feathr.outputFormat": "parquet"}]) def test__local_spark_job_launcher__init_args( local_spark_job_launcher: _FeathrLocalSparkJobLauncher, confs: Dict[str, str], diff --git 
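The parametrized cases above pin down how a Databricks cache path is expected to resolve into its local-python (/dbfs/...) and Spark (dbfs:/...) forms. The sketch below shows that prefix normalization for illustration only; it is not the library's actual implementation, and it leaves out the step where a directory input gains the cached file name.

from typing import Tuple


def databricks_cache_paths(path: str) -> Tuple[str, str]:
    """Normalize a cache path into its local-python (/dbfs/...) and Spark (dbfs:/...) forms."""
    if path.startswith("dbfs:/"):
        rest = path[len("dbfs:/"):]
    elif path.startswith("/dbfs/"):
        rest = path[len("/dbfs/"):]
    else:
        rest = path.lstrip("/")
    return f"/dbfs/{rest}", f"dbfs:/{rest}"


# databricks_cache_paths("test_dir/test.csv")     -> ("/dbfs/test_dir/test.csv", "dbfs:/test_dir/test.csv")
# databricks_cache_paths("dbfs:/test_dir/test.csv") -> ("/dbfs/test_dir/test.csv", "dbfs:/test_dir/test.csv")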
a/feathr_project/test/unit/test_dtype.py b/feathr_project/test/unit/test_dtype.py index eb6aaf2ce..6d15d08b1 100644 --- a/feathr_project/test/unit/test_dtype.py +++ b/feathr_project/test/unit/test_dtype.py @@ -9,16 +9,13 @@ def test_key_type(): with pytest.raises(KeyError): key = TypedKey(key_column="key", key_column_type=INT32) + def test_feature_type(): key = TypedKey(key_column="key", key_column_type=ValueType.INT32) - feature = Feature(name="name", - key=key, - feature_type=INT32) - + feature = Feature(name="name", key=key, feature_type=INT32) + assert feature.feature_type == INT32 with pytest.raises(KeyError): - feature = Feature(name="name", - key=key, - feature_type=ValueType.INT32) \ No newline at end of file + feature = Feature(name="name", key=key, feature_type=ValueType.INT32) diff --git a/feathr_project/test/unit/udf/test_preprocessing_pyudf_manager.py b/feathr_project/test/unit/udf/test_preprocessing_pyudf_manager.py index 1daa87632..8b5b8e227 100644 --- a/feathr_project/test/unit/udf/test_preprocessing_pyudf_manager.py +++ b/feathr_project/test/unit/udf/test_preprocessing_pyudf_manager.py @@ -8,8 +8,11 @@ [ ("fn_without_type_hint", "def fn_without_type_hint(a):\n return a + 10\n"), ("fn_with_type_hint", "def fn_with_type_hint(a: int) -> int:\n return a + 10\n"), - ("fn_with_complex_type_hint", "def fn_with_complex_type_hint(a: Union[int, float]) -> Union[int, float]:\n return a + 10\n"), - ] + ( + "fn_with_complex_type_hint", + "def fn_with_complex_type_hint(a: Union[int, float]) -> Union[int, float]:\n return a + 10\n", + ), + ], ) def test__parse_function_str_for_name(fn_name, fn_str): assert fn_name == _PreprocessingPyudfManager._parse_function_str_for_name(fn_str) diff --git a/feathr_project/test/unit/utils/test_config.py b/feathr_project/test/unit/utils/test_config.py index 9bb5b4bae..b586ac45b 100644 --- a/feathr_project/test/unit/utils/test_config.py +++ b/feathr_project/test/unit/utils/test_config.py @@ -9,7 +9,8 @@ @pytest.mark.parametrize( - "output_filepath", [None, "config.yml"], + "output_filepath", + [None, "config.yml"], ) def test__generate_config__output_filepath( output_filepath: str, @@ -60,7 +61,7 @@ def test__generate_config__output_filepath( spark_config__azure_synapse__pool_name="pool_name", ), ), - ] + ], ) def test__generate_config__spark_cluster( mocker: MockerFixture, @@ -88,7 +89,7 @@ def test__generate_config__spark_cluster( ("some_key", "some_name", None), (None, "some_name", ValueError), ("some_key", None, ValueError), - ] + ], ) def test__generate_config__azure_synapse_exceptions( mocker: MockerFixture, @@ -99,10 +100,13 @@ def test__generate_config__azure_synapse_exceptions( """Test if exceptions are raised when databricks url and token are not provided.""" # Either env vars or argument should yield the same result - for environ in [{"ADLS_KEY": adls_key}, { - "ADLS_KEY": adls_key, - "SPARK_CONFIG__AZURE_SYNAPSE__POOL_NAME": pool_name, - }]: + for environ in [ + {"ADLS_KEY": adls_key}, + { + "ADLS_KEY": adls_key, + "SPARK_CONFIG__AZURE_SYNAPSE__POOL_NAME": pool_name, + }, + ]: # Mock the os.environ to return the specified env vars mocker.patch.object(feathr.utils.config.os, "environ", environ) @@ -135,7 +139,7 @@ def test__generate_config__azure_synapse_exceptions( ("some_token", "some_url", None), (None, "some_url", ValueError), ("some_token", None, ValueError), - ] + ], ) def test__generate_config__databricks_exceptions( mocker: MockerFixture, @@ -146,10 +150,13 @@ def test__generate_config__databricks_exceptions( """Test if exceptions are 
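The cases above feed _parse_function_str_for_name a UDF as a source string and expect the bare function name back. One way such a parser could work, shown only as an illustration of the contract the test asserts, not the manager's real code:

import re


def parse_function_name(fn_str: str) -> str:
    """Pull the function name out of a UDF source string such as the parametrized cases above."""
    match = re.match(r"\s*def\s+([A-Za-z_]\w*)\s*\(", fn_str)
    if match is None:
        raise ValueError("string does not start with a function definition")
    return match.group(1)


# parse_function_name("def fn_with_type_hint(a: int) -> int:\n    return a + 10\n") -> "fn_with_type_hint"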
raised when databricks url and token are not provided.""" # Either env vars or argument should yield the same result - for environ in [{"DATABRICKS_WORKSPACE_TOKEN_VALUE": databricks_token}, { - "DATABRICKS_WORKSPACE_TOKEN_VALUE": databricks_token, - "SPARK_CONFIG__DATABRICKS__WORKSPACE_INSTANCE_URL": workspace_url, - }]: + for environ in [ + {"DATABRICKS_WORKSPACE_TOKEN_VALUE": databricks_token}, + { + "DATABRICKS_WORKSPACE_TOKEN_VALUE": databricks_token, + "SPARK_CONFIG__DATABRICKS__WORKSPACE_INSTANCE_URL": workspace_url, + }, + ]: # Mock the os.environ to return the specified env vars mocker.patch.object(feathr.utils.config.os, "environ", environ) diff --git a/feathr_project/test/unit/utils/test_env_config_reader.py b/feathr_project/test/unit/utils/test_env_config_reader.py index fd54e4a27..2d73206e8 100644 --- a/feathr_project/test/unit/utils/test_env_config_reader.py +++ b/feathr_project/test/unit/utils/test_env_config_reader.py @@ -25,7 +25,7 @@ (None, TEST_CONFIG_FILE_CONTENT, TEST_CONFIG_AKV_VAL, TEST_CONFIG_FILE_VAL), (None, "", TEST_CONFIG_AKV_VAL, TEST_CONFIG_AKV_VAL), (None, "", None, "default"), - ] + ], ) def test__envvariableutil__get( mocker: MockerFixture, @@ -34,13 +34,16 @@ def test__envvariableutil__get( akv_value: str, expected_value: str, ): - """Test `get` method if it returns the expected value. - """ + """Test `get` method if it returns the expected value.""" # Mock env variables - mocker.patch.object(feathr.utils._env_config_reader.os, "environ", { - TEST_CONFIG_KEY: env_value, - "secrets__azure_key_vault__name": "test_akv_name", - }) + mocker.patch.object( + feathr.utils._env_config_reader.os, + "environ", + { + TEST_CONFIG_KEY: env_value, + "secrets__azure_key_vault__name": "test_akv_name", + }, + ) # Mock AKS mocker.patch.object( feathr.utils._env_config_reader.AzureKeyVaultClient, @@ -62,7 +65,7 @@ def test__envvariableutil__get( (TEST_CONFIG_ENV_VAL, TEST_CONFIG_AKV_VAL, TEST_CONFIG_ENV_VAL), (None, TEST_CONFIG_AKV_VAL, TEST_CONFIG_AKV_VAL), (None, None, None), - ] + ], ) def test__envvariableutil__get_from_env_or_akv( mocker: MockerFixture, @@ -70,13 +73,16 @@ def test__envvariableutil__get_from_env_or_akv( akv_value: str, expected_value: str, ): - """Test `get_from_env_or_akv` method if it returns the expected value. - """ + """Test `get_from_env_or_akv` method if it returns the expected value.""" # Mock env variables - mocker.patch.object(feathr.utils._env_config_reader.os, "environ", { - TEST_CONFIG_KEY: env_value, - "secrets__azure_key_vault__name": "test_akv_name", - }) + mocker.patch.object( + feathr.utils._env_config_reader.os, + "environ", + { + TEST_CONFIG_KEY: env_value, + "secrets__azure_key_vault__name": "test_akv_name", + }, + ) # Mock AKS mocker.patch.object( feathr.utils._env_config_reader.AzureKeyVaultClient, diff --git a/feathr_project/test/unit/utils/test_job_utils.py b/feathr_project/test/unit/utils/test_job_utils.py index 4a0d835e5..8e5f870dc 100644 --- a/feathr_project/test/unit/utils/test_job_utils.py +++ b/feathr_project/test/unit/utils/test_job_utils.py @@ -53,7 +53,8 @@ def test__get_result_spark_df(mocker: MockerFixture): @pytest.mark.parametrize( - "is_databricks,spark_runtime,res_url,local_cache_path,expected_local_cache_path", [ + "is_databricks,spark_runtime,res_url,local_cache_path,expected_local_cache_path", + [ # For local spark results, res_url must be a local path and local_cache_path will be ignored. 
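The env-config-reader cases above encode a lookup precedence: the environment variable wins, then the config-file value, then Azure Key Vault, then the caller's default, with empty strings falling through. A sketch that states that rule directly; the function name and signature are illustrative, not EnvConfigReader's API.

import os
from typing import Optional


def read_config_value(
    key: str,
    file_value: Optional[str],
    akv_value: Optional[str],
    default: Optional[str] = None,
) -> Optional[str]:
    """Environment variable first, then the config-file value, then Azure Key Vault, then the default."""
    # Empty strings fall through, matching the "" config-file case in the table above.
    return os.environ.get(key) or file_value or akv_value or default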
(False, "local", "some_res_url", None, "some_res_url"), (False, "local", "some_res_url", "some_local_cache_path", "some_res_url"), @@ -63,7 +64,7 @@ def test__get_result_spark_df(mocker: MockerFixture): (True, "databricks", "dbfs:/some_res_url", "some_local_cache_path", "/dbfs/some_res_url"), (False, "databricks", "dbfs:/some_res_url", None, "mocked_temp_path"), (False, "databricks", "dbfs:/some_res_url", "some_local_cache_path", "some_local_cache_path"), - ] + ], ) def test__get_result_df__with_local_cache_path( mocker: MockerFixture, @@ -98,7 +99,8 @@ def test__get_result_df__with_local_cache_path( @pytest.mark.parametrize( - "is_databricks,spark_runtime,res_url,data_format,expected_error", [ + "is_databricks,spark_runtime,res_url,data_format,expected_error", + [ # Test RuntimeError when the function is running at Databricks but client.spark_runtime is not databricks (True, "local", "some_url", "some_format", RuntimeError), (True, "azure_synapse", "some_url", "some_format", RuntimeError), @@ -116,7 +118,7 @@ def test__get_result_df__with_local_cache_path( (False, "local", "some_url", None, ValueError), (False, "azure_synapse", "some_url", None, ValueError), (False, "databricks", "some_url", None, ValueError), - ] + ], ) def test__get_result_df__exceptions( mocker: MockerFixture, @@ -158,13 +160,18 @@ def test__get_result_df__exceptions( @pytest.mark.parametrize( - "data_format,output_filename,expected_count", [ + "data_format,output_filename,expected_count", + [ ("csv", "output.csv", 5), - ("csv", "output_dir.csv", 4), # TODO add a header to the csv file and change expected_count to 5 after fixing the bug https://github.com/feathr-ai/feathr/issues/811 + ( + "csv", + "output_dir.csv", + 4, + ), # TODO add a header to the csv file and change expected_count to 5 after fixing the bug https://github.com/feathr-ai/feathr/issues/811 ("parquet", "output.parquet", 5), ("avro", "output.avro", 5), ("delta", "output-delta", 5), - ] + ], ) def test__get_result_df( workspace_dir: str, @@ -200,13 +207,18 @@ def test__get_result_df( @pytest.mark.parametrize( - "data_format,output_filename,expected_count", [ + "data_format,output_filename,expected_count", + [ ("csv", "output.csv", 5), - ("csv", "output_dir.csv", 4), # TODO add a header to the csv file and change expected_count = 5 after fixing the bug https://github.com/feathr-ai/feathr/issues/811 + ( + "csv", + "output_dir.csv", + 4, + ), # TODO add a header to the csv file and change expected_count = 5 after fixing the bug https://github.com/feathr-ai/feathr/issues/811 ("parquet", "output.parquet", 5), ("avro", "output.avro", 5), ("delta", "output-delta", 5), - ] + ], ) def test__get_result_df__with_spark_session( workspace_dir: str, @@ -240,9 +252,10 @@ def test__get_result_df__with_spark_session( @pytest.mark.parametrize( - "format,output_filename,expected_count", [ + "format,output_filename,expected_count", + [ ("csv", "output.csv", 5), - ] + ], ) def test__get_result_df__arg_alias( workspace_dir: str,