Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/oracle-xe-adapter-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- name: Install dbt-oracle with core dependencies
run: |
python -m pip install --upgrade pip
pip install pytest dbt-tests-adapter==1.5.0
pip install pytest dbt-tests-adapter==1.5.1
pip install -r requirements.txt
pip install -e .

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ doc/build.gitbak
.venv1.3/
.venv1.4/
.venv1.5/
dbt_adbs_py_test_project
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Configuration variables
VERSION=1.5.0
VERSION=1.5.1
PROJ_DIR?=$(shell pwd)
VENV_DIR?=${PROJ_DIR}/.bldenv
BUILD_DIR=${PROJ_DIR}/build
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/oracle/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
version = "1.5.0"
version = "1.5.1"
40 changes: 35 additions & 5 deletions dbt/adapters/oracle/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse
from dbt.events.functions import fire_event
from dbt.events.types import ConnectionUsed, SQLQuery, SQLCommit, SQLQueryStatus
from dbt.events import AdapterLogger
from dbt.events.contextvars import get_node_info
from dbt.utils import cast_to_str

from dbt.version import __version__ as dbt_version
from dbt.adapters.oracle.connection_helper import oracledb, SQLNET_ORA_CONFIG
Expand Down Expand Up @@ -105,6 +109,9 @@ class OracleAdapterCredentials(Credentials):
retry_count: Optional[int] = 1
retry_delay: Optional[int] = 3

# Fetch an auth token to run Python UDF
oml_auth_token_uri: Optional[str] = None


_ALIASES = {
'dbname': 'database',
Expand All @@ -129,7 +136,7 @@ def _connection_keys(self) -> Tuple[str]:
'service', 'connection_string',
'shardingkey', 'supershardingkey',
'cclass', 'purity', 'retry_count',
'retry_delay'
'retry_delay', 'oml_auth_token_uri'
)

@classmethod
Expand Down Expand Up @@ -293,20 +300,36 @@ def add_query(
if auto_begin and connection.transaction_open is False:
self.begin()

logger.debug('Using {} connection "{}".'
.format(self.TYPE, connection.name))
fire_event(
ConnectionUsed(
conn_type=self.TYPE,
conn_name=cast_to_str(connection.name),
node_info=get_node_info(),
)
)

with self.exception_handler(sql):
if abridge_sql_log:
log_sql = '{}...'.format(sql[:512])
else:
log_sql = sql

logger.debug(f'On {connection.name}: f{log_sql}')
fire_event(
SQLQuery(
conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info()
)
)

pre = time.time()
cursor = connection.handle.cursor()
cursor.execute(sql, bindings)
logger.debug(f"SQL status: {self.get_status(cursor)} in {(time.time() - pre)} seconds")
fire_event(
SQLQueryStatus(
status=str(self.get_response(cursor)),
elapsed=round((time.time() - pre)),
node_info=get_node_info(),
)
)
return connection, cursor

def add_begin_query(self):
Expand All @@ -317,3 +340,10 @@ def add_begin_query(self):
@classmethod
def data_type_code_to_name(cls, type_code) -> str:
return DATATYPES[type_code.name]

def commit(self):
connection = self.get_thread_connection()
fire_event(SQLCommit(conn_name=connection.name, node_info=get_node_info()))
self.add_commit_query()
connection.transaction_open = False
return connection
68 changes: 68 additions & 0 deletions dbt/adapters/oracle/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import datetime
from typing import (
Optional, List, Set
)
Expand All @@ -24,6 +25,7 @@
Dict)

import agate
import requests

import dbt.exceptions
from dbt.adapters.base.relation import BaseRelation, InformationSchema
Expand Down Expand Up @@ -345,3 +347,69 @@ def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]])
rendered_column_constraints.append(" ".join(rendered_column_constraint))

return rendered_column_constraints

def get_oml_auth_token(self) -> str:
if self.config.credentials.oml_auth_token_uri is None:
raise dbt.exceptions.DbtRuntimeError("oml_auth_token_uri should be set to run dbt-py models")
data = {
"grant_type": "password",
"username": self.config.credentials.user,
"password": self.config.credentials.password
}
try:
r = requests.post(url=self.config.credentials.oml_auth_token_uri,
json=data)
r.raise_for_status()
except requests.exceptions.RequestException:
raise dbt.exceptions.DbtRuntimeError("Error getting OML OAuth2.0 token")
else:
return r.json()["accessToken"]

def submit_python_job(self, parsed_model: dict, compiled_code: str):
"""Submit user defined Python function

The function pyqEval when used in Oracle Autonomous Database,
calls a user-defined Python function.

pyqEval(PAR_LST, OUT_FMT, SRC_NAME, SRC_OWNER, ENV_NAME)

- PAR_LST -> Parameter List
- OUT_FMT -> JSON clob of the columns
- ENV_NAME -> Name of conda environment


"""
identifier = parsed_model["alias"]
oml_oauth_access_token = self.get_oml_auth_token()
py_q_script_name = f"{identifier}_dbt_py_script"
py_q_eval_output_fmt = '{"result":"number"}'
py_q_eval_result_table = f"o$pt_dbt_pyqeval_{identifier}_tmp_{datetime.datetime.utcnow().strftime('%H%M%S')}"

conda_env_name = parsed_model["config"].get("conda_env_name")
if conda_env_name:
logger.info("Custom python environment is %s", conda_env_name)
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
out_fmt => ''{py_q_eval_output_fmt}'',
scr_name => ''{py_q_script_name}'',
scr_owner => NULL,
env_name => ''{conda_env_name}''))"""
else:
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
out_fmt => ''{py_q_eval_output_fmt}'',
scr_name => ''{py_q_script_name}'',
scr_owner => NULL))"""

py_exec_main_sql = f"""
BEGIN
sys.pyqSetAuthToken('{oml_oauth_access_token}');
sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE);
EXECUTE IMMEDIATE '{py_q_eval_sql}';
EXECUTE IMMEDIATE 'DROP TABLE {py_q_eval_result_table}';
sys.pyqScriptDrop('{py_q_script_name}');
END;
"""
response, _ = self.execute(sql=py_exec_main_sql)
logger.info(response)
return response
50 changes: 27 additions & 23 deletions dbt/include/oracle/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -136,29 +136,33 @@
{%- endmacro %}


{% macro oracle__create_table_as(temporary, relation, sql) -%}
{%- set sql_header = config.get('sql_header', none) -%}
{%- set parallel = config.get('parallel', none) -%}
{%- set compression_clause = config.get('table_compression_clause', none) -%}
{%- set contract_config = config.get('contract') -%}

{{ sql_header if sql_header is not none }}

create {% if temporary -%}
global temporary
{%- endif %} table {{ relation.include(schema=(not temporary)) }}
{%- if contract_config.enforced -%}
{{ get_assert_columns_equivalent(sql) }}
{{ get_table_columns_and_constraints() }}
{%- set sql = get_select_subquery(sql) %}
{% endif %}
{% if temporary -%} on commit preserve rows {%- endif %}
{% if not temporary -%}
{% if parallel %} parallel {{ parallel }}{% endif %}
{% if compression_clause %} {{ compression_clause }} {% endif %}
{%- endif %}
as
{{ sql }}
{% macro oracle__create_table_as(temporary, relation, sql, language='sql') -%}
{%- if language == 'sql' -%}
{%- set sql_header = config.get('sql_header', none) -%}
{%- set parallel = config.get('parallel', none) -%}
{%- set compression_clause = config.get('table_compression_clause', none) -%}
{%- set contract_config = config.get('contract') -%}
{{ sql_header if sql_header is not none }}
create {% if temporary -%}
global temporary
{%- endif %} table {{ relation.include(schema=(not temporary)) }}
{%- if contract_config.enforced -%}
{{ get_assert_columns_equivalent(sql) }}
{{ get_table_columns_and_constraints() }}
{%- set sql = get_select_subquery(sql) %}
{% endif %}
{% if temporary -%} on commit preserve rows {%- endif %}
{% if not temporary -%}
{% if parallel %} parallel {{ parallel }}{% endif %}
{% if compression_clause %} {{ compression_clause }} {% endif %}
{%- endif %}
as
{{ sql }}
{%- elif language == 'python' -%}
{{ py_write_table(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }}
{%- else -%}
{% do exceptions.raise_compiler_error("oracle__create_table_as macro didn't get supported language, it got %s" % language) %}
{%- endif -%}

{%- endmacro %}
{% macro oracle__create_view_as(relation, sql) -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
See the License for the specific language governing permissions and
limitations under the License.
#}
{% materialization incremental, adapter='oracle' %}
{% materialization incremental, adapter='oracle', supported_languages=['sql', 'python'] %}

{% set unique_key = config.get('unique_key') %}
{% set full_refresh_mode = flags.FULL_REFRESH %}

{%- set language = model['language'] -%}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
{% set tmp_relation = make_temp_relation(this) %}
Expand All @@ -32,7 +32,7 @@

{% set to_drop = [] %}
{% if existing_relation is none %}
{% set build_sql = create_table_as(False, target_relation, sql) %}
{% set build_sql = create_table_as(False, target_relation, sql, language) %}
{% elif existing_relation.is_view or full_refresh_mode %}
{#-- Make sure the backup doesn't exist so we don't encounter issues with the rename below #}
{% set backup_identifier = existing_relation.identifier ~ "__dbt_backup" %}
Expand All @@ -43,12 +43,16 @@
{% else %}
{% do adapter.rename_relation(existing_relation, backup_relation) %}
{% endif %}
{% set build_sql = create_table_as(False, target_relation, sql) %}
{% set build_sql = create_table_as(False, target_relation, sql, language) %}
{% do to_drop.append(backup_relation) %}
{% else %}
{% set tmp_relation = make_temp_relation(target_relation) %}
{% do to_drop.append(tmp_relation) %}
{% do run_query(create_table_as(True, tmp_relation, sql)) %}
{% call statement("make_tmp_relation", language=language) %}
{{create_table_as(True, tmp_relation, sql, language)}}
{% endcall %}
{#-- After this language should be SQL --#}
{% set language = 'sql' %}
{% do adapter.expand_target_column_types(
from_relation=tmp_relation,
to_relation=target_relation) %}
Expand All @@ -66,7 +70,7 @@

{% endif %}

{% call statement("main") %}
{% call statement("main", language=language) %}
{{ build_sql }}
{% endcall %}

Expand Down
Loading