In [None]:
import os
from pyspark.sql.types import StructType, StructField
from pyspark.sql.types import DoubleType, IntegerType, StringType, LongType

In [None]:
# Postgres connection settings used by the JDBC read and write cells.
# Values are taken from environment variables so real credentials never have
# to be stored in the notebook; the literals are placeholder fallbacks only.
driver = os.environ.get("POSTGRES_DRIVER", "driver")
url = os.environ.get("POSTGRES_URL", "POSTGRES_URL")
table = os.environ.get("POSTGRES_TABLE", "table_name")
user = os.environ.get("POSTGRES_USER", "postgres_username")
password = os.environ.get("POSTGRES_PASSWORD", "postgres_password")
full_url = os.environ.get("POSTGRES_FULL_URL", "full_postgres_url")

In [None]:
# Column names and Spark types expected for the Postgres reviews table.
reviews_schema = StructType([
    StructField(column_name, column_type)
    for column_name, column_type in (
        ("col_id", IntegerType()),
        ("reviewid", StringType()),
        ("userid", IntegerType()),
        ("business_id", StringType()),
        ("rating", DoubleType()),
        ("bid", IntegerType()),
        ("username", StringType()),
    )
])

In [None]:
# Load the reviews table over JDBC with SSL required.
# `inferSchema` is not a JDBC-source option and JDBC rejects a user-supplied
# StructType, so the declared schema was never applied. The JDBC source's
# supported override is `customSchema`, a DDL-style "name type, ..." string;
# build it from reviews_schema so the declared column types are used.
reviews_ddl = ", ".join(
    "{} {}".format(field.name, field.dataType.simpleString())
    for field in reviews_schema.fields
)
remote_table = (
    spark.read.format("jdbc")
    .option("driver", driver)
    .option("customSchema", reviews_ddl)
    .option("url", url)
    .option("dbtable", table)
    .option("user", user)
    .option("password", password)
    .option("ssl", True)
    .option("sslmode", "require")
    .load()
)

In [None]:
# Show the raw reviews table as loaded (a bare frame already has every column).
display(remote_table)

col_id,reviewid,userid,business_id,rating,bid,username
1,un4FKbZLLCgeqe_M2UYKHA,100000111,DFeEJf8h04q3KwRVWUsoMQ,1,6172,kaggle
2,40eEBbACmuOjKPG5GlCeBQ,100000140,-fiUXzkxRfbHY9TKWwuptw,3,840,kaggle
3,n-Zp8ByGBnqlnlKd_GD-vQ,100001009,RA5ubX3nRc_1lghtcQC4nw,5,6771,kaggle
4,Zc1E1kn5qVPgCt-GBxKRsA,100001011,O8sBSjxL8hQbA41lKtcoJg,5,4526,kaggle
5,p-x8nyO4crWMGmxubTNcBg,100001602,Cy8XYYDrZ5wd3Bq-toXMsg,5,1465,kaggle
6,0xQAItDYWX2CHEmqUY9e9Q,100001701,E7VRsyNQWPteH6_qI4j7bw,4,1111,kaggle
7,rtkyXejQbZSgCqUX3lEb5Q,100001701,XTtNTWH_Nqv27RC7OtS7dQ,5,4111,kaggle
8,rdgGUoCoKwFr_EHHn5UMyw,100001701,ufmokEGxGqEWIdvIVsNg_Q,5,2253,kaggle
9,Bhj9s2-q_k0_kulcAzZRmA,100001701,XXW_OFaYQkkGOGniujZFHg,5,11895,kaggle
10,93bU2vXw8qbYSU-rfvYLKw,100001714,M0pkmBUi_CI0qrzN7ee80Q,5,2293,kaggle


In [None]:
# Exploratory: isolate the review whose col_id is 3.
df = remote_table.where(remote_table["col_id"] == 3)

In [None]:
# Show the row(s) selected by the col_id == 3 filter in the previous cell.
display(df)

col_id,reviewid,userid,business_id,rating,bid,username
3,n-Zp8ByGBnqlnlKd_GD-vQ,100001009,RA5ubX3nRc_1lghtcQC4nw,5,6771,kaggle


In [None]:
# NOTE(review): this result is immediately overwritten by the next cell
# (`df = remote_table`) and is never used; it is only an exploratory filter.
df = remote_table.where(remote_table["col_id"] >= 500000)

In [None]:
# Use the full reviews table for modelling; this replaces the exploratory
# filters assigned to `df` in the cells above.
df = remote_table

In [None]:
# Show the DataFrame that is split into training/validation/test below.
display(df)

col_id,reviewid,userid,business_id,rating,bid,username
1,un4FKbZLLCgeqe_M2UYKHA,100000111,DFeEJf8h04q3KwRVWUsoMQ,1,6172,kaggle
2,40eEBbACmuOjKPG5GlCeBQ,100000140,-fiUXzkxRfbHY9TKWwuptw,3,840,kaggle
3,n-Zp8ByGBnqlnlKd_GD-vQ,100001009,RA5ubX3nRc_1lghtcQC4nw,5,6771,kaggle
4,Zc1E1kn5qVPgCt-GBxKRsA,100001011,O8sBSjxL8hQbA41lKtcoJg,5,4526,kaggle
5,p-x8nyO4crWMGmxubTNcBg,100001602,Cy8XYYDrZ5wd3Bq-toXMsg,5,1465,kaggle
6,0xQAItDYWX2CHEmqUY9e9Q,100001701,E7VRsyNQWPteH6_qI4j7bw,4,1111,kaggle
7,rtkyXejQbZSgCqUX3lEb5Q,100001701,XTtNTWH_Nqv27RC7OtS7dQ,5,4111,kaggle
8,rdgGUoCoKwFr_EHHn5UMyw,100001701,ufmokEGxGqEWIdvIVsNg_Q,5,2253,kaggle
9,Bhj9s2-q_k0_kulcAzZRmA,100001701,XXW_OFaYQkkGOGniujZFHg,5,11895,kaggle
10,93bU2vXw8qbYSU-rfvYLKw,100001714,M0pkmBUi_CI0qrzN7ee80Q,5,2293,kaggle


In [None]:
# 60/20/20 split into training, validation and test sets. A fixed seed makes
# the split, and every metric computed from it, reproducible on re-run.
SPLIT_SEED = 42
(training, validation, test) = df.randomSplit([0.6, 0.2, 0.2], seed=SPLIT_SEED)

# caching data to cut down on cross-validation time later
training.cache()
validation.cache()
test.cache()

In [None]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder

# Template ALS estimator for explicit ratings with non-negative factors,
# mapping users -> `userid`, items -> `bid`, ratings -> `rating`.
als_dt = ALS(
    userCol="userid",
    itemCol="bid",
    ratingCol="rating",
    maxIter=5,
    regParam=0.25,
    implicitPrefs=False,
    nonnegative=True,
    coldStartStrategy="drop",
)

In [None]:
def tune_ALS(train_data, validation_data, maxIter, regParams, ranks, als_dt, evaluator=None):
    """Grid-search ALS over ``ranks`` x ``regParams`` and return the best fitted model.

    Every (rank, regParam) combination is fitted on ``train_data`` and scored
    on ``validation_data``. The candidate with the lowest score wins; the
    default score is RMSE.

    Args:
        train_data: DataFrame every candidate model is fitted on.
        validation_data: DataFrame every candidate model is scored on.
        maxIter: iteration count applied to every candidate.
        regParams: iterable of regularization values to try.
        ranks: iterable of latent-factor counts to try.
        als_dt: template ALS estimator. It is copied for each candidate, so
            the caller's estimator is not modified.
        evaluator: optional object with ``evaluate(predictions)`` where lower
            is better. Defaults to an RMSE ``RegressionEvaluator`` on the
            ``rating`` / ``prediction`` columns.

    Returns:
        The best fitted model, or ``None`` if the grid is empty or no score
        was smaller than infinity (e.g. every score was NaN).
    """
    if evaluator is None:
        evaluator = RegressionEvaluator(metricName="rmse",
                                        labelCol="rating",
                                        predictionCol="prediction")
    # Materialize once so a one-shot iterator is not exhausted by the outer loop.
    reg_values = list(regParams)
    min_error = float('inf')
    best_rank = -1
    best_regularization = 0
    best_model = None
    for rank in ranks:
        for reg in reg_values:
            # Configure a copy so the caller's template estimator keeps its params.
            als = als_dt.copy().setMaxIter(maxIter).setRank(rank).setRegParam(reg)
            # Fit and score on the data passed in, not on notebook globals.
            model = als.fit(train_data)
            predictions = model.transform(validation_data)
            rmse = evaluator.evaluate(predictions)
            print('{} latent factors and regularization = {}: '
                  'validation RMSE is {}'.format(rank, reg, rmse))
            if rmse < min_error:
                min_error = rmse
                best_rank = rank
                best_regularization = reg
                best_model = model
    print('\nThe best model has {} latent factors and '
          'regularization = {}'.format(best_rank, best_regularization))
    return best_model

In [None]:
# Final ALS configuration: rank 50, 2 iterations, regParam 0.20, with the same
# column mapping and flags as the tuning template.
als = ALS(
    rank=50,
    maxIter=2,
    regParam=0.20,
    userCol="userid",
    itemCol="bid",
    ratingCol="rating",
    implicitPrefs=False,
    nonnegative=True,
    coldStartStrategy="drop",
)

In [None]:
# Fit the final ALS configuration on the training split.
model = als.fit(dataset=training)

In [None]:
# Top 10 recommended businesses (bid) for every user seen by the model.
ALS_recommendations = model.recommendForAllUsers(10)

In [None]:
# Spot-check the top-10 recommendation list of a single user.
display(ALS_recommendations.where(ALS_recommendations.userid == 999998859))

userid,recommendations
999998859,"List(List(63945, 4.911867), List(63947, 3.929494), List(63951, 3.929494), List(10334, 3.138642), List(9434, 3.075673), List(305, 3.0444708), List(10862, 3.039565), List(10368, 3.0060344), List(8139, 2.9573507), List(7610, 2.9526193))"


In [None]:
# Flatten each user's recommendation array into one (userid, bid, prediction)
# row per recommended business, via a temporary SQL view.
# createOrReplaceTempView is the non-deprecated replacement for registerTempTable
# and is safe to re-run (the view is replaced rather than duplicated).
ALS_recommendations.createOrReplaceTempView("ALS_recs_temp")
clean_recs = spark.sql("""SELECT userid,
                            bIds_and_ratings.bid AS bid,
                            bIds_and_ratings.rating AS prediction
                        FROM ALS_recs_temp LATERAL VIEW explode(recommendations) exploded_table AS bIds_and_ratings""")

In [None]:
# Preview recommendations whose (userid, bid) pair has no rating in the source table.
(
    clean_recs
    .join(remote_table, ["userid", "bid"], "left")
    .filter(remote_table.rating.isNull())
    .show()
)

# Keep only the recommendation columns.
clean_recs_filtered = clean_recs.select(["userid", "bid", "prediction"])

In [None]:
# Keep only recommendations for (userid, bid) pairs that have no review row in
# the source table. A left-anti join on the key columns expresses this directly.
# It also does not carry the review table's columns along as all-NULL padding.
# Note: a pair present in the table with a NULL rating counts as already reviewed.
new_res = clean_recs_filtered.join(
    remote_table.select("userid", "bid"),
    on=["userid", "bid"],
    how="left_anti",
)

In [None]:
# Reduce to the recommendation columns, keep only users whose id is above
# 999998859, and show the result.
new_res_fnl = new_res.select(['userid', 'bid', 'prediction'])
new_res_users = new_res_fnl.where(new_res_fnl.userid > 999998859)
new_res_use = new_res_users.select(['userid', 'bid', 'prediction'])
display(new_res_use)

userid,bid,prediction
999998860,1663,2.602079
999998860,2668,2.9864202
999998860,6824,2.548663
999998860,7610,2.7829602
999998860,7617,2.934951
999998860,7652,2.5675263
999998860,10269,2.6922617
999998861,3570,2.819012
999998861,6482,2.7527943
999998861,6772,3.2657988


In [None]:
# How many of the selected recommendations predict a rating above 4?
new = new_res_use.where(new_res_use.prediction > 4)
print(new.count())

In [None]:
# Persist the selected recommendations to the `new_recs` table over JDBC (SSL
# required), in overwrite mode with truncate=True.
(
    new_res_use.write
    .format("jdbc")
    .options(
        driver=driver,
        url=url,
        dbtable='new_recs',
        user=user,
        password=password,
        ssl=True,
        sslmode="require",
        truncate=True,
    )
    .mode("overwrite")
    .save()
)