From 8bc81a83611194e9efce3d3428536509340e2d78 Mon Sep 17 00:00:00 2001 From: govarsha Date: Fri, 7 Jun 2024 13:34:42 +0530 Subject: [PATCH 1/3] added vault support --- ads/opctl/operator/common/operator_config.py | 1 + .../operator/lowcode/anomaly/schema.yaml | 6 +++ ads/opctl/operator/lowcode/common/utils.py | 42 ++++++++++++++----- .../operator/lowcode/forecast/schema.yaml | 9 ++++ 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/ads/opctl/operator/common/operator_config.py b/ads/opctl/operator/common/operator_config.py index 7256fe247..4a5e49b1e 100644 --- a/ads/opctl/operator/common/operator_config.py +++ b/ads/opctl/operator/common/operator_config.py @@ -28,6 +28,7 @@ class InputData(DataClassSerializable): limit: int = None sql: str = None table_name: str = None + vault_secret_id: str = None @dataclass(repr=True) diff --git a/ads/opctl/operator/lowcode/anomaly/schema.yaml b/ads/opctl/operator/lowcode/anomaly/schema.yaml index 5b5f066ca..ea273078a 100644 --- a/ads/opctl/operator/lowcode/anomaly/schema.yaml +++ b/ads/opctl/operator/lowcode/anomaly/schema.yaml @@ -78,6 +78,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string validation_data: required: false @@ -130,6 +133,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string datetime_column: type: dict diff --git a/ads/opctl/operator/lowcode/common/utils.py b/ads/opctl/operator/lowcode/common/utils.py index 41355d2b2..dd0c92e66 100644 --- a/ads/opctl/operator/lowcode/common/utils.py +++ b/ads/opctl/operator/lowcode/common/utils.py @@ -7,7 +7,9 @@ import argparse import logging import os +import shutil import sys +import tempfile import time from string import Template from typing import Any, Dict, List, Tuple @@ -28,6 +30,7 @@ ) from ads.opctl.operator.common.operator_config import OutputDirectory from ads.common.object_storage_details import ObjectStorageDetails +from ads.secrets import ADBSecretKeeper def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs): @@ -53,10 +56,12 @@ def load_data(data_spec, storage_options=None, **kwargs): sql = data_spec.sql table_name = data_spec.table_name limit = data_spec.limit - + vault_secret_id = data_spec.vault_secret_id storage_options = storage_options or ( default_signer() if ObjectStorageDetails.is_oci_path(filename) else {} ) + if vault_secret_id is not None and connect_args is None: + connect_args = dict() if filename is not None: if not format: @@ -76,15 +81,32 @@ def load_data(data_spec, storage_options=None, **kwargs): f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ." ) elif connect_args is not None: - con = oracledb.connect(**connect_args) - if table_name is not None: - data = pd.read_sql_table(table_name, con) - elif sql is not None: - data = pd.read_sql(sql, con) - else: - raise InvalidParameterError( - f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`." - ) + with tempfile.TemporaryDirectory() as temp_dir: + if vault_secret_id is not None: + try: + with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret: + if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args: + shutil.unpack_archive(adwsecret["wallet_location"], temp_dir) + connect_args['wallet_location'] = temp_dir + if 'user_name' in adwsecret and 'user' not in connect_args: + connect_args['user'] = adwsecret['user_name'] + if 'password' in adwsecret and 'password' not in connect_args: + connect_args['password'] = adwsecret['password'] + if 'service_name' in adwsecret and 'service_name' not in connect_args: + connect_args['service_name'] = adwsecret['service_name'] + + except Exception as e: + logger.debug(f"Could not retrieve database credentials from vault : {e}") + + con = oracledb.connect(**connect_args) + if table_name is not None: + data = pd.read_sql(f"SELECT * FROM {table_name}", con) + elif sql is not None: + data = pd.read_sql(sql, con) + else: + raise InvalidParameterError( + f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`." + ) else: raise InvalidParameterError( f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively." diff --git a/ads/opctl/operator/lowcode/forecast/schema.yaml b/ads/opctl/operator/lowcode/forecast/schema.yaml index 3bac1aa0f..934ab9469 100644 --- a/ads/opctl/operator/lowcode/forecast/schema.yaml +++ b/ads/opctl/operator/lowcode/forecast/schema.yaml @@ -78,6 +78,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string additional_data: required: false @@ -130,6 +133,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string test_data: required: false @@ -181,6 +187,9 @@ spec: limit: required: false type: integer + vault_secret_id: + required: false + type: string type: dict output_directory: From d39a87fdf1a99be6a9443d0f0b2c1eb2b05e8b33 Mon Sep 17 00:00:00 2001 From: govarsha Date: Tue, 18 Jun 2024 17:45:31 +0530 Subject: [PATCH 2/3] added unit tests --- ads/opctl/operator/lowcode/common/utils.py | 2 +- tests/operators/common/test_load_data.py | 58 ++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/operators/common/test_load_data.py diff --git a/ads/opctl/operator/lowcode/common/utils.py b/ads/opctl/operator/lowcode/common/utils.py index dd0c92e66..1c9ede754 100644 --- a/ads/opctl/operator/lowcode/common/utils.py +++ b/ads/opctl/operator/lowcode/common/utils.py @@ -96,7 +96,7 @@ def load_data(data_spec, storage_options=None, **kwargs): connect_args['service_name'] = adwsecret['service_name'] except Exception as e: - logger.debug(f"Could not retrieve database credentials from vault : {e}") + raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}") con = oracledb.connect(**connect_args) if table_name is not None: diff --git a/tests/operators/common/test_load_data.py b/tests/operators/common/test_load_data.py new file mode 100644 index 000000000..3b54c2225 --- /dev/null +++ b/tests/operators/common/test_load_data.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import pytest +from ads.opctl.operator.lowcode.common.utils import ( + load_data, +) +from ads.opctl.operator.common.operator_config import InputData +from unittest.mock import patch, Mock +import unittest + + +class TestDataLoad(unittest.TestCase): + def setUp(self): + self.data_spec = Mock(spec=InputData) + self.data_spec.connect_args = { + 'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))', + 'wallet_password': '@Varsha1' + } + self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq' + self.data_spec.table_name = 'DF_SALARY' + self.data_spec.url = None + self.data_spec.format = None + self.data_spec.columns = None + self.data_spec.limit = None + + def testLoadSecretAndDBConnection(self): + data = load_data(self.data_spec) + assert len(data) == 135, f"Expected length 135, but got {len(data)}" + expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL'] + assert list( + data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}" + + def testLoadVaultFailure(self): + msg = "Vault exception message" + + def mock_load_secret(*args, **kwargs): + raise Exception(msg) + + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret): + with pytest.raises(Exception) as e: + load_data(self.data_spec) + + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}" + assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'" + + def testDBConnectionFailure(self): + msg = "Mocked DB connection error" + + def mock_oracledb_connect(*args, **kwargs): + raise Exception(msg) + + with patch('oracledb.connect', side_effect=mock_oracledb_connect): + with pytest.raises(Exception) as e: + load_data(self.data_spec) + + assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'" From 1cbb6816c62ae711e1e69711b6ba01950ebfd587 Mon Sep 17 00:00:00 2001 From: govarsha Date: Wed, 19 Jun 2024 07:41:56 +0530 Subject: [PATCH 3/3] added unit tests --- tests/operators/common/test_load_data.py | 97 +++++++++++++++++------- 1 file changed, 71 insertions(+), 26 deletions(-) diff --git a/tests/operators/common/test_load_data.py b/tests/operators/common/test_load_data.py index 3b54c2225..3fc5197d9 100644 --- a/tests/operators/common/test_load_data.py +++ b/tests/operators/common/test_load_data.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from typing import Union # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ @@ -7,52 +8,96 @@ load_data, ) from ads.opctl.operator.common.operator_config import InputData -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock import unittest +import pandas as pd + +mock_secret = { + 'user_name': 'mock_user', + 'password': 'mock_password', + 'service_name': 'mock_service_name' +} + +mock_connect_args = { + 'user': 'mock_user', + 'password': 'mock_password', + 'service_name': 'mock_service_name', + 'dsn': 'mock_dsn' +} + +# Mock data for testing +mock_data = pd.DataFrame({ + 'id': [1, 2, 3], + 'name': ['Alice', 'Bob', 'Charlie'] +}) + +mock_db_connection = MagicMock() + +load_secret_err_msg = "Vault exception message" +db_connect_err_msg = "Mocked DB connection error" + + +def mock_oracledb_connect_failure(*args, **kwargs): + raise Exception(db_connect_err_msg) + + +def mock_oracledb_connect(**kwargs): + assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}" + return mock_db_connection + + +class MockADBSecretKeeper: + @staticmethod + def __enter__(*args, **kwargs): + return mock_secret + + @staticmethod + def __exit__(*args, **kwargs): + pass + + @staticmethod + def load_secret(vault_secret_id, wallet_dir): + return MockADBSecretKeeper() + + @staticmethod + def load_secret_fail(*args, **kwargs): + raise Exception(load_secret_err_msg) class TestDataLoad(unittest.TestCase): def setUp(self): self.data_spec = Mock(spec=InputData) self.data_spec.connect_args = { - 'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))', - 'wallet_password': '@Varsha1' + 'dsn': 'mock_dsn' } - self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq' - self.data_spec.table_name = 'DF_SALARY' + self.data_spec.vault_secret_id = 'mock_secret_id' + self.data_spec.table_name = 'mock_table_name' self.data_spec.url = None self.data_spec.format = None self.data_spec.columns = None self.data_spec.limit = None def testLoadSecretAndDBConnection(self): - data = load_data(self.data_spec) - assert len(data) == 135, f"Expected length 135, but got {len(data)}" - expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL'] - assert list( - data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}" + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): + with patch('oracledb.connect', side_effect=mock_oracledb_connect): + with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql: + data = load_data(self.data_spec) + mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}", + mock_db_connection) + pd.testing.assert_frame_equal(data, mock_data) def testLoadVaultFailure(self): - msg = "Vault exception message" - - def mock_load_secret(*args, **kwargs): - raise Exception(msg) - - with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret): + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail): with pytest.raises(Exception) as e: load_data(self.data_spec) - expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}" + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}" assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'" def testDBConnectionFailure(self): - msg = "Mocked DB connection error" - - def mock_oracledb_connect(*args, **kwargs): - raise Exception(msg) - - with patch('oracledb.connect', side_effect=mock_oracledb_connect): - with pytest.raises(Exception) as e: - load_data(self.data_spec) + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): + with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure): + with pytest.raises(Exception) as e: + load_data(self.data_spec) - assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'" + assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'"