This code creates a Streamlit web application for comparing and mapping data between modernized and legacy tables in Snowflake. It uses AI models to analyze the data and generate mappings, with options for incorporating vector embeddings and sample data. The app allows users to input table information, select AI models and parameters, and view the generated prompts and AI responses.

# Import necessary libraries
import streamlit as st
from snowflake.snowpark.context import get_active_session
import json

# Acquire the Snowpark session that the Streamlit-in-Snowflake runtime has
# already opened for this app; every query below runs through it.
session = get_active_session()

# Page heading shown at the top of the app.
st.title(
    "Data Comparison and Mapping Tool"
)

# Initialize session state variables for storing prompt, AI response, and token count
# These persist across reruns of the app

# Input fields for user to enter database, schema, and table information
# Default values are provided for convenience

# Dropdown for selecting the AI model to use

# Checkboxes for additional options (vector embeddings and sample data)

# Function to fetch available vector embeddings based on table names
def fetch_available_embeddings(modern_table, legacy_table):
    """Outline stub: look up vector-embedding mappings for a table pair.

    Placeholder that only documents this step of the app outline; the working
    version is defined in the next code cell. As written it returns None.
    """
    # Query Snowflake to get relevant vector mappings

# Fetch and display available embeddings if any are found

# Function to generate the prompt for the AI model
def generate_prompt():
    """Outline stub: build the prompt that is sent to the AI model.

    Placeholder that only documents this step of the app outline; the working
    version is defined in the next code cell. As written it returns None.
    """
    # Construct the prompt based on user selections and data
    # Include vector embeddings and/or sample data if selected
    # Format the prompt with instructions for the AI model

# Function to get the AI response using Snowflake's CORTEX.COMPLETE function
def get_ai_response(options):
    """Outline stub: call SNOWFLAKE.CORTEX.COMPLETE with the generated prompt.

    Placeholder that only documents this step of the app outline; the working
    version is defined in the next code cell. As written it ignores
    ``options`` and returns None.
    """
    # Call the AI model with the generated prompt and user-specified parameters
    # Handle the response and any potential errors

# Button to generate the prompt
# When clicked, it generates the prompt and calculates the token count

# Display the generated prompt and token count if available

# UI for setting AI model parameters (temperature, max tokens, top p)

# Button to get the AI response
# When clicked, it calls the AI model with the generated prompt and chosen parameters

# Display the AI model response, including the generated text and usage statistics

In [None]:
import streamlit as st
from snowflake.snowpark.context import get_active_session
import json

# Snowpark session supplied by the Streamlit-in-Snowflake runtime.
session = get_active_session()

st.title("Data Comparison and Mapping Tool")

# Seed per-user state once; Streamlit keeps st.session_state across reruns,
# so values that already exist are left untouched.
for _state_key, _state_default in (
    ("prompt", ""),
    ("ai_response", ""),
    ("token_count", 0),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Location of the two tables to compare (defaults pre-filled for convenience).
database_name = st.text_input("Enter Database Name", "DATA_OPS_MAPPING")
schema_name = st.text_input("Enter Schema Name", "DATA_ONESTREAM")
modern_table_name = st.text_input("Enter Modernised Table Name", "Deposit_Accounts_Onestream")
legacy_table_name = st.text_input("Enter Legacy Table Name", "LEGACY_DEPOSIT_ACCOUNTS")
record_limit = st.number_input("Enter Sample Record Limit", min_value=1, value=5)

# Cortex models offered in the dropdown, in display order.
_CORTEX_MODELS = [
    "snowflake-arctic", "mistral-large", "reka-flash", "reka-core", "mixtral-8x7b",
    "jamba-instruct", "llama2-70b-chat", "llama3-8b", "llama3-70b", "llama3.1-8b",
    "llama3.1-70b", "llama3.1-405b", "mistral-7b", "gemma-7b",
]
model = st.selectbox("Select AI Model", _CORTEX_MODELS)

# Which optional context to include in the prompt (rendered in this order).
use_vector_embeddings, use_sample_data = (
    st.checkbox("Use Vector Embeddings"),
    st.checkbox("Use Sample Data"),
)

def fetch_available_embeddings(modern_table, legacy_table):
    """Return the vector-mapping names that mention both tables.

    Looks up distinct ``VECTOR_MAPPING_FOR_TABLE_NAME`` values in
    ``SMART_AI_MAPPER.SMART_AI_MAPPER_TOOL.TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB``
    whose upper-cased name contains both table names as substrings.

    Args:
        modern_table: Modernised table name as typed by the user.
        legacy_table: Legacy table name as typed by the user.

    Returns:
        list[str]: Matching mapping names; empty when nothing matches.
    """
    # The table names come straight from text inputs, so they are passed as
    # bind variables instead of being formatted into the SQL text: a quote in
    # a name can no longer break or alter the statement. CONTAINS also treats
    # "_" and "%" literally, whereas LIKE reads them as wildcards (names such
    # as "Deposit_Accounts" contain "_").
    query = """
    SELECT VECTOR_MAPPING_FOR_TABLE_NAME
    FROM SMART_AI_MAPPER.SMART_AI_MAPPER_TOOL.TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB
    WHERE CONTAINS(UPPER(VECTOR_MAPPING_FOR_TABLE_NAME), ?)
      AND CONTAINS(UPPER(VECTOR_MAPPING_FOR_TABLE_NAME), ?)
    GROUP BY 1
    """
    rows = session.sql(query, params=[modern_table.upper(), legacy_table.upper()]).collect()
    return [row['VECTOR_MAPPING_FOR_TABLE_NAME'] for row in rows]

# Offer an embedding dropdown only when at least one mapping matches both tables.
available_embeddings = fetch_available_embeddings(modern_table_name, legacy_table_name)

selected_embedding = (
    st.selectbox("Select Vector Embedding", available_embeddings)
    if available_embeddings
    else None
)
if not available_embeddings:
    st.warning("No matching vector embeddings found for the given table names.")

def generate_prompt():
    """Build the LLM prompt from the UI selections and store it in session state.

    Reads the module-level widget values (``use_vector_embeddings``,
    ``selected_embedding``, ``use_sample_data``, database/schema/table names
    and ``record_limit``), gathers the requested context into a dict, and
    writes the instruction text followed by that dict as JSON to
    ``st.session_state.prompt``.
    """
    combined_data = {}
    prompt_text = "Given the provided data, compare the fields between the Modernised and Legacy tables. "

    if use_vector_embeddings and selected_embedding:
        # Collapse the chosen mapping's per-field similarity rows into one
        # "FIELD-candidates," string and return it as a JSON array (at most one
        # object, since the WHERE clause fixes the mapping name).
        # selected_embedding is a bind variable so a quote in a mapping name
        # cannot break the statement.
        vector_mappings = session.sql("""
        SELECT ARRAY_AGG(
            OBJECT_CONSTRUCT(*)) AS JSON_DATA FROM
            (
        SELECT VECTOR_MAPPING_FOR_TABLE_NAME,
         LISTAGG(CONCAT(MODERNISED_TABLE_FIELD_NAME,'-',TOP_SIMILARITIES_LEGACY_FIELDS,',')) within group
            (ORDER BY MODERNISED_TABLE_FIELD_NAME ) AS Modern_table_field_to_legacy_table_mappings
        FROM  SMART_AI_MAPPER.SMART_AI_MAPPER_TOOL.TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB
        WHERE VECTOR_MAPPING_FOR_TABLE_NAME = ?
        GROUP BY 1)
        """, params=[selected_embedding]).collect()[0]['JSON_DATA']
        combined_data["vector_mappings"] = vector_mappings
        prompt_text += "Use the vector embeddings to help map fields between the Modernised table and the Legacy table. "

    if use_sample_data:
        # Sample rows for each table come from the fetch_and_process_table_data
        # stored procedure (defined outside this notebook).
        modern_result = session.call('fetch_and_process_table_data',
                                     database_name,
                                     schema_name,
                                     modern_table_name,
                                     record_limit)

        legacy_result = session.call('fetch_and_process_table_data',
                                     database_name,
                                     schema_name,
                                     legacy_table_name,
                                     record_limit)

        combined_data["modern_data"] = modern_result
        combined_data["legacy_data"] = legacy_result
        prompt_text += "Analyze the sample data to identify discrepancies and provide a mapping between the two tables. Highlight any differences in field names or data values. "

    # "{data}" is the only placeholder; braces inside the JSON payload are safe
    # because they are substituted as an argument, not part of the template.
    prompt_text += """
    For fields with mismatches, suggest corrections and as per mapping identified, provide a Snowflake compatible SQL query to transform data from Modernised table to align with the Legacy Table. 
    Ensure that all data matches correctly between Modernised and Legacy, Report if any discrepancies and apply if transformations required in the snowflake 
    SQL generated.
    If there are fields in one table that do not have direct matches in the other table, note these discrepancies and indicate how to handle them.
    
    Data: {data}
    """

    st.session_state.prompt = prompt_text.format(data=json.dumps(combined_data))

def get_ai_response(options):
    """Send the stored prompt to Snowflake Cortex COMPLETE and keep the reply.

    Args:
        options: Mapping with ``temperature``, ``max_tokens`` and ``top_p``
            values taken from the "Model Parameters" widgets.

    On success ``st.session_state.ai_response`` holds the JSON object returned
    by ``SNOWFLAKE.CORTEX.COMPLETE``. An empty result shows an error; an
    exception shows the error and sets ``ai_response`` to an error string.
    """
    try:
        # The prompt travels as a bind variable (?) rather than being spliced
        # into a single-quoted SQL literal. Snowflake treats backslash as an
        # escape character inside such literals, so doubling only ' would still
        # let the JSON escapes inside the prompt (\", \\, \n) be rewritten
        # before the model sees them.
        # The numeric options go through float()/int() so only numbers are
        # formatted into the SQL text; the model name comes from a fixed
        # selectbox list and its quotes are doubled defensively.
        model_literal = model.replace("'", "''")
        query = f"""
        SELECT SNOWFLAKE.CORTEX.COMPLETE(
            '{model_literal}',
            ARRAY_CONSTRUCT(OBJECT_CONSTRUCT('role', 'user', 'content', ?)),
            OBJECT_CONSTRUCT(
                'temperature', {float(options['temperature'])},
                'max_tokens', {int(options['max_tokens'])},
                'top_p', {float(options['top_p'])}
            )
        ) AS response
        """
        result = session.sql(query, params=[st.session_state.prompt]).collect()

        if result:
            st.session_state.ai_response = json.loads(result[0]['RESPONSE'])
        else:
            st.error("No response received from the AI model.")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.session_state.ai_response = "Error: Unable to get AI response."

if st.button("Generate Prompt"):
    generate_prompt()

    # Count the prompt's tokens for the selected model. The prompt is bound
    # (?) instead of being spliced into a quoted literal so Snowflake does not
    # reinterpret the backslash escapes inside its JSON payload; the model
    # name comes from the fixed selectbox list and its quotes are doubled.
    try:
        model_literal = model.replace("'", "''")
        token_count_result = session.sql(
            f"SELECT SNOWFLAKE.CORTEX.COUNT_TOKENS('{model_literal}', ?) AS token_count",
            params=[st.session_state.prompt],
        ).collect()

        if token_count_result:
            st.session_state.token_count = token_count_result[0]['TOKEN_COUNT']
        else:
            st.warning("Unable to calculate token count.")
    except Exception as e:
        st.warning(f"Error calculating token count: {str(e)}")

if st.session_state.prompt:
    st.subheader("Generated Prompt:")
    st.text_area("Prompt", st.session_state.prompt, height=300)
    st.write(f"Token Count: {st.session_state.token_count}")

    st.subheader("Model Parameters")
    temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1)
    # Default max_tokens to the prompt's token count, clamped into the
    # widget's [1, 4096] range. token_count stays 0 when COUNT_TOKENS fails
    # (or was never run), and st.number_input raises when the default value
    # is below min_value, so the lower bound must be enforced as well.
    max_tokens_default = max(1, min(int(st.session_state.token_count or 0), 4096))
    max_tokens = st.number_input("Max Tokens", 1, 4096, max_tokens_default)
    top_p = st.slider("Top P", 0.0, 1.0, 1.0, 0.1)

    if st.button("Get AI Response"):
        options = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }
        get_ai_response(options)

if st.session_state.ai_response:
    reply = st.session_state.ai_response
    st.subheader("AI Model Response:")
    if not isinstance(reply, dict):
        # Plain strings (e.g. the error text stored by get_ai_response) are shown as-is.
        st.write(reply)
    else:
        # Generated text lives under choices[0]["messages"] in the COMPLETE reply.
        choices = reply['choices'] if 'choices' in reply else []
        if len(choices) > 0:
            st.markdown(choices[0].get('messages', ''))

        if 'usage' in reply:
            usage = reply['usage']
            st.write("Token Usage:")
            for label, usage_key in (
                ("Completion tokens", "completion_tokens"),
                ("Prompt tokens", "prompt_tokens"),
                ("Total tokens", "total_tokens"),
            ):
                st.write(f"- {label}: {usage.get(usage_key, 'N/A')}")

        if 'model' in reply:
            st.write(f"Model used: {reply['model']}")


Sample report generated by llama3.1-405b with both vector embeddings and sample data provided.

AI Model Response:
After analyzing the provided data, I've identified the following field mappings between the Modernised and Legacy tables:

Matching Fields:

ACCOUNT_BALANCE (Modernised) → ACCT_BALANCE (Legacy)
ACCOUNT_NUMBER (Modernised) → ACCT_NUM (Legacy)
ACCOUNT_STATUS (Modernised) → ACCT_STATUS (Legacy)
ACCOUNT_TYPE (Modernised) → ACCT_TYPE (Legacy)
ADDRESS (Modernised) → ADDR (Legacy)
BRANCH_ID (Modernised) → BRANCH_ID (Legacy)
CITY (Modernised) → CITY (Legacy)
CREATED_TIMESTAMP (Modernised) → CREATION_TIMESTAMP (Legacy)
CUSTOMER_ID (Modernised) → CUST_ID (Legacy)
DATE_OF_BIRTH (Modernised) → DOB (Legacy)
EMAIL (Modernised) → EMAIL_ADDRESS (Legacy)
FIRST_NAME (Modernised) → F_NAME (Legacy)
INTEREST_RATE (Modernised) → INTEREST (Legacy)
KYC_STATUS (Modernised) → KYC (Legacy)
LAST_NAME (Modernised) → L_NAME (Legacy)
LAST_TRANSACTION_DATE (Modernised) → LAST_TRANS_DATE (Legacy)
MARKETING_OPT_IN (Modernised) → MARKETING_CONSENT (Legacy)
PHONE_NUMBER (Modernised) → PHONE (Legacy)
RISK_LEVEL (Modernised) → RISK (Legacy)
SSN (Modernised) → SOCIAL_SECURITY_NUMBER (Legacy)
STATE (Modernised) → STATE (Legacy)
ZIP_CODE (Modernised) → POSTAL_CODE (Legacy)
Non-Matching Fields:

EMPLOYMENT_STATUS (Legacy) - No equivalent field in Modernised table.
MIDDLE_NAME (Legacy) - No equivalent field in Modernised table.
NATIONALITY (Legacy) - No equivalent field in Modernised table.
MODIFICATION_TIMESTAMP (Legacy) - While there is an UPDATED_TIMESTAMP field in the Modernised table, the values do not match.
OPEN_DATE (Legacy) - While there is a DATE_OPENED field in the Modernised table, the values do not match.
Transformations:

To align the Modernised table with the Legacy table, the following transformations are necessary:

Rename ACCOUNT_BALANCE to ACCT_BALANCE
Rename ACCOUNT_NUMBER to ACCT_NUM
Rename ACCOUNT_STATUS to ACCT_STATUS
Rename ACCOUNT_TYPE to ACCT_TYPE
Rename ADDRESS to ADDR
Rename CREATED_TIMESTAMP to CREATION_TIMESTAMP
Rename CUSTOMER_ID to CUST_ID
Rename DATE_OF_BIRTH to DOB
Rename EMAIL to EMAIL_ADDRESS
Rename FIRST_NAME to F_NAME
Rename INTEREST_RATE to INTEREST
Rename KYC_STATUS to KYC
Rename LAST_NAME to L_NAME
Rename LAST_TRANSACTION_DATE to LAST_TRANS_DATE
Rename MARKETING_OPT_IN to MARKETING_CONSENT
Rename PHONE_NUMBER to PHONE
Rename RISK_LEVEL to RISK
Rename SSN to SOCIAL_SECURITY_NUMBER
Rename STATE to STATE (no change)
Rename ZIP_CODE to POSTAL_CODE
Here is a sample Snowflake-compatible SQL query that performs these transformations:

SELECT 
  ACCOUNT_BALANCE AS ACCT_BALANCE,
  ACCOUNT_NUMBER AS ACCT_NUM,
  ACCOUNT_STATUS AS ACCT_STATUS,
  ACCOUNT_TYPE AS ACCT_TYPE,
  ADDRESS AS ADDR,
  BRANCH_ID,
  CITY,
  CREATED_TIMESTAMP AS CREATION_TIMESTAMP,
  CUSTOMER_ID AS CUST_ID,
  DATE_OF_BIRTH AS DOB,
  EMAIL AS EMAIL_ADDRESS,
  FIRST_NAME AS F_NAME,
  INTEREST_RATE AS INTEREST,
  KYC_STATUS AS KYC,
  LAST_NAME AS L_NAME,
  LAST_TRANSACTION_DATE AS LAST_TRANS_DATE,
  MARKETING_OPT_IN AS MARKETING_CONSENT,
  PHONE_NUMBER AS PHONE,
  RISK_LEVEL AS RISK,
  SSN AS SOCIAL_SECURITY_NUMBER,
  STATE,
  ZIP_CODE AS POSTAL_CODE
FROM 
  MODERNISED_TABLE;

Note that this query assumes that the Modernised table is named MODERNISED_TABLE. You should replace this with the actual table name in your database. Additionally, this query does not handle the non-matching fields (EMPLOYMENT_STATUS, MIDDLE_NAME, NATIONALITY, MODIFICATION_TIMESTAMP, and OPEN_DATE). You will need to decide how to handle these fields based on your specific use case.

Token Usage:

Completion tokens: 1131
Prompt tokens: 4173
Total tokens: 5304
Model used: llama3.1-405b

Model response for a different dataset (containing an ARRAY-typed column), generated with sample data only.

Based on the provided data, I've identified the following discrepancies and mapping between the Modernised and Legacy tables:

Field name mismatches:
CREATED_AT (Modernised) vs. CRT_DT (Legacy)
CUSTOMER_SINCE (Modernised) vs. CUST_SINCE (Legacy)
DATE_OF_BIRTH (Modernised) vs. DOB (Legacy)
ANNUAL_INCOME (Modernised) vs. ANN_INC (Legacy)
EDUCATION_LEVEL (Modernised) vs. EDU_LVL (Legacy)
EMPLOYMENT_STATUS (Modernised) vs. EMP_STS (Legacy)
GENDER (Modernised) vs. GNDR (Legacy)
HOUSEHOLD_SIZE (Modernised) vs. HHLD_SZ (Legacy)
HOME_OWNERSHIP (Modernised) vs. HOME_OWN (Legacy)
LEGAL_NAME (Modernised) vs. LEGAL_NM (Legacy)
MARITAL_STATUS (Modernised) vs. MRTL_STS (Legacy)
NET_WORTH (Modernised) vs. NET_WRTH (Legacy)
OCCUPATION (Modernised) vs. OCCPTN (Legacy)
PARTY_TYPE (Modernised) vs. PARTY_TYP (Legacy)
UPDATED_AT (Modernised) vs. UPD_DT (Legacy)
Data type mismatches:
DEMOGRAPHICS (Modernised) is an array of objects, while the Legacy table has separate fields for each demographic attribute.
Data value mismatches:
The PARTY_ID values do not match between the two tables.
To transform the Modernised table to align with the Legacy table, I suggest the following Snowflake compatible SQL query:

SELECT 
    m.PARTY_ID,
    m.CREATED_AT AS CRT_DT,
    m.CUSTOMER_SINCE AS CUST_SINCE,
    m.DATE_OF_BIRTH AS DOB,
    d.ANNUAL_INCOME AS ANN_INC,
    d.EDUCATION_LEVEL AS EDU_LVL,
    d.EMPLOYMENT_STATUS AS EMP_STS,
    d.GENDER AS GNDR,
    d.HOUSEHOLD_SIZE AS HHLD_SZ,
    d.HOME_OWNERSHIP AS HOME_OWN,
    m.LEGAL_NAME AS LEGAL_NM,
    d.MARITAL_STATUS AS MRTL_STS,
    d.NET_WORTH AS NET_WRTH,
    d.OCCUPATION AS OCCPTN,
    m.PARTY_TYPE AS PARTY_TYP,
    m.UPDATED_AT AS UPD_DT
FROM 
    modern_data.processed_data m,
    LATERAL FLATTEN(m.DEMOGRAPHICS) d

This query flattens the DEMOGRAPHICS array and joins it with the main table, mapping the fields to their corresponding Legacy table fields.

Note that the PARTY_ID values still do not match between the two tables. To resolve this, you may need to perform additional data cleansing or matching steps.

Additionally, the Legacy table has separate fields for each demographic attribute, while the Modernised table has an array of objects. If you want to maintain the same data structure in the transformed table, you can modify the query to use Snowflake's OBJECT_CONSTRUCT function to create separate columns for each demographic attribute. For example:

SELECT 
    m.PARTY_ID,
    m.CREATED_AT AS CRT_DT,
    m.CUSTOMER_SINCE AS CUST_SINCE,
    m.DATE_OF_BIRTH AS DOB,
    OBJECT_CONSTRUCT(
        'ANN_INC': d.ANNUAL_INCOME,
        'EDU_LVL': d.EDUCATION_LEVEL,
        'EMP_STS': d.EMPLOYMENT_STATUS,
        'GNDR': d.GENDER,
        'HHLD_SZ': d.HOUSEHOLD_SIZE,
        'HOME_OWN': d.HOME_OWNERSHIP,
        'MRTL_STS': d.MARITAL_STATUS,
        'NET_WRTH': d.NET_WORTH,
        'OCCPTN': d.OCCUPATION
    ) AS DEMOGRAPHICS,
    m.LEGAL_NAME AS LEGAL_NM,
    m.PARTY_TYPE AS PARTY_TYP,
    m.UPDATED_AT AS UPD_DT
FROM 
    modern_data.processed_data m,
    LATERAL FLATTEN(m.DEMOGRAPHICS) d

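This will create a DEMOGRAPHICS column with a JSON object containing the individual demographic attributes.

(Reviewer note: this is unedited model output, and neither query above is valid Snowflake SQL as written. Fields of a FLATTEN-ed element are read through its VALUE column, e.g. d.value:ANNUAL_INCOME rather than d.ANNUAL_INCOME. OBJECT_CONSTRUCT also takes comma-separated key/value arguments, e.g. OBJECT_CONSTRUCT('ANN_INC', d.value:ANNUAL_INCOME, ...), not 'key': value pairs.)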

Token Usage:

Completion tokens: 1012
Prompt tokens: 896
Total tokens: 1908
Model used: llama3.1-405b

In [None]:
-- Audit log written by the Streamlit mapping app: one row per AI request,
-- recording who asked, what was sent, what came back, timings and feedback.
-- NOTE: CREATE OR REPLACE drops the table (and every logged row) if it
-- already exists; run this cell only to (re)initialise the log.
CREATE OR REPLACE TABLE AI_MAPPING_TOOL_METADATA (
    id                      INTEGER AUTOINCREMENT,
    -- caller context
    user                    STRING,
    role                    STRING,
    -- request inputs
    database_name           STRING,
    schema_name             STRING,
    modernized_table_name   STRING,
    legacy_table_name       STRING,
    vector_embeddings_used  BOOLEAN,
    sample_data_used        BOOLEAN,
    input_prompt            TEXT,
    -- user feedback, filled in afterwards by an UPDATE from the app
    user_liked              BOOLEAN,
    user_feedback           TEXT,
    -- token accounting
    input_token_length      INTEGER,
    output_token_length     INTEGER,
    total_token_length      INTEGER,
    model_name              VARCHAR,
    errors                  TEXT,
    output_response         VARCHAR,
    -- timing: the app writes start/end as whole epoch seconds (integers),
    -- not TIMESTAMP_NTZ; total_time is the elapsed time in seconds
    start_time              NUMBER(38,0),
    end_time                NUMBER(38,0),
    total_time              FLOAT,
    sample_input_limit      INTEGER,
    created_at              TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    updated_at              TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
);

In [None]:
import streamlit as st
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import parse_json, to_timestamp, current_timestamp, lit, col, when
from snowflake.snowpark.types import StringType, BooleanType, IntegerType, FloatType, TimestampType, VariantType
import json
import traceback
from snowflake.snowpark.functions import current_timestamp, col, when
from datetime import datetime, date
import json

# Snowpark session supplied by the Streamlit-in-Snowflake runtime.
session = get_active_session()

st.title("Data Comparison and Mapping Tool")

# Per-user state that must survive Streamlit reruns; each key is seeded only
# when it is not already present.
_SESSION_DEFAULTS = (
    ("prompt", ""),
    ("ai_response", ""),
    ("token_count", 0),
    ("start_time", None),
    ("end_time", None),
    ("errors", []),
    ("metadata_id", None),
)
for _key, _default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Location of the two tables to compare (defaults pre-filled for convenience).
database_name = st.text_input("Enter Database Name", "SMART_AI_MAPPER")
schema_name = st.text_input("Enter Schema Name", "SMART_AI_MAPPER_TOOL")
modern_table_name = st.text_input("Enter Modernised Table Name", "Deposit_Accounts_Modernised")
legacy_table_name = st.text_input("Enter Legacy Table Name", "Deposit_Accounts_Legacy")
record_limit = st.number_input("Enter Sample Record Limit", min_value=1, value=5)

# Make the chosen database/schema current, so the unqualified object names
# used further down resolve against them.
session.use_database(database_name)
session.use_schema(schema_name)

# Cortex models offered in the dropdown, in display order.
_CORTEX_MODELS = [
    "snowflake-arctic", "mistral-large", "reka-flash", "reka-core", "mixtral-8x7b",
    "jamba-instruct", "llama2-70b-chat", "llama3-8b", "llama3-70b", "llama3.1-8b",
    "llama3.1-70b", "llama3.1-405b", "mistral-7b", "gemma-7b",
]
model = st.selectbox("Select AI Model", _CORTEX_MODELS)

# Which optional context to include in the prompt (rendered in this order).
use_vector_embeddings, use_sample_data = (
    st.checkbox("Use Vector Embeddings"),
    st.checkbox("Use Sample Data"),
)

def fetch_available_embeddings(modern_table, legacy_table):
    """Return the vector-mapping names that mention both tables.

    Looks up distinct ``VECTOR_MAPPING_FOR_TABLE_NAME`` values in
    ``TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB`` (resolved against the session's
    current database/schema) whose upper-cased name contains both table
    names as substrings.

    Args:
        modern_table: Modernised table name as typed by the user.
        legacy_table: Legacy table name as typed by the user.

    Returns:
        list[str]: Matching mapping names; empty when nothing matches.
    """
    # The table names come straight from text inputs, so they are passed as
    # bind variables instead of being formatted into the SQL text: a quote in
    # a name can no longer break or alter the statement. CONTAINS also treats
    # "_" and "%" literally, whereas LIKE reads them as wildcards (names such
    # as "Deposit_Accounts" contain "_").
    query = """
    SELECT VECTOR_MAPPING_FOR_TABLE_NAME
    FROM TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB
    WHERE CONTAINS(UPPER(VECTOR_MAPPING_FOR_TABLE_NAME), ?)
      AND CONTAINS(UPPER(VECTOR_MAPPING_FOR_TABLE_NAME), ?)
    GROUP BY 1
    """
    rows = session.sql(query, params=[modern_table.upper(), legacy_table.upper()]).collect()
    return [row['VECTOR_MAPPING_FOR_TABLE_NAME'] for row in rows]

# Offer an embedding dropdown only when at least one mapping matches both tables.
available_embeddings = fetch_available_embeddings(modern_table_name, legacy_table_name)

selected_embedding = (
    st.selectbox("Select Vector Embedding", available_embeddings)
    if available_embeddings
    else None
)
if not available_embeddings:
    st.warning("No matching vector embeddings found for the given table names.")

def generate_prompt():
    """Build the LLM prompt from the UI selections and store it in session state.

    Reads the module-level widget values (``use_vector_embeddings``,
    ``selected_embedding``, ``use_sample_data``, database/schema/table names
    and ``record_limit``), gathers the requested context into a dict, and
    writes the instruction text followed by that dict as JSON (with double
    quotes backslash-escaped) to ``st.session_state.prompt``.
    """
    combined_data = {}
    prompt_text = "Given the provided data, compare the fields between the Modernised and Legacy tables. "

    if use_vector_embeddings and selected_embedding:
        # Collapse the chosen mapping's per-field similarity rows into one
        # "FIELD-candidates," string and return it as a JSON array (at most one
        # object, since the WHERE clause fixes the mapping name).
        # selected_embedding is a bind variable so a quote in a mapping name
        # cannot break the statement.
        vector_mappings = session.sql("""
        SELECT ARRAY_AGG(
            OBJECT_CONSTRUCT(*)) AS JSON_DATA FROM
            (
        SELECT VECTOR_MAPPING_FOR_TABLE_NAME,
         LISTAGG(CONCAT(MODERNISED_TABLE_FIELD_NAME,'-',TOP_SIMILARITIES_LEGACY_FIELDS,',')) within group
            (ORDER BY MODERNISED_TABLE_FIELD_NAME ) AS Modern_table_field_to_legacy_table_mappings
        FROM TOP_3_SIMILAR_FIELDS_FROM_VEC_EMB
        WHERE VECTOR_MAPPING_FOR_TABLE_NAME = ?
        GROUP BY 1)
        """, params=[selected_embedding]).collect()[0]['JSON_DATA']
        combined_data["vector_mappings"] = vector_mappings
        prompt_text += "Use the vector embeddings to help map fields between the Modernised table and the Legacy table. "

    if use_sample_data:
        # Sample rows for each table come from the fetch_and_process_table_data
        # stored procedure (defined outside this notebook).
        modern_result = session.call('fetch_and_process_table_data',
                                     database_name,
                                     schema_name,
                                     modern_table_name,
                                     record_limit)

        legacy_result = session.call('fetch_and_process_table_data',
                                     database_name,
                                     schema_name,
                                     legacy_table_name,
                                     record_limit)

        combined_data["modern_data"] = modern_result
        combined_data["legacy_data"] = legacy_result
        prompt_text += "Analyze the sample data to identify discrepancies and provide a mapping between the two tables. Highlight any differences in field names or data values. "

    prompt_text += """
    For fields with mismatches, suggest corrections and as per mapping identified, provide a Snowflake compatible SQL query to transform data from Modernised table to align with the Legacy Table. 
    Ensure that all data matches correctly between Modernised and Legacy, Report if any discrepancies and apply if transformations required in the snowflake 
    SQL generated.
    If there are fields in one table that do not have direct matches in the other table, note these discrepancies and indicate how to handle them.
    
    Data: {data}
    """

    # NOTE(review): double quotes are backslash-escaped, presumably because
    # the COUNT_TOKENS / CORTEX.COMPLETE calls splice this prompt into a
    # single-quoted SQL literal, where Snowflake reads \" back as ". The
    # escaped form is also what the prompt box shows and what gets logged;
    # changing this must be done together with those SQL calls.
    escaped_data = json.dumps(combined_data).replace('"', '\\"')
    st.session_state.prompt = prompt_text.format(data=escaped_data)


def insert_metadata():
    """Append one audit row for the latest AI call to AI_MAPPING_TOOL_METADATA.

    Collects the caller's user/role, the current widget values and the request
    details held in ``st.session_state`` (prompt, parsed AI response, token
    count, start/end times, accumulated errors). Writes them as a single row,
    then stores the row's ``id`` in ``st.session_state.metadata_id`` so that
    ``update_metadata`` can attach feedback to it. Write failures are reported
    in the UI rather than raised.
    """
    current_user = session.sql("SELECT CURRENT_USER()").collect()[0][0]
    current_role = session.sql("SELECT CURRENT_ROLE()").collect()[0][0]

    # The response is expected to be the dict parsed from CORTEX.COMPLETE; an
    # error string (or anything else) is treated as an empty response.
    ai_response = st.session_state.ai_response
    response = ai_response if isinstance(ai_response, dict) else {}

    # Store only the generated text (the first choice's "messages"). Missing
    # keys yield None instead of a KeyError, so an unexpected response shape
    # does not prevent the request from being logged.
    first_choice = (response.get("choices") or [{}])[0]
    messages = json.dumps({"messages": first_choice.get("messages")})
    usage = response.get("usage") or {}

    start_time = st.session_state.start_time
    end_time = st.session_state.end_time
    total_time = (end_time - start_time).total_seconds() if start_time and end_time else None

    # Keep the logged error text to at most 1000 characters.
    errors = ", ".join(st.session_state.errors) if st.session_state.errors else None
    if errors and len(errors) > 1000:
        errors = errors[:997] + "..."

    # id, user_liked, user_feedback, created_at and updated_at are left out so
    # the table fills them itself (AUTOINCREMENT, NULL, CURRENT_TIMESTAMP()).
    # Sending None for id would store an explicit NULL (column defaults only
    # apply to omitted columns), and the row could then never be found by id.
    metadata_df = session.create_dataframe([
        (current_user, current_role, database_name, schema_name, modern_table_name, legacy_table_name,
         use_vector_embeddings, use_sample_data, st.session_state.prompt, messages,
         st.session_state.token_count,
         usage.get('completion_tokens'),
         usage.get('total_tokens'),
         int(start_time.timestamp()) if start_time else None,
         int(end_time.timestamp()) if end_time else None,
         total_time, model, record_limit, errors)
    ], schema=[
        "user", "role", "database_name", "schema_name", "modernized_table_name", "legacy_table_name",
        "vector_embeddings_used", "sample_data_used", "input_prompt", "output_response",
        "input_token_length", "output_token_length", "total_token_length", "start_time",
        "end_time", "total_time", "model_name", "sample_input_limit", "errors",
    ])

    # Give columns whose value may be None an explicit type matching the table.
    for column_name, sql_type in (
        ("output_response", "string"),
        ("output_token_length", "number(38,0)"),
        ("total_token_length", "number(38,0)"),
        ("start_time", "number(38,0)"),
        ("end_time", "number(38,0)"),
        ("total_time", "float"),
        ("errors", "string"),
    ):
        metadata_df = metadata_df.with_column(column_name, col(column_name).cast(sql_type))

    try:
        # Match DataFrame columns to table columns by name: the two orders
        # differ and several table columns are deliberately omitted.
        metadata_df.write.save_as_table("AI_MAPPING_TOOL_METADATA", mode="append", column_order="name")

        # NOTE(review): MAX(id) can return another session's row if two users
        # log a request at the same moment.
        last_id_rows = session.sql("SELECT MAX(id) AS last_id FROM AI_MAPPING_TOOL_METADATA").collect()
        st.session_state.metadata_id = last_id_rows[0]['LAST_ID']

        st.success("Metadata inserted successfully!")
    except Exception as e:
        st.error(f"Error inserting metadata: {str(e)}")

# Debug trace (to the server log) of the timing values held in session state:
# both types first, then both values.
_timing = {
    "start_time": st.session_state.start_time,
    "end_time": st.session_state.end_time,
}
for _name, _value in _timing.items():
    print(f"{_name} type: {type(_value)}")
for _name, _value in _timing.items():
    print(f"{_name} value: {_value}")

def update_metadata(user_liked, user_feedback):
    """Attach the user's rating and comment to the last logged AI call.

    Args:
        user_liked: Rating to store in the ``user_liked`` column.
        user_feedback: Free-text comment to store in ``user_feedback``.

    Updates the AI_MAPPING_TOOL_METADATA row whose id is
    ``st.session_state.metadata_id`` (also bumping ``updated_at``) and reports
    success or failure in the UI.
    """
    if st.session_state.metadata_id is None:
        # Nothing has been logged yet (or the insert failed), so there is no
        # row to update; "WHERE id = NULL" would silently match nothing.
        st.error("Error updating metadata: no logged AI response to attach feedback to.")
        return

    update_query = """
    UPDATE AI_MAPPING_TOOL_METADATA
    SET user_liked = ?,
        user_feedback = ?,
        updated_at = CURRENT_TIMESTAMP()
    WHERE id = ?
    """
    try:
        # A DataFrame from session.sql() has no .update() method; the values
        # are bound via params and collect() executes the UPDATE.
        session.sql(
            update_query,
            params=[user_liked, user_feedback, st.session_state.metadata_id],
        ).collect()
        st.success("Feedback submitted successfully!")
    except Exception as e:
        st.error(f"Error updating metadata: {str(e)}")

def get_ai_response(options):
    """Send the current prompt to SNOWFLAKE.CORTEX.COMPLETE and store the reply.

    Args:
        options: Dict with ``"temperature"``, ``"max_tokens"`` and ``"top_p"``,
            forwarded as COMPLETE's options object.

    Side effects:
        Sets ``st.session_state.start_time``, ``ai_response`` and ``end_time``
        and, on success, calls ``insert_metadata()``. On failure it shows the
        error, stores a placeholder ``ai_response`` and appends the message to
        ``st.session_state.errors``.
    """
    import traceback  # local import keeps this function self-contained

    def _sql_literal(text):
        # Snowflake single-quoted literals treat backslash as an escape
        # character: double backslashes first, then double single quotes.
        return text.replace("\\", "\\\\").replace("'", "''")

    try:
        st.session_state.start_time = datetime.now()
        prompt_content = _sql_literal(st.session_state.prompt)
        model_name = _sql_literal(model)

        # Call the CORTEX.COMPLETE function
        query = f"""
        SELECT SNOWFLAKE.CORTEX.COMPLETE(
            '{model_name}',
            ARRAY_CONSTRUCT(OBJECT_CONSTRUCT('role', 'user', 'content', '{prompt_content}')),
            OBJECT_CONSTRUCT(
                'temperature', {float(options['temperature'])},
                'max_tokens', {int(options['max_tokens'])},
                'top_p', {float(options['top_p'])}
            )
        ) AS response
        """
        result = session.sql(query).collect()

        if result:
            # With an options argument COMPLETE returns a JSON document as text.
            st.session_state.ai_response = json.loads(result[0]['RESPONSE'])
            st.session_state.end_time = datetime.now()

            print("LLM Response done")
            insert_metadata()
        else:
            st.error("No response received from the AI model.")
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        st.error(error_msg)
        st.session_state.ai_response = "Error: Unable to get AI response."
        # Don't let a missing "errors" list turn error reporting into a second failure.
        st.session_state.setdefault("errors", []).append(error_msg)

if st.button("Generate Prompt"):
    generate_prompt()

    # Calculate token count of the generated prompt for the selected model.
    try:
        # Snowflake single-quoted literals treat backslash as an escape
        # character, so backslashes are doubled before single quotes are.
        escaped_model = model.replace("\\", "\\\\").replace("'", "''")
        escaped_prompt = st.session_state.prompt.replace("\\", "\\\\").replace("'", "''")
        token_count_result = session.sql(f"""
        SELECT SNOWFLAKE.CORTEX.COUNT_TOKENS(
            '{escaped_model}',
            '{escaped_prompt}'
        ) AS token_count
        """).collect()

        if token_count_result:
            st.session_state.token_count = token_count_result[0]['TOKEN_COUNT']
        else:
            st.warning("Unable to calculate token count.")
    except Exception as e:
        st.warning(f"Error calculating token count: {str(e)}")

if st.session_state.prompt:
    st.subheader("Generated Prompt:")
    st.text_area("Prompt", st.session_state.prompt, height=300)
    st.write(f"Token Count: {st.session_state.token_count}")

    st.subheader("Model Parameters")
    temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1)
    # Default Max Tokens to the prompt's token count, clamped into the widget's
    # [1, 4096] range. If COUNT_TOKENS failed, token_count keeps whatever it held
    # before (possibly None/0 — an initial value is not visible here), in which
    # case fall back to the upper bound instead of crashing min()/number_input.
    token_count = st.session_state.token_count
    default_max_tokens = 4096 if not token_count else min(max(int(token_count), 1), 4096)
    max_tokens = st.number_input("Max Tokens", 1, 4096, default_max_tokens)
    top_p = st.slider("Top P", 0.0, 1.0, 1.0, 0.1)

    if st.button("Get AI Response"):
        options = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
        }
        get_ai_response(options)

if st.session_state.ai_response:
    st.subheader("AI Model Response:")
    if isinstance(st.session_state.ai_response, dict):
        if 'choices' in st.session_state.ai_response and len(st.session_state.ai_response['choices']) > 0:
            message = st.session_state.ai_response['choices'][0].get('message', {}).get('content', '')
            print("before LLM Markdown")
            st.markdown(message)

        if 'usage' in st.session_state.ai_response:
            usage = st.session_state.ai_response['usage']
            st.write("Token Usage:")
            st.write(f"- Completion tokens: {usage.get('completion_tokens', 'N/A')}")
            st.write(f"- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}")
            st.write(f"- Total tokens: {usage.get('total_tokens', 'N/A')}")

        if 'model' in st.session_state.ai_response:
            st.write(f"Model used: {st.session_state.ai_response['model']}")
    else:
        # Non-dict responses (e.g. the error placeholder string) are shown verbatim.
        st.write(st.session_state.ai_response)

    # User feedback.
    # st.button is True only on the rerun right after the click, while the text
    # area keeps its value across reruns. The rating is therefore remembered in
    # session state, so later edits to the feedback text do not overwrite an
    # earlier like/dislike with NULL. The remembered state is reset whenever
    # st.session_state.metadata_id changes (a different row is being rated).
    current_metadata_id = st.session_state.get("metadata_id")
    if ("feedback_metadata_id" not in st.session_state
            or st.session_state.feedback_metadata_id != current_metadata_id):
        st.session_state.feedback_metadata_id = current_metadata_id
        st.session_state.feedback_user_liked = None
        st.session_state.feedback_submitted_text = ""

    col1, col2 = st.columns(2)
    with col1:
        liked = st.button("👍 Like")
    with col2:
        disliked = st.button("👎 Dislike")

    feedback = st.text_area("Additional Feedback (optional)")

    if liked:
        st.session_state.feedback_user_liked = True
    elif disliked:
        st.session_state.feedback_user_liked = False

    # Write only when something changed: a rating button was clicked or the text
    # differs from what was last submitted. Previously any non-empty feedback
    # re-issued the UPDATE (with user_liked=None) on every rerun.
    feedback_changed = feedback != st.session_state.feedback_submitted_text
    if liked or disliked or feedback_changed:
        try:
            update_metadata(st.session_state.feedback_user_liked, feedback)
            st.session_state.feedback_submitted_text = feedback
        except Exception as e:
            import traceback  # local import; module-level import not visible here
            error_msg = f"Error submitting feedback: {str(e)}\n{traceback.format_exc()}"
            st.error(error_msg)
            st.session_state.setdefault("errors", []).append(error_msg)

In [None]:
-- Inspect every logged mapping run (token usage, timings, user feedback).
SELECT m.*
FROM AI_MAPPING_TOOL_METADATA AS m
/*DELETE FROM AI_MAPPING_TOOL_METADATA;
INSERT INTO AI_MAPPING_TOOL_METADATA (
    user,
    role,
    database_name,
    schema_name,
    modernized_table_name,
    legacy_table_name,
    vector_embeddings_used,
    sample_data_used,
    input_prompt,
    output_response,
    user_liked,
    user_feedback,
    input_token_length,
    output_token_length,
    total_token_length,
    start_time,
    end_time,
    total_time,
    model_name,
    sample_input_limit,
    errors
)
SELECT
    'JOHN_DOE',
    'DATA_ANALYST',
    'SMART_AI_MAPPER',
    'SMART_AI_MAPPER_TOOL',
    'Deposit_Accounts_Modernised',
    'Deposit_Accounts_Legacy',
    TRUE,
    TRUE,
    'Compare the fields between the Modernised and Legacy Deposit_Accounts tables. Analyze the sample data to identify discrepancies and provide a mapping between the two tables.',
    OBJECT_CONSTRUCT('choices', ARRAY_CONSTRUCT(OBJECT_CONSTRUCT('message', OBJECT_CONSTRUCT('content', 'Based on the analysis of the Modernised and Legacy Deposit_Accounts tables, here are the mappings and discrepancies identified: 1. Customer_ID (Modernised) -> Cust_ID (Legacy) 2. Account_Number (Modernised) -> Acct_Num (Legacy) 3. Account_Type (Modernised) -> Acct_Type (Legacy) 4. Balance (Modernised) -> Acct_Balance (Legacy) 5. Open_Date (Modernised) -> Opening_Date (Legacy) Discrepancies: 1. The Modernised table uses Customer_ID while the Legacy table uses Cust_ID for the unique identifier. 2. Account numbers are stored in Account_Number in the Modernised table and Acct_Num in the Legacy table. 3. The Balance field is named differently in both tables. 4. Date formats may differ between Open_Date and Opening_Date. Snowflake SQL to transform Modernised data to Legacy format: SELECT Customer_ID AS Cust_ID, Account_Number AS Acct_Num, Account_Type AS Acct_Type, Balance AS Acct_Balance, Open_Date AS Opening_Date FROM Deposit_Accounts_Modernised; Note: Ensure that the data types are compatible between the two tables, especially for the date and numeric fields. You may need to add appropriate CAST functions if there are any data type mismatches.'))), 'usage', OBJECT_CONSTRUCT('completion_tokens', 231, 'prompt_tokens', 124, 'total_tokens', 355)),
    TRUE,
    'The mapping provided is accurate and helpful.',
    124,
    231,
    355,
    '2023-08-06 10:15:30'::TIMESTAMP_NTZ,
    '2023-08-06 10:15:35'::TIMESTAMP_NTZ,
    5.0,
    'snowflake-arctic',
    5,
    NULL
UNION ALL
SELECT
    'JANE_SMITH',
    'DATA_ENGINEER',
    'SMART_AI_MAPPER',
    'SMART_AI_MAPPER_TOOL',
    'Customer_Info_Modernised',
    'Customer_Info_Legacy',
    FALSE,
    TRUE,
    'Provide a mapping between the Modernised and Legacy Customer_Info tables. Highlight any differences in field names or data values.',
    OBJECT_CONSTRUCT('choices', ARRAY_CONSTRUCT(OBJECT_CONSTRUCT('message', OBJECT_CONSTRUCT('content', 'Based on the sample data provided for the Modernised and Legacy Customer_Info tables, here is the mapping and analysis: 1. CustomerID (Modernised) -> Cust_ID (Legacy) 2. FirstName (Modernised) -> First_Name (Legacy) 3. LastName (Modernised) -> Last_Name (Legacy) 4. Email (Modernised) -> EmailAddress (Legacy) 5. PhoneNumber (Modernised) -> Phone (Legacy) 6. DateOfBirth (Modernised) -> DOB (Legacy) Differences and Discrepancies: 1. The Modernised table uses camelCase for field names, while the Legacy table uses underscores. 2. The email field is named differently in both tables. 3. Phone number field has a different name in each table. 4. Date of Birth is abbreviated in the Legacy table. Snowflake SQL to transform Modernised data to Legacy format: SELECT CustomerID AS Cust_ID, FirstName AS First_Name, LastName AS Last_Name, Email AS EmailAddress, PhoneNumber AS Phone, DateOfBirth AS DOB FROM Customer_Info_Modernised; Note: Ensure that the data types are consistent between the two tables, especially for the Date of Birth field. You may need to use appropriate date formatting functions if the date formats differ between the tables.'))), 'usage', OBJECT_CONSTRUCT('completion_tokens', 246, 'prompt_tokens', 98, 'total_tokens', 344)),
    TRUE,
    'The mapping is correct and the SQL transformation is useful.',
    98,
    246,
    344,
    '2023-08-07 14:30:00'::TIMESTAMP_NTZ,
    '2023-08-07 14:30:07'::TIMESTAMP_NTZ,
    7.0,
    'mistral-large',
    10,
    NULL
UNION ALL
SELECT
    'BOB_JOHNSON',
    'DATA_SCIENTIST',
    'SMART_AI_MAPPER',
    'SMART_AI_MAPPER_TOOL',
    'Transaction_Modernised',
    'Transaction_Legacy',
    TRUE,
    TRUE,
    'Compare the Transaction_Modernised and Transaction_Legacy tables. Provide a mapping and suggest any necessary data transformations.',
    OBJECT_CONSTRUCT('error', 'API request failed due to rate limiting. Please try again later.'),
    FALSE,
    'The AI model failed to provide a response due to rate limiting.',
    115,
    NULL,
    NULL,
    '2023-08-08 09:45:00'::TIMESTAMP_NTZ,
    '2023-08-08 09:45:02'::TIMESTAMP_NTZ,
    2.0,
    'reka-flash',
    5,
    'API request failed due to rate limiting. Please try again later.';*/

In [None]:
-- Show all rows currently stored in the AI mapping tool metadata table.
select meta.* from AI_MAPPING_TOOL_METADATA meta

In [None]:
import json
from datetime import datetime, timedelta
from snowflake.snowpark.functions import col, lit, current_timestamp
import logging

# Configure the root logger so the INFO-level messages emitted by the
# metadata test helper below are actually shown.
logging.basicConfig(level="INFO")

def test_insert_metadata(session, user, role, database_name, schema_name, modern_table_name, legacy_table_name,
                         use_vector_embeddings, use_sample_data, model, record_limit, mock_ai_response):
    """Insert one mock row into AI_MAPPING_TOOL_METADATA and check it can be read back.

    Args:
        session: Active Snowpark session.
        user, role, database_name, schema_name, modern_table_name, legacy_table_name:
            Values written to the matching metadata columns.
        use_vector_embeddings, use_sample_data: Written to
            vector_embeddings_used / sample_data_used.
        model: Written to model_name.
        record_limit: Written to sample_input_limit.
        mock_ai_response: Dict stored JSON-encoded as output_response; must
            contain ``usage.completion_tokens`` and ``usage.total_tokens``.

    Returns:
        The highest ``id`` in the table after the insert (taken to be the new row).

    Raises:
        RuntimeError: if no id can be read back or no row carries that id.
        Exception: errors from the Snowpark write are logged and re-raised.
    """
    logging.info("Starting metadata insertion test")

    mock_start_time = datetime.now()
    mock_end_time = mock_start_time + timedelta(seconds=5)  # Simulating 5 seconds of processing time
    total_time = (mock_end_time - mock_start_time).total_seconds()
    usage = mock_ai_response['usage']

    # id, created_at, updated_at and the always-NULL columns (user_liked,
    # user_feedback, errors) are deliberately NOT supplied. With
    # column_order="name" below, Snowflake fills unlisted columns from their
    # defaults (NULL if none). So an auto-generated id — presumably what id is,
    # since the app reads it back via MAX(id) — is not forced to NULL.
    raw_df = session.create_dataframe([
        (user, role, database_name, schema_name, modern_table_name, legacy_table_name,
         use_vector_embeddings, use_sample_data, "Test prompt", json.dumps(mock_ai_response),
         150, usage['completion_tokens'], usage['total_tokens'],
         int(mock_start_time.timestamp()), int(mock_end_time.timestamp()),
         float(total_time), model, record_limit)
    ], schema=[
        "user", "role", "database_name", "schema_name", "modernized_table_name", "legacy_table_name",
        "vector_embeddings_used", "sample_data_used", "input_prompt", "output_response",
        "input_token_length", "output_token_length", "total_token_length",
        "start_time", "end_time", "total_time", "model_name", "sample_input_limit",
    ])

    # Cast in place: with_column replaces an existing column of the same name, so
    # the frame keeps exactly one column per target column. The previous
    # select(<casts>, "*") emitted every cast column twice, so the append failed.
    column_types = {
        "output_response": "string",
        "start_time": "number(38,0)",
        "end_time": "number(38,0)",
        "total_time": "float",
        "input_token_length": "integer",
        "output_token_length": "integer",
        "total_token_length": "integer",
        "sample_input_limit": "integer",
        "model_name": "string",
        "vector_embeddings_used": "boolean",
        "sample_data_used": "boolean",
    }
    metadata_df = raw_df
    for column_name, sql_type in column_types.items():
        metadata_df = metadata_df.with_column(column_name, col(column_name).cast(sql_type))
    metadata_df = metadata_df.with_column("created_at", current_timestamp())
    metadata_df = metadata_df.with_column("updated_at", current_timestamp())

    try:
        # Insert the data into the table, matching columns by name.
        metadata_df.write.save_as_table("AI_MAPPING_TOOL_METADATA", mode="append", column_order="name")

        # NOTE: MAX(id) assumes no concurrent insert between the write and this read.
        last_id_rows = session.sql("SELECT MAX(id) as last_id FROM AI_MAPPING_TOOL_METADATA").collect()
        last_inserted_id = last_id_rows[0]['LAST_ID']
        if last_inserted_id is None:
            raise RuntimeError(
                "Inserted metadata row has no id (is AI_MAPPING_TOOL_METADATA.id auto-generated?)"
            )

        logging.info("Test metadata inserted successfully! Last inserted ID: %s", last_inserted_id)

        # Verify the inserted data; raise rather than assert so the check
        # survives python -O.
        matching_rows = (
            session.table("AI_MAPPING_TOOL_METADATA")
            .filter(col("id") == last_inserted_id)
            .count()
        )
        if matching_rows != 1:
            raise RuntimeError("Inserted data not found")

        return last_inserted_id
    except Exception as e:
        # logging.exception keeps the traceback. Re-collecting the DataFrame here
        # (as before) issues another query that can fail and hide this error.
        logging.exception("Error inserting test metadata: %s", e)
        logging.debug("Test metadata columns: %s", metadata_df.columns)
        raise
    finally:
        logging.info("DataFrame Schema:")
        metadata_df.printSchema()

# Stored procedure handler
def main(session):
    """Run test_insert_metadata with fixed mock inputs and report the outcome.

    Args:
        session: Active Snowpark session supplied by the stored procedure.

    Returns:
        A status string: the inserted id on success, or the error message on
        failure. Failures are logged with their traceback, not raised.
    """
    # Mock CORTEX.COMPLETE response: test_insert_metadata reads the usage
    # counters and stores the whole dict as JSON.
    mock_ai_response = {
        "usage": {
            "completion_tokens": 100,
            "total_tokens": 250
        }
    }

    try:
        last_id = test_insert_metadata(
            session=session,
            user="test_user",
            role="test_role",
            database_name="test_db",
            schema_name="test_schema",
            modern_table_name="modern_table",
            legacy_table_name="legacy_table",
            use_vector_embeddings=True,
            use_sample_data=True,
            model="test_model",
            record_limit=1000,
            mock_ai_response=mock_ai_response
        )
    except Exception as e:
        # The caller only sees the message string; keep the traceback in the logs.
        logging.exception("Metadata insertion test failed")
        return f"Test failed: {str(e)}"

    return f"Test completed successfully. Last inserted ID: {last_id}"

# Handler entry point kept for the Snowflake stored procedure; it only delegates to main().
def run(session):
    """Execute the metadata insertion test and return main()'s status string."""
    status_message = main(session)
    return status_message