In [3]:
import os

from lightfm import LightFM
from lightfm.evaluation import precision_at_k
from pyspark.sql import functions as sf, SparkSession, Window
from replay.session_handler import State
from replay.data_preparator import DataPreparator, Indexer
from replay.splitters import DateSplitter, UserSplitter
from replay.utils import get_log_info, get_top_k_recs
from replay.metrics import Coverage, HitRate, NDCG, MAP, Precision
from replay.experiment import Experiment
from rs_datasets import MillionSongDataset, MovieLens
from scipy.sparse import csr_matrix
from spark_gloo_lightfm import LightFMGlooWrap

In [2]:
# Point PySpark at the Python interpreter of the conda environment named in
# this path. This runs before the Spark session is built further down.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
os.environ["PYSPARK_PYTHON"] = "/opt/conda/envs/lightfm-gloo/bin/python3.8"

In [4]:
# Build (or reuse) a local Spark session running on 6 local threads.
# Options are applied in the order listed: Kryo serialization with a 512m
# buffer cap, 64g for driver and executor, 18 shuffle/default partitions,
# Arrow-backed pandas conversion, and /tmp as Spark's scratch directory.
spark_sess = (
    SparkSession.builder
    .master("local[6]")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "512m")
    .config("spark.driver.memory", "64g")
    .config("spark.executor.memory", "64g")
    .config("spark.sql.shuffle.partitions", "18")
    .config("spark.default.parallelism", "18")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.local.dir", "/tmp")
    .getOrCreate()
)

# Hand the session to RePlay's State and keep working with the session it
# exposes; only ERROR-level Spark log messages are shown from here on.
spark = State(spark_sess).session
spark.sparkContext.setLogLevel("ERROR")

In [5]:
# Notebook-wide constants.
cores = 6   # partition count for the log and world_size of the distributed trainer
K = 1000    # equals the largest metric cutoff evaluated below; presumably the rec-list length — TODO confirm
SEED = 22   # random seed

In [6]:
# Load the MovieLens dataset in its "10m" variant through rs_datasets;
# the ratings table is read from `data.ratings` in the next cell.
dataset_size = "10m"
data = MovieLens(dataset_size)

In [7]:
preparator = DataPreparator()

# Convert the ratings table into an interaction log (the "rating" column
# becomes "relevance"), rename the id columns to the *_idx names the rest of
# the notebook uses, then spread the log over `cores` partitions and cache it.
log = (
    preparator.transform(
        columns_mapping={
            "user_id": "user_id",
            "item_id": "item_id",
            "relevance": "rating",
            "timestamp": "timestamp",
        },
        data=data.ratings,
    )
    .withColumnRenamed("user_id", "user_idx")
    .withColumnRenamed("item_id", "item_idx")
    .repartition(cores)
    .cache()
)

# Writing to the no-op sink runs the full plan, so the cache is populated now
# rather than on first downstream use.
log.write.mode("overwrite").format("noop").save()

14-Feb-23 11:59:53, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.


In [8]:
# Treat interactions with relevance (rating) of 3 or more as positive feedback.
only_positives_log = log.where(sf.col("relevance") >= 3)

# Date-based split that drops cold users and items from the test part.
# NOTE(review): a float test_start is presumably a fraction of the log —
# confirm the exact meaning against the DateSplitter documentation.
train_spl = DateSplitter(test_start=0.2, drop_cold_items=True, drop_cold_users=True)
train, test = train_spl.split(only_positives_log)

# Report split sizes before the relevance column is overwritten.
print('train info:\n', get_log_info(train))
print('test info:\n', get_log_info(test))

# Binarize feedback: every remaining interaction gets relevance 1.
train = train.withColumn("relevance", sf.lit(1))
test = test.withColumn("relevance", sf.lit(1))

train info:
 total lines: 6593699, total users: 59263, total items: 8888
test info:
 total lines: 196480, total users: 3171, total items: 7482


In [9]:
# Mark both splits as cached, then run a no-op write on each so the cached
# data is computed here instead of during the first later action.
train, test = train.cache(), test.cache()
train.write.mode("overwrite").format("noop").save()
test.write.mode("overwrite").format("noop").save()

In [10]:
# Metric collector evaluated against the test split. NDCG, MAP, HitRate and
# Precision are each computed at the same list of cutoffs; every metric gets
# its own list object, as before.
prediction_quality = Experiment(
    test,
    {
        metric: [5, 10, 25, 100, 500, 1000]
        for metric in (NDCG(), MAP(), HitRate(), Precision())
    },
)


In [11]:
# Repartition the training log into `cores` partitions keyed by user_idx, so
# all interactions of a user land in the same partition. The partition count
# equals the world_size passed to the distributed trainer below.
train_repartitioned = train.repartition(cores, "user_idx")

In [12]:
# LightFM model trained with the WARP loss. The seed comes from the shared
# SEED constant instead of a repeated literal, so changing SEED in the
# configuration cell actually changes the model's random state.
model = LightFM(loss='warp', random_state=SEED, max_sampled=100, learning_rate=0.05)

In [14]:
# Wrap the LightFM model for distributed training on Spark.
# world_size is set to `cores`, the same value used as the partition count of
# train_repartitioned. (LightFMGlooWrap is imported in the notebook's import
# cell rather than here, so a fresh Restart & Run All shows every dependency
# up front.)
# NOTE(review): presumably one Gloo worker is started per partition — confirm
# against the spark_gloo_lightfm implementation.
lightfm_wrap = LightFMGlooWrap(
    model=model,
    world_size=cores,
    use_spark=True,
)

In [15]:
%%time
# Train for 30 epochs on the user-partitioned training log and time the cell.
# The returned object is kept as `wrapper`; later cells read the fitted
# LightFM model from `wrapper.model`.
wrapper = lightfm_wrap.fit_partial(train_repartitioned, verbose=True, epochs=30)

Spark executors launched
CPU times: user 70.6 ms, sys: 45 ms, total: 116 ms
Wall time: 2min 42s


In [22]:
# Bring the training log to the driver as a pandas frame.
pandas_log = train.toPandas().copy()

# Matrix width: largest item index seen in training, plus one.
items_count = pandas_log["item_idx"].max() + 1

# Sparse (users x items) matrix of training interactions, indexed directly by
# user_idx / item_idx with relevance as the stored value.
interactions_matrix_sparse = csr_matrix(
    (
        pandas_log["relevance"],
        (pandas_log["user_idx"], pandas_log["item_idx"]),
    ),
    shape=(pandas_log["user_idx"].max() + 1, items_count),
)

In [23]:
# Bring the test log to the driver as a pandas frame.
pandas_log_test = test.toPandas().copy()

# Sparse test-interaction matrix. Its shape is taken from the TRAINING log
# (pandas_log / items_count) so that train and test matrices have identical
# dimensions.
interactions_matrix_sparse_test = csr_matrix(
    (
        pandas_log_test["relevance"],
        (pandas_log_test["user_idx"], pandas_log_test["item_idx"]),
    ),
    shape=(pandas_log["user_idx"].max() + 1, items_count),
)

In [24]:
# Convert both CSR matrices to COO format for the LightFM evaluation call in
# the next cell.
train_matrix = interactions_matrix_sparse.tocoo()
test_matrix = interactions_matrix_sparse_test.tocoo()

In [26]:
# Mean precision@k from LightFM's own evaluator, at k=5 and then k=10.
# The training matrix is passed as train_interactions alongside the test
# matrix for both cutoffs.
for cutoff in (5, 10):
    print(
        precision_at_k(
            wrapper.model,
            test_matrix,
            train_interactions=train_matrix,
            k=cutoff,
        ).mean()
    )

0.24585305

In [28]:
# Candidate pairs to score: every distinct test user crossed with every
# distinct item that occurs in the training log.
test_pairs = (
    test.select("user_idx").distinct()
    .crossJoin(train.select("item_idx").distinct())
)

In [29]:
# Drop candidate pairs the user already interacted with in training.
# A left-anti join keeps only rows of test_pairs that have NO match in train,
# and returns only test_pairs' columns (user_idx, item_idx).
#
# The join is expressed on column names rather than on
# `test_pairs.col == train.col` expressions: test_pairs.item_idx is itself
# derived from `train`, so expression-based conditions depend on Spark's
# self-join ambiguity resolution to avoid comparing a column with itself.
# Joining by name is unambiguous and yields the same result.
filtered_test = test_pairs.join(
    train,
    on=["user_idx", "item_idx"],
    how="leftanti",
)

In [39]:
# Collect the unseen (user, item) candidate pairs to the driver as a pandas
# frame so they can be scored with the fitted LightFM model.
# NOTE(review): this is the cross product of test users and train items minus
# seen pairs — make sure the driver has enough memory for it.
filtered_test_pd = filtered_test.toPandas()

In [40]:
# Score every candidate pair with the fitted model and store the score as a
# new "relevance" column on the same frame (the frame is modified in place).
# NOTE(review): num_threads=40 is hard-coded and unrelated to `cores` (6) used
# elsewhere — confirm it suits the machine this runs on.
filtered_test_pd["relevance"] = wrapper.model.predict(
        user_ids=filtered_test_pd["user_idx"].to_numpy(),
        item_ids=filtered_test_pd["item_idx"].to_numpy(),
        num_threads=40,
)

In [41]:
# Preview the first three scored (user_idx, item_idx, relevance) rows.
filtered_test_pd.head(3)

Unnamed: 0,user_idx,item_idx,relevance
0,3,36,-1.255437
1,3,47,-0.742236
2,3,74,-3.778671


In [35]:
# Register the scored candidate pairs under the label 'warp' in the metric
# collector. This requires the "relevance" column added by the predict cell,
# so that cell must have run first.
prediction_quality.add_result('warp', filtered_test_pd)

In [37]:
# Display the metrics table: one row per registered result, one column per
# metric@cutoff.
prediction_quality.results

Unnamed: 0,HitRate@5,HitRate@10,HitRate@25,HitRate@100,HitRate@500,HitRate@1000,MAP@5,MAP@10,MAP@25,MAP@100,...,NDCG@25,NDCG@100,NDCG@500,NDCG@1000,Precision@5,Precision@10,Precision@25,Precision@100,Precision@500,Precision@1000
warp,0.535793,0.637023,0.752129,0.867865,0.950489,0.96878,0.192501,0.155537,0.115584,0.085876,...,0.216831,0.21959,0.326185,0.385131,0.245853,0.221823,0.183828,0.123929,0.062426,0.041921
