# Chapter 14. Distributed Shared Variables

In [2]:
from pyspark.sql import SparkSession

# Get or create the SparkSession for this notebook, with the number of
# shuffle partitions set to 5.
spark = (
    SparkSession.builder
    .appName("Distributed Shared Variables")
    .config("spark.sql.shuffle.partitions", 5)
    .getOrCreate()
)

24/09/04 10:31:58 WARN Utils: Your hostname, Khanhs-MAC.local resolves to a loopback address: 127.0.0.1; using 192.168.254.18 instead (on interface en0)
24/09/04 10:31:58 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/04 10:31:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Broadcast Variables

In [3]:
# Split the book title on single spaces, then distribute the resulting
# words as an RDD with 2 partitions.
book_title = "Spark The Definitive Guide : Big Data Processing Made Simple"
my_collection = book_title.split(" ")
words = spark.sparkContext.parallelize(my_collection, 2)

In [4]:
# Per-word lookup values; the map step further down falls back to 0 for
# any word that is not a key here.
supplementalData = {
    "Spark": 1000,
    "Definitive": 200,
    "Big": 300,
    "Simple": 100,
}

In [5]:
# Wrap the lookup dict in a broadcast variable; its contents are read back
# through `suppBroadcast.value`.
suppBroadcast = spark.sparkContext.broadcast(supplementalData)

In [6]:
# Display the value held by the broadcast variable.
suppBroadcast.value

{'Spark': 1000, 'Definitive': 200, 'Big': 300, 'Simple': 100}

In [7]:
# Pair each word with its broadcast lookup value (0 when the word has no
# entry), order the pairs by that value, and bring them back as a list.
(
    words
    .map(lambda token: (token, suppBroadcast.value.get(token, 0)))
    .sortBy(lambda pair: pair[1])
    .collect()
)

                                                                                

[('The', 0),
 ('Guide', 0),
 (':', 0),
 ('Data', 0),
 ('Processing', 0),
 ('Made', 0),
 ('Simple', 100),
 ('Definitive', 200),
 ('Big', 300),
 ('Spark', 1000)]

## Accumulators

In [8]:
# Load the 2010 flight summary from its Parquet file.
flights = spark.read.parquet(
    "../data/flight-data/parquet/2010-summary.parquet"
)

In [9]:
# Accumulator with initial value 0; accChinaFunc adds flight counts to it.
accChina = spark.sparkContext.accumulator(0)

In [10]:
def accChinaFunc(flight_row):
    """Add the row's flight count to ``accChina`` for each China endpoint.

    The count is added once if the destination is "China" and once more if
    the origin is "China", so a China-to-China row contributes twice.

    Args:
        flight_row: Mapping-like row exposing "DEST_COUNTRY_NAME",
            "ORIGIN_COUNTRY_NAME" and "count".
    """
    # Read both endpoints up front, destination first, before any add happens.
    endpoints = (
        flight_row["DEST_COUNTRY_NAME"],
        flight_row["ORIGIN_COUNTRY_NAME"],
    )
    for country in endpoints:
        if country == "China":
            accChina.add(flight_row["count"])

In [11]:
# Run accChinaFunc on every row. The function is passed directly: a lambda
# that only forwards its single argument adds nothing.
# Re-running this cell without re-creating accChina adds the counts again.
flights.foreach(accChinaFunc)

                                                                                

In [12]:
# Read the accumulated total after the foreach above has run.
accChina.value

953

## Custom Accumulators

In [14]:
from pyspark import SparkContext, AccumulatorParam

# Custom AccumulatorParam that only accumulates even terms
class EvenAccumulatorParam(AccumulatorParam):
    """Accumulator parameter that adds a term only when the term is even.

    NOTE(review): if PySpark also routes per-task partial results through
    ``addInPlace`` when merging, the parity check applies to those partial
    sums as well. That should still be correct for integer terms, because a
    sum of even numbers is itself even -- confirm before reusing this
    pattern with a different filter.
    """

    def zero(self, initial_value):
        """Return 0, regardless of ``initial_value``."""
        return 0

    def addInPlace(self, v1, v2):
        """Return ``v1 + v2`` when ``v2`` is even, otherwise ``v1`` unchanged."""
        return v1 + v2 if v2 % 2 == 0 else v1

# Get the active SparkContext, or create one if none exists.
# NOTE(review): a SparkSession is built at the top of this notebook, so this
# presumably returns that session's context rather than starting a new one.
sc = SparkContext.getOrCreate()

# Create an accumulator that starts at 0 and combines terms through
# EvenAccumulatorParam.
even_acc = sc.accumulator(0, EvenAccumulatorParam())

# Feed one row's flight count into the even-only accumulator
def add_even_count(flight_row):
    """Add ``flight_row['count']`` to ``even_acc``.

    Whether the term is actually accumulated is decided by the accumulator's
    own param, not by this function.
    """
    flight_count = flight_row['count']
    even_acc.add(flight_count)

# Run add_even_count on every row of the flights dataset. The function is
# passed directly; a lambda that only forwards its argument is redundant.
flights.foreach(add_even_count)

# Print the accumulated total once the foreach action has finished
print(even_acc.value)

31390
