# 목표 : spark DF DeepFM Predict

### 1. Data load

In [1]:
from modules.train import get_data
from modules.DeepFM import DeepFM
import modules.config as config
import tensorflow as tf
from tensorflow.keras.metrics import BinaryAccuracy, AUC
import pandas as pd
# data load
test = pd.read_parquet('data/test.parquet')
y_test = test['target']
x_test = test.drop('target',axis = 1)


2023-02-05 19:52:22.170355: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


### 2. model load

In [2]:
# model load
_, _, field_dict, field_index = get_data()
fm = DeepFM(embedding_size=config.EMBEDDING_SIZE, num_feature=len(field_index),
               num_field=len(field_dict), field_index=field_index)

fm.build(input_shape = (1,len(field_index)))
fm.load_weights('./weights/weights-epoch(10)-batch(256)-embedding(5).h5')

Data Prepared...
X shape: (32561, 108)
# of Feature: 108
# of Field: 14
train/test save complete


2023-02-05 19:52:28.391829: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


### 2. Pandas DF -> Spark DF convert

In [38]:
from pyspark.sql import SparkSession
#Create PySpark SparkSession
spark = SparkSession.builder \
    .master("local[1]") \
    .appName("SparkByExamples.com") \
    .getOrCreate()
#Create PySpark DataFrame from Pandas
sparkDF=spark.createDataFrame(x_test) 
sparkDF.printSchema()


root
 |-- age: double (nullable = true)
 |-- workclass- ?: long (nullable = true)
 |-- workclass- Federal-gov: long (nullable = true)
 |-- workclass- Local-gov: long (nullable = true)
 |-- workclass- Never-worked: long (nullable = true)
 |-- workclass- Private: long (nullable = true)
 |-- workclass- Self-emp-inc: long (nullable = true)
 |-- workclass- Self-emp-not-inc: long (nullable = true)
 |-- workclass- State-gov: long (nullable = true)
 |-- workclass- Without-pay: long (nullable = true)
 |-- fnlwgt: double (nullable = true)
 |-- education- 10th: long (nullable = true)
 |-- education- 11th: long (nullable = true)
 |-- education- 12th: long (nullable = true)
 |-- education- 1st-4th: long (nullable = true)
 |-- education- 5th-6th: long (nullable = true)
 |-- education- 7th-8th: long (nullable = true)
 |-- education- 9th: long (nullable = true)
 |-- education- Assoc-acdm: long (nullable = true)
 |-- education- Assoc-voc: long (nullable = true)
 |-- education- Bachelors: long (nullable

In [None]:
# import pyspark.sql.functions as F
# from pyspark.ml.feature import OneHotEncoder
# from pyspark.ml.feature import StringIndexer

# indexer = StringIndexer(inputCols=CAT_FIELDS, outputCols=[col + "_encoded" for col in CAT_FIELDS])
# label_df = indexer.fit(sparkDF).transform(sparkDF)
# label_df.show()



In [5]:
y = float(fm(x_test.head(1).values))
y

0.0488150417804718

### UDF predit

In [11]:
import numpy as np

# column order required by the model.
FEATURES = x_test.columns
path = './weights/weights-epoch(10)-batch(256)-embedding(5).h5' 

def predict(features):

    np_features = np.array([features])
    
    
#     # model load
#     fm = DeepFM(embedding_size=config.EMBEDDING_SIZE, num_feature=len(field_index),
#                num_field=len(field_dict), field_index=field_index)

#     fm.build(input_shape = (1,len(field_index)))
#     fm.load_weights(path)
    
    y = fm(np_features)


    return float(y)

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType, StringType
 

predict_udf = udf(predict, DoubleType())


# test용 udf
def test(x):
    return x 
test_udf = udf(test, StringType())

In [83]:
# sparkDF2 = sparkDF.withColumn(
#     "test",
#     test_udf(F.col('country- Philippines'))
# )
# sparkDF2.select(F.col('test')).show()

### MapPartitions predict

In [67]:
import pandas as pd

# column order required by the model.
FEATURES = x_test.columns
path = './weights/weights-epoch(10)-batch(256)-embedding(5).h5' 

def predict_partition(rows):
    # model load
    fm = DeepFM(embedding_size=config.EMBEDDING_SIZE, num_feature=len(field_index),
               num_field=len(field_dict), field_index=field_index)
    fm.build(input_shape = (1,len(field_index)))
    fm.load_weights(path)
    
    y = fm(rows[FEATURES].values)
    return y

#     y = fm(rows[FEATURES].values)
#     return float(y)


def test(rows):
    for row in rows:
        yield row.age + row['workclass- Federal-gov']

In [85]:
test_ = df.rdd.mapPartitions(test).collect()
# test_

In [79]:
float(fm(x_test.head(1).values))

0.04881441965699196

In [89]:
df.printSchema

<bound method DataFrame.printSchema of DataFrame[age: double, workclass- ?: smallint, workclass- Federal-gov: smallint, workclass- Local-gov: smallint, workclass- Never-worked: smallint, workclass- Private: smallint, workclass- Self-emp-inc: smallint, workclass- Self-emp-not-inc: smallint, workclass- State-gov: smallint, workclass- Without-pay: smallint, fnlwgt: double, education- 10th: smallint, education- 11th: smallint, education- 12th: smallint, education- 1st-4th: smallint, education- 5th-6th: smallint, education- 7th-8th: smallint, education- 9th: smallint, education- Assoc-acdm: smallint, education- Assoc-voc: smallint, education- Bachelors: smallint, education- Doctorate: smallint, education- HS-grad: smallint, education- Masters: smallint, education- Preschool: smallint, education- Prof-school: smallint, education- Some-college: smallint, education-num: double, marital-status- Divorced: smallint, marital-status- Married-AF-spouse: smallint, marital-status- Married-civ-spouse: 

23/02/06 04:28:40 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 3204856 ms exceeds timeout 120000 ms
23/02/06 04:28:40 WARN SparkContext: Killing executors is not supported by current scheduler.


In [88]:
FEATURES = x_test.columns
path = './weights/weights-epoch(10)-batch(256)-embedding(5).h5' 

# Load the data into a PySpark DataFrame
df = spark.read.parquet('data/test.parquet', header=True, inferSchema=True)


# Define the function that will make predictions with the model
def predict(iterator):
    # model load
    fm = DeepFM(embedding_size=config.EMBEDDING_SIZE, num_feature=len(field_index),
               num_field=len(field_dict), field_index=field_index)
    fm.build(input_shape = (1,len(field_index)))
    fm.load_weights(path)

    # Make predictions for each row in the iterator
    results = [float(fm(row[FEATURES].values)) for row in iterator]
    return results

#     return fm(iterator[FEATURES].values)

# Use the mapPartitions function to make predictions on the data
predictions = df.rdd.mapPartitions(predict).collect()
predictions

2023-02-05 21:04:14.994412: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-05 21:04:19.564882: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


23/02/05 21:04:19 ERROR Executor: Exception in task 0.0 in stage 52.0 (TID 52)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/sql/types.py", line 1884, in __getitem__
    idx = self.__fields__.index(item)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 676, in process
    out_iter = func(split_index, iterator)
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pysp

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 52.0 failed 1 times, most recent failure: Lost task 0.0 in stage 52.0 (TID 52) (192.168.0.24 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/sql/types.py", line 1884, in __getitem__
    idx = self.__fields__.index(item)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 676, in process
    out_iter = func(split_index, iterator)
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/rdd.py", line 540, in func
    return f(iterator)
  File "/var/folders/y7/ctbm4_yn3_11qs65_7zhkdcm0000gn/T/ipykernel_9775/3889536951.py", line 17, in predict
  File "/var/folders/y7/ctbm4_yn3_11qs65_7zhkdcm0000gn/T/ipykernel_9775/3889536951.py", line 17, in <listcomp>
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/sql/types.py", line 1889, in __getitem__
    raise ValueError(item)
ValueError: Index(['age', 'workclass- ?', 'workclass- Federal-gov', 'workclass- Local-gov',
       'workclass- Never-worked', 'workclass- Private',
       'workclass- Self-emp-inc', 'workclass- Self-emp-not-inc',
       'workclass- State-gov', 'workclass- Without-pay',
       ...
       'country- Portugal', 'country- Puerto-Rico', 'country- Scotland',
       'country- South', 'country- Taiwan', 'country- Thailand',
       'country- Trinadad&Tobago', 'country- United-States',
       'country- Vietnam', 'country- Yugoslavia'],
      dtype='object', length=108)

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:559)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:765)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:747)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:512)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1021)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2268)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:831)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2293)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1021)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1020)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:180)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:78)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:567)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:831)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/sql/types.py", line 1884, in __getitem__
    idx = self.__fields__.index(item)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 676, in process
    out_iter = func(split_index, iterator)
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/rdd.py", line 540, in func
    return f(iterator)
  File "/var/folders/y7/ctbm4_yn3_11qs65_7zhkdcm0000gn/T/ipykernel_9775/3889536951.py", line 17, in predict
  File "/var/folders/y7/ctbm4_yn3_11qs65_7zhkdcm0000gn/T/ipykernel_9775/3889536951.py", line 17, in <listcomp>
  File "/Users/youngyong/opt/anaconda3/envs/spark/lib/python3.8/site-packages/pyspark/python/lib/pyspark.zip/pyspark/sql/types.py", line 1889, in __getitem__
    raise ValueError(item)
ValueError: Index(['age', 'workclass- ?', 'workclass- Federal-gov', 'workclass- Local-gov',
       'workclass- Never-worked', 'workclass- Private',
       'workclass- Self-emp-inc', 'workclass- Self-emp-not-inc',
       'workclass- State-gov', 'workclass- Without-pay',
       ...
       'country- Portugal', 'country- Puerto-Rico', 'country- Scotland',
       'country- South', 'country- Taiwan', 'country- Thailand',
       'country- Trinadad&Tobago', 'country- United-States',
       'country- Vietnam', 'country- Yugoslavia'],
      dtype='object', length=108)

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:559)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:765)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:747)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:512)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1021)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2268)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	... 1 more


### MapPartitions example

In [29]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [('James','Smith','M',3000),
  ('Anna','Rose','F',4100),
  ('Robert','Williams','M',6200), 
]

columns = ["firstname","lastname","gender","salary"]
df = spark.createDataFrame(data=data, schema = columns)
df.show()

+---------+--------+------+------+
|firstname|lastname|gender|salary|
+---------+--------+------+------+
|    James|   Smith|     M|  3000|
|     Anna|    Rose|     F|  4100|
|   Robert|Williams|     M|  6200|
+---------+--------+------+------+



                                                                                

In [30]:
# This function calls for each partition
def reformat(partitionData):
    for row in partitionData:
        yield [row['firstname']+","+row.lastname,row.salary*10/100]

df2=df.rdd.mapPartitions(reformat).toDF(["name","bonus"])
df2.show()

+---------------+-----+
|           name|bonus|
+---------------+-----+
|    James,Smith|300.0|
|      Anna,Rose|410.0|
|Robert,Williams|620.0|
+---------------+-----+

