In [None]:
# Import python packages
import snowflake.connector
from snowflake.connector import DictCursor
import pandas as pd
from snowflake.connector.pandas_tools import write_pandas
import time
import sys

# We can also use Snowpark for our analyses!
# from snowflake.snowpark.context import get_active_session
# session = get_active_session()


In [None]:
def parse_notebook_args(raw_args):
    """Split a comma-separated parameter string into connection settings.

    The expected layout is ``user,warehouse,database[,schema,...]``.

    Args:
        raw_args: The raw parameter string (the notebook reads it from
            ``sys.argv[0]``).

    Returns:
        A ``(user, warehouse, database, schemas)`` tuple. The first three
        values fall back to the built-in defaults unless at least three
        comma-separated fields are supplied; ``schemas`` falls back to
        ``['MODELED', 'REPORTING']`` unless a fourth field is present.
    """
    # Strip padding so "u, WH, DB" does not yield names with stray spaces.
    fields = [field.strip() for field in raw_args.split(',')]

    if len(fields) >= 3:
        parsed_user, parsed_warehouse, parsed_database = fields[:3]
    else:
        parsed_user = 'BASUK'
        parsed_warehouse = 'DEV_CMSC_CMO_WH'
        parsed_database = 'DEV_P_CMSC_CMO_DB'

    parsed_schemas = fields[3:] if len(fields) >= 4 else ['MODELED', 'REPORTING']
    return parsed_user, parsed_warehouse, parsed_database, parsed_schemas


# Guard against an empty sys.argv so the defaults still apply instead of
# raising IndexError on sys.argv[0].
user, warehouse, database, schemas = parse_notebook_args(
    sys.argv[0] if sys.argv else '')
print(f"snowflake user : {user}")
print(f"snowflake database : {database}")
print(f"snowflake warehouse : {warehouse}")
print(f"snowflake schemas : {schemas}")

In [None]:
# Names of the lineage output columns: the ROOT_* identifiers, DISTANCE, and
# the SOURCE_* / TARGET_* fields produced by the lineage and metadata queries.
LINEAGE_COLUMNS = [
    'ROOT_DATABASE',
    'ROOT_SCHEMA',
    'ROOT_TABLE',
    'ROOT_COLUMN',
    'DISTANCE',
    'SOURCE_OBJECT_DOMAIN',
    'SOURCE_OBJECT_DATABASE',
    'SOURCE_OBJECT_SCHEMA',
    'SOURCE_OBJECT_NAME',
    'SOURCE_COLUMN_NAME',
    'SOURCE_IS_IDENTITY',
    'SOURCE_IS_NULLABLE',
    'SOURCE_DATA_TYPE',
    'SOURCE_CHARACTER_MAXIMUM_LENGTH',
    'SOURCE_NUMERIC_PRECISION',
    'SOURCE_NUMERIC_SCALE',
    'SOURCE_COMMENT',
    'SOURCE_STATUS',
    'TARGET_OBJECT_DOMAIN',
    'TARGET_OBJECT_DATABASE',
    'TARGET_OBJECT_SCHEMA',
    'TARGET_OBJECT_NAME',
    'TARGET_COLUMN_NAME',
    'TARGET_IS_IDENTITY',
    'TARGET_IS_NULLABLE',
    'TARGET_DATA_TYPE',
    'TARGET_CHARACTER_MAXIMUM_LENGTH',
    'TARGET_NUMERIC_PRECISION',
    'TARGET_NUMERIC_SCALE',
    'TARGET_COMMENT',
    'TARGET_STATUS'
]

# The constant was originally spelled with a stray "print" prefix; keep that
# name bound to the same list so any existing reference keeps working.
printLINEAGE_COLUMNS = LINEAGE_COLUMNS


In [None]:
def get_cursor(conn):
    """Open a new ``DictCursor`` on ``conn`` and hand it back to the caller."""
    dict_cursor = conn.cursor(DictCursor)
    return dict_cursor

In [None]:
def get_columns(conn, database, schemas) -> list:
    """List the columns of ``database``, optionally limited to some schemas.

    Args:
        conn: Open Snowflake connection used to run the query.
        database: Database whose ``INFORMATION_SCHEMA.COLUMNS`` view is read.
            It is interpolated as an identifier, so it must be a trusted name.
        schemas: Iterable of schema names to keep. Names are upper-cased
            before comparison; a falsy value disables the schema filter.

    Returns:
        All rows fetched from the cursor, ordered by catalog, schema, table
        and ordinal position.
    """
    column_sql = f"""
        select table_catalog, table_schema, table_name, column_name, ordinal_position
        from {database}.INFORMATION_SCHEMA.COLUMNS
        where 1=1
    """

    if schemas:
        # Double any embedded single quote so a schema name cannot terminate
        # the SQL string literal early.
        quoted_schemas = [
            "'" + str(schema).upper().replace("'", "''") + "'"
            for schema in schemas
        ]
        column_sql += f"""
        and table_schema in ({','.join(quoted_schemas)})
        """

    column_sql += """
        order by table_catalog, table_schema, table_name, ordinal_position
    """

    cur = get_cursor(conn)
    try:
        return cur.execute(column_sql).fetchall()
    finally:
        # Release the cursor even when the query fails.
        cur.close()

In [None]:
def get_column_lineage_sql(column, with_order_by=True):
    """Build the upstream-lineage SQL statement for a single column.

    Args:
        column: Mapping exposing ``TABLE_CATALOG``, ``TABLE_SCHEMA``,
            ``TABLE_NAME`` and ``COLUMN_NAME`` through ``.get``.
        with_order_by: When true, the statement ends with
            ``ORDER BY DISTANCE``.

    Returns:
        The SQL text. It selects the lineage fields from
        ``SNOWFLAKE.CORE.GET_LINEAGE`` and adds the column's own identifiers
        as ``ROOT_DATABASE`` / ``ROOT_SCHEMA`` / ``ROOT_TABLE`` /
        ``ROOT_COLUMN``.
    """
    def literal(key):
        # Values end up inside single-quoted SQL literals; double embedded
        # quotes so a name such as O'BRIEN does not break the statement.
        return str(column.get(key)).replace("'", "''")

    catalog = literal('TABLE_CATALOG')
    schema = literal('TABLE_SCHEMA')
    table = literal('TABLE_NAME')
    column_name = literal('COLUMN_NAME')

    domain = f"'{catalog}.{schema}.{table}.{column_name}'"
    order_by = "ORDER BY DISTANCE" if with_order_by else ""

    lineage_sql = f"""
        SELECT
            DISTANCE,
            SOURCE_OBJECT_DOMAIN,
            SOURCE_OBJECT_DATABASE,
            SOURCE_OBJECT_SCHEMA,
            SOURCE_OBJECT_NAME,
            SOURCE_COLUMN_NAME,
            SOURCE_STATUS,
            TARGET_OBJECT_DOMAIN,
            TARGET_OBJECT_DATABASE,
            TARGET_OBJECT_SCHEMA,
            TARGET_OBJECT_NAME,
            TARGET_COLUMN_NAME,
            TARGET_STATUS,
            '{catalog}' as ROOT_DATABASE,
            '{schema}' as ROOT_SCHEMA,
            '{table}' as ROOT_TABLE,
            '{column_name}' as ROOT_COLUMN
        FROM TABLE (SNOWFLAKE.CORE.GET_LINEAGE({domain}, 'COLUMN', 'UPSTREAM'))
        {order_by}
    """

    return lineage_sql


In [None]:
# Rows from INFORMATION_SCHEMA.COLUMNS, keyed by "<TYPE>:<DATABASE>.<SCHEMA>"
# (upper-cased), so each schema is queried at most once per prefix type.
column_metadata_cache = {}


def get_column_metadata(conn, database, schema, table, column, type):
    """Return the ``<type>_*`` metadata fields for one column.

    Args:
        conn: Open Snowflake connection; only used on a cache miss.
        database: Database that owns the column.
        schema: Schema that owns the column.
        table: Table (object) name, compared exactly against ``TABLE_NAME``.
        column: Column name, compared exactly against ``COLUMN_NAME``.
        type: Prefix for the returned keys; callers in this notebook pass
            ``"SOURCE"`` or ``"TARGET"``.

    Returns:
        A new dict with the ``{type}_IS_IDENTITY``, ``{type}_IS_NULLABLE``,
        ``{type}_DATA_TYPE``, ``{type}_CHARACTER_MAXIMUM_LENGTH``,
        ``{type}_NUMERIC_PRECISION``, ``{type}_NUMERIC_SCALE`` and
        ``{type}_COMMENT`` values, or an empty dict when the schema has no
        matching column.
    """
    cache_key = f"{type}:{database}.{schema}".upper()

    if cache_key not in column_metadata_cache:
        print(f"Cache miss for cache_key: {cache_key}")

        query = f"""
            select
                table_name as TABLE_NAME,
                column_name as COLUMN_NAME,
                '{type}' as COLUMN_TYPE,
                is_identity as {type}_IS_IDENTITY,
                is_nullable as {type}_IS_NULLABLE,
                data_type as {type}_DATA_TYPE,
                character_maximum_length as {type}_CHARACTER_MAXIMUM_LENGTH,
                numeric_precision as {type}_NUMERIC_PRECISION,
                numeric_scale as {type}_NUMERIC_SCALE,
                comment as {type}_COMMENT
            from {database}.INFORMATION_SCHEMA.COLUMNS
            where table_catalog = '{database}'
            and table_schema = '{schema}'
        """

        # Only open a cursor when the schema actually has to be queried.
        cur = get_cursor(conn)
        column_metadata_cache[cache_key] = cur.execute(query).fetchall()

    # Default to None: without it a column that is absent from the schema
    # would raise StopIteration out of this function.
    column_data = next(
        (
            row for row in column_metadata_cache[cache_key]
            if row.get("TABLE_NAME") == table
            and row.get("COLUMN_NAME") == column
            and row.get("COLUMN_TYPE") == type
        ),
        None,
    )
    if column_data is None:
        return {}

    # Work on a copy so the cached row is not altered.
    column_data_copy = column_data.copy()
    # These lookup keys must not appear in the final output.
    for lookup_key in ("TABLE_NAME", "COLUMN_NAME", "COLUMN_TYPE"):
        column_data_copy.pop(lookup_key, None)

    return column_data_copy


In [None]:
def get_column_lineage(conn, column):
    """Execute the lineage statement for ``column`` and return every fetched row."""
    cursor = get_cursor(conn)
    return cursor.execute(get_column_lineage_sql(column)).fetchall()

In [None]:
def get_column_lineage_results(conn, sfqid):
    """Fetch all rows produced by the query identified by ``sfqid``."""
    result_cursor = get_cursor(conn)
    # Point the fresh cursor at the existing query's result set before reading it.
    result_cursor.get_results_from_sfqid(sfqid)
    return result_cursor.fetchall()

In [None]:
def check_if_query_complete(conn, sfqid):
    """Return True once the query ``sfqid`` is no longer reported as running."""
    status = conn.get_query_status(sfqid)
    if conn.is_still_running(status):
        return False
    return True


In [None]:
def get_column_lineage_async(conn, column):
    """Submit the lineage statement for ``column`` without waiting; return its query id."""
    async_cursor = get_cursor(conn)
    async_cursor.execute_async(get_column_lineage_sql(column))
    return async_cursor.sfqid

In [None]:
def drop_table(user, database, schema, table, warehouse):
    '''Drop the existing mapping table so that it can be recreated.

    Opens its own connection (external-browser authentication), issues
    ``DROP TABLE IF EXISTS <database>.<schema>.<table>`` and closes the
    cursor and connection again, even when the statement fails.

    Args:
        user: Snowflake user name for the connection.
        database: Database that contains the table.
        schema: Schema that contains the table.
        table: Name of the table to drop.
        warehouse: Warehouse used for the connection.
    '''
    table_name = f'{database}.{schema}.{table}'
    con = snowflake.connector.connect(
        user=user,
        authenticator="externalBrowser",
        account='vrtx-data',
        database=database,
        warehouse=warehouse,
        client_session_keep_alive=True,
        client_fetch_use_mp=True
    )
    try:
        cur = con.cursor()
        try:
            cur.execute(f"DROP TABLE IF EXISTS {table_name};")
        finally:
            cur.close()
    finally:
        # Always release the session; previously a failing DROP leaked it.
        con.close()

In [None]:
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items."""
    for offset in range(0, len(lst), n):
        window = slice(offset, offset + n)
        yield lst[window]

In [None]:
if __name__ == "__main__":
    # Connection settings (user, database, warehouse, schemas) are the
    # module-level values assigned earlier in the notebook.
    args = {'user':user,'database':database,'warehouse':warehouse,'schemas':schemas}
    print(args)
    conn = snowflake.connector.connect(
        user=user,
        authenticator="externalBrowser",
        account='vrtx-data',
        database=args['database'],
        warehouse=args['warehouse'],
        client_session_keep_alive=True,
        client_fetch_use_mp=True
    )
    try:
        columns: list = get_columns(conn, args['database'], schemas)

        # Best effort: a failed drop is reported, and the run still continues.
        try:
            drop_table(user, database, 'ARTIFACTS', 'LINEAGE', warehouse)
        except Exception as e:
            print(e)

        # Submit lineage queries in batches of 50 columns, then collect and
        # write each batch before starting the next one.
        for column_subset in chunks(columns, 50):
            sfqids = []
            for column in column_subset:
                domain = f"'{column.get('TABLE_CATALOG')}.{column.get('TABLE_SCHEMA')}.{column.get('TABLE_NAME')}.{column.get('COLUMN_NAME')}'"
                print(f"Submitting lineage SQL for {domain}")
                sfqids.append(get_column_lineage_async(conn, column))

            query_results = []
            # Drain the batch in submission order, polling the oldest
            # outstanding query until it finishes.
            while sfqids:
                print(
                    f"# of unfinished lineage SQL queries: {len(sfqids)}")
                sfqid = sfqids[0]
                if not check_if_query_complete(conn, sfqid):
                    time.sleep(5)
                    continue

                for record in get_column_lineage_results(conn, sfqid):
                    source_column_metadata = get_column_metadata(
                        conn=conn,
                        database=record.get('SOURCE_OBJECT_DATABASE'),
                        schema=record.get('SOURCE_OBJECT_SCHEMA'),
                        table=record.get('SOURCE_OBJECT_NAME'),
                        column=record.get('SOURCE_COLUMN_NAME'),
                        type="SOURCE"
                    )
                    target_column_metadata = get_column_metadata(
                        conn=conn,
                        database=record.get('TARGET_OBJECT_DATABASE'),
                        schema=record.get('TARGET_OBJECT_SCHEMA'),
                        table=record.get('TARGET_OBJECT_NAME'),
                        column=record.get('TARGET_COLUMN_NAME'),
                        type="TARGET"
                    )
                    # Merge the lineage row with both metadata dicts.
                    query_results.append(
                        record | source_column_metadata | target_column_metadata)
                sfqids.pop(0)

            # A batch without any lineage rows would produce a column-less
            # DataFrame; skip the write instead of sending that to Snowflake.
            if not query_results:
                print("No lineage records in this batch; nothing to write")
                continue

            df = pd.DataFrame(query_results)
            # Write the fresh mapping
            try:
                write_pandas(
                    conn=conn,
                    df=df,
                    database=database,
                    schema='ARTIFACTS',
                    table_name='LINEAGE',
                    auto_create_table=True)
                print(f'Check {database}.ARTIFACTS.LINEAGE table for source to target mapping details')
            except Exception as e:
                print(e)
    finally:
        # Release the session (and its keep-alive) however the run ends.
        conn.close()