In [2]:
## imports and environment variables
# imports
from firecloud import api as fapi
import json
import os
import pandas as pd
import csv
from io import StringIO
from google.cloud import storage
from google.cloud import bigquery
import re
import hashlib
import logging

# Emit log records as "<timestamp> - <LEVEL>: <message>", INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)

# Workspace coordinates come from the notebook runtime's environment; indexing
# os.environ directly fails fast with KeyError if any of them is missing.
ws_name = os.environ["WORKSPACE_NAME"]
ws_project = os.environ["WORKSPACE_NAMESPACE"]
ws_bucket = os.environ["WORKSPACE_BUCKET"]

# The storage client wants the bare bucket name, so drop a leading "gs://" scheme.
ws_bucket_name = ws_bucket[len("gs://"):] if ws_bucket.startswith("gs://") else ws_bucket


In [4]:
def _parse_tdr_schema(schema_dict):
    """Flatten a TDR schema dict into the structures used to build profiling queries.

    Args:
        schema_dict: Parsed TDR schema JSON. Keys read: ``tables`` (each with
            ``name``, ``columns`` and optionally ``primaryKey``; each column with
            ``name``, ``datatype`` and optionally ``array_of``) and optionally
            ``relationships`` (each with ``from``/``to`` dicts holding ``table`` and
            ``column``).

    Returns:
        Tuple ``(table_names, array_field_set, field_list, file_join_list)``:
          * table_names: table names, de-duplicated, in schema order.
          * array_field_set: ``"table.column"`` strings for array columns.
          * field_list: one dict per column with keys ``table``, ``column``,
            ``datatype``, ``is_array``, ``is_primary_key`` and ``joins_to`` (a list
            of ``{"table", "column"}`` dicts for that column's outgoing relationships).
          * file_join_list: ``[table, column]`` pairs whose relationship targets
            ``file.file_id``.
    """
    relationships = schema_dict.get('relationships') or []

    # Index outgoing relationships by (table, column) once instead of rescanning
    # the whole relationship list for every column. Order is preserved per key.
    joins_by_source = {}
    for relation_entry in relationships:
        source_key = (relation_entry['from']['table'], relation_entry['from']['column'])
        joins_by_source.setdefault(source_key, []).append(
            {'table': relation_entry['to']['table'], 'column': relation_entry['to']['column']}
        )

    table_names = []
    array_field_set = set()
    field_list = []
    for table_entry in schema_dict['tables']:
        table_name = table_entry['name']
        table_names.append(table_name)
        # primaryKey / array_of may be absent from a column or table definition;
        # treat absence as "no primary key" / "not an array".
        primary_keys = table_entry.get('primaryKey') or []
        for column_entry in table_entry['columns']:
            col_name = column_entry['name']
            is_array = bool(column_entry.get('array_of', False))
            field_list.append({
                'table': table_name,
                'column': col_name,
                'datatype': column_entry['datatype'],
                'is_array': is_array,
                'is_primary_key': col_name in primary_keys,
                'joins_to': joins_by_source.get((table_name, col_name), []),
            })
            if is_array:
                array_field_set.add(table_name + '.' + col_name)

    # Incoming references to file.file_id, used to build the orphaned-file check.
    file_join_list = [
        [relation_entry['from']['table'], relation_entry['from']['column']]
        for relation_entry in relationships
        if relation_entry['to']['table'] == 'file' and relation_entry['to']['column'] == 'file_id'
    ]

    # dict.fromkeys de-duplicates while keeping first-seen (schema) order, so the
    # output CSV rows come out in a stable order across runs.
    return list(dict.fromkeys(table_names)), array_field_set, field_list, file_join_list


def _run_query(client, query):
    """Execute `query` with the BigQuery `client` and return its result rows as a DataFrame."""
    return client.query(query).result().to_dataframe()


def _build_orphaned_file_query(bq_project, bq_schema, file_join_list, array_field_set):
    """Build the query that counts ``file`` rows whose file_id no other table references.

    The ``file_fks`` CTE unions every file_id referenced through a relationship to
    ``file.file_id``, unnesting array-valued source columns. If no relationship
    references the file table, the CTE falls back to an empty, correctly typed
    projection of ``file.file_id``. This keeps the SQL valid (an empty
    ``WITH file_fks AS ()`` is not) and reports every file as orphaned.
    """
    segments = []
    for source_table, source_column in file_join_list:
        if source_table + '.' + source_column in array_field_set:
            segments.append('SELECT DISTINCT file_id FROM `{project}.{schema}.{table}` CROSS JOIN UNNEST({col}) AS file_id'.format(project = bq_project, schema = bq_schema, table = source_table, col = source_column))
        else:
            segments.append('SELECT DISTINCT {col} as file_id FROM `{project}.{schema}.{table}`'.format(project = bq_project, schema = bq_schema, table = source_table, col = source_column))
    if not segments:
        segments.append('SELECT file_id FROM `{project}.{schema}.file` WHERE FALSE'.format(project = bq_project, schema = bq_schema))
    cte_query = 'WITH file_fks AS (' + ' UNION ALL '.join(segments) + ')'

    return """{cte} SELECT 'Orphaned Files' As metric_type, 'file' AS source_table, 'file_id' AS source_column, 
                          'Count of file_ids not referenced by another table' AS metric,
                          COUNT(DISTINCT CASE WHEN tar.file_id IS NULL THEN src.file_id END) AS n,
                          COUNT(DISTINCT src.file_id) AS d,
                          CASE WHEN COUNT(DISTINCT src.file_id) > 0 THEN COUNT(DISTINCT CASE WHEN tar.file_id IS NULL THEN src.file_id END)/COUNT(DISTINCT src.file_id) END AS r
                          FROM `{project}.{schema}.file` src LEFT JOIN file_fks tar ON src.file_id = tar.file_id""".format(cte = cte_query, project = bq_project, schema = bq_schema)


def profile_data(params):
    """Profile a TDR dataset in BigQuery and write the metrics CSV to the workspace bucket.

    Reads a TDR schema JSON from the workspace bucket, runs one BigQuery query per
    metric, and uploads the combined results. Each output row has columns
    metric_type, source_table, source_column, metric, n, d, r, where r = n / d
    when d > 0:
      * Summary Stats: record count per table; null/empty-list count and
        distinct-value count per column.
      * Referential Integrity: per relationship, count of non-null source rows
        that do not join to the target column.
      * Orphaned Files: count of file.file_id values not referenced through any
        relationship. This is skipped when the schema has no ``file`` table.
    Missing values (e.g. d and r for table record counts) are written as 0.

    Args:
        params: dict with keys
            bq_project: BigQuery project holding the dataset.
            bq_schema: BigQuery dataset name; also names the output file
                ``<bq_schema>_metric_results.csv``.
            tdr_schema_file: object path of the TDR schema JSON in the workspace bucket.
            val_output_dir: object prefix in the workspace bucket for the CSV.

    Relies on the module-level ``ws_bucket_name``. Upload errors propagate
    instead of being silently discarded.
    """
    # Retrieve parameters of interest
    bq_project = params["bq_project"]
    bq_schema = params["bq_schema"]
    tdr_schema_file = params["tdr_schema_file"]
    val_output_dir = params["val_output_dir"]
    val_output_file = bq_schema + "_metric_results.csv"

    # Read schema object from GCS into a dictionary
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(ws_bucket_name)
    schema_dict = json.loads(bucket.blob(tdr_schema_file).download_as_bytes())
    table_names, array_field_set, field_list, file_join_list = _parse_tdr_schema(schema_dict)

    # Collect per-query result frames and concatenate once at the end
    # (DataFrame.append was removed in pandas 2.0 and is quadratic in a loop).
    client = bigquery.Client()
    result_frames = []

    ## Build and execute table-level queries
    logging.info("Building and executing table-level queries.")
    for table_name in table_names:
        query = """SELECT 'Summary Stats' AS metric_type, '{table}' AS source_table, 'All' AS source_column, 
                   'Count of records in table' AS metric, 
                   COUNT(*) AS n, null AS d, null AS r 
                   FROM `{project}.{schema}.{table}`""".format(project = bq_project, schema = bq_schema, table = table_name)
        result_frames.append(_run_query(client, query))

    ## Build and execute column-level queries
    logging.info("Building and executing column-level queries.")
    for column_entry in field_list:
        table_name = column_entry['table']
        col_name = column_entry['column']
        if column_entry['is_array']:
            # An array column counts as "null" when empty; its distinct value is the
            # comma-joined string of its elements.
            null_case_statement = 'array_length({col}) = 0'.format(col = col_name)
            distinct_case_statement = "(select string_agg(cast(val as string), ',') from unnest({col}) val)".format(col = col_name)
        else:
            null_case_statement = '{col} is null'.format(col = col_name)
            distinct_case_statement = col_name

        # Null / empty-list count
        null_query = """SELECT 'Summary Stats' AS metric_type, '{table}' AS source_table, '{col}' AS source_column, 
                   'Count of nulls or empty lists in column' AS metric, 
                   SUM(CASE WHEN {null_case} THEN 1 ELSE 0 END) AS n, 
                   COUNT(*) AS d, 
                   CASE WHEN COUNT(*) > 0 THEN SUM(CASE WHEN {null_case} THEN 1 ELSE 0 END)/COUNT(*) END AS r 
                   FROM `{project}.{schema}.{table}`""".format(project = bq_project, schema = bq_schema, table = table_name, col = col_name, null_case = null_case_statement)
        result_frames.append(_run_query(client, null_query))

        # Distinct-value count
        distinct_query = """SELECT 'Summary Stats' AS metric_type, '{table}' AS source_table, '{col}' AS source_column, 
                   'Count of distinct values in column' AS metric, 
                   COUNT(DISTINCT {distinct_case}) AS n, 
                   COUNT(*) AS d, 
                   CASE WHEN COUNT(*) > 0 THEN COUNT(DISTINCT {distinct_case})/COUNT(*) END AS r 
                   FROM `{project}.{schema}.{table}`""".format(project = bq_project, schema = bq_schema, table = table_name, col = col_name, distinct_case = distinct_case_statement)
        result_frames.append(_run_query(client, distinct_query))

        # Referential integrity checks for each outgoing relationship (if any).
        # Array columns on either side are unnested so individual elements are joined.
        for join_entry in column_entry['joins_to']:
            target_table = join_entry['table']
            target_col = join_entry['column']
            target_table_col = target_table + '.' + target_col
            if column_entry['is_array']:
                src_col_name = '{col}_unnest'.format(col = col_name)
                from_statement = '(select * from `{project}.{schema}.{table}` t left join unnest(t.{col}) as {unnest_col}) src'.format(project = bq_project, schema = bq_schema, table = table_name, col = col_name, unnest_col = src_col_name)
                where_statement = 'array_length(src.{col}) > 0'.format(col = col_name)
            else:
                src_col_name = col_name
                from_statement = '`{project}.{schema}.{table}` src'.format(project = bq_project, schema = bq_schema, table = table_name)
                where_statement = 'src.{col} is not null'.format(col = col_name)
            if target_table_col in array_field_set:
                tar_col_name = '{col}_unnest'.format(col = target_col)
                join_statement = '(select * from `{project}.{schema}.{table}` t left join unnest(t.{col}) as {unnest_col}) tar'.format(project = bq_project, schema = bq_schema, table = target_table, col = target_col, unnest_col = tar_col_name)
            else:
                tar_col_name = target_col
                join_statement = '`{project}.{schema}.{table}` tar'.format(project = bq_project, schema = bq_schema, table = target_table)

            ref_int_query = """SELECT 'Referential Integrity' AS metric_type, '{table}' AS source_table, '{col}' AS source_column, 
                   'Count of non-null rows that do not fully join to {target}' AS metric, 
                   COUNT(DISTINCT CASE WHEN tar.datarepo_row_id IS NULL THEN src.datarepo_row_id END) AS n, 
                   COUNT(DISTINCT src.datarepo_row_id) AS d, 
                   CASE WHEN COUNT(DISTINCT src.datarepo_row_id) > 0 THEN COUNT(DISTINCT CASE WHEN tar.datarepo_row_id IS NULL THEN src.datarepo_row_id END)/COUNT(DISTINCT src.datarepo_row_id) END AS r
                   FROM {frm}
                   LEFT JOIN {join}
                   ON src.{src_col} = tar.{tar_col}
                   WHERE {where}""".format(table = table_name, col = col_name, target = target_table_col, frm = from_statement, join = join_statement, src_col = src_col_name, tar_col = tar_col_name, where = where_statement)
            result_frames.append(_run_query(client, ref_int_query))

    ## Build and execute file orphan query
    logging.info("Building and executing orphan file queries.")
    if 'file' in table_names:
        orphaned_file_query = _build_orphaned_file_query(bq_project, bq_schema, file_join_list, array_field_set)
        result_frames.append(_run_query(client, orphaned_file_query))
    else:
        logging.warning("Schema has no 'file' table; skipping orphaned file check.")

    ## Combine results and write them to the workspace bucket as CSV
    metric_columns = ['metric_type', 'source_table', 'source_column', 'metric', 'n', 'd', 'r']
    if result_frames:
        df = pd.concat(result_frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns = metric_columns)
    df_final = df.reindex(columns = metric_columns).fillna(0)

    # Same object path the previous `gsutil cp <file> gs://<bucket>/<dir>/` produced,
    # without doubled or leading slashes when val_output_dir is empty or ends in "/".
    destination_path = '/'.join(part for part in (val_output_dir.strip('/'), val_output_file) if part)
    logging.info(f"Writing out results to {val_output_dir}/{val_output_file}.")
    bucket.blob(destination_path).upload_from_string(df_final.to_csv(index=False), content_type='text/csv')


In [5]:
# Test
# params = {}
# params["val_output_dir"] = "ingest_pipeline/output/tim_core/validation"
# params["tdr_schema_file"] = "ingest_pipeline/output/tim_core/schema/tdr_schema_object.json"
# params["bq_project"] = "datarepo-7949025c"
# params["bq_schema"] = "datarepo_tdr_anvil_ingest_bjt"
# profile_data(params)

05/10/2022 04:11:58 PM - INFO: Building and executing table-level queries.
05/10/2022 04:12:19 PM - INFO: Building and executing column-level queries.
05/10/2022 04:27:41 PM - INFO: Building and executing orphan file queries.
05/10/2022 04:27:44 PM - INFO: Writing out results to ingest_pipeline/output/tim_core/validation/datarepo_tdr_anvil_ingest_bjt_metric_results.csv.
