## Init Script

#### - Imports

In [1]:
#------------------------__Init__Import Packages-------------------------------------
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col, count, lit, avg, mean, length, min, median, max, stddev, when, mode, approx_percentile, datediff, dateadd, to_date, date_from_parts, unix_timestamp, from_unixtime, collect_set
import pandas as pd
import ipywidgets as widgets
from IPython.display import display
from snowflake.snowpark.types import StringType, FloatType, IntegerType, DateType, BooleanType, DecimalType, VariantType, StructType, StructField, LongType
import datetime
from datetime import datetime
import json
from time import time
import numpy as np
from IPython.display import clear_output
import ast
from pprint import pprint 
import matplotlib.pyplot as plt
from snowflake.snowpark import DataFrame


#### - Parameters

In [2]:
#------------------------__Init__Initialize Parameters-------------------------------------
# Shared pipeline state, grouped by the pipeline function each name belongs to.
# Everything starts out as None except the two progress flags, which start as
# False. The pipeline functions reach these names through `global` statements.

#--get_user_inputs--
# TODO: rename enroll_data_table_name -> raw_source_name
# TODO: add col_analysis_table_name to the user inputs
# TODO: add database & schema to the user inputs
(connection_parameters, restart_session_boolean, reload_data_boolean,
 enroll_data_table_name, col_analysis_table_name,
 raw_enroll_data_table_name) = (None,) * 6

#--get_or_create_session--
session_exists = False
session = None

#--load_data--
# TODO: rename enroll_stable_copy -> raw_source_copy and
#       enroll_stable_copy_schema -> raw_source_schema
enroll_stable_copy, enroll_stable_copy_schema, dtype_to_class_name = (None,) * 3
dataframe_initialized = False

#--map_data_types_to_categories--
schema_mapped_df = None

#--run_column_tests--
test_results_df = None

#--upload_column_tests_to_snowflake--
column_analysis, snowflake_updated_column_analysis = None, None

#--classify_columns_for_corr_and_anova--
(categorical_corr_anova_threshold, main_target_column,
 NUMERIC_COLUMN_NAMES, CATEGORICAL_COLUMN_NAMES,
 CATEGORICAL_COLUMN_NAMES_plus_goal, IGNORE_COLUMN_NAMES) = (None,) * 6

#--convert_dates_to_unix_encoding--
raw_data_snowpark_copy, DATES_TO_CONVERT_TO_NUMERIC = None, None

#--sort_category_values_by_goal_metric--
sorted_category_averages = None

#--convert_category_values_to_encoding--
encoded_df, binary_encoded_df = None, None

#--run_standard_encoded_corr_analysis--
standard_pearson_corr_results = None

#--run_binary_encoded_corr_analysis--
binary_pearson_corr_results = None

#### - Custom Debug Functions

In [None]:
#------------------------ "Peek" Debug Tool ("","v","p"..."3p") -------------------------------------
import pandas as pd

# Debug helper: preview the globals registered in `var_param_structure`.
def peek(scope="1"):
    """Print (and return) a truncated preview of tracked notebook globals.

    `scope` has the form "<level><kind>":
      level: "1" -> "direct", "2" -> "context", "3" -> "general" section of
             the global `var_param_structure` registry (default "1");
      kind:  "v" -> "<section>_variables", "p" -> "<section>_parameters"
             (case-insensitive, default "v").
    Either part may be omitted, so "", "v", "p", "2" and "3p" are all valid.

    Each registered name is looked up in this module's globals. A dotted name
    such as "connection_parameters.database" is resolved by starting from the
    first component and walking dict keys (for dicts) or attributes (otherwise).
    Unresolvable names show as "Variable not found".

    Snowpark and pandas DataFrames are rendered with `display` (top-left 5x5
    corner, rendered once each) and appear as "None" in the text output.
    Other values are truncated to 30 items / characters.

    Returns:
        str: the printed "<name>:\\n<preview>\\n" entries joined by newlines.

    Raises:
        ValueError: if the level part of `scope` is not "1", "2" or "3".
    """
    scope_map = {"1": "direct", "2": "context", "3": "general"}
    section_suffix = {"v": "variables", "p": "parameters"}

    # Parse "<level><kind>"; a missing kind defaults to variables and a
    # missing level to "1".
    level = str(scope).strip()
    kind = "v"
    if level and level[-1].lower() in section_suffix:
        kind = level[-1].lower()
        level = level[:-1]
    level = level or "1"
    if level not in scope_map:
        raise ValueError(
            f"peek: unknown scope {scope!r}; expected an optional level "
            f"{sorted(scope_map)} followed by an optional 'v' or 'p'"
        )

    section = scope_map[level]
    scope_key = f"{section}_{section_suffix[kind]}"
    print(scope_key)

    global_variables = globals()
    not_found = object()  # private sentinel so real values are never confused with "missing"

    def resolve(var_name):
        # "a.b.c" -> globals()["a"], then key/attribute "b", then "c".
        head, *path = var_name.split(".")
        value = global_variables.get(head, not_found)
        for part in path:
            if value is not_found:
                break
            if isinstance(value, dict):
                value = value.get(part, not_found)
            else:
                value = getattr(value, part, not_found)
        return 'Variable not found' if value is not_found else value

    def limited_output(value):
        # DataFrames are rendered (display returns None); everything else is
        # turned into a short string.
        if isinstance(value, DataFrame):
            return display(value.to_pandas().iloc[:5, :5])
        elif isinstance(value, pd.DataFrame):
            return display(value.iloc[:5, :5])
        elif isinstance(value, list):
            return ', '.join(map(str, value[:30]))
        elif isinstance(value, str):
            return value[:30]
        elif isinstance(value, dict):
            return ', '.join([f"{k}: {str(v)[:30]}" for k, v in list(value.items())[:3]])
        elif isinstance(value, set):
            return ', '.join(map(str, list(value)[:30]))
        else:
            return str(value)[:30]

    # Single pass: each value is rendered once, both printed and collected.
    output = []
    for var_name in var_param_structure[section][scope_key]:
        entry = f"{var_name}:\n{limited_output(resolve(var_name))}\n"
        print(entry)
        output.append(entry)
    return "\n".join(output)

# Registry read by `peek`: section ("direct" / "context" / "general") ->
# "<section>_variables" / "<section>_parameters" -> names of globals to preview.
# Entries containing a dot (e.g. "connection_parameters.database") name a
# field of a global rather than a global itself.
# NOTE(review): "Category_Mapping_Reference" is a local inside
# map_data_types_to_categories, not a global, and "enroll_stable_copy_schema"
# is listed under both direct lists -- confirm both are intended.
var_param_structure = dict(
    direct=dict(
        direct_variables=[
            "enroll_stable_copy", "enroll_stable_copy_schema",
            "NUMERIC_COLUMN_NAMES", "CATEGORICAL_COLUMN_NAMES",
            "IGNORE_COLUMN_NAMES", "DATES_TO_CONVERT_TO_NUMERIC",
            "sorted_category_averages", "encoded_df",
        ],
        direct_parameters=[
            "enroll_data_table_name", "col_analysis_table_name",
            "Category_Mapping_Reference", "enroll_stable_copy_schema",
        ],
    ),
    context=dict(
        context_variables=["schema_mapped_df"],
        context_parameters=["connection_parameters.database",
                            "connection_parameters.schema"],
    ),
    general=dict(
        general_variables=[
            "experiment_category_columns", "experiment_category_list",
            "cat_test_results_dict", "categorical_summary_results_df",
        ],
        general_parameters=[
            "GlobalStructureMap_file_name", "CellStructureMap_file_name",
            "CellInteractionMap_file_name",
        ],
    ),
)

# Example usage:
#peek("1v")

In [4]:
#------------------------Showall Function-------------------------------------
def showall(data, max_rows=900, max_columns=120):
    """Display `data` with temporarily raised pandas row/column limits.

    The limits apply only while `data` is being rendered. Afterwards the
    caller's previous `display.max_rows` / `display.max_columns` values are
    restored, even if rendering raises. Previously they were overwritten with
    fixed values, which were the reverse of pandas' own defaults.

    Args:
        data: Anything IPython's `display` can render (typically a DataFrame).
        max_rows: Temporary value for the `display.max_rows` option.
        max_columns: Temporary value for the `display.max_columns` option.
    """
    with pd.option_context("display.max_rows", max_rows,
                           "display.max_columns", max_columns):
        display(data)

#### - Get User Inputs

In [5]:
#------------------------ Get User Inputs-------------------------------------
def get_user_inputs():
    """Populate the run-configuration globals.

    Sets at module level:
      restart_session_boolean, reload_data_boolean: 'Yes'/'No' switches read by
          get_or_create_session and load_data respectively.
      connection_parameters: Snowpark connection settings (browser SSO).
      raw_enroll_data_table_name, enroll_data_table_name, col_analysis_table_name:
          Snowflake table names used by the pipeline.
      categorical_corr_anova_threshold: numeric threshold (10).

    Values marked NEEDS_POPULATED must be filled in before running.
    """
    #__Set for the first time in this function__
    global restart_session_boolean
    global reload_data_boolean
    global connection_parameters
    global enroll_data_table_name
    global col_analysis_table_name
    global raw_enroll_data_table_name
    # Declared global so the threshold is visible outside this function;
    # without it the assignment below only created a dead local.
    global categorical_corr_anova_threshold

    #__function logic starts here__

    restart_session_boolean = 'No'
    reload_data_boolean = 'No'

    # NOTE(review): consider reading account/user from environment variables
    # instead of editing credentials into the notebook.
    connection_parameters = {
        "account": "NEEDS_POPULATED",
        "user": "NEEDS_POPULATED",
        "authenticator": "externalbrowser",
        "database": "NEEDS_POPULATED",  # Specify your database name here
        "schema": "NEEDS_POPULATED"  # Specify your schema name here
        #"warehouse": "your_warehouse_name",  # Specify if necessary
        #"role": "your_role_name",  # Specify if necessary
    }

    #Eventually change to [database].[schema].... format
    raw_enroll_data_table_name = 'NEEDS_POPULATED'

    enroll_data_table_name = 'NEEDS_POPULATED'

    col_analysis_table_name = "NEEDS_POPULATED"

    categorical_corr_anova_threshold = 10

#### - Connections/Integrations

In [6]:
#------------------------Connect to Snowpark API-------------------------------------
def get_or_create_session():
    """Create the Snowpark session on first use (or on request) and return it.

    Reads the globals `connection_parameters` and `restart_session_boolean`
    (set by get_user_inputs) and writes the globals `session` and
    `session_exists`.

    A new session is built when none exists yet, or whenever
    `restart_session_boolean == 'Yes'`. In the restart case the previous
    session is closed first so its connection is not leaked. A new session
    is pointed at connection_parameters['database'] / ['schema'].

    Returns:
        The session stored in the global `session`, for both the create and
        the reuse path.
    """
    #__Needed from previous functions__
    #--get_user_inputs--
    global connection_parameters
    global restart_session_boolean

    #__Set for the first time in this function__
    global session_exists
    global session

    #___function logic starts here___

    # Reuse path: a session exists and no restart was requested.
    if session_exists and restart_session_boolean != 'Yes':
        print("continuing existing session")
        return session

    if not session_exists:
        print("no existing session, creating new...")
        print("")
    else:
        print("closing existing session, reconnecting...")
        # Close the old connection explicitly; rebinding `session` alone
        # would leave it open. A dead/expired session must not block the
        # reconnect, so failures here are only reported.
        if session is not None:
            try:
                session.close()
            except Exception as exc:
                print(f"warning: could not close previous session: {exc}")

    # Create a new session
    print("snowpark API response...")
    session = Session.builder.configs(connection_parameters).create()

    print("")
    print("session created")
    print("")

    session.use_database(connection_parameters['database'])
    session.use_schema(connection_parameters['schema'])
    print("Database:", connection_parameters['database'])
    print("Schema:", connection_parameters['schema'])

    session_exists = True  # later calls take the reuse path
    return session

## Load & Pre-Process Data

#### - Convert Potential Date Fields

In [7]:
#------------------------ Convert Potential Date Fields  -------------------------------------
def convert_potential_date_fields():
    """Convert text columns of the raw table that parse as dates and stage them.

    Reads the globals `session` and `raw_enroll_data_table_name`.

    Steps:
      1. List the TEXT columns of `raw_enroll_data_table_name` (current
         schema) from INFORMATION_SCHEMA.
      2. Pull those columns into pandas.
      3. Try `pd.to_datetime(..., errors='raise')` on each column that has at
         least one non-null value. Columns that parse completely are replaced
         by their datetime version.
      4. Overwrite "<raw_enroll_data_table_name>_temporary" with the result:
         the text columns, converted where possible.

    Returns:
        list[str]: names of the columns converted to datetime. The list is
        empty when the table has no text columns, and nothing is written then.
    """
    print("converting potential date fields...")
    # Escape for use inside a single-quoted SQL string literal.
    table_literal = raw_enroll_data_table_name.replace("'", "''")
    string_columns_query = f"""
    SELECT COLUMN_NAME
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_NAME = '{table_literal}'
    AND TABLE_SCHEMA = CURRENT_SCHEMA()
    AND DATA_TYPE = 'TEXT'
    ORDER BY ORDINAL_POSITION;
    """
    string_columns_df = session.sql(string_columns_query).collect()
    string_column_names = [row['COLUMN_NAME'] for row in string_columns_df]
    print("string columns pulled")

    if not string_column_names:
        # An empty SELECT list would be a SQL syntax error.
        print("no text columns found, nothing to convert")
        return []

    # Quote identifiers so names resolve exactly as stored; the table itself
    # is referenced quoted as well.
    quoted_columns = ", ".join('"' + name.replace('"', '""') + '"' for name in string_column_names)
    select_query = f"""
    SELECT {quoted_columns}
    FROM "{raw_enroll_data_table_name}"
    """
    df = session.sql(select_query).to_pandas()
    print("string column data queried")

    successful_conversions = []

    for column in df.select_dtypes(include=['object']).columns:
        # An all-null column "parses" trivially; that is no evidence of a date.
        if df[column].isna().all():
            continue
        try:
            converted = pd.to_datetime(df[column], errors='raise')
        except (ValueError, TypeError, OverflowError):
            # Not (entirely) parseable as dates: leave the column as text.
            continue
        df[column] = converted  # actually apply the conversion before staging
        print(f"Column '{column}' successfully converted to datetime.")
        successful_conversions.append(column)

    snowpark_df = session.create_dataframe(df)
    snowpark_df.write.save_as_table(raw_enroll_data_table_name+"_temporary", mode="overwrite")
    return successful_conversions

#### - Normalize Dates

In [8]:
#------------------------ Add Normalized Dates -------------------------------------
def add_normalized_dates():
    """Create and call a Snowflake stored procedure that rebuilds the analysis
    table with "<DATE_COL>_NORMALIZED" day-offset columns.

    The JavaScript procedure `dynamic_date_diff_and_union_save_as_table`
    (EXECUTE AS CALLER) does the following:
      * reads the column names and types of
        EDW_PROD.ANALYTICS.BI_ENROLL_RATE_PROD_DATA_TEMPORARY from
        INFORMATION_SCHEMA;
      * emits each DATE / TIMESTAMP* column and, unless the column is
        IDENTIFIED_LIST_LOAD_DATE itself, a companion
        DATEDIFF(day, IDENTIFIED_LIST_LOAD_DATE, <col>) AS "<COL>_NORMALIZED";
      * then emits number, binary, boolean, string, array, variant/object,
        geography, time and all remaining columns, in that order;
      * runs CREATE OR REPLACE TABLE "ANALYTICS"."BI_ENROLL_RATE_FULL_2" AS
        SELECT <those columns> FROM "EDW_PROD"."ANALYTICS"."BI_ENROLL_RATE_PROD_DATA".

    Note that the column *types* come from the ..._TEMPORARY table, while the
    *data* is selected from BI_ENROLL_RATE_PROD_DATA.

    The procedure catches a failure of the CREATE TABLE statement and returns
    the error text as its result string. Such a failure is therefore printed
    below, not raised. The INFORMATION_SCHEMA query is outside that try/catch.

    NOTE(review): every database/schema/table name is hard-coded in the SQL,
    rather than taken from raw_enroll_data_table_name / enroll_data_table_name.
    NOTE(review): inside DATEDIFF the column names are interpolated unquoted,
    unlike every other column reference. Case-sensitive (quoted) date column
    names would therefore not resolve there; confirm all are upper-case.
    NOTE(review): confirm that `console.log` is available to Snowflake
    JavaScript procedures. If it is not, the procedure fails before it runs.
    """
    # Define your Stored Procedure creation SQL
    create_procedure_sql = """
    CREATE OR REPLACE PROCEDURE dynamic_date_diff_and_union_save_as_table()
    RETURNS STRING
    LANGUAGE JAVASCRIPT
    EXECUTE AS CALLER
    AS
    $$
    var query = `
    SELECT COLUMN_NAME, DATA_TYPE
    FROM "EDW_PROD"."INFORMATION_SCHEMA"."COLUMNS"
    WHERE TABLE_SCHEMA = 'ANALYTICS'
    AND TABLE_NAME = 'BI_ENROLL_RATE_PROD_DATA_TEMPORARY';
    `;
    // Debug: Print the query to check for errors
    console.log(query); 
    
    var resultSet = snowflake.execute({sqlText: query});
    var columnPairs = []; // For storing original and normalized date columns together
    var numberColumns = [];
    var binaryColumns = [];
    var booleanColumns = [];
    var stringColumns = [];
    var arrayColumns = [];
    var variantObjectColumns = [];
    var geographyColumns = [];
    var timeColumns = [];
    var otherColumns = [];
    
    while (resultSet.next()) {
        var colName = resultSet.getColumnValue(1);
        var dataType = resultSet.getColumnValue(2);
    
        if (['DATE', 'TIMESTAMP', 'TIMESTAMP_LTZ', 'TIMESTAMP_NTZ', 'TIMESTAMP_TZ'].includes(dataType)) {
            // Include the original date column
            columnPairs.push(`"${colName}"`);
            // Include the normalized date column if it's not the reference date column
            if (colName !== 'IDENTIFIED_LIST_LOAD_DATE') {
                //columnPairs.push(`DATEDIFF(day, IDENTIFIED_LIST_LOAD_DATE, "${colName}") AS ${colName.toUpperCase()}_NORMALIZED`);
                columnPairs.push(`DATEDIFF(day, IDENTIFIED_LIST_LOAD_DATE, ${colName}) AS "${colName.toUpperCase()}_NORMALIZED"`);
            }
        } else if (['NUMBER', 'FLOAT', 'DECIMAL', 'INTEGER', 'BIGINT', 'SMALLINT', 'NUMERIC'].includes(dataType)) {
            numberColumns.push(`"${colName}"`);
        } else if (dataType === 'BINARY') {
            binaryColumns.push(`"${colName}"`);
        } else if (dataType === 'BOOLEAN') {
            booleanColumns.push(`"${colName}"`);
        } else if (['VARCHAR', 'CHAR', 'TEXT', 'STRING'].includes(dataType)) {
            stringColumns.push(`"${colName}"`);
        } else if (dataType === 'ARRAY') {
            arrayColumns.push(`"${colName}"`);
        } else if (['VARIANT', 'OBJECT'].includes(dataType)) {
            variantObjectColumns.push(`"${colName}"`);
        } else if (dataType === 'GEOGRAPHY') {
            geographyColumns.push(`"${colName}"`);
        } else if (dataType === 'TIME') {
            timeColumns.push(`"${colName}"`);
        } else {
            otherColumns.push(`"${colName}"`);
        }
    }
    
    // Construct the final select query with the desired column order
    var combinedColumns = columnPairs.concat(numberColumns, binaryColumns, booleanColumns, stringColumns, arrayColumns, variantObjectColumns, geographyColumns, timeColumns, otherColumns).join(', ');
    
    // Debug: Print the combinedColumns to check for errors
    console.log(combinedColumns); 
    
    var selectQuery = `CREATE OR REPLACE TABLE "ANALYTICS"."BI_ENROLL_RATE_FULL_2" AS SELECT ${combinedColumns} FROM "EDW_PROD"."ANALYTICS"."BI_ENROLL_RATE_PROD_DATA";`;

    // Debug: Print the selectQuery to check for errors
    console.log(selectQuery); 
    
    try {
        snowflake.execute({sqlText: selectQuery});
        return 'Table "ANALYTICS"."BI_ENROLL_RATE_FULL_2" created successfully with columns in the desired order.';
    } catch (err) {
        return `Error when attempting to create table "ANALYTICS"."BI_ENROLL_RATE_FULL_2":
                Query: ${selectQuery}
                Query results: ${JSON.stringify(resultSet).substring(0, 200)}
                combinedColumns: ${combinedColumns}
                selectQuery: ${selectQuery}
                selectQuery Error Code: ${err.code}
                selectQuery Error Message: ${err.message}`;
    }
    $$;
    """

    print("executing sql")
    # Execute the DDL that creates (or replaces) the Stored Procedure
    session.sql(create_procedure_sql).collect()

    print("calling stored procedure")
    # Now call the newly created Stored Procedure; its STRING return value
    # (success or error message) comes back through collect()
    call_procedure_sql = "CALL dynamic_date_diff_and_union_save_as_table()"
    result = session.sql(call_procedure_sql).collect()

    print("result is:")
    # Print the result (includes the procedure's error text if CREATE TABLE failed)
    print(result)

#add_normalized_dates()

#### - Add Date Sequences (For Later Date Analysis)

In [9]:
#---------------- Add date sequences ----------------------

from snowflake.snowpark.functions import lit, row_number
from snowflake.snowpark.window import Window
import pandas as pd

def add_date_sequences():
    """Add per-row, chronologically ordered date sequences to the enrollment data.

    For every row of the table `enroll_data_table_name`, the row's non-null
    date columns are sorted by date and three list columns are added:
      ORDERED_DATE_COL_NAMES  - date column names in chronological order
      ORDERED_DATES           - the corresponding dates
      ORDERED_NORMALIZED_NUMS - the matching "<COL>_NORMALIZED" value per date
                                (None when that column is absent)

    The candidate date columns come from INFORMATION_SCHEMA for
    BI_ENROLL_RATE_FULL_2 and are restricted to columns that actually exist in
    the loaded table.
    NOTE(review): the table name in that query is hard-coded; presumably it
    should match enroll_data_table_name -- confirm.

    Side effects: replaces the global `enroll_stable_copy` with the enriched
    DataFrame, displays its first rows, and overwrites
    "<raw_enroll_data_table_name>_temporary" with it.
    """
    print("add_date_sequences running...")
    global enroll_stable_copy

    enroll_stable_copy = session.table(enroll_data_table_name)

    # Step 0: tag rows with an artificial row_id. ORDER BY a constant makes the
    # numbering arbitrary, and Snowpark DataFrames are lazy. Without
    # materialising, the to_pandas() pull below and the join at the end would
    # each re-run the window and could number rows differently. cache_result()
    # snapshots the ids into a temp table so both sides agree.
    enroll_stable_copy_with_id = enroll_stable_copy.withColumn(
        "row_id", row_number().over(Window.orderBy(lit(1)))
    ).cache_result()

    # Step 1: Identify date columns
    date_columns_query = f"""
    SELECT COLUMN_NAME
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_NAME = 'BI_ENROLL_RATE_FULL_2'
    AND DATA_TYPE IN('DATE', 'TIMESTAMP', 'TIMESTAMP_LTZ', 'TIMESTAMP_NTZ', 'TIMESTAMP_TZ');
    """
    date_columns_df = session.sql(date_columns_query).collect()
    date_column_names = [row['COLUMN_NAME'] for row in date_columns_df]
    print("date_column_names")
    print(date_column_names)

    # Only work with columns that really exist in the loaded table; indexing a
    # missing one below would raise KeyError.
    available_columns = set(enroll_stable_copy_with_id.columns)
    present_date_columns = [name for name in date_column_names if name in available_columns]
    present_normalized_columns = [
        name + "_NORMALIZED" for name in date_column_names if name + "_NORMALIZED" in available_columns
    ]
    date_columns_with_normalized = present_date_columns + present_normalized_columns

    # Pull only row_id plus the date-related columns into pandas
    columns_to_select = ["row_id"] + date_columns_with_normalized
    pandas_df_with_dates = enroll_stable_copy_with_id.select(columns_to_select).to_pandas()

    # Convert date columns to datetime format (unparseable values -> NaT)
    for col_name in present_date_columns:
        pandas_df_with_dates[col_name] = pd.to_datetime(pandas_df_with_dates[col_name], errors='coerce')

    def transform_row_to_ordered_lists(row):
        # Sort this row's non-null dates chronologically, keeping each date's
        # column name and normalized value aligned with it.
        items = sorted(
            [(name, row[name], row.get(name + "_NORMALIZED")) for name in present_date_columns if pd.notnull(row[name])],
            key=lambda item: item[1],
        )
        return {
            'ORDERED_DATE_COL_NAMES': [item[0] for item in items],
            'ORDERED_DATES': [item[1] for item in items],
            'ORDERED_NORMALIZED_NUMS': [item[2] for item in items],
        }

    # Build the list columns aligned on the pandas index (also works for 0 rows)
    ordered_lists_df = pd.DataFrame(
        [transform_row_to_ordered_lists(row) for row in pandas_df_with_dates.to_dict("records")],
        index=pandas_df_with_dates.index,
        columns=['ORDERED_DATE_COL_NAMES', 'ORDERED_DATES', 'ORDERED_NORMALIZED_NUMS'],
    )

    # Keep row_id + the new list columns; the original date columns come back
    # through the join below.
    pandas_df_final = pd.concat(
        [pandas_df_with_dates.drop(columns=date_columns_with_normalized), ordered_lists_df],
        axis=1,
    )

    # Convert modified pandas DataFrame back to Snowpark DataFrame
    modified_snowpark_df = session.create_dataframe(pandas_df_final)

    # Merge modified data back using 'row_id' and update the global enroll_stable_copy
    enroll_stable_copy = enroll_stable_copy_with_id.join(
        modified_snowpark_df,
        "row_id"
    ).drop("row_id")

    # Only fetch the preview rows instead of pulling the whole table into pandas
    display(enroll_stable_copy.limit(5).to_pandas())

    print("done, pushing to snowflake")
    enroll_stable_copy.write.save_as_table(raw_enroll_data_table_name+"_temporary", mode="overwrite")
    print("pushed to snowflake")

#add_date_sequences()

#### - Load Raw Source Data from Snowflake

In [10]:
#------------------------Load Data-------------------------------------

#deprecated... columns_to_load = ["IS_CURRENT"]
def load_data(columns_to_load=None, num_rows=None):
    """Load the source table into the global `enroll_stable_copy` (lazily).

    Does nothing when data is already loaded (`dataframe_initialized`) unless
    `reload_data_boolean == 'Yes'`.

    Args:
        columns_to_load: optional list of column names to select.
        num_rows: optional row limit, applied after the column selection.

    Side effects (globals): sets `enroll_stable_copy`,
    `enroll_stable_copy_schema`, `dataframe_initialized` (True) and
    `dtype_to_class_name`, which maps str(datatype) to the datatype's class
    name for every schema field. Also prints the row and column counts; the
    row count runs a COUNT query.
    """
    #__Needed from previous functions__
    #--get_user_inputs--
    global reload_data_boolean
    global enroll_data_table_name

    #--get_or_create_session--
    global session

    #__Set for the first time in this function__
    global enroll_stable_copy
    global enroll_stable_copy_schema
    global dataframe_initialized
    global dtype_to_class_name

    #___function logic starts here___

    if dataframe_initialized and reload_data_boolean != 'Yes':
        return

    # Build the lazy table reference: optional projection, then optional limit.
    # (Previously the columns-only case called .limit(None) and ignored the
    # requested columns.)
    table = session.table(enroll_data_table_name)
    if columns_to_load is not None:
        table = table.select(columns_to_load)
    if num_rows is not None:
        table = table.limit(num_rows)
    enroll_stable_copy = table

    enroll_stable_copy_schema = enroll_stable_copy.schema

    row_count = enroll_stable_copy.count()
    print(f"Number of rows: {row_count}")

    num_columns = len(enroll_stable_copy_schema.fields)
    print(f"Number of columns: {num_columns}")

    # Update flag to indicate dataframe is now initialized
    dataframe_initialized = True

    # key = str of datatype, value = datatype class name
    dtype_to_class_name = {
        str(field.datatype): field.datatype.__class__.__name__
        for field in enroll_stable_copy_schema.fields
    }

#load_data()
 # Update flag to indicate dataframe is now initialized





#### - Map Native Data Types -> Simplified Test Types

In [11]:
#------------------------Map Native Data Types => Test Data Types-------------------------------------
def map_data_types_to_categories():
    """Tag every column of the loaded table with a simplified test type.

    Reads the globals `enroll_stable_copy_schema` and `dtype_to_class_name`
    (both produced by load_data). Stores the result in the global
    `schema_mapped_df`, one row per schema field in schema order, with the
    columns COLUMN_NAME, NATIVE_DATA_TYPE, NATIVE_DATA_CLASS and
    TEST_DATA_TYPE. The frame is then shown with showall.
    """
    #__Needed from previous functions__
    global session
    global enroll_stable_copy
    global enroll_stable_copy_schema
    global dtype_to_class_name

    #__Set for the first time in this function__
    global schema_mapped_df

    # Simplified test type -> Snowpark datatype class names belonging to it.
    Category_Mapping_Reference = {
        'STRING': ['StringType'],
        'BINARY': ['BinaryType', 'BooleanType'],
        'NUMBER': ['DecimalType', 'DoubleType', 'FloatType', 'IntegerType', 'LongType', 'ShortType'],
        'DATE': ['DateType', 'TimestampType'],
        'TIME': ['TimeType'],
        'VARIANT': ['ArrayType', 'MapType', 'Variant', 'VariantType'],
        'GEOGRAPHY': ['Geography', 'GeographyType'],
        'OTHER': ['ByteType', 'ColumnIdentifier', 'DataType', 'StructField', 'StructType']
    }

    def test_type_for(class_name):
        # First category listing the class wins; unknown classes -> 'OTHER'.
        return next(
            (category for category, class_names in Category_Mapping_Reference.items()
             if class_name in class_names),
            'OTHER',
        )

    # One mapping row per distinct native datatype string.
    dtype_mapping_df = pd.DataFrame([
        {
            'native_data_type': dtype_str,
            'native_data_class': class_name,
            'test_data_type': test_type_for(class_name),
        }
        for dtype_str, class_name in dtype_to_class_name.items()
    ])

    # Attach the mapping to each schema field, preserving schema order.
    rows = []
    for field in enroll_stable_copy_schema.fields:
        is_match = dtype_mapping_df['native_data_type'] == str(field.datatype)
        hit = dtype_mapping_df[is_match].iloc[0]
        rows.append({
            'column_name': field.name,
            'native_data_type': hit['native_data_type'],
            'native_data_class': hit['native_data_class'],
            'test_data_type': hit['test_data_type'],
        })
    schema_mapped_df = pd.DataFrame(rows)

    print("upper-casing pandas columns before transform")
    schema_mapped_df.columns = schema_mapped_df.columns.str.upper()
    print("")
    print("Mapping is:")
    showall(schema_mapped_df)


#map_data_types_to_categories()

## Column Analysis

#### - Run Analysis

In [12]:
#------------------------Run Column Analysis-------------------------------------

def run_column_analysis (subset = None):
    """Profile every column of the raw data and collect all results in one table.

    Pulls only the data needed for each test: columns are grouped by their
    TEST_DATA_TYPE (from ``schema_mapped_df``), and each test runs as ONE
    Snowpark aggregation over all columns of that type. Each result is
    left-joined onto the mapping table.

    * GENERAL tests (NULL_PERCENTAGE, DISTINCT_VALUES + DISTINCT_COUNT) run on
      every column.
    * BINARY, STRING, NUMBER and DATE columns also get their easy tests.
    * DATE additionally runs its hard tests; only DATE_AVG is implemented.
    * VARIANT / GEOGRAPHY / CATEGORICAL have no implemented tests yet.

    Args:
        subset: optional Snowpark DataFrame to profile instead of the global
            ``enroll_stable_copy`` (e.g. a sample or filtered copy).

    Returns:
        pandas DataFrame whose first column is COLUMN_NAME, followed by the
        mapping columns from ``schema_mapped_df`` and one column per test that
        was run. NULL_PERCENTAGE and the BINARY_* tests are on a 0-100 scale.
        DISTINCT_VALUES holds a JSON-encoded list. The frame is also stored in
        the global ``test_results_df``.
    """
    #__Needed from previous functions__
    #--get_or_create_session--
    global session
    #--load_data--
    global enroll_stable_copy

    #--map_data_types_to_categories--
    global schema_mapped_df

    #__Set for the first time in this function__
    global test_results_df

    #___helpers___
    def aggregate_per_column(data, column_names, build_expression, result_name):
        """Run one aggregate per column in a single query; return a one-column frame (named result_name) indexed by column name."""
        expressions = [build_expression(name).alias(name) for name in column_names]
        return data.agg(*expressions).to_pandas().transpose().rename(columns={0: result_name})

    def parse_distinct_values(raw_value):
        """Turn one collect_set result into a Python list; returns [] when it cannot be parsed."""
        if isinstance(raw_value, list):
            return raw_value
        try:
            # Snowflake hands ARRAY results back as JSON text. json understands
            # true/false/null and arbitrary whitespace natively. The previous
            # 'true'->'True' string replace also rewrote data values that
            # merely contained "true"/"false".
            parsed = json.loads(raw_value)
        except (TypeError, ValueError):
            try:
                parsed = ast.literal_eval(raw_value)
            except Exception as e:
                print(f"Error with value: {raw_value}. Error: {e}")
                return []
        return parsed if isinstance(parsed, list) else []

    #___function logic starts here___
    # `is None`, not `== None`: DataFrame types may overload `==`.
    if subset is None:
        print("No subset provided - using original raw snowflake data")
        raw_data_subset = enroll_stable_copy
    else:
        print("Subset provided - using subset instead of original raw snowflake data")
        raw_data_subset = subset

    easy_tests = {
        'GENERAL': ['NULL_PERCENTAGE','DISTINCT_VALUES'],
        'STRING': ['STRING_AVG_LENGTH','STRING_MAX_LENGTH','STRING_MIN_LENGTH'],
        'NUMBER': ['NUMBER_MIN','NUMBER_AVG','NUMBER_MEDIAN','NUMBER_MODE','NUMBER_MAX','NUMBER_STD_DEV','NUMBER_LOWER_QUARTILE','NUMBER_UPPER_QUARTILE'],
        'BINARY': ['BINARY_PERCENT_TRUE_OR_1','BINARY_PERCENT_FALSE_OR_0'],
        'DATE': ['DATE_MIN','DATE_MAX']
    }

    hard_tests = {
        'VARIANT': ['VARIANT_UNIQUE_KEYS','VARIANT_AVG_VALUE_LENGTH','VARIANT_UNIQUE_VALUE_TYPES'],
        'CATEGORICAL': ['CATEGORICAL_UNIQUE_CATEGORIES_COUNT','CATEGORICAL_UNIQUE_CATEGORIES_LIST_TOP_50','CATEGORICAL_CATEGORY_FREQUENCY','CATEGORICAL_PERCENTAGE_OF_TOTAL_PER_CATEGORY','CATEGORICAL_DISTRIBUTION_NORMALIZATION_COEF','CATEGORICAL_DISTRIBUTION_SKEW_COEF'],
        'STRING': ['STRING_COMMON_PATTERNS'],
        'NUMBER': ['NUMBER_DISTRIBUTION_NORMALIZATION_COEF','NUMBER_DISTRIBUTION_SKEW_COEF'],
        'DATE': ['DATE_AVG','DATE_FREQUENCY_DISTRIBUTION'],
        'GEOGRAPHY': ['GEOGRAPHY_DATA_TYPES']#find more?
    }

    # Snowpark expression builder for every test that is a single aggregate per
    # column. `min`/`max` here are the Snowpark functions imported at the top
    # of the notebook (they shadow the Python builtins).
    aggregate_tests = {
        'NULL_PERCENTAGE': lambda name: (count(lit(1)) - count(col(name))) / count(lit(1)) * 100,
        'BINARY_PERCENT_TRUE_OR_1': lambda name: count(when(col(name).isin([1, True]), 1)) / count(lit(1)) * 100,
        'BINARY_PERCENT_FALSE_OR_0': lambda name: count(when(col(name).isin([0, False]), 1)) / count(lit(1)) * 100,
        'STRING_AVG_LENGTH': lambda name: avg(length(col(name))),
        'STRING_MAX_LENGTH': lambda name: max(length(col(name))),
        'STRING_MIN_LENGTH': lambda name: min(length(col(name))),
        'NUMBER_MIN': lambda name: min(col(name)),
        'NUMBER_AVG': lambda name: avg(col(name)),
        'NUMBER_MEDIAN': lambda name: approx_percentile(col(name), 0.5),
        'NUMBER_MODE': lambda name: mode(col(name)),
        'NUMBER_MAX': lambda name: max(col(name)),
        'NUMBER_STD_DEV': lambda name: stddev(col(name)),
        'NUMBER_LOWER_QUARTILE': lambda name: approx_percentile(col(name), 0.25),
        'NUMBER_UPPER_QUARTILE': lambda name: approx_percentile(col(name), 0.75),
        'DATE_MIN': lambda name: min(col(name)),
        'DATE_MAX': lambda name: max(col(name)),
        'DATE_AVG': lambda name: to_date(from_unixtime(avg(unix_timestamp(col(name))))),
    }

    # Data types that have implemented tests; hard tests are only wired up for DATE.
    implemented_types = ('GENERAL', 'BINARY', 'STRING', 'NUMBER', 'DATE')
    types_running_hard_tests = ('DATE',)

    print("----------starting testing----------")

    print("Copy schema_mapped_df with COLUMN_NAME as index for results compilation")
    test_results_df = schema_mapped_df.set_index(schema_mapped_df.columns[0])

    # Fixed order in which data types are tested (and result columns appended).
    combined_keys = ['GENERAL','DATE','NUMBER','BINARY','STRING','VARIANT','GEOGRAPHY','CATEGORICAL']

    test_column_subsets = {}
    for test_data_type in combined_keys:
        if test_data_type == 'GENERAL':
            # GENERAL tests feed every column's NULL/DISTINCT stats downstream,
            # so they always cover all columns.
            test_column_subsets[test_data_type] = schema_mapped_df
            continue
        type_rows = schema_mapped_df.loc[schema_mapped_df['TEST_DATA_TYPE'] == test_data_type]
        if type_rows.empty:
            print(f"No columns for test_data_type '{test_data_type}' found.")
        else:
            test_column_subsets[test_data_type] = type_rows

    for test_data_type, type_rows in test_column_subsets.items():
        print("Testing data type: ",test_data_type)

        if test_data_type not in implemented_types:
            print("No tests for data type:",test_data_type)
            continue

        test_columns = type_rows['COLUMN_NAME'].tolist()
        test_data = raw_data_subset.select(test_columns)

        tests_to_run = list(easy_tests[test_data_type])
        if test_data_type in types_running_hard_tests:
            tests_to_run += hard_tests[test_data_type]

        for test in tests_to_run:
            if test == 'DISTINCT_VALUES':
                # Distinct values (JSON-encoded list) AND their count per column.
                print('   DISTINCT_VALUES...')
                raw_sets = aggregate_per_column(
                    test_data, test_columns, lambda name: collect_set(col(name)), test
                )[test]
                distinct_lists = raw_sets.map(parse_distinct_values)
                test_results_df = test_results_df.join(distinct_lists.map(json.dumps).to_frame(test), how='left')
                test_results_df = test_results_df.join(distinct_lists.map(len).to_frame('DISTINCT_COUNT'), how='left')
                print("distinct done.")
            elif test in aggregate_tests:
                print(f'   {test}...')
                result_df = aggregate_per_column(test_data, test_columns, aggregate_tests[test], test)
                test_results_df = test_results_df.join(result_df, how='left')
            # Tests without an implementation yet (e.g. DATE_FREQUENCY_DISTRIBUTION) are skipped.

    # Convert the COLUMN_NAME index back into the first column
    test_results_df = test_results_df.reset_index()

    display(test_results_df)

    return test_results_df

    
#run_column_analysis()

#### - Push Analysis to Snowflake

In [13]:
#------------------------Push Column Analysis to Snowflake-------------------------------------

def upload_column_analysis_to_snowflake ():
    """Merge the latest column analysis (``test_results_df``) into the Snowflake analysis table.

    Steps:
      1. Write ``test_results_df`` to the scratch table ``"temp_test_results"``.
      2. TEMP STEP: drop every target column except those in ``keep_list``,
         then add a STRING column for each result column the target lacks.
      3. MERGE on COLUMN_NAME. Matched rows take the new values where they are
         non-null (COALESCE keeps the old value otherwise). Unmatched rows are
         inserted.
      4. Read the merged table back, display it, and store it in the global
         ``snowflake_updated_column_analysis``.

    The scratch table is dropped at the end even if a step fails.

    The target table is ``col_analysis_table_name`` (set in get_user_inputs).
    Its columns are read via DESC from the hard-coded
    ``"EDW_PROD"."ANALYTICS"`` schema.
    """
    #__Needed from previous functions__
    #--get_user_inputs--
    global col_analysis_table_name
    #--get_or_create_session--
    global session
    #--run_column_tests--
    global test_results_df

    #__Set for the first time in this function__
    global column_analysis
    global snowflake_updated_column_analysis

    #___function logic starts here___
    # One name for the target everywhere (ALTER, MERGE and SELECT previously
    # disagreed: ALTER used this variable, MERGE/SELECT a hard-coded name).
    target_table_name = col_analysis_table_name

    # write_pandas quotes identifiers by default, so the scratch table is the
    # case-sensitive, lowercase "temp_test_results". Every SQL reference must
    # quote it; an unquoted name resolves to TEMP_TEST_RESULTS instead.
    temp_table_name = 'temp_test_results'
    quoted_temp_table = f'"{temp_table_name}"'

    def get_snowflake_table_columns(session: Session, table_name: str) -> list:
        """Return the upper-cased column names of a table in EDW_PROD.ANALYTICS."""
        # NOTE(review): database/schema are hard-coded here while the other
        # statements use the session's current database/schema — confirm they match.
        get_columns_SQL = f'DESC TABLE "EDW_PROD"."ANALYTICS"."{table_name}"'
        result = session.sql(get_columns_SQL).collect()
        return [row["name"].upper() for row in result]

    print("-----dropping old temp table, and making new-----")
    print("")
    session.sql(f"DROP TABLE IF EXISTS {quoted_temp_table};").collect()

    try:
        session.write_pandas(test_results_df, temp_table_name, auto_create_table=True)

        print("#-----add missing columns-----")
        print("")

        # TEMP STEP!!! Remove everything except the identity/mapping columns.
        keep_list = ['COLUMN_NAME', 'NATIVE_DATA_TYPE', 'TEST_DATA_TYPE', 'NATIVE_DATA_CLASS', 'FUNCTIONAL_CATEGORY']
        columns_to_reset = [
            name for name in get_snowflake_table_columns(session, target_table_name)
            if name not in keep_list
        ]
        for column in columns_to_reset:
            session.sql(f"ALTER TABLE {target_table_name} DROP COLUMN {column};").collect()

        # Add every result column (source schema) missing from the target schema.
        source_columns = test_results_df.columns.tolist()
        target_columns = get_snowflake_table_columns(session, target_table_name)
        missing_columns = [name for name in source_columns if name not in target_columns]
        for column in missing_columns:
            column_type = "STRING" #"DATE" if "DATE" in column.upper() else "STRING"
            session.sql(f"ALTER TABLE {target_table_name} ADD COLUMN {column} {column_type};").collect()

        print("#-----perform merge-----")
        print("")
        # Build the update and insert clauses from the source columns.
        update_set = ",\n".join(
            f"target.{name} = COALESCE(source.{name}, target.{name})" for name in source_columns
        )
        insert_columns = ", ".join(source_columns)
        insert_values = ", ".join(f"source.{name}" for name in source_columns)

        merge_sql = f"""
        MERGE INTO {target_table_name} AS target
        USING {quoted_temp_table} AS source
        ON target.column_name = source.column_name
        WHEN MATCHED THEN
            UPDATE SET
            {update_set}
        WHEN NOT MATCHED THEN
            INSERT ({insert_columns})
            VALUES ({insert_values});
        """
        session.sql(merge_sql).collect()

        print("#-----display what was uploaded:-----")
        print("")
        after_pandas_df = session.sql(f"SELECT * FROM {target_table_name}").to_pandas()
        after_pandas_df = after_pandas_df.reindex(columns=test_results_df.columns)
        display(after_pandas_df)

        snowflake_updated_column_analysis = after_pandas_df
    finally:
        print("#------------------------------drop temp table------------------------------")
        # Clean up the scratch table even when a step above raised.
        session.sql(f"DROP TABLE IF EXISTS {quoted_temp_table}").collect()

#upload_column_analysis_to_snowflake()

## Date Sequence Analysis

## MLR / Correlation Analysis

#### - Data Cleaning & Encoding

In [14]:
#--------------------------    Classify Columns for .Corr / ANOVA     ----------------------------

def classify_columns_for_corr_and_anova():
    """Split the analysed columns into numeric and categorical lists for correlation / ANOVA.

    A column is ignored (added to IGNORE_COLUMN_NAMES) when any of these hold:
    it is entirely null; it has no nulls and a single distinct value; it has no
    distinct values; it has no DISTINCT_VALUES result; or its TEST_DATA_TYPE is
    OTHER.

    The remaining columns are assigned by TEST_DATA_TYPE:
    * BINARY columns are numeric.
    * STRING / VARIANT / MapType / GEOGRAPHY columns are categorical.
    * NUMBER / DATE / TIME columns are numeric when they have more than
      ``categorical_corr_anova_threshold`` distinct values (10 when unset),
      and categorical otherwise.

    Sets these globals:
    * NUMERIC_COLUMN_NAMES
    * CATEGORICAL_COLUMN_NAMES
    * CATEGORICAL_COLUMN_NAMES_plus_goal (the categorical list with
      CURRENT_VALUE first)
    * IGNORE_COLUMN_NAMES
    """
    #__Needed from previous functions__
    #--load_data()--
    global enroll_stable_copy
    #--map_data_types_to_categories--
    global schema_mapped_df
    #--run_column_analysis--
    global test_results_df

    #__Set for the first time in this function__
    global categorical_corr_anova_threshold #manually set in user input function
    global main_target_column #placeholder, need to consolidate & replace hardcoding first. already initialized in params cell.
    global NUMERIC_COLUMN_NAMES
    global CATEGORICAL_COLUMN_NAMES
    global CATEGORICAL_COLUMN_NAMES_plus_goal #placeholder, not sure if I'll need. Not currently initialized in params cell.
    global IGNORE_COLUMN_NAMES

    #___function logic starts here___
    # Honour the user-supplied threshold; previously it was silently
    # overwritten with a hard-coded 10, which is now only the fallback.
    default_threshold = 10
    threshold = categorical_corr_anova_threshold if categorical_corr_anova_threshold is not None else default_threshold

    print("Starting...")

    # NULL_PERCENTAGE from run_column_analysis is on a 0-100 scale, so an
    # all-null column has NULL_PERCENTAGE == 100 (the old `== 1` matched 1% nulls).
    remove_mask = (
        (test_results_df['NULL_PERCENTAGE'] == 100) |
        ((test_results_df['NULL_PERCENTAGE'] == 0) & (test_results_df['DISTINCT_COUNT'] == 1)) |
        (test_results_df['DISTINCT_COUNT'] == 0) |
        (test_results_df['DISTINCT_VALUES'].isna())
    )
    test_results_to_remove = test_results_df.loc[
        remove_mask, ['COLUMN_NAME','NULL_PERCENTAGE','DISTINCT_VALUES','DISTINCT_COUNT']
    ]
    IGNORE_COLUMN_NAMES = test_results_to_remove['COLUMN_NAME'].tolist()

    print("Columns removed from testing:")
    display(test_results_to_remove)

    filtered_test_results_df = test_results_df[~remove_mask]

    schema_mapped_df_encoding_copy = schema_mapped_df[~schema_mapped_df['COLUMN_NAME'].isin(IGNORE_COLUMN_NAMES)]

    Category_Mapping_Reference = {
        'ALWAYS_CATEGORICAL': ['STRING', 'VARIANT', 'MapType', 'GEOGRAPHY'],
        'ALWAYS_NUMERIC': ['BINARY'],
        'POTENTIALLY_NUMERIC_ALL' : ['NUMBER', 'DATE', 'TIME'],
        'IGNORE': ['OTHER']
    }

    def column_names_for(category):
        """Names of kept columns whose TEST_DATA_TYPE belongs to the given mapping category."""
        in_category = schema_mapped_df_encoding_copy['TEST_DATA_TYPE'].isin(Category_Mapping_Reference[category])
        return schema_mapped_df_encoding_copy.loc[in_category, 'COLUMN_NAME'].tolist()

    ALWAYS_CATEGORICAL_COLUMN_NAMES = column_names_for('ALWAYS_CATEGORICAL')
    ALWAYS_NUMERIC_COLUMN_NAMES = column_names_for('ALWAYS_NUMERIC')
    POTENTIALLY_NUMERIC_ALL_COLUMN_NAMES = column_names_for('POTENTIALLY_NUMERIC_ALL')
    IGNORE_COLUMN_NAMES += column_names_for('IGNORE')

    # NUMBER/DATE/TIME columns: numeric when DISTINCT_COUNT > threshold, else categorical.
    potentially_numeric_results = filtered_test_results_df[
        filtered_test_results_df["COLUMN_NAME"].isin(POTENTIALLY_NUMERIC_ALL_COLUMN_NAMES)
    ].sort_values(by='DISTINCT_COUNT', ascending=False)
    is_numeric = potentially_numeric_results["DISTINCT_COUNT"] > threshold
    columns_to_add_to_numeric = potentially_numeric_results.loc[is_numeric, "COLUMN_NAME"].tolist()
    columns_to_add_to_categorical = potentially_numeric_results.loc[~is_numeric, "COLUMN_NAME"].tolist()

    NUMERIC_COLUMN_NAMES = ALWAYS_NUMERIC_COLUMN_NAMES + columns_to_add_to_numeric
    CATEGORICAL_COLUMN_NAMES = ALWAYS_CATEGORICAL_COLUMN_NAMES + columns_to_add_to_categorical

    # Goal column first, followed by the categorical columns (no duplicate CURRENT_VALUE).
    CATEGORICAL_COLUMN_NAMES_plus_goal = ['CURRENT_VALUE'] + [
        name for name in CATEGORICAL_COLUMN_NAMES if name != 'CURRENT_VALUE'
    ]

    print("Done!")
    print("")
    print("# of columns in original: ",len(schema_mapped_df)) 
    print("# of columns to test: ",len(NUMERIC_COLUMN_NAMES) + len(CATEGORICAL_COLUMN_NAMES))
    print("# of numeric columns to test: ",len(NUMERIC_COLUMN_NAMES))
    print("# of categorical columns to test: ",len(CATEGORICAL_COLUMN_NAMES))

In [15]:
#--------------------------    Convert Dates to Unix #s for Encoding (snowpark)     ----------------------------

def convert_dates_to_unix_encoding ():    
    """Build ``raw_data_snowpark_copy`` with numeric-classified DATE columns as Unix epoch seconds.

    Starts from ``enroll_stable_copy`` and drops IGNORE_COLUMN_NAMES. Every
    DATE column that appears in NUMERIC_COLUMN_NAMES is then replaced by
    ``unix_timestamp`` of itself.

    Sets the globals DATES_TO_CONVERT_TO_NUMERIC and raw_data_snowpark_copy
    (a lazy Snowpark DataFrame). It then displays a small preview of the
    converted columns.
    """
    print("#--------------------------    Convert Dates to Unix #s for Encoding (snowpark)     ----------------------------")
    #__Needed from previous functions__
    #--load_data--
    global enroll_stable_copy
    #--map_data_types_to_categories--
    global schema_mapped_df
    #--classify_columns_for_corr_and_anova--
    global NUMERIC_COLUMN_NAMES
    global IGNORE_COLUMN_NAMES
    
    #__Set for the first time in this function__
    global raw_data_snowpark_copy
    global DATES_TO_CONVERT_TO_NUMERIC

    # Local import guards against module-level `col` being shadowed by another cell.
    from snowflake.snowpark.functions import unix_timestamp, col

    #___function logic starts here___
    DATES_TO_CONVERT_TO_NUMERIC = schema_mapped_df[
        (schema_mapped_df['COLUMN_NAME'].isin(NUMERIC_COLUMN_NAMES)) &
        (schema_mapped_df['TEST_DATA_TYPE'] == 'DATE')
    ]['COLUMN_NAME'].tolist()

    # Snowpark's DataFrame.drop() rejects an empty argument list, so only drop when needed.
    if IGNORE_COLUMN_NAMES:
        raw_data_snowpark_copy = enroll_stable_copy.drop(*IGNORE_COLUMN_NAMES)
    else:
        raw_data_snowpark_copy = enroll_stable_copy

    print("Convert date columns to Unix time in seconds (epoch time)")
    # Only the DATE columns classified as numeric are converted.
    for column_name in DATES_TO_CONVERT_TO_NUMERIC:
        raw_data_snowpark_copy = raw_data_snowpark_copy.withColumn(column_name, unix_timestamp(col(column_name)))

    # Preview a few rows of just the converted columns; the previous version
    # pulled the entire table into pandas before calling .head().
    if DATES_TO_CONVERT_TO_NUMERIC:
        display(raw_data_snowpark_copy.select(DATES_TO_CONVERT_TO_NUMERIC).limit(5).to_pandas())
    else:
        print("No date columns to convert.")

In [16]:
#--------------------------    Sort Category Values by Avg Goal/Value Metric     ----------------------------

def sort_category_values_by_goal_metric():
    """Rank each categorical column's values by the average ``current_value`` of their rows.

    For every column in CATEGORICAL_COLUMN_NAMES, the function groups
    ``raw_data_snowpark_copy`` by that column and averages ``current_value``
    per group. Groups are then ordered by that average ASCENDING, so rank 0 is
    the category with the lowest average.

    Sets the global ``sorted_category_averages`` to
    ``{column: {category_value: rank, ...}, ...}`` and prints one sample entry.
    """
    print("#--------------------------    Sort Category Values by Avg Goal/Value Metric     ----------------------------")
    #__Needed from previous functions__
    #--classify_columns_for_corr_and_anova--
    global CATEGORICAL_COLUMN_NAMES
    #--convert_dates_to_unix_encoding--
    global raw_data_snowpark_copy
    
    #__Set for the first time in this function__
    global sorted_category_averages

    #___function logic starts here___
    def dynamic_sort_by_avg_current_value(df, categorical_columns):
        """Return {column: {category_value: rank}}, ranks ordered by ascending average current_value."""
        total_columns = len(categorical_columns)
        final_output_with_ranking = {}

        for index, column in enumerate(categorical_columns, start=1):
            if index % 10 == 0:
                print(f"{index}/{total_columns}")

            # Average CURRENT_VALUE per category of this column
            df_grouped_avg = df.groupBy(col(column)).agg(avg("current_value").alias("avg_current_value"))

            # Ascending: rank 0 is the category with the LOWEST average
            df_sorted_avg = df_grouped_avg.sort(col("avg_current_value").asc())

            # Collect once (the same query used to run twice per column)
            sorted_values = df_sorted_avg.collect()

            # {category_value: rank, ...}
            final_output_with_ranking[column] = {row[column]: rank for rank, row in enumerate(sorted_values)}

        return final_output_with_ranking

    sorted_category_averages = dynamic_sort_by_avg_current_value(raw_data_snowpark_copy, CATEGORICAL_COLUMN_NAMES)

    print("sample output:")
    if sorted_category_averages:
        print(sorted_category_averages[next(iter(sorted_category_averages))])
    else:
        print("No categorical columns to rank.")


In [17]:
#--------------------------    Convert Category Values to their Encoding     ----------------------------

def convert_category_values_to_encoding():
    """Build the two pandas frames used by the correlation and regression steps.

    Reads globals:
        raw_data_snowpark_copy: Snowpark DataFrame converted to pandas to form ``encoded_df``.
        sorted_category_averages: {column: {category_value: rank}} used to replace
            category values with their rank.
        enroll_stable_copy: Snowpark DataFrame used to build ``binary_encoded_df``.

    Sets globals:
        encoded_df: pandas copy of ``raw_data_snowpark_copy`` in which every column
            listed in ``sorted_category_averages`` (except CURRENT_VALUE) has its
            values replaced by their rank; values missing from a mapping keep their
            original value.
        binary_encoded_df: pandas frame with every column of ``enroll_stable_copy``
            except CURRENT_VALUE replaced by 1 (non-null) / 0 (null); CURRENT_VALUE
            is kept unchanged.

    Raises:
        Re-raises any error from converting ``raw_data_snowpark_copy`` to pandas.
    """
    #__Needed from previous functions__
    #--load_data--
    global enroll_stable_copy
    #--classify_columns_for_corr_and_anova--
    global CATEGORICAL_COLUMN_NAMES
    #--convert_dates_to_unix_encoding--
    global raw_data_snowpark_copy
    #--sort_category_values_by_goal_metric--
    global sorted_category_averages

    #__Set for the first time in this function__
    global encoded_df
    global binary_encoded_df

    target_column = 'CURRENT_VALUE'

    #___function logic starts here___
    # Convert the Snowflake DataFrame to a Pandas DataFrame for processing
    try:
        encoded_df = raw_data_snowpark_copy.toPandas()
    except Exception as e:
        print(f"Error converting Snowflake DataFrame to Pandas: {e}")
        # Re-raise so later steps don't run on a missing/stale frame
        raise

    # Replace categories with their ranks, column by column
    total_mappings = len(sorted_category_averages)
    for position, (column, cat_mapping) in enumerate(sorted_category_averages.items(), start=1):
        if position % 10 == 0:
            print(f"{position}/{total_mappings}")
        if column == target_column:
            # The target is intentionally never rank-encoded (it exists; it is skipped)
            print(f"Skipping target column {column}.")
            continue
        if column not in encoded_df.columns:
            print(f"Column {column} not found in DataFrame.")
            continue
        # Unmapped values fall back to their original value
        encoded_df[column] = encoded_df[column].map(cat_mapping).fillna(encoded_df[column])

    # Null/populated indicators: 1 where a value is present, 0 where null; target kept as-is.
    # The loop name `column_name` avoids shadowing the imported Snowpark `col()` function.
    source_df = enroll_stable_copy
    transformed_columns = [
        source_df[column_name]
        if column_name == target_column
        else when(source_df[column_name].is_not_null(), 1).otherwise(0).alias(column_name)
        for column_name in source_df.columns
    ]
    binary_encoded_df = source_df.select(transformed_columns).to_pandas()

    # Display the first few rows to verify
    display(binary_encoded_df.head())

    # TODO: writing the encoded frame back to Snowflake
    # (session.createDataFrame(...).write.mode("overwrite").saveAsTable(...)) was
    # prototyped here and is currently disabled.

#### - Run Corr Analysis (Numeric Conversion)

In [18]:
#--------------------------    Perform Correlation Analysis (Standard Encoded)    ----------------------------

def run_standard_encoded_corr_analysis(output_csv_path=r'C:\Users\miwilliams\Downloads\pandas_corr.csv'):
    """Compute the pairwise Pearson correlation matrix of the rank-encoded data.

    Reads global ``encoded_df`` (from convert_category_values_to_encoding). The
    globals ``schema_mapped_df``, ``NUMERIC_COLUMN_NAMES``,
    ``DATES_TO_CONVERT_TO_NUMERIC`` and ``CATEGORICAL_COLUMN_NAMES`` are only used
    for the summary printout.

    Sets global ``standard_pearson_corr_results``: the correlation matrix with
    all-NaN rows/columns dropped and the row labels moved into a leading
    ``COLUMN_NAME`` column. The result is displayed and optionally saved as CSV.

    Args:
        output_csv_path: CSV destination. Defaults to the original hard-coded local
            path for backward compatibility; pass a project-relative path instead,
            or None to skip writing the file.
    """
    #__Needed from previous functions__
    #--map_data_types_to_categories--
    global schema_mapped_df
    #--classify_columns_for_corr_and_anova--
    global CATEGORICAL_COLUMN_NAMES
    global NUMERIC_COLUMN_NAMES
    global DATES_TO_CONVERT_TO_NUMERIC
    #--convert_category_values_to_encoding--
    global encoded_df

    #__Set for the first time in this function__
    global standard_pearson_corr_results

    #___function logic starts here___
    print('starting...')
    start_time = datetime.now()
    print(start_time)

    def log_elapsed():
        """Print the current timestamp and the seconds elapsed since start_time."""
        now = datetime.now()
        print(now)
        print((now - start_time).total_seconds())

    print("# of columns in original: ", len(schema_mapped_df))
    print("# of columns to test: ", len(NUMERIC_COLUMN_NAMES) + len(CATEGORICAL_COLUMN_NAMES))
    print("# of numeric columns to test: ", len(NUMERIC_COLUMN_NAMES))
    print("# of numeric (date) columns to test: ", len(DATES_TO_CONVERT_TO_NUMERIC))
    print("# of categorical columns to test: ", len(CATEGORICAL_COLUMN_NAMES))

    print("running corr on pandas df...")
    corr_matrix = encoded_df.corr(method='pearson', numeric_only=False)
    log_elapsed()

    print("update structure")
    # Drop features whose correlations are all NaN (e.g. constant or fully-null
    # columns), then move the row labels into a COLUMN_NAME column.
    tidy_corr = (
        corr_matrix
        .dropna(axis=1, how='all')
        .dropna(axis=0, how='all')
        .reset_index()
    )
    standard_pearson_corr_results = tidy_corr.rename(columns={tidy_corr.columns[0]: 'COLUMN_NAME'})
    log_elapsed()

    print("displaying the results...")
    display(standard_pearson_corr_results)
    log_elapsed()

    if output_csv_path is not None:
        print("saving to csv...")
        standard_pearson_corr_results.to_csv(output_csv_path, index=False)
        log_elapsed()

    # NOTE: a Snowflake-side alternative (write_pandas + snowflake.ml.modeling.metrics
    # .correlation) was explored here and is currently disabled.
    print("done!")

#### - Run Corr Analysis (Binary Conversion - Null/Populated)

In [19]:
#--------------------------    Perform Correlation Analysis (Binary)    ----------------------------

def run_binary_encoded_corr_analysis(output_csv_path=r'C:\Users\miwilliams\Downloads\pandas_corr_binary.csv'):
    """Compute the Pearson correlation matrix of the null/populated (1/0) encoding.

    Reads global ``binary_encoded_df`` (from convert_category_values_to_encoding).

    Sets global ``binary_pearson_corr_results``: the correlation matrix with
    all-NaN rows/columns dropped and the row labels moved into a leading
    ``COLUMN_NAME`` column. The result is displayed and optionally saved as CSV.

    Args:
        output_csv_path: CSV destination. Defaults to the original hard-coded local
            path for backward compatibility; pass a project-relative path instead,
            or None to skip writing the file.
    """
    #__Needed from previous functions__
    #--convert_category_values_to_encoding--
    global binary_encoded_df

    #__Set for the first time in this function__
    global binary_pearson_corr_results

    #___function logic starts here___
    print('displaying binary_encoded_df...')
    display(binary_encoded_df.head())

    start_time = datetime.now()
    print(start_time)

    def log_elapsed():
        """Print the current timestamp and the seconds elapsed since start_time."""
        now = datetime.now()
        print(now)
        print((now - start_time).total_seconds())

    print("running corr on pandas df...")
    corr_matrix = binary_encoded_df.corr(method='pearson', numeric_only=False)
    log_elapsed()

    print("update structure")
    # Columns that are always (or never) populated are constant, so their
    # correlations are NaN; drop them, then move row labels into COLUMN_NAME.
    tidy_corr = (
        corr_matrix
        .dropna(axis=1, how='all')
        .dropna(axis=0, how='all')
        .reset_index()
    )
    binary_pearson_corr_results = tidy_corr.rename(columns={tidy_corr.columns[0]: 'COLUMN_NAME'})
    log_elapsed()

    print("displaying the results...")
    display(binary_pearson_corr_results)
    log_elapsed()

    if output_csv_path is not None:
        print("saving to csv...")
        binary_pearson_corr_results.to_csv(output_csv_path, index=False)
        log_elapsed()

#### - Impute Nulls

In [20]:
#--------------------------    Impute Nulls    ----------------------------

imputed_df = None
def impute_nulls():
    """Combine rank-encoded values with null/populated indicators and mean-impute nulls.

    Reads globals:
        encoded_df: rank-encoded pandas frame (from convert_category_values_to_encoding).
        binary_encoded_df: 1/0 populated-indicator frame from the same step; its
            CURRENT_VALUE column is the untransformed target.
        DATES_TO_CONVERT_TO_NUMERIC: raw unix-date columns removed from encoded_df.

    Sets global ``imputed_df``: encoded_df (minus the date columns) side by side with
    binary_encoded_df's indicator columns (suffixed ``_binary``; CURRENT_VALUE is
    excluded so the target is not duplicated into the features). Every remaining
    NaN is replaced by its column mean; columns that are entirely null are dropped
    because no mean exists for them.

    Raises:
        ValueError: if the two frames have different row counts, because they are
            combined by row position.
    """
    #__Needed from previous functions__
    #--convert_category_values_to_encoding--
    global encoded_df
    global binary_encoded_df

    #__Set for the first time in this function__
    global imputed_df

    #___function logic starts here___
    print("impute_nulls running...")
    #from sklearn.impute import KNNImputer <- Dope, but takes too long
    from sklearn.impute import SimpleImputer

    target_column = 'CURRENT_VALUE'

    # Removing all raw unix date columns from the encoded_df
    encoded_df_no_dates = encoded_df.drop(columns=DATES_TO_CONVERT_TO_NUMERIC)

    # Null/populated indicators with a '_binary' suffix. binary_encoded_df keeps the
    # raw target, so it is dropped here to avoid leaking CURRENT_VALUE into X.
    binary_indicators = (
        binary_encoded_df
        .drop(columns=[target_column], errors='ignore')
        .add_suffix('_binary')
    )

    if len(encoded_df_no_dates) != len(binary_indicators):
        raise ValueError(
            f"encoded_df has {len(encoded_df_no_dates)} rows but binary_encoded_df has "
            f"{len(binary_indicators)}; they must describe the same rows in the same order."
        )

    # NOTE(review): the two frames come from separate Snowflake queries; row order is
    # presumably the same but is not guaranteed without an ORDER BY on a shared key — confirm.
    combined_encoded_df = pd.concat(
        [encoded_df_no_dates.reset_index(drop=True), binary_indicators.reset_index(drop=True)],
        axis=1,
    )

    # SimpleImputer discards all-NaN columns, which would desynchronize the column
    # labels used to rebuild the DataFrame below, so drop them explicitly first.
    all_null_columns = combined_encoded_df.columns[combined_encoded_df.isna().all()]
    if len(all_null_columns) > 0:
        print(f"Dropping {len(all_null_columns)} all-null column(s): {list(all_null_columns)}")
        combined_encoded_df = combined_encoded_df.drop(columns=all_null_columns)

    imputed_data = SimpleImputer(strategy='mean').fit_transform(combined_encoded_df)  # alternatives: median, most_frequent
    imputed_df = pd.DataFrame(imputed_data, columns=combined_encoded_df.columns)

    # Display the imputed DataFrame
    print("Imputed DataFrame:")
    display(imputed_df.head())
#impute_nulls()

In [21]:
#--------------------------    Calculate MLR    ----------------------------

results_df = {}  # legacy name kept for compatibility; fit results live in results_dict
results_dict = {}
best_model_details = None
X_imputed_df = None
def calculate_MLR():
    """Fit Lasso regressions that predict CURRENT_VALUE from the imputed features.

    Reads global ``imputed_df`` (from impute_nulls). The features ``X`` are every
    column except ``CURRENT_VALUE`` and any ``'estimated goal'`` prediction column
    left behind by a previous extract_model() run; the target ``y`` is CURRENT_VALUE.

    Each iteration fits two configurations: "optimal" (alpha=0.1, max_iter=500) and
    "conservative" (alpha=1, max_iter=1000), both with tol=1e-4.

    Sets globals:
        X_imputed_df: the feature matrix used for fitting.
        results_dict: {iteration: {'optimal': record, 'conservative': record}} where a
            record has keys converged, model, r_squared, train_time, duality_gap and
            num_sparse_coefs (the number of NON-zero coefficients).
        best_model_details: record of the best converged fit (highest R^2, ties broken
            by fewer non-zero coefficients). If no fit converged, the best
            non-converged fit is used and a warning is printed; None if every fit raised.
    """
    #__Needed from previous functions__
    #--impute_nulls--
    global imputed_df

    #__Set for the first time in this function__
    global results_dict
    global best_model_details
    global X_imputed_df

    #___function logic starts here___
    import warnings
    from sklearn.linear_model import Lasso
    from sklearn.exceptions import ConvergenceWarning

    # Split the data into features and target
    print('Preparing data for modeling...')
    X = imputed_df.drop(columns=['CURRENT_VALUE']).drop(columns=['estimated goal'], errors='ignore')
    y = imputed_df['CURRENT_VALUE']
    X_imputed_df = X

    def fit_lasso(label, alpha, max_iter, tol):
        """Fit one Lasso model on (X, y) and return its result record."""
        record = {'converged': False, 'model': None, 'r_squared': None,
                  'train_time': None, 'duality_gap': None, 'num_sparse_coefs': None}
        start_time = time()
        # ConvergenceWarning is issued through warnings.warn, never raised, so it must
        # be read from the recorded warnings; an `except ConvergenceWarning` never fires.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always", ConvergenceWarning)
            try:
                model = Lasso(alpha=alpha, max_iter=max_iter, tol=tol).fit(X, y)
            except Exception as e:
                print(f"An unexpected error occurred: {e}")
                record['train_time'] = time() - start_time
                return record
        converged = not any(issubclass(w.category, ConvergenceWarning) for w in caught)
        record.update({
            'converged': converged,
            'model': model,
            'r_squared': model.score(X, y),
            'num_sparse_coefs': int(np.sum(model.coef_ != 0)),
            'duality_gap': model.dual_gap_,
            'train_time': time() - start_time,
        })
        if converged:
            print(f"{label} converged successfully")
        else:
            print(f"{label} Lasso failed to converge (duality gap: {model.dual_gap_})")
        return record

    def is_better(candidate, incumbent):
        """True if candidate has higher R^2, or equal R^2 with fewer non-zero coefficients."""
        if incumbent is None:
            return True
        if candidate['r_squared'] != incumbent['r_squared']:
            return candidate['r_squared'] > incumbent['r_squared']
        return candidate['num_sparse_coefs'] < incumbent['num_sparse_coefs']

    lasso_configs = {
        'optimal': ('Optimal', {'alpha': 0.1, 'max_iter': 500, 'tol': 0.0001}),
        'conservative': ('Conservative', {'alpha': 1, 'max_iter': 1000, 'tol': 0.0001}),
    }

    results_dict = {}
    best_model_details = None
    best_unconverged = None

    # NOTE: parameters are not adapted between iterations yet, so iteration 2 repeats
    # iteration 1's fits. TODO: re-enable refinement (shrink alpha / raise max_iter for
    # fits that fail to converge, tighten tol when the duality gap is large, and stop
    # once both fits reach R^2 >= 0.75).
    for iteration in range(1, 3):  # Iteration count starts at 1
        print(f"\nIteration {iteration}:")
        attempt_results = {
            model_type: fit_lasso(label, **params)
            for model_type, (label, params) in lasso_configs.items()
        }
        results_dict[iteration] = attempt_results

        for details in attempt_results.values():
            if details['model'] is None:
                continue  # the fit raised; nothing to compare
            if details['converged']:
                if is_better(details, best_model_details):
                    best_model_details = details
            elif is_better(details, best_unconverged):
                best_unconverged = details

    if best_model_details is None and best_unconverged is not None:
        print("Warning: no Lasso fit converged; using the best non-converged fit.")
        best_model_details = best_unconverged
#calculate_MLR()

In [22]:
#--------------------------    Extract Model    ----------------------------


def extract_model():
    """Report, plot and apply the best Lasso model selected by calculate_MLR().

    Reads globals:
        best_model_details: result record of the selected fit (its 'model' key holds
            the fitted estimator).
        results_dict: per-iteration fit records, summarized in a table.
        X_imputed_df: the feature matrix the model was fit on.
        sorted_category_averages: {column: {category: rank}} encoding; the entries for
            features that kept a non-zero coefficient are counted in the report.
        imputed_df: receives the model predictions as a new 'estimated goal' column.

    Raises:
        RuntimeError: if no fitted model is available (calculate_MLR() not run or
            every fit failed).
    """
    global sorted_category_averages
    global X_imputed_df
    global best_model_details

    if best_model_details is None or best_model_details.get('model') is None:
        raise RuntimeError("No fitted model available; run calculate_MLR() first.")

    # One row per (iteration, model type); the estimator objects are left out of the table
    fit_summary = pd.DataFrame([
        {'iteration': iteration, 'type': model_type,
         **{key: value for key, value in details.items() if key != 'model'}}
        for iteration, attempts in globals().get('results_dict', {}).items()
        for model_type, details in attempts.items()
    ])
    display(fit_summary)

    print("best_model_details")
    print(best_model_details)

    best_model = best_model_details['model']

    print("Intercept:")
    print(best_model.intercept_)

    # Coefficients keyed by feature name, largest first
    feature_coefficients = pd.Series(best_model.coef_, index=X_imputed_df.columns).sort_values(ascending=False)
    non_zero_coefficients = feature_coefficients[feature_coefficients != 0]
    print(f"{len(non_zero_coefficients)} of {len(feature_coefficients)} features kept a non-zero coefficient.")

    # Rank encodings for the categorical features the model still uses
    sparse_category_cipher = {
        column: mapping for column, mapping in sorted_category_averages.items()
        if column in non_zero_coefficients.index
    }
    print(f"{len(sparse_category_cipher)} of those features have a category rank encoding.")

    # Plot only surviving features (Lasso zeroes most of them), largest at the top
    fig, ax = plt.subplots(figsize=(10, max(4, 0.25 * len(non_zero_coefficients))))
    ax.barh(non_zero_coefficients.index.astype(str), non_zero_coefficients.values)
    ax.invert_yaxis()
    ax.set_xlabel("Coefficient Value")
    ax.set_ylabel("Feature Name")
    ax.set_title("Lasso Coefficients (non-zero)")
    plt.show()

    # Predict using the best model
    estimated_goals = best_model.predict(X_imputed_df)

    # NOTE: this mutates the shared imputed_df; any feature matrix later built from it
    # must exclude the 'estimated goal' column to avoid feeding predictions back in.
    imputed_df['estimated goal'] = estimated_goals

    # NOTE(review): showall is not defined in this part of the notebook; presumably a
    # display helper defined in an earlier cell — confirm.
    showall(imputed_df[['estimated goal', 'CURRENT_VALUE']])
#extract_model()

In [23]:




    
    
    
    # Placeholder for Step 1.3: Adjusting based on feedback
    # Here you would adjust the alpha values based on the results of the initial fits,
    # potentially in a loop or iteratively adjusting until a satisfactory starting point is found.
    
    # Note: For actual implementation, consider dynamically adjusting the alpha (and other parameters) 
    # based on performance metrics and convergence status.









    
    # Using sklearn for Elastic Net Regression
    #print('10')
    #model_sk_enet = ElasticNet(alpha=0.1, l1_ratio=0.5, verbose=True).fit(X, y)  # l1_ratio controls the mix of L1 and L2 regularization
    #print('11')
    #r_squared_enet = model_sk_enet.score(X, y)





    #--------- simple, maybe later/for comparison ---------

    # Using statsmodels for MLR
    #print('4')
    #X_sm = sm.add_constant(X)  # adding a constant
    #print('5')
    #model_sm = sm.OLS(y, X_sm).fit()
    
    # Using sklearn for MLR
    #print('6')
    #model_sk = LinearRegression().fit(X, y)
    #print('7')
    #r_squared_sk = model_sk.score(X, y)
    
    # Print the statsmodels summary
    #print("\nStatsmodels Summary:")
    #print(model_sm.summary())
    
    # Print the sklearn R-squared value
    #print("\nSklearn R-squared:", r_squared_sk)

## Master Controller

#### - Master Controller

In [24]:
#------------------------ Master Controller (Run all Functions) -------------------------------------
def master_controller():
    """Run every stage of the analysis pipeline, strictly in order.

    For each stage this prints the stage banner, runs the stage, then prints
    "--complete--" followed by a blank line. An exception raised by any stage
    propagates immediately and the remaining stages are not run.

    Jupyter shortcuts:
        Collapse all cells:   Ctrl + Shift + Left Arrow
        Uncollapse all cells: Ctrl + Shift + Right Arrow
    """
    print("Master Controller Starting..")

    # (banner, action) pairs. Each action is a lambda so the stage function's
    # global name is looked up only when that stage starts, just like a direct call.
    # A `clear_output(wait=False)` call could be added after any stage if desired.
    pipeline = [
        # General:
        ("--Getting user inputs/creating session--",
            lambda: (get_user_inputs(), get_or_create_session())),
        ("Converting Date Fields...", lambda: convert_potential_date_fields()),
        ("--Adding Normalized Dates--", lambda: add_normalized_dates()),
        ("--Adding date sequences--", lambda: add_date_sequences()),
        # Alternative: load_data(columns_to_load = None, num_rows = 5000)
        ("--loading data--", lambda: load_data()),
        ("--map_data_types_to_categories--", lambda: map_data_types_to_categories()),
        ("--run_column_analysis--", lambda: run_column_analysis()),
        # TODO: apply the "Fix true/false function" to the categorical values upstream, instead of just the count
        ("--uploading analysis to snowflake--", lambda: upload_column_analysis_to_snowflake()),
        # Encoding:
        ("--classify columns for corr/anova--", lambda: classify_columns_for_corr_and_anova()),
        ("--convert dates to unix encoding--", lambda: convert_dates_to_unix_encoding()),
        ("--sort category values by goal metric--", lambda: sort_category_values_by_goal_metric()),
        ("--convert category values to encoding--", lambda: convert_category_values_to_encoding()),
        # Correlation analysis:
        ("--run standard encoded corr analysis--", lambda: run_standard_encoded_corr_analysis()),
        ("--run binary encoded corr analysis--", lambda: run_binary_encoded_corr_analysis()),
        ("--impute nulls--", lambda: impute_nulls()),
        ("--calculate MLR--", lambda: calculate_MLR()),
        # Category analysis: not yet part of the pipeline
    ]

    for banner, run_stage in pipeline:
        print(banner)
        run_stage()
        print("--complete--")
        print("")

    print("Master Controller Complete.")

master_controller()

Master Controller Starting..
--Getting user inputs/creating session--
no existing session, creating new...

snowpark API response...
Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://aledade.okta.com/app/snowflake/<redacted>/sso/saml?SAMLRequest=<redacted>

Unnamed: 0,DOCUMENT_COMPLETION_STATUS,OPERATIONAL_ACO,MARKET,CURRENT_PRACTICE,PRAC_PRIMARY_FACILITY_TYPE,SECONDARY_REFERRAL_SOURCE,CONTRACT_PRAC_PRIMARY_FACILITY_TYPE,REFERRAL_SOURCE,TARGETED_CONTRACT_MSSP_TRACK_GROUPED,ACO_DISPLAY_NAME,...,TARGETED_PT_CONTRACT_TYPE,CURRENT_ENGAGEMENT_STATUS,CONTRACTED_ENTITY,IRIS_UNABLE_TO_SCHEDULE_REASON,SECOND_WAVE_ENGAGEMENT_DATE,PRAC_OR_LOCATION,REFER_TO_PARTNER_BY,ORDERED_DATE_COL_NAMES,ORDERED_DATES,ORDERED_NORMALIZED_NUMS
0,,KS MSSP Enhanced ACO,Kansas-Oklahoma,238,,,,PARTNER_ALGORITHM,MA,KS MSSP Enhanced (A2916),...,MA,TAGGED_FOR_CACP,,,,,prac_name,[],[],[]
1,,MD Chesapeake ACO,Maryland,1548,FQHC,,FQHC,DIRECT_REFERRAL,MSSP ENHANCED,Chesapeake MSSP (A5003),...,MSSP,PATIENT_DECLINED_SERVICE_TO_IRIS,205860113.0,,,West Cecil Health Center Inc (6261),prac_name,[],[],[]
2,,WA ACO,Washington-Oregon,1756,,,,PARTNER_ALGORITHM,NO ATTRIBUTION,CA MSSP 2021 Enhanced (A4777),...,,TAGGED_FOR_CACP,,,,,prac_name,[],[],[]
3,,OH-PA ACO,Ohio,902,Private Practice,,Private Practice,PARTNER_ALGORITHM,MSSP ENHANCED,PA MSSP Legacy + Gateway Enhanced (A3457),...,MSSP,PATIENT_DECLINED_SERVICE_TO_IRIS,141965232.0,,,Adult Geriatrics of Wooster (902),prac_name,[],[],[]
4,,NC East ACO,North Carolina,942,Private Practice,,Private Practice,PARTNER_ALGORITHM,MA,Non-MSSP,...,MA,IRIS_OUTREACH_IN_PROGRESS,834172038.0,,2024-04-11,Wolinsky Primary Care (942),prac_name,[],[],[]


done, pushing to snowflake
pushed to snowflake
--complete--

--loading data--
Number of rows: 429299
Number of columns: 90
--complete--

--map_data_types_to_categories--
upper-casing pandas columns before transform

Mapping is:


Unnamed: 0,COLUMN_NAME,NATIVE_DATA_TYPE,NATIVE_DATA_CLASS,TEST_DATA_TYPE
0,DOCUMENT_COMPLETION_STATUS,StringType(16777216),StringType,STRING
1,OPERATIONAL_ACO,StringType(16777216),StringType,STRING
2,MARKET,StringType(16777216),StringType,STRING
3,CURRENT_PRACTICE,StringType(16777216),StringType,STRING
4,PRAC_PRIMARY_FACILITY_TYPE,StringType(16777216),StringType,STRING
5,SECONDARY_REFERRAL_SOURCE,StringType(16777216),StringType,STRING
6,CONTRACT_PRAC_PRIMARY_FACILITY_TYPE,StringType(16777216),StringType,STRING
7,REFERRAL_SOURCE,StringType(16777216),StringType,STRING
8,TARGETED_CONTRACT_MSSP_TRACK_GROUPED,StringType(19),StringType,STRING
9,ACO_DISPLAY_NAME,StringType(16777216),StringType,STRING


--complete--

--run_column_analysis--
No subset provided - using original raw snowflake data
----------starting testing----------
Copy schema_mapped_df with COLUMN_NAME as index for results compilation
No columns for test_data_type 'DATE' found.
No columns for test_data_type 'NUMBER' found.
No columns for test_data_type 'BINARY' found.
No columns for test_data_type 'VARIANT' found.
No columns for test_data_type 'GEOGRAPHY' found.
No columns for test_data_type 'CATEGORICAL' found.
Testing data type:  GENERAL
raw_data_subset:
   NULL_PERCENTAGE...
   DISTINCT_VALUES...
distinct done.
Testing data type:  STRING
   STRING_AVG_LENGTH...
   STRING_MAX_LENGTH...
   STRING_MIN_LENGTH...


Unnamed: 0,COLUMN_NAME,NATIVE_DATA_TYPE,NATIVE_DATA_CLASS,TEST_DATA_TYPE,NULL_PERCENTAGE,DISTINCT_VALUES,DISTINCT_COUNT,STRING_AVG_LENGTH,STRING_MAX_LENGTH,STRING_MIN_LENGTH
0,DOCUMENT_COMPLETION_STATUS,StringType(16777216),StringType,STRING,100.0000,[],0,,,
1,OPERATIONAL_ACO,StringType(16777216),StringType,STRING,3.6513,"[""TN Opportunity ACO"", ""NJ Multi-State--AAC 10...",124,16.521353,41.0,6.0
2,MARKET,StringType(16777216),StringType,STRING,3.6816,"[""Kentucky"", ""Alabama"", ""Upper Midwest (IL, IN...",31,11.334215,36.0,4.0
3,CURRENT_PRACTICE,StringType(16777216),StringType,STRING,0.0002,"[""2174"", ""2247"", ""512"", ""1228"", ""140"", ""1421"",...",1965,3.328310,4.0,1.0
4,PRAC_PRIMARY_FACILITY_TYPE,StringType(16777216),StringType,STRING,11.2879,"[""CAH"", "" Unknown Facility Type"", ""Provider-ba...",6,14.300572,22.0,3.0
...,...,...,...,...,...,...,...,...,...,...
85,CONTRACTED_ENTITY,StringType(16777216),StringType,STRING,36.5251,"[""956419205"", ""522010253"", ""800682402"", ""56099...",1853,9.000000,9.0,9.0
86,IRIS_UNABLE_TO_SCHEDULE_REASON,StringType(16777216),StringType,STRING,81.0393,"[""NO_LONGER_WITH_PHYSICIAN"", ""DECEASED"", ""PATI...",8,12.330266,38.0,8.0
87,SECOND_WAVE_ENGAGEMENT_DATE,StringType(16777216),StringType,STRING,88.0370,"[""2022-09-12"", ""2023-03-31"", ""2024-02-16"", ""20...",41,10.000000,10.0,10.0
88,PRAC_OR_LOCATION,StringType(16777216),StringType,STRING,11.2879,"[""Dundalk (6156)"", ""Bear River Medical Arts PC...",4400,31.353836,90.0,10.0


--complete--

--uploading analysis to snowflake--
#-----get data from before merge-----

-----dropping old temp table, and making new-----

#-----add missing columns-----

#-----perform merge-----

#-----display what was uploaded:-----



Unnamed: 0,COLUMN_NAME,NATIVE_DATA_TYPE,NATIVE_DATA_CLASS,TEST_DATA_TYPE,NULL_PERCENTAGE,DISTINCT_VALUES,DISTINCT_COUNT,STRING_AVG_LENGTH,STRING_MAX_LENGTH,STRING_MIN_LENGTH
0,OPERATIONAL_MARKET_NAME_NORMALIZED,LongType(),LongType,NUMBER,,,,,,
1,CONTRACT_AGREEMENT_NORMALIZED,LongType(),LongType,NUMBER,,,,,,
2,START_OF_1ST_REGULAR____VIDA_ID_NORMALIZED,LongType(),LongType,NUMBER,,,,,,
3,SECOND_WAVE_ENGAGEMENT_DATE_NORMALIZED,LongType(),LongType,NUMBER,,,,,,
4,IS_CURRENTLY_HIPRI_NORMALIZED,LongType(),LongType,NUMBER,,,,,,
...,...,...,...,...,...,...,...,...,...,...
208,PT_CONTRACT_PAYER,StringType(16777216),StringType,STRING,11.2879,"[""Allwell MA"", ""AmeriHealth DE"", ""UHC"", ""BCBS ...",26.00000,7.647881,20,3
209,CONTRACT_RISK_TYPE,StringType(16777216),StringType,STRING,11.2879,"["""", ""1SR"", ""2SR""]",3.00000,2.127878,3,0
210,CONTRACT_MSSP_TRACK,StringType(16777216),StringType,STRING,11.2879,"[""Basic E"", ""Enhanced"", """", ""Global""]",4.00000,5.146791,8,0
211,MEMBER_ID,StringType(16777216),StringType,STRING,11.441,"[""<redacted>"", ""<redacted>"", ""<redacted>...",380177.00000,11.070813,16,5


#------------------------------drop temp table------------------------------
--complete--

--classify columns for corr/anova--
Starting...
Columns removed from testing:


Unnamed: 0,COLUMN_NAME,NULL_PERCENTAGE,DISTINCT_VALUES,DISTINCT_COUNT
0,DOCUMENT_COMPLETION_STATUS,100.0,[],0


Done!

# of columns in original:  90
# of columns to test:  89
# of numeric columns to test:  0
# of categorical columns to test:  89
--complete--

--convert dates to unix encoding--
#--------------------------    Convert Dates to Unix #s for Encoding (snowpark)     ----------------------------
Convert date columns to Unix time in seconds (epoch time)


0
1
2
3
4


--complete--

--sort category values by goal metric--
#--------------------------    Sort Category Values by Avg Goal/Value Metric     ----------------------------


SnowparkSQLException: (1304): 01b4599b-0604-bff1-0014-1f8374606a0e: 000904 (42000): SQL compilation error: error line 1 at position 30
invalid identifier 'CURRENT_VALUE'

## References

#### Parking Lot

#### Documentation/To-Do-List

In [None]:
# ------------------------Documentation-------------------------------------
# (Written as plain comments instead of a discarded string literal.)
#
# ----List of functions (in the order master_controller() runs them) ----
# General:
#     get_user_inputs
#     get_or_create_session
# Date preparation:
#     convert_potential_date_fields
#     add_normalized_dates
#     add_date_sequences
# Loading:
#     load_data
# Column analysis:
#     map_data_types_to_categories
#     run_column_analysis
#     upload_column_analysis_to_snowflake   (the parameter cell labels its globals "--upload_column_tests_to_snowflake--")
# Encoding:
#     classify_columns_for_corr_and_anova
#     convert_dates_to_unix_encoding
#     sort_category_values_by_goal_metric
#     convert_category_values_to_encoding
# Correlation / regression:
#     run_standard_encoded_corr_analysis
#     run_binary_encoded_corr_analysis
#     impute_nulls
#     calculate_MLR
# Master Controller:
#     master_controller
# Defined in this notebook but not called by master_controller:
#     categorical_analysis
#     jupyter_notebook_auto_documentation
#
#
# ---- Structure/template... ----
# Initialize Parameters
#     This section initializes all parameters that will be used.
#     Parameter = a variable that will be referenced/updated outside of the function that sets its initial value.
#     Below is a list of all functions in this script, and which variables each function sets.
#
#     Structured like so...
#     #--name_of_function--
#     #1st_param_it_sets
#     #2nd_param_it_sets
#     #...
#
# Functions
#     All functions are structured like so...
#
#     #------------------------func_name-------------------------------------
#
#     def func_name():
#         #__Needed from previous functions__
#         #--name_of_function_1--
#         #1st_param
#         #2nd_param
#
#         #--name_of_function_2--
#         #1st_param
#         #2nd_param
#
#         #...
#
#         #__Set for the first time in this function__
#         #--name_of_function_1--
#         #1st_param
#         #2nd_param
#
#         #--name_of_function_2--
#         #1st_param
#         #2nd_param
#
#         #...
#
#         #___function logic starts here___
#         #function logic...

In [None]:
#------------------------ Jupyter Notebook Auto-Documentation (MUST BE AT BOTTOM!)-------------------------------------
import os
import nbformat
import ast
import inspect
import re
import json
from collections import defaultdict
from IPython.core.getipython import get_ipython
from datetime import datetime

def _parse_cell_source(source, ipython=None):
    """Parse one code cell's source into an ``ast.Module``; return None if that is impossible.

    Cells that use IPython-only syntax (``%magic``, ``!shell``) are not valid
    Python. When plain parsing fails and the shell offers ``transform_cell``,
    the source is translated to plain Python by IPython and parsed again.
    """
    try:
        return ast.parse(source)
    except SyntaxError:
        pass
    transform_cell = getattr(ipython, "transform_cell", None)
    if transform_cell is not None:
        try:
            return ast.parse(transform_cell(source))
        except SyntaxError:
            pass
    return None


def _module_level_assign_names(tree):
    """Return the names bound by ``name = ...`` statements outside any def/class body.

    Unlike ``ast.walk``, this does not descend into function, async-function or
    class definitions, so function-local variables are not reported as globals.
    """
    names = set()
    pending = list(ast.iter_child_nodes(tree))
    while pending:
        node = pending.pop()
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            continue
        if isinstance(node, ast.Assign):
            names.update(target.id for target in node.targets if isinstance(target, ast.Name))
        pending.extend(ast.iter_child_nodes(node))
    return names


def _decorator_label(decorator):
    """Return a readable label for a decorator expression node."""
    if isinstance(decorator, ast.Name):
        return decorator.id
    unparse = getattr(ast, "unparse", None)  # ast.unparse exists on Python 3.9+
    return unparse(decorator) if unparse is not None else ast.dump(decorator)


def jupyter_notebook_auto_documentation(notebook_filename='Enroll Rate Data Analysis (Table Analysis).ipynb'):
    """Write three plain-text structure maps describing a notebook's cells.

    Files are written to the current working directory as
    ``<notebook base name>_<MapName>_<dd-mm-YYYY>.txt``:

    * GlobalStructureMap - Python version, imports, assigned variables per scope
      ('global' = assignments outside any def/class), ``peek()`` snapshots,
      and function names grouped by cell execution count.
    * CellStructureMap - per cell: type, execution count, imports, functions
      (with docstrings), assigned names and attribute-style calls.
    * CellInteractionMap - per cell index: classes, async functions, decorators,
      comprehensions and called names.

    Code cells that cannot be parsed (even after IPython input transformation)
    are reported with a printed warning and contribute no AST details.

    Args:
        notebook_filename: notebook file to analyse, resolved relative to the
            current working directory (an absolute path is used as-is).

    Raises:
        FileNotFoundError: if the notebook file does not exist.
        NameError: if ``peek`` (not defined in this cell) is missing from the notebook's globals.
    """
    import platform  # local import: only this function needs it

    import_pattern = r'^(?:import|from)\s+(\S+)'

    # Step 1: Define the notebook path and get the base file name
    notebook_path = os.path.join(os.getcwd(), notebook_filename)
    base_name = os.path.splitext(os.path.basename(notebook_path))[0]

    # Step 2: Load the notebook
    with open(notebook_path, 'r', encoding='utf-8') as f:
        nb = nbformat.read(f, as_version=4)

    # Step 3: Python version of the running kernel. Asking platform directly avoids
    # executing an extra cell through ipython.run_cell (which bumps the execution
    # counter and fails when get_ipython() returns None).
    python_version = platform.python_version()

    # Parse every code cell exactly once; non-code cells carry tree=None.
    ipython = get_ipython()
    parsed_cells = []
    for cell_index, cell in enumerate(nb['cells']):
        tree = None
        if cell['cell_type'] == 'code':
            tree = _parse_cell_source(cell['source'], ipython)
            if tree is None:
                print(f"Warning: could not parse code cell {cell_index}; its AST details are omitted.")
                tree = ast.parse("")
        parsed_cells.append((cell_index, cell, tree))

    # Initialize containers for summary information
    all_imports = set()
    all_variables = defaultdict(set)
    function_listing = defaultdict(list)

    # Get current date for file naming
    update_date = datetime.now().strftime("%d-%m-%Y")
    GlobalStructureMap_file_name = f"{base_name}_GlobalStructureMap_{update_date}.txt"
    CellStructureMap_file_name = f"{base_name}_CellStructureMap_{update_date}.txt"
    CellInteractionMap_file_name = f"{base_name}_CellInteractionMap_{update_date}.txt"

    # Step 4: Parse notebook cells for overview information
    for _, cell, tree in parsed_cells:
        if tree is None:
            continue
        all_imports.update(re.findall(import_pattern, cell['source'], re.MULTILINE))
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                function_listing[cell.get('execution_count', 'N/A')].append(node.name)
                # Variables assigned anywhere inside this function
                for inner_node in ast.walk(node):
                    if isinstance(inner_node, ast.Assign):
                        for target in inner_node.targets:
                            if isinstance(target, ast.Name):
                                all_variables[node.name].add(target.id)
        # Only assignments outside def/class bodies are treated as globals
        module_level_names = _module_level_assign_names(tree)
        if module_level_names:
            all_variables['global'].update(module_level_names)

    # Save summary information to a text file
    with open(GlobalStructureMap_file_name, 'w', encoding='utf-8') as summary_file:
        # Overview section at the beginning
        summary_file.write("Overview of GlobalStructureMap Document:\n")
        summary_file.write("   - Structure: High-level overview including the environment setup.\n")
        summary_file.write("   - Contents: Python version, imports, global variables, functions.\n")
        summary_file.write("   - Use-Case: Quick understanding of dependencies and setup.\n\n")

        # Table of Contents for quick navigation
        summary_file.write("Table of Contents:\n")
        summary_file.write("   1. Python Version\n")
        summary_file.write("   2. Imports\n")
        summary_file.write("   3. Variables\n")
        summary_file.write("   4. Function Listings\n\n")

        # How to interpret/use
        summary_file.write("How to use/interpret this document:\n")
        summary_file.write("   - Python Version: Ensures compatibility and environment setup.\n")
        summary_file.write("   - Imports and Global Variables: Highlights external dependencies and key global data.\n")
        summary_file.write("   - Functions List: A quick reference to functionalities defined within.\n\n")

        # Variables
        summary_file.write(f"Python version: {python_version}\n")
        summary_file.write(f"All Imports: {sorted(all_imports)}\n")
        summary_file.write("All Variables:\n")
        for scope, scope_vars in all_variables.items():
            summary_file.write(f"Scope: {scope}, Variables: {sorted(scope_vars)}\n")
        summary_file.write("Current values of important variables:\n")
        summary_file.write("(Direct = used in current WIP, context = referenced...)\n")

        # Current state of variables via peek(), which must be defined elsewhere in the notebook.
        # NOTE(review): '3p' is requested twice and '2p' never is; the fourth key may have
        # been meant as '2p' - confirm before changing.
        for peek_key in ('1v', '1p', '2v', '3p', '3v', '3p'):
            summary_file.write(f"peek('{peek_key}'): {peek(peek_key)}\n")

        summary_file.write("Cell Execution and Functions:\n")
        # None (never-executed cells) sorts after every numbered execution count
        for exec_order, funcs in sorted(function_listing.items(), key=lambda x: (x[0] is None, x[0])):
            summary_file.write(f"Cell Execution Order: {exec_order}, Functions: {funcs}\n")

    # Detailed analysis part
    with open(CellStructureMap_file_name, 'w', encoding='utf-8') as detailed_file:
        # Overview section at the beginning
        detailed_file.write("Overview of CellStructureMap Document:\n")
        detailed_file.write("   - Structure: Detailed content mapping for each cell.\n")
        detailed_file.write("   - Contents: Cell types, execution orders, imports, functions, variables, and integrations.\n")
        detailed_file.write("   - Use-Case: Facilitates debugging and understanding cell roles and outputs.\n\n")
        # Table of Contents for quick navigation
        detailed_file.write("Table of Contents:\n")
        detailed_file.write("   - Cell Execution and Contents Overview\n\n")
        # How to interpret/use
        detailed_file.write("How to use/interpret this document:\n")
        detailed_file.write("   - Execution Order: Provides the sequence of notebook execution for logical flow.\n")
        detailed_file.write("   - Detailed Cell Analysis: Offers an in-depth look at the code structure and data processing steps.\n\n")

        # Content
        for _, cell, tree in parsed_cells:
            detailed_file.write(f"\nCell type: {cell['cell_type']}, Execution order: {cell.get('execution_count', 'N/A')}\n")
            if tree is None:
                continue

            imports = re.findall(import_pattern, cell['source'], re.MULTILINE)
            detailed_file.write(f"Imports: {imports}\n")

            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    detailed_file.write(f"Function: {node.name}, Docstring: {ast.get_docstring(node)}\n")
                    nested_funcs = [n.name for n in ast.walk(node) if isinstance(n, ast.FunctionDef) and n is not node]
                    if nested_funcs:
                        detailed_file.write(f"Nested functions: {nested_funcs}\n")
                elif isinstance(node, ast.Assign):
                    for target in node.targets:
                        if isinstance(target, ast.Name):
                            detailed_file.write(f"Variable: {target.id}\n")
                elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
                    detailed_file.write(f"Possible API call or Integration: {ast.dump(node)}\n")

    # Containers for highly detailed information (keyed by cell index)
    classes = defaultdict(list)
    async_functions = defaultdict(list)
    decorators = defaultdict(list)
    comprehensions = defaultdict(list)
    function_calls_details = defaultdict(lambda: defaultdict(list))

    # Parsing for highly detailed information
    for cell_index, _, tree in parsed_cells:
        if tree is None:
            continue
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                classes[cell_index].append(node.name)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # One branch for both kinds so async functions also get their decorators recorded
                if isinstance(node, ast.AsyncFunctionDef):
                    async_functions[cell_index].append(node.name)
                if node.decorator_list:
                    decorators[cell_index].extend(_decorator_label(d) for d in node.decorator_list)
            elif isinstance(node, (ast.ListComp, ast.DictComp, ast.SetComp, ast.GeneratorExp)):
                comprehensions[cell_index].append(ast.dump(node))
            elif isinstance(node, ast.Call):
                callee = getattr(node.func, 'id', getattr(node.func, 'attr', str(node.func)))
                function_calls_details[cell_index]['calls'].append(callee)

    # Writing highly detailed analysis to file
    with open(CellInteractionMap_file_name, 'w', encoding='utf-8') as f:
        # Overview section at the beginning
        f.write("Overview of CellInteractionMap Document:\n")
        f.write("   - Structure: Examination of cell interconnections and relationships.\n")
        f.write("   - Contents: Classes, async functions, decorators, comprehensions, function calls.\n")
        f.write("   - Use-Case: Essential for data flow, dependencies, and interaction optimization.\n\n")
        # Table of Contents for quick navigation
        f.write("Table of Contents:\n")
        f.write("   1. Classes\n")
        f.write("   2. Async Functions\n")
        f.write("   3. Decorators\n")
        f.write("   4. Comprehensions\n")
        f.write("   5. Function Calls\n\n")
        # How to interpret/use
        f.write("How to use/interpret this document:\n")
        f.write("   - Inter-cell Relationships: Unveils how different components work together, enhancing comprehension of complex notebooks.\n")
        f.write("   - Insights for Optimization: Identifies potential refactoring opportunities for efficiency or readability improvements.\n\n")

        # Content
        if classes:
            f.write("Classes:\n")
            for cell, class_names in classes.items():
                f.write(f"Cell {cell}: {', '.join(class_names)}\n")
        if async_functions:
            f.write("\nAsync Functions:\n")
            for cell, functions in async_functions.items():
                f.write(f"Cell {cell}: {', '.join(functions)}\n")
        if decorators:
            f.write("\nDecorators:\n")
            for cell, decorator_names in decorators.items():
                f.write(f"Cell {cell}: {', '.join(decorator_names)}\n")
        if comprehensions:
            f.write("\nComprehensions:\n")
            for cell, comps in comprehensions.items():
                f.write(f"Cell {cell}: {', '.join(comps)}\n")
        if function_calls_details:
            f.write("\nFunction Calls:\n")
            for cell, details in function_calls_details.items():
                f.write(f"Cell {cell}: {', '.join(details['calls'])}\n")

# Regenerate the three structure-map text files for this notebook (the .ipynb is read from the current working directory).
jupyter_notebook_auto_documentation()

## Categorical Analysis (comparison step still needs to be converted to a function)

In [None]:
#------------------------ Categorical Analysis - Data Generation -------------------------------------
#to add later:
#   -distribution (to column & compare here for both continuous & categorical)
#   -variability changes/analysis

experiment_category_columns = None
experiment_category_list = None
cat_test_results_dict = None
def categorical_analysis ():
    print("categorical_analysis starting...")
    #__Needed from previous functions__
    #--get_or_create_session--
    global session
    #--load_data--
    global enroll_stable_copy
    #--upload_column_tests_to_snowflake--
    global snowflake_updated_column_analysis
    
    #__Set for the first time in this function__
    global experiment_category_columns
    global experiment_category_list
    global cat_test_results_dict
    
    #___function logic starts here___
    if snowflake_updated_column_analysis is not None:
        print("Column Analysis Results in Local Storage")
    else:
        print("Checking for Analysis Results Table in Snowflake")

        print("If not in snowflake...")
        print("Calling Column Analysis Function...")

    columns_to_pull = ["COLUMN_NAME", "TEST_DATA_TYPE", "DISTINCT_COUNT", "DISTINCT_VALUES"]
    new_column_names = {
        "COLUMN_NAME": "CATEGORICAL_FIELD",
        "TEST_DATA_TYPE": "CATEGORICAL_FIELD_TYPE",
        "DISTINCT_COUNT": "NUM_CATEGORIES",
        "DISTINCT_VALUES": "CATEGORIES_LIST"
    }
    
    # Convert 'DISTINCT_COUNT' column to numeric, errors='coerce' will convert non-convertible values to NaN
    snowflake_updated_column_analysis['DISTINCT_COUNT'] = pd.to_numeric(snowflake_updated_column_analysis['DISTINCT_COUNT'], errors='coerce')
    
    # Now you can filter the DataFrame
    potentially_categorical_df = snowflake_updated_column_analysis[snowflake_updated_column_analysis['DISTINCT_COUNT'] <= 10][columns_to_pull]
    
    # Rename the columns to the new names
    potentially_categorical_df = potentially_categorical_df.rename(columns=new_column_names)

    print("len is:",potentially_categorical_df.shape[0])
    #pd.options.display.max_columns = 60
    #pd.options.display.max_rows = 180
    #display(potentially_categorical_df)
    #pd.options.display.max_columns = 60
    #pd.options.display.max_rows = 20

    experiment_category_columns = ['CURRENT_PIPELINE_STEP']

    cat_test_results_dict = {}
    for experiment_category_column in experiment_category_columns:
        
        print("experiment_category_column is:")
        print(experiment_category_column)
        experiment_category_list = potentially_categorical_df[potentially_categorical_df['CATEGORICAL_FIELD'] == experiment_category_column]['CATEGORIES_LIST']
        experiment_category_list = ast.literal_eval(experiment_category_list.iloc[0])
        print(experiment_category_list)
        for experiment_category_value in experiment_category_list:
            print("experiment_category_value is:")
            print(experiment_category_value)

            print("running column analysis on cat...")
            category_subset_results_raw = enroll_stable_copy[enroll_stable_copy[experiment_category_column] == experiment_category_value]
            category_subset_results = run_column_analysis(category_subset_results_raw)
            cat_test_results_dict[experiment_category_value] = category_subset_results
            #display(category_subset_results)
            print("cat column analysis complete")

            print("running column analysis on inverse...")
            category_subset_results_inverse_raw = enroll_stable_copy[enroll_stable_copy[experiment_category_column] != experiment_category_value]
            category_subset_results_inverse = run_column_analysis(category_subset_results_inverse_raw)
            cat_test_results_dict[experiment_category_value + "inverse"] = category_subset_results_inverse
            #display(category_subset_results_inverse)
            print("inverse column analysis complete")
            
        print("done")
            
# Requires snowflake_updated_column_analysis and enroll_stable_copy to already be populated (e.g. by running master_controller() first).
categorical_analysis()

In [None]:
#------------------------ Categorical Analysis - Comparison -------------------------------------
# For every (category column, category value) pair processed by `categorical_analysis`, compare the
# column-analysis stats of the category subset against (a) its inverse subset and (b) the full table,
# and collect one summary row per pair into `stats_df`.
#
# Reads notebook globals: snowflake_updated_column_analysis, experiment_category_columns,
# potentially_categorical_df, cat_test_results_dict.
# NOTE(review): cat_test_results_dict is keyed by category value only, so two category columns that
# share a value (e.g. "Y"/"N") overwrite each other's results inside `categorical_analysis` — confirm
# this cannot happen for the chosen columns.

# Stats compared between subsets. Not yet compared: DISTINCT_VALUES, DATE_AVG.
columns_to_analyze = ['NULL_PERCENTAGE', 'NUMBER_AVG', 'BINARY_PERCENT_TRUE_OR_1', 'BINARY_PERCENT_FALSE_OR_0', 'STRING_AVG_LENGTH']


def calculate_percent_change(before, after):
    """Return the relative change ``(after - before) / before`` as a fraction.

    Both inputs are coerced with ``float()``. Returns None when either value is missing
    (None or NaN) or not convertible (``float()`` raises ValueError/TypeError, e.g. ``pd.NA``).
    When ``before`` is 0 the ratio is undefined: returns 0.0 if ``after`` is also 0, otherwise
    +inf / -inf so the change is still counted as a change (and sorts to the extremes).
    """
    def _as_float(value):
        # None, NaN and non-numeric values all count as "missing".
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return None if np.isnan(number) else number

    before_numeric = _as_float(before)
    after_numeric = _as_float(after)
    if before_numeric is None or after_numeric is None:
        return None
    if before_numeric == 0:
        # Previously reported as None ("didn't change") even for 0 -> 0.3; keep it visible instead.
        if after_numeric == 0:
            return 0.0
        return float('inf') if after_numeric > 0 else float('-inf')
    return (after_numeric - before_numeric) / before_numeric


def _category_values_for(category_column):
    """Parse the CATEGORIES_LIST literal stored for `category_column` in potentially_categorical_df.

    Mirrors the lookup in `categorical_analysis`, so the values match the keys it wrote into
    cat_test_results_dict.
    """
    categories_literal = potentially_categorical_df.loc[
        potentially_categorical_df['CATEGORICAL_FIELD'] == category_column, 'CATEGORIES_LIST'
    ].iloc[0]
    return ast.literal_eval(categories_literal)


def _sort_by_percent_change(changes):
    """Return `changes` ordered by each entry's 'percent_change', largest first."""
    return dict(sorted(changes.items(), key=lambda item: item[1]['percent_change'], reverse=True))


def _analyze_stat_columns(stat_columns, original_stats, cat_stats, inverse_stats):
    """Compare each stat in `stat_columns` for the category subset against the inverse subset and the full table.

    All three frames are expected to be indexed by COLUMN_NAME. For every stat ``X`` the returned
    dict holds, in this order:
      - percent_change_in_X_compared_to_inverse / ..._compared_to_original:
        {column_name: {values..., 'percent_change': p}} for columns with a non-zero change,
        sorted by percent change (largest first);
      - columns_where_X_did_change_compared_to_inverse / ..._didnt_change_compared_to_inverse;
      - columns_where_X_did_change_compared_to_original / ..._didnt_change_compared_to_original.
    Percent change is measured relative to the category-subset value.
    """
    results = {}
    for stat in stat_columns:
        original_values = original_stats[stat].to_dict()
        cat_values = cat_stats[stat].to_dict()
        inverse_values = inverse_stats[stat].to_dict()

        changes_vs_inverse, changes_vs_original = {}, {}
        changed_vs_inverse, unchanged_vs_inverse = [], []
        changed_vs_original, unchanged_vs_original = [], []

        for column_name, cat_value in cat_values.items():
            inv_value = inverse_values.get(column_name)
            orig_value = original_values.get(column_name)

            percent_change_to_inv = calculate_percent_change(cat_value, inv_value)
            if percent_change_to_inv not in (None, 0):
                changes_vs_inverse[column_name] = {'cat_value': cat_value, 'inv_value': inv_value, 'percent_change': percent_change_to_inv}
                changed_vs_inverse.append(column_name)
            else:
                unchanged_vs_inverse.append(column_name)

            percent_change_to_orig = calculate_percent_change(cat_value, orig_value)
            if percent_change_to_orig not in (None, 0):
                changes_vs_original[column_name] = {'cat_value': cat_value, 'original_value': orig_value, 'percent_change': percent_change_to_orig}
                changed_vs_original.append(column_name)
            else:
                unchanged_vs_original.append(column_name)

        # Key insertion order determines the column order of stats_df.
        results[f'percent_change_in_{stat}_compared_to_inverse'] = _sort_by_percent_change(changes_vs_inverse)
        results[f'percent_change_in_{stat}_compared_to_original'] = _sort_by_percent_change(changes_vs_original)
        results[f'columns_where_{stat}_did_change_compared_to_inverse'] = changed_vs_inverse
        results[f'columns_where_{stat}_didnt_change_compared_to_inverse'] = unchanged_vs_inverse
        results[f'columns_where_{stat}_did_change_compared_to_original'] = changed_vs_original
        results[f'columns_where_{stat}_didnt_change_compared_to_original'] = unchanged_vs_original
    return results


# Full-table stats, indexed like the per-category results so rows line up by column name.
snowflake_updated_column_analysis_copy = snowflake_updated_column_analysis.set_index("COLUMN_NAME")

# One summary row per (category column, category value) pair.
rows_list = []

for experiment_category_column in experiment_category_columns:
    print("CATEGORY_COLUMN:", experiment_category_column)
    # Use this column's own category list: the global experiment_category_list is reassigned per
    # column inside categorical_analysis, so it only holds the last column's values.
    for experiment_category_value in _category_values_for(experiment_category_column):
        print("CATEGORY_VALUE:", experiment_category_value)

        cat_results = cat_test_results_dict[experiment_category_value].set_index("COLUMN_NAME")
        cat_inverse_results = cat_test_results_dict[experiment_category_value + "inverse"].set_index("COLUMN_NAME")

        _0_percent_null_fields = cat_results[cat_results['NULL_PERCENTAGE'] == 0].index.tolist()
        _0_percent_null_fields_count = len(_0_percent_null_fields)

        _100_percent_null_fields = cat_results[cat_results['NULL_PERCENTAGE'] == 1].index.tolist()
        _100_percent_null_fields_count = len(_100_percent_null_fields)

        analysis_results = _analyze_stat_columns(
            columns_to_analyze, snowflake_updated_column_analysis_copy, cat_results, cat_inverse_results
        )

        row_dict = {
            "CATEGORY_COLUMN": experiment_category_column,
            "CATEGORY_VALUE": experiment_category_value,
            "_0_percent_null_fields": _0_percent_null_fields,
            "_0_percent_null_fields_count": _0_percent_null_fields_count,
            "_100_percent_null_fields": _100_percent_null_fields,
            "_100_percent_null_fields_count": _100_percent_null_fields_count
        }
        row_dict.update(analysis_results)
        rows_list.append(row_dict)

# Build the summary frame once from the collected rows.
stats_df = pd.DataFrame(rows_list)

# First 12 columns: identifiers, null-field summaries and the six NULL_PERCENTAGE comparison columns.
display(stats_df.iloc[:, :12])