In [None]:
# imports
import time
from datetime import datetime

import altair as alt
import pandas as pd
import streamlit as st
from snowflake.snowpark import DataFrame
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, regexp_extract, lit

# Get the active Snowpark session. Every function defined below reads this
# module-level `session` instead of taking a session argument, so this cell
# must run before any of them is called.
session = get_active_session()


In [None]:
def load_data(
    db_schema: str,
    wh_size: str,
    tbls_to_load: list[str],
    stage_location=None,
    file_format_name=None,
) -> list[str]:
    """Resize the ``LOAD`` warehouse, then truncate and asynchronously reload tables.

    Each table is truncated and a ``COPY INTO`` is submitted with
    ``collect_nowait()``, so the function returns without waiting for the
    copies to finish.

    Args:
        db_schema: Schema that holds the target tables. It is made the
            session's current schema and is also used to qualify every
            table name in the generated SQL.
        wh_size: Size applied to the ``LOAD`` warehouse before loading.
        tbls_to_load: Names of the tables to reload. Each table is read from
            ``@<stage_location>/<lower-cased table name>/``.
        stage_location: Stage path prefix to load from. When omitted, the
            notebook-level ``location`` variable is used.
        file_format_name: Named file format passed to ``COPY INTO``. When
            omitted, the notebook-level ``fmt_name`` variable is used.

    Returns:
        The query IDs of the submitted ``COPY INTO`` statements, in the same
        order as ``tbls_to_load``.
    """
    # Resolve the stage and file format before touching anything, so a
    # missing notebook variable fails here rather than after the warehouse
    # has been resized and a table has already been truncated.
    if stage_location is None:
        stage_location = location
    if file_format_name is None:
        file_format_name = fmt_name

    # change the warehouse size
    session.sql(f"ALTER WAREHOUSE LOAD SET WAREHOUSE_SIZE='{wh_size}'").collect()

    # set the database schema
    session.use_schema(db_schema)

    # truncate each table synchronously, then submit its COPY INTO without
    # waiting for the result
    query_ids: list[str] = []
    for tbl in tbls_to_load:
        # qualified like the COPY INTO below, so both statements name the
        # same table explicitly
        session.sql(f"TRUNCATE TABLE {db_schema}.{tbl}").collect()
        job = session.sql(f"""
            COPY INTO {db_schema}.{tbl}
            FROM @{stage_location}/{tbl.lower()}/
            FILE_FORMAT = ( FORMAT_NAME = '{file_format_name}')
            MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
            FORCE = TRUE
            """).collect_nowait()
        query_ids.append(job.query_id)
    return query_ids

In [None]:
def wait_till_end_and_scale_down(
    query_ids: list[str],
    poll_interval_seconds: float = 2.0,
) -> None:
    """Block until the given queries have finished, then shrink the ``LOAD`` warehouse.

    The function polls ``QUERY_HISTORY_BY_SESSION`` for the current session
    and keeps waiting while any of ``query_ids`` is in a status other than
    the finished ones listed below. The warehouse is set back to
    ``X-SMALL`` when polling ends, including when polling raises or is
    interrupted.

    Args:
        query_ids: IDs of the asynchronous queries to wait for. An empty
            list skips polling and goes straight to the scale-down.
        poll_interval_seconds: Pause between two polls of the query history.

    Returns:
        None.
    """
    # statuses treated as "no longer running"
    finished_statuses = [
        "SUCCESS",
        "FAILED_WITH_ERROR",
        "FAILED_WITH_INCIDENT",
        "ABORTED",
        "DISCONNECTED",
    ]

    try:
        # An empty list means nothing was submitted; skip polling rather
        # than build a filter with an empty IN list.
        while query_ids:
            # queries of interest that are still in a non-finished status
            pending = (
                session.table_function(
                    "MEETUP_GDDP.INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION",
                    result_limit=lit(10000)
                )
                .filter(
                    ~(col("EXECUTION_STATUS").in_(finished_statuses))
                    & (col("QUERY_ID").in_(query_ids))
                )
            )

            if pending.count() == 0:
                break

            # pause so the loop does not issue history queries back to back
            time.sleep(poll_interval_seconds)
    finally:
        # scale down the warehouse even if polling failed, so it is not
        # left at the larger size
        session.sql("ALTER WAREHOUSE LOAD SET WAREHOUSE_SIZE='X-SMALL'").collect()

In [None]:
def display_results(query_ids: list[str]) -> None:
    """Render timing statistics for the given queries in Streamlit.

    Reads the rows for ``query_ids`` from ``QUERY_HISTORY_BY_SESSION``,
    converts the millisecond timing columns to seconds, and extracts the
    table name from each ``COPY INTO <a>.<b>.<table>`` statement. It then
    shows:

    * a metric with the span between the earliest ``START_TIME`` and the
      latest ``END_TIME``,
    * a bar chart of elapsed seconds per table, offset by warehouse size,
    * the underlying rows as a table.

    Args:
        query_ids: IDs of the queries to report on. When the list is empty
            or no matching rows are found, a short message is written
            instead of the metric, chart and table.

    Returns:
        None.
    """
    if not query_ids:
        st.write("No queries to display.")
        return

    # capture group 1 is the third dot-separated part of the COPY INTO target
    regex_pattern = r'COPY INTO \S+\.\S+\.(\S+)'

    # get queries stats by table
    df = (
        session.table_function(
            "MEETUP_GDDP.INFORMATION_SCHEMA.QUERY_HISTORY_BY_SESSION",
            result_limit=lit(10000)
        )
        .filter((col("QUERY_ID").in_(query_ids)))
        .select(
            [
                "SESSION_ID",
                "QUERY_ID",
                "QUERY_TEXT",
                "WAREHOUSE_SIZE",
                "START_TIME",
                "END_TIME",
                # the three timing columns are divided by 1000 to show seconds
                (col("TOTAL_ELAPSED_TIME") / 1000).alias("TOTAL_ELAPSED_TIME_SECONDS"),
                (col("COMPILATION_TIME") / 1000).alias("COMPILATION_TIME_SECONDS"),
                (col("QUEUED_OVERLOAD_TIME") / 1000).alias("QUEUED_OVERLOAD_TIME_SECONDS"),
                "ROWS_PRODUCED",
                regexp_extract(col("QUERY_TEXT"), regex_pattern, 1).alias("TABLE_NAME")
            ]
        )
    )

    # Collect the DataFrame to Pandas
    df_pandas = df.to_pandas()

    # Without rows the duration below would be computed from empty columns
    # and the chart would be empty, so stop with a message instead.
    if df_pandas.empty:
        st.write("No matching queries found in the session query history.")
        return

    # Display the total elapsed time as a Streamlit metric: earliest start
    # to latest end across all selected queries
    first_start = pd.to_datetime(df_pandas['START_TIME']).min()
    last_end = pd.to_datetime(df_pandas['END_TIME']).max()
    total_duration_seconds = (last_end - first_start).total_seconds()
    st.metric(label="Total Elapsed Time (seconds)",
              value=f"{total_duration_seconds:.2f}")

    # Create the Altair bar chart
    bar_chart = (
        alt.Chart(df_pandas)
        .mark_bar(color="#872D60")
        .encode(
            x=alt.X("TABLE_NAME:N", title="Table", axis=alt.Axis(labelAngle=-90, labelLimit=0)),
            # title previously read "Duration in seconds)" (unbalanced parenthesis)
            y=alt.Y("TOTAL_ELAPSED_TIME_SECONDS:Q", title="Duration (seconds)"),
            xOffset=alt.XOffset("WAREHOUSE_SIZE:N")
        )
    )

    # Display the Altair chart in Streamlit
    st.altair_chart(bar_chart, use_container_width=True)

    # Display the DataFrame in Streamlit
    st.write(df_pandas)

In [None]:
# --- run configuration -------------------------------------------------
# Alternative configuration, kept for reference (not active):
#   wh_size = "Medium"
#   tbls_to_load = ["CUSTOMER", "LINEITEM", "NATION", "ORDERS", "PART", "PARTSUPP", "SUPPLIER", "REGION"]

# target schema, named file format and base location on the external stage;
# `fmt_name` and `location` are read by load_data as notebook-level variables
db_schema = "MEETUP_GDDP.TPCH_SF100"
fmt_name = "MEETUP_GDDP.UTILS.CSV_FMT1"
location = "MEETUP_GDDP.UTILS.LANDING/tpch-sf100/csv"

# active configuration: warehouse size and tables to reload
wh_size = "X-Small"
tbls_to_load = ["NATION", "REGION"]

# set to False to skip the Streamlit report
view_results = True

# --- run ---------------------------------------------------------------
# submit the loads, wait for them to finish (which also scales the
# warehouse back down), then optionally show the statistics
query_ids = load_data(db_schema, wh_size, tbls_to_load)
wait_till_end_and_scale_down(query_ids)
if view_results:
    display_results(query_ids)