In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from math import floor
import psutil
import os

import pandas as pd

from pyspark.sql import SparkSession

from replay.data_preparator import DataPreparator, Indexer
from replay.experiment import Experiment
from replay.metrics import Coverage, HitRate, NDCG, MAP
from replay.model_handler import save, load, save_indexer, load_indexer
from replay.models import ALSWrap, ItemKNN, SLIM
from replay.session_handler import get_spark_session, State 
from replay.splitters import UserSplitter
from replay.utils import convert2spark, get_log_info
from replay.filters import filter_by_min_count, filter_out_low_ratings


libgomp: Invalid value for environment variable OMP_NUM_THREADS

libgomp: Invalid value for environment variable OMP_NUM_THREADS


In [3]:
K = 5
SEED = 42

In [4]:
spark_memory = floor(psutil.virtual_memory().total / 1024**3 * 0.7)
driver_memory = f"{spark_memory}g"

shuffle_partitions = os.cpu_count() * 3
user_home = os.environ["HOME"]

session = (
    SparkSession.builder.config("spark.driver.memory", driver_memory)
        .config(
            "spark.driver.extraJavaOptions",
            "-Dio.netty.tryReflectionSetAccessible=true",
        )
        .config("spark.sql.shuffle.partitions", str(shuffle_partitions))
        .config("spark.local.dir", os.path.join(user_home, "tmp"))
        .config("spark.driver.maxResultSize", "4g")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .config("spark.driver.host", "localhost")
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .config("spark.kryoserializer.buffer.max", "256m")
        # .config("spark.worker.cleanup.enabled", "true")
        # .config("spark.worker.cleanup.interval", "5")
        # .config("spark.worker.cleanup.appDataTtl", "5")
        .master("local[*]")
        .enableHiveSupport()
        .getOrCreate()
)

23/08/22 17:33:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/22 17:33:34 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [5]:
spark = State(session).session
spark.sparkContext.setLogLevel('ERROR')
spark

In [6]:
spark.sparkContext.getConf().getAll()

[('spark.app.startTime', '1692714814522'),
 ('spark.driver.memory', '1402g'),
 ('spark.sql.shuffle.partitions', '768'),
 ('spark.executor.id', 'driver'),
 ('spark.driver.host', 'localhost'),
 ('spark.driver.port', '37049'),
 ('spark.app.name', 'pyspark-shell'),
 ('spark.driver.bindAddress', '127.0.0.1'),
 ('spark.driver.extraJavaOptions',
  '-Dio.netty.tryReflectionSetAccessible=true'),
 ('spark.sql.warehouse.dir',
  'file:/home/jovyan/n.belousov/sber-recsys/notebooks/test_notebooks/spark-warehouse'),
 ('spark.sql.catalogImplementation', 'hive'),
 ('spark.rdd.compress', 'True'),
 ('spark.local.dir', '/home/jovyan/tmp'),
 ('spark.serializer.objectStreamReset', '100'),
 ('spark.kryoserializer.buffer.max', '256m'),
 ('spark.master', 'local[*]'),
 ('spark.submit.pyFiles', ''),
 ('spark.submit.deployMode', 'client'),
 ('spark.app.id', 'local-1692714815233'),
 ('spark.ui.showConsoleProgress', 'true'),
 ('spark.sql.execution.arrow.pyspark.enabled', 'true'),
 ('spark.driver.maxResultSize', '4g

In [7]:
def print_config_param(session, conf_name):
    # get current spark session configuration:
    conf = session.sparkContext.getConf().getAll()
    print(conf)
    # get num partitions
    print(f'{conf_name}: {dict(conf)[conf_name]}')

In [8]:
print_config_param(spark, 'spark.sql.shuffle.partitions')

[('spark.driver.memory', '1402g'), ('spark.sql.shuffle.partitions', '768'), ('spark.repl.local.jars', 'file:/home/jovyan/n.belousov/sber-recsys/notebooks/test_notebooks/jars/replay_2.12-0.1_spark_3.1.jar'), ('spark.jars', 'jars/replay_2.12-0.1_spark_3.1.jar'), ('spark.executor.id', 'driver'), ('spark.driver.host', 'localhost'), ('spark.app.startTime', '1692714639594'), ('spark.app.name', 'pyspark-shell'), ('spark.driver.bindAddress', '127.0.0.1'), ('spark.driver.extraJavaOptions', '-Dio.netty.tryReflectionSetAccessible=true'), ('spark.driver.port', '32877'), ('spark.sql.warehouse.dir', 'file:/home/jovyan/n.belousov/sber-recsys/notebooks/test_notebooks/spark-warehouse'), ('spark.sql.catalogImplementation', 'hive'), ('spark.rdd.compress', 'True'), ('spark.local.dir', '/home/jovyan/tmp'), ('spark.serializer.objectStreamReset', '100'), ('spark.master', 'local[*]'), ('spark.submit.pyFiles', ''), ('spark.kryoserializer.buffer.max', '256m'), ('spark.submit.deployMode', 'client'), ('spark.app.

## Data preprocessing

In [7]:
df = pd.read_parquet("../../data/amazon_cds/CDs_and_Vinyl.parquet")

In [8]:
df.head(10)

Unnamed: 0,userId,itemId,rating,timestamp
0,1393774,A171I27YBM4FL6,5.0,1461888000
1,1393774,A1H1DL4K669VQ9,5.0,1461888000
2,1393774,A23WIHT5886G36,5.0,1461024000
3,1393774,A3SZNOJP8OL26X,5.0,1459296000
4,1393774,A3V5XBBT7OZG5G,5.0,1456185600
5,1393774,A3SNL7UJY7GWBI,5.0,1455148800
6,1393774,A3478QRKQDOPQ2,5.0,1448668800
7,1393774,A3CP0CNKNFCYBZ,4.0,1437177600
8,1393774,A3OIIDZ137NJOU,5.0,1436572800
9,1393774,A3GVAG32NMMYT4,4.0,1432598400


## Data preporator

In [9]:
preparator = DataPreparator()

In [11]:
log = preparator.transform(columns_mapping={
    'user_id': 'userId',
    'item_id': 'itemId',
    'relevance': 'rating',
    'timestamp': 'timestamp'
}, data=df)

01-Aug-23 13:41:53, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.


In [12]:
log.show(2)

+----------+--------------+---------+-------------------+
|   user_id|       item_id|relevance|          timestamp|
+----------+--------------+---------+-------------------+
|0001393774|A171I27YBM4FL6|      5.0|2016-04-29 00:00:00|
|0001393774|A1H1DL4K669VQ9|      5.0|2016-04-29 00:00:00|
+----------+--------------+---------+-------------------+
only showing top 2 rows



In [13]:
get_log_info(log, user_col='user_id', item_col='item_id')

                                                                                

'total lines: 4543369, total users: 434060, total items: 1944316'

## Filtering

In [14]:
log = filter_out_low_ratings(log, .2)
get_log_info(log, user_col='user_id', item_col='item_id')

                                                                                

'total lines: 4543369, total users: 434060, total items: 1944316'

In [15]:
log = filter_by_min_count(log, num_entries=5, group_by='user_id')
get_log_info(log, user_col='user_id', item_col='item_id')

01-Aug-23 13:42:11, replay, INFO: current threshold removes 0.11964293457123998% of data
                                                                                

'total lines: 3999787, total users: 145522, total items: 1792030'

## Indexing

In [16]:
indexer = Indexer(user_col='user_id', item_col='item_id')

In [17]:
indexer.fit(users=log.select('user_id'), items=log.select('item_id'))

                                                                                

In [18]:
log_replay = indexer.transform(df=log)

                                                                                

In [19]:
log_replay.show(5)

[Stage 63:>                                                         (0 + 1) / 1]

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|  108495|  687152|      5.0|2017-01-25 00:00:00|
|  108495|   82697|      5.0|2017-01-20 00:00:00|
|  108495|  781273|      3.0|2017-01-11 00:00:00|
|  108495|  890670|      1.0|2014-08-28 00:00:00|
|  108495|  106139|      3.0|2014-08-27 00:00:00|
+--------+--------+---------+-------------------+
only showing top 5 rows



                                                                                

## Split

In [20]:
splitter = UserSplitter(
    drop_cold_items=True,
    drop_cold_users=True,
    item_test_size=K,
    user_test_size=500,
    seed=SEED,
    shuffle=True
)
train, test = splitter.split(log_replay)

                                                                                

In [20]:
print(train.count(), test.count())



19997763 2500


                                                                                

In [21]:
test.is_cached

True

## Models training

In [22]:
item_knn = ItemKNN()

In [23]:
item_knn.fit(train)

                                                                                

In [24]:
recs = item_knn.predict(train, K, users=test.select('user_idx').distinct(), filter_seen_items=True)

                                                                                

In [26]:
recs.show(5)

+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|   12853|      40|  13.4743145837735|
|   12853|      89| 12.78011963595636|
|   12853|     108| 8.871822092603605|
|   12853|      20| 8.112028529202195|
|   12853|      49|5.9643452591206785|
+--------+--------+------------------+
only showing top 5 rows

