In [1]:
import random
import string
from typing import List

import pandas as pd
from pyspark.sql.functions import struct, col
from pyspark.sql.pandas.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType

from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
from sparkstudy.libs.tools import create_random_data
# Re-import edited project modules (e.g. sparkstudy.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# NOTE(review): none of the visible cells plot anything; this magic looks removable.
%matplotlib inline

# Column names for the rows produced by create_random_data(); in the outputs below
# `name` is a single letter, `age` an integer and `salary` a float.
COLUMNS = ["name","age","salary"]

比较一下开启和不开启 Arrow（`spark.sql.execution.arrow.pyspark.enabled`）时的区别。

测试下来，性能提升的表现有点奇怪：有时候会快，有时候会慢。单次 `%time` 的结果波动较大（例如 `cache()` 是惰性的，要到 `toPandas()` 时才真正物化，这部分开销也被计入了计时），比较时最好多跑几次取平均。

In [4]:
def test_performance(session_factory:DemoSQLSessionFactory, n:int = 100000) -> pd.DataFrame:
    """Round-trip ``n`` random rows through Spark and back into pandas.

    Builds a session from ``session_factory``, creates a cached DataFrame from
    ``create_random_data(n)`` and collects it with ``toPandas()``. The speed of that
    collection step is what changes when Arrow is enabled in the factory's config.

    Args:
        session_factory: factory whose ``build_session()`` supplies the SparkSession.
        n: number of random rows to generate.

    Returns:
        The first 5 rows of the collected pandas DataFrame (previously computed
        and silently discarded).
    """
    data = create_random_data(n)
    spark_session = session_factory.build_session()
    # cache() is lazy: the cache is only materialised by toPandas() below, so that
    # cost is part of whatever this call is timed at.
    df = spark_session.createDataFrame(data, COLUMNS).cache()
    try:
        return df.toPandas().head(5)
    finally:
        # Release the cached blocks; without this every call leaves another cached
        # copy of the data behind in the running application.
        df.unpersist()

CPU times: user 2.84 s, sys: 58.1 ms, total: 2.9 s
Wall time: 10.3 s


In [5]:
# Run test_performance with Arrow-based columnar transfer enabled, to compare against
# a run without this setting.
# NOTE(review): `name` presumably becomes the session/app label ("arraw" is a typo in it).
session_factory_arrow = DemoSQLSessionFactory(name="with arraw")
session_factory_arrow.add_config("spark.sql.execution.arrow.pyspark.enabled","true")
%time test_performance(session_factory_arrow)

CPU times: user 2.56 s, sys: 20.9 ms, total: 2.58 s
Wall time: 3.11 s


常规的 HelloWorld 示例，即页面上的第一个例子。本质就是生成一个新的 DataFrame：
1. 在 annotation（`@pandas_udf` 的参数）里列出的是新 DataFrame 的列名和类型；
2. 它会自动把 pandas 的结果转换成 Spark 的数据；
3. 函数应该是分批（按 Arrow batch）执行，然后再汇总——因为我看到 hello world 函数被执行了好几次。

In [6]:
# Limit Arrow batches to 10 records so the 1000-row frame below is split into many
# batches, and the pandas UDFs in later cells get invoked many times.
# NOTE(review): the factory already built a session in In [5]; whether this extra
# config reaches the session returned below depends on build_session() — confirm.
session_factory_arrow.add_config('spark.sql.execution.arrow.maxRecordsPerBatch',10)
spark = session_factory_arrow.build_session()
# Small demo frame reused by the following cells.
test_data = create_random_data(row_num=1000)
basic_df = spark.createDataFrame(test_data,COLUMNS)
basic_df.show()

+----+---+-------------------+
|name|age|             salary|
+----+---+-------------------+
|   L|968| 0.4871012494101088|
|   J|310| 0.2164529909895111|
|   H|784| 0.9459266156019447|
|   C|875| 0.1757703778152585|
|   W|369|0.34885292079772745|
|   J|259|    0.9833618918608|
|   A|724|0.37111406445478035|
|   H|659|0.19371750530598741|
|   T|522|0.11004928119368063|
|   O|971| 0.8169672921580708|
|   L|766|0.22604935892515998|
|   F|240| 0.2845676594137294|
|   F|767|0.03420255096249325|
|   I|604| 0.7210673035833547|
|   L|560| 0.8370531228587831|
|   R|961| 0.2337764138858095|
|   Q|960|0.10799747674724702|
|   A|262| 0.8642967644099723|
|   G|873| 0.0977050978708397|
|   Z|887| 0.2777555551588018|
+----+---+-------------------+
only showing top 20 rows



In [7]:
@pandas_udf("total double")
def func(s1: pd.Series, s2: pd.Series) -> pd.DataFrame:
    """Add two columns and return the sum wrapped in a one-field struct ``total``.

    The declared return type is a struct, so the UDF hands back a pandas DataFrame
    whose column names match the struct's field names.
    """
    # Expected to fire once per Arrow batch; it runs in the Python worker, so the
    # text may not appear in the notebook itself.
    print("execute")
    return pd.DataFrame({"total": s1 + s2})


basic_df.select(func("age", "salary").alias("result")).show()

+--------------------+
|              result|
+--------------------+
| [968.4871012494101]|
| [310.2164529909895]|
| [784.9459266156019]|
| [875.1757703778153]|
|[369.34885292079775]|
| [259.9833618918608]|
| [724.3711140644548]|
|  [659.193717505306]|
| [522.1100492811937]|
| [971.8169672921581]|
| [766.2260493589251]|
|[240.28456765941374]|
| [767.0342025509625]|
| [604.7210673035834]|
| [560.8370531228588]|
| [961.2337764138858]|
| [960.1079974767473]|
|[262.86429676440997]|
| [873.0977050978709]|
| [887.2777555551588]|
+--------------------+
only showing top 20 rows



主要是想看看 select 方法能不能接受一个 List。

In [8]:
def to_str_func(s1: pd.Series) -> pd.Series:
    """Convert every value of ``s1`` to its string form, keeping missing values missing.

    A plain ``astype(str)`` turns NaN/None into the literal strings ``"nan"``/``"None"``,
    which Spark would then store as real text. Masking with ``notna()`` leaves those
    positions missing, so they should come back to Spark as null.

    NOTE(review): an integer column that contains nulls presumably reaches pandas as
    float64, so its non-null values render as e.g. ``"968.0"``.
    """
    return s1.astype(str).where(s1.notna())
# Wrap the plain function as a Spark pandas UDF that produces strings.
to_str = pandas_udf(to_str_func, returnType=StringType())

# Build the Column expressions first and hand them to select() as one Python list;
# the output below shows select() accepts a list of Columns.
age_c = to_str("age").alias("age")
salary_c = to_str("salary").alias("salary")
selects = [age_c,salary_c]
basic_df.select(selects).show()

+---+-------------------+
|age|             salary|
+---+-------------------+
|968| 0.4871012494101088|
|310| 0.2164529909895111|
|784| 0.9459266156019447|
|875| 0.1757703778152585|
|369|0.34885292079772745|
|259|    0.9833618918608|
|724|0.37111406445478035|
|659|0.19371750530598741|
|522|0.11004928119368063|
|971| 0.8169672921580708|
|766|0.22604935892515998|
|240| 0.2845676594137294|
|767|0.03420255096249325|
|604| 0.7210673035833547|
|560| 0.8370531228587831|
|961| 0.2337764138858095|
|960|0.10799747674724702|
|262| 0.8642967644099723|
|873| 0.0977050978708397|
|887| 0.2777555551588018|
+---+-------------------+
only showing top 20 rows



测试一下：如果参数个数不确定，行不行？

简单地来说：
- column 个数确定时，用 Series；
- 不确定时，用 DataFrame（把多列打包成一个 struct 传入）；
- Iterator 类似于流式处理。

In [9]:
@pandas_udf("double")
def to_sum_func(data: pd.DataFrame) -> pd.Series:
    """Multiply the ``age`` and ``salary`` fields of a struct column, row by row.

    A struct argument arrives as a single pandas DataFrame with one column per
    struct field, so the UDF's signature does not depend on how many columns
    were packed into the struct.

    NOTE(review): despite the name this computes a product, not a sum (e.g. the
    first output row is 968 * 0.4871... = 471.51...). Name kept to avoid breaking
    callers.
    """
    return data.age * data.salary


# Pack the input columns into one struct so they reach the UDF as one DataFrame.
cols = [col("age"), col("salary")]
headers = struct(cols)
basic_df.select(to_sum_func(headers).alias("result")).show()


+------------------+
|            result|
+------------------+
|471.51400942898533|
| 67.10042720674844|
| 741.6064666319246|
| 153.7990805883512|
|128.72672777436142|
| 254.6907299919472|
|268.68658266526097|
|127.65983599664571|
|57.445724783101284|
| 793.2752406854868|
|173.15380893667253|
| 68.29623825929505|
|26.233356588232322|
|435.52465136434625|
|468.74974880091855|
|224.65913374426293|
|103.67757767735714|
|226.44575227541273|
| 85.29655044124306|
| 246.3691774258572|
+------------------+
only showing top 20 rows



能不能用于 SQL？

In [10]:
# Expose basic_df to SQL as the temp view "pandas_udf" and register the pandas UDF
# `to_str` under the SQL function name "pandas_to_str"; the query output below shows
# a pandas UDF can be called from Spark SQL.
basic_df.createOrReplaceTempView("pandas_udf")
spark.udf.register("pandas_to_str", to_str)
spark.sql("select pandas_to_str(age) from pandas_udf").show()

+------------------+
|pandas_to_str(age)|
+------------------+
|               968|
|               310|
|               784|
|               875|
|               369|
|               259|
|               724|
|               659|
|               522|
|               971|
|               766|
|               240|
|               767|
|               604|
|               560|
|               961|
|               960|
|               262|
|               873|
|               887|
+------------------+
only showing top 20 rows



# NOTE(review): verbatim repeat of cell In [10] above (same temp view, same UDF
# registration, same query). Re-running it only replaces the same view and
# re-registers the same function; looks like a leftover and can likely be deleted.
basic_df.createOrReplaceTempView("pandas_udf")
spark.udf.register("pandas_to_str", to_str)
spark.sql("select pandas_to_str(age) from pandas_udf").show()

`__call__` 这个方法能不能用？即：能否把一个实现了 `__call__` 的对象当作 pandas UDF 来使用。

In [11]:
class PandasFunc:
    """Callable object used as a pandas UDF body: element-wise ``age * salary``."""

    def __call__(self, data: pd.DataFrame)-> pd.Series:
        # `data` holds one column per field of the struct passed to the UDF.
        age, salary = data["age"], data["salary"]
        return age.mul(salary)

# Pack age and salary into one struct column; the UDF receives it as a single
# pandas DataFrame with one column per struct field.
cols = [col("age"),col("salary")]
headers = struct(cols)
# pandas_udf is given a callable instance instead of a function; the output below
# shows this works. NOTE(review): like to_sum_func, "sum" in the name is misleading
# — PandasFunc multiplies.
class_my_sum = pandas_udf(PandasFunc(), returnType=DoubleType())
basic_df.select(class_my_sum(headers).alias("result")).show()

+------------------+
|            result|
+------------------+
|471.51400942898533|
| 67.10042720674844|
| 741.6064666319246|
| 153.7990805883512|
|128.72672777436142|
| 254.6907299919472|
|268.68658266526097|
|127.65983599664571|
|57.445724783101284|
| 793.2752406854868|
|173.15380893667253|
| 68.29623825929505|
|26.233356588232322|
|435.52465136434625|
|468.74974880091855|
|224.65913374426293|
|103.67757767735714|
|226.44575227541273|
| 85.29655044124306|
| 246.3691774258572|
+------------------+
only showing top 20 rows



返回多列的处理方法。

In [12]:
@pandas_udf("col1 double, col2 double")
def to_multi_return_func(data: pd.DataFrame) -> pd.DataFrame:
    """Return ``age`` and ``salary`` as a two-field struct (``col1``, ``col2``).

    With a struct return type the pandas UDF returns a DataFrame whose column
    names match the struct fields; the single resulting struct column is unpacked
    below via ``abc.col1`` / ``abc.col2``. ``age`` is integer input but is declared
    ``double`` here, and the output shows it arrives as e.g. ``968.0``.
    """
    # Debug marker; runs in the Python worker, so it may not appear in the notebook.
    print("execute")
    return pd.DataFrame({"col1": data.age, "col2": data.salary})


# Pack the inputs into one struct column so the UDF receives them as one DataFrame.
cols = [col("age"), col("salary")]
headers = struct(cols)
multi_return_df = basic_df.withColumn("abc", to_multi_return_func(headers))
multi_return_df.select(col("age"), col("salary"), col("abc.col1"), col("abc.col2")).show()


+---+-------------------+-----+-------------------+
|age|             salary| col1|               col2|
+---+-------------------+-----+-------------------+
|968| 0.4871012494101088|968.0| 0.4871012494101088|
|310| 0.2164529909895111|310.0| 0.2164529909895111|
|784| 0.9459266156019447|784.0| 0.9459266156019447|
|875| 0.1757703778152585|875.0| 0.1757703778152585|
|369|0.34885292079772745|369.0|0.34885292079772745|
|259|    0.9833618918608|259.0|    0.9833618918608|
|724|0.37111406445478035|724.0|0.37111406445478035|
|659|0.19371750530598741|659.0|0.19371750530598741|
|522|0.11004928119368063|522.0|0.11004928119368063|
|971| 0.8169672921580708|971.0| 0.8169672921580708|
|766|0.22604935892515998|766.0|0.22604935892515998|
|240| 0.2845676594137294|240.0| 0.2845676594137294|
|767|0.03420255096249325|767.0|0.03420255096249325|
|604| 0.7210673035833547|604.0| 0.7210673035833547|
|560| 0.8370531228587831|560.0| 0.8370531228587831|
|961| 0.2337764138858095|961.0| 0.2337764138858095|
|960|0.10799