## 사용자 정의 함수 (User-Defined Function, UDF)

### 스파크 SQL UDF

In [15]:
import os
# Tell PySpark which Python interpreters to launch.
# The worker interpreter path below is machine-specific: replace it with the
# actual interpreter location on this machine.
os.environ.update({
    "PYSPARK_PYTHON": "C:/Python39/python.exe",
    "PYSPARK_DRIVER_PYTHON": "python",
})

In [16]:
from pyspark.sql.types import LongType
from pyspark.sql import SparkSession
# Absolute path to the MySQL Connector/J jar (machine-specific; adjust locally).
driver_path = os.path.abspath(r"C:\mysql-connector-j-8.3.0\mysql-connector-j-8.3.0.jar")

# Build (or reuse) the SparkSession and put the JDBC driver jar on the
# classpath via `spark.jars` so the MySQL cells below can load the driver.
# NOTE(review): the app name looks left over from an MLlib example.
spark = (
    SparkSession.builder
    .appName("SparkMllibExampleApp")
    .config("spark.jars", driver_path)
    .getOrCreate()
)

In [3]:
# 큐브 함수 생성
def cubed(s):
    return s*s*s

# Register the Python function under the SQL name "cubed" with a LongType
# result, so it can be called from Spark SQL queries.
spark.udf.register(
    "cubed",
    cubed,
    LongType(),
)

# Temporary view over ids 1..8 (the range end is exclusive).
spark.range(1, 9).createOrReplaceTempView("udf_test")

spark.sql(
    "select id, cubed(id) as id_cubed from udf_test"
).show()

+---+--------+
| id|id_cubed|
+---+--------+
|  1|       1|
|  2|       8|
|  3|      27|
|  4|      64|
|  5|     125|
|  6|     216|
|  7|     343|
|  8|     512|
+---+--------+



### 판다스 UDF

In [4]:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

def cubed(a: pd.Series) -> pd.Series:
    """Cube every element of a pandas Series (vectorized, no Python loop).

    Wrapped as a pandas UDF in the next statement. Note that this rebinds the
    name ``cubed``. The SQL UDF registered earlier is unaffected, because
    ``spark.udf.register`` was handed the earlier function object itself.
    """
    squared = a * a
    return squared * a

# Wrap the Series -> Series function defined just above as a pandas UDF with a
# LongType result (second positional argument of pandas_udf is returnType).
cubed_udf = pandas_udf(cubed, LongType())

In [5]:
# Call the plain pandas function directly (no Spark involved) to check it
# before using it through the pandas UDF.
x = pd.Series([1, 2, 3])
x_cubed = cubed(x)
assert x_cubed.tolist() == [1, 8, 27], "cubed() should cube each element"
x_cubed  # last expression -> rich display instead of print()

0     1
1     8
2    27
dtype: int64


In [6]:
# Apply the pandas UDF through the DataFrame API on ids 1..3.
# A dedicated name (instead of reusing `df`) means that if the MySQL read cell
# fails or is skipped, the JDBC write cell cannot silently pick up this toy
# frame and persist it to the database.
ids_df = spark.range(1, 4)
ids_df.select('id', cubed_udf(col('id'))).show()

+---+---------+
| id|cubed(id)|
+---+---------+
|  1|        1|
|  2|        8|
|  3|       27|
+---+---------+



### MySQL DB 연결

In [7]:
import os

# MySQL connection settings. Each value can be overridden with an environment
# variable, so credentials need not be hard-coded in the notebook. The defaults
# are the original local-development values.
ip = os.environ.get("MYSQL_HOST", "127.0.0.1")
port = os.environ.get("MYSQL_PORT", "3306")
user = os.environ.get("MYSQL_USER", "root")
passwd = os.environ.get("MYSQL_PASSWORD", "")
db = os.environ.get("MYSQL_DB", "news_project")
table_name = 'search_information'  # source table read in the next cell

In [None]:
# JDBC URL for the MySQL database configured in the previous cell.
url = f"jdbc:mysql://{ip}:{port}/{db}"

# Connection properties. The driver class is provided by the Connector/J jar
# passed through `spark.jars` when the SparkSession was built.
properties = {
    "user": user,
    "password": passwd,
    "driver": "com.mysql.cj.jdbc.Driver"
}

# Read data: load `table_name` over JDBC, taking user/password/driver from
# `properties` instead of repeating the same options inline.
df = (spark.read.format("jdbc")
    .option("url", url)
    .option("dbtable", table_name)
    .options(**properties)
    .load()
)

df.show()


+-------+-----+-------+----+----+
|user_id|title|summary|date|href|
+-------+-----+-------+----+----+
+-------+-----+-------+----+----+



#### 데이터 쓰기

In [None]:
# Write data: persist `df` (the table read above) to the `test` table in the
# same database. Credentials and driver come from `properties`, so they stay in
# sync with the config/read cells instead of being hard-coded a second time.
# NOTE(review): no .mode(...) is set, so Spark's default save mode
# ('errorifexists') applies. Re-running this cell once `test` exists will fail.
(df
 .write
 .format('jdbc')
 .options(
        url=url,          # DB connection string
        dbtable='test',   # target table name
        **properties,     # user / password / driver
    )
 .save())

### 고차 함수 (Higher-order functions)

In [19]:
from pyspark.sql.types import *

# Schema: a single column `celsius` holding an array of integers.
schema = StructType([
    StructField('celsius', ArrayType(IntegerType())),
])

# Two rows; each row is a one-element list wrapping its temperature array.
t_list = (
    [[35, 36, 32, 30, 40, 42, 38]],
    [[31, 32, 34, 55, 56]],
)
t_c = spark.createDataFrame(t_list, schema)
t_c.createOrReplaceTempView('tC')

t_c.show()

+--------------------+
|             celsius|
+--------------------+
|[35, 36, 32, 30, ...|
|[31, 32, 34, 55, 56]|
+--------------------+



#### transform()
`transform()`: 배열의 각 원소에 람다 함수를 적용한 결과로 새 배열을 만듭니다.

In [21]:
# transform(): apply the lambda to every element and return a new array.
# `div` is integer division, so results are truncated (31°C -> 87, not 87.8).
(spark
 .sql("""SELECT celsius,
                    transform(celsius, t -> ((t * 9) div 5) + 32 ) AS fahrenheit
             FROM tC""")
 .show())

+--------------------+--------------------+
|             celsius|          fahrenheit|
+--------------------+--------------------+
|[35, 36, 32, 30, ...|[95, 96, 89, 86, ...|
|[31, 32, 34, 55, 56]|[87, 89, 93, 131,...|
+--------------------+--------------------+



#### filter()
`filter()`: 람다 조건을 만족하는 원소만 남긴 배열을 반환합니다.

In [None]:
# filter(): keep only the array elements strictly greater than 38.
(spark
 .sql("""SELECT celsius,
        filter(celsius, t -> t>38) AS high
        FROM tC""")
 .show())

+--------------------+--------+
|             celsius|    high|
+--------------------+--------+
|[35, 36, 32, 30, ...|[40, 42]|
|[31, 32, 34, 55, 56]|[55, 56]|
+--------------------+--------+



#### exists()
`exists()`: 람다 조건을 만족하는 원소가 하나라도 있으면 `true`를 반환합니다.

In [23]:
# exists(): true when at least one array element equals exactly 38.
(spark
 .sql("""SELECT celsius,
                    exists(celsius, t -> t = 38) as threshold
             FROM tC""")
 .show())

+--------------------+---------+
|             celsius|threshold|
+--------------------+---------+
|[35, 36, 32, 30, ...|     true|
|[31, 32, 34, 55, 56]|    false|
+--------------------+---------+



In [14]:
# Stop the SparkSession created at the start of the notebook; keep this as the
# last cell, since any cell run afterwards that uses `spark` would need a new session.
spark.stop()