In [0]:
# Start from a clean slate: drop any widgets left over from a previous run,
# then declare the single parameter that selects which test runs below.
dbutils.widgets.removeAll()
dbutils.widgets.text(
    "test_type",
    "join",
    "1. Test Type e.g join, aggregate",
)

In [0]:
# Normalize the free-text widget value so stray whitespace or capitalization
# (e.g. " Join", "AGGREGATE") still selects the intended test below.
test_type = dbutils.widgets.get("test_type").strip().lower()
print(f"test_type: {test_type}")

In [0]:
def _load_parquet_view(spark, path, view_name):
    """Read the Parquet data at ``path`` and register it as temp view ``view_name``.

    Uses ``createOrReplaceTempView``, so an existing view of the same name is
    replaced.

    Args:
        spark: SparkSession (anything exposing ``.read.parquet``).
        path: Parquet file, directory or glob to read.
        view_name: Name of the temporary view to (re)register.
    """
    spark.read.parquet(path).createOrReplaceTempView(view_name)


def load_transactions(spark, path="dbfs:/transactions/*.parquet"):
    """Register the transactions Parquet data as the ``transactions`` temp view.

    Args:
        spark: SparkSession (anything exposing ``.read.parquet``).
        path: Parquet location to read; defaults to the original DBFS glob.
    """
    _load_parquet_view(spark, path, "transactions")


def load_stores(spark, path="dbfs:/stores/*.parquet"):
    """Register the stores Parquet data as the ``stores`` temp view.

    Args:
        spark: SparkSession (anything exposing ``.read.parquet``).
        path: Parquet location to read; defaults to the original DBFS glob.
    """
    _load_parquet_view(spark, path, "stores")


def load_countries(spark, path="dbfs:/countries/*.parquet"):
    """Register the countries Parquet data as the ``countries`` temp view.

    Args:
        spark: SparkSession (anything exposing ``.read.parquet``).
        path: Parquet location to read; defaults to the original DBFS glob.
    """
    _load_parquet_view(spark, path, "countries")

## Preparation: register the source Parquet data as temp views

In [0]:
# Register every source table as a temp view before the tests query them.
for register_view in (load_transactions, load_stores, load_countries):
    register_view(spark)

## Run the selected test

In [0]:
# Dispatch on the widget value. Each branch runs one Spark SQL query against
# the `transactions` / `stores` / `countries` temp views and writes the result
# to DBFS as Parquet.
if test_type == "join":
    ## Test join w/o broadcast
    # NOTE(review): nothing in this notebook disables automatic broadcast joins
    # (e.g. spark.sql.autoBroadcastJoinThreshold = -1), so Spark may still
    # broadcast `stores` / `countries` if they are small — confirm this matches
    # the intent of a "no broadcast" test.
    joined_df_no_broadcast = spark.sql("""
        SELECT 
            transactions.id,
            amount,
            countries.name as country_name,
            employees,
            stores.name as store_name
        FROM
            transactions
        LEFT JOIN
            stores
            ON
                transactions.store_id = stores.id
        LEFT JOIN
            countries
            ON
                transactions.country_id = countries.id
    """)
    # Write the joined data as Parquet files, replacing any previous output.
    joined_df_no_broadcast.write.mode("overwrite").parquet(
        "dbfs:/photon_transact_countries"
    )

elif test_type == "aggregate":
    ## Test groupBy
    grouped_df = spark.sql("""
        SELECT 
            country_id, 
            COUNT(*) AS count,
            AVG(amount) AS avg_amount
        FROM transactions
        GROUP BY country_id
    """)

    # Write the aggregated data as Parquet files, replacing any previous output.
    grouped_df.write.mode("overwrite").parquet("dbfs:/photon_country_agg")

else:
    # Fail loudly instead of silently running nothing when the widget holds
    # an unexpected value (e.g. a typo).
    raise ValueError(
        f"Unknown test_type {test_type!r}; expected 'join' or 'aggregate'"
    )
