### Machine Learning Advanced Big data Project

In [0]:
orders_sdf = spark.read.csv('/FileStore/tables/orders.csv', header=True, inferSchema=True)
trains_sdf = spark.read.csv('/FileStore/tables/order_products_train.csv', header=True, inferSchema=True)
products_sdf = spark.read.csv('/FileStore/tables/products.csv', header=True, inferSchema=True)
aisles_sdf = spark.read.csv('/FileStore/tables/aisles.csv', header=True, inferSchema=True)
depts_sdf = spark.read.csv('/FileStore/tables/departments.csv', header=True, inferSchema=True)

In [0]:
%fs 
cp /FileStore/tables/order_products_prior.zip file:/home/order_products_prior.zip 

In [0]:
import pandas as pd

priors_pdf = pd.read_csv('/home/order_products_prior.zip', compression='zip', header=0, sep=',', quotechar='"')
priors_sdf = spark.createDataFrame(priors_pdf)
del priors_pdf # 메모리 절약을 위해 pandas dataframe삭제

In [0]:
orders_sdf.createOrReplaceTempView("orders")
priors_sdf.createOrReplaceTempView("priors")
trains_sdf.createOrReplaceTempView("trains")
products_sdf.createOrReplaceTempView("products")
aisles_sdf.createOrReplaceTempView("aisles")
depts_sdf.createOrReplaceTempView("depts")

In [0]:
#테이블 등록
spark.catalog.listTables()

Out[4]: [Table(name='aisles', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='depts', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='orders', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='priors', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='products', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='trains', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/order_priors_prods

In [0]:
%sql 
drop table if exists order_priors_prods; 

-- priors와 orders를 조인 
-- orders에서는 pk를 확인할 수 없기 때문에 조인
create table  order_priors_prods 
as 
select a.order_id, a.product_id, a.add_to_cart_order, a.reordered, b.user_id, b.eval_set, b.order_number, b.order_dow, b.order_hour_of_day, b.days_since_prior_order 
from priors a, orders b 
where a.order_id = b.order_id;

num_affected_rows,num_inserted_rows


### Creating a product analysis table based on product level analysis attributes
* PK is a product code (product_id) and generates a product analysis table with attributes analyzed in the previous EDA..

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/prd_mart

In [0]:
%sql
drop table if exists prd_mart;

create table prd_mart
as
with 
-- with 구문 첫번째 집합. product_id 레벨로 group by 하여 상품별 서로 다른 개별 사용자 비율을 추출한 결과에 상품명과 상품 중분류명 알기 위해 products와 aisles로 조인
order_prods_grp as
(
  select a.product_id 
    -- ## 상품별 재주문 속성
    , sum(case when reordered=1 then 1 else 0 end) as prd_reordered_cnt -- 상품별 재 주문 건수
    , sum(case when reordered=0 then 1 else 0 end) as prd_no_reordered_cnt -- 상품별 재 주문 하지 않은 건수 
    , avg(reordered) prd_avg_reordered -- 상품별 재 주문 비율
    -- ## 상품별 고유 사용자 및 이전 주문이후 걸린 일자 속성. 
    , count(distinct user_id) prd_unq_usr_cnt -- 상품별 고유 사용자 건수
    , count(*)  prd_total_cnt -- 상품별 건수
    , count(distinct user_id)/count(*) as prd_usr_ratio -- 상품별 전체 건수 대비 고유 사용자 비율
    , max(c.aisle_id) aisle_id -- 상품 중분류 코드 
    , nvl(avg(days_since_prior_order), 0) as prd_avg_prior_days -- 평균 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    , nvl(min(days_since_prior_order), 0) as prd_min_prior_days -- 최소 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    , nvl(max(days_since_prior_order), 0) as prd_max_prior_days -- 최대 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    from order_priors_prods a, products b, aisles c
  where a.product_id = b.product_id 
  and b.aisle_id = c.aisle_id
  group by a.product_id
),
-- with 구문 두번째 집합. product_id 레벨로 group by 하여 상품별 서로 다른 개별 사용자 비율을 추출한 결과에 product_name과 중분류명, 대분류명을 알기 위해 aisles와 dept로 조인. 
order_aisles_grp as
(
  select c.aisle_id as aisle_id 
     , count(distinct a.user_id) aisle_distinct_usr_cnt -- 상품 중분류별 고유 사용자 건수
     , count(*)  aisle_total_cnt -- 상품 중분류 건수
     , count(distinct a.user_id)/count(*) as aisle_usr_ratio -- 상품 중분류 건수 대비 고유 사용자 건수 비율
  from order_priors_prods a, products b, aisles c
  where a.product_id = b.product_id 
  and b.aisle_id = c.aisle_id
  group by c.aisle_id
),
-- with 구문 세번째 집합. 상품 중분류 별 개별 사용자 비율과 상품별 개별 사용자 비율 차이 추출. 
order_prd_grp_aisle as
(
  select product_id, prd_reordered_cnt,  prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt, prd_total_cnt, prd_usr_ratio
    , prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days-- 상품별 속성들
    , b.aisle_distinct_usr_cnt, b.aisle_total_cnt, b.aisle_usr_ratio -- 상품 중분류별 속성들 
    , a.prd_usr_ratio - b.aisle_usr_ratio as usr_ratio_diff -- 상품 중분류 별 개별 사용자 비율과 상품별 개별 사용자 비율 차이
  from order_prods_grp a, order_aisles_grp b
  where a.aisle_id = b.aisle_id
) 
-- end of with 절
select * from order_prd_grp_aisle

num_affected_rows,num_inserted_rows


In [0]:
%sql
select * from prd_mart limit 10

product_id,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
47908,0,3,0.0,3,3,1.0,22.33333333333333,7.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
9856,0,3,0.0,3,3,1.0,10.666666666666666,7.0,17.0,85357,575881,0.1482198579220359,0.851780142077964
3832,0,2,0.0,2,2,1.0,21.5,13.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
12120,0,3,0.0,3,3,1.0,6.333333333333333,6.0,7.0,85357,575881,0.1482198579220359,0.851780142077964
10536,3,7,0.3,7,10,0.7,21.625,9.0,30.0,85357,575881,0.1482198579220359,0.551780142077964
33171,0,8,0.0,8,8,1.0,10.571428571428571,3.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
29994,13,4,0.7647058823529411,4,17,0.2352941176470588,7.666666666666667,1.0,21.0,85357,575881,0.1482198579220359,0.0870742597250228
28551,28,7,0.8,7,35,0.2,6.8,0.0,29.0,85357,575881,0.1482198579220359,0.051780142077964
46860,0,5,0.0,5,5,1.0,11.25,3.0,22.0,85357,575881,0.1482198579220359,0.851780142077964
1431,9,14,0.391304347826087,14,23,0.6086956521739131,11.428571428571429,1.0,30.0,85357,575881,0.1482198579220359,0.4604757942518771


In [0]:
%sql
--49676
select count(*) from prd_mart

count(1)
49676


In [0]:
import pyspark.sql.functions as F

prd_mart_sdf = spark.sql("select * from prd_mart")

display(prd_mart_sdf.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in prd_mart_sdf.columns]))

product_id,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
0,0,0,0,0,0,0,0,0,0,0,0,0,0


### Create user analysis tables based on user-level analysis properties
* PK is the user ID (user_id) and creates a user analysis table with attributes analyzed by previous EDA.
* Order_id is required to create prediction data in the future. To this end, it is necessary to extract order_id by joining the order data for training and testing with user_id.
* The order table is user_id level m, but if eval_set is train and test, the user_id level is 1, so the user_mart table level does not change when joining.

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/user_mart_01

In [0]:
%sql
drop table if exists user_mart_01;

create table user_mart_01
as
select user_id 
  , count(*) as usr_total_cnt -- 사용자별 주문 건수
  -- 주문 건수 관련 속성 추출. 
  , count(distinct product_id) prd_uq_cnt  -- 사용자별 고유 상품 주문 건수
  , count(distinct order_id) order_uq_cnt -- 사용자별 고유 주문 건수
  , count(*)/count(distinct order_id) as usr_avg_prd_cnt -- 사용자별 1회 주문시 평균 주문 상품 건수
  , count(*)/count(distinct product_id) as usr_avg_uq_prd_cnt -- 사용자별 1회 주문시 평균 고유 주문 상품 건수
  , count(distinct product_id)/count(*) as usr_uq_prd_ratio --사용자별 총 상품 건수 대비 고유 상품 건수 비율
  -- ### reordered 관련 속성 추출. ###
  , sum(reordered) usr_reord_cnt -- 사용자별 reordered된 상품 건수
  , sum(case when reordered = 0 then 1 else 0 end) as usr_no_reord_cnt -- 사용자별 reorder 하지 않은 상품 건수. count(*) - sum(reoredred)와 동일. 
  , avg(reordered) usr_reordered_avg -- 사용자별 reordered 비율
  -- ### days_since_prior_order 관련 속성 추출. ###
  , avg(days_since_prior_order) usr_avg_prior_days
  , max(days_since_prior_order) usr_max_prior_days
  , min(days_since_prior_order) usr_min_prior_days
  -- ### order_dow, order_hour_of_day 관련 속성 추출. ###
  , avg(order_dow) usr_avg_order_dow
  , avg(order_hour_of_day) usr_avg_order_hour_of_day
  -- 사용자별 최대 order_number
  , max(order_number) as usr_max_order_number
from order_priors_prods a group by user_id

num_affected_rows,num_inserted_rows


In [0]:
%sql 
select count(*) from user_mart_01

count(1)
206209


In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/user_mart

In [0]:
%sql
-- drop table if exists user_mart;
-- cmd 20 
-- orders는 eval_set이 train/test일 경우 한개의 user_id가 한개의 order_id를 가짐. 때문에 train/test인 경우 조인키값 user_id로 1레벨이 됨.
-- order_priors_prods에 있는 모든 user_id는 orders의 모든 user_id와 동일. orders는 user_id별로 여러건의 order가 있고, 이들중 마지막 order를 train또는 test로 할당하기 때문
-- 따라서 user_mart_01과 eval_set이 train과 test인 orders를 user_id로 조인하면 1:1 조인이 되고 user_mart_01의 집합 레벨의 변화가 없음. outer join을 하지 않아도 됨. 
create table user_mart
as
select a.*, b.order_id, b.eval_set, b.days_since_prior_order
from user_mart_01 a, orders b
where a.user_id = b.user_id
and b.eval_set in ('train', 'test')   

num_affected_rows,num_inserted_rows


In [0]:
%sql
select * from user_mart limit 10

user_id,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,order_id,eval_set,days_since_prior_order
1,59,18,10,5.9,3.2777777777777777,0.3050847457627119,41,18,0.6949152542372882,20.25925925925926,30.0,0.0,2.6440677966101696,10.542372881355933,10,1187899,train,14.0
2,195,102,14,13.928571428571429,1.911764705882353,0.5230769230769231,93,102,0.4769230769230769,15.967032967032967,30.0,3.0,2.005128205128205,10.44102564102564,14,1492625,train,30.0
3,88,33,12,7.333333333333333,2.6666666666666665,0.375,55,33,0.625,11.487179487179487,21.0,7.0,1.0113636363636365,16.352272727272727,12,2774568,test,11.0
4,18,17,5,3.6,1.0588235294117647,0.9444444444444444,1,17,0.0555555555555555,15.357142857142858,21.0,0.0,4.722222222222222,13.11111111111111,5,329954,test,30.0
5,37,23,4,9.25,1.608695652173913,0.6216216216216216,14,23,0.3783783783783784,14.5,19.0,10.0,1.6216216216216215,15.72972972972973,4,2196797,train,6.0
6,14,12,3,4.666666666666667,1.1666666666666667,0.8571428571428571,2,12,0.1428571428571428,7.8,12.0,6.0,3.857142857142857,17.0,3,1528013,test,22.0
7,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,525192,train,6.0
8,49,36,3,16.333333333333332,1.3611111111111112,0.7346938775510204,13,36,0.2653061224489796,30.0,30.0,30.0,4.204081632653061,2.4489795918367347,3,880375,train,10.0
9,76,58,3,25.33333333333333,1.3103448275862069,0.7631578947368421,18,58,0.2368421052631578,24.26086956521739,30.0,6.0,2.6973684210526314,14.263157894736842,3,1094988,train,30.0
10,143,94,5,28.6,1.5212765957446808,0.6573426573426573,49,94,0.3426573426573426,20.746376811594203,30.0,12.0,4.013986013986014,16.902097902097903,5,1822501,train,30.0


In [0]:
%sql
--206209
select count(*) from user_mart

count(1)
206209


In [0]:
%sql
select count(*)
from orders b
where b.eval_set in ('train', 'test')  

count(1)
206209


user_id : 1 means that the tables includes records for training 

user_id : 3 means that the tables includes records for test

In [0]:

%sql
select * from orders where user_id = 1

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
2539329,1,prior,1,2,8,
2398795,1,prior,2,3,7,15.0
473747,1,prior,3,3,12,21.0
2254736,1,prior,4,4,7,29.0
431534,1,prior,5,4,15,28.0
3367565,1,prior,6,2,7,19.0
550135,1,prior,7,1,9,20.0
3108588,1,prior,8,1,14,14.0
2295261,1,prior,9,1,16,0.0
2550362,1,prior,10,4,8,30.0


In [0]:
%sql
select * from orders where user_id = 3

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
1374495,3,prior,1,1,14,
444309,3,prior,2,3,19,9.0
3002854,3,prior,3,3,16,21.0
2037211,3,prior,4,2,18,20.0
2710558,3,prior,5,0,17,12.0
1972919,3,prior,6,0,16,7.0
1839752,3,prior,7,0,15,7.0
3225766,3,prior,8,0,17,7.0
3160850,3,prior,9,0,16,7.0
676467,3,prior,10,3,16,17.0


### Create user+product analysis table based on user+product level analysis attributes
* PK is user ID (user_id) + product code (product_id) and generates user + product analysis table with attributes analyzed in previous EDA.
* The previously created prd_mart and user_mart are joined with the user+product analysis table to combine product-related attributes and user-related attributes.

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/up_mart

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/up_mart_01

In [0]:
%sql
select * from order_priors_prods where user_id = 1

order_id,product_id,add_to_cart_order,reordered,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
473747,196,1,1,1,prior,3,3,12,21.0
473747,12427,2,1,1,prior,3,3,12,21.0
473747,10258,3,1,1,prior,3,3,12,21.0
473747,25133,4,0,1,prior,3,3,12,21.0
473747,30450,5,0,1,prior,3,3,12,21.0
3108588,12427,1,1,1,prior,8,1,14,14.0
3108588,196,2,1,1,prior,8,1,14,14.0
3108588,10258,3,1,1,prior,8,1,14,14.0
3108588,25133,4,1,1,prior,8,1,14,14.0
3108588,46149,5,0,1,prior,8,1,14,14.0


In [0]:
%sql
drop table if exists up_mart;
drop table if exists up_mart_01;

create table up_mart
as
with 
-- 사용자+상품 레벨로 group by 하여 속성 추출. 
up_grp as
(
SELECT user_id, product_id
    , count(*) up_cnt  -- 사용자의 개별 상품별 주문 건수
    , sum(reordered) up_reord_cnt -- 사용자의 개별 상품별 reorder 건수
    , sum(case when reordered=0 then 1 else 0 end) up_no_reord_cnt
    , avg(reordered) up_reoredered_avg -- 사용자의 개별 상품 주문별 reorder비율 
    , max(order_number) up_max_ord_num -- 사용자+상품레벨에서 가장 큰 order_number. order_number는 사용자 별로 주문을 수행한 일련번호를 순차적으로 가짐. 
    , min(order_number) up_min_ord_num -- 사용자+상품레벨에서 가장 작은 order_number
    , avg(add_to_cart_order) up_avg_cart --사용자 상품레벨에서 보통 cart에 몇번째로 담는가?
    , avg(days_since_prior_order) as up_avg_prior_days
    , max(days_since_prior_order) as up_max_prior_days
    , min(days_since_prior_order) as up_min_prior_days
    , avg(order_dow) as up_avg_ord_dow
    , avg(order_hour_of_day) as up_avg_ord_hour
FROM order_priors_prods GROUP BY user_id, product_id
)
-- end of with 절 
-- 사용자 레벨로 group by 한 user_mart 테이블과 조인하여 사용자 레벨 속성과 사용자+상품 레벨 속성의 비율을 추출. 
select a.* 
  , a.up_cnt/b.usr_total_cnt as up_usr_ratio -- 사용자별 전체 주문 건수 대비 사용자+상품 주문 건수 비율
  , a.up_reord_cnt/b.usr_reord_cnt as up_usr_reord_ratio -- 사용자별 전체 재주문 건수 대비 사용자+상품 재주문 건수 비율
  , b.usr_reord_cnt
  , b.usr_max_order_number - a.up_max_ord_num as up_usr_ord_num_diff -- 사용자의 가장 최근 주문(가장 큰 주문번호)에서 현 상품 주문번호가 어느정도 이후에 있는지
from up_grp a, user_mart b
where a.user_id = b.user_id

num_affected_rows,num_inserted_rows


In [0]:
%sql
--13307953
select count(*) from up_mart

count(1)
13307953


In [0]:
%sql
-- up_mart에서 user_mart로, user_id로 join이 안되거나 prd_mart로, product_id로 join이 안되는 경우 추출.  
select count(*)
from up_mart a 
left outer join user_mart b
on a.user_id = b.user_id
left outer join prd_mart c
on a.product_id = c.product_id
where (b.user_id is null or c.product_id is null)

count(1)
3


In [0]:
%sql
select * from aisles where aisle_id='Blunted'
/* 
select * from products a where product_id = 6816
select * from aisles where aisle_id='Blunted' 
*/

aisle_id,aisle


In [0]:
%sql
-- 현재까지 만들어진 테이블의 건수 조사 
select 'user_mart count' as gubun, count(*) from user_mart
union all
select 'prd_mart count' as gubun, count(*) from prd_mart
union all
select 'up_mart count' as gubun, count(*) from up_mart

gubun,count(1)
user_mart count,206209
prd_mart count,49676
up_mart count,13307953


#### Create data_mart by combining prd_mart, user_mart, and up_mart created so far.
* The generated data_mart combines product analysis properties and user analysis properties by joining prd_mart and user_mart based on up_mart.

In [0]:
%sql
describe up_mart

col_name,data_type,comment
user_id,int,
product_id,bigint,
up_cnt,bigint,
up_reord_cnt,bigint,
up_no_reord_cnt,bigint,
up_reoredered_avg,double,
up_max_ord_num,int,
up_min_ord_num,int,
up_avg_cart,double,
up_avg_prior_days,double,


In [0]:
print(spark.sql("select * from up_mart").columns)
print(spark.sql("select * from user_mart").columns)
print(spark.sql("select * from prd_mart").columns)

['user_id', 'product_id', 'up_cnt', 'up_reord_cnt', 'up_no_reord_cnt', 'up_reoredered_avg', 'up_max_ord_num', 'up_min_ord_num', 'up_avg_cart', 'up_avg_prior_days', 'up_max_prior_days', 'up_min_prior_days', 'up_avg_ord_dow', 'up_avg_ord_hour', 'up_usr_ratio', 'up_usr_reord_ratio', 'usr_reord_cnt', 'up_usr_ord_num_diff']
['user_id', 'usr_total_cnt', 'prd_uq_cnt', 'order_uq_cnt', 'usr_avg_prd_cnt', 'usr_avg_uq_prd_cnt', 'usr_uq_prd_ratio', 'usr_reord_cnt', 'usr_no_reord_cnt', 'usr_reordered_avg', 'usr_avg_prior_days', 'usr_max_prior_days', 'usr_min_prior_days', 'usr_avg_order_dow', 'usr_avg_order_hour_of_day', 'usr_max_order_number', 'order_id', 'eval_set', 'days_since_prior_order']
['product_id', 'prd_reordered_cnt', 'prd_no_reordered_cnt', 'prd_avg_reordered', 'prd_unq_usr_cnt', 'prd_total_cnt', 'prd_usr_ratio', 'prd_avg_prior_days', 'prd_min_prior_days', 'prd_max_prior_days', 'aisle_distinct_usr_cnt', 'aisle_total_cnt', 'aisle_usr_ratio', 'usr_ratio_diff']


In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/data_mart

In [0]:
%sql
drop table if exists data_mart;

-- up_mart에 user_mart를 user_id로 조인, prd_mart는 product_id로 조인하여 개별 xxx_mart테이블의 속성들을 취합하여 data_mart 테이블 생성. 약 4분정도 걸림. 
create table data_mart
as
select 
  -- up_mart 컬럼 
  a.user_id, a.product_id, b.order_id -- 테스트 데이터 예측 데이터 제출을 위해서 order_id가 필요함. 
  , up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days
  , up_min_prior_days, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff
  -- user_mart 컬럼, eval_set에 train과 test용 데이터(사용자)구분
  , usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt, usr_uq_prd_ratio, a.usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days
  , usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow, usr_avg_order_hour_of_day, usr_max_order_number, eval_set, days_since_prior_order
  -- prd_mart 컬럼
  , prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days
  , aisle_distinct_usr_cnt, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
from up_mart a, user_mart b, prd_mart c
where a.user_id = b.user_id and a.product_id = c.product_id

num_affected_rows,num_inserted_rows


In [0]:
%sql
-- 현재까지 생성된 테이블의 건수 조사. data_mart는 up_mart와 동일 건수 - 3 
select 'data_mart count' as gubun, count(*) from data_mart
union all 
select 'user_mart count' as gubun, count(*) from user_mart
union all
select 'prd_mart count' as gubun, count(*) from prd_mart
union all
select 'up_mart count' as gubun, count(*) from up_mart

gubun,count(1)
data_mart count,13307950
user_mart count,206209
prd_mart count,49676
up_mart count,13307953


In [0]:
%sql
select * from data_mart limit 10

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,eval_set,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
4041,34063,730911,8,7,1,0.875,27,4,4.875,15.5,30.0,7.0,3.25,13.0,0.023598820058997,0.0424242424242424,0,339,174,27,12.555555555555555,1.9482758620689653,0.5132743362831859,165,174,0.4867256637168141,12.615615615615615,30.0,2.0,3.274336283185841,11.050147492625369,27,test,20.0,1057,738,0.588857938718663,738,1795,0.411142061281337,12.03411131059246,0.0,30.0,46636,207075,0.2252130870457563,0.1859289742355807
99471,14917,2560214,11,10,1,0.9090909090909092,17,1,6.818181818181818,7.5,15.0,5.0,3.0,13.0,0.0311614730878186,0.0423728813559322,2,353,117,19,18.57894736842105,3.017094017094017,0.3314447592067989,236,117,0.6685552407932012,10.475609756097562,30.0,5.0,2.903682719546742,13.025495750708217,19,test,19.0,2411,1055,0.6956145412579342,1055,3466,0.3043854587420658,9.70420829548895,0.0,30.0,124393,1452343,0.0856498774738474,0.2187355812682183
7440,18770,3078569,16,15,1,0.9375,99,62,7.875,3.875,15.0,1.0,2.625,13.75,0.0133111480865224,0.0155440414507772,0,1202,237,99,12.141414141414142,5.071729957805907,0.1971713810316139,965,237,0.802828618968386,3.687925170068027,18.0,0.0,2.952579034941764,14.537437603993345,99,test,1.0,1177,897,0.5675024108003858,897,2074,0.4324975891996143,10.902051282051282,0.0,30.0,78030,395130,0.1974793106066358,0.2350182785929785
34858,21174,3009068,6,5,1,0.8333333333333334,42,2,5.666666666666667,5.166666666666667,15.0,0.0,2.5,13.666666666666666,0.008695652173913,0.011574074074074,5,690,258,47,14.680851063829786,2.6744186046511627,0.3739130434782609,432,258,0.6260869565217392,7.616191904047976,22.0,0.0,2.1333333333333333,13.7,47,train,5.0,6852,5509,0.5543240838119893,5509,12361,0.4456759161880106,11.14011466296039,0.0,30.0,159418,1765313,0.0903057984618025,0.3553701177262081
16181,35168,1675807,5,4,1,0.8,14,5,10.0,15.8,30.0,6.0,1.2,12.4,0.0137741046831955,0.0156862745098039,5,363,108,19,19.105263157894736,3.361111111111111,0.2975206611570248,255,108,0.7024793388429752,13.746438746438749,30.0,5.0,1.578512396694215,13.110192837465563,19,train,30.0,3250,2618,0.5538513974096796,2618,5868,0.4461486025903204,11.240999265246142,0.0,30.0,57255,193297,0.2962022173132537,0.1499463852770667
153133,21137,1769478,2,1,1,0.5,17,13,24.5,23.0,30.0,16.0,4.0,19.5,0.0086580086580086,0.0070921985815602,0,231,90,17,13.588235294117649,2.566666666666667,0.3896103896103896,141,90,0.6103896103896104,14.748815165876778,30.0,1.0,3.0043290043290045,14.80952380952381,17,test,7.0,205845,58838,0.7777038948477991,58838,264683,0.2222961051522009,10.03712850345864,0.0,30.0,177141,3642188,0.048635874919142,0.1736602302330589
29916,26209,2199322,3,2,1,0.6666666666666666,7,3,8.333333333333334,20.666666666666668,30.0,7.0,3.0,11.333333333333334,0.032258064516129,0.0476190476190476,0,93,51,7,13.285714285714286,1.8235294117647056,0.5483870967741935,42,51,0.4516129032258064,19.51315789473684,30.0,7.0,2.7419354838709675,12.118279569892474,7,train,30.0,95768,44859,0.6810072034531064,44859,140627,0.3189927965468935,11.135485385786456,0.0,30.0,177141,3642188,0.048635874919142,0.2703569216277515
87185,44570,3278341,5,4,1,0.8,11,4,8.8,12.2,30.0,0.0,1.8,16.4,0.017921146953405,0.0305343511450381,3,279,148,14,19.928571428571427,1.885135135135135,0.5304659498207885,131,148,0.4695340501792114,16.070038910505836,30.0,0.0,1.7491039426523298,14.89247311827957,14,test,30.0,6843,5298,0.5636273783049173,5298,12141,0.4363726216950828,12.486755858848882,0.0,30.0,159213,3418021,0.0465804627882625,0.3897921589068202
172302,32981,1243888,2,1,1,0.5,4,3,1.0,3.5,4.0,3.0,3.5,15.0,0.0465116279069767,0.1666666666666666,1,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,test,17.0,2033,1307,0.608682634730539,1307,3340,0.3913173652694611,11.486433474991829,0.0,30.0,11566,37691,0.3068637075163832,0.0844536577530779
112606,23341,330621,11,10,1,0.9090909090909092,14,1,14.272727272727272,17.3,30.0,7.0,1.181818181818182,12.0,0.0212355212355212,0.0297619047619047,1,518,182,15,34.53333333333333,2.8461538461538463,0.3513513513513513,336,182,0.6486486486486487,18.054054054054053,30.0,7.0,1.1853281853281854,11.971042471042471,15,train,20.0,4115,2681,0.605503237198352,2681,6796,0.394496762801648,14.160586319218242,0.0,30.0,92240,452134,0.2040103155259281,0.1904864472757198


### Creating data sets for learning and testing.
* order_products_train.csv (train table) is given the ordered label value for the train.
* The pk in the trains table is order_id + product_id, but in reality, only one order_id is assigned to one user_id, so it is unique to user_id + product_id.
* Create an order_train_prods table to obtain user_id by joining the trains table and the orders table.
* If you join two tables with user_id + product_id (order_train_prods left-out) to attach attributes generated by data_mart relative to the order_train_prods table, a lot of data is not joined.
* Property created by data_mart is not available when not joined.
* Set the label value of ordered by joining order_trains_prods based on data_mart (data_mart left outer join) and set the ordered value to 0 if not joined.

In [0]:
%sql
select * from trains limit 10

order_id,product_id,add_to_cart_order,reordered
1,49302,1,1
1,11109,2,1
1,10246,3,0
1,49683,4,0
1,43633,5,1
1,13176,6,0
1,47209,7,0
1,22035,8,1
36,39612,1,0
36,19660,2,1


In [0]:
%sql
--1384617
select count(*) from trains

count(1)
1384617


In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/order_trains_prods

In [0]:
%sql
drop table if exists order_trains_prods;
-- order_products_train 데이터에(trains 테이블)에 user_id를 얻기 위해서 orders 테이블과 조인
-- 해당 테이블은 kaggle에서 train 용으로 제공었지만, 많은 속성(feature)로 만들어진 data_mart 테이블에 비해 적은 속성을 가지고 있음. 
create table order_trains_prods
as
select a.order_id, a.product_id, a.reordered
  , b.user_id
from trains a, orders b
where a.order_id = b.order_id

num_affected_rows,num_inserted_rows


In [0]:
%sql
select count(*) from order_trains_prods

count(1)
1384617


In [0]:
%sql
select * from order_trains_prods limit 10

order_id,product_id,reordered,user_id
762,21137,1,41751
762,41220,0,41751
762,15872,0,41751
762,30391,0,41751
844,14992,1,27766
844,21405,1,27766
844,11182,1,27766
844,28289,1,27766
844,9387,1,27766
844,18599,1,27766


In [0]:
%sql
-- user_id + product_id 로 중복되는 건수가 없음. 즉 user_id + product_id로 unique
select user_id, product_id, count(*) from order_trains_prods group by user_id, product_id having count(*) > 1

user_id,product_id,count(1)


In [0]:
%sql
-- 555793 건이 user_id + product_id 레벨로 order_trains_prods과 data_mart와 조인되지 않음. 
select count(*) 
from
order_trains_prods a
left outer join data_mart b
on a.user_id = b.user_id and a.product_id = b.product_id
where b.product_id is null

count(1)
555793


In [0]:
%sql
-- 개별 user_id 레벨, 또는 개별 product_id 레벨로 조인이 되지 않는 건수는 거의 없음. 즉 개별 user는 동일하지만 user+상품은 prior와 train에 별도로 존재하는 경우가 많다는 의미
with
data_user_grp as
(
  select user_id from data_mart group by user_id
),
data_product_grp as
(
  select product_id from data_mart group by product_id
)
select 'only_user_id_count' as gubun, count(*) from order_trains_prods a left outer join data_user_grp b on a.user_id = b.user_id 
where b.user_id is null
union all
select 'only_product_id_count' as gubun, count(*) from order_trains_prods a left outer join data_product_grp b on a.product_id = b.product_id 
where b.product_id is null

gubun,count(1)
only_user_id_count,0
only_product_id_count,9


In [0]:
%sql
-- data_mart와 user_id + product_id로 조인되지 않는 order_trains_prods의 reordered는 모두 0 임. 
select a.reordered, count(*) 
from
order_trains_prods a
left outer join data_mart b
on a.user_id = b.user_id and a.product_id = b.product_id
where b.product_id is null
group by a.reordered

reordered,count(1)
0,555793


In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/train_data

In [0]:
print(spark.sql("select * from data_mart").columns)

['user_id', 'product_id', 'order_id', 'up_cnt', 'up_reord_cnt', 'up_no_reord_cnt', 'up_reoredered_avg', 'up_max_ord_num', 'up_min_ord_num', 'up_avg_cart', 'up_avg_prior_days', 'up_max_prior_days', 'up_min_prior_days', 'up_avg_ord_dow', 'up_avg_ord_hour', 'up_usr_ratio', 'up_usr_reord_ratio', 'up_usr_ord_num_diff', 'usr_total_cnt', 'prd_uq_cnt', 'order_uq_cnt', 'usr_avg_prd_cnt', 'usr_avg_uq_prd_cnt', 'usr_uq_prd_ratio', 'usr_reord_cnt', 'usr_no_reord_cnt', 'usr_reordered_avg', 'usr_avg_prior_days', 'usr_max_prior_days', 'usr_min_prior_days', 'usr_avg_order_dow', 'usr_avg_order_hour_of_day', 'usr_max_order_number', 'eval_set', 'days_since_prior_order', 'prd_reordered_cnt', 'prd_no_reordered_cnt', 'prd_avg_reordered', 'prd_unq_usr_cnt', 'prd_total_cnt', 'prd_usr_ratio', 'prd_avg_prior_days', 'prd_min_prior_days', 'prd_max_prior_days', 'aisle_distinct_usr_cnt', 'aisle_total_cnt', 'aisle_usr_ratio', 'usr_ratio_diff']


#### Create Train dataset

In [0]:
%sql
-- 학습용 feature와 label 데이터 세트 생성. 
-- order_trains_prods를 기준으로 data_mart와 outer 조인하면 많은 데이터가 조인되지 않음. 이 경우 해당 데이터는 data_mart의 속성들을 사용할 수 없음. 
-- data_mart를 기준으로 order_trains_prods를 outer 조인하면 역시 많은 데이터가 조인되지 않음. data_mart의 속성은 여전히 사용할 수 있음. 
-- order_trains_prods를 기준으로 학습 데이터를 만들지 않고 data_mart를 기준으로 학습 데이터를 생성. 
-- order_trains_prods의 eval_set가 'train' 인 경우 user_id 레벨로 학습 데이터이므로 이를 이용하여 학습 데이터를 생성. 
-- data_mart와 order_trains_prods가 조인이 되는 경우 order_trains_prods의 reorder값을 이용하고, 조인이 되지 않는 경우는 0으로 (추후)변경
drop table if exists train_data;

create table train_data
as
select 
-- user_id, product_id, order_id -- 학습용 feature 데이터를 만들기에 user_id, product_id, order_id 와 같은 id 속성은 제외
  up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days, up_min_prior_days
, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff, usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt
, usr_uq_prd_ratio, usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days, usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow
, usr_avg_order_hour_of_day, usr_max_order_number
--, eval_set -- eval_set 제외
, days_since_prior_order, prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt
, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days, aisle_distinct_usr_cnt
, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
, b.reordered -- label 값. order_train_prods와 조인되지 않는 경우 label을 0으로 변경. nvl(b.reordered, 0) 적용
from data_mart a left outer join order_trains_prods b
on a.user_id = b.user_id and a.product_id = b.product_id
where a.eval_set = 'train'

num_affected_rows,num_inserted_rows


In [0]:
%sql
-- data_mart에서 order_trains_prods와 user_id+product_id로 조인되지 않는 건은 7645837 건. 추후에 DataFrame에서 이들 데이터의 reordered를 모두 0으로 변경. 
select 'train_data' as gubun, count(*) from train_data
union all
select 'reordered null' as gubun, count(*) from train_data where reordered is null

gubun,count(1)
reordered null,7645837
train_data,8474661


#### Create Test dataset

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/test_data

In [0]:
%sql
-- 테스트용 데이터 세트 생성. data_mart에서 eval_set이 test 인것만 추출하여 생성. reordered 필요 없음. 
drop table if exists test_data;

create table test_data
as
select 
user_id, product_id, order_id -- 테스트 데이터와 학습 데이터와 마찬가지로 id 속성값은 필요가 없지만, 추후에 kaggle 테스트 성능 결과 제출을 위해 order_id, product_id가 필요. 추후 제거. 
, up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days, up_min_prior_days
, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff, usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt
, usr_uq_prd_ratio, usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days, usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow
, usr_avg_order_hour_of_day, usr_max_order_number
--, eval_set -- eval_set 제외
, days_since_prior_order, prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt
, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days, aisle_distinct_usr_cnt
, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
--, b.reordered -- label 제외
from data_mart a where a.eval_set = 'test' -- data_mart에서 eval_set이 test 인것만 추출하여 생성. 

num_affected_rows,num_inserted_rows


In [0]:
%sql
select count(*) from test_data

count(1)
4833289


### Pre-processing of learning data, learning models, and performing predictive evaluations
* Treat all learning data null values as zero

-- %fs
-- ls dbfs:/user/hive/warehouse/train_data

In [0]:
spark.catalog.listTables()

Out[13]: [Table(name='data_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='order_priors_prods', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='order_trains_prods', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='prd_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='test_data', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='train_data', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='up_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='user_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='user_mart_01', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='aisles', database=None,

In [0]:
spark.sql("set spark.databricks.delta.formatCheck.enabled=false")

Out[14]: DataFrame[key: string, value: string]

#### SQL to Pyspark Dataframe

In [0]:
# train_data와 test_data 테이블을 DataFrame으로 변환. 
train_sdf = spark.sql("select * from train_data")
test_sdf = spark.sql("select * from test_data")
print('train_sdf type:', type(train_sdf))
print('test_sdf type:', type(test_sdf))

train_sdf type: <class 'pyspark.sql.dataframe.DataFrame'>
test_sdf type: <class 'pyspark.sql.dataframe.DataFrame'>


In [0]:
train_sdf.printSchema()

root
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = true)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = true)
 |-- up_avg_prior_days: double (nullable = true)
 |-- up_max_prior_days: double (nullable = true)
 |-- up_min_prior_days: double (nullable = true)
 |-- up_avg_ord_dow: double (nullable = true)
 |-- up_avg_ord_hour: double (nullable = true)
 |-- up_usr_ratio: double (nullable = true)
 |-- up_usr_reord_ratio: double (nullable = true)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = true)
 |-- usr_avg_uq_prd_cnt: double (nullable = true)
 |-- usr_uq_prd_ratio: double (nullable = true)
 |-- usr_reord_cnt: long (null

In [0]:
# 컬럼별로 Null 인 경우만 count하는 select 로직.
import pyspark.sql.functions as F
# up_avg_prior_days, up_max_prior_days, up_min_prior_days이 각각 552218건, up_max_prior_days 552218건, up_usr_reord_ratio가 30912건,  
# prd_avg_prior_days가 29건, prd_max_prior_days가 29 건이 null임. up_usr_reord_ratio는 사용자별 재주문도 0건, 사용자 상품별 재주문도 0건이어 무한대가 Null로 처리됨. 
# reoredered는 7654837 건이 null임. 
display(train_sdf.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in train_sdf.columns]))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,reordered
0,0,0,0,0,0,0,552218,552218,552218,0,0,0,30912,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7645837


In [0]:
train_sdf = train_sdf.fillna(0)

#### Feature vectorization 
* Select some columns for training. 
* reordered : It is label. So drop only that column

In [0]:
# feature vectorization 적용할 column명 추출. label 컬럼인 reordered는 제외
vector_columns = [column_name for column_name, column_type in train_sdf.dtypes if column_name != 'reordered']
print(vector_columns)

['up_cnt', 'up_reord_cnt', 'up_no_reord_cnt', 'up_reoredered_avg', 'up_max_ord_num', 'up_min_ord_num', 'up_avg_cart', 'up_avg_prior_days', 'up_max_prior_days', 'up_min_prior_days', 'up_avg_ord_dow', 'up_avg_ord_hour', 'up_usr_ratio', 'up_usr_reord_ratio', 'up_usr_ord_num_diff', 'usr_total_cnt', 'prd_uq_cnt', 'order_uq_cnt', 'usr_avg_prd_cnt', 'usr_avg_uq_prd_cnt', 'usr_uq_prd_ratio', 'usr_reord_cnt', 'usr_no_reord_cnt', 'usr_reordered_avg', 'usr_avg_prior_days', 'usr_max_prior_days', 'usr_min_prior_days', 'usr_avg_order_dow', 'usr_avg_order_hour_of_day', 'usr_max_order_number', 'days_since_prior_order', 'prd_reordered_cnt', 'prd_no_reordered_cnt', 'prd_avg_reordered', 'prd_unq_usr_cnt', 'prd_total_cnt', 'prd_usr_ratio', 'prd_avg_prior_days', 'prd_min_prior_days', 'prd_max_prior_days', 'aisle_distinct_usr_cnt', 'aisle_total_cnt', 'aisle_usr_ratio', 'usr_ratio_diff']


In [0]:
# feature vector화 적용 
from pyspark.ml.feature import VectorAssembler

vector_assembler = VectorAssembler(inputCols=vector_columns, outputCol='features')
train_sdf_vectorized = vector_assembler.transform(train_sdf)

display(train_sdf_vectorized.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,reordered,features
6,5,1,0.8333333333333334,42,2,5.666666666666667,5.166666666666667,15.0,0.0,2.5,13.666666666666666,0.008695652173913,0.011574074074074,5,690,258,47,14.680851063829786,2.6744186046511627,0.3739130434782609,432,258,0.6260869565217392,7.616191904047976,22.0,0.0,2.1333333333333333,13.7,47,5.0,6852,5509,0.5543240838119893,5509,12361,0.4456759161880106,11.14011466296039,0.0,30.0,159418,1765313,0.0903057984618025,0.3553701177262081,0,"Map(vectorType -> dense, length -> 44, values -> List(6.0, 5.0, 1.0, 0.8333333333333334, 42.0, 2.0, 5.666666666666667, 5.166666666666667, 15.0, 0.0, 2.5, 13.666666666666666, 0.008695652173913044, 0.011574074074074073, 5.0, 690.0, 258.0, 47.0, 14.680851063829786, 2.6744186046511627, 0.3739130434782609, 432.0, 258.0, 0.6260869565217392, 7.616191904047976, 22.0, 0.0, 2.1333333333333333, 13.7, 47.0, 5.0, 6852.0, 5509.0, 0.5543240838119893, 5509.0, 12361.0, 0.44567591618801067, 11.14011466296039, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.35537011772620813))"
5,4,1,0.8,14,5,10.0,15.8,30.0,6.0,1.2,12.4,0.0137741046831955,0.0156862745098039,5,363,108,19,19.105263157894736,3.361111111111111,0.2975206611570248,255,108,0.7024793388429752,13.746438746438749,30.0,5.0,1.578512396694215,13.110192837465563,19,30.0,3250,2618,0.5538513974096796,2618,5868,0.4461486025903204,11.240999265246142,0.0,30.0,57255,193297,0.2962022173132537,0.1499463852770667,0,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 14.0, 5.0, 10.0, 15.8, 30.0, 6.0, 1.2, 12.4, 0.013774104683195593, 0.01568627450980392, 5.0, 363.0, 108.0, 19.0, 19.105263157894736, 3.361111111111111, 0.2975206611570248, 255.0, 108.0, 0.7024793388429752, 13.746438746438747, 30.0, 5.0, 1.578512396694215, 13.110192837465565, 19.0, 30.0, 3250.0, 2618.0, 0.5538513974096796, 2618.0, 5868.0, 0.4461486025903204, 11.240999265246142, 0.0, 30.0, 57255.0, 193297.0, 0.2962022173132537, 0.14994638527706672))"
3,2,1,0.6666666666666666,7,3,8.333333333333334,20.666666666666668,30.0,7.0,3.0,11.333333333333334,0.032258064516129,0.0476190476190476,0,93,51,7,13.285714285714286,1.8235294117647056,0.5483870967741935,42,51,0.4516129032258064,19.51315789473684,30.0,7.0,2.7419354838709675,12.118279569892474,7,30.0,95768,44859,0.6810072034531064,44859,140627,0.3189927965468935,11.135485385786456,0.0,30.0,177141,3642188,0.048635874919142,0.2703569216277515,0,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 7.0, 3.0, 8.333333333333334, 20.666666666666668, 30.0, 7.0, 3.0, 11.333333333333334, 0.03225806451612903, 0.047619047619047616, 0.0, 93.0, 51.0, 7.0, 13.285714285714286, 1.8235294117647058, 0.5483870967741935, 42.0, 51.0, 0.45161290322580644, 19.513157894736842, 30.0, 7.0, 2.7419354838709675, 12.118279569892474, 7.0, 30.0, 95768.0, 44859.0, 0.6810072034531064, 44859.0, 140627.0, 0.31899279654689355, 11.135485385786456, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.27035692162775155))"
11,10,1,0.9090909090909092,14,1,14.272727272727272,17.3,30.0,7.0,1.181818181818182,12.0,0.0212355212355212,0.0297619047619047,1,518,182,15,34.53333333333333,2.8461538461538463,0.3513513513513513,336,182,0.6486486486486487,18.054054054054053,30.0,7.0,1.1853281853281854,11.971042471042471,15,20.0,4115,2681,0.605503237198352,2681,6796,0.394496762801648,14.160586319218242,0.0,30.0,92240,452134,0.2040103155259281,0.1904864472757198,1,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 14.0, 1.0, 14.272727272727273, 17.3, 30.0, 7.0, 1.1818181818181819, 12.0, 0.021235521235521235, 0.02976190476190476, 1.0, 518.0, 182.0, 15.0, 34.53333333333333, 2.8461538461538463, 0.35135135135135137, 336.0, 182.0, 0.6486486486486487, 18.054054054054053, 30.0, 7.0, 1.1853281853281854, 11.971042471042471, 15.0, 20.0, 4115.0, 2681.0, 0.605503237198352, 2681.0, 6796.0, 0.394496762801648, 14.160586319218242, 0.0, 30.0, 92240.0, 452134.0, 0.20401031552592816, 0.19048644727571987))"
2,1,1,0.5,22,17,5.0,6.5,8.0,5.0,3.0,14.5,0.015625,0.0172413793103448,0,128,70,22,5.818181818181818,1.8285714285714283,0.546875,58,70,0.453125,7.573913043478261,23.0,0.0,2.5703125,13.515625,22,30.0,1061,758,0.583287520615723,758,1819,0.4167124793842771,10.576741440377804,0.0,30.0,44854,163524,0.2742961277855238,0.1424163515987532,0,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 22.0, 17.0, 5.0, 6.5, 8.0, 5.0, 3.0, 14.5, 0.015625, 0.017241379310344827, 0.0, 128.0, 70.0, 22.0, 5.818181818181818, 1.8285714285714285, 0.546875, 58.0, 70.0, 0.453125, 7.573913043478261, 23.0, 0.0, 2.5703125, 13.515625, 22.0, 30.0, 1061.0, 758.0, 0.583287520615723, 758.0, 1819.0, 0.4167124793842771, 10.576741440377804, 0.0, 30.0, 44854.0, 163524.0, 0.2742961277855238, 0.14241635159875327))"
25,24,1,0.96,52,7,9.08,5.16,10.0,2.0,2.16,8.96,0.0358166189111747,0.0411663807890223,0,698,115,52,13.423076923076923,6.069565217391304,0.164756446991404,583,115,0.835243553008596,4.965116279069767,11.0,1.0,2.0773638968481376,9.101719197707736,52,11.0,607,247,0.7107728337236534,247,854,0.2892271662763466,10.057810578105782,0.0,30.0,53197,249341,0.2133503916323428,0.0758767746440037,1,"Map(vectorType -> dense, length -> 44, values -> List(25.0, 24.0, 1.0, 0.96, 52.0, 7.0, 9.08, 5.16, 10.0, 2.0, 2.16, 8.96, 0.03581661891117478, 0.0411663807890223, 0.0, 698.0, 115.0, 52.0, 13.423076923076923, 6.069565217391304, 0.16475644699140402, 583.0, 115.0, 0.835243553008596, 4.965116279069767, 11.0, 1.0, 2.0773638968481376, 9.101719197707736, 52.0, 11.0, 607.0, 247.0, 0.7107728337236534, 247.0, 854.0, 0.2892271662763466, 10.057810578105782, 0.0, 30.0, 53197.0, 249341.0, 0.21335039163234285, 0.07587677464400375))"
1,0,1,0.0,9,9,10.0,28.0,28.0,28.0,4.0,16.0,0.0158730158730158,0.0,0,63,45,9,7.0,1.4,0.7142857142857143,18,45,0.2857142857142857,21.946428571428573,30.0,6.0,1.8412698412698407,16.317460317460316,9,14.0,19327,9797,0.6636107677516825,9797,29124,0.3363892322483175,11.38551075861056,0.0,30.0,27077,129474,0.2091307907379087,0.1272584415104087,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 9.0, 9.0, 10.0, 28.0, 28.0, 28.0, 4.0, 16.0, 0.015873015873015872, 0.0, 0.0, 63.0, 45.0, 9.0, 7.0, 1.4, 0.7142857142857143, 18.0, 45.0, 0.2857142857142857, 21.946428571428573, 30.0, 6.0, 1.8412698412698412, 16.317460317460316, 9.0, 14.0, 19327.0, 9797.0, 0.6636107677516825, 9797.0, 29124.0, 0.3363892322483175, 11.385510758610561, 0.0, 30.0, 27077.0, 129474.0, 0.20913079073790877, 0.12725844151040874))"
6,5,1,0.8333333333333334,12,1,3.6666666666666665,24.2,30.0,11.0,2.5,15.5,0.1034482758620689,0.1785714285714285,0,58,30,12,4.833333333333333,1.9333333333333331,0.5172413793103449,28,30,0.4827586206896552,22.403846153846157,30.0,3.0,2.2758620689655173,14.89655172413793,12,30.0,315913,63537,0.832555013835815,63537,379450,0.167444986164185,10.06454486916168,0.0,30.0,177141,3642188,0.048635874919142,0.118809111245043,0,"Map(vectorType -> dense, length -> 44, values -> List(6.0, 5.0, 1.0, 0.8333333333333334, 12.0, 1.0, 3.6666666666666665, 24.2, 30.0, 11.0, 2.5, 15.5, 0.10344827586206896, 0.17857142857142858, 0.0, 58.0, 30.0, 12.0, 4.833333333333333, 1.9333333333333333, 0.5172413793103449, 28.0, 30.0, 0.4827586206896552, 22.403846153846153, 30.0, 3.0, 2.2758620689655173, 14.89655172413793, 12.0, 30.0, 315913.0, 63537.0, 0.832555013835815, 63537.0, 379450.0, 0.16744498616418502, 10.06454486916168, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.11880911124504301))"
6,5,1,0.8333333333333334,16,9,4.666666666666667,18.166666666666668,30.0,9.0,2.833333333333333,13.333333333333334,0.0161290322580645,0.0210970464135021,5,372,135,21,17.714285714285715,2.7555555555555555,0.3629032258064516,237,135,0.6370967741935484,15.79886685552408,30.0,2.0,2.9516129032258065,13.526881720430108,21,8.0,315913,63537,0.832555013835815,63537,379450,0.167444986164185,10.06454486916168,0.0,30.0,177141,3642188,0.048635874919142,0.118809111245043,0,"Map(vectorType -> dense, length -> 44, values -> List(6.0, 5.0, 1.0, 0.8333333333333334, 16.0, 9.0, 4.666666666666667, 18.166666666666668, 30.0, 9.0, 2.8333333333333335, 13.333333333333334, 0.016129032258064516, 0.02109704641350211, 5.0, 372.0, 135.0, 21.0, 17.714285714285715, 2.7555555555555555, 0.3629032258064516, 237.0, 135.0, 0.6370967741935484, 15.798866855524079, 30.0, 2.0, 2.9516129032258065, 13.526881720430108, 21.0, 8.0, 315913.0, 63537.0, 0.832555013835815, 63537.0, 379450.0, 0.16744498616418502, 10.06454486916168, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.11880911124504301))"
1,0,1,0.0,4,4,13.0,6.0,6.0,6.0,0.0,17.0,0.0078740157480314,0.0,3,127,81,7,18.142857142857142,1.567901234567901,0.6377952755905512,46,81,0.3622047244094488,8.839285714285714,14.0,6.0,1.5196850393700787,17.078740157480315,7,5.0,12792,7277,0.6374009666650057,7277,20069,0.3625990333349942,11.452100930182828,0.0,30.0,78030,395130,0.1974793106066358,0.1651197227283584,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 13.0, 6.0, 6.0, 6.0, 0.0, 17.0, 0.007874015748031496, 0.0, 3.0, 127.0, 81.0, 7.0, 18.142857142857142, 1.5679012345679013, 0.6377952755905512, 46.0, 81.0, 0.36220472440944884, 8.839285714285714, 14.0, 6.0, 1.5196850393700787, 17.078740157480315, 7.0, 5.0, 12792.0, 7277.0, 0.6374009666650057, 7277.0, 20069.0, 0.36259903333499427, 11.452100930182828, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.16511972272835848))"


#### Creation of Estimator Model : With RandomForestClassifier
* rf_estimaitor is RandomForestClassifier 
* FeatureCol is features, and labelCol is reordered.

In [0]:
# 학습 데이터로 학습하여 Estimator Model 생성.. 7~8분 정도 걸림. 
from pyspark.ml.classification import RandomForestClassifier

rf_estimator = RandomForestClassifier(featuresCol='features', labelCol='reordered')
rf_model = rf_estimator.fit(train_sdf_vectorized)

#### Test Dataset preprocessing, and Prediction

In [0]:
test_sdf = spark.sql("select * from test_data")

In [0]:
test_sdf.printSchema()

root
 |-- user_id: integer (nullable = true)
 |-- product_id: long (nullable = true)
 |-- order_id: integer (nullable = true)
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = true)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = true)
 |-- up_avg_prior_days: double (nullable = true)
 |-- up_max_prior_days: double (nullable = true)
 |-- up_min_prior_days: double (nullable = true)
 |-- up_avg_ord_dow: double (nullable = true)
 |-- up_avg_ord_hour: double (nullable = true)
 |-- up_usr_ratio: double (nullable = true)
 |-- up_usr_reord_ratio: double (nullable = true)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = true)
 |-- us

In [0]:
test_sdf_id = test_sdf.select('user_id', 'product_id', 'order_id')
test_sdf = test_sdf.drop('user_id', 'product_id', 'order_id')
display(test_sdf.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
8,7,1,0.875,27,4,4.875,15.5,30.0,7.0,3.25,13.0,0.023598820058997,0.0424242424242424,0,339,174,27,12.555555555555555,1.9482758620689653,0.5132743362831859,165,174,0.4867256637168141,12.615615615615615,30.0,2.0,3.274336283185841,11.050147492625369,27,20.0,1057,738,0.588857938718663,738,1795,0.411142061281337,12.03411131059246,0.0,30.0,46636,207075,0.2252130870457563,0.1859289742355807
11,10,1,0.9090909090909092,17,1,6.818181818181818,7.5,15.0,5.0,3.0,13.0,0.0311614730878186,0.0423728813559322,2,353,117,19,18.57894736842105,3.017094017094017,0.3314447592067989,236,117,0.6685552407932012,10.475609756097562,30.0,5.0,2.903682719546742,13.025495750708217,19,19.0,2411,1055,0.6956145412579342,1055,3466,0.3043854587420658,9.70420829548895,0.0,30.0,124393,1452343,0.0856498774738474,0.2187355812682183
16,15,1,0.9375,99,62,7.875,3.875,15.0,1.0,2.625,13.75,0.0133111480865224,0.0155440414507772,0,1202,237,99,12.141414141414142,5.071729957805907,0.1971713810316139,965,237,0.802828618968386,3.687925170068027,18.0,0.0,2.952579034941764,14.537437603993345,99,1.0,1177,897,0.5675024108003858,897,2074,0.4324975891996143,10.902051282051282,0.0,30.0,78030,395130,0.1974793106066358,0.2350182785929785
2,1,1,0.5,17,13,24.5,23.0,30.0,16.0,4.0,19.5,0.0086580086580086,0.0070921985815602,0,231,90,17,13.588235294117649,2.566666666666667,0.3896103896103896,141,90,0.6103896103896104,14.748815165876778,30.0,1.0,3.0043290043290045,14.80952380952381,17,7.0,205845,58838,0.7777038948477991,58838,264683,0.2222961051522009,10.03712850345864,0.0,30.0,177141,3642188,0.048635874919142,0.1736602302330589
5,4,1,0.8,11,4,8.8,12.2,30.0,0.0,1.8,16.4,0.017921146953405,0.0305343511450381,3,279,148,14,19.928571428571427,1.885135135135135,0.5304659498207885,131,148,0.4695340501792114,16.070038910505836,30.0,0.0,1.7491039426523298,14.89247311827957,14,30.0,6843,5298,0.5636273783049173,5298,12141,0.4363726216950828,12.486755858848882,0.0,30.0,159213,3418021,0.0465804627882625,0.3897921589068202
2,1,1,0.5,4,3,1.0,3.5,4.0,3.0,3.5,15.0,0.0465116279069767,0.1666666666666666,1,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,2033,1307,0.608682634730539,1307,3340,0.3913173652694611,11.486433474991829,0.0,30.0,11566,37691,0.3068637075163832,0.0844536577530779
1,0,1,0.0,3,3,4.0,3.0,3.0,3.0,1.0,19.0,0.0014619883040935,0.0,67,684,129,70,9.771428571428572,5.3023255813953485,0.1885964912280701,555,129,0.8114035087719298,5.411242603550296,14.0,0.0,2.6695906432748537,14.953216374269006,70,3.0,96,115,0.4549763033175355,115,211,0.5450236966824644,10.17412935323383,0.0,30.0,76177,306487,0.2485488780927086,0.2964748185897558
2,1,1,0.5,7,5,3.0,20.5,30.0,11.0,3.0,16.0,0.1052631578947368,1.0,0,19,18,7,2.7142857142857144,1.0555555555555556,0.9473684210526316,1,18,0.0526315789473684,17.72222222222222,30.0,2.0,4.421052631578948,12.947368421052632,7,30.0,1843,999,0.6484869809992962,999,2842,0.3515130190007037,11.666792595391009,0.0,30.0,99755,841533,0.1185396175788709,0.2329734014218328
33,32,1,0.9696969696969696,98,2,9.757575757575758,3.6666666666666665,6.0,1.0,3.212121212121212,11.636363636363637,0.0250950570342205,0.0311890838206627,0,1315,289,98,13.418367346938776,4.550173010380623,0.2197718631178707,1026,289,0.7802281368821293,3.78110599078341,17.0,1.0,3.173384030418251,12.89277566539924,98,2.0,37968,19517,0.6604853440027834,19517,57485,0.3395146559972166,10.508839368616528,0.0,30.0,159213,3418021,0.0465804627882625,0.2929341932089541
1,0,1,0.0,15,15,3.0,1.0,1.0,1.0,3.0,1.0,0.0016722408026755,0.0,22,598,180,37,16.16216216216216,3.3222222222222224,0.3010033444816054,418,180,0.6989966555183946,5.900840336134454,30.0,1.0,2.3444816053511706,10.175585284280936,37,2.0,1585,1801,0.4681039574719433,1801,3386,0.5318960425280567,13.657869327325397,0.0,30.0,77939,254317,0.3064639799934727,0.2254320625345839


In [0]:
test_sdf = test_sdf.fillna(0)
test_sdf_vectorized = vector_assembler.transform(test_sdf)

In [0]:
predictions = rf_model.transform(test_sdf_vectorized)
display(predictions)

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction
1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0
3,2,1,0.6666666666666666,22,18,8.333333333333334,9.666666666666666,13.0,4.0,2.6666666666666665,12.666666666666666,0.0118577075098814,0.0144927536231884,0,253,115,22,11.5,2.2,0.4545454545454545,138,115,0.5454545454545454,14.646586345381523,30.0,1.0,3.4229249011857705,12.936758893280633,22,4.0,41329,25657,0.6169796673931867,25657,66986,0.3830203326068133,10.895465002678767,0.0,30.0,159213,3418021,0.0465804627882625,0.3364398698185508,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 22.0, 18.0, 8.333333333333334, 9.666666666666666, 13.0, 4.0, 2.6666666666666665, 12.666666666666666, 0.011857707509881422, 0.014492753623188406, 0.0, 253.0, 115.0, 22.0, 11.5, 2.2, 0.45454545454545453, 138.0, 115.0, 0.5454545454545454, 14.646586345381525, 30.0, 1.0, 3.4229249011857705, 12.936758893280633, 22.0, 4.0, 41329.0, 25657.0, 0.6169796673931867, 25657.0, 66986.0, 0.38302033260681334, 10.895465002678769, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3364398698185508))","Map(vectorType -> dense, length -> 2, values -> List(17.450519964560513, 2.54948003543949))","Map(vectorType -> dense, length -> 2, values -> List(0.8725259982280255, 0.1274740017719745))",0.0
17,16,1,0.9411764705882352,30,1,3.882352941176471,6.8125,26.0,1.0,3.411764705882353,15.058823529411764,0.0723404255319148,0.1584158415841584,1,235,134,31,7.580645161290323,1.7537313432835822,0.5702127659574469,101,134,0.4297872340425532,6.969026548672566,30.0,1.0,3.093617021276596,15.6,31,25.0,23463,5421,0.8123182384711258,5421,28884,0.1876817615288741,10.691012672095829,0.0,30.0,99755,841533,0.1185396175788709,0.0691421439500031,"Map(vectorType -> dense, length -> 44, values -> List(17.0, 16.0, 1.0, 0.9411764705882353, 30.0, 1.0, 3.8823529411764706, 6.8125, 26.0, 1.0, 3.411764705882353, 15.058823529411764, 0.07234042553191489, 0.15841584158415842, 1.0, 235.0, 134.0, 31.0, 7.580645161290323, 1.7537313432835822, 0.5702127659574469, 101.0, 134.0, 0.4297872340425532, 6.969026548672566, 30.0, 1.0, 3.0936170212765957, 15.6, 31.0, 25.0, 23463.0, 5421.0, 0.8123182384711258, 5421.0, 28884.0, 0.1876817615288741, 10.691012672095827, 0.0, 30.0, 99755.0, 841533.0, 0.11853961757887094, 0.06914214395000318))","Map(vectorType -> dense, length -> 2, values -> List(12.827097103561352, 7.17290289643865))","Map(vectorType -> dense, length -> 2, values -> List(0.6413548551780676, 0.3586451448219325))",0.0
3,2,1,0.6666666666666666,55,45,20.0,5.666666666666667,10.0,2.0,1.3333333333333333,9.0,0.0030895983522142,0.0028901734104046,5,971,279,60,16.183333333333334,3.4802867383512543,0.2873326467559217,692,279,0.7126673532440783,6.110759493670886,23.0,1.0,2.368692070030896,10.383110195674565,60,4.0,1904,867,0.6871165644171779,867,2771,0.3128834355828221,9.424610710216482,0.0,30.0,89271,498425,0.1791061844811155,0.1337772511017066,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 55.0, 45.0, 20.0, 5.666666666666667, 10.0, 2.0, 1.3333333333333333, 9.0, 0.003089598352214212, 0.002890173410404624, 5.0, 971.0, 279.0, 60.0, 16.183333333333334, 3.4802867383512543, 0.28733264675592174, 692.0, 279.0, 0.7126673532440783, 6.110759493670886, 23.0, 1.0, 2.368692070030896, 10.383110195674563, 60.0, 4.0, 1904.0, 867.0, 0.6871165644171779, 867.0, 2771.0, 0.3128834355828221, 9.424610710216482, 0.0, 30.0, 89271.0, 498425.0, 0.1791061844811155, 0.1337772511017066))","Map(vectorType -> dense, length -> 2, values -> List(18.223688624151897, 1.7763113758481053))","Map(vectorType -> dense, length -> 2, values -> List(0.9111844312075947, 0.08881556879240525))",0.0
15,14,1,0.9333333333333332,84,15,10.533333333333331,3.4,6.0,1.0,3.0,14.666666666666666,0.0114068441064638,0.0136452241715399,14,1315,289,98,13.418367346938776,4.550173010380623,0.2197718631178707,1026,289,0.7802281368821293,3.78110599078341,17.0,1.0,3.173384030418251,12.89277566539924,98,2.0,16507,12806,0.5631289871388121,12806,29313,0.4368710128611878,11.550171871571711,0.0,30.0,159213,3418021,0.0465804627882625,0.3902905500729253,"Map(vectorType -> dense, length -> 44, values -> List(15.0, 14.0, 1.0, 0.9333333333333333, 84.0, 15.0, 10.533333333333333, 3.4, 6.0, 1.0, 3.0, 14.666666666666666, 0.011406844106463879, 0.01364522417153996, 14.0, 1315.0, 289.0, 98.0, 13.418367346938776, 4.550173010380623, 0.21977186311787072, 1026.0, 289.0, 0.7802281368821293, 3.78110599078341, 17.0, 1.0, 3.173384030418251, 12.89277566539924, 98.0, 2.0, 16507.0, 12806.0, 0.5631289871388121, 12806.0, 29313.0, 0.43687101286118785, 11.550171871571711, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3902905500729253))","Map(vectorType -> dense, length -> 2, values -> List(18.24271502637221, 1.757284973627792))","Map(vectorType -> dense, length -> 2, values -> List(0.9121357513186105, 0.0878642486813896))",0.0
1,0,1,0.0,1,1,14.0,0.0,0.0,0.0,0.0,13.0,0.0175438596491228,0.0,2,57,38,3,19.0,1.5,0.6666666666666666,19,38,0.3333333333333333,18.5,23.0,10.0,1.3859649122807018,12.403508771929824,3,30.0,4726,3951,0.5446582920364181,3951,8677,0.4553417079635818,9.624555581891444,0.0,30.0,159213,3418021,0.0465804627882625,0.4087612451753193,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 14.0, 0.0, 0.0, 0.0, 0.0, 13.0, 0.017543859649122806, 0.0, 2.0, 57.0, 38.0, 3.0, 19.0, 1.5, 0.6666666666666666, 19.0, 38.0, 0.3333333333333333, 18.5, 23.0, 10.0, 1.3859649122807018, 12.403508771929825, 3.0, 30.0, 4726.0, 3951.0, 0.5446582920364181, 3951.0, 8677.0, 0.45534170796358187, 9.624555581891444, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.4087612451753193))","Map(vectorType -> dense, length -> 2, values -> List(18.791753559627125, 1.2082464403728759))","Map(vectorType -> dense, length -> 2, values -> List(0.9395876779813562, 0.06041232201864379))",0.0
1,0,1,0.0,3,3,1.0,17.0,17.0,17.0,0.0,17.0,0.0101010101010101,0.0,16,99,53,19,5.2105263157894735,1.8679245283018868,0.5353535353535354,46,53,0.4646464646464646,16.282608695652176,30.0,2.0,3.2525252525252526,14.171717171717171,19,7.0,3301,2831,0.5383235485975212,2831,6132,0.4616764514024788,12.165841584158416,0.0,30.0,77080,377586,0.2041389246423331,0.2575375267601457,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 1.0, 17.0, 17.0, 17.0, 0.0, 17.0, 0.010101010101010102, 0.0, 16.0, 99.0, 53.0, 19.0, 5.2105263157894735, 1.8679245283018868, 0.5353535353535354, 46.0, 53.0, 0.46464646464646464, 16.282608695652176, 30.0, 2.0, 3.2525252525252526, 14.171717171717171, 19.0, 7.0, 3301.0, 2831.0, 0.5383235485975212, 2831.0, 6132.0, 0.4616764514024788, 12.165841584158416, 0.0, 30.0, 77080.0, 377586.0, 0.20413892464233313, 0.2575375267601457))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0
16,15,1,0.9375,43,6,17.5,8.0,16.0,4.0,1.8125,10.0625,0.0168776371308016,0.0195567144719687,2,948,181,45,21.066666666666663,5.237569060773481,0.1909282700421941,767,181,0.8090717299578059,8.375135135135135,23.0,3.0,2.721518987341772,11.827004219409282,45,11.0,8316,3863,0.682814681008293,3863,12179,0.317185318991707,10.55242566510172,0.0,30.0,124393,1452343,0.0856498774738474,0.2315354415178596,"Map(vectorType -> dense, length -> 44, values -> List(16.0, 15.0, 1.0, 0.9375, 43.0, 6.0, 17.5, 8.0, 16.0, 4.0, 1.8125, 10.0625, 0.016877637130801686, 0.01955671447196871, 2.0, 948.0, 181.0, 45.0, 21.066666666666666, 5.237569060773481, 0.1909282700421941, 767.0, 181.0, 0.8090717299578059, 8.375135135135135, 23.0, 3.0, 2.721518987341772, 11.827004219409282, 45.0, 11.0, 8316.0, 3863.0, 0.682814681008293, 3863.0, 12179.0, 0.31718531899170704, 10.55242566510172, 0.0, 30.0, 124393.0, 1452343.0, 0.08564987747384743, 0.2315354415178596))","Map(vectorType -> dense, length -> 2, values -> List(17.944122527317727, 2.055877472682275))","Map(vectorType -> dense, length -> 2, values -> List(0.8972061263658863, 0.10279387363411373))",0.0
3,2,1,0.6666666666666666,30,24,12.666666666666666,8.333333333333334,15.0,5.0,4.333333333333333,10.666666666666666,0.0051635111876075,0.0046948356807511,6,581,155,36,16.13888888888889,3.7483870967741937,0.2667814113597246,426,155,0.7332185886402753,10.564516129032258,30.0,1.0,3.253012048192771,12.81067125645439,36,1.0,5399,3880,0.5818514926177389,3880,9279,0.418148507382261,10.420258123991704,0.0,30.0,54202,234065,0.231568154145216,0.1865803532370449,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 30.0, 24.0, 12.666666666666666, 8.333333333333334, 15.0, 5.0, 4.333333333333333, 10.666666666666666, 0.0051635111876075735, 0.004694835680751174, 6.0, 581.0, 155.0, 36.0, 16.13888888888889, 3.7483870967741937, 0.2667814113597246, 426.0, 155.0, 0.7332185886402753, 10.564516129032258, 30.0, 1.0, 3.253012048192771, 12.81067125645439, 36.0, 1.0, 5399.0, 3880.0, 0.5818514926177389, 3880.0, 9279.0, 0.418148507382261, 10.420258123991703, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.18658035323704494))","Map(vectorType -> dense, length -> 2, values -> List(18.248206580873042, 1.7517934191269613))","Map(vectorType -> dense, length -> 2, values -> List(0.9124103290436519, 0.08758967095634805))",0.0
2,1,1,0.5,11,10,10.0,5.0,8.0,2.0,5.5,10.5,0.0119047619047619,0.0175438596491228,1,168,111,12,14.0,1.5135135135135136,0.6607142857142857,57,111,0.3392857142857143,5.778523489932886,14.0,1.0,2.9166666666666665,10.392857142857142,12,7.0,5115,2220,0.6973415132924335,2220,7335,0.3026584867075664,11.891954690676735,0.0,30.0,25608,99032,0.2585830842555942,0.0440754024519722,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 11.0, 10.0, 10.0, 5.0, 8.0, 2.0, 5.5, 10.5, 0.011904761904761904, 0.017543859649122806, 1.0, 168.0, 111.0, 12.0, 14.0, 1.5135135135135136, 0.6607142857142857, 57.0, 111.0, 0.3392857142857143, 5.778523489932886, 14.0, 1.0, 2.9166666666666665, 10.392857142857142, 12.0, 7.0, 5115.0, 2220.0, 0.6973415132924335, 2220.0, 7335.0, 0.30265848670756645, 11.891954690676735, 0.0, 30.0, 25608.0, 99032.0, 0.2585830842555942, 0.04407540245197228))","Map(vectorType -> dense, length -> 2, values -> List(18.51466965945544, 1.4853303405445624))","Map(vectorType -> dense, length -> 2, values -> List(0.9257334829727719, 0.07426651702722811))",0.0


#### All done!

In [0]:
predictions.groupBy('prediction').count().show()

+----------+-------+
|prediction|  count|
+----------+-------+
|       0.0|4826749|
|       1.0|   6540|
+----------+-------+



In [0]:
display(test_sdf_id.limit(50))

user_id,product_id,order_id
4041,34063,730911
99471,14917,2560214
7440,18770,3078569
153133,21137,1769478
87185,44570,3278341
172302,32981,1243888
5692,45542,784733
192691,2717,2332893
163789,10749,1689230
86173,42475,1200695


### Change prediction result to kaggle submission format.

In [0]:
#예측 결과 predictions와 test_sdf_id를 조인하여 붙이기.  
from pyspark.sql.functions import monotonically_increasing_id
# row건수별로 0부터 순차적으로 증가하는 row_id 컬럼을 monotonically_increasing_id()을 이용하여 생성. 
test_sdf_id = test_sdf_id.withColumn("row_id", monotonically_increasing_id())
predictions = predictions.withColumn("row_id", monotonically_increasing_id())

display(predictions.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,row_id
1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0,0
3,2,1,0.6666666666666666,22,18,8.333333333333334,9.666666666666666,13.0,4.0,2.6666666666666665,12.666666666666666,0.0118577075098814,0.0144927536231884,0,253,115,22,11.5,2.2,0.4545454545454545,138,115,0.5454545454545454,14.646586345381523,30.0,1.0,3.4229249011857705,12.936758893280633,22,4.0,41329,25657,0.6169796673931867,25657,66986,0.3830203326068133,10.895465002678767,0.0,30.0,159213,3418021,0.0465804627882625,0.3364398698185508,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 22.0, 18.0, 8.333333333333334, 9.666666666666666, 13.0, 4.0, 2.6666666666666665, 12.666666666666666, 0.011857707509881422, 0.014492753623188406, 0.0, 253.0, 115.0, 22.0, 11.5, 2.2, 0.45454545454545453, 138.0, 115.0, 0.5454545454545454, 14.646586345381525, 30.0, 1.0, 3.4229249011857705, 12.936758893280633, 22.0, 4.0, 41329.0, 25657.0, 0.6169796673931867, 25657.0, 66986.0, 0.38302033260681334, 10.895465002678769, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3364398698185508))","Map(vectorType -> dense, length -> 2, values -> List(17.450519964560513, 2.54948003543949))","Map(vectorType -> dense, length -> 2, values -> List(0.8725259982280255, 0.1274740017719745))",0.0,1
17,16,1,0.9411764705882352,30,1,3.882352941176471,6.8125,26.0,1.0,3.411764705882353,15.058823529411764,0.0723404255319148,0.1584158415841584,1,235,134,31,7.580645161290323,1.7537313432835822,0.5702127659574469,101,134,0.4297872340425532,6.969026548672566,30.0,1.0,3.093617021276596,15.6,31,25.0,23463,5421,0.8123182384711258,5421,28884,0.1876817615288741,10.691012672095829,0.0,30.0,99755,841533,0.1185396175788709,0.0691421439500031,"Map(vectorType -> dense, length -> 44, values -> List(17.0, 16.0, 1.0, 0.9411764705882353, 30.0, 1.0, 3.8823529411764706, 6.8125, 26.0, 1.0, 3.411764705882353, 15.058823529411764, 0.07234042553191489, 0.15841584158415842, 1.0, 235.0, 134.0, 31.0, 7.580645161290323, 1.7537313432835822, 0.5702127659574469, 101.0, 134.0, 0.4297872340425532, 6.969026548672566, 30.0, 1.0, 3.0936170212765957, 15.6, 31.0, 25.0, 23463.0, 5421.0, 0.8123182384711258, 5421.0, 28884.0, 0.1876817615288741, 10.691012672095827, 0.0, 30.0, 99755.0, 841533.0, 0.11853961757887094, 0.06914214395000318))","Map(vectorType -> dense, length -> 2, values -> List(12.827097103561352, 7.17290289643865))","Map(vectorType -> dense, length -> 2, values -> List(0.6413548551780676, 0.3586451448219325))",0.0,2
3,2,1,0.6666666666666666,55,45,20.0,5.666666666666667,10.0,2.0,1.3333333333333333,9.0,0.0030895983522142,0.0028901734104046,5,971,279,60,16.183333333333334,3.4802867383512543,0.2873326467559217,692,279,0.7126673532440783,6.110759493670886,23.0,1.0,2.368692070030896,10.383110195674565,60,4.0,1904,867,0.6871165644171779,867,2771,0.3128834355828221,9.424610710216482,0.0,30.0,89271,498425,0.1791061844811155,0.1337772511017066,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 55.0, 45.0, 20.0, 5.666666666666667, 10.0, 2.0, 1.3333333333333333, 9.0, 0.003089598352214212, 0.002890173410404624, 5.0, 971.0, 279.0, 60.0, 16.183333333333334, 3.4802867383512543, 0.28733264675592174, 692.0, 279.0, 0.7126673532440783, 6.110759493670886, 23.0, 1.0, 2.368692070030896, 10.383110195674563, 60.0, 4.0, 1904.0, 867.0, 0.6871165644171779, 867.0, 2771.0, 0.3128834355828221, 9.424610710216482, 0.0, 30.0, 89271.0, 498425.0, 0.1791061844811155, 0.1337772511017066))","Map(vectorType -> dense, length -> 2, values -> List(18.223688624151897, 1.7763113758481053))","Map(vectorType -> dense, length -> 2, values -> List(0.9111844312075947, 0.08881556879240525))",0.0,3
15,14,1,0.9333333333333332,84,15,10.533333333333331,3.4,6.0,1.0,3.0,14.666666666666666,0.0114068441064638,0.0136452241715399,14,1315,289,98,13.418367346938776,4.550173010380623,0.2197718631178707,1026,289,0.7802281368821293,3.78110599078341,17.0,1.0,3.173384030418251,12.89277566539924,98,2.0,16507,12806,0.5631289871388121,12806,29313,0.4368710128611878,11.550171871571711,0.0,30.0,159213,3418021,0.0465804627882625,0.3902905500729253,"Map(vectorType -> dense, length -> 44, values -> List(15.0, 14.0, 1.0, 0.9333333333333333, 84.0, 15.0, 10.533333333333333, 3.4, 6.0, 1.0, 3.0, 14.666666666666666, 0.011406844106463879, 0.01364522417153996, 14.0, 1315.0, 289.0, 98.0, 13.418367346938776, 4.550173010380623, 0.21977186311787072, 1026.0, 289.0, 0.7802281368821293, 3.78110599078341, 17.0, 1.0, 3.173384030418251, 12.89277566539924, 98.0, 2.0, 16507.0, 12806.0, 0.5631289871388121, 12806.0, 29313.0, 0.43687101286118785, 11.550171871571711, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3902905500729253))","Map(vectorType -> dense, length -> 2, values -> List(18.24271502637221, 1.757284973627792))","Map(vectorType -> dense, length -> 2, values -> List(0.9121357513186105, 0.0878642486813896))",0.0,4
1,0,1,0.0,1,1,14.0,0.0,0.0,0.0,0.0,13.0,0.0175438596491228,0.0,2,57,38,3,19.0,1.5,0.6666666666666666,19,38,0.3333333333333333,18.5,23.0,10.0,1.3859649122807018,12.403508771929824,3,30.0,4726,3951,0.5446582920364181,3951,8677,0.4553417079635818,9.624555581891444,0.0,30.0,159213,3418021,0.0465804627882625,0.4087612451753193,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 14.0, 0.0, 0.0, 0.0, 0.0, 13.0, 0.017543859649122806, 0.0, 2.0, 57.0, 38.0, 3.0, 19.0, 1.5, 0.6666666666666666, 19.0, 38.0, 0.3333333333333333, 18.5, 23.0, 10.0, 1.3859649122807018, 12.403508771929825, 3.0, 30.0, 4726.0, 3951.0, 0.5446582920364181, 3951.0, 8677.0, 0.45534170796358187, 9.624555581891444, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.4087612451753193))","Map(vectorType -> dense, length -> 2, values -> List(18.791753559627125, 1.2082464403728759))","Map(vectorType -> dense, length -> 2, values -> List(0.9395876779813562, 0.06041232201864379))",0.0,5
1,0,1,0.0,3,3,1.0,17.0,17.0,17.0,0.0,17.0,0.0101010101010101,0.0,16,99,53,19,5.2105263157894735,1.8679245283018868,0.5353535353535354,46,53,0.4646464646464646,16.282608695652176,30.0,2.0,3.2525252525252526,14.171717171717171,19,7.0,3301,2831,0.5383235485975212,2831,6132,0.4616764514024788,12.165841584158416,0.0,30.0,77080,377586,0.2041389246423331,0.2575375267601457,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 1.0, 17.0, 17.0, 17.0, 0.0, 17.0, 0.010101010101010102, 0.0, 16.0, 99.0, 53.0, 19.0, 5.2105263157894735, 1.8679245283018868, 0.5353535353535354, 46.0, 53.0, 0.46464646464646464, 16.282608695652176, 30.0, 2.0, 3.2525252525252526, 14.171717171717171, 19.0, 7.0, 3301.0, 2831.0, 0.5383235485975212, 2831.0, 6132.0, 0.4616764514024788, 12.165841584158416, 0.0, 30.0, 77080.0, 377586.0, 0.20413892464233313, 0.2575375267601457))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0,6
16,15,1,0.9375,43,6,17.5,8.0,16.0,4.0,1.8125,10.0625,0.0168776371308016,0.0195567144719687,2,948,181,45,21.066666666666663,5.237569060773481,0.1909282700421941,767,181,0.8090717299578059,8.375135135135135,23.0,3.0,2.721518987341772,11.827004219409282,45,11.0,8316,3863,0.682814681008293,3863,12179,0.317185318991707,10.55242566510172,0.0,30.0,124393,1452343,0.0856498774738474,0.2315354415178596,"Map(vectorType -> dense, length -> 44, values -> List(16.0, 15.0, 1.0, 0.9375, 43.0, 6.0, 17.5, 8.0, 16.0, 4.0, 1.8125, 10.0625, 0.016877637130801686, 0.01955671447196871, 2.0, 948.0, 181.0, 45.0, 21.066666666666666, 5.237569060773481, 0.1909282700421941, 767.0, 181.0, 0.8090717299578059, 8.375135135135135, 23.0, 3.0, 2.721518987341772, 11.827004219409282, 45.0, 11.0, 8316.0, 3863.0, 0.682814681008293, 3863.0, 12179.0, 0.31718531899170704, 10.55242566510172, 0.0, 30.0, 124393.0, 1452343.0, 0.08564987747384743, 0.2315354415178596))","Map(vectorType -> dense, length -> 2, values -> List(17.944122527317727, 2.055877472682275))","Map(vectorType -> dense, length -> 2, values -> List(0.8972061263658863, 0.10279387363411373))",0.0,7
3,2,1,0.6666666666666666,30,24,12.666666666666666,8.333333333333334,15.0,5.0,4.333333333333333,10.666666666666666,0.0051635111876075,0.0046948356807511,6,581,155,36,16.13888888888889,3.7483870967741937,0.2667814113597246,426,155,0.7332185886402753,10.564516129032258,30.0,1.0,3.253012048192771,12.81067125645439,36,1.0,5399,3880,0.5818514926177389,3880,9279,0.418148507382261,10.420258123991704,0.0,30.0,54202,234065,0.231568154145216,0.1865803532370449,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 30.0, 24.0, 12.666666666666666, 8.333333333333334, 15.0, 5.0, 4.333333333333333, 10.666666666666666, 0.0051635111876075735, 0.004694835680751174, 6.0, 581.0, 155.0, 36.0, 16.13888888888889, 3.7483870967741937, 0.2667814113597246, 426.0, 155.0, 0.7332185886402753, 10.564516129032258, 30.0, 1.0, 3.253012048192771, 12.81067125645439, 36.0, 1.0, 5399.0, 3880.0, 0.5818514926177389, 3880.0, 9279.0, 0.418148507382261, 10.420258123991703, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.18658035323704494))","Map(vectorType -> dense, length -> 2, values -> List(18.248206580873042, 1.7517934191269613))","Map(vectorType -> dense, length -> 2, values -> List(0.9124103290436519, 0.08758967095634805))",0.0,8
2,1,1,0.5,11,10,10.0,5.0,8.0,2.0,5.5,10.5,0.0119047619047619,0.0175438596491228,1,168,111,12,14.0,1.5135135135135136,0.6607142857142857,57,111,0.3392857142857143,5.778523489932886,14.0,1.0,2.9166666666666665,10.392857142857142,12,7.0,5115,2220,0.6973415132924335,2220,7335,0.3026584867075664,11.891954690676735,0.0,30.0,25608,99032,0.2585830842555942,0.0440754024519722,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 11.0, 10.0, 10.0, 5.0, 8.0, 2.0, 5.5, 10.5, 0.011904761904761904, 0.017543859649122806, 1.0, 168.0, 111.0, 12.0, 14.0, 1.5135135135135136, 0.6607142857142857, 57.0, 111.0, 0.3392857142857143, 5.778523489932886, 14.0, 1.0, 2.9166666666666665, 10.392857142857142, 12.0, 7.0, 5115.0, 2220.0, 0.6973415132924335, 2220.0, 7335.0, 0.30265848670756645, 11.891954690676735, 0.0, 30.0, 25608.0, 99032.0, 0.2585830842555942, 0.04407540245197228))","Map(vectorType -> dense, length -> 2, values -> List(18.51466965945544, 1.4853303405445624))","Map(vectorType -> dense, length -> 2, values -> List(0.9257334829727719, 0.07426651702722811))",0.0,9


In [0]:
# order_id와 product_id를 얻기 위해 test_sdf_id와 predictions을 row_id로 조인 시킴. 
predictions = test_sdf_id.join(predictions, ("row_id")).drop("row_id")
print(test_sdf.count(), predictions.count())
display(predictions.limit(10))

4833289 4833289


user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction
172302,19822,1243888,1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0
1623,43908,1778015,2,1,1,0.5,4,1,2.0,30.0,30.0,30.0,0.0,11.0,0.016260162601626,0.0126582278481012,16,123,44,20,6.15,2.7954545454545454,0.3577235772357723,79,44,0.6422764227642277,15.721739130434782,30.0,2.0,1.1788617886178865,12.471544715447154,20,6.0,2115,2357,0.4729427549194991,2357,4472,0.5270572450805009,11.802751166789486,0.0,30.0,58749,390299,0.150523060525392,0.3765341845551088,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 4.0, 1.0, 2.0, 30.0, 30.0, 30.0, 0.0, 11.0, 0.016260162601626018, 0.012658227848101266, 16.0, 123.0, 44.0, 20.0, 6.15, 2.7954545454545454, 0.35772357723577236, 79.0, 44.0, 0.6422764227642277, 15.721739130434782, 30.0, 2.0, 1.1788617886178863, 12.471544715447154, 20.0, 6.0, 2115.0, 2357.0, 0.4729427549194991, 2357.0, 4472.0, 0.5270572450805009, 11.802751166789486, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.3765341845551088))","Map(vectorType -> dense, length -> 2, values -> List(18.63729236572747, 1.3627076342725357))","Map(vectorType -> dense, length -> 2, values -> List(0.9318646182863732, 0.06813538171362678))",0.0
160722,46979,2741763,11,10,1,0.9090909090909092,34,1,12.454545454545457,10.1,22.0,1.0,2.4545454545454546,11.727272727272728,0.0121546961325966,0.0168067226890756,4,905,310,38,23.81578947368421,2.919354838709677,0.3425414364640884,595,310,0.6574585635359116,8.077966101694916,26.0,1.0,2.889502762430939,12.103867403314917,38,5.0,41585,25698,0.618061025816328,25698,67283,0.3819389741836719,10.788400680629104,0.0,30.0,159213,3418021,0.0465804627882625,0.3353585113954094,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 34.0, 1.0, 12.454545454545455, 10.1, 22.0, 1.0, 2.4545454545454546, 11.727272727272727, 0.012154696132596685, 0.01680672268907563, 4.0, 905.0, 310.0, 38.0, 23.81578947368421, 2.9193548387096775, 0.3425414364640884, 595.0, 310.0, 0.6574585635359116, 8.077966101694916, 26.0, 1.0, 2.889502762430939, 12.103867403314917, 38.0, 5.0, 41585.0, 25698.0, 0.618061025816328, 25698.0, 67283.0, 0.38193897418367195, 10.788400680629104, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3353585113954094))","Map(vectorType -> dense, length -> 2, values -> List(18.10412682180586, 1.8958731781941405))","Map(vectorType -> dense, length -> 2, values -> List(0.905206341090293, 0.09479365890970702))",0.0
137485,21903,1392405,5,4,1,0.8,16,6,4.0,17.2,30.0,2.0,2.6,20.8,0.0427350427350427,0.0909090909090909,0,117,73,16,7.3125,1.6027397260273972,0.6239316239316239,44,73,0.376068376068376,16.570093457943926,30.0,0.0,3.1880341880341883,20.05982905982906,16,20.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 16.0, 6.0, 4.0, 17.2, 30.0, 2.0, 2.6, 20.8, 0.042735042735042736, 0.09090909090909091, 0.0, 117.0, 73.0, 16.0, 7.3125, 1.6027397260273972, 0.6239316239316239, 44.0, 73.0, 0.37606837606837606, 16.570093457943926, 30.0, 0.0, 3.1880341880341883, 20.05982905982906, 16.0, 20.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(12.628242273317392, 7.371757726682604))","Map(vectorType -> dense, length -> 2, values -> List(0.6314121136658697, 0.3685878863341303))",0.0
100368,1529,1583293,2,1,1,0.5,8,6,14.0,13.5,21.0,6.0,0.0,16.0,0.0082987551867219,0.0064935064935064,4,241,87,12,20.08333333333333,2.7701149425287355,0.3609958506224066,154,87,0.6390041493775933,16.5990990990991,30.0,6.0,1.95850622406639,12.755186721991702,12,30.0,3838,4898,0.4393315018315018,4898,8736,0.5606684981684982,11.348765049025692,0.0,30.0,76738,377741,0.2031497772283125,0.3575187209401856,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 8.0, 6.0, 14.0, 13.5, 21.0, 6.0, 0.0, 16.0, 0.008298755186721992, 0.006493506493506494, 4.0, 241.0, 87.0, 12.0, 20.083333333333332, 2.7701149425287355, 0.36099585062240663, 154.0, 87.0, 0.6390041493775933, 16.5990990990991, 30.0, 6.0, 1.95850622406639, 12.755186721991702, 12.0, 30.0, 3838.0, 4898.0, 0.4393315018315018, 4898.0, 8736.0, 0.5606684981684982, 11.348765049025692, 0.0, 30.0, 76738.0, 377741.0, 0.2031497772283125, 0.35751872094018566))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0
177418,35166,1786000,2,1,1,0.5,70,28,7.0,3.0,4.0,2.0,2.5,10.0,0.0022935779816513,0.0015723270440251,24,872,236,94,9.27659574468085,3.694915254237288,0.2706422018348624,636,236,0.7293577981651376,3.1795166858458,21.0,0.0,2.311926605504587,10.97362385321101,94,5.0,489,1169,0.2949336550060313,1169,1658,0.7050663449939686,10.973868706182282,0.0,30.0,76476,297037,0.2574628749953709,0.4476034699985976,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 70.0, 28.0, 7.0, 3.0, 4.0, 2.0, 2.5, 10.0, 0.0022935779816513763, 0.0015723270440251573, 24.0, 872.0, 236.0, 94.0, 9.27659574468085, 3.694915254237288, 0.2706422018348624, 636.0, 236.0, 0.7293577981651376, 3.1795166858457997, 21.0, 0.0, 2.311926605504587, 10.97362385321101, 94.0, 5.0, 489.0, 1169.0, 0.29493365500603136, 1169.0, 1658.0, 0.7050663449939686, 10.973868706182282, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.4476034699985976))","Map(vectorType -> dense, length -> 2, values -> List(18.850900731484263, 1.1490992685157388))","Map(vectorType -> dense, length -> 2, values -> List(0.9425450365742132, 0.057454963425786945))",0.0
115553,41793,2548573,5,4,1,0.8,96,17,6.8,4.4,15.0,1.0,2.4,9.0,0.0046816479400749,0.0047449584816132,2,1068,225,98,10.89795918367347,4.746666666666667,0.2106741573033707,843,225,0.7893258426966292,3.885230479774224,30.0,0.0,3.093632958801498,11.82865168539326,98,3.0,2333,869,0.7286071205496565,869,3202,0.2713928794503435,9.980026195153895,0.0,30.0,54202,234065,0.231568154145216,0.0398247253051274,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 96.0, 17.0, 6.8, 4.4, 15.0, 1.0, 2.4, 9.0, 0.0046816479400749065, 0.004744958481613286, 2.0, 1068.0, 225.0, 98.0, 10.89795918367347, 4.746666666666667, 0.21067415730337077, 843.0, 225.0, 0.7893258426966292, 3.885230479774224, 30.0, 0.0, 3.093632958801498, 11.82865168539326, 98.0, 3.0, 2333.0, 869.0, 0.7286071205496565, 869.0, 3202.0, 0.2713928794503435, 9.980026195153897, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.039824725305127456))","Map(vectorType -> dense, length -> 2, values -> List(18.136398020554342, 1.863601979445663))","Map(vectorType -> dense, length -> 2, values -> List(0.9068199010277169, 0.09318009897228313))",0.0
105400,45007,2176640,2,1,1,0.5,12,9,6.0,19.5,30.0,9.0,2.5,15.5,0.011049723756906,0.0075187969924812,6,181,48,18,10.055555555555555,3.7708333333333335,0.2651933701657458,133,48,0.7348066298342542,13.538922155688622,30.0,5.0,2.0939226519337018,16.03867403314917,18,12.0,72165,32658,0.6884462379439626,32658,104823,0.3115537620560373,10.597150051962586,0.0,30.0,159213,3418021,0.0465804627882625,0.2649732992677747,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 12.0, 9.0, 6.0, 19.5, 30.0, 9.0, 2.5, 15.5, 0.011049723756906077, 0.007518796992481203, 6.0, 181.0, 48.0, 18.0, 10.055555555555555, 3.7708333333333335, 0.26519337016574585, 133.0, 48.0, 0.7348066298342542, 13.538922155688622, 30.0, 5.0, 2.0939226519337018, 16.03867403314917, 18.0, 12.0, 72165.0, 32658.0, 0.6884462379439626, 32658.0, 104823.0, 0.3115537620560373, 10.597150051962586, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.26497329926777474))","Map(vectorType -> dense, length -> 2, values -> List(18.596309099092576, 1.403690900907425))","Map(vectorType -> dense, length -> 2, values -> List(0.9298154549546288, 0.07018454504537125))",0.0
93880,49517,247785,28,27,1,0.9642857142857144,33,1,5.25,10.22222222222222,28.0,1.0,3.107142857142857,10.785714285714286,0.0590717299578059,0.0694087403598971,0,474,85,33,14.363636363636363,5.576470588235294,0.1793248945147679,389,85,0.820675105485232,10.951002227171491,28.0,1.0,2.8375527426160336,11.232067510548523,33,26.0,2706,1115,0.7081915728866789,1115,3821,0.2918084271133211,11.783276450511943,0.0,30.0,109226,891015,0.1225860395167309,0.1692223875965902,"Map(vectorType -> dense, length -> 44, values -> List(28.0, 27.0, 1.0, 0.9642857142857143, 33.0, 1.0, 5.25, 10.222222222222221, 28.0, 1.0, 3.107142857142857, 10.785714285714286, 0.05907172995780591, 0.06940874035989718, 0.0, 474.0, 85.0, 33.0, 14.363636363636363, 5.576470588235294, 0.17932489451476794, 389.0, 85.0, 0.820675105485232, 10.951002227171493, 28.0, 1.0, 2.8375527426160336, 11.232067510548523, 33.0, 26.0, 2706.0, 1115.0, 0.7081915728866789, 1115.0, 3821.0, 0.2918084271133211, 11.783276450511945, 0.0, 30.0, 109226.0, 891015.0, 0.12258603951673092, 0.1692223875965902))","Map(vectorType -> dense, length -> 2, values -> List(10.195436432543355, 9.804563567456643))","Map(vectorType -> dense, length -> 2, values -> List(0.5097718216271677, 0.49022817837283217))",0.0
63685,39190,1263744,1,0,1,0.0,4,4,14.0,8.0,8.0,8.0,1.0,23.0,0.0054347826086956,0.0,8,184,130,12,15.333333333333334,1.4153846153846157,0.7065217391304348,54,130,0.2934782608695652,18.22360248447205,30.0,2.0,2.717391304347826,13.36413043478261,12,19.0,6294,4678,0.5736419978126139,4678,10972,0.426358002187386,12.504883101509323,0.0,30.0,96326,638253,0.1509213431037535,0.2754366590836326,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 14.0, 8.0, 8.0, 8.0, 1.0, 23.0, 0.005434782608695652, 0.0, 8.0, 184.0, 130.0, 12.0, 15.333333333333334, 1.4153846153846155, 0.7065217391304348, 54.0, 130.0, 0.29347826086956524, 18.22360248447205, 30.0, 2.0, 2.717391304347826, 13.36413043478261, 12.0, 19.0, 6294.0, 4678.0, 0.5736419978126139, 4678.0, 10972.0, 0.42635800218738606, 12.504883101509323, 0.0, 30.0, 96326.0, 638253.0, 0.15092134310375352, 0.2754366590836326))","Map(vectorType -> dense, length -> 2, values -> List(18.74200689205801, 1.2579931079419875))","Map(vectorType -> dense, length -> 2, values -> List(0.9371003446029006, 0.06289965539709938))",0.0


In [0]:
predictions.printSchema()

root
 |-- user_id: integer (nullable = true)
 |-- product_id: long (nullable = true)
 |-- order_id: integer (nullable = true)
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = false)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = false)
 |-- up_avg_prior_days: double (nullable = false)
 |-- up_max_prior_days: double (nullable = false)
 |-- up_min_prior_days: double (nullable = false)
 |-- up_avg_ord_dow: double (nullable = false)
 |-- up_avg_ord_hour: double (nullable = false)
 |-- up_usr_ratio: double (nullable = false)
 |-- up_usr_reord_ratio: double (nullable = false)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = fals

In [0]:
# 여러 값으로 구성된 vector 컬럼에서 특정 값만 추출. probability 컬럼은 0/1 일때의 확률을 모두 가짐. 이중 1일 때(즉 재주문)의 확률을 추출
# 먼저 vector를 array로 변환
from pyspark.ml.functions import vector_to_array
predictions = predictions.withColumn("probability_arr", vector_to_array('probability'))
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr
172302,19822,1243888,1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0,"List(0.9292008768034521, 0.07079912319654799)"
1623,43908,1778015,2,1,1,0.5,4,1,2.0,30.0,30.0,30.0,0.0,11.0,0.016260162601626,0.0126582278481012,16,123,44,20,6.15,2.7954545454545454,0.3577235772357723,79,44,0.6422764227642277,15.721739130434782,30.0,2.0,1.1788617886178865,12.471544715447154,20,6.0,2115,2357,0.4729427549194991,2357,4472,0.5270572450805009,11.802751166789486,0.0,30.0,58749,390299,0.150523060525392,0.3765341845551088,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 4.0, 1.0, 2.0, 30.0, 30.0, 30.0, 0.0, 11.0, 0.016260162601626018, 0.012658227848101266, 16.0, 123.0, 44.0, 20.0, 6.15, 2.7954545454545454, 0.35772357723577236, 79.0, 44.0, 0.6422764227642277, 15.721739130434782, 30.0, 2.0, 1.1788617886178863, 12.471544715447154, 20.0, 6.0, 2115.0, 2357.0, 0.4729427549194991, 2357.0, 4472.0, 0.5270572450805009, 11.802751166789486, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.3765341845551088))","Map(vectorType -> dense, length -> 2, values -> List(18.63729236572747, 1.3627076342725357))","Map(vectorType -> dense, length -> 2, values -> List(0.9318646182863732, 0.06813538171362678))",0.0,"List(0.9318646182863732, 0.06813538171362678)"
160722,46979,2741763,11,10,1,0.9090909090909092,34,1,12.454545454545457,10.1,22.0,1.0,2.4545454545454546,11.727272727272728,0.0121546961325966,0.0168067226890756,4,905,310,38,23.81578947368421,2.919354838709677,0.3425414364640884,595,310,0.6574585635359116,8.077966101694916,26.0,1.0,2.889502762430939,12.103867403314917,38,5.0,41585,25698,0.618061025816328,25698,67283,0.3819389741836719,10.788400680629104,0.0,30.0,159213,3418021,0.0465804627882625,0.3353585113954094,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 34.0, 1.0, 12.454545454545455, 10.1, 22.0, 1.0, 2.4545454545454546, 11.727272727272727, 0.012154696132596685, 0.01680672268907563, 4.0, 905.0, 310.0, 38.0, 23.81578947368421, 2.9193548387096775, 0.3425414364640884, 595.0, 310.0, 0.6574585635359116, 8.077966101694916, 26.0, 1.0, 2.889502762430939, 12.103867403314917, 38.0, 5.0, 41585.0, 25698.0, 0.618061025816328, 25698.0, 67283.0, 0.38193897418367195, 10.788400680629104, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3353585113954094))","Map(vectorType -> dense, length -> 2, values -> List(18.10412682180586, 1.8958731781941405))","Map(vectorType -> dense, length -> 2, values -> List(0.905206341090293, 0.09479365890970702))",0.0,"List(0.905206341090293, 0.09479365890970702)"
137485,21903,1392405,5,4,1,0.8,16,6,4.0,17.2,30.0,2.0,2.6,20.8,0.0427350427350427,0.0909090909090909,0,117,73,16,7.3125,1.6027397260273972,0.6239316239316239,44,73,0.376068376068376,16.570093457943926,30.0,0.0,3.1880341880341883,20.05982905982906,16,20.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 16.0, 6.0, 4.0, 17.2, 30.0, 2.0, 2.6, 20.8, 0.042735042735042736, 0.09090909090909091, 0.0, 117.0, 73.0, 16.0, 7.3125, 1.6027397260273972, 0.6239316239316239, 44.0, 73.0, 0.37606837606837606, 16.570093457943926, 30.0, 0.0, 3.1880341880341883, 20.05982905982906, 16.0, 20.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(12.628242273317392, 7.371757726682604))","Map(vectorType -> dense, length -> 2, values -> List(0.6314121136658697, 0.3685878863341303))",0.0,"List(0.6314121136658697, 0.3685878863341303)"
100368,1529,1583293,2,1,1,0.5,8,6,14.0,13.5,21.0,6.0,0.0,16.0,0.0082987551867219,0.0064935064935064,4,241,87,12,20.08333333333333,2.7701149425287355,0.3609958506224066,154,87,0.6390041493775933,16.5990990990991,30.0,6.0,1.95850622406639,12.755186721991702,12,30.0,3838,4898,0.4393315018315018,4898,8736,0.5606684981684982,11.348765049025692,0.0,30.0,76738,377741,0.2031497772283125,0.3575187209401856,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 8.0, 6.0, 14.0, 13.5, 21.0, 6.0, 0.0, 16.0, 0.008298755186721992, 0.006493506493506494, 4.0, 241.0, 87.0, 12.0, 20.083333333333332, 2.7701149425287355, 0.36099585062240663, 154.0, 87.0, 0.6390041493775933, 16.5990990990991, 30.0, 6.0, 1.95850622406639, 12.755186721991702, 12.0, 30.0, 3838.0, 4898.0, 0.4393315018315018, 4898.0, 8736.0, 0.5606684981684982, 11.348765049025692, 0.0, 30.0, 76738.0, 377741.0, 0.2031497772283125, 0.35751872094018566))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0,"List(0.9343075997121274, 0.06569240028787272)"
177418,35166,1786000,2,1,1,0.5,70,28,7.0,3.0,4.0,2.0,2.5,10.0,0.0022935779816513,0.0015723270440251,24,872,236,94,9.27659574468085,3.694915254237288,0.2706422018348624,636,236,0.7293577981651376,3.1795166858458,21.0,0.0,2.311926605504587,10.97362385321101,94,5.0,489,1169,0.2949336550060313,1169,1658,0.7050663449939686,10.973868706182282,0.0,30.0,76476,297037,0.2574628749953709,0.4476034699985976,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 70.0, 28.0, 7.0, 3.0, 4.0, 2.0, 2.5, 10.0, 0.0022935779816513763, 0.0015723270440251573, 24.0, 872.0, 236.0, 94.0, 9.27659574468085, 3.694915254237288, 0.2706422018348624, 636.0, 236.0, 0.7293577981651376, 3.1795166858457997, 21.0, 0.0, 2.311926605504587, 10.97362385321101, 94.0, 5.0, 489.0, 1169.0, 0.29493365500603136, 1169.0, 1658.0, 0.7050663449939686, 10.973868706182282, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.4476034699985976))","Map(vectorType -> dense, length -> 2, values -> List(18.850900731484263, 1.1490992685157388))","Map(vectorType -> dense, length -> 2, values -> List(0.9425450365742132, 0.057454963425786945))",0.0,"List(0.9425450365742132, 0.057454963425786945)"
115553,41793,2548573,5,4,1,0.8,96,17,6.8,4.4,15.0,1.0,2.4,9.0,0.0046816479400749,0.0047449584816132,2,1068,225,98,10.89795918367347,4.746666666666667,0.2106741573033707,843,225,0.7893258426966292,3.885230479774224,30.0,0.0,3.093632958801498,11.82865168539326,98,3.0,2333,869,0.7286071205496565,869,3202,0.2713928794503435,9.980026195153895,0.0,30.0,54202,234065,0.231568154145216,0.0398247253051274,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 96.0, 17.0, 6.8, 4.4, 15.0, 1.0, 2.4, 9.0, 0.0046816479400749065, 0.004744958481613286, 2.0, 1068.0, 225.0, 98.0, 10.89795918367347, 4.746666666666667, 0.21067415730337077, 843.0, 225.0, 0.7893258426966292, 3.885230479774224, 30.0, 0.0, 3.093632958801498, 11.82865168539326, 98.0, 3.0, 2333.0, 869.0, 0.7286071205496565, 869.0, 3202.0, 0.2713928794503435, 9.980026195153897, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.039824725305127456))","Map(vectorType -> dense, length -> 2, values -> List(18.136398020554342, 1.863601979445663))","Map(vectorType -> dense, length -> 2, values -> List(0.9068199010277169, 0.09318009897228313))",0.0,"List(0.9068199010277169, 0.09318009897228313)"
105400,45007,2176640,2,1,1,0.5,12,9,6.0,19.5,30.0,9.0,2.5,15.5,0.011049723756906,0.0075187969924812,6,181,48,18,10.055555555555555,3.7708333333333335,0.2651933701657458,133,48,0.7348066298342542,13.538922155688622,30.0,5.0,2.0939226519337018,16.03867403314917,18,12.0,72165,32658,0.6884462379439626,32658,104823,0.3115537620560373,10.597150051962586,0.0,30.0,159213,3418021,0.0465804627882625,0.2649732992677747,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 12.0, 9.0, 6.0, 19.5, 30.0, 9.0, 2.5, 15.5, 0.011049723756906077, 0.007518796992481203, 6.0, 181.0, 48.0, 18.0, 10.055555555555555, 3.7708333333333335, 0.26519337016574585, 133.0, 48.0, 0.7348066298342542, 13.538922155688622, 30.0, 5.0, 2.0939226519337018, 16.03867403314917, 18.0, 12.0, 72165.0, 32658.0, 0.6884462379439626, 32658.0, 104823.0, 0.3115537620560373, 10.597150051962586, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.26497329926777474))","Map(vectorType -> dense, length -> 2, values -> List(18.596309099092576, 1.403690900907425))","Map(vectorType -> dense, length -> 2, values -> List(0.9298154549546288, 0.07018454504537125))",0.0,"List(0.9298154549546288, 0.07018454504537125)"
93880,49517,247785,28,27,1,0.9642857142857144,33,1,5.25,10.22222222222222,28.0,1.0,3.107142857142857,10.785714285714286,0.0590717299578059,0.0694087403598971,0,474,85,33,14.363636363636363,5.576470588235294,0.1793248945147679,389,85,0.820675105485232,10.951002227171491,28.0,1.0,2.8375527426160336,11.232067510548523,33,26.0,2706,1115,0.7081915728866789,1115,3821,0.2918084271133211,11.783276450511943,0.0,30.0,109226,891015,0.1225860395167309,0.1692223875965902,"Map(vectorType -> dense, length -> 44, values -> List(28.0, 27.0, 1.0, 0.9642857142857143, 33.0, 1.0, 5.25, 10.222222222222221, 28.0, 1.0, 3.107142857142857, 10.785714285714286, 0.05907172995780591, 0.06940874035989718, 0.0, 474.0, 85.0, 33.0, 14.363636363636363, 5.576470588235294, 0.17932489451476794, 389.0, 85.0, 0.820675105485232, 10.951002227171493, 28.0, 1.0, 2.8375527426160336, 11.232067510548523, 33.0, 26.0, 2706.0, 1115.0, 0.7081915728866789, 1115.0, 3821.0, 0.2918084271133211, 11.783276450511945, 0.0, 30.0, 109226.0, 891015.0, 0.12258603951673092, 0.1692223875965902))","Map(vectorType -> dense, length -> 2, values -> List(10.195436432543355, 9.804563567456643))","Map(vectorType -> dense, length -> 2, values -> List(0.5097718216271677, 0.49022817837283217))",0.0,"List(0.5097718216271677, 0.49022817837283217)"
63685,39190,1263744,1,0,1,0.0,4,4,14.0,8.0,8.0,8.0,1.0,23.0,0.0054347826086956,0.0,8,184,130,12,15.333333333333334,1.4153846153846157,0.7065217391304348,54,130,0.2934782608695652,18.22360248447205,30.0,2.0,2.717391304347826,13.36413043478261,12,19.0,6294,4678,0.5736419978126139,4678,10972,0.426358002187386,12.504883101509323,0.0,30.0,96326,638253,0.1509213431037535,0.2754366590836326,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 14.0, 8.0, 8.0, 8.0, 1.0, 23.0, 0.005434782608695652, 0.0, 8.0, 184.0, 130.0, 12.0, 15.333333333333334, 1.4153846153846155, 0.7065217391304348, 54.0, 130.0, 0.29347826086956524, 18.22360248447205, 30.0, 2.0, 2.717391304347826, 13.36413043478261, 12.0, 19.0, 6294.0, 4678.0, 0.5736419978126139, 4678.0, 10972.0, 0.42635800218738606, 12.504883101509323, 0.0, 30.0, 96326.0, 638253.0, 0.15092134310375352, 0.2754366590836326))","Map(vectorType -> dense, length -> 2, values -> List(18.74200689205801, 1.2579931079419875))","Map(vectorType -> dense, length -> 2, values -> List(0.9371003446029006, 0.06289965539709938))",0.0,"List(0.9371003446029006, 0.06289965539709938)"


In [0]:
predictions.select(F.col('probability_arr')[1]).show(10)

+-------------------+
| probability_arr[1]|
+-------------------+
|0.07079912319654799|
|0.10279387363411373|
|0.06813538171362678|
|0.10114706612587518|
|0.09479365890970702|
| 0.3685878863341303|
|0.06132787010799483|
|0.09811041377649277|
|0.06569240028787272|
|0.06041232201864379|
+-------------------+
only showing top 10 rows



In [0]:
# 변환된 array에서 1일때의 확률값을 추출. 
predictions = predictions.withColumn('1_proba', F.col('probability_arr')[1])
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba
172302,19822,1243888,1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0,"List(0.9292008768034521, 0.07079912319654799)",0.0707991231965479
1623,43908,1778015,2,1,1,0.5,4,1,2.0,30.0,30.0,30.0,0.0,11.0,0.016260162601626,0.0126582278481012,16,123,44,20,6.15,2.7954545454545454,0.3577235772357723,79,44,0.6422764227642277,15.721739130434782,30.0,2.0,1.1788617886178865,12.471544715447154,20,6.0,2115,2357,0.4729427549194991,2357,4472,0.5270572450805009,11.802751166789486,0.0,30.0,58749,390299,0.150523060525392,0.3765341845551088,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 4.0, 1.0, 2.0, 30.0, 30.0, 30.0, 0.0, 11.0, 0.016260162601626018, 0.012658227848101266, 16.0, 123.0, 44.0, 20.0, 6.15, 2.7954545454545454, 0.35772357723577236, 79.0, 44.0, 0.6422764227642277, 15.721739130434782, 30.0, 2.0, 1.1788617886178863, 12.471544715447154, 20.0, 6.0, 2115.0, 2357.0, 0.4729427549194991, 2357.0, 4472.0, 0.5270572450805009, 11.802751166789486, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.3765341845551088))","Map(vectorType -> dense, length -> 2, values -> List(18.63729236572747, 1.3627076342725357))","Map(vectorType -> dense, length -> 2, values -> List(0.9318646182863732, 0.06813538171362678))",0.0,"List(0.9318646182863732, 0.06813538171362678)",0.0681353817136267
160722,46979,2741763,11,10,1,0.9090909090909092,34,1,12.454545454545457,10.1,22.0,1.0,2.4545454545454546,11.727272727272728,0.0121546961325966,0.0168067226890756,4,905,310,38,23.81578947368421,2.919354838709677,0.3425414364640884,595,310,0.6574585635359116,8.077966101694916,26.0,1.0,2.889502762430939,12.103867403314917,38,5.0,41585,25698,0.618061025816328,25698,67283,0.3819389741836719,10.788400680629104,0.0,30.0,159213,3418021,0.0465804627882625,0.3353585113954094,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 34.0, 1.0, 12.454545454545455, 10.1, 22.0, 1.0, 2.4545454545454546, 11.727272727272727, 0.012154696132596685, 0.01680672268907563, 4.0, 905.0, 310.0, 38.0, 23.81578947368421, 2.9193548387096775, 0.3425414364640884, 595.0, 310.0, 0.6574585635359116, 8.077966101694916, 26.0, 1.0, 2.889502762430939, 12.103867403314917, 38.0, 5.0, 41585.0, 25698.0, 0.618061025816328, 25698.0, 67283.0, 0.38193897418367195, 10.788400680629104, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3353585113954094))","Map(vectorType -> dense, length -> 2, values -> List(18.10412682180586, 1.8958731781941405))","Map(vectorType -> dense, length -> 2, values -> List(0.905206341090293, 0.09479365890970702))",0.0,"List(0.905206341090293, 0.09479365890970702)",0.094793658909707
137485,21903,1392405,5,4,1,0.8,16,6,4.0,17.2,30.0,2.0,2.6,20.8,0.0427350427350427,0.0909090909090909,0,117,73,16,7.3125,1.6027397260273972,0.6239316239316239,44,73,0.376068376068376,16.570093457943926,30.0,0.0,3.1880341880341883,20.05982905982906,16,20.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 16.0, 6.0, 4.0, 17.2, 30.0, 2.0, 2.6, 20.8, 0.042735042735042736, 0.09090909090909091, 0.0, 117.0, 73.0, 16.0, 7.3125, 1.6027397260273972, 0.6239316239316239, 44.0, 73.0, 0.37606837606837606, 16.570093457943926, 30.0, 0.0, 3.1880341880341883, 20.05982905982906, 16.0, 20.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(12.628242273317392, 7.371757726682604))","Map(vectorType -> dense, length -> 2, values -> List(0.6314121136658697, 0.3685878863341303))",0.0,"List(0.6314121136658697, 0.3685878863341303)",0.3685878863341303
100368,1529,1583293,2,1,1,0.5,8,6,14.0,13.5,21.0,6.0,0.0,16.0,0.0082987551867219,0.0064935064935064,4,241,87,12,20.08333333333333,2.7701149425287355,0.3609958506224066,154,87,0.6390041493775933,16.5990990990991,30.0,6.0,1.95850622406639,12.755186721991702,12,30.0,3838,4898,0.4393315018315018,4898,8736,0.5606684981684982,11.348765049025692,0.0,30.0,76738,377741,0.2031497772283125,0.3575187209401856,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 8.0, 6.0, 14.0, 13.5, 21.0, 6.0, 0.0, 16.0, 0.008298755186721992, 0.006493506493506494, 4.0, 241.0, 87.0, 12.0, 20.083333333333332, 2.7701149425287355, 0.36099585062240663, 154.0, 87.0, 0.6390041493775933, 16.5990990990991, 30.0, 6.0, 1.95850622406639, 12.755186721991702, 12.0, 30.0, 3838.0, 4898.0, 0.4393315018315018, 4898.0, 8736.0, 0.5606684981684982, 11.348765049025692, 0.0, 30.0, 76738.0, 377741.0, 0.2031497772283125, 0.35751872094018566))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0,"List(0.9343075997121274, 0.06569240028787272)",0.0656924002878727
177418,35166,1786000,2,1,1,0.5,70,28,7.0,3.0,4.0,2.0,2.5,10.0,0.0022935779816513,0.0015723270440251,24,872,236,94,9.27659574468085,3.694915254237288,0.2706422018348624,636,236,0.7293577981651376,3.1795166858458,21.0,0.0,2.311926605504587,10.97362385321101,94,5.0,489,1169,0.2949336550060313,1169,1658,0.7050663449939686,10.973868706182282,0.0,30.0,76476,297037,0.2574628749953709,0.4476034699985976,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 70.0, 28.0, 7.0, 3.0, 4.0, 2.0, 2.5, 10.0, 0.0022935779816513763, 0.0015723270440251573, 24.0, 872.0, 236.0, 94.0, 9.27659574468085, 3.694915254237288, 0.2706422018348624, 636.0, 236.0, 0.7293577981651376, 3.1795166858457997, 21.0, 0.0, 2.311926605504587, 10.97362385321101, 94.0, 5.0, 489.0, 1169.0, 0.29493365500603136, 1169.0, 1658.0, 0.7050663449939686, 10.973868706182282, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.4476034699985976))","Map(vectorType -> dense, length -> 2, values -> List(18.850900731484263, 1.1490992685157388))","Map(vectorType -> dense, length -> 2, values -> List(0.9425450365742132, 0.057454963425786945))",0.0,"List(0.9425450365742132, 0.057454963425786945)",0.0574549634257869
115553,41793,2548573,5,4,1,0.8,96,17,6.8,4.4,15.0,1.0,2.4,9.0,0.0046816479400749,0.0047449584816132,2,1068,225,98,10.89795918367347,4.746666666666667,0.2106741573033707,843,225,0.7893258426966292,3.885230479774224,30.0,0.0,3.093632958801498,11.82865168539326,98,3.0,2333,869,0.7286071205496565,869,3202,0.2713928794503435,9.980026195153895,0.0,30.0,54202,234065,0.231568154145216,0.0398247253051274,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 96.0, 17.0, 6.8, 4.4, 15.0, 1.0, 2.4, 9.0, 0.0046816479400749065, 0.004744958481613286, 2.0, 1068.0, 225.0, 98.0, 10.89795918367347, 4.746666666666667, 0.21067415730337077, 843.0, 225.0, 0.7893258426966292, 3.885230479774224, 30.0, 0.0, 3.093632958801498, 11.82865168539326, 98.0, 3.0, 2333.0, 869.0, 0.7286071205496565, 869.0, 3202.0, 0.2713928794503435, 9.980026195153897, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.039824725305127456))","Map(vectorType -> dense, length -> 2, values -> List(18.136398020554342, 1.863601979445663))","Map(vectorType -> dense, length -> 2, values -> List(0.9068199010277169, 0.09318009897228313))",0.0,"List(0.9068199010277169, 0.09318009897228313)",0.0931800989722831
105400,45007,2176640,2,1,1,0.5,12,9,6.0,19.5,30.0,9.0,2.5,15.5,0.011049723756906,0.0075187969924812,6,181,48,18,10.055555555555555,3.7708333333333335,0.2651933701657458,133,48,0.7348066298342542,13.538922155688622,30.0,5.0,2.0939226519337018,16.03867403314917,18,12.0,72165,32658,0.6884462379439626,32658,104823,0.3115537620560373,10.597150051962586,0.0,30.0,159213,3418021,0.0465804627882625,0.2649732992677747,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 12.0, 9.0, 6.0, 19.5, 30.0, 9.0, 2.5, 15.5, 0.011049723756906077, 0.007518796992481203, 6.0, 181.0, 48.0, 18.0, 10.055555555555555, 3.7708333333333335, 0.26519337016574585, 133.0, 48.0, 0.7348066298342542, 13.538922155688622, 30.0, 5.0, 2.0939226519337018, 16.03867403314917, 18.0, 12.0, 72165.0, 32658.0, 0.6884462379439626, 32658.0, 104823.0, 0.3115537620560373, 10.597150051962586, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.26497329926777474))","Map(vectorType -> dense, length -> 2, values -> List(18.596309099092576, 1.403690900907425))","Map(vectorType -> dense, length -> 2, values -> List(0.9298154549546288, 0.07018454504537125))",0.0,"List(0.9298154549546288, 0.07018454504537125)",0.0701845450453712
93880,49517,247785,28,27,1,0.9642857142857144,33,1,5.25,10.22222222222222,28.0,1.0,3.107142857142857,10.785714285714286,0.0590717299578059,0.0694087403598971,0,474,85,33,14.363636363636363,5.576470588235294,0.1793248945147679,389,85,0.820675105485232,10.951002227171491,28.0,1.0,2.8375527426160336,11.232067510548523,33,26.0,2706,1115,0.7081915728866789,1115,3821,0.2918084271133211,11.783276450511943,0.0,30.0,109226,891015,0.1225860395167309,0.1692223875965902,"Map(vectorType -> dense, length -> 44, values -> List(28.0, 27.0, 1.0, 0.9642857142857143, 33.0, 1.0, 5.25, 10.222222222222221, 28.0, 1.0, 3.107142857142857, 10.785714285714286, 0.05907172995780591, 0.06940874035989718, 0.0, 474.0, 85.0, 33.0, 14.363636363636363, 5.576470588235294, 0.17932489451476794, 389.0, 85.0, 0.820675105485232, 10.951002227171493, 28.0, 1.0, 2.8375527426160336, 11.232067510548523, 33.0, 26.0, 2706.0, 1115.0, 0.7081915728866789, 1115.0, 3821.0, 0.2918084271133211, 11.783276450511945, 0.0, 30.0, 109226.0, 891015.0, 0.12258603951673092, 0.1692223875965902))","Map(vectorType -> dense, length -> 2, values -> List(10.195436432543355, 9.804563567456643))","Map(vectorType -> dense, length -> 2, values -> List(0.5097718216271677, 0.49022817837283217))",0.0,"List(0.5097718216271677, 0.49022817837283217)",0.4902281783728321
63685,39190,1263744,1,0,1,0.0,4,4,14.0,8.0,8.0,8.0,1.0,23.0,0.0054347826086956,0.0,8,184,130,12,15.333333333333334,1.4153846153846157,0.7065217391304348,54,130,0.2934782608695652,18.22360248447205,30.0,2.0,2.717391304347826,13.36413043478261,12,19.0,6294,4678,0.5736419978126139,4678,10972,0.426358002187386,12.504883101509323,0.0,30.0,96326,638253,0.1509213431037535,0.2754366590836326,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 14.0, 8.0, 8.0, 8.0, 1.0, 23.0, 0.005434782608695652, 0.0, 8.0, 184.0, 130.0, 12.0, 15.333333333333334, 1.4153846153846155, 0.7065217391304348, 54.0, 130.0, 0.29347826086956524, 18.22360248447205, 30.0, 2.0, 2.717391304347826, 13.36413043478261, 12.0, 19.0, 6294.0, 4678.0, 0.5736419978126139, 4678.0, 10972.0, 0.42635800218738606, 12.504883101509323, 0.0, 30.0, 96326.0, 638253.0, 0.15092134310375352, 0.2754366590836326))","Map(vectorType -> dense, length -> 2, values -> List(18.74200689205801, 1.2579931079419875))","Map(vectorType -> dense, length -> 2, values -> List(0.9371003446029006, 0.06289965539709938))",0.0,"List(0.9371003446029006, 0.06289965539709938)",0.0628996553970993


In [0]:
REORDER_THRESHOLD = 0.21
# 1_proba값이 REORDER_THRESHOLD보다 크면 1, 그렇지 않으면 0으로 reordered 컬럼 추가.
predictions = predictions.withColumn('reordered', (F.col('1_proba') > REORDER_THRESHOLD).cast('int')) 
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba,reordered
172302,19822,1243888,1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0,"List(0.9292008768034521, 0.07079912319654799)",0.0707991231965479,0
1623,43908,1778015,2,1,1,0.5,4,1,2.0,30.0,30.0,30.0,0.0,11.0,0.016260162601626,0.0126582278481012,16,123,44,20,6.15,2.7954545454545454,0.3577235772357723,79,44,0.6422764227642277,15.721739130434782,30.0,2.0,1.1788617886178865,12.471544715447154,20,6.0,2115,2357,0.4729427549194991,2357,4472,0.5270572450805009,11.802751166789486,0.0,30.0,58749,390299,0.150523060525392,0.3765341845551088,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 4.0, 1.0, 2.0, 30.0, 30.0, 30.0, 0.0, 11.0, 0.016260162601626018, 0.012658227848101266, 16.0, 123.0, 44.0, 20.0, 6.15, 2.7954545454545454, 0.35772357723577236, 79.0, 44.0, 0.6422764227642277, 15.721739130434782, 30.0, 2.0, 1.1788617886178863, 12.471544715447154, 20.0, 6.0, 2115.0, 2357.0, 0.4729427549194991, 2357.0, 4472.0, 0.5270572450805009, 11.802751166789486, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.3765341845551088))","Map(vectorType -> dense, length -> 2, values -> List(18.63729236572747, 1.3627076342725357))","Map(vectorType -> dense, length -> 2, values -> List(0.9318646182863732, 0.06813538171362678))",0.0,"List(0.9318646182863732, 0.06813538171362678)",0.0681353817136267,0
160722,46979,2741763,11,10,1,0.9090909090909092,34,1,12.454545454545457,10.1,22.0,1.0,2.4545454545454546,11.727272727272728,0.0121546961325966,0.0168067226890756,4,905,310,38,23.81578947368421,2.919354838709677,0.3425414364640884,595,310,0.6574585635359116,8.077966101694916,26.0,1.0,2.889502762430939,12.103867403314917,38,5.0,41585,25698,0.618061025816328,25698,67283,0.3819389741836719,10.788400680629104,0.0,30.0,159213,3418021,0.0465804627882625,0.3353585113954094,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 34.0, 1.0, 12.454545454545455, 10.1, 22.0, 1.0, 2.4545454545454546, 11.727272727272727, 0.012154696132596685, 0.01680672268907563, 4.0, 905.0, 310.0, 38.0, 23.81578947368421, 2.9193548387096775, 0.3425414364640884, 595.0, 310.0, 0.6574585635359116, 8.077966101694916, 26.0, 1.0, 2.889502762430939, 12.103867403314917, 38.0, 5.0, 41585.0, 25698.0, 0.618061025816328, 25698.0, 67283.0, 0.38193897418367195, 10.788400680629104, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3353585113954094))","Map(vectorType -> dense, length -> 2, values -> List(18.10412682180586, 1.8958731781941405))","Map(vectorType -> dense, length -> 2, values -> List(0.905206341090293, 0.09479365890970702))",0.0,"List(0.905206341090293, 0.09479365890970702)",0.094793658909707,0
137485,21903,1392405,5,4,1,0.8,16,6,4.0,17.2,30.0,2.0,2.6,20.8,0.0427350427350427,0.0909090909090909,0,117,73,16,7.3125,1.6027397260273972,0.6239316239316239,44,73,0.376068376068376,16.570093457943926,30.0,0.0,3.1880341880341883,20.05982905982906,16,20.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 16.0, 6.0, 4.0, 17.2, 30.0, 2.0, 2.6, 20.8, 0.042735042735042736, 0.09090909090909091, 0.0, 117.0, 73.0, 16.0, 7.3125, 1.6027397260273972, 0.6239316239316239, 44.0, 73.0, 0.37606837606837606, 16.570093457943926, 30.0, 0.0, 3.1880341880341883, 20.05982905982906, 16.0, 20.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(12.628242273317392, 7.371757726682604))","Map(vectorType -> dense, length -> 2, values -> List(0.6314121136658697, 0.3685878863341303))",0.0,"List(0.6314121136658697, 0.3685878863341303)",0.3685878863341303,1
100368,1529,1583293,2,1,1,0.5,8,6,14.0,13.5,21.0,6.0,0.0,16.0,0.0082987551867219,0.0064935064935064,4,241,87,12,20.08333333333333,2.7701149425287355,0.3609958506224066,154,87,0.6390041493775933,16.5990990990991,30.0,6.0,1.95850622406639,12.755186721991702,12,30.0,3838,4898,0.4393315018315018,4898,8736,0.5606684981684982,11.348765049025692,0.0,30.0,76738,377741,0.2031497772283125,0.3575187209401856,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 8.0, 6.0, 14.0, 13.5, 21.0, 6.0, 0.0, 16.0, 0.008298755186721992, 0.006493506493506494, 4.0, 241.0, 87.0, 12.0, 20.083333333333332, 2.7701149425287355, 0.36099585062240663, 154.0, 87.0, 0.6390041493775933, 16.5990990990991, 30.0, 6.0, 1.95850622406639, 12.755186721991702, 12.0, 30.0, 3838.0, 4898.0, 0.4393315018315018, 4898.0, 8736.0, 0.5606684981684982, 11.348765049025692, 0.0, 30.0, 76738.0, 377741.0, 0.2031497772283125, 0.35751872094018566))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0,"List(0.9343075997121274, 0.06569240028787272)",0.0656924002878727,0
177418,35166,1786000,2,1,1,0.5,70,28,7.0,3.0,4.0,2.0,2.5,10.0,0.0022935779816513,0.0015723270440251,24,872,236,94,9.27659574468085,3.694915254237288,0.2706422018348624,636,236,0.7293577981651376,3.1795166858458,21.0,0.0,2.311926605504587,10.97362385321101,94,5.0,489,1169,0.2949336550060313,1169,1658,0.7050663449939686,10.973868706182282,0.0,30.0,76476,297037,0.2574628749953709,0.4476034699985976,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 70.0, 28.0, 7.0, 3.0, 4.0, 2.0, 2.5, 10.0, 0.0022935779816513763, 0.0015723270440251573, 24.0, 872.0, 236.0, 94.0, 9.27659574468085, 3.694915254237288, 0.2706422018348624, 636.0, 236.0, 0.7293577981651376, 3.1795166858457997, 21.0, 0.0, 2.311926605504587, 10.97362385321101, 94.0, 5.0, 489.0, 1169.0, 0.29493365500603136, 1169.0, 1658.0, 0.7050663449939686, 10.973868706182282, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.4476034699985976))","Map(vectorType -> dense, length -> 2, values -> List(18.850900731484263, 1.1490992685157388))","Map(vectorType -> dense, length -> 2, values -> List(0.9425450365742132, 0.057454963425786945))",0.0,"List(0.9425450365742132, 0.057454963425786945)",0.0574549634257869,0
115553,41793,2548573,5,4,1,0.8,96,17,6.8,4.4,15.0,1.0,2.4,9.0,0.0046816479400749,0.0047449584816132,2,1068,225,98,10.89795918367347,4.746666666666667,0.2106741573033707,843,225,0.7893258426966292,3.885230479774224,30.0,0.0,3.093632958801498,11.82865168539326,98,3.0,2333,869,0.7286071205496565,869,3202,0.2713928794503435,9.980026195153895,0.0,30.0,54202,234065,0.231568154145216,0.0398247253051274,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 96.0, 17.0, 6.8, 4.4, 15.0, 1.0, 2.4, 9.0, 0.0046816479400749065, 0.004744958481613286, 2.0, 1068.0, 225.0, 98.0, 10.89795918367347, 4.746666666666667, 0.21067415730337077, 843.0, 225.0, 0.7893258426966292, 3.885230479774224, 30.0, 0.0, 3.093632958801498, 11.82865168539326, 98.0, 3.0, 2333.0, 869.0, 0.7286071205496565, 869.0, 3202.0, 0.2713928794503435, 9.980026195153897, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.039824725305127456))","Map(vectorType -> dense, length -> 2, values -> List(18.136398020554342, 1.863601979445663))","Map(vectorType -> dense, length -> 2, values -> List(0.9068199010277169, 0.09318009897228313))",0.0,"List(0.9068199010277169, 0.09318009897228313)",0.0931800989722831,0
105400,45007,2176640,2,1,1,0.5,12,9,6.0,19.5,30.0,9.0,2.5,15.5,0.011049723756906,0.0075187969924812,6,181,48,18,10.055555555555555,3.7708333333333335,0.2651933701657458,133,48,0.7348066298342542,13.538922155688622,30.0,5.0,2.0939226519337018,16.03867403314917,18,12.0,72165,32658,0.6884462379439626,32658,104823,0.3115537620560373,10.597150051962586,0.0,30.0,159213,3418021,0.0465804627882625,0.2649732992677747,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 12.0, 9.0, 6.0, 19.5, 30.0, 9.0, 2.5, 15.5, 0.011049723756906077, 0.007518796992481203, 6.0, 181.0, 48.0, 18.0, 10.055555555555555, 3.7708333333333335, 0.26519337016574585, 133.0, 48.0, 0.7348066298342542, 13.538922155688622, 30.0, 5.0, 2.0939226519337018, 16.03867403314917, 18.0, 12.0, 72165.0, 32658.0, 0.6884462379439626, 32658.0, 104823.0, 0.3115537620560373, 10.597150051962586, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.26497329926777474))","Map(vectorType -> dense, length -> 2, values -> List(18.596309099092576, 1.403690900907425))","Map(vectorType -> dense, length -> 2, values -> List(0.9298154549546288, 0.07018454504537125))",0.0,"List(0.9298154549546288, 0.07018454504537125)",0.0701845450453712,0
93880,49517,247785,28,27,1,0.9642857142857144,33,1,5.25,10.22222222222222,28.0,1.0,3.107142857142857,10.785714285714286,0.0590717299578059,0.0694087403598971,0,474,85,33,14.363636363636363,5.576470588235294,0.1793248945147679,389,85,0.820675105485232,10.951002227171491,28.0,1.0,2.8375527426160336,11.232067510548523,33,26.0,2706,1115,0.7081915728866789,1115,3821,0.2918084271133211,11.783276450511943,0.0,30.0,109226,891015,0.1225860395167309,0.1692223875965902,"Map(vectorType -> dense, length -> 44, values -> List(28.0, 27.0, 1.0, 0.9642857142857143, 33.0, 1.0, 5.25, 10.222222222222221, 28.0, 1.0, 3.107142857142857, 10.785714285714286, 0.05907172995780591, 0.06940874035989718, 0.0, 474.0, 85.0, 33.0, 14.363636363636363, 5.576470588235294, 0.17932489451476794, 389.0, 85.0, 0.820675105485232, 10.951002227171493, 28.0, 1.0, 2.8375527426160336, 11.232067510548523, 33.0, 26.0, 2706.0, 1115.0, 0.7081915728866789, 1115.0, 3821.0, 0.2918084271133211, 11.783276450511945, 0.0, 30.0, 109226.0, 891015.0, 0.12258603951673092, 0.1692223875965902))","Map(vectorType -> dense, length -> 2, values -> List(10.195436432543355, 9.804563567456643))","Map(vectorType -> dense, length -> 2, values -> List(0.5097718216271677, 0.49022817837283217))",0.0,"List(0.5097718216271677, 0.49022817837283217)",0.4902281783728321,1
63685,39190,1263744,1,0,1,0.0,4,4,14.0,8.0,8.0,8.0,1.0,23.0,0.0054347826086956,0.0,8,184,130,12,15.333333333333334,1.4153846153846157,0.7065217391304348,54,130,0.2934782608695652,18.22360248447205,30.0,2.0,2.717391304347826,13.36413043478261,12,19.0,6294,4678,0.5736419978126139,4678,10972,0.426358002187386,12.504883101509323,0.0,30.0,96326,638253,0.1509213431037535,0.2754366590836326,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 14.0, 8.0, 8.0, 8.0, 1.0, 23.0, 0.005434782608695652, 0.0, 8.0, 184.0, 130.0, 12.0, 15.333333333333334, 1.4153846153846155, 0.7065217391304348, 54.0, 130.0, 0.29347826086956524, 18.22360248447205, 30.0, 2.0, 2.717391304347826, 13.36413043478261, 12.0, 19.0, 6294.0, 4678.0, 0.5736419978126139, 4678.0, 10972.0, 0.42635800218738606, 12.504883101509323, 0.0, 30.0, 96326.0, 638253.0, 0.15092134310375352, 0.2754366590836326))","Map(vectorType -> dense, length -> 2, values -> List(18.74200689205801, 1.2579931079419875))","Map(vectorType -> dense, length -> 2, values -> List(0.9371003446029006, 0.06289965539709938))",0.0,"List(0.9371003446029006, 0.06289965539709938)",0.0628996553970993,0


In [0]:
# reordered가 1인 데이터만 추출. 
prediction_reordered = predictions.filter('reordered == 1')

In [0]:
#sample_submission.csv에 있는 데이터 로드하여, orders.csv에 있는 eval_set이 test이 데이터와 건수 비교. 
submission_sdf = spark.read.csv('/FileStore/tables/sample_submission.csv', header=True, inferSchema=True)
print(submission_sdf.count(), orders_sdf.filter("eval_set == 'test'").count())
display(submission_sdf)

75000 75000


order_id,products
17,39276 29259
34,39276 29259
137,39276 29259
182,39276 29259
257,39276 29259
313,39276 29259
353,39276 29259
386,39276 29259
414,39276 29259
418,39276 29259


In [0]:
test_orders_sdf = orders_sdf.filter("eval_set == 'test'")
display(test_orders_sdf.orderBy('order_id'))

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
17,36855,test,5,6,15,1.0
34,35220,test,20,3,11,8.0
137,187107,test,9,2,19,30.0
182,115892,test,28,0,11,8.0
257,35581,test,9,6,23,5.0
313,113359,test,31,6,22,7.0
353,173814,test,4,4,13,30.0
386,55492,test,8,0,15,30.0
414,120775,test,18,5,14,8.0
418,33565,test,12,0,12,14.0


In [0]:
submission_sdf.createOrReplaceTempView('submission')

In [0]:
%sql
-- test_data에서 submission에 없는 order가 있는지 확인
select count(*)
from test_data a
left outer join submission b
on a.order_id = b.order_id 
where b.order_id is null 

count(1)
0


In [0]:
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba,reordered
172302,19822,1243888,1,0,1,0.0,3,3,8.0,4.0,4.0,4.0,2.0,12.0,0.0232558139534883,0.0,2,43,37,5,8.6,1.162162162162162,0.8604651162790697,6,37,0.1395348837209302,5.269230769230769,8.0,3.0,2.441860465116279,12.30232558139535,5,17.0,15,19,0.4411764705882353,19,34,0.5588235294117647,10.303030303030305,0.0,30.0,5367,11798,0.4549076114595694,0.1039159179521952,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 8.0, 4.0, 4.0, 4.0, 2.0, 12.0, 0.023255813953488372, 0.0, 2.0, 43.0, 37.0, 5.0, 8.6, 1.162162162162162, 0.8604651162790697, 6.0, 37.0, 0.13953488372093023, 5.269230769230769, 8.0, 3.0, 2.441860465116279, 12.30232558139535, 5.0, 17.0, 15.0, 19.0, 0.4411764705882353, 19.0, 34.0, 0.5588235294117647, 10.303030303030303, 0.0, 30.0, 5367.0, 11798.0, 0.4549076114595694, 0.10391591795219529))","Map(vectorType -> dense, length -> 2, values -> List(18.58401753606904, 1.41598246393096))","Map(vectorType -> dense, length -> 2, values -> List(0.9292008768034521, 0.07079912319654799))",0.0,"List(0.9292008768034521, 0.07079912319654799)",0.0707991231965479,0
1623,43908,1778015,2,1,1,0.5,4,1,2.0,30.0,30.0,30.0,0.0,11.0,0.016260162601626,0.0126582278481012,16,123,44,20,6.15,2.7954545454545454,0.3577235772357723,79,44,0.6422764227642277,15.721739130434782,30.0,2.0,1.1788617886178865,12.471544715447154,20,6.0,2115,2357,0.4729427549194991,2357,4472,0.5270572450805009,11.802751166789486,0.0,30.0,58749,390299,0.150523060525392,0.3765341845551088,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 4.0, 1.0, 2.0, 30.0, 30.0, 30.0, 0.0, 11.0, 0.016260162601626018, 0.012658227848101266, 16.0, 123.0, 44.0, 20.0, 6.15, 2.7954545454545454, 0.35772357723577236, 79.0, 44.0, 0.6422764227642277, 15.721739130434782, 30.0, 2.0, 1.1788617886178863, 12.471544715447154, 20.0, 6.0, 2115.0, 2357.0, 0.4729427549194991, 2357.0, 4472.0, 0.5270572450805009, 11.802751166789486, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.3765341845551088))","Map(vectorType -> dense, length -> 2, values -> List(18.63729236572747, 1.3627076342725357))","Map(vectorType -> dense, length -> 2, values -> List(0.9318646182863732, 0.06813538171362678))",0.0,"List(0.9318646182863732, 0.06813538171362678)",0.0681353817136267,0
160722,46979,2741763,11,10,1,0.9090909090909092,34,1,12.454545454545457,10.1,22.0,1.0,2.4545454545454546,11.727272727272728,0.0121546961325966,0.0168067226890756,4,905,310,38,23.81578947368421,2.919354838709677,0.3425414364640884,595,310,0.6574585635359116,8.077966101694916,26.0,1.0,2.889502762430939,12.103867403314917,38,5.0,41585,25698,0.618061025816328,25698,67283,0.3819389741836719,10.788400680629104,0.0,30.0,159213,3418021,0.0465804627882625,0.3353585113954094,"Map(vectorType -> dense, length -> 44, values -> List(11.0, 10.0, 1.0, 0.9090909090909091, 34.0, 1.0, 12.454545454545455, 10.1, 22.0, 1.0, 2.4545454545454546, 11.727272727272727, 0.012154696132596685, 0.01680672268907563, 4.0, 905.0, 310.0, 38.0, 23.81578947368421, 2.9193548387096775, 0.3425414364640884, 595.0, 310.0, 0.6574585635359116, 8.077966101694916, 26.0, 1.0, 2.889502762430939, 12.103867403314917, 38.0, 5.0, 41585.0, 25698.0, 0.618061025816328, 25698.0, 67283.0, 0.38193897418367195, 10.788400680629104, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3353585113954094))","Map(vectorType -> dense, length -> 2, values -> List(18.10412682180586, 1.8958731781941405))","Map(vectorType -> dense, length -> 2, values -> List(0.905206341090293, 0.09479365890970702))",0.0,"List(0.905206341090293, 0.09479365890970702)",0.094793658909707,0
137485,21903,1392405,5,4,1,0.8,16,6,4.0,17.2,30.0,2.0,2.6,20.8,0.0427350427350427,0.0909090909090909,0,117,73,16,7.3125,1.6027397260273972,0.6239316239316239,44,73,0.376068376068376,16.570093457943926,30.0,0.0,3.1880341880341883,20.05982905982906,16,20.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 16.0, 6.0, 4.0, 17.2, 30.0, 2.0, 2.6, 20.8, 0.042735042735042736, 0.09090909090909091, 0.0, 117.0, 73.0, 16.0, 7.3125, 1.6027397260273972, 0.6239316239316239, 44.0, 73.0, 0.37606837606837606, 16.570093457943926, 30.0, 0.0, 3.1880341880341883, 20.05982905982906, 16.0, 20.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(12.628242273317392, 7.371757726682604))","Map(vectorType -> dense, length -> 2, values -> List(0.6314121136658697, 0.3685878863341303))",0.0,"List(0.6314121136658697, 0.3685878863341303)",0.3685878863341303,1
100368,1529,1583293,2,1,1,0.5,8,6,14.0,13.5,21.0,6.0,0.0,16.0,0.0082987551867219,0.0064935064935064,4,241,87,12,20.08333333333333,2.7701149425287355,0.3609958506224066,154,87,0.6390041493775933,16.5990990990991,30.0,6.0,1.95850622406639,12.755186721991702,12,30.0,3838,4898,0.4393315018315018,4898,8736,0.5606684981684982,11.348765049025692,0.0,30.0,76738,377741,0.2031497772283125,0.3575187209401856,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 8.0, 6.0, 14.0, 13.5, 21.0, 6.0, 0.0, 16.0, 0.008298755186721992, 0.006493506493506494, 4.0, 241.0, 87.0, 12.0, 20.083333333333332, 2.7701149425287355, 0.36099585062240663, 154.0, 87.0, 0.6390041493775933, 16.5990990990991, 30.0, 6.0, 1.95850622406639, 12.755186721991702, 12.0, 30.0, 3838.0, 4898.0, 0.4393315018315018, 4898.0, 8736.0, 0.5606684981684982, 11.348765049025692, 0.0, 30.0, 76738.0, 377741.0, 0.2031497772283125, 0.35751872094018566))","Map(vectorType -> dense, length -> 2, values -> List(18.68615199424255, 1.3138480057574546))","Map(vectorType -> dense, length -> 2, values -> List(0.9343075997121274, 0.06569240028787272))",0.0,"List(0.9343075997121274, 0.06569240028787272)",0.0656924002878727,0
177418,35166,1786000,2,1,1,0.5,70,28,7.0,3.0,4.0,2.0,2.5,10.0,0.0022935779816513,0.0015723270440251,24,872,236,94,9.27659574468085,3.694915254237288,0.2706422018348624,636,236,0.7293577981651376,3.1795166858458,21.0,0.0,2.311926605504587,10.97362385321101,94,5.0,489,1169,0.2949336550060313,1169,1658,0.7050663449939686,10.973868706182282,0.0,30.0,76476,297037,0.2574628749953709,0.4476034699985976,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 70.0, 28.0, 7.0, 3.0, 4.0, 2.0, 2.5, 10.0, 0.0022935779816513763, 0.0015723270440251573, 24.0, 872.0, 236.0, 94.0, 9.27659574468085, 3.694915254237288, 0.2706422018348624, 636.0, 236.0, 0.7293577981651376, 3.1795166858457997, 21.0, 0.0, 2.311926605504587, 10.97362385321101, 94.0, 5.0, 489.0, 1169.0, 0.29493365500603136, 1169.0, 1658.0, 0.7050663449939686, 10.973868706182282, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.4476034699985976))","Map(vectorType -> dense, length -> 2, values -> List(18.850900731484263, 1.1490992685157388))","Map(vectorType -> dense, length -> 2, values -> List(0.9425450365742132, 0.057454963425786945))",0.0,"List(0.9425450365742132, 0.057454963425786945)",0.0574549634257869,0
115553,41793,2548573,5,4,1,0.8,96,17,6.8,4.4,15.0,1.0,2.4,9.0,0.0046816479400749,0.0047449584816132,2,1068,225,98,10.89795918367347,4.746666666666667,0.2106741573033707,843,225,0.7893258426966292,3.885230479774224,30.0,0.0,3.093632958801498,11.82865168539326,98,3.0,2333,869,0.7286071205496565,869,3202,0.2713928794503435,9.980026195153895,0.0,30.0,54202,234065,0.231568154145216,0.0398247253051274,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 96.0, 17.0, 6.8, 4.4, 15.0, 1.0, 2.4, 9.0, 0.0046816479400749065, 0.004744958481613286, 2.0, 1068.0, 225.0, 98.0, 10.89795918367347, 4.746666666666667, 0.21067415730337077, 843.0, 225.0, 0.7893258426966292, 3.885230479774224, 30.0, 0.0, 3.093632958801498, 11.82865168539326, 98.0, 3.0, 2333.0, 869.0, 0.7286071205496565, 869.0, 3202.0, 0.2713928794503435, 9.980026195153897, 0.0, 30.0, 54202.0, 234065.0, 0.23156815414521606, 0.039824725305127456))","Map(vectorType -> dense, length -> 2, values -> List(18.136398020554342, 1.863601979445663))","Map(vectorType -> dense, length -> 2, values -> List(0.9068199010277169, 0.09318009897228313))",0.0,"List(0.9068199010277169, 0.09318009897228313)",0.0931800989722831,0
105400,45007,2176640,2,1,1,0.5,12,9,6.0,19.5,30.0,9.0,2.5,15.5,0.011049723756906,0.0075187969924812,6,181,48,18,10.055555555555555,3.7708333333333335,0.2651933701657458,133,48,0.7348066298342542,13.538922155688622,30.0,5.0,2.0939226519337018,16.03867403314917,18,12.0,72165,32658,0.6884462379439626,32658,104823,0.3115537620560373,10.597150051962586,0.0,30.0,159213,3418021,0.0465804627882625,0.2649732992677747,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 12.0, 9.0, 6.0, 19.5, 30.0, 9.0, 2.5, 15.5, 0.011049723756906077, 0.007518796992481203, 6.0, 181.0, 48.0, 18.0, 10.055555555555555, 3.7708333333333335, 0.26519337016574585, 133.0, 48.0, 0.7348066298342542, 13.538922155688622, 30.0, 5.0, 2.0939226519337018, 16.03867403314917, 18.0, 12.0, 72165.0, 32658.0, 0.6884462379439626, 32658.0, 104823.0, 0.3115537620560373, 10.597150051962586, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.26497329926777474))","Map(vectorType -> dense, length -> 2, values -> List(18.596309099092576, 1.403690900907425))","Map(vectorType -> dense, length -> 2, values -> List(0.9298154549546288, 0.07018454504537125))",0.0,"List(0.9298154549546288, 0.07018454504537125)",0.0701845450453712,0
93880,49517,247785,28,27,1,0.9642857142857144,33,1,5.25,10.22222222222222,28.0,1.0,3.107142857142857,10.785714285714286,0.0590717299578059,0.0694087403598971,0,474,85,33,14.363636363636363,5.576470588235294,0.1793248945147679,389,85,0.820675105485232,10.951002227171491,28.0,1.0,2.8375527426160336,11.232067510548523,33,26.0,2706,1115,0.7081915728866789,1115,3821,0.2918084271133211,11.783276450511943,0.0,30.0,109226,891015,0.1225860395167309,0.1692223875965902,"Map(vectorType -> dense, length -> 44, values -> List(28.0, 27.0, 1.0, 0.9642857142857143, 33.0, 1.0, 5.25, 10.222222222222221, 28.0, 1.0, 3.107142857142857, 10.785714285714286, 0.05907172995780591, 0.06940874035989718, 0.0, 474.0, 85.0, 33.0, 14.363636363636363, 5.576470588235294, 0.17932489451476794, 389.0, 85.0, 0.820675105485232, 10.951002227171493, 28.0, 1.0, 2.8375527426160336, 11.232067510548523, 33.0, 26.0, 2706.0, 1115.0, 0.7081915728866789, 1115.0, 3821.0, 0.2918084271133211, 11.783276450511945, 0.0, 30.0, 109226.0, 891015.0, 0.12258603951673092, 0.1692223875965902))","Map(vectorType -> dense, length -> 2, values -> List(10.195436432543355, 9.804563567456643))","Map(vectorType -> dense, length -> 2, values -> List(0.5097718216271677, 0.49022817837283217))",0.0,"List(0.5097718216271677, 0.49022817837283217)",0.4902281783728321,1
63685,39190,1263744,1,0,1,0.0,4,4,14.0,8.0,8.0,8.0,1.0,23.0,0.0054347826086956,0.0,8,184,130,12,15.333333333333334,1.4153846153846157,0.7065217391304348,54,130,0.2934782608695652,18.22360248447205,30.0,2.0,2.717391304347826,13.36413043478261,12,19.0,6294,4678,0.5736419978126139,4678,10972,0.426358002187386,12.504883101509323,0.0,30.0,96326,638253,0.1509213431037535,0.2754366590836326,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 14.0, 8.0, 8.0, 8.0, 1.0, 23.0, 0.005434782608695652, 0.0, 8.0, 184.0, 130.0, 12.0, 15.333333333333334, 1.4153846153846155, 0.7065217391304348, 54.0, 130.0, 0.29347826086956524, 18.22360248447205, 30.0, 2.0, 2.717391304347826, 13.36413043478261, 12.0, 19.0, 6294.0, 4678.0, 0.5736419978126139, 4678.0, 10972.0, 0.42635800218738606, 12.504883101509323, 0.0, 30.0, 96326.0, 638253.0, 0.15092134310375352, 0.2754366590836326))","Map(vectorType -> dense, length -> 2, values -> List(18.74200689205801, 1.2579931079419875))","Map(vectorType -> dense, length -> 2, values -> List(0.9371003446029006, 0.06289965539709938))",0.0,"List(0.9371003446029006, 0.06289965539709938)",0.0628996553970993,0


In [0]:
# predictions 결과를 order_id로 group by하여 개별 order별 예측 상품건수와 재 주문 상품 건수 계산. 
predictions_grp = predictions.groupby('order_id').agg(F.count('*').alias('total_cnt_by_order_id'), 
                                                      F.sum(F.col('reordered')).alias('reordered_cnt'))
print(predictions_grp.count(), predictions_grp.filter('reordered_cnt == 0').count())
display(predictions_grp.filter('reordered_cnt == 0').orderBy('order_id'))

75000 7464


order_id,total_cnt_by_order_id,reordered_cnt
353,12,0
474,21,0
513,16,0
1195,16,0
1564,20,0
1789,17,0
2297,10,0
3519,15,0
4848,27,0
5216,23,0


In [0]:
# collect_list()함수 결과 보기. 
import pyspark.sql.functions as F

display(predictions.filter('reordered == 1').groupBy('order_id').agg(F.collect_list('product_id')).limit(10))

order_id,collect_list(product_id)
34,"List(47766, 43504, 39180, 21137, 47792, 2596, 16083, 47029, 39475)"
137,"List(2326, 24852, 38689, 5134, 41787, 23794, 25890)"
386,"List(21479, 42265, 28985, 38281, 40759, 15872, 22124, 45066, 4920, 39180, 37935, 47766, 24852, 30450)"
497,"List(31964, 36316, 17122, 1831, 39947, 27275)"
604,"List(28745, 24852, 37511, 16797, 12099)"
758,List(19660)
1304,"List(24852, 22035)"
1802,"List(34969, 43295, 20114, 3896, 38837, 21137, 13176, 4920, 38313, 47209, 21709)"
2247,"List(19125, 18234, 13176, 49235)"
2721,"List(21137, 33129)"


In [0]:
# collect_list('product_id')로 입력되는 product_id list값을 ' '으로 결합된 문자열로 변환하는 함수 생성. 
def get_product_ids_str(product_id_group):
    #product_id_group은 collect_list('product_id')로 group by된 집합으로 product_id를 list로 가지고 있는 형태로 입력 됨. 
    product_ids_str = ''
    for product_id in product_id_group:
        product_ids_str += ' ' + str(product_id)
    
    return product_ids_str

In [0]:
from pyspark.sql.functions import udf,col
from pyspark.sql.types import StringType

# 일반 python용 UDF를 pyspark용 UDF로 변환. udf(lambda 입력변수: 일반 UDF, 해당 일반 UDF의 반환형)
udf_get_product_ids_str = udf(lambda x:get_product_ids_str(x), StringType() )

In [0]:
submission_01 = predictions.filter('reordered == 1').groupBy('order_id').agg(udf_get_product_ids_str(F.collect_list('product_id')).alias('products'))
display(submission_01.limit(10))

order_id,products
34,47766 43504 39180 21137 47792 2596 16083 47029 39475
137,2326 24852 38689 5134 41787 23794 25890
386,21479 42265 28985 38281 40759 15872 22124 45066 4920 39180 37935 47766 24852 30450
497,31964 36316 17122 1831 39947 27275
604,28745 24852 37511 16797 12099
758,19660
1304,24852 22035
1802,34969 43295 20114 3896 38837 21137 13176 4920 38313 47209 21709
2247,19125 18234 13176 49235
2721,21137 33129


In [0]:
submission_02 = predictions_grp.filter('reordered_cnt == 0').withColumn('products', F.lit('None')).select('order_id', 'products')
display(submission_02.limit(10))


order_id,products
2446723,
1552436,
2726972,
488000,
2322361,
2322403,
3319167,
567184,
1107801,
743061,


In [0]:
submission = submission_01.union(submission_02)
print('submission count:', submission.count())
submission = submission.orderBy('order_id')

display(submission.limit(76000))

submission count: 75000


order_id,products
17,13107
34,47766 43504 39180 21137 47792 2596 16083 47029 39475
137,2326 24852 38689 5134 41787 23794 25890
182,33000 9337 13629 39275 47672 5479
257,24852 30233 45013 4605 13870 1025 27104 49235 29837 27966
313,21903 13198 46906 12779 28535 45007
353,
386,21479 42265 28985 38281 40759 15872 22124 45066 4920 39180 37935 47766 24852 30450
414,21709 20564 31730 20392 21230 33320 27845
418,30489 47766 41950 40268 38694


In [0]:
display(submission)

order_id,products
17,13107
34,47766 43504 39180 21137 47792 2596 16083 47029 39475
137,2326 24852 38689 5134 41787 23794 25890
182,33000 9337 13629 39275 47672 5479
257,24852 30233 45013 4605 13870 1025 27104 49235 29837 27966
313,21903 13198 46906 12779 28535 45007
353,
386,21479 42265 28985 38281 40759 15872 22124 45066 4920 39180 37935 47766 24852 30450
414,21709 20564 31730 20392 21230 33320 27845
418,30489 47766 41950 40268 38694


In [0]:
kaggle competitions submit -c instacart-market-basket-analysis -f submission.csv -m "Message"