Commit
refactor(data-loader,connector): fixed some flake8 issues
kartikay-bagla committed Jun 9, 2022
1 parent a8a51cf commit 8c38634
Showing 7 changed files with 203 additions and 161 deletions.
83 changes: 39 additions & 44 deletions chaos_genius/connectors/base_db.py
@@ -4,7 +4,6 @@
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import text


logger = logging.getLogger(__name__)


@@ -17,7 +16,7 @@ class BaseDb:

@property
def sql_identifier(self):
"""Used to quote any SQL identifier in case of it using special characters or keywords."""
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

@property
@@ -62,46 +61,43 @@ def init_inspector(self):
return self.inspector

def get_schema_metadata(self, get_sequences=False, tables=[]):
"""
Gets all the metadata for the schema provided as input.
"""Gets all the metadata for the schema provided as input.
Output: A multi-dimensional dictionary.
"""
schema = self.get_schema()
schema_dict = dict()
table_dictionary = dict()
table_dictionary = {}
db_tables = self.get_tables(use_schema=schema)
if tables:
db_tables = list(set(db_tables) & set(tables))
for db_table in db_tables:
try:
table_dictionary_info = dict()
table_dictionary_info["table_columns"] = self.get_columns(
db_table, use_schema=schema
)
table_dictionary_info = {
"table_columns": self.get_columns(db_table, use_schema=schema)
}

table_dictionary_info["primary_key"] = self.get_primary_key(
db_table, use_schema=schema
)
table_dictionary_info[
"table_comment"
] = self.get_table_comment(db_table, use_schema=schema)
table_dictionary[db_table] = table_dictionary_info
except sqlalchemy_exc.ResourceClosedError as e:
except sqlalchemy_exc.ResourceClosedError:
logger.warn(f"get_columns failed for table: {db_table}")

schema_dict["tables"] = table_dictionary
schema_dict = {"tables": table_dictionary}
if get_sequences:
schema_sequences = self.get_sequences(use_schema=schema)
schema_dict["sequences"] = schema_sequences
return schema_dict

def get_schema_metadata_from_query(self, query):
"""
Gets all the metadata for the schema provided as input.
"""Gets all the metadata for the schema provided as input.
Output: A multi-dimensional dictionary.
"""
schema_dict = dict()
table_dictionary = dict()
table_dictionary_info = dict()
table_dictionary_info = {}
table_columns = []

# smartly add the limit 1
@@ -119,34 +115,36 @@ def get_schema_metadata_from_query(self, query):
# create inconsistency because of their case insensitive
# nature and can do automated case conversion for metadata
columns = results.keys()
for col in columns:
table_columns.append({"name": col, "type": "TEXT"})
table_columns.extend({"name": col, "type": "TEXT"} for col in columns)
table_dictionary_info["table_columns"] = table_columns
table_dictionary["query"] = table_dictionary_info
schema_dict["tables"] = table_dictionary
return schema_dict
table_dictionary = {"query": table_dictionary_info}
return {"tables": table_dictionary}

def get_tables(self, use_schema=None):
"""
Output: An array with the names of all tables in the database's schema.
"""
"""Returns an array with the names of all tables in the db's schema."""
return self.inspector.get_table_names(schema=use_schema)

def get_columns(self, use_table, use_schema=None):
"""
Output: An array with information about all columns in a table.
"""Returns an array with information about all columns in a table.
Example Output:
[
{'name': 'id', 'type': INTEGER(), 'nullable': False, 'default': 'nextval(\'"API_secrets_id_seq"\'::regclass)', 'autoincrement': True, 'comment': "None"},
{'name': 'secret', 'type': BYTEA(), 'nullable': False, 'default': None, 'autoincrement': False, 'comment': "None"},
{'name': 'datatime', 'type': TIMESTAMP(), 'nullable': True, 'default': None, 'autoincrement': False, 'comment': "None"}
{'name': 'id', 'type': INTEGER(), 'nullable': False,
'default': 'nextval(\'"API_secrets_id_seq"\'::regclass)',
'autoincrement': True, 'comment': "None"},
{'name': 'secret', 'type': BYTEA(), 'nullable': False,
'default': None, 'autoincrement': False, 'comment': "None"},
{'name': 'datatime', 'type': TIMESTAMP(), 'nullable': True,
'default': None, 'autoincrement': False, 'comment': "None"}
]
"""
db_columns = self.inspector.get_columns(
table_name=use_table, schema=use_schema
)
for i in range(len(db_columns)):
try: # Put in Try-Except because some DBs like SQLite do not have comments for columns.
# Put in Try-Except because some DBs like SQLite
# do not have comments for columns.
try:
if db_columns[i]["comment"] is None:
db_columns[i]["comment"] = "None"
except Exception as err_msg:
@@ -158,16 +156,15 @@ def get_columns(self, use_table, use_schema=None):
return db_columns

def get_primary_key(self, use_table, use_schema=None):
"""
Output: The name of the primary key, or if there is none, it will return "None".
"""
"""Returns the name of the primary key, or "None" if there is no primary key."""
return self.inspector.get_pk_constraint(
table_name=use_table, schema=use_schema
)

def get_table_comment(self, use_table, use_schema=None):
"""
Output: The comment linked with the database table. If there is no comment, it returns "None".
"""Returns the comment linked with the database table.
Returns "None" if there is no comment.
"""
table_comment = self.inspector.get_table_comment(
table_name=use_table, schema=use_schema
@@ -177,18 +174,16 @@ def get_table_comment(self, use_table, use_schema=None):
return table_comment

def get_sequences(self, use_schema=None):
"""
Output: An array with the names of all sequences in the database's schema.
Example Output: ['secrets_id_seq', 'API_secrets_id_seq', 'hashed__encryption_id_seq']
"""Returns n array with the names of all sequences in the database's schema.
Example Output:
['secrets_id_seq', 'API_secrets_id_seq', 'hashed__encryption_id_seq']
"""
return self.inspector.get_sequence_names(schema=use_schema)

def get_view_names_list(self, schema_name):
data = self.inspector.get_view_names(schema=schema_name)
return data
return self.inspector.get_view_names(schema=schema_name)

def resolve_identifier(self, identifier: str) -> str:
"""
Resolve the identifier if it uses special characters.
"""
"""Resolve the identifier if it uses special characters."""
return identifier
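
For orientation (not part of this commit): a minimal sketch of the nested dictionary that get_schema_metadata assembles, assuming SQLAlchemy-style inspector output. The table name, column entries, and constraint values below are hypothetical.

    # Rough shape of the value returned by BaseDb.get_schema_metadata(),
    # inferred from the hunks above; all concrete names and values are made up.
    schema_metadata = {
        "tables": {
            "orders": {
                "table_columns": [
                    {"name": "id", "type": "INTEGER()", "nullable": False},
                    {"name": "amount", "type": "NUMERIC()", "nullable": True},
                ],
                "primary_key": {"constrained_columns": ["id"], "name": "orders_pkey"},
                "table_comment": {"text": "None"},
            }
        },
        # present only when get_sequences=True
        "sequences": ["orders_id_seq"],
    }

    # A caller could then list the column names of one table:
    column_names = [
        col["name"] for col in schema_metadata["tables"]["orders"]["table_columns"]
    ]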
25 changes: 13 additions & 12 deletions chaos_genius/connectors/bigquery.py
@@ -1,7 +1,9 @@
import json

import pandas as pd
from sqlalchemy.engine import create_engine
from sqlalchemy import text
from sqlalchemy.engine import create_engine

from .base_db import BaseDb
from .connector_utils import merge_dataframe_chunks

@@ -14,7 +16,7 @@ class BigQueryDb(BaseDb):

@property
def sql_identifier(self):
"""Used to quote any SQL identifier in case of it using special characters or keywords."""
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

def __init__(self, *args, **kwargs):
@@ -36,7 +38,9 @@ def get_db_engine(self):
if not credentials_info:
raise NotImplementedError("Credentials JSON not found for Google BigQuery.")
credentials_info = json.loads(credentials_info)
self.engine = create_engine(db_uri, credentials_info=credentials_info, echo=self.debug)
self.engine = create_engine(
db_uri, credentials_info=credentials_info, echo=self.debug
)
return self.engine

def test_connection(self):
@@ -48,21 +52,18 @@ def test_connection(self):
with self.engine.connect() as connection:
cursor = connection.execute(query_text)
results = cursor.all()
if results[0][0] == 1:
status = True
else:
status = False
status = results[0][0] == 1
except Exception as err_msg:
status = False
message = str(err_msg)
return status, message

def run_query(self, query, as_df=True):
engine = self.get_db_engine()
if as_df == True:
return merge_dataframe_chunks(pd.read_sql_query(query,
engine,
chunksize=self.CHUNKSIZE))
if as_df:
return merge_dataframe_chunks(
pd.read_sql_query(query, engine, chunksize=self.CHUNKSIZE)
)
else:
return []

@@ -71,4 +72,4 @@ def get_schema(self):
return self.schema

def get_schema_names_list(self):
return None
return None
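
Not shown in this diff: merge_dataframe_chunks comes from connector_utils and is fed the iterator that pandas returns when read_sql_query is called with a chunksize. A plausible stand-in, assuming it simply concatenates the chunks, could look like this (the function body is a guess, not the project's actual implementation):

    import pandas as pd

    def merge_dataframe_chunks_sketch(chunks):
        """Hypothetical equivalent of connector_utils.merge_dataframe_chunks:
        concatenate the DataFrames yielded by pd.read_sql_query(..., chunksize=N).
        """
        frames = list(chunks)
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    # Usage mirroring BigQueryDb.run_query (engine is assumed to be a valid
    # SQLAlchemy engine and query any SQL string):
    # df = merge_dataframe_chunks_sketch(
    #     pd.read_sql_query(query, engine, chunksize=50000)
    # )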
70 changes: 39 additions & 31 deletions chaos_genius/connectors/druid.py
@@ -1,26 +1,29 @@
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy import create_engine, text

from .base_db import BaseDb
from .connector_utils import merge_dataframe_chunks


class Druid(BaseDb):
db_name = "druid"
test_db_query = "SELECT 1"
druid_internal_tables = ["COLUMNS",
"SCHEMATA",
"TABLES",
"segments",
"server_segments",
"servers",
"supervisors",
"tasks"]
druid_internal_tables = [
"COLUMNS",
"SCHEMATA",
"TABLES",
"segments",
"server_segments",
"servers",
"supervisors",
"tasks",
]

__SQL_IDENTIFIER = '"'

@property
def sql_identifier(self):
"""Used to quote any SQL identifier in case of it using special characters or keywords."""
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

def __init__(self, *args, **kwargs):
@@ -32,13 +35,17 @@ def get_db_uri(self):
port = int(db_info.get("port"))
username = db_info.get("username")
password = db_info.get("password")
if not(host and port):
raise NotImplementedError("Database Credential not found for Druid.")
if not (host and port):
raise NotImplementedError(
"Database Credential not found for Druid."
)

self.sqlalchemy_db_uri = (
f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"
if (username and password)
else f"druid://{host}:{port}/druid/v2/sql/"
)

if not(username and password):
self.sqlalchemy_db_uri = f"druid://{host}:{port}/druid/v2/sql/"
else:
self.sqlalchemy_db_uri = f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"
return self.sqlalchemy_db_uri

def get_db_engine(self):
@@ -47,44 +54,45 @@ def test_connection(self):
return self.engine

def test_connection(self):
if not hasattr(self, 'engine') or not self.engine:
if not hasattr(self, "engine") or not self.engine:
self.engine = self.get_db_engine()
query_text = text(self.test_db_query)
status, message = None, ""
try:
with self.engine.connect() as connection:
cursor = connection.execute(query_text)
results = cursor.all()
if results[0][0] == 1:
status = True
else:
status = False
status = results[0][0] == 1
except Exception as err_msg:
status = False
message = str(err_msg)
return status, message

def get_tables(self, use_schema=None):
all_tables = self.inspector.get_table_names(schema=use_schema)
filtered_tables = [table for table in all_tables if table not in self.druid_internal_tables]
return filtered_tables
return [
table
for table in all_tables
if table not in self.druid_internal_tables
]

def get_columns(self, use_table, use_schema=None):
db_columns = self.inspector.get_columns(table_name=use_table, schema=use_schema)
db_columns = self.inspector.get_columns(
table_name=use_table, schema=use_schema
)
for i in range(len(db_columns)):
if db_columns[i]["default"] is None:
db_columns[i]["default"] = "None"

db_columns[i]['type'] = str(db_columns[i]['type'])
db_columns[i]["type"] = str(db_columns[i]["type"])
return db_columns


def run_query(self, query, as_df=True):
engine = self.get_db_engine()
if as_df == True:
return merge_dataframe_chunks(pd.read_sql_query(query,
engine,
chunksize=self.CHUNKSIZE))
if as_df:
return merge_dataframe_chunks(
pd.read_sql_query(query, engine, chunksize=self.CHUNKSIZE)
)
else:
return []
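
Read outside the diff, the get_db_uri change swaps an if/else assignment for a single conditional expression. The same pattern as a standalone, runnable sketch (all credential values here are placeholders):

    host, port, username, password = "localhost", 8082, "druid_user", "secret"

    # Before: two branches assigning the URI.
    if not (username and password):
        uri = f"druid://{host}:{port}/druid/v2/sql/"
    else:
        uri = f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"

    # After: one conditional expression, as in the refactored hunk.
    uri = (
        f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"
        if (username and password)
        else f"druid://{host}:{port}/druid/v2/sql/"
    )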
