In [1]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [2]:
! pip install rs-datasets

# The notebook contains an example of features preprocessing with PySpark for RePlay LightFM model wrapper and includes:
1. Data loading and reindexing
2. Features preprocessing with pyspark
3. Building LightFM model based on interaction matrix and features
4. Model evaluation

In [3]:
import warnings
from optuna.exceptions import ExperimentalWarning
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)

In [4]:
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import array_contains, col, explode, split, substring

from replay.data_preparator import Indexer, DataPreparator
from replay.experiment import Experiment
from replay.metrics import HitRate, NDCG, MAP, Coverage
from replay.models import LightFMWrap
from replay.session_handler import State
from replay.splitters import DateSplitter
from replay.utils import get_log_info
from rs_datasets import MovieLens

In [5]:
spark = State().session
spark

22/07/05 13:33:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/07/05 13:33:34 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
22/07/05 13:33:35 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [6]:
spark.sparkContext.setLogLevel('ERROR')

In [7]:
K=10
SEED=1234

# 1. Data loading

We will use MovieLens 10m dataset from rs_datasets package, which contains a list of recommendations datasets.

In [8]:
data = MovieLens("10m")
data.info()

ratings


Unnamed: 0,user_id,item_id,rating,timestamp
0,1,122,5.0,838985046
1,1,185,5.0,838983525
2,1,231,5.0,838983392



items


Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance



tags


Unnamed: 0,user_id,item_id,tag,timestamp
0,15,4973,excellent!,1215184630
1,20,1747,politics,1188263867
2,20,1747,satire,1188263867





### 1.1 Convert interaction log to RePlay format

In [9]:
preparator = DataPreparator()

In [10]:
%%time
log = preparator.transform(columns_mapping={'user_id': 'user_id',
                                      'item_id': 'item_id',
                                      'relevance': 'rating',
                                      'timestamp': 'timestamp'
                                     }, 
                           data=data.ratings)
item_features = preparator.transform(columns_mapping={'item_id': 'item_id'}, 
                           data=data.items)

05-Jul-22 13:33:39, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.
05-Jul-22 13:33:47, replay, INFO: Column with ids of users or items is absent in mapping. The dataframe will be treated as a users'/items' features dataframe.


CPU times: user 348 ms, sys: 184 ms, total: 533 ms
Wall time: 9 s


In [11]:
log.show(2)

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|    122|      5.0|1996-08-02 11:24:06|
|      1|    185|      5.0|1996-08-02 10:58:45|
+-------+-------+---------+-------------------+
only showing top 2 rows



In [12]:
item_features.show(2)

+-------+----------------+--------------------+
|item_id|           title|              genres|
+-------+----------------+--------------------+
|      1|Toy Story (1995)|Adventure|Animati...|
|      2|  Jumanji (1995)|Adventure|Childre...|
+-------+----------------+--------------------+
only showing top 2 rows



<a id='indexing'></a>
### 1.2. Indexing

Convert given users' and items' identifiers (\_id) to integers starting at zero without gaps (\_idx) with Indexer class.

In [13]:
indexer = Indexer(user_col='user_id', item_col='item_id')

In [14]:
%%time
indexer.fit(users=log.select('user_id'),
           items=log.select('item_id').union(item_features.select('item_id')))

                                                                                

CPU times: user 179 ms, sys: 7.34 ms, total: 186 ms
Wall time: 3.36 s


In [15]:
%%time
log_replay = indexer.transform(df=log)
log_replay.show(2)

                                                                                

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|   65232|    1057|      5.0|1996-08-02 11:24:06|
|   65232|      76|      5.0|1996-08-02 10:58:45|
+--------+--------+---------+-------------------+
only showing top 2 rows

CPU times: user 202 ms, sys: 11.4 ms, total: 213 ms
Wall time: 3.75 s


In [16]:
%%time
item_features_replay = indexer.transform(df=item_features)
item_features_replay.show(2)

+--------+----------------+--------------------+
|item_idx|           title|              genres|
+--------+----------------+--------------------+
|      11|Toy Story (1995)|Adventure|Animati...|
|     117|  Jumanji (1995)|Adventure|Childre...|
+--------+----------------+--------------------+
only showing top 2 rows

CPU times: user 35 ms, sys: 631 µs, total: 35.7 ms
Wall time: 851 ms


### 1.3. Data split

In [17]:
# train/test split 
train_spl = DateSplitter(
    test_start=0.2,
    drop_cold_items=True,
    drop_cold_users=True,

)
train, test = train_spl.split(log_replay)
print('train info:\n', get_log_info(train, user_col='user_idx', item_col='item_idx'))
print('test info:\n', get_log_info(test, user_col='user_idx', item_col='item_idx'))

                                                                                

train info:
 total lines: 8000043, total users: 59522, total items: 8989




test info:
 total lines: 249418, total users: 3196, total items: 8180


                                                                                

In [18]:
train.is_cached

True

# 2. Features preprocessing with pyspark

#### Year

In [19]:
year = item_features_replay.withColumn('year', substring(col('title'), -5, 4).astype(IntegerType())).select('item_idx', 'year')
year.show(2)

+--------+----+
|item_idx|year|
+--------+----+
|      11|1995|
|     117|1995|
+--------+----+
only showing top 2 rows



#### Genres

In [20]:
genres = (
    item_features_replay
    .select(
        "item_idx",
        split("genres", "\|").alias("genres")
    )
)

In [21]:
genres.show()

+--------+--------------------+
|item_idx|              genres|
+--------+--------------------+
|      11|[Adventure, Anima...|
|     117|[Adventure, Child...|
|     274|   [Comedy, Romance]|
|    1382|[Comedy, Drama, R...|
|     320|            [Comedy]|
|      89|[Action, Crime, T...|
|     252|   [Comedy, Romance]|
|    2179|[Adventure, Child...|
|    1018|            [Action]|
|      51|[Action, Adventur...|
|     139|[Comedy, Drama, R...|
|    1112|    [Comedy, Horror]|
|    2403|[Animation, Child...|
|     682|             [Drama]|
|    1348|[Action, Adventur...|
|     189|      [Crime, Drama]|
|     111|[Comedy, Drama, R...|
|     880|[Comedy, Drama, T...|
|     129|            [Comedy]|
|    1039|[Action, Comedy, ...|
+--------+--------------------+
only showing top 20 rows



In [22]:
genres_list = (
    genres.select(explode("genres").alias("genre"))
    .distinct().filter('genre <> "(no genres listed)"')
    .toPandas()["genre"].tolist()
)

In [23]:
genres_list

['Documentary',
 'IMAX',
 'Adventure',
 'Animation',
 'Comedy',
 'Thriller',
 'Sci-Fi',
 'Musical',
 'Horror',
 'Action',
 'Fantasy',
 'War',
 'Mystery',
 'Drama',
 'Film-Noir',
 'Crime',
 'Western',
 'Romance',
 'Children']

In [24]:
item_features = genres
for genre in genres_list:
    item_features = item_features.withColumn(
        genre,
        array_contains(col("genres"), genre).astype(IntegerType())
    )
item_features = item_features.drop("genres").cache()
item_features.count()

10681

In [25]:
item_features = item_features.join(year, on='item_idx', how='inner')
item_features.cache()
item_features.count()

10681

In [26]:
item_features.show(2)

+--------+-----------+----+---------+---------+------+--------+------+-------+------+------+-------+---+-------+-----+---------+-----+-------+-------+--------+----+
|item_idx|Documentary|IMAX|Adventure|Animation|Comedy|Thriller|Sci-Fi|Musical|Horror|Action|Fantasy|War|Mystery|Drama|Film-Noir|Crime|Western|Romance|Children|year|
+--------+-----------+----+---------+---------+------+--------+------+-------+------+------+-------+---+-------+-----+---------+-----+-------+-------+--------+----+
|      11|          0|   0|        1|        1|     1|       0|     0|      0|     0|     0|      1|  0|      0|    0|        0|    0|      0|      0|       1|1995|
|     117|          0|   0|        1|        0|     0|       0|     0|      0|     0|     0|      1|  0|      0|    0|        0|    0|      0|      0|       1|1995|
+--------+-----------+----+---------+---------+------+--------+------+-------+------+------+-------+---+-------+-----+---------+-----+-------+-------+--------+----+
only showi

# 3. Building LightFM model based on interaction matrix and item features

In [27]:
model_feat = LightFMWrap(random_state=SEED, loss='warp', no_components=16)

In [28]:
%%time
model_feat.fit(train, item_features=item_features)

                                                                                

CPU times: user 2h 59min 9s, sys: 5.87 s, total: 2h 59min 15s
Wall time: 4min 51s


In [36]:
%%time
recs = model_feat.predict(
    log=train,
    k=K,
    users=test.select('user_idx').distinct(),
    item_features=item_features,
    filter_seen_items=True,
)
recs.cache()
recs.count()

05-Jul-22 13:42:08, replay, INFO: This model can't predict cold users, they will be ignored

CPU times: user 200 ms, sys: 40 ms, total: 240 ms
Wall time: 31.8 s


                                                                                

31960

# 4. Model evaluation

In [37]:
metrics = Experiment(test, {NDCG(): K,
                            MAP() : K,
                            HitRate(): [1, K],
                           Coverage(train): K})
 

                                                                                

In [38]:
%%time
metrics.add_result("LightFM_item_features", recs)
metrics.results

                                                                                

CPU times: user 105 ms, sys: 36 ms, total: 141 ms
Wall time: 8.26 s


Unnamed: 0,Coverage@10,HitRate@1,HitRate@10,MAP@10,NDCG@10
LightFM_item_features,0.118144,0.31164,0.645807,0.171449,0.25576
