Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault support for Forecast & AD operators #876

Merged
merged 5 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ads/opctl/operator/common/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InputData(DataClassSerializable):
limit: int = None
sql: str = None
table_name: str = None
vault_secret_id: str = None


@dataclass(repr=True)
Expand Down
6 changes: 6 additions & 0 deletions ads/opctl/operator/lowcode/anomaly/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

validation_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

datetime_column:
type: dict
Expand Down
42 changes: 32 additions & 10 deletions ads/opctl/operator/lowcode/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import argparse
import logging
import os
import shutil
import sys
import tempfile
import time
from string import Template
from typing import Any, Dict, List, Tuple
Expand All @@ -28,6 +30,7 @@
)
from ads.opctl.operator.common.operator_config import OutputDirectory
from ads.common.object_storage_details import ObjectStorageDetails
from ads.secrets import ADBSecretKeeper


def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
Expand All @@ -53,10 +56,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
sql = data_spec.sql
table_name = data_spec.table_name
limit = data_spec.limit

vault_secret_id = data_spec.vault_secret_id
storage_options = storage_options or (
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
)
if vault_secret_id is not None and connect_args is None:
connect_args = dict()

if filename is not None:
if not format:
Expand All @@ -76,15 +81,32 @@ def load_data(data_spec, storage_options=None, **kwargs):
f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ."
)
elif connect_args is not None:
con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql_table(table_name, con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
with tempfile.TemporaryDirectory() as temp_dir:
if vault_secret_id is not None:
try:
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great handling here! Connect_args can overwrite the Vault.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahosler Added unit tests.

shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
connect_args['wallet_location'] = temp_dir
if 'user_name' in adwsecret and 'user' not in connect_args:
connect_args['user'] = adwsecret['user_name']
if 'password' in adwsecret and 'password' not in connect_args:
connect_args['password'] = adwsecret['password']
if 'service_name' in adwsecret and 'service_name' not in connect_args:
connect_args['service_name'] = adwsecret['service_name']

except Exception as e:
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")

con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql(f"SELECT * FROM {table_name}", con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
else:
raise InvalidParameterError(
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
Expand Down
9 changes: 9 additions & 0 deletions ads/opctl/operator/lowcode/forecast/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

additional_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

test_data:
required: false
Expand Down Expand Up @@ -181,6 +187,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string
type: dict

output_directory:
Expand Down
103 changes: 103 additions & 0 deletions tests/operators/common/test_load_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python
from typing import Union

# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import pytest
from ads.opctl.operator.lowcode.common.utils import (
load_data,
)
from ads.opctl.operator.common.operator_config import InputData
from unittest.mock import patch, Mock, MagicMock
import unittest
import pandas as pd

mock_secret = {
'user_name': 'mock_user',
'password': 'mock_password',
'service_name': 'mock_service_name'
}

mock_connect_args = {
'user': 'mock_user',
'password': 'mock_password',
'service_name': 'mock_service_name',
'dsn': 'mock_dsn'
}

# Mock data for testing
mock_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Alice', 'Bob', 'Charlie']
})

mock_db_connection = MagicMock()

load_secret_err_msg = "Vault exception message"
db_connect_err_msg = "Mocked DB connection error"


def mock_oracledb_connect_failure(*args, **kwargs):
raise Exception(db_connect_err_msg)


def mock_oracledb_connect(**kwargs):
assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}"
return mock_db_connection


class MockADBSecretKeeper:
@staticmethod
def __enter__(*args, **kwargs):
return mock_secret

@staticmethod
def __exit__(*args, **kwargs):
pass

@staticmethod
def load_secret(vault_secret_id, wallet_dir):
return MockADBSecretKeeper()

@staticmethod
def load_secret_fail(*args, **kwargs):
raise Exception(load_secret_err_msg)


class TestDataLoad(unittest.TestCase):
def setUp(self):
self.data_spec = Mock(spec=InputData)
self.data_spec.connect_args = {
'dsn': 'mock_dsn'
}
self.data_spec.vault_secret_id = 'mock_secret_id'
self.data_spec.table_name = 'mock_table_name'
self.data_spec.url = None
self.data_spec.format = None
self.data_spec.columns = None
self.data_spec.limit = None

def testLoadSecretAndDBConnection(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
with patch('oracledb.connect', side_effect=mock_oracledb_connect):
with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql:
data = load_data(self.data_spec)
mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}",
mock_db_connection)
pd.testing.assert_frame_equal(data, mock_data)

def testLoadVaultFailure(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail):
with pytest.raises(Exception) as e:
load_data(self.data_spec)

expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}"
assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'"

def testDBConnectionFailure(self):
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure):
with pytest.raises(Exception) as e:
load_data(self.data_spec)

assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'"