## Accumulator Demo

In [6]:
from typing import Optional

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr
from pyspark.sql.types import IntegerType

def handle_bad_record(shipments: Optional[str]) -> Optional[int]:
    """Parse a shipment count, flagging unparseable values on ``bad_rec``.

    Args:
        shipments: Raw shipment count as text. May be ``None`` when the
            source cell is null.

    Returns:
        The parsed integer, or ``None`` when the value cannot be parsed.

    Side effects:
        Adds 1 to the module-level ``bad_rec`` accumulator for every value
        that fails to parse (null values included).
    """
    try:
        return int(shipments)
    except (ValueError, TypeError):
        # ValueError: non-numeric text such as 'three'.
        # TypeError: int(None) for a null cell; left uncaught it would
        # propagate out of the UDF instead of being counted as a bad record.
        bad_rec.add(1)
        return None

if __name__ == "__main__":
    # Local session with 3 worker threads; parenthesised chain instead of
    # backslash continuations.
    spark = (
        SparkSession.builder
        .appName("accumulator_demo")
        .master("local[3]")
        .getOrCreate()
    )

    try:
        # Sample rows: (source, destination, shipments). Two of the shipment
        # values are deliberately non-numeric.
        data_list = [
            ("India", "India", '5'),
            ("India", "China", '7'),
            ("China", "India", 'three'),
            ("China", "China", '6'),
            ("Japan", "India", 'Five'),
        ]
        df = spark.createDataFrame(data_list).toDF(
            "source", "destination", "shipments"
        )

        # NOTE: bad_rec must remain a module-level name. handle_bad_record
        # looks it up as a global, so moving this code into a function would
        # break that lookup.
        bad_rec = spark.sparkContext.accumulator(0)

        # Register the Python function so it can be called from SQL
        # expressions.
        spark.udf.register(
            "udf_handle_bad_record", handle_bad_record, IntegerType()
        )

        # show() is the action that actually runs the UDF; the accumulator is
        # only updated for rows evaluated by an action.
        # NOTE(review): per the Spark programming guide, accumulator updates
        # made inside transformations may be applied more than once if a task
        # is re-executed, so treat this count as indicative on real workloads.
        df.withColumn(
            "shipment_int", expr("udf_handle_bad_record(shipments)")
        ).show()

        print(f"Bad Record Counts {bad_rec.value}")
    finally:
        # Always release the session, even if a step above raised.
        spark.stop()

23/01/22 14:32:46 WARN SimpleFunctionRegistry: The function udf_handle_bad_record replaced a previously registered function.


+------+-----------+---------+------------+
|source|destination|shipments|shipment_int|
+------+-----------+---------+------------+
| India|      India|        5|           5|
| India|      China|        7|           7|
| China|      India|    three|        null|
| China|      China|        6|           6|
| Japan|      India|     Five|        null|
+------+-----------+---------+------------+

Bad Record Counts 2
