diff --git a/dataduct/etl/etl_pipeline.py b/dataduct/etl/etl_pipeline.py
index 67a1c5b..32b62fc 100644
--- a/dataduct/etl/etl_pipeline.py
+++ b/dataduct/etl/etl_pipeline.py
@@ -294,17 +294,11 @@ def emr_cluster(self):
         # Process the boostrap input
         bootstrap = self.emr_cluster_config.get('bootstrap', None)
         if bootstrap:
-            if isinstance(bootstrap, dict):
-                # If bootstrap script is not a path to local file
-                param_type = bootstrap['type']
-                bootstrap = bootstrap['value']
-            else:
-                # Default the type to path of a local file
-                param_type = 'path'
-
-            if param_type == 'path':
-                bootstrap = S3File(path=bootstrap)
+            if 'string' in bootstrap:
+                bootstrap = bootstrap['string']
+            elif 'script' in bootstrap:
                 # Set the S3 Path for the bootstrap script
+                bootstrap = S3File(path=bootstrap)
                 bootstrap.s3_path = self.s3_source_dir
 
         self.emr_cluster_config['bootstrap'] = bootstrap
diff --git a/dataduct/steps/count_check.py b/dataduct/steps/count_check.py
index 7a792b7..1977a05 100644
--- a/dataduct/steps/count_check.py
+++ b/dataduct/steps/count_check.py
@@ -23,7 +23,7 @@ class CountCheckStep(QATransformStep):
     def __init__(self, id, source_host, source_sql=None, source_table_name=None,
                  destination_table_name=None, destination_table_definition=None,
                  destination_sql=None, tolerance=1.0, script_arguments=None,
-                 log_to_s3=False, script=None, **kwargs):
+                 log_to_s3=False, script=None, source_count_sql=None, **kwargs):
         """Constructor for the CountCheckStep class
 
         Args:
@@ -37,9 +37,9 @@ def __init__(self, id, source_host, source_sql=None, source_table_name=None,
             raise ETLInputError(
                 'One of dest table name/schema or dest sql needed')
 
-        if not exactly_one(source_sql, source_table_name):
+        if not exactly_one(source_sql, source_table_name, source_count_sql):
             raise ETLInputError(
-                'One of source table name or source sql needed')
+                'One of source table name or source sql or source count needed')
 
         if script_arguments is None:
             script_arguments = list()
@@ -55,7 +55,7 @@ def __init__(self, id, source_host, source_sql=None, source_table_name=None,
             destination_table_name, destination_sql)
 
         src_sql = self.convert_source_to_count_sql(
-            source_table_name, source_sql)
+            source_table_name, source_sql, source_count_sql)
 
         script_arguments.extend([
             '--tolerance=%s' % str(tolerance),
@@ -89,11 +89,14 @@ def convert_destination_to_count_sql(destination_table_name=None,
 
     @staticmethod
     def convert_source_to_count_sql(source_table_name=None,
-                                    source_sql=None):
+                                    source_sql=None,
+                                    source_count_sql=None):
         """Convert the source query into generic structure to compare
         """
         if source_table_name is not None:
             source_sql = "SELECT COUNT(1) FROM %s" % source_table_name
+        elif source_count_sql is not None:
+            source_sql = source_count_sql
         else:
             origin_sql = SqlStatement(source_sql)
             source_sql = "SELECT COUNT(1) FROM (%s)a" % origin_sql.sql()
diff --git a/dataduct/steps/transform.py b/dataduct/steps/transform.py
index 832516b..5537fce 100644
--- a/dataduct/steps/transform.py
+++ b/dataduct/steps/transform.py
@@ -99,6 +99,7 @@ def __init__(self,
         else:
             self._output = base_output_node
 
+        logger.debug('Script Arguments:')
         logger.debug(script_arguments)
 
         self.create_pipeline_object(
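
A minimal sketch (not part of the diff) of how the three mutually exclusive source options of CountCheckStep resolve to a count query after this change. The table and query literals below are hypothetical, and the SqlStatement normalization used by the real method is skipped here.

    # Sketch only: mirrors CountCheckStep.convert_source_to_count_sql after this
    # change; hypothetical inputs, SqlStatement normalization omitted.
    def convert_source_to_count_sql(source_table_name=None, source_sql=None,
                                    source_count_sql=None):
        if source_table_name is not None:
            # A bare table name is wrapped in a COUNT(1) query
            return "SELECT COUNT(1) FROM %s" % source_table_name
        elif source_count_sql is not None:
            # New option: the caller supplies the counting query itself,
            # so it is passed through unchanged
            return source_count_sql
        else:
            # Arbitrary source SQL is counted as a subquery
            return "SELECT COUNT(1) FROM (%s)a" % source_sql

    print(convert_source_to_count_sql(source_table_name="prod.users"))
    # SELECT COUNT(1) FROM prod.users
    print(convert_source_to_count_sql(
        source_count_sql="SELECT COUNT(1) FROM prod.users WHERE is_active"))
    # SELECT COUNT(1) FROM prod.users WHERE is_active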