Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ Run
python3 -m pip install select_ai
```

## Documentation

See [Select AI for Python documentation][documentation]

## Samples

Examples can be found in the [/samples][samples] directory
Expand Down Expand Up @@ -81,6 +85,7 @@ Released under the Universal Permissive License v1.0 as shown at
<https://oss.oracle.com/licenses/upl/>.

[contributing]: https://github.com/oracle/python-select-ai/blob/main/CONTRIBUTING.md
[documentation]: https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/
[ghdiscussions]: https://github.com/oracle/python-select-ai/discussions
[ghissues]: https://github.com/oracle/python-select-ai/issues
[samples]: https://github.com/oracle/python-select-ai/tree/main/samples
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ keywords = [
license = " UPL-1.0"
license-files = ["LICENSE.txt"]
classifiers = [
"Development Status :: 4 - Beta",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Natural Language :: English",
"Operating System :: OS Independent",
Expand All @@ -34,7 +34,9 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Database"
"Topic :: Database",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules"
]
dependencies = [
"oracledb",
Expand All @@ -45,6 +47,7 @@ dependencies = [
Homepage = "https://github.com/oracle/python-select-ai"
Repository = "https://github.com/oracle/python-select-ai"
Issues = "https://github.com/oracle/python-select-ai/issues"
Documentation = "https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
123 changes: 123 additions & 0 deletions src/select_ai/_validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# -----------------------------------------------------------------------------
# Copyright (c) 2025, Oracle and/or its affiliates.
#
# Licensed under the Universal Permissive License v 1.0 as shown at
# http://oss.oracle.com/licenses/upl.
# -----------------------------------------------------------------------------

import inspect
from collections.abc import Mapping, Sequence, Set
from functools import wraps
from typing import Any, get_args, get_origin, get_type_hints

NoneType = type(None)


def _match(value, annot) -> bool:
"""Recursively validate value against a typing annotation."""
if annot is Any:
return True

origin = get_origin(annot)
args = get_args(annot)

# Handle Annotated[T, ...] → treat as T
if origin is getattr(__import__("typing"), "Annotated", None):
annot = args[0]
origin = get_origin(annot)
args = get_args(annot)

# Optional[T] is Union[T, NoneType]
if origin is getattr(__import__("typing"), "Union", None):
return any(_match(value, a) for a in args)

# Literal[…]
if origin is getattr(__import__("typing"), "Literal", None):
return any(value == lit for lit in args)

# Tuple cases
if origin is tuple:
if not isinstance(value, tuple):
return False
if len(args) == 2 and args[1] is Ellipsis:
# tuple[T, ...]
return all(_match(v, args[0]) for v in value)
if len(args) != len(value):
return False
return all(_match(v, a) for v, a in zip(value, args))

# Mappings (dict-like)
if origin in (dict, Mapping):
if not isinstance(value, Mapping):
return False
k_annot, v_annot = args if args else (Any, Any)
return all(
_match(k, k_annot) and _match(v, v_annot) for k, v in value.items()
)

# Sequences (list, Sequence) – but not str/bytes
if origin in (list, Sequence):
if isinstance(value, (str, bytes)):
return False
if not isinstance(value, Sequence):
return False
elem_annot = args[0] if args else Any
return all(_match(v, elem_annot) for v in value)

# Sets
if origin in (set, frozenset, Set):
if not isinstance(value, (set, frozenset)):
return False
elem_annot = args[0] if args else Any
return all(_match(v, elem_annot) for v in value)

# Fall back to normal isinstance for non-typing classes
if isinstance(annot, type):
return isinstance(value, annot)

# If annot is a typing alias like 'list' without args
if origin is not None:
# Treat bare containers as accepting anything inside
return isinstance(value, origin)

# Unknown/unsupported typing form: accept conservatively
return True


def enforce_types(func):
# Resolve ForwardRefs using function globals (handles "User" as a string, etc.)
hints = get_type_hints(
func, globalns=func.__globals__, include_extras=True
)
sig = inspect.signature(func)

def _check(bound):
for name, val in bound.arguments.items():
if name in hints:
annot = hints[name]
if not _match(val, annot):
raise TypeError(
f"Argument '{name}' failed type check: expected {annot!r}, "
f"got {type(val).__name__} -> {val!r}"
)

if inspect.iscoroutinefunction(func):

@wraps(func)
async def aw(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
_check(bound)
return await func(*args, **kwargs)

return aw
else:

@wraps(func)
def w(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
_check(bound)
return func(*args, **kwargs)

return w
14 changes: 10 additions & 4 deletions src/select_ai/async_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,15 @@ async def generate(
keyword_parameters=parameters,
)
if data is not None:
return await data.read()
return None
result = await data.read()
else:
result = None
if action == Action.RUNSQL and result:
return pandas.DataFrame(json.loads(result))
elif action == Action.RUNSQL:
return pandas.DataFrame()
else:
return result

async def chat(self, prompt, params: Mapping = None) -> str:
"""Asynchronously chat with the LLM
Expand Down Expand Up @@ -411,8 +418,7 @@ async def run_sql(
:param params: Parameters to include in the LLM request
:return: pandas.DataFrame
"""
data = await self.generate(prompt, action=Action.RUNSQL, params=params)
return pandas.DataFrame(json.loads(data))
return await self.generate(prompt, action=Action.RUNSQL, params=params)

async def show_sql(self, prompt, params: Mapping = None):
"""Show the generated SQL
Expand Down
1 change: 1 addition & 0 deletions src/select_ai/base_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ProfileAttributes(SelectAIDataClass):
vector_index_name: Optional[str] = None

def __post_init__(self):
super().__post_init__()
if self.provider and not isinstance(self.provider, Provider):
raise ValueError(
f"'provider' must be an object of " f"type select_ai.Provider"
Expand Down
22 changes: 14 additions & 8 deletions src/select_ai/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from contextlib import contextmanager
from dataclasses import replace as dataclass_replace
from typing import Iterator, Mapping, Optional, Union
from typing import Generator, Iterator, Mapping, Optional, Union

import oracledb
import pandas
Expand Down Expand Up @@ -258,7 +258,9 @@ def _from_db(cls, profile_name: str) -> "Profile":
raise ProfileNotFoundError(profile_name=profile_name)

@classmethod
def list(cls, profile_name_pattern: str = ".*") -> Iterator["Profile"]:
def list(
cls, profile_name_pattern: str = ".*"
) -> Generator["Profile", None, None]:
"""List AI Profiles saved in the database.

:param str profile_name_pattern: Regular expressions can be used
Expand Down Expand Up @@ -314,8 +316,15 @@ def generate(
keyword_parameters=parameters,
)
if data is not None:
return data.read()
return None
result = data.read()
else:
result = None
if action == Action.RUNSQL and result:
return pandas.DataFrame(json.loads(result))
elif action == Action.RUNSQL:
return pandas.DataFrame()
else:
return result

def chat(self, prompt: str, params: Mapping = None) -> str:
"""Chat with the LLM
Expand Down Expand Up @@ -375,10 +384,7 @@ def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame:
:param params: Parameters to include in the LLM request
:return: pandas.DataFrame
"""
data = json.loads(
self.generate(prompt, action=Action.RUNSQL, params=params)
)
return pandas.DataFrame(data)
return self.generate(prompt, action=Action.RUNSQL, params=params)

def show_sql(self, prompt: str, params: Mapping = None) -> str:
"""Show the generated SQL
Expand Down
13 changes: 9 additions & 4 deletions src/select_ai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional, Union

from select_ai._abc import SelectAIDataClass
from select_ai._validations import enforce_types

from .db import async_cursor, cursor
from .sql import (
Expand Down Expand Up @@ -194,6 +195,7 @@ class AnthropicProvider(Provider):
provider_endpoint = "api.anthropic.com"


@enforce_types
async def async_enable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -210,7 +212,7 @@ async def async_enable_provider(

async with async_cursor() as cr:
for user in users:
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
if provider_endpoint:
await cr.execute(
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -219,6 +221,7 @@ async def async_enable_provider(
)


@enforce_types
async def async_disable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -234,7 +237,7 @@ async def async_disable_provider(

async with async_cursor() as cr:
for user in users:
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
if provider_endpoint:
await cr.execute(
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -243,6 +246,7 @@ async def async_disable_provider(
)


@enforce_types
def enable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -256,7 +260,7 @@ def enable_provider(

with cursor() as cr:
for user in users:
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
if provider_endpoint:
cr.execute(
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -265,6 +269,7 @@ def enable_provider(
)


@enforce_types
def disable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -279,7 +284,7 @@ def disable_provider(

with cursor() as cr:
for user in users:
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
if provider_endpoint:
cr.execute(
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand Down
10 changes: 10 additions & 0 deletions src/select_ai/vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def __init__(
attributes: Optional[VectorIndexAttributes] = None,
):
"""Initialize a Vector Index"""
if attributes and not isinstance(attributes, VectorIndexAttributes):
raise TypeError(
"'attributes' must be an object of type "
"select_ai.VectorIndexAttributes"
)
if profile and not isinstance(profile, BaseProfile):
raise TypeError(
"'profile' must be an object of type "
"select_ai.Profile or select_ai.AsyncProfile"
)
self.profile = profile
self.index_name = index_name
self.attributes = attributes
Expand Down
2 changes: 1 addition & 1 deletion src/select_ai/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# http://oss.oracle.com/licenses/upl.
# -----------------------------------------------------------------------------

__version__ = "1.0.0b1"
__version__ = "1.0.0"