diff --git a/ads/opctl/operator/common/operator_config.py b/ads/opctl/operator/common/operator_config.py index 7256fe247..4a5e49b1e 100644 --- a/ads/opctl/operator/common/operator_config.py +++ b/ads/opctl/operator/common/operator_config.py @@ -28,6 +28,7 @@ class InputData(DataClassSerializable): limit: int = None sql: str = None table_name: str = None + vault_secret_id: str = None @dataclass(repr=True) diff --git a/ads/opctl/operator/lowcode/anomaly/schema.yaml b/ads/opctl/operator/lowcode/anomaly/schema.yaml index 5b5f066ca..ea273078a 100644 --- a/ads/opctl/operator/lowcode/anomaly/schema.yaml +++ b/ads/opctl/operator/lowcode/anomaly/schema.yaml @@ -78,6 +78,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string validation_data: required: false @@ -130,6 +133,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string datetime_column: type: dict diff --git a/ads/opctl/operator/lowcode/common/utils.py b/ads/opctl/operator/lowcode/common/utils.py index 41355d2b2..1c9ede754 100644 --- a/ads/opctl/operator/lowcode/common/utils.py +++ b/ads/opctl/operator/lowcode/common/utils.py @@ -7,7 +7,9 @@ import argparse import logging import os +import shutil import sys +import tempfile import time from string import Template from typing import Any, Dict, List, Tuple @@ -28,6 +30,7 @@ ) from ads.opctl.operator.common.operator_config import OutputDirectory from ads.common.object_storage_details import ObjectStorageDetails +from ads.secrets import ADBSecretKeeper def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs): @@ -53,10 +56,12 @@ def load_data(data_spec, storage_options=None, **kwargs): sql = data_spec.sql table_name = data_spec.table_name limit = data_spec.limit - + vault_secret_id = data_spec.vault_secret_id storage_options = storage_options or ( default_signer() if ObjectStorageDetails.is_oci_path(filename) else {} ) + if vault_secret_id is not None and connect_args is None: + connect_args = dict() if filename is not None: if not format: @@ -76,15 +81,32 @@ def load_data(data_spec, storage_options=None, **kwargs): f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ." ) elif connect_args is not None: - con = oracledb.connect(**connect_args) - if table_name is not None: - data = pd.read_sql_table(table_name, con) - elif sql is not None: - data = pd.read_sql(sql, con) - else: - raise InvalidParameterError( - f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`." - ) + with tempfile.TemporaryDirectory() as temp_dir: + if vault_secret_id is not None: + try: + with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret: + if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args: + shutil.unpack_archive(adwsecret["wallet_location"], temp_dir) + connect_args['wallet_location'] = temp_dir + if 'user_name' in adwsecret and 'user' not in connect_args: + connect_args['user'] = adwsecret['user_name'] + if 'password' in adwsecret and 'password' not in connect_args: + connect_args['password'] = adwsecret['password'] + if 'service_name' in adwsecret and 'service_name' not in connect_args: + connect_args['service_name'] = adwsecret['service_name'] + + except Exception as e: + raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}") + + con = oracledb.connect(**connect_args) + if table_name is not None: + data = pd.read_sql(f"SELECT * FROM {table_name}", con) + elif sql is not None: + data = pd.read_sql(sql, con) + else: + raise InvalidParameterError( + f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`." + ) else: raise InvalidParameterError( f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively." diff --git a/ads/opctl/operator/lowcode/forecast/schema.yaml b/ads/opctl/operator/lowcode/forecast/schema.yaml index 3bac1aa0f..934ab9469 100644 --- a/ads/opctl/operator/lowcode/forecast/schema.yaml +++ b/ads/opctl/operator/lowcode/forecast/schema.yaml @@ -78,6 +78,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string additional_data: required: false @@ -130,6 +133,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string test_data: required: false @@ -181,6 +187,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string type: dict output_directory: diff --git a/tests/operators/common/test_load_data.py b/tests/operators/common/test_load_data.py new file mode 100644 index 000000000..3fc5197d9 --- /dev/null +++ b/tests/operators/common/test_load_data.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +from typing import Union + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import pytest +from ads.opctl.operator.lowcode.common.utils import ( + load_data, +) +from ads.opctl.operator.common.operator_config import InputData +from unittest.mock import patch, Mock, MagicMock +import unittest +import pandas as pd + +mock_secret = { + 'user_name': 'mock_user', + 'password': 'mock_password', + 'service_name': 'mock_service_name' +} + +mock_connect_args = { + 'user': 'mock_user', + 'password': 'mock_password', + 'service_name': 'mock_service_name', + 'dsn': 'mock_dsn' +} + +# Mock data for testing +mock_data = pd.DataFrame({ + 'id': [1, 2, 3], + 'name': ['Alice', 'Bob', 'Charlie'] +}) + +mock_db_connection = MagicMock() + +load_secret_err_msg = "Vault exception message" +db_connect_err_msg = "Mocked DB connection error" + + +def mock_oracledb_connect_failure(*args, **kwargs): + raise Exception(db_connect_err_msg) + + +def mock_oracledb_connect(**kwargs): + assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}" + return mock_db_connection + + +class MockADBSecretKeeper: + @staticmethod + def __enter__(*args, **kwargs): + return mock_secret + + @staticmethod + def __exit__(*args, **kwargs): + pass + + @staticmethod + def load_secret(vault_secret_id, wallet_dir): + return MockADBSecretKeeper() + + @staticmethod + def load_secret_fail(*args, **kwargs): + raise Exception(load_secret_err_msg) + + +class TestDataLoad(unittest.TestCase): + def setUp(self): + self.data_spec = Mock(spec=InputData) + self.data_spec.connect_args = { + 'dsn': 'mock_dsn' + } + self.data_spec.vault_secret_id = 'mock_secret_id' + self.data_spec.table_name = 'mock_table_name' + self.data_spec.url = None + self.data_spec.format = None + self.data_spec.columns = None + self.data_spec.limit = None + + def testLoadSecretAndDBConnection(self): + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): + with patch('oracledb.connect', side_effect=mock_oracledb_connect): + with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql: + data = load_data(self.data_spec) + mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}", + mock_db_connection) + pd.testing.assert_frame_equal(data, mock_data) + + def testLoadVaultFailure(self): + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail): + with pytest.raises(Exception) as e: + load_data(self.data_spec) + + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}" + assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'" + + def testDBConnectionFailure(self): + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): + with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure): + with pytest.raises(Exception) as e: + load_data(self.data_spec) + + assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'"