Skip to content

Commit

Permalink
added changes as per comments
Browse files Browse the repository at this point in the history
feat: Adding utilities to connect AWS services like RDS

adding AWS Dataservices connectors utilities

----
Signed-off-by: Saisharath <srsharathreddy@gmail.com>
  • Loading branch information
skondakindi committed Apr 11, 2024
1 parent 219ed6d commit 6cf1acf
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion numalogic/connectors/rds/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_rds_token(self) -> str:
str: The generated RDS authentication token.
"""
rds_client = self.boto3_client_manager.get_client(DatabaseServiceProvider.RDS.value)
rds_client = self.boto3_client_manager.get_client(DatabaseServiceProvider.RDS)
return self.boto3_client_manager.get_rds_token(rds_client)

def get_password(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/rds/db/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_db_handler(cls, database_type: DatabaseTypes):
- UnRecognizedDatabaseTypeException: If the specified database type is not supported.
"""
if database_type == DatabaseTypes.MYSQL.value:
if database_type == DatabaseTypes.MYSQL:
from numalogic.connectors.rds.db.mysql_fetcher import MysqlFetcher

return MysqlFetcher
Expand Down
3 changes: 2 additions & 1 deletion numalogic/connectors/rds/db/mysql_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class that inherits from RDSBase. It is used to fetch data from a MySQL database
use cases.
"""

database_type = DatabaseTypes.MYSQL.value
database_type = DatabaseTypes.MYSQL

def __init__(self, db_config: RDSConfig, **kwargs):
super().__init__(db_config)
Expand Down Expand Up @@ -116,6 +116,7 @@ def execute_query(self, query: str) -> pd.DataFrame:
col_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
df = pd.DataFrame(rows, columns=col_names)
connection.close()
_end_time = time.perf_counter() - _start_time
_LOGGER.info("RDS MYSQL Query: %s, execution time: %.4fs", query, _end_time)
return df
4 changes: 2 additions & 2 deletions numalogic/connectors/utils/aws/boto3_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ def get_client(self, client_type: str) -> boto3.session.Session.client:
str(self.configurations),
)
if client_type in DatabaseServiceProvider:
if client_type == DatabaseServiceProvider.RDS.value:
if client_type == DatabaseServiceProvider.RDS:
self.rds_client = self.get_boto3_session().client(
"rds", region_name=self.configurations.aws_region
)
client = self.rds_client
if client_type == DatabaseServiceProvider.ATHENA.value:
if client_type == DatabaseServiceProvider.ATHENA:
self.athena_client = self.get_boto3_session().client(
"athena", region_name=self.configurations.aws_region
)
Expand Down
8 changes: 4 additions & 4 deletions numalogic/connectors/utils/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numalogic.connectors.utils.enum import BaseEnum


class DatabaseServiceProvider(BaseEnum):
class DatabaseServiceProvider(str, BaseEnum):
"""
A class representing the database service providers.
Expand All @@ -19,7 +19,7 @@ class DatabaseServiceProvider(BaseEnum):
ATHENA = "athena"


class DatabaseTypes(BaseEnum):
class DatabaseTypes(str, BaseEnum):
"""
A class representing different types of databases.
Expand Down Expand Up @@ -92,7 +92,7 @@ class RDBMSConfig:
database_username: str = ""
database_password: str = ""
database_connection_timeout: int = 10
database_type: str = DatabaseTypes.MYSQL.value
database_type: str = DatabaseTypes.MYSQL
ssl_enabled: bool = False
ssl: Optional[SSLConfig] = field(default_factory=lambda: SSLConfig())

Expand Down Expand Up @@ -130,4 +130,4 @@ class RDSConfig(AWSConfig, RDBMSConfig):

aws_region: str = ""
aws_rds_use_iam: bool = False
database_provider: str = DatabaseServiceProvider.RDS.value
database_provider: str = DatabaseServiceProvider.RDS
2 changes: 1 addition & 1 deletion tests/connectors/rds/db/test_mysql_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def mock_mysql_fetcher_ssl_disabled(mock_db_config_ssl_disabled):
def test_init_method(mock_mysql_fetcher, mock_db_config):
db_config, params = mock_db_config
assert mock_mysql_fetcher.db_config == db_config
assert mock_mysql_fetcher.database_type == DatabaseTypes.MYSQL.value
assert mock_mysql_fetcher.database_type == DatabaseTypes.MYSQL


def test_get_db_cursor_method(mock_mysql_fetcher):
Expand Down
4 changes: 2 additions & 2 deletions tests/connectors/utils/aws/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_rdbms_config():
database_username="user",
database_password="password",
database_connection_timeout=300,
database_type=DatabaseTypes.MYSQL.value,
database_type=DatabaseTypes.MYSQL,
ssl_enabled=True,
ssl=SSLConfig(ca="path_to_ca"),
)
Expand All @@ -50,7 +50,7 @@ def test_rds_config():
database_username="user",
database_password="password",
database_connection_timeout=300,
database_type=DatabaseTypes.MYSQL.value,
database_type=DatabaseTypes.MYSQL,
ssl_enabled=True,
ssl=SSLConfig(ca="path_to_ca"),
)
Expand Down

0 comments on commit 6cf1acf

Please sign in to comment.