# In [None]:
from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import call_builtin
import json
import re
import logging

# Root logger setup: show INFO and above, rendered as "<LEVEL>: <message>",
# so progress and failures are visible while the notebook runs.
logging.basicConfig(
    format='%(levelname)s: %(message)s',
    level=logging.INFO,
)

# Obtain the Snowpark session used by every query below. getOrCreate() hands
# back an existing session when there is one (in a Snowflake notebook the
# session is usually already available) and builds one otherwise.
session = Session.builder.getOrCreate()

def fetch_snowflake_data():
    """Fetch column-level lineage rows for the tracked JAFFLE_SHOP tables.

    The query flattens ``objects_modified`` from
    ``snowflake.account_usage.access_history`` into one row per
    (modified object, column, direct source); the innermost flatten is an
    OUTER flatten, so a column without direct sources still yields a row.
    Rows are ranked per object name by ``query_start_time`` descending and
    only rank 1 is kept. The text of the modifying statement is joined in
    from ``snowflake.account_usage.query_history`` on ``query_id``.

    Uses the module-level ``session``.

    Returns:
        The collected result of the query: distinct rows carrying
        OBJECT_NAME, COLUMN_NAME, DIRECT_SOURCE_TABLE, DIRECT_SOURCE_COLUMN,
        DIRECT_SOURCE_TYPE and QUERY_TEXT, ordered by column name.
    """
    # NOTE(review): ``schema_name`` is not produced by OBJECTS_VIEW, so it
    # presumably resolves to query_history.schema_name; filtering on it in
    # WHERE would make the LEFT OUTER JOIN behave like an inner join - confirm
    # that this is intended.
    lineage_sql = """
    WITH OBJECTS_VIEW AS (
    SELECT 
        query_start_time,
        query_id,
        objects_modified,
        f.value:objectName::STRING AS object_name,
        f.value:objectDomain::STRING AS object_domain,
        f.value:objectDomain::STRING AS object_type,
        f.value:objectId::NUMBER AS object_id,
        c.value:columnName::STRING AS column_name,
        c.value:columnId::NUMBER AS column_id,
        d.value:columnName::STRING AS direct_source_column,
        d.value:objectName::STRING AS direct_source_table,
        d.value:objectDomain::STRING AS direct_source_type,
        rank() over (PARTITION by object_name order by query_start_time desc) as rnk

    FROM snowflake.account_usage.access_history,
    LATERAL FLATTEN(input => objects_modified) f,
    LATERAL FLATTEN(input => f.value:columns) c,
    LATERAL FLATTEN(input => c.value:directSources,OUTER => TRUE) d
    )
    
    SELECT DISTINCT 
        ov.object_name,
        ov.column_name,
        ov.direct_source_table,
        ov.direct_source_column,
        ov.direct_source_type,
        q.query_text,
    FROM OBJECTS_VIEW ov
    LEFT OUTER JOIN snowflake.account_usage.query_history AS q ON ov.query_id = q.query_id
    WHERE schema_name = 'DBT_JMARWAHA'
    and OBJECT_NAME IN ('JAFFLE_SHOP.DBT_JMARWAHA.CUSTOMERS','JAFFLE_SHOP.DBT_JMARWAHA.LOCATIONS','JAFFLE_SHOP.DBT_JMARWAHA.ORDERS','JAFFLE_SHOP.DBT_JMARWAHA.ORDER_ITEMS','JAFFLE_SHOP.DBT_JMARWAHA.SUPPLIES','JAFFLE_SHOP.DBT_JMARWAHA.PRODUCTS')
    and ov.rnk = 1

    ORDER BY ov.column_name
    """
    lineage_frame = session.sql(lineage_sql)
    return lineage_frame.collect()
# Run the lineage query once; the rows are reused by the per-row LLM calls below.
table_defination = fetch_snowflake_data()

# Quick sanity check: how many lineage rows came back.
print(len(table_defination))



def extract_col_lineage_from_table(table_defination, model="llama3.1-405b"):
    """Ask a Cortex LLM to explain how one target column derives from its source.

    Args:
        table_defination: One lineage row. It must expose ``as_dict()``
            returning a mapping with the keys OBJECT_NAME, COLUMN_NAME,
            DIRECT_SOURCE_TABLE, DIRECT_SOURCE_COLUMN, DIRECT_SOURCE_TYPE
            and QUERY_TEXT.
        model: Model name passed as the first argument of
            SNOWFLAKE.CORTEX.COMPLETE. Defaults to the model this function
            has always used.

    Returns:
        A 7-tuple ``(object_name, column_name, direct_source_table,
        direct_source_column, direct_source_type, query_text,
        lineage_response)``. ``lineage_response`` is the literal string
        ``"error"`` when the Cortex call fails, so one bad row does not
        abort the whole run.
    """
    record = table_defination.as_dict()
    object_name = record["OBJECT_NAME"]
    column_name = record["COLUMN_NAME"]
    direct_source_table = record["DIRECT_SOURCE_TABLE"]
    direct_source_column = record["DIRECT_SOURCE_COLUMN"]
    direct_source_type = record["DIRECT_SOURCE_TYPE"]
    query_text = record["QUERY_TEXT"]

    # The prompt text (including its leading whitespace) is kept exactly as it
    # was, so the model receives the same instructions as before.
    prompt = f"""You are an expert in SQL lineage analysis.        
        Analyze the following data provided.
        Provide the very short and to the point and in business-friendly language transformation logic or reasoning between the source column and the target column based on the query text.        
        Additionally Provide the output in 100 characters or less without any unnecessary information about the task.
        
        OBJECT_NAME : {object_name}
        COLUMN_NAME : {column_name}
        DIRECT_SOURCE_TABLE : {direct_source_table}
        DIRECT_SOURCE_COLUMN : {direct_source_column}
        DIRECT_SOURCE_TYPE : {direct_source_type}
        QUERY_TEXT : {query_text}

        """

    # Call the Cortex LLM using the SNOWFLAKE.CORTEX.COMPLETE function.
    # The prompt embeds arbitrary query text, so it is sent as a bind variable
    # instead of being spliced into a quoted SQL literal: doubling single
    # quotes alone does not protect backslashes, which are escape characters
    # inside single-quoted Snowflake string literals.
    try:
        response_rows = session.sql(
            "SELECT SNOWFLAKE.CORTEX.COMPLETE(?, ?) AS LINEAGE_RESPONSE",
            params=[model, prompt],
        ).collect()
        lineage_response = response_rows[0]["LINEAGE_RESPONSE"]
    except Exception as e:
        # Deliberate best-effort: record a sentinel for this row and carry on
        # with the remaining rows.
        lineage_response = "error"
        logging.error("Error calling Cortex LLM: %s", e)

    return (
        object_name,
        column_name,
        direct_source_table,
        direct_source_column,
        direct_source_type,
        query_text,
        lineage_response,
    )

        
        
# One Cortex call per lineage row; each call returns a 7-tuple matching `columns`.
output_list = [extract_col_lineage_from_table(table) for table in table_defination]

# Column names of the result table, in the same order as the tuple fields.
columns = ['object_name','column_name','direct_source_table','direct_source_column','direct_source_type','query_text','reasoning']

if output_list:
    insert_df_table = session.create_dataframe(output_list, schema=columns)
    # mode="overwrite" replaces any previous contents of the target table.
    insert_df_table.write.save_as_table("FINAL_LINEAGE_SNOWFLAKE_TABLE_NEW", mode="overwrite")
    print("table is successfully created")
else:
    # Only column names are supplied, so the column types would have to be
    # inferred from the rows; with zero rows there is nothing to infer from.
    # Skip the write and say so instead of failing inside create_dataframe.
    logging.warning(
        "No lineage rows were fetched; FINAL_LINEAGE_SNOWFLAKE_TABLE_NEW was not written"
    )



    
       






        

        
       
    


    
    