In [4]:
import pandas as pd
import pyspark
import unittest
from mmlspark.TrainTestSplit import *
from mmlspark.evaluate import *
from pyspark.ml.tuning import *
from pyspark.sql.functions import col
from pyspark.sql.types import *

In [6]:
def getDF():
    """Return the synthetic ratings table as a Spark DataFrame.

    A SparkSession with master ``local[*]`` and app name ``EvaluationTest``
    is fetched (or created), then used to convert a 21-row pandas frame with
    the columns ``customerID``, ``itemID``, ``rating`` and ``timeStamp``.
    Every rating is 5, and the timestamps are the consecutive days
    2018-01-01 through 2018-01-21 rendered as ``YYYYMMDD`` strings.
    """
    from pyspark.sql import SparkSession

    session = (
        SparkSession.builder
        .master("local[*]")
        .appName("EvaluationTest")
        .getOrCreate()
    )

    # One entry per rating event: customers 1, 2 and 4 have five rows each,
    # customer 3 has six.
    customers = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4]
    items = [3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6]
    ratings = [5] * len(customers)
    # One calendar day per row, formatted as a compact date string.
    days = pd.date_range('2018-01-01', '2018-01-21').strftime('%Y%m%d').tolist()

    ratings_pdf = pd.DataFrame({
        'customerID': customers,
        'itemID': items,
        'rating': ratings,
        'timeStamp': days,
    })
    return session.createDataFrame(ratings_pdf)


In [10]:
# Start from a fresh copy of the synthetic ratings.
dfs_rating = getDF()

# Attach the filter to the Spark DataFrame class so it can be called as a method.
pyspark.sql.DataFrame.min_rating_filter = TrainTestSplit.min_rating_filter

# Filter per customer with a threshold of 6 and count the surviving rows.
# The recorded result is 6; presumably only customer 3 (the one customer with
# six ratings in the synthetic data) passes -- confirm against TrainTestSplit.
dfs_rating.min_rating_filter(
    min_rating=6,
    by_customer=True,
).count()

6

In [12]:
# Start from a fresh copy of the synthetic ratings.
dfs_rating = getDF()

# Attach the stratified splitter to the Spark DataFrame class.
pyspark.sql.DataFrame.stratified_split = TrainTestSplit.stratified_split

# Ratio-driven variant (ratio=0.5).
# NOTE(review): dfs_train/dfs_test are rebound by the next call before being
# shown, so this result never appears in the output. The other split cells
# call show() after each variant -- confirm whether that was intended here.
dfs_train, dfs_test = dfs_rating.stratified_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=False,
    ratio=0.5,
)

# Fixed-sample variant (sample=2); this is the only result displayed.
dfs_train, dfs_test = dfs_rating.stratified_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=True,
    sample=2,
)
dfs_train.show()
dfs_test.show()

+----------+------+------+
|customerID|itemID|rating|
+----------+------+------+
|         1|     6|   5.0|
|         1|     4|   5.0|
|         1|     3|   5.0|
|         2|     4|   5.0|
|         2|     2|   5.0|
|         2|     1|   5.0|
|         3|     6|   5.0|
|         3|     3|   5.0|
|         3|     5|   5.0|
|         3|     2|   5.0|
|         4|     5|   5.0|
|         4|     3|   5.0|
|         4|     2|   5.0|
+----------+------+------+

+----------+------+------+
|customerID|itemID|rating|
+----------+------+------+
|         1|     7|   5.0|
|         1|     5|   5.0|
|         2|     5|   5.0|
|         2|     3|   5.0|
|         3|     7|   5.0|
|         3|     4|   5.0|
|         4|     6|   5.0|
|         4|     4|   5.0|
+----------+------+------+



In [13]:
# Start from a fresh copy of the synthetic ratings.
dfs_rating = getDF()

# Attach the chronological splitter to the Spark DataFrame class.
pyspark.sql.DataFrame.chronological_split = TrainTestSplit.chronological_split

# Ratio-driven variant (ratio=0.3): display the train part, then the test part.
dfs_train, dfs_test = dfs_rating.chronological_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=False,
    ratio=0.3,
)
dfs_train.show()
dfs_test.show()

# Fixed-sample variant (sample=3): same call with the other sizing mode.
dfs_train, dfs_test = dfs_rating.chronological_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=True,
    sample=3,
)
dfs_train.show()
dfs_test.show()

+----------+------+---------+
|customerID|itemID|timeStamp|
+----------+------+---------+
|         1|     5| 20180103|
|         1|     4| 20180102|
|         1|     3| 20180101|
|         3|     5| 20180114|
|         3|     4| 20180113|
|         3|     3| 20180112|
|         3|     2| 20180111|
|         2|     3| 20180108|
|         2|     2| 20180107|
|         2|     1| 20180106|
|         4|     4| 20180119|
|         4|     3| 20180118|
|         4|     2| 20180117|
+----------+------+---------+

+----------+------+---------+
|customerID|itemID|timeStamp|
+----------+------+---------+
|         1|     7| 20180105|
|         1|     6| 20180104|
|         3|     7| 20180116|
|         3|     6| 20180115|
|         2|     5| 20180110|
|         2|     4| 20180109|
|         4|     6| 20180121|
|         4|     5| 20180120|
+----------+------+---------+

+----------+------+---------+
|customerID|itemID|timeStamp|
+----------+------+---------+
|         1|     4| 20180102|
|       

In [14]:
# Start from a fresh copy of the synthetic ratings.
dfs_rating = getDF()

# Attach the non-overlapping splitter to the Spark DataFrame class.
pyspark.sql.DataFrame.non_overlapping_split = TrainTestSplit.non_overlapping_split

# Ratio-driven variant (ratio=0.5): display the train part, then the test part.
dfs_train, dfs_test = dfs_rating.non_overlapping_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=False,
    ratio=0.5,
)
dfs_train.show()
dfs_test.show()

# Fixed-sample variant (sample=3): same call with the other sizing mode.
dfs_train, dfs_test = dfs_rating.non_overlapping_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=True,
    sample=3,
)
dfs_train.show()
dfs_test.show()

+----------+------+------+---------+
|customerID|itemID|rating|timeStamp|
+----------+------+------+---------+
|         1|     3|     5| 20180101|
|         1|     4|     5| 20180102|
|         1|     5|     5| 20180103|
|         1|     6|     5| 20180104|
|         1|     7|     5| 20180105|
|         3|     2|     5| 20180111|
|         3|     3|     5| 20180112|
|         3|     4|     5| 20180113|
|         3|     5|     5| 20180114|
|         3|     6|     5| 20180115|
|         3|     7|     5| 20180116|
|         4|     2|     5| 20180117|
|         4|     3|     5| 20180118|
|         4|     4|     5| 20180119|
|         4|     5|     5| 20180120|
|         4|     6|     5| 20180121|
+----------+------+------+---------+

+----------+------+------+---------+
|customerID|itemID|rating|timeStamp|
+----------+------+------+---------+
|         2|     1|     5| 20180106|
|         2|     2|     5| 20180107|
|         2|     3|     5| 20180108|
|         2|     4|     5| 20180109|


In [15]:
# Start from a fresh copy of the synthetic ratings.
dfs_rating = getDF()

# Attach the random splitter to the Spark DataFrame class.
pyspark.sql.DataFrame.random_split = TrainTestSplit.random_split

# NOTE(review): no seed is passed to either call, so the rows landing in each
# part may differ between runs -- confirm whether random_split accepts a seed.

# Ratio-driven variant (ratio=0.5): display the train part, then the test part.
dfs_train, dfs_test = dfs_rating.random_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=False,
    ratio=0.5,
)
dfs_train.show()
dfs_test.show()

# Fixed-sample variant (sample=3): same call with the other sizing mode.
dfs_train, dfs_test = dfs_rating.random_split(
    min_rating=3,
    by_customer=True,
    fixed_test_sample=True,
    sample=3,
)
dfs_train.show()
dfs_test.show()

+----------+------+------+---------+
|customerID|itemID|rating|timeStamp|
+----------+------+------+---------+
|         1|     5|     5| 20180103|
|         1|     7|     5| 20180105|
|         3|     2|     5| 20180111|
|         3|     3|     5| 20180112|
|         3|     4|     5| 20180113|
|         3|     5|     5| 20180114|
|         3|     7|     5| 20180116|
|         2|     3|     5| 20180108|
|         2|     5|     5| 20180110|
|         4|     2|     5| 20180117|
|         4|     5|     5| 20180120|
+----------+------+------+---------+

+----------+------+------+---------+
|customerID|itemID|rating|timeStamp|
+----------+------+------+---------+
|         1|     3|     5| 20180101|
|         1|     4|     5| 20180102|
|         1|     6|     5| 20180104|
|         3|     6|     5| 20180115|
|         2|     1|     5| 20180106|
|         2|     2|     5| 20180107|
|         2|     4|     5| 20180109|
|         4|     3|     5| 20180118|
|         4|     4|     5| 20180119|
