In [26]:
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
import pandas as pd

In [50]:
data=pd.read_csv('./data/ratings_Electronics.csv',names=['userId','productId','Rating','timestamp'])
data=data[['userId','productId','Rating']]
data.head()

Unnamed: 0,userId,productId,Rating
0,AKM1MP6P0OYPR,132793040,5.0
1,A2CX7LUOHB2NDG,321732944,5.0
2,A2NWSAGRHCP8N5,439886341,1.0
3,A2WNBOD3WNDNKT,439886341,3.0
4,A1GI0U4ZRJA8WN,439886341,1.0


In [51]:
newData=data.sample(frac=0.1,random_state=1)
print(len(data),len(newData))
print(len(newData)/len(data))

7824482 782448
0.0999999744392025


In [52]:
spark=SparkSession.builder.appName('alsRec').getOrCreate()

In [53]:
newData=spark.createDataFrame(newData)
newData.show(10)

+--------------+----------+------+
|        userId| productId|Rating|
+--------------+----------+------+
| A9UYPR4Q055LZ|B00AFXUUV6|   1.0|
| ACUWMWOHAFDXF|B00BFDHVAS|   5.0|
| ADH0W9QWMJ8V7|B000OMKR8E|   4.0|
|A1VDA4Z5EMT052|B0075SUG3Q|   5.0|
| ASYZVGMBYVY2Z|B0006B486K|   2.0|
| AJHVJMH379SQF|B0047UNP3I|   1.0|
|A1CDZ07YBEZP6C|B0046HNWO4|   5.0|
| APRVK6PDTNH82|B002U8573K|   3.0|
|A3KD4DEAEO2KRF|B001UZJBGI|   5.0|
|A2CGKM93KATTY3|B00746LVOM|   1.0|
+--------------+----------+------+
only showing top 10 rows



In [54]:
newData.describe().toPandas()

Unnamed: 0,summary,userId,productId,Rating
0,count,782448,782448,782448.0
1,mean,,5.370610959261765E9,4.012918430362146
2,stddev,,4.0989770434053845E9,1.3802497297953855
3,min,A00018041RRVMCICCAP79,0528881469,1.0
4,max,AZZZWXXUPZ1F3,BT008G3W52,5.0


In [55]:
newData.printSchema()

root
 |-- userId: string (nullable = true)
 |-- productId: string (nullable = true)
 |-- Rating: double (nullable = true)



In [56]:
newData.count()

782448

In [57]:
newData.select(countDistinct('userId')).show()

+----------------------+
|count(DISTINCT userId)|
+----------------------+
|                672618|
+----------------------+



In [58]:
indexer = StringIndexer(inputCol = 'userId', outputCol = "moduserID")
DFF=indexer.fit(newData).transform(newData)

In [59]:
indexer = StringIndexer(inputCol = 'productId', outputCol = "modproductID")
DFF=indexer.fit(DFF).transform(DFF)

In [60]:
DFF.show(10)

+--------------+----------+------+---------+------------+
|        userId| productId|Rating|moduserID|modproductID|
+--------------+----------+------+---------+------------+
| A9UYPR4Q055LZ|B00AFXUUV6|   1.0| 558262.0|     14326.0|
| ACUWMWOHAFDXF|B00BFDHVAS|   5.0| 571418.0|      1021.0|
| ADH0W9QWMJ8V7|B000OMKR8E|   4.0|  64513.0|       426.0|
|A1VDA4Z5EMT052|B0075SUG3Q|   5.0| 211300.0|      8889.0|
| ASYZVGMBYVY2Z|B0006B486K|   2.0| 642022.0|       328.0|
| AJHVJMH379SQF|B0047UNP3I|   1.0| 600495.0|    119827.0|
|A1CDZ07YBEZP6C|B0046HNWO4|   5.0| 128396.0|     65161.0|
| APRVK6PDTNH82|B002U8573K|   3.0| 627917.0|     11214.0|
|A3KD4DEAEO2KRF|B001UZJBGI|   5.0| 479855.0|    104154.0|
|A2CGKM93KATTY3|B00746LVOM|   1.0| 286936.0|       410.0|
+--------------+----------+------+---------+------------+
only showing top 10 rows



In [61]:
DFF.printSchema()

root
 |-- userId: string (nullable = true)
 |-- productId: string (nullable = true)
 |-- Rating: double (nullable = true)
 |-- moduserID: double (nullable = false)
 |-- modproductID: double (nullable = false)



In [62]:
DFF.filter(DFF.modproductID.isNull()).count()

0

In [63]:
DFF=DFF.drop(*['userId','productId'])

In [64]:
DFF.show(10)

+------+---------+------------+
|Rating|moduserID|modproductID|
+------+---------+------------+
|   1.0| 558262.0|     14326.0|
|   5.0| 571418.0|      1021.0|
|   4.0|  64513.0|       426.0|
|   5.0| 211300.0|      8889.0|
|   2.0| 642022.0|       328.0|
|   1.0| 600495.0|    119827.0|
|   5.0| 128396.0|     65161.0|
|   3.0| 627917.0|     11214.0|
|   5.0| 479855.0|    104154.0|
|   1.0| 286936.0|       410.0|
+------+---------+------------+
only showing top 10 rows



In [65]:
training, test = DFF.randomSplit([0.8,0.2])

In [66]:
als = ALS(maxIter=20, regParam=0.01, userCol="moduserID", itemCol="modproductID", 
          ratingCol="Rating",coldStartStrategy='drop')
model = als.fit(training)
predictions = model.transform(test)

Py4JJavaError: An error occurred while calling o807.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 486.0 failed 1 times, most recent failure: Lost task 0.0 in stage 486.0 (TID 4926, 192.168.110.129, executor driver): java.lang.OutOfMemoryError: Java heap space
	at java.lang.reflect.Array.newInstance(Array.java:75)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:1996)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2032)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2032)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2023)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:1972)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:1971)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1971)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:950)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:950)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:950)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2203)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2152)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2141)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:752)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2093)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2114)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2133)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD.count(RDD.scala:1227)
	at org.apache.spark.ml.recommendation.ALS$.train(ALS.scala:960)
	at org.apache.spark.ml.recommendation.ALS.$anonfun$fit$1(ALS.scala:709)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.recommendation.ALS.fit(ALS.scala:691)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at java.lang.reflect.Array.newInstance(Array.java:75)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:1996)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2032)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2032)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1613)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
	at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
	at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)


----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 52876)
Traceback (most recent call last):
  File "/usr/lib/python3.8/socketserver.py", line 316, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/usr/lib/python3.8/socketserver.py", line 347, in process_request
    self.finish_request(request, client_address)
  File "/usr/lib/python3.8/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/usr/lib/python3.8/socketserver.py", line 720, in __init__
    self.handle()
  File "/usr/lib/spark/python/pyspark/accumulators.py", line 268, in handle
    poll(accum_updates)
  File "/usr/lib/spark/python/pyspark/accumulators.py", line 241, in poll
    if func():
  File "/usr/lib/spark/python/pyspark/accumulators.py", line 245, in accum_updates
    num_updates = read_int(self.rfile)
  File "/usr/lib/spark/python/pyspark/serializers.py", line 595, i

In [47]:
predictions.describe().toPandas()

Unnamed: 0,summary,Rating,moduserID,modproductID,prediction
0,count,3019.0,3019.0,3019.0,3019.0
1,mean,4.192116594898973,4637.724412056972,7793.387214309374,-0.0091667246318268
2,stddev,1.2036740165664686,3357.326941445099,8494.303117700896,1.6709708858234569
3,min,1.0,0.0,0.0,-12.698547
4,max,5.0,10888.0,31626.0,11.719991


In [48]:
predictions =predictions.na.drop()

In [49]:
evaluator = RegressionEvaluator(metricName='rmse', labelCol='Rating')
rmse = evaluator.evaluate(predictions)
rmse

4.6839928356071985