In [1]:
from pyspark.sql.functions import udf, struct, collect_list
from pyspark.sql.types import IntegerType
from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
%load_ext autoreload
%autoreload 2
%matplotlib inline

初始化数据

In [2]:
# Build the demo session (presumably a SparkSession -- confirm against
# DemoSQLSessionFactory) and a three-row DataFrame whose columns are all strings.
sessionFactory = DemoSQLSessionFactory(name="local file")
spark = sessionFactory.build_session()

columns = ["str_age", "name", "address"]
data = [
    ("1", "a", "e"),
    ("2", "b", "f"),
    ("3", "c", "g"),
]

# Cached because the same frame is reused by several later cells.
logData = spark.createDataFrame(data, schema=columns).cache()
logData.show()

+-------+----+-------+
|str_age|name|address|
+-------+----+-------+
|      1|   a|      e|
|      2|   b|      f|
|      3|   c|      g|
+-------+----+-------+



下面是一个读取单列的例子。
[例子来源](https://www.bmc.com/blogs/how-to-write-spark-udf-python/)

In [3]:
# Column-level UDF that yields an IntegerType column. The lambda resolves
# `to_int` only when it is called, so this works even though `to_int` is
# defined further down in this cell.
colsInt = udf(lambda value: to_int(value), returnType=IntegerType())

# Also expose the same UDF to SQL under the name "colsInt".
spark.udf.register(name="colsInt", f=colsInt)
def to_int(s):
    """Convert a string to an int, returning None when that is not possible.

    Args:
        s: Value to convert. Only ``str`` instances are converted.

    Returns:
        ``int(s)`` when ``s`` is a string that ``int`` accepts; ``None`` when
        ``s`` is not a string or is a string that ``int`` rejects.
    """
    if not isinstance(s, str):
        return None
    try:
        return int(s)
    except ValueError:
        # An unparseable string (e.g. "" or "abc") would otherwise raise out of
        # the UDF wrapping this function; treat it like any other unusable input.
        return None

In [4]:
# Add an integer "age" column computed from the string column "str_age".
df2 = logData.withColumn(colName='age', col=colsInt('str_age'))
df2.show()

+-------+----+-------+---+
|str_age|name|address|age|
+-------+----+-------+---+
|      1|   a|      e|  1|
|      2|   b|      f|  2|
|      3|   c|      g|  3|
+-------+----+-------+---+



读多列的例子

In [5]:
def count_columns(row):
    """Print each item of ``row`` on its own line and return ``len(row)``.

    Args:
        row: An iterable, sized collection (the notebook passes a struct of
            columns through a UDF).

    Returns:
        The number of items in ``row``.
    """
    for field in row:
        print(field)
    return len(row)
# UDF wrapper around count_columns that yields an IntegerType column.
countRow = udf(lambda record: count_columns(record), returnType=IntegerType())

In [6]:
# Pack "name" and "address" into one struct column so the UDF receives both
# values as a single argument.
columns = struct('name', 'address')
df3 = logData.withColumn(colName="columns", col=countRow(columns))
df3.show()

+-------+----+-------+-------+
|str_age|name|address|columns|
+-------+----+-------+-------+
|      1|   a|      e|      2|
|      2|   b|      f|      2|
|      3|   c|      g|      2|
+-------+----+-------+-------+



要在 SQL 中使用 UDF，需要做以下两件事情：
1. 注册成为函数。
2. 在SQL中使用


In [7]:
# Expose to_int to Spark SQL. The return type is given explicitly because
# register() falls back to StringType for a plain Python function, which would
# make the SQL function named "to_int" produce a string column.
spark.udf.register("to_int", to_int, IntegerType())

# Make logData queryable from SQL under the view name "logdata".
logData.createOrReplaceTempView("logdata")

In [8]:
def count_sql_columns(*row):
    """Print every positional argument on its own line and return how many there were.

    Args:
        *row: Arbitrary positional values (the notebook's SQL passes ``*``,
            i.e. every column of the row).

    Returns:
        The number of positional arguments received.
    """
    seen = 0
    for value in row:
        print(value)
        seen += 1
    return seen
# Register with an explicit IntegerType: without a returnType, register()
# defaults to StringType and "col_num" would come back as a string column.
spark.udf.register("count_sql_columns", count_sql_columns, IntegerType())

# Call both registered functions from SQL; `*` passes every column of the row.
sql_data = spark.sql(
    "select to_int(str_age) as abc, count_sql_columns(*) as col_num from logdata"
)
sql_data.show()

+---+-------+
|abc|col_num|
+---+-------+
|  1|      3|
|  2|      3|
|  3|      3|
+---+-------+



聚合，简单地来说，就是 groupby

In [9]:
# Sample rows for the aggregation demo: name "a" appears three times and
# str_age "3" appears twice.
columns = ["str_age", "name", "address"]
agg_data = [
    ("1", "a", "e"),
    ("2", "a", "e"),
    ("3", "a", "e"),
    ("4", "b", "f"),
    ("3", "c", "g"),
]
agg_df = spark.createDataFrame(agg_data, schema=columns).cache()
agg_df.show()

+-------+----+-------+
|str_age|name|address|
+-------+----+-------+
|      1|   a|      e|
|      2|   a|      e|
|      3|   a|      e|
|      4|   b|      f|
|      3|   c|      g|
+-------+----+-------+



In [10]:
def find_a(x):
    """Print ``x``, then return how many of its items equal ``'a'``.

    Args:
        x: Iterable of values (the notebook passes the list produced by
            ``collect_list``).

    Returns:
        The number of items in ``x`` that compare equal to ``'a'``.
    """
    print(x)
    return sum(1 for item in x if item == 'a')

# UDF wrapper around find_a that yields an IntegerType column.
find_a_udf = udf(find_a, returnType=IntegerType())

# Group rows by str_age, gather each group's names into a list, and count the
# 'a' entries in that list with the UDF.
(
    agg_df.groupBy('str_age')
    .agg(find_a_udf(collect_list('name')).alias('a_count'))
    .show()
)

+-------+-------+
|str_age|a_count|
+-------+-------+
|      3|      1|
|      1|      1|
|      4|      0|
|      2|      1|
+-------+-------+

