In [1]:
import findspark
findspark.init()

In [5]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf,PandasUDFType 
from pyspark.sql.types import LongType, IntegerType, ArrayType
from pyspark.sql.functions import *

In [6]:
spark = (SparkSession.builder.appName('UDFwithPandas').getOrCreate())

In [6]:
def cubed(a: pd.Series)->pd.Series:
    return a*a*a

In [7]:
#Create the pandas UDF for the cubed function
cubed_udf = pandas_udf(cubed, returnType=LongType())

In [68]:
x = pd.Series([1,2,3,4])

In [9]:
print(cubed(x))

0     1
1     8
2    27
3    64
dtype: int64


In [11]:
df = spark.range(1,9)

In [15]:
df.createOrReplaceTempView('udf_test')

In [13]:
(df.select("id", cubed_udf(col('id'))).show())

+---+---------+
| id|cubed(id)|
+---+---------+
|  1|        1|
|  2|        8|
|  3|       27|
|  4|       64|
|  5|      125|
|  6|      216|
|  7|      343|
|  8|      512|
+---+---------+



In [24]:
#convert dataframe column to list 
id_list = df.select('id').collect()

In [25]:
spark.sql("""SELECT array_distinct(array(1,2,3,4,5,3))""").show()

+---------------------------------------+
|array_distinct(array(1, 2, 3, 4, 5, 3))|
+---------------------------------------+
|                        [1, 2, 3, 4, 5]|
+---------------------------------------+



### High-Order Function

In [5]:
from pyspark.sql.types import *

In [6]:
schema = StructType([StructField("celsius", ArrayType(IntegerType()))])

In [7]:
t_list = [[35, 36, 32, 30, 40, 42, 36]], [[31, 32, 34, 55, 56, 32]]
t_c = spark.createDataFrame(t_list, schema=schema)
t_c.createOrReplaceTempView('tC')

In [8]:
t_c.show(truncate=False)

+----------------------------+
|celsius                     |
+----------------------------+
|[35, 36, 32, 30, 40, 42, 36]|
|[31, 32, 34, 55, 56, 32]    |
+----------------------------+



In [10]:
#transform
fahrenheit_t=spark.sql("""SELECT celsius, transform(celsius, t->((t*9) div 5) +32) as fahrenheit
    FROM tC """)

In [11]:
fahrenheit_t.write.saveAsTable('fahrenheit_transform_tbl')

In [19]:
t_c.select('celsius',transform(col('id'), lambda x: ((x*9) / 5)+32).alias('fahrenheit')).show() #new in version 3.1

NameError: name 'transform' is not defined

In [12]:
spark.sql("""SELECT filter(celsius, t->t>32)as high FROM tC""").show() #filter

+--------------------+
|                high|
+--------------------+
|[35, 36, 40, 42, 36]|
|        [34, 55, 56]|
+--------------------+



In [13]:
spark.sql("""SELECT celsius, exists(celsius, t->t>42) as threshold FROM tC """).show()

+--------------------+---------+
|             celsius|threshold|
+--------------------+---------+
|[35, 36, 32, 30, ...|    false|
|[31, 32, 34, 55, ...|     true|
+--------------------+---------+



In [16]:
def calculate(cel):
    sum_e = 0
    for v in cel:
        sum_e += v
    avg = sum_e/len(cel)
    return (avg*9)/5+32
        
def avgAndTransform(a:pd.Series)->pd.Series:
    return a.map(lambda x: calculate(x))

In [18]:
avgAndTransform_udf = pandas_udf(avgAndTransform, returnType=LongType())

In [22]:
#caculate average celsius and transform into fahrenheit using reduce() function
(t_c.select('celsius', avgAndTransform_udf(col('celsius')).alias('avgFahrenheit')).show(truncate=False))

+----------------------------+-------------+
|celsius                     |avgFahrenheit|
+----------------------------+-------------+
|[35, 36, 32, 30, 40, 42, 36]|96           |
|[31, 32, 34, 55, 56, 32]    |104          |
+----------------------------+-------------+



In [15]:
fahrenheit = spark.read.parquet('spark-warehouse/fahrenheit_transform_tbl')

In [16]:
fahrenheit.show()

+--------------------+--------------------+
|             celsius|          fahrenheit|
+--------------------+--------------------+
|[35, 36, 32, 30, ...|[95, 96, 89, 86, ...|
|[31, 32, 34, 55, ...|[87, 89, 93, 131,...|
+--------------------+--------------------+



In [15]:
fahrenheit.show(truncate=False)

+----------------------------+------------------------------+
|celsius                     |fahrenheit                    |
+----------------------------+------------------------------+
|[35, 36, 32, 30, 40, 42, 36]|[95, 96, 89, 86, 104, 107, 96]|
|[31, 32, 34, 55, 56, 32]    |[87, 89, 93, 131, 132, 89]    |
+----------------------------+------------------------------+

