## Setup

In [None]:
from datetime import date

import pandas as pd
from pyspark.sql import SparkSession

# getOrCreate() hands back the active Spark session when one exists and
# builds a new one otherwise.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [None]:
# Sample transactions, written one row per item as
# (item_id, price, transaction_date), assembled in pandas and then handed to
# Spark to obtain a Spark DataFrame.
item_price_pandas = pd.DataFrame(
    [
        (1, 4, date(2025, 1, 15)),
        (2, 2, date(2025, 2, 1)),
        (3, 5, date(2025, 3, 10)),
        (4, 1, date(2025, 4, 22)),
    ],
    columns=["item_id", "price", "transaction_date"],
)

item_price = spark.createDataFrame(item_price_pandas)
item_price.show()

## Traditional Query Approach

In [None]:
# Expose the DataFrame to SQL under a fixed view name.
item_price.createOrReplaceTempView("item_price_view")
transaction_date_str = "2025-02-15"

# The value is pasted straight into the SQL text. The surrounding quotes have
# to be written by hand, and an untrusted value could rewrite the query
# (SQL injection).
query_with_fstring = (
    "SELECT *\n"
    "FROM item_price_view\n"
    f"WHERE transaction_date > '{transaction_date_str}'\n"
)

spark.sql(query_with_fstring).show()

## Parameterized Queries with PySpark Custom String Formatting

In [None]:
# The braced placeholders are filled by PySpark's own formatter, not by Python:
# a DataFrame argument is exposed to the query as a temporary view, so no view
# has to be registered by hand.
parametrized_query = (
    "SELECT *\n"
    "FROM {item_price}\n"
    "WHERE transaction_date > {transaction_date}\n"
)

spark.sql(
    parametrized_query,
    transaction_date=transaction_date_str,
    item_price=item_price,
).show()

## Parameterized Queries with Parameter Markers

In [None]:
transaction_date = date(2025, 2, 15)

# `:transaction_date` is a named parameter marker. Its value travels separately
# in `args` instead of being spliced into the SQL text, so a real `date` object
# can be passed as-is.
query_with_markers = """SELECT *
FROM {item_price}
WHERE transaction_date > :transaction_date
"""

spark.sql(
    query_with_markers,
    args={"transaction_date": transaction_date},
    item_price=item_price,
).show()

## Make SQL Easier to Reuse

In [None]:
# Same SQL text as above; only the value bound to :transaction_date changes.
transaction_date_1 = date(2025, 3, 9)
spark.sql(
    query_with_markers, args={"transaction_date": transaction_date_1}, item_price=item_price
).show()

In [None]:
# Run the query once more with a later cutoff date, again without editing the SQL.
transaction_date_2 = date(2025, 3, 15)
spark.sql(
    query_with_markers, args={"transaction_date": transaction_date_2}, item_price=item_price
).show()

## Easier Unit Testing with Parameterized Queries

In [None]:
def filter_by_price_threshold(df, amount):
    """Return the rows of ``df`` whose ``price`` is strictly greater than ``amount``.

    The threshold is bound through the ``:amount`` parameter marker rather than
    being formatted into the SQL text.

    Args:
        df: Spark DataFrame that has a ``price`` column.
        amount: Exclusive lower bound on ``price``.

    Returns:
        A Spark DataFrame containing only the matching rows.
    """
    # Query through the DataFrame's own session rather than a module-level
    # ``spark`` global. PySpark registers ``{df}`` as a temporary view in df's
    # session, so running the SQL in a different session may fail to resolve it.
    # Dropping the global also lets the function be imported into a test module
    # that has no ``spark`` variable.
    return df.sparkSession.sql(
        "SELECT * from {df} where price > :amount", df=df, args={"amount": amount}
    )


In [None]:
# Small in-memory fixture: three products with known prices (10.0, 15.0, 8.0).
df = spark.createDataFrame(
    data=[
        ("Product 1", 10.0, 5),
        ("Product 2", 15.0, 3),
        ("Product 3", 8.0, 2),
    ],
    schema=["name", "price", "quantity"],
)

# The filter is strict (>), so a price equal to the threshold is excluded:
# > 10 keeps only Product 2, while > 8 keeps Products 1 and 2.
assert filter_by_price_threshold(df, 10).count() == 1
assert filter_by_price_threshold(df, 8).count() == 2