In [16]:
from pyspark.sql import SparkSession

# Create the Spark session used by the rest of the notebook
# (getOrCreate reuses a session that is already running).
spark = (
    SparkSession.builder
    .appName("Python Spark RFM example")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

In [17]:
# Load the Online Retail CSV with Spark's built-in CSV reader.
# header=True: the first line of the file holds the column names.
# inferSchema=True: column types are inferred from the data instead of
# defaulting to string.
# NOTE(review): the path is an absolute local path; consider making it
# configurable before sharing the notebook.
df_raw = spark.read.csv(
    "file:///home/hadoop/OnlineRetail.csv",
    header=True,
    inferSchema=True,
)

In [18]:
# Preview the first five rows, then print the schema of the loaded frame.
df_raw.show(n=5)
df_raw.printSchema()

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows

root
 |-- InvoiceNo: string (nullable = true)
 |

In [19]:
from pyspark.sql.functions import count

def my_count(df_in):
    """Display, for every column of ``df_in``, how many non-null values it holds.

    One ``count`` aggregate is built per column and aliased back to the
    column's own name, so the displayed one-row table has the same header
    as ``df_in``. Nothing is returned; the result is only shown.
    """
    per_column = []
    for col_name in df_in.columns:
        per_column.append(count(col_name).alias(col_name))

    summary = df_in.agg(*per_column)
    summary.show()

In [20]:
# Data cleaning and data manipulation

# Check for and remove null values
import pyspark.sql.functions as F
from pyspark.sql.functions import round
# Add the per-row amount: Quantity * UnitPrice, rounded to two decimals.
df_raw = df_raw.withColumn(
    'Asset',
    F.round(F.col('Quantity') * F.col('UnitPrice'), 2),
)

# Keep only the columns used below, exposing StockCode under the name Cusip.
df = df_raw.select(
    'CustomerID',
    F.col('StockCode').alias('Cusip'),
    'Quantity',
    'UnitPrice',
    'Asset',
)

In [21]:
# Show how many non-null values each column of df holds.
my_count(df)

+----------+------+--------+---------+------+
|CustomerID| Cusip|Quantity|UnitPrice| Asset|
+----------+------+--------+---------+------+
|    406829|541909|  541909|   541909|541909|
+----------+------+--------+---------+------+



In [22]:
# The per-column counts are not all equal: the CustomerID column has missing
# values, so those records are removed from the dataset.
# Rows with a negative Asset are filtered out first; then any row that still
# contains a null is dropped and the counts are displayed again.
df = df.where(F.col('Asset') >= 0).na.drop(how='any')
my_count(df)

+----------+------+--------+---------+------+
|CustomerID| Cusip|Quantity|UnitPrice| Asset|
+----------+------+--------+---------+------+
|    397924|397924|  397924|   397924|397924|
+----------+------+--------+---------+------+



In [23]:
# Preview the first three rows of df.
df.show(3)

+----------+------+--------+---------+-----+
|CustomerID| Cusip|Quantity|UnitPrice|Asset|
+----------+------+--------+---------+-----+
|     17850|85123A|       6|     2.55| 15.3|
|     17850| 71053|       6|     3.39|20.34|
|     17850|84406B|       8|     2.75| 22.0|
+----------+------+--------+---------+-----+
only showing top 3 rows



In [24]:
# Convert Cusip values to a consistent format (upper case)
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, DoubleType

def toUpper(s):
    """Return ``s`` converted to upper case.

    Args:
        s: The string to convert, or ``None``.

    Returns:
        The upper-cased string, or ``None`` when ``s`` is ``None``.
        Passing ``None`` through matters when this function backs a Spark
        UDF: null column values reach Python as ``None``, and calling
        ``.upper()`` on them would raise ``AttributeError``.
    """
    if s is None:
        return None
    return s.upper()

# Wrap toUpper as a Spark UDF whose declared return type is a string.
upper_udf = udf(f=lambda value: toUpper(value), returnType=StringType())

In [25]:
# Rank every Cusip by its number of purchase rows and by its total Asset
# (rounded to two decimals), both in descending order.
# NOTE(review): 'Customers' is a count of rows with a CustomerID, not a count
# of distinct customers - confirm this is the intended popularity measure.
pop = (
    df.groupBy('Cusip')
    .agg(
        F.count('CustomerID').alias('Customers'),
        F.round(F.sum('Asset'), 2).alias('TotalAsset'),
    )
    .orderBy(F.col('Customers').desc(), F.col('TotalAsset').desc())
)

pop.show(5)

+------+---------+----------+
| Cusip|Customers|TotalAsset|
+------+---------+----------+
|85123A|     2035|  100603.5|
| 22423|     1724| 142592.95|
|85099B|     1618|  85220.78|
| 84879|     1408|  56580.34|
| 47566|     1397|  68844.33|
+------+---------+----------+
only showing top 5 rows



In [26]:
import pandas as pd
# Number of leading rows of `pop` whose Cusip becomes a portfolio column.
top = 10

# Column list for the portfolio table: CustomerID first, then the Cusip of
# the first `top` rows of `pop`, each as a plain string.
# The rows are read directly from the collected result; no pandas round-trip
# is needed, and an empty `pop` simply yields ['CustomerID'] instead of
# failing on a positional column lookup.
cusip_lst = ['CustomerID'] + [
    str(row['Cusip']) for row in pop.select('Cusip').head(top)
]

In [27]:
# Create the portfolio table for each customer: one row per CustomerID, one
# column per Cusip, each cell holding the summed Asset. Combinations with no
# rows come out as null and are replaced by 0.
# NOTE(review): pivot() without an explicit value list makes Spark compute
# the distinct Cusip values first and creates a column for each of them.
pivot_tab = (
    df.groupBy('CustomerID')
    .pivot('Cusip')
    .sum('Asset')
    .na.fill(0)
)

In [28]:
# Keep only CustomerID and the top-n Cusip columns of each customer's
# portfolio table, then preview four rows.
selected_tab = pivot_tab.select(*cusip_lst)
selected_tab.show(n=4)

+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
|CustomerID|85123A|22423|85099B|84879|47566|20725|22720|20727|POST|23203|
+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
|     16503|   0.0|  0.0|   0.0|  0.0|  0.0|  0.0|  0.0| 33.0| 0.0|  0.0|
|     15727| 123.9| 25.5|   0.0|  0.0|  0.0| 33.0| 99.0|  0.0| 0.0|  0.0|
|     14570|   0.0|  0.0|   0.0|  0.0|  0.0|  0.0|  0.0|  0.0| 0.0|  0.0|
|     14450|   0.0|  0.0|  8.32|  0.0|  0.0|  0.0| 49.5|  0.0| 0.0|  0.0|
+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
only showing top 4 rows



In [36]:
# Build the train and test datasets with an 80% / 20% random split.
# A fixed seed is passed so that re-running the notebook does not draw a
# different split each time.
train, test = selected_tab.randomSplit([0.8, 0.2], seed=42)

train.show(5)
test.show(5)

+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
|CustomerID|85123A|22423|85099B|84879|47566|20725|22720|20727|POST|23203|
+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
|     12940|  17.7| 51.0|   0.0|13.52| 19.8|  0.0|  0.0|  0.0| 0.0|  0.0|
|     13623| 14.75| 25.5|   0.0|  0.0|  0.0|  0.0|  0.0|  0.0| 0.0|  0.0|
|     13832|   0.0|  0.0|   0.0|  0.0|  0.0|  0.0|  0.0|  0.0| 0.0|  0.0|
|     14450|   0.0|  0.0|  8.32|  0.0|  0.0|  0.0| 49.5|  0.0| 0.0|  0.0|
|     15447|   0.0| 25.5|   0.0|13.52|23.25|  0.0|  0.0|  0.0| 0.0|  0.0|
+----------+------+-----+------+-----+-----+-----+-----+-----+----+-----+
only showing top 5 rows

+----------+------+-----+------+-----+-----+-----+-----+-----+-----+-----+
|CustomerID|85123A|22423|85099B|84879|47566|20725|22720|20727| POST|23203|
+----------+------+-----+------+-----+-----+-----+-----+-----+-----+-----+
|     13285|   0.0|12.75|  41.6|  0.0|  0.0| 33.0|  0.0| 33.0|  0.0| 40.3|
|     145