<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Introduction" data-toc-modified-id="Introduction-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Introduction</a></span></li><li><span><a href="#Setup" data-toc-modified-id="Setup-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Setup</a></span></li><li><span><a href="#Libraries" data-toc-modified-id="Libraries-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Libraries</a></span></li><li><span><a href="#Data" data-toc-modified-id="Data-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#Split" data-toc-modified-id="Split-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Split</a></span></li></ul></div>

# Introduction

This notebook splits the MovieLens 20M ratings dataset into a training set and a test set.

# Setup

In [1]:
%%capture
# Move up to the project root so that the local `src` package can be imported
%cd ..

# Libraries

In [2]:
from src import data                  # project-local data loading helpers
import pyspark.sql.functions as f
from lenskit import crossfold as xf   # LensKit's train/test splitting utilities

In [3]:
from pyspark.sql import SparkSession

# Data

In [4]:
df = data.get_data("/tmp/ml-20m/ratings.csv")
df.show(5)

+------+-------+------+-------------------+
|userId|movieId|rating|          timestamp|
+------+-------+------+-------------------+
|     1|      2|   3.5|2005-04-02 23:53:47|
|     1|     29|   3.5|2005-04-02 23:31:16|
|     1|     32|   3.5|2005-04-02 23:33:39|
|     1|     47|   3.5|2005-04-02 23:32:07|
|     1|     50|   3.5|2005-04-02 23:29:40|
+------+-------+------+-------------------+
only showing top 5 rows



In [5]:
# LensKit expects the columns to be named `user`, `item` and `rating`
df = df \
    .withColumnRenamed("userId", "user") \
    .withColumnRenamed("movieId", "item") \
    .cache()
df.show(5)

+----+----+------+-------------------+
|user|item|rating|          timestamp|
+----+----+------+-------------------+
|   1|   2|   3.5|2005-04-02 23:53:47|
|   1|  29|   3.5|2005-04-02 23:31:16|
|   1|  32|   3.5|2005-04-02 23:33:39|
|   1|  47|   3.5|2005-04-02 23:32:07|
|   1|  50|   3.5|2005-04-02 23:29:40|
+----+----+------+-------------------+
only showing top 5 rows



# Split

Check the timestamp range:

In [6]:
df.selectExpr("min(timestamp)", "max(timestamp)").show()

+-------------------+-------------------+
|     min(timestamp)|     max(timestamp)|
+-------------------+-------------------+
|1995-01-09 11:46:44|2015-03-31 06:40:02|
+-------------------+-------------------+



Training on the whole dataset may take too long, so we keep only ratings made on or after 2010-01-01 to reduce the dataset size:

In [7]:
df = df.filter(f.expr("timestamp >= '2010-01-01'"))
df.show(5)

+----+-----+------+-------------------+
|user| item|rating|          timestamp|
+----+-----+------+-------------------+
|  11| 4226|   5.0|2011-01-12 01:35:59|
|  11| 5971|   5.0|2011-01-12 01:36:41|
|  11| 6291|   5.0|2011-01-12 01:35:13|
|  11| 7153|   5.0|2011-01-12 01:35:32|
|  11|30707|   5.0|2011-01-12 01:36:16|
+----+-----+------+-------------------+
only showing top 5 rows



In [8]:
%%time
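# Collect to pandas and hold out 20% of each user's ratings;
# with a single partition, every user is a test user
train_df, test_df = next(xf.partition_users(df.toPandas(), 1, xf.SampleFrac(0.2)))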

Wall time: 13min 24s
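With a single partition, `partition_users` treats every user as a test user, and `SampleFrac(0.2)` moves roughly 20% of each user's ratings into the test set. The sketch below approximates that behaviour in plain pandas on a toy frame, so the split logic can be checked without Spark or LensKit. It is not LensKit's implementation: the helper name `split_user_fraction` and the seed are hypothetical, pandas' rounding of per-user sample sizes may differ from LensKit's, and `groupby(...).sample` requires pandas 1.1 or later.

```python
import pandas as pd


def split_user_fraction(ratings: pd.DataFrame, frac: float = 0.2, seed: int = 42):
    """Hypothetical helper: hold out `frac` of each user's ratings as the test set.

    Approximates partition_users(ratings, 1, SampleFrac(frac)): every user is a
    test user, and round(frac * n_user_ratings) of their rows go to the test set.
    Assumes the frame has a unique index.
    """
    # Sample a fraction of the rows within each user's group
    test = ratings.groupby("user").sample(frac=frac, random_state=seed)
    # Every row that was not sampled stays in the training set
    train = ratings.drop(index=test.index)
    return train, test


# Toy ratings frame using the column names from the renaming step above
ratings = pd.DataFrame({
    "user": [1] * 5 + [2] * 10,
    "item": list(range(5)) + list(range(10)),
    "rating": [3.5] * 15,
})
train, test = split_user_fraction(ratings)
print(test.groupby("user").size())
```

Sampling within each user, rather than across the whole table, guarantees that every user with enough ratings appears in both sets, which is what a per-user top-N evaluation needs.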


In [9]:
nrows_train, _ = train_df.shape
nrows_test, _ = test_df.shape

print(f"Number of rows in training set: {nrows_train:,d}")
print(f"Number of rows in test set: {nrows_test:,d}")

Number of rows in training set: 3,078,276
Number of rows in test set: 769,271
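As a quick sanity check, the held-out share can be computed from the two counts above: 769,271 / (3,078,276 + 769,271) ≈ 0.1999, which is very close to the requested 20%. The small shortfall is likely due to each user's sample size being rounded to a whole number of ratings. The helper name `held_out_fraction` is hypothetical.

```python
def held_out_fraction(n_train: int, n_test: int) -> float:
    """Share of all ratings that ended up in the test set."""
    return n_test / (n_train + n_test)


# Row counts printed by the cell above
frac = held_out_fraction(3_078_276, 769_271)
print(f"Test fraction: {frac:.2%}")  # prints "Test fraction: 19.99%"
```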


Save the results:

In [10]:
spark = SparkSession.builder.getOrCreate()

In [11]:
%%time
# Convert the pandas splits back to Spark DataFrames and write them as Parquet
spark.createDataFrame(train_df).write.mode("overwrite").parquet("/tmp/ml-20m/train_df.parquet")
spark.createDataFrame(test_df).write.mode("overwrite").parquet("/tmp/ml-20m/test_df.parquet")

Wall time: 21min 48s
