Skip to content

Commit

Permalink
Added changes as per review comments
Browse files Browse the repository at this point in the history
feat: Adding utilities to connect AWS services like RDS

Adding connector utilities for AWS data services

----
Signed-off-by: Saisharath <srsharathreddy@gmail.com>
Signed-off-by: skondakindi <saisharathreddy_kondakindi@intuit.com>
  • Loading branch information
skondakindi committed Apr 11, 2024
1 parent 5fdf486 commit 330a6be
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 305 deletions.
1 change: 1 addition & 0 deletions numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class ConnectorType(IntEnum):
redis = 0
prometheus = 1
druid = 2
rds = 3


@dataclass
Expand Down
149 changes: 68 additions & 81 deletions numalogic/connectors/rds/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABCMeta, abstractmethod
from typing import Optional
import pandas as pd
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConfig
Expand All @@ -9,28 +10,63 @@
_LOGGER = logging.getLogger(__name__)


class RDSBase:
def format_dataframe(
    df: pd.DataFrame,
    query: str,
    datetime_field_name: str,
    group_by: Optional[list[str]] = None,
    pivot: Optional[Pivot] = None,
) -> pd.DataFrame:
    """
    Format a query result: add a ``timestamp`` column, then optionally group and pivot.

    The values of ``df[datetime_field_name]`` are parsed with ``pd.to_datetime``,
    cast to ``int64`` and floor-divided by ``10**6``; the result is stored in a
    column named ``timestamp`` and the source datetime column is removed.
    The input frame itself is left untouched; a new frame is returned.

    Arguments
    ----------
    df : pd.DataFrame
        The input DataFrame to be formatted.
    query : str
        The SQL query used to retrieve the data (only used for logging).
    datetime_field_name : str
        The name of the datetime field in the DataFrame.
    group_by : Optional[list[str]], optional
        A list of column names to group the DataFrame by (groups are summed),
        by default None.
    pivot : Optional[Pivot], optional
        An optional Pivot object specifying the index, columns,
        and values for pivoting the DataFrame, by default None.
        Pivoting is skipped when ``pivot.columns`` is empty.

    Returns
    -------
    pd.DataFrame : The formatted DataFrame.
    """
    _start_time = time.perf_counter()

    # NOTE(review): the // 10**6 yields epoch milliseconds only if the parsed
    # datetimes have nanosecond resolution -- confirm for the pandas version in use.
    timestamps = pd.to_datetime(df[datetime_field_name]).astype("int64") // 10**6

    # ``assign`` returns a new frame, so the caller's DataFrame is not mutated.
    formatted = df.assign(timestamp=timestamps)

    # When the source column is itself called "timestamp" it has just been
    # overwritten with the converted values; dropping it would lose them.
    if datetime_field_name != "timestamp":
        formatted = formatted.drop(columns=datetime_field_name)

    if group_by:
        formatted = formatted.groupby(by=group_by).sum().reset_index()

    if pivot and pivot.columns:
        formatted = formatted.pivot(
            index=pivot.index,
            columns=pivot.columns,
            values=pivot.value,
        )
        # Pivoted column labels are tuples; keep only their second element.
        formatted.columns = formatted.columns.map("{0[1]}".format)
        formatted = formatted.reset_index()

    _end_time = time.perf_counter() - _start_time
    _LOGGER.info("RDS MYSQL Query: %s, Format time: %.4fs", query, _end_time)
    return formatted


class RDSBase(metaclass=ABCMeta):
"""
class represents a data fetcher for RDS (Relational Database Service) connections. It
provides methods for retrieving the RDS token, getting the password, establishing a
connection, and executing queries.
Attributes
----------
- db_config (RDSConfig): The configuration object for the RDS connection.
- kwargs (dict): Additional keyword arguments.
Methods
-------
- get_rds_token(): Retrieves the RDS token using the Boto3ClientManager. - get_password() ->
str: Retrieves the password for the RDS connection. If 'aws_rds_use_iam' is True, it calls
the get_rds_token() method, otherwise it returns the database password from the
configuration. - get_connection(): Placeholder method for establishing a connection to the
RDS database. - get_db_cursor(): Placeholder method for getting a database cursor. -
execute_query(query) -> pd.DataFrame: Placeholder method for executing a query and returning
the result as a pandas DataFrame.
Args:
- db_config (RDSConfig): The configuration object for the RDS connection.
- kwargs (dict): Additional keyword arguments.
"""

Expand All @@ -45,10 +81,9 @@ def get_rds_token(self) -> str:
"""
Generates an RDS authentication token using the provided RDS boto3 client.
Arguments
----------
- rds_boto3_client (boto3.client): The RDS boto3 client used to generate the
authentication token.
Args:
- rds_boto3_client (boto3.client): The RDS boto3 client used to generate the
authentication token.
Returns
-------
Expand All @@ -70,15 +105,14 @@ def get_password(self) -> str:
str: The password for the RDS connection.
"""
password = None
if self.db_config.aws_rds_use_iam:
_LOGGER.info("using aws_rds_use_iam to generate RDS Token")
password = self.get_rds_token()
else:
_LOGGER.info("using password from config to connect RDS Database")
password = self.db_config.database_password
return password
return self.get_rds_token()

_LOGGER.info("using password from config to connect RDS Database")
return self.db_config.database_password

@abstractmethod
def get_connection(self):
"""
Establishes a connection to the RDS database.
Expand All @@ -94,7 +128,8 @@ def get_connection(self):
"""
raise NotImplementedError

def get_db_cursor(self):
@abstractmethod
def get_db_cursor(self, *args, **kwargs):
"""
Retrieves a database cursor for executing queries.
Expand All @@ -106,63 +141,15 @@ def get_db_cursor(self):
None
"""
pass

def format_dataframe(
self,
df: pd.DataFrame,
query: str,
datetime_field_name: str,
group_by: Optional[list[str]] = None,
pivot: Optional[Pivot] = None,
):
"""
Executes formatting operations on a pandas DataFrame.
Arguments
----------
df : pd.DataFrame
The input DataFrame to be formatted.
query : str
The SQL query used to retrieve the data.
datetime_field_name : str
The name of the datetime field in the DataFrame.
group_by : Optional[list[str]], optional
A list of column names to group the DataFrame by, by default None.
pivot : Optional[Pivot], optional
An optional Pivot object specifying the index, columns,
and values for pivoting the DataFrame, by default None.
Returns
-------
pd.DataFrame : The formatted DataFrame.
raise NotImplementedError

"""
_start_time = time.perf_counter()
df["timestamp"] = pd.to_datetime(df[datetime_field_name]).astype("int64") // 10**6
df.drop(columns=datetime_field_name, inplace=True)
if group_by:
df = df.groupby(by=group_by).sum().reset_index()

if pivot and pivot.columns:
df = df.pivot(
index=pivot.index,
columns=pivot.columns,
values=pivot.value,
)
df.columns = df.columns.map("{0[1]}".format)
df.reset_index(inplace=True)
_end_time = time.perf_counter() - _start_time
_LOGGER.info("RDS MYSQL Query: %s, Format time: %.4fs", query, _end_time)
return df

def execute_query(self, query) -> pd.DataFrame:
@abstractmethod
def execute_query(self, query: str) -> pd.DataFrame:
"""
Executes a query on the RDS database and returns the result as a pandas DataFrame.
Parameters
----------
query (str): The SQL query to be executed.
Args:
query (str): The SQL query to be executed.
Returns
-------
Expand Down
18 changes: 5 additions & 13 deletions numalogic/connectors/rds/_rds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from numalogic.connectors._base import DataFetcher
from numalogic.connectors._config import Pivot
from numalogic.connectors.rds._base import format_dataframe
from numalogic.connectors.utils.aws.config import RDSConfig
import logging
import pandas as pd
Expand All @@ -12,9 +13,7 @@

class RDSFetcher(DataFetcher):
"""
RDSFetcher class.
This class is a subclass of DataFetcher and ABC (Abstract Base Class).
class is a subclass of DataFetcher and ABC (Abstract Base Class).
It is used to fetch data from an RDS (Relational Database Service) instance by executing
a given SQL query.
Expand All @@ -23,17 +22,10 @@ class RDSFetcher(DataFetcher):
db_config (RDSConfig): The configuration object for the RDS instance.
fetcher (db.CLASS_TYPE): The fetcher object for the specific database type.
Methods
-------
__init__(self, db_config: RDSConfig):
Initializes the RDSFetcher object with the given RDSConfig object.
fetch(self, query):
Fetches data from the RDS instance by executing the given SQL query.
"""

def __init__(self, db_config: RDSConfig):
super().__init__(db_config.__dict__.get("url"))
super().__init__(db_config.endpoint)
self.db_config = db_config
factory_object = RdsFactory()
self.fetcher = factory_object.get_db_handler(db_config.database_type.lower())(db_config)
Expand All @@ -49,7 +41,7 @@ def fetch(
"""
Fetches data from the RDS instance by executing the given query.
Arguments:
Args:
query (str): The SQL query to be executed.
datetime_field_name (str): The name of the datetime field in the fetched data.
pivot (Optional[Pivot], optional): The pivot configuration for the fetched data.
Expand All @@ -67,7 +59,7 @@ def fetch(
_LOGGER.warning("No data found for query : %s ", query)
return pd.DataFrame()

formatted_df = self.fetcher.format_dataframe(
formatted_df = format_dataframe(
df, query=query, datetime_field_name=datetime_field_name, pivot=pivot, group_by=group_by
)
_end_time = time.perf_counter() - _start_time
Expand Down
42 changes: 20 additions & 22 deletions numalogic/connectors/rds/db/factory.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
import logging

from numalogic.connectors.utils.aws.config import DatabaseTypes
from numalogic.connectors.utils.aws.exceptions import UnRecognizedDatabaseTypeException

_LOGGER = logging.getLogger(__name__)


class RdsFactory:
    """Factory that returns the database handler class for a given database type."""

    @classmethod
    def get_db_handler(cls, database_type: "DatabaseTypes | str"):
        """
        Get the database handler for the specified database type.

        Args:
        - database_type (DatabaseTypes | str): The type of the database, given
          either as a ``DatabaseTypes`` member or as its string value.

        Returns
        -------
        - The database handler class for the specified database type.

        Raises
        ------
        - UnRecognizedDatabaseTypeException: If the specified database type is not supported.
        """
        # Accept both an enum member and its raw value, so the comparison below
        # does not depend on whether DatabaseTypes members compare equal to strings.
        db_type = getattr(database_type, "value", database_type)

        if db_type == DatabaseTypes.MYSQL.value:
            # Imported here rather than at module level, as in the original;
            # presumably to defer loading the MySQL handler until needed -- TODO confirm.
            from numalogic.connectors.rds.db.mysql_fetcher import MysqlFetcher

            return MysqlFetcher

        raise UnRecognizedDatabaseTypeException(f"database_type: {database_type} is not supported")
Loading

0 comments on commit 330a6be

Please sign in to comment.