In [91]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.ml.recommendation import ALS
from pyspark.sql import SparkSession

from pyspark.sql.functions import col, rank, expr
from pyspark.sql.window import Window
import itertools

In [92]:
spark = (
    SparkSession.builder.appName("Collaborative Filtering")  # type: ignore
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.4")
    .config("fs.s3a.endpoint", "s3.us-east-2.amazonaws.com")
    .config(
        "fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    )
    .getOrCreate()
)

In [93]:
# Leer el archivo Parquet
data = spark.read.parquet("s3a://amazon-reviews-eafit/sample-for-model/")

In [94]:
data.printSchema()

root
 |-- customer_id: integer (nullable = true)
 |-- product_id: string (nullable = true)
 |-- star_rating: float (nullable = true)
 |-- category: string (nullable = true)
 |-- review_date: string (nullable = true)
 |-- verified_purchase: string (nullable = true)
 |-- review_id: string (nullable = true)
 |-- product_title: string (nullable = true)
 |-- product_category: string (nullable = true)



In [95]:
data.count()

                                                                                

4184

In [96]:
indexer = StringIndexer(inputCol="product_id", outputCol="item_id")

indexer_model = indexer.fit(data)

inverter = IndexToString(
    inputCol="item_id", outputCol="original_item_id", labels=indexer_model.labels
)

inverter.write().overwrite().save("./inverter")

                                                                                

In [97]:
data = indexer_model.transform(data)

In [98]:
loaded_inverter = IndexToString.load("./inverter")

## Split de la data

Se hace un sampleo por customer id. Se toma el 80% de los customer id para el train y el 20% restante para el test.
Se repite el proceso para obtener el conjunto de validación partiendo de la data de training.


In [99]:
def split_data(data, percent_items_to_mask=0.2):
    user_window = Window.partitionBy("customer_id").orderBy(col("product_id").desc())
    data_processed = data.withColumn(
        "number_of_products", expr("count(*) over (partition by customer_id)")
    )
    data_processed = data_processed.withColumn(
        "number_of_products_to_mask",
        (col("number_of_products") * percent_items_to_mask).cast("int"),
    )
    data_processed = data_processed.withColumn("product_rank", rank().over(user_window))

    training = data_processed.filter(
        col("product_rank") > col("number_of_products_to_mask")
    )
    test = data_processed.filter(
        col("product_rank") <= col("number_of_products_to_mask")
    )

    return training, test

In [100]:
training, test = split_data(data, percent_items_to_mask=0.1)
training, validation = split_data(training, percent_items_to_mask=0.1)

training_count = training.count()
validation_count = validation.count()
test_count = test.count()

print(f"Training count: {training_count}")
print(f"Validation count: {validation_count}")
print(f"Test count: {test_count}")

[Stage 24671:>                                                      (0 + 2) / 2]

Training count: 3819
Validation count: 169
Test count: 196


                                                                                

In [101]:
training.select("customer_id").distinct().show()

[Stage 24677:>                                                      (0 + 2) / 2]

+-----------+
|customer_id|
+-----------+
|      76286|
|      93131|
|     205422|
|     337529|
|     583137|
|     722290|
|     890674|
|    1176852|
|    1498833|
|    1553047|
|    1671892|
|    1735873|
|    2008691|
|    2039136|
|    2040014|
|    2127453|
|    2156883|
|    2177134|
|    2531809|
|    2561699|
+-----------+
only showing top 20 rows



                                                                                

## Validación con un usuario

Validamos cuántos datos tiene un usuario en total, cuántos quedaron en training, test y validation.


In [102]:
data[data["customer_id"] == 76286].count()

                                                                                

59

In [103]:
training[training["customer_id"] == 76286].count()

                                                                                

49

In [104]:
validation[validation["customer_id"] == 76286].count()

                                                                                

5

In [105]:
test[test["customer_id"] == 76286].count()

                                                                                

5

## Función para entrenar el modelo

Esta función recibe dos parámetros: `maxIter` y `regParam`. Esto debido a que más adelante se hará un grid search para ver la mejor combinación de parámetros.


In [106]:
def train_model(maxIter=5, regParam=0.1, rank=10):
    als = ALS(
        maxIter=maxIter,
        regParam=regParam,
        userCol="customer_id",
        itemCol="item_id",
        ratingCol="star_rating",
        seed=42,
        nonnegative=True,
        rank=rank,
        coldStartStrategy="drop",
    )
    model = als.fit(training)
    return model

In [107]:
model = train_model()

                                                                                

In [108]:
def get_metrics(dataset, model):
    predictions = model.transform(dataset)

    print(f"Predictions count: {predictions.count()}")
    evaluator_rmse = RegressionEvaluator(
        metricName="rmse", labelCol="star_rating", predictionCol="prediction"
    )
    rmse = evaluator_rmse.evaluate(predictions)
    print(f"Root-mean-square error = {rmse}")

    evaluator_mae = RegressionEvaluator(
        metricName="mae", labelCol="star_rating", predictionCol="prediction"
    )
    mae = evaluator_mae.evaluate(predictions)
    print(f"Mean absolute error = {mae}")

    return rmse, mae

In [109]:
get_metrics(test, model)

                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.8107639193596663




Mean absolute error = 3.6978548765182495


                                                                                

(3.8107639193596663, 3.6978548765182495)

In [110]:
get_metrics(validation, model)

                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.183111697187393
Mean absolute error = 4.138520240783691


                                                                                

(4.183111697187393, 4.138520240783691)

In [111]:
parameters = {
    "maxIter": [5, 10, 15],
    "regParam": [0.001, 0.01, 0.1],
    "rank": [1, 5, 10, 15, 20],
}
param_combinations = list(itertools.product(*parameters.values()))
tuning_parameters = [
    {"maxIter": maxIter, "regParam": regParam, "rank": rank}
    for maxIter, regParam, rank in param_combinations
]

In [112]:
corresponding_rmse, best_mae, best_parameters = float("inf"), float("inf"), None

for parameters_combination in tuning_parameters:
    print(f"Parameters: {parameters_combination}")
    model = train_model(**parameters_combination)
    rmse, mae = get_metrics(validation, model)
    print("-----------------------------------------")
    if mae < best_mae:
        best_mae = mae
        corresponding_rmse = rmse
        best_parameters = parameters_combination

Parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.9999550755487627


                                                                                

Mean absolute error = 2.3331985473632812
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 0.7040036770265027


                                                                                

Mean absolute error = 0.647675355275472
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.442838768943967


                                                                                

Mean absolute error = 4.424139300982158
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.124614782057183


                                                                                

Mean absolute error = 4.090181152025859
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.7141686935634044


                                                                                

Mean absolute error = 3.7124000787734985
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.01, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.999589635741288


                                                                                

Mean absolute error = 2.3321011861165366
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.01, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 1.0310413006586032


                                                                                

Mean absolute error = 0.8423579533894857
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.01, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.389662417994656


                                                                                

Mean absolute error = 4.3667657772699995
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.01, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.9825858731779937


                                                                                

Mean absolute error = 3.9085830052693686
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.01, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.6567692713115014


                                                                                

Mean absolute error = 3.6559038956960044
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.1, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.998836596938872


                                                                                

Mean absolute error = 2.3298346201578775
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.1, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.5061240312220514


                                                                                

Mean absolute error = 2.486309051513672
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.1, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.183111697187393


                                                                                

Mean absolute error = 4.138520240783691
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.1, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.637357563474419


                                                                                

Mean absolute error = 3.3218613465627036
-----------------------------------------
Parameters: {'maxIter': 5, 'regParam': 0.1, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.4766801807606798


                                                                                

Mean absolute error = 3.476507822672526
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.001, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.99984523680335


                                                                                

Mean absolute error = 2.3328688939412436
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.001, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 0.6938335762143409


                                                                                

Mean absolute error = 0.6785397529602051
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.001, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.436164034193167


                                                                                

Mean absolute error = 4.416965802510579
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.001, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.104895412418584


                                                                                

Mean absolute error = 4.066122810045878
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.001, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.706713357214028


                                                                                

Mean absolute error = 3.7050787607828775
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.01, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.998530045163354


                                                                                

Mean absolute error = 2.3289098739624023
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.01, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 1.4359664921506516


                                                                                

Mean absolute error = 1.338454246520996
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.01, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.3458611538903265


                                                                                

Mean absolute error = 4.319109042485555
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.01, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.8899204282443014


                                                                                

Mean absolute error = 3.777902046839396
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.01, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.6129521556747095


                                                                                

Mean absolute error = 3.612561027208964
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.1, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.988458761076085


                                                                                

Mean absolute error = 2.297830899556478
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.1, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.711224306331523


                                                                                

Mean absolute error = 2.699795961380005
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.1, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.121508585177505


                                                                                

Mean absolute error = 4.068540891011556
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.1, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.582987345791004


                                                                                

Mean absolute error = 3.206581711769104
-----------------------------------------
Parameters: {'maxIter': 10, 'regParam': 0.1, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.4340819600750914


                                                                                

Mean absolute error = 3.433438460032145
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.001, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.9997355542234994


                                                                                

Mean absolute error = 2.3325395584106445
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.001, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 0.7063828566357221


                                                                                

Mean absolute error = 0.7059003512064616
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.001, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.429731434441888


                                                                                

Mean absolute error = 4.410045027732849
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.001, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.086602577432527


                                                                                

Mean absolute error = 4.043509999910991
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.001, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.699599725831476


                                                                                

Mean absolute error = 3.6980881690979004
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.01, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.997462281379828


                                                                                

Mean absolute error = 2.3256794611612954
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.01, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 1.7038016215066183


                                                                                

Mean absolute error = 1.6375991503397624
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.01, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.312230711141637


                                                                                

Mean absolute error = 4.28226097424825
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.01, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.828307905470626


                                                                                

Mean absolute error = 3.683886965115865
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.01, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.5813216992195147


                                                                                

Mean absolute error = 3.5811583201090493
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.1, 'rank': 1}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.9773260849300396


                                                                                

Mean absolute error = 2.2617225646972656
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.1, 'rank': 5}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 2.7875285291060297


                                                                                

Mean absolute error = 2.7797346115112305
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.1, 'rank': 10}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 4.086077795436402


                                                                                

Mean absolute error = 4.027838468551636
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.1, 'rank': 15}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.5566283861521875


                                                                                

Mean absolute error = 3.155261516571045
-----------------------------------------
Parameters: {'maxIter': 15, 'regParam': 0.1, 'rank': 20}


                                                                                

Predictions count: 3


                                                                                

Root-mean-square error = 3.4083183575771763




Mean absolute error = 3.4071412086486816
-----------------------------------------


                                                                                

In [113]:
print(f"Best parameters: {best_parameters}")
print(f"Best MAE: {mae}")
print(f"RMSE corresponding to the best MAE: {corresponding_rmse}")

Best parameters: {'maxIter': 5, 'regParam': 0.001, 'rank': 5}
Best MAE: 3.4071412086486816
RMSE corresponding to the best MAE: 0.7040036770265027
