# 0. 設定 Spark

In [1]:
spark

In [2]:
import pandas as pd
import pyspark.pandas as ps
import numpy as np
import gc
import pyspark.sql.functions as sf
from pyspark.context import SparkContext
from pyspark.sql.functions import array_max
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.pandas.config import set_option, reset_option
from dateutil.relativedelta import relativedelta

In [3]:
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)
ps.set_option("compute.default_index_type", "distributed")
set_option("compute.ops_on_diff_frames", True)

In [4]:
spark.conf.get("spark.kryoserializer.buffer.max")

'512m'

# 1. 讀取檔案 => tran_ps

In [13]:
# 讀取檔案
# customers = ps.read_parquet('/user/HM_parquet/customers.parquet')
# articles = ps.read_parquet('/user/HM_parquet/articles.parquet')
tran_ps = ps.read_parquet('/user/HM_parquet/transactions_train.parquet').drop(['price', 'sales_channel_id'], axis=1)

In [14]:
tran_ps.set_index('t_dat',inplace=True)
tran_ps['start_test'] = ''
tran_ps['split_id'] = ''
tran_ps.head(5)

                                                                                

Unnamed: 0_level_0,customer_id,article_id,start_test,split_id
t_dat,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2018-09-20,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,663713001,,
2018-09-20,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,541518023,,
2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,505221004,,
2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687003,,
2018-09-20,00007d2de826758b65a93dd24ce629ed66842531df6699...,685687004,,


# 2. 作時間分割 tran_ps => split_data, split_id

In [15]:
def split(data,train_period=30, test_period=7, stride=30,show_progress=False):
    
    split_data = ps.DataFrame(columns = ['t_dat','customer_id', 'article_id', 'split_id', 'start_test']).set_index('t_dat',inplace=True)

    end_test = data.index.max()
    start_test = end_test - relativedelta(days=test_period)
    start_train = start_test - relativedelta(days=train_period)
    split_id=0

    while start_train >= data.index.min():

        df = data.loc[start_train:end_test]
        df['start_test']=start_test
        df['split_id'] = split_id
        split_data = ps.concat([split_data,df])

        if(show_progress):
            print("Split_id:",split_id,", Train period:",start_train,"-" , start_test, ", test period", start_test, "-", end_test)

        # update dates:
        end_test = end_test - relativedelta(days=stride)
        start_test = end_test - relativedelta(days=test_period)
        start_train = start_test - relativedelta(days=train_period)
        split_id += 1
    
    return split_data, split_id

In [16]:
split_data, split_id = split(tran_ps,30,7,30,True)

                                                                                

Split_id: 0 , Train period: 2020-08-16 - 2020-09-15 , test period 2020-09-15 - 2020-09-22


                                                                                

Split_id: 1 , Train period: 2020-07-17 - 2020-08-16 , test period 2020-08-16 - 2020-08-23


                                                                                

Split_id: 2 , Train period: 2020-06-17 - 2020-07-17 , test period 2020-07-17 - 2020-07-24


                                                                                

Split_id: 3 , Train period: 2020-05-18 - 2020-06-17 , test period 2020-06-17 - 2020-06-24


                                                                                

Split_id: 4 , Train period: 2020-04-18 - 2020-05-18 , test period 2020-05-18 - 2020-05-25


                                                                                

Split_id: 5 , Train period: 2020-03-19 - 2020-04-18 , test period 2020-04-18 - 2020-04-25


                                                                                

Split_id: 6 , Train period: 2020-02-18 - 2020-03-19 , test period 2020-03-19 - 2020-03-26


                                                                                

Split_id: 7 , Train period: 2020-01-19 - 2020-02-18 , test period 2020-02-18 - 2020-02-25


                                                                                

Split_id: 8 , Train period: 2019-12-20 - 2020-01-19 , test period 2020-01-19 - 2020-01-26


                                                                                

Split_id: 9 , Train period: 2019-11-20 - 2019-12-20 , test period 2019-12-20 - 2019-12-27


                                                                                

Split_id: 10 , Train period: 2019-10-21 - 2019-11-20 , test period 2019-11-20 - 2019-11-27


                                                                                

Split_id: 11 , Train period: 2019-09-21 - 2019-10-21 , test period 2019-10-21 - 2019-10-28


                                                                                

Split_id: 12 , Train period: 2019-08-22 - 2019-09-21 , test period 2019-09-21 - 2019-09-28


                                                                                

Split_id: 13 , Train period: 2019-07-23 - 2019-08-22 , test period 2019-08-22 - 2019-08-29


                                                                                

Split_id: 14 , Train period: 2019-06-23 - 2019-07-23 , test period 2019-07-23 - 2019-07-30


                                                                                

Split_id: 15 , Train period: 2019-05-24 - 2019-06-23 , test period 2019-06-23 - 2019-06-30


                                                                                

Split_id: 16 , Train period: 2019-04-24 - 2019-05-24 , test period 2019-05-24 - 2019-05-31


                                                                                

Split_id: 17 , Train period: 2019-03-25 - 2019-04-24 , test period 2019-04-24 - 2019-05-01


                                                                                

Split_id: 18 , Train period: 2019-02-23 - 2019-03-25 , test period 2019-03-25 - 2019-04-01


                                                                                

Split_id: 19 , Train period: 2019-01-24 - 2019-02-23 , test period 2019-02-23 - 2019-03-02


                                                                                

Split_id: 20 , Train period: 2018-12-25 - 2019-01-24 , test period 2019-01-24 - 2019-01-31


                                                                                

Split_id: 21 , Train period: 2018-11-25 - 2018-12-25 , test period 2018-12-25 - 2019-01-01


                                                                                

Split_id: 22 , Train period: 2018-10-26 - 2018-11-25 , test period 2018-11-25 - 2018-12-02


                                                                                

Split_id: 23 , Train period: 2018-09-26 - 2018-10-26 , test period 2018-10-26 - 2018-11-02


                                                                                

In [17]:
split_data.reset_index(inplace=True)

In [18]:
split_data.count()

                                                                                

start_test     39919883
t_dat          39919883
customer_id    39919883
article_id     39919883
split_id       39919883
dtype: int64

# 3. 製作參數表 split_id => para_cross_split

In [None]:
# 製作參數表 paras_grid
from itertools import product

paras = list(
    product(
        [25,50,100,150,200],
        [20,30,40,50],
        [0.01]
    )
)
paras_grid = ps.DataFrame(paras,columns= ['n_factors','n_epochs','reg_all'])
paras_grid.count()

In [None]:
# 製作 split_id 表
split_id_ps = ps.DataFrame({'split_id': range(split_id)})
split_id_ps.count()

In [None]:
# 將 paras_grid 與 split_id 做 cross join
paras_grid['key'] = 1
split_id_ps['key'] = 1

para_cross_split = ps.merge(paras_grid, split_id_ps, on ='key').drop('key')
del split_id_ps
len(para_cross_split)

In [None]:
# 將 cross join 後的表新增遞增的 group_id 欄位，之後要用來做 pandas_udf 的 groupby
para_cross_split['group_id'] = 0
para_cross_split['group_id'] = np.arange(len(para_cross_split)).tolist()
para_cross_split

In [None]:
# para_cross_split.to_parquet('/user/HM_parquet/SVD_model/para_cross_split.parquet')

# 4. join 參數表和資料表 split_data, para_cross_split => join_data

In [None]:
join_data = split_data.join(para_cross_split.set_index('split_id'), on='split_id')
join_data.set_index('t_dat',inplace=True)

In [None]:
join_data.head(5)

In [None]:
39919883*20

In [None]:
join_data.count()

In [None]:
# join_data.to_parquet('/user/HM_parquet/SVD_model/join_data30.parquet')

In [None]:
# 刪除用不到的資料表
del split_data, para_cross_split
gc.collect()

# -----------------------------

# 讀取資料表

In [6]:
join_data = ps.read_parquet('/user/HM_parquet/SVD_model/join_data30.parquet')
join_data.set_index('t_dat',inplace=True)
join_data.count()

                                                                                

n_factors      798397660
group_id       798397660
start_test     798397660
n_epochs       798397660
customer_id    798397660
reg_all        798397660
article_id     798397660
split_id       798397660
dtype: int64

In [7]:
join_data.head(5)

                                                                                

Unnamed: 0_level_0,split_id,customer_id,article_id,start_test,n_factors,n_epochs,reg_all,group_id
t_dat,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2019-06-23,14,00007d2de826758b65a93dd24ce629ed66842531df6699...,651697001,2019-07-23,25,20,0.01,14
2019-06-23,14,00007d2de826758b65a93dd24ce629ed66842531df6699...,651697001,2019-07-23,25,40,0.01,62
2019-06-23,14,00007d2de826758b65a93dd24ce629ed66842531df6699...,651697001,2019-07-23,25,50,0.01,86
2019-06-23,14,00007d2de826758b65a93dd24ce629ed66842531df6699...,651697001,2019-07-23,50,40,0.01,158
2019-06-23,14,00007d2de826758b65a93dd24ce629ed66842531df6699...,651697001,2019-07-23,100,50,0.01,278


# surpriseSVD

In [5]:
import pandas as pd
# from surprise import NormalPredictor
from surprise import Dataset
from surprise import Reader
from surprise import SVDpp,SVD
from surprise import accuracy
from surprise.model_selection import train_test_split
from collections import defaultdict
import numpy as np
import average_precision as metrics
# import ml_metrics as metrics

class surpriseSVD():
    def __init__(self):
        self = self

    def get_top_n(self, predictions, n=12):
        """Return the top-N recommendation for each user from a set of predictions.
        Args:
            predictions(list of Prediction objects): The list of predictions, as
                returned by the test method of an algorithm.
            n(int): The number of recommendation to output for each user. Default
                is 10.
        Returns:
        A dict where keys are user (raw) ids and values are lists of tuples:
            [(raw item id, rating estimation), ...] of size n.
        """

        # First map the predictions to each user.
        top_n = defaultdict(list)
        for uid, iid, true_r, est, _ in predictions:
            top_n[uid].append((iid, est))

        # Then sort the predictions for each user and retrieve the k highest ones.
        for uid, user_ratings in top_n.items():
            user_ratings.sort(key=lambda x: x[1], reverse=True)
            top_n[uid] = user_ratings[:n]

        return top_n

    def get_set(self,df):
        reader = Reader(rating_scale=(1, 500))
        data_set = Dataset.load_from_df(df[['customer_id','article_id','rating']], reader)
        return data_set

    def get_rating_set(self,df):
        rating = df[['customer_id','article_id','price']].groupby(['customer_id','article_id']).count().reset_index()
        rating.columns = ['customer_id','article_id','rating']
        rating_set = self.get_set(rating)
        return rating_set


    def train_SVD(self, dataTrain, dataTest, paras={}):

        ## 讀取評分資料為surprise可以訓練的格式
        trainset = self.get_rating_set(train_data)
        testset = self.get_rating_set(test_data)

        ## rmse 需要的資料
        testset2 = [testset.df.loc[i].to_list() for i in range(len(testset.df))]

        ## map@k testing 需要產的資料
        test_data.loc[:,'rating']=0
        test_processed = self.get_set(test_data)
        NA, test2 = train_test_split(test_processed, test_size=1.0)

        # ======= 消費者的實際購買清單 =======
        test_data['article_id'] = test_data['article_id'].astype('str')
        test_uni = test_data.drop_duplicates(subset=['customer_id', 'article_id'], keep='first')
        buy_n = test_uni[['customer_id','article_id']].groupby('customer_id')['article_id'].apply(list).to_dict()

        cust_actual_list = []
        for uid, user_ratings in buy_n.items():
            cust_pred_tuple = (uid, [iid for iid in user_ratings])
            cust_actual_list.append(cust_pred_tuple)

        # ======= 訓練 SVD 模型 =======
        algo = SVD(random_state=42,**paras)

        # 訓練模型
        algo.fit(trainset.build_full_trainset())

        ##### rmse #####
        predictions = algo.test(testset2)
        rmse = accuracy.rmse(predictions)

        ##### map@k #####
        predictions_map = algo.test(test2)
        # est = [i.est for i in predictions_map] 

        ##  消費者的預測清單 
        top_n = self.get_top_n(predictions=predictions_map, n=12)

        cust_pred_list = []
        for uid, user_ratings in top_n.items():
            cust_pred_tuple = (uid, [str(iid) for (iid, _) in user_ratings])
            cust_pred_list.append(cust_pred_tuple)

        final_list = list(zip(cust_actual_list, cust_pred_list))

        # map@k計算 
        mapk_list = []
        for i in range(len(final_list)):
            map_k = metrics.mapk([final_list[i][0][1]],[final_list[i][1][1]],12)
            mapk_list.append(map_k)

        map_k = sum(mapk_list)/len(mapk_list)

        return rmse, map_k

# pandas_udf

In [8]:
t = [True, False, False, True]
not t[0]

False

In [21]:
a = (0,1)
b = (2,3)
c = (4,5)

x,y  = zip(a,b,c)

In [22]:
y

(1, 3, 5)

In [34]:
def get_train(d) -> (bool,bool):
    train_index = (d['t_dat'] <= d['start_test'])
    test_index = not train_index
    return [train_index, test_index]

In [37]:
train_index, test_index = zip(split_data[['t_dat','start_test']].apply(get_train).tolist())

2022-03-29 17:11:36,660 WARN scheduler.TaskSetManager: Lost task 32.0 in stage 119.0 (TID 1782) (bdse137.example.com executor 11): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/pyspark/pandas/groupby.py", line 1459, in rename_output
  File "/usr/local/spark/python/pyspark/pandas/frame.py", line 2525, in apply_func
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py", line 8833, in apply
    return op.apply().__finalize__(self, method="apply")
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 727, in apply
    return self.apply_standard()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 851, in apply_standard
    results, res_index = self.apply_series_generator()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 871, in apply_series_generator
    results[i] = results[i].copy(deep=False)
  File "/tmp/ipykernel_17163/4018977022.py", 

Py4JJavaError: An error occurred while calling o5601.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:301)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:97)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:93)
	at sun.reflect.GeneratedMethodAccessor65.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 32 in stage 119.0 failed 4 times, most recent failure: Lost task 32.3 in stage 119.0 (TID 1824) (bdse108.example.com executor 4): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/pyspark/pandas/groupby.py", line 1459, in rename_output
  File "/usr/local/spark/python/pyspark/pandas/frame.py", line 2525, in apply_func
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py", line 8833, in apply
    return op.apply().__finalize__(self, method="apply")
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 727, in apply
    return self.apply_standard()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 851, in apply_standard
    results, res_index = self.apply_series_generator()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 871, in apply_series_generator
    results[i] = results[i].copy(deep=False)
  File "/tmp/ipykernel_17163/4018977022.py", line 2, in get_train
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/series.py", line 958, in __getitem__
    return self._get_value(key)
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/series.py", line 1069, in _get_value
    loc = self.index.get_loc(label)
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 't_dat'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:555)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:101)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage26.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.hasNext(ArrowConverters.scala:99)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.foreach(ArrowConverters.scala:97)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.to(ArrowConverters.scala:97)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.toBuffer(ArrowConverters.scala:97)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.toArray(ArrowConverters.scala:97)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$6(Dataset.scala:3650)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2308)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2309)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$5(Dataset.scala:3648)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$2(Dataset.scala:3652)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$2$adapted(Dataset.scala:3629)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3706)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3704)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$1(Dataset.scala:3629)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$1$adapted(Dataset.scala:3628)
	at org.apache.spark.security.SocketAuthServer$.$anonfun$serveToStream$2(SocketAuthServer.scala:139)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.security.SocketAuthServer$.$anonfun$serveToStream$1(SocketAuthServer.scala:141)
	at org.apache.spark.security.SocketAuthServer$.$anonfun$serveToStream$1$adapted(SocketAuthServer.scala:136)
	at org.apache.spark.security.SocketFuncServer.handleConnection(SocketAuthServer.scala:113)
	at org.apache.spark.security.SocketFuncServer.handleConnection(SocketAuthServer.scala:107)
	at org.apache.spark.security.SocketAuthServer$$anon$1.$anonfun$run$4(SocketAuthServer.scala:68)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:68)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/pyspark/pandas/groupby.py", line 1459, in rename_output
  File "/usr/local/spark/python/pyspark/pandas/frame.py", line 2525, in apply_func
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py", line 8833, in apply
    return op.apply().__finalize__(self, method="apply")
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 727, in apply
    return self.apply_standard()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 851, in apply_standard
    results, res_index = self.apply_series_generator()
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/apply.py", line 871, in apply_series_generator
    results[i] = results[i].copy(deep=False)
  File "/tmp/ipykernel_17163/4018977022.py", line 2, in get_train
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/series.py", line 958, in __getitem__
    return self._get_value(key)
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/series.py", line 1069, in _get_value
    loc = self.index.get_loc(label)
  File "/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 't_dat'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:555)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:101)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:508)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage26.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.hasNext(ArrowConverters.scala:99)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.foreach(ArrowConverters.scala:97)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.to(ArrowConverters.scala:97)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.toBuffer(ArrowConverters.scala:97)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.toArray(ArrowConverters.scala:97)
	at org.apache.spark.sql.Dataset.$anonfun$collectAsArrowToPython$6(Dataset.scala:3650)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2308)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)


In [None]:
def time_split_hyperparameter_search(data):
    paras = {
        'factors':data.factors.values[0], 
        'iterations':dta.iterations.values[0], 
        'regularization':data.regularization.values[0]
    }
    
    train_index, test_index = join_data.apply(get_train, axis=1)
    dataTrain = data[(data['t_dat'] <= data['start_test'])]
    dataTest = data[(data['t_dat'] > start_test)]
    
    rmse, map12 = model.train_SVD(dataTrain, dataTest, paras)
    
    paras.update({
        'date_x_paras_id' : data.date_x_paras_id.values[0],
        'val_date' : data.val_date.values[0],
        'map12' : map12
    })
    
    results = pd.DataFrame([paras])
    
    return results


In [None]:
# pandas_udf

schema = StructType(
    [
        StructField('date_x_paras_id', IntegerType(),True),
        StructField('map12', FloatType(),True),
        StructField('val_date', DateTime(),True),
        StructField("para1", IntegerType(), True),
        StructField("para2", IntegerType(), True)
     ]
)

results = df.groupby('date_x_paras_id').applyInPandas(time_split_hyperparameter_search, schema)
results.show()

# (DataFrame)

## 1. 讀取檔案 /DataFrame

In [None]:
# # 讀取檔案
# customers = spark.read.option('header','true').parquet('/user/HM_parquet/customers.parquet')
# articles = spark.read.option('header','true').parquet('/user/HM_parquet/articles.parquet')
# transactions = spark.read.option('header','true').parquet('/user/HM_parquet/transactions_train.parquet')

In [None]:
# transactions.show()

## 2. 將customer_id(字串)轉為customer_index(整數) /DataFrame

In [None]:
# # 將customers的customer_id轉為數字(buffer要增加到512m)
# toIndex = StringIndexer(inputCol="customer_id", outputCol="customer_index").fit(customers)
# customers = toIndex.transform(customers)
# # customers.head(5)

In [None]:
# # 將transactions的customer_id轉為數字
# transactions = toIndex.transform(transactions)
# # transactions.head(5)

In [None]:
# transactions.describe()