In [1]:
import string
import random
import pandas as pd
from typing import List,Iterator,Tuple
from pyspark.sql.pandas.functions import pandas_udf
from pyspark.sql.functions import struct, col

from pyspark.sql.types import StringType,DoubleType
from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
%load_ext autoreload
%autoreload 2
%matplotlib inline

COLUMNS = ["name","age","salary"]

比较下开不开启arrow的区别

测试下来，感觉性能提升有点奇怪。有时候会快，有时候会慢。

In [2]:
def create_random_data(row_num:int)->List[tuple]:
    """Generate ``row_num`` random three-field rows.

    Each row is a tuple of:
      * one uppercase ASCII letter,
      * an integer drawn from ``[1, row_num]`` (inclusive),
      * a float drawn from ``[0.0, 1.0)``.

    The module-level ``random`` generator is used, so results are only
    reproducible if the caller seeds it beforehand.

    :param row_num: number of rows to produce; ``0`` yields an empty list.
    :return: list of ``(letter, int, float)`` tuples.
    """
    letters = string.ascii_uppercase
    # Tuple elements are evaluated left to right, so the RNG is consumed in
    # the order choice -> randint -> random for every row.
    return [
        (random.choice(letters), random.randint(1, row_num), random.random())
        for _ in range(row_num)
    ]

In [3]:
def test_performance(session_factory:DemoSQLSessionFactory, n:int = 100000):
    """Push ``n`` random rows through Spark and pull them back into pandas.

    Intended to be wrapped in ``%time``. Everything below runs inside the
    measured span: generating the random rows, building the session,
    creating and caching the Spark DataFrame, and the ``toPandas()``
    conversion.

    :param session_factory: factory whose ``build_session()`` supplies the
        Spark session to use.
    :param n: number of random rows to generate.
    :return: ``None``; the pandas result is discarded.
    """
    rows = create_random_data(n)
    session = session_factory.build_session()
    spark_df = session.createDataFrame(rows, COLUMNS)
    cached_df = spark_df.cache()
    # toPandas() is the step the Arrow setting is meant to affect.
    pandas_df = cached_df.toPandas()
    # Only evaluated, never returned or displayed.
    pandas_df.head(5)

In [4]:
# Baseline run: the factory gets no extra config here.
# NOTE(review): presumably Arrow conversion is disabled by default in this factory - confirm.
session_factory_normal = DemoSQLSessionFactory(name="normal")
# %time prints CPU and wall-clock time for this single call.
%time test_performance(session_factory_normal)

CPU times: user 3.17 s, sys: 57.3 ms, total: 3.23 s
Wall time: 11.3 s


In [5]:
# Same benchmark with Arrow-based Spark <-> pandas conversion switched on.
# NOTE(review): "arraw" in the name looks like a typo for "arrow"; left as is because it is a runtime string.
session_factory_arrow = DemoSQLSessionFactory(name="with arraw")
session_factory_arrow.add_config("spark.sql.execution.arrow.pyspark.enabled","true")
# NOTE(review): TODO confirm build_session() creates a fresh session here. If it reuses the
# session started by the baseline run, this config may not apply and the two timings are
# not comparable (the first run may also include one-off startup cost).
%time test_performance(session_factory_arrow)

CPU times: user 2.64 s, sys: 20.7 ms, total: 2.66 s
Wall time: 3.41 s


常规的HelloWorld的example。
页面上面的第一个例子。本质就是生成一个新的dataframe
1. 在annotation上面列出的是新的dataframe的col和类型
2. 它会自动地把pandas的类型转换成Spark的类型
3. 函数应该会按批次（batch）分发到各个node上执行，然后再汇总，因为我看到hello world的函数被执行了好几次

In [6]:
# Limit each Arrow record batch to 10 rows.
# NOTE(review): presumably chosen so the pandas UDFs below get called many times on small batches - confirm.
session_factory_arrow.add_config('spark.sql.execution.arrow.maxRecordsPerBatch',10)
# `spark` and `basic_df` are shared by all following cells.
spark = session_factory_arrow.build_session()
# 1000 random rows laid out according to COLUMNS (name, age, salary).
test_data = create_random_data(row_num=1000)
basic_df = spark.createDataFrame(test_data,COLUMNS)
basic_df.show()

+----+----+-------------------+
|name| age|             salary|
+----+----+-------------------+
|   X| 981| 0.7073095672945235|
|   Q| 902|0.09765986318709052|
|   V|  10| 0.3852364010725843|
|   B| 558|0.49268180032026987|
|   E| 960|0.15956741825769705|
|   T| 347|  0.134229798131681|
|   U|  94|0.20868403846792571|
|   E| 102| 0.5278589754699419|
|   I| 398| 0.9390961853964784|
|   Z| 787| 0.2688598344579807|
|   H| 552| 0.4164119910611638|
|   F| 936| 0.9507560150943194|
|   X| 574| 0.6550968641233801|
|   B| 377| 0.8364295345267588|
|   C|1000| 0.8048297391977565|
|   U| 593|0.39202752999800294|
|   V| 489| 0.7930708888138559|
|   W| 297|  0.904389303455135|
|   R| 926|  0.589543682410987|
|   G|  70| 0.5786476288393697|
+----+----+-------------------+
only showing top 20 rows



In [7]:
@pandas_udf("total double")
def func(s1: pd.Series, s2: pd.Series) -> pd.DataFrame:
    """Add two columns element-wise and return the sum as a one-field struct.

    The declared return schema is ``total double``, so the result is a
    pandas DataFrame with a single ``total`` column.

    :param s1: first numeric column.
    :param s2: second numeric column.
    :return: DataFrame whose ``total`` column is ``s1 + s2``.
    """
    # Deliberate trace: one line per invocation, to observe how many times
    # the function is called.
    print("execute")
    return pd.DataFrame({'total': s1 + s2})
# Apply the UDF to the "age" and "salary" columns; the struct it returns is shown as a single "result" column.
basic_df.select(func("age","salary").alias("result")).show()

+--------------------+
|              result|
+--------------------+
| [981.7073095672945]|
| [902.0976598631871]|
|[10.385236401072584]|
| [558.4926818003203]|
| [960.1595674182577]|
| [347.1342297981317]|
| [94.20868403846792]|
|[102.52785897546994]|
|[398.93909618539647]|
|  [787.268859834458]|
| [552.4164119910612]|
| [936.9507560150943]|
| [574.6550968641234]|
|[377.83642953452676]|
|[1000.8048297391978]|
|  [593.392027529998]|
| [489.7930708888139]|
|[297.90438930345516]|
|  [926.589543682411]|
| [70.57864762883936]|
+--------------------+
only showing top 20 rows



主要是想看看select方法能不能接受一个List

In [8]:
def to_str_func(s1: pd.Series) -> pd.Series:
    """Cast every element of ``s1`` to ``str``.

    :param s1: input column of any dtype.
    :return: a Series of the same length and index holding the string form
        of each element.
    """
    as_text = s1.astype(str)
    return as_text
# Non-decorator form: wrap the plain function as a pandas UDF returning StringType.
to_str = pandas_udf(to_str_func, returnType=StringType())

# One UDF column per input column, each aliased back to its original name.
age_c = to_str("age").alias("age")
salary_c = to_str("salary").alias("salary")
# The point of this cell: select() is handed a Python list of columns instead of varargs.
selects = [age_c,salary_c]
basic_df.select(selects).show()

+----+-------------------+
| age|             salary|
+----+-------------------+
| 981| 0.7073095672945235|
| 902|0.09765986318709052|
|  10| 0.3852364010725843|
| 558|0.49268180032026987|
| 960|0.15956741825769705|
| 347|  0.134229798131681|
|  94|0.20868403846792571|
| 102| 0.5278589754699419|
| 398| 0.9390961853964784|
| 787| 0.2688598344579807|
| 552| 0.4164119910611638|
| 936| 0.9507560150943194|
| 574| 0.6550968641233801|
| 377| 0.8364295345267588|
|1000| 0.8048297391977565|
| 593|0.39202752999800294|
| 489| 0.7930708888138559|
| 297|  0.904389303455135|
| 926|  0.589543682410987|
|  70| 0.5786476288393697|
+----+-------------------+
only showing top 20 rows



测试一下：如果参数的个数不确定行不行

简单地来说，
- 确定的column个数，用Series
- 不确定的column个数，用DataFrame
- Iterator类似于流式处理

In [9]:
@pandas_udf("double")
def to_sum_func(data: pd.DataFrame) -> pd.Series:
    """Multiply the ``age`` and ``salary`` columns of a struct input row-wise.

    Despite the name, this computes a product, not a sum.

    :param data: DataFrame holding at least ``age`` and ``salary`` columns
        (the fields of the struct column passed to the UDF).
    :return: Series of ``age * salary`` per row.
    """
    age, salary = data["age"], data["salary"]
    return age * salary
# Pack the input columns into one struct column so the UDF receives them as a single pandas DataFrame.
cols = [col("age"),col("salary")]
headers = struct(cols)
# NOTE(review): leftover commented-out line; to_sum_func is already wrapped by its @pandas_udf decorator.
#my_sum = pandas_udf(to_sum_func, returnType=DoubleType())
basic_df.select(to_sum_func(headers).alias("result")).show()


+------------------+
|            result|
+------------------+
| 693.8706855159276|
| 88.08919659475565|
| 3.852364010725843|
| 274.9164445787106|
|153.18472152738917|
| 46.57773995169331|
|19.616299615985017|
| 53.84161549793407|
| 373.7602817877984|
| 211.5926897184308|
| 229.8594190657624|
| 889.9076301282829|
| 376.0256000068202|
| 315.3339345165881|
| 804.8297391977565|
|232.47232528881574|
|387.81166462997555|
| 268.6036231261751|
| 545.9174499125739|
| 40.50533401875588|
+------------------+
only showing top 20 rows



能不能用于SQL

In [10]:
# Expose basic_df to SQL under the view name "pandas_udf".
basic_df.createOrReplaceTempView("pandas_udf")
# Register the existing to_str pandas UDF under the SQL function name "pandas_to_str".
spark.udf.register("pandas_to_str", to_str)
# Call the registered UDF from a SQL query.
spark.sql("select pandas_to_str(age) from pandas_udf").show()

+------------------+
|pandas_to_str(age)|
+------------------+
|               981|
|               902|
|                10|
|               558|
|               960|
|               347|
|                94|
|               102|
|               398|
|               787|
|               552|
|               936|
|               574|
|               377|
|              1000|
|               593|
|               489|
|               297|
|               926|
|                70|
+------------------+
only showing top 20 rows



basic_df.createOrReplaceTempView("pandas_udf")
spark.udf.register("pandas_to_str", to_str)
spark.sql("select pandas_to_str(age) from pandas_udf").show()

`__call__`这个方法能不能这么用

In [15]:
class PandasFunc:
    """Callable object that multiplies ``age`` by ``salary`` row-wise.

    Used to check whether an instance with ``__call__`` can be handed to
    ``pandas_udf`` in place of a plain function.
    """

    def __call__(self, data: pd.DataFrame)-> pd.Series:
        """Return ``age * salary`` for each row of ``data``.

        :param data: DataFrame with ``age`` and ``salary`` columns.
        :return: Series holding the per-row product.
        """
        age, salary = data["age"], data["salary"]
        return age * salary

# Same struct input as the to_sum_func cell; rebinds `cols` and `headers` to the same values.
cols = [col("age"),col("salary")]
headers = struct(cols)
# Wrap a callable instance (not a plain function) as a pandas UDF returning DoubleType.
class_my_sum = pandas_udf(PandasFunc(), returnType=DoubleType())
basic_df.select(class_my_sum(headers).alias("result")).show()

+------------------+
|            result|
+------------------+
| 693.8706855159276|
| 88.08919659475565|
| 3.852364010725843|
| 274.9164445787106|
|153.18472152738917|
| 46.57773995169331|
|19.616299615985017|
| 53.84161549793407|
| 373.7602817877984|
| 211.5926897184308|
| 229.8594190657624|
| 889.9076301282829|
| 376.0256000068202|
| 315.3339345165881|
| 804.8297391977565|
|232.47232528881574|
|387.81166462997555|
| 268.6036231261751|
| 545.9174499125739|
| 40.50533401875588|
+------------------+
only showing top 20 rows

