Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

fix(connectors): updated quoted identifier strings #980

Merged
merged 6 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 42 additions & 41 deletions chaos_genius/connectors/base_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import text


logger = logging.getLogger(__name__)


Expand All @@ -17,7 +16,7 @@ class BaseDb:

@property
def sql_identifier(self):
"""Used to quote any SQL identifier in case of it using special characters or keywords."""
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

@property
Expand Down Expand Up @@ -62,46 +61,43 @@ def init_inspector(self):
return self.inspector

def get_schema_metadata(self, get_sequences=False, tables=[]):
"""
Gets all the metadata for the schema provided as input.
"""Gets all the metadata for the schema provided as input.

Output: A multi-dimensional dictionary.
"""
schema = self.get_schema()
schema_dict = dict()
table_dictionary = dict()
table_dictionary = {}
db_tables = self.get_tables(use_schema=schema)
if tables:
db_tables = list(set(db_tables) & set(tables))
for db_table in db_tables:
try:
table_dictionary_info = dict()
table_dictionary_info["table_columns"] = self.get_columns(
db_table, use_schema=schema
)
table_dictionary_info = {
"table_columns": self.get_columns(db_table, use_schema=schema)
}

table_dictionary_info["primary_key"] = self.get_primary_key(
db_table, use_schema=schema
)
table_dictionary_info[
"table_comment"
] = self.get_table_comment(db_table, use_schema=schema)
table_dictionary[db_table] = table_dictionary_info
except sqlalchemy_exc.ResourceClosedError as e:
except sqlalchemy_exc.ResourceClosedError:
logger.warn(f"get_columns failed for table: {db_table}")

schema_dict["tables"] = table_dictionary
schema_dict = {"tables": table_dictionary}
if get_sequences:
schema_sequences = self.get_sequences(use_schema=schema)
schema_dict["sequences"] = schema_sequences
return schema_dict

def get_schema_metadata_from_query(self, query):
"""
Gets all the metadata for the schema provided as input.
"""Gets all the metadata for the schema provided as input.

Output: A multi-dimensional dictionary.
"""
schema_dict = dict()
table_dictionary = dict()
table_dictionary_info = dict()
table_dictionary_info = {}
table_columns = []

# smartly add the limit 1
Expand All @@ -119,34 +115,36 @@ def get_schema_metadata_from_query(self, query):
# create inconsistency because of their case-insensitive
# nature and can do automated case conversion for metadata
columns = results.keys()
for col in columns:
table_columns.append({"name": col, "type": "TEXT"})
table_columns.extend({"name": col, "type": "TEXT"} for col in columns)
table_dictionary_info["table_columns"] = table_columns
table_dictionary["query"] = table_dictionary_info
schema_dict["tables"] = table_dictionary
return schema_dict
table_dictionary = {"query": table_dictionary_info}
return {"tables": table_dictionary}

def get_tables(self, use_schema=None):
"""
Output: An array with the names of all tables in the database's schema.
"""
"""Returns an array with the names of all tables in the db's schema."""
return self.inspector.get_table_names(schema=use_schema)

def get_columns(self, use_table, use_schema=None):
"""
Output: An array with information about all columns in a table.
"""Returns an array with information about all columns in a table.

Example Output:
[
{'name': 'id', 'type': INTEGER(), 'nullable': False, 'default': 'nextval(\'"API_secrets_id_seq"\'::regclass)', 'autoincrement': True, 'comment': "None"},
{'name': 'secret', 'type': BYTEA(), 'nullable': False, 'default': None, 'autoincrement': False, 'comment': "None"},
{'name': 'datatime', 'type': TIMESTAMP(), 'nullable': True, 'default': None, 'autoincrement': False, 'comment': "None"}
{'name': 'id', 'type': INTEGER(), 'nullable': False,
'default': 'nextval(\'"API_secrets_id_seq"\'::regclass)',
'autoincrement': True, 'comment': "None"},
{'name': 'secret', 'type': BYTEA(), 'nullable': False,
'default': None, 'autoincrement': False, 'comment': "None"},
{'name': 'datatime', 'type': TIMESTAMP(), 'nullable': True,
'default': None, 'autoincrement': False, 'comment': "None"}
]
"""
db_columns = self.inspector.get_columns(
table_name=use_table, schema=use_schema
)
for i in range(len(db_columns)):
try: # Put in Try-Except because some DBs like SQLite do not have comments for columns.
# Put in Try-Except because some DBs like SQLite
# do not have comments for columns.
try:
if db_columns[i]["comment"] is None:
db_columns[i]["comment"] = "None"
except Exception as err_msg:
Expand All @@ -158,16 +156,15 @@ def get_columns(self, use_table, use_schema=None):
return db_columns

def get_primary_key(self, use_table, use_schema=None):
"""
Output: The name of the primary key, or if there is none, it will return "None".
"""
"""Returns the name of the primary key, or "None" if there is no primary key."""
return self.inspector.get_pk_constraint(
table_name=use_table, schema=use_schema
)

def get_table_comment(self, use_table, use_schema=None):
"""
Output: The comment linked with the database table. If there is no comment, it returns "None".
"""Returns the comment linked with the database table.

Returns "None" if there is no comment.
"""
table_comment = self.inspector.get_table_comment(
table_name=use_table, schema=use_schema
Expand All @@ -177,12 +174,16 @@ def get_table_comment(self, use_table, use_schema=None):
return table_comment

def get_sequences(self, use_schema=None):
"""
Output: An array with the names of all sequences in the database's schema.
Example Output: ['secrets_id_seq', 'API_secrets_id_seq', 'hashed__encryption_id_seq']
"""Returns an array with the names of all sequences in the database's schema.

Example Output:
['secrets_id_seq', 'API_secrets_id_seq', 'hashed__encryption_id_seq']
"""
return self.inspector.get_sequence_names(schema=use_schema)

def get_view_names_list(self, schema_name):
data = self.inspector.get_view_names(schema=schema_name)
return data
return self.inspector.get_view_names(schema=schema_name)

def resolve_identifier(self, identifier: str) -> str:
"""Resolve the identifier if it uses special characters."""
return identifier
30 changes: 19 additions & 11 deletions chaos_genius/connectors/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

import pandas as pd
from sqlalchemy.engine import create_engine
from sqlalchemy import text
from sqlalchemy.engine import create_engine

from .base_db import BaseDb
from .connector_utils import merge_dataframe_chunks

Expand All @@ -10,6 +12,13 @@ class BigQueryDb(BaseDb):
db_name = "bigquery"
test_db_query = "SELECT 1"

__SQL_IDENTIFIER = "`"

@property
def sql_identifier(self):
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand All @@ -29,7 +38,9 @@ def get_db_engine(self):
if not credentials_info:
raise NotImplementedError("Credentials JSON not found for Google BigQuery.")
credentials_info = json.loads(credentials_info)
self.engine = create_engine(db_uri, credentials_info=credentials_info, echo=self.debug)
self.engine = create_engine(
db_uri, credentials_info=credentials_info, echo=self.debug
)
return self.engine

def test_connection(self):
Expand All @@ -41,21 +52,18 @@ def test_connection(self):
with self.engine.connect() as connection:
cursor = connection.execute(query_text)
results = cursor.all()
if results[0][0] == 1:
status = True
else:
status = False
status = results[0][0] == 1
except Exception as err_msg:
status = False
message = str(err_msg)
return status, message

def run_query(self, query, as_df=True):
engine = self.get_db_engine()
if as_df == True:
return merge_dataframe_chunks(pd.read_sql_query(query,
engine,
chunksize=self.CHUNKSIZE))
if as_df:
return merge_dataframe_chunks(
pd.read_sql_query(query, engine, chunksize=self.CHUNKSIZE)
)
else:
return []

Expand All @@ -64,4 +72,4 @@ def get_schema(self):
return self.schema

def get_schema_names_list(self):
return None
return None
75 changes: 45 additions & 30 deletions chaos_genius/connectors/druid.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy import create_engine, text

from .base_db import BaseDb
from .connector_utils import merge_dataframe_chunks


class Druid(BaseDb):
db_name = "druid"
test_db_query = "SELECT 1"
druid_internal_tables = ["COLUMNS",
"SCHEMATA",
"TABLES",
"segments",
"server_segments",
"servers",
"supervisors",
"tasks"]
druid_internal_tables = [
"COLUMNS",
"SCHEMATA",
"TABLES",
"segments",
"server_segments",
"servers",
"supervisors",
"tasks",
]

__SQL_IDENTIFIER = '"'

@property
def sql_identifier(self):
"""Used to quote SQL illegal identifiers."""
return self.__SQL_IDENTIFIER

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -25,13 +35,17 @@ def get_db_uri(self):
port = int(db_info.get("port"))
username = db_info.get("username")
password = db_info.get("password")
if not(host and port):
raise NotImplementedError("Database Credential not found for Druid.")
if not (host and port):
raise NotImplementedError(
"Database Credential not found for Druid."
)

self.sqlalchemy_db_uri = (
f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"
if (username and password)
else f"druid://{host}:{port}/druid/v2/sql/"
)

if not(username and password):
self.sqlalchemy_db_uri = f"druid://{host}:{port}/druid/v2/sql/"
else:
self.sqlalchemy_db_uri = f"druid://{username}:{password}@{host}:{port}/druid/v2/sql/"
return self.sqlalchemy_db_uri

def get_db_engine(self):
Expand All @@ -40,44 +54,45 @@ def get_db_engine(self):
return self.engine

def test_connection(self):
if not hasattr(self, 'engine') or not self.engine:
if not hasattr(self, "engine") or not self.engine:
self.engine = self.get_db_engine()
query_text = text(self.test_db_query)
status, message = None, ""
try:
with self.engine.connect() as connection:
cursor = connection.execute(query_text)
results = cursor.all()
if results[0][0] == 1:
status = True
else:
status = False
status = results[0][0] == 1
except Exception as err_msg:
status = False
message = str(err_msg)
return status, message

def get_tables(self, use_schema=None):
all_tables = self.inspector.get_table_names(schema=use_schema)
filtered_tables = [table for table in all_tables if table not in self.druid_internal_tables]
return filtered_tables
return [
table
for table in all_tables
if table not in self.druid_internal_tables
]

def get_columns(self, use_table, use_schema=None):
db_columns = self.inspector.get_columns(table_name=use_table, schema=use_schema)
db_columns = self.inspector.get_columns(
table_name=use_table, schema=use_schema
)
for i in range(len(db_columns)):
if db_columns[i]["default"] is None:
db_columns[i]["default"] = "None"

db_columns[i]['type'] = str(db_columns[i]['type'])
db_columns[i]["type"] = str(db_columns[i]["type"])
return db_columns


def run_query(self, query, as_df=True):
engine = self.get_db_engine()
if as_df == True:
return merge_dataframe_chunks(pd.read_sql_query(query,
engine,
chunksize=self.CHUNKSIZE))
if as_df:
return merge_dataframe_chunks(
pd.read_sql_query(query, engine, chunksize=self.CHUNKSIZE)
)
else:
return []

Expand Down
Loading