From c58329d348c9a0ccfc7cf78318cd631fa9cb4ac1 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Thu, 6 Jul 2023 16:08:19 -0700 Subject: [PATCH 1/3] ADBS OML REST API and support for async execution of Python models - Switched to REST API for remote Python execution - Introduced config parameters: async_flag, timeout and service - Better error reporting for Python models - Tested with 10M data records --- .github/workflows/oracle-xe-adapter-tests.yml | 2 +- dbt/adapters/oracle/__version__.py | 2 +- dbt/adapters/oracle/connections.py | 6 +- dbt/adapters/oracle/impl.py | 56 ++--- dbt/adapters/oracle/python_submissions.py | 225 ++++++++++++++++++ dbt_adbs_test_project/models/test_py_ref.py | 3 + dbt_adbs_test_project/profiles.yml | 2 +- requirements.txt | 4 +- requirements_dev.txt | 2 +- setup.cfg | 6 +- setup.py | 6 +- tox.ini | 2 +- 12 files changed, 262 insertions(+), 54 deletions(-) create mode 100644 dbt/adapters/oracle/python_submissions.py diff --git a/.github/workflows/oracle-xe-adapter-tests.yml b/.github/workflows/oracle-xe-adapter-tests.yml index 0f9c03b..6709f0e 100644 --- a/.github/workflows/oracle-xe-adapter-tests.yml +++ b/.github/workflows/oracle-xe-adapter-tests.yml @@ -48,7 +48,7 @@ jobs: - name: Install dbt-oracle with core dependencies run: | python -m pip install --upgrade pip - pip install pytest dbt-tests-adapter==1.5.1 + pip install pytest dbt-tests-adapter==1.5.2 pip install -r requirements.txt pip install -e . diff --git a/dbt/adapters/oracle/__version__.py b/dbt/adapters/oracle/__version__.py index 54eae38..fe59a53 100644 --- a/dbt/adapters/oracle/__version__.py +++ b/dbt/adapters/oracle/__version__.py @@ -14,4 +14,4 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -version = "1.5.1" +version = "1.5.2" diff --git a/dbt/adapters/oracle/connections.py b/dbt/adapters/oracle/connections.py index 1c10c78..3e20d5c 100644 --- a/dbt/adapters/oracle/connections.py +++ b/dbt/adapters/oracle/connections.py @@ -109,8 +109,8 @@ class OracleAdapterCredentials(Credentials): retry_count: Optional[int] = 1 retry_delay: Optional[int] = 3 - # Fetch an auth token to run Python UDF - oml_auth_token_uri: Optional[str] = None + # Base URL for ADB-S OML REST API + oml_cloud_service_url: Optional[str] = None _ALIASES = { @@ -136,7 +136,7 @@ def _connection_keys(self) -> Tuple[str]: 'service', 'connection_string', 'shardingkey', 'supershardingkey', 'cclass', 'purity', 'retry_count', - 'retry_delay', 'oml_auth_token_uri' + 'retry_delay', 'oml_cloud_service_url' ) @classmethod diff --git a/dbt/adapters/oracle/impl.py b/dbt/adapters/oracle/impl.py index 495fee6..a0e5142 100644 --- a/dbt/adapters/oracle/impl.py +++ b/dbt/adapters/oracle/impl.py @@ -42,6 +42,8 @@ from dbt.utils import filter_null_values from dbt.adapters.oracle.keyword_catalog import KEYWORDS +from dbt.adapters.oracle.python_submissions import OracleADBSPythonJob +from dbt.adapters.oracle.connections import AdapterResponse logger = AdapterLogger("oracle") @@ -367,49 +369,27 @@ def get_oml_auth_token(self) -> str: def submit_python_job(self, parsed_model: dict, compiled_code: str): """Submit user defined Python function - - The function pyqEval when used in Oracle Autonomous Database, - calls a user-defined Python function. 
- - pyqEval(PAR_LST, OUT_FMT, SRC_NAME, SRC_OWNER, ENV_NAME) - - - PAR_LST -> Parameter List - - OUT_FMT -> JSON clob of the columns - - ENV_NAME -> Name of conda environment + https://docs.oracle.com/en/database/oracle/machine-learning/oml4py/1/mlepe/op-py-scripts-v1-do-eval-scriptname-post.html """ identifier = parsed_model["alias"] - oml_oauth_access_token = self.get_oml_auth_token() py_q_script_name = f"{identifier}_dbt_py_script" - py_q_eval_output_fmt = '{"result":"number"}' - py_q_eval_result_table = f"o$pt_dbt_pyqeval_{identifier}_tmp_{datetime.datetime.utcnow().strftime('%H%M%S')}" - - conda_env_name = parsed_model["config"].get("conda_env_name") - if conda_env_name: - logger.info("Custom python environment is %s", conda_env_name) - py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table} - AS SELECT * FROM TABLE(pyqEval(par_lst => NULL, - out_fmt => ''{py_q_eval_output_fmt}'', - scr_name => ''{py_q_script_name}'', - scr_owner => NULL, - env_name => ''{conda_env_name}''))""" - else: - py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table} - AS SELECT * FROM TABLE(pyqEval(par_lst => NULL, - out_fmt => ''{py_q_eval_output_fmt}'', - scr_name => ''{py_q_script_name}'', - scr_owner => NULL))""" - - py_exec_main_sql = f""" - BEGIN - sys.pyqSetAuthToken('{oml_oauth_access_token}'); - sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE); - EXECUTE IMMEDIATE '{py_q_eval_sql}'; - EXECUTE IMMEDIATE 'DROP TABLE {py_q_eval_result_table}'; - sys.pyqScriptDrop('{py_q_script_name}'); - END; + py_q_create_script = f""" + BEGIN + sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE); + END; """ - response, _ = self.execute(sql=py_exec_main_sql) + response, _ = self.execute(sql=py_q_create_script) + python_job = OracleADBSPythonJob(parsed_model=parsed_model, + credential=self.config.credentials) + python_job() + py_q_drop_script = f""" + BEGIN + 
sys.pyqScriptDrop('{py_q_script_name}'); + END; + """ + + response, _ = self.execute(sql=py_q_drop_script) logger.info(response) return response diff --git a/dbt/adapters/oracle/python_submissions.py b/dbt/adapters/oracle/python_submissions.py new file mode 100644 index 0000000..5a43152 --- /dev/null +++ b/dbt/adapters/oracle/python_submissions.py @@ -0,0 +1,225 @@ +""" +Copyright (c) 2023, Oracle and/or its affiliates. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import datetime +import http +import json +from typing import Dict + +import requests +import time + +import dbt.exceptions +from dbt.adapters.oracle import OracleAdapterCredentials +from dbt.events import AdapterLogger +from dbt.ui import red, green + +# ADB-S OML Rest API minimum timeout is 1800 seconds +DEFAULT_TIMEOUT_IN_SECONDS = 1800 +DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS = 2 + +OMLUSERS_OAUTH_API = "/omlusers/api/oauth2/v1/token" +OML_DO_EVAL_API = "/oml/api/py-scripts/v1/do-eval/{script_name}" + +logger = AdapterLogger("oracle") + + +class OracleOML4PYClient: + + def __init__(self, oml_cloud_service_url, username, password): + self.base_url = oml_cloud_service_url + self._username = username + self._password = password + self.token = None + self.token_expires_at = None + self.token_url = self.base_url + OMLUSERS_OAUTH_API + self._session = requests.Session() + + @property + def session(self): + return self._session + + def get_token(self): + """Get access_token or refresh_token""" + # If access token is 
about to expire then refresh the token + if self.token_expires_at and self.token_expires_at - datetime.datetime.utcnow() < datetime.timedelta(minutes=1): + return self._get_token(grant_type="refresh_token") + elif self.token: # Token is valid + return self.token + else: # Generate a new token + return self._get_token(grant_type="password") + + def _get_token(self, grant_type="password"): + """Gets access_token or refresh_token using /omlusers/api/oauth2/v1/token""" + data = {"grant_type": grant_type} + if grant_type == "password": + data["username"] = self._username + data["password"] = self._password + else: + data["token"] = self.token + + r = self.session.post( + url=self.token_url, + json=data, + headers={ + "Accept": "application/json", + "Content-type": "application/json", + }, + ) + r.raise_for_status() + response = r.json() + self.token = response["accessToken"] + self.token_expires_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=response["expiresIn"]) + return self.token + + @property + def default_headers(self): + """Default headers added to every request""" + return { + "Content-type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {self.get_token()}", + } + + def request(self, method: str, path: str, + raise_for_status: bool = False, + **kwargs) -> requests.Response: + """ + Description: + Perform a desired action (GET, PUT, POST) on a certain resource + + Args: + method (str) -> HTTP verb like GET, PUT, POST, etc + path (str) -> path to the resource e.g. 
/job/{job_id} + raise_for_status (bool) -> True if HTTPError should be raised in case of an error else False + + Returns: + object of type requests.Response + + Raises: + requests.HTTPError() in case of an error, if raise_for_status is True + + """ + url = path if path.startswith(self.base_url) else self.base_url + path + self.session.headers.update(self.default_headers) + r = self.session.request(method=method, url=url, **kwargs) + try: + r.raise_for_status() + except requests.HTTPError: + if raise_for_status: + raise + return r + + +class OracleADBSPythonJob: + """Callable to submit Python Script to ADB-S + + """ + + def __init__(self, + parsed_model: Dict, + credential: OracleAdapterCredentials) -> None: + self.identifier = parsed_model["alias"] + self.py_q_script_name = f"{self.identifier}_dbt_py_script" + self.conda_env_name = parsed_model["config"].get("conda_env_name") + self.timeout = parsed_model["config"].get("timeout", DEFAULT_TIMEOUT_IN_SECONDS) + self.async_flag = parsed_model["config"].get("async_flag", False) + self.service = parsed_model["config"].get("service", "HIGH") + self.oml4py_client = OracleOML4PYClient(oml_cloud_service_url=credential.oml_cloud_service_url, + username=credential.user, + password=credential.password) + + def schedule_async_job_and_wait_for_completion(self, data): + logger.info(f"Running Python aysnc job using {data}") + try: + r = self.oml4py_client.request(method="POST", + path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name), + data=json.dumps(data), + raise_for_status=False) + if r.status_code in (http.HTTPStatus.BAD_REQUEST, http.HTTPStatus.INTERNAL_SERVER_ERROR): + logger.error(red(r.json())) + r.raise_for_status() + except requests.exceptions.RequestException as e: + logger.error(red(f"Error {e} scheduling async Python job for model {self.identifier}")) + raise dbt.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}") + + job_location = r.headers["location"] + logger.info(f"Started 
async job {job_location}") + start_time = time.time() + + while time.time() - start_time < self.timeout: + logger.debug(f"Checking Job status for : {job_location}") + try: + job_status = self.oml4py_client.request(method="GET", + path=job_location, + raise_for_status=False) + job_status_code = job_status.status_code + logger.debug(f"Job status code is: {job_status_code}") + if job_status_code == http.HTTPStatus.FOUND: + logger.info(green(f"Job {job_location} completed")) + job_result = self.oml4py_client.request(method="GET", + path=f"{job_location}/result", + raise_for_status=False) + job_result_json = job_result.json() + if 'errorMessage' in job_result_json: + logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result_json}")) + raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}") + job_result.raise_for_status() + logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result_json}")) + return + elif job_status_code == http.HTTPStatus.INTERNAL_SERVER_ERROR: + logger.error(red(f"FAILURE - Job status is: {job_status.json()}")) + raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}") + else: + logger.debug(f"Python model {self.identifier} job status is: {job_status.json()}") + job_status.raise_for_status() + + except requests.exceptions.RequestException as e: + logger.error(red(f"Error {e} checking status of Python job {job_location} for model {self.identifier}")) + raise dbt.exceptions.DbtRuntimeError(f"Error checking status for job {job_location}") + + time.sleep(DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS) + logger.error(red(f"Timeout error for Python model {self.identifier}")) + raise dbt.exceptions.DbtRuntimeError(f"Timeout error for Python model {self.identifier}") + + def __call__(self, *args, **kwargs): + data = { + "service": self.service + } + if self.async_flag: + data["asyncFlag"] = self.async_flag + data["timeout"] = self.timeout + if 
self.conda_env_name: + data["envName"] = self.conda_env_name + + if self.async_flag: + self.schedule_async_job_and_wait_for_completion(data=data) + else: # Run in blocking mode + logger.info(f"Running Python model {self.identifier} with args {data}") + try: + r = self.oml4py_client.request(method="POST", + path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name), + data=json.dumps(data), + raise_for_status=False) + job_result = r.json() + if 'errorMessage' in job_result: + logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result}")) + raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}") + r.raise_for_status() + logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result}")) + except requests.exceptions.RequestException as e: + logger.error(red(f"Error {e} running Python model {self.identifier}")) + raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}") + diff --git a/dbt_adbs_test_project/models/test_py_ref.py b/dbt_adbs_test_project/models/test_py_ref.py index 3085f6e..12b1ce7 100644 --- a/dbt_adbs_test_project/models/test_py_ref.py +++ b/dbt_adbs_test_project/models/test_py_ref.py @@ -1,6 +1,9 @@ def model(dbt, session): # Must be either table or incremental (view is not currently supported) dbt.config(materialized="table") + dbt.config(async_flag=True) + dbt.config(timeout=900) # In seconds + dbt.config(service="HIGH") # LOW, MEDIUM, HIGH # oml.core.DataFrame representing a datasource s_df = dbt.ref("sales_cost") return s_df diff --git a/dbt_adbs_test_project/profiles.yml b/dbt_adbs_test_project/profiles.yml index 8e544b8..b211f52 100644 --- a/dbt_adbs_test_project/profiles.yml +++ b/dbt_adbs_test_project/profiles.yml @@ -11,7 +11,7 @@ dbt_test: service: "{{ env_var('DBT_ORACLE_SERVICE') }}" #database: "{{ env_var('DBT_ORACLE_DATABASE') }}" schema: "{{ env_var('DBT_ORACLE_SCHEMA') }}" - oml_auth_token_uri: "{{ 
env_var('DBT_ORACLE_OML_AUTH_TOKEN_API')}}" + oml_cloud_service_url: "{{ env_var('DBT_ORACLE_OML_CLOUD_SERVICE_URL')}}" retry_count: 1 retry_delay: 5 shardingkey: diff --git a/requirements.txt b/requirements.txt index 2230318..9c24e14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -dbt-core==1.5.1 +dbt-core==1.5.2 cx_Oracle==8.3.0 -oracledb==1.3.1 +oracledb==1.3.2 diff --git a/requirements_dev.txt b/requirements_dev.txt index 2c0d18c..8cfc04b 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,4 +6,4 @@ tox coverage twine pytest -dbt-tests-adapter==1.5.1 +dbt-tests-adapter==1.5.2 diff --git a/setup.cfg b/setup.cfg index 647e275..d8fb4f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,12 +33,12 @@ zip_safe = False packages = find: include_package_data = True install_requires = - dbt-core==1.5.1 + dbt-core==1.5.2 cx_Oracle==8.3.0 - oracledb==1.3.1 + oracledb==1.3.2 test_suite=tests test_requires = - dbt-tests-adapter==1.5.1 + dbt-tests-adapter==1.5.2 pytest scripts = bin/create-pem-from-p12 diff --git a/setup.py b/setup.py index 5473d31..6fbe168 100644 --- a/setup.py +++ b/setup.py @@ -32,13 +32,13 @@ requirements = [ - "dbt-core==1.5.1", + "dbt-core==1.5.2", "cx_Oracle==8.3.0", - "oracledb==1.3.1" + "oracledb==1.3.2" ] test_requirements = [ - "dbt-tests-adapter==1.5.1", + "dbt-tests-adapter==1.5.2", "pytest" ] diff --git a/tox.ini b/tox.ini index 30543f5..fd80d84 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,7 @@ passenv = deps = -rrequirements.txt - dbt-tests-adapter==1.5.1 + dbt-tests-adapter==1.5.2 pytest commands = pytest From 79c3552ff207956b66cf8909a26399f986272728 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 7 Jul 2023 14:07:27 -0700 Subject: [PATCH 2/3] Changed signature of oracle__get_empty_subquery_sql based on dbt-core upgrade --- dbt/include/oracle/macros/adapters.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/include/oracle/macros/adapters.sql 
b/dbt/include/oracle/macros/adapters.sql index 4653eaf..9b20c6c 100644 --- a/dbt/include/oracle/macros/adapters.sql +++ b/dbt/include/oracle/macros/adapters.sql @@ -25,7 +25,7 @@ {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }} {% endmacro %} -{% macro oracle__get_empty_subquery_sql(select_sql) %} +{% macro oracle__get_empty_subquery_sql(select_sql, select_sql_header=none) %} select * from ( {{ select_sql }} ) dbt_sbq_tmp From 6c6cd590a0dca3fbca9e8b4b003c25749e3aa87e Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 7 Jul 2023 14:54:58 -0700 Subject: [PATCH 3/3] Test Case fix: Removed round brackets around constraint 'check' expression --- tests/functional/adapter/constraints/fixtures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/adapter/constraints/fixtures.py b/tests/functional/adapter/constraints/fixtures.py index fe49d40..1011b7e 100644 --- a/tests/functional/adapter/constraints/fixtures.py +++ b/tests/functional/adapter/constraints/fixtures.py @@ -194,7 +194,7 @@ - type: not_null - type: primary_key - type: check - expression: (id > 0) + expression: id > 0 tests: - unique - name: color @@ -269,7 +269,7 @@ enforced: true constraints: - type: check - expression: (id > 0) + expression: id > 0 - type: primary_key columns: [ id ] - type: unique