Создадим ещё пару базовых решений, качество которых хотелось бы в итоге превзойти.

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.ml.feature import Word2Vec
from pyspark.ml.recommendation import ALS
from pyspark.sql import SparkSession
from sklearn.manifold import TSNE


%matplotlib inline


# Path to the CSV file with the user/item interactions sample.
LAUNCHER_SAMPLE_PATH = 'launcher_sample.csv'

Используем тот же самый класс для разбиения датасета на обучение и тест, чтобы всё было по-честному.

In [2]:
class UserBasedShuffleSplit(object):
    """
    Shuffle-split cross-validation for recommenders.

    Test and train sets are split across items of the same users: i.e.,
    if item I of user U is present in a test set, user U (with some other
    items) must also be present in the corresponding train set.

    Iterating over an instance yields ``n_splits`` pairs
    ``(train_idx, test_idx)`` of positional row indices into ``array``.
    Randomness comes from the global ``np.random`` state, so seed it
    beforehand if reproducible splits are needed.

    Parameters
    ----------
    array : pandas.DataFrame
        Interactions table; must contain ``'user'`` and ``'item'`` columns.
    test_user_size : int or float
        Number of users (int) or share of the splittable users (float)
        whose items are held out for testing.
    test_item_size : int
        Number of items held out for every selected user.
    n_splits : int
        Number of random splits produced by one iteration.
    """

    def __init__(self, array, test_user_size=0.1, test_item_size=1, n_splits=1):
        assert isinstance(test_user_size, (int, float)), 'test_user_size must be int or float'
        # The message must name the parameter that is actually checked.
        assert isinstance(test_item_size, int), 'test_item_size must be integer'

        # Work on a private copy and remember every row's position, so the
        # indices produced later can be used with ``DataFrame.iloc``.
        self.array = array[['user', 'item']].copy()
        self.array['index'] = np.arange(len(self.array))
        self.test_user_size = test_user_size
        self.test_item_size = test_item_size
        self.n_splits = n_splits

    def get_splittable_users(self):
        """
        Returns all suitable users (those who have sufficient
        amount of items), i.e. strictly more than `test_item_size`
        rows, so that at least one row is left for training.
        """
        users, counts = np.unique(self.array['user'], return_counts=True)
        return users[counts > self.test_item_size]

    def get_subset_to_split(self, splittable_users):
        """
        Returns a subset of original `array` which contains only the
        rows of randomly chosen test users. The users are drawn without
        replacement from `splittable_users`; a float `test_user_size`
        is interpreted as a share of `splittable_users`.
        """
        if isinstance(self.test_user_size, float):
            test_user_size = int(self.test_user_size * len(splittable_users))
        else:
            test_user_size = self.test_user_size
        test_users = np.random.choice(splittable_users, test_user_size, replace=False)
        # `Series.isin` replaces `np.in1d`, which is deprecated in NumPy 2.0;
        # `.values` keeps the mask purely positional.
        return self.array[self.array['user'].isin(test_users).values]

    def __len__(self):
        """Returns the number of splits produced by one iteration."""
        return self.n_splits

    def __iter__(self):
        """
        Yields `n_splits` pairs `(train_idx, test_idx)` of positional
        row indices; every row that is not in the test set goes to the
        train set.
        """
        self.splittable_users = self.get_splittable_users()
        all_idx = np.arange(len(self.array))
        for _ in range(self.n_splits):
            splittable_subset = self.get_subset_to_split(self.splittable_users)
            test_idx = np.hstack([
                self.split_user(subset)
                for _, subset in splittable_subset.groupby('user')
            ])
            train_idx = np.setdiff1d(all_idx, test_idx)
            yield train_idx, test_idx

    def split_user(self, subset):
        """
        Returns positional indices of `test_item_size` rows of a single
        user (`subset`), drawn at random without replacement.
        """
        return np.random.choice(subset['index'], self.test_item_size, replace=False)

И качество будем измерять так же.

In [3]:
def precision_score_at_k(df_test, df_predict, k):
    """
    Mean per-user precision of the first `k` recommendations.

    Both frames must have `user` and `item` columns. For every user of
    `df_predict` that is also present in `df_test`, the score is the
    number of distinct held-out items found among the user's first `k`
    rows of `df_predict`, divided by `min(k, number of held-out rows)`.
    Rows of `df_predict` are taken in the order given, so they are
    expected to be sorted by relevance already. Users missing from
    `df_test` are skipped.

    Raises `Exception` when the two frames share no users.
    """
    held_out = dict(iter(df_test.groupby('user')))
    per_user = []
    for uid, recs in df_predict.groupby('user'):
        if uid not in held_out:
            continue
        relevant = held_out[uid]
        top_k = recs['item'].iloc[:k]
        hits = np.intersect1d(relevant['item'], top_k)
        denominator = min(k, len(relevant))
        per_user.append(float(len(hits)) / denominator)
    if per_user:
        return np.mean(per_user)
    raise Exception("Users from test and train set don't intersect!")

Загружаем данные и разбиваем на обучение и тест.

In [4]:
# Read the interactions sample into a pandas DataFrame and preview it.
launcher = pd.read_csv(LAUNCHER_SAMPLE_PATH)
launcher.head()

Unnamed: 0,user,item
0,12156,5527
1,7982,15525
2,5614,13600
3,465,14937
4,465,9556


In [5]:
# Hold out 10 items for each of 1000 randomly chosen users; only the first
# (and, with the default n_splits=1, the only) split is used.
cv = UserBasedShuffleSplit(launcher, test_user_size=1000, test_item_size=10)
train_idx, test_idx = next(iter(cv))

# The indices are positional, hence `iloc`.
df_train, df_test = launcher.iloc[train_idx], launcher.iloc[test_idx]

In [6]:
# Preview the first rows of the training part.
df_train.head()

Unnamed: 0,user,item
0,12156,5527
1,7982,15525
2,5614,13600
3,465,14937
4,465,9556


In [7]:
# Preview the first rows of the held-out part.
df_test.head()

Unnamed: 0,user,item
393272,22,1145
385711,22,6964
392536,22,5055
380764,22,8882
380757,22,1778


Для создания базовых решений воспользуемся реализациями популярных алгоритмов ALS и Word2Vec из библиотеки Spark ML.

In [8]:
# Create the Spark session (or reuse an already running one) for both baselines.
spark = SparkSession.builder.appName('als_and_w2v_baselines').getOrCreate()

## 1. ALS

Сформируем обучающую и тестовую выборки.

In [9]:
# Training interactions as a Spark DataFrame; every observed (user, item)
# pair gets a constant rating of 1.
als_train = (
    spark.createDataFrame(df_train)
    .withColumn('rating', f.lit(1))
    .cache()
)
als_train.orderBy('user').show()

+----+-----+------+
|user| item|rating|
+----+-----+------+
|   0| 3719|     1|
|   0| 1376|     1|
|   0| 6736|     1|
|   0| 1367|     1|
|   0| 1332|     1|
|   0|20577|     1|
|   0|21315|     1|
|   0| 6659|     1|
|   0| 1338|     1|
|   0|21496|     1|
|   0| 1334|     1|
|   0| 1454|     1|
|   0|20579|     1|
|   0| 4943|     1|
|   0| 1373|     1|
|   0|13535|     1|
|   0| 1379|     1|
|   0|20174|     1|
|   0|21258|     1|
|   0|11478|     1|
+----+-----+------+
only showing top 20 rows



In [10]:
# Held-out interactions as a cached Spark DataFrame (no rating column).
als_test = spark.createDataFrame(df_test).cache()
als_test.orderBy('user').show()

+----+-----+
|user| item|
+----+-----+
|  22| 8882|
|  22| 1227|
|  22| 7056|
|  22|   34|
|  22| 5055|
|  22| 1145|
|  22|10504|
|  22| 1778|
|  22|11058|
|  22| 6964|
|  27| 4048|
|  27|10841|
|  27|10113|
|  27|17436|
|  27| 4755|
|  27|18861|
|  27|10789|
|  27|12259|
|  27|12865|
|  27|18334|
+----+-----+
only showing top 20 rows



Обучим модель и сделаем предсказание.

In [11]:
# ALS estimator in implicit-feedback mode: 10 latent factors, 5 iterations,
# fixed seed. Column names are left at the ALS defaults — presumably
# 'user', 'item' and 'rating', matching `als_train`; verify against the
# installed PySpark version.
als = ALS(rank=10, maxIter=5, implicitPrefs=True, seed=1707)

In [12]:
# NOTE(review): the recorded run of this cell failed with
# "Param coldStartStrategy does not exist" (see the traceback below), so
# `als_model` was never created in that session. Presumably the Python
# PySpark package is newer than the Spark JVM it talks to — TODO confirm
# that both versions match before re-running.
als_model = als.fit(als_train)

Py4JJavaError: An error occurred while calling o63.getParam.
: java.util.NoSuchElementException: Param coldStartStrategy does not exist.
	at org.apache.spark.ml.param.Params$$anonfun$getParam$2.apply(params.scala:601)
	at org.apache.spark.ml.param.Params$$anonfun$getParam$2.apply(params.scala:601)
	at scala.Option.getOrElse(Option.scala:121)
	at org.apache.spark.ml.param.Params$class.getParam(params.scala:600)
	at org.apache.spark.ml.PipelineStage.getParam(Pipeline.scala:42)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:745)


In [None]:
# Score the (user, item) pairs of the training DataFrame with the fitted
# model. The original cell referenced `train` and a bare `desc`, neither of
# which is defined in this notebook; `als_train` and `f.desc` are the
# matching names.
predictions = als_model.transform(als_train)

predictions.cache()
# Show the highest-scored items first within each user.
predictions.orderBy(f.col('user'), f.desc('prediction')).show()

## 2. Word2Vec

Сформируем обучающую и тестовую выборки.

In [13]:
# One row per user holding the list of that user's items. Items are cast to
# strings so that each list can serve as a "sentence" of tokens for Word2Vec.
w2v_train = (
    spark.createDataFrame(df_train)
    .select(f.col('user'), f.col('item').cast(t.StringType()))
    .orderBy('user', 'item')
    .groupBy('user')
    .agg(f.collect_list('item').alias('items'))
    .cache()
)
w2v_train.orderBy('user').show()

+----+--------------------+
|user|               items|
+----+--------------------+
|   0|[11478, 1332, 133...|
|   1|[1042, 11120, 117...|
|   2|[10386, 10813, 11...|
|   3|[10741, 10948, 11...|
|   4|[10495, 10581, 10...|
|   5|[10181, 12330, 14...|
|   6|[10348, 10579, 10...|
|   7|[10730, 11012, 11...|
|   8|[10495, 10572, 10...|
|   9|[10411, 11882, 12...|
|  10|[10752, 10759, 10...|
|  11|[11015, 11131, 11...|
|  12|[10063, 10251, 10...|
|  13|[10034, 10948, 11...|
|  14|[10285, 12656, 13...|
|  15|[10338, 11657, 11...|
|  16|[12249, 12505, 12...|
|  17|[10948, 11132, 11...|
|  18|[10319, 10320, 10...|
|  19|[10948, 11209, 12...|
+----+--------------------+
only showing top 20 rows



In [14]:
# Check that `items` is an array of strings.
w2v_train.printSchema()

root
 |-- user: long (nullable = true)
 |-- items: array (nullable = true)
 |    |-- element: string (containsNull = true)



Обучаем модель.

In [15]:
# Word2Vec estimator over the per-user item lists: 10-dimensional vectors,
# context window of 10; minCount=0 sets no minimum frequency for an item.
item2Vec = Word2Vec(vectorSize=10, minCount=0, inputCol='items', outputCol='result', windowSize=10)

In [16]:
# NOTE: this rebinds `item2Vec` from the estimator to whatever `fit` returns
# (the fitted model); the estimator object is no longer reachable by name.
item2Vec = item2Vec.fit(w2v_train)

Метод transform возвращает векторные представления наборов приложений пользователей (для каждого пользователя — усреднение векторов всех его приложений).

In [17]:
# Add the `result` column with one vector per user and cache the DataFrame.
result = item2Vec.transform(w2v_train)
result.cache()

DataFrame[user: bigint, items: array<string>, result: vector]

In [18]:
# Preview the per-user vectors.
result.show()

+----+--------------------+--------------------+
|user|               items|              result|
+----+--------------------+--------------------+
|  26|[10040, 10948, 11...|[-0.1730425206323...|
|  29|[10140, 12249, 13...|[0.06771651612451...|
| 474|[10179, 10961, 11...|[-0.0976660244844...|
| 964|[10229, 12734, 13...|[-0.2425984864433...|
|1677|[10923, 10924, 13...|[-0.1463075963950...|
|1697|[12234, 12939, 12...|[-0.2148559339344...|
|1806|[12826, 13977, 14...|[-0.2647646979793...|
|1950|[12193, 1227, 140...|[-0.1251080517585...|
|2040|[10326, 10348, 11...|[-0.1833049550100...|
|2214|[10780, 12940, 13...|[-0.1368040009623...|
|2250|[10720, 11658, 13...|[-0.1577911119569...|
|2453|[10010, 10016, 12...|[-0.0538457855158...|
|2509|[10020, 11658, 11...|[0.10259561520069...|
|2529|[10187, 10565, 11...|[-0.1790539913746...|
|2927|[10627, 10948, 11...|[-0.3726906167343...|
|3091|[10058, 10670, 10...|[-0.1290517181320...|
|3506|[11033, 11599, 11...|[-0.2458223892442...|
|3764|[10948, 11189,

Метод getVectors возвращает векторные представления приложений.

In [19]:
# Preview the learned per-item vectors (`word` holds the item id as a string).
item2Vec.getVectors().show()

+-----+--------------------+
| word|              vector|
+-----+--------------------+
|10292|[0.04394483566284...|
|19125|[-0.0762670785188...|
| 5451|[0.03049250692129...|
| 4018|[0.09986700862646...|
|17319|[-0.6657822728157...|
|20778|[-0.0970354080200...|
|17079|[-0.5113746523857...|
| 9936|[0.06084300205111...|
|13172|[-0.0211183037608...|
|17840|[-0.1792943924665...|
|10304|[-0.3674074113368...|
|20323|[-0.1198502331972...|
|16997|[-0.8670586943626...|
|14779|[-0.2956146597862...|
|15822|[-0.8894290328025...|
| 4056|[-0.0083574801683...|
|15469|[-0.0352986119687...|
|12209|[-0.2829174101352...|
|21299|[-0.0479806289076...|
|  710|[-0.0756228566169...|
+-----+--------------------+
only showing top 20 rows

