From 57aad4805f54a006d24664322a76165abf7e6f23 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 15 Aug 2025 13:00:30 -0700 Subject: [PATCH 1/4] bug fixes related to profile attributes and optional attributes --- samples/async/create_ai_credential.py | 44 +++++++++ samples/async/delete_ai_credential.py | 32 ++++++ samples/async/disable_ai_provider.py | 33 +++++++ samples/async/enable_ai_provider.py | 33 +++++++ samples/delete_ai_credential.py | 25 +++++ src/select_ai/__init__.py | 15 ++- src/select_ai/async_profile.py | 20 +++- src/select_ai/base_profile.py | 10 +- src/select_ai/conversation.py | 20 +++- src/select_ai/credential.py | 135 ++++++++++++++++++++++++++ src/select_ai/db.py | 4 +- src/select_ai/profile.py | 26 +++-- src/select_ai/provider.py | 106 +++++++++++++++++++- src/select_ai/vector_index.py | 10 +- src/select_ai/version.py | 2 +- 15 files changed, 486 insertions(+), 29 deletions(-) create mode 100644 samples/async/create_ai_credential.py create mode 100644 samples/async/delete_ai_credential.py create mode 100644 samples/async/disable_ai_provider.py create mode 100644 samples/async/enable_ai_provider.py create mode 100644 samples/delete_ai_credential.py create mode 100644 src/select_ai/credential.py diff --git a/samples/async/create_ai_credential.py b/samples/async/create_ai_credential.py new file mode 100644 index 0000000..45af376 --- /dev/null +++ b/samples/async/create_ai_credential.py @@ -0,0 +1,44 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/create_ai_credential.py +# +# Async API to create credential +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import oci +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + default_config = oci.config.from_file() + oci.config.validate_config(default_config) + with open(default_config["key_file"]) as fp: + key_contents = fp.read() + credential = { + "credential_name": "my_oci_ai_profile_key", + "user_ocid": default_config["user"], + "tenancy_ocid": default_config["tenancy"], + "private_key": key_contents, + "fingerprint": default_config["fingerprint"], + } + await select_ai.async_create_credential( + credential=credential, replace=True + ) + print("Created credential: ", credential["credential_name"]) + + +asyncio.run(main()) diff --git a/samples/async/delete_ai_credential.py b/samples/async/delete_ai_credential.py new file mode 100644 index 0000000..8affbb2 --- /dev/null +++ b/samples/async/delete_ai_credential.py @@ -0,0 +1,32 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/create_ai_credential.py +# +# Async API to create credential +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + await select_ai.async_delete_credential( + credential_name="my_oci_ai_profile_key", force=True + ) + print("Deleted credential: my_oci_ai_profile_key") + + +asyncio.run(main()) diff --git a/samples/async/disable_ai_provider.py b/samples/async/disable_ai_provider.py new file mode 100644 index 0000000..8e53601 --- /dev/null +++ b/samples/async/disable_ai_provider.py @@ -0,0 +1,33 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/disable_ai_provider.py +# +# Async API to disable AI provider for database users +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + + +async def main(): + await select_ai.async_connect(user=admin_user, password=password, dsn=dsn) + await select_ai.async_disable_provider( + users=select_ai_user, provider_endpoint="*.openai.azure.com" + ) + print("Disabled AI provider for user: ", select_ai_user) + + +asyncio.run(main()) diff --git a/samples/async/enable_ai_provider.py b/samples/async/enable_ai_provider.py new file mode 100644 index 0000000..ef5c5cb --- /dev/null +++ b/samples/async/enable_ai_provider.py @@ -0,0 +1,33 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/enable_ai_provider.py +# +# Async API to enable AI provider for database users +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +admin_user = os.getenv("SELECT_AI_ADMIN_USER") +password = os.getenv("SELECT_AI_ADMIN_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") +select_ai_user = os.getenv("SELECT_AI_USER") + + +async def main(): + await select_ai.async_connect(user=admin_user, password=password, dsn=dsn) + await select_ai.async_enable_provider( + users=select_ai_user, provider_endpoint="*.openai.azure.com" + ) + print("Enabled AI provider for user: ", select_ai_user) + + +asyncio.run(main()) diff --git a/samples/delete_ai_credential.py b/samples/delete_ai_credential.py new file mode 100644 index 0000000..457ffa4 --- /dev/null +++ b/samples/delete_ai_credential.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# delete_ai_credential.py +# +# Create a Database credential storing OCI Gen AI's credentials +# ----------------------------------------------------------------------------- +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + +select_ai.connect(user=user, password=password, dsn=dsn) +select_ai.delete_credential( + credential_name="my_oci_ai_profile_key", force=True +) +print("Deleted credential: my_oci_ai_profile_key") diff --git a/src/select_ai/__init__.py b/src/select_ai/__init__.py index cc79842..54b50fb 100644 --- a/src/select_ai/__init__.py +++ b/src/select_ai/__init__.py @@ -6,11 +6,6 @@ # ----------------------------------------------------------------------------- from .action import Action -from .admin import ( - create_credential, - disable_provider, - enable_provider, -) from .async_profile import AsyncProfile from .base_profile import BaseProfile, ProfileAttributes from .conversation import ( @@ -18,6 +13,12 @@ Conversation, ConversationAttributes, ) +from .credential import ( + async_create_credential, + async_delete_credential, + create_credential, + delete_credential, +) from .db import ( async_connect, async_cursor, @@ -39,6 +40,10 @@ OCIGenAIProvider, OpenAIProvider, Provider, + async_disable_provider, + async_enable_provider, + disable_provider, + enable_provider, ) from .synthetic_data import ( SyntheticDataAttributes, diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index a00b7ad..bbaf0b3 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -55,7 +55,7 @@ async def _init_profile(self): :return: None :raises: oracledb.DatabaseError """ - if self.profile_name is not None: + if self.profile_name: profile_exists = False try: saved_attributes = await self._get_attributes( @@ -75,7 +75,7 @@ async def _init_profile(self): profile_name=self.profile_name ) except ProfileNotFoundError: - if self.attributes is None: + if self.attributes is None and self.description is None: raise else: if self.attributes is None: @@ -91,10 +91,13 @@ async def _init_profile(self): await self.create( replace=self.replace, description=self.description ) + else: # profile name is None: + if self.attributes is not None or self.description is not None: + raise ValueError("'profile_name' cannot be empty or None") return self @staticmethod - async def _get_profile_description(profile_name) -> str: + async def _get_profile_description(profile_name) -> Union[str, None]: """Get description of profile from USER_CLOUD_AI_PROFILES :param str profile_name: Name of profile @@ -110,7 +113,10 @@ async def _get_profile_description(profile_name) -> str: ) profile = await cr.fetchone() if profile: - return await profile[1].read() + if profile[1] is not None: + return await profile[1].read() + else: + return None else: raise ProfileNotFoundError(profile_name) @@ -186,6 +192,12 @@ async def set_attributes(self, attributes: ProfileAttributes): attributes :return: None """ + if not isinstance(attributes, ProfileAttributes): + raise TypeError( + "'attributes' must be an object of type " + "select_ai.ProfileAttributes" + ) + self.attributes = attributes parameters = { "profile_name": self.profile_name, diff --git a/src/select_ai/base_profile.py b/src/select_ai/base_profile.py index 431b792..5336ca2 100644 --- a/src/select_ai/base_profile.py +++ b/src/select_ai/base_profile.py @@ -73,10 +73,9 @@ class ProfileAttributes(SelectAIDataClass): vector_index_name: Optional[str] = None def __post_init__(self): - if not isinstance(self.provider, Provider): + if self.provider and not isinstance(self.provider, Provider): raise ValueError( - f"The arg `provider` must be an object of " - f"type select_ai.Provider" + f"'provider' must be an object of " f"type select_ai.Provider" ) def json(self, exclude_null=True): @@ -166,6 +165,11 @@ def __init__( ): """Initialize a base profile""" self.profile_name = profile_name + if attributes and not isinstance(attributes, ProfileAttributes): + raise TypeError( + "'attributes' must be an object of type " + "select_ai.ProfileAttributes" + ) self.attributes = attributes self.description = description self.merge = merge diff --git a/src/select_ai/conversation.py b/src/select_ai/conversation.py index 600a52b..84fb10e 100644 --- a/src/select_ai/conversation.py +++ b/src/select_ai/conversation.py @@ -129,7 +129,10 @@ def get_attributes(self) -> ConversationAttributes: attributes = cr.fetchone() if attributes: conversation_title = attributes[0] - description = attributes[1].read() # Oracle.LOB + if attributes[1]: + description = attributes[1].read() # Oracle.LOB + else: + description = None retention_days = attributes[2] return ConversationAttributes( title=conversation_title, @@ -154,7 +157,10 @@ def list(cls) -> Iterator["Conversation"]: for row in cr.fetchall(): conversation_id = row[0] conversation_title = row[1] - description = row[2].read() # Oracle.LOB + if row[2]: + description = row[2].read() # Oracle.LOB + else: + description = None retention_days = row[3] attributes = ConversationAttributes( title=conversation_title, @@ -224,7 +230,10 @@ async def get_attributes(self) -> ConversationAttributes: attributes = await cr.fetchone() if attributes: conversation_title = attributes[0] - description = await attributes[1].read() # Oracle.AsyncLOB + if attributes[1]: + description = await attributes[1].read() # Oracle.AsyncLOB + else: + description = None retention_days = attributes[2] return ConversationAttributes( title=conversation_title, @@ -250,7 +259,10 @@ async def list(cls) -> AsyncGenerator["AsyncConversation", None]: for row in rows: conversation_id = row[0] conversation_title = row[1] - description = await row[2].read() # Oracle.AsyncLOB + if row[2]: + description = await row[2].read() # Oracle.AsyncLOB + else: + description = None retention_days = row[3] attributes = ConversationAttributes( title=conversation_title, diff --git a/src/select_ai/credential.py b/src/select_ai/credential.py new file mode 100644 index 0000000..17df49b --- /dev/null +++ b/src/select_ai/credential.py @@ -0,0 +1,135 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +from typing import Mapping + +import oracledb + +from .db import async_cursor, cursor + +__all__ = [ + "async_create_credential", + "async_delete_credential", + "create_credential", + "delete_credential", +] + + +def _validate_credential(credential: Mapping[str, str]): + valid_keys = { + "credential_name", + "username", + "password", + "user_ocid", + "tenancy_ocid", + "private_key", + "fingerprint", + "comments", + } + for k in credential.keys(): + if k.lower() not in valid_keys: + raise ValueError( + f"Invalid value {k}: {credential[k]} for credential object" + ) + + +async def async_create_credential(credential: Mapping, replace: bool = False): + """ + Async API to create credential. + + Creates a credential object using DBMS_CLOUD.CREATE_CREDENTIAL. if replace + is True, credential will be replaced if it already exists + + """ + _validate_credential(credential) + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD.CREATE_CREDENTIAL", keyword_parameters=credential + ) + except oracledb.DatabaseError as e: + (error,) = e.args + # If already exists and replace is True then drop and recreate + if error.code == 20022 and replace: + await cr.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={ + "credential_name": credential["credential_name"] + }, + ) + await cr.callproc( + "DBMS_CLOUD.CREATE_CREDENTIAL", + keyword_parameters=credential, + ) + else: + raise + + +async def async_delete_credential(credential_name: str, force: bool = False): + """ + Async API to create credential. + + Deletes a credential object using DBMS_CLOUD.DROP_CREDENTIAL + """ + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={"credential_name": credential_name}, + ) + except oracledb.DatabaseError as e: + (error,) = e.args + if error.code == 20004 and force: # does not exist + pass + else: + raise + + +def create_credential(credential: Mapping, replace: bool = False): + """ + + Creates a credential object using DBMS_CLOUD.CREATE_CREDENTIAL. if replace + is True, credential will be replaced if it "already exists" + + """ + _validate_credential(credential) + with cursor() as cr: + try: + cr.callproc( + "DBMS_CLOUD.CREATE_CREDENTIAL", keyword_parameters=credential + ) + except oracledb.DatabaseError as e: + (error,) = e.args + # If already exists and replace is True then drop and recreate + if error.code == 20022 and replace: + cr.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={ + "credential_name": credential["credential_name"] + }, + ) + cr.callproc( + "DBMS_CLOUD.CREATE_CREDENTIAL", + keyword_parameters=credential, + ) + else: + raise + + +def delete_credential(credential_name: str, force: bool = False): + with cursor() as cr: + try: + cr.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={"credential_name": credential_name}, + ) + except oracledb.DatabaseError as e: + (error,) = e.args + if error.code == 20004 and force: # does not exist + pass + else: + raise diff --git a/src/select_ai/db.py b/src/select_ai/db.py index aa10986..1a7987a 100644 --- a/src/select_ai/db.py +++ b/src/select_ai/db.py @@ -73,7 +73,7 @@ def is_connected() -> bool: return False try: return conn.ping() is None - except oracledb.DatabaseError: + except (oracledb.DatabaseError, oracledb.InterfaceError): return False @@ -87,7 +87,7 @@ async def async_is_connected() -> bool: return False try: return await conn.ping() is None - except oracledb.DatabaseError: + except (oracledb.DatabaseError, oracledb.InterfaceError): return False diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index 45b4e2a..eaeb5b9 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -44,7 +44,7 @@ def _init_profile(self) -> None: :return: None :raises: oracledb.DatabaseError """ - if self.profile_name is not None: + if self.profile_name: profile_exists = False try: saved_attributes = self._get_attributes( @@ -64,7 +64,7 @@ def _init_profile(self) -> None: profile_name=self.profile_name ) except ProfileNotFoundError: - if self.attributes is None: + if self.attributes is None and self.description is None: raise else: if self.attributes is None: @@ -78,20 +78,28 @@ def _init_profile(self) -> None: ) if self.replace or not profile_exists: self.create(replace=self.replace) + else: # profile name is None + if self.attributes is not None or self.description is not None: + raise ValueError( + "Attribute 'profile_name' cannot be empty or None" + ) @staticmethod - def _get_profile_description(profile_name) -> str: + def _get_profile_description(profile_name) -> Union[str, None]: """Get description of profile from USER_CLOUD_AI_PROFILES :param str profile_name: - :return: str + :return: Union[str, None] profile description :raises: ProfileNotFoundError """ with cursor() as cr: cr.execute(GET_USER_AI_PROFILE, profile_name=profile_name.upper()) profile = cr.fetchone() if profile: - return profile[1].read() + if profile[1] is not None: + return profile[1].read() + else: + return None else: raise ProfileNotFoundError(profile_name) @@ -165,6 +173,11 @@ def set_attributes(self, attributes: ProfileAttributes): attributes :return: None """ + if not isinstance(attributes, ProfileAttributes): + raise TypeError( + "'attributes' must be an object of type" + " select_ai.ProfileAttributes" + ) self.attributes = attributes parameters = { "profile_name": self.profile_name, @@ -182,7 +195,8 @@ def create(self, replace: Optional[int] = False) -> None: :return: None :raises: oracledb.DatabaseError """ - + if self.attributes is None: + raise AttributeError("Profile attributes cannot be None") parameters = { "profile_name": self.profile_name, "attributes": self.attributes.json(), diff --git a/src/select_ai/provider.py b/src/select_ai/provider.py index ffa3018..9dec23c 100644 --- a/src/select_ai/provider.py +++ b/src/select_ai/provider.py @@ -5,11 +5,19 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -from dataclasses import dataclass, fields -from typing import Optional +from dataclasses import dataclass +from typing import List, Optional, Union from select_ai._abc import SelectAIDataClass +from .db import async_cursor, cursor +from .sql import ( + DISABLE_AI_PROFILE_DOMAIN_FOR_USER, + ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + GRANT_PRIVILEGES_TO_USER, + REVOKE_PRIVILEGES_FROM_USER, +) + OPENAI = "openai" COHERE = "cohere" AZURE = "azure" @@ -184,3 +192,97 @@ class AnthropicProvider(Provider): provider_name: str = ANTHROPIC provider_endpoint = "api.anthropic.com" + + +async def async_enable_provider( + users: Union[str, List[str]], provider_endpoint: str = None +): + """ + Async API to enable AI profile for database users. + + This method grants execute privilege on the packages DBMS_CLOUD, + DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It also enables the database + user to invoke the AI Provider (LLM) endpoint + + """ + if isinstance(users, str): + users = [users] + + async with async_cursor() as cr: + for user in users: + await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user)) + if provider_endpoint: + await cr.execute( + ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + user=user, + host=provider_endpoint, + ) + + +async def async_disable_provider( + users: Union[str, List[str]], provider_endpoint: str = None +): + """ + Async API to disable AI profile for database users + + Disables AI provider for the user. This method revokes execute privilege + on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It + also disables the user to invoke the AI Provider (LLM) endpoint + """ + if isinstance(users, str): + users = [users] + + async with async_cursor() as cr: + for user in users: + await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user)) + if provider_endpoint: + await cr.execute( + DISABLE_AI_PROFILE_DOMAIN_FOR_USER, + user=user, + host=provider_endpoint, + ) + + +def enable_provider( + users: Union[str, List[str]], provider_endpoint: str = None +): + """ + Enables AI profile for the user. This method grants execute privilege + on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It + also enables the user to invoke the AI Provider (LLM) endpoint + """ + if isinstance(users, str): + users = [users] + + with cursor() as cr: + for user in users: + cr.execute(GRANT_PRIVILEGES_TO_USER.format(user)) + if provider_endpoint: + cr.execute( + ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + user=user, + host=provider_endpoint, + ) + + +def disable_provider( + users: Union[str, List[str]], provider_endpoint: str = None +): + """ + Disables AI provider for the user. This method revokes execute privilege + on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It + also disables the user to invoke the AI(LLM) endpoint + + """ + if isinstance(users, str): + users = [users] + + with cursor() as cr: + for user in users: + cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user)) + if provider_endpoint: + cr.execute( + DISABLE_AI_PROFILE_DOMAIN_FOR_USER, + user=user, + host=provider_endpoint, + ) diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index c6078c4..d18d03b 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -325,7 +325,10 @@ def list(cls, index_name_pattern: str = ".*") -> Iterator["VectorIndex"]: ) for row in cr.fetchall(): index_name = row[0] - description = row[1].read() # Oracle.LOB + if row[1]: + description = row[1].read() # Oracle.LOB + else: + description = None attributes = cls._get_attributes(index_name=index_name) yield cls( index_name=index_name, @@ -534,7 +537,10 @@ async def list( rows = await cr.fetchall() for row in rows: index_name = row[0] - description = await row[1].read() # AsyncLOB + if row[1]: + description = await row[1].read() # AsyncLOB + else: + description = None attributes = await cls._get_attributes(index_name=index_name) yield VectorIndex( index_name=index_name, diff --git a/src/select_ai/version.py b/src/select_ai/version.py index bfcce91..71eeca5 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.0.0.dev7" +__version__ = "1.0.0.dev8" From ac3d3bca99bc7d6871ac44bc232e93484237f92d Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 15 Aug 2025 15:08:51 -0700 Subject: [PATCH 2/4] Use error code to check for objects which already exists --- samples/async/delete_ai_credential.py | 4 +- samples/delete_ai_credential.py | 2 +- src/select_ai/admin.py | 116 -------------------------- src/select_ai/async_profile.py | 5 +- src/select_ai/profile.py | 2 +- src/select_ai/vector_index.py | 4 +- 6 files changed, 9 insertions(+), 124 deletions(-) delete mode 100644 src/select_ai/admin.py diff --git a/samples/async/delete_ai_credential.py b/samples/async/delete_ai_credential.py index 8affbb2..94e043e 100644 --- a/samples/async/delete_ai_credential.py +++ b/samples/async/delete_ai_credential.py @@ -6,9 +6,9 @@ # ----------------------------------------------------------------------------- # ----------------------------------------------------------------------------- -# async/create_ai_credential.py +# async/delete_ai_credential # -# Async API to create credential +# Async API to delete credential # ----------------------------------------------------------------------------- import asyncio diff --git a/samples/delete_ai_credential.py b/samples/delete_ai_credential.py index 457ffa4..1d54cbb 100644 --- a/samples/delete_ai_credential.py +++ b/samples/delete_ai_credential.py @@ -8,7 +8,7 @@ # ----------------------------------------------------------------------------- # delete_ai_credential.py # -# Create a Database credential storing OCI Gen AI's credentials +# Delete AI credential # ----------------------------------------------------------------------------- import os diff --git a/src/select_ai/admin.py b/src/select_ai/admin.py deleted file mode 100644 index 0d195fb..0000000 --- a/src/select_ai/admin.py +++ /dev/null @@ -1,116 +0,0 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025, Oracle and/or its affiliates. -# -# Licensed under the Universal Permissive License v 1.0 as shown at -# http://oss.oracle.com/licenses/upl. -# ----------------------------------------------------------------------------- - -from typing import List, Mapping, Union - -import oracledb - -from .db import cursor -from .sql import ( - DISABLE_AI_PROFILE_DOMAIN_FOR_USER, - ENABLE_AI_PROFILE_DOMAIN_FOR_USER, - GRANT_PRIVILEGES_TO_USER, - REVOKE_PRIVILEGES_FROM_USER, -) - -__all__ = [ - "create_credential", - "disable_provider", - "enable_provider", -] - - -def create_credential(credential: Mapping, replace: bool = False): - """ - Creates a credential object using DBMS_CLOUD.CREATE_CREDENTIAL - - if replace is True, credential will be replaced if it "already exists" - - """ - valid_keys = { - "credential_name", - "username", - "password", - "user_ocid", - "tenancy_ocid", - "private_key", - "fingerprint", - "comments", - } - for k in credential.keys(): - if k.lower() not in valid_keys: - raise ValueError( - f"Invalid value {k}: {credential[k]} for credential object" - ) - - with cursor() as cr: - try: - cr.callproc( - "DBMS_CLOUD.CREATE_CREDENTIAL", keyword_parameters=credential - ) - except oracledb.DatabaseError as e: - (error,) = e.args - # If already exists and replace is True then drop and recreate - if "already exists" in error.message.lower() and replace: - cr.callproc( - "DBMS_CLOUD.DROP_CREDENTIAL", - keyword_parameters={ - "credential_name": credential["credential_name"] - }, - ) - cr.callproc( - "DBMS_CLOUD.CREATE_CREDENTIAL", - keyword_parameters=credential, - ) - else: - raise - - -def enable_provider( - users: Union[str, List[str]], provider_endpoint: str = None -): - """ - Enables AI profile for the user. This method grants execute privilege - on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It - also enables the user to invoke the AI(LLM) endpoint hosted at a - certain domain - """ - if isinstance(users, str): - users = [users] - - with cursor() as cr: - for user in users: - cr.execute(GRANT_PRIVILEGES_TO_USER.format(user)) - if provider_endpoint: - cr.execute( - ENABLE_AI_PROFILE_DOMAIN_FOR_USER, - user=user, - host=provider_endpoint, - ) - - -def disable_provider( - users: Union[str, List[str]], provider_endpoint: str = None -): - """ - Disables AI provider for the user. This method revokes execute privilege - on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It - also disables the user to invoke the AI(LLM) endpoint hosted at a - certain domain - """ - if isinstance(users, str): - users = [users] - - with cursor() as cr: - for user in users: - cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user)) - if provider_endpoint: - cr.execute( - DISABLE_AI_PROFILE_DOMAIN_FOR_USER, - user=user, - host=provider_endpoint, - ) diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index bbaf0b3..dbe2e3e 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -218,13 +218,14 @@ async def create( :return: None :raises: oracledb.DatabaseError """ + if self.attributes is None: + raise AttributeError("Profile attributes cannot be None") parameters = { "profile_name": self.profile_name, "attributes": self.attributes.json(), } if description: parameters["description"] = description - async with async_cursor() as cr: try: await cr.callproc( @@ -234,7 +235,7 @@ async def create( except oracledb.DatabaseError as e: (error,) = e.args # If already exists and replace is True then drop and recreate - if "already exists" in error.message.lower() and replace: + if error.code == 20046 and replace: await self.delete(force=True) await cr.callproc( "DBMS_CLOUD_AI.CREATE_PROFILE", diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index eaeb5b9..61fed81 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -213,7 +213,7 @@ def create(self, replace: Optional[int] = False) -> None: except oracledb.DatabaseError as e: (error,) = e.args # If already exists and replace is True then drop and recreate - if "already exists" in error.message.lower() and replace: + if error.code == 20046 and replace: self.delete(force=True) cr.callproc( "DBMS_CLOUD_AI.CREATE_PROFILE", diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index d18d03b..9143d88 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -192,7 +192,7 @@ def create(self, replace: Optional[bool] = False): except oracledb.DatabaseError as e: (error,) = e.args # If already exists and replace is True then drop and recreate - if "already exists" in error.message.lower() and replace: + if error.code == 20048 and replace: self.delete(force=True) cr.callproc( "DBMS_CLOUD_AI.CREATE_VECTOR_INDEX", @@ -399,7 +399,7 @@ async def create(self, replace: Optional[bool] = False) -> None: except oracledb.DatabaseError as e: (error,) = e.args # If already exists and replace is True then drop and recreate - if "already exists" in error.message.lower() and replace: + if error.code == 20048 and replace: await self.delete(force=True) await cr.callproc( "DBMS_CLOUD_AI.CREATE_VECTOR_INDEX", From d0e370d076af1cc8197468a3c9a0bb368825ce77 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 15 Aug 2025 18:51:49 -0700 Subject: [PATCH 3/4] Beta release documentation, readme and samples updates --- README.md | 61 +++++++++++++++++-- doc/source/user_guide/async_profile.rst | 8 +++ doc/source/user_guide/conversation.rst | 6 ++ doc/source/user_guide/credential.rst | 17 ++++++ doc/source/user_guide/profile.rst | 11 +++- doc/source/user_guide/provider.rst | 36 ++++++++++- doc/source/user_guide/synthetic_data.rst | 56 +++++++++++++++++ doc/source/user_guide/vector_index.rst | 7 +++ pyproject.toml | 10 ++- samples/async/profile_explain_sql.py | 3 +- .../profile_gen_multi_table_synthetic_data.py | 48 +++++++++++++++ ...profile_gen_single_table_synthetic_data.py | 41 +++++++++++++ samples/async/profile_run_sql.py | 2 +- samples/async/profile_show_sql.py | 2 +- samples/async/profile_sql_concurrent_tasks.py | 4 +- samples/async/vector_index_create.py | 1 - samples/async/vector_index_rag.py | 1 - samples/conversation_chat_session.py | 1 + samples/conversation_create.py | 1 + samples/create_ai_credential.py | 1 + samples/profile_explain_sql.py | 2 +- samples/profile_narrate.py | 2 +- samples/profile_run_sql.py | 4 +- samples/profile_show_sql.py | 4 +- src/select_ai/async_profile.py | 14 ++++- src/select_ai/profile.py | 13 ++++ src/select_ai/synthetic_data.py | 6 ++ src/select_ai/version.py | 2 +- 28 files changed, 336 insertions(+), 28 deletions(-) create mode 100644 samples/async/profile_gen_multi_table_synthetic_data.py create mode 100644 samples/async/profile_gen_single_table_synthetic_data.py diff --git a/README.md b/README.md index 03ce307..f98bb70 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,63 @@ python3 -m pip install select_ai ## Samples -Examples can be found in the samples directory +Examples can be found in the [/samples][samples] directory -## Contributing +### Basic Example + +```python +import select_ai + +user = "" +password = "" +dsn = "" + +select_ai.connect(user=user, password=password, dsn=dsn) +profile = select_ai.Profile(profile_name="oci_ai_profile") +# run_sql returns a pandas dataframe +df = profile.run_sql(prompt="How many promotions?") +print(df.columns) +print(df) +``` + +### Async Example + +```python + +import asyncio + +import select_ai +user = "" +password = "" +dsn = "" -This project welcomes contributions from the community. Before submitting a pull request, please [review our contribution guide](./CONTRIBUTING.md) +# This example shows how to asynchronously run sql +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + async_profile = await select_ai.AsyncProfile( + profile_name="async_oci_ai_profile", + ) + # run_sql returns a pandas df + df = await async_profile.run_sql("How many promotions?") + print(df) + +asyncio.run(main()) + +``` +## Help + +Questions can be asked in [GitHub Discussions][ghdiscussions]. + +Problem reports can be raised in [GitHub Issues][ghissues]. + +## Contributing + +This project welcomes contributions from the community. Before submitting a pull request, please [review our contribution guide][contributing] ## Security -Please consult the [security guide](./SECURITY.md) for our responsible security vulnerability disclosure process +Please consult the [security guide][security] for our responsible security vulnerability disclosure process ## License @@ -32,3 +79,9 @@ Copyright (c) 2025 Oracle and/or its affiliates. Released under the Universal Permissive License v1.0 as shown at . + +[contributing]: https://github.com/oracle/python-select-ai/blob/main/CONTRIBUTING.md +[ghdiscussions]: https://github.com/oracle/python-select-ai/discussions +[ghissues]: https://github.com/oracle/python-select-ai/issues +[samples]: https://github.com/oracle/python-select-ai/tree/main/samples +[security]: https://github.com/oracle/python-select-ai/blob/main/SECURITY.md diff --git a/doc/source/user_guide/async_profile.rst b/doc/source/user_guide/async_profile.rst index 24d80d9..fdad4d9 100644 --- a/doc/source/user_guide/async_profile.rst +++ b/doc/source/user_guide/async_profile.rst @@ -20,6 +20,7 @@ Async Profile creation .. literalinclude:: ../../../samples/async/profile_create.py :language: python + :lines: 14- output:: @@ -60,6 +61,7 @@ Async explain SQL .. literalinclude:: ../../../samples/async/profile_explain_sql.py :language: python + :lines: 12- output:: @@ -89,6 +91,7 @@ Async run SQL .. literalinclude:: ../../../samples/async/profile_run_sql.py :language: python + :lines: 14- output:: @@ -103,6 +106,7 @@ Async show SQL .. literalinclude:: ../../../samples/async/profile_show_sql.py :language: python + :lines: 14- output:: @@ -117,6 +121,7 @@ Async concurrent SQL .. literalinclude:: ../../../samples/async/profile_sql_concurrent_tasks.py :language: python + :lines: 15- output:: @@ -152,6 +157,7 @@ Async chat .. literalinclude:: ../../../samples/async/profile_chat.py :language: python + :lines: 14- output:: @@ -177,6 +183,7 @@ Async pipeline .. literalinclude:: ../../../samples/async/profile_pipeline.py :language: python + :lines: 14- output:: @@ -209,6 +216,7 @@ List profiles asynchronously .. literalinclude:: ../../../samples/async/profiles_list.py :language: python + :lines: 14- output:: diff --git a/doc/source/user_guide/conversation.rst b/doc/source/user_guide/conversation.rst index e589a13..6e6d7dc 100644 --- a/doc/source/user_guide/conversation.rst +++ b/doc/source/user_guide/conversation.rst @@ -36,6 +36,7 @@ Create conversion .. literalinclude:: ../../../samples/conversation_create.py :language: python + :lines: 15- output:: @@ -48,6 +49,7 @@ Chat session .. literalinclude:: ../../../samples/conversation_chat_session.py :language: python + :lines: 14- output:: @@ -71,6 +73,7 @@ List conversations .. literalinclude:: ../../../samples/conversations_list.py :language: python + :lines: 14- output:: @@ -87,6 +90,7 @@ Delete conversation .. literalinclude:: ../../../samples/conversation_delete.py :language: python + :lines: 14- output:: @@ -109,6 +113,7 @@ Async chat session .. literalinclude:: ../../../samples/async/conversation_chat_session.py :language: python + :lines: 13- output:: @@ -132,6 +137,7 @@ Async list conversations .. literalinclude:: ../../../samples/async/conversations_list.py :language: python + :lines: 14- output:: diff --git a/doc/source/user_guide/credential.rst b/doc/source/user_guide/credential.rst index 9f30879..8bc8f47 100644 --- a/doc/source/user_guide/credential.rst +++ b/doc/source/user_guide/credential.rst @@ -39,8 +39,25 @@ Create credential In this example, we create a credential object to authenticate to OCI Gen AI service provider: +Sync API +++++++++ + .. literalinclude:: ../../../samples/create_ai_credential.py :language: python + :lines: 14- + +output:: + + Created credential: my_oci_ai_profile_key + +.. latex:clearpage:: + +Async API ++++++++++ + +.. literalinclude:: ../../../samples/async/create_ai_credential.py + :language: python + :lines: 14- output:: diff --git a/doc/source/user_guide/profile.rst b/doc/source/user_guide/profile.rst index a5a8cbd..1211cb0 100644 --- a/doc/source/user_guide/profile.rst +++ b/doc/source/user_guide/profile.rst @@ -39,6 +39,7 @@ Create Profile .. literalinclude:: ../../../samples/profile_create.py :language: python + :lines: 14- output:: @@ -79,6 +80,7 @@ Narrate .. literalinclude:: ../../../samples/profile_narrate.py :language: python + :lines: 14- output:: @@ -93,6 +95,7 @@ Show SQL .. literalinclude:: ../../../samples/profile_show_sql.py :language: python + :lines: 14- output:: @@ -108,6 +111,7 @@ Run SQL .. literalinclude:: ../../../samples/profile_run_sql.py :language: python + :lines: 14- output:: @@ -123,7 +127,8 @@ Chat ************************** .. literalinclude:: ../../../samples/profile_chat.py - :language: python + :language: python + :lines: 14- output:: @@ -139,7 +144,9 @@ List profiles ************************** .. literalinclude:: ../../../samples/profiles_list.py - :language: python + :language: python + :lines: 14- + output:: diff --git a/doc/source/user_guide/provider.rst b/doc/source/user_guide/provider.rst index 1e57793..e998818 100644 --- a/doc/source/user_guide/provider.rst +++ b/doc/source/user_guide/provider.rst @@ -103,13 +103,28 @@ Enable AI service provider export SELECT_AI_DB_CONNECT_STRING= export TNS_ADMIN= +Sync API +++++++++ + This method grants execute privilege on the packages ``DBMS_CLOUD``, ``DBMS_CLOUD_AI`` and ``DBMS_CLOUD_PIPELINE``. It -also enables the user to invoke the AI(LLM) endpoint hosted at a -certain domain +also enables the database user to invoke the AI(LLM) endpoint .. literalinclude:: ../../../samples/enable_ai_provider.py :language: python + :lines: 15- + +output:: + + Enabled AI provider for user: + +.. latex:clearpage:: + +Async API ++++++++++ +.. literalinclude:: ../../../samples/async/enable_ai_provider.py + :language: python + :lines: 14- output:: @@ -121,8 +136,25 @@ output:: Disable AI service provider *************************** +Sync API +++++++++ + .. literalinclude:: ../../../samples/disable_ai_provider.py :language: python + :lines: 14- + +output:: + + Disabled AI provider for user: + +.. latex:clearpage:: + +Async API ++++++++++ + +.. literalinclude:: ../../../samples/async/disable_ai_provider.py + :language: python + :lines: 14- output:: diff --git a/doc/source/user_guide/synthetic_data.rst b/doc/source/user_guide/synthetic_data.rst index bec0601..97e2272 100644 --- a/doc/source/user_guide/synthetic_data.rst +++ b/doc/source/user_guide/synthetic_data.rst @@ -27,8 +27,12 @@ Single table synthetic data The below example shows single table synthetic data generation +Sync API +++++++++ + .. literalinclude:: ../../../samples/profile_gen_single_table_synthetic_data.py :language: python + :lines: 14- output:: @@ -40,14 +44,66 @@ output:: .. latex:clearpage:: +Async API ++++++++++ + +.. literalinclude:: ../../../samples/async/profile_gen_single_table_synthetic_data.py + :language: python + :lines: 12- + +output:: + + SQL> select count(*) from movie; + + COUNT(*) + ---------- + 100 + +.. latex:clearpage:: + + **************************** Multi table synthetic data **************************** The below example shows multitable synthetic data generation +Sync API +++++++++ + .. literalinclude:: ../../../samples/profile_gen_multi_table_synthetic_data.py :language: python + :lines: 14- + + +output:: + + SQL> select count(*) from actor; + + COUNT(*) + ---------- + 40 + + SQL> select count(*) from director; + + COUNT(*) + ---------- + 13 + + SQL> select count(*) from movie; + + COUNT(*) + ---------- + 300 + + +Async API ++++++++++ + + +.. literalinclude:: ../../../samples/async/profile_gen_multi_table_synthetic_data.py + :language: python + :lines: 12- output:: diff --git a/doc/source/user_guide/vector_index.rst b/doc/source/user_guide/vector_index.rst index d29b0a5..193c2e2 100644 --- a/doc/source/user_guide/vector_index.rst +++ b/doc/source/user_guide/vector_index.rst @@ -59,6 +59,7 @@ objects (to create embedding for) reside in OCI's object store .. literalinclude:: ../../../samples/vector_index_create.py :language: python + :lines: 14- output:: @@ -71,6 +72,7 @@ List vector index .. literalinclude:: ../../../samples/vector_index_list.py :language: python + :lines: 15- output:: @@ -84,6 +86,7 @@ RAG using vector index .. literalinclude:: ../../../samples/vector_index_rag.py :language: python + :lines: 14- output:: @@ -108,6 +111,7 @@ Delete vector index .. literalinclude:: ../../../samples/vector_index_delete.py :language: python + :lines: 12- output:: @@ -131,6 +135,7 @@ Async create vector index .. literalinclude:: ../../../samples/async/vector_index_create.py :language: python + :lines: 14- output:: @@ -144,6 +149,7 @@ Async list vector index .. literalinclude:: ../../../samples/async/vector_index_list.py :language: python + :lines: 15- output:: @@ -158,6 +164,7 @@ Async RAG using vector index .. literalinclude:: ../../../samples/async/vector_index_rag.py :language: python + :lines: 15- output:: diff --git a/pyproject.toml b/pyproject.toml index c692e02..2a11a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,8 @@ build-backend = "setuptools.build_meta" [project] name = "select_ai" dynamic = ["version"] -description = "Python API for Select AI" +description = "Select AI for Python" +readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.9" authors = [ {name="Abhishek Singh", email="abhishek.o.singh@oracle.com"} @@ -22,7 +23,7 @@ keywords = [ license = " UPL-1.0" license-files = ["LICENSE.txt"] classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Natural Language :: English", "Operating System :: OS Independent", @@ -40,6 +41,11 @@ dependencies = [ "pandas==2.2.3" ] +[project.urls] +Homepage = "https://github.com/oracle/python-select-ai" +Repository = "https://github.com/oracle/python-select-ai" +Issues = "https://github.com/oracle/python-select-ai/issues" + [tool.setuptools.packages.find] where = ["src"] diff --git a/samples/async/profile_explain_sql.py b/samples/async/profile_explain_sql.py index 17bd92d..0b79c7d 100644 --- a/samples/async/profile_explain_sql.py +++ b/samples/async/profile_explain_sql.py @@ -19,13 +19,12 @@ dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") -# This example shows how to asynchronously ask the LLM to explain SQL async def main(): await select_ai.async_connect(user=user, password=password, dsn=dsn) async_profile = await select_ai.AsyncProfile( profile_name="async_oci_ai_profile", ) - response = await async_profile.explain_sql("How many promotions") + response = await async_profile.explain_sql("How many promotions ?") print(response) diff --git a/samples/async/profile_gen_multi_table_synthetic_data.py b/samples/async/profile_gen_multi_table_synthetic_data.py new file mode 100644 index 0000000..5732ae7 --- /dev/null +++ b/samples/async/profile_gen_multi_table_synthetic_data.py @@ -0,0 +1,48 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/profile_gen_multi_table_synthetic_data.py +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + async_profile = await select_ai.AsyncProfile( + profile_name="async_oci_ai_profile", + ) + synthetic_data_params = select_ai.SyntheticDataParams( + sample_rows=100, table_statistics=True, priority="HIGH" + ) + object_list = [ + { + "owner": user, + "name": "MOVIE", + "record_count": 100, + "user_prompt": "the release date for the movies should be in 2019", + }, + {"owner": user, "name": "ACTOR", "record_count": 10}, + {"owner": user, "name": "DIRECTOR", "record_count": 5}, + ] + synthetic_data_attributes = select_ai.SyntheticDataAttributes( + object_list=object_list, params=synthetic_data_params + ) + await async_profile.generate_synthetic_data( + synthetic_data_attributes=synthetic_data_attributes + ) + + +asyncio.run(main()) diff --git a/samples/async/profile_gen_single_table_synthetic_data.py b/samples/async/profile_gen_single_table_synthetic_data.py new file mode 100644 index 0000000..15eaf6f --- /dev/null +++ b/samples/async/profile_gen_single_table_synthetic_data.py @@ -0,0 +1,41 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# async/profile_gen_single_table_synthetic_data.py +# ----------------------------------------------------------------------------- + +import asyncio +import os + +import select_ai + +user = os.getenv("SELECT_AI_USER") +password = os.getenv("SELECT_AI_PASSWORD") +dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING") + + +async def main(): + await select_ai.async_connect(user=user, password=password, dsn=dsn) + async_profile = await select_ai.AsyncProfile( + profile_name="async_oci_ai_profile", + ) + synthetic_data_params = select_ai.SyntheticDataParams( + sample_rows=100, table_statistics=True, priority="HIGH" + ) + synthetic_data_attributes = select_ai.SyntheticDataAttributes( + object_name="MOVIE", + user_prompt="the release date for the movies should be in 2019", + params=synthetic_data_params, + record_count=100, + ) + await async_profile.generate_synthetic_data( + synthetic_data_attributes=synthetic_data_attributes + ) + + +asyncio.run(main()) diff --git a/samples/async/profile_run_sql.py b/samples/async/profile_run_sql.py index 837c234..49b5089 100644 --- a/samples/async/profile_run_sql.py +++ b/samples/async/profile_run_sql.py @@ -28,7 +28,7 @@ async def main(): profile_name="async_oci_ai_profile", ) # run_sql returns a pandas df - df = await async_profile.run_sql("How many promotions") + df = await async_profile.run_sql("How many promotions?") print(df) diff --git a/samples/async/profile_show_sql.py b/samples/async/profile_show_sql.py index 93a3f22..f7a88b2 100644 --- a/samples/async/profile_show_sql.py +++ b/samples/async/profile_show_sql.py @@ -26,7 +26,7 @@ async def main(): async_profile = await select_ai.AsyncProfile( profile_name="async_oci_ai_profile", ) - response = await async_profile.show_sql("How many promotions") + response = await async_profile.show_sql("How many promotions?") print(response) diff --git a/samples/async/profile_sql_concurrent_tasks.py b/samples/async/profile_sql_concurrent_tasks.py index 11f2b3a..141ac01 100644 --- a/samples/async/profile_sql_concurrent_tasks.py +++ b/samples/async/profile_sql_concurrent_tasks.py @@ -29,8 +29,8 @@ async def main(): ) sql_tasks = [ async_profile.show_sql(prompt="How many customers?"), - async_profile.run_sql(prompt="How many promotions"), - async_profile.explain_sql(prompt="How many promotions"), + async_profile.run_sql(prompt="How many promotions?"), + async_profile.explain_sql(prompt="How many promotions?"), ] # Collect results from multiple asynchronous tasks diff --git a/samples/async/vector_index_create.py b/samples/async/vector_index_create.py index a2159b8..34f8a2d 100644 --- a/samples/async/vector_index_create.py +++ b/samples/async/vector_index_create.py @@ -11,7 +11,6 @@ # Create a vector index for Retrieval Augmented Generation (RAG) # ----------------------------------------------------------------------------- - import asyncio import os diff --git a/samples/async/vector_index_rag.py b/samples/async/vector_index_rag.py index 8589f50..85c1ce8 100644 --- a/samples/async/vector_index_rag.py +++ b/samples/async/vector_index_rag.py @@ -11,7 +11,6 @@ # Demonstrates Retrieval Augmented Generation (RAG) using ai_profile.narrate() # ----------------------------------------------------------------------------- - import asyncio import os diff --git a/samples/conversation_chat_session.py b/samples/conversation_chat_session.py index df7f219..522d7e8 100644 --- a/samples/conversation_chat_session.py +++ b/samples/conversation_chat_session.py @@ -10,6 +10,7 @@ # # Demonstrates context aware conversation using AI Profile # ----------------------------------------------------------------------------- + import os import select_ai diff --git a/samples/conversation_create.py b/samples/conversation_create.py index a231884..11086ca 100644 --- a/samples/conversation_create.py +++ b/samples/conversation_create.py @@ -11,6 +11,7 @@ # Create a new conversation given a title and description. The created # conversation can be used in profile.chat_session() # ----------------------------------------------------------------------------- + import os import select_ai diff --git a/samples/create_ai_credential.py b/samples/create_ai_credential.py index fd38be4..d942300 100644 --- a/samples/create_ai_credential.py +++ b/samples/create_ai_credential.py @@ -10,6 +10,7 @@ # # Create a Database credential storing OCI Gen AI's credentials # ----------------------------------------------------------------------------- + import os import oci diff --git a/samples/profile_explain_sql.py b/samples/profile_explain_sql.py index 7ab0b00..fa847d4 100644 --- a/samples/profile_explain_sql.py +++ b/samples/profile_explain_sql.py @@ -23,6 +23,6 @@ ) print(profile.description) explanation = profile.explain_sql( - prompt="How many promotions are there in the sh database?" + prompt="How many promotions are there in the database?" ) print(explanation) diff --git a/samples/profile_narrate.py b/samples/profile_narrate.py index fb62f81..420ad95 100644 --- a/samples/profile_narrate.py +++ b/samples/profile_narrate.py @@ -24,6 +24,6 @@ profile_name="oci_ai_profile", ) narration = profile.narrate( - prompt="How many promotions are there in the sh database?" + prompt="How many promotions are there in the database?" ) print(narration) diff --git a/samples/profile_run_sql.py b/samples/profile_run_sql.py index fbfba95..e30635f 100644 --- a/samples/profile_run_sql.py +++ b/samples/profile_run_sql.py @@ -21,8 +21,6 @@ select_ai.connect(user=user, password=password, dsn=dsn) profile = select_ai.Profile(profile_name="oci_ai_profile") -df = profile.run_sql( - prompt="How many promotions are there in the sh database?" -) +df = profile.run_sql(prompt="How many promotions ?") print(df.columns) print(df) diff --git a/samples/profile_show_sql.py b/samples/profile_show_sql.py index 57af084..df852e0 100644 --- a/samples/profile_show_sql.py +++ b/samples/profile_show_sql.py @@ -21,7 +21,5 @@ select_ai.connect(user=user, password=password, dsn=dsn) profile = select_ai.Profile(profile_name="oci_ai_profile") -sql = profile.show_sql( - prompt="How many promotions are there in the sh database?" -) +sql = profile.show_sql(prompt="How many promotions ?") print(sql) diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index dbe2e3e..da4f502 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -315,7 +315,7 @@ async def list( ) async def generate( - self, prompt, action=Action.SHOWSQL, params: Mapping = None + self, prompt: str, action=Action.SHOWSQL, params: Mapping = None ) -> Union[pandas.DataFrame, str, None]: """Asynchronously perform AI translation using this profile @@ -325,6 +325,9 @@ async def generate( conversation_id for context-aware chats :return: Union[pandas.DataFrame, str] """ + if not prompt: + raise ValueError("prompt cannot be empty or None") + parameters = { "prompt": prompt, "action": action, @@ -444,6 +447,15 @@ async def generate_synthetic_data( :raises: oracledb.DatabaseError """ + if synthetic_data_attributes is None: + raise ValueError("'synthetic_data_attributes' cannot be None") + + if not isinstance(synthetic_data_attributes, SyntheticDataAttributes): + raise TypeError( + "'synthetic_data_attributes' must be an object " + "of type select_ai.SyntheticDataAttributes" + ) + keyword_parameters = synthetic_data_attributes.prepare() keyword_parameters["profile_name"] = self.profile_name async with async_cursor() as cr: diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index 61fed81..e375a01 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -297,6 +297,8 @@ def generate( conversation_id for context-aware chats :return: Union[pandas.DataFrame, str] """ + if not prompt: + raise ValueError("prompt cannot be empty or None") parameters = { "prompt": prompt, "action": action, @@ -407,6 +409,17 @@ def generate_synthetic_data( :raises: oracledb.DatabaseError """ + if synthetic_data_attributes is None: + raise ValueError( + "Param 'synthetic_data_attributes' cannot be None" + ) + + if not isinstance(synthetic_data_attributes, SyntheticDataAttributes): + raise TypeError( + "'synthetic_data_attributes' must be an object " + "of type select_ai.SyntheticDataAttributes" + ) + keyword_parameters = synthetic_data_attributes.prepare() keyword_parameters["profile_name"] = self.profile_name with cursor() as cr: diff --git a/src/select_ai/synthetic_data.py b/src/select_ai/synthetic_data.py index b378c88..0daa5ba 100644 --- a/src/select_ai/synthetic_data.py +++ b/src/select_ai/synthetic_data.py @@ -60,6 +60,12 @@ class SyntheticDataAttributes(SelectAIDataClass): record_count: Optional[int] = None user_prompt: Optional[str] = None + def __post_init__(self): + if self.params and not isinstance(self.params, SyntheticDataParams): + raise TypeError( + "'params' must be an object of" " type SyntheticDataParams'" + ) + def dict(self, exclude_null=True): attributes = {} for k, v in self.__dict__.items(): diff --git a/src/select_ai/version.py b/src/select_ai/version.py index 71eeca5..1875fd5 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.0.0.dev8" +__version__ = "1.0.0b1" From f5c989d77c43ec1bb54fe3262c6e4d81a3872bb0 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Fri, 15 Aug 2025 18:58:17 -0700 Subject: [PATCH 4/4] Fix the prompt in samples --- samples/profile_explain_sql.py | 4 +--- samples/profile_narrate.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/samples/profile_explain_sql.py b/samples/profile_explain_sql.py index fa847d4..5245d65 100644 --- a/samples/profile_explain_sql.py +++ b/samples/profile_explain_sql.py @@ -22,7 +22,5 @@ profile_name="oci_ai_profile", ) print(profile.description) -explanation = profile.explain_sql( - prompt="How many promotions are there in the database?" -) +explanation = profile.explain_sql(prompt="How many promotions?") print(explanation) diff --git a/samples/profile_narrate.py b/samples/profile_narrate.py index 420ad95..7637ca8 100644 --- a/samples/profile_narrate.py +++ b/samples/profile_narrate.py @@ -23,7 +23,5 @@ profile = select_ai.Profile( profile_name="oci_ai_profile", ) -narration = profile.narrate( - prompt="How many promotions are there in the database?" -) +narration = profile.narrate(prompt="How many promotions?") print(narration)