-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adding utilities to connect AWS services like RDS
adding AWS Dataservices connectors utilities ---- Signed-off-by: Saisharath <srsharathreddy@gmail.com> Signed-off-by: skondakindi <saisharathreddy_kondakindi@intuit.com>
- Loading branch information
skondakindi
committed
Apr 2, 2024
1 parent
2a14ac0
commit 7f6eb72
Showing
16 changed files
with
1,493 additions
and
305 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from enum import Enum, EnumMeta | ||
|
||
|
||
class MetaEnum(EnumMeta): | ||
def __contains__(cls, item): | ||
try: | ||
cls(item) | ||
except ValueError: | ||
return False | ||
return True | ||
|
||
|
||
class BaseEnum(Enum, metaclass=MetaEnum): | ||
@classmethod | ||
def list(cls): | ||
return list(map(lambda c: c.value, cls)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from boto3 import Session | ||
import logging | ||
from numalogic.connectors.aws import BaseEnum | ||
from numalogic.connectors.aws.exceptions import UnRecognizedAWSClientException | ||
from numalogic.connectors.aws.sts_client_manager import STSClientManager | ||
from numalogic.connectors.aws.db_configurations import ( | ||
load_db_conf, | ||
DatabaseServiceProvider, | ||
DatabaseTypes, | ||
) | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class Boto3ClientManager: | ||
|
||
def __init__(self, configurations): | ||
""" | ||
Initializes the Boto3ClientManager with the given configurations. | ||
The Boto3ClientManager is responsible for managing AWS clients for different services like RDS and Athena. | ||
It uses the configurations to create the clients and manage their sessions. | ||
Args: | ||
configurations (object): An object containing the necessary configurations. The configurations should include: | ||
- aws_assume_role_arn: The ARN of the role to assume for AWS services. | ||
- aws_assume_role_session_name: The session name to use when assuming the role. | ||
- endpoint: The endpoint for the AWS service. | ||
- port: The port to use for the AWS service. | ||
- database_username: The username for the database. | ||
- aws_region: The AWS region where the services are located. | ||
Attributes: | ||
rds_client (boto3.client): The client for AWS RDS service. Initialized as None. | ||
athena_client (boto3.client): The client for AWS Athena service. Initialized as None. | ||
configurations (object): The configurations for the AWS services. | ||
sts_client_manager (STSClientManager): The STSClientManager for managing AWS STS sessions. | ||
""" | ||
Check failure on line 39 in numalogic/connectors/aws/boto3_client_manager.py GitHub Actions / ruffRuff (D407)
|
||
self.rds_client = None | ||
self.athena_client = None | ||
self.configurations = configurations | ||
self.sts_client_manager = STSClientManager() | ||
|
||
def get_boto3_session(self) -> Session: | ||
""" | ||
Returns a Boto3 session object with the necessary credentials. | ||
This method retrieves the credentials from the STSClientManager using the given AWS assume role ARN and | ||
session name. It then creates a Boto3 session object with the retrieved credentials and returns it. | ||
Returns: | ||
Session: A Boto3 session object with the necessary credentials. | ||
""" | ||
credentials = self.sts_client_manager.get_credentials( | ||
self.configurations.aws_assume_role_arn, | ||
self.configurations.aws_assume_role_session_name, | ||
) | ||
tmp_access_key = credentials["AccessKeyId"] | ||
tmp_secret_key = credentials["SecretAccessKey"] | ||
security_token = credentials["SessionToken"] | ||
boto3_session = Session( | ||
aws_access_key_id=tmp_access_key, | ||
aws_secret_access_key=tmp_secret_key, | ||
aws_session_token=security_token, | ||
) | ||
return boto3_session | ||
|
||
def get_rds_token(self, rds_boto3_client) -> str: | ||
""" | ||
Generates an RDS authentication token using the provided RDS boto3 client. | ||
This method generates an RDS authentication token by calling the 'generate_db_auth_token' method of the | ||
provided RDS boto3 client. The authentication token is generated using the following parameters: - | ||
DBHostname: The endpoint of the RDS database. - Port: The port number of the RDS database. - DBUsername: The | ||
username for the RDS database. - Region: The AWS region where the RDS database is located. | ||
Parameters: | ||
rds_boto3_client (boto3.client): The RDS boto3 client used to generate the authentication token. | ||
Returns: | ||
str: The generated RDS authentication token. | ||
""" | ||
rds_token = rds_boto3_client.generate_db_auth_token( | ||
DBHostname=self.configurations.endpoint, | ||
Port=self.configurations.port, | ||
DBUsername=self.configurations.database_username, | ||
Region=self.configurations.aws_region, | ||
) | ||
return rds_token | ||
|
||
def get_client(self, client_type: str): | ||
""" | ||
Generates an AWS client based on the provided client type. | ||
This method generates an AWS client based on the provided client type. It first checks if the client type is | ||
recognized by checking if it exists in the `DatabaseServiceProvider` enum. If the client type is recognized, | ||
it creates the corresponding AWS client using the `get_boto3_session().client()` method and returns the | ||
client object. | ||
Parameters: client_type (str): The type of AWS client to generate. This should be one of the values defined | ||
in the `DatabaseServiceProvider` enum. | ||
Returns: | ||
boto3.client: The generated AWS client object. | ||
Raises: UnRecognizedAWSClientException: If the client type is not recognized, an exception is raised with a | ||
message indicating the unrecognized client type and the available options. | ||
""" | ||
_LOGGER.debug( | ||
f"Generating AWS client for client_type: {client_type} , and configurations: {str(self.configurations)}" | ||
) | ||
if client_type in DatabaseServiceProvider: | ||
if client_type == DatabaseServiceProvider.rds.value: | ||
self.rds_client = self.get_boto3_session().client( | ||
"rds", region_name=self.configurations.aws_region | ||
) | ||
return self.rds_client | ||
if client_type == DatabaseServiceProvider.athena.value: | ||
self.athena_client = self.get_boto3_session().client( | ||
"athena", region_name=self.configurations.aws_region | ||
) | ||
else: | ||
raise UnRecognizedAWSClientException( | ||
f"Unrecognized Client Type : {client_type}, please choose one from {DatabaseServiceProvider.list()}" | ||
) | ||
|
||
# | ||
# if __name__ == "__main__": | ||
# config = load_db_conf( | ||
# "./db_config.yaml") | ||
# boto3_client_manager = Boto3ClientManager(config) | ||
# rds = DatabaseServiceProvider.rds.value | ||
# rds_client = boto3_client_manager.get_client(rds) | ||
# _LOGGER.info(boto3_client_manager.get_rds_token(rds_client)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import os | ||
import logging | ||
from dataclasses import dataclass, field | ||
from typing import Optional | ||
from omegaconf import OmegaConf | ||
from numalogic.connectors.aws import BaseEnum | ||
from numalogic.connectors.aws.exceptions import ConfigNotFoundError | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class DatabaseServiceProvider(BaseEnum): | ||
""" | ||
A class representing the database service providers. | ||
Attributes: | ||
rds (str): Represents the RDS (Relational Database Service) provider. | ||
athena (str): Represents the Athena provider. | ||
""" | ||
|
||
rds = "rds" | ||
athena = "athena" | ||
|
||
|
||
class DatabaseTypes(BaseEnum): | ||
""" | ||
A class representing different types of databases. | ||
Attributes: | ||
mysql (str): Represents the MySQL database type. | ||
athena (str): Represents the Athena database type. | ||
""" | ||
|
||
mysql = "mysql" | ||
athena = "athena" | ||
|
||
|
||
@dataclass | ||
class AWSConfig: | ||
""" | ||
Class representing AWS configuration. | ||
Attributes: | ||
aws_assume_role_arn (str): The ARN of the IAM role to assume. | ||
aws_assume_role_session_name (str): The name of the session when assuming the IAM role. | ||
""" | ||
|
||
aws_assume_role_arn: str = "" | ||
aws_assume_role_session_name: str = "" | ||
|
||
|
||
@dataclass | ||
class SSLConfig: | ||
""" | ||
SSLConfig class represents the configuration for SSL/TLS settings. | ||
Attributes: | ||
ca (Optional[str]): The path to the Certificate Authority (CA) file. Defaults to an empty string. | ||
""" | ||
|
||
ca: Optional[str] = "" | ||
|
||
|
||
@dataclass | ||
class RDBMSConfig: | ||
""" | ||
RDBMSConfig class represents the configuration for a Relational Database Management System (RDBMS). | ||
Attributes: | ||
endpoint (str): The endpoint or hostname of the database. Defaults to an empty string. | ||
port (int): The port number of the database. Defaults to 3306. | ||
database_name (str): The name of the database. Defaults to an empty string. | ||
database_username (str): The username for the database connection. Defaults to an empty string. | ||
database_password (str): The password for the database connection. Defaults to an empty string. | ||
database_connection_timeout (int): The timeout duration for the database connection in seconds. Defaults to 10. | ||
database_type (str): The type of the database. Defaults to 'mysql'. | ||
database_provider (str): The provider of the database service. Defaults to 'rds'. | ||
ssl_enabled (bool): Flag indicating whether SSL/TLS is enabled for the database connection. Defaults to False. | ||
ssl (Optional[SSLConfig]): The SSL/TLS configuration for the database connection. Defaults to an empty SSLConfig object. | ||
""" | ||
|
||
endpoint: str = "" | ||
port: int = 3306 | ||
database_name: str = "" | ||
database_username: str = "" | ||
database_password: str = "" | ||
database_connection_timeout: int = 10 | ||
database_type: str = DatabaseTypes.mysql.value | ||
database_provider: str = DatabaseServiceProvider.rds.value | ||
ssl_enabled: bool = False | ||
ssl: Optional[SSLConfig] = field(default_factory=lambda: SSLConfig()) | ||
|
||
|
||
@dataclass | ||
class RDSConfig(AWSConfig, RDBMSConfig): | ||
""" | ||
Class representing the configuration for an RDS (Relational Database Service) instance. | ||
Inherits from: | ||
- AWSConfig: Class representing AWS configuration. | ||
- RDBMSConfig: Class representing the configuration for a Relational Database Management System (RDBMS). | ||
Attributes: | ||
aws_assume_role_arn (str): The ARN of the IAM role to assume. | ||
aws_assume_role_session_name (str): The name of the session when assuming the IAM role. | ||
endpoint (str): The endpoint or hostname of the database. Defaults to an empty string. | ||
port (int): The port number of the database. Defaults to 3306. | ||
database_name (str): The name of the database. Defaults to an empty string. | ||
database_username (str): The username for the database connection. Defaults to an empty string. | ||
database_password (str): The password for the database connection. Defaults to an empty string. | ||
database_connection_timeout (int): The timeout duration for the database connection in seconds. Defaults to 10. | ||
database_type (str): The type of the database. Defaults to 'mysql'. | ||
database_provider (str): The provider of the database service. Defaults to 'rds'. | ||
ssl_enabled (bool): Flag indicating whether SSL/TLS is enabled for the database connection. Defaults to False. | ||
ssl (Optional[SSLConfig]): The SSL/TLS configuration for the database connection. Defaults to an empty SSLConfig object. | ||
aws_region (str): The AWS region for the RDS instance. | ||
aws_rds_use_iam (bool): Flag indicating whether to use IAM authentication for the RDS instance. Defaults to False. | ||
""" | ||
|
||
aws_region: str = "" | ||
aws_rds_use_iam: bool = False | ||
|
||
|
||
def load_db_conf(*paths: str) -> RDSConfig: | ||
""" | ||
Load database configuration from one or more YAML files. | ||
Parameters: | ||
- paths (str): One or more paths to YAML files containing the database configuration. | ||
Returns: | ||
- RDSConfig: An instance of the RDSConfig class representing the loaded database configuration. | ||
Raises: | ||
- ConfigNotFoundError: If none of the given configuration file paths exist. | ||
Example: | ||
load_db_conf("/path/to/config.yaml", "/path/to/another/config.yaml") | ||
""" | ||
confs = [] | ||
for _path in paths: | ||
try: | ||
conf = OmegaConf.load(_path) | ||
except FileNotFoundError: | ||
_LOGGER.warning("Config file path: %s not found. Skipping...", _path) | ||
continue | ||
confs.append(conf) | ||
|
||
if not confs: | ||
_err_msg = f"None of the given conf paths exist: {paths}" | ||
raise ConfigNotFoundError(_err_msg) | ||
|
||
schema = OmegaConf.structured(RDSConfig) | ||
conf = OmegaConf.merge(schema, *confs) | ||
return OmegaConf.to_object(conf) | ||
|
||
|
||
# if __name__ == "__main__": | ||
# print( | ||
# load_db_conf( | ||
# "/Users/skondakindi/Desktop/codebase/odl/odl-ml-python-sdk/tests/resources/db_config.yaml" | ||
# ) | ||
# ) |
Oops, something went wrong.