In [1]:
import random
import string
from typing import List

import pandas as pd
from pyspark.sql.functions import struct, col
from pyspark.sql.pandas.functions import pandas_udf
from pyspark.sql.types import StringType, DoubleType

from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
%load_ext autoreload
%autoreload 2
%matplotlib inline

COLUMNS = ["name","age","salary"]

比较下开不开启arrow的区别

测试下来，感觉性能提升有点奇怪。有时候会快，有时候会慢。

In [2]:
def create_random_data(row_num:int)->List[tuple]:
     result = list()
     a_str = string.ascii_uppercase
     for i in range(row_num):
         random_letter = random.choice(a_str)
         result.append((random_letter,random.randint(1,row_num),random.random()))
     return result

In [3]:
def test_performance(session_factory:DemoSQLSessionFactory, n:int = 100000):
    data = create_random_data(n)
    spark_session = session_factory.build_session()
    df = spark_session.createDataFrame(data,COLUMNS).cache()
    df.toPandas().head(5)

In [4]:
session_factory_normal = DemoSQLSessionFactory(name="normal")
%time test_performance(session_factory_normal)

CPU times: user 2.75 s, sys: 54.4 ms, total: 2.8 s
Wall time: 10.2 s


In [5]:
session_factory_arrow = DemoSQLSessionFactory(name="with arraw")
session_factory_arrow.add_config("spark.sql.execution.arrow.pyspark.enabled","true")
%time test_performance(session_factory_arrow)

CPU times: user 2.84 s, sys: 29.3 ms, total: 2.87 s
Wall time: 3.62 s


常规的HelloWorld的example。
页面上面的第一个例子。本质就是生成一个新的dataframe
1. 在annotation上面列出的是新的dataframe的col和类型
2. 他会自动的把pd的转换成spark的
3. 函数应该会分批node执行。然后再汇总。因为我看到了。hello world的函数会被执行好几次

In [6]:
session_factory_arrow.add_config('spark.sql.execution.arrow.maxRecordsPerBatch',10)
spark = session_factory_arrow.build_session()
test_data = create_random_data(row_num=1000)
basic_df = spark.createDataFrame(test_data,COLUMNS)
basic_df.show()

+----+---+--------------------+
|name|age|              salary|
+----+---+--------------------+
|   V|998|  0.2713079265456966|
|   W|473| 0.09608969362120778|
|   F|678| 0.07191766994148774|
|   M|624|  0.7805802879792632|
|   S|361|  0.8922136333483364|
|   G|505|  0.9017783167361239|
|   C|389| 0.27793918129778183|
|   M|247|  0.9734607293830692|
|   X|874|  0.9584370599441783|
|   F|384| 0.21520652529231066|
|   V|699| 0.03320378604414065|
|   I|764|  0.5951774351126187|
|   H|597| 0.28892018066268255|
|   N|985|  0.8533399451400041|
|   F|459|  0.9953521916101892|
|   P|316|  0.8110962722359051|
|   J|206| 0.43682929686694527|
|   X|850|   0.955839790517653|
|   I|205| 0.46154739932642985|
|   C|233|0.044360600720149246|
+----+---+--------------------+
only showing top 20 rows



In [7]:
@pandas_udf("total double")
def func(s1: pd.Series, s2: pd.Series) -> pd.DataFrame:
    print("execute")
    s3 = pd.DataFrame()
    s3['total'] = s1 + s2
    return s3
basic_df.select(func("age","salary").alias("result")).show()

+--------------------+
|              result|
+--------------------+
| [998.2713079265457]|
| [473.0960896936212]|
| [678.0719176699415]|
| [624.7805802879793]|
| [361.8922136333483]|
| [505.9017783167361]|
|[389.27793918129777]|
|[247.97346072938308]|
| [874.9584370599442]|
|[384.21520652529233]|
| [699.0332037860442]|
| [764.5951774351126]|
| [597.2889201806627]|
|   [985.85333994514]|
| [459.9953521916102]|
| [316.8110962722359]|
|[206.43682929686693]|
| [850.9558397905176]|
|[205.46154739932643]|
|[233.04436060072015]|
+--------------------+
only showing top 20 rows



主要是想要看看。select方法，不能不能接受一个List

In [8]:
def to_str_func(s1: pd.Series) -> pd.Series:
    return s1.astype(dtype=str)
to_str = pandas_udf(to_str_func, returnType=StringType())

age_c = to_str("age").alias("age")
salary_c = to_str("salary").alias("salary")
selects = [age_c,salary_c]
basic_df.select(selects).show()

+---+--------------------+
|age|              salary|
+---+--------------------+
|998|  0.2713079265456966|
|473| 0.09608969362120778|
|678| 0.07191766994148774|
|624|  0.7805802879792632|
|361|  0.8922136333483364|
|505|  0.9017783167361239|
|389| 0.27793918129778183|
|247|  0.9734607293830692|
|874|  0.9584370599441783|
|384| 0.21520652529231066|
|699| 0.03320378604414065|
|764|  0.5951774351126187|
|597| 0.28892018066268255|
|985|  0.8533399451400041|
|459|  0.9953521916101892|
|316|  0.8110962722359051|
|206| 0.43682929686694527|
|850|   0.955839790517653|
|205| 0.46154739932642985|
|233|0.044360600720149246|
+---+--------------------+
only showing top 20 rows



测试以下。如果参数是不定的行不行

简单的来书，
- 确定的column个数，用Series
- 不确定用dataframe
- iterator是类似用流

In [9]:
@pandas_udf("double")
def to_sum_func(data: pd.DataFrame) -> pd.Series:
    return data.age*data.salary
cols = [col("age"),col("salary")]
headers = struct(cols)
#my_sum = pandas_udf(to_sum_func, returnType=DoubleType())
basic_df.select(to_sum_func(headers).alias("result")).show()


+------------------+
|            result|
+------------------+
|270.76531069260517|
| 45.45042508283128|
|48.760180220328685|
| 487.0820996990602|
| 322.0891216387494|
|455.39804995174256|
|108.11834152483713|
| 240.4448001576181|
| 837.6739903912118|
| 82.63930571224729|
|23.209446444854315|
|454.71556042604067|
|172.48534785562148|
|  840.539845962904|
|456.86665594907686|
|  256.306422026546|
| 89.98683515459072|
|  812.463821940005|
| 94.61721686191812|
|10.336019967794774|
+------------------+
only showing top 20 rows



能不能用于SQL

In [10]:
basic_df.createOrReplaceTempView("pandas_udf")
spark.udf.register("pandas_to_str", to_str)
spark.sql("select pandas_to_str(age) from pandas_udf").show()

+------------------+
|pandas_to_str(age)|
+------------------+
|               998|
|               473|
|               678|
|               624|
|               361|
|               505|
|               389|
|               247|
|               874|
|               384|
|               699|
|               764|
|               597|
|               985|
|               459|
|               316|
|               206|
|               850|
|               205|
|               233|
+------------------+
only showing top 20 rows



basic_df.createOrReplaceTempView("pandas_udf")
spark.udf.register("pandas_to_str", to_str)
spark.sql("select pandas_to_str(age) from pandas_udf").show()

`__call__`这个方法能不能用哪用

In [11]:
class PandasFunc:
    def __call__(self, data: pd.DataFrame)-> pd.Series:
         return data.age*data.salary

cols = [col("age"),col("salary")]
headers = struct(cols)
class_my_sum = pandas_udf(PandasFunc(), returnType=DoubleType())
basic_df.select(class_my_sum(headers).alias("result")).show()

+------------------+
|            result|
+------------------+
|270.76531069260517|
| 45.45042508283128|
|48.760180220328685|
| 487.0820996990602|
| 322.0891216387494|
|455.39804995174256|
|108.11834152483713|
| 240.4448001576181|
| 837.6739903912118|
| 82.63930571224729|
|23.209446444854315|
|454.71556042604067|
|172.48534785562148|
|  840.539845962904|
|456.86665594907686|
|  256.306422026546|
| 89.98683515459072|
|  812.463821940005|
| 94.61721686191812|
|10.336019967794774|
+------------------+
only showing top 20 rows



返回多列的处理方法。

In [34]:
@pandas_udf("col1 double, col2 double")
def to_multi_return_func(data: pd.DataFrame) -> pd.DataFrame:
    print("execute")
    s3 = pd.DataFrame()
    s3['col1'] = data.age
    s3['col2'] = data.salary
    return s3
cols = [col("age"),col("salary")]
headers = struct(cols)
#my_sum = pandas_udf(to_sum_func, returnType=DoubleType())
multi_return_df = basic_df.withColumn("abc",to_multi_return_func(headers))
multi_return_df.select(col("age"),col("salary"),col("abc.col1"),col("abc.col2")).show()


+---+--------------------+-----+--------------------+
|age|              salary| col1|                col2|
+---+--------------------+-----+--------------------+
|998|  0.2713079265456966|998.0|  0.2713079265456966|
|473| 0.09608969362120778|473.0| 0.09608969362120778|
|678| 0.07191766994148774|678.0| 0.07191766994148774|
|624|  0.7805802879792632|624.0|  0.7805802879792632|
|361|  0.8922136333483364|361.0|  0.8922136333483364|
|505|  0.9017783167361239|505.0|  0.9017783167361239|
|389| 0.27793918129778183|389.0| 0.27793918129778183|
|247|  0.9734607293830692|247.0|  0.9734607293830692|
|874|  0.9584370599441783|874.0|  0.9584370599441783|
|384| 0.21520652529231066|384.0| 0.21520652529231066|
|699| 0.03320378604414065|699.0| 0.03320378604414065|
|764|  0.5951774351126187|764.0|  0.5951774351126187|
|597| 0.28892018066268255|597.0| 0.28892018066268255|
|985|  0.8533399451400041|985.0|  0.8533399451400041|
|459|  0.9953521916101892|459.0|  0.9953521916101892|
|316|  0.8110962722359051|31