In [1]:
import os

os.environ["PYSPARK_PYTHON"]="/home/jovyan/notebooks/env39/bin/python3.9"

In [2]:
import numpy as np
import pandas as pd
import nmslib
import scann

from pyspark.sql import functions as sf, types as st, SparkSession, Window
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
from pyspark.ml.linalg import DenseVector
from pyspark.ml.feature import IndexToString, StringIndexer
from replay.session_handler import State
from replay.data_preparator import DataPreparator
from replay.splitters import DateSplitter
from replay.utils import get_log_info, get_top_k_recs
from replay.models import ALSWrap
from replay.model_handler import save, load
from replay.metrics import Coverage, HitRate, NDCG, MAP, Precision
from replay.experiment import Experiment
from rs_datasets import MillionSongDataset, MovieLens

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
spark_sess = (
    SparkSession
        .builder
        .master("local[6]")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.kryoserializer.buffer.max", "512m")
        .config("spark.driver.memory", "64g")
        .config("spark.executor.memory", "64g")
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .config("spark.local.dir", "/tmp")
        .getOrCreate()
)

spark = State(spark_sess).session
spark.sparkContext.setLogLevel("ERROR")

In [4]:
cores=6
K = 5
SEED=1234

In [5]:
%%time
# msd = MillionSongDataset()
data = MovieLens("1m")

CPU times: user 4.39 s, sys: 192 ms, total: 4.58 s
Wall time: 468 ms


In [6]:
preparator = DataPreparator()
log, _, _ = preparator(data.ratings, mapping={"relevance": "rating"})#, "user_id": "user_idx", "item_id": "item_idx"})
print(get_log_info(log))

total lines: 1000209, total users: 6040, total items: 3706


In [7]:
log = log.repartition(cores).cache()
log.write.mode('overwrite').format('noop').save()

In [8]:
only_positives_log = log.filter(sf.col('relevance') >= 3).withColumn('relevance', sf.lit(1))

In [9]:
# train/test split 
train_spl = DateSplitter(
    test_start=0.2,
    drop_cold_items=True,
    drop_cold_users=True,
)
train, test = train_spl.split(only_positives_log)
print('train info:\n', get_log_info(train))
print('test info:\n', get_log_info(test))

train info:
 total lines: 669181, total users: 5397, total items: 3569
test info:
 total lines: 86542, total users: 1139, total items: 3279


In [10]:
train = train.repartition(cores).cache()
test = test.repartition(cores).cache()
train.write.mode('overwrite').format('noop').save()
test.write.mode('overwrite').format('noop').save()

## 1. ALS model training

In [11]:
%%time
als = ALSWrap(rank=256, seed=SEED)
als.fit(log=train)

CPU times: user 26.5 ms, sys: 10.8 ms, total: 37.3 ms
Wall time: 2min 6s


In [12]:
save(als, '/tmp/als_ml1m')
als = load('/tmp/als_ml1m')

In [13]:
import time
def exec_time(start, end):
    diff_time = end - start
    return round(diff_time, 4)

time_dict = {}
index_time_dict = {}
index_time_dict['ALS'] = 0

In [14]:
%%time
start = time.time()
predict = als.predict(log=train, k=K, users=test.select('user_idx').distinct(), filter_seen_items=True)
predict = predict.cache()
predict.count()
end = time.time()
time_dict['ALS'] = exec_time(start,end)



CPU times: user 46 ms, sys: 12.9 ms, total: 58.9 ms
Wall time: 3min 22s


### 2. Evaluate prediction quality

In [15]:
prediction_quality = Experiment(test, {NDCG(): K,
                            MAP() : K,
                            HitRate(): [1, K],
                           })

In [16]:
%%time
prediction_quality.add_result("ALS", predict)
prediction_quality.results.sort_values('NDCG@5', ascending=False)

CPU times: user 78.7 ms, sys: 33.8 ms, total: 112 ms
Wall time: 17.8 s


Unnamed: 0,HitRate@1,HitRate@5,MAP@5,NDCG@5
ALS,0.279192,0.56014,0.179415,0.251115


### 3. Get features from ALS

In [17]:
def get_numpy_ids_vectors_from_als(id_vector_spark_df, id_name='item_id', vector_col_name='item_factors'):
    vectors = id_vector_spark_df.toPandas()
    ids = vectors[id_name].to_numpy()
    vectors = vectors[vector_col_name].to_numpy()
    return vectors, ids

In [18]:
%%time
user_vectors, _ = als.get_features(test.select('user_idx').distinct())
user_vectors_np, user_ids_np = get_numpy_ids_vectors_from_als(user_vectors, id_name='user_idx', vector_col_name='user_factors')

CPU times: user 6.85 ms, sys: 4.93 ms, total: 11.8 ms
Wall time: 299 ms


In [19]:
%%time
item_vectors, _ = als.get_features(train.select('item_idx').distinct())
item_vectors_np, item_ids_np = get_numpy_ids_vectors_from_als(item_vectors, id_name='item_idx', vector_col_name='item_factors')

CPU times: user 16.2 ms, sys: 5.1 ms, total: 21.3 ms
Wall time: 266 ms


### index with default parameters


##### Using hnsw index from nmslib as is

In [20]:
%%time
start = time.time()
index = nmslib.init(method='hnsw', space='negdotprod', data_type=nmslib.DataType.DENSE_VECTOR)
index.addDataPointBatch(data=np.stack(item_vectors_np))
index.createIndex()
end = time.time()
index_time_dict['HNSW'] = exec_time(start,end)

CPU times: user 1.36 s, sys: 108 ms, total: 1.47 s
Wall time: 162 ms


In [21]:
def get_neighbours(index, vectors, user_ids_list, item_ids_list, k):
    """
    - find nearest items based on user vector
    - convert to spark and process columns to get `user_id, item_id, relevance` columns
    - replace item numbers in index with item ids
    """
    neighbours = index.knnQueryBatch(np.stack(vectors), k=k)
    pd_res = pd.DataFrame(neighbours, columns=['item_idx', 'distance'])
    pd_res['user_idx'] = user_ids_list
    spark_res = spark.createDataFrame(pd_res)
    spark_res = spark_res.withColumn('zip_exp', sf.explode(sf.arrays_zip('item_idx', 'distance'))).select('user_idx', 'zip_exp')
    spark_res = spark_res.withColumn('item_idx', sf.col('zip_exp.item_idx'))
    spark_res = spark_res.withColumn('distance', sf.col('zip_exp.distance'))
    spark_res = (spark_res.withColumn('relevance',  sf.lit(-1.) * sf.col('distance'))
                 .select('user_idx', 'item_idx', 'relevance')
                )
    # ids_mapping = spark.createDataFrame(list(zip(range(len(list(item_ids_list))),
    #                                              list(item_ids_list))),
    #                                     schema="item_idx int, item_id string")
    ids_mapping = pd.DataFrame(list(zip(range(len(list(item_ids_list))),
                                                 list(item_ids_list))), columns = ['item_idx', 'item_id'])
    ids_mapping = spark.createDataFrame(ids_mapping)
    # return ids_mapping
    spark_res = spark_res.join(ids_mapping, on='item_idx').drop('item_idx').orderBy('user_idx')
    
    def withColumnRenamed(existingName, newName): DataFrame
    spark_res = spark_res.withColumnRenamed("item_id","item_idx")

    
    
    return spark_res

In [22]:
def filter_seen(log, pred, k, id_type='idx'):
    """
    filter items seen in log and leave top-k most relevant
    """
    
    user_id = 'user_' + id_type
    item_id = 'item_' + id_type
    num_of_seen = (
            log.groupBy(user_id)
            .agg(sf.count(item_id).alias("seen_count"))
        )


    max_seen = num_of_seen.select(sf.max("seen_count")).collect()[0][0]

    recs = pred.withColumn(
        "temp_rank",
        sf.row_number().over(
            Window.partitionBy(user_id).orderBy(
                sf.col("relevance").desc()
            )
        ),
    ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))
    

    
    recs = (
        recs.join(num_of_seen, on=user_id, how="left")
        .fillna(0)
        .filter(
            sf.col("temp_rank") <= sf.col("seen_count") + sf.lit(k)
        )
        .drop("temp_rank", "seen_count")
    )
    
    recs = recs.join(log, on=[user_id, item_id], how="anti")
    return get_top_k_recs(recs, k, id_type=id_type)

In [23]:
max_items_to_retrieve = train.groupBy('user_idx').agg(sf.count('item_idx').alias('num_items')).select(sf.max('num_items')).collect()[0][0]

In [24]:
%%time
start = time.time()
ann_res = get_neighbours(index, user_vectors_np, user_ids_np, item_ids_np, K + max_items_to_retrieve)
ann_res = filter_seen(train, ann_res, K)
ann_res = ann_res.cache()
ann_res.count()
end = time.time()
time_dict['HNSW'] = exec_time(start,end)

CPU times: user 124 ms, sys: 26.7 ms, total: 150 ms
Wall time: 3.61 s


In [25]:
%%time
prediction_quality.add_result("HNSW", ann_res)
prediction_quality.results.sort_values('NDCG@5', ascending=False)

CPU times: user 82.1 ms, sys: 25.8 ms, total: 108 ms
Wall time: 9.69 s


Unnamed: 0,HitRate@1,HitRate@5,MAP@5,NDCG@5
HNSW,0.280948,0.561018,0.180984,0.252473
ALS,0.279192,0.56014,0.179415,0.251115


### MLlib LSH

In [26]:
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())

In [27]:
%%time

user_indexer = StringIndexer(
    inputCol="user_idx", outputCol="user_idx_indexed"
).fit(only_positives_log.select('user_idx').distinct())
item_indexer = StringIndexer(
    inputCol="item_idx", outputCol="item_idx_indexed"
).fit(only_positives_log.select('item_idx').distinct())

inv_user_indexer = IndexToString(
    inputCol="user_idx_indexed",
    outputCol="user_idx",
    labels=user_indexer.labels,
)
inv_item_indexer = IndexToString(
    inputCol="item_idx_indexed",
    outputCol="item_idx",
    labels=item_indexer.labels,
)

CPU times: user 24.2 ms, sys: 1.83 ms, total: 26 ms
Wall time: 716 ms


In [28]:
%%time

user_vectors_brp = (user_indexer.transform(user_vectors)
                    .withColumn('user_idx_indexed', sf.col('user_idx_indexed').cast("int"))
                    .drop('us`ber_idx')
                    .withColumnRenamed('user_idx_indexed', 'id')
                    .withColumn('features', list_to_vector_udf(sf.col('user_factors')))
                    .drop('user_factors')
                   )
item_vectors_brp = (item_indexer.transform(item_vectors)
                    .withColumn('item_idx_indexed', sf.col('item_idx_indexed').cast("int"))
                    .drop('item_idx')
                    .withColumnRenamed('item_idx_indexed', 'id')
                    .withColumn('features', list_to_vector_udf(sf.col('item_factors')))
                    .drop('item_factors')
                   )

CPU times: user 5.36 ms, sys: 2.44 ms, total: 7.8 ms
Wall time: 68.9 ms


In [29]:
# %%time
# user_vectors_brp.write.parquet(path='user_vectors_lsh.parquet', mode='overwrite')
# item_vectors_brp.write.parquet(path='item_vectors_lsh.parquet', mode='overwrite')

### Normalize vectors for better index performance and easier selection of bucketLength

In [30]:
@sf.udf(returnType=VectorUDT())
def norm_vector(first: DenseVector) -> DenseVector:
    return first * (1/float(first.dot(first))) ** 0.5

In [31]:
%%time
item_vectors_brp_n = item_vectors_brp.withColumn('features', norm_vector(sf.col('features')))
user_vectors_brp_n = user_vectors_brp.withColumn('features', norm_vector(sf.col('features')))


item_vectors_brp_n = item_vectors_brp_n.cache()
user_vectors_brp_n = user_vectors_brp_n.cache()
item_vectors_brp_n.count()
user_vectors_brp_n.count()

CPU times: user 23.5 ms, sys: 12 ms, total: 35.5 ms
Wall time: 4.36 s


1139

In [32]:
%%time
brp = BucketedRandomProjectionLSH(inputCol='features', seed=SEED, bucketLength=1, numHashTables=1)  #be ware with bucketLength
start = time.time()
brp_model = brp.fit(item_vectors_brp_n)
end = time.time()
index_time_dict['LSH'] = exec_time(start,end)

CPU times: user 3.77 ms, sys: 636 µs, total: 4.4 ms
Wall time: 76.2 ms


In [33]:
def convert_back(inv_user_indexer, inv_item_indexer, log):
    res = log
    if "user_idx_indexed" in log.columns:
        res = (
            inv_user_indexer.transform(res)
            .drop("user_idx_indexed")
            .withColumn("user_idx", sf.col("user_idx").cast(train.schema["user_idx"].dataType))
        )
    if "item_idx_indexed" in log.columns:
        res = (
            inv_item_indexer.transform(res)
            .drop("item_idx_indexed")
            .withColumn("item_idx", sf.col("item_idx").cast(train.schema["item_idx"].dataType))
        )
    return res

In [34]:
def get_pred_from_brp(fitted_brp):
    test_all = fitted_brp.approxSimilarityJoin(user_vectors_brp_n, 
                                          item_vectors_brp_n,
                                          threshold = 2.,
                                          distCol="EuclideanDistance").select(
    sf.col("datasetA.id").alias("user_idx_indexed"),
    sf.col("datasetB.id").alias("item_idx_indexed"),
    (sf.col("EuclideanDistance") * sf.lit(-1)).alias('relevance'))
    
    test_res = test_all.withColumn("temp_rank",
        sf.row_number().over(
            Window.partitionBy("user_idx_indexed").orderBy(
                sf.col("relevance").desc()
            )
        )
    ).filter(sf.col("temp_rank") <= sf.lit(max_items_to_retrieve + K))

    test_res = convert_back(inv_user_indexer, inv_item_indexer, test_res)
    test_res = test_res.cache()
    test_res_filter = get_top_k_recs(filter_seen(train, test_res, K), K)
    test_res = get_top_k_recs(test_res, K)
    test_res = test_res.cache()
    test_res.count()
    return test_res_filter, test_res

In [35]:
%%time
start = time.time()
brp_pred_filtered, brp_pred = get_pred_from_brp(brp_model)   #time depends on threshold
end=time.time()
time_dict['LSH'] = exec_time(start,end)

CPU times: user 322 ms, sys: 97.7 ms, total: 420 ms
Wall time: 18.4 s


In [36]:
prediction_quality.add_result("LSH", brp_pred_filtered)

## ScaNN

In [37]:
item_vectors_np_scann = np.array(list(item_vectors_np))
user_vectors_np_scann = np.array(list(user_vectors_np))

In [38]:
def compute_recall(neighbors, true_neighbors):
    total = 0
    for gt_row, row in zip(true_neighbors, neighbors):
        total += np.intersect1d(gt_row, row).shape[0]
    return total / true_neighbors.size

In [39]:
%%time
start = time.time()
searcher = scann.scann_ops_pybind.builder(item_vectors_np_scann, K + int(max_items_to_retrieve/20), "dot_product").score_brute_force().build()
end = time.time()
index_time_dict['SCANN'] = exec_time(start,end)

CPU times: user 5.5 ms, sys: 15.1 ms, total: 20.6 ms
Wall time: 9.76 ms


In [40]:
def get_neighbours_scann(searcher, vectors, user_ids_list, item_ids_list, k):
   
    neighbors, distances = searcher.search_batched(user_vectors_np_scann)
    pd_res = pd.DataFrame({'user_idx': np.repeat(user_ids_list, k),'item_idx': neighbors.flatten(), 'relevance': distances.flatten()})
    spark_res = spark.createDataFrame(pd_res)
    ids_mapping = pd.DataFrame(list(zip(range(len(list(item_ids_list))),
                                                 list(item_ids_list))), columns = ['item_idx', 'item_id'])    

    ids_mapping = spark.createDataFrame(ids_mapping)    
    spark_res = spark_res.join(ids_mapping, on='item_idx').drop('item_idx').orderBy('user_idx')
    def withColumnRenamed(existingName, newName): DataFrame
    spark_res = spark_res.withColumnRenamed("item_id","item_idx")
    return spark_res
    

In [41]:
%%time
start = time.time()
scann_res = get_neighbours_scann(searcher, user_vectors_np_scann, user_ids_np, item_ids_np, K + int(max_items_to_retrieve/20)) #K + max_items_to_retrieve
scann_res = filter_seen(train, scann_res, K)
scann_res = scann_res.cache()
scann_res.count()
end = time.time()
time_dict['SCANN'] = exec_time(start,end)

  Unsupported type in conversion from Arrow: uint32
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)


CPU times: user 1.78 s, sys: 46.9 ms, total: 1.83 s
Wall time: 4.66 s


In [42]:
prediction_quality.add_result("SCANN", scann_res)
prediction_quality.results.sort_values('NDCG@5', ascending=False)

Unnamed: 0,HitRate@1,HitRate@5,MAP@5,NDCG@5
HNSW,0.280948,0.561018,0.180984,0.252473
ALS,0.279192,0.56014,0.179415,0.251115
SCANN,0.274802,0.537313,0.17106,0.239966
LSH,0.23266,0.47849,0.13429,0.196782


In [43]:
res = prediction_quality.results

In [44]:
res['index_time'] = index_time_dict.values()

In [45]:
res['inference_time'] = time_dict.values()

In [46]:
res

Unnamed: 0,HitRate@1,HitRate@5,MAP@5,NDCG@5,index_time,inference_time
ALS,0.279192,0.56014,0.179415,0.251115,0.0,202.448
HNSW,0.280948,0.561018,0.180984,0.252473,0.1619,3.6091
LSH,0.23266,0.47849,0.13429,0.196782,0.0702,18.3989
SCANN,0.274802,0.537313,0.17106,0.239966,0.0097,4.6555
