In [41]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.sql import SparkSession

from pyspark.sql.functions import col, rank, expr
from pyspark.sql.window import Window

In [42]:
# Build (or reuse, via getOrCreate) the Spark session used by the rest of the notebook.
spark = (
    SparkSession.builder.appName("Collaborative Filtering with random stratified split")  # type: ignore
    # Pull in the hadoop-aws package so s3a:// paths can be read.
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.4")
    # Use the us-east-2 regional S3 endpoint.
    .config("fs.s3a.endpoint", "s3.us-east-2.amazonaws.com")
    # Credentials are resolved by the AWS SDK default provider chain;
    # no access keys are hard-coded in this notebook.
    .config(
        "fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    )
    .getOrCreate()
)

In [43]:
# Load the demo sample of reviews (Parquet files) from S3 through the s3a filesystem.
data = spark.read.parquet("s3a://amazon-reviews-eafit/sample-for-demo/")

In [44]:
# Inspect column names and types before building the model inputs.
data.printSchema()

root
 |-- customer_id: integer (nullable = true)
 |-- product_id: string (nullable = true)
 |-- star_rating: float (nullable = true)
 |-- category: string (nullable = true)
 |-- review_date: string (nullable = true)
 |-- verified_purchase: string (nullable = true)
 |-- review_id: string (nullable = true)
 |-- product_title: string (nullable = true)
 |-- product_category: string (nullable = true)



In [45]:
# Number of review rows in the loaded sample.
data.count()

                                                                                

4193

In [46]:
# product_id is a string column; derive a numeric item_id from it, which is the
# column ALS is later configured to use as itemCol.
indexer = StringIndexer(inputCol="product_id", outputCol="item_id")

# Fit on the full dataset so every product_id present in `data` receives an index.
indexer_model = indexer.fit(data)

                                                                                

In [47]:
# Add the numeric item_id column. Dropping item_id first is a no-op on the first
# run and makes this cell safe to re-run: without it, a second execution fails
# because the indexer's output column already exists on `data`.
data = indexer_model.transform(data.drop("item_id"))

## Split de la data

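Se hace un split estratificado por `customer_id`: para cada cliente se ordenan sus productos por `product_id` (descendente) y se reserva una fracción de ellos (`percent_items_to_mask`, 30% en este notebook) para el test; el resto queda en training.
Se repite el proceso para obtener el conjunto de validación partiendo de la data de training.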


In [48]:
def split_data(data, percent_items_to_mask=0.2):
    """Split every customer's rows into a kept set and a held-out set.

    For each ``customer_id`` the rows are ranked by ``product_id`` in descending
    order. The number of rows to hold out is the customer's row count times
    ``percent_items_to_mask``, cast to an integer. Rows whose rank is within
    that quota are held out; all other rows are kept.

    Args:
        data: Spark DataFrame with at least ``customer_id`` and ``product_id``.
        percent_items_to_mask: Fraction of each customer's rows to hold out.

    Returns:
        A ``(training, test)`` tuple of DataFrames. Both still carry the helper
        columns ``number_of_products``, ``number_of_products_to_mask`` and
        ``product_rank``.
    """
    # Column expressions are built up front; they are only resolved once attached.
    products_per_customer = expr("count(*) over (partition by customer_id)")
    mask_quota = (col("number_of_products") * percent_items_to_mask).cast("int")
    rank_within_customer = rank().over(
        Window.partitionBy("customer_id").orderBy(col("product_id").desc())
    )

    ranked = (
        data.withColumn("number_of_products", products_per_customer)
        .withColumn("number_of_products_to_mask", mask_quota)
        .withColumn("product_rank", rank_within_customer)
    )

    quota = col("number_of_products_to_mask")
    position = col("product_rank")

    # The first `quota` ranks of each customer are held out; the rest are kept.
    held_out = ranked.filter(quota >= position)
    kept = ranked.filter(quota < position)

    return kept, held_out

In [49]:
# Hold out 30% of each customer's products as test, then apply the same rule to
# the remaining training rows to carve out the validation set.
training, test = split_data(data, percent_items_to_mask=0.3)
training, validation = split_data(training, percent_items_to_mask=0.3)

# Each count() runs a Spark job; evaluated in training, validation, test order.
training_count, validation_count, test_count = [
    split.count() for split in (training, validation, test)
]

print(f"Training count: {training_count}")
print(f"Validation count: {validation_count}")
print(f"Test count: {test_count}")

[Stage 1305:>                                                       (0 + 1) / 1]

Training count: 2525
Validation count: 661
Test count: 1007


                                                                                

In [50]:
# Peek at distinct customers present in the training split (show() prints the top 20 rows).
training.select("customer_id").distinct().show()

[Stage 1311:>                                                       (0 + 1) / 1]

+-----------+
|customer_id|
+-----------+
|      17396|
|     240731|
|     416779|
|     538167|
|     665033|
|     753990|
|     887722|
|    1138029|
|    1144861|
|    1451822|
|    1467085|
|    1501941|
|    1569901|
|    1945409|
|    2052299|
|    2098369|
|    2299179|
|    2465423|
|    2466460|
|    2623619|
+-----------+
only showing top 20 rows



                                                                                

## Validación con un usuario

Validamos cuántos datos tiene un usuario en total, cuántos quedaron en training, test y validation.


In [51]:
# Rows for the sample customer in the full dataset.
data.where(data.customer_id == 17396).count()

                                                                                

4

In [52]:
# Rows for the sample customer that stayed in training.
training.where(training.customer_id == 17396).count()

                                                                                

3

In [53]:
# Rows for the sample customer that went to validation.
validation.where(validation.customer_id == 17396).count()

                                                                                

0

In [54]:
# Rows for the sample customer that went to test.
test.where(test.customer_id == 17396).count()

                                                                                

1

## Función para entrenar el modelo

Esta función recibe los hiperparámetros `maxIter`, `regParam` y `rank`. Esto se debe a que más adelante se hará un grid search para encontrar la mejor combinación de parámetros.


In [55]:
def train_model(maxIter=15, regParam=0.1, rank=15, training_data=None):
    """Fit an ALS collaborative-filtering model on review ratings.

    Args:
        maxIter: Value passed to ALS as ``maxIter``.
        regParam: Value passed to ALS as ``regParam``.
        rank: Value passed to ALS as ``rank``. Inside this function the name
            shadows the ``rank`` window function imported from
            ``pyspark.sql.functions``; that function is not needed here.
        training_data: DataFrame to fit on, with ``customer_id``, ``item_id``
            and ``star_rating`` columns. When omitted, the notebook-level
            ``training`` DataFrame is used, so existing ``train_model(...)``
            calls behave as before.

    Returns:
        The fitted model returned by ``ALS.fit``.
    """
    if training_data is None:
        # Backward-compatible default: fall back to the notebook-global split.
        training_data = training

    als = ALS(
        maxIter=maxIter,
        regParam=regParam,
        userCol="customer_id",
        itemCol="item_id",
        ratingCol="star_rating",
        seed=42,  # fixed seed so repeated fits are comparable
        nonnegative=True,
        rank=rank,
        # Rows that cannot be scored are dropped at prediction time
        # instead of yielding NaN predictions.
        coldStartStrategy="drop",
    )
    return als.fit(training_data)

In [56]:
# Fit a baseline model with the default hyperparameters (maxIter=15, regParam=0.1, rank=15).
model = train_model()

                                                                                

In [57]:
def get_metrics(dataset, model):
    """Score ``dataset`` with ``model`` and report RMSE and MAE.

    Predictions are compared against the ``star_rating`` column. The number of
    predictions and both metrics are printed as well as returned.

    Args:
        dataset: DataFrame to score; must contain ``star_rating`` plus the
            columns the model needs for ``transform``.
        model: Fitted model exposing ``transform(dataset)``.

    Returns:
        A ``(rmse, mae)`` tuple.
    """
    # Three actions run on the predictions (count, RMSE, MAE). Cache them so
    # the model scoring is computed once instead of once per action.
    predictions = model.transform(dataset).cache()
    try:
        print(f"Predictions count: {predictions.count()}")

        evaluator_rmse = RegressionEvaluator(
            metricName="rmse", labelCol="star_rating", predictionCol="prediction"
        )
        rmse = evaluator_rmse.evaluate(predictions)
        print(f"Root-mean-square error = {rmse}")

        evaluator_mae = RegressionEvaluator(
            metricName="mae", labelCol="star_rating", predictionCol="prediction"
        )
        mae = evaluator_mae.evaluate(predictions)
        print(f"Mean absolute error = {mae}")

        return rmse, mae
    finally:
        # Release the cached predictions even if an evaluation fails.
        predictions.unpersist()

In [58]:
# Evaluate on the test split.
# NOTE(review): the recorded output shows only 10 predictions for 1007 test rows;
# presumably coldStartStrategy="drop" discards rows whose item never appears in
# training, so these metrics rest on very few rows. Confirm.
get_metrics(test, model)

Predictions count: 10


                                                                                

Root-mean-square error = 3.362075552166499


[Stage 1573:>                                                       (0 + 1) / 1]

Mean absolute error = 3.1968454718589783


                                                                                

(3.362075552166499, 3.1968454718589783)

In [59]:
# Evaluate on the validation split.
# NOTE(review): the recorded output shows only 13 predictions for 661 validation
# rows; presumably most rows are dropped by coldStartStrategy="drop". Confirm.
get_metrics(validation, model)

                                                                                

Predictions count: 13


                                                                                

Root-mean-square error = 2.951945224498374


[Stage 1807:>                                                       (0 + 1) / 1]

Mean absolute error = 2.5373427936663995


                                                                                

(2.951945224498374, 2.5373427936663995)