In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f


# Obtain the Spark session for this notebook under the app name "Chapter4-3".
spark = (
    SparkSession.builder
    .appName("Chapter4-3")
    .getOrCreate()
)

# Path to ratings.csv inside the ml-latest-small data-set directory.
RATINGS_CSV_LOCATION = "/home/jovyan/data-sets/ml-latest-small/ratings.csv"

In [2]:
# Read the ratings CSV with an explicit schema and keep only the columns the
# recommender needs.
ratings = (
    spark.read
    .schema("userId INT, movieId INT, rating DOUBLE, timestamp INT")
    .option("sep", ",")
    .option("header", True)
    .option("quote", '"')
    .csv(RATINGS_CSV_LOCATION)
    # The timestamp column is declared only so the schema lines up with the
    # file's columns; it is not used afterwards, so it is dropped.
    .drop("timestamp")
    # Cached because the frame is reused by both the fit and transform cells.
    .cache()
)

The `ALS` estimator (`pyspark.ml.recommendation.ALS`) has the following constructor signature, shown with its default values:

```python
class pyspark.ml.recommendation.ALS(
    rank=10,
    maxIter=10,
    regParam=0.1,
    numUserBlocks=10,
    numItemBlocks=10,
    implicitPrefs=False,
    alpha=1.0,
    userCol="user",
    itemCol="item",
    seed=None,
    ratingCol="rating",
    nonnegative=False,
    checkpointInterval=10,
    intermediateStorageLevel="MEMORY_AND_DISK",
    finalStorageLevel="MEMORY_AND_DISK",
    coldStartStrategy="nan",
)
```

In [3]:
from pyspark.ml.recommendation import ALS

In [4]:
# Fit an ALS model on the ratings, mapping the estimator's user / item / rating
# inputs onto this data set's column names. All other hyper-parameters keep
# their defaults.
#
# The seed is pinned so the random initialisation of the factor matrices is
# repeatable: without it, a "Restart & Run All" is not guaranteed to reproduce
# the factors and predictions shown in the cells that follow.
model = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    seed=42,
).fit(ratings)

In [5]:
# Score the same ratings the model was fitted on (an in-sample look at the
# predictions, not a held-out evaluation).
predictions = model.transform(ratings)

# Display the first 10 rows without truncating column values.
predictions.show(n=10, truncate=False)

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|191   |148    |5.0   |4.918388  |
|133   |471    |4.0   |3.240957  |
|597   |471    |2.0   |3.904796  |
|385   |471    |4.0   |3.421205  |
|436   |471    |3.0   |3.3957336 |
|602   |471    |4.0   |3.3982077 |
|91    |471    |1.0   |2.497981  |
|409   |471    |3.0   |3.702609  |
|372   |471    |3.0   |3.11863   |
|599   |471    |2.5   |2.5535092 |
+------+-------+------+----------+
only showing top 10 rows



In [6]:
# Peek at the first 5 rows of the learned user factors.
model.userFactors.show(n=5)

+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[0.9379934, -0.30...|
| 20|[1.0685308, 0.153...|
| 30|[1.107887, 0.5763...|
| 40|[0.34756732, 0.22...|
| 50|[0.19457427, 0.41...|
+---+--------------------+
only showing top 5 rows



In [7]:
# Peek at the first 5 rows of the learned item (movie) factors.
model.itemFactors.show(n=5)

+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[0.69911313, 0.31...|
| 20|[0.8197691, 0.147...|
| 30|[-0.32572478, 0.3...|
| 40|[0.24038884, 0.89...|
| 50|[0.6135771, 0.661...|
+---+--------------------+
only showing top 5 rows

