Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault support for Forecast & AD operators #876

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions ads/opctl/operator/common/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InputData(DataClassSerializable):
limit: int = None
sql: str = None
table_name: str = None
vault_secret_id: str = None


@dataclass(repr=True)
Expand Down
6 changes: 6 additions & 0 deletions ads/opctl/operator/lowcode/anomaly/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

validation_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

datetime_column:
type: dict
Expand Down
42 changes: 32 additions & 10 deletions ads/opctl/operator/lowcode/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import argparse
import logging
import os
import shutil
import sys
import tempfile
import time
from string import Template
from typing import Any, Dict, List, Tuple
Expand All @@ -28,6 +30,7 @@
)
from ads.opctl.operator.common.operator_config import OutputDirectory
from ads.common.object_storage_details import ObjectStorageDetails
from ads.secrets import ADBSecretKeeper


def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
Expand All @@ -53,10 +56,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
sql = data_spec.sql
table_name = data_spec.table_name
limit = data_spec.limit

vault_secret_id = data_spec.vault_secret_id
storage_options = storage_options or (
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
)
if vault_secret_id is not None and connect_args is None:
connect_args = dict()

if filename is not None:
if not format:
Expand All @@ -76,15 +81,32 @@ def load_data(data_spec, storage_options=None, **kwargs):
f"The format {format} is not currently supported for reading data. Please reformat the data source: {filename} ."
)
elif connect_args is not None:
con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql_table(table_name, con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
with tempfile.TemporaryDirectory() as temp_dir:
if vault_secret_id is not None:
try:
with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great handling here! Values passed explicitly in `connect_args` take precedence over the ones loaded from the Vault secret.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahosler Added unit tests.

shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
connect_args['wallet_location'] = temp_dir
if 'user_name' in adwsecret and 'user' not in connect_args:
connect_args['user'] = adwsecret['user_name']
if 'password' in adwsecret and 'password' not in connect_args:
connect_args['password'] = adwsecret['password']
if 'service_name' in adwsecret and 'service_name' not in connect_args:
connect_args['service_name'] = adwsecret['service_name']

except Exception as e:
logger.debug(f"Could not retrieve database credentials from vault : {e}")

con = oracledb.connect(**connect_args)
if table_name is not None:
data = pd.read_sql(f"SELECT * FROM {table_name}", con)
elif sql is not None:
data = pd.read_sql(sql, con)
else:
raise InvalidParameterError(
f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
)
else:
raise InvalidParameterError(
f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
Expand Down
9 changes: 9 additions & 0 deletions ads/opctl/operator/lowcode/forecast/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

additional_data:
required: false
Expand Down Expand Up @@ -130,6 +133,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string

test_data:
required: false
Expand Down Expand Up @@ -181,6 +187,9 @@ spec:
limit:
required: false
type: integer
vault_secret_id:
required: false
type: string
type: dict

output_directory:
Expand Down