In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import os
sys.path.append(os.path.abspath("../.."))  # Adds the project root to sys.path
import pyspark
from pyspark.ml.recommendation import ALS
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField
from pyspark.sql.types import StringType, FloatType, IntegerType, LongType
from recommenders.datasets.spark_splitters import spark_random_split
from recommenders.utils.timer import Timer
import utils

In [2]:
COL_USER = "UserID"
COL_ITEM = "ProductID"
COL_RATING = "Rating"
DATA_FILE_PATH = "tools_recommendation_dataset.json"

In [None]:
#utils.generate_fake_data()
#utils.generate_csv_arrays_from_json(DATA_FILE_PATH)

Dataset saved.
Saved users.csv
Saved products.csv
Saved userPurchases.csv
Saved userReviews.csv


In [3]:
# the following settings work well for debugging locally on VM - change when running on a cluster
# set up a giant single executor with many threads and specify memory cap
spark = utils.start_or_get_spark("ALS PySpark", memory="16g")
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")

In [4]:
#users_df = spark.read.option("multiLine", "true").csv("users.csv", header=True, inferSchema=True)
#products_df = spark.read.option("multiLine", "true").csv("products.csv", header=True, inferSchema=True)
#purchases_df = spark.read.option("multiLine", "true").csv("userPurchases.csv", header=True, inferSchema=True)
reviews_df = spark.read.option("multiLine", "true").csv("ratings_Tools_and_Home_Improvement.csv", header=True, inferSchema=True)

In [5]:
reviews_df = reviews_df.na.fill({
    "UserID": 0,
    "ProductID": 0,
    "Rating": 0.0,
    "Timestamp": 0
})
reviews_df = reviews_df.withColumn("UserID", col("UserID").cast("integer")).withColumn("ProductID", col("ProductID").cast("integer"))
reviews_df.toPandas()

Unnamed: 0,UserID,ProductID,Rating,Timestamp
0,1594226,1212835,5.0,1389657600
1,19039,205062040,5.0,1373846400
2,9834,205062040,5.0,1382659200
3,8583,205062040,5.0,1372723200
4,165,205062040,4.0,1364256000
...,...,...,...,...
99994,3348,645,5.0,1199577600
99995,1699815,645,2.0,1277769600
99996,22,645,4.0,1244160000
99997,1904,645,5.0,1325635200


In [6]:
train, test = spark_random_split(reviews_df, ratio=0.75, seed=123)
print ("N train", train.cache().count())
print ("N test", test.cache().count())

N train 75018
N test 24981


In [7]:
header = {
    "userCol": COL_USER,
    "itemCol": COL_ITEM,
    "ratingCol": COL_RATING,
}

als = ALS(
    rank=10,
    maxIter=15,
    implicitPrefs=False,
    regParam=0.05,
    coldStartStrategy='drop',
    nonnegative=False,
    seed=42,
    **header
)

In [8]:
with Timer() as train_time:
    model = als.fit(train)

print("Took {} seconds for training.".format(train_time.interval))

Took 5.120330699952319 seconds for training.


In [9]:
with Timer() as test_time:

    # Get the cross join of all user-item pairs and score them.
    users = train.select(COL_USER).distinct()
    items = train.select(COL_ITEM).distinct()
    user_item = users.crossJoin(items)
    dfs_pred = model.transform(user_item)

    # Remove seen items.
    dfs_pred_exclude_train = dfs_pred.alias("pred").join(
        train.alias("train"),
        (dfs_pred[COL_USER] == train[COL_USER]) & (dfs_pred[COL_ITEM] == train[COL_ITEM]),
        how='outer'
    )

    top_all = dfs_pred_exclude_train.filter(dfs_pred_exclude_train[f"train.{COL_RATING}"].isNull()) \
        .select('pred.' + COL_USER, 'pred.' + COL_ITEM, 'pred.' + "prediction")

    # In Spark, transformations are lazy evaluation
    # Use an action to force execute and measure the test time 
    top_all.cache().count()

print("Took {} seconds for prediction.".format(test_time.interval))

Took 25.14241459988989 seconds for prediction.


In [10]:
top_all.show()

+-----------+-----------+----------+
|     UserID|  ProductID|prediction|
+-----------+-----------+----------+
|-2084477354|          0| 3.9647949|
|-2084477354|      22444| 2.3318439|
|-2084477354|      22454| 2.2646725|
|-2084477354| 1398501980|0.46135074|
|-2002542219|        496| 1.0963039|
|-2002542219|        569| 0.8405908|
|-2002542219|       2262| 1.1932302|
|-2002542219|       4837| 1.7314619|
|-2002542219| 1059552957|  0.823743|
|-2002542219| 1800039584|0.29650843|
|-1957642280|        410| 3.6291337|
|-1957642280|       2542| 3.2722466|
|-1957642280|       5283| 3.8566172|
|-1957642280| 1223027090| 3.4972682|
|-1913026792|       2205| 3.8257842|
|-1913026792|       4109| 3.9599617|
|-1913026792|      22551| 1.2333772|
|-1801416171|-1366914274| 0.5103797|
|-1801416171|       2258|  3.665957|
|-1801416171|       6405| 0.3616918|
+-----------+-----------+----------+
only showing top 20 rows

