# 1. Setup & Imports

In [17]:
import os
import random 
import pandas as pd
import warnings

from pprint import pprint

warnings.filterwarnings('ignore')

# BigQuery
from google.cloud import bigquery

# Snowflake
import snowflake.connector as sf

# 2. Configure Connections

## 2.1 BigQuey Client

In [18]:
bq_project_id = "okta-ga-rollup"
bq_client = bigquery.Client(project=bq_project_id)
timeout_seconds = 300

## 2.2 Snoflake Connection

In [47]:
# Snoflake credentials
SNOWFLAKE_USER = 'vlad.parakhin@okta.com'
SNOWFLAKE_AUTHENTICATOR = 'https://okta.okta.com'
SNOWFLAKE_ACCOUNT = 'okta'
SNOWFLAKE_PASSWORD = ' '
SNOWFLAKE_WAREHOUSE = 'okta_dt_m'
SNOWFLAKE_ROLE = 'digital_analytics_access_role'
SNOWFLAKE_DB = 'OKTA'
SNOWFLAKE_SCHEMA = 'OKTA_DT'


conn = sf.connect(
    user = SNOWFLAKE_USER, 
    authenticator='externalbrowser', 
    account = SNOWFLAKE_ACCOUNT, 
    password = SNOWFLAKE_PASSWORD,
    warehouse = SNOWFLAKE_WAREHOUSE,
    role = SNOWFLAKE_ROLE,
    database = SNOWFLAKE_DB,
    schema = SNOWFLAKE_SCHEMA
    )

cursor = conn.cursor()
pprint(" ")
pprint("Successfully connected to Snowflake")

Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://okta.okta.com/app/snowflake/exk1egd3l4qGXJ0YZ1d8/sso/saml?SAMLRequest=jZJRb9owFIX%2FSuQ9EzuBatQCKgorYwWKSlg33rzYgBfHTn0dAv9%2BTihT%2B9CqL5Zln%2BP7XZ%2FbuznmKjgIC9LoPopCggKhU8Ol3vXROrlrdVEAjmnOlNGij04C0M2gByxXBR2Wbq8fxXMpwAX%2BIQ20ueij0mpqGEigmuUCqEvpajif0TgktLDGmdQo9MrysYMBCOs84cXCQXq8vXMFxbiqqrBqh8bucEwIweQae1Ut%2BXLRH31P7%2BgjTDq13iu8fPnCdiv1%2BQs%2BwvpzFgH9niTL1vJhlaBgeEEdGQ1lLuxK2INMxfpxdgYAT%2FBwnwxD0KbaKpaJ1ORF6fxDod%2FhreBYmZ30vU7HfVRkku%2FIpnuYTDbz4j6fbbL1LWR%2F5wv1ND%2B1Ty4m68X%2B%2BmkxH%2FGr8luKgp%2BXMOM6zClAKaa6jtD5IxJftUjUittJ1KFtQsnXMI7IBgVjH6HUzDXOC6fJHAubpSZjRYH%2FQ2NxzCKx423VeZ78%2BkF%2BbyLexQAG13Gi84TQprodfKLvHn5teJmvhf%2Fy6XhplExPwZ2xOXPvJxKFUXMieWvbSKnImVRDzq0A8MkoZaqRFcz5MXa2FAgPzlXfDvLgHw%3D%3D&RelayState=ver%3A1

# 3. Define Test Configuration

In [155]:
TABLES_TO_TEST = [
    {
        "bq_dataset":"dbt_prod_ga4_reporting",
        "bq_table":"ga4__content_with_ua_union",
        "sf_table":"GA4_CONTENT",
        "agg_columns":["unique_pageviews", "any_conversion_session_conversions_unique"]
    },
    {
        "bq_dataset":"dbt_prod_ga4_reporting",
        "bq_table":"ga4__traffic_with_ua_union",
        "sf_table":"GA4_TRAFFIC",
        "agg_columns":["sessions", "high_value_visits"]
    },
    {
        "bq_dataset":"dbt_prod_ga4_reporting",
        "bq_table":"ga4__flattened_hits",
        "sf_table":"GA4_HITS",
        "agg_columns": [
            "session_engaged"
        ],
        "custom_aggregates": { 
           "sum_flag_high_value": "CASE WHEN is_high_value_visit=TRUE THEN 1 ELSE 0 END"
        }
    }
]

For MoM, QoQ, YoY checks define the time windows:

In [156]:
time_windows = [
    ("2024-12-01", "2024-12-31"),  # December 2024 (MoM example)
    ("2024-10-01", "2024-12-31"),  # Q4 2024 (QoQ example)
    ("2024-01-01", "2024-12-31")  # Full year 2024 (YoY example)
]

# 4. Helper Functions

## 4.1 Get Schema

In [157]:
def get_bq_schema(dataset_name, table_name):
    query = f"""
    SELECT column_name, data_type
    FROM `{bq_project_id}.{dataset_name}.INFORMATION_SCHEMA.COLUMNS`
    WHERE table_name = '{table_name}'
    ORDER BY ordinal_position
    """
    df = bq_client.query(query).to_dataframe()
    return df

def get_snowflake_schema(db, schema, table_name):
    query = f"""
    SELECT 
        column_name, 
        data_type 
    FROM {db}.information_schema.columns
    WHERE table_name = '{table_name.upper()}'
      AND table_schema = '{schema.upper()}'
    ORDER BY ordinal_position
    """
    return pd.read_sql(query, conn)

## 4.2. Compare Schemas

In [158]:
def compare_schemas(bq_schema_df, sf_schema_df):
    """
    Compare BigQuery and Snowflake schemas for missing, extra, or mismatched columns.
    Adjusts for differences in data type representation between the two platforms.
    """
    # Normalize BigQuery and Snowflake data types for comparison
    type_mapping = {
        'string': 'text',
        'int64': 'number',
        'float64': 'float',
        'timestamp': 'timestamp_ntz',  
        'boolean': 'boolean',
        'date': 'date',
        'datetime': 'timestamp',
        'bool': 'boolean',
        'array':'variant',
        'struct':'variant'
    }
    
    def normalize_type(data_type, mapping):
        return mapping.get(data_type.lower(), data_type.lower())
    
    # Convert to dict: {column_name.lower(): normalized_data_type}
    bq_cols = {
        row["column_name"].lower(): normalize_type(row["data_type"], type_mapping)
        for _, row in bq_schema_df.iterrows()
    }
    sf_cols = {
        row["COLUMN_NAME"].lower(): normalize_type(row["DATA_TYPE"], type_mapping)
        for _, row in sf_schema_df.iterrows()
    }
    
    bq_set = set(bq_cols.keys())
    sf_set = set(sf_cols.keys())
    
    missing_in_snowflake = bq_set - sf_set
    extra_in_snowflake = sf_set - bq_set
    
    # For columns that exist in both, compare normalized data types
    type_mismatches = []
    for col in (bq_set & sf_set):
        if bq_cols[col] != sf_cols[col]:
            type_mismatches.append((col, bq_cols[col], sf_cols[col]))
    
    return missing_in_snowflake, extra_in_snowflake, type_mismatches


## 4.3 Row Count & Aggregation Checks

In [159]:
def get_bq_row_count(dataset_name, table_name, start_date, end_date):  
    query = f"""
    SELECT COUNT(*) AS row_count
    FROM `{bq_project_id}.{dataset_name}.{table_name}`
    WHERE date >= '{start_date}' AND date <= '{end_date}'
    """
    df = bq_client.query(query).to_dataframe()
    return df["row_count"].iloc[0]

def get_sf_row_count(db, schema, table_name, start_date, end_date):
    query = f"""
    SELECT COUNT(*) AS ROW_COUNT
    FROM {db}.{schema}.{table_name}
    WHERE date >= '{start_date}' AND date <= '{end_date}'
    """
    df = pd.read_sql(query, conn)
    return df["ROW_COUNT"].iloc[0]

In [160]:
def get_bq_aggregate(dataset_name, table_name, start_date, end_date, agg_column):
    query = f"""
    SELECT SUM({agg_column}) AS total_value
    FROM `{bq_project_id}.{dataset_name}.{table_name}`
    WHERE date >= '{start_date}' AND date <= '{end_date}'
    """
    df = bq_client.query(query).to_dataframe()
    df["total_value"]= df["total_value"].fillna(0)
    return df["total_value"].iloc[0]

def get_sf_aggregate(db, schema, table_name, start_date, end_date, agg_column):
    query = f"""
    SELECT SUM({agg_column}) AS TOTAL_VALUE
    FROM {db}.{schema}.{table_name}
    WHERE date >= '{start_date}' AND date <= '{end_date}'
    """
    df = pd.read_sql(query, conn)
    return df["TOTAL_VALUE"].iloc[0] or 0

In [161]:
def get_bq_custom_aggregate(dataset_name, table_name, start_date, end_date, agg_expr):
    query = f"""
    SELECT SUM({agg_expr}) AS custom_agg
    FROM `{bq_project_id}.{dataset_name}.{table_name}`
    WHERE date BETWEEN '{start_date}' AND '{end_date}'
    """
    result = bq_client.query(query).to_dataframe()
    return result["custom_agg"].iloc[0] if not result.empty else 0

def get_sf_custom_aggregate(db, schema, table_name, start_date, end_date, agg_expr):
    query = f"""
    SELECT SUM({agg_expr.strip()}) AS custom_agg
    FROM {db}.{schema}.{table_name}
    WHERE DATE BETWEEN '{start_date}' AND '{end_date}'
    """
    
    #print(f"Executing Snowflake custom aggregate query:\n{query}")  # Debugging
    result = pd.read_sql(query, conn)


    if "CUSTOM_AGG" in result.columns:
        value = result["CUSTOM_AGG"].iloc[0]
        return value if not result.empty else 0

## 4.4 Random Sampling Checks

In [162]:
test_columns = ["user_type", "country", "data_source", "date", "unique_pageviews", "last_non_direct_channel"]

In [163]:
def random_sample_check(dataset_name, bq_table, sf_table, start_date, end_date, sample_size=5):
    """
    Perform a random sample check between BigQuery and Snowflake.
    Dynamically handles column selection based on the schema.
    """
    # Fetch the schema of the BigQuery table
    bq_schema = get_bq_schema(dataset_name, bq_table)

    # Ensure `bq_schema` is a list of dictionaries
    if isinstance(bq_schema, pd.DataFrame):
        bq_schema = bq_schema.to_dict(orient="records")

    # Extract available column names
    available_columns = [col["column_name"] for col in bq_schema]

    # Define potential key columns to test
    test_columns = ["user_type", "country", "data_source", "date", "unique_pageviews", "last_non_direct_channel"]
    selected_columns = [col for col in test_columns if col in available_columns]

    if not selected_columns:
        return f"No matching columns found for table {bq_table} to perform random sampling."

    # Dynamically construct the SELECT statement
    columns_to_query = ", ".join(selected_columns)
    sample_query = f"""
    SELECT {columns_to_query}
    FROM `{bq_project_id}.{dataset_name}.{bq_table}`
    WHERE DATE BETWEEN '{start_date}' AND '{end_date}'
    AND {selected_columns[0]} IS NOT NULL
    LIMIT 5000
    """
    print(f"Executing BigQuery random sampling query:\n{sample_query}")
    df_candidates = bq_client.query(sample_query).to_dataframe()

    if df_candidates.empty:
        return f"No data found in BigQuery table {bq_table} for the given date range."

    # Randomly sample rows from the retrieved candidates
    df_random = df_candidates.sample(n=min(sample_size, len(df_candidates)))

    mismatch_count = 0

    # Loop through sampled rows and check existence in Snowflake
    for idx, row in df_random.iterrows():
        # Dynamically build conditions for Snowflake query
        conditions = []
        for col in selected_columns:
            if col == "date":
                conditions.append(f"TO_DATE({col}) = TO_DATE('{row[col]}')")
            elif isinstance(row[col], str):
                conditions.append(f"{col} = '{row[col]}'")
            else:
                conditions.append(f"{col} = {row[col]}")
        conditions_str = " AND ".join(conditions)

        sf_check_query = f"""
        SELECT COUNT(*) AS CUSTOM_AGG
        FROM {SNOWFLAKE_DB}.{SNOWFLAKE_SCHEMA}.{sf_table}
        WHERE {conditions_str}
        """
        print(f"Executing Snowflake check query:\n{sf_check_query}")

        try:
            result = pd.read_sql(sf_check_query, conn)

            if "CUSTOM_AGG" in result.columns:
                count_in_sf = result["CUSTOM_AGG"].iloc[0] if not result.empty else 0
            else:
                print(f"  Column 'CUSTOM_AGG' is missing in the result.")
                count_in_sf = 0
        except Exception as e:
            print(f"Error executing Snowflake query: {e}")
            count_in_sf = 0

        if count_in_sf == 0:
            mismatch_count += 1

    # Return results
    if mismatch_count == 0:
        return f"All {sample_size} sampled rows exist in Snowflake."
    else:
        return f"{mismatch_count} out of {sample_size} sampled rows not found in Snowflake."

# 5. Runt tests

In [164]:
for table_map in TABLES_TO_TEST:
    bq_dataset = table_map["bq_dataset"]
    bq_table = table_map["bq_table"]
    sf_table = table_map["sf_table"]
    agg_columns = table_map.get("agg_columns", [])
    custom_aggregates = table_map.get("custom_aggregates", {})  # Fetch custom aggregates if defined
    
    print(f"\n--- Testing Table: BQ[{bq_dataset}.{bq_table}] vs SF[{sf_table}] ---")
    
    # Schema Comparison
    bq_schema = get_bq_schema(bq_dataset, bq_table)
    sf_schema = get_snowflake_schema(SNOWFLAKE_DB, SNOWFLAKE_SCHEMA, sf_table)
    missing_in_sf, extra_in_sf, type_mismatches = compare_schemas(bq_schema, sf_schema)
    
    print("Schema Check Results:")
    print(f"  Missing in Snowflake: {missing_in_sf}" if missing_in_sf else "  No missing columns.")
    print(f"  Extra in Snowflake: {extra_in_sf}" if extra_in_sf else "  No extra columns.")
    if type_mismatches:
        for col, bq_type, sf_type in type_mismatches:
            print(f"  Type mismatch: {col} (BQ={bq_type}, SF={sf_type})")
    else:
        print("  No data type mismatches.")
    
    # Row Count & Aggregates by custom time windows
    for start_date, end_date in time_windows:
        bq_count = get_bq_row_count(bq_dataset, bq_table, start_date, end_date)
        sf_count = get_sf_row_count(SNOWFLAKE_DB, SNOWFLAKE_SCHEMA, sf_table, start_date, end_date)
        
        print(f"\nDate Range: {start_date} to {end_date}")
        print(f"   BQ row count: {bq_count}")
        print(f"   SF row count: {sf_count}")
        
        # Standard aggregates
        for col in agg_columns:
            bq_agg = get_bq_aggregate(bq_dataset, bq_table, start_date, end_date, col)
            sf_agg = get_sf_aggregate(SNOWFLAKE_DB, SNOWFLAKE_SCHEMA, sf_table, start_date, end_date, col)
            
            difference_pct = (
                0 if bq_agg == 0 and sf_agg == 0 else
                abs(bq_agg - sf_agg) / bq_agg * 100 if bq_agg != 0 else 100
            )
            print(f"     {col} => BQ: {bq_agg}, SF: {sf_agg}, Diff%: {difference_pct:.2f}")
        
        # Custom aggregates
        for agg_name, agg_expr in custom_aggregates.items():
            bq_custom_agg = get_bq_custom_aggregate(bq_dataset, bq_table, start_date, end_date, agg_expr)
            sf_custom_agg = get_sf_custom_aggregate(SNOWFLAKE_DB, SNOWFLAKE_SCHEMA, sf_table, start_date, end_date, agg_expr)
            
            difference_pct = (
                0 if bq_custom_agg == 0 and sf_custom_agg == 0 else
                abs(bq_custom_agg - sf_custom_agg) / bq_custom_agg * 100 if bq_custom_agg != 0 else 100
            )
            print(f"     {agg_name} => BQ: {bq_custom_agg}, SF: {sf_custom_agg}, Diff%: {difference_pct:.2f}")



--- Testing Table: BQ[dbt_prod_ga4_reporting.ga4__content_with_ua_union] vs SF[GA4_CONTENT] ---
Schema Check Results:
  No missing columns.
  Extra in Snowflake: {'export_timestamp', 'batch_id'}
  Type mismatch: conversions (BQ=array<struct<conversion_name string, on_page_conversions int64, on_page_conversion_value float64, converting_page_conversions int64, converting_page_conversion_value float64, session_conversions int64, session_conversions_unique int64, session_conversion_value float64, conversion_type string, conversion_type_group string>>, SF=variant)

Date Range: 2024-12-01 to 2024-12-31
   BQ row count: 8204534
   SF row count: 8153241
     unique_pageviews => BQ: 8207389, SF: 8186808, Diff%: 0.25
     any_conversion_session_conversions_unique => BQ: 147051, SF: 146742, Diff%: 0.21

Date Range: 2024-10-01 to 2024-12-31
   BQ row count: 26399946
   SF row count: 26284296
     unique_pageviews => BQ: 26051381, SF: 26018155, Diff%: 0.13
     any_conversion_session_conversions_u

In [154]:

test_bq_dataset = "dbt_prod_ga4_reporting"
test_bq_table = "ga4__content_with_ua_union"
test_sf_table = "GA4_CONTENT"
test_start_date = "2024-10-01"
test_end_date = "2024-10-11"
test_sample_size = 5

result = random_sample_check(
    dataset_name=test_bq_dataset,
    bq_table=test_bq_table,
    sf_table=test_sf_table,
    start_date=test_start_date,
    end_date=test_end_date,
    sample_size=test_sample_size
)


print(result)

Executing BigQuery random sampling query:

    SELECT user_type, country, data_source, date, unique_pageviews, last_non_direct_channel
    FROM `okta-ga-rollup.dbt_prod_ga4_reporting.ga4__content_with_ua_union`
    WHERE DATE BETWEEN '2024-10-01' AND '2024-10-11'
    AND user_type IS NOT NULL
    LIMIT 5000
    
Executing Snowflake check query:

        SELECT COUNT(*) AS CUSTOM_AGG
        FROM OKTA.OKTA_DT.GA4_CONTENT
        WHERE user_type = 'Returning User' AND country = 'United States' AND data_source = 'GA4' AND TO_DATE(date) = TO_DATE('2024-10-11') AND unique_pageviews = 1 AND last_non_direct_channel = 'Paid Search'
        
Executing Snowflake check query:

        SELECT COUNT(*) AS CUSTOM_AGG
        FROM OKTA.OKTA_DT.GA4_CONTENT
        WHERE user_type = 'Returning User' AND country = 'United States' AND data_source = 'GA4' AND TO_DATE(date) = TO_DATE('2024-10-11') AND unique_pageviews = 1 AND last_non_direct_channel = 'Paid Search'
        
Executing Snowflake check query: