In [None]:
from pyspark.sql.functions import udf, col, broadcast
from pyspark.sql.types import StringType
from pyspark.sql import SparkSession

In [None]:
# Build the SparkSession used by every cell below; getOrCreate() hands back
# an already-running session when there is one.
spark = (
    SparkSession.builder
    .appName("broadcast_n_accumulators")
    .getOrCreate()
)

In [None]:
# Sample transaction rows. Positional layout of each tuple:
# (transaction id, customer id, product, quantity, price, transaction date string)
data = [
    (1,  101, "Apple",  3,  2.5, "2024-01-01"),
    (2,  102, "Banana", 5,  1.2, "2024-01-02"),
    (3,  101, "Orange", 2,  3.0, "2024-01-03"),
    (4,  103, "Milk",   1,  4.5, "2024-01-04"),
    (5,  104, "Eggs",   12, 0.2, "2024-01-05"),
    (6,  102, "Bread",  2,  2.0, "2024-01-06"),
    (7,  105, "Butter", 1,  3.5, "2024-01-07"),
    (8,  106, "Cheese", 1,  5.0, "2024-01-08"),
    (9,  103, "Cereal", 2,  3.8, "2024-01-09"),
    (10, 107, "Juice",  1,  2.5, "2024-01-10"),
]

# Lookup table: customer id -> customer name.
customer_dict = {
    101: "John Doe",
    102: "Jane Smith",
    103: "Emily Davis",
    104: "Michael Brown",
    105: "Sarah Wilson",
    106: "Chris Johnson",
    107: "Laura Lee",
}

In [None]:
# Wrap the customer lookup in a Spark broadcast variable; code that needs
# the dictionary reads it back through `Customers.value`.
Customers = spark.sparkContext.broadcast(customer_dict)

In [None]:
def getCustomerName(cust_id):
    """Return the customer name for ``cust_id`` from the broadcast lookup.

    Args:
        cust_id: Customer id to resolve; may be None when the column is null.

    Returns:
        The matching name, or None when ``cust_id`` is None or is not a key
        of the broadcast dictionary.
    """
    # dict.get() instead of [] so an unknown or null id produces a null in
    # the output column rather than a KeyError that fails the whole job.
    return Customers.value.get(cust_id)


# Register the lookup as a Spark UDF that returns a string column.
getCustomerName_udf = udf(getCustomerName, StringType())

In [None]:
# Column names for the transaction rows. The second column is called
# "Customer" and initially holds the customer id.
columns = [
    "TransactionID",
    "Customer",
    "Product",
    "Quantity",
    "Price",
    "TransactionDate",
]
df = spark.createDataFrame(data, schema=columns)

In [None]:
# Overwrite the "Customer" column (customer id) with the name returned by
# the broadcast-backed UDF, then display the result.
transformed_df = df.withColumn(
    "Customer",
    getCustomerName_udf(col("Customer")),
)
transformed_df.show()

## Broadcasting entire DataFrame

In [None]:
# Sample transaction rows. Positional layout of each tuple:
# (transaction id, customer id, product, quantity, price, transaction date string)
data = [
    (1,  101, "Apple",  3,  2.5, "2024-01-01"),
    (2,  102, "Banana", 5,  1.2, "2024-01-02"),
    (3,  101, "Orange", 2,  3.0, "2024-01-03"),
    (4,  103, "Milk",   1,  4.5, "2024-01-04"),
    (5,  104, "Eggs",   12, 0.2, "2024-01-05"),
    (6,  102, "Bread",  2,  2.0, "2024-01-06"),
    (7,  105, "Butter", 1,  3.5, "2024-01-07"),
    (8,  106, "Cheese", 1,  5.0, "2024-01-08"),
    (9,  103, "Cereal", 2,  3.8, "2024-01-09"),
    (10, 107, "Juice",  1,  2.5, "2024-01-10"),
]

# Customer rows. Positional layout of each tuple:
# (customer id, name, age, city, email)
customer_data = [
    (101, "John Doe", 28, "New York", "johndoe@example.com"),
    (102, "Jane Smith", 34, "Los Angeles", "janesmith@example.com"),
    (103, "Emily Davis", 23, "Chicago", "emilydavis@example.com"),
    (104, "Michael Brown", 40, "Houston", "michaelbrown@example.com"),
    (105, "Sarah Wilson", 30, "San Francisco", "sarahwilson@example.com"),
    (106, "Chris Johnson", 36, "Seattle", "chrisjohnson@example.com"),
    (107, "Laura Lee", 27, "Austin", "lauralee@example.com"),
]

In [None]:
# Column names for the two tables; both include a "CustomerID" column.
columns = [
    "TransactionID",
    "CustomerID",
    "Product",
    "Quantity",
    "Price",
    "TransactionDate",
]
customer_columns = ["CustomerID", "Name", "Age", "City", "Email"]

# Build one DataFrame per table from the in-memory rows.
df = spark.createDataFrame(data, schema=columns)
customer_df = spark.createDataFrame(customer_data, schema=customer_columns)

In [None]:
# Wrap the customer DataFrame with broadcast() before using it in the join.
cust_broad = broadcast(customer_df)

# Join transactions to customers on CustomerID, keep the listed columns
# (renaming Name to CustomerName), and display the result.
transformed_df = (
    df.join(cust_broad, df.CustomerID == cust_broad.CustomerID)
    .selectExpr(
        "TransactionID",
        "Product",
        "Quantity",
        "Price",
        "TransactionDate",
        "Name as CustomerName",
        "Age",
        "City",
        "Email",
    )
)
transformed_df.show()

## Checking explain plan

### With broadcasting the DataFrame

In [0]:
# Same join as before against the broadcast-wrapped customer DataFrame;
# print its plan instead of showing rows.
transformed_df = (
    df.join(cust_broad, df.CustomerID == cust_broad.CustomerID)
    .selectExpr(
        "TransactionID",
        "Product",
        "Quantity",
        "Price",
        "TransactionDate",
        "Name as CustomerName",
        "Age",
        "City",
        "Email",
    )
)
transformed_df.explain()

### Without broadcasting the DataFrame

In [0]:
# Same join, but against the plain customer_df (no broadcast() wrapper),
# so the printed plan can be compared with the broadcast version.
# NOTE(review): depending on the session's auto-broadcast settings Spark may
# still choose a broadcast join for a table this small; confirm against the
# printed plan.
transformed_df = (
    df.join(customer_df, df.CustomerID == customer_df.CustomerID)
    .selectExpr(
        "TransactionID",
        "Product",
        "Quantity",
        "Price",
        "TransactionDate",
        "Name as CustomerName",
        "Age",
        "City",
        "Email",
    )
)
transformed_df.explain()

## Accumulators

In [0]:
# Numeric accumulator that starts at zero.
acc = spark.sparkContext.accumulator(0)

# Add 10 to it from the driver and print the running total.
acc += 10
print(acc.value)