In [2]:
%output --no-stdout

In [3]:
@file:Repository("https://binrepo.target.com/artifactory/gradle")
@file:Repository("https://binrepo.target.com/artifactory/maven-central")
@file:Repository("https://binrepo.target.com/artifactory/jcenter")
@file:Repository("https://binrepo.target.com/artifactory/jitpack-maven")
@file:Repository("https://binrepo.target.com/artifactory/kotlin-maven")
@file:Repository("https://binrepo.target.com/artifactory/apache-maven")
@file:Repository("https://binrepo.target.com/artifactory/jitpack")
%use spark

In [4]:
%output --reset-to-defaults
@file:DependsOn("org.jetbrains.kotlinx.spark:kotlin-spark-api-3.0.0_2.12:1.0.0-preview1")

In [41]:
import org.jetbrains.kotlinx.spark.api.*
import org.apache.spark.sql.functions.*
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.tuning.ParamGridBuilder

In [6]:
val ratingFile = "data/ml-latest/ratings.csv"
val movieFile = "data/ml-latest/movies.csv"
val linkFile = "data/ml-latest/links.csv"
val tagFile = "data/ml-latest/tags.csv"

In [7]:
val spark = SparkSession
.builder()
.master("local[*]")
.appName("SparkMl").orCreate

In [14]:
val ratings = spark
        .read()
        .option("header", "true")
        .option("inferSchema", "true")
        .csv(ratingFile)
        .drop("timestamp")
        .cache()
        

In [15]:
ratings.show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|    307|   3.5|
|     1|    481|   3.5|
|     1|   1091|   1.5|
|     1|   1257|   4.5|
|     1|   1449|   4.5|
|     1|   1590|   2.5|
|     1|   1591|   1.5|
|     1|   2134|   4.5|
|     1|   2478|   4.0|
|     1|   2840|   3.0|
|     1|   2986|   2.5|
|     1|   3020|   4.0|
|     1|   3424|   4.5|
|     1|   3698|   3.5|
|     1|   3826|   2.0|
|     1|   3893|   3.5|
|     2|    170|   3.5|
|     2|    849|   3.5|
|     2|   1186|   3.5|
|     2|   1235|   3.0|
+------+-------+------+
only showing top 20 rows



In [16]:
ratings.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)



In [12]:
ratings.summary().show()

+-------+------------------+-----------------+------------------+
|summary|            userId|          movieId|            rating|
+-------+------------------+-----------------+------------------+
|  count|          27753444|         27753444|          27753444|
|   mean|141942.01557064414|18487.99983414671|3.5304452124932677|
| stddev| 81707.40009148984| 35102.6252474677|1.0663527502319696|
|    min|                 1|                1|               0.5|
|    25%|             71164|             1099|               3.0|
|    50%|            142014|             2716|               3.5|
|    75%|            212447|             7151|               4.0|
|    max|            283228|           193886|               5.0|
+-------+------------------+-----------------+------------------+



In [20]:
val als = ALS()
    .setUserCol("userId")
    .setRatingCol("rating")
    .setItemCol("movieId")
    
val model = als.fit(ratings)

In [21]:
val predictions = model.transform(ratings)

In [23]:
predictions.show(10, false)

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|107339|148    |4.0   |3.3621397 |
|93112 |148    |3.0   |2.9129353 |
|106148|148    |2.5   |2.8071108 |
|234926|148    |4.0   |2.945117  |
|253535|148    |4.0   |2.991361  |
|50155 |148    |3.0   |3.0100925 |
|65991 |148    |4.0   |3.0302284 |
|146376|148    |5.0   |3.6678638 |
|207939|148    |3.0   |2.6889248 |
|41788 |148    |3.0   |2.760429  |
+------+-------+------+----------+
only showing top 10 rows



In [24]:
val evaluator = RegressionEvaluator()
    .setMetricName("rmse")
    .setLabelCol("rating")
    .setPredictionCol("prediction")

In [25]:
val rmse = evaluator.evaluate(predictions)

In [26]:
rmse

0.773457379850624

In [28]:
val (trainingData, validationData) = ratings.randomSplit(doubleArrayOf(8.0, 2.0))

In [30]:
println(trainingData.count())
println(validationData.count())

22201630
5551814


In [31]:
ratings.count()

27753444

In [32]:
val model = als.fit(trainingData)

In [36]:
val predictions = model.transform(validationData)

In [39]:
val rmse = evaluator.evaluate(predictions.na().drop())

In [40]:
rmse

0.8140429426510024

In [42]:
val parameterGrid = ParamGridBuilder()
    .addGrid(als.rank(), intArrayOf(1,5,10))
    .addGrid(als.maxIter(), intArrayOf(20))
    .addGrid(als.regParam(), doubleArrayOf(0.05, 0.1))
    .build()

In [43]:
parameterGrid.javaClass.name

[Lorg.apache.spark.ml.param.ParamMap;

In [46]:
parameterGrid.contentToString()

[{
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 1,
	als_a93d71366e31-regParam: 0.05
}, {
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 5,
	als_a93d71366e31-regParam: 0.05
}, {
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 10,
	als_a93d71366e31-regParam: 0.05
}, {
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 1,
	als_a93d71366e31-regParam: 0.1
}, {
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 5,
	als_a93d71366e31-regParam: 0.1
}, {
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-rank: 10,
	als_a93d71366e31-regParam: 0.1
}]

In [47]:
val crossValidator = CrossValidator().setEstimator(als)
.setEstimatorParamMaps(parameterGrid)
.setEvaluator(evaluator)
.setNumFolds(2)

In [48]:
val crossValidatedModel = crossValidator.fit(trainingData)

In [49]:
val predictions = crossValidatedModel.transform(validationData)

In [50]:
val rmse = evaluator.evaluate(predictions.na().drop())
rmse

0.874035007571565

In [51]:
val model = crossValidatedModel.bestModel()

In [78]:
model.parent().extractParamMap()

{
	als_a93d71366e31-alpha: 1.0,
	als_a93d71366e31-checkpointInterval: 10,
	als_a93d71366e31-coldStartStrategy: nan,
	als_a93d71366e31-finalStorageLevel: MEMORY_AND_DISK,
	als_a93d71366e31-implicitPrefs: false,
	als_a93d71366e31-intermediateStorageLevel: MEMORY_AND_DISK,
	als_a93d71366e31-itemCol: movieId,
	als_a93d71366e31-maxIter: 20,
	als_a93d71366e31-nonnegative: false,
	als_a93d71366e31-numItemBlocks: 10,
	als_a93d71366e31-numUserBlocks: 10,
	als_a93d71366e31-predictionCol: prediction,
	als_a93d71366e31-rank: 1,
	als_a93d71366e31-ratingCol: rating,
	als_a93d71366e31-regParam: 0.05,
	als_a93d71366e31-seed: 1994790107,
	als_a93d71366e31-userCol: userId
}